In [17]:
import torch, math
torch.set_printoptions(precision=2, sci_mode=False, linewidth=200)
def generate_memory(n_tokens):
    # memory = torch.randint(0, n_tokens, (n_tokens, n_tokens))
    memory = torch.triu(torch.ones(n_tokens, n_tokens), diagonal=0)
    memory = memory * torch.randint_like(memory, 0, n_tokens)
    memory = memory + memory.T - torch.diag(memory.diagonal())
    memory = memory.long()
    return memory

generate_memory(10)

tensor([[7, 1, 4, 4, 5, 6, 0, 7, 6, 0],
        [1, 2, 8, 4, 0, 3, 8, 1, 9, 7],
        [4, 8, 3, 6, 2, 1, 6, 2, 3, 9],
        [4, 4, 6, 6, 1, 7, 4, 1, 3, 5],
        [5, 0, 2, 1, 3, 8, 4, 0, 0, 1],
        [6, 3, 1, 7, 8, 6, 7, 1, 3, 4],
        [0, 8, 6, 4, 4, 7, 2, 1, 4, 6],
        [7, 1, 2, 1, 0, 1, 1, 1, 7, 9],
        [6, 9, 3, 3, 0, 3, 4, 7, 8, 7],
        [0, 7, 9, 5, 1, 4, 6, 9, 7, 6]])

