In [1]:
import sys; sys.path.append('..')

In [2]:
# ! pip install lovely-tensors

import lovely_tensors as lt
lt.monkey_patch()

In [3]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [4]:
import wandb
from tqdm.auto import tqdm

from omegaconf import OmegaConf

from datetime import datetime

from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from src.tree import list2tree
from src.tree_dataset import TreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges
from src.utils import seed_all


from src.trainer import accuracy_by_depth
from src.trainer import TreeTrainer

In [6]:
conf = OmegaConf.load('../conf/00_reproduce_6L_nodes=16.yaml')

# Output is identical to the YAML file
conf.n_nodes = 16
conf.device = 'cpu'
print(OmegaConf.to_yaml(conf))

random_seed: 42
n_nodes: 16
model:
  d_model: 128
  d_head: 128
  n_layers: 6
  act_fn: gelu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.01
batch_size: 64
epoch_len_steps: 5000
checkpoint_every_epoch: 2
device: cpu
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 00_6L_nodes=16
max_iters: null



In [7]:
REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-10_16-10-20/00_6L_nodes=16__step=220000.pt'
# REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-12_14-26-20/00_6L_nodes=16__deep_trees__step=9256.pt'

In [8]:
DEV_RUN = False
USE_WANDB = (not DEV_RUN) and conf.use_wandb
device = conf.device
N_TEST_BATCHES = 500
CHECKPOINT_ROOT = Path('../checkpoints')

In [9]:
RANDOM_SEED = conf['random_seed']
print(f'{RANDOM_SEED=}')
seed_all(RANDOM_SEED)

RANDOM_SEED=42


In [10]:
# Build the trainer from the config loaded above; the model and dataset are
# reached through it below.
trainer = TreeTrainer(conf)

tokenizer = trainer.dataset.tokenizer

# Short aliases for the tokenizer's encode / decode helpers.
tok = tokenizer.tokenize
detok = tokenizer.detokenize


# Index of ':' in our vocabulary.
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written during a GPU run still loads when conf.device is 'cpu'.
state_dict = torch.load(REPRODUCED_MODEL_CKPT, map_location=torch.device(conf.device))
trainer.model.load_state_dict(state_dict)

Moving model to device:  cpu


<All keys matched successfully>

In [11]:
def load_baseline_model(
    device='cpu',
    ckpt_path="/Users/mykhailokilianovskyi/src/backward-chaining-circuits/model.pt",
):
    """Build the 6-layer baseline HookedTransformer and load its weights.

    Args:
        device: Device string passed to the model config and used as
            ``map_location`` when reading the checkpoint.
        ckpt_path: Location of the baseline ``state_dict`` file. The default is
            the path the notebook was originally written against (kept for
            backward compatibility); pass your own path on any other machine.

    Returns:
        The ``HookedTransformer`` with the checkpoint weights loaded.
    """
    n_states = 16
    max_seq_length = n_states * 4 + 2

    # Vocabulary: three separators, then one ">n" token and one bare "n" token
    # per state. The stable sort by descending length puts two-digit numbers first.
    number_tokens = sorted([str(i) for i in range(n_states)], key=lambda x: len(x), reverse=True)
    idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens

    cfg = HookedTransformerConfig(
        n_layers=6,
        d_model=128,
        n_ctx=max_seq_length - 1,
        n_heads=1,
        d_mlp=512,
        d_head=128,
        #attn_only=True,
        d_vocab=len(idx2tokens),
        device=device,
        attention_dir= "causal",
        act_fn="gelu",
    )
    model = HookedTransformer(cfg)

    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))

    return model

In [12]:
import random
import collections
from torch.utils.data import IterableDataset, DataLoader

In [13]:
import random

from src.tree import TreeNode
from src.utils import seed_all
from src.tree_dataset import random_tree_of_depth, DeepTreeDataset

In [14]:
deep_dataset = DeepTreeDataset(n_nodes=16, possible_depths=(15,14,13))
deep_tree_dataloader = DataLoader(deep_dataset, batch_size=conf['batch_size'])

In [15]:
baseline_model = load_baseline_model()

In [16]:
our_idx2token = trainer.dataset.tokenizer.idx2token


