- [x] Create a translator from our token ids to the baseline model's token ids
- [x] Check that the baseline model achieves similar metrics on our dataset generator
- [x] Take a direct look 👀 at the reference implementation

## A direct look 👀 at the reference implementation


The reference generator can emit the edge list in two orders, "random" and "backward":

- https://github.com/abhay-sheshadri/backward-chaining-circuits/blob/main/src/tree_generation/gen.py#L125
- https://github.com/abhay-sheshadri/backward-chaining-circuits/blob/main/src/tree_generation/gen.py#L153

```python
if order == "random":
    rng.shuffle(edgelist)
elif order == "backward":
    edgelist = edgelist[::-1]
```


It is unclear which order they used for training; the function's default is "random".

I used "random", which seems the stricter setting: the edge order then carries no hint about the path. So this should not be a concern.

## Setup

In [1]:
# Number of dataloader batches consumed by each evaluation loop below.
N_TEST_BATCHES = 100

In [2]:
import sys; sys.path.append('..')

In [3]:
# ! pip install lovely-tensors

import lovely_tensors as lt
lt.monkey_patch()

In [4]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [5]:
import wandb
from tqdm.auto import tqdm

from omegaconf import OmegaConf

from datetime import datetime

from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [6]:
from src.tree import list2tree
from src.tree_dataset import TreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges
from src.utils import seed_all


from src.trainer import accuracy_by_depth
from src.trainer import TreeTrainer

In [7]:
conf = OmegaConf.load('../conf/00_reproduce_6L_nodes=16.yaml')
# The printed config is the YAML file plus the overrides below (n_nodes, device).
conf.n_nodes = 16
conf.device = 'cpu'
print(OmegaConf.to_yaml(conf))

random_seed: 42
n_nodes: 16
model:
  d_model: 128
  d_head: 128
  n_layers: 6
  act_fn: gelu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.01
batch_size: 64
epoch_len_steps: 5000
checkpoint_every_epoch: 2
device: cpu
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 00_6L_nodes=16
max_iters: null



In [8]:
# Checkpoint of our reproduced 6-layer model (config above); path is relative to this notebook.
REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-10_16-10-20/00_6L_nodes=16__step=220000.pt'

In [9]:
DEV_RUN = False
# W&B logging only for non-dev runs, and only if the config enables it.
USE_WANDB = (not DEV_RUN) and conf.use_wandb
device = conf.device

CHECKPOINT_ROOT = Path('../checkpoints')

In [10]:
RANDOM_SEED = conf['random_seed']
print(f'{RANDOM_SEED=}')
# Seed via the project helper before the trainer (and its dataset) is built in the next cell.
seed_all(RANDOM_SEED)

RANDOM_SEED=42


In [11]:
trainer = TreeTrainer(conf)

tokenizer = trainer.dataset.tokenizer

# Short aliases for our tokenizer's encode / decode methods.
tok = tokenizer.tokenize
detok = tokenizer.detokenize


# Id of the ':' token in OUR vocabulary (trainer.tok presumably returns a sequence of ids).
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

Moving model to device:  cpu


In [12]:
# Load the reproduced weights. map_location remaps the checkpoint tensors onto
# `device`, so a checkpoint saved from a GPU run still loads when device='cpu'.
state_dict = torch.load(REPRODUCED_MODEL_CKPT, map_location=device)
trainer.model.load_state_dict(state_dict)

<All keys matched successfully>

In [13]:
# Evaluate the reproduced model on the first N_TEST_BATCHES batches of trainer.dataloader.
sample_size = 0  # number of samples evaluated (accumulated, not displayed here)

metric_df = []

for i, batch in tqdm(zip(range(N_TEST_BATCHES), trainer.dataloader), total=N_TEST_BATCHES):
    sample_size += len(batch['input_idx'])
    # Evaluation only, no autograd graph. NOTE(review): assumes train_step does no
    # backward/optimizer step -- that would fail on inference-mode tensors.
    with torch.inference_mode():
        loss, metrics = trainer.train_step(batch)
        metric_df.append(metrics)


