In [None]:
# Reload edited project modules (dyck_k_generator, dataset, transformer, ...)
# automatically, so changes to their .py files take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch

from dyck_k_generator import constants

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

In [None]:
# Release memory held by PyTorch's caching allocator on the selected accelerator
# (no-op branch for CPU).
if device == "cuda:0":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

In [None]:
torch.manual_seed(42)

In [None]:
k = 1

In [None]:
from dyck_k_generator.generator import generate_dataset

path = generate_dataset(
    n=1000,
    k=k,
    max_length=10,
    balanced=0.55,
)

In [None]:
VOCAB = "".join(
    ["".join((key, value)) for key, value in list(constants.BRACKETS.items())[:k]]
)
VOCAB

In [None]:
from dataset.dataset import DyckLanguageDataset

In [None]:
dataset = DyckLanguageDataset(path, VOCAB).to(device)

In [None]:
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
val_size = int(0.15 * train_size)
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

In [None]:
from torch.utils.data import DataLoader

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

# Hooked Transformer (Bidirectional mask) - Dyck-1 dataset


In [None]:
from transformer.hooked_transformer import (
    TransformerClassifier,
    TransformerClassifierConfig,
    causal_mask,
    pad_token_mask,
)

In [None]:
model_config = TransformerClassifierConfig(
    vocab_size=len(VOCAB),
    d_model=256,
    n_heads=1,
    dim_ff=384,
    n_layers=1,
    n_classes=2,
    max_seq_len=10,
)

In [None]:
model_bidirectional = TransformerClassifier(model_config)

In [None]:
model_bidirectional.to(device)

In [None]:
import torch.optim as optim

crit = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model_bidirectional.parameters(), lr=1e-4)

In [None]:
train_loss, train_acc, val_loss, val_acc = model_bidirectional.train_model(
    device=device,
    epochs=20,
    optimizer=optimizer,
    criterion=crit,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    use_mask="bidirectional",
)

In [None]:
test_loss, test_acc = model_bidirectional.eval_model(
    device=device,
    test_dataloader=test_dataloader,
    criterion=crit,
    use_mask="bidirectional",
)

In [None]:
import matplotlib.pyplot as plt

# Loss and accuracy live on different scales, so give each its own labeled axes
# instead of overlaying all four curves on one unlabeled y-axis.
# Assumes train_model returns one value per epoch -- TODO confirm.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_loss, label="train loss")
loss_ax.plot(val_loss, label="val loss")
loss_ax.set(title="Loss", xlabel="epoch", ylabel="loss")
loss_ax.legend()

acc_ax.plot(train_acc, label="train acc")
acc_ax.plot(val_acc, label="val acc")
acc_ax.set(title="Accuracy", xlabel="epoch", ylabel="accuracy")
acc_ax.legend()

fig.suptitle("Bidirectional mask - Dyck-1")
fig.tight_layout()
plt.show()

# Hooked Transformer (causal mask) - Dyck-1 dataset

In [None]:
model_causal = TransformerClassifier(model_config).to(device)

In [None]:
crit = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model_causal.parameters(), lr=1e-4)

In [None]:
train_loss, train_acc, val_loss, val_acc = model_causal.train_model(
    device=device,
    epochs=50,
    optimizer=optimizer,
    criterion=crit,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    use_mask="causal",
)

In [None]:
test_loss, test_acc = model_causal.eval_model(
    device=device,
    test_dataloader=test_dataloader,
    criterion=crit,
    use_mask="causal",
)

In [None]:
import matplotlib.pyplot as plt

# Separate, labeled axes for loss and accuracy (different scales).
# Assumes train_model returns one value per epoch -- TODO confirm.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_loss, label="train loss")
loss_ax.plot(val_loss, label="val loss")
loss_ax.set(title="Loss", xlabel="epoch", ylabel="loss")
loss_ax.legend()

acc_ax.plot(train_acc, label="train acc")
acc_ax.plot(val_acc, label="val acc")
acc_ax.set(title="Accuracy", xlabel="epoch", ylabel="accuracy")
acc_ax.legend()

fig.suptitle("Causal mask - Dyck-1")
fig.tight_layout()
plt.show()

# Attention plots - Dyck-1 bidirectional model


In [None]:
# Take one (shuffled) Dyck-1 test batch and pull the bidirectional model's
# attention matrices for it, masking padding positions.
test_batch_iter = iter(test_dataloader)
batch = next(test_batch_iter)
strings, labels, tokens = batch

mask = pad_token_mask(tokens)
attn_matrices = model_bidirectional.get_attn_matrices(tokens, mask)

In [None]:
from transformer_viz.visualizer import min_max_normalize, plot_attn_matrices

In [None]:
from dataset.dataset import DyckLanguageTokenizer
from dyck_k_generator.checker import is_dyck_word

In [None]:
# Hand-written probe: every closer comes before every opener, so the counts are
# balanced but the string is not a Dyck word. It must not exceed the Dyck-1
# model's max_seq_len (10). The previous 14-character probe exceeded the
# positions the model was built for.
probe = ")))))((((("
assert len(probe) <= model_config.max_seq_len, (
    f"probe has {len(probe)} tokens but the model supports at most {model_config.max_seq_len}"
)

batch = (
    probe,
    is_dyck_word(probe, k=1),
    DyckLanguageTokenizer(VOCAB).tokenize(probe).to(device),
)
batch

