## Overview

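Goal: get hands-on experience with the `TransformerLens` hook API by caching activations of a trained model and training linear probes on them.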

In [1]:
# Number of batches pulled from the dataloader to build each linear-probe dataset.
N_TEST_BATCHES = 20

## Setup

In [2]:
import sys; sys.path.append('..')

In [3]:
# ! pip install lovely-tensors

import lovely_tensors as lt
lt.monkey_patch()

In [4]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [5]:
import wandb
from tqdm.auto import tqdm

from omegaconf import OmegaConf

from datetime import datetime

from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [6]:
from src.tree import list2tree
from src.tree_dataset import TreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges
from src.utils import seed_all


from src.trainer import accuracy_by_depth
from src.trainer import TreeTrainer

In [7]:
# Load the experiment configuration used to train the reproduced model.
conf = OmegaConf.load('../conf/00_reproduce_6L_nodes=16.yaml')

# Notebook-local overrides: fix the tree size and run everything on CPU.
conf.n_nodes = 16
conf.device = 'cpu'
# Show the effective configuration (i.e. after the overrides above).
print(OmegaConf.to_yaml(conf))

random_seed: 42
n_nodes: 16
model:
  d_model: 128
  d_head: 128
  n_layers: 6
  act_fn: gelu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.01
batch_size: 64
epoch_len_steps: 5000
checkpoint_every_epoch: 2
device: cpu
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 00_6L_nodes=16
max_iters: null



In [8]:
# Checkpoint whose weights are loaded into `trainer.model` further down.
REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-10_16-10-20/00_6L_nodes=16__step=220000.pt'
# Alternative checkpoint (a "deep_trees" run), kept commented out for quick switching:
# REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-12_14-26-20/00_6L_nodes=16__deep_trees__step=9256.pt'

In [9]:
# DEV_RUN=True forces USE_WANDB to False regardless of what the config says.
DEV_RUN = False
USE_WANDB = (not DEV_RUN) and conf.use_wandb
# Device name used by the helper functions below when moving tensors.
device = conf.device

# Directory (relative to this notebook) that holds the training checkpoints.
CHECKPOINT_ROOT = Path('../checkpoints')

In [10]:
# Seed taken from the config; seed_all is a project helper
# (presumably seeds python/numpy/torch RNGs — confirm in src.utils).
RANDOM_SEED = conf['random_seed']
print(f'{RANDOM_SEED=}')
seed_all(RANDOM_SEED)

RANDOM_SEED=42


In [11]:
# Build the trainer from the config; it exposes the dataset (with its tokenizer) and the model.
trainer = TreeTrainer(conf)

tokenizer = trainer.dataset.tokenizer

# Short aliases for the tokenizer's encode / decode helpers.
tok = tokenizer.tokenize
detok = tokenizer.detokenize


# Vocabulary index of the ':' token (first element of the tokenized single-token list).
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

# map_location remaps the stored tensors onto the device chosen for this notebook, so a
# checkpoint saved from a GPU run still loads on a CPU-only machine.
state_dict = torch.load(REPRODUCED_MODEL_CKPT, map_location=device)
trainer.model.load_state_dict(state_dict)

Moving model to device:  cpu


<All keys matched successfully>

In [12]:
def load_baseline_model(
    device='cpu',
    ckpt_path="/Users/mykhailokilianovskyi/src/backward-chaining-circuits/model.pt",
    n_states=16,
):
    """Build the baseline `HookedTransformer` and load its weights from disk.

    Args:
        device: device the model is created on and the checkpoint is mapped to.
        ckpt_path: path of the baseline state-dict file. The default is the
            original author's local path; pass your own location to run elsewhere.
        n_states: number of distinct node ids; determines vocabulary size and
            context length.

    Returns:
        The `HookedTransformer` with the checkpoint weights loaded.
    """
    max_seq_length = n_states * 4 + 2

    # Baseline vocabulary: 3 punctuation tokens (",", ":", "|"), one ">k" token
    # per node id and one plain "k" token per node id.
    d_vocab = 3 + 2 * n_states

    cfg = HookedTransformerConfig(
        n_layers=6,
        d_model=128,
        n_ctx=max_seq_length - 1,
        n_heads=1,
        d_mlp=512,
        d_head=128,
        #attn_only=True,
        d_vocab=d_vocab,
        device=device,
        attention_dir= "causal",
        act_fn="gelu",
    )
    model = HookedTransformer(cfg)

    model.load_state_dict(torch.load(ckpt_path, map_location=torch.device(device)))

    return model

In [13]:
import random
import collections
from torch.utils.data import IterableDataset, DataLoader

In [14]:
import random

from src.tree import TreeNode
from src.utils import seed_all
from src.tree_dataset import random_tree_of_depth, DeepTreeDataset

In [15]:
# Dataset of 16-node trees restricted to the listed depths (13-15), batched with the
# batch size from the training config.
deep_dataset = DeepTreeDataset(n_nodes=16, possible_depths=(15,14,13))
deep_tree_dataloader = DataLoader(deep_dataset, batch_size=conf['batch_size'])

In [16]:
# Instantiate the baseline model with the function's default arguments (CPU).
baseline_model = load_baseline_model()

### Setup Linear Probe

In [17]:
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

from sklearn.model_selection import train_test_split


from torch import nn


class LogisticRegressionModel(nn.Module):
    """Linear probe: a single affine layer that maps features to class logits."""

    def __init__(self, n_features, n_classes):
        super().__init__()
        self.linear = nn.Linear(in_features=n_features, out_features=n_classes)

    def forward(self, x):
        """Return un-normalised class scores (logits) for `x`."""
        logits = self.linear(x)
        return logits

# Assuming n_features is the number of features in your dataset and n_classes is the number of classes


In [18]:
from torch.utils.data import DataLoader, TensorDataset

def run_logreg(X,y, num_epochs=5, verbose=False):
    """Train a linear probe on (X, y) and report its held-out accuracy.

    The data is split 80/20 with a fixed `random_state`, a
    `LogisticRegressionModel` is trained with AdamW + cross-entropy for
    `num_epochs` epochs, and the probe is evaluated on the held-out split
    after every epoch.

    Args:
        X: 2-D float tensor of features, one row per sample.
        y: 1-D integer tensor of non-negative class labels.
        num_epochs: number of passes over the training split.
        verbose: if True, print train / test loss and accuracy each epoch.

    Returns:
        Test accuracy (in percent) measured after the final epoch.
    """
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Size the output layer from the largest label in the *whole* dataset so a class that
    # only lands in the test split still has a logit (needed now that test loss is computed).
    # For labels 0..k-1 that all occur in the training split this equals the unique-label count.
    n_classes = max(len(np.unique(y_train)), int(y.max()) + 1)
    logreg = LogisticRegressionModel(X_train.shape[1], n_classes)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
    # Evaluation does not depend on sample order, so the test loader is not shuffled.
    logreg_test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=64, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(logreg.parameters(), lr=0.001)

    train_losses = []
    train_accuracies = []

    test_losses = []
    test_accuracies = []

    for epoch in range(num_epochs):
        # ---- training pass ----
        logreg.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = logreg(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        if verbose:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

        # ---- evaluation pass ----
        logreg.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.inference_mode():
            for inputs, labels in logreg_test_loader:
                outputs = logreg(inputs)
                # Accumulate the test loss (previously never computed, so it printed as 0).
                running_loss += criterion(outputs, labels).item()

                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(logreg_test_loader)
        epoch_accuracy = 100 * correct / total
        # Record evaluation metrics in the test histories, not the training ones.
        test_losses.append(epoch_loss)
        test_accuracies.append(epoch_accuracy)

        if verbose:
            print(f'Test Loss: {epoch_loss:.4f}, Test Accuracy: {epoch_accuracy:.2f}%')
    return epoch_accuracy

In [19]:
# Build a mapping from this project's token indices to the baseline model's token indices.
our_idx2token = trainer.dataset.tokenizer.idx2token


n_states = 16
# Inputs are truncated to this many tokens before being fed to the baseline model.
bas_max_seq_length = n_states * 4 + 2

# Baseline vocabulary: node ids as strings, longer (two-digit) strings first, then
# [",", ":", "|"] + ">k" tokens + plain "k" tokens. The position in `idx2tokens` is the index.
number_tokens = sorted([str(i) for i in range(n_states)], key=lambda x: len(x), reverse=True)
idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens
tokens2idx = {token: idx for idx, token in enumerate(idx2tokens)}

from src.tree_dataset import PAD_TOKEN


# Our tokens use '→' where the baseline uses '>'; every other character is shared.
token2bastoken = {k.replace('>', '→'):k for k in tokens2idx.keys()}
# The baseline vocabulary has no padding token, so padding is mapped onto ','.
token2bastoken[PAD_TOKEN] = ','



# our index -> our token -> baseline token -> baseline index
idx2basidx = {}
for idx, our_tok in our_idx2token.items():
    bastoken = token2bastoken[our_tok]
    basidx = tokens2idx[bastoken]
    idx2basidx[idx] = basidx

In [20]:
def baseline_batch_train_step(baseline_model, batch):
    """Run the baseline model on one batch and compute loss and accuracy metrics.

    Args:
        baseline_model: model called as `baseline_model(inputs)` to get logits.
        batch: mapping with 'input_idx' (token indices in this project's
            vocabulary) and 'task_mask' (boolean mask of positions to score).

    Returns:
        (loss, metrics): cross-entropy loss over the masked positions, and the
        dict returned by `accuracy_by_depth` extended with 'accuracy/mean'.
    """
    # Tensor.apply_ only works on CPU tensors, so remap the vocabulary on a CPU copy
    # and move the result to the target device afterwards.
    input_idx = batch['input_idx'][..., :bas_max_seq_length].cpu().clone()
    input_idx.apply_(lambda i: idx2basidx[i])
    input_idx = input_idx.to(device)

    mask = batch['task_mask'][..., :bas_max_seq_length].clone().to(device)

    # Next-token setup: position t of the input predicts token t+1.
    inputs = input_idx[:, :-1]

    out_mask = mask[:, 1:]
    targets = input_idx[:, 1:][out_mask]

    outputs = baseline_model(inputs)

    predictions = outputs[out_mask]

    loss = F.cross_entropy(predictions, targets)

    is_correct = (predictions.argmax(dim=-1) == targets)
    accuracy_mean = is_correct.float().mean()
    metrics = accuracy_by_depth(outputs, input_idx, out_mask)
    metrics['accuracy/mean'] = accuracy_mean.item()

    return loss, metrics

In [21]:

ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

def extract_edges(sample_input_idx):
    """Read (first token, second token) edge pairs from the start of one tokenized sample.

    The detokenized sample is cut two tokens after the ':' delimiter and then
    scanned in strides of three (two edge tokens followed by a separator slot).
    Scanning stops as soon as a separator slot holds '|', so the pair that sits
    directly in front of '|' is not included in the result.
    """
    cutoff = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
    prompt = trainer.detok(sample_input_idx)[:cutoff]

    edges = []
    for separator_pos in range(2, len(prompt), 3):
        if prompt[separator_pos] == '|':
            break
        edges.append((prompt[separator_pos - 2], prompt[separator_pos - 1]))

    return edges


n_nodes = 16

# Enumerate every ordered (source, '→target') token pair, including self-pairs,
# giving n_nodes * n_nodes edge classes.
edges = [(str(i), '→'+str(j)) for i in range(n_nodes) for j in range(n_nodes)]

# Edge tuple -> class index (its position in `edges`).
edge2idx = {e:i for i,e in enumerate(edges)}

# Inverse lookup: class index -> edge tuple.
idx2edge = {i:e for e,i in edge2idx.items()}
print(f'{len(edge2idx)=}')

def get_first_node(edge_idx):
    """Return the integer id of the first node of the edge with class index `edge_idx`."""
    edge = idx2edge[edge_idx]
    return int(edge[0])


def get_second_node(edge_idx):
    """Return the integer id of the second node of the edge with class index `edge_idx`."""
    arrow_token = idx2edge[edge_idx][1]
    # Drop the leading '→' character before parsing the number.
    return int(arrow_token[1:])
    

def get_edge_labels(batch):
    """Return a tensor of edge class indices, one row per sample in `batch['input_idx']`.

    Every row must yield the same number of edges, otherwise `torch.tensor`
    cannot build a rectangular tensor.
    """
    label_rows = [
        [edge2idx[edge] for edge in extract_edges(sample)]
        for sample in batch['input_idx']
    ]
    return torch.tensor(label_rows)

len(edge2idx)=256


## Check that goal node is moved

### Reproduce "edge token" detection

In [22]:
# Every third position starting at 2 (2, 5, ..., 41): the slot that follows each
# two-token edge in the prompt. 14 is presumably the number of edges probed per sample — TODO confirm.
third_token_idx = torch.arange(2, 14*3, 3)
# `.v` is the verbose tensor view added by lovely_tensors' monkey patch.
third_token_idx.v

tensor[14] i64 x∈[2, 41] μ=21.500 σ=12.550
tensor([ 2,  5,  8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41])

In [23]:
# Activation sites to probe, as keyword arguments for the extract_* helpers:
# the raw embeddings first, then the residual stream after each of the 6 layers.
cache_keys = [{'key': 'embed', 'n': None}]  # just embeddings as a control experiment
cache_keys.extend({'key': 'resid_post', 'n': layer} for layer in range(6))



In [24]:
# Grab the first batch from the dataloader and keep it in `batch` for interactive inspection.
for batch in deep_tree_dataloader:
    break

In [25]:
@torch.no_grad()
def get_cache_baseline_model(batch, baseline_model):
    """Run the baseline model on one batch and return its activation cache.

    Gradients are disabled with `no_grad` rather than `inference_mode`: the cached
    activations are later used as inputs when training linear probes, and
    inference tensors cannot take part in autograd.

    Args:
        batch: mapping with 'input_idx' (token indices in this project's vocabulary).
        baseline_model: model exposing `run_with_cache(inputs)`.

    Returns:
        The cache object returned by `baseline_model.run_with_cache`.
    """
    # Tensor.apply_ only works on CPU tensors, so remap the vocabulary on a CPU copy
    # before moving anything to the target device.
    input_idx = batch['input_idx'][..., :bas_max_seq_length].cpu().clone()
    input_idx.apply_(lambda i: idx2basidx[i])

    # Same next-token framing as in training: the last token is never an input.
    inputs = input_idx[:, :-1].to(device)

    _, cache = baseline_model.run_with_cache(inputs)
    return cache

In [26]:
# Positions 1, 4, ..., 40: the second token of each three-token (first, second, separator)
# group at the start of the prompt.
second_node_idx = torch.arange(1, 14*3, 3)
second_node_idx.v

tensor[14] i64 x∈[1, 40] μ=20.500 σ=12.550
tensor([ 1,  4,  7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40])

In [27]:
# Positions 2, 5, ..., 41: the separator slot after each edge (same values as `third_token_idx`).
# NOTE(review): the name says "pad", but these look like the ',' separator positions — confirm.
pad_token_idx = torch.arange(2, 14*3, 3)
pad_token_idx.v

tensor[14] i64 x∈[2, 41] μ=21.500 σ=12.550
tensor([ 2,  5,  8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41])

In [28]:
d_model = conf.model.d_model


def _extract_Ln_xy(batch, baseline_model, token_positions, key, n):
    """Shared implementation: gather activations at `token_positions` plus edge labels.

    Args:
        batch: batch passed to `get_cache_baseline_model` and `get_edge_labels`.
        baseline_model: model whose activations are cached.
        token_positions: 1-D index tensor of sequence positions to read.
        key, n: activation name and layer used to index the cache as `cache[key, n]`.

    Returns:
        (X, y): activations flattened to (-1, d_model) and the flattened edge labels.
        The two line up one-to-one when each sample has as many extracted edges as
        there are entries in `token_positions`.
    """
    cache = get_cache_baseline_model(batch, baseline_model)
    activations = cache[key, n]

    X = activations[:, token_positions].reshape(-1, d_model)
    y = get_edge_labels(batch).reshape(-1)

    return X, y


def extract_Ln_second_node_xy(batch, baseline_model, key='resid_post', n=0):
    """Probe dataset built from activations at the second token of every edge."""
    return _extract_Ln_xy(batch, baseline_model, second_node_idx, key, n)


def extract_Ln_pad_token_xy(batch, baseline_model, key='resid_post', n=0):
    """Probe dataset built from activations at the slot that follows every edge."""
    return _extract_Ln_xy(batch, baseline_model, pad_token_idx, key, n)

In [29]:
# Reload the baseline model from its checkpoint (same call as earlier in the notebook).
baseline_model = load_baseline_model()

In [30]:
# For every activation site, collect a probe dataset from N_TEST_BATCHES batches and train
# two linear probes on it: one predicting the first node of the edge, one the second node.
accuracies = []

for cache_key in cache_keys:

    X = []
    y = []
    
    # zip with range() caps the number of batches drawn from the dataloader.
    for i,batch in tqdm(zip(range(N_TEST_BATCHES), deep_tree_dataloader), total=N_TEST_BATCHES):
        Xbatch, ybatch = extract_Ln_second_node_xy(batch, baseline_model, **cache_key)
    
        X.append(Xbatch)
        y.append(ybatch)
    
    
    # From here on X / y are single tensors rather than lists of per-batch tensors.
    X = torch.cat(X)
    y = torch.cat(y)
    
    # Split each edge class index into its two node ids, which become the probe targets.
    y_first_node = torch.tensor( [get_first_node(i.item()) for i in y] )
    y_second_node = torch.tensor( [get_second_node(i.item()) for i in y] )

    acc_node0 = run_logreg(X, y_first_node)
    acc_node1 = run_logreg(X, y_second_node)


    # Human-readable site name such as "key=resid_post n=0" ([:-1] drops the trailing space).
    layer_name = ''.join ( [f'{k}={v} ' for k,v in cache_key.items()] )[:-1]
    metric = {'acc_node_0': float(acc_node0), 'acc_node_1': float(acc_node1), 'layer_name':layer_name}
    print(f'{metric=}')

    accuracies.append(metric)

print('extract_Ln_second_node_xy')
# Last expression of the cell: rendered as a table with one row per activation site.
pd.DataFrame(accuracies)

  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 7.059151785714286, 'acc_node_1': 100.0, 'layer_name': 'key=embed n=None'}


  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=0'}


  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=1'}


  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=2'}


  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=3'}


  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=4'}


  0%|          | 0/20 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=5'}
extract_Ln_second_node_xy


Unnamed: 0,acc_node_0,acc_node_1,layer_name
0,7.059152,100.0,key=embed n=None
1,100.0,100.0,key=resid_post n=0
2,100.0,100.0,key=resid_post n=1
3,100.0,100.0,key=resid_post n=2
4,100.0,100.0,key=resid_post n=3
5,100.0,100.0,key=resid_post n=4
6,100.0,100.0,key=resid_post n=5


In [31]:
import os
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
from dataclasses import dataclass

In [32]:
def linear_lr(step, steps):
    """Learning-rate multiplier that falls linearly from 1 at step 0 to 0 at step == steps."""
    fraction_done = step / steps
    return 1 - fraction_done

def constant_lr(*_):
    """Learning-rate multiplier that is always 1.0; any positional arguments are ignored."""
    multiplier = 1.0
    return multiplier

def cosine_decay_lr(step, steps):
    """Learning-rate multiplier following a quarter cosine wave.

    Equals 1 at step 0 and reaches (numerically almost) 0 at step == steps - 1.
    """
    angle = 0.5 * np.pi * step / (steps - 1)
    return np.cos(angle)

In [33]:
@dataclass
class AutoEncoderConfig:
    """Hyperparameters for `AutoEncoder`."""
    # Leading dimension of every AutoEncoder parameter (see that class's shape annotations).
    n_instances: int
    # Size of the activations the autoencoder receives and reconstructs.
    n_input_ae: int
    # Size of the autoencoder's hidden layer.
    n_hidden_ae: int
    # Weight applied to the L1 term when it is added to the L2 term in the total loss.
    l1_coeff: float = 0.5
    # If True, only W_enc is created and the decoder uses its normalized transpose.
    tied_weights: bool = False
    # Presumably a small constant added during decoder-weight normalization — TODO confirm.
    weight_normalize_eps: float = 1e-8


class AutoEncoder(nn.Module):
    # NOTE(review): every method below is an unimplemented stub (`pass`); as written,
    # __init__ does not call nn.Module.__init__ and forward returns None.
    W_enc: Float[Tensor, "n_instances n_input_ae n_hidden_ae"]
    W_dec: Float[Tensor, "n_instances n_hidden_ae n_input_ae"]
    b_enc: Float[Tensor, "n_instances n_hidden_ae"]
    b_dec: Float[Tensor, "n_instances n_input_ae"]


    def __init__(self, cfg: AutoEncoderConfig):
        '''
        Initializes the two weights and biases according to the type signature above.

        If self.cfg.tied_weights = True, then we only create W_enc, not W_dec.
        '''
        pass


    def normalize_and_return_W_dec(self) -> Float[Tensor, "n_instances n_hidden_ae n_input_ae"]:
        '''
        If self.cfg.tied_weights = True, we return the normalized & transposed encoder weights.
        If self.cfg.tied_weights = False, we normalize the decoder weights in-place, and return them.

        Normalization should be over the `n_input_ae` dimension, i.e. each feature should have a normalized decoder weight.
        '''
        pass


    def forward(self, h: Float[Tensor, "batch_size n_instances n_input_ae"]):
        '''
        Runs a forward pass on the autoencoder, and returns several outputs.

        Inputs:
            h: Float[Tensor, "batch_size n_instances n_input_ae"]
                hidden activations generated from a Model instance

        Returns:
            l1_loss: Float[Tensor, "batch_size n_instances"]
                L1 loss for each batch elem & each instance (sum over the `n_hidden_ae` dimension)
            l2_loss: Float[Tensor, "batch_size n_instances"]
                L2 loss for each batch elem & each instance (take mean over the `n_input_ae` dimension)
            loss: Float[Tensor, ""]
                Sum of L1 and L2 loss (with the former scaled by `self.cfg.l1_coeff`). We sum over the `n_instances`
                dimension but take mean over the batch dimension
            acts: Float[Tensor, "batch_size n_instances n_hidden_ae"]
                Activations of the autoencoder's hidden states (post-ReLU)
            h_reconstructed: Float[Tensor, "batch_size n_instances n_input_ae"]
                Reconstructed hidden states, i.e. the autoencoder's final output
        '''
        pass

