In [None]:
import torch
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np
import torch.nn.functional as F
import seaborn as sns
import pandas as pd
from tqdm import tqdm

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer
import transformer_lens.patching as patching
from fancy_einsum import einsum

from functools import partial
import scipy as sp

from copy import deepcopy

from ioi_dataset import IOIDataset
import networkx as nx
import einops
torch.set_grad_enabled(False)

import utils
import scipy as sp
import scipy.cluster as cl

import h5py
import matplotlib.lines as mlines

In [None]:
class ModelContext:
    '''
    Holds one model together with everything needed to extract its control signals.

    Typical use in this notebook:
        ctx = ModelContext(...)            # load model, build IOI prompts, run and cache activations
        ctx.get_svds()                     # per-head decompositions -> self.U, self.S, self.VT
        ctx.trace_prompts(...)             # or ctx.load_cached_tracing(...); fills svs_used_u/v, dfs_i/j
        ctx.compute_signals()              # per-firing unit-norm signal vectors
        ctx.compute_control_signals(...)   # per-head consistency statistics and mean control signals

    Firing dictionaries are keyed by (prompt_id, layer, ah_idx, dest_token, src_token).
    '''

    def __init__(self, model_name, model_family, device, num_prompts):
        '''
        Args:
            model_name: name passed to HookedTransformer.from_pretrained.
            model_family: family string passed to IOIDataset; "gemma" also enables prepend_bos.
            device: device for the model and the dataset.
            num_prompts: number of IOI prompts to generate (N).
        '''
        self.model_name = model_name
        self.model_family = model_family
        self.device = device
        # Loading the model with processing; fold_ln = True, center_writing_weights = True
        self.model = HookedTransformer.from_pretrained(model_name, device=device)
        # Only gemma prompts get a BOS token prepended
        prepend_bos = model_family == "gemma"
        self.ioi_dataset = IOIDataset(
            model_family=self.model_family,
            prompt_type="mixed",
            N=num_prompts,
            tokenizer=self.model.tokenizer,
            prepend_bos=prepend_bos,
            seed=0,
            device=device)

        # Run the model once on this set of prompts and keep every activation
        self.logits, self.cache = self.model.run_with_cache(self.ioi_dataset.toks)

        # This creates the keys which hold the individual attention heads' outputs
        self.cache.compute_head_results()

        self.ALL_AHS = [(i, j) for i in range(self.model.cfg.n_layers) for j in range(self.model.cfg.n_heads)]
        self.d_model = self.model.cfg.d_model

        # Filled by trace_prompts / load_cached_tracing
        self.svs_used_u = {}
        self.svs_used_v = {}
        self.dfs_i = {}
        self.dfs_j = {}

        # Filled by compute_control_signals, keyed by signal type
        self.mean_ips = {}
        self.std_ips = {}

    def get_svds(self):
        '''Compute the per-head decompositions (new omega definition) and store them in self.U, self.S, self.VT.'''
        self.U, self.S, self.VT = utils.get_omega_decomposition_all_ahs(self.model, self.model_name, new_defn_omega=True)

    @staticmethod
    def _last_sv_index(df, use_svs, frac_contrib_thresh=1.0):
        '''
        Index of the last row of a decomposition dataframe to keep.

        With use_svs, this is the first row whose 'sv_frac_contribution' (rounded to 5 decimals)
        reaches frac_contrib_thresh. If no row reaches it (e.g. floating-point shortfall), or
        use_svs is False, every row is kept.
        '''
        all_rows = len(df) - 1
        if not use_svs:
            return all_rows
        reached = np.flatnonzero(df['sv_frac_contribution'].values.round(5) >= frac_contrib_thresh)
        return int(reached[0]) if reached.size else all_rows

    def trace_prompts(self, prompt_list, firing_criteria = 'threshold', attn_thresh = 0.5, trace_mlp = False, use_svs = True):
        '''
        Trace the prompts through the model, and store the singular vectors used in each firing.
        A firing is a (head, dest_token, src_token) whose attention score is >= the threshold.

        Args:
            prompt_list: list of prompt ids to trace
            firing_criteria: 'threshold' (fixed attn_thresh) or '1/n' (threshold 1/(dest_token+1))
            attn_thresh: threshold for attention firing if firing_criteria is 'threshold'
            trace_mlp: currently unused
            use_svs: if True, keep singular vectors up to the first one whose
                sv_frac_contribution reaches 1.0; if False, keep all of them

        Fills self.svs_used_u / self.dfs_i (decomposition on x_i, the destination token) and
        self.svs_used_v / self.dfs_j (decomposition on x_j, the source token).

        This is based on the first half of the code in utils/__trace_firing_optimized_rope_new_defn_omega() -
        only the svd part, not the tracing part.
        '''
        frac_contrib_thresh = 1.0
        candidates = []
        for prompt_id in prompt_list:
            last_dest = int(self.ioi_dataset.word_idx["end"][prompt_id])
            for layer in range(self.model.cfg.n_layers):
                # One cache lookup per layer instead of one per (head, dest, src)
                pattern = self.cache[f"blocks.{layer}.attn.hook_pattern"][prompt_id]
                for ah_idx in range(self.model.cfg.n_heads):
                    # skipping dest = 0; special case where contrib can be negative
                    for dest_token in range(1, last_dest + 1):
                        uniform = 1 / (dest_token + 1)  # n = number of src tokens visible from dest
                        if firing_criteria == "1/n":
                            thresh = uniform
                        else:
                            thresh = attn_thresh
                            # We cannot trace firings with attention score < 1/n
                            if firing_criteria == "threshold" and thresh < uniform:
                                continue
                        for src_token in range(dest_token + 1):
                            # did the attention head fire on this source/dest combination?
                            if pattern[ah_idx, dest_token, src_token].item() < thresh:
                                continue
                            # NOTE: the punct token is deliberately not skipped
                            candidates.append((prompt_id, layer, ah_idx, dest_token, src_token))

        if self.model_name == "gemma-2-2b":
            decompose = utils.get_components_used_comparative_no_bias
        else:
            decompose = utils.get_components_used_comparative_new_defn

        for key in tqdm(candidates, total=len(candidates)):
            prompt_id, layer, ah_idx, dest_token, src_token = key
            X = self.cache[f"blocks.{layer}.ln1.hook_normalized"][prompt_id, :, :]  # Float[Tensor, 'n_tokens d_model']
            df_decomp_i, df_decomp_j = decompose(X, src_token, dest_token, layer, ah_idx,
                                                 self.U, self.S, self.VT, self.model_name, self.device)

            # A decomposition is None when the contribution comes from the bias term c_1 and
            # can't be traced (see last paragraph of Appendix A)
            if df_decomp_i is not None:  # decomposing on x_i
                last_sv_idx = self._last_sv_index(df_decomp_i, use_svs, frac_contrib_thresh)
                self.svs_used_u[key] = df_decomp_i.iloc[:last_sv_idx + 1]['idx'].astype(int).values
                self.dfs_i[key] = df_decomp_i
            if df_decomp_j is not None:  # decomposing on x_j
                last_sv_idx = self._last_sv_index(df_decomp_j, use_svs, frac_contrib_thresh)
                self.svs_used_v[key] = df_decomp_j.iloc[:last_sv_idx + 1]['idx'].astype(int).values
                self.dfs_j[key] = df_decomp_j

    def _contrib_weighted_signal(self, df, svs, SVecs):
        '''Unit-norm sum of the columns SVecs[:, sv] for sv in svs, each weighted by its 'contrib' value in df.'''
        retvec = torch.zeros(self.d_model)
        for sv in svs:
            contrib = df[df['idx'] == sv]['contrib'].values
            retvec += SVecs[:, sv] * torch.Tensor(contrib)
        return retvec / torch.linalg.norm(retvec)

    def _projection_signal(self, svs, SVecs, x):
        '''Unit-norm sum of the projections of x onto each column SVecs[:, sv] for sv in svs.'''
        retvec = torch.zeros(self.d_model)
        for sv in svs:
            retvec += utils.apply_projection(SVecs[:, sv].reshape(-1, 1), x)
        return retvec / torch.linalg.norm(retvec)

    def get_contrib_u_signals(self, df, svs, SVecs, x):
        '''
        Signal used in a firing (destination side), weighting each singular vector by its contribution.
        x is accepted for interface compatibility but is not used.
        '''
        return self._contrib_weighted_signal(df, svs, SVecs)

    def get_contrib_v_signals(self, df, svs, SVecs, x):
        '''
        Signal used in a firing (source side), weighting each singular vector by its contribution
        (not a projection). x is accepted for interface compatibility but is not used.
        '''
        return self._contrib_weighted_signal(df, svs, SVecs)

    def get_u_signals(self, df, svs, SVecs, x):
        '''Signal used in a firing (destination side): projection of the residual x on the signal subspace. df is unused.'''
        return self._projection_signal(svs, SVecs, x)

    def get_v_signals(self, df, svs, SVecs, x):
        '''Signal used in a firing (source side): projection of the residual x on the signal subspace. df is unused.'''
        return self._projection_signal(svs, SVecs, x)

    def _singular_vectors(self, layer, ah_idx, diff, side):
        '''
        Matrix whose columns are the singular vectors for one head: U for side 'u', VT transposed for side 'v'.
        gemma-2-2b decompositions are indexed directly; other models' are nested under 'd' (U) / 's' (VT).
        '''
        if side == 'u':
            U = self.U if self.model_name == "gemma-2-2b" else self.U['d']
            return U[layer, ah_idx, diff]
        VT = self.VT if self.model_name == "gemma-2-2b" else self.VT['s']
        return VT[layer, ah_idx, diff].T

    def _side_signals(self, side):
        '''Return (projection signals, contrib-weighted signals) as numpy arrays for side 'u' or 'v'.'''
        if side == 'u':
            svs_used, dfs, desc = self.svs_used_u, self.dfs_i, "Destination signals"
            contrib_fn, proj_fn = self.get_contrib_u_signals, self.get_u_signals
        else:
            svs_used, dfs, desc = self.svs_used_v, self.dfs_j, "Source signals"
            contrib_fn, proj_fn = self.get_contrib_v_signals, self.get_v_signals
        signals, contrib_signals = [], []
        for key in tqdm(svs_used.keys(), total=len(svs_used), desc=desc):
            prompt_id, layer, ah_idx, dest_token, src_token = key
            # gpt2-small decompositions are always read at index -1; others by the dest-src offset
            diff = -1 if self.model_name == 'gpt2-small' else dest_token - src_token
            X = self.cache[f"blocks.{layer}.ln1.hook_normalized"][prompt_id, :, :]  # no deepcopy needed here
            x = X[dest_token if side == 'u' else src_token, :]
            SVecs = self._singular_vectors(layer, ah_idx, diff, side)
            contrib_signals.append(contrib_fn(dfs[key], svs_used[key], SVecs, x))
            signals.append(proj_fn(dfs[key], svs_used[key], SVecs, x))
        return (np.array([s.numpy() for s in signals]),
                np.array([s.numpy() for s in contrib_signals]))

    def compute_signals(self):
        '''
        Compile the set of all signals used across all traced firings of the model.

        Sets arrays with one row per firing, in the key order of the matching svs_used_* dict:
            u_signals, contrib_u_signals: destination-token signals (firings in svs_used_u)
            v_signals, contrib_v_signals: source-token signals (firings in svs_used_v)
        '''
        self.u_signals, self.contrib_u_signals = self._side_signals('u')
        self.v_signals, self.contrib_v_signals = self._side_signals('v')

    def _is_default_firing(self, prompt_id, src_token):
        '''
        Whether a firing's source token is a default token for this model:
        token 0 for gpt2-small; token 0 or the prompt's punct position for pythia-160m and gemma-2-2b.

        Raises:
            ValueError: for any other model.
        '''
        if self.model_name == 'gpt2-small':
            return src_token == 0
        if self.model_name in ('EleutherAI/pythia-160m', 'gemma-2-2b'):
            return src_token in [0, self.ioi_dataset.word_idx["punct"][prompt_id].item()]
        raise ValueError('default tokens are not yet configured for this model')

    def compute_control_signals(self, similarity_threshold = 0.75):
        '''
        For each signal type, identify "controlled" heads and estimate their control signal.

        For every head, the signals of its firings on default tokens (see _is_default_firing)
        are collected, and the mean/std of their pairwise cosine distances computed (NaN when
        there are fewer than two). A head is controlled when that mean distance is below
        similarity_threshold; its control signal is the mean of those signals.

        Sets, keyed by 'u_signals', 'v_signals', 'contrib_u_signals', 'contrib_v_signals':
            self.ctrl_sigs: array with one mean control signal per controlled head
            self.ctrld_heads: list of controlled (layer, head) pairs aligned with ctrl_sigs
            self.mean_ips, self.std_ips: (n_layers, n_heads) arrays of the distance statistics
        '''
        self.ctrl_sigs = {}
        self.ctrld_heads = {}
        test_signal_candidates = [self.u_signals, self.v_signals,
                                  self.contrib_u_signals, self.contrib_v_signals]
        test_signal_types = ['u_signals', 'v_signals', 'contrib_u_signals', 'contrib_v_signals']
        grid_shape = (self.model.cfg.n_layers, self.model.cfg.n_heads)
        for test_signals, signal_type in tqdm(zip(test_signal_candidates, test_signal_types), total=len(test_signal_candidates)):
            # destination-side signals were built from svs_used_u, source-side from svs_used_v
            if signal_type in ['contrib_u_signals', 'u_signals']:
                firings = list(self.svs_used_u.keys())
            else:
                firings = list(self.svs_used_v.keys())

            # Single pass over the firings, bucketing default-token signals by head
            vec_sets = {head: [] for head in self.ALL_AHS}
            for key, sig in zip(firings, test_signals):
                prompt_id, layer, ah_idx, dest_token, src_token = key
                head = (layer, ah_idx)
                if head in vec_sets and self._is_default_firing(prompt_id, src_token):
                    vec_sets[head].append(sig)

            # Mean/std pairwise cosine distance between each head's default-token signals
            mean_ips = np.zeros(grid_shape)
            std_ips = np.zeros(grid_shape)
            for head in self.ALL_AHS:
                vecs = np.array(vec_sets[head])
                vec_sets[head] = vecs
                if vecs.shape[0] > 1:
                    dists = sp.spatial.distance.pdist(vecs, metric='cosine')
                    mean_ips[head] = np.mean(dists)
                    std_ips[head] = np.std(dists)
                else:
                    mean_ips[head] = np.nan
                    std_ips[head] = np.nan

            # NaN compares False, so heads with < 2 default firings are never controlled
            controlled_heads = [head for head in self.ALL_AHS if mean_ips[head] < similarity_threshold]
            control_signals = [np.mean(vec_sets[head], axis = 0) for head in controlled_heads]
            self.ctrl_sigs[signal_type] = np.array(control_signals)
            self.ctrld_heads[signal_type] = controlled_heads
            self.mean_ips[signal_type] = mean_ips
            self.std_ips[signal_type] = std_ips

    def load_cached_tracing(self, prompt_list):
        '''
        Load tracing results from control_signals_cache/dicts_{model_name}_{prompt_id}.hdf5 instead of
        calling trace_prompts.

        Each dataset is named "<dict_type>_<firing tuple repr>", where dict_type is one of
        dfs-i, dfs-j, svs-used-u, svs-used-v; unknown types are reported and skipped.
        '''
        import ast  # literal_eval parses the tuple keys without executing file contents (unlike eval)
        df_columns = ['idx', 'singular_value', 'x_i_ip', 'x_j_ip', 'denom_avg', 'product', 'contrib', 'sv_frac_contribution']
        for prompt_id in prompt_list:
            with h5py.File(f'control_signals_cache/dicts_{self.model_name}_{prompt_id}.hdf5', 'r') as f:
                for key in f.keys():
                    dict_type, dict_key = key.split("_", 1)
                    dict_key = ast.literal_eval(dict_key)
                    if dict_type in ("dfs-i", "dfs-j"):
                        df = pd.DataFrame(f[key][:], columns=df_columns)
                        df["idx"] = df["idx"].astype(int)
                        target = self.dfs_i if dict_type == "dfs-i" else self.dfs_j
                        target[dict_key] = df
                    elif dict_type == "svs-used-u":
                        self.svs_used_u[dict_key] = f[key][:]
                    elif dict_type == "svs-used-v":
                        self.svs_used_v[dict_key] = f[key][:]
                    else:
                        print(f"Error in the dict_type={dict_type}. Key={key}")