In [None]:
plot_attn_matrices(VOCAB, batch, model_bidirectional, min_max_normalize, pad_token_mask)

In [None]:
minimax_norm = min_max_normalize(attn_matrices[0][0][0].cpu().detach().numpy())

In [None]:
minimax_norm[0][0]

# Hooked Transformer (Bidirectional mask) - Dyck-3 dataset


In [None]:
k = 3

In [None]:
path = generate_dataset(
    n=5_000,
    k=3,
    min_length=8,
    max_length=8,
    balanced=0.6,
)

In [None]:
VOCAB = "".join(
    ["".join((key, value)) for key, value in list(constants.BRACKETS.items())[:k]]
)
VOCAB

In [None]:
dataset_dyck_2 = DyckLanguageDataset(path, VOCAB).to(device)

In [None]:
train_size = int(0.8 * len(dataset_dyck_2))
val_size = int(0.15 * train_size)
test_size = len(dataset_dyck_2) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset_dyck_2, [train_size, val_size, test_size]
)

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=True)

In [None]:
model_config = TransformerClassifierConfig(
    vocab_size=len(VOCAB),
    d_model=512,
    n_heads=1,
    dim_ff=1024,
    n_layers=1,
    n_classes=2,
    max_seq_len=10,
)

In [None]:
model = TransformerClassifier(model_config).to(device)

In [None]:
crit = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [None]:
train_loss, train_acc, val_loss, val_acc = model.train_model(
    device=device,
    epochs=15,
    optimizer=optimizer,
    criterion=crit,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    use_mask="bidirectional",
)

In [None]:
test_loss, test_acc = model.eval_model(
    device=device,
    test_dataloader=test_dataloader,
    criterion=crit,
    use_mask="bidirectional",
)

In [None]:
import matplotlib.pyplot as plt

# Separate, labeled axes for loss and accuracy (different scales).
# Assumes train_model returns one value per epoch -- TODO confirm.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_loss, label="train loss")
loss_ax.plot(val_loss, label="val loss")
loss_ax.set(title="Loss", xlabel="epoch", ylabel="loss")
loss_ax.legend()

acc_ax.plot(train_acc, label="train acc")
acc_ax.plot(val_acc, label="val acc")
acc_ax.set(title="Accuracy", xlabel="epoch", ylabel="accuracy")
acc_ax.legend()

fig.suptitle("Bidirectional mask - Dyck-3")
fig.tight_layout()
plt.show()

In [None]:
batch = next(iter(test_dataloader))
batch

In [None]:
plot_attn_matrices(VOCAB, batch, model, min_max_normalize, pad_token_mask)

# Out-of-distribution length generalization - Dyck-3 (train on length 96, test on lengths 32-128)

In [None]:
k = 3

In [None]:
dyck_3_train_dataset = generate_dataset(
    n=50_000,
    k=3,
    min_length=96,
    max_length=96,
    balanced=0.5,
)

In [None]:
dyck_3_train_dataset

In [None]:
VOCAB = "".join(
    ["".join((key, value)) for key, value in list(constants.BRACKETS.items())[:k]]
)
VOCAB

In [None]:
dyck_3_train = DyckLanguageDataset(dyck_3_train_dataset, VOCAB).to(device)

train_dataset, val_dataset = random_split(
    dyck_3_train, [0.8, 0.2]
)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=True)

In [None]:
ood_model_config = TransformerClassifierConfig(
    vocab_size=len(VOCAB),
    d_model=256,
    n_heads=1,
    dim_ff=320,
    n_layers=2,
    n_classes=2,
    max_seq_len=128,
)

ood_model = TransformerClassifier(ood_model_config).to(device)

In [None]:
crit = torch.nn.CrossEntropyLoss()

optimizer = optim.Adam(ood_model.parameters(), lr=1e-4)

In [None]:
# Train only on length-96 strings; length generalisation is measured separately.
ood_train_kwargs = dict(
    device=device,
    epochs=10,
    optimizer=optimizer,
    criterion=crit,
    train_dataloader=train_dataloader,
    eval_dataloader=val_dataloader,
    use_mask="bidirectional",
)
train_loss, train_acc, val_loss, val_acc = ood_model.train_model(**ood_train_kwargs)

In [None]:
import matplotlib.pyplot as plt

# Separate, labeled axes for loss and accuracy (different scales).
# Assumes train_model returns one value per epoch -- TODO confirm.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

loss_ax.plot(train_loss, label="train loss")
loss_ax.plot(val_loss, label="val loss")
loss_ax.set(title="Loss", xlabel="epoch", ylabel="loss")
loss_ax.legend()

acc_ax.plot(train_acc, label="train acc")
acc_ax.plot(val_acc, label="val acc")
acc_ax.set(title="Accuracy", xlabel="epoch", ylabel="accuracy")
acc_ax.legend()

fig.suptitle("OOD Dyck-3 - trained on length 96")
fig.tight_layout()
plt.show()

In [None]:
dyck_3_test_dataset = generate_dataset(
    n=10_000,
    k=3,
    min_length=32,
    max_length=128,
    balanced=0.5,
)

In [None]:
dyck_3_test = DyckLanguageDataset(dyck_3_test_dataset, VOCAB).to(device)
test_dataloader = DataLoader(dyck_3_test, batch_size=8, shuffle=True)


In [None]:
test_loss, test_acc = ood_model.eval_model(
    device=device,
    test_dataloader=test_dataloader,
    criterion=crit,
    use_mask="bidirectional",
)