# Analyze the post-patching results in depth
We want to probe the model's activations after activation patching has been applied.

In [2]:

#%%
from IPython import get_ipython

ipython = get_ipython()
# Automatically reload edited modules (e.g. the TransformerLens code) without restarting the kernel.
# `InteractiveShell.magic` is deprecated, so the magics are invoked through `run_line_magic`.
# get_ipython() returns None outside an IPython session; skip the magics in that case
# so the file can still be executed as a plain script.
if ipython is not None:
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")
    
import plotly.io as pio
# pio.renderers.default = "png"
# Import stuff

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

from tqdm import tqdm
# from utils.probing_utils import ModelActs
from utils.dataset_utils import CounterFact_Dataset, TQA_MC_Dataset, EZ_Dataset

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

from utils.iti_utils import patch_iti

from utils.analytics_utils import plot_probe_accuracies, plot_norm_diffs, plot_cosine_sims
import os
from torch import Tensor
from plotly.subplots import make_subplots

from utils.analytics_utils import plot_z_probe_accuracies, plot_resid_probe_accuracies, acc_tensor_from_dict, get_px_fig

  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


## Format raw activations

Convert the raw activation files (one file per prompt) into formatted activation files in which each component's activations are collated.

Only has to be done once.

In [20]:
# Run / model configuration shared by the cells below.
# NOTE(review): `run_id` is reassigned by later cells (5 for the clean-acts cell, 10 for the
# patched-acts cell), so its value depends on which cell ran last.
run_id = 10
# N = 2550 #upper bound the global (level 0) index
d_head = 128
# n_layers * n_heads = 5120, which matches the 5120-step progress bars printed while loading "z" acts.
n_layers = 80
n_heads = 64
patch_id = 1
# num_params = "70b"

from utils.cache_utils import create_probe_dataset, create_all_probe_datasets, format_logits
act_type = "z"
data_dir = "/mnt/ssd-2/phillipguo3/patching_acts"
splits  =["azaria_mitchell_facts"]
# One-off formatting step (see the markdown above: only has to be done once), kept commented out.
# create_probe_dataset(run_id, -1, "honest", act_type, data_dir=data_dir, patch_id=patch_id, splits=splits)

from datasets import load_dataset
# NOTE(review): `dataset_name` holds the Hugging Face dataset id here, but later cells overwrite it
# with the split name "azaria_mitchell_facts".
dataset_name = "notrichardren/truthfulness_high_quality"
dataset = load_dataset(dataset_name)

# 'ind' values of the first 300 rows of the "combined" split whose 'dataset' field is azaria_mitchell_facts.
dataset_indices = [row['ind'] for row in dataset["combined"] if row['dataset'] == "azaria_mitchell_facts"][:300]

# One-off logits formatting step, kept commented out (note it hard-codes run_id=10).
# format_logits(dataset_indices, "azaria_mitchell_facts", f"{data_dir}/data/large_run_{run_id}_patch_{patch_id}", run_id=10)

## Initialize ModelActs objects. 

Use `ModelActsLargeSimple` when we want to store all activations in memory at once (memory-inefficient, but transfer accuracy is faster to compute). 

Use `ChunkedModelActs` when we don't want to keep activations in memory: activations are loaded in batches of layers, probes are trained iteratively, and the activations are then unloaded. Computing transfer accuracy is much slower.

In [11]:
from utils.new_probing_utils import ModelActsLargeSimple, ChunkedModelActs

# The two modes that key every activation dictionary below.
modes = ["honest", "liar"]

# One empty in-memory activation container per mode for the clean runs.
clean_acts = {mode: ModelActsLargeSimple() for mode in modes}

# Patched runs where the clean run is honest.
patched_acts_0 = {mode: ModelActsLargeSimple() for mode in modes}

# Patched runs where the clean run is liar.
patched_acts_1 = {mode: ModelActsLargeSimple() for mode in modes}


## Train probes on clean and patched runs

