In [1]:
N_TEST_BATCHES: int = 100  # batches drawn from the dataloader for every probe run

## Setup

In [2]:
import sys; sys.path.append('..')

In [3]:
# Install if missing: %pip install lovely-tensors   (%pip targets this kernel's environment)

import lovely_tensors as lt
# Patch torch tensor reprs with lovely-tensors' compact summaries; this also adds the
# `.v` attribute used in later cells to show the summary followed by the raw values.
lt.monkey_patch()

In [4]:
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F
from torch.utils.data import DataLoader

from transformer_lens import HookedTransformer, HookedTransformerConfig

In [5]:
import wandb
from tqdm.auto import tqdm

from omegaconf import OmegaConf

from datetime import datetime

from pathlib import Path

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [6]:
from src.tree import list2tree
from src.tree_dataset import TreeDataset, parse_input_idx, input_tokens_to_tree, tree_to_edges
from src.utils import seed_all


from src.trainer import accuracy_by_depth
from src.trainer import TreeTrainer

In [7]:
conf = OmegaConf.load('../conf/00_reproduce_6L_nodes=16.yaml')

# Pin the two settings this notebook relies on (16-node trees, CPU execution), then
# show the resolved config. Per the original note, the printed output is identical
# to the YAML file, i.e. these overrides match the values already stored there.
conf.n_nodes, conf.device = 16, 'cpu'
print(OmegaConf.to_yaml(conf))

random_seed: 42
n_nodes: 16
model:
  d_model: 128
  d_head: 128
  n_layers: 6
  act_fn: gelu
  attention_dir: causal
optimizer:
  lr: 0.001
  weight_decay: 0.01
batch_size: 64
epoch_len_steps: 5000
checkpoint_every_epoch: 2
device: cpu
debug: false
use_wandb: true
wandb:
  project: reasoning-mech-interp
  name: 00_6L_nodes=16
max_iters: null



In [8]:
# Checkpoint of our reproduced 6-layer model (step 220000 of the 2024-04-10 run).
REPRODUCED_MODEL_CKPT = (
    '../checkpoints/reasoning-mech-interp__2024-04-10_16-10-20/'
    '00_6L_nodes=16__step=220000.pt'
)
# Alternative checkpoint (a "deep_trees" run, judging by the filename):
# REPRODUCED_MODEL_CKPT = '../checkpoints/reasoning-mech-interp__2024-04-12_14-26-20/00_6L_nodes=16__deep_trees__step=9256.pt'

In [9]:
DEV_RUN = False
# W&B logging is never used in dev mode; otherwise it follows the config flag.
USE_WANDB = False if DEV_RUN else conf.use_wandb
device = conf.device

CHECKPOINT_ROOT = Path('../checkpoints')

In [10]:
RANDOM_SEED = conf.random_seed
print(f'{RANDOM_SEED=}')
seed_all(RANDOM_SEED)  # project helper; presumably seeds random/numpy/torch — see src.utils

RANDOM_SEED=42


In [11]:
trainer = TreeTrainer(conf)

tokenizer = trainer.dataset.tokenizer

tok = tokenizer.tokenize
detok = tokenizer.detokenize

# Token id of the ':' delimiter in our tokenizer's vocabulary.
ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]

# map_location places the checkpoint tensors on the configured `device`, so a
# checkpoint saved from a GPU run still loads on a CPU-only machine.
state_dict = torch.load(REPRODUCED_MODEL_CKPT, map_location=device)
trainer.model.load_state_dict(state_dict)

Moving model to device:  cpu


<All keys matched successfully>

In [12]:
def load_baseline_model(device='cpu', model_path=None):
    """Build the 6-layer, single-head baseline HookedTransformer and load its weights.

    The baseline vocabulary is rebuilt here only to size `d_vocab`:
    ',', ':', '|', then '>k' edge-target tokens, then plain node tokens
    (multi-digit node names first).

    Args:
        device: device the model is created on and the checkpoint is mapped to.
        model_path: path to the baseline state dict. When None, the
            BASELINE_MODEL_PATH environment variable is used, falling back to the
            original author's local path (so the default behaviour is unchanged).

    Returns:
        The HookedTransformer with the checkpoint weights loaded.
    """
    import os

    if model_path is None:
        model_path = os.environ.get(
            'BASELINE_MODEL_PATH',
            '/Users/mykhailokilianovskyi/src/backward-chaining-circuits/model.pt',
        )

    n_states = 16
    max_seq_length = n_states * 4 + 2

    number_tokens = sorted([str(i) for i in range(n_states)], key=len, reverse=True)
    idx2tokens = [",", ":", "|"] + [f">{t}" for t in number_tokens] + number_tokens

    cfg = HookedTransformerConfig(
        n_layers=6,
        d_model=128,
        # Callers feed `input_idx[:, :-1]`, i.e. the max sequence minus one token.
        n_ctx=max_seq_length - 1,
        n_heads=1,
        d_mlp=512,
        d_head=128,
        d_vocab=len(idx2tokens),
        device=device,
        attention_dir="causal",
        act_fn="gelu",
    )
    model = HookedTransformer(cfg)

    model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))

    return model

In [13]:
import random
import collections
from torch.utils.data import IterableDataset, DataLoader

In [14]:
import random

from src.tree import TreeNode
from src.utils import seed_all
from src.tree_dataset import random_tree_of_depth, DeepTreeDataset

In [15]:
# 16-node trees restricted to large depths (presumably each tree's depth is one of
# 15/14/13 — see DeepTreeDataset).
deep_dataset = DeepTreeDataset(n_nodes=16, possible_depths=(15, 14, 13))
deep_tree_dataloader = DataLoader(deep_dataset, batch_size=conf.batch_size)

In [16]:
baseline_model = load_baseline_model(device='cpu')  # baseline weights on CPU

### Set up the linear probe

In [17]:
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

from sklearn.model_selection import train_test_split


from torch import nn


class LogisticRegressionModel(nn.Module):
    """Multinomial logistic-regression probe: a single affine layer producing class logits.

    No softmax is applied; the outputs are meant for `nn.CrossEntropyLoss`, which
    expects raw logits.
    """

    def __init__(self, n_features, n_classes):
        super().__init__()
        self.linear = nn.Linear(in_features=n_features, out_features=n_classes)

    def forward(self, x):
        logits = self.linear(x)
        return logits

# Assuming n_features is the number of features in your dataset and n_classes is the number of classes


In [18]:
from torch.utils.data import DataLoader, TensorDataset