In [None]:
def main_clusters(cluster_id):
    '''
    Order the distinct cluster ids by cluster size, largest cluster first.

    Args:
        cluster_id: iterable of per-sample cluster labels (e.g. the output of fcluster).
    Returns:
        list of distinct labels sorted by descending frequency; equally sized clusters keep
        the order in which their label first appears.
    '''
    sizes = {}
    for label in cluster_id:
        sizes[label] = sizes.get(label, 0) + 1
    return sorted(sizes, key=lambda label: sizes[label], reverse=True)

def f_counts(fset):
    '''
    Count how many firings use each source token.

    Args:
        fset: iterable of firing tuples (prompt, layer, head, dest_token, src_token).
    Returns:
        dict mapping src_token -> number of firings with that source token, in first-seen order.
    '''
    counts = {}
    for firing in fset:
        src = firing[4]
        counts[src] = counts.get(src, 0) + 1
    return counts

def conf_matrix(c_sorted, subject_cluster, targets, firings=None, cluster_id=None):
    '''
    Confusion matrix of "belongs to cluster c_sorted[subject_cluster]" as a classifier of
    "firing's source token is in targets".

    Args:
        c_sorted: cluster labels ordered by size (see main_clusters).
        subject_cluster: position in c_sorted of the cluster treated as the positive prediction.
        targets: source-token positions counted as positives.
        firings: firing tuples (..., src_token) aligned with cluster_id. Defaults to the
            notebook-global `firings` (kept for backward compatibility).
        cluster_id: per-firing cluster labels. Defaults to the notebook-global `cluster_id`.
    Returns:
        (cm, f): cm is a dict with 'tp', 'fp', 'tn', 'fn' counts; f[k] lists the firings
        in cluster c_sorted[k].
    '''
    if firings is None:
        firings = globals()['firings']
    if cluster_id is None:
        cluster_id = globals()['cluster_id']
    cluster_id = np.asarray(cluster_id)

    cm = {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0}
    f = []
    for position, label in enumerate(c_sorted):
        c_firings = [firings[i] for i in np.flatnonzero(cluster_id == label)]
        f.append(c_firings)
        predicted_positive = position == subject_cluster
        for firing in c_firings:
            # Use the same target set for both sides; previously fn only checked token 0
            is_target = firing[4] in targets
            if predicted_positive:
                cm['tp' if is_target else 'fp'] += 1
            else:
                cm['fn' if is_target else 'tn'] += 1
    return cm, f

