In [1]:
import torch
from torch import Tensor
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import numpy as np
import torch as t
import tqdm
#functional
import torch.nn.functional as F
import matplotlib.pyplot as plt
import plotly.io as pio
import plotly.express as px
import plotly.graph_objects as go
from functools import partial
import einops
import circuitsvis as cv

# Graphing stuff

In [2]:
def to_numpy(tensor, flat=False):
    if type(tensor)!=torch.Tensor:
        return tensor
    if flat:
        return tensor.flatten().detach().cpu().numpy()
    else:
        return tensor.detach().cpu().numpy()

def imshow(tensor, xaxis=None, yaxis=None, animation_name='Snapshot', **kwargs):
    tensor = torch.squeeze(tensor)
    px.imshow(to_numpy(tensor, flat=False),aspect='auto', 
              labels={'x':xaxis, 'y':yaxis, 'animation_name':animation_name}, 
              **kwargs).show()
# Set default colour scheme
imshow = partial(imshow, color_continuous_scale='Blues')
# Creates good defaults for showing divergent colour scales (ie with both 
# positive and negative values, where 0 is white)
imshow_div = partial(imshow, color_continuous_scale='RdBu', color_continuous_midpoint=0.0)
# Presets a bunch of defaults to imshow to make it suitable for showing heatmaps 
# of activations with x axis being input 1 and y axis being input 2.
inputs_heatmap = partial(imshow, xaxis='Input 1', yaxis='Input 2', color_continuous_scale='RdBu', color_continuous_midpoint=0.0)

def line(x, y=None, hover=None, xaxis='', yaxis='', **kwargs):
    if type(y)==torch.Tensor:
        y = to_numpy(y, flat=True)
    if type(x)==torch.Tensor:
        x = to_numpy(x, flat=True)
    fig = px.line(x, y=y, hover_name=hover, **kwargs)
    fig.update_layout(xaxis_title=xaxis, yaxis_title=yaxis)
    fig.show()

def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
    if type(lines_list)==torch.Tensor:
        lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
    if x is None:
        x=np.arange(len(lines_list[0]))
    fig = go.Figure(layout={'title':title})
    fig.update_xaxes(title=xaxis)
    fig.update_yaxes(title=yaxis)
    for c, line in enumerate(lines_list):
        if type(line)==torch.Tensor:
            line = to_numpy(line)
        if labels is not None:
            label = labels[c]
        else:
            label = c
        fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
    if log_y:
        fig.update_layout(yaxis_type="log")
    fig.show()

# Training Config Stuff

In [3]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 1,
    d_model = 128,
    d_head = 128,
    d_mlp = 512,
    act_fn = "relu",
    normalization_type=None,
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device="cuda",
    seed = 1337,
)

lr = 1e-3
weight_decay = 1e-4
test_train_split = 0.8
epochs = 100
batch_size = 4096



In [4]:

np_data = np.load('data/moves.npy')

In [5]:
np_data.shape

(46080, 11)

In [6]:
#load npy file
np_data = np.load('data/moves.npy')
data = np_data[:, :-1]
labels = np_data[:, 1:]

print(len(data))
print(len(labels))
print(data[0])
print(labels[0])

46080
46080
[10  0  1  2  3  4  6  5  8  7]
[0 1 2 3 4 6 5 8 7 9]


In [7]:
encoded_labels = F.one_hot(t.tensor(labels))
print(encoded_labels)
print(t.sum(encoded_labels, axis=1))

tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 0, 0, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 1]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 0, 0, 1]],

        ...,

        [[0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 1]],

        [[0, 0, 0,  ..., 0, 1, 0],
         [0, 0, 0,  ..., 1, 0, 0],
         [0,

In [8]:
encoded_data = F.one_hot(t.tensor(data))
print(encoded_data[1238])

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])


In [9]:
encoded_labels[0,:,:]

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])

In [10]:
#data and labels as numpy arrays
data = np.array(data)
encoded_labels = np.array(encoded_labels)
#data and encoded_labels as tensors
data = t.from_numpy(data)
encoded_labels = t.from_numpy(encoded_labels).to(t.float)
total_data = list(zip(data, encoded_labels))
split_data = list(t.utils.data.random_split(total_data, [.8, .2]))
train_pairs = split_data[0]
test_pairs= split_data[1]
train_data, train_labels = zip(*train_pairs)
test_data, test_labels = zip(*test_pairs)

train_data = t.stack(train_data).to(cfg.device)
train_labels = t.stack(train_labels).to(cfg.device)
test_data = t.stack(test_data).to(cfg.device)
test_labels = t.stack(test_labels).to(cfg.device)


#test train split
train_data = data[:int(len(data)*test_train_split)]
train_labels = encoded_labels[:int(len(data)*test_train_split)]
test_data = data[int(len(data)*test_train_split):]
test_labels = encoded_labels[int(len(data)*test_train_split):]

In [11]:
print(len(test_data))
print(len(test_labels))

9216
9216


In [12]:
def loss_fn(logits, labels):
    return t.nn.functional.cross_entropy(logits, labels)

In [13]:
ten = t.tensor([0,1]).to(t.float)
loss_fn(ten, ten)

tensor(0.3133)

In [14]:
train_losses = []
test_losses = []
model = HookedTransformer(cfg).to(cfg.device)
optimizer = t.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# for epoch in tqdm.tqdm(range(epochs)):
for epoch in range(epochs):
    for batch in range(0, len(train_data), batch_size):
        train_logits = model(train_data[batch:batch+batch_size])
        print(train_logits.dtype)
        print(train_labels.dtype)
        train_loss = loss_fn(train_logits, train_labels[batch:batch+batch_size])

        train_loss.backward()

        train_losses.append(train_loss.item())
        optimizer.step()
        optimizer.zero_grad()

        with t.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(test_logits, test_labels)
            test_losses.append(test_loss.item())

        print(f"Epoch {epoch} | Train Loss: {train_loss.item()} | Test Loss: {test_loss.item()}")

Moving model to device:  cuda


  1%|          | 1/100 [00:00<00:43,  2.28it/s]

torch.float32
torch.float32
Epoch 0 | Train Loss: 2.301990032196045 | Test Loss: 2.2703514099121094
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.271068572998047 | Test Loss: 2.2395405769348145
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.23994779586792 | Test Loss: 2.2085189819335938
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.2092206478118896 | Test Loss: 2.1764636039733887
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.176811695098877 | Test Loss: 2.142791748046875
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.1434266567230225 | Test Loss: 2.10731840133667
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.1076362133026123 | Test Loss: 2.0705783367156982
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.070155382156372 | Test Loss: 2.0336356163024902
torch.float32
torch.float32
Epoch 0 | Train Loss: 2.0340754985809326 | Test Loss: 1.9981554746627808
torch.float32
torch.float32
Epoch 1 | Train Loss: 1.9980801343917847 | Test Loss: 1.965922474861145


  3%|▎         | 3/100 [00:00<00:17,  5.57it/s]

