In [1]:
try:
  # for google colab users
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
  # for local setup
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Imports for displaying vis in Colab / notebook
import webbrowser
import http.server
import socketserver
import threading
PORT = 8000

# general imports
import os
import torch
from tqdm import tqdm
import plotly.express as px

torch.set_grad_enabled(False);

In [2]:
def display_vis_inline(filename: str, height: int = 850):
    '''
    Displays the HTML files in Colab. Uses global `PORT` variable defined in prev cell, so that each
    vis has a unique port without having to define a port within the function.
    '''
    if not(COLAB):
        webbrowser.open(filename);

    else:
        global PORT

        def serve(directory):
            os.chdir(directory)

            # Create a handler for serving files
            handler = http.server.SimpleHTTPRequestHandler

            # Create a socket server with the handler
            with socketserver.TCPServer(("", PORT), handler) as httpd:
                print(f"Serving files from {directory} on port {PORT}")
                httpd.serve_forever()

        thread = threading.Thread(target=serve, args=("/content",))
        thread.start()

        output.serve_kernel_port_as_iframe(PORT, path=f"/{filename}", height=height, cache_in_notebook=True)

        PORT += 1

In [3]:
# package import
from torch import Tensor
from transformer_lens import utils
from functools import partial
from jaxtyping import Int, Float

device = ['cuda:1', 'cuda:2']
# device setup
# if torch.backends.mps.is_available():
#     device = "mps"
# else:
#     device = "cuda:2" if torch.cuda.is_available() else "cpu"

# print(f"Device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformer_lens import HookedTransformer
from sae_lens import SAE
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes

# Choose a layer you want to focus on
# For this tutorial, we're going to use layer 2
layer = 12

from transformers import AutoModelForCausalLM

    
local_model_0 = AutoModelForCausalLM.from_pretrained("/aifs4su/yaodong/models/gemma/gemma-2b-it")

# get model
model_0 = HookedTransformer.from_pretrained("gemma-2b-it", hf_model = local_model_0, device = device[0])

# get the SAE for this layer
sae_0, cfg_dict_0, _ = SAE.from_pretrained(
    release = "gemma-2b-it-res-jb",
    sae_id = f"blocks.{layer}.hook_resid_post",
    device = device[0]
)

# get hook point
hook_point_0 = sae_0.cfg.hook_name
print(hook_point_0)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


Loaded pretrained model gemma-2b-it into HookedTransformer
blocks.12.hook_resid_post


In [5]:
from transformers import AutoModelForCausalLM

    
local_model_1 = AutoModelForCausalLM.from_pretrained("/aifs4su/yaodong/projects/hantao/personal/models/output_gemma_it_sft0716")

# get model
model_1 = HookedTransformer.from_pretrained("gemma-2b-it", hf_model = local_model_1, device = device[1])

# get the SAE for this layer
sae_1, cfg_dict_1, _ = SAE.from_pretrained(
    release = "gemma-2b-it-res-jb",
    sae_id = f"blocks.{layer}.hook_resid_post",
    device = device[1]
)

# get hook point
hook_point_1 = sae_1.cfg.hook_name
print(hook_point_1)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.30it/s]


Loaded pretrained model gemma-2b-it into HookedTransformer
blocks.12.hook_resid_post


In [47]:
sv_prompt = " <start_of_turn>user\n Fuck you you asshole! \n<start_of_turn>model\n Don't say that."
# sv_prompt = "Fuck you, you asshole!"
# sv_prompt = "Where is the golden bridge?"

In [52]:
sv_logits_0, cache_0 = model_0.run_with_cache(sv_prompt, prepend_bos=True)
tokens_0 = model_0.to_tokens(sv_prompt)
print(tokens_0.shape[1])

# get the feature activations from our SAE
sv_feature_acts_0 = sae_0.encode(cache_0[hook_point_0])

# get sae_out
sae_out_0 = sae_0.decode(sv_feature_acts_0)

# print out the top activations, focus on the indices
topk1 = torch.topk(sv_feature_acts_0, 5)
print(torch.topk(sv_feature_acts_0, 5))