Technically, we don't need the `clean_acts` dictionary, since the patched acts for honest -> honest or liar -> liar are equivalent to the clean acts.

In [None]:
# Settings for the clean (unpatched) runs. This overrides the global run_id with 5.
seq_pos = -1
act_types = ["z", "logits"]
dataset_name = "azaria_mitchell_facts"
dont_include = None
run_id = 5
data_folder = "/mnt/ssd-2/jamescampbell4"

for mode in modes:
    for act_type in act_types:
        # File prefix: run_<run_id>_<mode>[_<seq_pos>]_<act_type>[_<dataset_name>]
        name_parts = [f"run_{run_id}_{mode}"]
        if seq_pos is not None:
            name_parts.append(f"{seq_pos}")
        name_parts.append(act_type)
        if dataset_name is not None:
            name_parts.append(dataset_name)
        file_prefix = f"{data_folder}/activations/formatted/" + "_".join(name_parts)

        # The labels file name always carries the "z" tag, whichever act_type is being loaded.
        labels_path = f"{data_folder}/activations/formatted/labels_{run_id}_{mode}_{seq_pos}_z_{dataset_name}.pt"
        with open(labels_path, "rb") as handle:
            labels = torch.load(handle)

        clean_acts[mode].load_acts(
            file_prefix,
            n_layers,
            n_heads=n_heads,
            labels=labels,
            exclude_points=dont_include,
            act_type=act_type,
        )

        # Logits are loaded but no probes are trained on them.
        if act_type == "logits":
            continue
        clean_acts[mode].train_probes(act_type, verbose=True, max_iter=10000)

print(f"Dataset Size: {labels.shape[0]}")

In [21]:
# Settings for the patched runs.
seq_pos = -1
act_types = ["z", "logits"]
dataset_name = "azaria_mitchell_facts"
dont_include = None
run_id = 10

# patch_id 0 fills patched_acts_0 and patch_id 1 fills patched_acts_1.
for patch_id, patched_acts in enumerate([patched_acts_0, patched_acts_1]):
    data_folder = f"/mnt/ssd-2/phillipguo3/patching_acts/data/large_run_{run_id}_patch_{patch_id}"

    for mode in modes:
        for act_type in act_types:
            # File prefix: run_<run_id>_<mode>[_<seq_pos>]_<act_type>[_<dataset_name>]
            name_parts = [f"run_{run_id}_{mode}"]
            if seq_pos is not None:
                name_parts.append(f"{seq_pos}")
            name_parts.append(act_type)
            if dataset_name is not None:
                name_parts.append(dataset_name)
            file_prefix = f"{data_folder}/activations/formatted/" + "_".join(name_parts)

            # The labels file name always carries the "z" tag, whichever act_type is being loaded.
            labels_path = f"{data_folder}/activations/formatted/labels_{run_id}_{mode}_{seq_pos}_z_{dataset_name}.pt"
            with open(labels_path, "rb") as handle:
                labels = torch.load(handle)

            patched_acts[mode].load_acts(
                file_prefix,
                n_layers,
                n_heads=n_heads,
                labels=labels,
                exclude_points=dont_include,
                act_type=act_type,
            )

            # Logits are loaded but no probes are trained on them.
            if act_type == "logits":
                continue
            patched_acts[mode].train_probes(act_type, verbose=True, max_iter=10000, test_ratio=.6)

    print(f"Dataset Size: {labels.shape[0]}")

100%|██████████| 5120/5120 [00:14<00:00, 356.47it/s]
100%|██████████| 5120/5120 [00:14<00:00, 352.60it/s]


Dataset Size: 300


100%|██████████| 5120/5120 [00:13<00:00, 365.74it/s]
100%|██████████| 5120/5120 [00:15<00:00, 341.27it/s]


Dataset Size: 300


