This notebook contains the code to train the transformer on a toy dataset.

References:
1. https://github.com/danielmamay/grokking/blob/main/grokking/training.py

In [2]:
import math
import os

import matplotlib.pyplot as plt
import torch
import tqdm
from torch.utils.data import DataLoader

from gen_data import get_dataset
from model import Transformer, Config

# Fix the seed of torch's default random number generator so that
# repeated runs of the notebook produce the same results.
torch.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [3]:
# Load the Config
cfg = Config()

# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 512

# Build both splits from the config, then wrap each in a loader.
# Only the training split is shuffled; validation keeps a fixed order.
train_dataset, valid_dataset = get_dataset(cfg)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=BATCH_SIZE)

In [4]:
# Sanity check: number of samples in the validation split.
len(valid_loader.dataset)

5646

#### Define the training and validation loops

In [5]:
def trainval_loop(model, optimizer, lr_scheduler, dataloader, is_train=False):
  """Run one full pass over `dataloader`, either training or evaluating `model`.

  Args:
    model: module called as `model(X)`. Its output is indexed with `[-1]`
      and the result is scored against `y` with cross entropy, so `[-1]`
      must yield `(batch, num_classes)` logits (presumably the logits at
      the last sequence position -- confirm against the model definition).
    optimizer: optimizer stepped once per batch. Only touched when
      `is_train` is true, so `None` is accepted for evaluation.
    lr_scheduler: scheduler stepped once per batch, right after the
      optimizer. Only touched when `is_train` is true.
    dataloader: iterable yielding `(X, y)` batches.
    is_train: when true the weights are updated; otherwise the model is put
      in eval mode and run with gradients disabled.

  Returns:
    `(mean_loss, accuracy)`, both averaged over every sample that was
    actually processed during the pass.

  Raises:
    ZeroDivisionError: if `dataloader` yields no samples.
  """
  is_train = bool(is_train)
  model.train(is_train)

  # cross entropy loss (mean over the batch)
  criterion = torch.nn.CrossEntropyLoss()

  # Move each batch to wherever the model's weights live instead of relying
  # on a notebook-level `device` global.
  first_param = next(model.parameters(), None)
  target_device = first_param.device if first_param is not None else torch.device("cpu")

  loss_total = 0.0
  num_correct = 0
  num_samples = 0

  for X, y in dataloader:
    X = X.to(target_device)
    y = y.to(target_device)

    # Single forward path for both modes; autograd is only recorded when
    # the weights are going to be updated.
    with torch.set_grad_enabled(is_train):
      logits = model(X)[-1]
      loss = criterion(logits, y)

    if is_train:
      # update the model weights
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      lr_scheduler.step()

    # Calculate the Accuracy
    preds = torch.argmax(logits, dim=1)

    batch_size = len(y)
    # `loss` is a batch mean; re-weight it so the epoch average is per-sample
    # even when the last batch is smaller.
    loss_total += loss.item() * batch_size
    num_correct += (preds == y).sum().item()
    num_samples += batch_size

  # Normalise by the samples actually seen, which stays correct even if the
  # loader drops or repeats samples.
  return loss_total / num_samples, num_correct / num_samples

In [6]:
# Load the model and the optimizers
model = Transformer(cfg).to(device)

# NOTE: Default setting used in the paper
# `betas` is passed as a tuple, which is the type torch.optim.AdamW documents.
optimizer = torch.optim.AdamW(
  model.parameters(),
  lr=1e-3,
  weight_decay=1,
  betas=(0.9, 0.98),
)

# Alternative (disabled): Adam with LR = 3e-4
# optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Linear learning-rate warmup over the first 10 scheduler steps: the LR is
# scaled from 0.1x up to 1x of the optimizer's base LR.
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=10)

In [7]:
# Number of optimizer updates per pass over the training set (one per batch).
# Taken from the loader itself so it cannot drift from the loader's batch size.
steps_per_epoch = len(train_loader)

# Number of full passes needed to perform at least cfg.num_updates updates.
num_epochs = math.ceil(cfg.num_updates / steps_per_epoch)

In [8]:
# Per-epoch metric histories; one value is appended to each list per epoch.
train_loss, train_acc = [], []
val_loss, val_acc = [], []

for epoch in tqdm.tqdm(range(num_epochs)):

  # Each epoch runs a training pass first and then a validation pass;
  # every pass records its loss and accuracy in the matching history lists.
  for loader, training, losses, accs in (
    (train_loader, True, train_loss, train_acc),
    (valid_loader, False, val_loss, val_acc),
  ):
    loss, acc = trainval_loop(model, optimizer, lr_scheduler, loader, is_train=training)
    losses.append(loss)
    accs.append(acc)

100%|██████████| 12500/12500 [30:29<00:00,  6.83it/s]


In [10]:
split_tag = str(cfg.split_size*100)
plot_title = "Modular Addition (training on {} of data) with {}".format(split_tag, "AdamW With Weight Decay")

# The metrics stored at index e were recorded once epoch e had finished, i.e.
# after (e + 1) * steps_per_epoch updates. Counting from 1 also keeps the
# first point drawable on the logarithmic x-axis, which cannot show x = 0.
steps = (torch.arange(len(train_acc)).numpy() + 1) * steps_per_epoch

# savefig does not create missing directories.
os.makedirs("results", exist_ok=True)


def _save_curves(steps, train_values, val_values, title, ylabel, out_path):
  """Plot train/val curves against `steps` on a log-10 x-axis and save to `out_path`."""
  fig, ax = plt.subplots()
  ax.plot(steps, train_values, label="train")
  ax.plot(steps, val_values, label="val")
  ax.legend()
  ax.set_title(title)
  ax.set_xlabel("Steps")
  ax.set_ylabel(ylabel)
  ax.set_xscale("log", base=10)
  fig.savefig(out_path, dpi=150)
  # Close the figure so it is neither rendered inline nor kept in memory.
  plt.close(fig)


_save_curves(steps, train_acc, val_acc, plot_title, "Accuracy",
             "results/acc_{}_adamW_wdecay.png".format(split_tag))
_save_curves(steps, train_loss, val_loss, plot_title, "Loss",
             "results/loss_{}_adamW_wdecay.png".format(split_tag))