In [199]:
import torch 
import torch.nn.functional as F
import seaborn as sns
from circuitsvis.attention import attention_heads, attention_patterns
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import os
import shutil
import time
import concurrent.futures

In [200]:
# Global settings
torch.set_grad_enabled(False) #to disable gradients -> faster computiations
# Ensure GPU acceleration is enabled on Mac
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
model = None
tokenizer = None
META_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B"
GOOGLE_GEMMA_2_2B = "google/gemma-2-2b"
dataset = {}
CSV_PATH_DATASET = "/Users/ivannaranjo/Documents/Helmholtz/experiments/hlmz-prep/dataset/examples.csv"

### Functions

In [201]:
def initialize_model(model_name: str, tokenizer_name: str = None):
    if not tokenizer_name:
        tokenizer_name = model_name
    # Initialize model and tokenizer
    global model
    model = AutoModelForCausalLM.from_pretrained(model_name)
    if not tokenizer_name:
        tokenizer_name = model_name
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)


In [202]:
def load_dataset(path_to_csv: str):
    # Check if the file at the given path exists
    if os.path.exists(path_to_csv):
        df = pd.read_csv(path_to_csv)
    else:
        print("File does not exist.")
        exit(1)

    global dataset 
    dataset = df
    
    # Create a new column "token_probability" for saving up the probabilites of the studied token for all prompts. Initially, 0.
    dataset["token_probability_true_sentence"] = 0
    dataset["token_probability_false_sentence"] = 0

In [203]:
def print_colored_separator(color="\033[94m", char="=", length=150):
    reset = "\033[0m"  # Reset color
    print(f"{color}{char * length}{reset}")

In [204]:
# Returns the model's output after feeding it with a prompt concatenated prompt_repetitions times and the concatenated prompt tensor
def feed_forward(prompt: str, prompt_repetitions: int = 1):
    print_colored_separator()
    print(f"Input: {prompt}\n")
    token_sequence = tokenizer(prompt, return_tensors="pt")
    tokens = token_sequence["input_ids"][0]
    print(f"Tokenizer tokens: {tokens}\n")
    
    sentence_to_concatenate = tokens
    for _ in range(prompt_repetitions):
        sentence_to_concatenate = torch.concat([sentence_to_concatenate, tokens[1:]])

    print("Concatenated prompt:")
    print(f"{tokenizer.decode(sentence_to_concatenate)}\n")
    print("Concatenated tokens:")
    print(f"{sentence_to_concatenate}\n")

    # Feed forward to the model
    global model
    out = model(sentence_to_concatenate.unsqueeze(0), return_dict=True, output_attentions=True)
    return out, sentence_to_concatenate