Epoch 2 | Train Loss: 1.7850780487060547 | Test Loss: 1.7634681463241577
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.7640475034713745 | Test Loss: 1.7454668283462524
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.7470600605010986 | Test Loss: 1.7260545492172241
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.7282990217208862 | Test Loss: 1.7056715488433838
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.707759141921997 | Test Loss: 1.6847436428070068
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.684143304824829 | Test Loss: 1.6626578569412231
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.6615384817123413 | Test Loss: 1.63853919506073
torch.float32
torch.float32
Epoch 2 | Train Loss: 1.6412895917892456 | Test Loss: 1.61211359500885
torch.float32
torch.float32
Epoch 3 | Train Loss: 1.615416407585144 | Test Loss: 1.583520531654358
torch.float32
torch.float32
Epoch 3 | Train Loss: 1.589840292930603 | Test Loss: 1.5527479648590088
torch.float32
torch.float32

  5%|▌         | 5/100 [00:00<00:12,  7.56it/s]

Epoch 4 | Train Loss: 1.055804967880249 | Test Loss: 0.9966414570808411
torch.float32
torch.float32
Epoch 4 | Train Loss: 0.998585045337677 | Test Loss: 0.9336987137794495
torch.float32
torch.float32
Epoch 4 | Train Loss: 0.9344625473022461 | Test Loss: 0.8701422810554504
torch.float32
torch.float32
Epoch 4 | Train Loss: 0.8709927797317505 | Test Loss: 0.8066260814666748
torch.float32
torch.float32
Epoch 4 | Train Loss: 0.8093199729919434 | Test Loss: 0.7441383600234985
torch.float32
torch.float32
Epoch 5 | Train Loss: 0.745267391204834 | Test Loss: 0.6831145286560059
torch.float32
torch.float32
Epoch 5 | Train Loss: 0.6857129335403442 | Test Loss: 0.6247608065605164
torch.float32
torch.float32
Epoch 5 | Train Loss: 0.6265265345573425 | Test Loss: 0.5698306560516357
torch.float32
torch.float32
Epoch 5 | Train Loss: 0.5715181231498718 | Test Loss: 0.5179246664047241
torch.float32
torch.float32
Epoch 5 | Train Loss: 0.517713725566864 | Test Loss: 0.46841806173324585
torch.float32
torch.f

  9%|▉         | 9/100 [00:01<00:09,  9.66it/s]

Epoch 6 | Train Loss: 0.0959727093577385 | Test Loss: 0.08052349090576172
torch.float32
torch.float32
Epoch 6 | Train Loss: 0.08155813813209534 | Test Loss: 0.06809712946414948
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.06833237409591675 | Test Loss: 0.05818776786327362
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.05873221904039383 | Test Loss: 0.04973116144537926
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.04990093410015106 | Test Loss: 0.04219410941004753
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.04259960353374481 | Test Loss: 0.035934869199991226
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.036075372248888016 | Test Loss: 0.031008217483758926
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.031222213059663773 | Test Loss: 0.027041928842663765
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.02733251266181469 | Test Loss: 0.023780616000294685
torch.float32
torch.float32
Epoch 7 | Train Loss: 0.023842643946409225 | Test Loss: 0.0210371930

 11%|█         | 11/100 [00:01<00:08, 10.21it/s]

Epoch 9 | Train Loss: 0.007358204107731581 | Test Loss: 0.00669760862365365
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.006756930146366358 | Test Loss: 0.006213771644979715
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.006286424584686756 | Test Loss: 0.005779371131211519
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.005878004245460033 | Test Loss: 0.0053921290673315525
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.005444241221994162 | Test Loss: 0.00504921842366457
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.005101026501506567 | Test Loss: 0.004745654296129942
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.004754047840833664 | Test Loss: 0.004474232438951731
torch.float32
torch.float32
Epoch 9 | Train Loss: 0.004473147448152304 | Test Loss: 0.004229612648487091
torch.float32
torch.float32
Epoch 10 | Train Loss: 0.0042833187617361546 | Test Loss: 0.004006821196526289
torch.float32
torch.float32
Epoch 10 | Train Loss: 0.004071654286235571 | Test Los

 13%|█▎        | 13/100 [00:01<00:08, 10.60it/s]

Epoch 11 | Train Loss: 0.0024464253801852465 | Test Loss: 0.00234392611309886
torch.float32
torch.float32
Epoch 11 | Train Loss: 0.002365282503888011 | Test Loss: 0.0022742105647921562
torch.float32
torch.float32
Epoch 11 | Train Loss: 0.0023117659147828817 | Test Loss: 0.0022089567501097918
torch.float32
torch.float32
Epoch 11 | Train Loss: 0.0022096955217421055 | Test Loss: 0.0021481411531567574
torch.float32
torch.float32
Epoch 11 | Train Loss: 0.0021499854046851397 | Test Loss: 0.002091736299917102
torch.float32
torch.float32
Epoch 12 | Train Loss: 0.0021094840485602617 | Test Loss: 0.002039363607764244
torch.float32
torch.float32
Epoch 12 | Train Loss: 0.002068719593808055 | Test Loss: 0.0019905806984752417
torch.float32
torch.float32
Epoch 12 | Train Loss: 0.001994180493056774 | Test Loss: 0.0019450061954557896
torch.float32
torch.float32
Epoch 12 | Train Loss: 0.001963801681995392 | Test Loss: 0.0019021208863705397
torch.float32
torch.float32
Epoch 12 | Train Loss: 0.00192346668

 15%|█▌        | 15/100 [00:01<00:07, 10.81it/s]

Epoch 13 | Train Loss: 0.0015281018568202853 | Test Loss: 0.0015023474115878344
torch.float32
torch.float32
Epoch 13 | Train Loss: 0.0015064658364281058 | Test Loss: 0.0014798303600400686
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.0014901889953762293 | Test Loss: 0.001458054524846375
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.001474240212701261 | Test Loss: 0.0014369753189384937
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.0014386108377948403 | Test Loss: 0.001416612882167101
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.0014276693109422922 | Test Loss: 0.0013969428837299347
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.0014114391524344683 | Test Loss: 0.0013779043219983578
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.0013845361536368728 | Test Loss: 0.0013594356132671237
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.0013822632608935237 | Test Loss: 0.0013415053253993392
torch.float32
torch.float32
Epoch 14 | Train Loss: 0.001343

 17%|█▋        | 17/100 [00:01<00:07, 10.98it/s]

Epoch 16 | Train Loss: 0.0011705746874213219 | Test Loss: 0.0011457238579168916
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.0011463674018159509 | Test Loss: 0.0011329231783747673
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.0011415740009397268 | Test Loss: 0.0011203923495486379
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.001132023986428976 | Test Loss: 0.001108107273466885
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.001112615573219955 | Test Loss: 0.0010960428044199944
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.0011146877659484744 | Test Loss: 0.0010841808980330825
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.001085318741388619 | Test Loss: 0.0010725517058745027
torch.float32
torch.float32
Epoch 16 | Train Loss: 0.0010770575609058142 | Test Loss: 0.0010611464967951179
torch.float32
torch.float32
Epoch 17 | Train Loss: 0.0010671202326193452 | Test Loss: 0.0010499432682991028
torch.float32
torch.float32
Epoch 17 | Train Loss: 0.0010604

 19%|█▉        | 19/100 [00:02<00:07, 11.10it/s]

