# Experiment 1
### Idea: For the heads that appear in both the true and the false sentences (their intersection), plot the probabilities of the token from the true sentence and of the token from the false sentence

In [None]:
import torch 
import seaborn as sns
from circuitsvis.attention import attention_heads, attention_patterns
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import os
import time
import gc
import json 

In [None]:
# Global settings
torch.set_grad_enabled(False)  # Disable gradients -> faster computations
torch.set_printoptions(sci_mode=False)
# Use the MPS backend (GPU acceleration on Mac) when available, otherwise the CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# Placeholders for the objects loaded by initialize_model(). The functions in this
# notebook declare `global model`, so the placeholder has to carry that name.
model = None
mod = None  # Older placeholder name, kept so existing references to `mod` keep working
tokenizer = None
META_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B"
GOOGLE_GEMMA_2_2B = "google/gemma-2-2b"
# Replaced by a pandas DataFrame once load_dataset() has run
dataset = {}
CSV_PATH_DATASET = "../dataset/examples.csv"

In [None]:
# Model identifiers that are passed as `llm_models` to run_experiment_suite() at the end of the notebook.
models = [META_LLAMA_3_2_3B] 

### Functions

In [None]:
def initialize_model(model_name: str, tokenizer_name: str = None):
    """Load a causal language model and its tokenizer into module-level globals.

    Args:
        model_name: Identifier passed to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            Falls back to ``model_name`` when omitted or empty.

    Side effects:
        Rebinds the globals ``model`` and ``tokenizer``.
    """
    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16)
    # A single fallback is enough: use the model's own tokenizer unless told otherwise.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)


In [None]:
def load_dataset(path_to_csv: str):
    """Read the examples CSV into the global ``dataset`` DataFrame.

    Two result columns, ``token_probability_true_sentence`` and
    ``token_probability_false_sentence``, are added and initialised to 0.

    Args:
        path_to_csv: Path of the CSV file to read.

    Raises:
        FileNotFoundError: If ``path_to_csv`` does not exist. Raising (instead of
            printing and calling ``exit``) gives the caller a proper traceback
            and never continues with an undefined DataFrame.
    """
    global dataset
    if not os.path.exists(path_to_csv):
        raise FileNotFoundError(f"Dataset file does not exist: {path_to_csv}")
    dataset = pd.read_csv(path_to_csv)

    # Columns that will hold the probabilities of the studied token for every prompt. Initially 0.
    dataset["token_probability_true_sentence"] = 0
    dataset["token_probability_false_sentence"] = 0

In [None]:
def print_colored_separator(color="\033[94m", char="=", length=150, prints_enabled: bool = False):
    """Print a colored horizontal rule, but only when ``prints_enabled`` is true.

    Args:
        color: ANSI escape sequence written before the rule.
        char: String that is repeated to form the rule.
        length: Number of repetitions of ``char``.
        prints_enabled: When false (the default) nothing is printed.
    """
    if not prints_enabled:
        return
    ansi_reset = "\033[0m"  # Restores the terminal's default color
    rule = char * length
    print(f"{color}{rule}{ansi_reset}")

