## Setup

### Imports

In [1]:
import sys; sys.path.append('..')

In [2]:
# ! pip install lovely-tensors

import lovely_tensors as lt
lt.monkey_patch()

In [3]:
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [4]:
import wandb
from tqdm.auto import tqdm

from omegaconf import OmegaConf

from datetime import datetime

from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from src.tree import list2tree
from src.tree_dataset import TreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges
from src.utils import seed_all, count_parameters


from src.trainer import TreeTrainer

### Read config

In [6]:
# Load the experiment configuration from YAML; every later cell reads its settings from `conf`.
conf = OmegaConf.load('../conf/01_reproduce_2L_nodes=8.yaml')
# Output is identical to the YAML file
print(OmegaConf.to_yaml(conf))

random_seed: 42
n_nodes: 8
model:
  d_model: 128
  d_head: 128
  n_layers: 2
  act_fn: gelu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.01
batch_size: 64
epoch_len_steps: 5000
checkpoint_every_epoch: 1
device: mps
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 01_2L_nodes=8
max_iters: null



### Constants (mostly derived)

In [7]:
# Set to True for a local development run; that switches Weights & Biases logging off below.
DEV_RUN = False
# W&B logging is enabled only when this is not a dev run AND the config asks for it.
USE_WANDB = (not DEV_RUN) and conf.use_wandb
# Device identifier taken straight from the config.
device = conf.device

# Root directory under which per-run checkpoint folders are created.
CHECKPOINT_ROOT = Path('../checkpoints')

In [8]:
def create_checkpoint_dir(conf):
    """Create and return a timestamped checkpoint directory for the current run.

    The directory lives under ``CHECKPOINT_ROOT`` and is named
    ``<wandb project>__<YYYY-MM-DD_HH-MM-SS>``, using the local time at call.

    Args:
        conf: Config object supporting ``conf["wandb"]["project"]`` lookup.

    Returns:
        pathlib.Path: The directory; it (and any missing parents) is created
        if it does not exist yet.
    """
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    run_dir = CHECKPOINT_ROOT / f'{conf["wandb"]["project"]}__{stamp}'
    # An already existing directory is not an error.
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

In [9]:
# Start a W&B run only when logging is enabled; the config is attached after conversion
# with OmegaConf.to_object.
if USE_WANDB:
    wandb.init(project=conf.wandb.project, name=conf.wandb.name, config=OmegaConf.to_object(conf))