Epoch 18 | Train Loss: 0.0009394578519277275 | Test Loss: 0.0009205344831570983
torch.float32
torch.float32
Epoch 18 | Train Loss: 0.0009238924249075353 | Test Loss: 0.000911643379367888
torch.float32
torch.float32
Epoch 18 | Train Loss: 0.0009269340080209076 | Test Loss: 0.0009028696804307401
torch.float32
torch.float32
Epoch 18 | Train Loss: 0.0009037085110321641 | Test Loss: 0.0008942388813011348
torch.float32
torch.float32
Epoch 18 | Train Loss: 0.0008980572456493974 | Test Loss: 0.0008857441716827452
torch.float32
torch.float32
Epoch 19 | Train Loss: 0.0008900508400984108 | Test Loss: 0.0008773694862611592
torch.float32
torch.float32
Epoch 19 | Train Loss: 0.0008857574430294335 | Test Loss: 0.0008691144757904112
torch.float32
torch.float32
Epoch 19 | Train Loss: 0.0008689725073054433 | Test Loss: 0.000860988802742213
torch.float32
torch.float32
Epoch 19 | Train Loss: 0.0008675504359416664 | Test Loss: 0.0008529868791811168
torch.float32
torch.float32
Epoch 19 | Train Loss: 0.00086

 23%|██▎       | 23/100 [00:02<00:06, 11.28it/s]

Epoch 20 | Train Loss: 0.0007658340036869049 | Test Loss: 0.000758373353164643
torch.float32
torch.float32
Epoch 20 | Train Loss: 0.0007617971277795732 | Test Loss: 0.0007517627673223615
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.000755111628677696 | Test Loss: 0.0007452411227859557
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.0007521556690335274 | Test Loss: 0.0007388003286905587
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.000738547823857516 | Test Loss: 0.0007324504549615085
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.0007381337927654386 | Test Loss: 0.000726187601685524
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.0007342075114138424 | Test Loss: 0.0007200014661066234
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.0007223595748655498 | Test Loss: 0.0007138896617107093
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.0007257675752043724 | Test Loss: 0.0007078427588567138
torch.float32
torch.float32
Epoch 21 | Train Loss: 0.0007085

 25%|██▌       | 25/100 [00:02<00:06, 11.30it/s]

Epoch 23 | Train Loss: 0.0006468213396146894 | Test Loss: 0.0006358455284498632
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.000635601463727653 | Test Loss: 0.000630769704002887
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.0006357257370837033 | Test Loss: 0.0006257587228901684
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.0006328379386104643 | Test Loss: 0.000620804843492806
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.0006228075362741947 | Test Loss: 0.000615903118159622
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.0006261146045289934 | Test Loss: 0.0006110514514148235
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.0006117022130638361 | Test Loss: 0.0006062620086595416
torch.float32
torch.float32
Epoch 23 | Train Loss: 0.000609199982136488 | Test Loss: 0.0006015316466800869
torch.float32
torch.float32
Epoch 24 | Train Loss: 0.0006040645530447364 | Test Loss: 0.0005968562327325344
torch.float32
torch.float32
Epoch 24 | Train Loss: 0.00060225

 27%|██▋       | 27/100 [00:02<00:06, 11.30it/s]

Epoch 25 | Train Loss: 0.0005510695627890527 | Test Loss: 0.0005407357239164412
torch.float32
torch.float32
Epoch 25 | Train Loss: 0.0005424552364274859 | Test Loss: 0.0005367479170672596
torch.float32
torch.float32
Epoch 25 | Train Loss: 0.0005456233047880232 | Test Loss: 0.0005327966646291316
torch.float32
torch.float32
Epoch 25 | Train Loss: 0.0005333537119440734 | Test Loss: 0.0005288926186040044
torch.float32
torch.float32
Epoch 25 | Train Loss: 0.0005315309972502291 | Test Loss: 0.0005250347894616425
torch.float32
torch.float32
Epoch 26 | Train Loss: 0.000527199765201658 | Test Loss: 0.0005212196265347302
torch.float32
torch.float32
Epoch 26 | Train Loss: 0.0005258712917566299 | Test Loss: 0.0005174422403797507
torch.float32
torch.float32
Epoch 26 | Train Loss: 0.0005172318196855485 | Test Loss: 0.0005137121770530939
torch.float32
torch.float32
Epoch 26 | Train Loss: 0.0005177754792384803 | Test Loss: 0.0005100247217342257
torch.float32
torch.float32
Epoch 26 | Train Loss: 0.0005

 29%|██▉       | 29/100 [00:02<00:06, 11.29it/s]

Epoch 27 | Train Loss: 0.00046912222751416266 | Test Loss: 0.00046541940537281334
torch.float32
torch.float32
Epoch 27 | Train Loss: 0.0004677890974562615 | Test Loss: 0.00046223122626543045
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.00046411162475124 | Test Loss: 0.0004590777098201215
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.00046313609345816076 | Test Loss: 0.0004559553926810622
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.00045578545541502535 | Test Loss: 0.00045286849490366876
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.0004564494302030653 | Test Loss: 0.0004498138732742518
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.00045515852980315685 | Test Loss: 0.0004467903927434236
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.00044817267917096615 | Test Loss: 0.0004437931929714978
torch.float32
torch.float32
Epoch 28 | Train Loss: 0.00045109959319233894 | Test Loss: 0.0004408188397064805
torch.float32
torch.float32
Epoch 28 | Train Loss

 31%|███       | 31/100 [00:03<00:06, 11.21it/s]

Epoch 30 | Train Loss: 0.000410973938414827 | Test Loss: 0.0004047964175697416
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.00040465276106260717 | Test Loss: 0.00040221205563284457
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.00040537910535931587 | Test Loss: 0.00039965510950423777
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.00040447135688737035 | Test Loss: 0.0003971206024289131
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.00039832876063883305 | Test Loss: 0.0003946075157728046
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.0004010929842479527 | Test Loss: 0.0003921138704754412
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.0003924966440536082 | Test Loss: 0.000389645661925897
torch.float32
torch.float32
Epoch 30 | Train Loss: 0.00039166491478681564 | Test Loss: 0.0003872033266816288
torch.float32
torch.float32
Epoch 31 | Train Loss: 0.0003887642815243453 | Test Loss: 0.00038478177157230675
torch.float32
torch.float32
Epoch 31 | Train Loss:

 33%|███▎      | 33/100 [00:03<00:05, 11.20it/s]

Epoch 32 | Train Loss: 0.00036240750341676176 | Test Loss: 0.00035743621992878616
torch.float32
torch.float32
Epoch 32 | Train Loss: 0.00036179713788442314 | Test Loss: 0.0003552911221049726
torch.float32
torch.float32
Epoch 32 | Train Loss: 0.0003563494828995317 | Test Loss: 0.0003531624097377062
torch.float32
torch.float32
Epoch 32 | Train Loss: 0.0003589499683585018 | Test Loss: 0.0003510491515044123
torch.float32
torch.float32
Epoch 32 | Train Loss: 0.00035137261147610843 | Test Loss: 0.00034895632416009903
torch.float32
torch.float32
Epoch 32 | Train Loss: 0.0003507730725686997 | Test Loss: 0.00034688482992351055
torch.float32
torch.float32
Epoch 33 | Train Loss: 0.0003482690663076937 | Test Loss: 0.00034483091440051794
torch.float32
torch.float32
Epoch 33 | Train Loss: 0.0003477955178823322 | Test Loss: 0.00034279562532901764
torch.float32
torch.float32
Epoch 33 | Train Loss: 0.00034268127637915313 | Test Loss: 0.000340779748512432
torch.float32
torch.float32
Epoch 33 | Train Los

 37%|███▋      | 37/100 [00:03<00:05, 11.17it/s]

