In [1]:
import random
from random import sample

import pandas as pd
import plotly.express as px
import torch
from nnsight import LanguageModel
from tqdm.auto import tqdm

from src.utils import path_patch, plain_run

In [2]:
# Load Gemma-2-2B through nnsight, in bfloat16, placed on the GPU.
# NOTE(review): `.train()` switches the wrapped model to training mode (its return
# value is what the cell displays). Activation-patching runs normally use `.eval()`;
# confirm training mode is intentional here.
gemma = LanguageModel(
  "google/gemma-2-2b",
  device_map="cuda",
  torch_dtype=torch.bfloat16,
)
gemma.model.train()

Gemma2Model(
  (embed_tokens): Embedding(256000, 2304, padding_idx=0)
  (layers): ModuleList(
    (0-25): 26 x Gemma2DecoderLayer(
      (self_attn): Gemma2Attention(
        (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        (rotary_emb): Gemma2RotaryEmbedding()
      )
      (mlp): Gemma2MLP(
        (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
        (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
        (act_fn): PytorchGELUTanh()
      )
      (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
      (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), e

In [2]:
# One-shot prompt: a worked example, then a `{}` slot for the athlete's name.
sport_prompt = (
  "Fact: Tiger Woods plays the sport of golf.\n"
  "Fact: {} plays the sport of"
)
# Athlete table; the direct-patching loop reads its "Name", "Sport" and "Sport Token" columns.
athlete_data = pd.read_csv("../data/athlete.csv")
# Fail fast here instead of with a KeyError deep inside the patching loop.
_missing_columns = {"Name", "Sport", "Sport Token"} - set(athlete_data.columns)
assert not _missing_columns, f"athlete.csv is missing columns: {sorted(_missing_columns)}"

In [4]:
# Short alias for the tokenizer bundled with the nnsight LanguageModel.
tokenizer = getattr(gemma, "tokenizer")

# Direct Patching

In [None]:
# Direct patching: for every layer and sublayer ("attention" / "feedforward"), patch the
# sublayer's post-norm output at the last position from a source-athlete prompt into a
# target-athlete prompt (receiver: the final norm). For each pair, record the logit
# difference between the two athletes' sport tokens. Pairs are drawn with a seeded RNG,
# so a re-run reproduces the same comparisons.
N_LAYERS = 26          # the model repr lists 26 Gemma2DecoderLayer modules (0-25)
N_PAIRS_PER_SITE = 6   # random (source, target) athlete pairs per (layer, sublayer)
SPORTS = ["American football", "basketball", "baseball"]
SEED = 0

rng = random.Random(SEED)


def sport_token_id(athlete):
  """Return the vocabulary id of `athlete["Sport Token"]`, failing loudly if it is unknown."""
  token_id = tokenizer.convert_tokens_to_ids(athlete["Sport Token"])
  # convert_tokens_to_ids maps out-of-vocabulary tokens to the unk id instead of raising,
  # which would silently make every logit difference meaningless.
  assert token_id != tokenizer.unk_token_id, f"{athlete['Sport Token']!r} is not a vocabulary token"
  return token_id


def sport_logit_diff(logits, src_token_id, target_token_id):
  """Last-position logit of the source sport token minus that of the target sport token."""
  return logits[:, -1, src_token_id] - logits[:, -1, target_token_id]


direct_patch_records = []
for layer in tqdm(range(N_LAYERS), desc="layers"):
  for comp in ["attention", "feedforward"]:
    for _ in range(N_PAIRS_PER_SITE):
      src_sport, target_sport = rng.sample(SPORTS, 2)

      src_athlete = athlete_data[athlete_data["Sport"] == src_sport].sample(
        n=1, random_state=rng.randrange(2**32)).iloc[0]
      target_athlete = athlete_data[athlete_data["Sport"] == target_sport].sample(
        n=1, random_state=rng.randrange(2**32)).iloc[0]

      src_prompt = sport_prompt.format(src_athlete["Name"])
      target_prompt = sport_prompt.format(target_athlete["Name"])

      src_token_id = sport_token_id(src_athlete)
      target_token_id = sport_token_id(target_athlete)

      src_diff = sport_logit_diff(plain_run(gemma, src_prompt), src_token_id, target_token_id)
      patched_diff = sport_logit_diff(path_patch(
        gemma,
        [(f".model.layers.{layer}.post_{comp}_layernorm", -1)],
        [(".model.norm", -1)],
        src_prompt,
        target_prompt
      ), src_token_id, target_token_id)

      direct_patch_records.append({
        "layer": layer,
        "comp": comp,
        "src_athlete": src_athlete["Name"],
        "target_athlete": target_athlete["Name"],
        "src_sport": src_sport,
        "target_sport": target_sport,
        # assumes a single-prompt batch, i.e. each diff holds exactly one element — TODO confirm
        "src_diff": float(src_diff),
        "patched_diff": float(patched_diff),
      })

# Previously these values were computed and discarded every iteration; keep them for analysis.
direct_patch_results = pd.DataFrame(direct_patch_records)
direct_patch_results.head()

      

      


In [3]:
# Sanity check: patching layer-25 sublayer outputs from "goodbye" into "goodbye" should
# be a no-op, so the patched logits are expected to match a plain run (up to bfloat16 noise).
patched_logits = path_patch(
  gemma,
  [(".model.layers.25.post_attention_layernorm", -1),
   (".model.layers.25.post_feedforward_layernorm", -1)],
  # NOTE(review): the receiver is a bare module path here but a (path, position) tuple
  # in the direct-patching loop — confirm which form path_patch expects.
  [".model.norm"],
  "goodbye",
  "goodbye"
)

clean_logits = plain_run(gemma, "goodbye")

# Previously nothing was compared. Show the largest absolute deviation; ~0 means the check passes.
(patched_logits - clean_logits).abs().max()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.


In [4]:
# Probe run on a toy prompt, saving the last layer's post-attention RMSNorm output.
# torch.no_grad() must be entered BEFORE the trace. nnsight executes the model when the
# trace context exits, and `with a, b:` exits b before a. With the original order
# (`trace(...), no_grad()`), grad mode was already re-enabled by the time the forward pass ran.
with torch.no_grad(), gemma.trace("testing!"):
  post_attn_out = gemma.model.layers[25].post_attention_layernorm.output.save()

In [5]:
# Display the module repr of the post-attention RMSNorm in layer 25 (the last decoder layer).
getattr(gemma.model.layers[25], "post_attention_layernorm")

Gemma2RMSNorm((2304,), eps=1e-06)