In this notebook we'll be using transformerlens and runpods to generate outputs 

I'll do this under head ablation

runpods questions 

do I have a credit card number to use for runpods?

how do I make sure runpods start when I want them and stop when im finished?

In [1]:
from transformer_lens import HookedTransformer
import transformer_lens
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import transformer_lens.utils as utils
import hashlib
import yaml 
import hashlib
import pickle
import numpy as np
import matplotlib.pyplot as plt 

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
import pickle

# Save object to a file
def save_pickle(obj, filepath):
    with open(filepath, 'wb') as f:
        pickle.dump(obj, f)

# Load object from a file
def load_pickle(filepath):
    with open(filepath, 'rb') as f:
        return pickle.load(f)
    

In [15]:
device = 'cuda:0'

In [16]:
reference_model_path = 'meta-llama/Llama-3.1-8B'
baseline_model_path = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

In [6]:
baseline_model_hf = AutoModelForCausalLM.from_pretrained(baseline_model_path, torch_dtype=torch.bfloat16)
baseline_model_tokenizer = AutoTokenizer.from_pretrained(baseline_model_path)

model = HookedTransformer.from_pretrained_no_processing(
    reference_model_path,
    hf_model=baseline_model_hf,
    tokenizer=baseline_model_tokenizer,
    device=device,
    move_to_device=True,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.03s/it]


Loaded pretrained model meta-llama/Llama-3.1-8B into HookedTransformer


In [7]:
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out)

In [8]:
def produce_head_ablation(head_idx):
    def head_ablation_hook(
        value,
        hook
    ):
        # print(f"Shape of the value tensor: {value.shape}")
        value[:, :, head_idx, :] = 0.
        return value

    return head_ablation_hook


def generate_using_ablated_head(model, init_strings, ablate_heads=None, max_new_tokens=300, format_init_strings=True, save_as_you_go=None):
    """ 
    
    """
    if format_init_strings:
        init_strings=[f"<｜User｜>{string}<｜Assistant｜><think>\n" for string in init_strings]
    
    # apply ablations
    model.reset_hooks()
    if ablate_heads is not None:
        for layer_idx, head_idx in ablate_heads:
            cur_head_ablation = produce_head_ablation(head_idx)
            model.add_hook(utils.get_act_name("z", layer_idx), cur_head_ablation)

        print("ablate heads:")
        for head in ablate_heads:
            print(ablate_heads)
    
    if save_as_you_go is not None:
        assert isinstance(save_as_you_go, str)
        if '.pkl' != save_as_you_go[-4:]:
            save_as_you_go = save_as_you_go+".pkl"

    # now generate 
    # these are the default settings according 
    # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B
    temp=0.6 
    top_p=0.95 

    all_generated_outs = []
    for init_string in init_strings:
        test_tokens = model.to_tokens(init_string)
        generated_outputs = model.generate(test_tokens, max_new_tokens=max_new_tokens, temperature=temp, top_p=top_p)

        all_generated_outs.append(generated_outputs)

        if save_as_you_go is not None:
            save_pickle(all_generated_outs, save_as_you_go)
    
    return all_generated_outs


load the collected think dict and ablate groups of important think heads 

In [9]:
with open('collected_think_dict.yaml', 'r') as file:
    think_dict = yaml.safe_load(file)

# ranking head importance 

In [10]:
head_str, mean_think_stop_think_ratios = zip(
    *sorted(
        [(k, think_dict[k]['mean_think_stop_think_ratio']) for k in think_dict],
        key = lambda x: x[-1],
        reverse=True
    )
)

In [None]:
for i in range(10):
    print(head_str[i], mean_think_stop_think_ratios[i])

0_22 8.45154857635498
3_1 6.1095075607299805
0_21 4.9486775398254395
4_21 4.679234027862549
8_18 4.370467185974121
1_23 4.120393753051758
1_12 3.7947165966033936
1_22 3.7929306030273438
16_20 3.6640069484710693
7_21 3.477773666381836


We'll take the top 4 important heads and ablate them one by one 

In [17]:
top_4_targets = [tuple([int(x) for x in h.split('_')]) for h in head_str[:4]]

In [18]:
top_4_targets

[(0, 22), (3, 1), (0, 21), (4, 21)]

In [19]:

