In [1]:
import sys; sys.path.append('..')

In [2]:
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [3]:
from src.tree_dataset import TreeDataset
from src.tree_dataset import parse_input_idx

In [4]:
from transformer_lens import HookedTransformer, HookedTransformerConfig

In [5]:
import wandb

In [6]:
from omegaconf import OmegaConf

def _pick_device(preferred="mps"):
    """Return `preferred` if that torch backend is usable on this machine.

    Falls back to "cuda", then "mps", then "cpu" so the notebook still runs
    on hardware that lacks the preferred accelerator.
    """
    available = {"cpu"}
    if torch.cuda.is_available():
        available.add("cuda")
    # torch.backends.mps is absent on older torch builds, hence the getattr.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        available.add("mps")

    for candidate in (preferred, "cuda", "mps", "cpu"):
        if candidate in available:
            return candidate


# Single source of truth for the run: model shape, optimizer, logging and
# loop settings. TreeTrainer later adds `n_ctx` and `d_vocab` to `model`.
conf = OmegaConf.create({
    "n_nodes" : 16,
    "model": {
        "d_model": 128,
        "d_head": 128,
        "n_layers": 3,
        "act_fn": 'relu',
        'attention_dir': 'causal',
    },

    "optimizer": {
        "lr": 1e-3,
        "weight_decay": 1e-3,
    },

    "batch_size": 64,
    # Prefer Apple MPS, but degrade gracefully on machines without it.
    "device": _pick_device("mps"),
    "epoch_len_steps": 5000,
    "checkpoint_every_epoch": 10,
    "debug": False,
    "use_wandb": True,
    "wandb": {
        "project": "reasoning-mech-interp",
        "name": "00_3L_nodes=16"
    },
    # None means the training loop runs until interrupted.
    "max_iters": None
})

USE_WANDB = conf.use_wandb
device = conf.device

print(OmegaConf.to_yaml(conf))

n_nodes: 16
model:
  d_model: 128
  d_head: 128
  n_layers: 3
  act_fn: relu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.001
batch_size: 64
device: mps
epoch_len_steps: 5000
checkpoint_every_epoch: 10
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 00_3L_nodes=16
max_iters: null



In [7]:
from datetime import datetime

# Capture the launch time once so every artifact of this run shares one suffix.
now = datetime.now()

# Filename-safe rendering of the launch time (no colons or spaces).
now_filename = f"{now:%Y-%m-%d_%H-%M-%S}"
now_filename

'2024-04-10_16-02-15'

In [8]:
from pathlib import Path
CHECKPOINT_ROOT = Path('../checkpoints')

In [9]:
# One directory per run: "<wandb project>__<launch timestamp>" under the checkpoint root.
checkpoint_dir = CHECKPOINT_ROOT.joinpath(f'{conf["wandb"]["project"]}__{now_filename}')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

In [10]:
# Start the wandb run only when logging is enabled in the config; the whole
# config is attached to the run as plain Python containers.
if USE_WANDB:
    wandb.init(
        project=conf.wandb.project,
        name=conf.wandb.name,
        config=OmegaConf.to_object(conf),
    )