In [18]:
import torch
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_head, dim_feedforward, dropout=0.0):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward

        self.self_attn = nn.MultiheadAttention(d_model, n_head, batch_first=True, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, tgt, skip_feedforward=False, skip_self_attn=False, linear_mask=None, history=None, custom_attention=False, head_mask=None):
        hist = {}
        attn_heads = None
        if not skip_self_attn:
            # mask = torch.triu(torch.ones(tgt.shape[1], tgt.shape[1]), diagonal=1).bool().to(self.device)
            # tgt2, attn_heads = hist['self_attn_non_residual'] = self.self_attn(tgt, tgt, tgt, attn_mask=mask, average_attn_weights=False)
            tgt2 = hist['self_attn_non_residual'] = self.attn_forward(tgt, self.self_attn, custom_attention, hist, head_mask)
            tgt = hist['self_attn'] = tgt + tgt2
        tgt = hist['norm1'] = self.norm1(tgt)
        if self.dim_feedforward > 0 and not skip_feedforward:
            tgt2 = hist['linear1'] = nn.functional.relu(self.linear1(tgt))
            tgt2 = hist['linear1_dropout'] = self.dropout(tgt2)
            if linear_mask is not None:
                tgt2 = tgt2 * linear_mask
            tgt2 = hist['linear2_non_residual'] = self.linear2(tgt2)
            tgt = hist['linear2'] = tgt + tgt2
        tgt = hist['norm2'] = self.norm2(tgt)
        return tgt if history is None else hist[history], attn_heads
    
    def attn_forward(self, x, attn, custom_attention, history, head_mask=None):
        attn_mask = torch.tril(torch.ones(x.shape[1], x.shape[1]), diagonal=0).to(self.device)
        if not custom_attention:
            return attn(x, x, x, attn_mask=attn_mask)[0]

        batch_size = x.shape[0]
        x = x.transpose(0, 1)
        # this is just torch's attention but expanded so we can modify it
        proj = F.linear(x, attn.in_proj_weight, attn.in_proj_bias)
        proj = proj.unflatten(-1, (3, self.d_model)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = proj[0], proj[1], proj[2]
        q = q.unflatten(-1, (self.n_head, self.d_model // self.n_head)).permute(1, 2, 0, 3)
        k = k.unflatten(-1, (self.n_head, self.d_model // self.n_head)).permute(1, 2, 0, 3)
        v = v.unflatten(-1, (self.n_head, self.d_model // self.n_head)).permute(1, 2, 0, 3)
        
        history.update({
            'q': q,
            'k': k,
            'v': v
        })

        attn_mask = attn_mask.masked_fill(attn_mask == False, float('-inf'))
        
        attn_output = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1)) + attn_mask
        attn_output = F.softmax(attn_output, dim=-1)
        attn_output = torch.matmul(attn_output, v)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous()  # [seq_len, batch_size, n_head, d_model // n_head]
        # apply head mask
        if head_mask is not None:
            attn_output = attn_output * head_mask[None, None, :, None]
        
        history['attn_output'] = attn_output

        attn_output = attn_output.flatten(-2, -1)
        attn_output = F.linear(attn_output, attn.out_proj.weight, attn.out_proj.bias)
        return attn_output.transpose(0, 1)

class ToyTransformer(nn.Module):
    def __init__(self, n_layers, d_model, n_head, hidden_size, n_tokens, max_len, dropout=0.0):
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_head = n_head
        self.hidden_size = hidden_size
        self.tokens = list(range(n_tokens))
        self.max_len = max_len

        self.embed = nn.Embedding(n_tokens, embedding_dim=d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=hidden_size)
            for _ in range(n_layers)
        ])
        self.unembed = nn.Linear(d_model, n_tokens)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        skip_feedforward=False,
        skip_self_attn=False,
        return_before_embedding=False,
        linear_mask=None,
        history=None,
        return_attn_weights=False,
        custom_attention=False,
        head_mask=None):
        if head_mask is not None:
            custom_attention = True
        tgt = self.embed(x)
        tgt = F.pad(tgt, (0, 0, 0, 1))  # [batch_size, seq_len + 1, d_model]
        for layer in self.layers:
            tgt, attn_heads = layer(
                tgt,
                skip_feedforward=skip_feedforward,
                skip_self_attn=skip_self_attn,
                linear_mask=linear_mask,
                history=history,
                custom_attention=custom_attention,
                head_mask=head_mask)
            if history is not None:
                return tgt
        if return_before_embedding:
            return tgt
        x = self.unembed(tgt)
        if return_attn_weights:
            return x, attn_heads
        return x

    def train(self, memory, lr=1e-3, batch_size=128, n_epochs=1000):
        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        for _ in tqdm(range(n_epochs)):
            batch = self.generate_data(batch_size, memory)
            optimizer.zero_grad()
            input = batch[:, :-1]
            output = self(input)
            loss = criterion(output.reshape(-1, len(self.tokens)), batch.reshape(-1))

            loss.backward()
            optimizer.step()

        print('loss: ', loss.item())

    def generate_data(self, batch_size, memory):
        random_indices = torch.randint(0, n_tokens, (batch_size, 2))  # [batch_size, 2]
        next_tokens = memory[random_indices[:, 0], random_indices[:, 1]].unsqueeze(1)  # [batch_size, 1]
        tensor = torch.cat([random_indices, next_tokens], dim=1)
        return tensor.to(self.device)



In [22]:
hidden_size = 64
n_tokens = 40
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = ToyTransformer(n_layers=3, d_model=16, n_head=4, hidden_size=hidden_size, n_tokens=n_tokens, max_len=3, dropout=0.1).to(device)
memory = generate_memory(n_tokens)
history = []
while True:
    for epoch in range(999):
        model.train(memory, lr=1e-2, n_epochs=100, batch_size=144 * 2)
        samples = 1000
        data = model.generate_data(samples, memory)
        output = (model(data[:,:-1])[:,-1,:].argmax(dim=-1))
        accuracy = output.eq(data[:,-1]).sum().item() / samples
        print('Accuracy: ', accuracy)
        if accuracy == 1.0:
            print(f'Finished training in {epoch} epochs')
            history.append(epoch)
            memory = generate_memory(n_tokens)
            break


  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:01<00:00, 51.86it/s]


loss:  0.7550489902496338
Accuracy:  0.432


100%|██████████| 100/100 [00:01<00:00, 62.34it/s]


loss:  0.17218740284442902
Accuracy:  0.835


100%|██████████| 100/100 [00:01<00:00, 58.32it/s]


loss:  0.018484337255358696
Accuracy:  0.995


100%|██████████| 100/100 [00:01<00:00, 60.17it/s]


loss:  0.006250449456274509
Accuracy:  1.0
Finished training in 3 epochs


100%|██████████| 100/100 [00:01<00:00, 57.03it/s]


loss:  0.568324863910675
Accuracy:  0.557


100%|██████████| 100/100 [00:01<00:00, 53.86it/s]


loss:  0.19932609796524048
Accuracy:  0.847


100%|██████████| 100/100 [00:01<00:00, 58.89it/s]


loss:  0.03722058981657028
Accuracy:  0.97


100%|██████████| 100/100 [00:01<00:00, 57.88it/s]


loss:  0.007524285931140184
Accuracy:  0.998


100%|██████████| 100/100 [00:01<00:00, 59.43it/s]


loss:  0.006701682228595018
Accuracy:  1.0
Finished training in 4 epochs


100%|██████████| 100/100 [00:01<00:00, 55.46it/s]


loss:  0.5959761142730713
Accuracy:  0.538


100%|██████████| 100/100 [00:01<00:00, 63.22it/s]


loss:  0.27054154872894287
Accuracy:  0.823


100%|██████████| 100/100 [00:01<00:00, 62.48it/s]


loss:  0.17694872617721558
Accuracy:  0.85


100%|██████████| 100/100 [00:01<00:00, 62.60it/s]


loss:  0.015762673690915108
Accuracy:  0.991


100%|██████████| 100/100 [00:01<00:00, 62.55it/s]


loss:  0.011228998191654682
Accuracy:  1.0
Finished training in 4 epochs


100%|██████████| 100/100 [00:01<00:00, 57.82it/s]


loss:  0.6276482939720154
Accuracy:  0.487


100%|██████████| 100/100 [00:01<00:00, 59.05it/s]


loss:  0.3395613431930542
Accuracy:  0.688


100%|██████████| 100/100 [00:01<00:00, 58.93it/s]


loss:  0.20652669668197632
Accuracy:  0.815


100%|██████████| 100/100 [00:01<00:00, 59.87it/s]


loss:  0.17627695202827454
Accuracy:  0.877


100%|██████████| 100/100 [00:01<00:00, 59.03it/s]


loss:  0.1108958050608635
Accuracy:  0.94


100%|██████████| 100/100 [00:01<00:00, 50.71it/s]


loss:  0.006048696581274271
Accuracy:  0.996


100%|██████████| 100/100 [00:01<00:00, 60.27it/s]


loss:  0.0022060424089431763
Accuracy:  1.0
Finished training in 6 epochs


100%|██████████| 100/100 [00:01<00:00, 59.88it/s]


loss:  0.7497039437294006
Accuracy:  0.391


100%|██████████| 100/100 [00:01<00:00, 57.73it/s]


loss:  0.41760939359664917
Accuracy:  0.631


100%|██████████| 100/100 [00:01<00:00, 59.26it/s]


loss:  0.2117954045534134
Accuracy:  0.8


100%|██████████| 100/100 [00:01<00:00, 60.06it/s]


loss:  0.03847656771540642
Accuracy:  0.938


100%|██████████| 100/100 [00:01<00:00, 57.54it/s]


loss:  0.014851180836558342
Accuracy:  1.0
Finished training in 4 epochs


100%|██████████| 100/100 [00:01<00:00, 61.17it/s]


loss:  0.7556623816490173
Accuracy:  0.346


100%|██████████| 100/100 [00:01<00:00, 60.44it/s]


loss:  0.4236052930355072
Accuracy:  0.644


100%|██████████| 100/100 [00:01<00:00, 60.90it/s]


loss:  0.3614368736743927
Accuracy:  0.726


100%|██████████| 100/100 [00:01<00:00, 60.49it/s]


loss:  0.23052619397640228
Accuracy:  0.797


100%|██████████| 100/100 [00:01<00:00, 57.94it/s]


loss:  0.08825845271348953
Accuracy:  0.918


100%|██████████| 100/100 [00:01<00:00, 60.11it/s]


loss:  0.017529258504509926
Accuracy:  0.99


100%|██████████| 100/100 [00:01<00:00, 60.88it/s]


loss:  0.028869928792119026
Accuracy:  0.993


100%|██████████| 100/100 [00:01<00:00, 58.96it/s]


loss:  0.03825359418988228
Accuracy:  0.994


100%|██████████| 100/100 [00:01<00:00, 62.23it/s]


loss:  0.15833349525928497
Accuracy:  0.835


100%|██████████| 100/100 [00:01<00:00, 56.91it/s]


loss:  0.06636983901262283
Accuracy:  0.924


100%|██████████| 100/100 [00:01<00:00, 58.92it/s]


loss:  0.015254157595336437
Accuracy:  0.988


100%|██████████| 100/100 [00:01<00:00, 61.03it/s]


loss:  0.028926627710461617
Accuracy:  0.985


100%|██████████| 100/100 [00:01<00:00, 59.23it/s]


loss:  0.021795667707920074
Accuracy:  0.988


100%|██████████| 100/100 [00:01<00:00, 56.55it/s]


loss:  0.03588704392313957
Accuracy:  0.991


100%|██████████| 100/100 [00:01<00:00, 59.29it/s]


loss:  0.030179845169186592
Accuracy:  0.984


100%|██████████| 100/100 [00:01<00:00, 58.07it/s]


loss:  0.011445862241089344
Accuracy:  0.993


100%|██████████| 100/100 [00:01<00:00, 60.08it/s]


loss:  0.0028058465104550123
Accuracy:  1.0
Finished training in 16 epochs


100%|██████████| 100/100 [00:01<00:00, 59.30it/s]


loss:  0.8884651064872742
Accuracy:  0.3


100%|██████████| 100/100 [00:01<00:00, 61.47it/s]


loss:  0.4341450333595276
Accuracy:  0.56


100%|██████████| 100/100 [00:01<00:00, 55.20it/s]


loss:  0.3477916717529297
Accuracy:  0.718


100%|██████████| 100/100 [00:01<00:00, 59.99it/s]


loss:  0.26819872856140137
Accuracy:  0.778


100%|██████████| 100/100 [00:01<00:00, 59.33it/s]


loss:  0.10967680811882019
Accuracy:  0.938


100%|██████████| 100/100 [00:01<00:00, 58.99it/s]


loss:  0.08830738067626953
Accuracy:  0.931


100%|██████████| 100/100 [00:01<00:00, 59.66it/s]


loss:  0.010166054591536522
Accuracy:  0.993


100%|██████████| 100/100 [00:01<00:00, 55.91it/s]


loss:  0.020569054409861565
Accuracy:  0.996


100%|██████████| 100/100 [00:01<00:00, 60.19it/s]


loss:  0.0139619754627347
Accuracy:  0.992


100%|██████████| 100/100 [00:01<00:00, 60.54it/s]


loss:  0.09340084344148636
Accuracy:  0.984


100%|██████████| 100/100 [00:01<00:00, 59.20it/s]


loss:  0.004948828835040331
Accuracy:  1.0
Finished training in 10 epochs


100%|██████████| 100/100 [00:01<00:00, 58.92it/s]


loss:  0.8927263617515564
Accuracy:  0.248


100%|██████████| 100/100 [00:01<00:00, 57.21it/s]


loss:  0.5418843626976013
Accuracy:  0.555


100%|██████████| 100/100 [00:01<00:00, 60.88it/s]


loss:  0.3815271556377411
Accuracy:  0.713


100%|██████████| 100/100 [00:01<00:00, 57.91it/s]


loss:  0.27326664328575134
Accuracy:  0.753


100%|██████████| 100/100 [00:01<00:00, 60.45it/s]


loss:  0.19977468252182007
Accuracy:  0.837


100%|██████████| 100/100 [00:01<00:00, 58.98it/s]


loss:  0.09751518815755844
Accuracy:  0.938


100%|██████████| 100/100 [00:01<00:00, 58.90it/s]


loss:  0.03194683790206909
Accuracy:  0.983


100%|██████████| 100/100 [00:01<00:00, 56.78it/s]


loss:  0.019657084718346596
Accuracy:  0.99


100%|██████████| 100/100 [00:01<00:00, 59.77it/s]


loss:  0.0779702216386795
Accuracy:  0.952


100%|██████████| 100/100 [00:01<00:00, 59.05it/s]


loss:  0.011208717711269855
Accuracy:  0.994


100%|██████████| 100/100 [00:01<00:00, 60.47it/s]


loss:  0.007801429368555546
Accuracy:  0.994


100%|██████████| 100/100 [00:01<00:00, 58.76it/s]


loss:  0.003618016839027405
Accuracy:  1.0
Finished training in 11 epochs


100%|██████████| 100/100 [00:01<00:00, 54.81it/s]


loss:  0.9475893378257751
Accuracy:  0.211


100%|██████████| 100/100 [00:01<00:00, 59.24it/s]


loss:  0.6699199676513672
Accuracy:  0.418


100%|██████████| 100/100 [00:01<00:00, 58.28it/s]


loss:  0.408300518989563
Accuracy:  0.615


100%|██████████| 100/100 [00:01<00:00, 53.09it/s]


loss:  0.37515997886657715
Accuracy:  0.696


100%|██████████| 100/100 [00:01<00:00, 58.60it/s]


loss:  0.22867900133132935
Accuracy:  0.788


100%|██████████| 100/100 [00:01<00:00, 58.35it/s]


loss:  0.15188048779964447
Accuracy:  0.874


100%|██████████| 100/100 [00:01<00:00, 60.16it/s]


loss:  0.04624839127063751
Accuracy:  0.963


100%|██████████| 100/100 [00:01<00:00, 60.22it/s]


loss:  0.10596963763237
Accuracy:  0.882


100%|██████████| 100/100 [00:01<00:00, 56.28it/s]


loss:  0.23384657502174377
Accuracy:  0.812


100%|██████████| 100/100 [00:01<00:00, 60.50it/s]


loss:  0.005777914542704821
Accuracy:  0.997


100%|██████████| 100/100 [00:01<00:00, 59.14it/s]


loss:  0.11040297895669937
Accuracy:  0.903


100%|██████████| 100/100 [00:01<00:00, 59.96it/s]


loss:  0.049073949456214905
Accuracy:  0.975


100%|██████████| 100/100 [00:01<00:00, 58.89it/s]


loss:  0.011504173278808594
Accuracy:  1.0
Finished training in 12 epochs


100%|██████████| 100/100 [00:01<00:00, 58.23it/s]


loss:  0.9492690563201904
Accuracy:  0.204


100%|██████████| 100/100 [00:01<00:00, 58.23it/s]


loss:  0.6703513264656067
Accuracy:  0.412


100%|██████████| 100/100 [00:01<00:00, 60.73it/s]


loss:  0.43237805366516113
Accuracy:  0.627


100%|██████████| 100/100 [00:01<00:00, 60.67it/s]


loss:  0.35912609100341797
Accuracy:  0.693


100%|██████████| 100/100 [00:01<00:00, 60.44it/s]


loss:  0.3016383647918701
Accuracy:  0.709


100%|██████████| 100/100 [00:01<00:00, 59.32it/s]


loss:  0.15507207810878754
Accuracy:  0.876


100%|██████████| 100/100 [00:01<00:00, 56.38it/s]


loss:  0.09348248690366745
Accuracy:  0.885


100%|██████████| 100/100 [00:01<00:00, 59.54it/s]


loss:  0.0881350040435791
Accuracy:  0.925


100%|██████████| 100/100 [00:01<00:00, 57.91it/s]


loss:  0.09372253715991974
Accuracy:  0.929


100%|██████████| 100/100 [00:01<00:00, 59.58it/s]


loss:  0.1118120476603508
Accuracy:  0.888


100%|██████████| 100/100 [00:01<00:00, 60.66it/s]


loss:  0.018474487587809563
Accuracy:  0.992


100%|██████████| 100/100 [00:01<00:00, 55.14it/s]


loss:  0.06124885380268097
Accuracy:  0.967


100%|██████████| 100/100 [00:01<00:00, 58.75it/s]


loss:  0.051035597920417786
Accuracy:  0.952


100%|██████████| 100/100 [00:01<00:00, 60.93it/s]


loss:  0.038973063230514526
Accuracy:  0.979


100%|██████████| 100/100 [00:01<00:00, 60.21it/s]


loss:  0.06234314292669296
Accuracy:  0.948


100%|██████████| 100/100 [00:01<00:00, 60.49it/s]


loss:  0.024198006838560104
Accuracy:  0.979


100%|██████████| 100/100 [00:01<00:00, 57.46it/s]


loss:  0.06348630040884018
Accuracy:  0.957


100%|██████████| 100/100 [00:01<00:00, 60.67it/s]


loss:  0.010188550688326359
Accuracy:  0.992


100%|██████████| 100/100 [00:01<00:00, 58.17it/s]


loss:  0.07256989181041718
Accuracy:  0.942


100%|██████████| 100/100 [00:01<00:00, 59.99it/s]


loss:  0.016146676614880562
Accuracy:  0.997


100%|██████████| 100/100 [00:01<00:00, 59.19it/s]


loss:  0.06761685013771057
Accuracy:  0.935


100%|██████████| 100/100 [00:01<00:00, 60.32it/s]


loss:  0.011999647133052349
Accuracy:  0.986


100%|██████████| 100/100 [00:01<00:00, 54.16it/s]


loss:  0.06639953702688217
Accuracy:  0.948


100%|██████████| 100/100 [00:01<00:00, 58.59it/s]


loss:  0.04279300943017006
Accuracy:  0.974


100%|██████████| 100/100 [00:01<00:00, 58.90it/s]


loss:  0.15423119068145752
Accuracy:  0.867


100%|██████████| 100/100 [00:01<00:00, 57.98it/s]


loss:  0.18580657243728638
Accuracy:  0.879


100%|██████████| 100/100 [00:01<00:00, 59.61it/s]


loss:  0.021303752437233925
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 55.38it/s]