def run_logreg(X, y, num_epochs=5, verbose=False):
    """Train a linear probe on (X, y) and return its held-out accuracy in percent.

    The data are split 80/20 (fixed `random_state=42`). A `LogisticRegressionModel`
    is trained with AdamW (lr=1e-3, batch size 64) for `num_epochs` epochs. After
    every epoch the probe is evaluated on the held-out split.

    Args:
        X: 2-D features, one row per sample (e.g. residual-stream activations).
        y: integer class labels, one per row of X; must be non-negative.
        num_epochs: number of training epochs (>= 1).
        verbose: print per-epoch train and test loss/accuracy.

    Returns:
        Test accuracy (percent) after the final epoch.

    Raises:
        ValueError: if num_epochs < 1.
    """
    if num_epochs < 1:
        raise ValueError(f'num_epochs must be >= 1, got {num_epochs}')

    X = torch.as_tensor(X)
    y = torch.as_tensor(y)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Size the output layer from the largest label over the *full* y, not the number
    # of distinct training labels. Otherwise non-contiguous labels, or labels that only
    # appear in the test split, would index past the logits in CrossEntropyLoss.
    n_classes = int(y.max().item()) + 1
    logreg = LogisticRegressionModel(X_train.shape[1], n_classes)

    train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
    # Evaluation order does not matter, so no shuffling is needed.
    test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=64, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(logreg.parameters(), lr=0.001)

    train_losses, train_accuracies = [], []
    test_losses, test_accuracies = [], []

    for epoch in range(num_epochs):
        # --- training pass ---
        logreg.train()
        running_loss, correct, total = 0.0, 0, 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = logreg(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        epoch_loss = running_loss / len(train_loader)
        epoch_accuracy = 100 * correct / total
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_accuracy)

        if verbose:
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

        # --- held-out evaluation (previously the loss was never accumulated and the
        # results were appended to the *train* lists) ---
        logreg.eval()
        running_loss, correct, total = 0.0, 0, 0
        with torch.inference_mode():
            for inputs, labels in test_loader:
                outputs = logreg(inputs)
                running_loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += labels.size(0)

        epoch_loss = running_loss / len(test_loader)
        epoch_accuracy = 100 * correct / total
        test_losses.append(epoch_loss)
        test_accuracies.append(epoch_accuracy)

        if verbose:
            print(f'Test Loss: {epoch_loss:.4f}, Test Accuracy: {epoch_accuracy:.2f}%')

    return test_accuracies[-1]

In [19]:
# Translate our tokenizer's ids into the baseline model's vocabulary ids.
our_idx2token = trainer.dataset.tokenizer.idx2token

# Baseline vocabulary, rebuilt the same way as inside `load_baseline_model`.
n_states = 16
bas_max_seq_length = n_states * 4 + 2

number_tokens = sorted(map(str, range(n_states)), key=len, reverse=True)
idx2tokens = [",", ":", "|", *[f">{t}" for t in number_tokens], *number_tokens]
tokens2idx = {token: position for position, token in enumerate(idx2tokens)}

from src.tree_dataset import PAD_TOKEN

# Our edge-target tokens use '→' where the baseline uses '>', and our PAD token
# plays the role of the baseline's ',' separator.
token2bastoken = {bas_tok.replace('>', '→'): bas_tok for bas_tok in tokens2idx}
token2bastoken[PAD_TOKEN] = ','

idx2basidx = {}
for idx, our_tok in our_idx2token.items():
    bastoken = token2bastoken[our_tok]
    idx2basidx[idx] = basidx = tokens2idx[bastoken]

In [20]:
def baseline_batch_train_step(baseline_model, batch):
    """Compute the baseline model's next-token loss and accuracy metrics on one batch.

    Despite the name, no optimizer step is taken: only the forward pass and metrics.
    Token ids are translated from our vocabulary to the baseline's via `idx2basidx`,
    and sequences are truncated to `bas_max_seq_length`.

    Returns:
        (loss, metrics): cross-entropy over the task-masked target positions, and the
        dict from `accuracy_by_depth` extended with 'accuracy/mean'.
    """
    # Tensor.apply_ only works on CPU tensors, so remap the ids *before* moving to
    # `device`. `.clone()` keeps the in-place remap from mutating the caller's batch.
    input_idx = batch['input_idx'][..., :bas_max_seq_length].clone().cpu()
    input_idx.apply_(lambda i: idx2basidx[i])
    input_idx = input_idx.to(device)
    mask = batch['task_mask'][..., :bas_max_seq_length].clone().to(device)

    # Predict token t+1 from tokens <= t; score only where the (shifted) task mask is set.
    inputs = input_idx[:, :-1]
    out_mask = mask[:, 1:]
    targets = input_idx[:, 1:][out_mask]

    outputs = baseline_model(inputs)
    predictions = outputs[out_mask]

    loss = F.cross_entropy(predictions, targets)

    is_correct = predictions.argmax(dim=-1) == targets
    accuracy_mean = is_correct.float().mean()
    metrics = accuracy_by_depth(outputs, input_idx, out_mask)
    metrics['accuracy/mean'] = accuracy_mean.item()

    return loss, metrics

In [21]:

ROOT_DELIM_TOKEN_IDX = trainer.tok([':'])[0]


def extract_edges(sample_input_idx, include_last_edge=False):
    """Return the (source, target) token pairs of the edge list in one prompt.

    The prompt is detokenized up to two tokens past the ':' delimiter. It is then
    scanned in strides of 3: at each separator position i, the pair
    (prompt[i-2], prompt[i-1]) is an edge. The scan stops when prompt[i] is '|'.

    Args:
        sample_input_idx: token ids of one sample (our tokenizer's vocabulary);
            presumably a 1-D tensor, since `.tolist().index(...)` is used on it.
        include_last_edge: because the scan stops as soon as it reaches '|', the pair
            sitting right before '|' is NOT returned by default. This assumes '|'
            directly follows the last edge's target token — TODO confirm against the
            dataset format. The default keeps the historical edge count, which the
            position arrays and probe labels rely on. Pass True to also return that
            final pair.

    Returns:
        List of (first_token, second_token) string pairs, e.g. ('3', '→7').
    """
    upper_task_bound = sample_input_idx.tolist().index(ROOT_DELIM_TOKEN_IDX) + 2
    prompt = trainer.detok(sample_input_idx)[:upper_task_bound]

    edges = []
    i = 2
    while i < len(prompt) and prompt[i] != '|':
        edges.append((prompt[i - 2], prompt[i - 1]))
        i += 3

    # Opt-in: the pair immediately before '|' occupies a separator slot's edge positions
    # but is skipped by the loop condition above.
    if include_last_edge and i < len(prompt) and prompt[i] == '|':
        edges.append((prompt[i - 2], prompt[i - 1]))

    return edges


n_nodes = 16

# Every possible directed edge (src, '→dst'), enumerated source-major,
# so edge id == src * n_nodes + dst.
edges = [(str(src), '→' + str(dst)) for src in range(n_nodes) for dst in range(n_nodes)]

edge2idx = {edge: position for position, edge in enumerate(edges)}

idx2edge = dict(enumerate(edges))
print(f'{len(edge2idx)=}')


