In [7]:
import torch
from model import Base, BaselineModel
from torch import nn
from typing import Union, Iterable
from data import Data, train_test_split_exact
from collections import OrderedDict
import tqdm

In [9]:
class BaselineModel(Base):
    def __init__(
        self, vocab_size: Union[int, Iterable], hidden_dim: int, output_dim: int, dropout: float = 0.
    ):
        """
        :param vocab_size: number of tokens in the vocabulary,
          or an Iterable of vocab sizes for each input. One embedding layer will be created for each input.
        :param hidden_dim: dimension of the hidden layer
        :param output_dim: dimension of the output layer
        """

        super().__init__(vocab_size, hidden_dim)

        self.nonlinear = nn.Sequential(
            nn.Linear(3 * hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            # ResidualBlock(hidden_dim, dropout=dropout),
            # nn.SiLU(),
            # ResidualBlock(hidden_dim, dropout=dropout),
            # nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )
        self.readout = nn.Linear(hidden_dim, output_dim)

    def forward_with_embeddings(self, x, embs):  # embs: [ batch_size, 2 * hidden_dim ]
        x = self.embed_input(x, embs)
        x = self.nonlinear(x)  # [ batch_size, hidden_dim ]
        return self.readout(x)  # [ batch_size, output_dim ]

In [30]:
torch.manual_seed(0)

operations = ["add", "subtract", "multiply"]
P = 53
X = torch.cartesian_prod(torch.arange(P), torch.arange(P), torch.arange(len(operations)))
y = torch.zeros(len(X))
# modular arithemtic
for i, (a, b, op) in enumerate(X):
    if operations[op] == "add":
        y[i] = (a + b) % P
    elif operations[op] == "subtract":
        y[i] = (a - b) % P
    elif operations[op] == "multiply":
        y[i] = (a * b) % P

def train_test_split_exact(X, y=None, train_frac=0.8, seed=0):
    torch.manual_seed(seed)
    train_mask = torch.ones(X.shape[0], dtype=torch.bool)
    train_mask[int(train_frac * X.shape[0]) :] = False
    train_mask = train_mask[torch.randperm(X.shape[0])]
    test_mask = ~train_mask
    return train_mask, test_mask

def accuracy(out, y, task_mask=None):
    with torch.no_grad():
        if task_mask is None:
            return (out.argmax(dim=1) == y).float().mean() * 100
        else:
            return (out.argmax(dim=1)[task_mask] == y[task_mask]).float().mean() * 100

def get_task_accs(out, y):
    task_accs = {}
    for op in operations:
        task_mask = X[:, 2] == operations.index(op)
        task_accs[op] = accuracy(out, y.long(), task_mask=task_mask)
    return task_accs



model = BaselineModel(vocab_size=(P + len(operations)), hidden_dim=64, output_dim=P)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
mask_train, mask_test = train_test_split_exact(X, y, train_frac=0.8, seed=0)

print("T:  loss |  acc  || V:  loss |  acc")
bar = tqdm.trange(20000)
for epoch in bar:
    model.train()
    optimizer.zero_grad()
    out = model(X[mask_train])
    loss_train = criterion(out, y[mask_train].long())
    loss_train.backward()
    optimizer.step()
    acc_train = accuracy(out, y[mask_train].long())
    task_accs_train = get_task_accs(out, y[mask_train].long())
    
    model.eval()
    out = model(X[mask_test])
    loss_test = criterion(out, y[mask_test].long())
    acc_test = accuracy(out, y[mask_test].long())
    msg = f"T: {loss_train:.3f} | {acc_train:.3f} || V: {loss_test:.3f} | {acc_test:.3f}"
    bar.set_description_str(msg)
    task_accs_test = get_task_accs(out, y[mask_test].long())
    postfix = " | ".join([f"{op}: {acc:.2f}" for op, acc in task_accs_test.items()])
    bar.set_postfix_str(postfix)


T:  loss |  acc  || V:  loss |  acc


  0%|          | 0/20000 [00:00<?, ?it/s]


IndexError: The shape of the mask [8427] at index 0 does not match the shape of the indexed tensor [6741] at index 0