In [None]:
import sys
sys.path.append('..')
import torch
torch.set_grad_enabled(False)

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer, SVDInterpreter
import transformers
from transformers import LlamaForCausalLM
from typing import List, Optional, Union

In [None]:
import random
from transformers import set_seed
### Random Seed ###
SEED = 42


def seed_everything(seed):
    """
    Seed every RNG this notebook touches and request deterministic cuDNN behaviour.

    Seeds torch (CPU generator plus the current and all CUDA devices), NumPy's
    legacy global RNG, and Python's `random` module, and sets
    `torch.backends.cudnn.deterministic = True` / `torch.backends.cudnn.benchmark = False`.

    Args:
        seed: integer seed applied to every generator.
    """
    # Imported locally: no visible import cell binds `np`, so relying on a global
    # would raise NameError on a fresh kernel (Restart & Run All).
    import numpy as np

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
# Seed once, up front, so model loading and every generation cell below start from the same RNG state.
seed_everything(SEED)
# Also apply transformers' own seeding helper with the same seed.
set_seed(SEED)

## Load Model

In [None]:
# Hugging Face Hub id of the checkpoint, and its bare name (the part after the org prefix).
MODEL_PATH = "meta-llama/Llama-2-7b-chat-hf"
model_name = MODEL_PATH.rsplit("/", 1)[-1]

# Options forwarded to HookedTransformer.from_pretrained: fold_ln,
# center_writing_weights and center_unembed enabled;
# refactor_factored_attn_matrices disabled.
# NOTE(review): the device is a single hard-coded GPU.
MODEL_LOAD_KWARGS = {
    "device": "cuda:0",
    "fold_ln": True,
    "center_writing_weights": True,
    "center_unembed": True,
    "refactor_factored_attn_matrices": False,
}

model = HookedTransformer.from_pretrained(MODEL_PATH, **MODEL_LOAD_KWARGS)

In [None]:
def ablate_attention_head_output(attn_out, hook, head_idx_to_ablate=None):
    """
    Zero one attention head of a per-head attention tensor, in place.

    Used as a forward-hook body for `blocks.{l}.attn.hook_result`. The head axis
    is taken to be the second-to-last dimension. For hook_result this is
    presumably [batch, pos, head_index, d_model]; confirm before reusing on other
    hook points.

    Args:
        attn_out: per-head tensor; mutated in place.
        hook: the HookPoint that fired (not used).
        head_idx_to_ablate: head index to zero; None leaves the tensor untouched.

    Returns:
        The same tensor object that was passed in as `attn_out`.
    """
    if head_idx_to_ablate is None:
        return attn_out

    # Ellipsis keeps this independent of how many leading axes precede the head axis.
    attn_out[..., head_idx_to_ablate, :] = 0.0
    return attn_out

def ablate_and_infer(model, prompt, layer_idx, head_idx, max_new_tokens=50):
    """
    Greedily generate from `prompt` with one attention head zero-ablated.

    The head `head_idx` of `blocks.{layer_idx}.attn.hook_result` is zeroed on every
    forward pass made during generation. The hook is removed again when generation
    finishes (or fails). The model is therefore not left silently ablated for
    later cells.

    Side effects:
        - Removes ALL hooks on `model`, including permanent ones, both before
          and after generation.
        - Sets `model.cfg.use_split_qkv_input` and `model.cfg.use_attn_result` to
          True and leaves them set. Per TransformerLens docs, hook_result is only
          computed when use_attn_result is enabled.

    Args:
        model: HookedTransformer to run.
        prompt: text prompt; a BOS token is prepended when tokenizing.
        layer_idx: layer whose per-head attention output is ablated.
        head_idx: head index to zero in that layer.
        max_new_tokens: number of tokens to generate.

    Returns:
        `model.to_string(generated_tokens)` for the generated token batch
        (prompt + continuation). For the 2-D token tensor from `generate` this
        is presumably a list with one decoded string per batch row; confirm
        against your TransformerLens version.
    """
    hook_name = f"blocks.{layer_idx}.attn.hook_result"

    # Start from a clean slate: drop any hooks left by earlier cells or calls.
    model.reset_hooks(including_permanent=True)

    # hook_result must be materialised for the ablation hook to fire.
    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True

    def hook_fn(attn_out, hook):
        return ablate_attention_head_output(attn_out, hook, head_idx_to_ablate=head_idx)

    print(f"[INFO] Ablation on layer {layer_idx}, head {head_idx} - Prompt: {prompt}")
    input_tokens = model.to_tokens(prompt, prepend_bos=True)

    # Registered as a perma hook so it survives across generate's repeated forward
    # passes, then always removed so the ablation does not leak past this call.
    model.add_perma_hook(hook_name, hook_fn)
    try:
        generated_tokens = model.generate(
            input_tokens,
            max_new_tokens=max_new_tokens,
            temperature=0.0,   # 0.0 -> greedy (argmax) decoding in TransformerLens
            top_p=1
        )
    finally:
        model.reset_hooks(including_permanent=True)

    return model.to_string(generated_tokens)

In [None]:
prompt = 'In 1999, the name of president of South Korea was'
# prompt = 'In 2004, the name of president of South Korea was'
# prompt = 'In 2009, the name of president of South Korea was'

# (A) Baseline: model output without any ablation.
# Remove every hook, including permanent ones left by earlier ablation runs, so the
# model is un-ablated. (cfg flags such as use_attn_result set by earlier
# ablate_and_infer calls are not reset here.)
model.reset_hooks(including_permanent=True)

