In [44]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [45]:
import torch
from miexp.transformer import Transformer

In [46]:
# Prefer Apple's Metal backend when this PyTorch build/machine supports it and
# fall back to the CPU otherwise, so the notebook runs end-to-end anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [47]:
# Model hyperparameters, one per line so they are easy to scan and tweak.
# NOTE(review): the literal 113 duplicates the notebook constant `P`, which is
# only assigned in a later cell -- keep the two values in sync.
transformer = Transformer(
    embed_dim=128,
    num_heads=4,
    P=113,
    device=device,
    mlp_neurons=512,
    num_summands=2,
)

In [48]:
# Modulus of the addition task: labels are (a + b) % P. prep_batch also writes
# this value into column 0 of every input row.
P = 113

In [49]:
import itertools
import numpy as np


# Full dataset: every ordered pair (a, b) with 0 <= a, b < P, enumerated with
# `a` varying slowest. The range is driven by P (not a repeated literal) so the
# inputs and the modulus used for the labels cannot drift apart.
X = np.array(list(itertools.product(range(P), repeat=2)))
# Label for each pair is the modular sum of its two entries.
y = (X[:, 0] + X[:, 1]) % P

In [50]:
# Quick look at the label vector (numpy abbreviates the middle of long arrays).
y

array([  0,   1,   2, ..., 109, 110, 111])

In [51]:
from sklearn.model_selection import train_test_split

# Hold out 60% of all pairs for validation; the remaining 40% are used for
# training. The fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.6,
    random_state=0,
)

In [52]:
def prep_batch(
    batch: np.ndarray,
    device: torch.device,
    prefix_token: "int | None" = None,
) -> torch.Tensor:
    """Convert a 2-D array of operands into model input with a leading token.

    Args:
        batch: Integer array of shape (n_rows, n_cols).
        device: Device on which the returned tensor is allocated.
        prefix_token: Value written into column 0 of every row. When omitted,
            the notebook-level constant ``P`` is used, which matches the
            previous behavior of this function.

    Returns:
        A ``torch.int`` tensor of shape (n_rows, n_cols + 1) whose first column
        is ``prefix_token`` and whose remaining columns are ``batch``.
    """
    if prefix_token is None:
        # Resolved at call time so the function tracks the notebook's current P.
        prefix_token = P
    n_rows, n_cols = batch.shape[0], batch.shape[1]
    t = torch.zeros((n_rows, n_cols + 1), device=device, dtype=torch.int)
    # The slice assignment casts the operands to the int dtype of `t`.
    t[:, 1:] = torch.tensor(batch, device=device)
    t[:, 0] = prefix_token
    return t

In [53]:
# Training inputs: the training pairs with the extra leading column added by
# prep_batch, allocated on `device`.
inputs = prep_batch(batch=X_train, device=device)

In [54]:
# Validation inputs: the held-out pairs, prepared the same way as the
# training inputs.
val_inputs = prep_batch(batch=X_test, device=device)

In [55]:
# Sanity check that model and data live on the same device. In the recorded
# output the equality is False even though both are "mps": the model reports a
# device without an index, while the tensor reports index 0.
transformer.device == inputs.device, transformer.device, inputs.device

(False, device(type='mps'), device(type='mps', index=0))

In [56]:
# Explicitly move the model to `device` before building the optimizer.
# NOTE(review): presumably Transformer is a torch.nn.Module, whose .to()
# returns the module itself -- confirm against miexp.transformer.
transformer = transformer.to(device)

In [57]:
# AdamW over all model parameters; every option other than the learning rate
# is left at the library default.
optimizer = torch.optim.AdamW(
    params=transformer.parameters(),
    lr=0.01,
)

In [58]:
# Convert both label arrays to int64 tensors on `device`; they are used as the
# class-index targets for cross_entropy in the training loop.
# NOTE: this rebinds y_train / y_test, so re-running the cell feeds tensors
# (rather than numpy arrays) back into torch.tensor.
y_train, y_test = (
    torch.tensor(labels, dtype=torch.long, device=device)
    for labels in (y_train, y_test)
)

In [None]:
from tqdm.notebook import tqdm

# pandas must be imported before `pd` is used here; without this the cell
# raises NameError on a fresh kernel.
import pandas as pd

# One row per optimisation step. The columns are created as floats up front so
# the metrics never end up in object-dtype columns.
stats = pd.DataFrame(
    columns=[
        "train_loss",
        "train_acc",
        "val_loss",
        "val_acc",
    ],
    dtype=float,
)

In [None]:
import pandas as pd