loss:  0.010802561417222023
Accuracy:  0.995


100%|██████████| 100/100 [00:01<00:00, 59.01it/s]


loss:  0.10519900172948837
Accuracy:  0.941


100%|██████████| 100/100 [00:01<00:00, 59.61it/s]


loss:  0.004185507073998451
Accuracy:  1.0
Finished training in 29 epochs


100%|██████████| 100/100 [00:01<00:00, 57.55it/s]


loss:  0.9441618323326111
Accuracy:  0.258


100%|██████████| 100/100 [00:01<00:00, 60.31it/s]


loss:  0.6155077815055847
Accuracy:  0.441


100%|██████████| 100/100 [00:01<00:00, 55.30it/s]


loss:  0.4729582369327545
Accuracy:  0.63


100%|██████████| 100/100 [00:01<00:00, 58.92it/s]


loss:  0.3310343027114868
Accuracy:  0.717


100%|██████████| 100/100 [00:01<00:00, 58.55it/s]


loss:  0.2181536704301834
Accuracy:  0.809


100%|██████████| 100/100 [00:01<00:00, 59.52it/s]


loss:  0.24477319419384003
Accuracy:  0.795


100%|██████████| 100/100 [00:01<00:00, 58.71it/s]


loss:  0.12922273576259613
Accuracy:  0.906


