In [1]:
from dyck_k_generator import constants, checker, generator
from transformer import dataset, transformer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

In [4]:
k = 3

In [5]:
VOCAB = ''.join([''.join((key, value)) for key, value in list(constants.BRACKETS.items())[:k]])
VOCAB

'()[]{}'

In [6]:
config = transformer.generate_config(
    n_ctx=42,
    d_model=56,
    d_head=28,
    n_heads=2,
    d_mlp=56,
    n_layers=3,
    attention_dir='bidirectional',
    act_fn='relu',
    d_vocab=len(VOCAB) + 3,
    d_vocab_out=2,
    use_attn_result=True,
    device=device,
    use_hook_tokens=True
)

config

HookedTransformerConfig:
{'act_fn': 'relu',
 'attention_dir': 'bidirectional',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 28,
 'd_mlp': 56,
 'd_model': 56,
 'd_vocab': 9,
 'd_vocab_out': 2,
 'default_prepend_bos': True,
 'device': 'cuda:0',
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.10690449676496976,
 'model_name': 'custom',
 'n_ctx': 42,
 'n_devices': 1,
 'n_heads': 2,
 'n_key_value_heads': None,
 'n_layers': 3,
 'n_params': 56448,
 'normalization_type': 'LN',
 'original_architecture': None,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_adjacent_pairs': False,
 'rotary_base': 10000,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': None,
 'tokenize

In [7]:
model = transformer.generate_model(config)

In [8]:
model = model.to('cpu')

Moving model to device:  cpu


In [9]:
model.train(True)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (hook_tokens): HookPoint()
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out):

In [21]:
from transformer.dataset import DyckLanguageDataset

In [24]:
dataset = DyckLanguageDataset('data/dyck-1_10000-samples_40-len_p05.jsonl', VOCAB).to(device) 

In [25]:
dataset[0]

{'str': '))())()(())))',
 'tokens': tensor([0, 4, 4, 3, 4, 4, 3, 4, 3, 3, 4, 4, 4, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0'),
 'class': tensor(False, device='cuda:0')}

In [26]:
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [27]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (hook_tokens): HookPoint()
  (blocks): ModuleList(
    (0-2): 3 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out):

In [28]:
from torch.utils.data import DataLoader

In [29]:
dl = DataLoader(train_dataset, batch_size=1, shuffle=True)

In [30]:
for i, batch in enumerate(dl):
    print(batch)
    break

{'str': ['()((()))()()()'], 'tokens': tensor([[0, 3, 4, 3, 3, 3, 4, 4, 4, 3, 4, 3, 4, 3, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0'), 'class': tensor([True], device='cuda:0')}


In [31]:
from transformer_lens.train import train, HookedTransformerTrainConfig

In [32]:
cfg = HookedTransformerTrainConfig(
    num_epochs=1,
    batch_size=1,
    lr=1e-4,
    seed=42,
    device='cpu',
)

In [33]:
model = train(model, cfg, train_dataset)

Moving model to device:  cpu


0it [00:00, ?it/s]/1 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: index 3 is out of bounds for dimension 2 with size 2