In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the local packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import torch

from dyck_k_generator import constants

In [4]:
# Pick the compute device: CUDA when available, otherwise Apple MPS,
# otherwise fall back to the CPU. MPS is only probed when CUDA is absent.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

  return torch._C._cuda_getDeviceCount() > 0


'cpu'

In [None]:
# Release cached allocator memory on the active accelerator.
# Nothing is done when running on the CPU.
if device == "cuda:0":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

In [None]:
# Number of bracket types (the "k" in Dyck-k). It is passed to the dataset
# generator and selects the first k bracket pairs that make up VOCAB.
k = 1

In [None]:
from dyck_k_generator.generator import generate_dataset

# Generate a dataset for the current k.
# NOTE(review): the return value is not captured here, while the dataset-loading
# cell further down reads a hard-coded file name — confirm which file is
# actually meant to be used for this experiment.
generate_dataset(n=10_000, k=k, max_length=10, balanced=0.6)

In [None]:
# Vocabulary string: the first k (opening, closing) bracket pairs from
# constants.BRACKETS, concatenated in order.
bracket_pairs = list(constants.BRACKETS.items())[:k]
VOCAB = "".join("".join(pair) for pair in bracket_pairs)
VOCAB

In [None]:
from dataset.dataset import DyckLanguageDataset

In [None]:
# Build the dataset from a file on disk and call .to(device) on it.
# NOTE(review): this path is hard-coded and is not the value returned by the
# generate_dataset call above — the file presumably has to exist already.
dyck1_data_path = "data/dyck-1_500000-samples_8-len_p05.jsonl"
dataset = DyckLanguageDataset(dyck1_data_path, VOCAB).to(device)

In [None]:
from torch.utils.data import random_split

# Split sizes: 80% for training, a validation set sized at 15% of the training
# set, and whatever remains for testing (so the three sizes sum to the total).
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
val_size = int(0.15 * train_size)
test_size = n_samples - train_size - val_size

split_lengths = [train_size, val_size, test_size]
train_dataset, val_dataset, test_dataset = random_split(dataset, split_lengths)

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Only the training loader shuffles; validation and test iterate in a fixed
# order. The test loader uses a smaller batch size (8 instead of 64).
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

# Manual Transformer + BERTViz


In [None]:
from transformer.hooked_transformer import (
    TransformerClassifier,
    TransformerClassifierConfig,
    pad_token_mask,
)

In [None]:
# Configuration for the small single-layer, single-head classifier.
# NOTE(review): vocab_size counts only the bracket characters in VOCAB; the
# tokenizer appears to have special tokens as well — confirm that the model
# accounts for them.
model_config = TransformerClassifierConfig(
    # input / output sizes
    vocab_size=len(VOCAB),
    max_seq_len=10,
    n_classes=2,
    # encoder shape
    n_layers=1,
    n_heads=1,
    d_model=128,
    dim_ff=256,
)

In [None]:
# Instantiate the classifier from model_config.
model = TransformerClassifier(model_config)

In [None]:
# Call .to(device) on the model; as the cell's last expression, the returned
# value is also displayed.
model.to(device)

In [None]:
import torch.optim as optim

# Loss function and optimizer handed to the training call:
# cross-entropy loss, and Adam over all model parameters with lr = 1e-4.
crit = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train for 10 epochs with masking enabled, evaluating on the validation
# loader. The call returns four values, unpacked as train/validation loss and
# accuracy (they are plotted as curves further down).
training_kwargs = dict(
    device=device,
    epochs=10,
    optimizer=optimizer,
    criterion=crit,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    use_mask=True,
)
train_loss, train_acc, val_loss, val_acc = model.train_model(**training_kwargs)

In [None]:
# Evaluate the trained model on the held-out test loader, with masking enabled;
# the call returns two values, unpacked as test loss and test accuracy.
evaluation_kwargs = dict(
    device=device,
    test_dataloader=test_dataloader,
    criterion=crit,
    use_mask=True,
)
test_loss, test_acc = model.eval_model(**evaluation_kwargs)

In [None]:
import matplotlib.pyplot as plt

