In [1]:
# Reload imported modules before each cell runs so that edits to the local
# project .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from dyck_k_generator import checker, constants, generator
from transformer import dataset, transformer

In [3]:
import gc

# Run the garbage collector first: tensors kept alive only by reference cycles
# must be destroyed before the allocator can hand their cached blocks back.
unreachable_count = gc.collect()

# Release cached, now-unused allocator blocks on whichever accelerator is
# present. The CUDA call alone does nothing on Apple-silicon (MPS) machines.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    torch.mps.empty_cache()

# Echo the number of unreachable objects the collector found.
unreachable_count

4

In [4]:
# Choose the compute device in order of preference: first CUDA GPU, then
# Apple's MPS backend, and finally the CPU as the fallback.
if torch.cuda.is_available():
    device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'mps'

In [5]:
# Number of bracket pairs to use: the VOCAB cell below keeps the first k
# entries of constants.BRACKETS.
k = 3

In [6]:
# Glue every key of constants.BRACKETS to its value for the first k entries,
# producing one flat string of bracket characters.
VOCAB = "".join(
    opener + closer for opener, closer in list(constants.BRACKETS.items())[:k]
)
VOCAB

'()[]{}'

# TransformerLens


In [None]:
# Configuration for a small 3-layer, 2-head bidirectional transformer whose
# output has two logits per position.
config = transformer.generate_config(
    # NOTE(review): presumably the 100-character samples plus two special
    # tokens -- confirm against the dataset's tokenizer.
    n_ctx=102,
    d_model=56,
    d_head=28,  # n_heads * d_head == d_model
    n_heads=2,
    d_mlp=56,
    n_layers=3,
    attention_dir="bidirectional",
    act_fn="relu",
    # One id per bracket character plus 3 extra ids.
    # NOTE(review): the extra ids are presumably special tokens (e.g.
    # start / pad / end) -- confirm against the dataset's tokenizer.
    d_vocab=len(VOCAB) + 3,
    # NOTE(review): presumably the two classes are balanced / unbalanced -- confirm.
    d_vocab_out=2,
    use_attn_result=True,
    device=device,
    use_hook_tokens=True,
)

# Echo the resulting configuration object.
config

In [None]:
# Build the model from the configuration above using the project's helper.
model = transformer.generate_model(config)

In [None]:
# Move the model onto the selected compute device.
model = model.to(device)

In [None]:
# Switch the model to training mode; as the cell's last expression, the
# call's return value is echoed as the cell output.
model.train(True)

In [None]:
from transformer.dataset import DyckLanguageDataset

In [None]:
# Load the Dyck-language samples from the JSONL file and move them to the
# selected device.
# NOTE(review): this rebinds `dataset`, which was imported as a module from
# `transformer` at the top of the notebook; the module is no longer reachable
# under that name after this cell runs.
# NOTE(review): the file name says "dyck-1" while VOCAB is built from k = 3
# bracket pairs -- confirm this is the intended data file.
dataset = DyckLanguageDataset("data/dyck-1_1000-samples_100-len_p07.jsonl", VOCAB).to(
    device
)

In [None]:
# Inspect the first sample of the dataset.
dataset[0]

In [None]:
from collections import namedtuple

# Lightweight record naming the three positional parts of a dataset sample.
Item = namedtuple("Item", ["string", "target", "tokens"])

In [None]:
# Fetch the first sample once and name its three components, instead of
# indexing the dataset separately for every field.
first_sample = dataset[0]
item = Item(
    string=first_sample[0],
    target=first_sample[1],
    tokens=first_sample[2],
)

In [None]:
# Show the tokens field of the first sample.
item.tokens

In [None]:
from torch.utils.data import random_split

# Hold out 20% of the samples for testing; the test size is derived from the
# remainder so the two parts always add up to the full dataset.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Seed the split with a dedicated generator so the same samples land in the
# train and test sets on every run (Restart & Run All stays reproducible).
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Display the model's repr.
model

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Loader over the training split: one sample per batch, reshuffled each pass.
dl = DataLoader(train_dataset, batch_size=1, shuffle=True)

In [None]:
# Print the shape of each batch's third element (index 2, the position the
# Item record above names `tokens`). The batch index is not needed here, so
# the loader is iterated directly.
for batch in dl:
    print(batch[2].shape)

In [None]:
from transformer_lens.train import HookedTransformerTrainConfig, train

In [None]:
# Training settings passed to transformer_lens's train(): a single epoch with
# one sample per step, a fixed seed, and the device selected earlier.
cfg = HookedTransformerTrainConfig(
    num_epochs=1,
    batch_size=1,
    lr=1e-4,
    seed=42,
    device=device,
)

In [None]:
# Train on the training split and keep the returned model.
# NOTE(review): transformer_lens's train() appears to expect batches that can
# be indexed with "tokens" and to optimise the model's own loss, whereas this
# dataset's samples are indexed positionally above -- confirm the dataset
# format and objective are compatible before relying on this cell.
# NOTE(review): this rebinds `model`, so re-running the cell continues from
# the already-trained weights rather than from a fresh model.
model = train(model, cfg, train_dataset)

# Manual Transformer + BERTViz


In [None]:
from torch import nn

In [None]:
class PosEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to embeddings, then dropout.

    The encoding table is computed once at construction and registered as a
    buffer named ``pos_enc`` with shape ``(1, vocab_size, d_model)``. Even
    feature columns hold sines and odd feature columns hold cosines.

    Args:
        d_model: Size of the feature (last) dimension of the inputs. Both
            even and odd values are supported.
        vocab_size: Number of positions to precompute, i.e. the longest
            sequence ``forward`` can encode. Despite its name this is a
            position count, not a token-vocabulary size.
        dropout: Dropout probability applied after the encodings are added.
    """

    def __init__(self, d_model, vocab_size=5_000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pos_enc = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature index: 10000 ** (-i / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(1e4)) / d_model)
        )
        angles = position * div_term

        pos_enc[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency; for even d_model the slice is a no-op.
        pos_enc[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading singleton dimension so the table broadcasts over the batch.
        pos_enc = pos_enc.unsqueeze(0)
        self.register_buffer("pos_enc", pos_enc)

    def forward(self, x):
        """Return ``x`` plus the encodings for its positions, after dropout.

        Args:
            x: Tensor whose dimension 1 is the sequence position and whose
                last dimension is ``d_model``; the sequence length must not
                exceed the number of precomputed positions.
        """
        x = x + self.pos_enc[:, : x.size(1)]
        return self.dropout(x)

In [None]:
class TransformerClassifier(nn.Module):
    def __init__(
        self,
        embeddings,
        n_heads=8,
        d_ff=2048,
        n_layers=6,
        dropout=0.1,
        act="relu",
        classifier_dropout=0.1,
    ):
        super().__init__()