## Setup

In [1]:
import sys; sys.path.append('..')

In [2]:
# ! pip install lovely-tensors

import lovely_tensors as lt
lt.monkey_patch()

In [3]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [4]:
import wandb
from tqdm.auto import tqdm

from omegaconf import OmegaConf

from datetime import datetime

from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [5]:
from src.tree import list2tree
from src.tree_dataset import TreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges
from src.utils import seed_all


from src.trainer import accuracy_by_depth
from src.trainer import TreeTrainer

In [6]:
conf = OmegaConf.load('../conf/00_reproduce_6L_nodes=16.yaml')
# Output is identical to the YAML file
conf.n_nodes = 16
conf.device = 'cpu'
print(OmegaConf.to_yaml(conf))

random_seed: 42
n_nodes: 16
model:
  d_model: 128
  d_head: 128
  n_layers: 6
  act_fn: gelu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.01
batch_size: 64
epoch_len_steps: 5000
checkpoint_every_epoch: 2
device: cpu
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 00_6L_nodes=16
max_iters: null



In [7]:
REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-10_16-10-20/00_6L_nodes=16__step=220000.pt'

In [8]:
device = conf.device

RANDOM_SEED = conf['random_seed']
print(f'{RANDOM_SEED=}')
seed_all(RANDOM_SEED)

RANDOM_SEED=42


In [9]:
trainer = TreeTrainer(conf)

tokenizer = trainer.dataset.tokenizer

tok = tokenizer.tokenize
detok = tokenizer.detokenize


ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

state_dict = torch.load(REPRODUCED_MODEL_CKPT)
trainer.model.load_state_dict(state_dict)

Moving model to device:  cpu


<All keys matched successfully>

In [10]:
def load_baseline_model(device='cpu'):
    n_states = 16
    max_seq_length = n_states * 4 + 2
    
    number_tokens = sorted([str(i) for i in range(n_states)], key=lambda x: len(x), reverse=True)
    idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens
    tokens2idx = {token: idx for idx, token in enumerate(idx2tokens)}
    
    
    cfg = HookedTransformerConfig(
        n_layers=6,
        d_model=128,
        n_ctx=max_seq_length - 1,
        n_heads=1,
        d_mlp=512,
        d_head=128,
        #attn_only=True,
        d_vocab=len(idx2tokens),
        device=device,
        attention_dir= "causal",
        act_fn="gelu",
    )
    model = HookedTransformer(cfg)
    
    model.load_state_dict(torch.load("/Users/mykhailokilianovskyi/src/backward-chaining-circuits/model.pt", map_location=torch.device(device)))
    
    return model

In [11]:
from src.tree_dataset import PAD_TOKEN

In [12]:
n_states = 16
bas_max_seq_length = n_states * 4 + 2

number_tokens = sorted([str(i) for i in range(n_states)], key=lambda x: len(x), reverse=True)
idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens
tokens2idx = {token: idx for idx, token in enumerate(idx2tokens)}

In [13]:
token2bastoken = {k.replace('>', '→'):k for k in tokens2idx.keys()}
token2bastoken[PAD_TOKEN] = ','

our_idx2token = trainer.dataset.tokenizer.idx2token

idx2basidx = {}

In [14]:
for idx, our_tok in our_idx2token.items():
    bastoken = token2bastoken[our_tok]
    basidx = tokens2idx[bastoken]
    idx2basidx[idx] = basidx

In [15]:
def map_input_idx_to_bas(input_idx):
    return input_idx.clone().apply_(lambda i: idx2basidx[i])

## Do Transformers dream on their own goals?

In [16]:
from src.tree_dataset import random_binary_tree, create_tree_path


test_tree = random_binary_tree(16, seed=42)
prompt_tokens, path_tokens = create_tree_path(test_tree, goal=9)
print(f'{prompt_tokens[:7]=}')
test_tree

prompt_tokens[:7]=['12', '→10', ',', '4', '→15', ',', '12']


                       3
                       |
              +--------+-----------+
              |                    |
              0                    8
              |                    |
  +-----------+-----+        +-----+--------+
  |                 |        |              |
 11                 2       12              6
  |                 |        |              |
  +--------+     +--+     +--+--+     +-----+--+
           |     |        |     |     |        |
           4     1       10    14     5        7
           |                          |
     +-----+                          +--+
     |                                   |
    15                                   9
     |
     +--+
        |
       13

#### Setup

Experiment goal: test if trained model can come up with its 

**Why it is interesting?**

  - It was only trained to find path given the goal and root specified, and this is (very roughly) Out-of-distribution task
  - On the abstract level idea sounds cool: let transformer come up with its own goal and
    - We can study what kind of goals are hallucinated (e.g. do the depth distribution matches the training data, or it matches distribution that model learned best / has best performance on)
   

**Implementation details**

  - Leave only tree node
  - Generate batches of such examples
  - Find the correct path to the goal with BFS
  - Record a bunch of statistics:
    - If goal was hallucinated at all (have the correct format)
    - If prediction was successful
    - Depth of the hallucinated goal
    - Whether goal is a leaf node
    - Position of edges with a goal node (to see maybe goal generation is biased towards first/last mentioned edge)

In [17]:
for batch in trainer.dataloader:
    break

In [18]:
input_idx = batch['input_idx'].to()