In [None]:
# Ids of the IOI prompts traced for every model below
prompt_set = list(range(1, 4))

In [None]:
gpt2s = ModelContext(model_name='gpt2-small', model_family='gpt2', device='cpu', num_prompts=16)
gpt2s.get_svds()
gpt2s.trace_prompts(prompt_set, firing_criteria='threshold', attn_thresh=0.4)  # alternative value: attn_thresh = 0.5
gpt2s.compute_signals()
# A head is labelled "controlled" (and assigned a single mean signal) when the mean pairwise
# cosine distance between its default-token signals is below similarity_threshold.
gpt2s.compute_control_signals(similarity_threshold=0.5)  # no big change between 0.5 and 0.9

In [None]:
pyth = ModelContext(model_name='EleutherAI/pythia-160m', model_family='pythia', device='cpu', num_prompts=16)
pyth.get_svds()
pyth.trace_prompts(prompt_set, firing_criteria='threshold', attn_thresh=0.4)  # alternative value: attn_thresh = 0.5
pyth.compute_signals()
# A head is labelled "controlled" (and assigned a single mean signal) when the mean pairwise
# cosine distance between its default-token signals is below similarity_threshold.
pyth.compute_control_signals(similarity_threshold=0.5)

In [None]:
gemma = ModelContext(model_name='gemma-2-2b', model_family='gemma', device='cpu', num_prompts=16)
gemma.get_svds()
# Reuse the tracing results cached under control_signals_cache/ rather than re-tracing.
# To fully trace the prompts again, use instead:
#   gemma.trace_prompts(prompt_set, firing_criteria='threshold', attn_thresh=0.4)  # alternative: 0.5
gemma.load_cached_tracing(prompt_set)
gemma.compute_signals()
# A head is labelled "controlled" (and assigned a single mean signal) when the mean pairwise
# cosine distance between its default-token signals is below similarity_threshold.
gemma.compute_control_signals(similarity_threshold=0.5)

## GPT-2

In [None]:
# Pairwise cosine distances between all gpt2-small source (v) signals,
# rows/columns ordered by average-linkage hierarchical clustering.
test_signals = gpt2s.v_signals
cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(test_signals, metric='cosine'))
Z = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)
grid = sns.clustermap(cosine_distances, row_linkage = Z, col_linkage = Z, figsize = (10,10))
# plt.title would label whatever pyplot considers the current axes after clustermap
# (not necessarily the heatmap), so title the grid's top axes explicitly.
grid.ax_col_dendrogram.set_title('Cosine distances of Source signals');

In [None]:
# Cut the average-linkage dendrogram of the gpt2-small source (v) signals at a fixed cosine
# distance, then check how well the largest cluster isolates firings on token 0.
# Tuning notes: for the "regular" U signals, a cut at 0.99 separated the two big clusters and
# a cut at 1.0 was much worse. (An older note claimed 1.0 works for the contrib signals,
# which contradicts this; TODO re-check.)
cluster_threshold = .97
Z_contrib = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)
cluster_id = cl.hierarchy.fcluster(Z_contrib, cluster_threshold, criterion = 'distance')
c_sorted = main_clusters(cluster_id)
firings = list(gpt2s.svs_used_v.keys())
# f[k] holds the firings assigned to the k-th largest cluster
f = [[firings[i] for i in np.flatnonzero(cluster_id == c)] for c in c_sorted]
print('Source tokens in each cluster:')
for cluster_firings in f:
    print(f_counts(cluster_firings))
# Confusion matrix of "in the largest cluster" as a classifier of firings on token 0
print(conf_matrix(c_sorted, 0, [0])[0])

In [None]:
# Prototype control signal: the member of the largest cluster (c_sorted[0]) with the smallest
# mean cosine distance to the other members, i.e. the cluster's medoid.
# This assumes the previous cell showed that cluster zero separates well (confusion matrix).
in_cluster_zero = cluster_id == c_sorted[0]
cluster_zero_signals = test_signals[in_cluster_zero]
cluster_nonzero_signals = test_signals[~in_cluster_zero]
cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(cluster_zero_signals, metric='cosine'))
average_cos_dist = np.mean(cosine_distances, axis = 1)
min_idx = int(np.argmin(average_cos_dist))  # first index attaining the minimum
min_vec = cluster_zero_signals[min_idx]

