In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to the project's .py files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from dyck_k_generator import constants

In [3]:
# Pick the compute backend in order of preference: CUDA GPU, then Apple MPS,
# then plain CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cuda:0'

In [4]:
# Ask the active accelerator backend to drop its cached allocations; nothing
# is done when running on CPU.
if device == "cuda:0":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

In [5]:
# Number of bracket pairs taken from constants.BRACKETS when VOCAB is built
# below (k = 1 yields the single pair "()").
k = 1

In [6]:
# Concatenate the first k (opening, closing) bracket pairs into one vocabulary
# string, keeping the order in which constants.BRACKETS lists them.
VOCAB = "".join(
    opening + closing
    for opening, closing in list(constants.BRACKETS.items())[:k]
)
VOCAB

'()'

In [7]:
from dataset.dataset import DyckLanguageDataset

In [8]:
# Build the dataset from the Dyck-1 JSONL file using the vocabulary above,
# then call its `.to(device)` so it targets the selected backend.
dataset = DyckLanguageDataset(
    "data/dyck-1_50000-samples_80-len_p05.jsonl",
    VOCAB,
).to(device)

Loaded 50000 samples from data/dyck-1_50000-samples_80-len_p05.jsonl


In [9]:
from torch.utils.data import random_split

# Hold out 20% of the samples for evaluation (80/20 train/test split).
# The split is driven by a dedicated, explicitly seeded generator so that
# Restart & Run All reproduces the same partition, and therefore comparable
# train/test metrics, on every run.
SPLIT_SEED = 0

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [10]:
from torch.utils.data import DataLoader

In [11]:
# Training loader: shuffled mini-batches of 16 samples from the training split.
dl = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)

In [12]:
# Evaluation loader: shuffled mini-batches of 8 samples from the held-out split.
test_dl = DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=True,
)

# Manual Transformer + BERTViz


In [13]:
from transformer.hooked_transformer import (
    TransformerClassifier,
    TransformerClassifierConfig,
    pad_token_mask,
)

In [14]:
# Hyper-parameters for a minimal classifier: one encoder layer with a single
# attention head, sized from the vocabulary, with two output classes.
model_config = TransformerClassifierConfig(
    vocab_size=len(VOCAB),
    d_model=128,
    n_heads=1,
    dim_ff=256,
    n_layers=1,
    n_classes=2,
)

In [15]:
# Instantiate the classifier from the configuration object built above.
model = TransformerClassifier(model_config)

In [16]:
# Switch the model to training mode ahead of the optimisation loop; the value
# returned by the call is what the notebook echoes as this cell's output.
model.train()

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=128, out_features=128, bias=True)
        (wk): Linear(in_features=128, out_features=128, bias=True)
        (wv): Linear(in_features=128, out_features=128, bias=True)
        (wo): Linear(in_features=128, out_features=128, bias=True)
        (attn): ScaledDotProductAttention()
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): ReLU()
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Line

In [17]:
# Move the model to the selected device; the value returned by the call is
# what the notebook echoes as this cell's output.
model.to(device)

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=128, out_features=128, bias=True)
        (wk): Linear(in_features=128, out_features=128, bias=True)
        (wv): Linear(in_features=128, out_features=128, bias=True)
        (wo): Linear(in_features=128, out_features=128, bias=True)
        (attn): ScaledDotProductAttention()
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): ReLU()
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Line

In [18]:
import torch.optim as optim

# Adam over all model parameters with a learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Cross-entropy loss used as the training objective.
crit = torch.nn.CrossEntropyLoss()

In [19]:
from tqdm.auto import tqdm

In [20]:
# Train the classifier for a fixed number of epochs, reporting the mean loss
# of the last `log_every` batches and the running accuracy of the current
# epoch every `log_every` batches.
epochs = 5
log_every = 100  # batches between progress reports

# Make the cell self-contained: training mode is (re-)enabled here so that
# re-running this cell after an evaluation cell does not train in eval mode.
model.train()