In [19]:
BAS_ROOT_DELIM_TOKEN_IDX = tokens2idx['|']
def basidx2prompt(idx): return [idx2tokens[i] for i in idx]

In [20]:
baseline_model = load_baseline_model()

In [21]:
our_tokens = detok(input_idx[0])
bas_tokens = [token2bastoken[t] for t in our_tokens]
# print(f'{bas_tokens[:5]=}')

bas_idx = [tokens2idx[t] for t in bas_tokens]
upper_task_bound = bas_idx.index(BAS_ROOT_DELIM_TOKEN_IDX) + 4
upper_TREE_bound = bas_idx.index(BAS_ROOT_DELIM_TOKEN_IDX)
prompt_idx = bas_idx[:upper_TREE_bound+1]
basidx2prompt(prompt_idx)[-5:]

['>15', ',', '8', '>1', '|']

In [22]:
for i in range(7):
    pred_idx_greedy = baseline_model(torch.tensor(prompt_idx))[0, -1].argmax()
    pred_token = idx2tokens[pred_idx_greedy]
    print(f'{pred_token=}')
    prompt_idx.append(pred_idx_greedy)

pred_token='>11'
pred_token='>13'
pred_token='>5'
pred_token='>9'
pred_token='>4'
pred_token='>11'
pred_token='>8'


In [23]:
for j,batch in tqdm(zip(range(1), trainer.dataloader) ):
    input_idx = batch['input_idx'].to(device)
    for sample_input_idx in input_idx:
        ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]
        
        
        upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
        prompt_autoregressive = trainer.detok(sample_input_idx)[:upper_task_bound-3]
        input_tree = input_tokens_to_tree(prompt_autoregressive)
        
        prompt = prompt_autoregressive
    
        parsed_input = parse_input_idx(sample_input_idx, trainer.dataset.tokenizer)
        gt_path = parsed_input['path']
        pred_path = []
        
        for i in range(len(gt_path)):
            pred_token = trainer.inference_on_prompt(prompt)
            pred_path.append(pred_token)
            prompt += [pred_token]
        
        if ':' in pred_path:
            print('FoUND')
            print(pred_path)

0it [00:00, ?it/s]

In [24]:
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]


upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
prompt_autoregressive = trainer.detok(sample_input_idx)[:upper_task_bound-1]
input_tree = input_tokens_to_tree(prompt_autoregressive)

prompt = prompt_autoregressive
print(prompt[-5:])
parsed_input = parse_input_idx(sample_input_idx, trainer.dataset.tokenizer)
gt_path = parsed_input['path']
pred_path = []

for i in range(len(gt_path)):
    pred_token = trainer.inference_on_prompt(prompt)
    pred_path.append(pred_token)
    prompt += [pred_token]

['15', '→8', '|', '7', ':']


In [25]:
pred_path

['→8', '→3']

In [26]:
input_idx = batch['input_idx'].to(device)
for sample_input_idx in input_idx:
    ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]
    
    
    upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
    prompt_autoregressive = trainer.detok(sample_input_idx)[:upper_task_bound-1]
    input_tree = input_tokens_to_tree(prompt_autoregressive)
    
    prompt = prompt_autoregressive

    parsed_input = parse_input_idx(sample_input_idx, trainer.dataset.tokenizer)
    gt_path = parsed_input['path']
    pred_path = []
    
    for i in range(len(gt_path)):
        pred_token = trainer.inference_on_prompt(prompt)
        pred_path.append(pred_token)
        prompt += [pred_token]
    
    if ':' in pred_path:
        print('FoUND')
        print(pred_path)

In [27]:
for j,batch in tqdm(zip(range(1), trainer.dataloader) ):
    input_idx = batch['input_idx'].to(device)
    for sample_input_idx in input_idx:
        ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]
        
        
        upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
        prompt_autoregressive = trainer.detok(sample_input_idx)[:upper_task_bound-3]
        input_tree = input_tokens_to_tree(prompt_autoregressive)
        
        prompt = prompt_autoregressive
    
        parsed_input = parse_input_idx(sample_input_idx, trainer.dataset.tokenizer)
        gt_path = parsed_input['path']
        pred_path = []
        
        for i in range(len(gt_path)):
            pred_token = trainer.inference_on_prompt(prompt)
            pred_path.append(pred_token)
            prompt += [pred_token]
        
        if ':' in pred_path:
            print('FoUND')
            print(pred_path)

0it [00:00, ?it/s]

In [28]:
# break

## H#0: "Edge Embeddings" are learned

### Check that this info is in $x_1$ (residual stream after first layer)

In [95]:
seed_all(20)
for batch in trainer.dataloader: 
    break

In [96]:
self = trainer

In [97]:
input_idx = batch['input_idx'].to(self.device)
mask = batch['task_mask'].to(self.device)

inputs = input_idx[:, :-1]

out_mask = mask[:, 1:]
targets = input_idx[:, 1:][out_mask]


# print(input_idx[:1, :4])
outputs = self.model(inputs)

predictions = outputs[out_mask]

loss = F.cross_entropy(predictions, targets)

is_correct = (predictions.argmax(dim=-1) == targets)
accuracy_mean = is_correct.float().mean()
metrics = accuracy_by_depth(outputs, input_idx, out_mask)
metrics['accuracy/mean'] = accuracy_mean.item()

In [98]:
outputs, cache = self.model.run_with_cache(inputs)

