In [1]:
import os
import sys
import requests
from pathlib import Path
import numpy as np
import plotly.express as px
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
from fancy_einsum import einsum
import math 

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

Create the toy dataset

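There are two kinds of input sequence: signal and noise. A signal sequence ends with the token 0, and its first and second-to-last tokens are integers in [1, 4]. Its label is given by an arbitrary map of that unordered pair. Every other token is drawn from [5, 10]. A noise sequence contains only tokens in [5, 10], and its label is drawn uniformly from [0, 10].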

In [2]:
# Label for each unordered pair of signal tokens drawn from [1, 4] (10 pairs -> labels 0..9).
ans_map = {(1, 1): 0, (1, 2): 1, (1, 3): 2, (1, 4): 3, (2, 2): 4,
           (2, 3): 5, (2, 4): 6, (3, 3): 7, (3, 4): 8, (4, 4): 9}


class SkipTrigramDataset(Dataset):
    """Toy skip-trigram classification data, generated eagerly at construction.

    With probability 1/prob a sample is *signal*: its first token and its
    second-to-last token are drawn from [1, 4], its last token is 0, every other
    token is drawn from [5, 10], and its label is ``ans_map[sorted pair]``.
    Otherwise it is *noise*: every token is drawn from [5, 10] and the label is
    drawn uniformly from [0, 10].

    Items are ``(tokens, label)`` with shapes ``[prompt_length]`` and ``[1]``.

    Raises:
        ValueError: if ``prompt_length < 3`` (the three signal positions 0, -2
            and -1 would overlap), ``prob < 1``, or ``size < 0``.
    """

    def __init__(self, size: int, prob: int, prompt_length: int, seed: int = 42):
        if prompt_length < 3:
            raise ValueError(f"prompt_length must be >= 3, got {prompt_length}")
        if prob < 1:
            raise ValueError(f"prob must be >= 1 (signal rate is 1/prob), got {prob}")
        if size < 0:
            raise ValueError(f"size must be non-negative, got {size}")
        self.size = size
        self.prob = prob
        self.prompt_length = prompt_length
        self.seed = seed
        self.data = [self.generate_sample() for _ in range(size)]
        # NOTE: the global torch RNG is seeded only *after* the samples are drawn,
        # so `seed` does not determine this dataset's contents; it fixes the RNG
        # state for whatever runs next. Trainer builds its train and validation
        # sets with the same seed value, so seeding *before* generation would make
        # the validation set an exact copy of the first training samples.
        t.manual_seed(seed=seed)

    def generate_sample(self):
        """Draw one ``(tokens, label)`` pair from the global torch RNG."""
        tokens = t.randint(low=5, high=11, size=(self.prompt_length,))  # noise tokens in [5, 10]
        # Signal with probability 1/prob.
        is_signal = t.randint(low=0, high=self.prob, size=(1,)).item() == self.prob - 1
        if is_signal:
            first = t.randint(1, 5, size=(1,)).item()
            second = t.randint(1, 5, size=(1,)).item()
            tokens[0] = first
            tokens[-2] = second
            tokens[-1] = 0  # trigger token
            label = t.tensor([ans_map[tuple(sorted((first, second)))]])
        else:
            # Noise: the label carries no information about the tokens.
            label = t.randint(low=0, high=11, size=(1,))
        return tokens, label

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.size

In [3]:
# `ans_map` is defined once, in the SkipTrigramDataset cell; redefining it here
# would let the two copies drift apart.
example_dataset = SkipTrigramDataset(size = 10_000, prob = 2, prompt_length = 4)
for i in range(10):
    sample, label = example_dataset[i]
    print(f'Sample: {sample}, label: {label[-1]}')



Sample: tensor([2, 9, 2, 0]), label: 4
Sample: tensor([ 2, 10,  1,  0]), label: 1
Sample: tensor([7, 5, 5, 5]), label: 3
Sample: tensor([2, 9, 4, 0]), label: 6
Sample: tensor([7, 6, 6, 5]), label: 0
Sample: tensor([8, 8, 9, 9]), label: 8
Sample: tensor([ 2, 10,  1,  0]), label: 1
Sample: tensor([ 9,  5, 10,  6]), label: 7
Sample: tensor([6, 8, 5, 7]), label: 2
Sample: tensor([2, 6, 2, 0]), label: 4


In [4]:
# Partition the example labels by sequence type: a trailing 0 marks a signal sequence.
signal, noise = [], []
for tokens, label in example_dataset:
    (signal if tokens[-1] == 0 else noise).append(label.item())

print(len(signal), len(noise))

px.histogram(signal)

## Labels of repeated pairs (a, a) -- 0, 4, 7, 9 -- are half as likely as mixed pairs in the
## signal data, since (a, b) and (b, a) share one label. Worth checking whether the model
## treats those outputs systematically differently.


4935 5065


In [5]:
@dataclass
class model_cfg:
    """Architecture hyperparameters for the attention-only model."""
    d_model: int = 12  # residual width; must equal d_vocab (one-hot embed output is added to the attention output)
    d_vocab: int = 12  # number of token ids; size of the one-hot embedding/unembedding
    init_range: float = 0.02  # std of the normal init of W_Q, W_K, W_V, W_O
    n_ctx: int = 5  # NOTE(review): not read by any model class in this notebook section
    d_head: int = 3  # per-head query/key/value width
    n_heads: int = 3  # number of attention heads
    n_layers: int = 1  # NOTE(review): not read by the model classes shown; OneLayerAttnOnly has exactly one layer

model_config = model_cfg()

print(model_config)

model_cfg(d_model=12, d_vocab=12, init_range=0.02, n_ctx=5, d_head=3, n_heads=3, n_layers=1)


In [6]:
class OneHotEmbed(nn.Module):
    """Frozen identity embedding: token id i maps to row i of eye(d_vocab)."""
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        self.embedding = nn.Embedding(self.model_cfg.d_vocab, self.model_cfg.d_vocab)
        # Replace the random init with the identity and freeze it.
        self.embedding.weight = nn.Parameter(t.eye(self.model_cfg.d_vocab))
        self.embedding.weight.requires_grad = False

    def forward(self, input):
        # input:  [batch, position] integer token ids
        # output: [batch, position, d_vocab] one-hot vectors
        # Follow the embedding's own device so the module works wherever .to() placed it.
        input = input.to(self.embedding.weight.device)
        return self.embedding(input)
    
class OneHotUnembed(nn.Module):
    """Frozen identity unembedding: logits are the residual vector read in the token basis."""
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        # Non-persistent buffer: moves with .to(device) alongside the module (no per-call
        # copy) but stays out of state_dict, so saved checkpoints are unchanged.
        self.register_buffer("unembed", t.eye(self.model_cfg.d_vocab), persistent=False)

    def forward(self, input):
        # input:  [batch, position, d_model]  (d_model must equal d_vocab)
        # output: [batch, position, d_vocab]  logits
        input = input.to(self.unembed.device)
        return t.einsum("vm,bpm->bpv", self.unembed, input)


In [7]:
class Attention(nn.Module):
    """Multi-head causal self-attention (no LayerNorm) that caches its intermediates.

    Every forward pass writes its intermediate tensors into
    `self.attn_activation_cache` (the same dict object each call) under `hook_*` keys.
    """
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg

        # set weights, biases as parameters
        # W_Q / W_K / W_V: [n_heads, d_model, d_head]; W_O: [n_heads, d_head, d_model]
        self.W_Q = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_model, model_cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((model_cfg.n_heads, model_cfg.d_head, model_cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((model_cfg.n_heads, model_cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((model_cfg.d_model)))

        # initialize weight matrices randomly; biases stay at zero
        nn.init.normal_(self.W_Q, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_K, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_V, std=self.model_cfg.init_range)
        nn.init.normal_(self.W_O, std=self.model_cfg.init_range)
        # Large negative score written into masked (future) positions before the softmax
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

        self.attn_activation_cache = {}

    def forward(self, normalized_resid_pre: t.Tensor):
        # normalized_resid_pre: [batch, position, d_model] (we're not using LN, so just input)
        # output: [batch, position, d_model]
        normalized_resid_pre = normalized_resid_pre.to(device)

        queries = einsum('n_heads d_model d_head, batch pos_q d_model -> batch pos_q n_heads d_head', self.W_Q, normalized_resid_pre) + self.b_Q
        keys = einsum('n_heads d_model d_head, batch pos_k d_model -> batch pos_k n_heads d_head', self.W_K, normalized_resid_pre) + self.b_K
        values = einsum('n_heads d_model d_head, batch pos_k d_model -> batch pos_k n_heads d_head', self.W_V, normalized_resid_pre) + self.b_V

        attn_scores = einsum('batch pos_q n_heads d_head, batch pos_k n_heads d_head -> batch n_heads pos_q pos_k', queries, keys)

        scaled_attn = attn_scores / math.sqrt(self.model_cfg.d_head)
        masked_attn = self.apply_causal_mask(scaled_attn)
        # apply softmax over keys to get attention pattern
        softmax_attn = masked_attn.softmax(dim = -1)

        # get weighted attention values
        z = einsum('batch pos_k n_heads d_head, batch n_heads pos_q pos_k -> batch pos_q n_heads d_head', values, softmax_attn)

        # layer output: heads summed through W_O, plus one shared bias b_O
        output = einsum('n_heads d_head d_model, batch pos_q n_heads d_head -> batch pos_q d_model', self.W_O, z) + self.b_O

        ## save to the cache (attn_scores here are unscaled and unmasked)
        self.attn_activation_cache['hook_q'] = queries
        self.attn_activation_cache['hook_k'] = keys
        self.attn_activation_cache['hook_attn_scores'] = attn_scores
        self.attn_activation_cache['hook_pattern'] = softmax_attn
        self.attn_activation_cache['hook_v'] = values
        self.attn_activation_cache['hook_z'] = z
        self.attn_activation_cache['hook_attn_out'] = output

        return output

    def apply_causal_mask(self, attn_scores: t.Tensor):
        """Replace scores of keys after the query position with IGNORE.

        attn_scores: [batch, n_heads, query_pos, key_pos]; returns a new tensor of the
        same shape. The mask depends only on positions -- building it from the score
        values (as `triu(scores).bool()` does) leaves future keys whose score is exactly
        0 unmasked.
        """
        n_query, n_key = attn_scores.shape[-2:]
        future = t.triu(
            t.ones(n_query, n_key, dtype=t.bool, device=attn_scores.device), diagonal=1
        )
        return attn_scores.masked_fill(future, self.IGNORE)

Define a one-layer, attention-only transformer class.

In [8]:
class OneLayerAttnOnly(nn.Module):
    """One-layer attention-only transformer: one-hot embed -> attention (+ residual) -> one-hot unembed.

    After each forward, `self.activation_cache` is the attention layer's own cache
    dict (the same object, refilled in place on every call), extended with the
    residual-stream and logit tensors.

    Raises:
        ValueError: if ``model_cfg.d_model != model_cfg.d_vocab``.
    """
    def __init__(self, model_cfg):
        super().__init__()
        # The one-hot embedding emits d_vocab-wide vectors that are added to the
        # d_model-wide attention output and unembedded by a d_vocab x d_vocab
        # identity, so the two widths must agree.
        if model_cfg.d_model != model_cfg.d_vocab:
            raise ValueError(
                f"OneLayerAttnOnly needs d_model == d_vocab (one-hot residual stream), "
                f"got d_model={model_cfg.d_model}, d_vocab={model_cfg.d_vocab}"
            )
        self.model_cfg = model_cfg
        self.embed = OneHotEmbed(model_cfg)
        self.attn = Attention(model_cfg)
        self.unembed = OneHotUnembed(model_cfg)

        self.activation_cache = {}


    def forward(self, tokens):
        # tokens [batch, pos]
        # output [batch, pos, d_vocab] logits
        input = tokens.to(device)
        resid_pre = self.embed(input)
        resid_post = resid_pre + self.attn(resid_pre)
        out = self.unembed(resid_post)

        # Alias (not copy) the attention cache: callers that hold this dict see
        # fresh activations after every forward pass.
        self.activation_cache = self.attn.attn_activation_cache
        self.activation_cache['hook_resid_pre'] = resid_pre
        self.activation_cache['hook_resid_post'] = resid_post
        self.activation_cache['unembed'] = out

        return out


# Build from the config *instance* defined above, not the dataclass itself.
attn_only_1l = OneLayerAttnOnly(model_config).to(device)
attn_only_1l

        

OneLayerAttnOnly(
  (embed): OneHotEmbed(
    (embedding): Embedding(12, 12)
  )
  (attn): Attention()
  (unembed): OneHotUnembed()
)

In [9]:
def plot_metrics(train_loss,train_acc, test_loss, test_acc):
    """Show per-epoch train/test loss (left column) and accuracy (right column) in a 2x2 grid."""
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=['Train Loss', 'Train Accuracy', 'Test Loss', 'Test Accuracy'],
    )
    panels = ((train_loss, 1, 1), (train_acc, 1, 2), (test_loss, 2, 1), (test_acc, 2, 2))
    for series, row, col in panels:
        fig.add_trace(go.Scatter(y=series), row=row, col=col)

    fig.update_layout(showlegend=False)
    fig.update_xaxes(title='Epoch', zeroline=True, zerolinewidth=1,
                     zerolinecolor='black', showgrid=False)
    for col, axis_title in ((1, 'Loss'), (2, 'Accuracy')):
        fig.update_yaxes(title=axis_title, showgrid=False, col=col)
    fig.update_layout(
        font=dict(size=12),
        width=800,
        height=500,
        margin=dict(l=10, r=20, b=10, t=20, pad=0),
    )
    fig.show()

In [10]:
@dataclass
class train_cfg:
    """Training and data hyperparameters consumed by Trainer and train_and_eval."""
    ctx_length: int = 4          # prompt length of each SkipTrigramDataset sample
    trainset_size: int = 100000  # number of generated training samples
    valset_size: int = 1000      # number of generated validation samples
    prob: int = 2                # a sample is signal with probability 1/prob
    epochs: int = 15
    batch_size: int = 512
    weight_decay: float = 0.0    # Adam weight decay
    seed: int = 42               # forwarded to SkipTrigramDataset by the dataloader factories
    lr: float = 1e-3             # Adam learning rate

class Trainer:
    """Holds a fresh OneLayerAttnOnly model plus factories for its data loaders and optimizer."""

    def __init__(self, train_cfg, model_cfg):
        self.train_cfg = train_cfg
        self.model = OneLayerAttnOnly(model_cfg).to(device)


    def train_dataloader(self, seed: int):
        """Shuffled DataLoader over a freshly generated training set; `seed` goes to the dataset."""
        trainset = SkipTrigramDataset(size=self.train_cfg.trainset_size,
                                      prob=self.train_cfg.prob,
                                      prompt_length=self.train_cfg.ctx_length,
                                      seed=seed)

        return DataLoader(trainset, batch_size=self.train_cfg.batch_size, shuffle=True)

    def val_dataloader(self, seed: int):
        """Unshuffled DataLoader over a freshly generated validation set; `seed` goes to the dataset."""
        valset = SkipTrigramDataset(size=self.train_cfg.valset_size,
                                    prob=self.train_cfg.prob,
                                    prompt_length=self.train_cfg.ctx_length,
                                    seed=seed)
        return DataLoader(valset, batch_size=self.train_cfg.batch_size, shuffle=False)

    def configure_optimizers(self):
        """Adam over the trainable parameters (the frozen one-hot embedding is excluded)."""
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        return t.optim.Adam(trainable, lr=self.train_cfg.lr, weight_decay=self.train_cfg.weight_decay)

In [11]:



def train(model, train_data, criterion, optimizer):

    avg_loss = 0
    acc = 0

    for _, batch in list(enumerate(train_data)):

        X = batch[0].to(device)
        y = batch[1].squeeze().to(device)
        assert X.shape[0] == y.shape[0]

        output = model(X)[:,-1]

        with t.no_grad():
            corr = (output.argmax(dim=-1) == y).count_nonzero()
        loss = criterion(output,y)
        loss.backward()

        optimizer.step()
        avg_loss += loss.item() * len(batch)
        acc += corr.item()
        optimizer.zero_grad()

    cache = model.activation_cache

    return avg_loss/train_cfg.trainset_size, acc/train_cfg.trainset_size, cache
    
def test(model, test_data, criterion):

    avg_loss = 0
    acc = 0

    for _, batch in list(enumerate(test_data)):
        X = batch[0].to(device)
        y = batch[1].squeeze().to(device)
        assert X.shape[0] == y.shape[0]

        # don't log the gradients for testing
        with t.no_grad():
            output = model(X)[:,-1]
        loss = criterion(output,y)
        corr = (output.argmax(dim=-1) == y).count_nonzero()
        avg_loss += loss.item() * len(batch)
        acc += corr.item()


    return avg_loss/train_cfg.valset_size, acc/train_cfg.valset_size

In [12]:
from tqdm import tqdm 

def train_and_eval(train_cfg, model_cfg, save_caches = ()):
    """Train a fresh model for `train_cfg.epochs` epochs, evaluating after each one.

    Args:
        train_cfg: training/data config (a `train_cfg` instance).
        model_cfg: architecture config (a `model_cfg` instance).
        save_caches: epoch indices (0-based) whose end-of-training activation cache is kept.

    Returns:
        (train_loss, train_acc, test_loss, test_acc, caches, model), where the first
        four are per-epoch lists and `caches` holds one detached snapshot per saved epoch.
    """
    trainer = Trainer(train_cfg, model_cfg)
    model = trainer.model
    train_data = trainer.train_dataloader(seed=train_cfg.seed)
    test_data = trainer.val_dataloader(seed=train_cfg.seed)
    criterion = nn.CrossEntropyLoss()
    optimizer = trainer.configure_optimizers()

    train_loss, train_acc, test_loss, test_acc = [], [], [], []
    caches = []

    for n in tqdm(range(train_cfg.epochs)):
        loss, acc, cache = train(model, train_data, criterion, optimizer)
        # Snapshot before test(): the model's cache dict is refilled in place by every
        # forward pass, so storing the dict itself would keep only the latest activations.
        if n in save_caches:
            caches.append({
                name: act.detach().clone() if isinstance(act, t.Tensor) else act
                for name, act in cache.items()
            })

        train_loss.append(loss)
        train_acc.append(acc)
        t_loss, t_acc = test(model, test_data, criterion)
        test_loss.append(t_loss)
        test_acc.append(t_acc)
        # tqdm.write keeps the progress bar intact while logging
        tqdm.write(f"train loss = {loss}, train acc = {acc}, test loss = {t_loss}, test acc = {t_acc}")

    return train_loss, train_acc, test_loss, test_acc, caches, trainer.model



Helper functions for plotting

In [13]:
from transformer_lens import utils, HookedTransformer, ActivationCache

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Render a tensor as a diverging RdBu heatmap centred on zero."""
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Line plot of a tensor (converted to numpy) with optional axis labels."""
    axis_labels = {"x": xaxis, "y": yaxis}
    fig = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter plot of y against x (tensors or arrays) with axis and colour labels."""
    x_np, y_np = utils.to_numpy(x), utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    px.scatter(y=y_np, x=x_np, labels=axis_labels, **kwargs).show(renderer)

In [14]:
def get_head_to_resid(model, cache, head, seq):
    """Return one head's write into the residual stream for `seq`: [batch, pos, d_model].

    Runs `model(seq)` only for its side effect of refreshing `cache` (which must be
    the dict the model refills on each forward), then projects that head's z through
    its slice of W_O. NOTE(review): the shared bias b_O is added here, so summing
    this over heads counts b_O once per head, while the layer itself adds it once.
    """
    model(seq)
    head_z = cache["hook_z"][:, :, head, :]   # [batch, pos, d_head]
    head_W_O = model.attn.W_O[head, :, :]     # [d_head, d_model]
    projected = einsum("batch pos d_head, d_head d_model -> batch pos d_model", head_z, head_W_O)
    return projected + model.attn.b_O

def get_head_to_logits(model, head_to_resid):
    """Map a residual-stream contribution to vocabulary logits via the model's unembedding."""
    unembed_layer = model.unembed
    return unembed_layer(head_to_resid)

def show_pattern(cache, seq, model):
    """Run `model` on one token sequence and plot every head's attention pattern.

    `cache` is read *after* the forward pass, so it must be the dict object the model
    refills in place (for OneLayerAttnOnly: `model.activation_cache` captured after at
    least one earlier forward, since the first forward replaces the initial empty dict).
    `seq` may be a list of ints or a 1-D tensor of token ids.
    """
    tokens = t.as_tensor(seq, dtype=t.long).unsqueeze(0)  # [1, pos]
    with t.no_grad():  # plotting only; no autograd graph needed
        model(tokens)
    ax_label = [f"{i}: {tok.item()}" for i, tok in enumerate(tokens[0])]
    # hook_pattern[0]: [n_heads, query_pos, key_pos] -> one facet per head
    imshow(cache['hook_pattern'][0], facet_col = 0, x = ax_label, y = ax_label)

# get_head_to_resid and get_head_to_logits are defined once, earlier in this cell.
# Re-defining them here verbatim would silently shadow those definitions.





Train one model for each head dimension d_head in {1, 2, 3, 4, 5}.

In [17]:
# Default seed of the `train_cfg` dataclass (read from the class itself, not an instance).
train_cfg.seed

42

In [15]:
# Per-d_head results, keyed by d_head.
train_losses, train_accs, test_losses, test_accs, models = {}, {}, {}, {}, {}

# The training settings do not depend on d_head, so build them once.
train_config = train_cfg(
    ctx_length=4,
    weight_decay=0,
    trainset_size=100000,
    prob=2,
    epochs=80,
    batch_size=512,
)

for dh in [1, 2, 3, 4, 5]:
    model_config = model_cfg(
        d_model=11,
        d_vocab=11,
        d_head=dh,
        init_range=0.02,
        n_ctx=4,
        n_heads=2,
        n_layers=1,
    )
    (train_losses[dh], train_accs[dh], test_losses[dh], test_accs[dh],
     _, models[dh]) = train_and_eval(train_config, model_config)

  1%|▏         | 1/80 [00:00<00:36,  2.15it/s]

train loss = 0.009550961723327638, train acc = 0.07676, test loss = 0.009638270854949952, test acc = 0.084


  2%|▎         | 2/80 [00:00<00:33,  2.30it/s]

train loss = 0.009302466802597047, train acc = 0.10172, test loss = 0.009314502716064453, test acc = 0.168


  4%|▍         | 3/80 [00:01<00:33,  2.30it/s]

train loss = 0.008834695105552673, train acc = 0.19934, test loss = 0.008712907791137696, test acc = 0.298


  5%|▌         | 4/80 [00:01<00:32,  2.34it/s]

train loss = 0.008206376249790192, train acc = 0.2847, test loss = 0.008053151607513428, test acc = 0.364


  6%|▋         | 5/80 [00:02<00:31,  2.37it/s]

train loss = 0.007618644745349884, train acc = 0.30395, test loss = 0.0075563061237335204, test acc = 0.33


  8%|▊         | 6/80 [00:02<00:31,  2.37it/s]

train loss = 0.007186954641342163, train acc = 0.37076, test loss = 0.0071667275428771975, test acc = 0.38


  9%|▉         | 7/80 [00:02<00:30,  2.37it/s]

train loss = 0.006827733063697815, train acc = 0.38778, test loss = 0.006840481042861939, test acc = 0.378


 10%|█         | 8/80 [00:03<00:31,  2.30it/s]

train loss = 0.006542300763130188, train acc = 0.39606, test loss = 0.006587871074676514, test acc = 0.445


 11%|█▏        | 9/80 [00:03<00:30,  2.33it/s]

train loss = 0.006322209053039551, train acc = 0.45134, test loss = 0.006390723705291748, test acc = 0.447


 12%|█▎        | 10/80 [00:04<00:29,  2.35it/s]

train loss = 0.006150352396965027, train acc = 0.45115, test loss = 0.006239802122116089, test acc = 0.443


 14%|█▍        | 11/80 [00:04<00:29,  2.36it/s]

train loss = 0.0060201964521408085, train acc = 0.45116, test loss = 0.006117268323898316, test acc = 0.447


 15%|█▌        | 12/80 [00:05<00:28,  2.37it/s]

train loss = 0.0059129298329353335, train acc = 0.45137, test loss = 0.006025593757629394, test acc = 0.447


 16%|█▋        | 13/80 [00:05<00:28,  2.38it/s]

train loss = 0.005826706764698029, train acc = 0.45141, test loss = 0.0059453389644622805, test acc = 0.447


 18%|█▊        | 14/80 [00:05<00:27,  2.36it/s]

train loss = 0.005760289332866668, train acc = 0.45475, test loss = 0.005887062788009644, test acc = 0.474


 19%|█▉        | 15/80 [00:06<00:27,  2.37it/s]

train loss = 0.005702745277881623, train acc = 0.48247, test loss = 0.00583536434173584, test acc = 0.474


 20%|██        | 16/80 [00:06<00:26,  2.38it/s]

train loss = 0.005649068911075592, train acc = 0.48247, test loss = 0.005785156965255737, test acc = 0.474


 21%|██▏       | 17/80 [00:07<00:26,  2.39it/s]

train loss = 0.005605617749691009, train acc = 0.48247, test loss = 0.005740759134292603, test acc = 0.474


 22%|██▎       | 18/80 [00:07<00:25,  2.39it/s]

train loss = 0.005560840709209442, train acc = 0.48247, test loss = 0.005701068162918091, test acc = 0.474


 24%|██▍       | 19/80 [00:08<00:25,  2.36it/s]

train loss = 0.005522268440723419, train acc = 0.48247, test loss = 0.005662858486175537, test acc = 0.474


 25%|██▌       | 20/80 [00:08<00:25,  2.35it/s]

train loss = 0.00548581460237503, train acc = 0.48247, test loss = 0.005633593559265137, test acc = 0.474


 26%|██▋       | 21/80 [00:08<00:24,  2.37it/s]

train loss = 0.005452868139743805, train acc = 0.48247, test loss = 0.0056019511222839356, test acc = 0.474


 28%|██▊       | 22/80 [00:09<00:24,  2.38it/s]

train loss = 0.005424977436065674, train acc = 0.48247, test loss = 0.005579197883605957, test acc = 0.474


 29%|██▉       | 23/80 [00:09<00:23,  2.39it/s]

train loss = 0.005403270094394684, train acc = 0.48247, test loss = 0.005556965827941894, test acc = 0.474


 30%|███       | 24/80 [00:10<00:23,  2.37it/s]

train loss = 0.0053815675234794615, train acc = 0.48247, test loss = 0.005538103103637696, test acc = 0.474


 31%|███▏      | 25/80 [00:10<00:23,  2.35it/s]

train loss = 0.005365478851795197, train acc = 0.48247, test loss = 0.005522105455398559, test acc = 0.474


 32%|███▎      | 26/80 [00:11<00:22,  2.37it/s]

train loss = 0.005348674623966217, train acc = 0.48247, test loss = 0.005511218070983886, test acc = 0.474


 34%|███▍      | 27/80 [00:11<00:22,  2.37it/s]

train loss = 0.005335283501148224, train acc = 0.48247, test loss = 0.0054978299140930175, test acc = 0.474


 35%|███▌      | 28/80 [00:11<00:22,  2.33it/s]

train loss = 0.005323798019886017, train acc = 0.48247, test loss = 0.005489046335220337, test acc = 0.474


 36%|███▋      | 29/80 [00:12<00:21,  2.35it/s]

train loss = 0.005316620354652405, train acc = 0.48247, test loss = 0.005480691432952881, test acc = 0.474


 38%|███▊      | 30/80 [00:12<00:21,  2.37it/s]

train loss = 0.005307605211734772, train acc = 0.48247, test loss = 0.005475027322769165, test acc = 0.474


 39%|███▉      | 31/80 [00:13<00:20,  2.38it/s]

train loss = 0.005301489360332489, train acc = 0.48247, test loss = 0.00546956467628479, test acc = 0.474


 40%|████      | 32/80 [00:13<00:20,  2.36it/s]

train loss = 0.005295080690383911, train acc = 0.48247, test loss = 0.0054582076072692874, test acc = 0.474


 41%|████▏     | 33/80 [00:13<00:19,  2.35it/s]

train loss = 0.0052919957327842715, train acc = 0.48246, test loss = 0.005454441785812378, test acc = 0.474


 42%|████▎     | 34/80 [00:14<00:19,  2.35it/s]

train loss = 0.005284664416313171, train acc = 0.48249, test loss = 0.005449446439743042, test acc = 0.474


 44%|████▍     | 35/80 [00:14<00:19,  2.35it/s]

train loss = 0.005281779930591583, train acc = 0.48247, test loss = 0.00544964337348938, test acc = 0.474


 45%|████▌     | 36/80 [00:15<00:18,  2.37it/s]

train loss = 0.00527630816936493, train acc = 0.48247, test loss = 0.005449980974197388, test acc = 0.475


 46%|████▋     | 37/80 [00:15<00:18,  2.38it/s]

train loss = 0.005273200304508209, train acc = 0.48244, test loss = 0.005441493988037109, test acc = 0.475


 48%|████▊     | 38/80 [00:16<00:17,  2.39it/s]

train loss = 0.005270545558929443, train acc = 0.4824, test loss = 0.0054392340183258055, test acc = 0.474


 49%|████▉     | 39/80 [00:16<00:17,  2.40it/s]

train loss = 0.0052686593461036685, train acc = 0.48246, test loss = 0.0054404456615448, test acc = 0.474


 50%|█████     | 40/80 [00:16<00:16,  2.39it/s]

train loss = 0.005266204731464386, train acc = 0.48247, test loss = 0.00543584942817688, test acc = 0.475


 51%|█████▏    | 41/80 [00:17<00:16,  2.40it/s]

train loss = 0.005264739603996277, train acc = 0.48245, test loss = 0.005435604333877564, test acc = 0.475


 52%|█████▎    | 42/80 [00:17<00:15,  2.40it/s]

train loss = 0.005264997351169586, train acc = 0.48248, test loss = 0.005428198099136352, test acc = 0.474


 54%|█████▍    | 43/80 [00:18<00:15,  2.38it/s]

train loss = 0.005261015613079071, train acc = 0.48247, test loss = 0.005430811166763306, test acc = 0.474


 55%|█████▌    | 44/80 [00:18<00:15,  2.38it/s]

train loss = 0.0052616004276275635, train acc = 0.48246, test loss = 0.005427539348602295, test acc = 0.474


 56%|█████▋    | 45/80 [00:19<00:14,  2.36it/s]

train loss = 0.005261186368465423, train acc = 0.48246, test loss = 0.005425002098083496, test acc = 0.474


 57%|█████▊    | 46/80 [00:19<00:14,  2.38it/s]

train loss = 0.00525757372379303, train acc = 0.48247, test loss = 0.0054232368469238285, test acc = 0.474


 59%|█████▉    | 47/80 [00:19<00:13,  2.37it/s]

train loss = 0.005254922339916229, train acc = 0.48247, test loss = 0.005424931287765503, test acc = 0.474


 60%|██████    | 48/80 [00:20<00:13,  2.38it/s]

train loss = 0.005254559664726257, train acc = 0.48247, test loss = 0.005425036191940308, test acc = 0.474


 61%|██████▏   | 49/80 [00:20<00:13,  2.37it/s]

train loss = 0.005253592758178711, train acc = 0.48247, test loss = 0.005426823139190674, test acc = 0.474


 62%|██████▎   | 50/80 [00:21<00:12,  2.38it/s]

train loss = 0.005252173137664795, train acc = 0.48247, test loss = 0.005421801090240479, test acc = 0.474


 64%|██████▍   | 51/80 [00:21<00:12,  2.38it/s]

train loss = 0.005250022985935211, train acc = 0.48247, test loss = 0.005419909715652466, test acc = 0.474


 65%|██████▌   | 52/80 [00:21<00:11,  2.38it/s]

train loss = 0.005250340201854706, train acc = 0.48247, test loss = 0.005411871910095215, test acc = 0.474


 66%|██████▋   | 53/80 [00:22<00:11,  2.38it/s]

train loss = 0.005227652952671051, train acc = 0.48247, test loss = 0.0053573715686798095, test acc = 0.474


 68%|██████▊   | 54/80 [00:22<00:11,  2.34it/s]

train loss = 0.005128857917785645, train acc = 0.50444, test loss = 0.0052043933868408205, test acc = 0.51


 69%|██████▉   | 55/80 [00:23<00:10,  2.34it/s]

train loss = 0.005029714484214783, train acc = 0.51345, test loss = 0.005130629062652588, test acc = 0.51


 70%|███████   | 56/80 [00:23<00:10,  2.34it/s]

train loss = 0.004970382885932922, train acc = 0.52267, test loss = 0.005065196752548217, test acc = 0.547


 71%|███████▏  | 57/80 [00:24<00:09,  2.36it/s]

train loss = 0.004917886900901795, train acc = 0.54459, test loss = 0.005016106367111206, test acc = 0.547


 72%|███████▎  | 58/80 [00:24<00:09,  2.36it/s]

train loss = 0.004875438003540039, train acc = 0.54459, test loss = 0.004976727485656738, test acc = 0.547


 74%|███████▍  | 59/80 [00:24<00:08,  2.37it/s]

train loss = 0.004852990698814392, train acc = 0.54459, test loss = 0.004952244281768799, test acc = 0.547


 75%|███████▌  | 60/80 [00:25<00:08,  2.38it/s]

train loss = 0.00483303474187851, train acc = 0.54459, test loss = 0.004936633348464966, test acc = 0.547


 76%|███████▋  | 61/80 [00:25<00:07,  2.39it/s]

train loss = 0.004820993571281433, train acc = 0.54459, test loss = 0.004925325393676758, test acc = 0.547


 78%|███████▊  | 62/80 [00:26<00:07,  2.39it/s]

train loss = 0.004810764572620392, train acc = 0.54459, test loss = 0.0049172892570495605, test acc = 0.547


 79%|███████▉  | 63/80 [00:26<00:07,  2.39it/s]

train loss = 0.004805529315471649, train acc = 0.54459, test loss = 0.004914032697677612, test acc = 0.547


 80%|████████  | 64/80 [00:27<00:06,  2.39it/s]

train loss = 0.004801647012233734, train acc = 0.54459, test loss = 0.0049068424701690675, test acc = 0.547


 81%|████████▏ | 65/80 [00:27<00:06,  2.37it/s]

train loss = 0.004798023929595947, train acc = 0.54459, test loss = 0.004903839111328125, test acc = 0.547


 82%|████████▎ | 66/80 [00:27<00:05,  2.37it/s]

train loss = 0.004798273224830628, train acc = 0.54459, test loss = 0.004905275821685791, test acc = 0.547


 84%|████████▍ | 67/80 [00:28<00:05,  2.37it/s]

train loss = 0.0047926186633110044, train acc = 0.54459, test loss = 0.004899334192276001, test acc = 0.547


 85%|████████▌ | 68/80 [00:28<00:05,  2.34it/s]

train loss = 0.004791494407653809, train acc = 0.54459, test loss = 0.004899673700332642, test acc = 0.547


 86%|████████▋ | 69/80 [00:29<00:04,  2.35it/s]

train loss = 0.004792289962768555, train acc = 0.54459, test loss = 0.004898046970367432, test acc = 0.547


 88%|████████▊ | 70/80 [00:29<00:04,  2.36it/s]

train loss = 0.004789402704238892, train acc = 0.54459, test loss = 0.00489532470703125, test acc = 0.547


 89%|████████▉ | 71/80 [00:30<00:03,  2.36it/s]

train loss = 0.004789553849697113, train acc = 0.54459, test loss = 0.0048947098255157475, test acc = 0.547


 90%|█████████ | 72/80 [00:30<00:03,  2.36it/s]

train loss = 0.004785815575122834, train acc = 0.54459, test loss = 0.004894742012023926, test acc = 0.547


 91%|█████████▏| 73/80 [00:30<00:02,  2.37it/s]

train loss = 0.004783466620445252, train acc = 0.54459, test loss = 0.004895516872406006, test acc = 0.547


 92%|█████████▎| 74/80 [00:31<00:02,  2.35it/s]

train loss = 0.004783501770496368, train acc = 0.54459, test loss = 0.004889191150665283, test acc = 0.547


 94%|█████████▍| 75/80 [00:31<00:02,  2.36it/s]

train loss = 0.0047832177352905274, train acc = 0.54459, test loss = 0.0048890793323516845, test acc = 0.547


 95%|█████████▌| 76/80 [00:32<00:01,  2.37it/s]

train loss = 0.00478316796541214, train acc = 0.54459, test loss = 0.004892652750015259, test acc = 0.547


 96%|█████████▋| 77/80 [00:32<00:01,  2.38it/s]

train loss = 0.004780767335891723, train acc = 0.54459, test loss = 0.004887993574142456, test acc = 0.547


 98%|█████████▊| 78/80 [00:32<00:00,  2.38it/s]

train loss = 0.004784581906795502, train acc = 0.54459, test loss = 0.004888828754425049, test acc = 0.547


 99%|█████████▉| 79/80 [00:33<00:00,  2.38it/s]

train loss = 0.004779645187854767, train acc = 0.54459, test loss = 0.00488926100730896, test acc = 0.547


100%|██████████| 80/80 [00:33<00:00,  2.37it/s]

train loss = 0.004780913963317871, train acc = 0.54459, test loss = 0.004889308929443359, test acc = 0.547



  1%|▏         | 1/80 [00:00<00:36,  2.19it/s]

train loss = 0.00951553041934967, train acc = 0.07697, test loss = 0.009564930438995362, test acc = 0.114


  2%|▎         | 2/80 [00:00<00:35,  2.20it/s]

train loss = 0.00905029794216156, train acc = 0.1851, test loss = 0.008800426006317138, test acc = 0.239


  4%|▍         | 3/80 [00:01<00:36,  2.13it/s]

train loss = 0.008057956492900848, train acc = 0.27651, test loss = 0.007706060409545899, test acc = 0.354


  5%|▌         | 4/80 [00:01<00:35,  2.15it/s]

train loss = 0.007050858821868897, train acc = 0.40803, test loss = 0.006765671730041504, test acc = 0.446


  6%|▋         | 5/80 [00:02<00:34,  2.16it/s]

train loss = 0.006281016607284546, train acc = 0.45798, test loss = 0.006191268920898437, test acc = 0.482


  8%|▊         | 6/80 [00:02<00:34,  2.17it/s]

train loss = 0.005858484902381897, train acc = 0.49778, test loss = 0.005864950656890869, test acc = 0.519


  9%|▉         | 7/80 [00:03<00:33,  2.17it/s]

train loss = 0.005599147119522095, train acc = 0.51331, test loss = 0.005627917289733887, test acc = 0.521


 10%|█         | 8/80 [00:03<00:33,  2.18it/s]

train loss = 0.005372270050048828, train acc = 0.51896, test loss = 0.0053976716995239255, test acc = 0.548


 11%|█▏        | 9/80 [00:04<00:32,  2.18it/s]

train loss = 0.005162566630840301, train acc = 0.54445, test loss = 0.005217376708984375, test acc = 0.548


 12%|█▎        | 10/80 [00:04<00:32,  2.18it/s]

train loss = 0.005015750827789307, train acc = 0.54454, test loss = 0.00510426664352417, test acc = 0.548


 14%|█▍        | 11/80 [00:05<00:31,  2.17it/s]

train loss = 0.004932906112670899, train acc = 0.54476, test loss = 0.005039632081985474, test acc = 0.548


 15%|█▌        | 12/80 [00:05<00:31,  2.17it/s]

train loss = 0.0048839395928382875, train acc = 0.54476, test loss = 0.004998907089233399, test acc = 0.546


 16%|█▋        | 13/80 [00:06<00:31,  2.14it/s]

train loss = 0.004852416191101074, train acc = 0.5448, test loss = 0.004973002195358276, test acc = 0.546


 18%|█▊        | 14/80 [00:06<00:30,  2.15it/s]

train loss = 0.004831020817756652, train acc = 0.54491, test loss = 0.004952987194061279, test acc = 0.548


 19%|█▉        | 15/80 [00:06<00:30,  2.16it/s]

train loss = 0.004815866239070893, train acc = 0.54484, test loss = 0.004937817335128784, test acc = 0.547


 20%|██        | 16/80 [00:07<00:29,  2.17it/s]

train loss = 0.004806578943729401, train acc = 0.54486, test loss = 0.004930262327194214, test acc = 0.548


 21%|██▏       | 17/80 [00:07<00:29,  2.15it/s]

train loss = 0.004798087985515595, train acc = 0.54486, test loss = 0.004920792579650879, test acc = 0.546


 22%|██▎       | 18/80 [00:08<00:28,  2.15it/s]

train loss = 0.0047951497626304625, train acc = 0.54505, test loss = 0.004915032148361206, test acc = 0.547


 24%|██▍       | 19/80 [00:08<00:28,  2.16it/s]

train loss = 0.004784793281555176, train acc = 0.54468, test loss = 0.004911665916442871, test acc = 0.547


 25%|██▌       | 20/80 [00:09<00:27,  2.15it/s]

train loss = 0.0047803000354766844, train acc = 0.54482, test loss = 0.004902473211288452, test acc = 0.547


 26%|██▋       | 21/80 [00:09<00:27,  2.16it/s]

train loss = 0.004778108747005462, train acc = 0.54491, test loss = 0.004896500825881958, test acc = 0.55


 28%|██▊       | 22/80 [00:10<00:26,  2.17it/s]

train loss = 0.004777222216129303, train acc = 0.54437, test loss = 0.004896040439605713, test acc = 0.547


 29%|██▉       | 23/80 [00:10<00:26,  2.13it/s]

train loss = 0.004769908684492111, train acc = 0.54496, test loss = 0.004891326427459717, test acc = 0.548


 30%|███       | 24/80 [00:11<00:26,  2.14it/s]

train loss = 0.004769435937404633, train acc = 0.54474, test loss = 0.004888580322265625, test acc = 0.548


 31%|███▏      | 25/80 [00:11<00:25,  2.16it/s]

train loss = 0.004768450644016266, train acc = 0.54501, test loss = 0.004885380268096924, test acc = 0.549


 32%|███▎      | 26/80 [00:12<00:24,  2.17it/s]

train loss = 0.004768938627243042, train acc = 0.54478, test loss = 0.004882905006408691, test acc = 0.548


 34%|███▍      | 27/80 [00:12<00:24,  2.17it/s]

train loss = 0.004767064850330353, train acc = 0.54474, test loss = 0.00488211989402771, test acc = 0.548


 35%|███▌      | 28/80 [00:12<00:23,  2.17it/s]

train loss = 0.004764732329845429, train acc = 0.54476, test loss = 0.004879072904586792, test acc = 0.548


 36%|███▋      | 29/80 [00:13<00:23,  2.14it/s]

train loss = 0.004764417216777801, train acc = 0.54471, test loss = 0.004880631685256958, test acc = 0.548


 38%|███▊      | 30/80 [00:13<00:23,  2.15it/s]

train loss = 0.004762759754657745, train acc = 0.5446, test loss = 0.00487605357170105, test acc = 0.548


 39%|███▉      | 31/80 [00:14<00:22,  2.16it/s]

train loss = 0.004761562058925629, train acc = 0.54464, test loss = 0.004876136541366577, test acc = 0.549


 40%|████      | 32/80 [00:14<00:22,  2.17it/s]

train loss = 0.004760935513973236, train acc = 0.54488, test loss = 0.004875807762145996, test acc = 0.548


 41%|████▏     | 33/80 [00:15<00:21,  2.15it/s]

train loss = 0.004759933841228485, train acc = 0.54485, test loss = 0.004871336936950683, test acc = 0.549


 42%|████▎     | 34/80 [00:15<00:21,  2.16it/s]

train loss = 0.004757483761310578, train acc = 0.54472, test loss = 0.00486861515045166, test acc = 0.55


 44%|████▍     | 35/80 [00:16<00:20,  2.17it/s]

train loss = 0.004759180316925049, train acc = 0.54461, test loss = 0.0048701329231262205, test acc = 0.549


 45%|████▌     | 36/80 [00:16<00:20,  2.17it/s]

train loss = 0.004759334082603455, train acc = 0.54445, test loss = 0.004870748043060303, test acc = 0.55


 46%|████▋     | 37/80 [00:17<00:19,  2.18it/s]

train loss = 0.0047581147050857545, train acc = 0.54484, test loss = 0.004866942644119263, test acc = 0.552


 48%|████▊     | 38/80 [00:17<00:19,  2.18it/s]

train loss = 0.004756369051933288, train acc = 0.54456, test loss = 0.004869256496429443, test acc = 0.55


 49%|████▉     | 39/80 [00:18<00:18,  2.18it/s]

train loss = 0.004756955630779267, train acc = 0.54475, test loss = 0.004866597890853882, test acc = 0.552


 50%|█████     | 40/80 [00:18<00:18,  2.18it/s]

train loss = 0.004753861875534057, train acc = 0.54486, test loss = 0.004865450143814087, test acc = 0.549


 51%|█████▏    | 41/80 [00:18<00:17,  2.18it/s]

train loss = 0.004756295282840729, train acc = 0.54465, test loss = 0.004862723350524903, test acc = 0.551


 52%|█████▎    | 42/80 [00:19<00:19,  1.98it/s]

train loss = 0.0047579584693908696, train acc = 0.54454, test loss = 0.004861756563186646, test acc = 0.551


 54%|█████▍    | 43/80 [00:20<00:19,  1.93it/s]

train loss = 0.004754674196243286, train acc = 0.54494, test loss = 0.004863103628158569, test acc = 0.548


 55%|█████▌    | 44/80 [00:20<00:17,  2.01it/s]

train loss = 0.004755365986824036, train acc = 0.54461, test loss = 0.004860641956329345, test acc = 0.55


 56%|█████▋    | 45/80 [00:21<00:16,  2.06it/s]

train loss = 0.004751631166934967, train acc = 0.54445, test loss = 0.004862504720687866, test acc = 0.549


 57%|█████▊    | 46/80 [00:21<00:16,  2.01it/s]

train loss = 0.004751785144805908, train acc = 0.54468, test loss = 0.00486148190498352, test acc = 0.551


 59%|█████▉    | 47/80 [00:21<00:15,  2.06it/s]

train loss = 0.004752078409194946, train acc = 0.5446, test loss = 0.004857778549194336, test acc = 0.553


 60%|██████    | 48/80 [00:22<00:15,  2.11it/s]

train loss = 0.004753673286437989, train acc = 0.5445, test loss = 0.004857347249984741, test acc = 0.549


 61%|██████▏   | 49/80 [00:22<00:14,  2.13it/s]

train loss = 0.004751134865283966, train acc = 0.54437, test loss = 0.0048565993309021, test acc = 0.552


 62%|██████▎   | 50/80 [00:23<00:13,  2.15it/s]

train loss = 0.004750727462768555, train acc = 0.54485, test loss = 0.004855893135070801, test acc = 0.549


 64%|██████▍   | 51/80 [00:23<00:13,  2.16it/s]

train loss = 0.004751954245567322, train acc = 0.54473, test loss = 0.004857486486434937, test acc = 0.549


 65%|██████▌   | 52/80 [00:24<00:13,  2.13it/s]

train loss = 0.004751352500915527, train acc = 0.54454, test loss = 0.004856463670730591, test acc = 0.549


 66%|██████▋   | 53/80 [00:24<00:12,  2.15it/s]

train loss = 0.004747534611225128, train acc = 0.54461, test loss = 0.004855319261550904, test acc = 0.548


 68%|██████▊   | 54/80 [00:25<00:12,  2.17it/s]

train loss = 0.004748717679977417, train acc = 0.54449, test loss = 0.004853755235671997, test acc = 0.553


 69%|██████▉   | 55/80 [00:25<00:11,  2.18it/s]

train loss = 0.00474956778049469, train acc = 0.54451, test loss = 0.004853756666183472, test acc = 0.552


 70%|███████   | 56/80 [00:26<00:11,  2.15it/s]

train loss = 0.004747631130218506, train acc = 0.54431, test loss = 0.0048520355224609376, test acc = 0.554


 71%|███████▏  | 57/80 [00:26<00:10,  2.14it/s]

train loss = 0.004748239026069641, train acc = 0.54468, test loss = 0.004853277444839477, test acc = 0.548


 72%|███████▎  | 58/80 [00:27<00:10,  2.15it/s]

train loss = 0.004748916640281677, train acc = 0.54473, test loss = 0.004850903034210205, test acc = 0.55


 74%|███████▍  | 59/80 [00:27<00:09,  2.16it/s]

train loss = 0.004746802129745483, train acc = 0.54465, test loss = 0.0048525044918060305, test acc = 0.55


 75%|███████▌  | 60/80 [00:27<00:09,  2.17it/s]

train loss = 0.004746077275276184, train acc = 0.54472, test loss = 0.004851552486419677, test acc = 0.555


 76%|███████▋  | 61/80 [00:28<00:08,  2.15it/s]

train loss = 0.004747584388256073, train acc = 0.54458, test loss = 0.004847626447677613, test acc = 0.554


 78%|███████▊  | 62/80 [00:28<00:08,  2.16it/s]

train loss = 0.0047455978322029115, train acc = 0.54439, test loss = 0.004848750591278076, test acc = 0.553


 79%|███████▉  | 63/80 [00:29<00:07,  2.16it/s]

train loss = 0.004746938858032227, train acc = 0.54521, test loss = 0.0048506691455841066, test acc = 0.548


 80%|████████  | 64/80 [00:29<00:07,  2.16it/s]

train loss = 0.004742992570400238, train acc = 0.54455, test loss = 0.004845983982086181, test acc = 0.553


 81%|████████▏ | 65/80 [00:30<00:06,  2.17it/s]

train loss = 0.0047439209127426146, train acc = 0.54494, test loss = 0.004847533226013183, test acc = 0.551


 82%|████████▎ | 66/80 [00:30<00:06,  2.18it/s]

train loss = 0.00474570198059082, train acc = 0.54453, test loss = 0.004848968505859375, test acc = 0.55


 84%|████████▍ | 67/80 [00:31<00:05,  2.18it/s]

train loss = 0.004744319500923157, train acc = 0.54445, test loss = 0.00484574294090271, test acc = 0.557


 85%|████████▌ | 68/80 [00:31<00:05,  2.18it/s]

train loss = 0.004743385617733002, train acc = 0.54504, test loss = 0.004846273899078369, test acc = 0.553


 86%|████████▋ | 69/80 [00:32<00:05,  2.18it/s]

train loss = 0.004741148841381073, train acc = 0.54481, test loss = 0.004843167543411255, test acc = 0.554


 88%|████████▊ | 70/80 [00:32<00:04,  2.15it/s]

train loss = 0.004744489691257477, train acc = 0.54443, test loss = 0.004844628810882569, test acc = 0.553


 89%|████████▉ | 71/80 [00:33<00:04,  2.16it/s]

train loss = 0.004743471903800965, train acc = 0.54442, test loss = 0.00484515118598938, test acc = 0.552


 90%|█████████ | 72/80 [00:33<00:03,  2.17it/s]

train loss = 0.004743412067890167, train acc = 0.54455, test loss = 0.004843022346496582, test acc = 0.553


 91%|█████████▏| 73/80 [00:33<00:03,  2.18it/s]

train loss = 0.00474413676738739, train acc = 0.54462, test loss = 0.004845801830291748, test acc = 0.551


 92%|█████████▎| 74/80 [00:34<00:02,  2.19it/s]

train loss = 0.004743422408103943, train acc = 0.54472, test loss = 0.00484452486038208, test acc = 0.55


 94%|█████████▍| 75/80 [00:34<00:02,  2.20it/s]

train loss = 0.004742223863601684, train acc = 0.54518, test loss = 0.004844552993774414, test acc = 0.548


 95%|█████████▌| 76/80 [00:35<00:01,  2.22it/s]

train loss = 0.004742143251895904, train acc = 0.54452, test loss = 0.0048456952571868895, test acc = 0.555


 96%|█████████▋| 77/80 [00:35<00:01,  2.23it/s]

train loss = 0.004741428470611572, train acc = 0.54487, test loss = 0.004843711137771606, test acc = 0.549


 98%|█████████▊| 78/80 [00:36<00:00,  2.22it/s]

train loss = 0.004740850291252137, train acc = 0.54493, test loss = 0.004842173099517822, test acc = 0.552


 99%|█████████▉| 79/80 [00:36<00:00,  2.17it/s]

train loss = 0.004739903838634491, train acc = 0.54479, test loss = 0.004842676639556885, test acc = 0.55


100%|██████████| 80/80 [00:37<00:00,  2.15it/s]

train loss = 0.0047414569902420045, train acc = 0.54525, test loss = 0.004845996618270874, test acc = 0.554



  1%|▏         | 1/80 [00:00<00:37,  2.11it/s]

train loss = 0.009478747129440308, train acc = 0.08168, test loss = 0.009433747291564942, test acc = 0.152


  2%|▎         | 2/80 [00:00<00:36,  2.12it/s]

train loss = 0.008621103112697601, train acc = 0.35204, test loss = 0.008008205890655517, test acc = 0.382


  4%|▍         | 3/80 [00:01<00:36,  2.11it/s]

train loss = 0.0071289038968086246, train acc = 0.4424, test loss = 0.006623030662536621, test acc = 0.447


  5%|▌         | 4/80 [00:01<00:35,  2.11it/s]

train loss = 0.0061592613792419435, train acc = 0.4515, test loss = 0.006042149066925049, test acc = 0.447


  6%|▋         | 5/80 [00:02<00:35,  2.11it/s]

train loss = 0.005759308848381043, train acc = 0.45146, test loss = 0.005752755403518677, test acc = 0.447


  8%|▊         | 6/80 [00:02<00:35,  2.11it/s]

train loss = 0.005520240316390992, train acc = 0.45656, test loss = 0.005535243034362793, test acc = 0.47


  9%|▉         | 7/80 [00:03<00:34,  2.10it/s]

train loss = 0.005299038553237915, train acc = 0.51177, test loss = 0.005335330247879028, test acc = 0.51


 10%|█         | 8/80 [00:03<00:34,  2.09it/s]

train loss = 0.005118202230930328, train acc = 0.52402, test loss = 0.0051079046726226806, test acc = 0.547


 11%|█▏        | 9/80 [00:04<00:34,  2.09it/s]

train loss = 0.004938303470611572, train acc = 0.5447, test loss = 0.0049921321868896485, test acc = 0.548


 12%|█▎        | 10/80 [00:04<00:33,  2.09it/s]

train loss = 0.004863742311000824, train acc = 0.54471, test loss = 0.004939999580383301, test acc = 0.547


 14%|█▍        | 11/80 [00:05<00:32,  2.09it/s]

train loss = 0.004830195281505585, train acc = 0.54467, test loss = 0.004910322427749633, test acc = 0.549


 15%|█▌        | 12/80 [00:05<00:32,  2.09it/s]

train loss = 0.0048041027092933656, train acc = 0.5445, test loss = 0.004890887022018432, test acc = 0.553


 16%|█▋        | 13/80 [00:06<00:32,  2.09it/s]

train loss = 0.004791292169094086, train acc = 0.54453, test loss = 0.004873574495315552, test acc = 0.553


 18%|█▊        | 14/80 [00:06<00:31,  2.09it/s]

train loss = 0.004780538864135742, train acc = 0.54464, test loss = 0.004862775325775146, test acc = 0.555


 19%|█▉        | 15/80 [00:07<00:31,  2.09it/s]

train loss = 0.0047743708848953246, train acc = 0.54454, test loss = 0.004854764938354492, test acc = 0.55


 20%|██        | 16/80 [00:07<00:31,  2.06it/s]

train loss = 0.004767156443595886, train acc = 0.54471, test loss = 0.004845614194869995, test acc = 0.552


 21%|██▏       | 17/80 [00:08<00:30,  2.07it/s]

train loss = 0.004762356200218201, train acc = 0.54444, test loss = 0.0048410708904266355, test acc = 0.55


 22%|██▎       | 18/80 [00:08<00:29,  2.08it/s]

train loss = 0.004757891664505005, train acc = 0.54475, test loss = 0.0048382430076599125, test acc = 0.544


 24%|██▍       | 19/80 [00:09<00:29,  2.07it/s]

train loss = 0.004756423578262329, train acc = 0.5444, test loss = 0.004830882310867309, test acc = 0.547


 25%|██▌       | 20/80 [00:09<00:28,  2.08it/s]

train loss = 0.004754240555763244, train acc = 0.54487, test loss = 0.004829394578933716, test acc = 0.547


 26%|██▋       | 21/80 [00:10<00:28,  2.08it/s]

train loss = 0.004752367534637451, train acc = 0.54524, test loss = 0.00482642674446106, test acc = 0.547


 28%|██▊       | 22/80 [00:10<00:27,  2.09it/s]

train loss = 0.004750029702186584, train acc = 0.54529, test loss = 0.00482066535949707, test acc = 0.552


 29%|██▉       | 23/80 [00:10<00:27,  2.09it/s]

train loss = 0.004747046921253204, train acc = 0.54544, test loss = 0.004819394111633301, test acc = 0.55


 30%|███       | 24/80 [00:11<00:26,  2.10it/s]

train loss = 0.0047467385101318356, train acc = 0.54552, test loss = 0.00481932544708252, test acc = 0.547


 31%|███▏      | 25/80 [00:11<00:26,  2.08it/s]

train loss = 0.004741177151203155, train acc = 0.54551, test loss = 0.004813868999481202, test acc = 0.551


 32%|███▎      | 26/80 [00:12<00:25,  2.09it/s]

train loss = 0.004740764863491059, train acc = 0.54525, test loss = 0.0048138067722320555, test acc = 0.549


 34%|███▍      | 27/80 [00:12<00:25,  2.09it/s]

train loss = 0.004740860657691956, train acc = 0.54558, test loss = 0.0048118956089019775, test acc = 0.554


 35%|███▌      | 28/80 [00:13<00:24,  2.09it/s]

train loss = 0.004739051728248596, train acc = 0.54594, test loss = 0.004813713073730469, test acc = 0.552


 36%|███▋      | 29/80 [00:13<00:24,  2.07it/s]

train loss = 0.0047385767889022825, train acc = 0.54554, test loss = 0.004809595108032226, test acc = 0.555


 38%|███▊      | 30/80 [00:14<00:24,  2.08it/s]

train loss = 0.0047370147013664245, train acc = 0.54598, test loss = 0.004807865619659424, test acc = 0.554


 39%|███▉      | 31/80 [00:14<00:23,  2.08it/s]

train loss = 0.00474081916809082, train acc = 0.54557, test loss = 0.004808748483657837, test acc = 0.549


 40%|████      | 32/80 [00:15<00:22,  2.10it/s]

train loss = 0.004736781141757965, train acc = 0.5458, test loss = 0.00480804443359375, test acc = 0.548


 41%|████▏     | 33/80 [00:15<00:22,  2.06it/s]

train loss = 0.004737604765892029, train acc = 0.54596, test loss = 0.004806783199310303, test acc = 0.55


 42%|████▎     | 34/80 [00:16<00:22,  2.07it/s]

train loss = 0.004736194236278534, train acc = 0.54577, test loss = 0.004804455757141114, test acc = 0.553


 44%|████▍     | 35/80 [00:16<00:21,  2.08it/s]

train loss = 0.004736238088607788, train acc = 0.5454, test loss = 0.004806426286697388, test acc = 0.552


 45%|████▌     | 36/80 [00:17<00:21,  2.08it/s]

train loss = 0.004733779294490814, train acc = 0.54587, test loss = 0.004805728435516358, test acc = 0.558


 46%|████▋     | 37/80 [00:17<00:20,  2.08it/s]

train loss = 0.004733864171504974, train acc = 0.54545, test loss = 0.004804339647293091, test acc = 0.548


 48%|████▊     | 38/80 [00:18<00:20,  2.07it/s]

train loss = 0.004734448983669281, train acc = 0.54545, test loss = 0.004802844762802124, test acc = 0.555


 49%|████▉     | 39/80 [00:18<00:19,  2.08it/s]

train loss = 0.00473315598487854, train acc = 0.54589, test loss = 0.004804204940795898, test acc = 0.561


 50%|█████     | 40/80 [00:19<00:19,  2.09it/s]

train loss = 0.004732128689289093, train acc = 0.54579, test loss = 0.004802020311355591, test acc = 0.553


 51%|█████▏    | 41/80 [00:19<00:18,  2.07it/s]

train loss = 0.004730619556903839, train acc = 0.54522, test loss = 0.004798758745193482, test acc = 0.555


 52%|█████▎    | 42/80 [00:20<00:18,  2.07it/s]

train loss = 0.004732116785049438, train acc = 0.54578, test loss = 0.004798100709915161, test acc = 0.555


 54%|█████▍    | 43/80 [00:20<00:17,  2.07it/s]

train loss = 0.004732822020053863, train acc = 0.54558, test loss = 0.004798774003982544, test acc = 0.553


 55%|█████▌    | 44/80 [00:21<00:17,  2.07it/s]

train loss = 0.004730937843322754, train acc = 0.54539, test loss = 0.004800644397735596, test acc = 0.557


 56%|█████▋    | 45/80 [00:21<00:16,  2.07it/s]

train loss = 0.004730690677165985, train acc = 0.54549, test loss = 0.004796837329864502, test acc = 0.554


 57%|█████▊    | 46/80 [00:22<00:16,  2.09it/s]

train loss = 0.004731722083091736, train acc = 0.54532, test loss = 0.004800463438034057, test acc = 0.554


 59%|█████▉    | 47/80 [00:22<00:15,  2.08it/s]

train loss = 0.004731436553001404, train acc = 0.54543, test loss = 0.004799001693725586, test acc = 0.549


 60%|██████    | 48/80 [00:23<00:15,  2.09it/s]

train loss = 0.00473037015914917, train acc = 0.54552, test loss = 0.004802028894424439, test acc = 0.554


 61%|██████▏   | 49/80 [00:23<00:15,  2.06it/s]

train loss = 0.004732270219326019, train acc = 0.54523, test loss = 0.004800427675247193, test acc = 0.557


 62%|██████▎   | 50/80 [00:23<00:14,  2.07it/s]

train loss = 0.004729691979885101, train acc = 0.54572, test loss = 0.004799762725830078, test acc = 0.553


 64%|██████▍   | 51/80 [00:24<00:13,  2.07it/s]

train loss = 0.004729812126159668, train acc = 0.54555, test loss = 0.0048002443313598635, test acc = 0.555


 65%|██████▌   | 52/80 [00:24<00:13,  2.08it/s]

train loss = 0.004729051280021667, train acc = 0.54541, test loss = 0.004798478841781616, test acc = 0.552


 66%|██████▋   | 53/80 [00:25<00:12,  2.09it/s]

train loss = 0.004727449173927307, train acc = 0.54529, test loss = 0.004799980401992798, test acc = 0.552


 68%|██████▊   | 54/80 [00:25<00:12,  2.09it/s]

train loss = 0.004728794054985047, train acc = 0.54502, test loss = 0.004796858072280884, test acc = 0.547


 69%|██████▉   | 55/80 [00:26<00:12,  2.08it/s]

train loss = 0.004729200935363769, train acc = 0.54572, test loss = 0.00479919695854187, test acc = 0.555


 70%|███████   | 56/80 [00:26<00:11,  2.08it/s]

train loss = 0.004728716537952423, train acc = 0.54522, test loss = 0.0047977914810180666, test acc = 0.554


 71%|███████▏  | 57/80 [00:27<00:11,  2.05it/s]

train loss = 0.004726150312423706, train acc = 0.5455, test loss = 0.004796363115310669, test acc = 0.547


 72%|███████▎  | 58/80 [00:27<00:10,  2.06it/s]

train loss = 0.004727552092075348, train acc = 0.54584, test loss = 0.004796884298324585, test acc = 0.559


 74%|███████▍  | 59/80 [00:28<00:10,  2.07it/s]

train loss = 0.004727743263244629, train acc = 0.54572, test loss = 0.004797378063201904, test acc = 0.553


 75%|███████▌  | 60/80 [00:28<00:09,  2.07it/s]

train loss = 0.0047277786135673525, train acc = 0.54562, test loss = 0.004797377586364746, test acc = 0.551


 76%|███████▋  | 61/80 [00:29<00:09,  2.07it/s]

train loss = 0.004726678001880646, train acc = 0.54554, test loss = 0.004800804615020752, test acc = 0.557


 78%|███████▊  | 62/80 [00:29<00:08,  2.07it/s]

train loss = 0.004726732916831971, train acc = 0.5463, test loss = 0.004799926042556762, test acc = 0.555


 79%|███████▉  | 63/80 [00:30<00:08,  2.08it/s]

train loss = 0.004724910144805908, train acc = 0.5462, test loss = 0.004795466899871826, test acc = 0.555


 80%|████████  | 64/80 [00:30<00:07,  2.07it/s]

train loss = 0.0047252709341049194, train acc = 0.54598, test loss = 0.004800972223281861, test acc = 0.558


 81%|████████▏ | 65/80 [00:31<00:07,  2.04it/s]

train loss = 0.004724679238796234, train acc = 0.54538, test loss = 0.0047959151268005375, test acc = 0.561


 82%|████████▎ | 66/80 [00:31<00:06,  2.06it/s]

train loss = 0.004726174359321594, train acc = 0.54601, test loss = 0.00479865026473999, test acc = 0.558


 84%|████████▍ | 67/80 [00:32<00:06,  2.06it/s]

train loss = 0.004724848954677582, train acc = 0.54627, test loss = 0.004798196792602539, test acc = 0.559


 85%|████████▌ | 68/80 [00:32<00:05,  2.08it/s]

train loss = 0.0047248947930336, train acc = 0.54564, test loss = 0.0047970876693725585, test acc = 0.561


 86%|████████▋ | 69/80 [00:33<00:05,  2.09it/s]

train loss = 0.004721114115715027, train acc = 0.54584, test loss = 0.004798722267150879, test acc = 0.553


 88%|████████▊ | 70/80 [00:33<00:04,  2.11it/s]

train loss = 0.004725247552394867, train acc = 0.54623, test loss = 0.004796544313430786, test acc = 0.555


 89%|████████▉ | 71/80 [00:34<00:04,  2.08it/s]

train loss = 0.004724398927688599, train acc = 0.54629, test loss = 0.004801439762115479, test acc = 0.559


 90%|█████████ | 72/80 [00:34<00:03,  2.09it/s]

train loss = 0.004722430593967437, train acc = 0.5466, test loss = 0.00479723334312439, test acc = 0.556


 91%|█████████▏| 73/80 [00:35<00:03,  2.06it/s]

train loss = 0.0047238845610618595, train acc = 0.54617, test loss = 0.004798467636108398, test acc = 0.558


 92%|█████████▎| 74/80 [00:35<00:02,  2.07it/s]

train loss = 0.004722468922138214, train acc = 0.54606, test loss = 0.00480118465423584, test acc = 0.558


 94%|█████████▍| 75/80 [00:36<00:02,  2.07it/s]

train loss = 0.004722620944976806, train acc = 0.54643, test loss = 0.004796331405639649, test acc = 0.556


 95%|█████████▌| 76/80 [00:36<00:01,  2.07it/s]

train loss = 0.004721331670284271, train acc = 0.54537, test loss = 0.004797630548477173, test acc = 0.555


 96%|█████████▋| 77/80 [00:37<00:01,  2.07it/s]

train loss = 0.0047214480757713315, train acc = 0.54602, test loss = 0.004800450325012207, test acc = 0.556


 98%|█████████▊| 78/80 [00:37<00:00,  2.07it/s]

train loss = 0.004720251457691192, train acc = 0.54659, test loss = 0.004798675775527954, test acc = 0.556


 99%|█████████▉| 79/80 [00:37<00:00,  2.06it/s]

train loss = 0.00472139196395874, train acc = 0.54638, test loss = 0.004798148155212402, test acc = 0.555


100%|██████████| 80/80 [00:38<00:00,  2.08it/s]

train loss = 0.0047226300883293154, train acc = 0.54653, test loss = 0.004797106504440308, test acc = 0.555



  1%|▏         | 1/80 [00:00<00:38,  2.04it/s]

train loss = 0.009462050557136536, train acc = 0.09792, test loss = 0.009419225215911865, test acc = 0.192


  2%|▎         | 2/80 [00:00<00:38,  2.04it/s]

train loss = 0.008465453600883483, train acc = 0.38618, test loss = 0.007564168214797973, test acc = 0.45


  4%|▍         | 3/80 [00:01<00:37,  2.04it/s]

train loss = 0.006559716749191284, train acc = 0.47092, test loss = 0.00602867078781128, test acc = 0.514


  5%|▌         | 4/80 [00:02<00:41,  1.82it/s]

train loss = 0.005564754421710968, train acc = 0.52274, test loss = 0.005425474405288696, test acc = 0.547


  6%|▋         | 5/80 [00:02<00:39,  1.88it/s]

train loss = 0.0051761319732666015, train acc = 0.54478, test loss = 0.005167640209197998, test acc = 0.547


  8%|▊         | 6/80 [00:03<00:38,  1.93it/s]

train loss = 0.004995270004272461, train acc = 0.54488, test loss = 0.005033241987228394, test acc = 0.548


  9%|▉         | 7/80 [00:03<00:37,  1.94it/s]

train loss = 0.004893611392974853, train acc = 0.54508, test loss = 0.004954443216323852, test acc = 0.546


 10%|█         | 8/80 [00:04<00:39,  1.82it/s]

train loss = 0.004840755541324616, train acc = 0.54498, test loss = 0.004908499956130982, test acc = 0.544


 11%|█▏        | 9/80 [00:04<00:38,  1.84it/s]

train loss = 0.004806696555614471, train acc = 0.54513, test loss = 0.004879437208175659, test acc = 0.546


 12%|█▎        | 10/80 [00:05<00:37,  1.89it/s]

train loss = 0.004784473333358764, train acc = 0.54503, test loss = 0.0048617286682128905, test acc = 0.541


 14%|█▍        | 11/80 [00:05<00:35,  1.92it/s]

train loss = 0.004770792925357818, train acc = 0.54449, test loss = 0.004847801685333252, test acc = 0.547


 15%|█▌        | 12/80 [00:06<00:35,  1.91it/s]

train loss = 0.004761341483592987, train acc = 0.54463, test loss = 0.004839116096496582, test acc = 0.542


 16%|█▋        | 13/80 [00:06<00:34,  1.94it/s]

train loss = 0.0047566616344451905, train acc = 0.54436, test loss = 0.004831864356994629, test acc = 0.545


 18%|█▊        | 14/80 [00:07<00:33,  1.94it/s]

train loss = 0.004751752238273621, train acc = 0.54416, test loss = 0.004827778100967408, test acc = 0.543


 19%|█▉        | 15/80 [00:07<00:32,  1.98it/s]

train loss = 0.004747306549549103, train acc = 0.54477, test loss = 0.004824394464492798, test acc = 0.544


 20%|██        | 16/80 [00:08<00:32,  1.99it/s]

train loss = 0.004743329722881317, train acc = 0.54433, test loss = 0.00481986141204834, test acc = 0.541


 21%|██▏       | 17/80 [00:08<00:31,  2.00it/s]

train loss = 0.004745013635158539, train acc = 0.54464, test loss = 0.004815592527389527, test acc = 0.544


 22%|██▎       | 18/80 [00:09<00:30,  2.02it/s]

train loss = 0.004741548407077789, train acc = 0.5445, test loss = 0.004816439151763916, test acc = 0.54


 24%|██▍       | 19/80 [00:09<00:30,  2.03it/s]

train loss = 0.004742329604625702, train acc = 0.54502, test loss = 0.004813015460968018, test acc = 0.544


 25%|██▌       | 20/80 [00:10<00:29,  2.02it/s]

train loss = 0.004737864224910736, train acc = 0.54475, test loss = 0.004809869289398193, test acc = 0.544


 26%|██▋       | 21/80 [00:10<00:29,  2.02it/s]

train loss = 0.004737449657917022, train acc = 0.54482, test loss = 0.0048069143295288086, test acc = 0.544


 28%|██▊       | 22/80 [00:11<00:28,  2.02it/s]

train loss = 0.00473456992149353, train acc = 0.5449, test loss = 0.0048106563091278075, test acc = 0.544


 29%|██▉       | 23/80 [00:11<00:28,  2.04it/s]

train loss = 0.004733904929161072, train acc = 0.54478, test loss = 0.00480668044090271, test acc = 0.543


 30%|███       | 24/80 [00:12<00:28,  1.99it/s]

train loss = 0.004731795284748077, train acc = 0.54501, test loss = 0.004806365966796875, test acc = 0.544


 31%|███▏      | 25/80 [00:12<00:27,  2.00it/s]

train loss = 0.004731407723426818, train acc = 0.54466, test loss = 0.004805030584335327, test acc = 0.544


 32%|███▎      | 26/80 [00:13<00:26,  2.01it/s]

train loss = 0.004732590112686157, train acc = 0.54522, test loss = 0.00480374026298523, test acc = 0.54


 34%|███▍      | 27/80 [00:13<00:26,  1.99it/s]

train loss = 0.004729926745891571, train acc = 0.54451, test loss = 0.004805571317672729, test acc = 0.542


 35%|███▌      | 28/80 [00:14<00:26,  1.99it/s]

train loss = 0.004726933455467224, train acc = 0.5451, test loss = 0.004804162740707397, test acc = 0.545


 36%|███▋      | 29/80 [00:14<00:25,  2.01it/s]

train loss = 0.0047228612756729125, train acc = 0.54524, test loss = 0.00480854868888855, test acc = 0.549


 38%|███▊      | 30/80 [00:15<00:25,  2.00it/s]

train loss = 0.004724364404678345, train acc = 0.54594, test loss = 0.004806888341903686, test acc = 0.541


 39%|███▉      | 31/80 [00:15<00:24,  2.01it/s]

train loss = 0.004722283825874329, train acc = 0.54556, test loss = 0.004806943893432617, test acc = 0.545


 40%|████      | 32/80 [00:16<00:24,  1.96it/s]

train loss = 0.00472177759885788, train acc = 0.54544, test loss = 0.004807654857635498, test acc = 0.542


 41%|████▏     | 33/80 [00:16<00:23,  1.97it/s]

train loss = 0.0047214214015007016, train acc = 0.54554, test loss = 0.004805919885635376, test acc = 0.55


 42%|████▎     | 34/80 [00:17<00:23,  1.98it/s]

train loss = 0.00471809757232666, train acc = 0.54503, test loss = 0.004810960769653321, test acc = 0.545


 44%|████▍     | 35/80 [00:17<00:22,  1.98it/s]

train loss = 0.004717262206077576, train acc = 0.545, test loss = 0.004811004638671875, test acc = 0.534


 45%|████▌     | 36/80 [00:18<00:22,  1.99it/s]

train loss = 0.004717282266616821, train acc = 0.54537, test loss = 0.004809486865997315, test acc = 0.54


 46%|████▋     | 37/80 [00:18<00:21,  1.99it/s]

train loss = 0.004717717030048371, train acc = 0.54575, test loss = 0.00480827522277832, test acc = 0.548


 48%|████▊     | 38/80 [00:19<00:21,  1.98it/s]

train loss = 0.004716897597312927, train acc = 0.54529, test loss = 0.0048087496757507325, test acc = 0.538


 49%|████▉     | 39/80 [00:19<00:20,  1.98it/s]

train loss = 0.0047178199863433835, train acc = 0.5453, test loss = 0.004806087493896484, test acc = 0.55


 50%|█████     | 40/80 [00:20<00:20,  1.94it/s]

train loss = 0.004717386701107025, train acc = 0.54495, test loss = 0.004807471513748169, test acc = 0.535


 51%|█████▏    | 41/80 [00:20<00:20,  1.93it/s]

train loss = 0.0047144813990592956, train acc = 0.54577, test loss = 0.004809731960296631, test acc = 0.55


 52%|█████▎    | 42/80 [00:21<00:19,  1.95it/s]

train loss = 0.00471535281419754, train acc = 0.54563, test loss = 0.004805679798126221, test acc = 0.541


 54%|█████▍    | 43/80 [00:21<00:18,  1.98it/s]

train loss = 0.004715586884021759, train acc = 0.54588, test loss = 0.004806746006011963, test acc = 0.544


 55%|█████▌    | 44/80 [00:22<00:18,  1.99it/s]

train loss = 0.004714569578170776, train acc = 0.54524, test loss = 0.004804548740386963, test acc = 0.556


 56%|█████▋    | 45/80 [00:22<00:17,  2.00it/s]

train loss = 0.004714397854804993, train acc = 0.54537, test loss = 0.004806409597396851, test acc = 0.548


 57%|█████▊    | 46/80 [00:23<00:16,  2.00it/s]

train loss = 0.004714294261932373, train acc = 0.54583, test loss = 0.004805712223052978, test acc = 0.547


 59%|█████▉    | 47/80 [00:23<00:16,  2.02it/s]

train loss = 0.00471419399023056, train acc = 0.54446, test loss = 0.004806097745895386, test acc = 0.548


 60%|██████    | 48/80 [00:24<00:15,  2.01it/s]

train loss = 0.004712075769901276, train acc = 0.54473, test loss = 0.0048057172298431395, test acc = 0.536


 61%|██████▏   | 49/80 [00:24<00:15,  2.01it/s]

train loss = 0.00471456645488739, train acc = 0.54503, test loss = 0.0048040149211883544, test acc = 0.542


 62%|██████▎   | 50/80 [00:25<00:14,  2.01it/s]

train loss = 0.004713075602054596, train acc = 0.54541, test loss = 0.004802241086959839, test acc = 0.537


 64%|██████▍   | 51/80 [00:25<00:14,  2.00it/s]

train loss = 0.004712031681537628, train acc = 0.54643, test loss = 0.004805008888244629, test acc = 0.542


 65%|██████▌   | 52/80 [00:26<00:14,  1.98it/s]

train loss = 0.004712352545261383, train acc = 0.54503, test loss = 0.004800794363021851, test acc = 0.55


 66%|██████▋   | 53/80 [00:26<00:13,  1.98it/s]

train loss = 0.004712093658447266, train acc = 0.54633, test loss = 0.004806929349899292, test acc = 0.539


 68%|██████▊   | 54/80 [00:27<00:13,  1.98it/s]

train loss = 0.004711039426326752, train acc = 0.54611, test loss = 0.004803006172180176, test acc = 0.557


 69%|██████▉   | 55/80 [00:27<00:12,  1.99it/s]

train loss = 0.004713614225387573, train acc = 0.54668, test loss = 0.004803350210189819, test acc = 0.536


 70%|███████   | 56/80 [00:28<00:11,  2.00it/s]

train loss = 0.004711748530864716, train acc = 0.54579, test loss = 0.004805304527282715, test acc = 0.541


 71%|███████▏  | 57/80 [00:28<00:11,  1.98it/s]

train loss = 0.00471117892742157, train acc = 0.54614, test loss = 0.004803779125213623, test acc = 0.553


 72%|███████▎  | 58/80 [00:29<00:11,  1.99it/s]

train loss = 0.004713870346546173, train acc = 0.54621, test loss = 0.004801984071731567, test acc = 0.542


 74%|███████▍  | 59/80 [00:29<00:10,  2.00it/s]

train loss = 0.004714366414546966, train acc = 0.54551, test loss = 0.004802431106567383, test acc = 0.548


 75%|███████▌  | 60/80 [00:30<00:09,  2.00it/s]

train loss = 0.004712811462879181, train acc = 0.54614, test loss = 0.004804147243499756, test acc = 0.538


 76%|███████▋  | 61/80 [00:30<00:09,  2.01it/s]

train loss = 0.0047126702618598935, train acc = 0.54659, test loss = 0.004801020622253418, test acc = 0.548


 78%|███████▊  | 62/80 [00:31<00:08,  2.01it/s]

train loss = 0.00471289252281189, train acc = 0.54643, test loss = 0.004800879955291748, test acc = 0.545


 79%|███████▉  | 63/80 [00:31<00:08,  2.00it/s]

train loss = 0.004712242743968964, train acc = 0.54701, test loss = 0.004803021669387817, test acc = 0.545


 80%|████████  | 64/80 [00:32<00:07,  2.01it/s]

train loss = 0.00471305370092392, train acc = 0.54607, test loss = 0.004801756381988526, test acc = 0.545


 81%|████████▏ | 65/80 [00:32<00:07,  2.03it/s]

train loss = 0.004712297804355621, train acc = 0.5453, test loss = 0.00480318021774292, test acc = 0.537


 82%|████████▎ | 66/80 [00:33<00:07,  1.95it/s]

train loss = 0.004711859993934631, train acc = 0.54642, test loss = 0.004802588701248169, test acc = 0.54


 84%|████████▍ | 67/80 [00:33<00:06,  1.95it/s]

train loss = 0.004712930552959442, train acc = 0.54632, test loss = 0.004801120519638062, test acc = 0.547


 85%|████████▌ | 68/80 [00:34<00:06,  1.97it/s]

train loss = 0.00471168779373169, train acc = 0.54698, test loss = 0.0048017089366912845, test acc = 0.539


 86%|████████▋ | 69/80 [00:34<00:05,  1.99it/s]

train loss = 0.004712226989269257, train acc = 0.54534, test loss = 0.004801218032836914, test acc = 0.547


 88%|████████▊ | 70/80 [00:35<00:04,  2.00it/s]

train loss = 0.004709899108409882, train acc = 0.54727, test loss = 0.004804938793182373, test acc = 0.537


 89%|████████▉ | 71/80 [00:35<00:04,  1.98it/s]

train loss = 0.004712778327465058, train acc = 0.54626, test loss = 0.00480030083656311, test acc = 0.544


 90%|█████████ | 72/80 [00:36<00:04,  1.99it/s]

train loss = 0.004711336212158203, train acc = 0.54648, test loss = 0.0048036806583404545, test acc = 0.549


 91%|█████████▏| 73/80 [00:36<00:03,  2.01it/s]

train loss = 0.004711285102367401, train acc = 0.54524, test loss = 0.004804506063461304, test acc = 0.536


 92%|█████████▎| 74/80 [00:37<00:02,  2.02it/s]

train loss = 0.0047123293375968935, train acc = 0.54666, test loss = 0.004800443410873413, test acc = 0.541


 94%|█████████▍| 75/80 [00:37<00:02,  2.03it/s]

train loss = 0.004711050333976745, train acc = 0.54614, test loss = 0.004799888372421264, test acc = 0.552


 95%|█████████▌| 76/80 [00:38<00:01,  2.01it/s]

train loss = 0.004712929112911224, train acc = 0.54624, test loss = 0.0048005921840667725, test acc = 0.553


 96%|█████████▋| 77/80 [00:38<00:01,  2.02it/s]

train loss = 0.004713024170398712, train acc = 0.54652, test loss = 0.004798702478408814, test acc = 0.548


 98%|█████████▊| 78/80 [00:39<00:00,  2.03it/s]

train loss = 0.004712210078239441, train acc = 0.54632, test loss = 0.004798026561737061, test acc = 0.542


 99%|█████████▉| 79/80 [00:39<00:00,  2.03it/s]

train loss = 0.004714231522083282, train acc = 0.54689, test loss = 0.004801828145980835, test acc = 0.536


100%|██████████| 80/80 [00:40<00:00,  1.98it/s]

train loss = 0.004713001444339752, train acc = 0.54633, test loss = 0.004800518989562988, test acc = 0.54



  1%|▏         | 1/80 [00:00<00:39,  1.98it/s]

train loss = 0.0094303129863739, train acc = 0.0973, test loss = 0.009311222076416015, test acc = 0.218


  2%|▎         | 2/80 [00:01<00:39,  1.98it/s]

train loss = 0.008134816386699677, train acc = 0.39941, test loss = 0.00708597469329834, test acc = 0.414


  4%|▍         | 3/80 [00:01<00:38,  1.98it/s]

train loss = 0.006096689882278443, train acc = 0.49862, test loss = 0.0055714952945709224, test acc = 0.547


  5%|▌         | 4/80 [00:02<00:38,  1.98it/s]

train loss = 0.005194367735385895, train acc = 0.54467, test loss = 0.005111918210983276, test acc = 0.547


  6%|▋         | 5/80 [00:02<00:37,  1.98it/s]

train loss = 0.00493424030303955, train acc = 0.54457, test loss = 0.004973664045333863, test acc = 0.547


  8%|▊         | 6/80 [00:03<00:37,  1.99it/s]

train loss = 0.004844783589839936, train acc = 0.54473, test loss = 0.004914512395858764, test acc = 0.547


  9%|▉         | 7/80 [00:03<00:36,  1.99it/s]

train loss = 0.004807057466506958, train acc = 0.54497, test loss = 0.004884533166885376, test acc = 0.549


 10%|█         | 8/80 [00:04<00:36,  1.97it/s]

train loss = 0.004783241069316864, train acc = 0.54502, test loss = 0.004864580392837524, test acc = 0.544


 11%|█▏        | 9/80 [00:04<00:35,  1.97it/s]

train loss = 0.004769763226509094, train acc = 0.54418, test loss = 0.004854818344116211, test acc = 0.547


 12%|█▎        | 10/80 [00:05<00:35,  1.98it/s]

train loss = 0.0047595687246322635, train acc = 0.54508, test loss = 0.004843274354934693, test acc = 0.548


 14%|█▍        | 11/80 [00:05<00:34,  1.99it/s]

train loss = 0.00475396889925003, train acc = 0.54431, test loss = 0.004837527513504029, test acc = 0.546


 15%|█▌        | 12/80 [00:06<00:34,  2.00it/s]

train loss = 0.0047467751169204715, train acc = 0.54496, test loss = 0.004829741477966308, test acc = 0.543


 16%|█▋        | 13/80 [00:06<00:33,  2.00it/s]

train loss = 0.004743961067199707, train acc = 0.54472, test loss = 0.004826525926589966, test acc = 0.548


 18%|█▊        | 14/80 [00:07<00:33,  2.00it/s]

train loss = 0.00473652509689331, train acc = 0.54479, test loss = 0.004821557521820069, test acc = 0.542


 19%|█▉        | 15/80 [00:07<00:32,  2.00it/s]

train loss = 0.004734155147075653, train acc = 0.54466, test loss = 0.0048184733390808105, test acc = 0.552


 20%|██        | 16/80 [00:08<00:32,  1.99it/s]

train loss = 0.0047331268835067745, train acc = 0.54481, test loss = 0.004814785003662109, test acc = 0.542


 21%|██▏       | 17/80 [00:08<00:31,  1.98it/s]

train loss = 0.004730062246322632, train acc = 0.54537, test loss = 0.004813156366348267, test acc = 0.541


 22%|██▎       | 18/80 [00:09<00:31,  1.99it/s]

train loss = 0.004727235403060913, train acc = 0.54536, test loss = 0.004814130067825317, test acc = 0.549


 24%|██▍       | 19/80 [00:09<00:31,  1.92it/s]

train loss = 0.0047252332210540775, train acc = 0.54565, test loss = 0.004811561107635498, test acc = 0.546


 25%|██▌       | 20/80 [00:10<00:31,  1.92it/s]

train loss = 0.004726328959465027, train acc = 0.54514, test loss = 0.004809439182281494, test acc = 0.545


 26%|██▋       | 21/80 [00:10<00:30,  1.92it/s]

train loss = 0.004722263314723968, train acc = 0.54551, test loss = 0.004808932304382324, test acc = 0.545


 28%|██▊       | 22/80 [00:11<00:30,  1.92it/s]

train loss = 0.004718759469985962, train acc = 0.54536, test loss = 0.0048100767135620115, test acc = 0.539


 29%|██▉       | 23/80 [00:11<00:29,  1.93it/s]

train loss = 0.004718432822227478, train acc = 0.54537, test loss = 0.004805418729782104, test acc = 0.542


 30%|███       | 24/80 [00:12<00:29,  1.93it/s]

train loss = 0.004718751995563507, train acc = 0.54525, test loss = 0.004804368019104004, test acc = 0.544


 31%|███▏      | 25/80 [00:12<00:29,  1.89it/s]

train loss = 0.0047163254022598265, train acc = 0.54528, test loss = 0.004806602001190185, test acc = 0.546


 32%|███▎      | 26/80 [00:13<00:28,  1.88it/s]

train loss = 0.004713949103355408, train acc = 0.54528, test loss = 0.00480536413192749, test acc = 0.545


 34%|███▍      | 27/80 [00:13<00:27,  1.91it/s]

train loss = 0.004715083923339844, train acc = 0.54497, test loss = 0.004804364442825317, test acc = 0.546


 35%|███▌      | 28/80 [00:14<00:27,  1.89it/s]

train loss = 0.00471404483795166, train acc = 0.54561, test loss = 0.004806877851486206, test acc = 0.535


 36%|███▋      | 29/80 [00:14<00:26,  1.92it/s]

train loss = 0.004714613139629364, train acc = 0.54514, test loss = 0.0048082141876220705, test acc = 0.545


 38%|███▊      | 30/80 [00:15<00:26,  1.92it/s]

train loss = 0.004713387308120727, train acc = 0.54569, test loss = 0.004804826736450195, test acc = 0.549


 39%|███▉      | 31/80 [00:15<00:25,  1.90it/s]

train loss = 0.0047115318274497985, train acc = 0.54543, test loss = 0.004804322481155396, test acc = 0.532


 40%|████      | 32/80 [00:16<00:25,  1.90it/s]

train loss = 0.004713640513420105, train acc = 0.54509, test loss = 0.004802637100219727, test acc = 0.543


 41%|████▏     | 33/80 [00:16<00:24,  1.92it/s]

train loss = 0.0047135748624801636, train acc = 0.54633, test loss = 0.004804875373840332, test acc = 0.534


 42%|████▎     | 34/80 [00:17<00:24,  1.90it/s]

train loss = 0.004712309989929199, train acc = 0.54586, test loss = 0.004802209377288818, test acc = 0.548


 44%|████▍     | 35/80 [00:18<00:23,  1.89it/s]

train loss = 0.004714313032627105, train acc = 0.54613, test loss = 0.004804498195648194, test acc = 0.544


 45%|████▌     | 36/80 [00:18<00:23,  1.88it/s]

train loss = 0.00471343878030777, train acc = 0.54517, test loss = 0.004801771402359009, test acc = 0.546


 46%|████▋     | 37/80 [00:19<00:22,  1.90it/s]

train loss = 0.0047126070833206175, train acc = 0.54599, test loss = 0.004803292274475098, test acc = 0.539


 48%|████▊     | 38/80 [00:19<00:22,  1.91it/s]

train loss = 0.004711179580688477, train acc = 0.54608, test loss = 0.004799712657928467, test acc = 0.548


 49%|████▉     | 39/80 [00:20<00:21,  1.92it/s]

train loss = 0.004710539424419403, train acc = 0.54573, test loss = 0.004802758932113647, test acc = 0.545


 50%|█████     | 40/80 [00:20<00:22,  1.78it/s]

train loss = 0.004711674180030823, train acc = 0.54495, test loss = 0.004801714658737182, test acc = 0.54


 51%|█████▏    | 41/80 [00:21<00:21,  1.81it/s]

train loss = 0.004712451531887055, train acc = 0.54594, test loss = 0.00479879641532898, test acc = 0.555


 52%|█████▎    | 42/80 [00:21<00:20,  1.86it/s]

train loss = 0.004710101890563965, train acc = 0.54631, test loss = 0.004802274703979492, test acc = 0.543


 54%|█████▍    | 43/80 [00:22<00:19,  1.87it/s]

train loss = 0.004711338384151459, train acc = 0.54598, test loss = 0.0048022673130035404, test acc = 0.546


 55%|█████▌    | 44/80 [00:22<00:18,  1.90it/s]

train loss = 0.004711864039897919, train acc = 0.54514, test loss = 0.004801601886749268, test acc = 0.546


 56%|█████▋    | 45/80 [00:23<00:18,  1.91it/s]

train loss = 0.004710354573726654, train acc = 0.54629, test loss = 0.004802903890609741, test acc = 0.549


 57%|█████▊    | 46/80 [00:23<00:18,  1.89it/s]

train loss = 0.004709758079051971, train acc = 0.54571, test loss = 0.004800906181335449, test acc = 0.547


 59%|█████▉    | 47/80 [00:24<00:17,  1.91it/s]

train loss = 0.004710544838905335, train acc = 0.54629, test loss = 0.004798802852630615, test acc = 0.55


 60%|██████    | 48/80 [00:24<00:16,  1.93it/s]

train loss = 0.004710602819919587, train acc = 0.5453, test loss = 0.004796726226806641, test acc = 0.557


 61%|██████▏   | 49/80 [00:25<00:16,  1.92it/s]

train loss = 0.00471005031824112, train acc = 0.54673, test loss = 0.004796028137207031, test acc = 0.556


 62%|██████▎   | 50/80 [00:25<00:15,  1.91it/s]

train loss = 0.004708787467479706, train acc = 0.54709, test loss = 0.004800733327865601, test acc = 0.54


 64%|██████▍   | 51/80 [00:26<00:15,  1.89it/s]

train loss = 0.004707887134552002, train acc = 0.54642, test loss = 0.004801681995391846, test acc = 0.553


 65%|██████▌   | 52/80 [00:27<00:14,  1.89it/s]

train loss = 0.004709899973869324, train acc = 0.54805, test loss = 0.004799724340438843, test acc = 0.549


 66%|██████▋   | 53/80 [00:27<00:14,  1.87it/s]

train loss = 0.00471101012468338, train acc = 0.54601, test loss = 0.004800884485244751, test acc = 0.545


 68%|██████▊   | 54/80 [00:28<00:14,  1.81it/s]

train loss = 0.004709787976741791, train acc = 0.54603, test loss = 0.0047986388206481935, test acc = 0.554


 69%|██████▉   | 55/80 [00:28<00:13,  1.82it/s]

train loss = 0.004709134802818298, train acc = 0.54631, test loss = 0.00479884934425354, test acc = 0.559


 70%|███████   | 56/80 [00:29<00:13,  1.83it/s]

train loss = 0.004712182238101959, train acc = 0.54635, test loss = 0.004801736354827881, test acc = 0.539


 71%|███████▏  | 57/80 [00:29<00:12,  1.84it/s]

train loss = 0.004711369254589081, train acc = 0.54688, test loss = 0.004799647808074951, test acc = 0.548


 72%|███████▎  | 58/80 [00:30<00:11,  1.87it/s]

train loss = 0.004710566287040711, train acc = 0.54634, test loss = 0.004798644542694091, test acc = 0.54


 74%|███████▍  | 59/80 [00:30<00:11,  1.88it/s]

train loss = 0.004711019304990768, train acc = 0.54705, test loss = 0.0047981483936309815, test acc = 0.547


 75%|███████▌  | 60/80 [00:31<00:10,  1.85it/s]

train loss = 0.004711514573097229, train acc = 0.54704, test loss = 0.004798704862594604, test acc = 0.549


 76%|███████▋  | 61/80 [00:31<00:10,  1.88it/s]

train loss = 0.004711718573570251, train acc = 0.54621, test loss = 0.004800042629241943, test acc = 0.552


 78%|███████▊  | 62/80 [00:32<00:09,  1.88it/s]

train loss = 0.004710539517402649, train acc = 0.54654, test loss = 0.004800405263900757, test acc = 0.544


 79%|███████▉  | 63/80 [00:32<00:08,  1.89it/s]

train loss = 0.004713464081287384, train acc = 0.54713, test loss = 0.004800037622451782, test acc = 0.539


 80%|████████  | 64/80 [00:33<00:08,  1.90it/s]

train loss = 0.004709660415649414, train acc = 0.54689, test loss = 0.004798319816589356, test acc = 0.549


 81%|████████▏ | 65/80 [00:34<00:07,  1.90it/s]

train loss = 0.004711520998477936, train acc = 0.54633, test loss = 0.004799001216888428, test acc = 0.549


 82%|████████▎ | 66/80 [00:34<00:07,  1.88it/s]

train loss = 0.0047108497309684754, train acc = 0.54595, test loss = 0.004798398494720459, test acc = 0.538


 84%|████████▍ | 67/80 [00:35<00:06,  1.87it/s]

train loss = 0.004709542245864868, train acc = 0.54692, test loss = 0.004800363540649414, test acc = 0.543


 85%|████████▌ | 68/80 [00:35<00:06,  1.84it/s]

train loss = 0.00471198972940445, train acc = 0.54639, test loss = 0.004797365188598633, test acc = 0.543


 86%|████████▋ | 69/80 [00:36<00:05,  1.86it/s]

train loss = 0.004709606640338898, train acc = 0.54638, test loss = 0.004798911809921265, test acc = 0.542


 88%|████████▊ | 70/80 [00:36<00:05,  1.87it/s]

train loss = 0.004709746415615082, train acc = 0.54631, test loss = 0.0047977535724639895, test acc = 0.544


 89%|████████▉ | 71/80 [00:37<00:04,  1.87it/s]

train loss = 0.004711196177005768, train acc = 0.54621, test loss = 0.004799327611923218, test acc = 0.552


 90%|█████████ | 72/80 [00:37<00:04,  1.88it/s]

train loss = 0.004709707458019257, train acc = 0.54641, test loss = 0.004796696424484253, test acc = 0.55


 91%|█████████▏| 73/80 [00:38<00:03,  1.90it/s]

train loss = 0.004710192782878876, train acc = 0.54635, test loss = 0.004798258543014526, test acc = 0.549


 92%|█████████▎| 74/80 [00:38<00:03,  1.89it/s]

train loss = 0.00470910561800003, train acc = 0.54664, test loss = 0.004798536062240601, test acc = 0.555


 94%|█████████▍| 75/80 [00:39<00:02,  1.90it/s]

train loss = 0.004711121730804443, train acc = 0.54661, test loss = 0.0048015379905700685, test acc = 0.542


 95%|█████████▌| 76/80 [00:39<00:02,  1.91it/s]

train loss = 0.004712517921924591, train acc = 0.54682, test loss = 0.0048001534938812256, test acc = 0.54


 96%|█████████▋| 77/80 [00:40<00:01,  1.89it/s]

train loss = 0.004708718173503876, train acc = 0.54666, test loss = 0.004799456596374512, test acc = 0.549


 98%|█████████▊| 78/80 [00:40<00:01,  1.90it/s]

train loss = 0.004709537277221679, train acc = 0.54702, test loss = 0.004798601388931275, test acc = 0.542


 99%|█████████▉| 79/80 [00:41<00:00,  1.91it/s]

train loss = 0.004712462744712829, train acc = 0.54706, test loss = 0.004800898313522339, test acc = 0.541


100%|██████████| 80/80 [00:41<00:00,  1.91it/s]

train loss = 0.004708308176994323, train acc = 0.54606, test loss = 0.004797579765319825, test acc = 0.541





In [19]:
# For each trained model (looked up in `models` by its head dimension d_head):
#   1. Direct logit attribution (DLA): for every two-token prompt [tok, 0] with tok in
#      PROBE_TOKENS, record each head's output mapped to the vocabulary at the final position.
#   2. Attention ranking: on the prompt [1, 2, 3, 4, 0], rank the earlier key positions by the
#      final query position's attention score, separately for each head.
D_HEAD_VALUES = [1, 2, 3, 4, 5]
PROBE_TOKENS = [1, 2, 3, 4]

with t.no_grad():  # pure analysis: no autograd graph is needed
    for dh in D_HEAD_VALUES:
        pre_trained = models[dh]
        # One row per probe token. Sizing this axis by n_ctx instead would leave extra
        # all-zero rows that do not correspond to the y-axis labels passed to imshow below.
        head_outputs = t.zeros((model_config.n_heads, len(PROBE_TOKENS), model_config.d_vocab))

        ## getting DLA from each token to each possible continuation
        for row, tok in enumerate(PROBE_TOKENS):
            X = t.tensor([[tok, 0]], dtype=t.int32)  # batch of one prompt: shape (1, 2)
            for head in range(model_config.n_heads):
                head_resid = get_head_to_resid(pre_trained, pre_trained.activation_cache, head, X)
                head_logits = get_head_to_logits(pre_trained, head_resid)
                # keep only the final position (the trigger token 0)
                head_outputs[head, row, :] = head_logits[0, -1, :]

        seq = [1, 2, 3, 4, 0]
        keys = seq[:-1]
        X = t.tensor([seq], dtype=t.int32)
        # Run the forward pass; the attention scores are read back from activation_cache below.
        logits = pre_trained(X)
        scores = pre_trained.activation_cache['hook_attn_scores'][0]

        ## rank key positions by the last query position's attention score, per head
        # each entry: (key position, key token, score), highest score first
        rankings = {
            head: sorted(
                zip(range(len(keys)), keys, scores[head, -1, :-1].detach()),
                key=lambda entry: entry[-1],
                reverse=True,
            )
            for head in range(model_config.n_heads)
        }

        print(f"Attention ranking for d_head = {dh} is {rankings}")
        imshow(head_outputs, facet_col = 0, y = PROBE_TOKENS, title = f"d_head = {dh}")

Attention ranking for d_head = 1 is {0: [(2, 3, tensor(28.1765)), (0, 1, tensor(18.2762)), (1, 2, tensor(10.2014)), (3, 4, tensor(4.1399))], 1: [(3, 4, tensor(12.3929)), (1, 2, tensor(4.8676)), (0, 1, tensor(3.6358)), (2, 3, tensor(-2.5423))]}


Attention ranking for d_head = 2 is {0: [(2, 3, tensor(13.2931)), (3, 4, tensor(11.7319)), (1, 2, tensor(11.1438)), (0, 1, tensor(-7.8753))], 1: [(1, 2, tensor(20.3196)), (0, 1, tensor(19.8794)), (3, 4, tensor(8.4864)), (2, 3, tensor(-18.3274))]}


Attention ranking for d_head = 3 is {0: [(3, 4, tensor(22.2124)), (0, 1, tensor(13.2219)), (2, 3, tensor(4.4880)), (1, 2, tensor(-2.8268))], 1: [(1, 2, tensor(17.8666)), (2, 3, tensor(16.2000)), (0, 1, tensor(7.6122)), (3, 4, tensor(-19.8617))]}


Attention ranking for d_head = 4 is {0: [(1, 2, tensor(26.4344)), (0, 1, tensor(16.0252)), (2, 3, tensor(4.7179)), (3, 4, tensor(-3.8639))], 1: [(3, 4, tensor(21.4538)), (2, 3, tensor(14.8827)), (0, 1, tensor(3.4966)), (1, 2, tensor(-3.7463))]}


Attention ranking for d_head = 5 is {0: [(0, 1, tensor(22.6485)), (2, 3, tensor(14.0099)), (1, 2, tensor(3.1352)), (3, 4, tensor(-6.6797))], 1: [(3, 4, tensor(23.3817)), (1, 2, tensor(14.6686)), (2, 3, tensor(6.0567)), (0, 1, tensor(-4.8255))]}


In [18]:
ind = 5  # which training run to inspect; presumably indexed like `models` (by d_head) — confirm
run_curves = (train_losses[ind], train_accs[ind], test_losses[ind], test_accs[ind])
# Final-epoch train loss, train acc, test loss, test acc, space-separated on one line.
print(*(curve[-1] for curve in run_curves))
plot_metrics(*run_curves)

0.004708308176994323 0.54606 0.004797579765319825 0.541