100%|██████████| 100/100 [00:01<00:00, 54.59it/s]


loss:  0.03358851373195648
Accuracy:  0.98


100%|██████████| 100/100 [00:01<00:00, 59.49it/s]


loss:  0.0520222969353199
Accuracy:  0.976


100%|██████████| 100/100 [00:01<00:00, 58.42it/s]


loss:  0.17950445413589478
Accuracy:  0.854


100%|██████████| 100/100 [00:01<00:00, 57.53it/s]


loss:  0.110144704580307
Accuracy:  0.887


100%|██████████| 100/100 [00:01<00:00, 58.71it/s]


loss:  0.05640792101621628
Accuracy:  0.949


100%|██████████| 100/100 [00:01<00:00, 57.49it/s]


loss:  0.012376599945127964
Accuracy:  1.0
Finished training in 12 epochs


100%|██████████| 100/100 [00:01<00:00, 59.24it/s]


loss:  1.0298993587493896
Accuracy:  0.157


100%|██████████| 100/100 [00:01<00:00, 58.89it/s]


loss:  0.7470618486404419
Accuracy:  0.338


100%|██████████| 100/100 [00:01<00:00, 59.04it/s]


loss:  0.567891001701355
Accuracy:  0.488


100%|██████████| 100/100 [00:01<00:00, 59.81it/s]