In [None]:
def feed_forward(true_sentence: str, false_sentence: str, prompt_repetitions: int = 1, prints_enabled: bool = False):
    """Build the prompt from both sentences and run it through the global ``model``.

    The prompt consists of ``true_sentence``, ``false_sentence`` and
    ``true_sentence`` without its last space-separated word, joined by newlines.

    Args:
        true_sentence: Sentence that opens the prompt and is repeated (truncated) at its end.
        false_sentence: Sentence placed between the two.
        prompt_repetitions: Currently unused; accepted for interface compatibility.
        prints_enabled: Forwarded to ``print_colored_separator``.

    Returns:
        ``(out, tokens, true_sentence_token_n)``: the model output (with attentions),
        the 1-D tensor of prompt token ids, and the token count of ``true_sentence``.
        ``(None, None, None)`` when the two sentences tokenize to different lengths.
    """
    # Pass the flag by keyword; positionally it would be bound to `color`.
    print_colored_separator(prints_enabled=prints_enabled)
    # Both sentences must have the same number of tokens, otherwise the positional
    # indexing done downstream would not line up.
    # NOTE(review): whether this count includes a BOS token depends on the tokenizer
    # configuration - confirm, because later indexing relies on it.
    true_sentence_token_n = tokenizer(true_sentence, return_tensors="pt")["input_ids"][0].shape[0]
    false_sentence_token_n = tokenizer(false_sentence, return_tensors="pt")["input_ids"][0].shape[0]
    if true_sentence_token_n != false_sentence_token_n:
        return None, None, None

    # Everything before the last space; empty string when the sentence has no space.
    sentence_without_last_token = "".join(true_sentence.rsplit(" ", 1)[:-1])
    # One-shot style prompt: the three parts are joined with newline characters.
    prompt = true_sentence + "\n" + false_sentence + "\n" + sentence_without_last_token
    token_sequence = tokenizer(prompt, return_tensors="pt")
    tokens = token_sequence["input_ids"][0]

    # Feed forward through the model, asking for the attention patterns as well.
    out = model(tokens.unsqueeze(0).to(model.device), return_dict=True, output_attentions=True)
    return out, tokens, true_sentence_token_n

In [None]:
import numpy as np
import plotly.graph_objects as go

def plot_induction_mask_with_plotly(induction_mask, induction_mask_text, prompt):
    """Show the induction mask as a Plotly heatmap with per-cell hover text.

    Args:
        induction_mask: Numeric matrix used as the heatmap values.
        induction_mask_text: Matrix of strings shown when hovering over a cell.
        prompt: Prompt text that is included in the figure title.
    """
    mask_trace = go.Heatmap(
        z=induction_mask,
        text=induction_mask_text,
        hoverinfo="text",  # Hovering shows only the text
        colorscale="Blues",
        showscale=True,
    )
    layout_options = {
        # Linking the x scale to y keeps the cells square
        "xaxis": {"scaleanchor": "y", "scaleratio": 1},
        # Reversed y-axis puts row 0 at the top
        "yaxis": {"autorange": "reversed"},
        "title": f"Induction Mask for prompt: {prompt}\n",
    }

    figure = go.Figure(data=[mask_trace])
    figure.update_layout(**layout_options)
    figure.show()

In [None]:
def create_attention_mask(token_sequence: torch.Tensor, token_number_sentence: int, show_induction_mask: bool = False, prints_enabled: bool = False):
    """Build the induction mask for a tokenized prompt.

    For every position ``i`` from ``token_number_sentence + 1`` onwards and every
    earlier position ``j`` holding the same token id, the cell ``[i, j + 1]`` is
    set to 1, i.e. the position right after the previous occurrence.

    Args:
        token_sequence: 1-D tensor of token ids.
        token_number_sentence: Token count of one sentence; the scan starts one
            position after it.
        show_induction_mask: When true, the mask is also plotted with Plotly.
        prints_enabled: Forwarded to ``print_colored_separator``.

    Returns:
        The ``(sequence_length, sequence_length)`` mask as a float64 tensor. It is
        returned whether or not the mask is plotted.
    """
    # Pass the flag by keyword; positionally it would be bound to `color`.
    print_colored_separator(prints_enabled=prints_enabled)
    sequence_length = token_sequence.shape[0]
    induction_mask = torch.zeros(sequence_length, sequence_length).to(float)
    induction_mask_text = np.full((sequence_length, sequence_length), "", dtype=object)

    def _escaped(token_id):
        # Escape the decoded token so that e.g. newline tokens stay visible in the hover text.
        return tokenizer.decode(token_id).encode("unicode_escape").decode("utf-8")

    for i in range(token_number_sentence + 1, sequence_length):
        # Positions before i that hold the same token id as position i.
        previous_positions = (token_sequence[:i] == token_sequence[i]).nonzero().flatten().tolist()
        if not previous_positions:
            continue
        current_text = _escaped(token_sequence[i])
        for j in previous_positions:
            induction_mask[i, j + 1] = 1
            induction_mask_text[i, j + 1] = current_text + "/" + _escaped(token_sequence[j + 1])

    if show_induction_mask:
        plot_induction_mask_with_plotly(induction_mask, induction_mask_text, prompt=tokenizer.decode(token_sequence))
    # Must sit outside the `if`: callers need the mask even when it is not plotted.
    return induction_mask

