In [1]:
# Automatically reload edited project modules (e.g. dyck_k_generator, transformer)
# before executing code, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from dyck_k_generator import constants

In [3]:
# Pick the compute device: first CUDA GPU, then Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cuda:0'

In [4]:
# Release cached allocator memory on the active accelerator; nothing to do on CPU.
if device == "cuda:0":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

In [5]:
# Number of bracket-pair types (the k in Dyck-k); the vocabulary cell below keeps
# the first k entries of constants.BRACKETS.
k = 1

In [6]:
# Vocabulary string: each of the first k (key, value) bracket pairs from
# constants.BRACKETS concatenated in dict order, e.g. "()" for k = 1.
VOCAB = "".join(
    opener + closer
    for opener, closer in list(constants.BRACKETS.items())[:k]
)
VOCAB

'()'

In [7]:
from transformer.dataset import DyckLanguageDataset

In [8]:
# Load the Dyck samples and move the dataset to the selected device.
# NOTE(review): the filename is hard-coded to a dyck-1 file; keep it in sync with `k`.
dataset = DyckLanguageDataset("data/dyck-1_50000-samples_80-len_p05.jsonl", VOCAB)
dataset = dataset.to(device)

Loaded 50000 samples from data/dyck-1_50000-samples_80-len_p05.jsonl


In [9]:
from torch.utils.data import random_split

# 80/20 train/test split. A dedicated, seeded generator makes the split
# reproducible across kernel restarts; without `generator=`, random_split draws
# from torch's global RNG and yields a different split on every run.
SPLIT_SEED = 42

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [10]:
from torch.utils.data import DataLoader

In [11]:
# Training loader: shuffled mini-batches of 16 samples.
dl = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)

In [12]:
# Evaluation loader. The aggregate test metrics do not depend on batch order, so
# shuffling only adds nondeterminism; iterate in a fixed order instead.
test_dl = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Manual Transformer + BERTViz


In [13]:
from transformer.hooked_transformer import (
    TransformerClassifier,
    TransformerClassifierConfig,
    pad_token_mask,
)

In [14]:
# A deliberately tiny encoder: one layer, one attention head, two output classes.
# NOTE(review): vocab_size is len(VOCAB) == 2, yet the printed model has
# Embedding(5, 128); presumably the model adds special tokens (e.g. padding). Confirm.
model_config = TransformerClassifierConfig(
    vocab_size=len(VOCAB),
    d_model=128,
    n_heads=1,
    dim_ff=256,
    n_layers=1,
    n_classes=2,
)

In [15]:
# Instantiate the classifier from `model_config`.
# NOTE(review): no global torch seed is set beforehand, so any random parameter
# initialization differs between runs.
model = TransformerClassifier(model_config)

In [16]:
# Put the model in training mode so its Dropout(p=0.1) layers are active.
model.train()

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=128, out_features=128, bias=True)
        (wk): Linear(in_features=128, out_features=128, bias=True)
        (wv): Linear(in_features=128, out_features=128, bias=True)
        (wo): Linear(in_features=128, out_features=128, bias=True)
        (attn): ScaledDotProductAttention()
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): ReLU()
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Line

In [17]:
# Move the parameters to the selected device; Module.to returns the module, so its repr is displayed.
model.to(device=device)

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=128, out_features=128, bias=True)
        (wk): Linear(in_features=128, out_features=128, bias=True)
        (wv): Linear(in_features=128, out_features=128, bias=True)
        (wo): Linear(in_features=128, out_features=128, bias=True)
        (attn): ScaledDotProductAttention()
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): ReLU()
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Line

In [18]:
import torch.optim as optim

# Cross-entropy over the class logits, optimized with Adam at lr = 1e-4.
crit = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [19]:
from tqdm.auto import tqdm

In [23]:
epochs = 5
log_every = 100  # batches between progress prints