Epoch 34 | Train Loss: 0.0003230949805583805 | Test Loss: 0.00031609315192326903
torch.float32
torch.float32
Epoch 34 | Train Loss: 0.00031636565108783543 | Test Loss: 0.00031430349918082356
torch.float32
torch.float32
Epoch 34 | Train Loss: 0.00031594614847563207 | Test Loss: 0.0003125302609987557
torch.float32
torch.float32
Epoch 35 | Train Loss: 0.00031375783146359026 | Test Loss: 0.0003107719530817121
torch.float32
torch.float32
Epoch 35 | Train Loss: 0.00031342310830950737 | Test Loss: 0.00030902816797606647
torch.float32
torch.float32
Epoch 35 | Train Loss: 0.00030893218354322016 | Test Loss: 0.0003073005354963243
torch.float32
torch.float32
Epoch 35 | Train Loss: 0.0003096718282904476 | Test Loss: 0.0003055886772926897
torch.float32
torch.float32
Epoch 35 | Train Loss: 0.00030937898554839194 | Test Loss: 0.0003038910508621484
torch.float32
torch.float32
Epoch 35 | Train Loss: 0.0003047810459975153 | Test Loss: 0.000302202592138201
torch.float32
torch.float32
Epoch 35 | Train Los

 39%|███▉      | 39/100 [00:03<00:05, 11.27it/s]

torch.float32
torch.float32
Epoch 37 | Train Loss: 0.00028410652885213494 | Test Loss: 0.0002814882027450949
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.0002838734944816679 | Test Loss: 0.00027998225414194167
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.0002798969217110425 | Test Loss: 0.0002784915850497782
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.0002806284464895725 | Test Loss: 0.0002770129940472543
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.0002804772520903498 | Test Loss: 0.0002755463938228786
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.0002763477386906743 | Test Loss: 0.00027408829191699624
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.00027854222571477294 | Test Loss: 0.00027263900847174227
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.00027284384123049676 | Test Loss: 0.0002712029963731766
torch.float32
torch.float32
Epoch 37 | Train Loss: 0.0002726298407651484 | Test Loss: 0.00026977804373018444
torch.float32
torch.f

 41%|████      | 41/100 [00:03<00:05, 11.28it/s]

Epoch 39 | Train Loss: 0.00025474984431639314 | Test Loss: 0.0002535332168918103
torch.float32
torch.float32
Epoch 39 | Train Loss: 0.0002554673992563039 | Test Loss: 0.0002522474096622318
torch.float32
torch.float32
Epoch 39 | Train Loss: 0.00025542356888763607 | Test Loss: 0.0002509711775928736
torch.float32
torch.float32
Epoch 39 | Train Loss: 0.0002516921085771173 | Test Loss: 0.0002497025125194341
torch.float32
torch.float32
Epoch 39 | Train Loss: 0.0002537420659791678 | Test Loss: 0.00024844237486831844
torch.float32
torch.float32
Epoch 39 | Train Loss: 0.00024860582198016346 | Test Loss: 0.00024719134671613574
torch.float32
torch.float32
Epoch 39 | Train Loss: 0.00024849988403730094 | Test Loss: 0.00024595099966973066
torch.float32
torch.float32
Epoch 40 | Train Loss: 0.00024689818383194506 | Test Loss: 0.00024472008226439357
torch.float32
torch.float32
Epoch 40 | Train Loss: 0.0002467793528921902 | Test Loss: 0.00024349801242351532
torch.float32
torch.float32
Epoch 40 | Train L

 43%|████▎     | 43/100 [00:04<00:05, 11.25it/s]

Epoch 41 | Train Loss: 0.00023018075444269925 | Test Loss: 0.0002284167130710557
torch.float32
torch.float32
Epoch 41 | Train Loss: 0.00023209214850794524 | Test Loss: 0.00022731312492396683
torch.float32
torch.float32
Epoch 41 | Train Loss: 0.00022744557645637542 | Test Loss: 0.00022621717653237283
torch.float32
torch.float32
Epoch 41 | Train Loss: 0.00022742045985069126 | Test Loss: 0.0002251297264592722
torch.float32
torch.float32
Epoch 42 | Train Loss: 0.0002259925240650773 | Test Loss: 0.0002240507019450888
torch.float32
torch.float32
Epoch 42 | Train Loss: 0.00022593094035983086 | Test Loss: 0.000222981019760482
torch.float32
torch.float32
Epoch 42 | Train Loss: 0.000222908376599662 | Test Loss: 0.00022191782773006707
torch.float32
torch.float32
Epoch 42 | Train Loss: 0.00022359447029884905 | Test Loss: 0.00022086358512751758
torch.float32
torch.float32
Epoch 42 | Train Loss: 0.00022366979101207107 | Test Loss: 0.00021981667669024318
torch.float32
torch.float32
Epoch 42 | Train L

 45%|████▌     | 45/100 [00:04<00:04, 11.29it/s]

Epoch 43 | Train Loss: 0.00020889700681436807 | Test Loss: 0.0002068315225187689
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020762217172887176 | Test Loss: 0.00020588030747603625
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020760312327183783 | Test Loss: 0.00020493488409556448
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020486643188633025 | Test Loss: 0.00020399958884809166
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020553162903524935 | Test Loss: 0.00020306839724071324
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020566035527735949 | Test Loss: 0.00020214510732330382
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020270806271582842 | Test Loss: 0.00020122523710597306
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020443453104235232 | Test Loss: 0.00020031226449646056
torch.float32
torch.float32
Epoch 44 | Train Loss: 0.00020040613890159875 | Test Loss: 0.0001994050689972937
torch.float32
torch.float32
Epoch 44 | T

 47%|████▋     | 47/100 [00:04<00:04, 11.28it/s]

Epoch 46 | Train Loss: 0.00018892389198299497 | Test Loss: 0.00018816010560840368
torch.float32
torch.float32
Epoch 46 | Train Loss: 0.00018956430722028017 | Test Loss: 0.00018733479373622686
torch.float32
torch.float32
Epoch 46 | Train Loss: 0.00018973635451402515 | Test Loss: 0.0001865155209088698
torch.float32
torch.float32
Epoch 46 | Train Loss: 0.00018702638044487685 | Test Loss: 0.00018570059910416603
torch.float32
torch.float32
Epoch 46 | Train Loss: 0.00018864368030335754 | Test Loss: 0.00018488845671527088
torch.float32
torch.float32
Epoch 46 | Train Loss: 0.000184961871127598 | Test Loss: 0.0001840836921473965
torch.float32
torch.float32
Epoch 46 | Train Loss: 0.00018506291962694377 | Test Loss: 0.0001832839334383607
torch.float32
torch.float32
Epoch 47 | Train Loss: 0.0001839831384131685 | Test Loss: 0.0001824907521950081
torch.float32
torch.float32
Epoch 47 | Train Loss: 0.00018401276611257344 | Test Loss: 0.00018170078692492098
torch.float32
torch.float32
Epoch 47 | Train 

 49%|████▉     | 49/100 [00:04<00:04, 11.32it/s]