# Draw the four training/validation curves on a single set of axes,
# in the same order and with the same legend entries as before.
curves = (
    (train_loss, "train loss"),
    (val_loss, "val loss"),
    (train_acc, "train acc"),
    (val_acc, "val acc"),
)
for series, legend_name in curves:
    plt.plot(series, label=legend_name)
plt.legend()
plt.show()

# Attention plots:


In [None]:
# Take the first batch from the test loader; a batch unpacks into
# (strings, labels, tokens).
batch = next(iter(test_dataloader))
strings, labels, tokens = batch

# Build the padding mask for the tokens and ask the model for its attention
# matrices on this batch.
mask = pad_token_mask(tokens)
attn_matrices = model.get_attn_matrices(tokens, mask)

In [None]:
# Show the first sample of the batch: its string, its tokens and its label.
print(strings[0], tokens[0], labels[0])

In [None]:
# Shape of the first element of attn_matrices.
attn_matrices[0].shape

In [None]:
# Select element [0][0] of attn_matrices and display its first entry.
# assumes attn_matrices is indexed [layer][batch][head] — TODO confirm
attn_matrix = attn_matrices[0][0]
attn_matrix[0]

In [None]:
from dataset.dataset import DyckLanguageTokenizer

In [None]:
# Tokenizer built over the same vocabulary string as the dataset.
tokenizer = DyckLanguageTokenizer(VOCAB)

In [None]:
# Tick labels for the heatmap: the decoded tokens of the first sequence, with
# special tokens kept, split on spaces.
# NOTE(review): this rebinds `labels`, which earlier held the batch's labels;
# re-running the earlier print cell after this one will show these strings.
labels = tokenizer.decode_single(tokens[0], remove_special_tokens=False).split(" ")

# imshow draws columns along the x axis and rows along the y axis, so x ticks
# are sized from the column count and y ticks from the row count. (For a
# square attention matrix the two are the same.)
attn_rows, attn_cols = attn_matrices[0][0][0].shape[:2]
x_ticks = list(range(attn_cols))
y_ticks = list(range(attn_rows))

In [None]:
# Length of the first string in the batch.
len(strings[0])

In [None]:
# Shape of the matrix that the heatmap cell plots.
attn_matrices[0][0][0].shape

In [None]:
import numpy as np


def z_score_normalize(matrix):
    """Standardise a matrix to zero mean and unit standard deviation.

    The mean and standard deviation are computed over all elements.

    Args:
        matrix: Array-like of numbers.

    Returns:
        A NumPy array of the same shape holding ``(matrix - mean) / std``.
        If every element is equal (standard deviation of zero) an all-zero
        float array is returned instead of dividing by zero. An empty input
        is returned as an empty float array.
    """
    matrix = np.asarray(matrix)
    if matrix.size == 0:
        # Nothing to normalise; avoid NaN warnings from mean/std of empty input.
        return matrix.astype(float)

    mean = np.mean(matrix)
    std = np.std(matrix)
    if std == 0:
        # Constant input: every element sits exactly at the mean.
        return np.zeros_like(matrix, dtype=float)
    return (matrix - mean) / std


def min_max_normalize(matrix):
    """Rescale a matrix linearly so its minimum maps to 0 and its maximum to 1.

    The minimum and maximum are taken over all elements.

    Args:
        matrix: Array-like of numbers.

    Returns:
        A NumPy array of the same shape holding
        ``(matrix - min) / (max - min)``. If every element is equal
        (``max == min``) an all-zero float array is returned instead of
        dividing by zero. An empty input is returned as an empty float array.
    """
    matrix = np.asarray(matrix)
    if matrix.size == 0:
        # np.min / np.max raise on empty input; there is nothing to rescale.
        return matrix.astype(float)

    min_val = np.min(matrix)
    max_val = np.max(matrix)
    value_range = max_val - min_val
    if value_range == 0:
        # Constant input: no spread to rescale.
        return np.zeros_like(matrix, dtype=float)
    return (matrix - min_val) / value_range

In [None]:
import matplotlib.pyplot as plt

# Heatmap of the first attention matrix, min-max normalised before plotting.
# Tick positions and labels come from the earlier tick-setup cell.
first_attn = attn_matrices[0][0][0].cpu().detach().numpy()