In [None]:
def compute_induction_head_scores(token_sequence: torch.Tensor, induction_mask: torch.Tensor, model_output):
    """Score every attention head against the induction mask.

    For each layer and head, the score is the mask-weighted attention over the
    lower triangle (diagonal included) divided by the total attention over that
    same triangle.

    Args:
        token_sequence: Tensor of token ids; only its first dimension is used.
        induction_mask: Square mask indexed with the lower-triangle indices.
        model_output: Model output whose ``"attentions"`` entry is indexed as
            ``[layer][0][head]``.

    Returns:
        Tensor of shape ``(num_hidden_layers, num_attention_heads)``.
    """
    config = model.config
    layer_count = config.num_hidden_layers
    head_count = config.num_attention_heads
    seq_len = token_sequence.shape[0]

    #TODO: Does it make sense to remove the first row and first column? corresponding to the BOS token
    # Row/column indices of all elements on and below the diagonal
    rows, cols = torch.tril_indices(seq_len, seq_len)
    mask_values = induction_mask[rows, cols].flatten()

    scores = torch.zeros(layer_count, head_count)
    attentions = model_output["attentions"]
    for layer_idx in range(layer_count):
        for head_idx in range(head_count):
            head_pattern = attentions[layer_idx][0][head_idx].cpu().to(float)
            pattern_values = head_pattern[rows, cols].flatten()
            scores[layer_idx, head_idx] = (mask_values @ pattern_values) / pattern_values.sum()
    return scores

In [None]:
def create_heatmap(induction_scores: torch.Tensor):
    """Draw a seaborn heatmap of the induction scores (rows: layers, columns: heads).

    Args:
        induction_scores: 2-D tensor of scores handed to ``sns.heatmap``.
    """
    print_colored_separator()
    _figure, axes = plt.subplots()
    print("Heatmap of induction scores across heads and layers: \n")
    colorbar_options = {"label": "Induction Head Score"}
    sns.heatmap(induction_scores, cbar_kws=colorbar_options, ax=axes)
    axes.set_ylabel("Layer #")
    axes.set_xlabel("Head #")
    plt.show()

In [None]:
def sort_filter_high_scoring_induction_heads(induction_scores: torch.Tensor, model_output: any, filter_by_threshold: float = 0.5, show_induction_heads: bool = False, prints_enabled: bool = False): 
    """Return the (layer, head) indices whose score reaches the threshold, best first.

    Args:
        induction_scores: 2-D tensor of scores indexed as ``[layer, head]``.
        model_output: Model output; its ``"attentions"`` entry is only read when
            ``show_induction_heads`` is true.
        filter_by_threshold: Minimum score (inclusive) a head needs to be kept.
        show_induction_heads: When true, print and plot the five best heads.
        prints_enabled: Forwarded to ``print_colored_separator``.

    Returns:
        Integer tensor of shape ``(k, 2)`` with one ``[layer, head]`` row per kept
        head, sorted by descending score.
    """
    # Pass the flag by keyword; positionally it would be bound to `color`.
    print_colored_separator(prints_enabled=prints_enabled)

    # Sort all heads by score, then keep those that reach the threshold.
    flat_scores = induction_scores.flatten()
    descending_order = torch.argsort(flat_scores, descending=True)
    keep = flat_scores[descending_order] >= filter_by_threshold
    valid_indices = descending_order[keep]

    # Convert the flat indices back to (layer, head) pairs, one pair per row.
    sorted_indices = torch.stack(torch.unravel_index(valid_indices, induction_scores.shape), dim=1)

    if show_induction_heads:
        print("Top 5 Induction Heads with the highest induction score - Descending order\n")
        # Plain ints index the attentions tuple without relying on tensor conversion.
        for layer, head in sorted_indices[:5].tolist():
            induction_score = induction_scores[layer][head]
            print(f"Layer: {layer}\nHead: {head}\nInduction Score: {induction_score}")
            plt.imshow(model_output["attentions"][layer][0][head].cpu().float())
            plt.show()
            print()
    return sorted_indices

