In [44]:
import torch
import torch.nn.functional as F
import numpy as np

In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size=64, dim_model=32):
        super(SimpleTransformer, self).__init__()
        self.embed = nn.Embedding(vocab_size, dim_model)
        self.query = nn.Linear(dim_model, dim_model)
        self.key = nn.Linear(dim_model, dim_model)
        self.value = nn.Linear(dim_model, dim_model)
        self.ffn = nn.Sequential(
            nn.Linear(dim_model, dim_model * 4),
            nn.ReLU(),
            nn.Linear(dim_model * 4, dim_model)
        )
        self.out = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        x = self.embed(x)
        # Simple single-head attention
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        attn_weights = F.softmax(q @ k.transpose(-2, -1) / (32 ** 0.5), dim=-1)
        attn_output = attn_weights @ v

        # Feed-forward network
        ffn_output = self.ffn(attn_output + x)  # Add & Norm not implemented for simplicity

        logits = self.out(ffn_output)
        return logits

# Example usage
model = SimpleTransformer()

In [46]:
# data generation
import random

FIXED_LENGTH = 3

def generateLists(n):
    output = []
    for _ in range(n):
        curr = []
        for _ in range(FIXED_LENGTH):
            curr.append(random.randint(0, 64 - 1))
            # curr.append(random.randint(0, 100))

        # maximum = max(curr)
        # output.append((curr, maximum))
        output.append(curr)

    output = torch.tensor(output)
    return output

In [47]:
def loss_function(logits, tokens, return_per_token=True, print_tokens=False):
    # we take the last element of the logits to make the next prediction
    logits = logits[:, -1, :]
    answer = torch.max(tokens, dim=1)[0]
    log_prob = logits.log_softmax(-1)
    if print_tokens:
        print("tokens", tokens)
        print("predicted", torch.argmax(logits, dim=-1))
    # shape is (batch_size, 1) which represents probabilities 
    # of the correct answer
    output_prob = log_prob.gather(-1, answer.unsqueeze(-1))
    if return_per_token:
        return -1 * output_prob.squeeze()
    return -1 * output_prob.mean()

In [48]:
def accuracy(logits, tokens, return_per_token=False):
    logits = logits[:, -1, :]
    predicted = torch.argmax(logits, dim=1)
    answer = torch.max(tokens, dim=1)[0]
    if return_per_token:
        return (predicted == answer).float()
    return (predicted == answer).float().mean().item()

In [52]:
def train_model(model, n_epochs, batch_size, batches_per, sequence_length=2):
    lr = 1e-3
    betas = (0.9, 0.999)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas)

    train_losses = []
    for epoch in range(n_epochs):
        epoch_losses = []
        for _ in range(batches_per):
            tokens = generateLists(batch_size)
            logits = model(tokens)
            # print(tokens.shape)
            # print(logits.shape)
            losses = loss_function(logits, tokens, print_tokens=True)
            losses.mean().backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_losses.extend(losses.detach())

        train_losses.append(np.mean(epoch_losses))
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, train loss: {train_losses[-1]}")

    model.eval()
    test_data = generateLists(1280)
    logits = model(test_data)
    acc = accuracy(logits, test_data, return_per_token=False)

    print(f"Test accuracy: {acc}")

    return losses
                

In [54]:
losses = train_model(model, 1, 128, 10, 3)

tokens tensor([[26,  6, 34],
        [25, 47, 18],
        [53, 11,  2],
        [50, 63, 43],
        [50, 40,  9],
        [23, 22, 34],
        [38, 46, 60],
        [40, 60, 32],
        [61,  8, 26],
        [49, 18, 42],
        [ 1, 57, 40],
        [51, 63, 19],
        [52, 37, 12],
        [20, 27, 46],
        [29, 36, 24],
        [ 7, 44, 22],
        [40, 29, 36],
        [17, 48, 24],
        [31, 29,  6],
        [23, 36, 48],
        [17, 52, 20],
        [59, 48, 20],
        [32, 17, 58],
        [55, 56,  6],
        [52, 25, 22],
        [16, 56, 19],
        [52, 46, 47],
        [61, 13, 46],
        [50, 63, 59],
        [ 4, 26, 20],
        [ 5, 27, 38],
        [43,  7, 48],
        [34, 29,  5],
        [33, 22, 33],
        [18, 30, 36],
        [62, 14,  7],
        [29, 47, 41],
        [14, 16, 36],
        [10, 59, 63],
        [19,  7,  7],
        [50,  6, 55],
        [60, 24, 26],
        [24, 13, 22],
        [ 7, 56, 54],
        [38,  5, 19],
   