In [None]:
# Split gpt2-small source (v) signals by whether the firing's source token is token 0
is_zero_src = [key[4] == 0 for key in gpt2s.svs_used_v]
zero_token_signals = np.array([sig for sig, z in zip(gpt2s.v_signals, is_zero_src) if z])
nonzero_token_signals = np.array([sig for sig, z in zip(gpt2s.v_signals, is_zero_src) if not z])

In [None]:
import os
import matplotlib
# Embed fonts as TrueType (Type 42) in PDF/PS output
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rc('font', size=8)

out_path = "figures/control_signals/gpt2-small_control-signals_ip_v-signals.pdf"
# savefig does not create missing directories
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Distribution of inner products between the prototype signal and every other signal
fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))
sns.kdeplot(zero_token_signals @ min_vec, label = 'zero tokens', bw_adjust = 0.125, ax=ax)
sns.kdeplot(nonzero_token_signals @ min_vec, label = 'non-zero tokens', bw_adjust = 0.125, ax=ax)
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
ax.set_xlabel("Inner product")
fig.tight_layout()
fig.savefig(out_path, bbox_inches='tight', dpi=800)
plt.close(fig)

## Pythia

In [None]:
# Pairwise cosine distances between all pythia-160m source (v) signals,
# rows/columns ordered by average-linkage hierarchical clustering.
test_signals = pyth.v_signals
cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(test_signals, metric='cosine'))
Z = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)
grid = sns.clustermap(cosine_distances, row_linkage = Z, col_linkage = Z, figsize = (10,10))
# plt.title would label whatever pyplot considers the current axes after clustermap
# (not necessarily the heatmap), so title the grid's top axes explicitly.
grid.ax_col_dendrogram.set_title('Cosine distances of Src signals');

In [None]:
# Pythia: cut the average-linkage dendrogram of the source (v) signals and check how well the
# largest cluster isolates firings on the default tokens (token 0 or a punct position).
# An earlier run reported very good separation at threshold 0.89; 0.99 is the value used here.
plt.figure()
cluster_threshold = 0.99
Z_contrib = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)
cluster_id = cl.hierarchy.fcluster(Z_contrib, cluster_threshold, criterion = 'distance')
c_sorted = main_clusters(cluster_id)
firings = list(pyth.svs_used_v.keys())
# f[k] holds the firings assigned to the k-th largest cluster
f = [[firings[i] for i in np.flatnonzero(cluster_id == c)] for c in c_sorted]
cl.hierarchy.dendrogram(Z_contrib, color_threshold = cluster_threshold)
# NOTE(review): test_signals are the v_signals here, not the contrib signals this title suggests
plt.title('Contrib signal defn')
for cluster_firings in f:
    print(f_counts(cluster_firings))
# Note this is only approximately right: the punct positions of all traced prompts are pooled
# into one target list instead of being applied per prompt. This can inflate false negatives
# at the expense of true negatives, and true positives at the expense of false positives.
pyth_puncts = [int(pyth.ioi_dataset.word_idx["punct"][prompt_id]) for prompt_id in prompt_set]
print(conf_matrix(c_sorted, 0, [0] + pyth_puncts)[0])

In [None]:
# Prototype control signal: the member of the largest cluster (c_sorted[0]) with the smallest
# mean cosine distance to the other members, i.e. the cluster's medoid.
in_cluster_zero = cluster_id == c_sorted[0]
cluster_zero_signals = test_signals[in_cluster_zero]
cluster_nonzero_signals = test_signals[~in_cluster_zero]
cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(cluster_zero_signals, metric='cosine'))
average_cos_dist = np.mean(cosine_distances, axis = 1)
min_idx = int(np.argmin(average_cos_dist))  # first index attaining the minimum
min_vec = cluster_zero_signals[min_idx]

In [None]:
# Split signals by whether the firing's source token is a default token (0 or a punct position)
default_srcs = [0] + pyth_puncts
is_default_src = [key[4] in default_srcs for key in firings]
zero_token_signals = np.array([sig for sig, d in zip(test_signals, is_default_src) if d])
nonzero_token_signals = np.array([sig for sig, d in zip(test_signals, is_default_src) if not d])

In [None]:
import os
import matplotlib
# Embed fonts as TrueType (Type 42) in PDF/PS output
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rc('font', size=8)

out_path = "figures/control_signals/pythia-160m_control-signals_ip_v-signals.pdf"
# savefig does not create missing directories
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Distribution of inner products between the prototype signal and every other signal
fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))
sns.kdeplot(zero_token_signals @ min_vec, label = 'zero tokens', bw_adjust = 0.125, ax=ax)
sns.kdeplot(nonzero_token_signals @ min_vec, label = 'non-zero tokens', bw_adjust = 0.125, ax=ax)
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
ax.set_xlabel("Inner product")
fig.tight_layout()
fig.savefig(out_path, bbox_inches='tight', dpi=800)
plt.close(fig)

## Gemma-2 2B

In [None]:
# Pairwise cosine distances between all gemma-2-2b source (v) signals,
# rows/columns ordered by average-linkage hierarchical clustering.
test_signals = gemma.v_signals
cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(test_signals, metric='cosine'))
Z = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)
grid = sns.clustermap(cosine_distances, row_linkage = Z, col_linkage = Z, figsize = (10,10))
# plt.title would label whatever pyplot considers the current axes after clustermap
# (not necessarily the heatmap), so title the grid's top axes explicitly.
grid.ax_col_dendrogram.set_title('Cosine distances of Source signals');

In [None]:
# Gemma: cut the average-linkage dendrogram of the source (v) signals and check how well the
# largest cluster isolates firings on the default tokens (token 0 or a punct position).
plt.figure()
cluster_threshold = 0.99
Z_contrib = cl.hierarchy.linkage(test_signals, 'average', 'cosine', optimal_ordering = False)
cluster_id = cl.hierarchy.fcluster(Z_contrib, cluster_threshold, criterion = 'distance')
c_sorted = main_clusters(cluster_id)
firings = list(gemma.svs_used_v.keys())
# f[k] holds the firings assigned to the k-th largest cluster
f = [[firings[i] for i in np.flatnonzero(cluster_id == c)] for c in c_sorted]
cl.hierarchy.dendrogram(Z_contrib, color_threshold = cluster_threshold)
# NOTE(review): test_signals are the v_signals here, not the contrib signals this title suggests
plt.title('Contrib signal defn')
for cluster_firings in f:
    print(f_counts(cluster_firings))
# Note this is only approximately right: the punct positions of all traced prompts are pooled
# into one target list instead of being applied per prompt. This can inflate false negatives
# at the expense of true negatives, and true positives at the expense of false positives.
gemma_puncts = [int(gemma.ioi_dataset.word_idx["punct"][prompt_id]) for prompt_id in prompt_set]
print(conf_matrix(c_sorted, 0, [0] + gemma_puncts)[0])

In [None]:
# Prototype control signal: the member of the largest cluster (c_sorted[0]) with the smallest
# mean cosine distance to the other members, i.e. the cluster's medoid.
in_cluster_zero = cluster_id == c_sorted[0]
cluster_zero_signals = test_signals[in_cluster_zero]
cluster_nonzero_signals = test_signals[~in_cluster_zero]
cosine_distances = sp.spatial.distance.squareform(sp.spatial.distance.pdist(cluster_zero_signals, metric='cosine'))
average_cos_dist = np.mean(cosine_distances, axis = 1)
min_idx = int(np.argmin(average_cos_dist))  # first index attaining the minimum
min_vec = cluster_zero_signals[min_idx]