# One row per batch; depths missing from a batch become NaN and mean() skips them.
# NOTE(review): this is an unweighted mean of per-batch accuracies, so rarely seen
# depths (e.g. depth=11) presumably rest on very few samples.
metric_df = pd.DataFrame(metric_df)

metric_df.mean()

  0%|          | 0/100 [00:00<?, ?it/s]

acc/depth=1      0.996667
acc/depth=2      1.000000
acc/depth=3      0.994493
acc/depth=4      0.956678
acc/depth=5      0.888542
acc/depth=6      0.841603
acc/depth=8      0.743056
accuracy/mean    0.982899
acc/depth=7      0.780296
acc/depth=9      0.862069
acc/depth=10     1.000000
acc/depth=11     0.000000
acc/depth=12     1.000000
dtype: float64

In [14]:
# Loss + metrics on the last batch of the loop above.
# NOTE(review): runs outside torch.inference_mode, so the returned loss carries an
# autograd graph (NllLossBackward0 in the output).
trainer.train_step(batch)

(tensor grad NllLossBackward0 0.040,
 {'acc/depth=1': 1.0,
  'acc/depth=2': 1.0,
  'acc/depth=3': 1.0,
  'acc/depth=4': 0.9333333333333333,
  'acc/depth=5': 0.8666666666666667,
  'acc/depth=6': 0.9230769230769231,
  'acc/depth=7': 1.0,
  'accuracy/mean': 0.9855595827102661})

In [15]:
# Alias so the train_step body can be replayed line-by-line in the next cell.
self = trainer

In [16]:
# Manual replay of the loss/metric computation; its metrics match the
# trainer.train_step(batch) output above.
input_idx = batch['input_idx'].to(self.device)
mask = batch['task_mask'].to(self.device)

# input_idx = batch['input_idx'][..., :bas_max_seq_length].to(device)
# mask = batch['task_mask'][..., :bas_max_seq_length].to(device)

# Next-token prediction: the model sees every token but the last and is
# scored on the following token wherever the task mask is set.
inputs = input_idx[:, :-1]

out_mask = mask[:, 1:]
targets = input_idx[:, 1:][out_mask]


# print(input_idx[:1, :4])
outputs = self.model(inputs)

# Logits at the masked (task) positions only.
predictions = outputs[out_mask]

loss = F.cross_entropy(predictions, targets)

is_correct = (predictions.argmax(dim=-1) == targets)
accuracy_mean = is_correct.float().mean()
metrics = accuracy_by_depth(outputs, input_idx, out_mask)
metrics['accuracy/mean'] = accuracy_mean.item()

metrics

{'acc/depth=1': 1.0,
 'acc/depth=2': 1.0,
 'acc/depth=3': 1.0,
 'acc/depth=4': 0.9333333333333333,
 'acc/depth=5': 0.8666666666666667,
 'acc/depth=6': 0.9230769230769231,
 'acc/depth=7': 1.0,
 'accuracy/mean': 0.9855595827102661}

## Create translator