loss:  0.38648611307144165
Accuracy:  0.588


100%|██████████| 100/100 [00:01<00:00, 58.95it/s]


loss:  0.37561485171318054
Accuracy:  0.643


100%|██████████| 100/100 [00:01<00:00, 54.77it/s]


loss:  0.3339044749736786
Accuracy:  0.761


100%|██████████| 100/100 [00:01<00:00, 57.93it/s]


loss:  0.18093182146549225
Accuracy:  0.787


100%|██████████| 100/100 [00:01<00:00, 56.35it/s]


loss:  0.2734350860118866
Accuracy:  0.752


100%|██████████| 100/100 [00:01<00:00, 58.55it/s]


loss:  0.18346373736858368
Accuracy:  0.863


100%|██████████| 100/100 [00:01<00:00, 58.20it/s]


loss:  0.24843864142894745
Accuracy:  0.79


100%|██████████| 100/100 [00:01<00:00, 55.62it/s]


loss:  0.11047355085611343
Accuracy:  0.951


100%|██████████| 100/100 [00:01<00:00, 57.87it/s]


loss:  0.03044791705906391
Accuracy:  0.962


100%|██████████| 100/100 [00:01<00:00, 58.47it/s]


loss:  0.018541505560278893
Accuracy:  0.998


100%|██████████| 100/100 [00:01<00:00, 59.08it/s]


loss:  0.11634504050016403
Accuracy:  0.895


100%|██████████| 100/100 [00:01<00:00, 57.78it/s]


loss:  0.10811571031808853
Accuracy:  0.914


100%|██████████| 100/100 [00:01<00:00, 54.13it/s]


loss:  0.03216610848903656
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 59.53it/s]


loss:  0.014625008217990398
Accuracy:  0.996