In [None]:
# Split signals by whether the firing's source token is a default token (0 or a punct position)
default_srcs = [0] + gemma_puncts
is_default_src = [key[4] in default_srcs for key in firings]
zero_token_signals = np.array([sig for sig, d in zip(test_signals, is_default_src) if d])
nonzero_token_signals = np.array([sig for sig, d in zip(test_signals, is_default_src) if not d])

In [None]:
import os
import matplotlib
# Embed fonts as TrueType (Type 42) in PDF/PS output
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
plt.rc('font', size=8)

out_path = "figures/control_signals/gemma-2-2b_control-signals_ip_v-signals.pdf"
# savefig does not create missing directories
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Distribution of inner products between the prototype signal and every other signal
fig, ax = plt.subplots(1, 1, figsize=(3, 1.9))
sns.kdeplot(zero_token_signals @ min_vec, label = 'zero tokens', bw_adjust = 0.125, ax=ax)
sns.kdeplot(nonzero_token_signals @ min_vec, label = 'non-zero tokens', bw_adjust = 0.125, ax=ax)
ax.legend(loc='lower center', bbox_to_anchor=(0.5, 1), ncol=2)
ax.set_xlabel("Inner product")
fig.tight_layout()
fig.savefig(out_path, bbox_inches='tight', dpi=800)
plt.close(fig)

## GPT-2 at the head level

In [None]:
import matplotlib
# Embed fonts as TrueType (Type 42) in PDF/PS output; 8pt base font size.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42, 'font.size': 8})

In [None]:
# Cluster the u- and v-control signals of GPT-2 small (average linkage, cosine
# distance). Top row: dendrogram with links coloured by flat cluster. Bottom row:
# (layer x head) map coloured by the cluster of each head's signal.
fig, axs = plt.subplots(2, 2, figsize = (5, 3.5))

# Has to be set manually after seeing the plot: number of leading clusters (in
# main_clusters order) visible as coloured groups in the dendrogram. Heads in
# later clusters are drawn white in the head map.
n_clusters_appearing_dendogram = {
    0: 6,
    1: 9
}

signal_types = {0: "u_signals", 1: "v_signals"}
current_cluster_threshold = 0.6
# Head-map size comes from the model config instead of being hard-coded.
n_layers, n_heads = gpt2s.model.cfg.n_layers, gpt2s.model.cfg.n_heads
# Translucent gray for links above the cut and for uncoloured clusters.
hex_above_threshold_color = mcolors.to_hex((0.5, 0.5, 0.5, 0.3), keep_alpha=True)

for t in range(2):
    control_signals = gpt2s.ctrl_sigs[signal_types[t]]
    current_Z = cl.hierarchy.linkage(control_signals, method='average', metric='cosine', optimal_ordering=False)
    n_samples = current_Z.shape[0] + 1

    flat_cluster_ids_for_samples = cl.hierarchy.fcluster(current_Z, current_cluster_threshold, criterion='distance')
    c_sorted_list = list(main_clusters(flat_cluster_ids_for_samples))

    # Cluster k is drawn with "deep" colour k + 1; colour 0 is reserved for the
    # white background of the head map below.
    hex_dendrogram_link_colors = [
        mcolors.to_hex(rgb, keep_alpha=False)
        for rgb in list(sns.color_palette("deep", len(c_sorted_list) + 1))[1:]
    ]

    def custom_dendrogram_link_color_func(link_k):
        """Return the hex colour for dendrogram link `link_k` (a linkage node id)."""
        # Node ids >= n_samples refer to row (link_k - n_samples) of the linkage matrix.
        link_distance = current_Z[link_k - n_samples, 2]
        # fcluster(criterion='distance') keeps merges at height <= threshold inside
        # a flat cluster, so only strictly higher links are "above threshold".
        if link_distance > current_cluster_threshold:
            return hex_above_threshold_color
        # All leaves below such a link share one flat cluster; descend the left
        # children to any leaf to find which.
        node = link_k
        while node >= n_samples:
            node = int(current_Z[node - n_samples, 0])
        target_flat_cid = flat_cluster_ids_for_samples[node]
        if target_flat_cid in c_sorted_list:
            cmap_idx = c_sorted_list.index(target_flat_cid)
            if cmap_idx < len(hex_dendrogram_link_colors):
                return hex_dendrogram_link_colors[cmap_idx]
        return hex_above_threshold_color

    cl.hierarchy.dendrogram(
        current_Z,
        link_color_func=custom_dendrogram_link_color_func,
        ax=axs[0, t],
        above_threshold_color=hex_above_threshold_color
    )
    axs[0, t].set_xticks([])
    # Linkage heights are cosine distances; label the ticks as similarity = 1 - distance.
    axs[0, t].set_ylabel("Cosine similarity")
    axs[0, t].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    axs[0, t].set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])

    # -1 marks heads that belong to no main cluster; others hold their cluster rank.
    headmap = np.full((n_layers, n_heads), -1, dtype=int)
    for c_map_idx, actual_cluster_id_val in enumerate(c_sorted_list):
        for i in np.flatnonzero(flat_cluster_ids_for_samples == actual_cluster_id_val):
            layer, ah_idx = gpt2s.ctrld_heads[signal_types[t]][i]
            headmap[layer, ah_idx] = c_map_idx

    if c_sorted_list:
        final_heatmap_cmap = list(sns.color_palette("deep", len(c_sorted_list)))
        final_heatmap_cmap[0] = (1.0, 1.0, 1.0)
        for c in range(n_clusters_appearing_dendogram[t] + 1, len(final_heatmap_cmap)):
            final_heatmap_cmap[c] = (1, 1, 1)  # clusters not visible in the dendrogram -> white
        # Pin the colour range to [-1, n_clusters - 1] so the mapping from cluster
        # rank to colour does not depend on which values happen to occur.
        heatmap_range = {"vmin": -1, "vmax": len(c_sorted_list) - 1}
    else:
        final_heatmap_cmap = None
        heatmap_range = {}

    sns.heatmap(headmap, ax=axs[1, t], cmap=final_heatmap_cmap, cbar=False, annot=False, **heatmap_range)
    axs[1, t].set_ylabel("Layer")
    axs[1, t].set_xlabel("Attention Head Index")

fig.tight_layout()
fig.savefig("figures/control_signals/gpt-2-small_control-signals.pdf", bbox_inches='tight', dpi=800)
plt.close(fig)

## Pythia-160m at the head level

In [None]:
# Cluster the u- and v-control signals of Pythia-160m (average linkage, cosine
# distance). Top row: dendrogram with links coloured by flat cluster. Bottom row:
# (layer x head) map coloured by the cluster of each head's signal.
fig, axs = plt.subplots(2, 2, figsize = (5, 3.5))

# Has to be set manually after seeing the plot: number of leading clusters (in
# main_clusters order) visible as coloured groups in the dendrogram. Heads in
# later clusters are drawn white in the head map.
n_clusters_appearing_dendogram = {
    0: 6,
    1: 9
}

signal_types = {0: "u_signals", 1: "v_signals"}
current_cluster_threshold = 0.7
# Head-map size comes from the model config instead of being hard-coded.
n_layers, n_heads = pyth.model.cfg.n_layers, pyth.model.cfg.n_heads
# Translucent gray for links above the cut and for uncoloured clusters.
hex_above_threshold_color = mcolors.to_hex((0.5, 0.5, 0.5, 0.3), keep_alpha=True)