In [17]:
# Rebuild the baseline vocabulary at notebook level, using the same construction
# as inside load_baseline_model, so later cells can translate ids <-> tokens.
n_states = 16
bas_max_seq_length = n_states * 4 + 2

# sorted() is stable, so with key=len and reverse=True the two-digit numbers come
# first and each length group keeps its ascending order.
number_tokens = sorted([str(i) for i in range(n_states)], key=lambda x: len(x), reverse=True)
# Layout: 3 separator tokens, then the ">n" tokens, then the bare "n" tokens.
idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens
# Inverse lookup: baseline token string -> baseline index.
tokens2idx = {token: idx for idx, token in enumerate(idx2tokens)}

In [18]:
from src.tree_dataset import PAD_TOKEN


token2bastoken = {k.replace('>', '→'):k for k in tokens2idx.keys()}
token2bastoken[PAD_TOKEN] = ','

In [19]:


# Compose the lookups: our idx -> our token -> baseline token -> baseline idx.
# A token missing from either table raises KeyError here rather than being skipped.
idx2basidx = {}
for idx, our_tok in our_idx2token.items():
    bastoken = token2bastoken[our_tok]
    basidx = tokens2idx[bastoken]
    idx2basidx[idx] = basidx

In [20]:
def baseline_batch_train_step(baseline_model, batch):
    """Run the baseline model on one batch from our dataloader and score it.

    The batch is truncated to ``bas_max_seq_length`` tokens, our token ids are
    remapped to the baseline vocabulary through ``idx2basidx``, and next-token
    predictions are scored only at positions selected by ``task_mask``.

    Args:
        baseline_model: Callable mapping a token-id tensor to logits.
        batch: Mapping with ``'input_idx'`` and ``'task_mask'`` tensors
            (assumed (batch, seq) with a boolean mask -- TODO confirm).

    Returns:
        ``(loss, metrics)``: the cross-entropy loss over the masked positions,
        and the dict returned by ``accuracy_by_depth`` extended with
        ``'accuracy/mean'``.
    """
    # Tensor.apply_ only works on CPU tensors, so remap the ids on CPU first
    # and move the result to the target device afterwards.
    input_idx = batch['input_idx'][..., :bas_max_seq_length].clone().cpu()
    input_idx.apply_(lambda i: idx2basidx[i])
    input_idx = input_idx.to(device)
    mask = batch['task_mask'][..., :bas_max_seq_length].clone().to(device)

    # Next-token setup: position t of the input predicts token t + 1.
    inputs = input_idx[:, :-1]
    out_mask = mask[:, 1:]
    targets = input_idx[:, 1:][out_mask]

    outputs = baseline_model(inputs)

    # Keep only the logits at task positions.
    predictions = outputs[out_mask]

    loss = F.cross_entropy(predictions, targets)

    is_correct = (predictions.argmax(dim=-1) == targets)
    accuracy_mean = is_correct.float().mean()
    metrics = accuracy_by_depth(outputs, input_idx, out_mask)
    metrics['accuracy/mean'] = accuracy_mean.item()

    return loss, metrics

In [21]:
def split_depths(mdic):
    """Lazily yield the integer after the first '=' of every key containing '='.

    Keys without '=' (e.g. 'accuracy/mean') are ignored. The result is a
    generator, so wrap it in ``list`` to materialize it.
    """
    keyed_by_depth = (name for name in mdic.keys() if '=' in name)
    return (int(name.split('=')[1]) for name in keyed_by_depth)

In [22]:
depths = np.arange(1, 10)

In [23]:
def df2plotly_arr(df):
    """Return the column means of ``df`` ordered by the global ``depths``.

    For each depth ``d`` the mean of column ``'acc/depth=<d>'`` is looked up;
    depths with no such column yield ``None``.
    """
    column_means = df.mean().to_dict()
    per_depth = []
    for depth in depths:
        per_depth.append(column_means.get(f'acc/depth={depth}'))
    return per_depth

## Project the residual stream into the vocabulary

In [24]:
# for batch in deep_tree_dataloader:
#     break

seed_all(1)
for batch in trainer.dataloader:
    break

In [25]:
baseline_batch_train_step(baseline_model, batch)

