In [1]:
import torch
import torch.nn.functional as F
import numpy as np

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=64, dim_model=32):
        super(SimpleTransformer, self).__init__()
        self.embed = nn.Embedding(vocab_size, dim_model)
        self.query = nn.Linear(dim_model, dim_model)
        self.key = nn.Linear(dim_model, dim_model)
        self.value = nn.Linear(dim_model, dim_model)
        self.ffn = nn.Sequential(
            nn.Linear(dim_model, dim_model * 4),
            nn.ReLU(),
            nn.Linear(dim_model * 4, dim_model)
        )
        self.out = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn_weights = F.softmax(q @ k.transpose(-2, -1) / (32 ** 0.5), dim=-1)
        attn_output = attn_weights @ v

        ffn_output = self.ffn(attn_output + x)

        logits = self.out(ffn_output)
        return logits

model = SimpleTransformer()

In [3]:
import random

FIXED_LENGTH = 3

def generateLists(n):
    output = []
    for _ in range(n):
        curr = []
        for _ in range(FIXED_LENGTH):
            curr.append(random.randint(0, 64 - 1))
            # curr.append(random.randint(0, 100))

        # maximum = max(curr)
        # output.append((curr, maximum))
        output.append(curr)

    output = torch.tensor(output)
    return output

In [4]:
def loss_function(logits, tokens, return_per_token=True, print_tokens=False):
    # we take the last element of the logits to make the next prediction
    logits = logits[:, -1, :]
    answer = torch.max(tokens, dim=1)[0]
    log_prob = logits.log_softmax(-1)
    if print_tokens:
        print("tokens", tokens)
        print("predicted", torch.argmax(logits, dim=-1))
    # shape is (batch_size, 1) which represents probabilities 
    # of the correct answer
    output_prob = log_prob.gather(-1, answer.unsqueeze(-1))
    if return_per_token:
        return -1 * output_prob.squeeze()
    return -1 * output_prob.mean()

In [5]:
def accuracy(logits, tokens, return_per_token=False):
    logits = logits[:, -1, :]
    predicted = torch.argmax(logits, dim=1)
    answer = torch.max(tokens, dim=1)[0]
    if return_per_token:
        return (predicted == answer).float()
    return (predicted == answer).float().mean().item()

In [6]:
def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    for epoch in range(n_epochs):
        epoch_losses = []
        for _ in range(batches_per):
            tokens = generateLists(batch_size)
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens, print_tokens=True)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    test_data = generateLists(1280)
    logits = model(test_data)
    acc = accuracy(logits, test_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return losses
                

In [7]:
losses = train_model(model, 1, 128, 10, 3)

tokens tensor([[ 6, 15, 43],
        [18, 15, 57],
        [15, 44, 63],
        [28, 18, 19],
        [40,  5, 15],
        [ 8, 30, 38],
        [37, 43, 43],
        [37, 49, 61],
        [55, 30, 46],
        [47, 38, 10],
        [45, 54, 39],
        [19,  9,  4],
        [53, 52, 53],
        [ 8, 25, 29],
        [53, 15, 46],
        [25, 15,  1],
        [13, 35, 19],
        [62, 47, 27],
        [40,  0, 59],
        [28, 48, 45],
        [ 8,  9,  1],
        [26, 47, 16],
        [37, 36, 42],
        [15, 20, 57],
        [59, 57, 58],
        [63, 45,  8],
        [ 1, 30, 49],
        [43, 23,  7],
        [ 6,  0, 44],
        [23, 35,  2],
        [20, 45, 55],
        [59, 53, 17],
        [12, 29, 11],
        [18, 48, 31],
        [ 2,  2,  4],
        [23, 29, 57],
        [40,  5, 12],
        [21,  0, 32],
        [52,  9, 50],
        [11, 57, 11],
        [55,  7, 58],
        [28,  5, 52],
        [42, 47, 12],
        [18, 14,  3],
        [58, 49, 63],
   