for t in range(2):
    control_signals = pyth.ctrl_sigs[signal_types[t]]
    current_Z = cl.hierarchy.linkage(control_signals, method='average', metric='cosine', optimal_ordering=False)
    n_samples = current_Z.shape[0] + 1

    flat_cluster_ids_for_samples = cl.hierarchy.fcluster(current_Z, current_cluster_threshold, criterion='distance')
    c_sorted_list = list(main_clusters(flat_cluster_ids_for_samples))

    # Cluster k is drawn with "deep" colour k + 1; colour 0 is reserved for the
    # white background of the head map below.
    hex_dendrogram_link_colors = [
        mcolors.to_hex(rgb, keep_alpha=False)
        for rgb in list(sns.color_palette("deep", len(c_sorted_list) + 1))[1:]
    ]

    def custom_dendrogram_link_color_func(link_k):
        """Return the hex colour for dendrogram link `link_k` (a linkage node id)."""
        # Node ids >= n_samples refer to row (link_k - n_samples) of the linkage matrix.
        link_distance = current_Z[link_k - n_samples, 2]
        # fcluster(criterion='distance') keeps merges at height <= threshold inside
        # a flat cluster, so only strictly higher links are "above threshold".
        if link_distance > current_cluster_threshold:
            return hex_above_threshold_color
        # All leaves below such a link share one flat cluster; descend the left
        # children to any leaf to find which.
        node = link_k
        while node >= n_samples:
            node = int(current_Z[node - n_samples, 0])
        target_flat_cid = flat_cluster_ids_for_samples[node]
        if target_flat_cid in c_sorted_list:
            cmap_idx = c_sorted_list.index(target_flat_cid)
            if cmap_idx < len(hex_dendrogram_link_colors):
                return hex_dendrogram_link_colors[cmap_idx]
        return hex_above_threshold_color

    cl.hierarchy.dendrogram(
        current_Z,
        link_color_func=custom_dendrogram_link_color_func,
        ax=axs[0, t],
        above_threshold_color=hex_above_threshold_color
    )
    axs[0, t].set_xticks([])
    # Linkage heights are cosine distances; label the ticks as similarity = 1 - distance.
    axs[0, t].set_ylabel("Cosine similarity")
    axs[0, t].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    axs[0, t].set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])

    # -1 marks heads that belong to no main cluster; others hold their cluster rank.
    headmap = np.full((n_layers, n_heads), -1, dtype=int)
    for c_map_idx, actual_cluster_id_val in enumerate(c_sorted_list):
        for i in np.flatnonzero(flat_cluster_ids_for_samples == actual_cluster_id_val):
            layer, ah_idx = pyth.ctrld_heads[signal_types[t]][i]
            headmap[layer, ah_idx] = c_map_idx

    if c_sorted_list:
        final_heatmap_cmap = list(sns.color_palette("deep", len(c_sorted_list)))
        final_heatmap_cmap[0] = (1.0, 1.0, 1.0)
        for c in range(n_clusters_appearing_dendogram[t] + 1, len(final_heatmap_cmap)):
            final_heatmap_cmap[c] = (1, 1, 1)  # clusters not visible in the dendrogram -> white
        # Pin the colour range to [-1, n_clusters - 1] so the mapping from cluster
        # rank to colour does not depend on which values happen to occur.
        heatmap_range = {"vmin": -1, "vmax": len(c_sorted_list) - 1}
    else:
        final_heatmap_cmap = None
        heatmap_range = {}

    sns.heatmap(headmap, ax=axs[1, t], cmap=final_heatmap_cmap, cbar=False, annot=False, **heatmap_range)
    axs[1, t].set_ylabel("Layer")
    axs[1, t].set_xlabel("Attention Head Index")

fig.tight_layout()
fig.savefig("figures/control_signals/pythia-160m_control-signals.pdf", bbox_inches='tight', dpi=800)
plt.close(fig)

## Gemma-2 2B at the head level

In [None]:
# Cluster the u- and v-control signals of Gemma-2 2B (average linkage, cosine
# distance). Top row: dendrogram with links coloured by flat cluster. Bottom row:
# (layer x head) map coloured by the cluster of each head's signal.
fig, axs = plt.subplots(2, 2, figsize = (5, 3.5))

# Has to be set manually after seeing the plot: number of leading clusters (in
# main_clusters order) visible as coloured groups in the dendrogram. Heads in
# later clusters are drawn white in the head map.
n_clusters_appearing_dendogram = {
    0: 4,
    1: 9
}

signal_types = {0: "u_signals", 1: "v_signals"}
current_cluster_threshold = 0.75
# Head-map size comes from the model config instead of being hard-coded.
n_layers, n_heads = gemma.model.cfg.n_layers, gemma.model.cfg.n_heads
# Translucent gray for links above the cut and for uncoloured clusters.
hex_above_threshold_color = mcolors.to_hex((0.5, 0.5, 0.5, 0.3), keep_alpha=True)

for t in range(2):
    control_signals = gemma.ctrl_sigs[signal_types[t]]
    current_Z = cl.hierarchy.linkage(control_signals, method='average', metric='cosine', optimal_ordering=False)
    n_samples = current_Z.shape[0] + 1

    flat_cluster_ids_for_samples = cl.hierarchy.fcluster(current_Z, current_cluster_threshold, criterion='distance')
    c_sorted_list = list(main_clusters(flat_cluster_ids_for_samples))

    # Cluster k is drawn with "deep" colour k + 1; colour 0 is reserved for the
    # white background of the head map below.
    hex_dendrogram_link_colors = [
        mcolors.to_hex(rgb, keep_alpha=False)
        for rgb in list(sns.color_palette("deep", len(c_sorted_list) + 1))[1:]
    ]

    def custom_dendrogram_link_color_func(link_k):
        """Return the hex colour for dendrogram link `link_k` (a linkage node id)."""
        # Node ids >= n_samples refer to row (link_k - n_samples) of the linkage matrix.
        link_distance = current_Z[link_k - n_samples, 2]
        # fcluster(criterion='distance') keeps merges at height <= threshold inside
        # a flat cluster, so only strictly higher links are "above threshold".
        if link_distance > current_cluster_threshold:
            return hex_above_threshold_color
        # All leaves below such a link share one flat cluster; descend the left
        # children to any leaf to find which.
        node = link_k
        while node >= n_samples:
            node = int(current_Z[node - n_samples, 0])
        target_flat_cid = flat_cluster_ids_for_samples[node]
        if target_flat_cid in c_sorted_list:
            cmap_idx = c_sorted_list.index(target_flat_cid)
            if cmap_idx < len(hex_dendrogram_link_colors):
                return hex_dendrogram_link_colors[cmap_idx]
        return hex_above_threshold_color

    cl.hierarchy.dendrogram(
        current_Z,
        link_color_func=custom_dendrogram_link_color_func,
        ax=axs[0, t],
        above_threshold_color=hex_above_threshold_color
    )
    axs[0, t].set_xticks([])
    # Linkage heights are cosine distances; label the ticks as similarity = 1 - distance.
    axs[0, t].set_ylabel("Cosine similarity")
    axs[0, t].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    axs[0, t].set_yticklabels([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])

    # -1 marks heads that belong to no main cluster; others hold their cluster rank.
    headmap = np.full((n_layers, n_heads), -1, dtype=int)
    for c_map_idx, actual_cluster_id_val in enumerate(c_sorted_list):
        for i in np.flatnonzero(flat_cluster_ids_for_samples == actual_cluster_id_val):
            layer, ah_idx = gemma.ctrld_heads[signal_types[t]][i]
            headmap[layer, ah_idx] = c_map_idx

    if c_sorted_list:
        final_heatmap_cmap = list(sns.color_palette("deep", len(c_sorted_list)))
        final_heatmap_cmap[0] = (1.0, 1.0, 1.0)
        for c in range(n_clusters_appearing_dendogram[t] + 1, len(final_heatmap_cmap)):
            final_heatmap_cmap[c] = (1, 1, 1)  # clusters not visible in the dendrogram -> white
        # Pin the colour range to [-1, n_clusters - 1] so the mapping from cluster
        # rank to colour does not depend on which values happen to occur.
        heatmap_range = {"vmin": -1, "vmax": len(c_sorted_list) - 1}
    else:
        final_heatmap_cmap = None
        heatmap_range = {}

    sns.heatmap(headmap, ax=axs[1, t], cmap=final_heatmap_cmap, cbar=False, annot=False, **heatmap_range)
    axs[1, t].set_ylabel("Layer")
    axs[1, t].set_xlabel("Attention Head Index")