In [99]:
l0_attention_pattern = cache['pattern', 0][0]
input_str = detok(inputs[0])

In [100]:
test_tree_tokens = detok(inputs[0])

In [101]:
trainer.print_sample_pred(inputs[0])

****************************************************************************************************
                                               4
                                               |
                 +-----------------------------+
                 |
                15
                 |
        +--------+-----------------------+
        |                                |
        7                               10
        |                                |
     +--+-----+              +-----------+--+
     |        |              |              |
    14       11             13              6
     |        |              |
  +--+     +--+           +--+-----+
  |        |              |        |
 12        8              1        2
                          |        |
                       +--+     +--+--+
                       |        |     |
                       5        0     3
                       |
                    +--+
                    |
              

In [102]:
import circuitsvis as cv

In [103]:
print("Layer 0 Head Attention Pattern:")
display(cv.attention.attention_patterns(
    tokens=input_str,
    attention=l0_attention_pattern,
    attention_head_names=[f"L0H{i}" for i in range(1)],
))

Layer 0 Head Attention Pattern:


#### Observation: last edge is special

 - looks like "edge embeddings" are accumulated in pad token `,`
 - but there is no `,` for the last edge (both in my impl and baseline impl)
 - attn pattern looks weird there
 - what if our path goes through last edge? Will performance drop?

In [38]:
from copy import deepcopy

In [39]:
def swap_edges(seq, e0, e1):
    seq = deepcopy(seq)
    
    e0i = -1
    e1i = -1
    
    for i in range(len(seq)-1):
        if tuple(seq[i:i+2]) == e0: e0i = i
    
    for i in range(len(seq)-1):
        if tuple(seq[i:i+2]) == e1: e1i = i
    
    seq[e0i:e0i+2] = e1
    seq[e1i:e1i+2] = e0
    return seq

In [40]:
# swap edge (12→15) with last edge (14→10)

e0 = ('15', '→4')
e1 = ('14', '→10')
input_str_swap0 = swap_edges(input_str, e0, e1)[:68]
input_idx_swap0 = torch.tensor(tok(input_str_swap0))

outputs, cache = self.model.run_with_cache(input_idx_swap0[None])

In [41]:
# for x,y in zip(input_str_swap0, input_str):
#     print(x,y)

In [42]:
input_str_swap0 == input_str

False

In [43]:
l0_attention_pattern

tensor[1, 67, 67] n=4489 (18Kb) x∈[0., 1.000] μ=0.015 σ=0.045

In [44]:
len(input_str_swap0)

67

In [45]:
print("Layer 0 Head Attention Pattern:")
l0_attention_pattern = cache['pattern', 0][0]
display(cv.attention.attention_patterns(
    tokens=input_str_swap0,
    attention=l0_attention_pattern,
    attention_head_names=[f"L0H{i}" for i in range(1)],
))

Layer 0 Head Attention Pattern:


In [46]:
trainer.print_sample_pred(input_idx[0])

****************************************************************************************************
                            12
                             |
                    +--------+-----------+
                    |                    |
                   15                    2
                    |                    |
              +-----+-----+           +--+-----+
              |           |           |        |
              4           6           5        9
              |           |           |        |
     +--------+--+     +--+        +--+     +--+
     |           |     |           |        |
     8           3     0          14        1
     |                             |
  +--+-----+                    +--+
  |        |                    |
 11        7                   10
           |
        +--+
        |
       13

goal=3
accuracy=1.0 gt_path=['→15', '→4', '→3'] pred_path=['→15', '→4', '→3']
*************************************************************

In [47]:
trainer.print_sample_pred(input_idx_swap0)

****************************************************************************************************
                            12
                             |
        +--------------------+-----------+
        |                                |
       15                                2
        |                                |
     +--+--------------+              +--+-----+
     |                 |              |        |
     6                 4              5        9
     |                 |              |        |
  +--+        +--------+--+        +--+     +--+
  |           |           |        |        |
  0           8           3       14        1
              |                    |
           +--+-----+           +--+
           |        |           |
          11        7          10
                    |
                 +--+
                 |
                13

goal=3
accuracy=1.0 gt_path=['→15', '→4', '→3'] pred_path=['→15', '→4', '→3']
*************************

#### is it also special in baseline model?

In [48]:
baseline_model = load_baseline_model()

In [49]:
l0_attention_pattern.squeeze()

tensor[67, 67] n=4489 (18Kb) x∈[0., 1.000] μ=0.015 σ=0.045

In [50]:
bas_input_idx = map_input_idx_to_bas(input_idx[0])[:bas_max_seq_length-1]
bas_input_str = [idx2tokens[i] for i in bas_input_idx]

outputs, cache = baseline_model.run_with_cache(bas_input_idx[None])
l0_attention_pattern_bas = cache['pattern', 0][0]

print("Baseline L0 Head Attention Pattern:")

display(cv.attention.attention_patterns(
    tokens=bas_input_str,
    attention=l0_attention_pattern_bas,
    attention_head_names=[f"L0H{i}" for i in range(1)],
))

Baseline L0 Head Attention Pattern:


In [51]:
bas_input_idx = map_input_idx_to_bas(input_idx_swap0)[:bas_max_seq_length-1]
bas_input_str = [idx2tokens[i] for i in bas_input_idx]

outputs, cache = baseline_model.run_with_cache(bas_input_idx[None])
l0_attention_pattern_bas = cache['pattern', 0][0]

