## A. Train the transformer model

In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
import sys
from pathlib import Path

# Add the parent directory to the import path before the project-local imports
# below (data, evals, models) -- presumably that is where they live; TODO confirm.
sys.path.append(str(Path("..").resolve()))
# Root folder for everything this notebook writes: ./results under the current
# working directory.
save_dir = Path(".").resolve() / "results"

import numpy as np

import torch
import torch.nn.functional as F

from data import Dataset, DataArgs, generate_data
from evals import attn1_score, attn2_score
from models.transformer import Transformer, TransformerConfig

# Device used for the inputs and models in the training cells.
DEVICE = torch.device("cpu")  # torch.device('cuda:0')
# Single seed shared by every random source seeded below.
SEED = 42
# Explicit NumPy generator; passed to Dataset.gen_batch in the Shakespeare section.
RNG = np.random.default_rng(SEED)
# Also seed NumPy's legacy global state and torch's global generator.
np.random.seed(seed=SEED)
torch.manual_seed(seed=SEED)

<torch._C.Generator at 0x7f1da7b32d70>

### 1. Synthetic data 

In [3]:
# Data
vocab_size = 10
bsz = 1024
length = 128
# One extra token is requested because the training cell drops the last token
# for the inputs and the first token for the targets.
# Assumes generate_data returns a (bsz, length + 1) tensor -- TODO confirm.
data = generate_data(vocab_size, bsz, length + 1)

# Model
# Only the fields below are set explicitly; every other TransformerConfig field
# keeps its default.
config = TransformerConfig(
    vocab_size=vocab_size,
    emb_dim=32,
    seq_len=length,
    n_head=1,
    n_layer=2,
)
# NOTE(review): the training cell builds a fresh Transformer(config) for each
# run, so this instance is never trained. It likely still draws from the torch
# RNG during initialisation, so removing it could change seeded results.
model = Transformer(config)

In [4]:
# Training: one full-batch run per optimizer variant listed in `names`.
niter = 1  # 100
names = ["Adam"]  # ['SGD', 'Adam', 'AdamLN', 'SGDLN']
all_losses, all_accs, all_attn1, all_attn2 = {}, {}, {}, {}

# Next-token setup: inputs drop the last token, targets drop the first.
X = data[:, :-1].to(dtype=torch.long, device=DEVICE)
Y = data[:, 1:].to(dtype=torch.long, device=DEVICE)

for name in names:
    print(name, flush=True)

    # A run name ending in "LN" turns the config's norm_layer flag on.
    # `config` is mutated in place, so the flag persists after this cell.
    config.norm_layer = name.endswith("LN")
    print(config, flush=True)

    # Fresh model for every run.
    model = Transformer(config)
    model.to(device=DEVICE)

    # Names starting with "SGD" use momentum-free SGD; all others use Adam
    # with both betas set to 0.
    if name.startswith("SGD"):
        lr = 1e-2
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0)
    else:
        lr = 1e-3
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0, 0))

    # Per-iteration statistics for this run.
    losses, accs, attn1, attn2 = (torch.zeros(niter) for _ in range(4))

    for i in range(niter):
        optimizer.zero_grad()

        # Forward pass on the whole batch; the call returns the scores
        # together with a second value that is indexed as attention below.
        score, attns = model(X, verbose=True)
        logits_flat = score.reshape((-1, vocab_size))
        loss = F.cross_entropy(logits_flat, Y.reshape(-1))

        loss.backward()
        optimizer.step()

        # Statistics come from the forward pass above, i.e. from the
        # parameters as they were before this step's update.
        with torch.no_grad():
            losses[i] = loss.item()
            accs[i] = score.argmax(dim=-1).eq(Y).float().mean()
            # presumably attns is indexed (layer, batch, head, ...) -- TODO confirm
            attn1[i] = attn1_score(attns[0, :, 0])
            attn2[i] = attn2_score(attns[1, :, 0], X)

    # File this run's curves under its name.
    all_losses[name] = losses
    all_accs[name] = accs
    all_attn1[name] = attn1
    all_attn2[name] = attn2

Adam
TransformerConfig(vocab_size=10, emb_dim=32, pos_emb=True, pos_dim=32, freeze_pos=False, seq_len=128, emb_dropout=0.0, n_head=1, attn_bias=False, attn_dropout=0.0, rope=False, rope_theta=10000, activation='gelu', ffn_dim=128, ffn_bias=False, ffn_dropout=0.0, norm_layer=False, norm_eps=1e-05, pre_norm=True, n_layer=2, flash=True, weight_tying=False, output_dropout=0.0, dropout=0.0)


In [5]:
# Output folder for the synthetic-data run; created if it does not exist yet.
path = save_dir / "toy_data"
path.mkdir(parents=True, exist_ok=True)

# Save model
# NOTE: `model` and `name` are whatever the training loop bound last, so with
# several entries in `names` only the final run's weights are written.
torch.save(model.state_dict(), path / f"{name}.pt")