In [17]:
def baseline_vocab(n_states=16):
    """Return the baseline model's vocabulary as a list (index -> token string).

    Layout:
        1. The delimiters ``","``, ``":"``, ``"|"``.
        2. One ``">"``-prefixed token per node. These are edge targets: the
           prompt pair ``"12", ">15"`` encodes the edge 12 -> 15.
        3. One bare token per node.

    Within each node group, the two-digit numbers come first. ``sorted`` is
    stable, so sorting by length in reverse keeps ascending numeric order
    inside each length group. The order must match the embedding table of the
    baseline checkpoint.
    """
    number_tokens = sorted([str(i) for i in range(n_states)], key=len, reverse=True)
    return [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens


def load_baseline_model(
    device='cpu',
    checkpoint_path="/Users/mykhailokilianovskyi/src/backward-chaining-circuits/model.pt",
    n_states=16,
):
    """Build the baseline HookedTransformer and load its checkpoint weights.

    Args:
        device: device to build the model on and to map the checkpoint onto.
        checkpoint_path: path to the baseline ``state_dict``. Defaults to the
            original author-local path; pass your own path on other machines.
        n_states: number of tree nodes the checkpoint was trained with. It sets
            the vocabulary size and context length, so it must match the
            checkpoint.

    Returns:
        The ``HookedTransformer`` with the checkpoint weights loaded.
    """
    max_seq_length = n_states * 4 + 2
    idx2tokens = baseline_vocab(n_states)

    cfg = HookedTransformerConfig(
        n_layers=6,
        d_model=128,
        # The model is fed the sequence without its last token (next-token
        # prediction), hence one position fewer than the full sequence.
        n_ctx=max_seq_length - 1,
        n_heads=1,
        d_mlp=512,
        d_head=128,
        d_vocab=len(idx2tokens),
        device=device,
        attention_dir="causal",
        act_fn="gelu",
    )
    model = HookedTransformer(cfg)

    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(state_dict)

    return model

In [18]:
# Baseline model from the original authors' checkpoint, built on load_baseline_model's default device ('cpu').
baseline_model = load_baseline_model()

In [19]:
seed_all(0)
# Take the first batch after reseeding. Unlike `for batch in ...: break`,
# next(iter(...)) raises StopIteration on an empty dataloader instead of
# silently leaving a stale `batch` from an earlier cell in place.
batch = next(iter(trainer.dataloader))

In [20]:
input_idx = batch['input_idx']
# Render sample 0: its tree, goal, and ground-truth vs. predicted path (via trainer's model).
trainer.print_sample_pred(input_idx[0])

****************************************************************************************************
                            12
                             |
                    +--------+-----------+
                    |                    |
                   15                    2
                    |                    |
              +-----+-----+           +--+-----+
              |           |           |        |
              4           6           5        9
              |           |           |        |
     +--------+--+     +--+        +--+     +--+
     |           |     |           |        |
     8           3     0          14        1
     |                             |
  +--+-----+                    +--+
  |        |                    |
 11        7                   10
           |
        +--+
        |
       13

goal=3
accuracy=1.0 gt_path=['â†’15', 'â†’4', 'â†’3'] pred_path=['â†’15', 'â†’4', 'â†’3']
*************************************************

In [21]:
# Goal: translate our token ids into baseline-model token ids (idx -> basidx).
#
# Path: our idx          -> our token       (our tokenizer's idx2token)
#       our token        -> baseline token  (token2bastoken, built below)
#       baseline token   -> baseline idx    (baseline vocabulary tokens2idx)
#
# NOTE: the baseline vocabulary has no dedicated pad token; ',' is used as <PAD>.

In [22]:
n_states = 16
bas_max_seq_length = 4 * n_states + 2  # full sequence length of the baseline model

# Node-number tokens, two-digit numbers first. sorted() is stable, so each
# length group keeps ascending numeric order.
number_tokens = sorted(map(str, range(n_states)), key=len, reverse=True)
idx2tokens = [",", ":", "|", *(f">{t}" for t in number_tokens), *number_tokens]
tokens2idx = dict(zip(idx2tokens, range(len(idx2tokens))))

In [23]:
# our_idx2token.keys()

In [24]:
from src.tree_dataset import PAD_TOKEN

In [25]:
# Our tokenizer writes edge-target tokens with the arrow U+2192 (see the
# gt_path / pred_path printout above), where the baseline vocabulary uses '>'.
# The escape sequence keeps this literal correct regardless of how the file is
# decoded. A mis-decoded arrow would silently never match our tokens.
token2bastoken = {k.replace('>', '\u2192'): k for k in tokens2idx.keys()}
# The baseline has no pad token; ',' plays that role.
token2bastoken[PAD_TOKEN] = ','

our_idx2token = trainer.dataset.tokenizer.idx2token

# our token id -> baseline token id; filled in the next cell.
idx2basidx = {}



In [26]:
# Compose the lookups: our idx -> our token -> baseline token -> baseline idx.
idx2basidx.update(
    (our_idx, tokens2idx[token2bastoken[token]])
    for our_idx, token in our_idx2token.items()
)

In [27]:
# Inspect the mapping: our token id -> baseline token id.
idx2basidx

{0: 0,
 1: 0,
 2: 2,
 3: 1,
 4: 25,
 5: 9,
 6: 26,
 7: 10,
 8: 27,
 9: 11,
 10: 28,
 11: 12,
 12: 29,
 13: 13,
 14: 30,
 15: 14,
 16: 31,
 17: 15,
 18: 32,
 19: 16,
 20: 33,
 21: 17,
 22: 34,
 23: 18,
 24: 19,
 25: 3,
 26: 20,
 27: 4,
 28: 21,
 29: 5,
 30: 22,
 31: 6,
 32: 23,
 33: 7,
 34: 24,
 35: 8}

In [28]:
# Baseline id of '|', which ends the edge list in the baseline prompt (the goal follows it).
BAS_ROOT_DELIM_TOKEN_IDX = tokens2idx['|']

In [29]:
def basidx2prompt(idx):
    """Decode a sequence of baseline token ids into baseline token strings."""
    return list(map(idx2tokens.__getitem__, idx))

In [30]:
# Translate sample 0 through token strings (independently of the idx2basidx table).
our_tokens = detok(input_idx[0])
bas_tokens = [token2bastoken[t] for t in our_tokens]
print(f'{bas_tokens[:5]=}')

bas_idx = [tokens2idx[t] for t in bas_tokens]
# Keep everything through '|' plus three more tokens: goal, ':' and root (see output).
upper_task_bound = bas_idx.index(BAS_ROOT_DELIM_TOKEN_IDX) + 4

prompt_idx = bas_idx[:upper_task_bound]
basidx2prompt(prompt_idx)

bas_tokens[:5]=['12', '>15', ',', '7', '>13']


['12',
 '>15',
 ',',
 '7',
 '>13',
 ',',
 '2',
 '>5',
 ',',
 '12',
 '>2',
 ',',
 '4',
 '>8',
 ',',
 '6',
 '>0',
 ',',
 '15',
 '>4',
 ',',
 '9',
 '>1',
 ',',
 '15',
 '>6',
 ',',
 '4',
 '>3',
 ',',
 '8',
 '>11',
 ',',
 '8',
 '>7',
 ',',
 '2',
 '>9',
 ',',
 '5',
 '>14',
 ',',
 '14',
 '>10',
 '|',
 '3',
 ':',
 '12']

In [31]:
# The string-level translation of sample 0 must agree with the id-level idx2basidx table.
assert prompt_idx == [idx2basidx[int(tok_id)] for tok_id in input_idx[0]][:len(prompt_idx)]

In [32]:
# Greedily decode 7 tokens with the baseline model to inspect its predicted path.
# inference_mode: no autograd graph is needed for generation.
with torch.inference_mode():
    for _ in range(7):
        logits = baseline_model(torch.tensor(prompt_idx))
        # .item() keeps prompt_idx a plain list of ints (not 0-d tensors).
        pred_idx_greedy = logits[0, -1].argmax().item()
        pred_token = idx2tokens[pred_idx_greedy]
        print(f'{pred_token=}')
        prompt_idx.append(pred_idx_greedy)

pred_token='>15'
pred_token='>4'
pred_token='>3'
pred_token=','
pred_token=','
pred_token=','
pred_token=','


## Verify metrics

In [33]:
def remap_token_ids(token_ids, mapping):
    """Return a new tensor with every id in ``token_ids`` replaced by ``mapping[id]``.

    This is a vectorized replacement for
    ``token_ids.clone().apply_(lambda i: mapping[i])``. A lookup table is built
    once and indexed on ``token_ids``' own device, so it also works for non-CPU
    tensors (``Tensor.apply_`` only supports CPU). ``token_ids`` is not modified,
    and the result has the same dtype and shape.

    Args:
        token_ids: integer tensor of any shape.
        mapping: dict with non-negative int keys and int values.

    Raises:
        ValueError: if ``mapping`` has a negative key.
        KeyError: if ``token_ids`` contains an id that is not a key of ``mapping``.
    """
    if any(key < 0 for key in mapping):
        raise ValueError("mapping keys must be non-negative")
    size = max(mapping, default=-1) + 1
    lut = torch.zeros(size, dtype=torch.long)
    known = torch.zeros(size, dtype=torch.bool)
    for key, value in mapping.items():
        lut[key] = value
        known[key] = True

    if token_ids.numel() == 0:
        return token_ids.clone()

    # Range check first: negative ids would otherwise wrap around when indexing.
    lo, hi = int(token_ids.min()), int(token_ids.max())
    if lo < 0 or hi >= size:
        raise KeyError(lo if lo < 0 else hi)

    ids = token_ids.long()
    lut, known = lut.to(ids.device), known.to(ids.device)
    hit = known[ids]
    if not bool(hit.all()):
        raise KeyError(int(ids[~hit][0]))
    return lut[ids].to(token_ids.dtype)


def baseline_batch_train_step(baseline_model, batch):
    """Compute loss and accuracy metrics of the baseline model on one of our batches.

    The batch is produced by our dataset, using our tokenizer. It is truncated to
    the baseline's sequence length and translated into baseline token ids via
    ``idx2basidx``; ``batch`` itself is not modified. The loss and metrics are
    then computed exactly as in our trainer's step: next-token prediction,
    scored on the task-mask positions.

    Uses the notebook globals ``bas_max_seq_length``, ``device``,
    ``idx2basidx`` and ``accuracy_by_depth``.

    Args:
        baseline_model: model mapping baseline token ids to logits; presumably
            already on ``device``.
        batch: dict with ``'input_idx'`` (our token ids) and ``'task_mask'``
            (assumed boolean, same shape).

    Returns:
        ``(loss, metrics)``: the cross-entropy on the masked positions, and a
        dict of per-depth accuracies plus ``'accuracy/mean'``.
    """
    input_idx = batch['input_idx'][..., :bas_max_seq_length].to(device)
    mask = batch['task_mask'][..., :bas_max_seq_length].to(device)

    # Translate our ids to baseline ids (returns a new tensor; works on any device).
    input_idx = remap_token_ids(input_idx, idx2basidx)

    # Next-token prediction: inputs drop the last token, targets drop the first.
    inputs = input_idx[:, :-1]
    out_mask = mask[:, 1:]
    targets = input_idx[:, 1:][out_mask]

    outputs = baseline_model(inputs)
    predictions = outputs[out_mask]

    loss = F.cross_entropy(predictions, targets)

    is_correct = predictions.argmax(dim=-1) == targets
    accuracy_mean = is_correct.float().mean()
    # NOTE(review): input_idx is in the *baseline* vocabulary here; confirm that
    # accuracy_by_depth does not rely on our tokenizer's ids to find depths.
    metrics = accuracy_by_depth(outputs, input_idx, out_mask)
    metrics['accuracy/mean'] = accuracy_mean.item()

    return loss, metrics

In [34]:
seed_all(0)
# Same first batch as before (reseeded with 0). next(iter(...)) raises on an
# empty dataloader instead of silently keeping a stale `batch` from an earlier cell.
batch = next(iter(trainer.dataloader))

In [35]:
# Baseline loss/metrics on the single batch selected above.
# NOTE(review): runs outside torch.inference_mode, so the loss carries an autograd graph.
baseline_batch_train_step(baseline_model, batch)

(tensor grad NllLossBackward0 1.688e-06,
 {'acc/depth=1': 1.0,
  'acc/depth=2': 1.0,
  'acc/depth=3': 1.0,
  'acc/depth=4': 1.0,
  'acc/depth=5': 1.0,
  'acc/depth=6': 1.0,
  'acc/depth=7': 1.0,
  'acc/depth=8': 1.0,
  'accuracy/mean': 1.0})

In [None]:
# Evaluate the baseline model on the first N_TEST_BATCHES batches of our dataloader,
# mirroring the evaluation of our reproduced model earlier in the notebook.
sample_size = 0  # number of samples evaluated (accumulated, not displayed here)
basmetric_df = []

for i, batch in tqdm(zip(range(N_TEST_BATCHES), trainer.dataloader), total=N_TEST_BATCHES):
    sample_size += len(batch['input_idx'])
    with torch.inference_mode():  # evaluation only, no autograd graph
        loss, metrics = baseline_batch_train_step(baseline_model, batch)
        basmetric_df.append(metrics)


# One row per batch; depths missing from a batch become NaN and mean() skips them.
basmetric_df = pd.DataFrame(basmetric_df)
basmetric_df.mean()

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# Per-batch baseline metrics (one row per evaluated batch).
basmetric_df