100%|██████████| 100/100 [00:01<00:00, 55.68it/s]


loss:  0.14868122339248657
Accuracy:  0.913


100%|██████████| 100/100 [00:01<00:00, 58.37it/s]


loss:  0.005856305826455355
Accuracy:  1.0
Finished training in 18 epochs


100%|██████████| 100/100 [00:01<00:00, 58.16it/s]


loss:  1.0050207376480103
Accuracy:  0.199


100%|██████████| 100/100 [00:01<00:00, 58.35it/s]


loss:  0.7699189186096191
Accuracy:  0.359


100%|██████████| 100/100 [00:01<00:00, 55.20it/s]


loss:  0.6182875633239746
Accuracy:  0.472


100%|██████████| 100/100 [00:01<00:00, 58.21it/s]


loss:  0.5025714635848999
Accuracy:  0.594


100%|██████████| 100/100 [00:01<00:00, 57.75it/s]


loss:  0.4351215064525604
Accuracy:  0.609


100%|██████████| 100/100 [00:01<00:00, 58.58it/s]


loss:  0.38340169191360474
Accuracy:  0.696


100%|██████████| 100/100 [00:01<00:00, 58.92it/s]


loss:  0.2675374746322632
Accuracy:  0.76


100%|██████████| 100/100 [00:01<00:00, 55.75it/s]


loss:  0.19008533656597137
Accuracy:  0.86


100%|██████████| 100/100 [00:01<00:00, 57.89it/s]


loss:  0.16209784150123596
Accuracy:  0.856


100%|██████████| 100/100 [00:01<00:00, 57.09it/s]


loss:  0.16673551499843597
Accuracy:  0.808


100%|██████████| 100/100 [00:01<00:00, 58.42it/s]


loss:  0.15191620588302612
Accuracy:  0.879


100%|██████████| 100/100 [00:01<00:00, 57.17it/s]


loss:  0.08855611085891724
Accuracy:  0.938


100%|██████████| 100/100 [00:01<00:00, 54.44it/s]


loss:  0.15069933235645294
Accuracy:  0.855


100%|██████████| 100/100 [00:01<00:00, 57.85it/s]


loss:  0.04194381833076477
Accuracy:  0.982


100%|██████████| 100/100 [00:01<00:00, 56.01it/s]


loss:  0.05774220451712608
Accuracy:  0.959


100%|██████████| 100/100 [00:01<00:00, 58.04it/s]


loss:  0.0865524485707283
Accuracy:  0.943


100%|██████████| 100/100 [00:01<00:00, 55.68it/s]


loss:  0.08419442176818848
Accuracy:  0.913


100%|██████████| 100/100 [00:01<00:00, 53.98it/s]


loss:  0.05881360545754433
Accuracy:  0.957


100%|██████████| 100/100 [00:01<00:00, 56.35it/s]


loss:  0.07151926308870316
Accuracy:  0.96


100%|██████████| 100/100 [00:01<00:00, 56.96it/s]


loss:  0.015408607199788094
Accuracy:  0.998


100%|██████████| 100/100 [00:01<00:00, 56.86it/s]


loss:  0.05365525186061859
Accuracy:  0.973


100%|██████████| 100/100 [00:01<00:00, 57.85it/s]


loss:  0.05464835837483406
Accuracy:  0.962


100%|██████████| 100/100 [00:01<00:00, 52.79it/s]


loss:  0.03918624296784401
Accuracy:  0.978


100%|██████████| 100/100 [00:01<00:00, 57.46it/s]


loss:  0.016877420246601105
Accuracy:  0.995


100%|██████████| 100/100 [00:01<00:00, 55.92it/s]


loss:  0.027585327625274658
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 56.54it/s]


loss:  0.09705983102321625
Accuracy:  0.913


100%|██████████| 100/100 [00:01<00:00, 56.76it/s]


loss:  0.011716087348759174
Accuracy:  0.997


100%|██████████| 100/100 [00:01<00:00, 50.76it/s]


loss:  0.06486054509878159
Accuracy:  0.967


100%|██████████| 100/100 [00:01<00:00, 57.72it/s]


loss:  0.050756555050611496
Accuracy:  0.963


100%|██████████| 100/100 [00:01<00:00, 53.81it/s]


loss:  0.042935553938150406
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 56.94it/s]


loss:  0.01767190732061863
Accuracy:  0.994


100%|██████████| 100/100 [00:01<00:00, 58.86it/s]


loss:  0.027548765763640404
Accuracy:  0.984


100%|██████████| 100/100 [00:01<00:00, 55.05it/s]


loss:  0.2555031478404999
Accuracy:  0.871


100%|██████████| 100/100 [00:01<00:00, 57.44it/s]


loss:  0.059887032955884933
Accuracy:  0.948


100%|██████████| 100/100 [00:01<00:00, 53.42it/s]


loss:  0.17414140701293945
Accuracy:  0.919


100%|██████████| 100/100 [00:01<00:00, 57.06it/s]


loss:  0.06496204435825348
Accuracy:  0.935


100%|██████████| 100/100 [00:01<00:00, 56.88it/s]


loss:  0.03593578562140465
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 57.79it/s]


loss:  0.05992095172405243
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 57.13it/s]


loss:  0.005086713936179876
Accuracy:  1.0
Finished training in 38 epochs