for epoch in range(epochs):
    running_loss = 0.0  # sum of batch losses since the last report

    total_correct = 0  # correct predictions so far in this epoch
    total_samples = 0  # samples seen so far in this epoch

    for i, (_, labels, tokens) in enumerate(tqdm(dl)):
        # Cast and move in a single call. `labels.type(torch.LongTensor)`
        # would first materialise a CPU tensor and then copy it back to the
        # device on every batch.
        labels = labels.to(device=device, dtype=torch.long)
        tokens = tokens.to(device)

        optimizer.zero_grad()

        # NOTE(review): the model is called without a padding mask
        # (mask=None), as in the evaluation loop, which calls model(tokens).
        # The pad_token_mask(tokens) result that used to be computed here was
        # never passed on, so that per-batch call is dropped.
        # TODO confirm whether masking should be enabled in both loops.
        outputs = model(tokens, mask=None)
        loss = crit(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Predicted class = index of the largest output along dim 1.
        predicted = outputs.argmax(dim=1)

        # Count correct predictions
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        # Running accuracy (in percent) over the epoch so far
        accuracy = (total_correct / total_samples) * 100
        if (i + 1) % log_every == 0:
            print(
                f"Epoch: {epoch + 1}, Loss: {running_loss / log_every}, Accuracy: {accuracy}%"
            )
            running_loss = 0.0

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.7017724961042404, Accuracy: 49.75%
Epoch: 1, Loss: 0.7044054061174393, Accuracy: 49.15625%
Epoch: 1, Loss: 0.7032073503732681, Accuracy: 49.375%
Epoch: 1, Loss: 0.7008068627119064, Accuracy: 49.578125%
Epoch: 1, Loss: 0.6968782502412796, Accuracy: 50.1375%
Epoch: 1, Loss: 0.6992424154281616, Accuracy: 50.0625%
Epoch: 1, Loss: 0.6996126294136047, Accuracy: 50.089285714285715%
Epoch: 1, Loss: 0.6993944478034974, Accuracy: 49.703125%
Epoch: 1, Loss: 0.695627635717392, Accuracy: 49.791666666666664%
Epoch: 1, Loss: 0.6995605140924454, Accuracy: 49.8%
Epoch: 1, Loss: 0.6998740947246551, Accuracy: 49.64772727272727%
Epoch: 1, Loss: 0.6983847111463547, Accuracy: 49.583333333333336%
Epoch: 1, Loss: 0.6933022540807724, Accuracy: 49.75961538461539%
Epoch: 1, Loss: 0.6937030464410782, Accuracy: 49.875%
Epoch: 1, Loss: 0.6963864505290985, Accuracy: 49.77916666666667%
Epoch: 1, Loss: 0.6943097025156021, Accuracy: 49.94140625%
Epoch: 1, Loss: 0.6951315319538116, Accuracy: 49.9044117

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 2, Loss: 0.6949646145105361, Accuracy: 49.75%
Epoch: 2, Loss: 0.6913996458053588, Accuracy: 50.84375%
Epoch: 2, Loss: 0.693353419303894, Accuracy: 50.4375%
Epoch: 2, Loss: 0.6923092359304428, Accuracy: 50.171875%
Epoch: 2, Loss: 0.6899596041440964, Accuracy: 49.925000000000004%
Epoch: 2, Loss: 0.6915988171100617, Accuracy: 50.15625%
Epoch: 2, Loss: 0.6907907122373581, Accuracy: 50.205357142857146%
Epoch: 2, Loss: 0.6954142928123475, Accuracy: 50.109375%
Epoch: 2, Loss: 0.6890650510787963, Accuracy: 50.48611111111111%
Epoch: 2, Loss: 0.6871444857120514, Accuracy: 50.80625%
Epoch: 2, Loss: 0.6911111283302307, Accuracy: 50.80681818181818%
Epoch: 2, Loss: 0.6935140496492386, Accuracy: 50.67708333333333%
Epoch: 2, Loss: 0.6894390892982483, Accuracy: 50.92307692307693%
Epoch: 2, Loss: 0.6904855871200561, Accuracy: 51.03124999999999%
Epoch: 2, Loss: 0.6845996057987214, Accuracy: 51.05833333333333%
Epoch: 2, Loss: 0.6882716172933578, Accuracy: 50.81250000000001%
Epoch: 2, Loss: 0.693306

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 3, Loss: 0.24446938287466766, Accuracy: 91.8125%
Epoch: 3, Loss: 0.22315522577613592, Accuracy: 92.0625%
Epoch: 3, Loss: 0.21963400296866895, Accuracy: 92.45833333333333%
Epoch: 3, Loss: 0.19216226940974593, Accuracy: 92.890625%
Epoch: 3, Loss: 0.2179800122976303, Accuracy: 92.83749999999999%
Epoch: 3, Loss: 0.21520473128184675, Accuracy: 92.98958333333334%
Epoch: 3, Loss: 0.19619207007810474, Accuracy: 93.13392857142857%
Epoch: 3, Loss: 0.16579739172011615, Accuracy: 93.40625%
Epoch: 3, Loss: 0.14749830724671484, Accuracy: 93.6875%
Epoch: 3, Loss: 0.17655942188575863, Accuracy: 93.78125%
Epoch: 3, Loss: 0.1689830768853426, Accuracy: 93.91477272727272%
Epoch: 3, Loss: 0.17367743355222046, Accuracy: 94.0%
Epoch: 3, Loss: 0.1889447011332959, Accuracy: 93.98557692307692%
Epoch: 3, Loss: 0.24123542699962855, Accuracy: 93.86160714285714%
Epoch: 3, Loss: 0.1733201264217496, Accuracy: 93.90416666666667%
Epoch: 3, Loss: 0.15850066802464424, Accuracy: 93.9921875%
Epoch: 3, Loss: 0.140978

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 4, Loss: 0.16679189823567867, Accuracy: 94.8125%
Epoch: 4, Loss: 0.15674508384428917, Accuracy: 95.1875%
Epoch: 4, Loss: 0.16673563987948, Accuracy: 95.0625%
Epoch: 4, Loss: 0.1798365420103073, Accuracy: 94.828125%
Epoch: 4, Loss: 0.14909384472295642, Accuracy: 95.025%
Epoch: 4, Loss: 0.1713946695625782, Accuracy: 94.91666666666667%
Epoch: 4, Loss: 0.11760223999619485, Accuracy: 95.23214285714286%
Epoch: 4, Loss: 0.14479970352724195, Accuracy: 95.3046875%
Epoch: 4, Loss: 0.14286796184256673, Accuracy: 95.36111111111111%
Epoch: 4, Loss: 0.14571362089365722, Accuracy: 95.43125%
Epoch: 4, Loss: 0.18372678393498063, Accuracy: 95.29545454545455%
Epoch: 4, Loss: 0.17082225690595806, Accuracy: 95.23958333333333%
Epoch: 4, Loss: 0.15746333418413996, Accuracy: 95.23557692307692%
Epoch: 4, Loss: 0.16973500283434986, Accuracy: 95.19196428571428%
Epoch: 4, Loss: 0.16200049759820104, Accuracy: 95.19999999999999%
Epoch: 4, Loss: 0.11911892343312502, Accuracy: 95.3125%
Epoch: 4, Loss: 0.146855

  0%|          | 0/2500 [00:00<?, ?it/s]

Epoch: 5, Loss: 0.1450021503958851, Accuracy: 95.75%
Epoch: 5, Loss: 0.1507642393000424, Accuracy: 95.625%
Epoch: 5, Loss: 0.17222867670468986, Accuracy: 95.3125%
Epoch: 5, Loss: 0.1598280956968665, Accuracy: 95.265625%
Epoch: 5, Loss: 0.15093883414752782, Accuracy: 95.2375%
Epoch: 5, Loss: 0.18484525859355927, Accuracy: 95.0625%
Epoch: 5, Loss: 0.16197976237162948, Accuracy: 95.02678571428571%
Epoch: 5, Loss: 0.16393272333778441, Accuracy: 95.046875%
Epoch: 5, Loss: 0.15040372792631387, Accuracy: 95.08333333333333%
Epoch: 5, Loss: 0.13091083045117557, Accuracy: 95.2125%
Epoch: 5, Loss: 0.16455965518951415, Accuracy: 95.19886363636364%
Epoch: 5, Loss: 0.14908032700419427, Accuracy: 95.24479166666666%
Epoch: 5, Loss: 0.1548014102317393, Accuracy: 95.2548076923077%
Epoch: 5, Loss: 0.1396953641809523, Accuracy: 95.32142857142857%
Epoch: 5, Loss: 0.18583439560607076, Accuracy: 95.21666666666667%
Epoch: 5, Loss: 0.16235683726146818, Accuracy: 95.1953125%
Epoch: 5, Loss: 0.21560352394357324,

In [24]:
# Switch the model to evaluation mode before measuring test performance; the
# value returned by the call is what the notebook echoes as this cell's output.
model.eval()

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (wq): Linear(in_features=128, out_features=128, bias=True)
        (wk): Linear(in_features=128, out_features=128, bias=True)
        (wv): Linear(in_features=128, out_features=128, bias=True)
        (wo): Linear(in_features=128, out_features=128, bias=True)
        (attn): ScaledDotProductAttention()
        (dropout): Dropout(p=0.1, inplace=False)
        (ln): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
        (3): ReLU()
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Line

In [27]:
# Evaluate the trained model on the held-out split and report accuracy and
# the per-sample average loss.
correct = 0  # number of correctly classified test samples
total = 0  # number of test samples seen
total_loss = 0.0  # sum of per-sample losses (batch mean * batch size)

criterion = torch.nn.CrossEntropyLoss()

# Make the cell self-contained: evaluation mode is set here so the result
# does not depend on an earlier cell having been run.
model.eval()

with torch.no_grad():  # no gradients needed; saves memory and computation
    for batch in test_dl:
        _, labels, tokens = batch
        # Cast and move in a single call instead of going through a CPU
        # LongTensor first.
        labels = labels.to(device=device, dtype=torch.long)
        tokens = tokens.to(device)
        batch_size = labels.size(0)

        # Forward pass
        outputs = model(tokens)

        # The criterion returns the mean loss over the batch; weight it by
        # the batch size so a shorter final batch is not over-represented in
        # the average.
        loss = criterion(outputs, labels)
        total_loss += loss.item() * batch_size

        # Predicted class = index of the largest output along dim 1
        predicted = outputs.argmax(dim=1)

        # Count total and correct predictions
        total += batch_size
        correct += (predicted == labels).sum().item()

# Calculate per-sample average loss and accuracy (in percent)
avg_loss = total_loss / total
accuracy = 100 * correct / total

print(f"Accuracy of the model on the test data: {accuracy:.2f}%")
print(f"Average loss on the test data: {avg_loss:.4f}")

Accuracy of the model on the test data: 96.19%
Average loss on the test data: 0.1328