def get_first_node(edge_idx):
    """Integer id of the source node of edge `edge_idx`."""
    source, _ = idx2edge[edge_idx]
    return int(source)


def get_second_node(edge_idx):
    """Integer id of the target node of edge `edge_idx` (its token is '→<id>')."""
    _, target = idx2edge[edge_idx]
    return int(target[1:])
    

def get_edge_labels(batch):
    """Encode each sample's edge list as edge ids (indices into `edge2idx`).

    Returns a tensor with one row per sample, built from Python ints (so int64);
    all samples must yield the same number of edges for `torch.tensor` to accept it.
    """
    per_sample = []
    for row in batch['input_idx']:
        per_sample.append([edge2idx[edge] for edge in extract_edges(row)])
    return torch.tensor(per_sample)

len(edge2idx)=256


## Check that goal node is moved

### Reproduce "edge token" detection

In [22]:
# Every third position starting at 2 (2, 5, ..., 41): the separator slot that follows
# each edge pair. Same values as `pad_token_idx`; not referenced again in this section.
third_token_idx = torch.arange(2, 14*3, 3)
third_token_idx.v

tensor[14] i64 x∈[2, 41] μ=21.500 σ=12.550
tensor([ 2,  5,  8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41])

In [23]:
# Activations to probe: the embedding activation first (just embeddings, as a control
# experiment), then the residual stream after each of the model's 6 layers.
# (Previously this list was built three times, with the first two overwritten.)
cache_keys = [{'key': 'embed', 'n': None}] + [
    {'key': 'resid_post', 'n': n} for n in range(6)
]



In [24]:
batch = next(iter(deep_tree_dataloader))  # one batch for the interactive checks below

In [25]:
@torch.no_grad()
def get_cache_baseline_model(batch, baseline_model):
    """Run the baseline model on one batch and return its activation cache.

    Ids are translated to the baseline vocabulary (`idx2basidx`) and truncated to
    `bas_max_seq_length`. The final token is dropped, matching the inputs used in
    `baseline_batch_train_step`.
    """
    # The previous bare `torch.inference_mode()` line above the def was a no-op (no `@`).
    # `no_grad` is used instead of `inference_mode` so the cached activations stay
    # ordinary tensors. Inference tensors cannot be saved for backward, and these
    # activations later feed the trainable probes.

    # Tensor.apply_ is CPU-only, so remap before moving to `device`; `.clone()` keeps
    # the caller's batch untouched.
    input_idx = batch['input_idx'][..., :bas_max_seq_length].clone().cpu()
    input_idx.apply_(lambda i: idx2basidx[i])

    inputs = input_idx[:, :-1].to(device)

    _, cache = baseline_model.run_with_cache(inputs)
    return cache

In [26]:
# Positions 1, 4, ..., 40: the second (target) token of each of the first 14 edges.
# `extract_edges` reads an edge from positions (i-2, i-1) for separator positions
# i = 2, 5, 8, ..., so edge k's target sits at 3k+1. The count 14 presumably matches
# the number of edge labels `extract_edges` returns per sample — the probes need one
# activation row per label.
second_node_idx = torch.arange(1, 14*3, 3)
second_node_idx.v

tensor[14] i64 x∈[1, 40] μ=20.500 σ=12.550
tensor([ 1,  4,  7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40])

In [27]:
# Positions 2, 5, ..., 41: the separator slot after each of the first 14 edges (our
# PAD token, which maps to the baseline's ',' separator).
pad_token_idx = torch.arange(2, 14*3, 3)
pad_token_idx.v

tensor[14] i64 x∈[2, 41] μ=21.500 σ=12.550
tensor([ 2,  5,  8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41])

In [28]:
d_model = conf.model.d_model


def _extract_positions_xy(batch, baseline_model, positions, key, n):
    """Pair activations `cache[key, n]` at `positions` with the batch's edge labels."""
    cache = get_cache_baseline_model(batch, baseline_model)
    resid = cache[key, n]

    # Take the width from the activation itself rather than the config-derived
    # `d_model`, so the probe input always matches the baseline model's real width.
    X = resid[:, positions].reshape(-1, resid.shape[-1])
    y = get_edge_labels(batch).reshape(-1)

    # One activation row per edge label; catches position/label misalignment early.
    assert X.shape[0] == y.shape[0], (
        f'{X.shape[0]} activation rows vs {y.shape[0]} edge labels'
    )
    return X, y


def extract_Ln_second_node_xy(batch, baseline_model, key='resid_post', n=0):
    """Activations at each edge's second-node position, paired with the edge ids."""
    return _extract_positions_xy(batch, baseline_model, second_node_idx, key, n)


def extract_Ln_pad_token_xy(batch, baseline_model, key='resid_post', n=0):
    """Activations at each edge's separator (pad) position, paired with the edge ids."""
    return _extract_positions_xy(batch, baseline_model, pad_token_idx, key, n)

In [29]:
baseline_model = load_baseline_model(device='cpu')  # fresh baseline weights on CPU

In [30]:
accuracies = []

for cache_key in cache_keys:
    # Gather activations at the second-node positions plus their edge labels.
    X, y = [], []
    for i, batch in tqdm(zip(range(N_TEST_BATCHES), deep_tree_dataloader), total=N_TEST_BATCHES):
        Xbatch, ybatch = extract_Ln_second_node_xy(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)

    X, y = torch.cat(X), torch.cat(y)

    # Decode every edge id into its two endpoint node ids and probe for each.
    y_first_node = torch.tensor([get_first_node(label.item()) for label in y])
    y_second_node = torch.tensor([get_second_node(label.item()) for label in y])

    acc_node0 = run_logreg(X, y_first_node)
    acc_node1 = run_logreg(X, y_second_node)

    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {
        'acc_node_0': float(acc_node0),
        'acc_node_1': float(acc_node1),
        'layer_name': layer_name,
    }
    print(f'{metric=}')

    accuracies.append(metric)

print('extract_Ln_second_node_xy')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 6.478794642857143, 'acc_node_1': 100.0, 'layer_name': 'key=embed n=None'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=0'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=1'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=2'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=3'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=4'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 100.0, 'acc_node_1': 100.0, 'layer_name': 'key=resid_post n=5'}
extract_Ln_second_node_xy


Unnamed: 0,acc_node_0,acc_node_1,layer_name
0,6.478795,100.0,key=embed n=None
1,100.0,100.0,key=resid_post n=0
2,100.0,100.0,key=resid_post n=1
3,100.0,100.0,key=resid_post n=2
4,100.0,100.0,key=resid_post n=3
5,100.0,100.0,key=resid_post n=4
6,100.0,100.0,key=resid_post n=5


### Search for the edge in the separator (pad) token

In [31]:
accuracies = []