def read_prompts(file_path="reasoning-prompts.md"):
    """
    Read prompts from a text file and return them as a list of strings.
    
    Args:
        file_path (str): Path to the file containing reasoning prompts
        
    Returns:
        list: A list of strings, each containing a reasoning prompt
    """
    prompts = []
    
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                # Skip empty lines, headers, and category labels
                line = line.strip()
                if (line and 
                    not line.startswith('#') and 
                    not line.startswith('##') and
                    not line == ""):
                    
                    # Extract the prompt text by removing the number and period
                    parts = line.split('. ', 1)
                    if len(parts) > 1 and parts[0].isdigit():
                        prompt = parts[1]
                    else:
                        prompt = line
                        
                    prompts.append(prompt)
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
    except Exception as e:
        print(f"Error reading prompts: {e}")
    
    return prompts

In [20]:
sampled_prompts = read_prompts()[:50]

In [15]:
collected_outputs = {}

for i in range(4):
    save_string = f"top_{i+1}_heads_ablated.pkl"

    generated_outputs = generate_using_ablated_head(
        model, 
        init_strings=sampled_prompts, 
        ablate_heads=top_4_targets[:i+1], 
        max_new_tokens=300, 
        format_init_strings=True, 
        save_as_you_go=save_string
    )

    collected_outputs[save_string]=generated_outputs