# Save metrics: one pickle per statistic.
for metric_file, metric_values in (
    ("losses.pkl", all_losses),
    ("accs.pkl", all_accs),
    ("attn1.pkl", all_attn1),
    ("attn2.pkl", all_attn2),
):
    with open(path / metric_file, "wb") as f:
        pickle.dump(metric_values, f)

### 2. Shakespeare data

In [6]:
# Data
bsz = 1024
length = 128
args = DataArgs(seq_length=length)
ds = Dataset(args=args, train_test=None, bigram_outs=False)
# The vocabulary size comes from the dataset instead of being hard-coded.
vocab_size = ds.num_tokens
# One batch is drawn with the seeded NumPy generator; `outs` is not used in
# this cell or in the training cell that follows.
# NOTE(review): the synthetic section requests length + 1 tokens so the shifted
# inputs keep `length` positions. Confirm gen_batch also returns
# seq_length + 1 tokens; otherwise the inputs are one shorter than seq_len.
x, outs = ds.gen_batch(rng=RNG, batch_size=bsz)
data = torch.from_numpy(x)

# Model
# Only the fields below are set explicitly; every other TransformerConfig field
# keeps its default.
config = TransformerConfig(
    vocab_size=vocab_size,
    emb_dim=32,
    seq_len=length,
    n_head=1,
    n_layer=2,
)
# NOTE(review): the training cell builds a fresh Transformer(config) for each
# run, so this instance is never trained. It likely still draws from the torch
# RNG during initialisation, so removing it could change seeded results.
model = Transformer(config)

In [7]:
# Training: one full-batch run per optimizer variant listed in `names`.
niter = 1  # 100
names = ["Adam"]  # ['SGD', 'Adam', 'AdamLN', 'SGDLN']
all_losses, all_accs, all_attn1, all_attn2 = {}, {}, {}, {}

# Next-token setup: inputs drop the last token, targets drop the first.
X = data[:, :-1].to(dtype=torch.long, device=DEVICE)
Y = data[:, 1:].to(dtype=torch.long, device=DEVICE)

for name in names:
    print(name, flush=True)

    # A run name ending in "LN" turns the config's norm_layer flag on.
    # `config` is mutated in place, so the flag persists after this cell.
    config.norm_layer = name.endswith("LN")
    print(config, flush=True)

    # Fresh model for every run.
    model = Transformer(config)
    model.to(device=DEVICE)

    # Names starting with "SGD" use momentum-free SGD; all others use Adam
    # with both betas set to 0.
    if name.startswith("SGD"):
        lr = 1e-2
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0)
    else:
        lr = 1e-3
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0, 0))

    # Per-iteration statistics for this run.
    losses, accs, attn1, attn2 = (torch.zeros(niter) for _ in range(4))

    for i in range(niter):
        optimizer.zero_grad()

        # Forward pass on the whole batch; the call returns the scores
        # together with a second value that is indexed as attention below.
        score, attns = model(X, verbose=True)
        logits_flat = score.reshape((-1, vocab_size))
        loss = F.cross_entropy(logits_flat, Y.reshape(-1))

        loss.backward()
        optimizer.step()

        # Statistics come from the forward pass above, i.e. from the
        # parameters as they were before this step's update.
        with torch.no_grad():
            losses[i] = loss.item()
            accs[i] = score.argmax(dim=-1).eq(Y).float().mean()
            # presumably attns is indexed (layer, batch, head, ...) -- TODO confirm
            attn1[i] = attn1_score(attns[0, :, 0])
            attn2[i] = attn2_score(attns[1, :, 0], X)

    # File this run's curves under its name.
    all_losses[name] = losses
    all_accs[name] = accs
    all_attn1[name] = attn1
    all_attn2[name] = attn2

Adam
TransformerConfig(vocab_size=65, emb_dim=32, pos_emb=True, pos_dim=32, freeze_pos=False, seq_len=128, emb_dropout=0.0, n_head=1, attn_bias=False, attn_dropout=0.0, rope=False, rope_theta=10000, activation='gelu', ffn_dim=128, ffn_bias=False, ffn_dropout=0.0, norm_layer=False, norm_eps=1e-05, pre_norm=True, n_layer=2, flash=True, weight_tying=False, output_dropout=0.0, dropout=0.0)


#### Save model and metrics

In [8]:
# Output folder for the Shakespeare run; created if it does not exist yet.
path = save_dir / "shakespeare_data"
path.mkdir(parents=True, exist_ok=True)

# Save model
# NOTE: `model` and `name` are whatever the training loop bound last, so with
# several entries in `names` only the final run's weights are written.
torch.save(model.state_dict(), path / f"{name}.pt")

# Save metrics: one pickle per statistic.
for metric_file, metric_values in (
    ("losses.pkl", all_losses),
    ("accs.pkl", all_accs),
    ("attn1.pkl", all_attn1),
    ("attn2.pkl", all_attn2),
):
    with open(path / metric_file, "wb") as f:
        pickle.dump(metric_values, f)