Epoch 48 | Train Loss: 0.00017308622773271054 | Test Loss: 0.00017189189384225756
torch.float32
torch.float32
Epoch 48 | Train Loss: 0.00017459677474107593 | Test Loss: 0.0001711696822894737
torch.float32
torch.float32
Epoch 48 | Train Loss: 0.00017122688586823642 | Test Loss: 0.00017045212734956294
torch.float32
torch.float32
Epoch 48 | Train Loss: 0.00017135788220912218 | Test Loss: 0.00016973901074379683
torch.float32
torch.float32
Epoch 49 | Train Loss: 0.00017038661462720484 | Test Loss: 0.0001690309145487845
torch.float32
torch.float32
Epoch 49 | Train Loss: 0.0001704374299151823 | Test Loss: 0.00016832680557854474
torch.float32
torch.float32
Epoch 49 | Train Loss: 0.0001682633883319795 | Test Loss: 0.0001676274259807542
torch.float32
torch.float32
Epoch 49 | Train Loss: 0.00016886909725144506 | Test Loss: 0.00016693337238393724
torch.float32
torch.float32
Epoch 49 | Train Loss: 0.00016908835095819086 | Test Loss: 0.0001662420982029289
torch.float32
torch.float32
Epoch 49 | Train

 53%|█████▎    | 53/100 [00:05<00:04, 11.36it/s]

Epoch 50 | Train Loss: 0.00015911283844616264 | Test Loss: 0.00015763497503940016
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.0001582356489961967 | Test Loss: 0.00015700106450822204
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.0001583019911777228 | Test Loss: 0.00015636993339285254
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.0001563067053211853 | Test Loss: 0.00015574412827845663
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.0001568914158269763 | Test Loss: 0.00015512316895183176
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.0001571332395542413 | Test Loss: 0.00015450383943971246
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.0001549104053992778 | Test Loss: 0.0001538883661851287
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.00015628435357939452 | Test Loss: 0.00015327564324252307
torch.float32
torch.float32
Epoch 51 | Train Loss: 0.00015331286704167724 | Test Loss: 0.0001526666892459616
torch.float32
torch.float32
Epoch 51 | Train L

 55%|█████▌    | 55/100 [00:05<00:03, 11.36it/s]

Epoch 53 | Train Loss: 0.0001455760357202962 | Test Loss: 0.00014507726882584393
torch.float32
torch.float32
Epoch 53 | Train Loss: 0.00014613995153922588 | Test Loss: 0.00014451798051595688
torch.float32
torch.float32
Epoch 53 | Train Loss: 0.00014639760775025934 | Test Loss: 0.0001439620100427419
torch.float32
torch.float32
Epoch 53 | Train Loss: 0.000144334597280249 | Test Loss: 0.0001434069126844406
torch.float32
torch.float32
Epoch 53 | Train Loss: 0.00014562490105163306 | Test Loss: 0.0001428549294359982
torch.float32
torch.float32
Epoch 53 | Train Loss: 0.00014288185047917068 | Test Loss: 0.00014230776287149638
torch.float32
torch.float32
Epoch 53 | Train Loss: 0.0001430543779861182 | Test Loss: 0.00014176327385939658
torch.float32
torch.float32
Epoch 54 | Train Loss: 0.00014230083615984768 | Test Loss: 0.00014122112770564854
torch.float32
torch.float32
Epoch 54 | Train Loss: 0.00014238145377021283 | Test Loss: 0.00014068307064007968
torch.float32
torch.float32
Epoch 54 | Train 

 57%|█████▋    | 57/100 [00:05<00:03, 11.33it/s]

Epoch 55 | Train Loss: 0.00013479792687576264 | Test Loss: 0.00013395535643212497
torch.float32
torch.float32
Epoch 55 | Train Loss: 0.00013601405953522772 | Test Loss: 0.00013345840852707624
torch.float32
torch.float32
Epoch 55 | Train Loss: 0.00013347754429560155 | Test Loss: 0.00013296173710841686
torch.float32
torch.float32
Epoch 55 | Train Loss: 0.00013365519407670945 | Test Loss: 0.0001324689801549539
torch.float32
torch.float32
Epoch 56 | Train Loss: 0.00013297161785885692 | Test Loss: 0.0001319803559454158
torch.float32
torch.float32
Epoch 56 | Train Loss: 0.0001330589147983119 | Test Loss: 0.00013149381265975535
torch.float32
torch.float32
Epoch 56 | Train Loss: 0.0001314260734943673 | Test Loss: 0.00013101020886097103
torch.float32
torch.float32
Epoch 56 | Train Loss: 0.0001319626608164981 | Test Loss: 0.00013052980648353696
torch.float32
torch.float32
Epoch 56 | Train Loss: 0.00013223545101936907 | Test Loss: 0.0001300506992265582
torch.float32
torch.float32
Epoch 56 | Train

 59%|█████▉    | 59/100 [00:05<00:03, 11.33it/s]

Epoch 57 | Train Loss: 0.00012514674745034426 | Test Loss: 0.00012405644520185888
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012452545342966914 | Test Loss: 0.00012361217522993684
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012461686856113374 | Test Loss: 0.00012317078653723001
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012309869634918869 | Test Loss: 0.0001227331958943978
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012361937842797488 | Test Loss: 0.000122297162306495
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012389985204208642 | Test Loss: 0.00012186289677629247
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012216343020554632 | Test Loss: 0.00012143031199229881
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.00012327880540397018 | Test Loss: 0.00012100039020879194
torch.float32
torch.float32
Epoch 58 | Train Loss: 0.0001210101690958254 | Test Loss: 0.00012057104322593659
torch.float32
torch.float32
Epoch 58 | Tra

 61%|██████    | 61/100 [00:05<00:03, 11.30it/s]

Epoch 60 | Train Loss: 0.0001155335339717567 | Test Loss: 0.00011520938278408721
torch.float32
torch.float32
Epoch 60 | Train Loss: 0.00011603676102822646 | Test Loss: 0.00011481294495752081
torch.float32
torch.float32
Epoch 60 | Train Loss: 0.00011632197856670246 | Test Loss: 0.00011441731476224959
torch.float32
torch.float32
Epoch 60 | Train Loss: 0.00011469283344922587 | Test Loss: 0.00011402335803722963
torch.float32
torch.float32
Epoch 60 | Train Loss: 0.0001157494043582119 | Test Loss: 0.00011363098019501194
torch.float32
torch.float32
Epoch 60 | Train Loss: 0.00011363637167960405 | Test Loss: 0.00011324074876029044
torch.float32
torch.float32
Epoch 60 | Train Loss: 0.00011382007505744696 | Test Loss: 0.00011285315849818289
torch.float32
torch.float32
Epoch 61 | Train Loss: 0.00011327788524795324 | Test Loss: 0.00011246812937315553
torch.float32
torch.float32
Epoch 61 | Train Loss: 0.0001133746700361371 | Test Loss: 0.00011208317300770432
torch.float32
torch.float32
Epoch 61 | Tr

 63%|██████▎   | 63/100 [00:05<00:03, 11.26it/s]