100%|██████████| 100/100 [00:01<00:00, 52.01it/s]


loss:  1.0619113445281982
Accuracy:  0.15


100%|██████████| 100/100 [00:01<00:00, 55.60it/s]


loss:  0.8790237903594971
Accuracy:  0.272


100%|██████████| 100/100 [00:01<00:00, 56.04it/s]


loss:  0.7161864042282104
Accuracy:  0.433


100%|██████████| 100/100 [00:01<00:00, 57.25it/s]


loss:  0.42726433277130127
Accuracy:  0.564


100%|██████████| 100/100 [00:01<00:00, 57.51it/s]


loss:  0.5599825978279114
Accuracy:  0.536


100%|██████████| 100/100 [00:01<00:00, 52.67it/s]


loss:  0.43185317516326904
Accuracy:  0.663


100%|██████████| 100/100 [00:01<00:00, 57.36it/s]


loss:  0.28700995445251465
Accuracy:  0.768


100%|██████████| 100/100 [00:01<00:00, 56.45it/s]


loss:  0.19789698719978333
Accuracy:  0.828


100%|██████████| 100/100 [00:01<00:00, 57.71it/s]


loss:  0.28754517436027527
Accuracy:  0.741


100%|██████████| 100/100 [00:01<00:00, 55.07it/s]


loss:  0.21490763127803802
Accuracy:  0.842


100%|██████████| 100/100 [00:01<00:00, 52.76it/s]


loss:  0.283600777387619
Accuracy:  0.729


100%|██████████| 100/100 [00:01<00:00, 55.93it/s]


loss:  0.125719353556633
Accuracy:  0.889


100%|██████████| 100/100 [00:01<00:00, 55.60it/s]


loss:  0.0885428711771965
Accuracy:  0.929


100%|██████████| 100/100 [00:01<00:00, 56.61it/s]


loss:  0.09153762459754944
Accuracy:  0.933


100%|██████████| 100/100 [00:01<00:00, 56.66it/s]


loss:  0.10884925723075867
Accuracy:  0.935


100%|██████████| 100/100 [00:01<00:00, 51.57it/s]


loss:  0.07355565577745438
Accuracy:  0.95


100%|██████████| 100/100 [00:01<00:00, 54.82it/s]


loss:  0.08749043196439743
Accuracy:  0.929


100%|██████████| 100/100 [00:01<00:00, 56.85it/s]


loss:  0.12402376532554626
Accuracy:  0.878


100%|██████████| 100/100 [00:01<00:00, 57.28it/s]


loss:  0.16448701918125153
Accuracy:  0.878


100%|██████████| 100/100 [00:01<00:00, 56.66it/s]


loss:  0.09785189479589462
Accuracy:  0.966


100%|██████████| 100/100 [00:01<00:00, 51.86it/s]


loss:  0.12377499788999557
Accuracy:  0.837


100%|██████████| 100/100 [00:01<00:00, 55.67it/s]


loss:  0.13746052980422974
Accuracy:  0.9


100%|██████████| 100/100 [00:01<00:00, 57.69it/s]


loss:  0.07727889716625214
Accuracy:  0.951


100%|██████████| 100/100 [00:01<00:00, 54.43it/s]


loss:  0.09704781323671341
Accuracy:  0.921


100%|██████████| 100/100 [00:01<00:00, 56.33it/s]


loss:  0.07887343317270279
Accuracy:  0.95


100%|██████████| 100/100 [00:01<00:00, 51.81it/s]


loss:  0.12545157968997955
Accuracy:  0.882


100%|██████████| 100/100 [00:01<00:00, 54.96it/s]


loss:  0.24134041368961334
Accuracy:  0.851


100%|██████████| 100/100 [00:01<00:00, 56.03it/s]


loss:  0.04392506182193756
Accuracy:  0.979


100%|██████████| 100/100 [00:01<00:00, 54.86it/s]


loss:  0.0668271854519844
Accuracy:  0.966


100%|██████████| 100/100 [00:01<00:00, 56.99it/s]


loss:  0.023951910436153412
Accuracy:  0.988


100%|██████████| 100/100 [00:01<00:00, 50.27it/s]


loss:  0.0887843444943428
Accuracy:  0.939


100%|██████████| 100/100 [00:01<00:00, 56.82it/s]


loss:  0.0908372700214386
Accuracy:  0.908


100%|██████████| 100/100 [00:01<00:00, 57.65it/s]


loss:  0.019724490121006966
Accuracy:  0.991


100%|██████████| 100/100 [00:01<00:00, 53.76it/s]


loss:  0.07482963055372238
Accuracy:  0.94


100%|██████████| 100/100 [00:01<00:00, 56.86it/s]


loss:  0.09423121064901352
Accuracy:  0.931


100%|██████████| 100/100 [00:01<00:00, 50.06it/s]


loss:  0.01091679371893406
Accuracy:  0.998


100%|██████████| 100/100 [00:01<00:00, 56.21it/s]


loss:  0.0596083365380764
Accuracy:  0.973


100%|██████████| 100/100 [00:01<00:00, 53.90it/s]


loss:  0.11093328893184662
Accuracy:  0.933


100%|██████████| 100/100 [00:01<00:00, 55.45it/s]


loss:  0.09027199447154999
Accuracy:  0.924


100%|██████████| 100/100 [00:01<00:00, 56.78it/s]