In [205]:
def create_attention_mask(token_sequence: str, show_induction_mask: bool = False):
    print_colored_separator()
    sequence_length = len(token_sequence)
    half_of_sequence = sequence_length // 2 - 1
    induction_mask = torch.zeros(sequence_length, sequence_length).to(float)

    for i in range(sequence_length // 2 + 1, sequence_length):
        induction_mask[i, i - half_of_sequence] = 1 
    
    if show_induction_mask:
        print("Induction Mask:\n")
        print(induction_mask)
        print()
        print("Induction Mask plot:\n")
        plt.imshow(induction_mask)
        plt.show()
        print()
    return induction_mask

In [206]:
def compute_induction_head_scores(token_sequence: str, induction_mask: torch.Tensor, model_output):
    num_heads = model.config.num_attention_heads
    num_layers = model.config.num_hidden_layers
    sequence_length = len(token_sequence)

    induction_scores = torch.zeros(num_layers, num_heads)
    tril = torch.tril_indices(sequence_length - 1, sequence_length - 1) # gets the indices of elements on and below the diagonal
    induction_flat = induction_mask[tril[0], tril[1]].flatten()
    
    for layer in range(num_layers):
        for head in range(num_heads):
            pattern = model_output["attentions"][layer][0][head].cpu().to(float)[1:, 1:]
            pattern_flat = pattern[tril[0], tril[1]].flatten()
            score = (induction_flat @ pattern_flat) / pattern_flat.sum()
            induction_scores[layer, head] = score
    return induction_scores

In [207]:
def create_heatmap(induction_scores: torch.Tensor):
    print_colored_separator()
    _, ax = plt.subplots()
    print("Heatmap of induction scores across heads and layers: \n")
    sns.heatmap(induction_scores, cbar_kws={"label": "Induction Head Score"}, ax=ax)
    ax.set_ylabel("Layer #")
    ax.set_xlabel("Head #")
    plt.show()

In [208]:
def sort_high_scoring_heads(induction_scores: torch.Tensor, model_output: any, show_induction_heads: bool = False): 
    print_colored_separator()
    # Get flattened indices sorted by scores in descending order
    sorted_flat_indices = torch.argsort(induction_scores.flatten(), descending=True)

    # Convert flattened indices to 2D indices
    sorted_indices = torch.unravel_index(sorted_flat_indices, induction_scores.shape)
    # Stack the row and column indices for final output
    sorted_indices = torch.stack(sorted_indices, dim=1)

    if show_induction_heads:
        print("Top 5 Induction Heads with the highest induction score - Descending order\n")
        for layer, head in sorted_indices[:5]:
            print(f"Layer: {layer}\nHead: {head}")
            plt.imshow(model_output["attentions"][layer][0][head].cpu().float())
            plt.show()
            print()
    return sorted_indices

In [209]:
def token_probability_extraction(head_indices: torch.Tensor, models_output: any):
    top_layer, top_head = head_indices[0]
    probs = models_output["attentions"][top_layer][0][top_head]

    # Extract probability of the specified token
    sequence_length = probs.shape[0]
    probability_token = probs[sequence_length - 2, sequence_length // 2].item() * 100

    print_colored_separator()
    print("Probability of token: ", probability_token)
    return probability_token


In [210]:
def save_probability(token_probability: int, example_id: int, column_name_probability: str):
    if dataset.empty:
        raise Exception("Dataset is empty")
    
    print_colored_separator()
    # Log the probability of the token into its corresponding row and column in the dataset.
    dataset.loc[dataset["example_id"] == example_id, f"{column_name_probability}"] = token_probability 
    print(f"Saved probability for token from example_id: {example_id}\n")

In [211]:
def display_attention_visualizations(head_indices: torch.Tensor, token_sequence: torch.Tensor, models_output):
    # Display attention diagrams
    tokens_vis = tokenizer.tokenize(tokenizer.decode(token_sequence.squeeze()))
    layer, head = head_indices[0]
    return attention_patterns(tokens_vis, models_output["attentions"][layer][0]), attention_heads(models_output["attentions"][layer][0], tokens_vis)

In [212]:
def run_experiment(prompt: str, example_id: int, column_name_probability: str, prompt_repetitions: int = 1, show_diagrams: bool = False):
    models_output, token_sequence = feed_forward(prompt=prompt, prompt_repetitions=prompt_repetitions)
    induction_mask = create_attention_mask(token_sequence=token_sequence)
    induction_scores = compute_induction_head_scores(token_sequence=token_sequence, induction_mask=induction_mask, model_output=models_output)
    # create_heatmap(induction_scores=induction_scores)
    high_scoring_heads_indices_sorted = sort_high_scoring_heads(induction_scores=induction_scores, model_output=models_output)

    # Extract the probability of the studied token. 
    token_probability = token_probability_extraction(high_scoring_heads_indices_sorted, models_output)

    return token_probability

    # Save up the probability of the token in the dataset. 
    # save_probability(token_probability=token_probability, example_id=example_id, column_name_probability=column_name_probability)

    # Visualize attention 
    if show_diagrams:
        attention_patterns_view, attention_head_view = display_attention_visualizations(head_indices=high_scoring_heads_indices_sorted, token_sequence=token_sequence, models_output=models_output)
        return attention_patterns_view, attention_head_view

In [213]:
def run_experiment_gpu(prompt, example_id, column_name_probability):
    """Wrapper function for running inference on Mac GPU"""
    with torch.no_grad():  # Prevents GPU memory issues
        return run_experiment(prompt, example_id, column_name_probability)

In [214]:
#TODO: This can take long. Think how to parallelize this.
def run_experiment_suite(dataset_csv_file_path: str):
    print(f"Using device: {torch.device('mps') if torch.backends.mps.is_available() else 'cpu'}")
    load_dataset(path_to_csv=dataset_csv_file_path)
    # Use ThreadPoolExecutor since multiprocessing won't work well with Metal
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        future_true = executor.submit(
            lambda: dataset.apply(lambda row: run_experiment_gpu(row["true_sentence"], 
                                                                 row["example_id"], 
                                                                 "token_probability_true_sentence"), axis=1)
        )
        future_false = executor.submit(
            lambda: dataset.apply(lambda row: run_experiment_gpu(row["false_sentence"], 
                                                                 row["example_id"], 
                                                                 "token_probability_false_sentence"), axis=1)
        )

        # Collect results
        dataset["token_probability_true_sentence"] = future_true.result()
        dataset["token_probability_false_sentence"] = future_false.result()


    # Create a backup
    shutil.copy(dataset_csv_file_path, dataset_csv_file_path + ".backup") 

    # Now overwrite the CSV file
    dataset.to_csv(dataset_csv_file_path, index=False)

### Experiment Start

In [215]:
initialize_model(model_name=META_LLAMA_3_2_3B, tokenizer_name=META_LLAMA_3_2_3B)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [216]:
start_time = time.perf_counter()
run_experiment_suite(dataset_csv_file_path=CSV_PATH_DATASET)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.4f} seconds.")

Using device: mps
Input: Macron is the president of France.

Input: Macron is the president of Germany.

Tokenizer tokens: tensor([128000,  20122,   2298,    374,    279,   4872,    315,  10057,     13])

Tokenizer tokens: tensor([128000,  20122,   2298,    374,    279,   4872,    315,   9822,     13])

Concatenated prompt:
Concatenated prompt:
<|begin_of_text|>Macron is the president of France.Macron is the president of France.

Concatenated tokens:
<|begin_of_text|>Macron is the president of Germany.Macron is the president of Germany.

Concatenated tokens:
tensor([128000,  20122,   2298,    374,    279,   4872,    315,   9822,     13,
         20122,   2298,    374,    279,   4872,    315,   9822,     13])

tensor([128000,  20122,   2298,    374,    279,   4872,    315,  10057,     13,
         20122,   2298,    374,    279,   4872,    315,  10057,     13])

Probability of token:  43.375617265701294
Probability of token:  49.307820200920105
Input: The Earth is round.

Input: The Eart

In [217]:
# attention_patterns_view, attention_heads_view = run_experiment(prompt="The planet earth is round", prompt_repetitions=1)
# attention_patterns_view
# attention_heads_view