[34m[1mwandb[0m: Currently logged in as: [33mkilianovski[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
class TreeTrainer:
    """Bundles the tree dataset, its dataloader, the transformer and the optimizer.

    Attributes:
        dataset: the TreeDataset built from `conf.n_nodes`.
        dataloader: DataLoader over `dataset` with `conf['batch_size']`.
        device: device string taken from `conf.device`.
        model: HookedTransformer moved to `device`.
        optimizer: AdamW over the model parameters, configured by `conf.optimizer`.
    """

    def __init__(self, conf):
        """Build dataset, model and optimizer from `conf`.

        Note: `conf.model` is mutated in place -- `n_ctx` and `d_vocab` are
        filled in from the dataset's tokenizer before the model config is built.
        """
        dataset = TreeDataset(conf.n_nodes)
        self.dataset = dataset
        self.dataloader = DataLoader(dataset, batch_size=conf['batch_size'])

        conf.model["n_ctx"] = dataset.tokenizer.MAX_SEQ_LEN
        conf.model["d_vocab"] = len(dataset.tokenizer.token2idx)

        model_cfg = HookedTransformerConfig(
            **conf.model
        )

        device = conf.device
        self.device = device

        self.model = HookedTransformer(model_cfg).to(device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), **conf.optimizer)


    def train_step(self, batch):
        """Compute next-token loss and accuracy on the masked positions of `batch`.

        Args:
            batch: mapping with tensors under 'input_idx' and 'task_mask'.
                'task_mask' is used for boolean indexing, so it is expected
                to be a bool tensor with the same leading shape as 'input_idx'.

        Returns:
            (loss, accuracy): cross-entropy loss and mean argmax accuracy, both
            scalar tensors, restricted to positions selected by the mask.
            No backward pass or optimizer step is performed here.
        """
        # Use the trainer's own device rather than a notebook-level global, so
        # the method works wherever the instance is used.
        input_idx = batch['input_idx'].to(self.device)
        mask = batch['task_mask'].to(self.device)

        # Standard next-token setup: feed tokens [0, T-1), score tokens [1, T).
        inputs = input_idx[:, :-1]

        # Shift the mask the same way as the targets so both stay aligned.
        out_mask = mask[:, 1:]
        targets = input_idx[:, 1:][out_mask]

        outputs = self.model(inputs)

        # Keep only the logits at positions whose next token is masked in.
        predictions = outputs[out_mask]

        loss = F.cross_entropy(predictions, targets)
        accuracy = (predictions.argmax(dim=-1) == targets).float().mean()
        return loss, accuracy

In [12]:
# Builds the dataset, dataloader, model and optimizer from the run config.
# Side effect: TreeTrainer.__init__ adds `n_ctx` and `d_vocab` to `conf.model`.
trainer = TreeTrainer(conf)

Moving model to device:  mps


In [13]:
# Short aliases for the dataset tokenizer, used by the debugging helpers below.
tokenizer = trainer.dataset.tokenizer
tok, detok = tokenizer.tokenize, tokenizer.detokenize

# Index of the ':' token; the helpers search for it to locate the end of the prompt.
ROOT_DELIM_TOKEN_IDX = tok([':'])[0]

def print_preds(batch, i=0):
    """Print the tree, the goal and the true vs. teacher-forced predicted path for sample `i`.

    Args:
        batch: mapping with 'input_idx' and 'task_mask' tensors, laid out as
            in TreeTrainer.train_step.
        i: index of the sample inside the batch to display.

    Reads the notebook globals `device`, `trainer`, `tokenizer` and `detok`.
    """
    input_idx = batch['input_idx'].to(device)
    mask = batch['task_mask'].to(device)

    # Pure inspection: no autograd graph needed.
    with torch.inference_mode():
        preds = trainer.model(input_idx)

    # Logits at position t score the token at position t + 1 (the convention
    # used in TreeTrainer.train_step), so both the mask and the targets are
    # shifted by one to keep predictions aligned with the true path.
    shifted_mask = mask[i, 1:]
    true_path = detok(input_idx[i, 1:][shifted_mask])
    predicted_path = detok(preds[i, :-1][shifted_mask].argmax(dim=-1))

    parsed_input = parse_input_idx(input_idx[i], tokenizer)
    print(parsed_input['tree'])
    print(f'goal: {parsed_input["goal"]}')
    print(f'true path = {true_path}, predicted path = {predicted_path}')


In [14]:
def inference_on_prompt(prompt, model=trainer.model):
    """Return the greedy next token (as a string token) for `prompt`.

    Args:
        prompt: sequence of string tokens accepted by `tok`.
        model: callable mapping a (1, seq) token tensor to logits; defaults to
            the trainer's model as bound when this cell was executed.

    Returns:
        The detokenized token with the highest logit at the last position.
    """
    tokens = torch.tensor(tok(prompt))[None]  # add the batch dimension

    # Inference only: skip building an autograd graph for the forward pass.
    with torch.inference_mode():
        pred_token_greedy = model(tokens)[0, -1].argmax()

    pred_token = detok([pred_token_greedy])
    return pred_token[0]

In [15]:
from src.tree_dataset import input_tokens_to_tree

In [16]:
def print_sample_pred(batch):
    """Greedy-decode the path for the first sample of `batch` and print it next to the ground truth.

    Reads the notebook globals `device`, `detok`, `tokenizer` and
    `ROOT_DELIM_TOKEN_IDX`; decoding goes through `inference_on_prompt`.
    """
    first_sample = batch['input_idx'].to(device)[0]

    # The prompt is everything up to and including the token right after ':'.
    delim_pos = first_sample.tolist().index(ROOT_DELIM_TOKEN_IDX)
    prompt = detok(first_sample)[:delim_pos + 2]
    input_tree = input_tokens_to_tree(prompt)

    parsed_input = parse_input_idx(first_sample, tokenizer)
    gt_path = parsed_input['path']

    # Decode exactly as many steps as the ground-truth path has, feeding each
    # prediction back into the prompt.
    pred_path = []
    for _ in gt_path:
        next_token = inference_on_prompt(prompt)
        pred_path.append(next_token)
        prompt += [next_token]

    matches = np.array(gt_path) == np.array(pred_path)
    accuracy = matches.astype(float).mean()

    banner = '*' * 100
    print(banner)
    print(input_tree)
    print()
    print(f'goal={parsed_input["goal"]}')
    print(f'{accuracy=} {gt_path=} {pred_path=}' )
    print(banner)

In [17]:
from tqdm.auto import tqdm

In [None]:
global_step = 0
epoch_i = 0

# None (or 0) means "train until the cell is interrupted".
max_iters = conf.get('max_iters')
reached_max_iters = False

while not reached_max_iters:

    # An "epoch" is a fixed number of optimizer steps; the progress bar also
    # bounds how many batches are drawn from the dataloader via zip().
    pbar = tqdm(range(conf.epoch_len_steps))

    for i, batch in zip(pbar, trainer.dataloader):
        if i == 0:
            # Qualitative check at the start of every epoch.
            print_sample_pred(batch)

        trainer.optimizer.zero_grad()
        loss, accuracy = trainer.train_step(batch)
        loss.backward()
        trainer.optimizer.step()

        # Convert once to Python floats so logging does not keep tensors
        # (and, for the loss, its autograd graph) alive.
        loss_value = loss.item()
        pbar.set_description(f'loss={loss_value:.3f}')
        if USE_WANDB:
            wandb.log({'loss': loss_value, 'accuracy': accuracy.item()})

        global_step += 1

        # Honor max_iters at step granularity instead of overshooting by up
        # to a whole epoch.
        if max_iters and global_step >= max_iters:
            reached_max_iters = True
            break

    pbar.close()
    epoch_i += 1

    if epoch_i % conf['checkpoint_every_epoch'] == 0:
        checkpoint_filename = checkpoint_dir/f'{conf["wandb"]["name"]}__step={global_step}.pt'
        print(f'Saving {checkpoint_filename=}')
        torch.save(trainer.model.state_dict(), checkpoint_filename)

    

  0%|          | 0/5000 [00:00<?, ?it/s]

****************************************************************************************************
     1
     |
  +--+-----------------+
  |                    |
  9                   15
                       |
                    +--+--------------+
                    |                 |
                   10                 6
                    |                 |
           +--------+        +--------+--------+
           |                 |                 |
           2                 5                12
           |                 |                 |
        +--+-----+        +--+-----+        +--+
        |        |        |        |        |
       14        8       11        4       13
                 |                 |        |
              +--+              +--+     +--+
              |                 |        |
              3                 0        7

goal=11
accuracy=0.0 gt_path=['→15', '→6', '→5', '→11'] pred_path=['2', '0', '→1', '0']
*********************

## Debug on custom input

In [None]:
# Grab one batch for the debugging cells below. next() fails loudly on an
# empty loader instead of silently leaving `batch` undefined.
batch = next(iter(trainer.dataloader))
print(parse_input_idx(batch['input_idx'][0], trainer.dataset.tokenizer)['tree'])

In [None]:
# Evaluate the debug batch without tracking gradients; `accuracy` is
# displayed in a later cell.
with torch.inference_mode():
    loss, accuracy = trainer.train_step(batch)

In [None]:
input_idx_custom = batch['input_idx'][0]


parsed_input = parse_input_idx(input_idx_custom, trainer.dataset.tokenizer)
parsed_input

In [None]:
# Cut the sample right after the token that follows the ':' delimiter
# (same "+ 2" bound as in print_sample_pred).
upper_task_bound = input_idx_custom.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
prompt_autoregressive = detok(input_idx_custom)[:upper_task_bound]
# NOTE: `prompt` is an alias, not a copy -- if detok returns a list, the
# in-place `prompt += [...]` in the next cell also grows `prompt_autoregressive`.
prompt = prompt_autoregressive
print(prompt)

In [None]:
# Extend the prompt by one greedily decoded token. Not idempotent: each
# re-run of this cell appends another token to `prompt`.
new_token = inference_on_prompt(prompt)
prompt += [new_token]
print(prompt)

In [None]:
# NOTE(review): `softmax_probs` is only assigned in the heatmap cell that
# follows, so this cell fails on a fresh top-to-bottom run.
softmax_probs.shape

In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


# Probabilities over the vocabulary at the masked positions of the first sample.
# Relies on `preds` and `mask` from other cells.
softmax_probs = preds[0][mask[0]].softmax(dim=-1).detach().cpu().numpy()

# Heatmap: one row per masked position, one column per vocabulary entry.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(softmax_probs, annot=True, cmap='coolwarm', fmt=".2f", ax=ax)
ax.set_title("Softmax Probabilities Heatmap")
ax.set_xlabel("Classes")
ax.set_ylabel("Instances")
plt.show()

In [None]:
# NOTE(review): `input_idx_autoregressive` is not assigned in any visible cell --
# presumably left over from a deleted cell; confirm and remove or define it.
input_idx_autoregressive

In [None]:
preds = trainer.model(input_idx)

In [None]:
detok(input_idx[0][mask[0]])

In [None]:
batch['input_idx'][:1]

In [None]:
accuracy

In [None]:
from src.tree import list2tree
from src.tree_dataset import tree_to_edges, tree2tokens



custom_tree = list2tree([0, 1, 2, None, None, 3])
custom_tree

In [None]:
prompt_tokens, path_tokens = tree2tokens(custom_tree)
prompt_tokens

In [None]:
# Build a padded sample and its task mask for the hand-made tree.
# NOTE(review): presumably mirrors what TreeDataset does per item -- confirm.
custom_tokenizer = trainer.dataset.tokenizer
max_seq_len = custom_tokenizer.MAX_SEQ_LEN

prompt_idx = custom_tokenizer(prompt_tokens)
path_idx = custom_tokenizer(path_tokens)

input_tokens = prompt_idx + path_idx
pad_len = max_seq_len - len(input_tokens)
if pad_len < 0:
    # `[0] * pad_len` would silently pad nothing and leave an over-long sequence.
    raise ValueError(
        f"Sequence of {len(input_tokens)} tokens exceeds MAX_SEQ_LEN={max_seq_len}"
    )

# Right-pad with index 0 up to the fixed context length.
input_idx = torch.tensor(input_tokens + [0] * pad_len)

# The task mask selects only the path tokens (the part the model must predict).
task_mask = torch.zeros(max_seq_len, dtype=torch.bool)
task_mask[len(prompt_idx):len(input_tokens)] = True

In [None]:
trainer.model(input_idx).argmax(dim=-1)

In [None]:
custom_input = ['0', ',', '→2', '0', ',', '→1', '1', ',', '→3', '|', '0', ':', '3', '→1', '→3']
tok(custom_input)

In [None]:
custom_batch = {
                'input_idx': input_idx[None],
                'task_mask': task_mask[None],
                # 'pad_mask': pad_mask,
            }

In [None]:
print_preds(custom_batch)

In [None]:
# Take both the tokens and the mask from the custom batch so they describe
# the same sample, and run the model on the tensor that was moved to `device`.
input_idx = custom_batch['input_idx'].to(device)
mask = custom_batch['task_mask'].to(device)

preds = trainer.model(input_idx)

In [None]:
detok(preds.argmax(dim=-1)[0])

In [None]:
tree_to_edges(parsed_input['tree'])

In [None]:
input_idx = batch['input_idx'].to(device)
mask = batch['task_mask'].to(device)

In [None]:
logits, cache = trainer.model.run_with_cache(input_idx[:1])

In [None]:
# pip install circuitsvis

In [None]:
import circuitsvis as cv

In [None]:
# ! pip install webbrowser

In [None]:
# path = "attn_heads.html"

# with open(path, "w") as f:
#     f.write(str(attn_heads))

# webbrowser.open(path)


In [None]:
# NOTE(review): the result of parse_input_idx is indexed with string keys
# elsewhere ('tree', 'goal', 'path'); if it is a plain dict, `[:0]` raises.
# Presumably meant to suppress output -- confirm; a trailing `;` would do that.
parse_input_idx((input_idx[0]), tokenizer)[:0]

In [None]:
print(type(cache))
# Layer-0 attention pattern from the cached forward pass.
attention_pattern = cache["pattern", 0]
print(attention_pattern.shape)
str_tokens = detok(input_idx[0].cpu())

# Patterns of the first (only) cached sample; one label is generated per
# entry along its leading axis instead of hard-coding a single head.
sample_pattern = attention_pattern[0]
head_names = [f"L0H{i}" for i in range(len(sample_pattern))]

print("Layer 0 Head Attention Patterns:")
attn_heads = cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=sample_pattern,
    attention_head_names=head_names,
)


attn_heads

In [None]:
attention_pattern.shape

In [None]:
# NOTE(review): unpinned install from a git repository, placed after the cell
# that already imports circuitsvis; consider pinning a commit and moving this
# to the top of the notebook (or into the environment setup).
%pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

In [None]:
# Dump the rendered attention view to a standalone HTML file.
path = "attn_heads.html"

with open(path, "w") as f:
    html_text = str(attn_heads)
    f.write(html_text)

In [None]:
cv.__version__

In [None]:
# Upload the saved attention view. Guarded because wandb.init is only called
# when USE_WANDB is set; the file is opened in a `with` block so the handle
# is closed after wandb has read it.
if USE_WANDB:
    with open(path) as html_file:
        wandb.log({"custom_file": wandb.Html(html_file)})

In [None]:
! python --version

In [None]:
display(attn_heads)

In [None]:
# Python Example
from circuitsvis.tokens import colored_tokens
colored_tokens(["My", "tokens"], [0.123, -0.226])