In [7]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from models import BasicTransformer

## Data Preparation

In [8]:
x = torch.load('X.pt')
y = torch.load('y.pt')

# Create a set of all possible play descriptors
play_descriptors = set()
for s in x.flatten():
    play_descriptors.add(s)
for s in y.flatten():
    play_descriptors.add(s)

# Assign each play descriptor a unique token
play_to_tok = {s: i for i, s in enumerate(play_descriptors)}
num_token_types = len(play_descriptors)

# Convert x and y to their token representations
x_tok = torch.zeros(x.shape, dtype=torch.long)
y_tok = torch.zeros(y.shape, dtype=torch.long)

for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        x_tok[i, j] = play_to_tok[x[i, j]]

for i in range(len(y)):
    y_tok[i] = play_to_tok[y[i]]



In [14]:
# Split into train and test
train_examples = math.floor(len(x) * 0.8)

# Train on only one example for now
x_train = x_tok[:train_examples][:1]
y_train = y_tok[:train_examples][:1]

x_test = x_tok[train_examples:]
y_test = y_tok[train_examples:]

train_loader = torch.utils.data.DataLoader(
    list(zip(x_train, y_train)),
    batch_size=32,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    list(zip(x_test, y_test)),
    batch_size=32,
    shuffle=True
)


## Training Loop

In [24]:
def train(model, n_epochs=50):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(n_epochs):
        model.train()
        i = 0
        for x_bat, y_bat in iter(train_loader):
            x_bat.permute(1, 0)
            # x_bat = F.one_hot(x_bat, num_classes=num_token_types).res
            y_bat = F.one_hot(y_bat, num_classes=num_token_types)
            optimizer.zero_grad()
            y_pred = model(x_bat)
            loss = criterion(y_pred, y_bat)
            loss.backward()
            optimizer.step()
            i += 1
            print(f'Epoch {epoch}, iter {i}, loss: {loss.item()}')

In [25]:
model = BasicTransformer(num_token_types, 128, 128, 2, 2, 0.1)



In [26]:
train(model)

Epoch 0, iter 1, loss: 1.6774523258209229
Epoch 1, iter 2, loss: 1.6876333951950073
Epoch 2, iter 3, loss: 1.699212908744812
Epoch 3, iter 4, loss: 1.6746820211410522
Epoch 4, iter 5, loss: 1.681522011756897
Epoch 5, iter 6, loss: 1.6916401386260986
Epoch 6, iter 7, loss: 1.6921319961547852
Epoch 7, iter 8, loss: 1.6903804540634155
Epoch 8, iter 9, loss: 1.6930173635482788
Epoch 9, iter 10, loss: 1.6939946413040161
Epoch 10, iter 11, loss: 1.681666374206543
Epoch 11, iter 12, loss: 1.6895463466644287
Epoch 12, iter 13, loss: 1.685921311378479
Epoch 13, iter 14, loss: 1.699947476387024
Epoch 14, iter 15, loss: 1.6905410289764404
Epoch 15, iter 16, loss: 1.6917647123336792
Epoch 16, iter 17, loss: 1.6950290203094482
Epoch 17, iter 18, loss: 1.690560221672058
Epoch 18, iter 19, loss: 1.687844157218933
Epoch 19, iter 20, loss: 1.691892147064209
Epoch 20, iter 21, loss: 1.6854383945465088
Epoch 21, iter 22, loss: 1.6818132400512695
Epoch 22, iter 23, loss: 1.6999239921569824
Epoch 23, iter 