In [None]:
import torch
# Seed the RNG so the lookup table below (and every model trained on it) is
# identical across "Restart Kernel -> Run All" executions.
SEED = 0
torch.manual_seed(SEED)

# Vocabulary size handed to the models in the later cells.
n_tokens = 16

# Random lookup table: memory[a, b] is the token that follows the pair (a, b)
# in the generated training sequences. Indices and values both lie in
# [0, n_tokens - 2] because torch.randint's upper bound is exclusive.
memory = torch.randint(0, n_tokens - 1, (n_tokens - 1, n_tokens - 1))

memory

tensor([[ 3, 12,  7,  2,  1,  9,  9, 12,  3,  9, 14,  1, 12, 13,  1],
        [ 3,  1,  2,  9, 11,  8,  4,  5, 13,  6, 14,  1, 12,  2, 14],
        [ 7, 14,  1,  4,  5,  0,  0,  9,  0, 10, 14,  3,  2,  4,  7],
        [ 5, 12,  9,  6,  3,  0,  7,  7,  3,  3,  4,  5,  1,  1,  0],
        [ 6,  8,  8, 13,  4,  3, 12,  2, 10,  8,  8,  5,  7,  9,  8],
        [ 3, 13, 12,  7,  0,  4, 11,  7,  3, 11, 10,  6,  0,  6, 11],
        [10, 10, 13,  4,  5, 12,  4, 10,  8, 13,  8, 11,  1,  3, 13],
        [12,  5, 10,  5,  3,  3,  8, 12,  3,  1, 14,  2,  9, 10,  2],
        [ 9,  7, 13,  1,  8, 13,  1,  5, 12, 10, 10, 11,  2, 13,  0],
        [ 0,  1, 11,  2,  2, 11, 12,  1,  4, 10,  0, 11,  0,  6, 10],
        [ 6,  5,  0, 12, 10, 14, 10, 10,  5, 13, 12, 10, 13, 14, 11],
        [ 0,  4,  8,  5, 11, 12,  5, 12,  4,  7, 12,  9,  6,  2,  8],
        [ 4, 11, 12, 12,  2,  2,  8,  8, 12,  3,  4,  2,  4, 11, 10],
        [12,  2, 10,  7, 11,  2,  0,  2,  3,  0,  8,  3,  1,  1,  0],
        [ 7,  2,  3,

In [None]:
import torch
import random
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


class DecoderLayer(nn.Module):
    """One decoder block: causal self-attention, then an optional ReLU MLP.

    Each sub-block adds its output to its input (residual connection) and is
    followed by a LayerNorm.

    Args:
        d_model: Width of the token representations.
        n_head: Number of heads given to ``nn.MultiheadAttention``.
        dim_feedforward: Hidden width of the MLP. When 0, the MLP is never
            applied (the two LayerNorms still are).
    """

    def __init__(self, d_model, n_head, dim_feedforward):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.dim_feedforward = dim_feedforward

        self.self_attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt, skip_feedforward=False, skip_self_attn=False):
        """Apply the block to ``tgt`` (batch-first: ``[batch, seq, d_model]``).

        Args:
            tgt: Input activations.
            skip_feedforward: If True, bypass the MLP sub-block.
            skip_self_attn: If True, bypass the attention sub-block.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        seq_len = tgt.shape[1]
        # Boolean causal mask: True above the diagonal marks the (query, key)
        # pairs that must NOT attend, i.e. every key that lies in the future.
        # It is created on the input's device so the layer works on CPU as
        # well as on GPU (previously it was unconditionally moved to CUDA).
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=tgt.device), diagonal=1
        ).bool()

        if not skip_self_attn:
            attn_out = self.self_attn(tgt, tgt, tgt, attn_mask=mask)[0]
            tgt = tgt + attn_out
        tgt = self.norm1(tgt)

        if self.dim_feedforward > 0 and not skip_feedforward:
            ff_out = self.linear2(nn.functional.relu(self.linear1(tgt)))
            tgt = tgt + ff_out
        tgt = self.norm2(tgt)
        return tgt

class ToyTransformer(nn.Module):
    """Small decoder-only transformer trained to memorise a pair -> token table.

    Training sequences have the form ``[bos, a, b, memory[a, b]]`` where
    ``memory`` is the notebook-level lookup tensor and ``bos`` is token 0
    (note that 0 can also occur as ``a`` or ``b``).

    Args:
        n_layers: Number of ``DecoderLayer`` blocks.
        d_model: Embedding / residual width.
        n_head: Attention heads per layer.
        hidden_size: MLP width of each layer (0 disables the MLP).
        n_tokens: Vocabulary size of the embedding and unembedding.
        max_len: Stored on the instance; not used by this class.
    """

    def __init__(self, n_layers, d_model, n_head, hidden_size, n_tokens, max_len):
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_head = n_head
        self.hidden_size = hidden_size
        self.tokens = list(range(n_tokens))
        self.max_len = max_len

        self.embed = nn.Embedding(n_tokens, embedding_dim=d_model)

        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, n_head=n_head, dim_feedforward=hidden_size)
            for _ in range(n_layers)
        ])
        self.unembed = nn.Linear(d_model, n_tokens)

    def forward(self, x, skip_feedforward=False, skip_self_attn=False, return_before_embedding=False):
        """Map token ids ``x`` of shape ``[batch, seq]`` to per-position logits.

        Args:
            x: Integer token ids.
            skip_feedforward: Forwarded to every layer; bypasses the MLPs.
            skip_self_attn: Forwarded to every layer; bypasses attention.
            return_before_embedding: If True, return the final hidden states
                (i.e. the activations *before* the unembedding projection)
                instead of logits.

        Returns:
            ``[batch, seq, n_tokens]`` logits, or ``[batch, seq, d_model]``
            hidden states when ``return_before_embedding`` is True.
        """
        tgt = self.embed(x)
        for layer in self.layers:
            # Outer residual around the whole layer, in addition to the
            # residual connections the layer applies internally.
            tgt = tgt + layer(tgt, skip_feedforward=skip_feedforward, skip_self_attn=skip_self_attn)
        if return_before_embedding:
            return tgt
        return self.unembed(tgt)

    def train(self, lr=1e-3, batch_size=128, n_epochs=1000):
        """Run the optimisation loop (or toggle train/eval mode, see below).

        This method shadows ``nn.Module.train(mode)``. PyTorch itself calls
        ``self.train(<bool>)`` from ``eval()`` and from a parent module's
        ``train()``/``eval()``, so a boolean first argument is treated as a
        mode switch and delegated to ``nn.Module.train``. Any other call
        (including the bare ``model.train()``) runs the training loop, as the
        rest of the notebook expects.

        Args:
            lr: Adam learning rate, or a ``bool`` mode flag (see above).
            batch_size: Number of freshly sampled sequences per step.
            n_epochs: Number of optimisation steps.

        Returns:
            ``self`` when used as a mode switch, otherwise ``None``.
        """
        if isinstance(lr, bool):
            return super().train(lr)

        optimizer = optim.Adam(self.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        loss = None
        for _ in tqdm(range(n_epochs)):
            batch = self.generate_data(batch_size)
            optimizer.zero_grad()
            output = self(batch)
            # Next-token objective over every position: logits at position t
            # are scored against the token at position t + 1.
            loss = criterion(output[:, :-1].reshape(-1, len(self.tokens)), batch[:, 1:].reshape(-1))

            loss.backward()
            optimizer.step()

        # Guard against n_epochs == 0, where no loss was ever computed.
        if loss is not None:
            print('loss: ', loss.item())

    def generate_data(self, batch_size):
        """Sample ``batch_size`` sequences ``[bos, a, b, memory[a, b]]``.

        ``a`` and ``b`` are drawn uniformly from ``[0, n_tokens - 2]`` using
        this model's own vocabulary size. The lookup uses the notebook-level
        ``memory`` tensor, which must therefore be defined before calling.

        Returns:
            Integer tensor of shape ``[batch_size, 4]`` on the same device as
            the model's parameters.
        """
        device = self.embed.weight.device
        n_symbols = len(self.tokens) - 1
        bos = torch.full((batch_size, 1), self.tokens[0], dtype=torch.long)  # [batch_size, 1]
        random_indices = torch.randint(0, n_symbols, (batch_size, 2))  # [batch_size, 2]
        next_tokens = memory[random_indices[:, 0], random_indices[:, 1]].unsqueeze(1)  # [batch_size, 1]
        tensor = torch.cat([bos, random_indices, next_tokens], dim=1)
        return tensor.to(device)



In [None]:
from time import sleep
# Ablations evaluated for every trained model: (printed label, extra kwargs
# passed to the model's forward call).
ablations = [
    ('Accuracy: ', {}),
    ('Accuracy without feedforward: ', {'skip_feedforward': True}),
    ('Accuracy without both: ', {'skip_feedforward': True, 'skip_self_attn': True}),
]

for hidden_size in [0, 16, 32]:
    print('hidden_size: ', hidden_size)
    sleep(1)  # NOTE(review): presumably keeps the print from interleaving with tqdm's bar - confirm
    model = ToyTransformer(n_layers=1, d_model=8, n_head=4, hidden_size=hidden_size, n_tokens=n_tokens, max_len=4).cuda()
    model.train(lr=1e-3, n_epochs=30000, batch_size=128)

    samples = 1000
    data = model.generate_data(samples)
    # Evaluation only: the model sees [bos, a, b] and its prediction at the
    # last position is compared with the fourth token. No gradients needed.
    with torch.no_grad():
        for label, forward_kwargs in ablations:
            output = model(data[:, :-1], **forward_kwargs)[:, -1, :].argmax(dim=-1)
            print(label, output.eq(data[:, -1]).sum().item() / samples)
    sleep(1)


hidden_size:  0


100%|██████████| 30000/30000 [01:25<00:00, 351.41it/s]


loss:  2.1032800674438477
Accuracy:  0.757
Accuracy without feedforward:  0.757
Accuracy without both:  0.134
hidden_size:  16


100%|██████████| 30000/30000 [01:37<00:00, 306.49it/s]


loss:  1.9600681066513062
Accuracy:  0.884
Accuracy without feedforward:  0.116
Accuracy without both:  0.114
hidden_size:  32


100%|██████████| 30000/30000 [01:38<00:00, 304.30it/s]


loss:  1.8287297487258911
Accuracy:  0.985
Accuracy without feedforward:  0.136
Accuracy without both:  0.089


In [None]:
class ModelWrapper(nn.Module):
    """Adds a small residual MLP on top of a base model with its own MLP skipped.

    The base model is run with ``skip_feedforward=True``; the hidden state at
    the last position is passed through ``hidden_layer -> ReLU -> output_layer``,
    added back to that hidden state, and projected to logits with the base
    model's ``unembed``.

    Args:
        model: Base model exposing ``d_model``, ``unembed`` and a forward that
            accepts ``skip_feedforward`` and ``return_before_embedding``.
        hidden_dim: Width of the replacement MLP's hidden layer.
    """

    def __init__(self, model, hidden_dim):
        super().__init__()
        self.model = model
        self.hidden_layer = nn.Linear(model.d_model, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, model.d_model)

    def forward(self, x):
        """Return logits ``[batch, n_tokens]`` for the last position of ``x``."""
        tgt = self.model(x, skip_feedforward=True, return_before_embedding=True)[:, -1, :]
        delta = self.output_layer(torch.nn.functional.relu(self.hidden_layer(tgt)))
        # Residual update. The result (not the bare MLP output) is what gets
        # unembedded, and the wrapped model is used rather than a global.
        tgt = tgt + delta
        return self.model.unembed(tgt)

wrapper = ModelWrapper(model, hidden_dim=2).cuda()

# Freeze the base transformer *before* building the optimizer, and hand Adam
# only the parameters that are still trainable (the wrapper's two linear
# layers). `wrapper.parameters()` also yields the base model's parameters
# because the model is registered as a submodule.
model.requires_grad_(False)
optimizer = optim.Adam([p for p in wrapper.parameters() if p.requires_grad], lr=1e-2)
criterion = nn.CrossEntropyLoss()

for _ in tqdm(range(10000)):
    optimizer.zero_grad()
    data = model.generate_data(128)
    # Inputs are the first three tokens; the target is the fourth.
    output = wrapper(data[:,:-1])
    loss = criterion(output, data[:,-1])
    loss.backward()
    optimizer.step()
print(loss.item())

100%|██████████| 10000/10000 [00:19<00:00, 501.99it/s]

2.3522305488586426





In [None]:
# Measure how often the wrapper's top prediction at the last position matches
# the fourth token of freshly sampled sequences.
samples = 1000
data = model.generate_data(samples)
inputs, targets = data[:, :-1], data[:, -1]
output = wrapper(inputs).argmax(dim=-1)
n_correct = output.eq(targets).sum().item()
print('Accuracy: ', n_correct / samples)

Accuracy:  0.166
