In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch

from dyck_k_generator import checker, constants, generator
from transformer import dataset, transformer

In [3]:
import gc

torch.cuda.empty_cache()
gc.collect()

4

In [4]:
device = (
    "cuda:0"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device

'mps'

In [5]:
k = 3

In [6]:
VOCAB = "".join(
    ["".join((key, value)) for key, value in list(constants.BRACKETS.items())[:k]]
)
VOCAB

'()[]{}'

# TransformerLens


In [13]:
config = transformer.generate_config(
    n_ctx=102,
    d_model=56,
    d_head=28,
    n_heads=2,
    d_mlp=56,
    n_layers=3,
    attention_dir="bidirectional",
    act_fn="relu",
    d_vocab=len(VOCAB) + 3,
    d_vocab_out=2,
    use_attn_result=True,
    device=device,
    use_hook_tokens=True,
)

config

HookedTransformerConfig:
{'act_fn': 'relu',
 'attention_dir': 'bidirectional',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 28,
 'd_mlp': 56,
 'd_model': 56,
 'd_vocab': 9,
 'd_vocab_out': 2,
 'default_prepend_bos': True,
 'device': 'mps',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.10690449676496976,
 'load_in_4bit': False,
 'model_name': 'custom',
 'n_ctx': 102,
 'n_devices': 1,
 'n_heads': 2,
 'n_key_value_heads': None,
 'n_layers': 3,
 'n_params': 56448,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': None,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_adjacent_pairs': False,
 'rotary_base': 10000,
 'rotary_dim': None,
 'scale_attn_by_inver

In [14]:
model = transformer.generate_model(config)

In [15]:
model = model.to(device)

Moving model to device:  mps


In [16]:
model.train(True)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (hook_tokens): HookPoint()
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out):

In [17]:
from transformer.dataset import DyckLanguageDataset

In [19]:
dataset = DyckLanguageDataset(
    "data/dyck-3_250000-samples_1024-len_p065.jsonl", VOCAB
).to(device)

Loaded 250000 samples from data/dyck-3_250000-samples_1024-len_p065.jsonl


In [None]:
dataset[0]

In [None]:
from collections import namedtuple

Item = namedtuple("Item", field_names=["string", "target", "tokens"])

In [None]:
item = Item(dataset[0][0], dataset[0][1], dataset[0][2])

In [None]:
item.tokens

In [20]:
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [None]:
model

In [22]:
from torch.utils.data import DataLoader

In [23]:
dl = DataLoader(train_dataset, batch_size=1, shuffle=True)

In [24]:
for i, batch in enumerate(dl):
    print(batch[2].shape)

torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size([1, 1026])
torch.Size

KeyboardInterrupt: 

In [None]:
from transformer_lens.train import HookedTransformerTrainConfig, train

In [None]:
cfg = HookedTransformerTrainConfig(
    num_epochs=1,
    batch_size=1,
    lr=1e-4,
    seed=42,
    device=device,
)

In [None]:
model = train(model, cfg, train_dataset)

# Manual Transformer + BERTViz


In [25]:
from torch import nn

In [26]:
class PosEncoding(nn.Module):
    def __init__(self, d_model, vocab_size=5_000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pos_enc = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(1e4)) / d_model)
        )

        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)
        pos_enc = pos_enc.unsqueeze(0)
        self.register_buffer("pos_enc", pos_enc)

    def forward(self, x):
        x = x + self.pos_enc[:, : x.size(1)]
        return self.dropout(x)

In [27]:
class TransformerClassifier(nn.Module):
    def __init__(
        self,
        embeddings,
        n_heads=8,
        d_ff=2048,
        n_layers=6,
        dropout=0.1,
        act="relu",
        classifier_dropout=0.1,
    ):
        super().__init__()
        vocab_size, d_model = embeddings.size()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.emb = nn.Embedding.from_pretrained(embeddings)
        self.pos_enc = PosEncoding(
            d_model=d_model, vocab_size=vocab_size, dropout=dropout
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation=act,
        )

        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

        self.classifier = nn.Linear(d_model, 2)
        self.d_model = d_model

    def forward(self, x):
        x = self.emb(x) * (self.d_model**0.5)
        x = self.pos_enc(x)
        x = self.encoder(x)
        x = x.mean(dim=1)
        x = self.classifier(x)
        return x