for cache_key in cache_keys:
    # Gather activations at the separator (pad) positions plus their edge labels.
    X, y = [], []
    for i, batch in tqdm(zip(range(N_TEST_BATCHES), deep_tree_dataloader), total=N_TEST_BATCHES):
        Xbatch, ybatch = extract_Ln_pad_token_xy(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)

    X, y = torch.cat(X), torch.cat(y)

    # Decode every edge id into its two endpoint node ids and probe for each.
    y_first_node = torch.tensor([get_first_node(label.item()) for label in y])
    y_second_node = torch.tensor([get_second_node(label.item()) for label in y])

    acc_node0 = run_logreg(X, y_first_node)
    acc_node1 = run_logreg(X, y_second_node)

    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {
        'acc_node_0': float(acc_node0),
        'acc_node_1': float(acc_node1),
        'layer_name': layer_name,
    }
    print(f'{metric=}')

    accuracies.append(metric)

print('extract_Ln_pad_token_xy')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 6.294642857142857, 'acc_node_1': 6.283482142857143, 'layer_name': 'key=embed n=None'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 19.575892857142858, 'acc_node_1': 6.40625, 'layer_name': 'key=resid_post n=0'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 9.486607142857142, 'acc_node_1': 7.109375, 'layer_name': 'key=resid_post n=1'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 9.893973214285714, 'acc_node_1': 6.640625, 'layer_name': 'key=resid_post n=2'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 12.204241071428571, 'acc_node_1': 8.030133928571429, 'layer_name': 'key=resid_post n=3'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 16.852678571428573, 'acc_node_1': 21.534598214285715, 'layer_name': 'key=resid_post n=4'}


  0%|          | 0/100 [00:00<?, ?it/s]

metric={'acc_node_0': 11.428571428571429, 'acc_node_1': 22.957589285714285, 'layer_name': 'key=resid_post n=5'}
extract_Ln_pad_token_xy


Unnamed: 0,acc_node_0,acc_node_1,layer_name
0,6.294643,6.283482,key=embed n=None
1,19.575893,6.40625,key=resid_post n=0
2,9.486607,7.109375,key=resid_post n=1
3,9.893973,6.640625,key=resid_post n=2
4,12.204241,8.030134,key=resid_post n=3
5,16.852679,21.534598,key=resid_post n=4
6,11.428571,22.957589,key=resid_post n=5


### Things get interesting: searching for the goal node

In [32]:
# `extract_Ln_pad_token_xy` is defined in the probe-helper cell above, next to
# `extract_Ln_second_node_xy`. It is deliberately not redefined here: the copy that
# used to sit in this cell was identical and only shadowed that definition.

In [33]:
# One forward pass on `batch`, keeping the logits and activation cache for the
# spot-checks below. NOTE(review): `batch` is whichever batch the most recent loop
# left bound, not necessarily the one drawn by `next(iter(...))` earlier.

# Tensor.apply_ is CPU-only, so remap to baseline ids before moving to `device`.
input_idx = batch['input_idx'][..., :bas_max_seq_length].clone().cpu()
input_idx.apply_(lambda i: idx2basidx[i])
input_idx = input_idx.to(device)
mask = batch['task_mask'][..., :bas_max_seq_length].clone().to(device)

inputs = input_idx[:, :-1]

out_mask = mask[:, 1:]
targets = input_idx[:, 1:][out_mask]

# Inspection only — no gradients needed.
with torch.no_grad():
    outputs, cache = baseline_model.run_with_cache(inputs)

In [34]:
# Position right after ':' — the check below shows ':' at LAST_TOKEN_I - 1 for the
# first samples of this batch.
LAST_TOKEN_I = 47
[idx2tokens[i] for i in inputs[:3, LAST_TOKEN_I - 1].tolist()]

[':', ':', ':']

In [35]:
# Assumed prompt layout `... | goal : ...`: the goal node sits two positions before
# LAST_TOKEN_I; the check below confirms ':' directly follows GOAL_TOKEN_I.
GOAL_TOKEN_I = LAST_TOKEN_I - 2
[idx2tokens[i] for i in inputs[:3, GOAL_TOKEN_I + 1].tolist()]

[':', ':', ':']

In [36]:
# Ground-truth token (baseline vocabulary ids) that follows position LAST_TOKEN_I.
inputs[:, LAST_TOKEN_I+1].v

tensor[64] i64 x∈[3, 18] μ=10.094 σ=4.363
tensor([ 4,  7, 10,  5, 11,  6,  9, 12, 16,  4, 12,  6, 11, 18,  8, 17,  9,  4,
        17,  6,  3, 10, 11, 12, 11, 14,  3, 11, 15, 18,  9, 15,  8,  5, 13, 15,
        15, 16,  7,  6, 11, 12, 16, 10,  4,  7, 13,  5, 11, 13,  6, 14, 14, 12,
         7, 16,  6,  3,  5,  5, 15,  8,  8, 16])

In [37]:
# Greedy prediction at LAST_TOKEN_I; compare with the ground truth in the previous cell.
outputs[:, LAST_TOKEN_I].argmax(dim=-1).v

tensor[64] i64 x∈[3, 18] μ=10.094 σ=4.363
tensor([ 4,  7, 10,  5, 11,  6,  9, 12, 16,  4, 12,  6, 11, 18,  8, 17,  9,  4,
        17,  6,  3, 10, 11, 12, 11, 14,  3, 11, 15, 18,  9, 15,  8,  5, 13, 15,
        15, 16,  7,  6, 11, 12, 16, 10,  4,  7, 13,  5, 11, 13,  6, 14, 14, 12,
         7, 16,  6,  3,  5,  5, 15,  8,  8, 16])

In [38]:
# Module-level defaults, presumably left over from stepping through the extract_*
# helpers by hand; the helpers themselves receive `key`/`n` as parameters.
key, n = 'resid_post', 0

In [39]:
def extract_Xy__last_token__goal(batch, baseline_model, key='resid_post', n=0):
    """Probe data: activation at the last prompt token vs. node ids at chosen positions.

    Returns:
        X: activations `cache[key, n]` at position LAST_TOKEN_I, reshaped to (-1, d_model).
        y: integer id of the node at GOAL_TOKEN_I, one per sample.
        extra_pos_dict: {pos: node ids at pos} for the extra positions 0, 1, 3 and 4.
    """
    activations = get_cache_baseline_model(batch, baseline_model)[key, n]
    X = activations[:, LAST_TOKEN_I].reshape(-1, d_model)

    def node_ids_at(pos):
        # NOTA BENE: detokenize first. Token indices are vocabulary ids, so the node id
        # is parsed from the token text ('5' or '→5').
        token_strs = detok(batch['input_idx'][:, pos])
        return torch.tensor([int(s.replace('→', '')) for s in token_strs])

    y = node_ids_at(GOAL_TOKEN_I)

    # extra_pos_dict {pos: id of a node at pos}
    extra_pos_dict = {pos: node_ids_at(pos) for pos in (0, 1, 3, 4)}

    return X, y, extra_pos_dict

