In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import BasicTransformer

## Data Preparation

In [3]:
# NOTE(review): torch.load unpickles the file — only load X.pt / y.pt from a trusted source.
x = torch.load('X.pt')
y = torch.load('y.pt')

# Vocabulary: every play descriptor that occurs anywhere in x or y.
play_descriptors = set(x.flatten()) | set(y.flatten())

# Sort before numbering so the descriptor -> token mapping is the same on every run;
# plain set iteration order for strings depends on the per-process hash seed.
play_to_tok = {s: i for i, s in enumerate(sorted(play_descriptors))}
num_token_types = len(play_descriptors)


def to_token_ids(arr, vocab):
    """Map every descriptor in ``arr`` to its token id, preserving ``arr``'s shape.

    Args:
        arr: array of descriptors supporting ``.flatten()`` and ``.shape``
            (presumably a numpy array of strings — TODO confirm).
        vocab: dict mapping each descriptor to an int token id.

    Returns:
        ``torch.long`` tensor with the same shape as ``arr``.

    Raises:
        KeyError: if ``arr`` contains a descriptor missing from ``vocab``.
    """
    ids = [vocab[s] for s in arr.flatten()]
    # reshape restores the original layout (row-major flatten + row-major reshape),
    # and also gives the right shape when arr is empty.
    return torch.tensor(ids, dtype=torch.long).reshape(tuple(arr.shape))


# Convert x and y to their token representations
x_tok = to_token_ids(x, play_to_tok)
y_tok = to_token_ids(y, play_to_tok)



In [5]:
# Split into train and test. The split is positional: the first TRAIN_FRACTION of
# the rows are used for training and the rest for validation (no shuffling first).
TRAIN_FRACTION = 0.8
# Cap each side so experiments stay quick; raise these for a full run.
MAX_TRAIN_EXAMPLES = 8000
MAX_VAL_EXAMPLES = 2000
BATCH_SIZE = 32

train_examples = math.floor(len(x) * TRAIN_FRACTION)

x_train = x_tok[:train_examples][:MAX_TRAIN_EXAMPLES]
y_train = y_tok[:train_examples][:MAX_TRAIN_EXAMPLES]

x_test = x_tok[train_examples:][:MAX_VAL_EXAMPLES]
y_test = y_tok[train_examples:][:MAX_VAL_EXAMPLES]

# TensorDataset yields the same (x_row, y_row) pairs as zip, but raises if the
# two tensors have different lengths instead of silently truncating.
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_train, y_train),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Validation needs no shuffling; a fixed order keeps the per-batch-averaged
# validation metrics reproducible between runs.
val_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(x_test, y_test),
    batch_size=BATCH_SIZE,
    shuffle=False,
)


In [6]:
def accuracy(preds, labels):
    """Fraction of sequences whose last-position prediction matches the label.

    Only the final sequence position (``[:, -1]``) is scored; the predicted class is
    the argmax over the class axis.

    Args:
        preds: per-class scores indexed ``[batch, position, class]``.
        labels: either one-hot / probability targets with the same shape as
            ``preds`` (as built with ``F.one_hot`` in ``train``), or integer class
            ids indexed ``[batch, position]``.

    Returns:
        0-d float tensor in [0, 1]. An empty batch yields ``nan``.
    """
    pred_last = preds[:, -1].argmax(dim=1)
    if labels.dim() == preds.dim():
        true_last = labels[:, -1].argmax(dim=1)
    else:
        # Integer class ids: the last position already is the class index.
        true_last = labels[:, -1]
    # Vectorised mean; Python's builtin sum would iterate the tensor element by element.
    return (pred_last == true_last).float().mean()

## Training Loop

In [10]:
def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader``; return ``(mean_loss, mean_accuracy)`` per batch.

    With an ``optimizer`` the model is put in training mode and updated after every
    batch. Without one it is put in eval mode and run with gradients disabled, so
    dropout is off and no autograd graph is built.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_acc = 0.0
    with torch.set_grad_enabled(training):
        for x_bat, y_bat in loader:
            y_pred = model(x_bat)
            # y_pred is laid out (batch, seq, vocab), matching one-hot targets, but
            # CrossEntropyLoss treats dim 1 as the class axis. Move vocab to dim 1
            # and score against the integer token ids directly.
            loss = criterion(y_pred.transpose(1, 2), y_bat)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            y_onehot = F.one_hot(y_bat, num_classes=y_pred.shape[-1]).float()
            total_acc += float(accuracy(y_pred, y_onehot))
    n_batches = max(len(loader), 1)  # avoid dividing by zero on an empty loader
    return total_loss / n_batches, total_acc / n_batches


def train(model, n_epochs=10, lr=0.001):
    """Fit ``model`` with Adam + cross-entropy, validating after every epoch.

    Uses the notebook-level ``train_loader`` and ``val_loader``. Batches of token
    ids are passed to the model as the loaders yield them. After each epoch it
    prints the mean per-batch loss and last-position accuracy for both splits.

    Args:
        model: module called as ``model(x_bat)``; its output must have the same
            shape as one-hot targets ``(batch, seq, num_token_types)``.
        n_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(n_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion)
        print(
            f'Epoch {epoch}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, '
            f'val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}'
        )

In [8]:
# Seed torch's global RNG so Restart & Run All is reproducible. The shuffling
# train_loader draws from it, and presumably so do BasicTransformer's weight
# initialisation and dropout.
torch.manual_seed(0)
# Positional hyper-parameters for models.BasicTransformer; their meaning is defined
# there (the trailing 0.1 looks like a dropout rate — TODO confirm).
model = BasicTransformer(num_token_types, 128, 128, 2, 2, 0.1)



In [11]:
# Train for the default 10 epochs; train() prints the metrics after each epoch.
train(model=model, n_epochs=10)