In [None]:
def token_probability_extraction(head_indices: torch.Tensor, models_output: any, token_number_sentence: int, prints_enabled: bool = False):
    """Read, for each given head, the attention from the last prompt position to two fixed positions.

    Args:
        head_indices: Iterable of ``(layer, head)`` pairs.
        models_output: Model output whose ``"attentions"`` entry is indexed as
            ``[layer][0][head]``.
        token_number_sentence: Token count of one sentence, used to derive the
            two column indices.
        prints_enabled: Forwarded to ``print_colored_separator``.

    Returns:
        Two JSON strings (true sentence, false sentence), each mapping
        ``"L<layer>_H<head>"`` to the extracted probability.
    """
    # Column indices are the same for every head, so compute them once.
    # NOTE(review): these indices presumably address the last token of the true and
    # of the false sentence; that depends on how the tokenizer counts BOS/newline
    # tokens - confirm against the tokenizer in use.
    true_column = token_number_sentence - 1
    false_column = 2 * token_number_sentence - 1

    result_true_sentence = {}
    result_false_sentence = {}
    for layer, head in head_indices:
        # Pass the flag by keyword; positionally it would be bound to `color`.
        print_colored_separator(prints_enabled=prints_enabled)
        layer, head = int(layer), int(head)
        probs = models_output["attentions"][layer][0][head]

        # First index is the row (y-axis), second the column (x-axis) of the pattern;
        # the last row belongs to the last token of the prompt.
        last_row = probs.shape[0] - 1
        key = f"L{layer}_H{head}"
        result_true_sentence[key] = probs[last_row, true_column].item()
        result_false_sentence[key] = probs[last_row, false_column].item()
    return json.dumps(result_true_sentence), json.dumps(result_false_sentence)


In [None]:
def save_probability(token_probability: int, example_id: int, column_name_probability: str, prints_enabled: bool = False):
    """Store a probability value in the global ``dataset`` for one example.

    Args:
        token_probability: Value written into the dataset.
        example_id: Rows whose ``example_id`` column equals this value are updated.
        column_name_probability: Name of the column that receives the value.
        prints_enabled: Enables the separator and the confirmation message.

    Raises:
        ValueError: If the dataset has not been loaded yet or is empty.
    """
    # The module-level placeholder is a plain dict without `.empty`; treat it as empty too.
    if getattr(dataset, "empty", True):
        raise ValueError("Dataset is empty")

    # Pass the flag by keyword; positionally it would be bound to `color`.
    print_colored_separator(prints_enabled=prints_enabled)
    # Log the probability of the token into its corresponding row and column in the dataset.
    dataset.loc[dataset["example_id"] == example_id, column_name_probability] = token_probability
    if prints_enabled:
        print(f"Saved probability for token from example_id: {example_id}\n")

In [None]:
def display_attention_visualizations(head_indices: torch.Tensor, token_sequence: torch.Tensor, models_output):
    """Build the circuitsvis attention views for the layer of the first given head.

    Args:
        head_indices: Sequence of ``(layer, head)`` pairs; only the first pair's
            layer is used.
        token_sequence: Token ids of the prompt; decoded and re-tokenized to
            obtain the token strings to display.
        models_output: Model output whose ``"attentions"`` entry is indexed as
            ``[layer][0]``.

    Returns:
        Tuple ``(attention_patterns view, attention_heads view)``.
    """
    prompt_text = tokenizer.decode(token_sequence.squeeze())
    tokens_vis = tokenizer.tokenize(prompt_text)
    top_layer, _top_head = head_indices[0]
    layer_attention = models_output["attentions"][top_layer][0]
    patterns_view = attention_patterns(tokens_vis, layer_attention)
    heads_view = attention_heads(layer_attention, tokens_vis)
    return patterns_view, heads_view