In [22]:
# Load the tokenizer that get_inference_accuracy is called with in the next cell.
# Only LlamaTokenizer is used in this cell; LlamaModel / LlamaForCausalLM are imported but unused here.
from transformers import LlamaModel, LlamaForCausalLM, LlamaTokenizer
# The weights directory is resolved relative to the current working directory of the kernel.
weights_dir = f"{os.getcwd()}/llama-weights-70b"
checkpoint_location = weights_dir
tokenizer = LlamaTokenizer.from_pretrained(checkpoint_location)


In [25]:
# Print the mean correct / incorrect inference accuracy for every (patch source, mode) pair.
patch_sources = (
    ("Patched from honest", patched_acts_0),
    ("Patched from liar", patched_acts_1),
)

for banner, acts in patch_sources:
    print(banner)
    for mode in modes:
        correct_acc, incorrect_acc = acts[mode].get_inference_accuracy(tokenizer)
        print(f"{mode} model, {correct_acc.mean()=}, {incorrect_acc.mean()=}")

Patched from honest
honest model, correct_acc.mean()=0.5985384, incorrect_acc.mean()=0.050292112
liar model, correct_acc.mean()=0.43267593, incorrect_acc.mean()=0.17278887
Patched from liar
honest model, correct_acc.mean()=0.22346856, incorrect_acc.mean()=0.3449059
liar model, correct_acc.mean()=0.506502, incorrect_acc.mean()=0.30651087


In [None]:
import resource

# Peak resident set size of this kernel process so far.
# NOTE(review): the unit of ru_maxrss is platform-dependent (kilobytes on Linux) - confirm before comparing numbers.
usage = resource.getrusage(resource.RUSAGE_SELF)
print(usage.ru_maxrss)  # check memory usage

## Show probe accuracies with clean and corrupt runs

In [None]:
# Only "z" probes are trained by the loading cells, and their `for act_type in act_types` loops
# leave the global `act_type` at "logits". Name the probed activation type explicitly instead of
# reusing that stale global, so the plots and the title are labelled correctly.
probe_act_type = "z"

probe_accs_fig = make_subplots(rows=len(modes), cols=len(modes))

# Rows are the different clean runs: row r holds the acts loaded with patch_id == r
# (0: clean run is honest, 1: clean run is liar).
for row, patched_acts in enumerate([patched_acts_0, patched_acts_1]):
    for col, mode in enumerate(modes):  # columns are different corrupt runs
        # Pass this row's own patch id; the global `patch_id` is a leftover loop variable (always 1).
        px_fig = plot_z_probe_accuracies(patched_acts[mode], mode, probe_act_type, run_id, row)

        probe_accs_fig.add_trace(
            px_fig['data'][0],  # add the trace from plotly express figure
            row=row+1,
            col=col+1
        )

# x-axis titles go on the bottom row, y-axis titles on the first column.
for col, mode in enumerate(modes, start=1):
    probe_accs_fig.update_xaxes(title_text=f"Corrupt Run {mode}", row=len(modes), col=col)

for row, mode in enumerate(modes, start=1):
    probe_accs_fig.update_yaxes(title_text=f"Clean Run  {mode}", row=row, col=1)

probe_accs_fig.update_layout(title_text=f"{probe_act_type} Probe Accuracies After Patching, Dataset {dataset_name}", showlegend=False)
probe_accs_fig.show()

## Test transfer from clean to patched acts

Check how clean probes (both honest and liar) transfer to corrupt acts (first honest->liar, then liar->honest)

In [None]:
# Probes come from the un-corrupted combinations: patched_acts_0 has honest as the clean run and
# patched_acts_1 has liar as the clean run, so honest->honest and liar->liar stand in for clean acts.
train_acts = {"honest": patched_acts_0["honest"], "liar": patched_acts_1["liar"]}
# Test sets are the two cross-mode (corrupted) combinations.
test_acts = {"honest->liar": patched_acts_0["liar"], "liar->honest": patched_acts_1["honest"]}

from utils.analytics_utils import plot_transfer_acc_subplots

# Returns the transfer accuracy tensors and the subplot figure that the next cell titles and shows.
transfer_acc_tensors, fig = plot_transfer_acc_subplots(train_acts, test_acts, act_type="z")