# Prepend BOS exactly as ablate_and_infer and the multi-head cell do. The baseline
# and ablated runs then see identical token sequences, and any output difference
# is attributable to the ablation alone.
input_tokens = model.to_tokens(prompt, prepend_bos=True)

# Same generation settings as ablate_and_infer's defaults (greedy, 50 new tokens).
no_ablation_tokens = model.generate(
    input_tokens,
    max_new_tokens=50,
    temperature=0.0,
    top_p=1
)

no_ablation_answer = model.to_string(no_ablation_tokens)

print("=== Without Ablation ===")
print(no_ablation_answer)
print("============================================")

# (B) Example: Ablation at layer=2, head=2 (a2.h2)
ablated_answer = ablate_and_infer(model, prompt, layer_idx=2, head_idx=2)
print("=== With Ablation (layer2, head2) ===")
print(ablated_answer)
print("============================================")

# (C) Temporal Head Ablation (layer=18, head=3 -> a18.h3)
ablated_answer_2 = ablate_and_infer(model, prompt, layer_idx=18, head_idx=3)
print("=== With Ablation (layer18, head3) ===")
print(ablated_answer_2)
print("============================================")

# (D) Temporal Head Ablation (layer=15, head=0 -> a15.h0)
ablated_answer_3 = ablate_and_infer(model, prompt, layer_idx=15, head_idx=0)
print("=== With Ablation (layer15, head0) ===")
print(ablated_answer_3)
print("============================================")

In [None]:
# Hook body that zero-ablates several attention heads at once.
def ablate_multi_attention_head_output(attn_out, hook, heads_to_ablate=None):
    """
    Zero a set of attention heads of a per-head attention tensor, in place.

    The head axis is taken to be the second-to-last dimension. For
    `blocks.{l}.attn.hook_result` this is presumably
    [batch, pos, head_index, d_model]; confirm for other hook points.

    Args:
        attn_out (torch.Tensor): per-head tensor; mutated in place.
        hook (HookPoint): the hook point that fired (not used).
        heads_to_ablate (List[int], optional): head indices to zero (a single int
            also works). None leaves the tensor untouched.

    Returns:
        torch.Tensor: the same tensor object passed in as `attn_out`.
    """
    if heads_to_ablate is None:
        return attn_out

    # A list index selects every listed head along the head axis in one assignment.
    attn_out[..., heads_to_ablate, :] = 0.0
    return attn_out

# Function to register ablation hooks for multiple heads across specified layers
def ablate_multiple_heads(model, layer_heads_dict):
    """
    Register permanent hooks that zero-ablate the given heads in the given layers.

    Side effects:
        - Removes ALL previously registered hooks, including permanent ones.
        - Sets `model.cfg.use_attn_result = True`. Per TransformerLens docs,
          `attn.hook_result` is only computed (so these hooks only fire) when it
          is enabled.
        - The registered hooks stay active for every later forward pass until
          `model.reset_hooks(including_permanent=True)` is called.

    Args:
        model (HookedTransformer): The transformer model.
        layer_heads_dict (Dict[int, List[int]]): Mapping from layer index to the
            head indices to ablate in that layer.
            Example: {2: [2, 5], 18: [3, 7]} ablates heads 2 and 5 in layer 2,
            and heads 3 and 7 in layer 18.
    """
    # Reset existing hooks to prevent conflicts
    model.reset_hooks(including_permanent=True)

    # Without this flag the hook_result hooks below would be silently inert, unless
    # some earlier cell (e.g. ablate_and_infer) happened to have enabled it.
    model.cfg.use_attn_result = True

    for layer_idx, heads_list in layer_heads_dict.items():
        hook_name = f"blocks.{layer_idx}.attn.hook_result"
        model.add_perma_hook(
            hook_name,
            # Bind heads_list as a default argument. A plain closure would late-bind
            # to the loop variable, and every hook would use the last layer's heads.
            lambda attn_out, hook, heads_list=heads_list: ablate_multi_attention_head_output(
                attn_out,
                hook,
                heads_to_ablate=heads_list
            )
        )
    print(f"[INFO] Registered multi-head ablation hooks for layers and heads: {layer_heads_dict}")

In [None]:
# Layer index -> head indices to zero-ablate in that layer.
# Other candidates tried earlier (uncomment into the dict as needed). Keep each layer
# key unique: a repeated key silently overwrites the earlier entry (18 and 3 each
# appear twice across the active and candidate entries).
#   2: [2], 1: [15], 16: [10], 20: [17], 3: [19], 18: [15],
#   23: [26], 10: [13], 12: [6], 3: [4], 17: [15],
ablation_dict = {
    18: [3],
    15: [0],
}

# Registers permanent hooks. They stay active after this cell until
# model.reset_hooks(including_permanent=True) is called.
ablate_multiple_heads(model, ablation_dict)

prompt = 'In 1999, the name of president of South Korea was'
# prompt = 'In 2004, the name of president of South Korea was'
# prompt = 'In 2009, the name of president of South Korea was'

# BOS-prefixed, matching the baseline and single-head runs.
input_tokens = model.to_tokens(prompt, prepend_bos=True)

ablated_tokens = model.generate(input_tokens, max_new_tokens=50, temperature=0.0, top_p=1)
ablated_answer = model.to_string(ablated_tokens)

print("=== With Multi-Head Ablation ===")
print(ablated_answer)