(tensor grad NllLossBackward0 0.001,
 {'acc/depth=1': 1.0,
  'acc/depth=2': 1.0,
  'acc/depth=3': 1.0,
  'acc/depth=4': 1.0,
  'acc/depth=5': 1.0,
  'acc/depth=6': 1.0,
  'acc/depth=7': 1.0,
  'acc/depth=8': 1.0,
  'accuracy/mean': 1.0})

In [26]:
# Tensor.apply_ only works on CPU tensors, so translate our token ids to the
# baseline vocabulary on CPU before moving the tensor to the target device.
input_idx = batch['input_idx'][..., :bas_max_seq_length].clone().cpu()
input_idx.apply_(lambda i: idx2basidx[i])
input_idx = input_idx.to(device)
mask = batch['task_mask'][..., :bas_max_seq_length].clone().to(device)

# Next-token setup: position t of `inputs` predicts token t + 1.
inputs = input_idx[:, :-1]

out_mask = mask[:, 1:]
targets = input_idx[:, 1:][out_mask]

# Keep the activation cache so later cells can read the residual stream.
outputs, cache = baseline_model.run_with_cache(inputs)

In [27]:
inputs

tensor[64, 65] i64 n=4160 (32Kb) x∈[0, 34] μ=10.152 σ=11.275

In [28]:
baseline_model.ln_final

LayerNorm(
  (hook_scale): HookPoint()
  (hook_normalized): HookPoint()
)

In [29]:
x_5 = cache['resid_post', -1]
x_5

tensor[64, 65, 128] n=532480 (2.0Mb) x∈[-93.056, 231.443] μ=0.221 σ=19.380

In [30]:
x_5 = cache['resid_post', -1]

assert torch.allclose(baseline_model.unembed(baseline_model.ln_final(x_5)), outputs)

In [31]:
def unembed(x_residual):
    """Turn a residual-stream tensor into vocabulary logits.

    Applies the baseline model's final LayerNorm, then its unembedding.
    """
    normalized = baseline_model.ln_final(x_residual)
    return baseline_model.unembed(normalized)

In [32]:
out_mask[0].v

tensor[65] bool x∈[False, True] μ=0.108 σ=0.312
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False,  True,  True,  True,
         True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False])

In [33]:
out_mask.int().argmax(dim=-1).v

tensor[64] i64 x∈[47, 47] μ=47.000 σ=0.
tensor([47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47,
        47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47,
        47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47, 47,
        47, 47, 47, 47, 47, 47, 47, 47, 47, 47])

In [34]:
tokens2idx[':']

1

In [35]:
tok([':']), detok([1])

([3], [','])

In [36]:
input_idx[:, 46:].v

tensor[64, 20] i64 n=1280 (10Kb) x∈[0, 34] μ=3.782 σ=7.303
tensor([[ 1, 29, 18,  ...,  0,  0,  0],
        [ 1, 24, 17,  ...,  0,  0,  0],
        [ 1, 23, 15,  ...,  0,  0,  0],
        ...,
        [ 1, 27, 16,  ...,  0,  0,  0],
        [ 1, 23,  4,  ...,  0,  0,  0],
        [ 1, 26, 12,  ...,  0,  0,  0]])

In [37]:
input_idx[:, 45].v

tensor[64] i64 x∈[19, 34] μ=25.812 σ=4.411
tensor([27, 30, 20, 28, 26, 20, 22, 26, 32, 34, 22, 25, 25, 21, 24, 22, 27, 30,
        26, 30, 34, 28, 22, 21, 31, 19, 33, 27, 30, 27, 19, 30, 20, 28, 24, 23,
        28, 24, 33, 20, 34, 24, 19, 33, 32, 30, 30, 22, 22, 22, 25, 29, 33, 23,
        25, 19, 28, 22, 24, 24, 25, 22, 25, 22])

In [38]:
out_mask[:, 46:].v

tensor[64, 19] bool n=1216 (1.2Kb) x∈[False, True] μ=0.237 σ=0.425
tensor([[False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        ...,
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False],
        [False,  True,  True,  ..., False, False, False]])

In [39]:
# detok(input_idx[0])

In [40]:
input_idx[:, 46:].v