ablate heads:
[(0, 22)]


  0%|          | 0/300 [00:00<?, ?it/s]

 78%|███████▊  | 233/300 [00:17<00:05, 13.27it/s]
 58%|█████▊    | 174/300 [00:12<00:09, 13.58it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
 52%|█████▏    | 156/300 [00:11<00:10, 13.57it/s]
 54%|█████▍    | 163/300 [00:12<00:10, 13.49it/s]
 60%|██████    | 181/300 [00:13<00:08, 13.57it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
100%|██████████| 300/300 [00:22<00:00, 13.62it/s]
 79%|███████▊  | 236/300 [00:17<00:04, 13.57it/s]
 77%|███████▋  | 230/300 [00:16<00:05, 13.56it/s]
 79%|███████▊  | 236/300 [00:17<00:04, 13.57it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 81%|████████  | 243/300 [00:17<00:04, 13.57it/s]
 82%|████████▏ | 247/300 [00:18<00:03, 13.58it/s]
100%|██████████| 300/300 [00:22<00:00, 13.59it/s]
 65%|██████▌   | 195/300 [00:14<00:07, 13.57it/s]
 72%|███████▏  | 217/300 [00:16<00:06, 13.56it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 82%|████████▏ | 245/300 [00:18<00:04, 13.51it/s]


ablate heads:
[(0, 22), (3, 1)]
[(0, 22), (3, 1)]


 93%|█████████▎| 280/300 [00:20<00:01, 13.59it/s]
 62%|██████▏   | 185/300 [00:13<00:08, 13.58it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 55%|█████▍    | 164/300 [00:12<00:10, 13.57it/s]
 89%|████████▉ | 268/300 [00:19<00:02, 13.56it/s]
 63%|██████▎   | 190/300 [00:13<00:08, 13.59it/s]
 91%|█████████ | 273/300 [00:20<00:01, 13.57it/s]
 90%|█████████ | 271/300 [00:19<00:02, 13.60it/s]
 80%|████████  | 240/300 [00:17<00:04, 13.58it/s]
 77%|███████▋  | 230/300 [00:16<00:05, 13.58it/s]
 88%|████████▊ | 265/300 [00:19<00:02, 13.57it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
 75%|███████▌  | 225/300 [00:16<00:05, 13.59it/s]
 87%|████████▋ | 260/300 [00:19<00:02, 13.57it/s]
100%|██████████| 300/300 [00:22<00:00, 13.63it/s]
 58%|█████▊    | 174/300 [00:12<00:09, 13.64it/s]
 83%|████████▎ | 248/300 [00:18<00:03, 13.61it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
100%|██████████| 300/300 [00:22<00:00, 13.62it/s]
 81%|████████  | 243/300 [00:17<00:04, 13.57it/s]


ablate heads:
[(0, 22), (3, 1), (0, 21)]
[(0, 22), (3, 1), (0, 21)]
[(0, 22), (3, 1), (0, 21)]


 84%|████████▍ | 252/300 [00:18<00:03, 13.55it/s]
 62%|██████▏   | 187/300 [00:13<00:08, 13.56it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 61%|██████    | 183/300 [00:13<00:08, 13.57it/s]
 72%|███████▏  | 215/300 [00:15<00:06, 13.55it/s]
 54%|█████▍    | 162/300 [00:11<00:10, 13.57it/s]
 82%|████████▏ | 245/300 [00:18<00:04, 13.55it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 83%|████████▎ | 248/300 [00:18<00:03, 13.59it/s]
 93%|█████████▎| 279/300 [00:20<00:01, 13.56it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 72%|███████▏  | 215/300 [00:15<00:06, 13.57it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
 66%|██████▌   | 198/300 [00:14<00:07, 13.59it/s]
 89%|████████▊ | 266/300 [00:19<00:02, 13.57it/s]
 95%|█████████▍| 284/300 [00:20<00:01, 13.57it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 82%|████████▏ | 246/300 [00:18<00:03, 13.58it/s]


ablate heads:
[(0, 22), (3, 1), (0, 21), (4, 21)]
[(0, 22), (3, 1), (0, 21), (4, 21)]
[(0, 22), (3, 1), (0, 21), (4, 21)]
[(0, 22), (3, 1), (0, 21), (4, 21)]


100%|██████████| 300/300 [00:21<00:00, 13.64it/s]
 57%|█████▋    | 170/300 [00:12<00:09, 13.59it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
 51%|█████     | 153/300 [00:11<00:10, 13.60it/s]
 74%|███████▍  | 222/300 [00:16<00:05, 13.58it/s]
 60%|██████    | 180/300 [00:13<00:08, 13.59it/s]
 97%|█████████▋| 292/300 [00:21<00:00, 13.56it/s]
100%|██████████| 300/300 [00:22<00:00, 13.61it/s]
 99%|█████████▉| 298/300 [00:21<00:00, 13.58it/s]
 81%|████████  | 242/300 [00:17<00:04, 13.58it/s]
100%|██████████| 300/300 [00:22<00:00, 13.60it/s]
 99%|█████████▉| 298/300 [00:21<00:00, 13.57it/s]
 76%|███████▋  | 229/300 [00:16<00:05, 13.55it/s]
 88%|████████▊ | 264/300 [00:19<00:02, 13.50it/s]
100%|██████████| 300/300 [00:22<00:00, 13.54it/s]
 61%|██████    | 182/300 [00:13<00:08, 13.51it/s]
 83%|████████▎ | 250/300 [00:18<00:03, 13.52it/s]
 92%|█████████▏| 276/300 [00:20<00:01, 13.52it/s]
100%|██████████| 300/300 [00:22<00:00, 13.51it/s]
 77%|███████▋  | 232/300 [00:17<00:05, 13.54it/s]


In [44]:
generated_outputs = generate_using_ablated_head(
    model, 
    init_strings=sampled_prompts, 
    ablate_heads=None, 
    max_new_tokens=300, 
    format_init_strings=True, 
    save_as_you_go="no_heads_ablated.pkl"
) 

no_heads_ablated = load_pickle("no_heads_ablated.pkl") 

no_heads_ablated_string = [
    model.to_string(no_heads_ablated_tensor) for no_heads_ablated_tensor 
    in no_heads_ablated
]

save_pickle(no_heads_ablated_string, "no_heads_ablated_string.pkl")

  0%|          | 0/300 [00:00<?, ?it/s]

100%|██████████| 300/300 [00:22<00:00, 13.49it/s]
 62%|██████▏   | 185/300 [00:13<00:08, 13.49it/s]
100%|██████████| 300/300 [00:22<00:00, 13.48it/s]
 54%|█████▍    | 163/300 [00:12<00:10, 13.47it/s]
 61%|██████▏   | 184/300 [00:13<00:08, 13.49it/s]
 55%|█████▍    | 164/300 [00:12<00:10, 13.47it/s]
 93%|█████████▎| 279/300 [00:20<00:01, 13.47it/s]
100%|██████████| 300/300 [00:22<00:00, 13.52it/s]
 71%|███████▏  | 214/300 [00:15<00:06, 13.48it/s]
100%|██████████| 300/300 [00:22<00:00, 13.54it/s]
 83%|████████▎ | 248/300 [00:18<00:03, 13.46it/s]
100%|██████████| 300/300 [00:22<00:00, 13.50it/s]
 67%|██████▋   | 202/300 [00:14<00:07, 13.55it/s]
 95%|█████████▌| 286/300 [00:21<00:01, 13.42it/s]
100%|██████████| 300/300 [00:22<00:00, 13.51it/s]
 62%|██████▏   | 187/300 [00:13<00:08, 13.53it/s]
100%|██████████| 300/300 [00:22<00:00, 13.51it/s]
100%|██████████| 300/300 [00:22<00:00, 13.50it/s]
 90%|█████████ | 271/300 [00:20<00:02, 13.50it/s]
 77%|███████▋  | 231/300 [00:17<00:05, 13.48it/s]


In [38]:
top_further_targets = [tuple([int(x) for x in h.split('_')]) for h in head_str[4:]]

# remove those heads beyond layer 4

top_further_targets = [(l, h) for l, h in top_further_targets if l<5][:16]

total_top_target_heads = top_4_targets+top_further_targets

In [39]:
total_top_target_heads 

[(0, 22),
 (3, 1),
 (0, 21),
 (4, 21),
 (1, 23),
 (1, 12),
 (1, 22),
 (0, 20),
 (3, 25),
 (1, 7),
 (3, 22),
 (1, 2),
 (4, 30),
 (0, 3),
 (2, 6),
 (2, 23),
 (4, 28),
 (4, 3),
 (0, 28),
 (4, 10)]

In [40]:
print(total_top_target_heads)

[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 30), (0, 3), (2, 6), (2, 23), (4, 28), (4, 3), (0, 28), (4, 10)]


In [41]:
# further removed heads 

save_string = f"further_top_{len(total_top_target_heads)}_before_l5_heads_ablated.pkl"

generated_outputs = generate_using_ablated_head(
    model, 
    init_strings=sampled_prompts, 
    ablate_heads=total_top_target_heads, 
    max_new_tokens=300, 
    format_init_strings=True, 
    save_as_you_go=save_string
)


ablate heads:
[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 30), (0, 3), (2, 6), (2, 23), (4, 28), (4, 3), (0, 28), (4, 10)]
[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 30), (0, 3), (2, 6), (2, 23), (4, 28), (4, 3), (0, 28), (4, 10)]
[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 30), (0, 3), (2, 6), (2, 23), (4, 28), (4, 3), (0, 28), (4, 10)]
[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 30), (0, 3), (2, 6), (2, 23), (4, 28), (4, 3), (0, 28), (4, 10)]
[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 30), (0, 3), (2, 6), (2, 23), (4, 28), (4, 3), (0, 28), (4, 10)]
[(0, 22), (3, 1), (0, 21), (4, 21), (1, 23), (1, 12), (1, 22), (0, 20), (3, 25), (1, 7), (3, 22), (1, 2), (4, 3

100%|██████████| 300/300 [00:22<00:00, 13.46it/s]
 58%|█████▊    | 173/300 [00:12<00:09, 13.45it/s]
100%|██████████| 300/300 [00:22<00:00, 13.51it/s]
 65%|██████▍   | 194/300 [00:14<00:07, 13.44it/s]
 64%|██████▍   | 193/300 [00:14<00:07, 13.46it/s]
 50%|█████     | 150/300 [00:11<00:11, 13.45it/s]
100%|██████████| 300/300 [00:22<00:00, 13.48it/s]
100%|██████████| 300/300 [00:22<00:00, 13.48it/s]
 94%|█████████▍| 282/300 [00:21<00:01, 13.43it/s]
 70%|███████   | 210/300 [00:15<00:06, 13.48it/s]
 97%|█████████▋| 291/300 [00:21<00:00, 13.44it/s]
100%|█████████▉| 299/300 [00:22<00:00, 13.43it/s]
 61%|██████    | 182/300 [00:13<00:08, 13.48it/s]
 76%|███████▌  | 228/300 [00:16<00:05, 13.45it/s]
100%|██████████| 300/300 [00:22<00:00, 13.48it/s]
 88%|████████▊ | 264/300 [00:19<00:02, 13.47it/s]
 92%|█████████▏| 275/300 [00:20<00:01, 13.43it/s]
 86%|████████▋ | 259/300 [00:19<00:03, 13.45it/s]
 90%|████████▉ | 269/300 [00:19<00:02, 13.48it/s]
 93%|█████████▎| 278/300 [00:20<00:01, 13.42it/s]


In [42]:
further_top_20_before_l5_heads_ablated = load_pickle("further_top_20_before_l5_heads_ablated.pkl") 

further_top_20_before_l5_heads_ablated_string = [
    model.to_string(further_top_20_before_l5_heads_ablated_tensor) for further_top_20_before_l5_heads_ablated_tensor 
    in further_top_20_before_l5_heads_ablated
]

save_pickle(further_top_20_before_l5_heads_ablated_string, "further_top_20_before_l5_heads_ablated_string.pkl")

In [43]:
further_top_20_before_l5_heads_ablated_string

[["<｜begin▁of▁sentence｜><｜User｜>What is the square root of 144?<｜Assistant｜><think>\nTo find the square root of 144, I will start by understanding what a square root means. The square root of a number is another number that, when multiplied by itself, gives the original number.\n\nNext, I will consider the factors of 144. I know that 12 times 12 equals 144, so 12 is a factor. Since both 12 and 12 are the same, 12 is the principal square root of 144.\n\nTherefore, the square root of 144 is 12.\n</think>\n\n**Solution:**\n\nTo find the square root of 144, follow these steps:\n\n1. **Understand the Definition of a Square Root:**\n   \n   The square root of a number \\( n \\) is a number \\( m \\) such that:\n   \\[\n   m \\times m = n\n   \\]\n   In this case, we are looking for \\( m \\) where:\n   \\[\n   m \\times m = 144\n   \\]\n\n2. **Find the Factors of 144:**\n   \n   We need to find two equal numbers that multiply to give 144. Let's list the factors of 144:\n   \n   \\[\n   1 \\t

# looking at results

In [31]:
top_1_heads_ablated = load_pickle("top_1_heads_ablated.pkl")
top_2_heads_ablated = load_pickle("top_2_heads_ablated.pkl")
top_3_heads_ablated = load_pickle("top_3_heads_ablated.pkl")
top_4_heads_ablated = load_pickle("top_4_heads_ablated.pkl")

convert to text

In [32]:
top_1_heads_ablated_string = [
    model.to_string(top_1_heads_ablated_tensor) for top_1_heads_ablated_tensor 
    in top_1_heads_ablated
]

save_pickle(top_1_heads_ablated_string, "top_1_heads_ablated_string.pkl")

######################

top_2_heads_ablated_string = [
    model.to_string(top_2_heads_ablated_tensor) for top_2_heads_ablated_tensor 
    in top_2_heads_ablated
]

save_pickle(top_2_heads_ablated_string, "top_2_heads_ablated_string.pkl")

######################

top_3_heads_ablated_string = [
    model.to_string(top_3_heads_ablated_tensor) for top_3_heads_ablated_tensor 
    in top_3_heads_ablated
]

save_pickle(top_3_heads_ablated_string, "top_3_heads_ablated_string.pkl")

######################

top_4_heads_ablated_string = [
    model.to_string(top_4_heads_ablated_tensor) for top_4_heads_ablated_tensor 
    in top_4_heads_ablated
]

save_pickle(top_4_heads_ablated_string, "top_4_heads_ablated_string.pkl")

# Comparing outputs 

Let's go through and see whether there is any trend here

In [45]:
! pip install rich

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [46]:
from rich import print as rprint
from rich.pretty import pprint

In [48]:
all_string_outputs = []

for i in range(-49, -5, 1):
    print("\n\n\n")
    print("original")
    pprint(no_heads_ablated_string[i])
    print("1 -----------------------")
    pprint(top_1_heads_ablated_string[i])
    print("2 -----------------------")
    pprint(top_2_heads_ablated_string[i])
    print("3 -----------------------")
    pprint(top_3_heads_ablated_string[i])
    print("4 -----------------------")
    pprint(top_4_heads_ablated_string[i])
    print("top20 -----------------------")
    pprint(further_top_20_before_l5_heads_ablated_string[i])
    print("\n\n\n")






original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------










original


1 -----------------------


2 -----------------------


3 -----------------------


4 -----------------------


top20 -----------------------








In [33]:
top_4_heads_ablated_string

[["<｜begin▁of▁sentence｜><｜User｜>What is the square root of 144?<｜Assistant｜><think>\nFirst, I recognize that 144 is a perfect square because its square root is a whole number.\n\nI will list the multiplication tables of numbers to identify which two numbers multiply to 144.\n\nStarting from 1, I multiply sequentially:\n1 × 144 = 144\n2 × 72 = 144\n3 × 48 = 144\n4 × 36 = 144\n5 × 24 = 144\n6 × 24 = 144\n7 × 21 = 144\n8 × 18 = 144\n9 × 16 = 144\n10 × 14.4 = 144\n\nNext, I check if any of these pairs are reciprocals of each other. For example, 12 and 12 multiply to 144, and 12 is its own reciprocal.\n\nSince 12 and 12 are reciprocals, the square root of 144 is 12.\n</think>\n\nTo find the square root of 144, we need to determine which number multiplied by itself equals 144.\n\nLet's explore the possibilities:\n\n1. **List Multiplication Pairs:**\n   - \\(1 \\times 144 = 144\\)\n   - \\(2 \\times 72 = 144\\)\n   - \\(3 \\times 48 = 144\\)\n   - \\(4 \\times 36 = 144\\)\n   - \\(5 \\times 2

In [None]:
tokens = model.to_tokens("Hey whats up?")

In [None]:
tokens

tensor([[128000,  19182,  41209,    709,     30]], device='cuda:0')

In [None]:

model.reset_hooks()
original_logits = model(tokens, return_type="logits")

test_tokens = model.to_tokens(["Hey whats up?"])

# try generating 
outs = model.generate(test_tokens, max_new_tokens=400, temperature=0.6, top_p=0.95)

100%|██████████| 400/400 [00:29<00:00, 13.60it/s]


In [None]:
model.to_string(outs)

["<｜begin▁of▁sentence｜>Hey whats up? So I just saw this video where someone was trying to make a cake from scratch, but it didn't turn out so well. It was kinda funny but also kind of sad because the person put so much effort into it. I wonder why it didn't work out. Maybe they didn't measure the ingredients properly? Or did they forget to add baking powder? I mean, I've had my fair share of baking failures before, so I can relate. Let me think about how to approach this.\n\nFirst off, when I try baking, I usually just throw in whatever I have, but that's probably why my cakes sometimes end up too dense or too dry. Maybe I should double-check the recipes next time. But wait, the person in the video was following a recipe, right? So maybe the issue was with the measurements. Did they use the right amounts of flour, sugar, butter, etc.? It's easy to get confused with cup measurements versus grams or milliliters, especially if you're not used to it.\n\nAlso, could it be the oven temperatu

In [None]:
utils.get_act_name("z", 0)

'blocks.0.attn.hook_z'

In [None]:
model.reset_hooks()

In [None]:

model.add_hook(utils.get_act_name("z", layer_to_ablate), head_ablation_hook)


modified_logits = model(tokens, return_type="logits")

Shape of the value tensor: torch.Size([1, 5, 32, 128])


In [None]:
modified_logits

tensor([[[ 1.2196,  2.1443,  4.9758,  ..., -1.1626, -1.1626, -1.1626],
         [12.2630,  1.2126,  0.5368,  ..., -3.3336, -3.3335, -3.3338],
         [ 5.2130,  2.5234, -0.0555,  ..., -4.4290, -4.4293, -4.4291],
         [11.5725,  3.8210,  1.2050,  ..., -5.2874, -5.2878, -5.2877],
         [ 5.1244, -3.3234,  1.2445,  ..., -4.2034, -4.2034, -4.2036]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [None]:
original_logits

tensor([[[ 1.1601,  2.0860,  4.8631,  ..., -1.1549, -1.1549, -1.1549],
         [12.2188,  1.1821,  0.5339,  ..., -3.3060, -3.3059, -3.3062],
         [ 5.2511,  2.6107, -0.0394,  ..., -4.4606, -4.4609, -4.4606],
         [11.5757,  3.8677,  1.2013,  ..., -5.2869, -5.2873, -5.2872],
         [ 5.1116, -3.3044,  1.2272,  ..., -4.2133, -4.2133, -4.2136]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [None]:
model.reset_hooks()
original_logits2 = model(tokens, return_type="logits")


In [None]:
original_logits2

tensor([[[ 1.1601,  2.0860,  4.8631,  ..., -1.1549, -1.1549, -1.1549],
         [12.2188,  1.1821,  0.5339,  ..., -3.3060, -3.3059, -3.3062],
         [ 5.2511,  2.6107, -0.0394,  ..., -4.4606, -4.4609, -4.4606],
         [11.5757,  3.8677,  1.2013,  ..., -5.2869, -5.2873, -5.2872],
         [ 5.1116, -3.3044,  1.2272,  ..., -4.2133, -4.2133, -4.2136]]],
       device='cuda:0', grad_fn=<ViewBackward0>)