In [None]:
def run_experiment(true_sentence: str, false_sentence: str, filter_by_threshold: float = 0.5):
    """Run the full pipeline for one sentence pair.

    Steps: feed the prompt through the model, build and plot the induction mask,
    score all heads, plot the score heatmap, select the heads reaching the
    threshold, and extract the token probabilities for those heads.

    Args:
        true_sentence: The true sentence of the pair.
        false_sentence: The false sentence of the pair.
        filter_by_threshold: Minimum induction score for a head to be selected.

    Returns:
        Two JSON strings (true sentence, false sentence) as produced by
        ``token_probability_extraction``, or ``(None, None)`` when the two
        sentences tokenize to different lengths.
    """
    forward_result = feed_forward(true_sentence=true_sentence, false_sentence=false_sentence)
    outputs, prompt_tokens, sentence_token_count = forward_result
    if prompt_tokens is None:
        # The sentences differ in token count, so the pair cannot be compared.
        return None, None

    print("Token Sequence: ", prompt_tokens)
    print("Token Sequence number: ", prompt_tokens.shape[0])
    print("Token number sentence: ", sentence_token_count)

    mask = create_attention_mask(
        token_sequence=prompt_tokens,
        token_number_sentence=sentence_token_count,
        show_induction_mask=True,
    )
    scores = compute_induction_head_scores(
        token_sequence=prompt_tokens,
        induction_mask=mask,
        model_output=outputs,
    )
    create_heatmap(induction_scores=scores)

    # Heads at or above the threshold, ordered from the highest score down
    selected_heads = sort_filter_high_scoring_induction_heads(
        induction_scores=scores,
        model_output=outputs,
        filter_by_threshold=filter_by_threshold,
        show_induction_heads=True,
    )
    print("Indices of high scoring induction heads - desc: ", selected_heads)

    # Per-head probabilities of the two studied tokens, serialized as JSON
    true_json, false_json = token_probability_extraction(selected_heads, outputs, sentence_token_count)
    return true_json, false_json


In [None]:
def save_plot_results(save_path: str):
    """Plot the per-head token probabilities of the global ``dataset`` and save the figure.

    Only heads that occur in every row, for both the true and the false
    sentence, are plotted. The figure is written to ``f"{save_path}-results-plot.png"``.
    When the dataset has no rows or no head is shared by all rows, a message is
    printed and nothing is saved.

    Args:
        save_path: Path prefix of the output image.
    """
    # Parse every JSON cell exactly once.
    true_probabilities = [json.loads(cell) for cell in dataset["token_probability_true_sentence"]]
    false_probabilities = [json.loads(cell) for cell in dataset["token_probability_false_sentence"]]

    if not true_probabilities:
        # set.intersection() needs at least one set, so an empty dataset cannot be plotted.
        print("No results to plot: the dataset has no rows.")
        return

    # Heads that appear in the true AND the false sentence of every row
    head_sets = [set(probabilities) for probabilities in true_probabilities + false_probabilities]
    total_intersection_set = set.intersection(*head_sets)
    if not total_intersection_set:
        print("No results to plot: no head appears in all true and false sentences.")
        return

    # Sorted so that the bar order does not depend on set iteration order.
    ordered_heads = sorted(total_intersection_set)

    data = []
    for true_row, false_row in zip(true_probabilities, false_probabilities):
        for head in ordered_heads:
            data.append({"Head": head, "Probability": true_row[head], "Type": "True"})
            data.append({"Head": head, "Probability": false_row[head], "Type": "False"})

    plot_df = pd.DataFrame(data=data)
    # Ensure "Type" column is formatted correctly to avoid duplicate legend entries
    plot_df["Type"] = plot_df["Type"].astype(str).str.strip()
    plt.figure(figsize=(12, 10))
    ax = sns.barplot(data=plot_df, x="Head", y="Probability", hue="Type")
    ax.get_figure().savefig(f"{save_path}-results-plot.png", dpi=300)