tensor[64, 20] i64 n=1280 (10Kb) x∈[0, 34] μ=3.782 σ=7.303
tensor([[ 1, 29, 18,  ...,  0,  0,  0],
        [ 1, 24, 17,  ...,  0,  0,  0],
        [ 1, 23, 15,  ...,  0,  0,  0],
        ...,
        [ 1, 27, 16,  ...,  0,  0,  0],
        [ 1, 23,  4,  ...,  0,  0,  0],
        [ 1, 26, 12,  ...,  0,  0,  0]])

In [41]:
inputs[47-4].v

tensor[65] i64 x∈[0, 33] μ=10.046 σ=10.868
tensor([24,  7,  0, 24, 15,  0, 19, 14,  0, 25, 10,  0, 26,  4,  0, 22,  9,  0,
        28,  5,  0, 21,  8,  0, 28,  3,  0, 31, 16,  0, 22, 11,  0, 26, 13,  0,
        32, 17,  0, 23, 18,  0, 21,  6,  2, 33,  1, 28,  5,  8, 15, 16, 17,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])

In [42]:
pred_x0 = unembed(cache['resid_post', 0]).argmax(dim=-1)
pred_x0

tensor[64, 65] i64 n=4160 (32Kb) x∈[0, 28] μ=5.644 σ=6.326

In [43]:
pred_x0_tokens = np.array([[idx2tokens[i] for i in sample] for sample in pred_x0])
pred_x0_tokens.shape

(64, 65)

In [44]:
pred_x0_tokens[:, 46:]

array([['>6', '>3', '>5', ..., ',', ',', ','],
       ['>6', '>5', '>13', ..., ',', ',', ','],
       ['>5', '>3', '>4', ..., ',', ',', ','],
       ...,
       ['>6', '>13', '>13', ..., ',', ',', ','],
       ['>5', '>13', '>13', ..., ',', ',', ','],
       ['>5', '>14', '>13', ..., ',', ',', ',']], dtype='<U3')

In [45]:
input_idx[:, 45].v

tensor[64] i64 x∈[19, 34] μ=25.812 σ=4.411
tensor([27, 30, 20, 28, 26, 20, 22, 26, 32, 34, 22, 25, 25, 21, 24, 22, 27, 30,
        26, 30, 34, 28, 22, 21, 31, 19, 33, 27, 30, 27, 19, 30, 20, 28, 24, 23,
        28, 24, 33, 20, 34, 24, 19, 33, 32, 30, 30, 22, 22, 22, 25, 29, 33, 23,
        25, 19, 28, 22, 24, 24, 25, 22, 25, 22])

In [89]:
seed_all(1)
for batch in trainer.dataloader: break

In [90]:
input_idx = batch['input_idx']
trainer.print_sample_pred(input_idx[0])

****************************************************************************************************
                                               4
                                               |
                                            +--+
                                            |
                                            9
                                            |
                                         +--+
                                         |
                                        13
                                         |
                    +--------------------+
                    |
                   12
                    |
        +-----------+--------+
        |                    |
        5                    1
        |                    |
     +--+--------+        +--+--------+
     |           |        |           |
     6           0        8          15
     |           |        |           |
  +--+        +--+     +--+        +--+
  |   

In [91]:
# Goal: translate our token indices into the baseline model's indices (idx -> basidx).

# Our model:      idx      -> token
# Baseline model: bastoken -> basidx

# First step: build token2bastoken.
# NOTE: the baseline uses ',' as <PAD>.

In [92]:
n_states = 16
bas_max_seq_length = n_states * 4 + 2

number_tokens = sorted([str(i) for i in range(n_states)], key=lambda x: len(x), reverse=True)
idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens
tokens2idx = {token: idx for idx, token in enumerate(idx2tokens)}

In [93]:
# our_idx2token.keys()

In [94]:
from src.tree_dataset import PAD_TOKEN

In [95]:
token2bastoken = {k.replace('>', '→'):k for k in tokens2idx.keys()}
token2bastoken[PAD_TOKEN] = ','

our_idx2token = trainer.dataset.tokenizer.idx2token

idx2basidx = {}



In [96]:
for idx, our_tok in our_idx2token.items():
    bastoken = token2bastoken[our_tok]
    basidx = tokens2idx[bastoken]
    idx2basidx[idx] = basidx