fig.tight_layout()
fig.savefig("figures/control_signals/gemma-2-2b_control-signals.pdf", bbox_inches='tight', dpi=800)
plt.close(fig)

# Which components add the control signals?

In [None]:
def get_proto_signals(mod, test_signals, cluster_threshold = 0.65):
    """Cluster one family of control signals and return one prototype per cluster.

    Args:
        mod: model context exposing ``ctrl_sigs[test_signals]`` (the signal
            vectors) and ``ctrld_heads[test_signals]`` (the ``(layer, head)``
            pair of each signal, in the same order).
        test_signals: key selecting the signal family, e.g. ``'u_signals'``.
        cluster_threshold: cosine-distance cut applied to the average-linkage tree.

    Returns:
        ``(proto_signals, head_map)``: ``proto_signals[c]`` is the element-wise
        mean of the signals in the c-th cluster (in ``main_clusters`` order), and
        ``head_map[layer, head] = c`` for every head in one of those clusters.

    Side effect: draws the dendrogram on the current matplotlib axes.
    """
    signals = mod.ctrl_sigs[test_signals]
    heads = mod.ctrld_heads[test_signals]
    linkage_matrix = cl.hierarchy.linkage(signals, 'average', 'cosine', optimal_ordering = False)
    cl.hierarchy.dendrogram(linkage_matrix, color_threshold = cluster_threshold)
    flat_ids = cl.hierarchy.fcluster(linkage_matrix, cluster_threshold, criterion = 'distance')
    ordered_clusters = main_clusters(flat_ids)

    head_map = {}
    prototypes = []
    for rank, cluster in enumerate(ordered_clusters):
        members = np.where(flat_ids == cluster)[0]
        for member in members:
            layer, head = heads[member]
            head_map[layer, head] = rank
        # NOTE: if each signal is itself an average, this is a mean of means.
        prototypes.append(np.array([signals[member] for member in members]).mean(axis = 0))
    return np.array(prototypes), head_map

## GPT-2

In [None]:
# Same 0.6 cut as the GPT-2 head-level figure.
gpt2s_v_proto_signals, gpt2s_v_head_map = get_proto_signals(mod=gpt2s, test_signals='v_signals', cluster_threshold=0.6)

In [None]:
# Same 0.6 cut as the GPT-2 head-level figure.
gpt2s_u_proto_signals, gpt2s_u_head_map = get_proto_signals(mod=gpt2s, test_signals='u_signals', cluster_threshold=0.6)

In [None]:
# # # Who is adding this?
# For each of the two main clusters, track the cosine similarity between the
# (normalised) cluster prototype and the residual stream of one prompt at every
# layer: before attention (pre), between attention and MLP (mid), after the
# block (post). u-signals are read at the "end" token, v-signals at "starts".

model_name = "gpt2-small"
prompt_id = 8
n_tokens = gpt2s.ioi_dataset.word_idx["end"][prompt_id].item() + 1
n_max_tokens = 21
markers_to_use = ['o', 's', '^']
marker_labels = ["Residual (pre)", "Residual (mid)", "Residual (post)"]
resid_hooks = ["hook_resid_pre", "hook_resid_mid", "hook_resid_post"]
n_hooks = len(resid_hooks)
n_layers = gpt2s.model.cfg.n_layers
# Cluster k uses "deep" colour k + 1, as in the dendrogram / head-map figures.
# Built here so this cell no longer depends on `final_heatmap_cmap` left over
# from whichever plotting cell ran last.
cluster_colors = list(sns.color_palette("deep", 3))[1:]
signal_setup = {
    "u-signals": (gpt2s_u_proto_signals, "end"),
    "v-signals": (gpt2s_v_proto_signals, "starts"),
}

for signal_type, (proto_signals, token_plot_gram) in signal_setup.items():
    token_plot = gpt2s.ioi_dataset.word_idx[token_plot_gram][prompt_id].item()
    # Row n_hooks * layer + k holds hook k of that layer; -1 fills positions
    # past the end of the prompt.
    ip_diffs = torch.zeros((2, n_layers * n_hooks, n_max_tokens)) - 1
    fig, ax = plt.subplots(figsize=(4, 1.9))

    for cid in range(2):
        control_signal = proto_signals[cid]
        control_signal = torch.from_numpy(control_signal / np.linalg.norm(control_signal))
        for layer in range(n_layers):
            for k, hook in enumerate(resid_hooks):
                # All token positions of this prompt at once.
                resid = gpt2s.cache[f"blocks.{layer}.{hook}"][prompt_id, :n_max_tokens]
                ip_diffs[cid, n_hooks * layer + k, :resid.shape[0]] = F.cosine_similarity(
                    resid, control_signal.unsqueeze(0), dim=-1).cpu()

        trace = ip_diffs[cid, :, token_plot]
        indices = np.arange(len(trace))
        ax.plot(indices, trace, linestyle='-', color='lightgray', alpha=0.7, label='_nolegend_')
        for k in range(n_hooks):
            ax.plot(indices[k::n_hooks], trace[k::n_hooks],
                    marker=markers_to_use[k],
                    linestyle='None',
                    label=marker_labels[k],
                    markersize=3,
                    color=cluster_colors[cid])

    # Custom legend: one line per cluster colour, one black marker per hook.
    legend_handles = [
        mlines.Line2D([], [], color=cluster_colors[c], linestyle='-', marker='None', label=f"Cluster {c}")
        for c in range(2)
    ] + [
        mlines.Line2D([], [], color="black", marker=marker, linestyle='None', label=label)
        for marker, label in zip(markers_to_use, marker_labels)
    ]
    ax.legend(handles=legend_handles, loc='lower center', bbox_to_anchor=(0.5, 1.02),
              title=None, ncol=3, fontsize=6)

    ax.set_xlabel('Layer')
    ax.set_ylabel('Cosine similarity')
    ax.set_xticks(range(0, n_layers * n_hooks, n_hooks))
    ax.set_xticklabels([f"Layer {layer_idx}" for layer_idx in range(n_layers)], rotation=90)
    ax.grid(True)
    filename = f"figures/control_signals/{model_name}_ioi_{signal_type}_{token_plot_gram}_pid-{prompt_id}.pdf"
    fig.savefig(filename, bbox_inches='tight', dpi=800)
    plt.close(fig)

## Pythia

In [None]:
# Same 0.7 cut as the Pythia head-level figure.
pyth_v_proto_signals, pyth_v_head_map = get_proto_signals(mod=pyth, test_signals='v_signals', cluster_threshold=0.7)

In [None]:
# Same 0.7 cut as the Pythia head-level figure.
pyth_u_proto_signals, pyth_u_head_map = get_proto_signals(mod=pyth, test_signals='u_signals', cluster_threshold=0.7)

In [None]:
# # # Who is adding this? Pythia
# For each of the two main clusters, track the cosine similarity between the
# (normalised) cluster prototype and the residual stream of one prompt before
# (pre) and after (post) every block. No "mid" point is plotted for Pythia.
# u-signals are read at the "end" token, v-signals at "starts".