Epoch 62 | Train Loss: 0.00010788103099912405 | Test Loss: 0.00010726747859735042
torch.float32
torch.float32
Epoch 62 | Train Loss: 0.00010888336692005396 | Test Loss: 0.00010690993804018945
torch.float32
torch.float32
Epoch 62 | Train Loss: 0.00010691144416341558 | Test Loss: 0.0001065531323547475
torch.float32
torch.float32
Epoch 62 | Train Loss: 0.00010709406342357397 | Test Loss: 0.00010619939712341875
torch.float32
torch.float32
Epoch 63 | Train Loss: 0.00010659772669896483 | Test Loss: 0.00010584726260276511
torch.float32
torch.float32
Epoch 63 | Train Loss: 0.0001066961485776119 | Test Loss: 0.00010549685976002365
torch.float32
torch.float32
Epoch 63 | Train Loss: 0.00010542001109570265 | Test Loss: 0.000105148101283703
torch.float32
torch.float32
Epoch 63 | Train Loss: 0.00010589690646156669 | Test Loss: 0.00010480011405888945
torch.float32
torch.float32
Epoch 63 | Train Loss: 0.00010618211672408506 | Test Loss: 0.00010445644147694111
torch.float32
torch.float32
Epoch 63 | Tra

 67%|██████▋   | 67/100 [00:06<00:02, 11.19it/s]

Epoch 64 | Train Loss: 0.00010094331082655117 | Test Loss: 0.00010011216363636777
torch.float32
torch.float32
Epoch 65 | Train Loss: 0.00010048670083051547 | Test Loss: 9.978978778235614e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 0.0001005862868623808 | Test Loss: 9.94682777673006e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 9.938951552612707e-05 | Test Loss: 9.914778638631105e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 9.985054202843457e-05 | Test Loss: 9.882986341835931e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 0.00010013727296609432 | Test Loss: 9.851455979514867e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 9.873984527075663e-05 | Test Loss: 9.819842671277002e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 9.966702054953203e-05 | Test Loss: 9.788454917725176e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 9.787989984033629e-05 | Test Loss: 9.757113002706319e-05
torch.float32
torch.float32
Epoch 65 | Train Loss: 9.

 69%|██████▉   | 69/100 [00:06<00:02, 11.21it/s]

Epoch 67 | Train Loss: 9.385561861563474e-05 | Test Loss: 9.364181460114196e-05
torch.float32
torch.float32
Epoch 67 | Train Loss: 9.430235513718799e-05 | Test Loss: 9.334916830994189e-05
torch.float32
torch.float32
Epoch 67 | Train Loss: 9.45873252931051e-05 | Test Loss: 9.305899584433064e-05
torch.float32
torch.float32
Epoch 67 | Train Loss: 9.32676630327478e-05 | Test Loss: 9.27686269278638e-05
torch.float32
torch.float32
Epoch 67 | Train Loss: 9.41497492021881e-05 | Test Loss: 9.247891284758225e-05
torch.float32
torch.float32
Epoch 67 | Train Loss: 9.247234993381426e-05 | Test Loss: 9.219249477609992e-05
torch.float32
torch.float32
Epoch 67 | Train Loss: 9.265122207580134e-05 | Test Loss: 9.190665878122672e-05
torch.float32
torch.float32
Epoch 68 | Train Loss: 9.224807581631467e-05 | Test Loss: 9.162179776467383e-05
torch.float32
torch.float32
Epoch 68 | Train Loss: 9.234821482095867e-05 | Test Loss: 9.133891580859199e-05
torch.float32
torch.float32
Epoch 68 | Train Loss: 9.1258902

 71%|███████   | 71/100 [00:06<00:02, 11.24it/s]

Epoch 69 | Train Loss: 8.823328244034201e-05 | Test Loss: 8.777180482866243e-05
torch.float32
torch.float32
Epoch 69 | Train Loss: 8.907149458536878e-05 | Test Loss: 8.750568667892367e-05
torch.float32
torch.float32
Epoch 69 | Train Loss: 8.749649714445695e-05 | Test Loss: 8.724067447474226e-05
torch.float32
torch.float32
Epoch 69 | Train Loss: 8.767056715441868e-05 | Test Loss: 8.697700104676187e-05
torch.float32
torch.float32
Epoch 70 | Train Loss: 8.729867840884253e-05 | Test Loss: 8.67146736709401e-05
torch.float32
torch.float32
Epoch 70 | Train Loss: 8.739794429857284e-05 | Test Loss: 8.645302295917645e-05
torch.float32
torch.float32
Epoch 70 | Train Loss: 8.637202699901536e-05 | Test Loss: 8.6194348114077e-05
torch.float32
torch.float32
Epoch 70 | Train Loss: 8.679823076818138e-05 | Test Loss: 8.593669190304354e-05
torch.float32
torch.float32
Epoch 70 | Train Loss: 8.708085806574672e-05 | Test Loss: 8.567737677367404e-05
torch.float32
torch.float32
Epoch 70 | Train Loss: 8.586419

 73%|███████▎  | 73/100 [00:06<00:02, 11.10it/s]

Epoch 71 | Train Loss: 8.307619282277301e-05 | Test Loss: 8.242733747465536e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.27311523607932e-05 | Test Loss: 8.218600851250812e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.282998169306666e-05 | Test Loss: 8.194510155590251e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.186345075955614e-05 | Test Loss: 8.170521323336288e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.227559010265395e-05 | Test Loss: 8.14660670584999e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.255357533926144e-05 | Test Loss: 8.12283469713293e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.140203863149509e-05 | Test Loss: 8.099075785139576e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.218085713451728e-05 | Test Loss: 8.075494406512007e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.074261131696403e-05 | Test Loss: 8.051852637436241e-05
torch.float32
torch.float32
Epoch 72 | Train Loss: 8.090945

 75%|███████▌  | 75/100 [00:07<00:02, 10.89it/s]

Epoch 74 | Train Loss: 7.860839104978368e-05 | Test Loss: 7.77757159085013e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.769309013383463e-05 | Test Loss: 7.755433034617454e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.809328963048756e-05 | Test Loss: 7.733221718808636e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.836746226530522e-05 | Test Loss: 7.711263606324792e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.727299089310691e-05 | Test Loss: 7.689172343816608e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.801698666298762e-05 | Test Loss: 7.667347381357104e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.665866723982617e-05 | Test Loss: 7.645547884749249e-05
torch.float32
torch.float32
Epoch 74 | Train Loss: 7.682364230277017e-05 | Test Loss: 7.623796409461647e-05
torch.float32
torch.float32
Epoch 75 | Train Loss: 7.651581836398691e-05 | Test Loss: 7.602317782584578e-05
torch.float32
torch.float32
Epoch 75 | Train Loss: 7.6614

 77%|███████▋  | 77/100 [00:07<00:02, 10.99it/s]