In [None]:
def save_result_csv(model_name: str, dataset_csv_file_path: str):
    """Write the global ``dataset`` and its plot into a per-model folder next to the input CSV.

    The folder is named after the first ``/``-separated part of ``model_name``;
    the files are named after the last part (``<name>-results.csv`` and the plot
    produced by ``save_plot_results``).

    Args:
        model_name: Model identifier, e.g. ``"organization/model"``.
        dataset_csv_file_path: Path of the input CSV; its directory hosts the
            result folder.
    """
    model_name_folder = model_name.split("/")
    # os.path.join avoids resolving to the filesystem root when the CSV path has
    # no directory component (dirname == "").
    folder_path = os.path.join(os.path.dirname(dataset_csv_file_path), model_name_folder[0])
    # Creates missing parent directories too and is a no-op if the folder exists.
    os.makedirs(folder_path, exist_ok=True)

    new_file_path = os.path.join(folder_path, f"{model_name_folder[-1]}-results.csv")
    dataset.to_csv(new_file_path, index=False)

    # Save plot for the current dataset
    model_and_path_image = os.path.join(folder_path, model_name_folder[-1])
    save_plot_results(save_path=model_and_path_image)

In [None]:
def delete_model():
    """Drop the global ``model`` and release cached accelerator memory.

    Safe to call when no model is loaded (or repeatedly): a missing ``model``
    global is ignored instead of raising ``NameError``.
    """
    # pop() removes the global binding like `del model`, but tolerates its absence.
    globals().pop("model", None)
    gc.collect()
    # The model is loaded with device_map="auto", so it may live on CUDA or MPS.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # Clear CUDA GPU memory
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()  # Clear MPS GPU memory
    


In [None]:
# Apply processing and store results
def process_row(row, filter_by_threshold):
    """Run the experiment for one dataset row.

    Args:
        row: Mapping-like row providing ``"true_sentence"`` and ``"false_sentence"``.
        filter_by_threshold: Threshold forwarded to ``run_experiment``.

    Returns:
        A two-element ``pd.Series`` with the true/false result, or two ``pd.NA``
        values when the experiment returned ``None`` (marks the row for removal).
    """
    true_json, false_json = run_experiment(
        true_sentence=row["true_sentence"],
        false_sentence=row["false_sentence"],
        filter_by_threshold=filter_by_threshold,
    )
    values = [pd.NA, pd.NA] if true_json is None else [true_json, false_json]
    return pd.Series(values)

In [None]:
def run_experiment_suite(dataset_csv_file_path: str, llm_models: list, prompt_repetitions: int = 1, filter_by_threshold: float = 0.5):
    """Run the experiment on every dataset row for each of the given models.

    For each model: load it, reload the dataset, compute the token probabilities
    per row, drop rows without a result, save CSV and plot, and finally release
    the model.

    Args:
        dataset_csv_file_path: Path of the examples CSV.
        llm_models: Model identifiers to evaluate, one after another.
        prompt_repetitions: Currently unused; accepted for interface compatibility.
        filter_by_threshold: Minimum induction score for a head to be selected.
    """
    global dataset
    for model_name in llm_models:
        print(f"Using device: {torch.device('mps') if torch.backends.mps.is_available() else 'cpu'}")
        initialize_model(model_name=model_name, tokenizer_name=model_name)
        try:
            load_dataset(path_to_csv=dataset_csv_file_path)

            # Row-wise apply; each row yields the pair of result columns
            dataset[["token_probability_true_sentence", "token_probability_false_sentence"]] = dataset.apply(
                lambda row: process_row(row, filter_by_threshold=filter_by_threshold),
                axis=1
            )

            # Drop rows where `process_row()` returned NaN
            dataset.dropna(subset=["token_probability_true_sentence", "token_probability_false_sentence"], inplace=True)

            # Create CSV result files saved in folders respective to the used LLM.
            save_result_csv(model_name=model_name, dataset_csv_file_path=dataset_csv_file_path)
        finally:
            # Release the loaded model even if one of the steps above failed.
            delete_model()


### Experiment Start

In [None]:
# Entry point of the notebook: run the whole suite once and report the wall-clock time.
print("Your current working directory:", os.getcwd())
# perf_counter() is used for measuring the elapsed interval
start_time = time.perf_counter()
# Evaluates every model in `models` on the dataset at CSV_PATH_DATASET with a head threshold of 0.10
run_experiment_suite(dataset_csv_file_path=CSV_PATH_DATASET, llm_models=models, prompt_repetitions=1, filter_by_threshold=0.10)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
print(f"Execution Time: {elapsed_time:.4f} seconds.")