for epoch in range(epochs):
    # Re-enter training mode every epoch. If this cell is re-run after
    # `model.eval()`, dropout would otherwise silently stay disabled.
    model.train()

    running_loss = 0.0
    total_correct = 0
    total_samples = 0

    for i, (_, labels, tokens) in enumerate(tqdm(dl, desc=f"epoch {epoch + 1}")):
        # `.long()` converts in place on the current device. `.type(torch.LongTensor)`
        # forces a CPU tensor, which then has to be copied back to `device`.
        labels = labels.long().to(device)
        tokens = tokens.to(device)

        optimizer.zero_grad()

        # NOTE(review): the model is called without a padding mask, so pad positions
        # take part in attention. `pad_token_mask(tokens)` can build one; if masking
        # is enabled, pass it both here and wherever the model is evaluated so that
        # training and evaluation agree.
        outputs = model(tokens, mask=None)
        loss = crit(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Predicted class = index of the largest logit.
        predicted = outputs.argmax(dim=1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        if (i + 1) % log_every == 0:
            # Loss is the mean over the last `log_every` batches; accuracy is
            # cumulative since the start of the epoch.
            accuracy = (total_correct / total_samples) * 100
            print(
                f"Epoch: {epoch + 1}, Loss: {running_loss / log_every}, Accuracy: {accuracy}%"
            )
            running_loss = 0.0
            running_loss = 0.0

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.1252574836742133, Accuracy: 96.5%
Epoch: 1, Loss: 0.1612290677241981, Accuracy: 95.875%
Epoch: 1, Loss: 0.1418215196020901, Accuracy: 95.9375%
Epoch: 1, Loss: 0.14777467732317745, Accuracy: 95.90625%
Epoch: 1, Loss: 0.1371242884453386, Accuracy: 95.975%
Epoch: 1, Loss: 0.15617146665230394, Accuracy: 95.89583333333334%
Epoch: 1, Loss: 0.11831501048058271, Accuracy: 96.03571428571429%
Epoch: 1, Loss: 0.298379271607846, Accuracy: 95.25%
Epoch: 1, Loss: 0.18436361694708467, Accuracy: 95.11111111111111%
Epoch: 1, Loss: 0.14364695316180587, Accuracy: 95.19375%
Epoch: 1, Loss: 0.1246203856356442, Accuracy: 95.31818181818181%
Epoch: 1, Loss: 0.14554212672635913, Accuracy: 95.38541666666667%
Epoch: 1, Loss: 0.1512499910220504, Accuracy: 95.41826923076923%
Epoch: 1, Loss: 0.12865840824320912, Accuracy: 95.48660714285714%
Epoch: 1, Loss: 0.1189960309676826, Accuracy: 95.57916666666667%
Epoch: 1, Loss: 0.15177488762885333, Accuracy: 95.59765625%
Epoch: 1, Loss: 0.1563127097487449

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.19645076582208276, Accuracy: 93.3125%
Epoch: 2, Loss: 0.16730761169455946, Accuracy: 93.96875%
Epoch: 2, Loss: 0.15643150942400097, Accuracy: 94.41666666666667%
Epoch: 2, Loss: 0.13075309737585486, Accuracy: 94.921875%
Epoch: 2, Loss: 0.13694046798162163, Accuracy: 95.16250000000001%
Epoch: 2, Loss: 0.14205553133040666, Accuracy: 95.26041666666667%
Epoch: 2, Loss: 0.14888643636368215, Accuracy: 95.3125%
Epoch: 2, Loss: 0.1439134368300438, Accuracy: 95.3984375%
Epoch: 2, Loss: 0.11215221283026039, Accuracy: 95.59722222222223%
Epoch: 2, Loss: 0.1628236943297088, Accuracy: 95.56875000000001%
Epoch: 2, Loss: 0.15365199402906002, Accuracy: 95.56818181818181%
Epoch: 2, Loss: 0.12017332941293717, Accuracy: 95.67708333333333%
Epoch: 2, Loss: 0.13315392060205342, Accuracy: 95.73557692307692%
Epoch: 2, Loss: 0.1207619784027338, Accuracy: 95.80803571428571%
Epoch: 2, Loss: 0.1377338131237775, Accuracy: 95.83333333333334%
Epoch: 2, Loss: 0.1470795820467174, Accuracy: 95.82421875%

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.1215998753439635, Accuracy: 96.5625%
Epoch: 3, Loss: 0.18043030761182308, Accuracy: 95.4375%
Epoch: 3, Loss: 0.14169106606394052, Accuracy: 95.64583333333333%
Epoch: 3, Loss: 0.13094644121825694, Accuracy: 95.84375%
Epoch: 3, Loss: 0.1352026926353574, Accuracy: 95.9125%
Epoch: 3, Loss: 0.13988829903304578, Accuracy: 95.96875%
Epoch: 3, Loss: 0.14037946834228932, Accuracy: 95.98214285714286%
Epoch: 3, Loss: 0.1800075493659824, Accuracy: 95.78125%
Epoch: 3, Loss: 0.1470486306026578, Accuracy: 95.74305555555556%
Epoch: 3, Loss: 0.11972265152260661, Accuracy: 95.84375%
Epoch: 3, Loss: 0.12632695742882788, Accuracy: 95.88068181818183%
Epoch: 3, Loss: 0.15436113554984332, Accuracy: 95.859375%
Epoch: 3, Loss: 0.14399097309447825, Accuracy: 95.86538461538461%
Epoch: 3, Loss: 0.12037418198771775, Accuracy: 95.93303571428572%
Epoch: 3, Loss: 0.14235676418989895, Accuracy: 95.9375%
Epoch: 3, Loss: 0.11693075906485319, Accuracy: 96.00390625%
Epoch: 3, Loss: 0.11701151153072714, A

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.1815065824612975, Accuracy: 94.3125%
Epoch: 4, Loss: 0.13000142465345563, Accuracy: 95.375%
Epoch: 4, Loss: 0.12617232226766645, Accuracy: 95.77083333333334%
Epoch: 4, Loss: 0.11283955536782742, Accuracy: 96.109375%
Epoch: 4, Loss: 0.16363862715661526, Accuracy: 95.89999999999999%
Epoch: 4, Loss: 0.1067474174965173, Accuracy: 96.13541666666666%
Epoch: 4, Loss: 0.14650901495479046, Accuracy: 96.08928571428571%
Epoch: 4, Loss: 0.12933372266590595, Accuracy: 96.1484375%
Epoch: 4, Loss: 0.14032363709993662, Accuracy: 96.13888888888889%
Epoch: 4, Loss: 0.10912287363782525, Accuracy: 96.24374999999999%
Epoch: 4, Loss: 0.1250154340453446, Accuracy: 96.2784090909091%
Epoch: 4, Loss: 0.18257066479884088, Accuracy: 96.09895833333333%
Epoch: 4, Loss: 0.14889267748221754, Accuracy: 96.07211538461539%
Epoch: 4, Loss: 0.14122974935919047, Accuracy: 96.04017857142857%
Epoch: 4, Loss: 0.15547442353330554, Accuracy: 96.00416666666666%
Epoch: 4, Loss: 0.1304411625303328, Accuracy: 96.0

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.13243139388039707, Accuracy: 96.375%
Epoch: 5, Loss: 0.14560538660734892, Accuracy: 96.03125%
Epoch: 5, Loss: 0.14247733395546675, Accuracy: 95.95833333333333%
Epoch: 5, Loss: 0.21314703173935412, Accuracy: 95.234375%
Epoch: 5, Loss: 0.15644852874800563, Accuracy: 95.1375%
Epoch: 5, Loss: 0.13537703832611442, Accuracy: 95.3125%
Epoch: 5, Loss: 0.13043797798454762, Accuracy: 95.46428571428571%
Epoch: 5, Loss: 0.13087120265699922, Accuracy: 95.59375%
Epoch: 5, Loss: 0.13894293339923025, Accuracy: 95.63888888888889%
Epoch: 5, Loss: 0.1251847102213651, Accuracy: 95.73125%
Epoch: 5, Loss: 0.3613552219979465, Accuracy: 94.9375%
Epoch: 5, Loss: 0.15554143078625202, Accuracy: 94.98958333333334%
Epoch: 5, Loss: 0.13141772048547865, Accuracy: 95.0673076923077%
Epoch: 5, Loss: 0.167375581972301, Accuracy: 95.03125%
Epoch: 5, Loss: 0.14805990718305112, Accuracy: 95.09166666666667%
Epoch: 5, Loss: 0.1373562161438167, Accuracy: 95.1640625%
Epoch: 5, Loss: 0.14191289467737078, Accur

In [24]:
# Switch to evaluation mode (Dropout layers become no-ops) before measuring test metrics.
model.eval()

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=128, out_features=128, bias=True)
        (wk): Linear(in_features=128, out_features=128, bias=True)
        (wv): Linear(in_features=128, out_features=128, bias=True)
        (wo): Linear(in_features=128, out_features=128, bias=True)
        (attn): ScaledDotProductAttention()
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): ReLU()
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Line

