In [2]:
import os
import sys
import requests
from pathlib import Path
import numpy as np
import plotly.express as px
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from fancy_einsum import einsum
import math 

device = t.device("cuda" if t.cuda.is_available() else "cpu")

Create the toy dataset

There are two options for input sequences: noise and signal. The signal contains a 0 preceded by 2 numbers less in [1,4], and has a ground truth answer given by an arbitrary map. 

In [3]:
ans_map = {(1, 1):0, (1, 2):1, (1, 3):2, (1, 4):3, (2, 2):4, (2, 3):5, (2, 4):6, (3,3):7, (3, 4):8, (4, 4):9}
class SkipTrigramDataset(Dataset):

    def __init__(self, size: int, prob:float, prompt_length: int, seed:int = 42):
        self.size = size
        self.prob = prob
        self.prompt_length = prompt_length
        self.data = [self.generate_sample() for _ in range(size)]
        t.manual_seed(seed=seed)

    def generate_sample(self):
        random_tensor = t.randint(low = 5, high = 11, size = (self.prompt_length,))
        random_int = t.randint(low = 0, high = self.prob, size = (1,))
        if random_int.item() == self.prob - 1:
            rand_num_1 = t.randint(1, 5, size = (1,)).item()
            rand_num_2 = t.randint(1, 5, size = (1,)).item()
            random_tensor[0] = rand_num_1
            random_tensor[-2] = rand_num_2
            sample = random_tensor
            sample[-1] = 0
            label = ans_map[tuple(sorted((rand_num_1, rand_num_2)))]
            label = t.tensor([label])
        else:
            sample = random_tensor
            label = t.randint(low = 0, high = 11, size = (1,))

        return sample, label


    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.size

In [4]:
ans_map = {(1, 1):0, (1, 2):1, (1, 3):2, (1, 4):3, (2, 2):4, (2, 3):5, (2, 4):6, (3,3):7, (3, 4):8, (4, 4):9}

example_dataset = SkipTrigramDataset(size = 10_000, prob = 2, prompt_length = 4)
for i in range(10):
    print(f'Sample: {example_dataset[i][0]}, label: {example_dataset[i][1][-1]}')



Sample: tensor([ 7,  9, 10, 10]), label: 3
Sample: tensor([5, 8, 9, 8]), label: 0
Sample: tensor([2, 7, 4, 0]), label: 6
Sample: tensor([3, 6, 4, 0]), label: 8
Sample: tensor([1, 9, 1, 0]), label: 0
Sample: tensor([8, 6, 7, 7]), label: 4
Sample: tensor([2, 7, 2, 0]), label: 4
Sample: tensor([10,  6,  5,  7]), label: 4
Sample: tensor([ 8,  6, 10,  9]), label: 3
Sample: tensor([1, 5, 2, 0]), label: 1


In [5]:
signal = [example[1].item() for example in example_dataset if example[0][-1] == 0]
noise = [example[1].item() for example in example_dataset if example[0][-1] != 0]

print(len(signal), len(noise))

px.histogram(signal)
#px.histogram(noise)

## repeated ints < 5 are half as likely in the signal dataset. This shouldn't matter, but could check to see if those outputs (0, 4, 7, 9) are treated systematically differently by the model


4957 5043


In [6]:
@dataclass
class model_cfg:
    d_model: int = 12
    d_vocab: int = 12
    init_range: float = 0.02
    n_ctx: int = 5
    d_head: int = 3
    n_heads: int = 3
    n_layers: int = 1

model_config = model_cfg()

print(model_config)

model_cfg(d_model=12, d_vocab=12, init_range=0.02, n_ctx=5, d_head=3, n_heads=3, n_layers=1)