In [40]:
X, y, extra_pos_dict = extract_Xy__last_token__goal(batch, baseline_model, key='resid_post', n=0)

In [41]:
# Raw token id (our tokenizer's vocabulary) at the goal position of sample 0.
batch['input_idx'][0][GOAL_TOKEN_I]

tensor i64 6

In [42]:
# detok(batch['input_idx'][0])

In [43]:
# Render sample 0's tree with the trainer's printer, next to the goal-node label
# extracted for it by `extract_Xy__last_token__goal`.
trainer.print_sample_pred(batch['input_idx'][0])
print(f'{y[0]=}')

****************************************************************************************************
                                               5
                                               |
                                            +--+
                                            |
                                           11
                                            |
                                         +--+
                                         |
                                         3
                                         |
                                      +--+
                                      |
                                      8
                                      |
     +--------------------------------+
     |
     7
     |
  +--+-----------------------------+
  |                                |
  6                                0
                                   |
                                +--+
                              

In [44]:
import collections

In [45]:
accuracies = []

for cache_key in cache_keys:
    # Collect last-token activations (X), goal-node ids (y) and the node ids found at
    # the extra positions returned by extract_Xy__last_token__goal.
    X, y = [], []
    extra_pos_dict = collections.defaultdict(list)

    for i, batch in tqdm(zip(range(N_TEST_BATCHES), deep_tree_dataloader), total=N_TEST_BATCHES):
        Xbatch, ybatch, extra_pos_dict_batch = extract_Xy__last_token__goal(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)
        for pos, label_batch in extra_pos_dict_batch.items():
            extra_pos_dict[pos].append(label_batch)

    X, y = torch.cat(X), torch.cat(y)
    for pos, label_batch in extra_pos_dict.items():
        extra_pos_dict[pos] = torch.cat(label_batch)

    acc_goal_from_last_node = run_logreg(X, y)

    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {
        'acc_goal_from_last_node': float(acc_goal_from_last_node),
        'layer_name': layer_name,
    }

    # Probe the same activations for the nodes at each extra position.
    for pos, node_label_batch in tqdm(extra_pos_dict.items(), leave=False):
        acc = run_logreg(X, node_label_batch)
        metric[f'acc_of_predicting_node_at_pos={pos}'] = acc

    print(f'{metric=}')
    accuracies.append(metric)

print('extract_Xy__last_token__goal')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 6.015625, 'layer_name': 'key=embed n=None', 'acc_of_predicting_node_at_pos=0': 6.40625, 'acc_of_predicting_node_at_pos=1': 6.5625, 'acc_of_predicting_node_at_pos=3': 5.78125, 'acc_of_predicting_node_at_pos=4': 7.03125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=0', 'acc_of_predicting_node_at_pos=0': 6.953125, 'acc_of_predicting_node_at_pos=1': 6.953125, 'acc_of_predicting_node_at_pos=3': 6.5625, 'acc_of_predicting_node_at_pos=4': 7.265625}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=1', 'acc_of_predicting_node_at_pos=0': 7.734375, 'acc_of_predicting_node_at_pos=1': 6.328125, 'acc_of_predicting_node_at_pos=3': 5.703125, 'acc_of_predicting_node_at_pos=4': 7.1875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=2', 'acc_of_predicting_node_at_pos=0': 9.0625, 'acc_of_predicting_node_at_pos=1': 6.5625, 'acc_of_predicting_node_at_pos=3': 7.8125, 'acc_of_predicting_node_at_pos=4': 7.8125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=3', 'acc_of_predicting_node_at_pos=0': 9.296875, 'acc_of_predicting_node_at_pos=1': 7.34375, 'acc_of_predicting_node_at_pos=3': 8.90625, 'acc_of_predicting_node_at_pos=4': 6.328125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=4', 'acc_of_predicting_node_at_pos=0': 8.984375, 'acc_of_predicting_node_at_pos=1': 7.109375, 'acc_of_predicting_node_at_pos=3': 9.6875, 'acc_of_predicting_node_at_pos=4': 7.265625}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=5', 'acc_of_predicting_node_at_pos=0': 7.96875, 'acc_of_predicting_node_at_pos=1': 7.265625, 'acc_of_predicting_node_at_pos=3': 7.890625, 'acc_of_predicting_node_at_pos=4': 6.484375}
extract_Xy__last_token__goal


Unnamed: 0,acc_goal_from_last_node,layer_name,acc_of_predicting_node_at_pos=0,acc_of_predicting_node_at_pos=1,acc_of_predicting_node_at_pos=3,acc_of_predicting_node_at_pos=4
0,6.015625,key=embed n=None,6.40625,6.5625,5.78125,7.03125
1,100.0,key=resid_post n=0,6.953125,6.953125,6.5625,7.265625
2,100.0,key=resid_post n=1,7.734375,6.328125,5.703125,7.1875
3,100.0,key=resid_post n=2,9.0625,6.5625,7.8125,7.8125
4,100.0,key=resid_post n=3,9.296875,7.34375,8.90625,6.328125
5,100.0,key=resid_post n=4,8.984375,7.109375,9.6875,7.265625
6,100.0,key=resid_post n=5,7.96875,7.265625,7.890625,6.484375


In [46]:
# Sanity check of the last-token position used by the probes.
LAST_TOKEN_I

47

In [47]:
accuracies = []

# NOTE(review): same probe as the previous cell. The dataloader yields fresh samples
# (the printed accuracies differ between the two runs), so this acts as a repeat run.
for cache_key in cache_keys:
    X, y = [], []
    extra_pos_dict = collections.defaultdict(list)

    for i, batch in tqdm(zip(range(N_TEST_BATCHES), deep_tree_dataloader), total=N_TEST_BATCHES):
        Xbatch, ybatch, extra_pos_dict_batch = extract_Xy__last_token__goal(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)
        for pos, label_batch in extra_pos_dict_batch.items():
            extra_pos_dict[pos].append(label_batch)

    X, y = torch.cat(X), torch.cat(y)
    for pos, label_batch in extra_pos_dict.items():
        extra_pos_dict[pos] = torch.cat(label_batch)

    acc_goal_from_last_node = run_logreg(X, y)

    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {
        'acc_goal_from_last_node': float(acc_goal_from_last_node),
        'layer_name': layer_name,
    }

    for pos, node_label_batch in tqdm(extra_pos_dict.items(), leave=False):
        acc = run_logreg(X, node_label_batch)
        metric[f'acc_of_predicting_node_at_pos={pos}'] = acc

    print(f'{metric=}')
    accuracies.append(metric)