Epoch 76 | Train Loss: 7.421689224429429e-05 | Test Loss: 7.350283703999594e-05
torch.float32
torch.float32
Epoch 76 | Train Loss: 7.448897667927667e-05 | Test Loss: 7.329812069656327e-05
torch.float32
torch.float32
Epoch 76 | Train Loss: 7.344682671828195e-05 | Test Loss: 7.309435022762045e-05
torch.float32
torch.float32
Epoch 76 | Train Loss: 7.415891013806686e-05 | Test Loss: 7.289140921784565e-05
torch.float32
torch.float32
Epoch 76 | Train Loss: 7.287473272299394e-05 | Test Loss: 7.268793706316501e-05
torch.float32
torch.float32
Epoch 76 | Train Loss: 7.303550228243694e-05 | Test Loss: 7.248748443089426e-05
torch.float32
torch.float32
Epoch 77 | Train Loss: 7.274961535586044e-05 | Test Loss: 7.228793401736766e-05
torch.float32
torch.float32
Epoch 77 | Train Loss: 7.284804451046512e-05 | Test Loss: 7.208852184703574e-05
torch.float32
torch.float32
Epoch 77 | Train Loss: 7.200597610790282e-05 | Test Loss: 7.188934250734746e-05
torch.float32
torch.float32
Epoch 77 | Train Loss: 7.238

 79%|███████▉  | 79/100 [00:07<00:01, 10.96it/s]

Epoch 78 | Train Loss: 6.98945441399701e-05 | Test Loss: 6.956636207178235e-05
torch.float32
torch.float32
Epoch 78 | Train Loss: 7.057568291202188e-05 | Test Loss: 6.93779657012783e-05
torch.float32
torch.float32
Epoch 78 | Train Loss: 6.935915007488802e-05 | Test Loss: 6.91903114784509e-05
torch.float32
torch.float32
Epoch 78 | Train Loss: 6.951892282813787e-05 | Test Loss: 6.900267180753872e-05
torch.float32
torch.float32
Epoch 79 | Train Loss: 6.925214984221384e-05 | Test Loss: 6.881551234982908e-05
torch.float32
torch.float32
Epoch 79 | Train Loss: 6.934661359991878e-05 | Test Loss: 6.863150338176638e-05
torch.float32
torch.float32
Epoch 79 | Train Loss: 6.854895036667585e-05 | Test Loss: 6.844630115665495e-05
torch.float32
torch.float32
Epoch 79 | Train Loss: 6.891774683026597e-05 | Test Loss: 6.826176831964403e-05
torch.float32
torch.float32
Epoch 79 | Train Loss: 6.91800523782149e-05 | Test Loss: 6.807931640651077e-05
torch.float32
torch.float32
Epoch 79 | Train Loss: 6.8213812

 83%|████████▎ | 83/100 [00:07<00:01, 11.07it/s]

Epoch 80 | Train Loss: 6.624406523769721e-05 | Test Loss: 6.575923907803372e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.599527841899544e-05 | Test Loss: 6.558666791534051e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.608924741158262e-05 | Test Loss: 6.541279435623437e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.533063424285501e-05 | Test Loss: 6.524119817186147e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.568824028363451e-05 | Test Loss: 6.506993668153882e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.594687147298828e-05 | Test Loss: 6.490003579529002e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.502569158328697e-05 | Test Loss: 6.472897075582296e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.566262163687497e-05 | Test Loss: 6.455900438595563e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.453745299950242e-05 | Test Loss: 6.438972195610404e-05
torch.float32
torch.float32
Epoch 81 | Train Loss: 6.469

 85%|████████▌ | 85/100 [00:07<00:01, 11.12it/s]

Epoch 83 | Train Loss: 6.232900341274217e-05 | Test Loss: 6.225262768566608e-05
torch.float32
torch.float32
Epoch 83 | Train Loss: 6.267743447097018e-05 | Test Loss: 6.209143612068146e-05
torch.float32
torch.float32
Epoch 83 | Train Loss: 6.293020851444453e-05 | Test Loss: 6.193338049342856e-05
torch.float32
torch.float32
Epoch 83 | Train Loss: 6.205063255038112e-05 | Test Loss: 6.17737096035853e-05
torch.float32
torch.float32
Epoch 83 | Train Loss: 6.266161653911695e-05 | Test Loss: 6.1615755839739e-05
torch.float32
torch.float32
Epoch 83 | Train Loss: 6.159243639558554e-05 | Test Loss: 6.145717634353787e-05
torch.float32
torch.float32
Epoch 83 | Train Loss: 6.174383452162147e-05 | Test Loss: 6.13012962276116e-05
torch.float32
torch.float32
Epoch 84 | Train Loss: 6.152060814201832e-05 | Test Loss: 6.114276038715616e-05
torch.float32
torch.float32
Epoch 84 | Train Loss: 6.160731572890654e-05 | Test Loss: 6.09880116826389e-05
torch.float32
torch.float32
Epoch 84 | Train Loss: 6.09044800

 87%|████████▋ | 87/100 [00:08<00:01, 10.97it/s]

Epoch 85 | Train Loss: 5.9271755162626505e-05 | Test Loss: 5.901376425754279e-05
torch.float32
torch.float32
Epoch 85 | Train Loss: 5.9858051827177405e-05 | Test Loss: 5.8864236052613705e-05
torch.float32
torch.float32
Epoch 85 | Train Loss: 5.8838904806179926e-05 | Test Loss: 5.871767643839121e-05
torch.float32
torch.float32
Epoch 85 | Train Loss: 5.898890594835393e-05 | Test Loss: 5.857040741830133e-05
torch.float32
torch.float32
Epoch 86 | Train Loss: 5.877863077330403e-05 | Test Loss: 5.842433893121779e-05
torch.float32
torch.float32
Epoch 86 | Train Loss: 5.886489088879898e-05 | Test Loss: 5.827758286613971e-05
torch.float32
torch.float32
Epoch 86 | Train Loss: 5.8194105804432184e-05 | Test Loss: 5.813304233015515e-05
torch.float32
torch.float32
Epoch 86 | Train Loss: 5.8528130466584116e-05 | Test Loss: 5.798911661258899e-05
torch.float32
torch.float32
Epoch 86 | Train Loss: 5.877532021258958e-05 | Test Loss: 5.784438326372765e-05
torch.float32
torch.float32
Epoch 86 | Train Loss:

 89%|████████▉ | 89/100 [00:08<00:00, 11.06it/s]

Epoch 87 | Train Loss: 5.626518031931482e-05 | Test Loss: 5.61540546186734e-05
torch.float32
torch.float32
Epoch 87 | Train Loss: 5.6412332924082875e-05 | Test Loss: 5.6016349844867364e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.6214277719845995e-05 | Test Loss: 5.587952182395384e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.629828592645936e-05 | Test Loss: 5.5742275435477495e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.565887113334611e-05 | Test Loss: 5.56067461729981e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.598340794676915e-05 | Test Loss: 5.5471416999353096e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.622449316433631e-05 | Test Loss: 5.533589865081012e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.543442603084259e-05 | Test Loss: 5.520130071090534e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.598719508270733e-05 | Test Loss: 5.506748129846528e-05
torch.float32
torch.float32
Epoch 88 | Train Loss: 5.

 91%|█████████ | 91/100 [00:08<00:00, 10.97it/s]