In [None]:
# The transfer figure was computed with act_type="z". The global `act_type` is a leftover loop
# variable that may hold another value by now, so the title names the activation type explicitly.
transfer_act_type = "z"

# x-axis titles (bottom row): one per corrupted test set, in insertion order of `test_acts`.
for col, test_name in enumerate(test_acts.keys(), start=1):
    fig.update_xaxes(title_text=f"Tested on Corrupt Acts {test_name}", row=len(modes), col=col)

# y-axis titles (first column): one per mode the clean probes were trained on.
for row, mode in enumerate(modes, start=1):
    fig.update_yaxes(title_text=f"Clean Probes from {mode}", row=row, col=1)

fig.update_layout(title_text=f"Transfer {transfer_act_type} Probe Accuracies, Dataset {dataset_name}", height=1000)
fig.show()

# Cosine Similarities of probe weights between ModelActs

In [None]:
fig = make_subplots(rows=len(modes), cols=len(modes))
# NOTE(review): the loading cells above only train "z" probes (and load "logits"). Confirm that
# "mlp_out" probes exist on clean_acts, otherwise set this to "z".
act_type = "mlp_out"

# cosine_similarities[i, j] holds the similarities between probes of modes[i] and modes[j]:
# one value per (layer, head) for "z", one value per layer otherwise.
if act_type == "z":
    cosine_similarities = np.zeros(shape=(len(modes),len(modes), n_layers, n_heads))
else:
    cosine_similarities = np.zeros(shape=(len(modes),len(modes), n_layers))

# The original cell read from `elem_acts`, which is not defined anywhere in this notebook
# (NameError on a fresh kernel). `clean_acts` is the per-mode ModelActs dictionary built above.
# NOTE(review): presumably `elem_acts` was an earlier name for `clean_acts` - confirm.
for idx1, mode in enumerate(modes, start=1):
    for idx2, other_mode in enumerate(modes, start=1):
        cos_sims = {}
        for probe_index in tqdm(clean_acts["honest"].probes[act_type]):
            coefs_1 = clean_acts[mode].probes[act_type][probe_index].coef_.squeeze()
            coefs_2 = clean_acts[other_mode].probes[act_type][probe_index].coef_.squeeze()
            # Cosine similarity between the two probes' coefficient vectors.
            cos_sims[probe_index] = np.dot(coefs_1, coefs_2)/(np.linalg.norm(coefs_1)*np.linalg.norm(coefs_2))

        px_fig = get_px_fig(act_type, cos_sims, n_layers, n_heads, title = f"Cosine Similarities, Probes from {mode} with {other_mode}", graph_type="square")

        if act_type == "z":
            cosine_similarities[idx1-1, idx2-1] = acc_tensor_from_dict(cos_sims, n_layers, n_heads)
        else:
            cosine_similarities[idx1-1, idx2-1] = acc_tensor_from_dict(cos_sims, n_layers)

        fig.add_trace(
            px_fig['data'][0],  # add the trace from plotly express figure
            row=idx1,
            col=idx2
        )

# Axis titles follow the actual grid size. The original hard-coded range(1, 4) and row=3,
# which indexes modes[2] (IndexError with two modes) and targets a row the figure does not have.
for idx1 in range(1, len(modes)+1):
    fig.update_xaxes(title_text=f"Tested on {modes[idx1-1]}", row=len(modes), col=idx1)

for idx2 in range(1, len(modes)+1):
    fig.update_yaxes(title_text=f"Trained on {modes[idx2-1]}", row=idx2, col=1)

fig.update_layout(title_text=f"Cosine Similarities of {act_type} Probe Coefficients", height=1000)
fig.show()

In [None]:
import matplotlib.pyplot as plt

# Percentile ranks summarised for each mode pair; one bar (and one entry of bar_colors) per rank.
percentiles = [99.8, 99, 95, 50]