print('extract_Xy__last_token__goal')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 7.8125, 'layer_name': 'key=embed n=None', 'acc_of_predicting_node_at_pos=0': 5.46875, 'acc_of_predicting_node_at_pos=1': 7.109375, 'acc_of_predicting_node_at_pos=3': 6.875, 'acc_of_predicting_node_at_pos=4': 6.40625}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=0', 'acc_of_predicting_node_at_pos=0': 6.25, 'acc_of_predicting_node_at_pos=1': 5.9375, 'acc_of_predicting_node_at_pos=3': 7.265625, 'acc_of_predicting_node_at_pos=4': 5.78125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=1', 'acc_of_predicting_node_at_pos=0': 6.875, 'acc_of_predicting_node_at_pos=1': 6.25, 'acc_of_predicting_node_at_pos=3': 7.890625, 'acc_of_predicting_node_at_pos=4': 6.171875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=2', 'acc_of_predicting_node_at_pos=0': 9.6875, 'acc_of_predicting_node_at_pos=1': 6.875, 'acc_of_predicting_node_at_pos=3': 7.03125, 'acc_of_predicting_node_at_pos=4': 7.109375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=3', 'acc_of_predicting_node_at_pos=0': 9.296875, 'acc_of_predicting_node_at_pos=1': 5.859375, 'acc_of_predicting_node_at_pos=3': 7.8125, 'acc_of_predicting_node_at_pos=4': 6.640625}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 100.0, 'layer_name': 'key=resid_post n=4', 'acc_of_predicting_node_at_pos=0': 7.5, 'acc_of_predicting_node_at_pos=1': 8.046875, 'acc_of_predicting_node_at_pos=3': 7.578125, 'acc_of_predicting_node_at_pos=4': 6.171875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 99.921875, 'layer_name': 'key=resid_post n=5', 'acc_of_predicting_node_at_pos=0': 7.890625, 'acc_of_predicting_node_at_pos=1': 7.109375, 'acc_of_predicting_node_at_pos=3': 9.53125, 'acc_of_predicting_node_at_pos=4': 7.03125}
extract_Xy__last_token__goal


Unnamed: 0,acc_goal_from_last_node,layer_name,acc_of_predicting_node_at_pos=0,acc_of_predicting_node_at_pos=1,acc_of_predicting_node_at_pos=3,acc_of_predicting_node_at_pos=4
0,7.8125,key=embed n=None,5.46875,7.109375,6.875,6.40625
1,100.0,key=resid_post n=0,6.25,5.9375,7.265625,5.78125
2,100.0,key=resid_post n=1,6.875,6.25,7.890625,6.171875
3,100.0,key=resid_post n=2,9.6875,6.875,7.03125,7.109375
4,100.0,key=resid_post n=3,9.296875,5.859375,7.8125,6.640625
5,100.0,key=resid_post n=4,7.5,8.046875,7.578125,6.171875
6,99.921875,key=resid_post n=5,7.890625,7.109375,9.53125,7.03125


### Swap positional embeddings

In [48]:
# Start from freshly loaded weights and keep an untouched copy of the positional
# embedding matrix to build the swapped variant from.
baseline_model = load_baseline_model(device='cpu')
W_pos = baseline_model.pos_embed.W_pos.detach().clone()

In [49]:
# Exchange rows 0 and GOAL_TOKEN_I of the positional embedding matrix. The token at
# the goal position is then embedded with position 0's embedding, and vice versa.
W_pos__first_token_is_goal = W_pos.clone()
t = W_pos[0].clone()
W_pos__first_token_is_goal[0] = W_pos[GOAL_TOKEN_I]
W_pos__first_token_is_goal[GOAL_TOKEN_I] = t

baseline_model.pos_embed.W_pos.data = W_pos__first_token_is_goal

In [50]:
accuracies = []

# Same goal probe as before, now on the model whose positional embeddings for
# position 0 and GOAL_TOKEN_I have been exchanged.
for cache_key in cache_keys:
    X, y = [], []
    extra_pos_dict = collections.defaultdict(list)

    for i, batch in tqdm(zip(range(N_TEST_BATCHES), deep_tree_dataloader), total=N_TEST_BATCHES):
        Xbatch, ybatch, extra_pos_dict_batch = extract_Xy__last_token__goal(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)
        for pos, label_batch in extra_pos_dict_batch.items():
            extra_pos_dict[pos].append(label_batch)

    X, y = torch.cat(X), torch.cat(y)
    for pos, label_batch in extra_pos_dict.items():
        extra_pos_dict[pos] = torch.cat(label_batch)

    acc_goal_from_last_node = run_logreg(X, y)

    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {
        'acc_goal_from_last_node': float(acc_goal_from_last_node),
        'layer_name': layer_name,
    }

    for pos, node_label_batch in tqdm(extra_pos_dict.items(), leave=False):
        acc = run_logreg(X, node_label_batch)
        metric[f'acc_of_predicting_node_at_pos={pos}'] = acc

    print(f'{metric=}')
    accuracies.append(metric)

print('swap with pos=0 | extract_Xy__last_token__goal')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 6.640625, 'layer_name': 'key=embed n=None', 'acc_of_predicting_node_at_pos=0': 7.1875, 'acc_of_predicting_node_at_pos=1': 6.5625, 'acc_of_predicting_node_at_pos=3': 7.03125, 'acc_of_predicting_node_at_pos=4': 6.484375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 7.5, 'layer_name': 'key=resid_post n=0', 'acc_of_predicting_node_at_pos=0': 100.0, 'acc_of_predicting_node_at_pos=1': 7.03125, 'acc_of_predicting_node_at_pos=3': 6.71875, 'acc_of_predicting_node_at_pos=4': 6.875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 8.28125, 'layer_name': 'key=resid_post n=1', 'acc_of_predicting_node_at_pos=0': 100.0, 'acc_of_predicting_node_at_pos=1': 5.234375, 'acc_of_predicting_node_at_pos=3': 7.1875, 'acc_of_predicting_node_at_pos=4': 6.328125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 8.125, 'layer_name': 'key=resid_post n=2', 'acc_of_predicting_node_at_pos=0': 100.0, 'acc_of_predicting_node_at_pos=1': 7.96875, 'acc_of_predicting_node_at_pos=3': 8.59375, 'acc_of_predicting_node_at_pos=4': 7.1875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 16.5625, 'layer_name': 'key=resid_post n=3', 'acc_of_predicting_node_at_pos=0': 99.921875, 'acc_of_predicting_node_at_pos=1': 8.90625, 'acc_of_predicting_node_at_pos=3': 7.109375, 'acc_of_predicting_node_at_pos=4': 7.03125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 18.59375, 'layer_name': 'key=resid_post n=4', 'acc_of_predicting_node_at_pos=0': 99.6875, 'acc_of_predicting_node_at_pos=1': 9.6875, 'acc_of_predicting_node_at_pos=3': 6.484375, 'acc_of_predicting_node_at_pos=4': 8.59375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 13.203125, 'layer_name': 'key=resid_post n=5', 'acc_of_predicting_node_at_pos=0': 99.765625, 'acc_of_predicting_node_at_pos=1': 8.75, 'acc_of_predicting_node_at_pos=3': 6.09375, 'acc_of_predicting_node_at_pos=4': 7.5}
swap with pos=0 | extract_Xy__last_token__goal