model_name = "pythia-160m"
prompt_id = 8
# Fixed: previously read from gpt2s instead of the Pythia dataset.
n_tokens = pyth.ioi_dataset.word_idx["end"][prompt_id].item() + 1
n_max_tokens = 21
markers_to_use = ['o', 's']
marker_labels = ["Residual (pre)", "Residual (post)"]
resid_hooks = ["hook_resid_pre", "hook_resid_post"]
n_hooks = len(resid_hooks)
n_layers = pyth.model.cfg.n_layers
# Cluster k uses "deep" colour k + 1, as in the dendrogram / head-map figures.
# Built here so this cell no longer depends on `final_heatmap_cmap` left over
# from whichever plotting cell ran last.
cluster_colors = list(sns.color_palette("deep", 3))[1:]
signal_setup = {
    "u-signals": (pyth_u_proto_signals, "end"),
    "v-signals": (pyth_v_proto_signals, "starts"),
}

for signal_type, (proto_signals, token_plot_gram) in signal_setup.items():
    token_plot = pyth.ioi_dataset.word_idx[token_plot_gram][prompt_id].item()
    # Row n_hooks * layer + k holds hook k of that layer; -1 fills positions
    # past the end of the prompt.
    ip_diffs = torch.zeros((2, n_layers * n_hooks, n_max_tokens)) - 1
    fig, ax = plt.subplots(figsize=(4, 1.9))

    for cid in range(2):
        control_signal = proto_signals[cid]
        control_signal = torch.from_numpy(control_signal / np.linalg.norm(control_signal))
        for layer in range(n_layers):
            for k, hook in enumerate(resid_hooks):
                # All token positions of this prompt at once.
                resid = pyth.cache[f"blocks.{layer}.{hook}"][prompt_id, :n_max_tokens]
                ip_diffs[cid, n_hooks * layer + k, :resid.shape[0]] = F.cosine_similarity(
                    resid, control_signal.unsqueeze(0), dim=-1).cpu()

        trace = ip_diffs[cid, :, token_plot]
        indices = np.arange(len(trace))
        ax.plot(indices, trace, linestyle='-', color='lightgray', alpha=0.7, label='_nolegend_')
        for k in range(n_hooks):
            ax.plot(indices[k::n_hooks], trace[k::n_hooks],
                    marker=markers_to_use[k],
                    linestyle='None',
                    label=marker_labels[k],
                    markersize=3,
                    color=cluster_colors[cid])

    # Custom legend: one line per cluster colour, one black marker per hook.
    legend_handles = [
        mlines.Line2D([], [], color=cluster_colors[c], linestyle='-', marker='None', label=f"Cluster {c}")
        for c in range(2)
    ] + [
        mlines.Line2D([], [], color="black", marker=marker, linestyle='None', label=label)
        for marker, label in zip(markers_to_use, marker_labels)
    ]
    ax.legend(handles=legend_handles, loc='lower center', bbox_to_anchor=(0.5, 1.02),
              title=None, ncol=3, fontsize=6)

    ax.set_xlabel('Layer')
    ax.set_ylabel('Cosine similarity')
    ax.set_xticks(range(0, n_layers * n_hooks, n_hooks))
    ax.set_xticklabels([f"Layer {layer_idx}" for layer_idx in range(n_layers)], rotation=90)
    ax.grid(True)
    filename = f"figures/control_signals/{model_name}_ioi_{signal_type}_{token_plot_gram}_pid-{prompt_id}.pdf"
    fig.savefig(filename, bbox_inches='tight', dpi=800)
    plt.close(fig)

## Gemma-2 2B

In [None]:
# Same 0.75 cut as the Gemma head-level figure.
gemma_v_proto_signals, gemma_v_head_map = get_proto_signals(mod=gemma, test_signals='v_signals', cluster_threshold=0.75)

In [None]:
# Same 0.75 cut as the Gemma head-level figure.
gemma_u_proto_signals, gemma_u_head_map = get_proto_signals(mod=gemma, test_signals='u_signals', cluster_threshold=0.75)

In [None]:
import matplotlib
import matplotlib.lines as mlines  # Line2D proxy artists for custom legends
# Embed fonts as TrueType (Type 42) in PDF/PS output; 8pt base font size.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42, 'font.size': 8})

In [None]:
# # # Who is adding this?
# For each of the two main clusters, track the cosine similarity between the
# (normalised) cluster prototype and the residual stream of one prompt at every
# layer: before attention (pre), between attention and MLP (mid), after the
# block (post). u-signals are read at the "end" token, v-signals at "starts".

model_name = "gemma-2-2b"
prompt_id = 8
n_tokens = gemma.ioi_dataset.word_idx["end"][prompt_id].item() + 1
n_max_tokens = 21
markers_to_use = ['o', 's', '^']
marker_labels = ["Residual (pre)", "Residual (mid)", "Residual (post)"]
resid_hooks = ["hook_resid_pre", "hook_resid_mid", "hook_resid_post"]
n_hooks = len(resid_hooks)
n_layers = gemma.model.cfg.n_layers
# Cluster k uses "deep" colour k + 1, as in the dendrogram / head-map figures.
# Built here so this cell no longer depends on `final_heatmap_cmap` left over
# from whichever plotting cell ran last.
cluster_colors = list(sns.color_palette("deep", 3))[1:]
signal_setup = {
    "u-signals": (gemma_u_proto_signals, "end"),
    "v-signals": (gemma_v_proto_signals, "starts"),
}

for signal_type, (proto_signals, token_plot_gram) in signal_setup.items():
    token_plot = gemma.ioi_dataset.word_idx[token_plot_gram][prompt_id].item()
    # Row n_hooks * layer + k holds hook k of that layer; -1 fills positions
    # past the end of the prompt.
    ip_diffs = torch.zeros((2, n_layers * n_hooks, n_max_tokens)) - 1
    fig, ax = plt.subplots(figsize=(4, 1.9))

    for cid in range(2):
        control_signal = proto_signals[cid]
        control_signal = torch.from_numpy(control_signal / np.linalg.norm(control_signal))
        for layer in range(n_layers):
            for k, hook in enumerate(resid_hooks):
                # All token positions of this prompt at once.
                resid = gemma.cache[f"blocks.{layer}.{hook}"][prompt_id, :n_max_tokens]
                ip_diffs[cid, n_hooks * layer + k, :resid.shape[0]] = F.cosine_similarity(
                    resid, control_signal.unsqueeze(0), dim=-1).cpu()

        trace = ip_diffs[cid, :, token_plot]
        indices = np.arange(len(trace))
        ax.plot(indices, trace, linestyle='-', color='lightgray', alpha=0.7, label='_nolegend_')
        for k in range(n_hooks):
            ax.plot(indices[k::n_hooks], trace[k::n_hooks],
                    marker=markers_to_use[k],
                    linestyle='None',
                    label=marker_labels[k],
                    markersize=2.5,
                    color=cluster_colors[cid])

    # Custom legend: one line per cluster colour, one black marker per hook.
    legend_handles = [
        mlines.Line2D([], [], color=cluster_colors[c], linestyle='-', marker='None', label=f"Cluster {c}")
        for c in range(2)
    ] + [
        mlines.Line2D([], [], color="black", marker=marker, linestyle='None', label=label)
        for marker, label in zip(markers_to_use, marker_labels)
    ]
    ax.legend(handles=legend_handles, loc='lower center', bbox_to_anchor=(0.5, 1.02),
              title=None, ncol=3, fontsize=6)

    ax.set_xlabel('Layer')
    ax.set_ylabel('Cosine similarity')
    ax.set_xticks(range(0, n_layers * n_hooks, n_hooks))
    ax.set_xticklabels([f"Layer {layer_idx}" for layer_idx in range(n_layers)], rotation=90)
    ax.grid(True)
    filename = f"figures/control_signals/{model_name}_ioi_{signal_type}_{token_plot_gram}_pid-{prompt_id}.pdf"
    fig.savefig(filename, bbox_inches='tight', dpi=800)
    plt.close(fig)