print("Baseline L0 Head Attention Pattern:")

display(cv.attention.attention_patterns(
    tokens=bas_input_str,
    attention=l0_attention_pattern_bas,
    attention_head_names=[f"L0H{i}" for i in range(1)],
))

Baseline L0 Head Attention Pattern:


In [52]:
print("Layer 1 Head Attention Pattern:")

l1_attention_pattern_bas = cache['pattern', 1][0]

display(cv.attention.attention_patterns(
    tokens=bas_input_str,
    attention=l1_attention_pattern_bas,
    attention_head_names=[f"L1H{i}" for i in range(1)],
))

Layer 1 Head Attention Pattern:


### Do so called Linear Probe

How to do linear probe?

  - Take info from the unit of interest and try to classify it.

Paper makes the hypothesis that first layer aggregates info about edge into the target pos.

> the model aggregates the source and target nodes of each edge in the edge list into the target node position

  - [ ] How to access residual stream in TransformerLens?


> cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].


`final_residual_stream = cache["resid_post", -1]`

In [53]:
from src.tree_dataset import TreeDataset, DeepTreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges

test_dataloader = DataLoader ( DeepTreeDataset(possible_depths=list(range(1,16))), batch_size=64 )

In [54]:
for batch in test_dataloader:
    break

In [55]:
model = trainer.model

In [56]:
torch.inference_mode()
def get_cache(batch, model):
    input_idx = batch['input_idx'].to(device)
    mask = batch['task_mask'].to(device)

    inputs = input_idx[:, :-1]
    
    out_mask = mask[:, 1:]
    targets = input_idx[:, 1:][out_mask]
    
    outputs, cache = model.run_with_cache(inputs)
    return cache

In [57]:
# we expect second node in edge to contain info about two nodes in edge

In [58]:
sample_input_idx = input_idx[8]

ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

def extract_edges(sample_input_idx):
    upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
    prompt = trainer.detok(sample_input_idx)[:upper_task_bound]
    
    i = 2
    
    edges = []
    
    while i < len(prompt) and prompt[i] != '|':
        edge = (prompt[i-2], prompt[i-1])
        edges.append(edge)
        i += 3
    
    return edges


n_nodes = 16

edges = [(str(i), '→'+str(j)) for i in range(n_nodes) for j in range(n_nodes)]

edge2idx = {e:i for i,e in enumerate(edges)}

idx2edge = {i:e for e,i in edge2idx.items()}
print(f'{len(edge2idx)=}')

def get_first_node(edge_idx): return int(idx2edge[edge_idx][0])
def get_second_node(edge_idx): return int(idx2edge[edge_idx][1][1:])
    

def get_edge_labels(batch):
    input_idx = batch['input_idx']
    edge_batch = []
    for row in input_idx:
        edges = extract_edges(row)
        edges = [edge2idx[e] for e in edges]
        edge_batch.append(edges)
    return torch.tensor(edge_batch)

len(edge2idx)=256


256

len(edge2idx)=256


In [63]:
second_token_idx = torch.arange(1, 14*3, 3)
second_token_idx.v

tensor[14] i64 x∈[1, 40] μ=20.500 σ=12.550
tensor([ 1,  4,  7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40])

In [64]:
d_model = conf.model.d_model

In [65]:
def extract_Ln_second_node_xy(batch, model, key='resid_post', n=0):
    cache = get_cache(batch, model)
    resid_act0 = cache[key, n]
    
    X = resid_act0[:, second_token_idx].reshape(-1, d_model)
    y = get_edge_labels(batch).reshape(-1)
    
    return X,y

In [66]:
model = trainer.model

In [67]:
extract_Ln_second_node_xy(batch, model)

(tensor[896, 128] n=114688 (0.4Mb) x∈[-19.844, 12.516] μ=-0.014 σ=3.187,
 tensor[896] i64 7Kb x∈[1, 254] μ=126.410 σ=72.792)

In [68]:
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

from sklearn.model_selection import train_test_split


from torch import nn


class LogisticRegressionModel(nn.Module):
    def __init__(self, n_features, n_classes):
        super(LogisticRegressionModel, self).__init__()
        self.linear = nn.Linear(n_features, n_classes)
    
    def forward(self, x):
        return self.linear(x)

# Assuming n_features is the number of features in your dataset and n_classes is the number of classes


In [71]:
num_epochs = 100

In [72]:
from torch.utils.data import DataLoader, TensorDataset

def run_logreg(X,y, num_epochs=5):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


    logreg = LogisticRegressionModel(X_train.shape[1], len(np.unique(y_train)))
    
    logreg_train_dataset = TensorDataset(X_train, y_train)
    train_loader = DataLoader(logreg_train_dataset, batch_size=64, shuffle=True)
    
    logreg_test_dataset = TensorDataset(X_test, y_test)
    logreg_test_loader = DataLoader(logreg_test_dataset, batch_size=64, shuffle=True)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(logreg.parameters(), lr=0.001)
    
    
    train_losses = []
    train_accuracies = []
    
    test_losses = []
    test_accuracies = []
    
    for epoch in range(num_epochs):
        logreg.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = logreg(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)
    
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')
    
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in logreg_test_loader:
    
            with torch.inference_mode():
                outputs = logreg(inputs)
    
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
        epoch_loss = running_loss / len(logreg_test_loader)
        epoch_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)
    
        print(f'Test Loss: {epoch_loss:.4f}, Test Accuracy: {epoch_accuracy:.2f}%')
    return epoch_accuracy