Unnamed: 0,acc_goal_from_last_node,layer_name,acc_of_predicting_node_at_pos=0,acc_of_predicting_node_at_pos=1,acc_of_predicting_node_at_pos=3,acc_of_predicting_node_at_pos=4
0,6.640625,key=embed n=None,7.1875,6.5625,7.03125,6.484375
1,7.5,key=resid_post n=0,100.0,7.03125,6.71875,6.875
2,8.28125,key=resid_post n=1,100.0,5.234375,7.1875,6.328125
3,8.125,key=resid_post n=2,100.0,7.96875,8.59375,7.1875
4,16.5625,key=resid_post n=3,99.921875,8.90625,7.109375,7.03125
5,18.59375,key=resid_post n=4,99.6875,9.6875,6.484375,8.59375
6,13.203125,key=resid_post n=5,99.765625,8.75,6.09375,7.5


### 

### Swap the goal position embedding with the second node's (position 1; note it has a separate embedding)

In [51]:
# Reload the baseline model (presumably a fresh copy, so any earlier patch is
# discarded) and keep a detached snapshot of its positional-embedding matrix;
# the next cell builds the patched matrix from this snapshot.
baseline_model = load_baseline_model()
W_pos = baseline_model.pos_embed.W_pos.detach().clone()

In [52]:
# Patch: swap the positional embedding of position SWAPPED_POS (= 1, the
# second token) with the positional embedding of the goal token.
# (The variable name is kept from the position-0 experiment for compatibility
# with later cells; here it is the *second* token that gets the goal embedding.)
SWAPPED_POS = 1

W_pos__first_token_is_goal = W_pos.clone()
# Advanced indexing on the right-hand side materialises a copy of both rows,
# so this is a simultaneous swap and no temporary is needed.
W_pos__first_token_is_goal[[SWAPPED_POS, GOAL_TOKEN_I]] = W_pos[[GOAL_TOKEN_I, SWAPPED_POS]]
# --- ---

baseline_model.pos_embed.W_pos.data = W_pos__first_token_is_goal

In [53]:
# N_TEST_BATCHES = 20

In [54]:
# Linear-probe `baseline_model` (with whatever positional-embedding patch is
# currently installed) at every cached layer in `cache_keys`. For each layer,
# collect activations over N_TEST_BATCHES deep-tree batches and fit
# logistic-regression probes for:
#   * the goal node, from the last-token activations (X, y), and
#   * the node at each extra position returned by extract_Xy__last_token__goal.
accuracies = []

for cache_key in cache_keys:
    X = []
    y = []
    # position -> list of per-batch label tensors (concatenated below)
    extra_pos_dict = collections.defaultdict(list)

    # zip with range() stops after N_TEST_BATCHES batches
    batches = zip(range(N_TEST_BATCHES), deep_tree_dataloader)
    for _, batch in tqdm(batches, total=N_TEST_BATCHES):
        Xbatch, ybatch, extra_pos_dict_batch = extract_Xy__last_token__goal(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)
        for pos, label_batch in extra_pos_dict_batch.items():
            extra_pos_dict[pos].append(label_batch)

    X = torch.cat(X)
    y = torch.cat(y)
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for pos in extra_pos_dict:
        extra_pos_dict[pos] = torch.cat(extra_pos_dict[pos])

    acc_goal_from_last_node = run_logreg(X, y)

    # e.g. 'key=resid_post n=3'
    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {'acc_goal_from_last_node': float(acc_goal_from_last_node), 'layer_name': layer_name}

    for pos, node_label_batch in tqdm(extra_pos_dict.items(), leave=False):
        # float() so every accuracy column has the same plain-float type as
        # acc_goal_from_last_node.
        metric[f'acc_of_predicting_node_at_pos={pos}'] = float(run_logreg(X, node_label_batch))

    print(f'{metric=}')
    accuracies.append(metric)

print('swap with pos=1 | extract_Xy__last_token__goal')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 6.328125, 'layer_name': 'key=embed n=None', 'acc_of_predicting_node_at_pos=0': 6.5625, 'acc_of_predicting_node_at_pos=1': 6.40625, 'acc_of_predicting_node_at_pos=3': 6.484375, 'acc_of_predicting_node_at_pos=4': 6.09375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 19.84375, 'layer_name': 'key=resid_post n=0', 'acc_of_predicting_node_at_pos=0': 6.953125, 'acc_of_predicting_node_at_pos=1': 43.203125, 'acc_of_predicting_node_at_pos=3': 7.265625, 'acc_of_predicting_node_at_pos=4': 6.875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 24.375, 'layer_name': 'key=resid_post n=1', 'acc_of_predicting_node_at_pos=0': 7.109375, 'acc_of_predicting_node_at_pos=1': 21.171875, 'acc_of_predicting_node_at_pos=3': 7.578125, 'acc_of_predicting_node_at_pos=4': 6.796875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 33.984375, 'layer_name': 'key=resid_post n=2', 'acc_of_predicting_node_at_pos=0': 13.359375, 'acc_of_predicting_node_at_pos=1': 18.59375, 'acc_of_predicting_node_at_pos=3': 10.546875, 'acc_of_predicting_node_at_pos=4': 7.734375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 44.53125, 'layer_name': 'key=resid_post n=3', 'acc_of_predicting_node_at_pos=0': 13.046875, 'acc_of_predicting_node_at_pos=1': 16.5625, 'acc_of_predicting_node_at_pos=3': 9.6875, 'acc_of_predicting_node_at_pos=4': 7.890625}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 45.3125, 'layer_name': 'key=resid_post n=4', 'acc_of_predicting_node_at_pos=0': 13.515625, 'acc_of_predicting_node_at_pos=1': 19.6875, 'acc_of_predicting_node_at_pos=3': 7.734375, 'acc_of_predicting_node_at_pos=4': 5.9375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 41.953125, 'layer_name': 'key=resid_post n=5', 'acc_of_predicting_node_at_pos=0': 11.5625, 'acc_of_predicting_node_at_pos=1': 20.0, 'acc_of_predicting_node_at_pos=3': 6.953125, 'acc_of_predicting_node_at_pos=4': 6.640625}
swap with pos=1 | extract_Xy__last_token__goal


Unnamed: 0,acc_goal_from_last_node,layer_name,acc_of_predicting_node_at_pos=0,acc_of_predicting_node_at_pos=1,acc_of_predicting_node_at_pos=3,acc_of_predicting_node_at_pos=4
0,6.328125,key=embed n=None,6.5625,6.40625,6.484375,6.09375
1,19.84375,key=resid_post n=0,6.953125,43.203125,7.265625,6.875
2,24.375,key=resid_post n=1,7.109375,21.171875,7.578125,6.796875
3,33.984375,key=resid_post n=2,13.359375,18.59375,10.546875,7.734375
4,44.53125,key=resid_post n=3,13.046875,16.5625,9.6875,7.890625
5,45.3125,key=resid_post n=4,13.515625,19.6875,7.734375,5.9375
6,41.953125,key=resid_post n=5,11.5625,20.0,6.953125,6.640625


