In [1]:
# Reload project modules (dataset, transformer, dyck_k_generator) before each cell runs,
# so edits to their .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from dyck_k_generator import constants

In [3]:
# Pick the compute device: CUDA if present, otherwise Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cuda:0'

In [4]:
# Free cached allocator memory on the selected accelerator (no-op on CPU).
if device == "cuda:0":
    torch.cuda.empty_cache()
elif device == "mps":
    torch.mps.empty_cache()

In [5]:
# Number of bracket types to use: the first k entries of constants.BRACKETS form VOCAB below.
# The dataset file loaded later is a Dyck-1 dataset, so this should stay 1 for that file.
k = 1

In [6]:
# Vocabulary string: each (opening, closing) pair of the first k bracket types, concatenated.
VOCAB = "".join("".join(pair) for pair in list(constants.BRACKETS.items())[:k])
VOCAB

'()'

In [7]:
from dataset.dataset import DyckLanguageDataset

In [8]:
# Load the pre-generated Dyck-1 samples (built with VOCAB) and move the dataset to `device`.
dataset = DyckLanguageDataset(
    "data/dyck-1_500000-samples_80-len_p05.jsonl",
    VOCAB,
).to(device)

Loaded 500000 samples from data/dyck-1_500000-samples_80-len_p05.jsonl


Tokenizing strings: 100%|██████████| 500000/500000 [00:03<00:00, 135183.86it/s]


In [9]:
from torch.utils.data import random_split

# 80/20 train/test split. A dedicated, fixed-seed generator makes the split identical on
# every Restart & Run All; without it random_split draws from torch's global RNG, which
# is never seeded in this notebook.
SPLIT_SEED = 42

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [10]:
from torch.utils.data import DataLoader

In [11]:
# Training loader: reshuffled every epoch; batch size 16 gives 25,000 steps per epoch here.
dl = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=16,
)

In [12]:
# Evaluation loader. Batch order does not affect accuracy or loss, so iterate in a fixed
# order (shuffle=False) to keep evaluation deterministic.
test_dl = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Manual Transformer + BERTViz


In [13]:
from transformer.hooked_transformer import (
    TransformerClassifier,
    TransformerClassifierConfig,
    pad_token_mask,
)

In [14]:
# Small single-layer, single-head encoder for binary classification of bracket strings.
model_config = TransformerClassifierConfig(
    # input / output dimensions
    vocab_size=len(VOCAB),
    max_seq_len=80,  # presumably matches the 80-length samples in the dataset file — confirm
    n_classes=2,
    # architecture
    n_layers=1,
    n_heads=1,
    d_model=128,
    dim_ff=256,
)

In [15]:
# Instantiate the hooked transformer classifier from the config above (weights start on CPU;
# moved to `device` in a later cell).
model = TransformerClassifier(model_config)