[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Seed the random number generators from the config before the trainer is built.
RANDOM_SEED = conf['random_seed']
print(f'{RANDOM_SEED=}')
seed_all(RANDOM_SEED)

RANDOM_SEED=42


In [11]:
# Build the trainer from the config; it exposes the model, dataset, dataloader and optimizer
# used by the cells below. Then report the model's parameter counts.
trainer = TreeTrainer(conf)
count_parameters(trainer.model)

Moving model to device:  mps
+--------------------+------------+
|      Modules       | Parameters |
+--------------------+------------+
|     embed.W_E      |    2560    |
|  pos_embed.W_pos   |    4608    |
|   blocks.0.ln1.w   |    128     |
|   blocks.0.ln1.b   |    128     |
|   blocks.0.ln2.w   |    128     |
|   blocks.0.ln2.b   |    128     |
| blocks.0.attn.W_Q  |   16384    |
| blocks.0.attn.W_O  |   16384    |
| blocks.0.attn.b_Q  |    128     |
| blocks.0.attn.b_O  |    128     |
| blocks.0.attn.W_K  |   16384    |
| blocks.0.attn.W_V  |   16384    |
| blocks.0.attn.b_K  |    128     |
| blocks.0.attn.b_V  |    128     |
| blocks.0.mlp.W_in  |   65536    |
| blocks.0.mlp.b_in  |    512     |
| blocks.0.mlp.W_out |   65536    |
| blocks.0.mlp.b_out |    128     |
|   blocks.1.ln1.w   |    128     |
|   blocks.1.ln1.b   |    128     |
|   blocks.1.ln2.w   |    128     |
|   blocks.1.ln2.b   |    128     |
| blocks.1.attn.W_Q  |   16384    |
| blocks.1.attn.W_O  |   16384    |

406548

In [12]:
# Short aliases for the dataset tokenizer and its tokenize / detokenize methods.
tokenizer = trainer.dataset.tokenizer

tok = tokenizer.tokenize
detok = tokenizer.detokenize


# Index of the ':' token: first element of the result of tokenizing the one-item list [':'].
# NOTE(review): this calls `trainer.tok` rather than the `tok` alias defined above —
# confirm the two are equivalent.
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

In [13]:
# Create this run's checkpoint folder up front; the training loop writes its checkpoints into it.
checkpoint_dir = create_checkpoint_dir(conf)

In [None]:
# Training loop.
# Each pass of the outer loop is one "epoch" of at most `conf.epoch_len_steps` optimizer steps,
# followed by an optional checkpoint. When `max_iters` is missing or null in the config the
# loop has no stop condition and runs until the kernel is interrupted.
global_step = 0
epoch_i = 0

# Read once: the config is not modified while training.
max_iters = conf.get('max_iters')

while True:

    # tqdm only needs a sized iterable; the range does not have to be copied into a list.
    pbar = tqdm(range(conf.epoch_len_steps))

    # zip ends the epoch after `epoch_len_steps` batches, or earlier if the dataloader runs out.
    for i, batch in zip(pbar, trainer.dataloader):
        if i == 0:
            # Print the prediction for the first sample of each epoch as a qualitative check.
            sample_input_idx = batch['input_idx'][0]
            trainer.print_sample_pred(sample_input_idx)

        trainer.optimizer.zero_grad()
        loss, accuracy = trainer.train_step(batch)
        loss.backward()
        trainer.optimizer.step()

        pbar.set_description(f'loss={float(loss):.3f}')
        if USE_WANDB:
            wandb.log({'loss': loss, 'accuracy': accuracy})

        global_step += 1

    epoch_i += 1

    if epoch_i % conf['checkpoint_every_epoch'] == 0:
        checkpoint_filename = checkpoint_dir/f'{conf["wandb"]["name"]}__step={global_step}.pt'
        print(f'Saving {checkpoint_filename=}')
        torch.save(trainer.model.state_dict(), checkpoint_filename)

    # Stop once the step budget has been reached. The comparison is `>=` so that a budget
    # that is an exact multiple of the epoch length does not trigger one extra epoch.
    # The check runs only between epochs, so training can overshoot by less than one epoch.
    if max_iters and global_step >= max_iters:
        break

    

  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
             1
             |
           +-+-+
           |   |
           5   0
           |
         +-+
         |
         2
         |
       +-+
       |
       7
       |
     +-+
     |
     6
     |
   +-+
   |
   4
   |
 +-+
 |
 3

goal=0
accuracy=0.0 gt_path=['→0'] pred_path=['→3']
****************************************************************************************************
Saving checkpoint_filename=Path('../checkpoints/reasoning-mech-interp__2024-04-11_10-38-39/01_2L_nodes=8__step=5000.pt')


  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
               1
               |
           +---+
           |
           0
           |
   +-------+-+
   |         |
   2         4
   |
 +-+-----+
 |       |
 6       3
         |
       +-+
       |
       5
       |
     +-+
     |
     7

goal=4
accuracy=1.0 gt_path=['→0', '→4'] pred_path=['→0', '→4']
****************************************************************************************************
Saving checkpoint_filename=Path('../checkpoints/reasoning-mech-interp__2024-04-11_10-38-39/01_2L_nodes=8__step=10000.pt')


  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
             0
             |
         +---+-+
         |     |
         4     5
         |
       +-+-+
       |   |
       7   2
       |
   +---+
   |
   1
   |
 +-+-+
 |   |
 6   3

goal=6
accuracy=1.0 gt_path=['→4', '→7', '→1', '→6'] pred_path=['→4', '→7', '→1', '→6']
****************************************************************************************************
Saving checkpoint_filename=Path('../checkpoints/reasoning-mech-interp__2024-04-11_10-38-39/01_2L_nodes=8__step=15000.pt')


  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
   6
   |
 +-+-------+
 |         |
 3         7
           |
         +-+---+
         |     |
         2     5
         |     |
       +-+   +-+
       |     |
       0     1
       |
     +-+
     |
     4

goal=1
accuracy=1.0 gt_path=['→7', '→5', '→1'] pred_path=['→7', '→5', '→1']
****************************************************************************************************
Saving checkpoint_filename=Path('../checkpoints/reasoning-mech-interp__2024-04-11_10-38-39/01_2L_nodes=8__step=20000.pt')


  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
       1
       |
   +---+-------+
   |           |
   4           2
   |           |
 +-+-+       +-+
 |   |       |
 7   5       6
             |
           +-+
           |
           3
           |
         +-+
         |
         0

goal=0
accuracy=1.0 gt_path=['→2', '→6', '→3', '→0'] pred_path=['→2', '→6', '→3', '→0']
****************************************************************************************************
Saving checkpoint_filename=Path('../checkpoints/reasoning-mech-interp__2024-04-11_10-38-39/01_2L_nodes=8__step=25000.pt')


  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
   6
   |
 +-+---+
 |     |
 4     2
       |
     +-+-------+
     |         |
     1         0
               |
             +-+
             |
             7
             |
           +-+
           |
           3
           |
         +-+
         |
         5

goal=5
accuracy=1.0 gt_path=['→2', '→0', '→7', '→3', '→5'] pred_path=['→2', '→0', '→7', '→3', '→5']
****************************************************************************************************
Saving checkpoint_filename=Path('../checkpoints/reasoning-mech-interp__2024-04-11_10-38-39/01_2L_nodes=8__step=30000.pt')


  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
               2
               |
   +-----------+
   |
   5
   |
 +-+---------+
 |           |
 4           7
             |
         +---+
         |
         1
         |
       +-+-+
       |   |
       0   6
       |
     +-+
     |
     3

goal=4
accuracy=1.0 gt_path=['→5', '→4'] pred_path=['→5', '→4']
****************************************************************************************************


## Debug on custom input

In [None]:
# Take the first batch from the training dataloader and print the tree parsed from its first sample.
# The loop exits after one iteration; `batch` stays bound and is reused by the cells below.
for batch in trainer.dataloader:
    print(parse_input_idx(batch['input_idx'][0], trainer.dataset.tokenizer)['tree'])
    break

In [None]:
# Run a single `train_step` on the batch inside inference mode (no backward pass follows here).
# NOTE(review): `locc` looks like a typo for `loss`; no later visible cell reads it.
with torch.inference_mode():
    locc, accuracy = trainer.train_step(batch)

In [None]:
# Use the first sample of the batch as the input to inspect.
input_idx_custom = batch['input_idx'][0]


# Parse the token indices back into a structured form; the result exposes at least a
# 'tree' entry, which other cells read.
parsed_input = parse_input_idx(input_idx_custom, trainer.dataset.tokenizer)
parsed_input

In [None]:
# Cut the detokenized sample shortly after the first ':' token: the slice keeps everything up to
# and including ':' plus the single token that follows it.
upper_task_bound = input_idx_custom.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
prompt_autoregressive = detok(input_idx_custom)[:upper_task_bound]
# `prompt` is a second name for the same object as `prompt_autoregressive`, not a copy.
prompt = prompt_autoregressive
print(prompt)

In [None]:
# One step of manual decoding: obtain a token for the current prompt and append it.
# Re-running this cell extends the prompt by one more token each time.
# NOTE(review): `inference_on_prompt` is neither defined nor imported in the visible part of this
# notebook, so this cell fails on a fresh kernel unless it is provided elsewhere.
new_token = inference_on_prompt(prompt)
prompt += [new_token]
print(prompt)

In [None]:
# NOTE(review): `softmax_probs` is first assigned in the cell after this one, so on a fresh
# top-to-bottom run this line raises NameError.
softmax_probs.shape

In [None]:
# Softmax over the last axis of the first sample's predictions, restricted to the positions
# selected by that sample's mask, converted to a NumPy array for plotting.
# NOTE(review): `preds` and `mask` are only assigned in cells further down the notebook.
softmax_probs = preds[0][mask[0]].softmax(dim=-1).detach().cpu().numpy()

# Draw the probabilities as an annotated heatmap on an explicit 10x8 inch axes.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(softmax_probs, annot=True, cmap='coolwarm', fmt=".2f", ax=ax)
ax.set_title("Softmax Probabilities Heatmap")
ax.set_xlabel("Classes")
ax.set_ylabel("Instances")
plt.show()

In [None]:
# NOTE(review): `input_idx_autoregressive` is not assigned anywhere in the visible notebook;
# this cell fails on a fresh kernel unless the name is provided elsewhere.
input_idx_autoregressive

In [None]:
# Forward pass of the model on `input_idx`.
# NOTE(review): `input_idx` is only assigned in cells further down the notebook.
preds = trainer.model(input_idx)

In [None]:
# Detokenize only the positions of the first sample that are selected by its mask.
detok(input_idx[0][mask[0]])

In [None]:
# Show the first row of the batch's token indices; the `[:1]` slice keeps the leading axis.
batch['input_idx'][:1]

In [None]:
# Display the accuracy value returned by the most recent `trainer.train_step` call.
accuracy

In [None]:




# Build a small hand-written tree from a flat list and display it.
# NOTE(review): presumably the list is a level-order encoding with None for missing children —
# confirm against `list2tree`.
custom_tree = list2tree([0, 1, 2, None, None, 3])
custom_tree

In [None]:
# Convert the custom tree into two token sequences: the prompt and the path.
# NOTE(review): `tree2tokens` is neither defined nor imported in the visible part of this
# notebook, so this cell fails on a fresh kernel unless it is provided elsewhere.
prompt_tokens, path_tokens = tree2tokens(custom_tree)
prompt_tokens

In [None]:
# Manually assemble one model input (token indices + task mask) from the custom tree's tokens.
# `self` is bound to the dataset at notebook level — presumably so the lines below can be pasted
# to/from a dataset method; verify.
self = trainer.dataset
prompt_idx = self.tokenizer(prompt_tokens)
path_idx = self.tokenizer(path_tokens)

# Prompt followed by path. This assumes the tokenizer returns Python lists, so that `+`
# concatenates — TODO confirm.
input_tokens = prompt_idx + path_idx
# Number of padding positions needed to reach the tokenizer's MAX_SEQ_LEN.
pad_len = self.tokenizer.MAX_SEQ_LEN - len(input_tokens)

# Right-pad with index 0 up to MAX_SEQ_LEN.
input_idx = torch.tensor(input_tokens + [0] * pad_len)
# pad_mask = torch.zeros(self.tokenizer.MAX_SEQ_LEN)
# pad_mask[:len(input_idx)] = 1

# Boolean mask that is True exactly on the path positions: after the prompt, before the padding.
task_mask = torch.zeros(self.tokenizer.MAX_SEQ_LEN, dtype=torch.bool)
task_mask[len(prompt_idx):len(input_tokens)] = True

In [None]:
# Run the model on `input_idx` and take the argmax over the last axis of its output.
trainer.model(input_idx).argmax(dim=-1)

In [None]:
# A hand-written token sequence, tokenized to check how these tokens map to indices.
custom_input = ['0', ',', '→2', '0', ',', '→1', '1', ',', '→3', '|', '0', ':', '3', '→1', '→3']
tok(custom_input)

In [None]:
# Wrap the single hand-built example as a batch of size one by adding a leading axis
# to both tensors. The keys match the ones read from dataloader batches in this notebook.
custom_batch = dict(
    input_idx=input_idx.unsqueeze(0),
    task_mask=task_mask.unsqueeze(0),
    # 'pad_mask': pad_mask,
)

In [None]:
# NOTE(review): `print_preds` is neither defined nor imported in the visible part of this
# notebook, so this cell fails on a fresh kernel unless it is provided elsewhere.
print_preds(custom_batch)

In [None]:
# Run the model on the hand-built custom example.
# Tokens and task mask must both come from `custom_batch`: taking the mask from the dataloader
# `batch` would pair these predictions with the task positions of unrelated samples.
input_idx = custom_batch['input_idx'].to(device)
mask = custom_batch['task_mask'].to(device)

# Feed the copy that was moved to `device` above.
preds = trainer.model(input_idx)

In [None]:
# Detokenize the argmax (over the last axis) predictions for the first sample.
detok(preds.argmax(dim=-1)[0])

In [None]:
# Convert the tree parsed from the sample input into its edges.
tree_to_edges(parsed_input['tree'])

In [None]:
# Switch back to the dataloader batch: move its token indices and task mask to the device.
# This rebinds `input_idx` and `mask`, replacing any values set by earlier cells.
input_idx = batch['input_idx'].to(device)
mask = batch['task_mask'].to(device)

In [None]:
# Forward pass on the first sample only (the `[:1]` slice keeps the leading axis), also returning
# the cache object that the attention-pattern cell reads from.
logits, cache = trainer.model.run_with_cache(input_idx[:1])

In [None]:
# pip install circuitsvis

In [None]:
import circuitsvis as cv

In [None]:
# ! pip install webbrowser

In [None]:
# path = "attn_heads.html"

# with open(path, "w") as f:
#     f.write(str(attn_heads))

# webbrowser.open(path)


In [None]:
# Display the parsed form of the first sample in the batch.
# The result is indexed by string key ('tree') elsewhere in this notebook, so it cannot be
# sliced; the stray `[:0]` slice that made this cell raise has been dropped.
parse_input_idx(input_idx[0], tokenizer)

In [None]:
# Visualise the layer-0 attention patterns for the first sequence in the batch.
print(type(cache))
attention_pattern = cache["pattern", 0]
print(attention_pattern.shape)
# Token strings for the same sequence, used as labels in the visualisation.
str_tokens = detok(input_idx[0].cpu())

# Indexing with [0] drops the leading (batch) axis. transformer_lens stores patterns as
# [batch, head, query, key], so the first remaining axis is the head axis. Head labels are
# derived from it rather than hard-coding a single head, so the cell keeps working for
# configs with more than one attention head.
first_sample_attention = attention_pattern[0]
n_heads = first_sample_attention.shape[0]

print("Layer 0 Head Attention Patterns:")
attn_heads = cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=first_sample_attention,
    attention_head_names=[f"L0H{i}" for i in range(n_heads)],
)


attn_heads

In [None]:
# Inspect the shape of the cached layer-0 attention pattern.
attention_pattern.shape

In [None]:
%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

In [None]:
# Save the string form of the attention visualisation to an HTML file in the working directory.
# `path` stays a plain string because the W&B upload cell reads it.
path = "attn_heads.html"

with open(path, "w") as html_file:
    html_source = str(attn_heads)
    html_file.write(html_source)

In [None]:
# Show which circuitsvis version is installed in this kernel.
cv.__version__

In [None]:
# Upload the saved attention visualisation to W&B as an HTML panel.
# Guarded like the other W&B calls in this notebook: when logging is disabled no run has been
# initialised, so the upload is skipped.
if USE_WANDB:
    # Open the file in a context manager so the handle is closed once the log call returns.
    with open(path) as html_file:
        wandb.log({"custom_file": wandb.Html(html_file)})

In [None]:
! python --version

In [None]:
# Render the attention visualisation inline using the notebook's rich display.
display(attn_heads)

In [None]:
# Python Example
# Minimal circuitsvis call with two tokens and one value per token — handy for checking
# that circuitsvis output renders in this environment.
from circuitsvis.tokens import colored_tokens
colored_tokens(["My", "tokens"], [0.123, -0.226])