### Just poking around: add the goal position embedding onto position 0

In [55]:
# Reload the baseline model (presumably discarding the swap patch from the
# previous experiment) and snapshot its positional embeddings.
baseline_model = load_baseline_model()
W_pos = baseline_model.pos_embed.W_pos.data.clone()

# Patch: ADD the goal positional embedding onto position ADD_TO_POS (= 0),
# and put the original position-0 embedding into the goal position.
# NOTE(review): the second assignment looks carried over from the swap
# experiment; confirm it is intended for the "add" variant.
ADD_TO_POS = 0

W_pos__first_token_is_goal = W_pos.clone()
# Right-hand sides read from the untouched snapshot W_pos, so no temporary is
# needed (also correct if GOAL_TOKEN_I == ADD_TO_POS).
W_pos__first_token_is_goal[ADD_TO_POS] = W_pos[ADD_TO_POS] + W_pos[GOAL_TOKEN_I]
W_pos__first_token_is_goal[GOAL_TOKEN_I] = W_pos[ADD_TO_POS]
# --- ---

baseline_model.pos_embed.W_pos.data = W_pos__first_token_is_goal

In [56]:
# Linear-probe `baseline_model` (with whatever positional-embedding patch is
# currently installed) at every cached layer in `cache_keys`. For each layer,
# collect activations over N_TEST_BATCHES deep-tree batches and fit
# logistic-regression probes for:
#   * the goal node, from the last-token activations (X, y), and
#   * the node at each extra position returned by extract_Xy__last_token__goal.
accuracies = []

for cache_key in cache_keys:
    X = []
    y = []
    # position -> list of per-batch label tensors (concatenated below)
    extra_pos_dict = collections.defaultdict(list)

    # zip with range() stops after N_TEST_BATCHES batches
    batches = zip(range(N_TEST_BATCHES), deep_tree_dataloader)
    for _, batch in tqdm(batches, total=N_TEST_BATCHES):
        Xbatch, ybatch, extra_pos_dict_batch = extract_Xy__last_token__goal(batch, baseline_model, **cache_key)
        X.append(Xbatch)
        y.append(ybatch)
        for pos, label_batch in extra_pos_dict_batch.items():
            extra_pos_dict[pos].append(label_batch)

    X = torch.cat(X)
    y = torch.cat(y)
    # Only existing keys are reassigned, so mutating while iterating is safe.
    for pos in extra_pos_dict:
        extra_pos_dict[pos] = torch.cat(extra_pos_dict[pos])

    acc_goal_from_last_node = run_logreg(X, y)

    # e.g. 'key=resid_post n=3'
    layer_name = ' '.join(f'{k}={v}' for k, v in cache_key.items())
    metric = {'acc_goal_from_last_node': float(acc_goal_from_last_node), 'layer_name': layer_name}

    for pos, node_label_batch in tqdm(extra_pos_dict.items(), leave=False):
        # float() so every accuracy column has the same plain-float type as
        # acc_goal_from_last_node.
        metric[f'acc_of_predicting_node_at_pos={pos}'] = float(run_logreg(X, node_label_batch))

    print(f'{metric=}')
    accuracies.append(metric)

print('ADD to zero pos extract_Xy__last_token__goal')
pd.DataFrame(accuracies)

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 8.671875, 'layer_name': 'key=embed n=None', 'acc_of_predicting_node_at_pos=0': 7.109375, 'acc_of_predicting_node_at_pos=1': 6.40625, 'acc_of_predicting_node_at_pos=3': 6.328125, 'acc_of_predicting_node_at_pos=4': 6.484375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 7.578125, 'layer_name': 'key=resid_post n=0', 'acc_of_predicting_node_at_pos=0': 89.296875, 'acc_of_predicting_node_at_pos=1': 8.125, 'acc_of_predicting_node_at_pos=3': 6.640625, 'acc_of_predicting_node_at_pos=4': 6.09375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 14.53125, 'layer_name': 'key=resid_post n=1', 'acc_of_predicting_node_at_pos=0': 75.546875, 'acc_of_predicting_node_at_pos=1': 8.125, 'acc_of_predicting_node_at_pos=3': 8.28125, 'acc_of_predicting_node_at_pos=4': 7.1875}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 23.203125, 'layer_name': 'key=resid_post n=2', 'acc_of_predicting_node_at_pos=0': 67.265625, 'acc_of_predicting_node_at_pos=1': 8.828125, 'acc_of_predicting_node_at_pos=3': 9.453125, 'acc_of_predicting_node_at_pos=4': 7.34375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 37.5, 'layer_name': 'key=resid_post n=3', 'acc_of_predicting_node_at_pos=0': 48.90625, 'acc_of_predicting_node_at_pos=1': 7.34375, 'acc_of_predicting_node_at_pos=3': 8.515625, 'acc_of_predicting_node_at_pos=4': 6.484375}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 32.578125, 'layer_name': 'key=resid_post n=4', 'acc_of_predicting_node_at_pos=0': 38.28125, 'acc_of_predicting_node_at_pos=1': 9.0625, 'acc_of_predicting_node_at_pos=3': 8.828125, 'acc_of_predicting_node_at_pos=4': 6.328125}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

metric={'acc_goal_from_last_node': 18.4375, 'layer_name': 'key=resid_post n=5', 'acc_of_predicting_node_at_pos=0': 21.09375, 'acc_of_predicting_node_at_pos=1': 8.75, 'acc_of_predicting_node_at_pos=3': 8.59375, 'acc_of_predicting_node_at_pos=4': 6.40625}
ADD to zero pos extract_Xy__last_token__goal


Unnamed: 0,acc_goal_from_last_node,layer_name,acc_of_predicting_node_at_pos=0,acc_of_predicting_node_at_pos=1,acc_of_predicting_node_at_pos=3,acc_of_predicting_node_at_pos=4
0,8.671875,key=embed n=None,7.109375,6.40625,6.328125,6.484375
1,7.578125,key=resid_post n=0,89.296875,8.125,6.640625,6.09375
2,14.53125,key=resid_post n=1,75.546875,8.125,8.28125,7.1875
3,23.203125,key=resid_post n=2,67.265625,8.828125,9.453125,7.34375
4,37.5,key=resid_post n=3,48.90625,7.34375,8.515625,6.484375
5,32.578125,key=resid_post n=4,38.28125,9.0625,8.828125,6.328125
6,18.4375,key=resid_post n=5,21.09375,8.75,8.59375,6.40625