In [16]:
# Put the model in training mode (enables the dropout layers).
model.train(mode=True)

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (pos_encoder): PositionalEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (q_linear): Linear(in_features=128, out_features=128, bias=True)
        (k_linear): Linear(in_features=128, out_features=128, bias=True)
        (v_linear): Linear(in_features=128, out_features=128, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (residual_dropout): Dropout(p=0.1, inplace=False)
        (attn): ScaledDotProductAttention()
        (out): Linear(in_features=128, out_features=128, bias=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_aff

In [17]:
# Move parameters and buffers to the selected device (in place; also returns the model).
model.to(device=device)

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (pos_encoder): PositionalEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (q_linear): Linear(in_features=128, out_features=128, bias=True)
        (k_linear): Linear(in_features=128, out_features=128, bias=True)
        (v_linear): Linear(in_features=128, out_features=128, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (residual_dropout): Dropout(p=0.1, inplace=False)
        (attn): ScaledDotProductAttention()
        (out): Linear(in_features=128, out_features=128, bias=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_aff

In [18]:
import torch.optim as optim

# Cross-entropy over the model's class scores, optimised with Adam at a small learning rate.
crit = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [19]:
from tqdm.auto import tqdm

In [20]:
epochs = 2
LOG_EVERY = 100  # report running loss / accuracy every LOG_EVERY batches

# Ensure dropout is active even if the evaluation cell (model.eval()) ran before a re-run.
model.train()

for epoch in range(epochs):
    running_loss = 0.0

    total_correct = 0
    total_samples = 0

    for i, data in enumerate(tqdm(dl, desc=f"epoch {epoch + 1}")):
        _, labels, tokens = data
        # Cast and move in one step. `labels.type(torch.LongTensor)` produces a CPU tensor,
        # forcing a device -> CPU -> device round-trip on every batch.
        labels = labels.to(device, dtype=torch.long)
        tokens = tokens.to(device)

        optimizer.zero_grad()

        mask = pad_token_mask(tokens)
        outputs = model(tokens, mask=mask)
        loss = crit(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Count correct predictions (highest-scoring class per sample).
        predicted = outputs.argmax(dim=1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        if i % LOG_EVERY == LOG_EVERY - 1:
            # Loss is averaged over the last LOG_EVERY batches; accuracy is cumulative
            # over the current epoch so far.
            accuracy = (total_correct / total_samples) * 100
            print(
                f"Epoch: {epoch + 1}, Loss: {running_loss / LOG_EVERY}, Accuracy: {accuracy}%"
            )
            running_loss = 0.0

  0%|          | 0/25000 [00:00<?, ?it/s]

Epoch: 1, Loss: 0.6221108528971672, Accuracy: 72.625%
Epoch: 1, Loss: 0.272737877741456, Accuracy: 85.40625%
Epoch: 1, Loss: 0.02348704358097166, Accuracy: 90.27083333333333%
Epoch: 1, Loss: 0.004190042151603848, Accuracy: 92.703125%
Epoch: 1, Loss: 0.0019761067756917327, Accuracy: 94.16250000000001%
Epoch: 1, Loss: 0.0012737030346761458, Accuracy: 95.13541666666666%
Epoch: 1, Loss: 0.0007541263606981374, Accuracy: 95.83035714285715%
Epoch: 1, Loss: 0.0006207737448858097, Accuracy: 96.3515625%
Epoch: 1, Loss: 0.0007105678715743125, Accuracy: 96.75694444444444%
Epoch: 1, Loss: 0.0003472620862885378, Accuracy: 97.08125%
Epoch: 1, Loss: 0.00028122027084464206, Accuracy: 97.3465909090909%
Epoch: 1, Loss: 0.0002505427705182228, Accuracy: 97.56770833333334%
Epoch: 1, Loss: 0.00024029393971432001, Accuracy: 97.7548076923077%
Epoch: 1, Loss: 0.00020851585213677026, Accuracy: 97.91517857142857%
Epoch: 1, Loss: 0.0001903409420629032, Accuracy: 98.05416666666666%
Epoch: 1, Loss: 0.000145152755503

  0%|          | 0/25000 [00:00<?, ?it/s]

Epoch: 2, Loss: 1.4087870034984462e-06, Accuracy: 100.0%
Epoch: 2, Loss: 1.2276749862394355e-06, Accuracy: 100.0%
Epoch: 2, Loss: 7.207621698057665e-07, Accuracy: 100.0%
Epoch: 2, Loss: 5.708611809751574e-07, Accuracy: 100.0%
Epoch: 2, Loss: 7.198649582562667e-07, Accuracy: 100.0%
Epoch: 2, Loss: 3.628422435397738e-07, Accuracy: 100.0%
Epoch: 2, Loss: 5.534193397949139e-07, Accuracy: 100.0%
Epoch: 2, Loss: 4.10002653730146e-07, Accuracy: 100.0%
Epoch: 2, Loss: 2.7701182347072973e-07, Accuracy: 100.0%
Epoch: 2, Loss: 2.4735867750536043e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.9669505896047213e-07, Accuracy: 100.0%
Epoch: 2, Loss: 2.0019694233042173e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.7009662702349716e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.862643027550348e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.2584026599427033e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.304596101192601e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.3098105911879542e-07, Accuracy: 100.0%
Epoch: 2, Loss: 1.1414282871768933e-07, 

In [21]:
# Switch to evaluation mode (disables dropout); equivalent to model.eval().
model.train(False)

TransformerClassifier(
  (embedding): Embedding(5, 128)
  (pos_encoder): PositionalEncoder(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder_layers): ModuleList(
    (0): TransformerEncoderLayer(
      (attn): MultiHeadAttention(
        (q_linear): Linear(in_features=128, out_features=128, bias=True)
        (k_linear): Linear(in_features=128, out_features=128, bias=True)
        (v_linear): Linear(in_features=128, out_features=128, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (residual_dropout): Dropout(p=0.1, inplace=False)
        (attn): ScaledDotProductAttention()
        (out): Linear(in_features=128, out_features=128, bias=True)
      )
      (ff): Sequential(
        (0): Linear(in_features=128, out_features=256, bias=True)
        (1): ReLU()
        (2): Linear(in_features=256, out_features=128, bias=True)
      )
      (ln1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((128,), eps=1e-05, elementwise_aff

In [23]:
# Evaluate on the held-out split. Set eval mode here as well, so dropout is disabled even
# if the training cell was re-run after the separate model.eval() cell.
model.eval()

correct = 0
total = 0
total_loss = 0.0

criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():  # no autograd graph needed: saves memory and compute
    for batch in tqdm(test_dl, desc="evaluating"):
        _, labels, tokens = batch
        # Cast and move in one step (avoids a device -> CPU -> device round-trip).
        labels = labels.to(device, dtype=torch.long)
        tokens = tokens.to(device)

        # Forward pass with the same padding mask the training loop uses, so the model
        # is evaluated under the conditions it was trained in.
        outputs = model(tokens, mask=pad_token_mask(tokens))

        # CrossEntropyLoss returns the batch mean; weight it by batch size so a smaller
        # final batch does not skew the dataset-level average.
        loss = criterion(outputs, labels)
        n_in_batch = labels.size(0)
        total_loss += loss.item() * n_in_batch

        # Predicted class = index of the highest score.
        predicted = outputs.argmax(dim=1)

        total += n_in_batch
        correct += (predicted == labels).sum().item()

# Per-sample average loss and accuracy over the whole test split.
avg_loss = total_loss / total
accuracy = 100 * correct / total

print(f"Accuracy of the model on the test data: {accuracy:.2f}%")
print(f"Average loss on the test data: {avg_loss:.4f}")

  0%|          | 0/12500 [00:00<?, ?it/s]

Accuracy of the model on the test data: 100.00%
Average loss on the test data: 0.0000