21
torch.return_types.topk(
values=tensor([[[35.2354, 35.2019, 34.1813, 34.1702, 33.1085],
         [ 5.2214,  1.9703,  1.9544,  1.0701,  1.0042],
         [ 4.7250,  2.1992,  1.8401,  1.6175,  1.3847],
         [ 6.5303,  2.2195,  2.1185,  1.0758,  1.0517],
         [ 4.3880,  2.2728,  0.9632,  0.7951,  0.6092],
         [ 8.7513,  5.4801,  4.0706,  2.1339,  2.0712],
         [ 4.9153,  4.3126,  3.7675,  2.2982,  1.7447],
         [ 4.8043,  3.1073,  2.0043,  1.9802,  1.7382],
         [ 5.2204,  3.4372,  2.7366,  2.4824,  2.3386],
         [ 5.8967,  1.8846,  1.5570,  1.5153,  1.5003],
         [ 5.2550,  2.5362,  2.0057,  1.8109,  1.7355],
         [ 2.6782,  2.3030,  2.2487,  1.5172,  1.1844],
         [ 4.5712,  2.4375,  2.2674,  1.5174,  1.3008],
         [ 6.0807,  4.5854,  3.9937,  2.5786,  1.9437],
         [ 5.6117,  3.5105,  2.0435,  1.6620,  1.4337],
         [ 7.3446,  5.2239,  1.4832,  1.2312,  1.2179],
         [ 3.3165,  2.2162,  1.9631,  1.4558,  1.2058],
         [ 7.

: 

In [49]:

sv_logits_1, cache_1 = model_1.run_with_cache(sv_prompt, prepend_bos=True)
tokens_1 = model_1.to_tokens(sv_prompt)
print(tokens_1)

# get the feature activations from our SAE
sv_feature_acts_1 = sae_1.encode(cache_1[hook_point_1])

# get sae_out
sae_out_1 = sae_1.decode(sv_feature_acts_1)

# print out the top activations, focus on the indices
topk2 = torch.topk(sv_feature_acts_1, 5)
print(torch.topk(sv_feature_acts_1, 5))

tensor([[     2, 235248,    106,   1645,    108,  48075,    692,    692,  73663,
         235341, 235248,    108,    106,   2516,    108,   4257, 235303, 235251,
           1931,    674, 235265]], device='cuda:2')
torch.return_types.topk(
values=tensor([[[30.0818, 30.0677, 29.2058, 29.1572, 28.3047],
         [ 4.2067,  2.3358,  1.5934,  1.3462,  1.0896],
         [ 4.2978,  1.8611,  1.5436,  1.2794,  0.9339],
         [ 7.0984,  1.8055,  1.4783,  1.4376,  0.9996],
         [ 3.6766,  0.9447,  0.9242,  0.6379,  0.5407],
         [ 8.5701,  6.5818,  4.1991,  1.8930,  1.7804],
         [ 5.6836,  3.2678,  3.1601,  1.9078,  1.5844],
         [ 5.9737,  4.3934,  2.0849,  1.9427,  1.8765],
         [ 5.8809,  3.6808,  3.2547,  3.0860,  2.0637],
         [ 7.4964,  2.1445,  1.5424,  1.4714,  1.4118],
         [ 5.4907,  2.7920,  2.2104,  1.7469,  1.2382],
         [ 2.9947,  1.6853,  1.4526,  1.1380,  1.1139],
         [ 3.9846,  2.1403,  1.5395,  1.1767,  0.9074],
         [ 9.0793,  2.2083

In [50]:
topk1_values_cpu = topk1.values.to('cpu')
topk1_indices_cpu = topk1.indices.to('cpu')
topk2_values_cpu = topk2.values.to('cpu')
topk2_indices_cpu = topk2.indices.to('cpu')

# Create new structures on the CPU
# topk1_cpu = torch.return_types.topk(values=topk1_values_cpu, indices=topk1_indices_cpu)
# topk2_cpu = torch.return_types.topk(values=topk2_values_cpu, indices=topk2_indices_cpu)

intersection_counts = []
value_changes = []

# for col in range(topk1_indices_cpu.size(1)):  # Assuming same number of columns
for col in range(15, 19):
    indices1 = topk1_indices_cpu[:,col,:]
    indices2 = topk2_indices_cpu[:,col,:]
    values1 = topk1_values_cpu[:,col,:]
    values2 = topk2_values_cpu[:,col,:]

    # Compute intersection manually
    mask = (indices1.unsqueeze(-1) == indices2.unsqueeze(-2)).any(-1)
    common_indices1 = indices1[mask]
    common_indices2 = indices2[mask]

    intersection_count = common_indices1.size(0)
    intersection_counts.append(intersection_count)

    # Compute value changes for intersected indices
    if intersection_count > 0:
        # Matching indices in both tensors
        matching_values1 = values1[mask]
        matching_values2 = values2[mask]

        # Calculate mean of absolute differences
        value_change = torch.abs(matching_values1 - matching_values2).mean() /  torch.abs(matching_values1 + matching_values2).mean()
    else:
        value_change = torch.tensor(0.0)  # If no intersection, set change to 0
    value_changes.append(value_change)

# Output results
intersection_counts = torch.tensor(intersection_counts)
value_changes = torch.tensor(value_changes)

print("Intersection Counts:", intersection_counts)
print("Average Value Changes:", value_changes)

Intersection Counts: tensor([2, 3, 3, 3])
Average Value Changes: tensor([0.1870, 0.0505, 0.2290, 0.0445])