# Full-batch training. Per-step metrics are buffered in a plain list and
# written to `stats` once at the end: enlarging a DataFrame with
# `stats.loc[len(stats)] = ...` reallocates it on every step, which gets slow
# over tens of thousands of steps.
# NOTE(review): the range starts at 10000, so this runs 40000 steps and the
# `epoch` value itself is never used -- confirm the offset is intentional.
step_records = []
try:
    for epoch in tqdm(range(10000, 50000)):
        # One optimisation step on the whole training set.
        logits = transformer(inputs)
        loss = torch.nn.functional.cross_entropy(logits, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Predictions come from the logits computed *before* this step's update.
        preds = torch.argmax(logits, dim=-1)

        # Validation pass; no gradients are needed here.
        with torch.no_grad():
            val_logits = transformer(val_inputs)
            val_loss = torch.nn.functional.cross_entropy(val_logits, y_test)
            val_preds = torch.argmax(val_logits, dim=-1)

        step_records.append({
            "train_loss": loss.item(),
            "train_acc": torch.sum(preds == y_train).item() / len(preds),
            "val_loss": val_loss.item(),
            "val_acc": torch.sum(val_preds == y_test).item() / len(val_preds)
        })
finally:
    # Flush even when the loop is interrupted, so a partial run keeps its
    # history for the plotting cells below.
    if step_records:
        new_rows = pd.DataFrame(step_records, columns=stats.columns)
        # Skip the concat when `stats` is still empty, so the new rows keep
        # their float dtypes.
        stats = new_rows if stats.empty else pd.concat([stats, new_rows], ignore_index=True)

  0%|          | 0/40000 [00:00<?, ?it/s]

In [61]:
import plotly.express as px
import plotly.graph_objects as go

In [67]:
# Training curves: loss on the left axis, accuracy on a second y-axis.
train_loss_trace = go.Scatter(
    x=stats.index, y=stats['train_loss'], name='train_loss', mode='lines'
)
# yaxis="y2" binds this trace to the secondary axis configured in the layout.
train_acc_trace = go.Scatter(
    x=stats.index, y=stats['train_acc'], name='train_acc', mode='lines', yaxis="y2"
)

fig = go.Figure(data=[train_loss_trace, train_acc_trace])
fig.update_layout(
    title='Transformer Training',
    xaxis_title='Epochs',
    yaxis_title='Loss',
    yaxis2={
        "title": 'Accuracy',
        "overlaying": 'y',  # draw the accuracy axis over the loss axis
        "side": 'right',    # put its ticks on the right-hand side
    },
)

fig.show()

In [63]:
# Validation curves: loss on the left axis, accuracy on a second y-axis.
val_loss_trace = go.Scatter(
    x=stats.index, y=stats['val_loss'], name='val_loss', mode='lines'
)
# yaxis="y2" binds this trace to the secondary axis configured in the layout.
val_acc_trace = go.Scatter(
    x=stats.index, y=stats['val_acc'], name='val_acc', mode='lines', yaxis="y2"
)

fig = go.Figure(data=[val_loss_trace, val_acc_trace])
fig.update_layout(
    title='Transformer Training',
    xaxis_title='Epochs',
    yaxis_title='Loss',
    yaxis2={
        "title": 'Accuracy',
        "overlaying": 'y',  # draw the accuracy axis over the loss axis
        "side": 'right',    # put its ticks on the right-hand side
    },
)

fig.show()

In [64]:
# First ten prepared training rows: column 0 holds the value of P written by
# prep_batch, columns 1-2 hold the two operands.
inputs[:10, :]

tensor([[113,  27,  66],
        [113,  12,  72],
        [113, 102,  63],
        [113,  57,  88],
        [113,   4,  31],
        [113,  48,  89],
        [113,  40,  96],
        [113,  36,  64],
        [113,  83,  45],
        [113,  58,  82]], device='mps:0', dtype=torch.int32)

In [None]:
# Training logits left over from the last loop iteration that ran.
logits

tensor([[ -2.4490, -29.2748,  40.3794,  ...,  27.4652, -29.6827,   0.1628],
        [  2.0166,   0.9747,   8.4324,  ...,   2.8632,  -0.6887, -10.0233],
        [  7.1558, -10.8454, -17.2419,  ..., -25.3343, -13.0056,  -0.7803],
        ...,
        [  3.1185,  10.0894, -13.3770,  ...,  15.8868,  -7.5139,  -8.2033],
        [ -1.4302,  -8.7777,  -6.3723,  ...,  -1.2857,  -8.2004, -23.8650],
        [  0.4781,  12.7040, -38.5620,  ...,   0.1288,  16.4200,   1.0142]],
       grad_fn=<MmBackward0>)

In [None]:
# Summary of the predicted classes (mean, std, min, max). Cast to float
# because mean/std are not defined for integer tensors.
preds.type(torch.float).mean(), preds.type(torch.float).std(), preds.min(), preds.max()

(tensor(56.5240), tensor(32.5610), tensor(0), tensor(112))

In [None]:
# First ten training predictions, to compare by eye with y_train[:10].
preds[:10]

tensor([ 93,  84,  52,  32,  35,  24,  23, 100,  15,  27])

In [None]:
# First ten ground-truth training labels, for comparison with preds[:10].
y_train[:10]

tensor([ 93,  84,  52,  32,  35,  24,  23, 100,  15,  27])