plt.figure(figsize=(10, 10))
heatmap = plt.imshow(
    min_max_normalize(first_attn), cmap="coolwarm", interpolation="nearest"
)
plt.xticks(ticks=x_ticks, labels=labels)
plt.yticks(ticks=y_ticks, labels=labels)

cbar = plt.colorbar(heatmap)
cbar.set_label("Attention weights")
plt.show()

In [None]:
# Min-max normalised NumPy copy of the same attention matrix that the heatmap
# cell plots.
minimax_norm = min_max_normalize(attn_matrices[0][0][0].cpu().detach().numpy())

In [None]:
# Inspect a single normalised entry.
# NOTE(review): index 11 is hard-coded and assumes the matrix has at least
# 12 rows and columns — confirm against the shape printed earlier.
minimax_norm[11][11]

# Experiment 2


In [None]:
# Experiment 2 uses three bracket types; VOCAB is rebuilt from this value below.
k = 3

In [None]:
# Generate the dataset for this experiment and keep the returned value, which
# is handed to DyckLanguageDataset as the file to load.
path = generate_dataset(
    n=10_000,
    # Use the notebook-level k (rather than a hard-coded 3) so the generated
    # data and the VOCAB derived from k cannot drift apart.
    k=k,
    max_length=512,
    balanced=0.6,
)

In [None]:
# Rebuild the vocabulary string for the new k: the first k (opening, closing)
# bracket pairs from constants.BRACKETS, concatenated in order.
bracket_pairs = list(constants.BRACKETS.items())[:k]
VOCAB = "".join("".join(pair) for pair in bracket_pairs)
VOCAB

In [None]:
# Load the dataset generated for this experiment from `path` and call
# .to(device) on it.
# NOTE(review): the name says "dyck_2" although k is set to 3 in this section —
# consider renaming.
dataset_dyck_2 = DyckLanguageDataset(path, VOCAB).to(device)

In [None]:
# Split the Experiment 2 dataset: 80% for training, a validation set sized at
# 15% of the training set, and the remainder for testing.
#
# Every size and the split itself must refer to dataset_dyck_2. Using the
# Experiment 1 `dataset` here would either fail (lengths not summing to the
# dataset size) or, worse, silently train this experiment on the wrong data.
n_samples = len(dataset_dyck_2)
train_size = int(0.8 * n_samples)
val_size = int(0.15 * train_size)
test_size = n_samples - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset_dyck_2, [train_size, val_size, test_size]
)

In [None]:
# Rebuild the loaders over the new splits. Only the training loader shuffles;
# the test loader uses a smaller batch size (8 instead of 64).
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

In [None]:
# Larger configuration for Experiment 2: two layers, four heads.
# NOTE(review): vocab_size counts only the bracket characters in VOCAB, and
# max_seq_len equals the generator's max_length; if the tokenizer adds special
# tokens, both may be too small — confirm.
model_config = TransformerClassifierConfig(
    # input / output sizes
    vocab_size=len(VOCAB),
    max_seq_len=512,
    n_classes=2,
    # encoder shape
    n_layers=2,
    n_heads=4,
    d_model=512,
    dim_ff=1024,
)

In [None]:
# Build a fresh classifier from the new configuration and call .to(device) on
# it; this replaces the Experiment 1 model bound to `model`.
model = TransformerClassifier(model_config).to(device)

In [None]:
# Fresh loss function and optimizer for the new model:
# cross-entropy loss, and Adam over all model parameters with lr = 1e-4.
crit = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train the Experiment 2 model for 10 epochs with masking enabled, evaluating
# on the validation loader. The four returned values overwrite the
# Experiment 1 results held in the same names.
training_kwargs = dict(
    device=device,
    epochs=10,
    optimizer=optimizer,
    criterion=crit,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    use_mask=True,
)
train_loss, train_acc, val_loss, val_acc = model.train_model(**training_kwargs)

In [None]:
# Evaluate the Experiment 2 model on its test loader, with masking enabled;
# the call returns two values, unpacked as test loss and test accuracy.
evaluation_kwargs = dict(
    device=device,
    test_dataloader=test_dataloader,
    criterion=crit,
    use_mask=True,
)
test_loss, test_acc = model.eval_model(**evaluation_kwargs)