loss:  0.03362278640270233
Accuracy:  0.975


100%|██████████| 100/100 [00:01<00:00, 51.72it/s]


loss:  0.06355036795139313
Accuracy:  0.952


100%|██████████| 100/100 [00:01<00:00, 55.79it/s]


loss:  0.117958202958107
Accuracy:  0.933


100%|██████████| 100/100 [00:01<00:00, 56.04it/s]


loss:  0.022203080356121063
Accuracy:  0.991


100%|██████████| 100/100 [00:01<00:00, 56.72it/s]


loss:  0.16548967361450195
Accuracy:  0.854


100%|██████████| 100/100 [00:01<00:00, 52.95it/s]


loss:  0.03698885068297386
Accuracy:  0.967


100%|██████████| 100/100 [00:01<00:00, 53.91it/s]


loss:  0.07994400709867477
Accuracy:  0.971


100%|██████████| 100/100 [00:01<00:00, 56.19it/s]


loss:  0.04382111504673958
Accuracy:  0.971


100%|██████████| 100/100 [00:01<00:00, 55.50it/s]


loss:  0.1458296924829483
Accuracy:  0.892


100%|██████████| 100/100 [00:01<00:00, 56.34it/s]


loss:  0.04375787451863289
Accuracy:  0.956


100%|██████████| 100/100 [00:01<00:00, 55.70it/s]


loss:  0.3665728271007538
Accuracy:  0.714


100%|██████████| 100/100 [00:01<00:00, 51.88it/s]


loss:  0.050591275095939636
Accuracy:  0.98


100%|██████████| 100/100 [00:01<00:00, 54.50it/s]


loss:  0.027746669948101044
Accuracy:  0.989


100%|██████████| 100/100 [00:01<00:00, 55.19it/s]


loss:  0.14015258848667145
Accuracy:  0.933


100%|██████████| 100/100 [00:01<00:00, 56.18it/s]


loss:  0.014453871175646782
Accuracy:  0.996


100%|██████████| 100/100 [00:01<00:00, 54.82it/s]


loss:  0.06906508654356003
Accuracy:  0.942


100%|██████████| 100/100 [00:01<00:00, 52.18it/s]


loss:  0.014698801562190056
Accuracy:  0.995


100%|██████████| 100/100 [00:01<00:00, 55.85it/s]


loss:  0.23676586151123047
Accuracy:  0.87


100%|██████████| 100/100 [00:01<00:00, 53.52it/s]


loss:  0.03496338427066803
Accuracy:  0.979


100%|██████████| 100/100 [00:01<00:00, 55.64it/s]


loss:  0.02677331306040287
Accuracy:  0.994


100%|██████████| 100/100 [00:01<00:00, 56.73it/s]


loss:  0.051974620670080185
Accuracy:  0.974


100%|██████████| 100/100 [00:01<00:00, 50.44it/s]


loss:  0.05899868160486221
Accuracy:  0.949


100%|██████████| 100/100 [00:01<00:00, 54.85it/s]


loss:  0.042401790618896484
Accuracy:  0.977


100%|██████████| 100/100 [00:01<00:00, 55.48it/s]


loss:  0.05502259358763695
Accuracy:  0.943


100%|██████████| 100/100 [00:01<00:00, 54.92it/s]


loss:  0.02477528713643551
Accuracy:  0.979


100%|██████████| 100/100 [00:01<00:00, 55.63it/s]


loss:  0.019096743315458298
Accuracy:  0.985


100%|██████████| 100/100 [00:01<00:00, 51.17it/s]


loss:  0.1200459823012352
Accuracy:  0.916


100%|██████████| 100/100 [00:01<00:00, 54.44it/s]


loss:  0.05788100138306618
Accuracy:  0.975


100%|██████████| 100/100 [00:01<00:00, 55.65it/s]


loss:  0.0191783644258976
Accuracy:  0.989


100%|██████████| 100/100 [00:01<00:00, 56.04it/s]


loss:  0.016003292053937912
Accuracy:  0.995


100%|██████████| 100/100 [00:01<00:00, 53.24it/s]


loss:  0.04699765145778656
Accuracy:  0.943


100%|██████████| 100/100 [00:01<00:00, 54.38it/s]


loss:  0.02288232557475567
Accuracy:  0.979


100%|██████████| 100/100 [00:01<00:00, 51.88it/s]


loss:  0.02526906318962574
Accuracy:  0.981


100%|██████████| 100/100 [00:01<00:00, 53.68it/s]


loss:  0.1042911559343338
Accuracy:  0.946


100%|██████████| 100/100 [00:01<00:00, 51.01it/s]


loss:  0.015517804771661758
Accuracy:  0.995


100%|██████████| 100/100 [00:01<00:00, 57.61it/s]


loss:  0.06965174525976181
Accuracy:  0.983


100%|██████████| 100/100 [00:01<00:00, 50.51it/s]


loss:  0.06634822487831116
Accuracy:  0.959


100%|██████████| 100/100 [00:01<00:00, 52.10it/s]


loss:  0.022469155490398407
Accuracy:  0.975


 33%|███▎      | 33/100 [00:00<00:01, 55.75it/s]


KeyboardInterrupt: 

In [23]:
history

[3, 4, 4, 6, 4, 16, 10, 11, 12, 29, 12, 18, 38]