In [7]:
class OneHotEmbed(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.embedding = nn.Embedding(self.model_cfg.d_vocab, self.model_cfg.d_vocab)
        self.embedding.weight = nn.Parameter(t.eye(self.model_cfg.d_vocab))
        self.embedding.weight.requires_grad = False

    def forward(self, input):
        # tokens: [batch, position]
        # output: [batch, position, d_model]
        input = input.to(device)
        embed_out = self.embedding(input)
        return embed_out
    
class OneHotUnembed(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.unembed = t.eye(self.model_cfg.d_vocab)

    def forward(self, input):
        # tokens: [batch, position, d_model]
        # output: [batch, position]
        input = input.to(device)
        return t.einsum("vm,bpm->bpv", self.unembed.to(device), input)  # [batch, pos, d_vocab]


In [8]:
class Attention(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg

        # set weights, biases as parameters 
        self.W_Q = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_head, model_cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((model_cfg.d_model)))

        #initialize randomly
        nn.init.normal_(self.W_Q, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_K, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_V, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_O, std=self.model_cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

        self.attn_activation_cache = {}

    def forward(self, normalized_resid_pre: t.Tensor):
        # normalized_resid_pre: [batch, position, d_model] (we're not using LN, so just input)
        # output: [batch, position, d_model]
        normalized_resid_pre = normalized_resid_pre.to(device)

        queries = einsum('n_heads d_model d_head, batch pos_q d_model -> batch pos_q n_heads d_head', self.W_Q, normalized_resid_pre) + self.b_Q
        keys = einsum('n_heads d_model d_head, batch pos_k d_model -> batch pos_k n_heads d_head', self.W_K, normalized_resid_pre) + self.b_K
        values = einsum('n_heads d_model d_head, batch pos_k d_model -> batch pos_k n_heads d_head', self.W_V, normalized_resid_pre) + self.b_V

        attn_scores = einsum('batch pos_q n_heads d_head, batch pos_k n_heads d_head -> batch n_heads pos_q pos_k', queries, keys)

        scaled_attn = attn_scores / math.sqrt(self.model_cfg.d_head)
        masked_attn = self.apply_causal_mask(scaled_attn)
        # apply softmax over keys to get attention pattern
        softmax_attn = masked_attn.softmax(dim = -1)

        # get weighted attention values
        z = einsum('batch pos_k n_heads d_head, batch n_heads pos_q pos_k -> batch pos_q n_heads d_head', values, softmax_attn)

        # layer output
        output = einsum('n_heads d_head d_model, batch pos_q n_heads d_head -> batch pos_q d_model', self.W_O, z) + self.b_O

        ## save to the cache 
        self.attn_activation_cache['hook_q'] = queries
        self.attn_activation_cache['hook_k'] = keys
        self.attn_activation_cache['hook_attn_scores'] = attn_scores
        self.attn_activation_cache['hook_pattern'] = softmax_attn
        self.attn_activation_cache['hook_v'] = values
        self.attn_activation_cache['hook_z'] = z
        self.attn_activation_cache['hook_attn_out'] = output

        return output

    def apply_causal_mask(self, attn_scores: t.Tensor):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        # output: [batch, n_heads, query_pos, key_pos]
        mask = t.triu(attn_scores, diagonal=1).bool()
        return attn_scores.masked_fill_(mask, self.IGNORE)

define a one layer attn only class      

In [9]:
class OneLayerAttnOnly(nn.Module):
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.embed = OneHotEmbed(model_cfg)
        self.attn = Attention(model_cfg)
        self.unembed = OneHotUnembed(model_cfg)

        self.activation_cache = {}


    def forward(self, tokens):
        # tokens [batch, pos]
        # output [batch, pos, d_vocab]

        input = tokens.to(device)
        resid_pre = self.embed(input)
        resid_post = resid_pre + self.attn(resid_pre)
        out = self.unembed(resid_post)

        # add activations to the same cache as before
        self.activation_cache = self.attn.attn_activation_cache
        self.activation_cache['hook_resid_pre'] = resid_pre
        self.activation_cache['hook_resid_post'] = resid_post
        self.activation_cache['unembed'] = out

        return out
    

attn_only_1l = OneLayerAttnOnly(model_cfg).to(device)
attn_only_1l

        

OneLayerAttnOnly(
  (embed): OneHotEmbed(
    (embedding): Embedding(12, 12)
  )
  (attn): Attention()
  (unembed): OneHotUnembed()
)

In [10]:
def plot_metrics(train_loss,train_acc, test_loss, test_acc):
    fig = make_subplots(rows=2,cols=2, subplot_titles = ['Train Loss', 'Train Accuracy', 'Test Loss', 'Test Accuracy'])
    fig.add_trace(go.Scatter(y=train_loss), row=1, col=1)
    fig.add_trace(go.Scatter(y=train_acc), row=1, col=2)
    fig.add_trace(go.Scatter(y=test_loss), row=2, col=1)
    fig.add_trace(go.Scatter(y=test_acc), row=2, col=2)
    # get rid of legend
    fig.update_layout(showlegend=False)
    ## add x and y axis labels and change plot style 
    fig.update_xaxes(
        title = 'Epoch',
        zeroline=True,
        zerolinewidth=1,
        zerolinecolor='black',
        showgrid=False
    )
    fig.update_yaxes(
        title = 'Loss',
        showgrid=False,
        col=1
    )
    fig.update_yaxes(
        title = 'Accuracy',
        showgrid=False,
        col=2
    )
    fig.update_layout(
        font = dict(size =12),
        width=800,
        height =500,
        margin=dict(l=10, 
                    r = 20,
                    b=10,
                    t=20,
                    pad=0)
    )
    fig.show()

In [11]:
@dataclass
class train_cfg:
    ctx_length: int = 4
    trainset_size: int = 100000
    valset_size: int = 1000
    prob: int = 2
    epochs: int = 15
    batch_size: int = 512
    weight_decay: float = 0.0
    seed: int = 42
    lr: float = 1e-3

class Trainer:

    def __init__(self, train_cfg, model_cfg):
        self.train_cfg = train_cfg
        self.model = OneLayerAttnOnly(model_cfg).to(device)


    def train_dataloader(self, seed: int):
        trainset = SkipTrigramDataset(size=self.train_cfg.trainset_size,
                                      prob=self.train_cfg.prob,
                                      prompt_length=self.train_cfg.ctx_length)

        return DataLoader(trainset, batch_size=self.train_cfg.batch_size, shuffle=True)

    def val_dataloader(self, seed: int):
        valset = SkipTrigramDataset(size=self.train_cfg.valset_size,
                                    prob = self.train_cfg.prob,
                                    prompt_length=self.train_cfg.ctx_length)
        return DataLoader(valset, batch_size=self.train_cfg.batch_size, shuffle=False)

    def configure_optimizers(self):
        optimizer = t.optim.Adam(self.model.parameters(), lr=self.train_cfg.lr, weight_decay=self.train_cfg.weight_decay)
        return optimizer

In [67]:



def train(model, train_data, criterion, optimizer):

    avg_loss = 0
    acc = 0

    for _, batch in list(enumerate(train_data)):

        X = batch[0].to(device)
        y = batch[1].squeeze().to(device)
        assert X.shape[0] == y.shape[0]

        output = model(X)[:,-1]

        with t.no_grad():
            corr = (output.argmax(dim=-1) == y).count_nonzero()
        loss = criterion(output,y)
        loss.backward()

        optimizer.step()
        avg_loss += loss.item() * len(batch)
        acc += corr.item()
        optimizer.zero_grad()

    cache = model.activation_cache

    return avg_loss/train_cfg.trainset_size, acc/train_cfg.trainset_size, cache
    
def test(model, test_data, criterion):

    avg_loss = 0
    acc = 0

    for _, batch in list(enumerate(test_data)):
        X = batch[0].to(device)
        y = batch[1].squeeze().to(device)
        assert X.shape[0] == y.shape[0]

        # don't log the gradients for testing
        with t.no_grad():
            output = model(X)[:,-1]
        loss = criterion(output,y)
        corr = (output.argmax(dim=-1) == y).count_nonzero()
        avg_loss += loss.item() * len(batch)
        acc += corr.item()


    return avg_loss/train_cfg.valset_size, acc/train_cfg.valset_size

In [68]:
from tqdm import tqdm 

def train_and_eval(train_cfg, model_cfg, save_caches = []):

    trainer = Trainer(train_cfg, model_cfg)
    model = trainer.model
    train_data = trainer.train_dataloader(seed=train_cfg.seed)

    test_data = trainer.val_dataloader(seed=train_cfg.seed)
    criterion = nn.CrossEntropyLoss()
    optimizer = trainer.configure_optimizers()
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []

    caches = []

    for n in tqdm(range(train_cfg.epochs)):
        loss, acc, cache = train(model, train_data, criterion, optimizer)

        train_loss.append(loss)
        train_acc.append(acc)
        t_loss, t_acc = test(model, test_data, criterion)
        test_loss.append(t_loss)
        test_acc.append(t_acc)
        print(f"train loss = {loss}, train acc = {acc}, test loss = {t_loss}, test acc = {t_acc}")

        if n in save_caches:
            caches.append(cache)

    return train_loss, train_acc, test_loss, test_acc, caches, trainer.model




Helper functions for plotting

In [69]:
from transformer_lens import utils, HookedTransformer, ActivationCache

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [70]:
def get_head_to_resid(model, cache, head, seq):
  model(seq)
  z = cache["hook_z"][:, :, head, :]
  wo = model.attn.W_O[head, :, :]
  bo = model.attn.b_O
  out = einsum("batch pos d_head, d_head d_model -> batch pos d_model", z, wo) + bo
  return out

def get_head_to_logits(model, head_to_resid):
  return model.unembed(head_to_resid)

def show_pattern(cache, seq, model):
    X = t.Tensor(seq).unsqueeze(0).int()
    model(X)
    ax_label = [f"{i}: {z.item()}" for i, z in enumerate(X[0])]
    imshow(cache['hook_pattern'][0], facet_col = 0, x = ax_label, y = ax_label)

def get_head_to_resid(model, cache, head, seq):
  model(seq)
  z = cache["hook_z"][:, :, head, :]
  wo = model.attn.W_O[head, :, :]
  bo = model.attn.b_O
  out = einsum("batch pos d_head, d_head d_model -> batch pos d_model", z, wo) + bo
  return out

def get_head_to_logits(model, head_to_resid):
  return model.unembed(head_to_resid)





Running for many iterations of d_head

In [94]:
train_cfg.lr

0.001

In [81]:
train_losses = {}
train_accs = {}
test_losses = {}
test_accs = {}
models = {}
#for dh in [1,2,3,4,5]:
for dh in range(1,6):

    train_config = train_cfg(
        ctx_length=4,
        weight_decay=0,
        trainset_size = 100000,
        prob = 2,
        epochs = 60,
        batch_size = 512
    )

    model_config = model_cfg(
        d_model = 11,
        d_vocab = 11,
        d_head = dh,
        init_range = 0.02,
        n_ctx = 4,
        n_heads = 2,
        n_layers = 1)

    train_losses[dh], train_accs[dh], test_losses[dh], test_accs[dh], _, models[dh] = train_and_eval(train_config, model_config)

  2%|▏         | 1/60 [00:00<00:25,  2.30it/s]

train loss = 0.009543825511932374, train acc = 0.07678, test loss = 0.009621893882751465, test acc = 0.084


  3%|▎         | 2/60 [00:00<00:25,  2.30it/s]

train loss = 0.009316181807518005, train acc = 0.12629, test loss = 0.009398043632507324, test acc = 0.144


  5%|▌         | 3/60 [00:01<00:24,  2.28it/s]

train loss = 0.009040278673171998, train acc = 0.17765, test loss = 0.008973809242248535, test acc = 0.21


  7%|▋         | 4/60 [00:01<00:24,  2.32it/s]

train loss = 0.008398842458724975, train acc = 0.2397, test loss = 0.00812405776977539, test acc = 0.335


  8%|▊         | 5/60 [00:02<00:23,  2.33it/s]

train loss = 0.007657382144927978, train acc = 0.28222, test loss = 0.007563254833221436, test acc = 0.275


 10%|█         | 6/60 [00:02<00:22,  2.37it/s]

train loss = 0.0072420110893249515, train acc = 0.3167, test loss = 0.007254176378250122, test acc = 0.335


 12%|█▏        | 7/60 [00:02<00:22,  2.39it/s]

train loss = 0.00699178507566452, train acc = 0.33279, test loss = 0.007032164573669434, test acc = 0.385


 13%|█▎        | 8/60 [00:03<00:21,  2.40it/s]

train loss = 0.006779979498386383, train acc = 0.34332, test loss = 0.006823414325714111, test acc = 0.319


 15%|█▌        | 9/60 [00:03<00:21,  2.37it/s]

train loss = 0.006552622761726379, train acc = 0.381, test loss = 0.006583744764328003, test acc = 0.45


 17%|█▋        | 10/60 [00:04<00:21,  2.36it/s]

train loss = 0.0063114313197135926, train acc = 0.47004, test loss = 0.006358935117721558, test acc = 0.477


 18%|█▊        | 11/60 [00:04<00:20,  2.33it/s]

train loss = 0.006108821287155152, train acc = 0.48292, test loss = 0.006171817541122436, test acc = 0.477


 20%|██        | 12/60 [00:05<00:20,  2.35it/s]

train loss = 0.005944058604240418, train acc = 0.4829, test loss = 0.006020768165588379, test acc = 0.477


 22%|██▏       | 13/60 [00:05<00:19,  2.38it/s]

train loss = 0.005809584743976593, train acc = 0.4829, test loss = 0.005890937089920044, test acc = 0.477


 23%|██▎       | 14/60 [00:05<00:19,  2.40it/s]

train loss = 0.005699225189685821, train acc = 0.4829, test loss = 0.005784554719924927, test acc = 0.477


 25%|██▌       | 15/60 [00:06<00:18,  2.42it/s]

train loss = 0.005602450997829437, train acc = 0.49077, test loss = 0.00569196891784668, test acc = 0.514


 27%|██▋       | 16/60 [00:06<00:18,  2.42it/s]

train loss = 0.005515144824981689, train acc = 0.51402, test loss = 0.005610179901123047, test acc = 0.514


 28%|██▊       | 17/60 [00:07<00:18,  2.38it/s]

train loss = 0.005441455819606781, train acc = 0.51411, test loss = 0.005537575721740723, test acc = 0.514


 30%|███       | 18/60 [00:07<00:17,  2.38it/s]

train loss = 0.005375908002853393, train acc = 0.51415, test loss = 0.005475484371185303, test acc = 0.514


 32%|███▏      | 19/60 [00:08<00:17,  2.38it/s]

train loss = 0.005317165205478668, train acc = 0.51411, test loss = 0.0054183506965637205, test acc = 0.513


 33%|███▎      | 20/60 [00:08<00:17,  2.29it/s]

train loss = 0.005267223613262177, train acc = 0.51417, test loss = 0.005370695352554322, test acc = 0.513


 35%|███▌      | 21/60 [00:08<00:17,  2.24it/s]

train loss = 0.005219370850324631, train acc = 0.51412, test loss = 0.005326268672943116, test acc = 0.513


 37%|███▋      | 22/60 [00:09<00:16,  2.26it/s]

train loss = 0.00518653840303421, train acc = 0.51428, test loss = 0.005290487289428711, test acc = 0.513


 38%|███▊      | 23/60 [00:09<00:16,  2.30it/s]

train loss = 0.005152434558868408, train acc = 0.5142, test loss = 0.005255842685699463, test acc = 0.514


 40%|████      | 24/60 [00:10<00:15,  2.33it/s]

train loss = 0.005120480504035949, train acc = 0.51799, test loss = 0.005226316213607788, test acc = 0.514


 42%|████▏     | 25/60 [00:10<00:15,  2.30it/s]

train loss = 0.005095786447525025, train acc = 0.53723, test loss = 0.005193859815597534, test acc = 0.547


 43%|████▎     | 26/60 [00:11<00:14,  2.33it/s]

train loss = 0.005065581171512604, train acc = 0.54476, test loss = 0.00516396164894104, test acc = 0.547


 45%|████▌     | 27/60 [00:11<00:14,  2.35it/s]

train loss = 0.005040324308872223, train acc = 0.54464, test loss = 0.005133144617080689, test acc = 0.547


 47%|████▋     | 28/60 [00:11<00:13,  2.36it/s]

train loss = 0.005009198000431061, train acc = 0.54463, test loss = 0.0051023709774017335, test acc = 0.547


 48%|████▊     | 29/60 [00:12<00:13,  2.38it/s]

train loss = 0.00497896723985672, train acc = 0.54463, test loss = 0.00506860637664795, test acc = 0.547


 50%|█████     | 30/60 [00:12<00:12,  2.38it/s]

train loss = 0.00495129681110382, train acc = 0.54463, test loss = 0.005041004657745362, test acc = 0.547


 52%|█████▏    | 31/60 [00:13<00:12,  2.39it/s]

train loss = 0.004924285914897918, train acc = 0.54463, test loss = 0.005016798973083496, test acc = 0.547


 53%|█████▎    | 32/60 [00:13<00:11,  2.40it/s]

train loss = 0.004905629906654358, train acc = 0.54463, test loss = 0.004997851371765137, test acc = 0.547


 55%|█████▌    | 33/60 [00:14<00:11,  2.39it/s]

train loss = 0.004890201222896576, train acc = 0.54463, test loss = 0.004981210470199585, test acc = 0.547


 57%|█████▋    | 34/60 [00:14<00:10,  2.39it/s]

train loss = 0.004875698120594025, train acc = 0.54463, test loss = 0.004970287084579468, test acc = 0.547


 58%|█████▊    | 35/60 [00:14<00:10,  2.40it/s]

train loss = 0.004865095903873444, train acc = 0.54463, test loss = 0.004954633474349976, test acc = 0.547


 60%|██████    | 36/60 [00:15<00:10,  2.40it/s]

train loss = 0.004853781707286835, train acc = 0.54463, test loss = 0.004945528507232666, test acc = 0.547


 62%|██████▏   | 37/60 [00:15<00:09,  2.38it/s]

train loss = 0.004844480187892914, train acc = 0.54463, test loss = 0.004935569047927856, test acc = 0.547


 63%|██████▎   | 38/60 [00:16<00:09,  2.38it/s]

train loss = 0.004836686832904816, train acc = 0.54463, test loss = 0.004927226781845093, test acc = 0.547


 65%|██████▌   | 39/60 [00:16<00:08,  2.36it/s]

train loss = 0.0048289052081108095, train acc = 0.54463, test loss = 0.004923397541046142, test acc = 0.547


 67%|██████▋   | 40/60 [00:16<00:08,  2.36it/s]

train loss = 0.004823688561916352, train acc = 0.54463, test loss = 0.00491566014289856, test acc = 0.547


 68%|██████▊   | 41/60 [00:17<00:08,  2.37it/s]

train loss = 0.004821750123500824, train acc = 0.54463, test loss = 0.004911091804504395, test acc = 0.547


 70%|███████   | 42/60 [00:17<00:07,  2.38it/s]

train loss = 0.004814723999500275, train acc = 0.54463, test loss = 0.004902906417846679, test acc = 0.547


 72%|███████▏  | 43/60 [00:18<00:07,  2.38it/s]

train loss = 0.004813664019107818, train acc = 0.54463, test loss = 0.004898611068725586, test acc = 0.547


 73%|███████▎  | 44/60 [00:18<00:06,  2.38it/s]

train loss = 0.004810398111343384, train acc = 0.54463, test loss = 0.004895089626312256, test acc = 0.547


 75%|███████▌  | 45/60 [00:19<00:06,  2.39it/s]

train loss = 0.004807173404693604, train acc = 0.54463, test loss = 0.004893956184387207, test acc = 0.547


 77%|███████▋  | 46/60 [00:19<00:05,  2.39it/s]

train loss = 0.004806364905834198, train acc = 0.54463, test loss = 0.004894092321395874, test acc = 0.547


 78%|███████▊  | 47/60 [00:19<00:05,  2.39it/s]

train loss = 0.004801880009174347, train acc = 0.54463, test loss = 0.004889023542404175, test acc = 0.547


 80%|████████  | 48/60 [00:20<00:05,  2.36it/s]

train loss = 0.004799890124797821, train acc = 0.54463, test loss = 0.004889080047607422, test acc = 0.547


 82%|████████▏ | 49/60 [00:20<00:04,  2.37it/s]

train loss = 0.004796942739486694, train acc = 0.54463, test loss = 0.004886142730712891, test acc = 0.547


 83%|████████▎ | 50/60 [00:21<00:04,  2.38it/s]

train loss = 0.004792945189476013, train acc = 0.54463, test loss = 0.004883267879486084, test acc = 0.547


 85%|████████▌ | 51/60 [00:21<00:03,  2.38it/s]

train loss = 0.004793940954208374, train acc = 0.54463, test loss = 0.004880615711212158, test acc = 0.547


 87%|████████▋ | 52/60 [00:21<00:03,  2.39it/s]

train loss = 0.004795179581642151, train acc = 0.54463, test loss = 0.004881762981414795, test acc = 0.547


 88%|████████▊ | 53/60 [00:22<00:02,  2.39it/s]

train loss = 0.0047883244562149044, train acc = 0.54463, test loss = 0.004880459547042847, test acc = 0.547


 90%|█████████ | 54/60 [00:22<00:02,  2.39it/s]

train loss = 0.004789302473068237, train acc = 0.54463, test loss = 0.004879246234893799, test acc = 0.547


 92%|█████████▏| 55/60 [00:23<00:02,  2.39it/s]

train loss = 0.004788370492458343, train acc = 0.54463, test loss = 0.004879751920700073, test acc = 0.547


 93%|█████████▎| 56/60 [00:23<00:01,  2.39it/s]

train loss = 0.0047885100698471066, train acc = 0.54463, test loss = 0.0048764655590057375, test acc = 0.547


 95%|█████████▌| 57/60 [00:24<00:01,  2.39it/s]

train loss = 0.0047863589501380924, train acc = 0.54463, test loss = 0.004874895811080932, test acc = 0.547


 97%|█████████▋| 58/60 [00:24<00:00,  2.39it/s]

train loss = 0.004787178256511688, train acc = 0.54463, test loss = 0.004876900434494019, test acc = 0.547


 98%|█████████▊| 59/60 [00:24<00:00,  2.35it/s]

train loss = 0.004785536644458771, train acc = 0.54463, test loss = 0.004875696897506714, test acc = 0.547


100%|██████████| 60/60 [00:25<00:00,  2.37it/s]

train loss = 0.0047860232496261595, train acc = 0.54463, test loss = 0.0048738689422607425, test acc = 0.547



  2%|▏         | 1/60 [00:00<00:26,  2.23it/s]

train loss = 0.009513147296905517, train acc = 0.07676, test loss = 0.009548699855804444, test acc = 0.084


  3%|▎         | 2/60 [00:00<00:26,  2.22it/s]

train loss = 0.008922900891304016, train acc = 0.25031, test loss = 0.008479077816009522, test acc = 0.382


  5%|▌         | 3/60 [00:01<00:25,  2.21it/s]

train loss = 0.007547362496852875, train acc = 0.41995, test loss = 0.007077910900115967, test acc = 0.423


  7%|▋         | 4/60 [00:01<00:25,  2.20it/s]

train loss = 0.006547872424125672, train acc = 0.45096, test loss = 0.006348059177398681, test acc = 0.451


  8%|▊         | 5/60 [00:02<00:24,  2.20it/s]

train loss = 0.006001511986255646, train acc = 0.45197, test loss = 0.005928042650222778, test acc = 0.451


 10%|█         | 6/60 [00:02<00:24,  2.20it/s]

train loss = 0.005670119967460632, train acc = 0.47338, test loss = 0.0056159842014312745, test acc = 0.521


 12%|█▏        | 7/60 [00:03<00:24,  2.19it/s]

train loss = 0.0053996628260612484, train acc = 0.51354, test loss = 0.00539586091041565, test acc = 0.521


 13%|█▎        | 8/60 [00:03<00:23,  2.19it/s]

train loss = 0.0052512659215927125, train acc = 0.51377, test loss = 0.005281142950057984, test acc = 0.523


 15%|█▌        | 9/60 [00:04<00:23,  2.16it/s]

train loss = 0.005138015999794006, train acc = 0.52566, test loss = 0.005161264419555664, test acc = 0.549


 17%|█▋        | 10/60 [00:04<00:23,  2.17it/s]

train loss = 0.005020086317062378, train acc = 0.54488, test loss = 0.00506502103805542, test acc = 0.549


 18%|█▊        | 11/60 [00:05<00:22,  2.18it/s]

train loss = 0.004941702845096588, train acc = 0.54463, test loss = 0.0050069899559020995, test acc = 0.549


 20%|██        | 12/60 [00:05<00:22,  2.18it/s]

train loss = 0.004894791188240051, train acc = 0.54487, test loss = 0.004969756603240966, test acc = 0.55


 22%|██▏       | 13/60 [00:05<00:21,  2.18it/s]

train loss = 0.00486255740404129, train acc = 0.54499, test loss = 0.004940414190292358, test acc = 0.55


 23%|██▎       | 14/60 [00:06<00:21,  2.14it/s]

train loss = 0.004840960669517517, train acc = 0.54472, test loss = 0.004925314664840698, test acc = 0.55


 25%|██▌       | 15/60 [00:06<00:20,  2.16it/s]

train loss = 0.0048249282908439635, train acc = 0.54489, test loss = 0.0049104845523834225, test acc = 0.548


 27%|██▋       | 16/60 [00:07<00:20,  2.17it/s]

train loss = 0.004813766024112701, train acc = 0.54473, test loss = 0.004900208473205566, test acc = 0.547


 28%|██▊       | 17/60 [00:07<00:19,  2.17it/s]

train loss = 0.004804046318531037, train acc = 0.54487, test loss = 0.00489087176322937, test acc = 0.549


 30%|███       | 18/60 [00:08<00:19,  2.18it/s]

train loss = 0.004799556682109833, train acc = 0.54471, test loss = 0.004883583307266235, test acc = 0.549


 32%|███▏      | 19/60 [00:08<00:18,  2.16it/s]

train loss = 0.004791651785373688, train acc = 0.54478, test loss = 0.004878183603286743, test acc = 0.55


 33%|███▎      | 20/60 [00:09<00:18,  2.17it/s]

train loss = 0.004786559956073761, train acc = 0.54469, test loss = 0.004872411489486694, test acc = 0.55


 35%|███▌      | 21/60 [00:09<00:17,  2.17it/s]

train loss = 0.0047818915963172915, train acc = 0.5446, test loss = 0.004869608163833618, test acc = 0.549


 37%|███▋      | 22/60 [00:10<00:17,  2.18it/s]

train loss = 0.0047796196842193606, train acc = 0.54484, test loss = 0.004868353843688965, test acc = 0.548


 38%|███▊      | 23/60 [00:10<00:16,  2.18it/s]

train loss = 0.004775916802883148, train acc = 0.54468, test loss = 0.004867995500564575, test acc = 0.546


 40%|████      | 24/60 [00:11<00:16,  2.19it/s]

train loss = 0.004774313304424286, train acc = 0.54453, test loss = 0.004863669633865356, test acc = 0.548


 42%|████▏     | 25/60 [00:11<00:15,  2.19it/s]

train loss = 0.0047698691940307614, train acc = 0.5447, test loss = 0.0048605425357818605, test acc = 0.547


 43%|████▎     | 26/60 [00:11<00:15,  2.19it/s]

train loss = 0.0047702870488166805, train acc = 0.54467, test loss = 0.004860096931457519, test acc = 0.545


 45%|████▌     | 27/60 [00:12<00:15,  2.19it/s]

train loss = 0.00476869790315628, train acc = 0.54451, test loss = 0.004855558395385742, test acc = 0.547


 47%|████▋     | 28/60 [00:12<00:14,  2.19it/s]

train loss = 0.004768895254135132, train acc = 0.54467, test loss = 0.004855591535568238, test acc = 0.546


 48%|████▊     | 29/60 [00:13<00:14,  2.15it/s]

train loss = 0.004765815804004669, train acc = 0.54447, test loss = 0.0048547577857971196, test acc = 0.547


 50%|█████     | 30/60 [00:13<00:14,  2.10it/s]

train loss = 0.004764660234451294, train acc = 0.54455, test loss = 0.004854392290115356, test acc = 0.545


 52%|█████▏    | 31/60 [00:14<00:13,  2.12it/s]

train loss = 0.004761838748455048, train acc = 0.54469, test loss = 0.0048531568050384525, test acc = 0.545


 53%|█████▎    | 32/60 [00:14<00:13,  2.14it/s]

train loss = 0.004761167349815369, train acc = 0.54445, test loss = 0.004850892782211303, test acc = 0.546


 55%|█████▌    | 33/60 [00:15<00:12,  2.15it/s]

train loss = 0.004757262309789657, train acc = 0.54472, test loss = 0.0048488245010375975, test acc = 0.547


 57%|█████▋    | 34/60 [00:15<00:12,  2.16it/s]

train loss = 0.004759743287563324, train acc = 0.54453, test loss = 0.004851956129074097, test acc = 0.547


 58%|█████▊    | 35/60 [00:16<00:11,  2.17it/s]

train loss = 0.004759508183002472, train acc = 0.54444, test loss = 0.0048470458984375, test acc = 0.546


 60%|██████    | 36/60 [00:16<00:11,  2.17it/s]

train loss = 0.0047595290613174435, train acc = 0.54463, test loss = 0.004847267627716065, test acc = 0.544


 62%|██████▏   | 37/60 [00:17<00:10,  2.18it/s]

train loss = 0.00475732696056366, train acc = 0.54452, test loss = 0.004848692655563354, test acc = 0.545


 63%|██████▎   | 38/60 [00:17<00:10,  2.18it/s]

train loss = 0.004756018319129944, train acc = 0.54462, test loss = 0.004845509767532349, test acc = 0.546


 65%|██████▌   | 39/60 [00:17<00:09,  2.17it/s]

train loss = 0.004754829728603363, train acc = 0.54462, test loss = 0.004842890501022339, test acc = 0.547


 67%|██████▋   | 40/60 [00:18<00:09,  2.18it/s]

train loss = 0.004757260501384735, train acc = 0.54447, test loss = 0.0048443727493286135, test acc = 0.546


 68%|██████▊   | 41/60 [00:18<00:08,  2.18it/s]

train loss = 0.004753548004627228, train acc = 0.54455, test loss = 0.004843430757522583, test acc = 0.546


 70%|███████   | 42/60 [00:19<00:08,  2.18it/s]

train loss = 0.004754845433235169, train acc = 0.54474, test loss = 0.004841988801956177, test acc = 0.547


 72%|███████▏  | 43/60 [00:19<00:07,  2.19it/s]

train loss = 0.004754567515850067, train acc = 0.54452, test loss = 0.0048408894538879395, test acc = 0.547


 73%|███████▎  | 44/60 [00:20<00:07,  2.18it/s]

train loss = 0.004751801271438599, train acc = 0.54451, test loss = 0.004839231252670288, test acc = 0.549


 75%|███████▌  | 45/60 [00:20<00:06,  2.19it/s]

train loss = 0.004754987626075745, train acc = 0.54456, test loss = 0.004843623638153076, test acc = 0.546


 77%|███████▋  | 46/60 [00:21<00:06,  2.19it/s]

train loss = 0.0047526243042945865, train acc = 0.54467, test loss = 0.00483496356010437, test acc = 0.549


 78%|███████▊  | 47/60 [00:21<00:05,  2.18it/s]

train loss = 0.004754479947090149, train acc = 0.54443, test loss = 0.004837481260299682, test acc = 0.547


 80%|████████  | 48/60 [00:22<00:05,  2.15it/s]

train loss = 0.0047502489161491395, train acc = 0.54432, test loss = 0.004835485219955444, test acc = 0.55


 82%|████████▏ | 49/60 [00:22<00:05,  2.16it/s]

train loss = 0.004750665853023529, train acc = 0.54427, test loss = 0.0048368899822235105, test acc = 0.547


 83%|████████▎ | 50/60 [00:23<00:04,  2.17it/s]

train loss = 0.004749338452816009, train acc = 0.54463, test loss = 0.0048404946327209476, test acc = 0.547


 85%|████████▌ | 51/60 [00:23<00:04,  2.17it/s]

train loss = 0.00474784658908844, train acc = 0.54449, test loss = 0.004835193157196045, test acc = 0.551


 87%|████████▋ | 52/60 [00:23<00:03,  2.18it/s]

train loss = 0.004748026795387268, train acc = 0.54461, test loss = 0.004836906433105468, test acc = 0.549


 88%|████████▊ | 53/60 [00:24<00:03,  2.18it/s]

train loss = 0.004748799881935119, train acc = 0.54446, test loss = 0.004835469961166382, test acc = 0.548


 90%|█████████ | 54/60 [00:24<00:02,  2.18it/s]

train loss = 0.004747399959564209, train acc = 0.54433, test loss = 0.004836886405944824, test acc = 0.548


 92%|█████████▏| 55/60 [00:25<00:02,  2.18it/s]

train loss = 0.00474582350730896, train acc = 0.54419, test loss = 0.004832076549530029, test acc = 0.55


 93%|█████████▎| 56/60 [00:25<00:01,  2.03it/s]

train loss = 0.004745541086196899, train acc = 0.54482, test loss = 0.004832269191741944, test acc = 0.549


 95%|█████████▌| 57/60 [00:26<00:01,  2.08it/s]

train loss = 0.004744028880596161, train acc = 0.54427, test loss = 0.00483471393585205, test acc = 0.547


 97%|█████████▋| 58/60 [00:26<00:00,  2.09it/s]

train loss = 0.004743015837669373, train acc = 0.54478, test loss = 0.004833490371704102, test acc = 0.547


 98%|█████████▊| 59/60 [00:27<00:00,  2.13it/s]

train loss = 0.00474383040189743, train acc = 0.54451, test loss = 0.0048368659019470215, test acc = 0.547


100%|██████████| 60/60 [00:27<00:00,  2.17it/s]

train loss = 0.004743586084842682, train acc = 0.54461, test loss = 0.004829589366912842, test acc = 0.549



  2%|▏         | 1/60 [00:00<00:27,  2.14it/s]

train loss = 0.0094864293384552, train acc = 0.0838, test loss = 0.009488294124603272, test acc = 0.114


  3%|▎         | 2/60 [00:01<00:33,  1.71it/s]

train loss = 0.008777773604393005, train acc = 0.2555, test loss = 0.008184651374816894, test acc = 0.349


  5%|▌         | 3/60 [00:01<00:30,  1.86it/s]

train loss = 0.0071165771317482, train acc = 0.3969, test loss = 0.00660037088394165, test acc = 0.419


  7%|▋         | 4/60 [00:02<00:28,  1.96it/s]

train loss = 0.006025668525695801, train acc = 0.47194, test loss = 0.005825858354568481, test acc = 0.512


  8%|▊         | 5/60 [00:02<00:27,  2.02it/s]

train loss = 0.00544201416015625, train acc = 0.54321, test loss = 0.005374245405197144, test acc = 0.546


 10%|█         | 6/60 [00:03<00:26,  2.06it/s]

train loss = 0.0051296464920043945, train acc = 0.54462, test loss = 0.005140653133392334, test acc = 0.547


 12%|█▏        | 7/60 [00:03<00:26,  2.02it/s]

train loss = 0.004968142447471619, train acc = 0.54471, test loss = 0.005019981861114502, test acc = 0.547


 13%|█▎        | 8/60 [00:04<00:25,  2.04it/s]

train loss = 0.00488368273973465, train acc = 0.54452, test loss = 0.004953644990921021, test acc = 0.546


 15%|█▌        | 9/60 [00:04<00:24,  2.06it/s]

train loss = 0.004840674736499786, train acc = 0.54446, test loss = 0.004918367385864258, test acc = 0.545


 17%|█▋        | 10/60 [00:04<00:24,  2.08it/s]

train loss = 0.0048140737628936765, train acc = 0.54471, test loss = 0.00489367413520813, test acc = 0.547


 18%|█▊        | 11/60 [00:05<00:23,  2.08it/s]

train loss = 0.004801923253536225, train acc = 0.54457, test loss = 0.004874054670333862, test acc = 0.549


 20%|██        | 12/60 [00:05<00:22,  2.09it/s]

train loss = 0.004786440236568451, train acc = 0.54466, test loss = 0.004868536710739136, test acc = 0.548


 22%|██▏       | 13/60 [00:06<00:22,  2.10it/s]

train loss = 0.004781545906066894, train acc = 0.54441, test loss = 0.004857131242752075, test acc = 0.548


 23%|██▎       | 14/60 [00:06<00:21,  2.10it/s]

train loss = 0.004770910346508026, train acc = 0.54451, test loss = 0.004848475217819214, test acc = 0.549


 25%|██▌       | 15/60 [00:07<00:21,  2.11it/s]

train loss = 0.004767722041606903, train acc = 0.54467, test loss = 0.004845495939254761, test acc = 0.548


 27%|██▋       | 16/60 [00:07<00:21,  2.07it/s]

train loss = 0.0047618457913398745, train acc = 0.54394, test loss = 0.004840256452560425, test acc = 0.546


 28%|██▊       | 17/60 [00:08<00:20,  2.08it/s]

train loss = 0.004757839037179947, train acc = 0.54461, test loss = 0.004835774421691894, test acc = 0.542


 30%|███       | 18/60 [00:08<00:20,  2.10it/s]

train loss = 0.00475812126159668, train acc = 0.54451, test loss = 0.004835150241851807, test acc = 0.543


 32%|███▏      | 19/60 [00:09<00:19,  2.11it/s]

train loss = 0.004755456509590149, train acc = 0.54434, test loss = 0.004830134868621826, test acc = 0.548


 33%|███▎      | 20/60 [00:09<00:18,  2.11it/s]

train loss = 0.004753807802200317, train acc = 0.54456, test loss = 0.004826694011688232, test acc = 0.548


 35%|███▌      | 21/60 [00:10<00:18,  2.11it/s]

train loss = 0.0047512233543396, train acc = 0.54414, test loss = 0.0048275871276855465, test acc = 0.546


 37%|███▋      | 22/60 [00:10<00:17,  2.11it/s]

train loss = 0.004749708545207977, train acc = 0.54475, test loss = 0.004823438882827759, test acc = 0.544


 38%|███▊      | 23/60 [00:11<00:17,  2.11it/s]

train loss = 0.004746721596717835, train acc = 0.54484, test loss = 0.0048216753005981446, test acc = 0.545


 40%|████      | 24/60 [00:11<00:17,  2.10it/s]

train loss = 0.004745888686180115, train acc = 0.54493, test loss = 0.004816394567489624, test acc = 0.547


 42%|████▏     | 25/60 [00:12<00:17,  2.05it/s]

train loss = 0.004743911325931549, train acc = 0.5449, test loss = 0.0048170387744903564, test acc = 0.552


 43%|████▎     | 26/60 [00:12<00:16,  2.05it/s]

train loss = 0.004742294292449952, train acc = 0.54491, test loss = 0.004816076040267944, test acc = 0.546


 45%|████▌     | 27/60 [00:13<00:16,  2.04it/s]

train loss = 0.004743288812637329, train acc = 0.54531, test loss = 0.004816760540008545, test acc = 0.555


 47%|████▋     | 28/60 [00:13<00:15,  2.05it/s]

train loss = 0.0047411097407341, train acc = 0.54475, test loss = 0.004810771465301514, test acc = 0.557


 48%|████▊     | 29/60 [00:14<00:15,  2.05it/s]

train loss = 0.004740234520435333, train acc = 0.54542, test loss = 0.0048124630451202395, test acc = 0.553


 50%|█████     | 30/60 [00:14<00:14,  2.06it/s]

train loss = 0.0047404482531547545, train acc = 0.54505, test loss = 0.00480885124206543, test acc = 0.555


 52%|█████▏    | 31/60 [00:15<00:15,  1.89it/s]

train loss = 0.004735819904804229, train acc = 0.5453, test loss = 0.004810020923614502, test acc = 0.55


 53%|█████▎    | 32/60 [00:15<00:16,  1.73it/s]

train loss = 0.004737437446117401, train acc = 0.54525, test loss = 0.00481025242805481, test acc = 0.554


 55%|█████▌    | 33/60 [00:16<00:15,  1.72it/s]

train loss = 0.004735367922782898, train acc = 0.54557, test loss = 0.004814444303512573, test acc = 0.551


 57%|█████▋    | 34/60 [00:17<00:15,  1.72it/s]

train loss = 0.004735669791698456, train acc = 0.5454, test loss = 0.004811274766921997, test acc = 0.554


 58%|█████▊    | 35/60 [00:17<00:14,  1.78it/s]

train loss = 0.004736109981536865, train acc = 0.54555, test loss = 0.004808882474899292, test acc = 0.55


 60%|██████    | 36/60 [00:18<00:12,  1.85it/s]

train loss = 0.004736190204620361, train acc = 0.54558, test loss = 0.004808110237121582, test acc = 0.551


 62%|██████▏   | 37/60 [00:18<00:12,  1.91it/s]

train loss = 0.0047356397604942325, train acc = 0.54486, test loss = 0.004808020830154419, test acc = 0.555


 63%|██████▎   | 38/60 [00:19<00:11,  1.97it/s]

train loss = 0.004735248119831085, train acc = 0.54565, test loss = 0.004808557987213135, test acc = 0.556


 65%|██████▌   | 39/60 [00:19<00:11,  1.87it/s]

train loss = 0.004734910571575165, train acc = 0.54551, test loss = 0.004809271574020385, test acc = 0.559


 67%|██████▋   | 40/60 [00:20<00:11,  1.80it/s]

train loss = 0.004735453782081604, train acc = 0.54569, test loss = 0.004809335708618164, test acc = 0.549


 68%|██████▊   | 41/60 [00:20<00:10,  1.84it/s]

train loss = 0.004734144394397736, train acc = 0.54613, test loss = 0.004809926509857178, test acc = 0.549


 70%|███████   | 42/60 [00:21<00:09,  1.83it/s]

train loss = 0.0047321539139747616, train acc = 0.5451, test loss = 0.004807326555252075, test acc = 0.557


 72%|███████▏  | 43/60 [00:21<00:09,  1.80it/s]

train loss = 0.004733203752040863, train acc = 0.54527, test loss = 0.004805993318557739, test acc = 0.557


 73%|███████▎  | 44/60 [00:22<00:08,  1.80it/s]

train loss = 0.004730983395576477, train acc = 0.54581, test loss = 0.0048080894947052, test acc = 0.551


 75%|███████▌  | 45/60 [00:22<00:08,  1.83it/s]

train loss = 0.004731972756385803, train acc = 0.54537, test loss = 0.004809393405914307, test acc = 0.555


 77%|███████▋  | 46/60 [00:23<00:07,  1.89it/s]

train loss = 0.00472960645198822, train acc = 0.54555, test loss = 0.0048076162338256834, test acc = 0.55


 78%|███████▊  | 47/60 [00:23<00:06,  1.91it/s]

train loss = 0.00472956734418869, train acc = 0.54588, test loss = 0.004807605028152466, test acc = 0.551


 80%|████████  | 48/60 [00:24<00:06,  1.95it/s]

train loss = 0.004729295814037323, train acc = 0.5459, test loss = 0.004809775829315186, test acc = 0.545


 82%|████████▏ | 49/60 [00:24<00:05,  1.99it/s]

train loss = 0.004730800180435181, train acc = 0.54577, test loss = 0.004812338590621948, test acc = 0.538


 83%|████████▎ | 50/60 [00:25<00:04,  2.02it/s]

train loss = 0.004729209887981415, train acc = 0.54512, test loss = 0.00481039834022522, test acc = 0.549


 85%|████████▌ | 51/60 [00:25<00:04,  2.04it/s]

train loss = 0.004729559025764466, train acc = 0.54594, test loss = 0.004811979055404663, test acc = 0.535


 87%|████████▋ | 52/60 [00:26<00:03,  2.05it/s]

train loss = 0.004729976341724396, train acc = 0.5456, test loss = 0.004807034969329834, test acc = 0.556


 88%|████████▊ | 53/60 [00:26<00:03,  2.04it/s]

train loss = 0.004730199389457702, train acc = 0.54535, test loss = 0.004806925773620605, test acc = 0.55


 90%|█████████ | 54/60 [00:27<00:02,  2.04it/s]

train loss = 0.00472721399307251, train acc = 0.54506, test loss = 0.004808887004852295, test acc = 0.547


 92%|█████████▏| 55/60 [00:27<00:02,  2.04it/s]

train loss = 0.004726460261344909, train acc = 0.54508, test loss = 0.004806224107742309, test acc = 0.554


 93%|█████████▎| 56/60 [00:28<00:01,  2.04it/s]

train loss = 0.004730852706432342, train acc = 0.54506, test loss = 0.004807130098342895, test acc = 0.549


 95%|█████████▌| 57/60 [00:28<00:01,  2.05it/s]

train loss = 0.004731340248584747, train acc = 0.54449, test loss = 0.004812620639801026, test acc = 0.543


 97%|█████████▋| 58/60 [00:29<00:00,  2.02it/s]

train loss = 0.004728030403852463, train acc = 0.54524, test loss = 0.004805869102478028, test acc = 0.551


 98%|█████████▊| 59/60 [00:29<00:00,  2.03it/s]

train loss = 0.004726052050590515, train acc = 0.54506, test loss = 0.00481128978729248, test acc = 0.536


100%|██████████| 60/60 [00:30<00:00,  1.98it/s]

train loss = 0.0047245363879203795, train acc = 0.54547, test loss = 0.004810011625289917, test acc = 0.551



  2%|▏         | 1/60 [00:00<00:28,  2.07it/s]

train loss = 0.009448097672462463, train acc = 0.10324, test loss = 0.009377979278564453, test acc = 0.284


  3%|▎         | 2/60 [00:00<00:28,  2.05it/s]

train loss = 0.008430820088386536, train acc = 0.35521, test loss = 0.007735038995742798, test acc = 0.337


  5%|▌         | 3/60 [00:01<00:28,  2.03it/s]

train loss = 0.006899761853218079, train acc = 0.44068, test loss = 0.006456543445587158, test acc = 0.447


  7%|▋         | 4/60 [00:01<00:27,  2.00it/s]

train loss = 0.005989473009109497, train acc = 0.45141, test loss = 0.00592834734916687, test acc = 0.448


  8%|▊         | 5/60 [00:02<00:27,  2.00it/s]

train loss = 0.005641938376426697, train acc = 0.46213, test loss = 0.005654768943786621, test acc = 0.475


 10%|█         | 6/60 [00:02<00:27,  1.99it/s]

train loss = 0.005387932059764862, train acc = 0.48713, test loss = 0.005392860412597656, test acc = 0.511


 12%|█▏        | 7/60 [00:03<00:26,  2.00it/s]

train loss = 0.005157398014068604, train acc = 0.51368, test loss = 0.005206426620483398, test acc = 0.512


 13%|█▎        | 8/60 [00:03<00:26,  2.00it/s]

train loss = 0.004955706021785736, train acc = 0.53575, test loss = 0.004979045867919922, test acc = 0.546


 15%|█▌        | 9/60 [00:04<00:25,  2.00it/s]

train loss = 0.004839563508033752, train acc = 0.54453, test loss = 0.004919290542602539, test acc = 0.548


 17%|█▋        | 10/60 [00:04<00:25,  1.99it/s]

train loss = 0.004805424537658691, train acc = 0.54442, test loss = 0.004895976781845093, test acc = 0.545


 18%|█▊        | 11/60 [00:05<00:24,  1.99it/s]

train loss = 0.004789747021198273, train acc = 0.54432, test loss = 0.004878388166427612, test acc = 0.546


 20%|██        | 12/60 [00:05<00:24,  1.99it/s]

train loss = 0.004779083197116851, train acc = 0.54419, test loss = 0.004871552705764771, test acc = 0.543


 22%|██▏       | 13/60 [00:06<00:23,  1.98it/s]

train loss = 0.004772185060977936, train acc = 0.54452, test loss = 0.004863869428634643, test acc = 0.543


 23%|██▎       | 14/60 [00:07<00:23,  1.98it/s]

train loss = 0.004767198491096496, train acc = 0.5438, test loss = 0.004852551221847534, test acc = 0.543


 25%|██▌       | 15/60 [00:07<00:22,  1.98it/s]

train loss = 0.004762380270957947, train acc = 0.54416, test loss = 0.0048464457988739015, test acc = 0.547


 27%|██▋       | 16/60 [00:08<00:22,  1.99it/s]

train loss = 0.004756265552043915, train acc = 0.54445, test loss = 0.004837013006210327, test acc = 0.546


 28%|██▊       | 17/60 [00:08<00:22,  1.91it/s]

train loss = 0.004751960577964782, train acc = 0.54419, test loss = 0.004828253746032715, test acc = 0.543


 30%|███       | 18/60 [00:09<00:22,  1.89it/s]

train loss = 0.004748843998908996, train acc = 0.54447, test loss = 0.004822947978973389, test acc = 0.54


 32%|███▏      | 19/60 [00:09<00:21,  1.90it/s]

train loss = 0.00474557190656662, train acc = 0.54468, test loss = 0.00481824254989624, test acc = 0.545


 33%|███▎      | 20/60 [00:10<00:20,  1.91it/s]

train loss = 0.004741587262153625, train acc = 0.54467, test loss = 0.004816799640655517, test acc = 0.542


 35%|███▌      | 21/60 [00:10<00:20,  1.92it/s]

train loss = 0.004739896562099457, train acc = 0.54458, test loss = 0.004810513734817505, test acc = 0.547


 37%|███▋      | 22/60 [00:11<00:20,  1.87it/s]

train loss = 0.0047380733180046085, train acc = 0.54505, test loss = 0.004810844898223877, test acc = 0.547


 38%|███▊      | 23/60 [00:11<00:19,  1.89it/s]

train loss = 0.004736558825969696, train acc = 0.54501, test loss = 0.004808025360107422, test acc = 0.546


 40%|████      | 24/60 [00:12<00:18,  1.91it/s]

train loss = 0.004736182332038879, train acc = 0.54472, test loss = 0.004804821729660035, test acc = 0.549


 42%|████▏     | 25/60 [00:12<00:18,  1.91it/s]

train loss = 0.004731022052764893, train acc = 0.545, test loss = 0.004803390979766846, test acc = 0.551


 43%|████▎     | 26/60 [00:13<00:17,  1.93it/s]

train loss = 0.004733965022563934, train acc = 0.54525, test loss = 0.004800679445266723, test acc = 0.554


 45%|████▌     | 27/60 [00:13<00:16,  1.95it/s]

train loss = 0.004731324770450592, train acc = 0.54524, test loss = 0.004803531885147095, test acc = 0.549


 47%|████▋     | 28/60 [00:14<00:16,  1.94it/s]

train loss = 0.004728152801990509, train acc = 0.54538, test loss = 0.004806639909744262, test acc = 0.545


 48%|████▊     | 29/60 [00:14<00:16,  1.93it/s]

train loss = 0.004729087624549866, train acc = 0.5452, test loss = 0.004803215742111206, test acc = 0.545


 50%|█████     | 30/60 [00:15<00:15,  1.95it/s]

train loss = 0.004731130955219269, train acc = 0.54455, test loss = 0.004801446676254272, test acc = 0.544


 52%|█████▏    | 31/60 [00:15<00:14,  1.96it/s]

train loss = 0.004729909179210663, train acc = 0.54483, test loss = 0.004800660133361816, test acc = 0.543


 53%|█████▎    | 32/60 [00:16<00:14,  1.96it/s]

train loss = 0.0047276060748100285, train acc = 0.54515, test loss = 0.004801815509796143, test acc = 0.545


 55%|█████▌    | 33/60 [00:16<00:13,  1.94it/s]

train loss = 0.004729731106758118, train acc = 0.54417, test loss = 0.004803106307983398, test acc = 0.548


 57%|█████▋    | 34/60 [00:17<00:13,  1.95it/s]

train loss = 0.004729436118602753, train acc = 0.54485, test loss = 0.004800138235092163, test acc = 0.545


 58%|█████▊    | 35/60 [00:17<00:12,  1.94it/s]

train loss = 0.004726193840503693, train acc = 0.54444, test loss = 0.004801130294799805, test acc = 0.546


 60%|██████    | 36/60 [00:18<00:12,  1.96it/s]

train loss = 0.004727589058876038, train acc = 0.5443, test loss = 0.004800123453140259, test acc = 0.548


 62%|██████▏   | 37/60 [00:18<00:11,  1.96it/s]

train loss = 0.004724250409603119, train acc = 0.54469, test loss = 0.004801476955413818, test acc = 0.545


 63%|██████▎   | 38/60 [00:19<00:11,  1.97it/s]

train loss = 0.004726832294464111, train acc = 0.54513, test loss = 0.00480097246170044, test acc = 0.546


 65%|██████▌   | 39/60 [00:19<00:10,  1.97it/s]

train loss = 0.004725964632034302, train acc = 0.54425, test loss = 0.004803264617919922, test acc = 0.539


 67%|██████▋   | 40/60 [00:20<00:10,  1.96it/s]

train loss = 0.004725330312252044, train acc = 0.54449, test loss = 0.004803256034851074, test acc = 0.539


 68%|██████▊   | 41/60 [00:20<00:09,  1.97it/s]

train loss = 0.004721378409862518, train acc = 0.54399, test loss = 0.004802314043045044, test acc = 0.544


 70%|███████   | 42/60 [00:21<00:09,  1.98it/s]

train loss = 0.004723547480106354, train acc = 0.54468, test loss = 0.004807394981384277, test acc = 0.54


 72%|███████▏  | 43/60 [00:21<00:08,  1.98it/s]

train loss = 0.004723051340579987, train acc = 0.54507, test loss = 0.004801590681076049, test acc = 0.555


 73%|███████▎  | 44/60 [00:22<00:08,  1.98it/s]

train loss = 0.004722842876911163, train acc = 0.54431, test loss = 0.004803121566772461, test acc = 0.55


 75%|███████▌  | 45/60 [00:22<00:07,  1.96it/s]

train loss = 0.0047224345064163205, train acc = 0.54465, test loss = 0.004803490877151489, test acc = 0.553


 77%|███████▋  | 46/60 [00:23<00:07,  1.97it/s]

train loss = 0.0047222945070266725, train acc = 0.5448, test loss = 0.004805148601531983, test acc = 0.547


 78%|███████▊  | 47/60 [00:23<00:06,  1.98it/s]

train loss = 0.00472210230588913, train acc = 0.54464, test loss = 0.0048025517463684084, test acc = 0.544


 80%|████████  | 48/60 [00:24<00:06,  1.97it/s]

train loss = 0.0047208940148353576, train acc = 0.54527, test loss = 0.0048052148818969725, test acc = 0.552


 82%|████████▏ | 49/60 [00:25<00:05,  1.95it/s]

train loss = 0.004719358348846435, train acc = 0.54441, test loss = 0.004804939985275269, test acc = 0.547


 83%|████████▎ | 50/60 [00:25<00:05,  1.96it/s]

train loss = 0.004717274050712586, train acc = 0.54558, test loss = 0.004807147741317749, test acc = 0.54


 85%|████████▌ | 51/60 [00:26<00:04,  1.98it/s]

train loss = 0.00471675062417984, train acc = 0.5457, test loss = 0.004803912401199341, test acc = 0.554


 87%|████████▋ | 52/60 [00:26<00:04,  1.98it/s]

train loss = 0.004719057900905609, train acc = 0.54443, test loss = 0.0048057839870452885, test acc = 0.556


 88%|████████▊ | 53/60 [00:27<00:03,  1.98it/s]

train loss = 0.004718918030261993, train acc = 0.54431, test loss = 0.004800783634185791, test acc = 0.548


 90%|█████████ | 54/60 [00:27<00:03,  1.99it/s]

train loss = 0.0047166458344459535, train acc = 0.54469, test loss = 0.004804550170898437, test acc = 0.55


 92%|█████████▏| 55/60 [00:28<00:02,  1.97it/s]

train loss = 0.00471531622171402, train acc = 0.54506, test loss = 0.004800571918487549, test acc = 0.554


 93%|█████████▎| 56/60 [00:28<00:02,  1.97it/s]

train loss = 0.004714936470985412, train acc = 0.54476, test loss = 0.004799603462219238, test acc = 0.544


 95%|█████████▌| 57/60 [00:29<00:01,  1.97it/s]

train loss = 0.004715318689346313, train acc = 0.54452, test loss = 0.004798198223114013, test acc = 0.553


 97%|█████████▋| 58/60 [00:29<00:01,  1.73it/s]

train loss = 0.004716505017280578, train acc = 0.54516, test loss = 0.004796977996826172, test acc = 0.562


 98%|█████████▊| 59/60 [00:30<00:00,  1.80it/s]

train loss = 0.004711396481990814, train acc = 0.54532, test loss = 0.004798660516738892, test acc = 0.546


100%|██████████| 60/60 [00:30<00:00,  1.94it/s]

train loss = 0.004711436553001404, train acc = 0.54537, test loss = 0.004799042463302612, test acc = 0.56



  2%|▏         | 1/60 [00:00<00:30,  1.93it/s]

train loss = 0.009426709961891174, train acc = 0.09919, test loss = 0.00929432487487793, test acc = 0.174


  3%|▎         | 2/60 [00:01<00:29,  1.95it/s]

train loss = 0.008106940677165985, train acc = 0.38674, test loss = 0.007089451789855957, test acc = 0.487


  5%|▌         | 3/60 [00:01<00:29,  1.93it/s]

train loss = 0.0061897466826438905, train acc = 0.48288, test loss = 0.005816920518875122, test acc = 0.487


  7%|▋         | 4/60 [00:02<00:29,  1.89it/s]

train loss = 0.00542929182767868, train acc = 0.50318, test loss = 0.0053168065547943115, test acc = 0.547


  8%|▊         | 5/60 [00:02<00:28,  1.91it/s]

train loss = 0.005042817702293396, train acc = 0.54468, test loss = 0.00504521131515503, test acc = 0.548


 10%|█         | 6/60 [00:03<00:28,  1.91it/s]

train loss = 0.004881454248428345, train acc = 0.54474, test loss = 0.004948205471038818, test acc = 0.547


 12%|█▏        | 7/60 [00:03<00:28,  1.89it/s]

train loss = 0.004819621262550354, train acc = 0.54474, test loss = 0.004906187534332276, test acc = 0.547


 13%|█▎        | 8/60 [00:04<00:27,  1.90it/s]

train loss = 0.004790864133834839, train acc = 0.54485, test loss = 0.004885115146636963, test acc = 0.549


 15%|█▌        | 9/60 [00:04<00:27,  1.88it/s]

train loss = 0.0047744172978401184, train acc = 0.54485, test loss = 0.004867448806762696, test acc = 0.547


 17%|█▋        | 10/60 [00:05<00:26,  1.90it/s]

train loss = 0.004762561876773834, train acc = 0.54461, test loss = 0.004857710599899292, test acc = 0.55


 18%|█▊        | 11/60 [00:05<00:25,  1.90it/s]

train loss = 0.004754691979885101, train acc = 0.54509, test loss = 0.004846793174743652, test acc = 0.545


 20%|██        | 12/60 [00:06<00:25,  1.91it/s]

train loss = 0.004746743626594543, train acc = 0.54553, test loss = 0.004840334177017212, test acc = 0.551


 22%|██▏       | 13/60 [00:06<00:24,  1.93it/s]

train loss = 0.004744234483242035, train acc = 0.54492, test loss = 0.004837571859359741, test acc = 0.546


 23%|██▎       | 14/60 [00:07<00:23,  1.95it/s]

train loss = 0.004737803800106049, train acc = 0.54502, test loss = 0.004832149028778076, test acc = 0.544


 25%|██▌       | 15/60 [00:07<00:22,  1.96it/s]

train loss = 0.0047362820219993594, train acc = 0.54444, test loss = 0.004826584577560425, test acc = 0.545


 27%|██▋       | 16/60 [00:08<00:22,  1.97it/s]

train loss = 0.0047310435962677, train acc = 0.54459, test loss = 0.004823717355728149, test acc = 0.543


 28%|██▊       | 17/60 [00:08<00:21,  1.97it/s]

train loss = 0.00473135694026947, train acc = 0.54432, test loss = 0.004816745519638062, test acc = 0.54


 30%|███       | 18/60 [00:09<00:21,  1.97it/s]

train loss = 0.0047277902436256404, train acc = 0.54428, test loss = 0.004815586805343628, test acc = 0.541


 32%|███▏      | 19/60 [00:09<00:20,  1.97it/s]

train loss = 0.00472312068939209, train acc = 0.54469, test loss = 0.004815646886825562, test acc = 0.541


 33%|███▎      | 20/60 [00:10<00:20,  1.97it/s]

train loss = 0.004722788207530975, train acc = 0.5451, test loss = 0.004809986352920532, test acc = 0.546


 35%|███▌      | 21/60 [00:10<00:19,  1.97it/s]

train loss = 0.004723888597488403, train acc = 0.5455, test loss = 0.004810807228088379, test acc = 0.539


 37%|███▋      | 22/60 [00:11<00:19,  1.95it/s]

train loss = 0.004723118035793304, train acc = 0.54426, test loss = 0.004809895038604736, test acc = 0.543


 38%|███▊      | 23/60 [00:11<00:19,  1.90it/s]

train loss = 0.004721186146736145, train acc = 0.54595, test loss = 0.004806874513626098, test acc = 0.554


 40%|████      | 24/60 [00:12<00:19,  1.89it/s]

train loss = 0.004718623450994492, train acc = 0.54577, test loss = 0.004806345224380493, test acc = 0.55


 42%|████▏     | 25/60 [00:13<00:18,  1.89it/s]

train loss = 0.004718859982490539, train acc = 0.54593, test loss = 0.004804351806640625, test acc = 0.546


 43%|████▎     | 26/60 [00:13<00:18,  1.89it/s]

train loss = 0.004717520599365235, train acc = 0.54618, test loss = 0.004805635213851928, test acc = 0.551


 45%|████▌     | 27/60 [00:14<00:17,  1.84it/s]

train loss = 0.004717654333114624, train acc = 0.54504, test loss = 0.004808368921279907, test acc = 0.55


 47%|████▋     | 28/60 [00:14<00:17,  1.82it/s]

train loss = 0.004717586119174957, train acc = 0.54528, test loss = 0.004806197643280029, test acc = 0.553


 48%|████▊     | 29/60 [00:15<00:17,  1.80it/s]

train loss = 0.004714353103637695, train acc = 0.54531, test loss = 0.004804039478302002, test acc = 0.546


 50%|█████     | 30/60 [00:15<00:16,  1.79it/s]

train loss = 0.0047175211501121524, train acc = 0.54477, test loss = 0.004802613973617554, test acc = 0.55


 52%|█████▏    | 31/60 [00:16<00:16,  1.75it/s]

train loss = 0.00471441335439682, train acc = 0.54553, test loss = 0.004805410623550415, test acc = 0.546


 53%|█████▎    | 32/60 [00:16<00:16,  1.74it/s]

train loss = 0.004712815001010895, train acc = 0.54544, test loss = 0.004806144952774048, test acc = 0.54


 55%|█████▌    | 33/60 [00:17<00:15,  1.74it/s]

train loss = 0.004713303730487823, train acc = 0.54557, test loss = 0.00480569863319397, test acc = 0.547


 57%|█████▋    | 34/60 [00:18<00:15,  1.68it/s]

train loss = 0.004712074890136719, train acc = 0.54645, test loss = 0.004805330276489258, test acc = 0.543


 58%|█████▊    | 35/60 [00:18<00:14,  1.67it/s]

train loss = 0.004714241592884064, train acc = 0.54566, test loss = 0.004802357196807861, test acc = 0.548


 60%|██████    | 36/60 [00:19<00:14,  1.69it/s]

train loss = 0.004714335961341858, train acc = 0.54551, test loss = 0.004803643465042114, test acc = 0.54


 62%|██████▏   | 37/60 [00:19<00:13,  1.71it/s]

train loss = 0.00471155179977417, train acc = 0.54603, test loss = 0.0048005242347717286, test acc = 0.543


 63%|██████▎   | 38/60 [00:20<00:12,  1.73it/s]

train loss = 0.004712367556095123, train acc = 0.54613, test loss = 0.0048016650676727295, test acc = 0.548


 65%|██████▌   | 39/60 [00:21<00:12,  1.74it/s]

train loss = 0.004712925031185151, train acc = 0.54615, test loss = 0.004803219795227051, test acc = 0.551


 67%|██████▋   | 40/60 [00:21<00:11,  1.73it/s]

train loss = 0.004712832016944885, train acc = 0.54567, test loss = 0.004800068855285645, test acc = 0.547


 68%|██████▊   | 41/60 [00:22<00:10,  1.74it/s]

train loss = 0.004709949038028717, train acc = 0.54598, test loss = 0.004799007177352905, test acc = 0.552


 70%|███████   | 42/60 [00:22<00:10,  1.73it/s]

train loss = 0.004710396590232849, train acc = 0.54621, test loss = 0.004801973581314087, test acc = 0.551


 72%|███████▏  | 43/60 [00:23<00:09,  1.74it/s]

train loss = 0.004713143482208252, train acc = 0.54532, test loss = 0.004801630735397339, test acc = 0.548


 73%|███████▎  | 44/60 [00:23<00:09,  1.75it/s]

train loss = 0.004711408755779267, train acc = 0.54593, test loss = 0.004799297571182251, test acc = 0.55


 75%|███████▌  | 45/60 [00:24<00:08,  1.73it/s]

train loss = 0.00471184850692749, train acc = 0.54661, test loss = 0.00480029821395874, test acc = 0.557


 77%|███████▋  | 46/60 [00:25<00:08,  1.73it/s]

train loss = 0.004711669647693634, train acc = 0.54535, test loss = 0.004801958799362183, test acc = 0.549


 78%|███████▊  | 47/60 [00:25<00:07,  1.74it/s]

train loss = 0.004712610893249512, train acc = 0.54665, test loss = 0.004803060531616211, test acc = 0.55


 80%|████████  | 48/60 [00:26<00:06,  1.74it/s]

train loss = 0.004710597724914551, train acc = 0.54634, test loss = 0.00480098032951355, test acc = 0.554


 82%|████████▏ | 49/60 [00:26<00:06,  1.75it/s]

train loss = 0.004712556512355804, train acc = 0.54603, test loss = 0.004801634550094604, test acc = 0.548


 83%|████████▎ | 50/60 [00:27<00:05,  1.73it/s]

train loss = 0.00471069073677063, train acc = 0.54538, test loss = 0.004802810907363892, test acc = 0.554


 85%|████████▌ | 51/60 [00:28<00:05,  1.70it/s]

train loss = 0.004710501909255981, train acc = 0.54682, test loss = 0.004803919792175293, test acc = 0.537


 87%|████████▋ | 52/60 [00:28<00:04,  1.71it/s]

train loss = 0.0047120732140541076, train acc = 0.54637, test loss = 0.004799614191055298, test acc = 0.549


 88%|████████▊ | 53/60 [00:29<00:04,  1.73it/s]

train loss = 0.0047127160143852235, train acc = 0.54599, test loss = 0.004801172971725464, test acc = 0.543


 90%|█████████ | 54/60 [00:29<00:03,  1.74it/s]

train loss = 0.004711239960193634, train acc = 0.54667, test loss = 0.004801924467086792, test acc = 0.547


 92%|█████████▏| 55/60 [00:30<00:03,  1.60it/s]

train loss = 0.004712237675189972, train acc = 0.54664, test loss = 0.004800886392593384, test acc = 0.556


 93%|█████████▎| 56/60 [00:31<00:02,  1.62it/s]

train loss = 0.00470966415643692, train acc = 0.54617, test loss = 0.004800062656402588, test acc = 0.545


 95%|█████████▌| 57/60 [00:31<00:01,  1.65it/s]

train loss = 0.00471140691280365, train acc = 0.54622, test loss = 0.0048009040355682375, test acc = 0.551


 97%|█████████▋| 58/60 [00:32<00:01,  1.68it/s]

train loss = 0.0047106051921844486, train acc = 0.54573, test loss = 0.004798409700393677, test acc = 0.557


 98%|█████████▊| 59/60 [00:32<00:00,  1.71it/s]

train loss = 0.004712594873905182, train acc = 0.54629, test loss = 0.004800342798233033, test acc = 0.552


100%|██████████| 60/60 [00:33<00:00,  1.80it/s]

train loss = 0.004710559675693512, train acc = 0.54626, test loss = 0.004805691242218017, test acc = 0.538





In [74]:
ans_map

{(1, 1): 0,
 (1, 2): 1,
 (1, 3): 2,
 (1, 4): 3,
 (2, 2): 4,
 (2, 3): 5,
 (2, 4): 6,
 (3, 3): 7,
 (3, 4): 8,
 (4, 4): 9}

In [82]:
def get_rankings(seq, model):
    keys = seq[:-1]
    X = t.Tensor(seq).unsqueeze(0).int()
    logits = model(X)
    scores = model.activation_cache['hook_attn_scores'][0]
    rankings = {}

    for head in range(model_config.n_heads):
        ## get scores of attention from last query pos to all other (key) positions
        rankings[head] = sorted(list(zip(range(len(keys)),keys, scores[head,-1,:-1].detach())), key = lambda x: x[-1], reverse = True)

    return rankings
    

In [85]:
indx = list(ans_map.keys())
indx.append("nan")
x_labels = [str(x) for x in indx]

for dh in range(1,6):
    pre_trained = models[dh]
    head_outputs = t.zeros((model_config.n_heads, model_config.n_ctx, model_config.d_vocab))

    ## getting DLA from each token to each possible continuation
    for i in range(1,5):
        seq = [i,0]
        X = t.Tensor(seq).unsqueeze(0).int()
        for head in range(model_config.n_heads):
            temp = get_head_to_resid(pre_trained, pre_trained.activation_cache, head, X)
            temp = get_head_to_logits(pre_trained, temp)
            #print(head_outputs[head, i-1].shape, temp[0,-1].shape)
            head_outputs[head, i-1,:] = temp[0,-1,:]
    
    seq = [1,2,3,4,0]
    rankings = get_rankings(seq, pre_trained)


    ## getting attention of each token per head


    print(f"Attention ranking for d_head = {dh} is {rankings}")
    imshow(head_outputs, facet_col = 0, y = [1,2,3,4], x = x_labels, title = f"Head Outputs Conditional on Attn for d_head = {dh}",xaxis = "Logit contribution to completions", yaxis = "Source token")

Attention ranking for d_head = 1 is {0: [(3, 4, tensor(10.7204)), (0, 1, tensor(10.5642)), (2, 3, tensor(2.6849)), (1, 2, tensor(2.4218))], 1: [(2, 3, tensor(11.4649)), (0, 1, tensor(3.0419)), (1, 2, tensor(3.0279)), (3, 4, tensor(2.1791))]}


Attention ranking for d_head = 2 is {0: [(0, 1, tensor(10.3109)), (3, 4, tensor(9.4504)), (1, 2, tensor(5.5252)), (2, 3, tensor(-1.1040))], 1: [(2, 3, tensor(11.0803)), (1, 2, tensor(10.9012)), (3, 4, tensor(5.5689)), (0, 1, tensor(-2.3904))]}


Attention ranking for d_head = 3 is {0: [(0, 1, tensor(19.8612)), (1, 2, tensor(14.4581)), (3, 4, tensor(5.2263)), (2, 3, tensor(-2.5995))], 1: [(2, 3, tensor(11.7716)), (3, 4, tensor(9.9533)), (1, 2, tensor(8.4585)), (0, 1, tensor(-1.5127))]}


Attention ranking for d_head = 4 is {0: [(1, 2, tensor(22.8451)), (0, 1, tensor(11.3628)), (2, 3, tensor(4.5868)), (3, 4, tensor(-4.1208))], 1: [(3, 4, tensor(20.1046)), (2, 3, tensor(16.7878)), (0, 1, tensor(7.9085)), (1, 2, tensor(-2.9431))]}


Attention ranking for d_head = 5 is {0: [(1, 2, tensor(20.5462)), (0, 1, tensor(11.0641)), (3, 4, tensor(3.4046)), (2, 3, tensor(-3.5068))], 1: [(2, 3, tensor(22.3207)), (3, 4, tensor(12.2126)), (0, 1, tensor(3.2533)), (1, 2, tensor(-7.4202))]}


In [88]:
## save models where heads share attention to the same source token as a different way of acheiving the task

#t.save(models[1], 'dhead1_new.pth')


In [96]:
ans_map

{(1, 1): 0,
 (1, 2): 1,
 (1, 3): 2,
 (1, 4): 3,
 (2, 2): 4,
 (2, 3): 5,
 (2, 4): 6,
 (3, 3): 7,
 (3, 4): 8,
 (4, 4): 9}

In [164]:
def get_scores(model, seq):
    X = t.Tensor(seq).unsqueeze(0).int()
    logits = model(X)
    scores = model.activation_cache['hook_attn_scores'][0]
    return scores

seq = [1,2,3,4,0]

test_2 = t.load('saved_models/dhead2_new.pth')
test_1 = t.load('saved_models/dhead1_new.pth')


print(t.stack([get_scores(test_1, seq)[head, -1, :-1].detach() for head in range(model_config.n_heads)]))

print(t.stack([get_scores(test_2, seq)[head, -1, :-1].detach() for head in range(model_config.n_heads)]))
#imshow(attn_list, facet_col = 0,  title = f"Attention Scores per Source Token for d_head = {1}")

tensor([[10.5642,  2.4218,  2.6849, 10.7204],
        [ 3.0419,  3.0279, 11.4649,  2.1791]])
tensor([[ -7.8753,  11.1438,  13.2931,  11.7319],
        [ 19.8794,  20.3196, -18.3274,   8.4864]])


In [160]:
seq = [2,5,3,0]


show_pattern(test_2.activation_cache, seq, test_2)
show_pattern(test_1.activation_cache, seq, test_1)

# for i in range(2,6):
#     show_pattern(models[i].activation_cache, seq, models[i])

In [158]:
seq = [1,2,3,4,0]
print(get_rankings(seq, test_1))
get_rankings(seq, test_2)

{0: [(3, 4, tensor(10.7204)), (0, 1, tensor(10.5642)), (2, 3, tensor(2.6849)), (1, 2, tensor(2.4218))], 1: [(2, 3, tensor(11.4649)), (0, 1, tensor(3.0419)), (1, 2, tensor(3.0279)), (3, 4, tensor(2.1791))]}


{0: [(2, 3, tensor(13.2931)),
  (3, 4, tensor(11.7319)),
  (1, 2, tensor(11.1438)),
  (0, 1, tensor(-7.8753))],
 1: [(1, 2, tensor(20.3196)),
  (0, 1, tensor(19.8794)),
  (3, 4, tensor(8.4864)),
  (2, 3, tensor(-18.3274))]}

For binomial distribution (k=1 Bernoulli trials), standard error for a sample stat is \sqrt(pq/n) ~ .015 for the val set, so possibly overfitting is pulling us out of 

In [90]:
test_err = np.sqrt(1/(4*train_config.valset_size))

train_err = np.sqrt(1/(4*train_config.trainset_size))
for ind in range(1,6):
    #print(train_losses[ind][-1],train_accs[ind][-1], test_losses[ind][-1], test_accs[ind][-1])
    print(f"For d_head = {ind}, the final test accuracy is {test_accs[ind][-1]}, the max test accuracy is {np.max(test_accs[ind])} at epoch {np.argmax(test_accs[ind])}/{train_config.epochs}")
    print(f"For d_head = {ind}, the final training accuracy is {train_accs[ind][-1]}, the max test accuracy is {np.max(train_accs[ind])} at epoch {np.argmax(train_accs[ind])}/{train_config.epochs}")

    print(f"test error within range? {np.abs(test_accs[ind][-1] - np.max(test_accs[ind])) <= test_err}")

    print(f"train error within range? {np.abs(train_accs[ind][-1] - np.max(train_accs[ind])) <= test_err}")
    plot_metrics(train_losses[ind],train_accs[ind], test_losses[ind], test_accs[ind]) 

For d_head = 1, the final test accuracy is 0.547, the max test accuracy is 0.547 at epoch 24/60
For d_head = 1, the final training accuracy is 0.54463, the max test accuracy is 0.54476 at epoch 25/60
test error within range? True
train error within range? True


For d_head = 2, the final test accuracy is 0.549, the max test accuracy is 0.551 at epoch 50/60
For d_head = 2, the final training accuracy is 0.54461, the max test accuracy is 0.54499 at epoch 12/60
test error within range? True
train error within range? True


For d_head = 3, the final test accuracy is 0.551, the max test accuracy is 0.559 at epoch 38/60
For d_head = 3, the final training accuracy is 0.54547, the max test accuracy is 0.54613 at epoch 40/60
test error within range? True
train error within range? True


For d_head = 4, the final test accuracy is 0.56, the max test accuracy is 0.562 at epoch 57/60
For d_head = 4, the final training accuracy is 0.54537, the max test accuracy is 0.5457 at epoch 50/60
test error within range? True
train error within range? True


For d_head = 5, the final test accuracy is 0.538, the max test accuracy is 0.557 at epoch 44/60
For d_head = 5, the final training accuracy is 0.54626, the max test accuracy is 0.54682 at epoch 50/60
test error within range? False
train error within range? True


In [91]:
test_signal = SkipTrigramDataset(size = 10000,
                                     prob = 1,
                                     prompt_length = train_config.ctx_length)

for dh in [1,2,3,4,5]:
    model = models[dh]
    acc = 0

    for i in range(len(test_signal)):
        X = test_signal[i][0].unsqueeze(0) # add batch dim to feed into model
        y = test_signal[i][1].item()

        output = model(X)[:,-1]
        corr = (output.argmax(dim=-1) == y).count_nonzero()
        acc += corr.item()

    print(acc/len(test_signal))




1.0
1.0
1.0
1.0
1.0


In [92]:
seq = [3,5,2,0]
for i in range(1,6):
    print(get_rankings(seq, models[i]))
    show_pattern(models[i].activation_cache, seq, models[i])

{0: [(0, 3, tensor(2.6849)), (2, 2, tensor(2.4218)), (1, 5, tensor(-7.2294))], 1: [(0, 3, tensor(11.4649)), (2, 2, tensor(3.0279)), (1, 5, tensor(-3.0561))]}


{0: [(2, 2, tensor(5.5252)), (0, 3, tensor(-1.1040)), (1, 5, tensor(-6.0368))], 1: [(0, 3, tensor(11.0803)), (2, 2, tensor(10.9012)), (1, 5, tensor(-7.2460))]}


{0: [(2, 2, tensor(14.4581)), (0, 3, tensor(-2.5995)), (1, 5, tensor(-10.5149))], 1: [(0, 3, tensor(11.7716)), (2, 2, tensor(8.4585)), (1, 5, tensor(-12.0351))]}


{0: [(2, 2, tensor(22.8451)), (0, 3, tensor(4.5868)), (1, 5, tensor(-14.3393))], 1: [(0, 3, tensor(16.7878)), (2, 2, tensor(-2.9431)), (1, 5, tensor(-20.8075))]}


{0: [(2, 2, tensor(20.5462)), (0, 3, tensor(-3.5068)), (1, 5, tensor(-15.0189))], 1: [(0, 3, tensor(22.3207)), (2, 2, tensor(-7.4202)), (1, 5, tensor(-15.3385))]}


Does the QK matrix store things orthogonally? (It's not sparse?)

In [127]:
def cos_sim(v1, v2):
    with t.no_grad():
        out = einsum("d, d -> ", v1, v2)/(t.norm(v1) * t.norm(v2))
    return out

assert cos_sim(t.tensor([1.,0.,0.]), t.tensor([0.,4.,0.])) == t.tensor(0.)
assert cos_sim(t.tensor([1.,0.,0.]), t.tensor([2.,0.,0.])) == t.tensor(1.)

In [152]:
## Queries have shape [batch, pos_q, n_heads, d_head]
## Keys have shape [batch, pos_k, n_heads, d_head]
model = models[5]

queries = rearrange(model.activation_cache['hook_q'][0], "q nh dh -> nh q dh") #remove batch dim and rearrange so that d_head is first
keys = rearrange(model.activation_cache['hook_k'][0], "q nh dh -> nh q dh") #remove batch dim and rearrange so that d_head is first

## get the cosim for all pairs of keys and queries

head_idx = 0
[[cos_sim(keys[head_idx,i], queries[head_idx, j]) for j in range(queries.shape[1])] for i in range(keys.shape[1])] ## [pos_k, pos_q]


[[tensor(0.9921), tensor(0.8856), tensor(0.9923), tensor(0.9983)],
 [tensor(-0.8149), tensor(-0.4114), tensor(-0.8022), tensor(-0.7912)],
 [tensor(0.9962), tensor(0.8557), tensor(0.9952), tensor(0.9989)],
 [tensor(-0.9923), tensor(-0.8798), tensor(-0.9922), tensor(-0.9984)]]