In [97]:
BAS_ROOT_DELIM_TOKEN_IDX = tokens2idx['|']

In [98]:
def basidx2prompt(idx):
    """Translate a sequence of baseline token indices into their token strings."""
    prompt_tokens = []
    for token_idx in idx:
        prompt_tokens.append(idx2tokens[token_idx])
    return prompt_tokens

In [99]:
# Translate the first sample of the batch into the baseline vocabulary:
# our ids -> our tokens -> baseline tokens -> baseline ids.
our_tokens = detok(input_idx[0])
bas_tokens = [token2bastoken[t] for t in our_tokens]
print(f'{bas_tokens[:5]=}')

bas_idx = [tokens2idx[t] for t in bas_tokens]
# Keep everything up to the first '|' plus that delimiter and the three tokens
# after it -- presumably goal, ':' and the root node; TODO confirm.
upper_task_bound = bas_idx.index(BAS_ROOT_DELIM_TOKEN_IDX) + 4

# Slicing copies, so appending to prompt_idx later does not modify bas_idx.
prompt_idx = bas_idx[:upper_task_bound]
# basidx2prompt(prompt_idx)

bas_tokens[:5]=['1', '>8', ',', '1', '>15']


In [100]:
assert (prompt_idx )  == ([idx2basidx[i.item()] for i in  input_idx[0]][:len(prompt_idx)])

In [101]:
# Greedy autoregressive decoding with the baseline model: 7 steps, each one
# appending the most likely next token to the prompt.
prompt_idx = bas_idx[:upper_task_bound]
for i in range(7):
    outputs, cache = baseline_model.run_with_cache(torch.tensor(prompt_idx))
    x_5 = cache['resid_post', -1]

    # Sanity check: final LayerNorm + unembed on the last residual stream
    # reproduces the model's logits.
    assert torch.allclose(baseline_model.unembed(baseline_model.ln_final(x_5)), outputs)

    # .item() keeps prompt_idx a plain list of Python ints instead of a mix of
    # ints and 0-dim tensors that torch.tensor() has to unpack on every step.
    pred_idx_greedy = outputs[0, -1].argmax().item()
    pred_token = idx2tokens[pred_idx_greedy]
    print(f'{pred_token=}')
    prompt_idx.append(pred_idx_greedy)

pred_token='>9'
pred_token='>13'
pred_token='>12'
pred_token='>5'
pred_token='>0'
pred_token='>10'
pred_token='>2'


In [102]:
# Single-step check of the `unembed` helper: the trailing `break` leaves the
# loop after the first iteration, so only one token is predicted.
prompt_idx = bas_idx[:upper_task_bound]
for i in range(7):
    outputs, cache = baseline_model.run_with_cache(torch.tensor(prompt_idx))
    x_5 = cache['resid_post', -1]
    # The helper applied to the last residual stream must reproduce the logits.
    assert torch.allclose(unembed(x_5), outputs)
    
    # Greedy prediction at the last sequence position.
    pred_idx_greedy = unembed(x_5)[0, -1].argmax()
    pred_token = idx2tokens[pred_idx_greedy]
    print(f'{pred_token=}')
    prompt_idx.append(pred_idx_greedy)
    break

pred_token='>9'


In [103]:
# Per-layer unembedding during greedy decoding: at each of 7 steps, decode the
# last position's residual stream after every layer and print the top token.
prompt_idx = bas_idx[:upper_task_bound]
for i in range(7):

    outputs, cache = baseline_model.run_with_cache(torch.tensor(prompt_idx))

    # Take the layer count from the model config instead of hard-coding 6.
    for L in range(baseline_model.cfg.n_layers):
        x_l = cache['resid_post', L]

        # .item() gives a plain Python int for indexing and for the prompt list.
        pred_idx_greedy = unembed(x_l)[0, -1].argmax().item()
        pred_token = idx2tokens[pred_idx_greedy]
        print(f'unembed {L=} {pred_token=}')
    # After the inner loop pred_idx_greedy holds the final layer's prediction,
    # which is the token used to extend the prompt.
    prompt_idx.append(pred_idx_greedy)
    print('*'*100)