### Linear Probe x_1

In [73]:
# n_batches = 100
# X = []
# y = []

# for i,batch in tqdm(zip(range(n_batches), test_dataloader), total=n_batches):
#     Xbatch, ybatch = extract_Ln_second_node_xy(batch, model, n=0)

#     X.append(Xbatch)
#     y.append(ybatch)


# X = torch.cat(X)
# y = torch.cat(y)

# y_first_node = torch.tensor( [get_first_node(i.item()) for i in y] )
# y_second_node = torch.tensor( [get_second_node(i.item()) for i in y] )

In [74]:
# run_logreg(X, y_first_node)

In [75]:
# run_logreg(X, y_second_node)

### Linear Probe x_0

In [76]:
# cache['embed']

In [77]:


# n_batches = 100
# X = []
# y = []

# for i,batch in tqdm(zip(range(n_batches), test_dataloader), total=n_batches):
#     Xbatch, ybatch = extract_Ln_second_node_xy(batch, model,key='embed', n=None)

#     X.append(Xbatch)
#     y.append(ybatch)


# X = torch.cat(X)
# y = torch.cat(y)

# y_first_node = torch.tensor( [get_first_node(i.item()) for i in y] )
# y_second_node = torch.tensor( [get_second_node(i.item()) for i in y] )

In [78]:
# run_logreg(X, y_first_node)

In [79]:
# run_logreg(X, y_second_node)

### Linear Probe each layer!

In [80]:
cache_keys = [{'key': 'resid_post', "n": n} for n in range(6)]
cache_keys = [{'key': 'embed', 'n': None}] + cache_keys
cache_keys

[{'key': 'embed', 'n': None},
 {'key': 'resid_post', 'n': 0},
 {'key': 'resid_post', 'n': 1},
 {'key': 'resid_post', 'n': 2},
 {'key': 'resid_post', 'n': 3},
 {'key': 'resid_post', 'n': 4},
 {'key': 'resid_post', 'n': 5}]

In [81]:
accuracies = []

for cache_key in cache_keys:
    n_batches = 100
    X = []
    y = []
    
    for i,batch in tqdm(zip(range(n_batches), test_dataloader), total=n_batches):
        Xbatch, ybatch = extract_Ln_second_node_xy(batch, model, **cache_key)
    
        X.append(Xbatch)
        y.append(ybatch)
    
    
    X = torch.cat(X)
    y = torch.cat(y)
    
    y_first_node = torch.tensor( [get_first_node(i.item()) for i in y] )
    y_second_node = torch.tensor( [get_second_node(i.item()) for i in y] )

    acc_node0 = run_logreg(X, y_first_node)
    acc_node1 = run_logreg(X, y_second_node)


    layer_name = ''.join ( [f'{k}={v} ' for k,v in cache_key.items()] )[:-1]
    metric = {'acc_node_0': float(acc_node0), 'acc_node_1': float(acc_node1), 'layer_name':layer_name}
    print(f'{metric=}')

    accuracies.append(metric)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 2.7370, Accuracy: 6.78%
