In [1]:
import torch
import torch.nn.functional as F
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=64, dim_model=32):
        super(SimpleTransformer, self).__init__()
        self.embed = nn.Embedding(vocab_size, dim_model)
        self.query = nn.Linear(dim_model, dim_model)
        self.key = nn.Linear(dim_model, dim_model)
        self.value = nn.Linear(dim_model, dim_model)
        self.ffn = nn.Sequential(
            nn.Linear(dim_model, dim_model * 4),
            nn.ReLU(),
            nn.Linear(dim_model * 4, dim_model)
        )
        self.out = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn_weights = F.softmax(q @ k.transpose(-2, -1) / (32 ** 0.5), dim=-1)
        attn_output = attn_weights @ v

        ffn_output = self.ffn(attn_output + x)

        logits = self.out(ffn_output)
        return logits

model = SimpleTransformer()

In [3]:
import random

FIXED_LENGTH = 3

def generateLists(n):
    output = []
    for _ in range(n):
        curr = []
        for _ in range(FIXED_LENGTH):
            curr.append(random.randint(0, 64 - 1))
            # curr.append(random.randint(0, 100))

        # maximum = max(curr)
        # output.append((curr, maximum))
        output.append(curr)

    output = torch.tensor(output)
    return output

Separate training and testing

In [4]:
def separate_data(n, split=0.7):
    output = []
    for i in range(64):
        for j in range(64):
            curr = [i, j]
            output.append(curr)

    random.shuffle(output)

    split_index = int(len(output) * split)
    # return training, testing
    return torch.tensor(output[:split_index]), torch.tensor(output[split_index:])

In [5]:
def output_data(data, batch_size=128):
    for i in range(0, len(data), batch_size):
        yield data[i:i + batch_size]

Loss functions

In [6]:
def loss_function(logits, tokens, return_per_token=True, print_tokens=False):
    # we take the last element of the logits to make the next prediction
    logits = logits[:, -1, :]
    answer = torch.max(tokens, dim=1)[0]
    log_prob = logits.log_softmax(-1)
    if print_tokens:
        print("tokens", tokens)
        print("predicted", torch.argmax(logits, dim=-1))
    # shape is (batch_size, 1) which represents probabilities 
    # of the correct answer
    output_prob = log_prob.gather(-1, answer.unsqueeze(-1))
    if return_per_token:
        return -1 * output_prob.squeeze()
    return -1 * output_prob.mean()

In [7]:
def accuracy(logits, tokens, return_per_token=False):
    logits = logits[:, -1, :]
    predicted = torch.argmax(logits, dim=-1)
    answer = torch.max(tokens, dim=1)[0]
    print(predicted, answer)
    if return_per_token:
        return (predicted == answer).float()
    return (predicted == answer).float().mean().item()

In [8]:
def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    training, testing = separate_data(64)
    print(training.shape, testing.shape)
    for epoch in range(n_epochs):
        model.train()
        epoch_losses = []
        generator = output_data(training, batch_size)
        for _ in range(batches_per):
            tokens = next(generator)
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens, print_tokens=False)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    logits = model(testing)
    acc = accuracy(logits, testing, return_per_token=False)
    print(f"Test accuracy: {acc}")

    return losses
                

In [9]:
torch.set_printoptions(threshold=10000)
losses = train_model(model, 1000, 128, 10, 3)

torch.Size([2867, 2]) torch.Size([1229, 2])
Epoch 0, train loss: 4.115307807922363
Epoch 10, train loss: 1.2805118560791016
Epoch 20, train loss: 0.2150605022907257
Epoch 30, train loss: 0.05243371054530144
Epoch 40, train loss: 0.019432011991739273
Epoch 50, train loss: 0.009694280102849007
Epoch 60, train loss: 0.005810381378978491
Epoch 70, train loss: 0.003874266054481268
Epoch 80, train loss: 0.00276588904671371
Epoch 90, train loss: 0.002070572692900896
Epoch 100, train loss: 0.0016054243315011263
Epoch 110, train loss: 0.0012786707375198603
Epoch 120, train loss: 0.001040170551277697
Epoch 130, train loss: 0.0008607251802459359
Epoch 140, train loss: 0.0007224645232781768
Epoch 150, train loss: 0.000613644253462553
Epoch 160, train loss: 0.0005265054060146213
Epoch 170, train loss: 0.00045569968642666936
Epoch 180, train loss: 0.0003973893471993506
Epoch 190, train loss: 0.00034885111381299794
Epoch 200, train loss: 0.0003080289752688259
Epoch 210, train loss: 0.0002734263835009