In [1]:
from main import get_last_data_path
import pickle
from consts import GEMMA_2

data_dir = "data"

# Data file chosen by the project helper (presumably the most recent one in `data_dir` — confirm in main.py).
data_path = get_last_data_path(data_dir)

# NOTE(review): pickle.load can execute arbitrary code — only open files produced by this project.
with open(data_path, mode="rb") as pickle_file:
    data = pickle.load(pickle_file)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Gemma 2 results; each value is the per-steering-vector record iterated over in the analysis cell.
gemma_2_data = data[GEMMA_2]
# NOTE(review): `position` is not read by any later visible cell (the analysis reads the 'all'
# entry instead) — confirm whether it is still needed.
position = -1

# Show the available vector ids via rich display (last expression) rather than print().
gemma_2_data.keys()

dict_keys(['1161', '1200', '137', '3949', '4086', '4287', '5548', '6832', '695', '709', '7988', '8015'])


In [10]:
import numpy as np
import re

# Only components from transformer blocks with index >= MIN_LAYER are summed.
MIN_LAYER = 15
# The block index is the integer that follows "blocks." in a component name.
LAYER_PATTERN = re.compile(r'blocks\.(\d+)')


def _to_numpy(tensor):
    """Return `tensor` as a float32 NumPy array (detached and moved to CPU first)."""
    return tensor.detach().float().cpu().numpy()


def _layer_index(component_name):
    """Return the block index parsed from `component_name`.

    Raises:
        ValueError: if the name contains no ``blocks.<int>`` segment.
    """
    match = LAYER_PATTERN.search(component_name)
    if match is None:
        raise ValueError(f"no 'blocks.<n>' layer index in component name {component_name!r}")
    return int(match.group(1))


def _unit_rows(vectors):
    """Stack equal-length 1-D vectors into a 2-D array and scale each row to unit L2 norm.

    An all-zero row has norm 0 and therefore becomes NaN.
    """
    stacked = np.stack(vectors)
    return stacked / np.linalg.norm(stacked, axis=1, keepdims=True)


# For every steering vector, sum the (positive - negative) aggregated train activations of the
# late-layer components, grouped into MLP ('ln2'), attention ('ln1') and everything else (resid).
all_steering_vectors = []
all_mlp_vectors, all_attn_vectors, all_resid_vectors = [], [], []
for per_vector_data in gemma_2_data.values():
    steering = _to_numpy(per_vector_data['meta']['direction'])
    all_steering_vectors.append(steering)

    # 'all' unpacks into train/test outputs for the negative and positive sets; each train entry
    # unpacks into (dots, norms, agg). Only the train aggregates are used below.
    negative_outputs_train, positive_outputs_train, negative_outputs_test, positive_outputs_test = per_vector_data['all']
    negative_dots_train, negative_norms_train, negative_agg_train = negative_outputs_train
    positive_dots_train, positive_norms_train, positive_agg_train = positive_outputs_train

    mlp_vector = np.zeros(steering.shape)
    attn_vector = np.zeros(steering.shape)
    resid_vector = np.zeros(steering.shape)
    for component_name, negative_agg in negative_agg_train.items():
        if _layer_index(component_name) < MIN_LAYER:
            continue
        diff_means = _to_numpy(positive_agg_train[component_name] - negative_agg)
        # Routing is purely by substring. NOTE(review): assumes 'ln1'/'ln2' components belong to
        # the attention/MLP sublayers respectively — confirm against the hook naming used.
        if 'ln2' in component_name:
            mlp_vector += diff_means
        elif 'ln1' in component_name:
            attn_vector += diff_means
        else:
            resid_vector += diff_means
    all_mlp_vectors.append(mlp_vector)
    all_attn_vectors.append(attn_vector)
    all_resid_vectors.append(resid_vector)

# Unit-normalize every row so the dot products computed next are cosine similarities.
all_steering_vectors = _unit_rows(all_steering_vectors)
all_mlp_vectors = _unit_rows(all_mlp_vectors)
all_attn_vectors = _unit_rows(all_attn_vectors)
all_resid_vectors = _unit_rows(all_resid_vectors)


In [11]:
def _rowwise_dot(left, right):
    """Dot each row of `left` with the matching row of `right`: (n, d), (n, d) -> (n,)."""
    return np.einsum('ij,ij->i', left, right)


# One score per steering vector for each component group.
dot_products_mlp = _rowwise_dot(all_mlp_vectors, all_steering_vectors)
dot_products_attn = _rowwise_dot(all_attn_vectors, all_steering_vectors)
dot_products_resid = _rowwise_dot(all_resid_vectors, all_steering_vectors)

In [12]:
# Per-vector dot product between the summed MLP ('ln2') direction and its steering vector.
dot_products_mlp

array([0.10420628, 0.05083645, 0.08824812, 0.14025829, 0.11017924,
       0.01504239, 0.08750446, 0.11274216, 0.08378014, 0.10738567,
       0.12414297, 0.09391268])

In [13]:
# Per-vector dot product between the summed attention ('ln1') direction and its steering vector.
dot_products_attn

array([ 0.01781004,  0.05212393, -0.00397307,  0.01475263,  0.08361365,
        0.02437153,  0.03358879,  0.02898224,  0.00423687,  0.01757811,
        0.06283319,  0.0066732 ])

In [14]:
# Per-vector dot product between the summed remaining ("resid") direction and its steering vector.
dot_products_resid

array([0.17869838, 0.09941412, 0.10847355, 0.180939  , 0.15526515,
       0.00679891, 0.09239458, 0.15401034, 0.14686764, 0.1547267 ,
       0.15670916, 0.11853292])

In [15]:
# Mean over steering vectors, in the order (MLP, attention, residual).
tuple(scores.mean() for scores in (dot_products_mlp, dot_products_attn, dot_products_resid))

(np.float64(0.09318657049480035),
 np.float64(0.028549260221813418),
 np.float64(0.12940253730694742))

In [16]:
# Population standard deviation over steering vectors, in the order (MLP, attention, residual).
tuple(scores.std() for scores in (dot_products_mlp, dot_products_attn, dot_products_resid))

(np.float64(0.0319034931483563),
 np.float64(0.02476502549424573),
 np.float64(0.04646612934913045))