torch.float32
torch.float32
Epoch 90 | Train Loss: 5.3891355491941795e-05 | Test Loss: 5.336624235496856e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.3282052249414846e-05 | Test Loss: 5.3238611144479364e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.3598032536683604e-05 | Test Loss: 5.3111041779629886e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.383332245401107e-05 | Test Loss: 5.298515679896809e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.30778088432271e-05 | Test Loss: 5.285854422254488e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.360872091841884e-05 | Test Loss: 5.2732459153048694e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.2702973334817216e-05 | Test Loss: 5.260774923954159e-05
torch.float32
torch.float32
Epoch 90 | Train Loss: 5.2846404287265614e-05 | Test Loss: 5.248177330940962e-05
torch.float32
torch.float32
Epoch 91 | Train Loss: 5.2666349802166224e-05 | Test Loss: 5.2356688684085384e-05
torch.float32
torch

 93%|█████████▎| 93/100 [00:08<00:00, 11.00it/s]

Epoch 92 | Train Loss: 5.105164018459618e-05 | Test Loss: 5.101603164803237e-05
torch.float32
torch.float32
Epoch 92 | Train Loss: 5.135930405231193e-05 | Test Loss: 5.0896447646664456e-05
torch.float32
torch.float32
Epoch 92 | Train Loss: 5.158887506695464e-05 | Test Loss: 5.0776889111148193e-05
torch.float32
torch.float32
Epoch 92 | Train Loss: 5.0863975047832355e-05 | Test Loss: 5.065797449788079e-05
torch.float32
torch.float32
Epoch 92 | Train Loss: 5.1374907343415543e-05 | Test Loss: 5.0540165830170736e-05
torch.float32
torch.float32
Epoch 92 | Train Loss: 5.050955951446667e-05 | Test Loss: 5.042129851062782e-05
torch.float32
torch.float32
Epoch 92 | Train Loss: 5.064783545094542e-05 | Test Loss: 5.030533793615177e-05
torch.float32
torch.float32
Epoch 93 | Train Loss: 5.0481397920520976e-05 | Test Loss: 5.018713272875175e-05
torch.float32
torch.float32
Epoch 93 | Train Loss: 5.055640576756559e-05 | Test Loss: 5.006990613765083e-05
torch.float32
torch.float32
Epoch 93 | Train Loss:

 95%|█████████▌| 95/100 [00:08<00:00, 10.99it/s]

Epoch 94 | Train Loss: 4.9478327127872035e-05 | Test Loss: 4.8700003389967605e-05
torch.float32
torch.float32
Epoch 94 | Train Loss: 4.87815668748226e-05 | Test Loss: 4.859014370595105e-05
torch.float32
torch.float32
Epoch 94 | Train Loss: 4.9275255150860175e-05 | Test Loss: 4.84778756799642e-05
torch.float32
torch.float32
Epoch 94 | Train Loss: 4.844596332986839e-05 | Test Loss: 4.8365298425778747e-05
torch.float32
torch.float32
Epoch 94 | Train Loss: 4.858102693106048e-05 | Test Loss: 4.825624273507856e-05
torch.float32
torch.float32
Epoch 95 | Train Loss: 4.8424219130538404e-05 | Test Loss: 4.814544445252977e-05
torch.float32
torch.float32
Epoch 95 | Train Loss: 4.8499197873752564e-05 | Test Loss: 4.803640331374481e-05
torch.float32
torch.float32
Epoch 95 | Train Loss: 4.7953602916095406e-05 | Test Loss: 4.792683830601163e-05
torch.float32
torch.float32
Epoch 95 | Train Loss: 4.8247802624246106e-05 | Test Loss: 4.781710231327452e-05
torch.float32
torch.float32
Epoch 95 | Train Loss:

 99%|█████████▉| 99/100 [00:09<00:00, 11.12it/s]

Epoch 96 | Train Loss: 4.6502846089424565e-05 | Test Loss: 4.643200009013526e-05
torch.float32
torch.float32
Epoch 96 | Train Loss: 4.663770232582465e-05 | Test Loss: 4.632792479242198e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.648849062505178e-05 | Test Loss: 4.622337291948497e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.656014425563626e-05 | Test Loss: 4.612081102095544e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.603829074767418e-05 | Test Loss: 4.601804539561272e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.6326244046213105e-05 | Test Loss: 4.591475226334296e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.6541401388822123e-05 | Test Loss: 4.5811910240445286e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.588545925798826e-05 | Test Loss: 4.5709151891060174e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 4.635014192899689e-05 | Test Loss: 4.560881279758178e-05
torch.float32
torch.float32
Epoch 97 | Train Loss: 

100%|██████████| 100/100 [00:09<00:00, 10.79it/s]

Epoch 99 | Train Loss: 4.473371882340871e-05 | Test Loss: 4.431409979588352e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.423230348038487e-05 | Test Loss: 4.421750418259762e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.451301720109768e-05 | Test Loss: 4.411974805407226e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.4723052269546315e-05 | Test Loss: 4.4024262024322525e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.409309622133151e-05 | Test Loss: 4.392766277305782e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.45414443674963e-05 | Test Loss: 4.38314164057374e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.379630263429135e-05 | Test Loss: 4.3735602957895026e-05
torch.float32
torch.float32
Epoch 99 | Train Loss: 4.39273972006049e-05 | Test Loss: 4.364071719464846e-05





In [15]:
seq = t.tensor([10] * 10).to(cfg.device)
out = model(seq)

In [16]:
out[:,-1,:]

tensor([[-68.5691, -53.3248, -33.7109,  28.4673, -47.1521,  49.8403,  20.7787,
          18.6870, -19.6562,  14.9037]], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [17]:
line(test_losses, log_y=True)
# plt.plot(train_losses)

In [18]:
line(train_losses, log_y=True)

In [19]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): Identity()
      (ln2): Identity()
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_resid_pre): HookPoint()
      (hook_resid_mid): HookPoint()
      (hook_resid_post): HookPoint()
    )
  )
  (unembed): Unembed()
)

In [20]:
imshow(model.embed.W_E)

In [21]:
# Take the dot product of all the embedding vectors
emb = model.embed.W_E
vec_count = emb.shape[0]
vec_dim = emb.shape[1]
print(f"The embedding shape is {emb.shape}, so our vectors of length {emb.shape[1]}")

dot_products = einops.einsum(emb, emb, "v2 embs, v1 emb -> v1 v2")

The embedding shape is torch.Size([11, 128]), so our vectors of length 128


In [22]:
print(dot_products.shape)
imshow_div(dot_products)

torch.Size([11, 11])


## What would your hypothesis around the attention head activations be based on seeing this?
+ Jack - My poorly informed guess is that tokens with low dot products and/or low norms won't have any strong attentional interaction
+ Omar - I think that corner moves [0, 2, 6, 8] will have similar attention patterns
+ Ari - I think same as Omar, plus center attends to everything, middle edges have attention symmetry too

In [23]:
tokens = [0,1,2,3,4,6,5,8,7]
# tokens = ([10] * 5) + [1,2,5,8,7]
str_tokens = [str(token) for token in tokens]
logits, cache = model.run_with_cache(torch.tensor(tokens).to('cuda'), remove_batch_dim=True)


In [24]:
print(type(cache))
attention_pattern = cache["pattern", 0, "attn"]
print(attention_pattern.shape)

<class 'transformer_lens.ActivationCache.ActivationCache'>
torch.Size([1, 10, 10])


In [25]:
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)