[Link to experiment 2](#) <!-- TODO: the link target was empty; add the URL of experiment 2 -->

In [None]:
import torch

# Seed explicitly so the random lookup table (and every later torch draw) is reproducible
# across "Restart Kernel -> Run All"; without it each run trains on a different task.
SEED = 0
torch.manual_seed(SEED)

n_tokens = 21

# Random lookup table: each training sequence is [BOS, a, b, memory[a, b]], so memory[a, b]
# is the token the model must recall after seeing keys a and b.
# torch.randint's upper bound is exclusive, so both the table entries and the keys lie in
# [0, n_tokens - 2]; token id n_tokens - 1 never occurs in the data.
# NOTE(review): token 0 is also used as the BOS token when sampling sequences, so it
# collides with key/value 0 -- confirm this is intended.
memory = torch.randint(0, n_tokens - 1, (n_tokens - 1, n_tokens - 1))

memory

tensor([[ 4,  4, 10, 13, 11,  3,  9,  0, 17, 13,  6, 10,  3,  8,  4, 16,  9, 11,
         17, 10],
        [ 0,  7, 18, 18,  4,  1,  8, 13, 19,  6, 10, 12, 13,  2,  8, 10,  6, 18,
          2,  6],
        [ 4, 10, 10, 19, 17,  2, 16,  6, 17, 12, 16, 18,  4, 10,  1, 15,  5, 17,
          7, 12],
        [10,  3, 12,  6,  3, 18, 10,  9,  5,  6, 18, 10, 12, 16, 11,  6,  4, 15,
          2,  7],
        [16,  7,  2,  0,  7,  0,  9,  6, 12, 11, 18, 11, 19, 14, 11,  1, 15,  6,
          8, 18],
        [18,  9, 18,  4, 18, 10,  4, 12, 17,  8,  3, 12, 19,  3, 19, 17, 12, 13,
         18, 14],
        [ 6, 17, 19, 11,  9,  0, 17, 10,  8,  7, 19,  7,  2,  1, 12, 18, 15, 14,
          3,  1],
        [ 0, 10,  7, 13, 19, 16, 19, 19,  6, 13,  1,  7,  8, 16, 10, 15, 11, 14,
         11, 11],
        [ 2,  8, 15,  4, 19,  3,  9,  6, 14,  5,  1, 19, 15, 19,  6, 17,  2,  6,
          6, 15],
        [17, 10, 10, 18,  2,  9,  7,  5,  7,  9, 18,  2,  7, 17, 17, 18, 16, 14,
         17,  6],
        [ 

In [None]:
import torch
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


class DecoderLayer(nn.Module):
    """Single causal self-attention block with an optional ReLU feed-forward sublayer.

    Post-norm layout:
        x -> norm1(x + SelfAttn(x)) -> norm2(x + FF(x))
    When ``dim_feedforward`` is 0 the feed-forward sublayer is skipped, but ``norm2``
    is still applied.

    Args:
        d_model: width of the input/output embeddings.
        n_head: number of attention heads (nn.MultiheadAttention requires it to divide d_model).
        dim_feedforward: hidden width of the feed-forward sublayer; 0 disables it.
    """

    def __init__(self, d_model, n_head, dim_feedforward):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward

        self.self_attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        # With dim_feedforward == 0 these are zero-width layers that forward() never calls;
        # they are still created so the parameter/state_dict layout stays the same.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
        """Apply causal self-attention (+ optional feed-forward) to ``tgt``.

        Args:
            tgt: batched input of shape (batch, seq_len, d_model) (attention uses batch_first=True).

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        seq_len = tgt.shape[1]
        # True above the diagonal = "may not attend": position i only sees positions <= i.
        # Built on tgt's device (not a hard-coded .cuda()) so the layer also runs on CPU
        # or on whichever GPU the model lives on.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=tgt.device), diagonal=1).bool()
        attn_out = self.self_attn(tgt, tgt, tgt, attn_mask=mask)[0]
        tgt = self.norm1(tgt + attn_out)
        if self.dim_feedforward > 0:
            ff_out = self.linear2(nn.functional.relu(self.linear1(tgt)))
            tgt = tgt + ff_out
        return self.norm2(tgt)

class ToyTransformer(nn.Module):
    """Small decoder-only transformer trained to recall ``memory[a, b]`` from ``[BOS, a, b]``.

    Each training sequence is ``[BOS, a, b, memory[a, b]]`` (see ``generate_data``), and the
    model is trained with next-token cross-entropy over every position.

    Args:
        n_layers: number of ``DecoderLayer`` blocks.
        d_model: embedding width.
        n_head: attention heads per layer.
        hidden_size: feed-forward width inside each layer (0 disables the feed-forward sublayer).
        n_tokens: vocabulary size; token ids are ``0 .. n_tokens - 1`` and token 0 is used as BOS.
        max_len: stored on the instance but not used by this class.
        memory: optional lookup table to sample targets from. When None (the default) the
            notebook-level global ``memory`` is used, as before.
    """

    def __init__(self, n_layers, d_model, n_head, hidden_size, n_tokens, max_len, memory=None):
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_head = n_head
        self.hidden_size = hidden_size
        self.tokens = list(range(n_tokens))
        self.max_len = max_len
        # Plain attribute (not a registered buffer) so state_dict keys are unchanged.
        self.memory = memory

        self.embed = nn.Embedding(n_tokens, embedding_dim=d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=hidden_size)
            for _ in range(n_layers)
        ])
        self.unembed = nn.Linear(d_model, n_tokens)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq_len) to logits of shape (batch, seq_len, n_tokens)."""
        hidden = self.embed(x)
        for layer in self.layers:
            # NOTE: DecoderLayer already adds its own residual internally, so this is a second
            # skip connection around each layer. Kept as-is: the recorded results depend on it.
            hidden = hidden + layer(hidden)
        return self.unembed(hidden)

    def fit(self, lr=1e-3, batch_size=128, n_epochs=1000):
        """Train with Adam on freshly sampled batches, one optimizer step per "epoch".

        Prints and returns the loss of the final batch; returns None when ``n_epochs`` is 0.
        """
        super().train(True)
        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        n_vocab = len(self.tokens)

        loss = None
        for _ in tqdm(range(n_epochs)):
            batch = self.generate_data(batch_size)
            optimizer.zero_grad()
            logits = self(batch)
            # Position t predicts token t+1. The two key targets (a, b) are uniformly random,
            # so even perfect recall of the last token leaves the mean loss near
            # (2/3) * ln(n_tokens - 1).
            loss = criterion(logits[:, :-1].reshape(-1, n_vocab), batch[:, 1:].reshape(-1))
            loss.backward()
            optimizer.step()

        if loss is None:
            return None
        final_loss = loss.item()
        print('loss: ', final_loss)
        return final_loss

    def train(self, lr=1e-3, batch_size=128, n_epochs=1000):
        """Run ``fit``; a bool first argument instead switches train/eval mode.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()`` invokes as
        ``self.train(False)``. A bool argument is therefore forwarded to ``nn.Module.train``
        (returning ``self``) so ``eval()``/``train(True)`` behave normally; any other call
        trains the model exactly as before via ``fit``.
        """
        if isinstance(lr, bool):
            return super().train(lr)
        return self.fit(lr=lr, batch_size=batch_size, n_epochs=n_epochs)

    def generate_data(self, batch_size):
        """Sample ``batch_size`` sequences ``[BOS, a, b, memory[a, b]]`` as a (batch_size, 4) long tensor.

        Keys ``a`` and ``b`` are drawn uniformly from ``[0, n_tokens - 2]`` (torch.randint's upper
        bound is exclusive). The result is placed on the device of the model's parameters.
        """
        table = getattr(self, 'memory', None)
        if table is None:
            table = memory
        n_keys = len(self.tokens) - 1
        bos = torch.full((batch_size, 1), self.tokens[0], dtype=torch.long, device=table.device)
        random_indices = torch.randint(0, n_keys, (batch_size, 2), device=table.device)
        next_tokens = table[random_indices[:, 0], random_indices[:, 1]].unsqueeze(1)
        sequences = torch.cat([bos, random_indices, next_tokens], dim=1)
        return sequences.to(self.embed.weight.device)



In [None]:
from time import sleep

# Sweep feed-forward width x head count for a 1-layer, d_model=8 model and report the
# fraction of freshly sampled sequences whose final token memory[a, b] is recalled correctly.
# With 50k steps x 512 sequences over only (n_tokens - 1)**2 key pairs, evaluation pairs are
# effectively all seen in training: this measures memorization of the table, by design.
samples = 50000
results = {}  # (hidden_size, nheads) -> recall accuracy

for hidden_size in [0, 16, 32]:
    for nheads in [2, 4, 8]:
        print('hidden_size: ', hidden_size, 'nheads: ', nheads)
        # Presumably lets the tqdm bar (stderr) and print output (stdout) settle without interleaving.
        sleep(1)
        model = ToyTransformer(n_layers=1, d_model=8, n_head=nheads, hidden_size=hidden_size, n_tokens=n_tokens, max_len=4).cuda()
        model.train(lr=1e-3, n_epochs=50000, batch_size=512)

        data = model.generate_data(samples)
        # Inference only: disable autograd so the large evaluation batch does not build a graph.
        with torch.no_grad():
            output = model(data[:, :-1])[:, -1, :].argmax(dim=-1)
        accuracy = output.eq(data[:, -1]).sum().item() / samples
        results[(hidden_size, nheads)] = accuracy
        print(accuracy)
        sleep(1)

results


hidden_size:  0 nheads:  2


100%|██████████| 50000/50000 [02:25<00:00, 343.75it/s]


loss:  2.5544683933258057
0.45614
hidden_size:  0 nheads:  4


100%|██████████| 50000/50000 [02:24<00:00, 346.65it/s]


loss:  2.4965195655822754
0.53126
hidden_size:  0 nheads:  8


100%|██████████| 50000/50000 [02:21<00:00, 352.82it/s]


loss:  2.4809160232543945
0.5849
hidden_size:  16 nheads:  2


100%|██████████| 50000/50000 [02:43<00:00, 305.02it/s]


loss:  2.399559259414673
0.68864
hidden_size:  16 nheads:  4


100%|██████████| 50000/50000 [02:45<00:00, 302.66it/s]


loss:  2.2880663871765137
0.77488
hidden_size:  16 nheads:  8


100%|██████████| 50000/50000 [02:44<00:00, 304.10it/s]


loss:  2.227630853652954
0.80994
hidden_size:  32 nheads:  2


100%|██████████| 50000/50000 [02:44<00:00, 303.69it/s]


loss:  2.1739981174468994
0.89682
hidden_size:  32 nheads:  4


100%|██████████| 50000/50000 [02:42<00:00, 308.32it/s]


loss:  2.169658899307251
0.87494
hidden_size:  32 nheads:  8


100%|██████████| 50000/50000 [02:44<00:00, 303.67it/s]


loss:  2.1942458152770996
0.8586