def get_probe_percentiles(accs, percentiles):
    """Return the value of `accs` at each requested percentile rank.

    Args:
        accs: array-like of scores, passed unchanged to np.percentile.
        percentiles: iterable of percentile ranks (0-100).

    Returns:
        A list with one np.percentile(accs, rank) result per rank, in the order given.
    """
    return [np.percentile(accs, rank) for rank in percentiles]

# Horizontal bar chart: for every (mode, other_mode) pair, one bar per percentile of the
# pair's cosine similarities.
plt.figure(figsize=(15, 15))

bar_width = 0.05
bar_colors = ['r', 'g', 'b', 'y']  # colors for different percentiles

bar_positions, bar_heights = [], []
tick_positions, tick_labels = [], []

for row_idx, mode in enumerate(modes, start=1):
    for col_idx, other_mode in enumerate(modes, start=1):
        # Every pair gets its own band on the y-axis; bars within a band are 1/20 apart.
        group_base = row_idx + col_idx/3.5
        pair_values = get_probe_percentiles(cosine_similarities[row_idx-1, col_idx-1], percentiles)

        bar_positions.extend(group_base + k/20 for k in range(len(pair_values)))
        bar_heights.extend(pair_values)

        # One tick per pair, offset into its band of bars.
        tick_positions.append(group_base + 0.015*len(percentiles))
        tick_labels.append(f'{mode}-{other_mode}')

# Plot horizontal bar chart
plt.barh(bar_positions, bar_heights, color=bar_colors, height=bar_width)
plt.yticks(tick_positions, tick_labels)

# Legend: one colored patch per percentile rank.
legend_handles = [plt.Rectangle((0,0),1,1, color=patch_color) for patch_color in bar_colors[:len(percentiles)]]
plt.legend(legend_handles, [f'{p}th Percentile' for p in percentiles])

plt.title("Percentiles for Cosine Similarities")
plt.xlabel("Cosine Sim")
plt.show()

More Analytics: Bar Chart of probe accuracies at different percentiles, and graph plotting average probe accuracy vs layer

In [None]:
def get_layer_accs(acc_tensor, percentile=50):
    """Compute one percentile value per layer.

    Args:
        acc_tensor: array indexed by layer along its first axis.
        percentile: percentile rank (0-100) taken over each layer's entries; 50 is the median.

    Returns:
        np.ndarray with one entry per layer (length acc_tensor.shape[0]).
    """
    n_rows = acc_tensor.shape[0]
    return np.array([np.percentile(acc_tensor[layer], percentile) for layer in range(n_rows)])

# Line style encodes the first mode of each pair (mode1).
line_styles = ['-', '--', '-.']

# Color encodes the second mode of each pair (mode2).
colors = ['b', 'g', 'r']

plt.figure(figsize=(15, 8))

for idx1, mode1 in enumerate(modes):
    pair_style = line_styles[idx1 % len(line_styles)]
    for idx2, mode2 in enumerate(modes):
        pair_color = colors[idx2 % len(colors)]
        # percentile=50, so each point is the per-layer median of the transfer accuracies.
        # NOTE(review): the axis label and title below say "Average" - confirm the intended statistic.
        avg_acc = get_layer_accs(transfer_acc_tensors[idx1, idx2], percentile=50)
        print(avg_acc.shape)
        plt.plot(avg_acc, linestyle=pair_style, color=pair_color, label=f'{mode1}-{mode2}')

plt.xlabel('Layer')
plt.ylabel('Average Accuracy')
plt.title('Average Accuracy for Each Layer')

from matplotlib.lines import Line2D  # for creating custom legend

# Custom legend: black lines explain the line styles, colored solid lines explain the colors.
style_handles = [Line2D([0], [0], color='k', linestyle=style, label=mode) for style, mode in zip(line_styles, modes)]
color_handles = [Line2D([0], [0], color=color, linestyle='-', label=mode) for color, mode in zip(colors, modes)]
legend_elements = style_handles + color_handles

plt.legend(handles=legend_elements)

plt.show()