In [27]:
# Evaluate on the held-out split.
# The loss is accumulated per sample, so the average stays exact even when the
# final batch is smaller than the loader's batch size. Averaging per-batch means
# would over-weight a short last batch.
model.eval()  # don't rely on an earlier cell having disabled dropout

criterion = torch.nn.CrossEntropyLoss()

correct = 0
total = 0
total_loss = 0.0

with torch.no_grad():  # no autograd graph needed: saves memory and compute
    for _, labels, tokens in test_dl:
        labels = labels.long().to(device)
        tokens = tokens.to(device)

        # NOTE(review): no padding mask is passed; keep this consistent with how
        # the model is called during training.
        outputs = model(tokens)

        n_batch = labels.size(0)
        # CrossEntropyLoss defaults to reduction="mean"; rescale to a batch sum.
        total_loss += criterion(outputs, labels).item() * n_batch

        # Predicted class (0 or 1) = index of the largest logit.
        predicted = outputs.argmax(dim=1)

        total += n_batch
        correct += (predicted == labels).sum().item()

if total == 0:
    raise ValueError("test_dl yielded no samples; cannot compute test metrics")

avg_loss = total_loss / total
accuracy = 100 * correct / total

print(f"Accuracy of the model on the test data: {accuracy:.2f}%")
print(f"Average loss on the test data: {avg_loss:.4f}")

Accuracy of the model on the test data: 96.19%
Average loss on the test data: 0.1328