Test Loss: 0.0000, Test Accuracy: 6.57%
Epoch [2/5], Loss: 2.7203, Accuracy: 6.79%
Test Loss: 0.0000, Test Accuracy: 6.78%
Epoch [3/5], Loss: 2.7168, Accuracy: 6.75%
Test Loss: 0.0000, Test Accuracy: 6.52%
Epoch [4/5], Loss: 2.7155, Accuracy: 6.78%
Test Loss: 0.0000, Test Accuracy: 6.70%
Epoch [5/5], Loss: 2.7147, Accuracy: 6.55%
Test Loss: 0.0000, Test Accuracy: 6.64%
Epoch [1/5], Loss: 0.2792, Accuracy: 98.98%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [2/5], Loss: 0.0091, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [3/5], Loss: 0.0032, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [4/5], Loss: 0.0015, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [5/5], Loss: 0.0008, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
metric={'acc_node_0': 6.640625, 'acc_node_1': 100.0, 'layer_name': 'key=embed n=None'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 0.2143, Accuracy: 93.28%
Test Loss: 0.0000, Test Accuracy: 95.76%
Epoch [2/5], Loss: 0.0916, Accuracy: 97.47%
Test Loss: 0.0000, Test Accuracy: 98.21%
Epoch [3/5], Loss: 0.0518, Accuracy: 99.14%
Test Loss: 0.0000, Test Accuracy: 99.59%
Epoch [4/5], Loss: 0.0305, Accuracy: 99.69%
Test Loss: 0.0000, Test Accuracy: 99.92%
Epoch [5/5], Loss: 0.0190, Accuracy: 99.87%
Test Loss: 0.0000, Test Accuracy: 99.92%
Epoch [1/5], Loss: 0.2600, Accuracy: 95.33%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [2/5], Loss: 0.0131, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [3/5], Loss: 0.0041, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [4/5], Loss: 0.0019, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
Epoch [5/5], Loss: 0.0010, Accuracy: 100.00%
Test Loss: 0.0000, Test Accuracy: 100.00%
metric={'acc_node_0': 99.921875, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=0'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 0.9839, Accuracy: 85.32%
Test Loss: 0.0000, Test Accuracy: 93.14%
Epoch [2/5], Loss: 0.2294, Accuracy: 93.47%
Test Loss: 0.0000, Test Accuracy: 93.97%
Epoch [3/5], Loss: 0.1966, Accuracy: 93.89%
Test Loss: 0.0000, Test Accuracy: 94.25%
Epoch [4/5], Loss: 0.1787, Accuracy: 94.40%
Test Loss: 0.0000, Test Accuracy: 94.56%
Epoch [5/5], Loss: 0.1598, Accuracy: 94.96%
Test Loss: 0.0000, Test Accuracy: 94.60%
Epoch [1/5], Loss: 1.7985, Accuracy: 82.56%
Test Loss: 0.0000, Test Accuracy: 94.59%
Epoch [2/5], Loss: 0.1405, Accuracy: 97.28%
Test Loss: 0.0000, Test Accuracy: 99.30%
Epoch [3/5], Loss: 0.0545, Accuracy: 99.49%
Test Loss: 0.0000, Test Accuracy: 99.82%
Epoch [4/5], Loss: 0.0273, Accuracy: 99.80%
Test Loss: 0.0000, Test Accuracy: 99.53%
Epoch [5/5], Loss: 0.0186, Accuracy: 99.79%
Test Loss: 0.0000, Test Accuracy: 99.94%
metric={'acc_node_0': 94.60379464285714, 'acc_node_1': 99.94419642857143, 'layer_name': 'key=resid_post n=1'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 1.6028, Accuracy: 79.74%
Test Loss: 0.0000, Test Accuracy: 92.14%
Epoch [2/5], Loss: 0.2696, Accuracy: 92.73%
Test Loss: 0.0000, Test Accuracy: 92.69%
Epoch [3/5], Loss: 0.2291, Accuracy: 93.39%
Test Loss: 0.0000, Test Accuracy: 93.76%
Epoch [4/5], Loss: 0.2044, Accuracy: 93.88%
Test Loss: 0.0000, Test Accuracy: 94.33%
Epoch [5/5], Loss: 0.1893, Accuracy: 94.21%
Test Loss: 0.0000, Test Accuracy: 95.38%
Epoch [1/5], Loss: 1.8390, Accuracy: 78.47%
Test Loss: 0.0000, Test Accuracy: 90.54%
Epoch [2/5], Loss: 0.2796, Accuracy: 93.41%
Test Loss: 0.0000, Test Accuracy: 95.35%
Epoch [3/5], Loss: 0.1627, Accuracy: 96.57%
Test Loss: 0.0000, Test Accuracy: 97.28%
Epoch [4/5], Loss: 0.1110, Accuracy: 97.85%
Test Loss: 0.0000, Test Accuracy: 98.90%
Epoch [5/5], Loss: 0.0769, Accuracy: 98.61%
Test Loss: 0.0000, Test Accuracy: 99.21%
metric={'acc_node_0': 95.37946428571429, 'acc_node_1': 99.21316964285714, 'layer_name': 'key=resid_post n=2'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 3.1391, Accuracy: 70.40%
Test Loss: 0.0000, Test Accuracy: 90.13%
Epoch [2/5], Loss: 0.3467, Accuracy: 91.15%
Test Loss: 0.0000, Test Accuracy: 92.68%
Epoch [3/5], Loss: 0.2908, Accuracy: 92.19%
Test Loss: 0.0000, Test Accuracy: 92.48%
Epoch [4/5], Loss: 0.2607, Accuracy: 92.79%
Test Loss: 0.0000, Test Accuracy: 93.34%
Epoch [5/5], Loss: 0.2431, Accuracy: 93.08%
Test Loss: 0.0000, Test Accuracy: 94.30%
Epoch [1/5], Loss: 2.3548, Accuracy: 74.44%
Test Loss: 0.0000, Test Accuracy: 88.24%
Epoch [2/5], Loss: 0.3895, Accuracy: 90.57%
Test Loss: 0.0000, Test Accuracy: 90.97%
Epoch [3/5], Loss: 0.2645, Accuracy: 93.00%
Test Loss: 0.0000, Test Accuracy: 94.58%
Epoch [4/5], Loss: 0.1874, Accuracy: 95.07%
Test Loss: 0.0000, Test Accuracy: 96.48%
Epoch [5/5], Loss: 0.1485, Accuracy: 96.17%
Test Loss: 0.0000, Test Accuracy: 95.57%
metric={'acc_node_0': 94.30245535714286, 'acc_node_1': 95.56919642857143, 'layer_name': 'key=resid_post n=3'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 2.6012, Accuracy: 71.35%
Test Loss: 0.0000, Test Accuracy: 90.03%
Epoch [2/5], Loss: 0.3403, Accuracy: 91.42%
Test Loss: 0.0000, Test Accuracy: 92.78%
Epoch [3/5], Loss: 0.2875, Accuracy: 92.22%
Test Loss: 0.0000, Test Accuracy: 91.96%
Epoch [4/5], Loss: 0.2619, Accuracy: 92.69%
Test Loss: 0.0000, Test Accuracy: 93.02%
Epoch [5/5], Loss: 0.2407, Accuracy: 93.21%
Test Loss: 0.0000, Test Accuracy: 93.86%
Epoch [1/5], Loss: 2.7030, Accuracy: 73.26%
Test Loss: 0.0000, Test Accuracy: 90.08%
Epoch [2/5], Loss: 0.4188, Accuracy: 90.18%
Test Loss: 0.0000, Test Accuracy: 91.65%
Epoch [3/5], Loss: 0.2766, Accuracy: 92.73%
Test Loss: 0.0000, Test Accuracy: 94.07%
Epoch [4/5], Loss: 0.1988, Accuracy: 94.88%
Test Loss: 0.0000, Test Accuracy: 97.11%
Epoch [5/5], Loss: 0.1491, Accuracy: 96.21%
Test Loss: 0.0000, Test Accuracy: 96.48%
metric={'acc_node_0': 93.85602678571429, 'acc_node_1': 96.484375, 'layer_name': 'key=resid_post n=4'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 2.2266, Accuracy: 72.71%
Test Loss: 0.0000, Test Accuracy: 90.05%
Epoch [2/5], Loss: 0.3424, Accuracy: 91.21%
Test Loss: 0.0000, Test Accuracy: 92.07%
Epoch [3/5], Loss: 0.2922, Accuracy: 92.05%
Test Loss: 0.0000, Test Accuracy: 92.91%
Epoch [4/5], Loss: 0.2544, Accuracy: 92.94%
Test Loss: 0.0000, Test Accuracy: 93.35%
Epoch [5/5], Loss: 0.2343, Accuracy: 93.36%
Test Loss: 0.0000, Test Accuracy: 94.08%
Epoch [1/5], Loss: 2.8577, Accuracy: 72.81%
Test Loss: 0.0000, Test Accuracy: 90.01%
Epoch [2/5], Loss: 0.3831, Accuracy: 90.54%
Test Loss: 0.0000, Test Accuracy: 92.29%
Epoch [3/5], Loss: 0.2567, Accuracy: 93.11%
Test Loss: 0.0000, Test Accuracy: 93.47%
Epoch [4/5], Loss: 0.1806, Accuracy: 95.47%
Test Loss: 0.0000, Test Accuracy: 97.59%
Epoch [5/5], Loss: 0.1436, Accuracy: 96.32%
Test Loss: 0.0000, Test Accuracy: 96.36%
metric={'acc_node_0': 94.08482142857143, 'acc_node_1': 96.35602678571429, 'layer_name': 'key=resid_post n=5'}


In [82]:
pd.DataFrame(accuracies)

Unnamed: 0,acc_node_0,acc_node_1,layer_name
0,6.640625,100.0,key=embed n=None
1,99.921875,100.0,key=resid_post n=0
2,94.603795,99.944196,key=resid_post n=1
3,95.379464,99.21317,key=resid_post n=2
4,94.302455,95.569196,key=resid_post n=3
5,93.856027,96.484375,key=resid_post n=4
6,94.084821,96.356027,key=resid_post n=5


### probe token that is next to second node token

In [83]:
third_token_idx = torch.arange(2, 14*3, 3)
third_token_idx.v

tensor[14] i64 x∈[2, 41] μ=21.500 σ=12.550
tensor([ 2,  5,  8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41])

In [84]:
def extract_Ln_third_node(batch, model, key='resid_post', n=0):
    cache = get_cache(batch, model)
    resid_act0 = cache[key, n]
    
    X = resid_act0[:, third_token_idx].reshape(-1, d_model)
    y = get_edge_labels(batch).reshape(-1)
    
    return X,y

In [85]:
accuracies = []

for cache_key in cache_keys:
    n_batches = 100
    X = []
    y = []
    
    for i,batch in tqdm(zip(range(n_batches), test_dataloader), total=n_batches):
        Xbatch, ybatch = extract_Ln_third_node(batch, model, **cache_key)
    
        X.append(Xbatch)
        y.append(ybatch)
    
    
    X = torch.cat(X)
    y = torch.cat(y)
    
    y_first_node = torch.tensor( [get_first_node(i.item()) for i in y] )
    y_second_node = torch.tensor( [get_second_node(i.item()) for i in y] )

    acc_node0 = run_logreg(X, y_first_node)
    acc_node1 = run_logreg(X, y_second_node)


    layer_name = ''.join ( [f'{k}={v} ' for k,v in cache_key.items()] )[:-1]
    metric = {'acc_node_0': float(acc_node0), 'acc_node_1': float(acc_node1), 'layer_name':layer_name}
    print(f'{metric=}')

    accuracies.append(metric)

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 2.7750, Accuracy: 6.20%
Test Loss: 0.0000, Test Accuracy: 5.99%
Epoch [2/5], Loss: 2.7750, Accuracy: 6.37%
Test Loss: 0.0000, Test Accuracy: 6.37%
Epoch [3/5], Loss: 2.7748, Accuracy: 6.32%
Test Loss: 0.0000, Test Accuracy: 6.19%
Epoch [4/5], Loss: 2.7748, Accuracy: 6.34%
Test Loss: 0.0000, Test Accuracy: 6.19%
Epoch [5/5], Loss: 2.7747, Accuracy: 6.34%
Test Loss: 0.0000, Test Accuracy: 6.43%
Epoch [1/5], Loss: 2.7749, Accuracy: 6.26%
Test Loss: 0.0000, Test Accuracy: 6.56%
Epoch [2/5], Loss: 2.7749, Accuracy: 6.16%
Test Loss: 0.0000, Test Accuracy: 6.07%
Epoch [3/5], Loss: 2.7746, Accuracy: 6.17%
Test Loss: 0.0000, Test Accuracy: 6.56%
Epoch [4/5], Loss: 2.7748, Accuracy: 6.33%
Test Loss: 0.0000, Test Accuracy: 6.25%
Epoch [5/5], Loss: 2.7748, Accuracy: 6.38%
Test Loss: 0.0000, Test Accuracy: 6.25%
metric={'acc_node_0': 6.434151785714286, 'acc_node_1': 6.25, 'layer_name': 'key=embed n=None'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 0.2662, Accuracy: 92.34%
Test Loss: 0.0000, Test Accuracy: 93.60%
Epoch [2/5], Loss: 0.2019, Accuracy: 93.51%
Test Loss: 0.0000, Test Accuracy: 93.61%
Epoch [3/5], Loss: 0.2019, Accuracy: 93.54%
Test Loss: 0.0000, Test Accuracy: 93.57%
Epoch [4/5], Loss: 0.2005, Accuracy: 93.54%
Test Loss: 0.0000, Test Accuracy: 93.60%
Epoch [5/5], Loss: 0.2017, Accuracy: 93.58%
Test Loss: 0.0000, Test Accuracy: 93.52%
Epoch [1/5], Loss: 0.3866, Accuracy: 91.06%
Test Loss: 0.0000, Test Accuracy: 93.38%
Epoch [2/5], Loss: 0.2751, Accuracy: 93.30%
Test Loss: 0.0000, Test Accuracy: 93.39%
Epoch [3/5], Loss: 0.2681, Accuracy: 93.26%
Test Loss: 0.0000, Test Accuracy: 93.36%
Epoch [4/5], Loss: 0.2613, Accuracy: 93.29%
Test Loss: 0.0000, Test Accuracy: 93.30%
Epoch [5/5], Loss: 0.2570, Accuracy: 93.26%
Test Loss: 0.0000, Test Accuracy: 93.35%
metric={'acc_node_0': 93.515625, 'acc_node_1': 93.34821428571429, 'layer_name': 'key=resid_post n=0'}


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch [1/5], Loss: 0.7171, Accuracy: 91.13%
Test Loss: 0.0000, Test Accuracy: 93.30%
Epoch [2/5], Loss: 0.4164, Accuracy: 93.46%
Test Loss: 0.0000, Test Accuracy: 93.50%
Epoch [3/5], Loss: 0.3902, Accuracy: 93.50%
Test Loss: 0.0000, Test Accuracy: 93.48%
Epoch [4/5], Loss: 0.4124, Accuracy: 93.47%
Test Loss: 0.0000, Test Accuracy: 93.60%
Epoch [5/5], Loss: 0.4039, Accuracy: 93.49%
Test Loss: 0.0000, Test Accuracy: 93.59%
Epoch [1/5], Loss: 0.8389, Accuracy: 89.46%
Test Loss: 0.0000, Test Accuracy: 93.38%
Epoch [2/5], Loss: 0.4604, Accuracy: 93.36%
Test Loss: 0.0000, Test Accuracy: 93.46%
Epoch [3/5], Loss: 0.4440, Accuracy: 93.29%
Test Loss: 0.0000, Test Accuracy: 93.40%
Epoch [4/5], Loss: 0.4356, Accuracy: 93.33%
Test Loss: 0.0000, Test Accuracy: 93.44%
Epoch [5/5], Loss: 0.4179, Accuracy: 93.32%
Test Loss: 0.0000, Test Accuracy: 93.48%
metric={'acc_node_0': 93.58816964285714, 'acc_node_1': 93.48214285714286, 'layer_name': 'key=resid_post n=1'}


  0%|          | 0/100 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
pd.DataFrame(accuracies)

In [None]:
fourth_token_idx = torch.arange(3, 15*3, 3) # <<-- Nota Bene magic number change
fourth_token_idx.v

In [None]:
def extract_Ln_fourth_node(batch, model, key='resid_post', n=0):
    cache = get_cache(batch, model)
    resid_act0 = cache[key, n]
    
    X = resid_act0[:, fourth_token_idx].reshape(-1, d_model)
    y = get_edge_labels(batch).reshape(-1)
    
    return X,y

In [None]:
accuracies = []

for cache_key in cache_keys:
    n_batches = 100
    X = []
    y = []
    
    for i,batch in tqdm(zip(range(n_batches), test_dataloader), total=n_batches):
        Xbatch, ybatch = extract_Ln_fourth_node(batch, model, **cache_key)
    
        X.append(Xbatch)
        y.append(ybatch)
    
    
    X = torch.cat(X)
    y = torch.cat(y)
    
    y_first_node = torch.tensor( [get_first_node(i.item()) for i in y] )
    y_second_node = torch.tensor( [get_second_node(i.item()) for i in y] )

    acc_node0 = run_logreg(X, y_first_node)
    acc_node1 = run_logreg(X, y_second_node)


    layer_name = ''.join ( [f'{k}={v} ' for k,v in cache_key.items()] )[:-1]
    metric = {'acc_node_0': float(acc_node0), 'acc_node_1': float(acc_node1), 'layer_name':layer_name}
    print(f'{metric=}')

    accuracies.append(metric)

In [None]:
pd.DataFrame(accuracies)

In [None]:
# model.fit(X_train[:], y_train[:])