unembed L=0 pred_token='>3'
unembed L=1 pred_token='>10'
unembed L=2 pred_token='>0'
unembed L=3 pred_token='>5'
unembed L=4 pred_token='>13'
unembed L=5 pred_token='>9'
****************************************************************************************************
unembed L=0 pred_token='>5'
unembed L=1 pred_token='>10'
unembed L=2 pred_token='>0'
unembed L=3 pred_token='>5'
unembed L=4 pred_token='>13'
unembed L=5 pred_token='>13'
****************************************************************************************************
unembed L=0 pred_token='>1'
unembed L=1 pred_token='>10'
unembed L=2 pred_token='>0'
unembed L=3 pred_token='>5'
unembed L=4 pred_token='>12'
unembed L=5 pred_token='>12'
****************************************************************************************************
unembed L=0 pred_token='>10'
unembed L=1 pred_token='>10'
unembed L=2 pred_token='>0'
unembed L=3 pred_token='>5'
unembed L=4 pred_token='>5'
unembed L=5 pred_token='>5'
***************

unembed L=1 pred_token='>7'


In [74]:

x_l = cache['resid_post', -1]
assert torch.allclose(unembed(x_l), outputs)

pred_idx_greedy = outputs[0, -1].argmax()
pred_token = idx2tokens[pred_idx_greedy]
print(f'{pred_token=}')

tensor[1, 48, 35] n=1680 (6.6Kb) x∈[-22.636, 29.493] μ=-0.211 σ=4.656 grad AddBackward0

In [None]:
x_5 = cache['resid_post', -1]
assert torch.allclose(baseline_model.unembed(baseline_model.ln_final(x_5)), outputs)

pred_idx_greedy = outputs[0, -1].argmax()
pred_token = idx2tokens[pred_idx_greedy]
print(f'{pred_token=}')

In [None]:
# NOTE(review): a bare `break` outside a loop is a SyntaxError, so this cell
# always fails; presumably left here on purpose to halt "Run All" before the
# cells below -- confirm, and consider removing it before sharing the notebook.
break

In [None]:
outputs

In [None]:
sample_input_idx = batch['input_idx'][0]

In [None]:
def inference_on_prompt_with_cache(trainer, prompt):
    """Greedily predict the token following ``prompt`` and return the cache.

    Args:
        trainer: Object providing ``tok``, ``detok`` and ``model.run_with_cache``.
        prompt: Sequence of token strings to condition on.

    Returns:
        ``(next_token, cache)``: the decoded greedy prediction at the last
        position, and the cache returned by ``run_with_cache``.
    """
    # Tokenize and add a leading batch dimension of size 1.
    batched_idx = torch.tensor(trainer.tok(prompt)).unsqueeze(0)
    logits, cache = trainer.model.run_with_cache(batched_idx)

    # Highest-scoring vocabulary entry at the final sequence position.
    next_idx = logits[0, -1].argmax()
    decoded = trainer.detok([next_idx])
    return decoded[0], cache

In [None]:
# Index of ':' in our vocabulary (recomputed here with the same expression as
# in the setup cell).
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]


# Keep everything up to ':' plus the one token after it -- presumably the root
# node the path starts from; TODO confirm.
upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
prompt_autoregressive = trainer.detok(sample_input_idx)[:upper_task_bound]
input_tree = input_tokens_to_tree(prompt_autoregressive)

# `prompt` aliases prompt_autoregressive, so the in-place `+=` below also
# extends that list; input_tree has already been built from it at this point.
prompt = prompt_autoregressive

parsed_input = parse_input_idx(sample_input_idx, trainer.dataset.tokenizer)
gt_path = parsed_input['path']
pred_path = []

# Greedy autoregressive rollout: predict one token per ground-truth path step,
# feeding each prediction back into the prompt.
for i in range(len(gt_path)):
    # pred_token = trainer.inference_on_prompt(prompt)
    pred_token, cache = inference_on_prompt_with_cache(trainer, prompt)
    pred_path.append(pred_token)
    prompt += [pred_token]

# Fraction of path positions where the prediction matches the ground truth.
accuracy = (np.array(gt_path) == np.array(pred_path)).astype(float).mean()
print('*'*100)
print(input_tree)
print()
print(f'goal={parsed_input["goal"]}')
print(f'{accuracy=} {gt_path=} {pred_path=}' )
print('*'*100)

In [None]:
trainer.model