# Import Required Libraries
Import libraries such as `torch`, `numpy`, `matplotlib`, and any other necessary modules.

In [1]:
# Import Required Libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem  # For calculating confidence intervals
from itertools import islice  # For slicing iterators
from functools import partial  # For creating partial functions
from tqdm.auto import tqdm  # For progress bars
import matplotlib.pyplot as plt
import seaborn as sns  # For enhanced visualizations
import plotly.express as px  # For interactive visualizations
import transformer_lens
import torch
# import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
# Import stuff
import torch
import torch.nn as nn
import einops
from fancy_einsum import einsum
from tqdm import tqdm
import plotly.express as px

from jaxtyping import Float
from functools import partial
import numpy as np

import sys
sys.path.insert(0, "..")
sys.path.insert(0, "../generators")
sys.path.insert(0, "generators")


from generators.string_reversal import StringReversalGenerator
import transformers
import itertools


# Ensure plots are displayed inline in Jupyter Notebook
%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=9

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=9


In [3]:
!pip install git+https://github.com/TransformerLensOrg/TransformerLens.git@refs/pull/900/head

Collecting git+https://github.com/TransformerLensOrg/TransformerLens.git@refs/pull/900/head
  Cloning https://github.com/TransformerLensOrg/TransformerLens.git (to revision refs/pull/900/head) to /tmp/pip-req-build-41u2xcmi
  Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens.git /tmp/pip-req-build-41u2xcmi
[0m  Running command git fetch -q https://github.com/TransformerLensOrg/TransformerLens.git refs/pull/900/head
  Running command git checkout -q b9e5bd8cc9390e6f3efcf0792e2b2dd023734357
  Resolved https://github.com/TransformerLensOrg/TransformerLens.git to commit b9e5bd8cc9390e6f3efcf0792e2b2dd023734357
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone


In [10]:
#model_name = "Qwen/Qwen2.5-1.5B-Instruct"
#checkpoint_path_old = "checkpoints/Qwen2.5-1.5B-Instruct_string_reversal_finetuned.weights"
#checkpoint_path_new = "checkpoints/Qwen2.5-1.5B-Instruct_string_reversal_finetuned_hooked.weights"
#model_name = "meta-llama/Llama-3.2-1B-Instruct"
#checkpoint_path_old = "checkpoints/old/llama_instruct_string_reversal_finetuned.weights"
#checkpoint_path_new = "checkpoints/llama_instruct_string_reversal_finetuned_hooked.weights"
model_name = "google/gemma-3-1b-it"
checkpoint_path_old = "checkpoints/gemma_3_1b_it_string_reversal_finetuned.weights"
checkpoint_path_new = "checkpoints/gemma_3_1b_it_string_reversal_finetuned_hooked.weights"

In [11]:

from transformer_lens.pretrained.weight_conversions import convert_gemma3_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_config(config)
model.load_state_dict(torch.load(checkpoint_path_old, weights_only=True))
hooked_config = transformer_lens.loading.get_pretrained_model_config(model_name)
new_state_dict = convert_gemma3_weights(model, hooked_config)
torch.save(new_state_dict, checkpoint_path_new)



In [12]:
config = transformer_lens.loading.get_pretrained_model_config(model_name)
model = transformer_lens.HookedTransformer(config)
model.load_and_process_state_dict(torch.load(checkpoint_path_new, weights_only=True))



In [13]:
print(model)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-25): 26 x TransformerBlock(
      (ln1): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln1_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2_post): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
        (q_norm): RMSNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
   

# Define Helper Functions
Define functions for computing accuracy, confidence intervals, and visualizing results.

In [14]:
# Define Helper Functions
best_heads = {}
def fix_attn_pattern_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint,
    attn_labels_1_idx,
    heads_to_reinforce,
    threshold
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    #value[:, :, :, :] += torch.tensor(neco, dtype=torch.float32).cuda()
    batch, heads, seq_len, _ = value.shape 
    
    layer = int(hook.name.split(".")[1])
    if layer not in heads_to_reinforce.keys():
        return
    for h in range(heads):
        if h not in heads_to_reinforce[layer]:
            continue
        #visualize_attention_heads(value, None, title="Example Attention Heatmap", heads=[heads_to_reinforce[layer]])
        acc = 0
        for i in range(seq_len):
            for j in range(seq_len):
                if (i, j) in attn_labels_1_idx and value[0, h, i, j] >= threshold:
                    value[0, h, i, j] += 0.6
                    #acc += value[0, h, i, j]
                    pass
        #best_heads[(layer, h)] = best_heads.get((layer, h), 0) + acc
        #print("After corrections")
        #visualize_attention_heads(value, None, title="Example Attention Heatmap", heads=[heads_to_reinforce[layer]])

    return value

def compute_accuracy(logits, targets, ignore_index=-100):
    """
    Compute accuracy by comparing predicted tokens with target tokens.
    Ignores tokens with the specified ignore_index.
    """
    predictions = torch.argmax(logits, dim=-1)  # Get predicted token indices
    correct = (predictions == targets) & (targets != ignore_index)  # Ignore padding tokens
    accuracy = correct.sum().item() / (targets != ignore_index).sum().item()
    return accuracy

def compute_confidence_interval(data, confidence=0.95):
    """
    Compute the confidence interval for a given dataset.
    Uses the standard error of the mean (SEM) and the t-distribution.
    """
    mean = np.mean(data)
    error = sem(data)  # Standard error of the mean
    margin = error * 1.96  # Approximation for 95% confidence interval
    return mean, margin

def plot_accuracy_curve(lengths, accuracies, confidence_intervals, title="Accuracy vs Input Length"):
    """
    Plot accuracy as a function of input length with enhanced confidence intervals using Seaborn.
    """
    lengths = np.array(lengths)
    accuracies = np.array(accuracies)
    confidence_intervals = np.array(confidence_intervals)

    # Create a DataFrame for Seaborn
    import pandas as pd
    data = pd.DataFrame({
        "Input Length": lengths,
        "Accuracy": accuracies,
        "Lower Bound": accuracies - confidence_intervals,
        "Upper Bound": accuracies + confidence_intervals
    })

    # Plot using Seaborn
    plt.figure(figsize=(12, 8))
    sns.lineplot(data=data, x="Input Length", y="Accuracy", marker="o", label="Accuracy")
    plt.fill_between(
        data["Input Length"],
        data["Lower Bound"],
        data["Upper Bound"],
        color="blue",
        alpha=0.2,
        label="Confidence Interval"
    )
    plt.xlabel("Input Length")
    plt.ylabel("Accuracy")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

# Generate Data Samples
Generate random data samples of varying lengths using the data generator.

In [15]:
# Generate Data Samples

# Initialize the data generator with the model's tokenizer
data_generator = StringReversalGenerator(
    tokenizer=model.tokenizer,
    seed=0,
    use_few_shot=True,
    use_instruction=True,
    apply_chat_template=True,
)

# Function to generate random samples of varying lengths
def generate_samples_for_lengths(lengths, num_samples=10):
    """
    Generate random data samples for a list of input lengths.

    Parameters:
        lengths (list): List of input lengths to generate samples for.
        num_samples (int): Number of samples to generate for each length.

    Returns:
        dict: A dictionary where keys are lengths and values are lists of samples.
    """
    samples_by_length = {}
    for length in tqdm(lengths, desc="Generating Samples"):
        data_generator.length = (length, length)  # Set the length range
        samples = list(islice(data_generator.generate_samples(), num_samples))
        samples_by_length[length] = samples
    return samples_by_length

# Define the input lengths and number of samples per length
input_lengths = [10, 20, 50, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800]
#input_lengths = [100, 800]
num_samples_per_length = 10

# Generate the data samples
samples_by_length = generate_samples_for_lengths(input_lengths, num_samples_per_length)

Generating Samples: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


# Measure Accuracy Across Input Lengths
Iterate over different input lengths, compute accuracy for a statistically significant number of samples, and store the results.

In [16]:
# Measure Accuracy Across Input Lengths
best_heads = {}
# Initialize lists to store results
lengths = []
accuracies = []
confidence_intervals = []
device = "cuda"
model.to(device)
is_attn_scores = lambda name: name.endswith("hook_pattern")

heads_to_reinforce=[
    (22, [7]),
    (24, [10]),
    (23, [0, 8, 3]),
    (27, [8, 11, 6]),
    (26, [1])

]

# Iterate over each input length
for length, samples in tqdm(samples_by_length.items(), desc="Measuring Accuracy"):
    length_accuracies = []  # Store accuracies for the current length
    threshold = 0
    # Iterate over the samples for the current length
    for sample in samples:
        input_ids, target_ids, attn_labels = sample[1], sample[2], sample[3]  # Extract input and target IDs
        input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)  # Add batch dimension
        target_ids = torch.tensor(target_ids).unsqueeze(0).to(device)
        attn_labels_1_idx = set()
        for i in range(len(attn_labels)):
            for j in range(len(attn_labels[0])):
                if attn_labels[i][j] == 1:
                    attn_labels_1_idx.add((i, j))
        # Run the model and compute accuracy
        with torch.no_grad():
            logits = model.run_with_hooks(input_ids, return_type="logits",
                                            fwd_hooks=[
                                                #(is_attn_scores, partial(fix_attn_pattern_hook, attn_labels_1_idx=attn_labels_1_idx, heads_to_reinforce=dict(heads_to_reinforce), threshold=threshold))
                                            ])
            
            accuracy = compute_accuracy(logits, target_ids)
            print(accuracy)
            length_accuracies.append(accuracy)

    # Compute mean accuracy and confidence interval for the current length
    mean_accuracy, confidence_margin = compute_confidence_interval(length_accuracies)
    lengths.append(input_ids.shape[1])
    accuracies.append(mean_accuracy)
    confidence_intervals.append(confidence_margin)

Moving model to device:  cuda


Measuring Accuracy:   0%|          | 0/18 [00:00<?, ?it/s]

  input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)  # Add batch dimension
  target_ids = torch.tensor(target_ids).unsqueeze(0).to(device)


0.0
0.0
0.0
0.0


Measuring Accuracy:   0%|          | 0/18 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [None]:
for item in sorted(best_heads.items(), key=lambda item: item[1], reverse=True):
    print(item[0][0], item[0][1], item[1].item())

In [None]:
best_heads

# Visualize Accuracy with Confidence Intervals
Plot the accuracy curve with confidence intervals using `matplotlib` or `plotly`.

In [None]:

# Plot the accuracy curve with confidence intervals
plot_accuracy_curve(
    lengths=lengths,
    accuracies=accuracies,
    confidence_intervals=confidence_intervals,
    title="Accuracy vs Input Length with Confidence Intervals"
)

In [None]:
import pandas as pd
import plotly.graph_objects as go

# Create a DataFrame for Plotly
data = pd.DataFrame({
    "Input Length": lengths,
    "Accuracy": accuracies,
    "Lower Bound": np.array(accuracies) - np.array(confidence_intervals),
    "Upper Bound": np.array(accuracies) + np.array(confidence_intervals)
})

# Create the main accuracy line
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=data["Input Length"],
    y=data["Accuracy"],
    mode="lines+markers",
    name="Accuracy",
    line=dict(color="blue"),
    marker=dict(size=8)
))

# Add the confidence interval as a shaded area
fig.add_trace(go.Scatter(
    x=np.concatenate([data["Input Length"], data["Input Length"][::-1]]),
    y=np.concatenate([data["Upper Bound"], data["Lower Bound"][::-1]]),
    fill="toself",
    fillcolor="rgba(0, 0, 255, 0.2)",  # Light blue fill
    line=dict(color="rgba(255,255,255,0)"),  # No border line
    hoverinfo="skip",
    name="Confidence Interval"
))

# Update layout for better visualization
fig.update_layout(
    title="Accuracy vs Input Length with Confidence Intervals",
    xaxis_title="Input Length",
    yaxis_title="Accuracy",
    template="plotly_white",
    showlegend=True
)

fig.show()

In [None]:
#qwen_vanilla_data = data
#vanilla_data = data3
#extrapolate_data = data2
#vanilla_data.to_csv("string_reversal_vanilla_extrapolate.csv")
#extrapolate_data.to_csv("string_reversal_fix_extrapolate.csv")
#qwen_vanilla_data.to_csv("qwen_string_reversal_vanilla_extrapolate.csv")
qwen_reinforced_data = data
qwen_reinforced_data.to_csv("qwen_string_reversal_reinforced_extrapolate.csv")

In [None]:
vanilla_data

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go

# Load the CSV files
string_reversal_vanilla = pd.read_csv("llama_string_reversal_vanilla_extrapolate.csv")
string_reversal_reinforced = pd.read_csv("llama_string_reversal_reinforcement_extrapolate.csv")
value_assignment_vanilla = pd.read_csv("llama_value_assignment_vanilla_extrapolate.csv")
value_assignment_reinforced = pd.read_csv("llama_value_assign_reinforcement_extrapolate.csv")

# Create the figure
fig = go.Figure()

# Add the string reversal vanilla model (LLama3-SFT) accuracy curve
fig.add_trace(go.Scatter(
    x=string_reversal_vanilla["Input Length"],
    y=string_reversal_vanilla["Accuracy"],
    mode="lines+markers",
    name="String Reversal - Llama-3.2-1B-Instruct - Vanilla",
    line=dict(color="blue", width=2),
    marker=dict(size=6)
))

# Add the confidence interval for the string reversal vanilla model
fig.add_trace(go.Scatter(
    x=np.concatenate([string_reversal_vanilla["Input Length"], string_reversal_vanilla["Input Length"][::-1]]),
    y=np.concatenate([string_reversal_vanilla["Upper Bound"], string_reversal_vanilla["Lower Bound"][::-1]]),
    fill="toself",
    fillcolor="rgba(0, 0, 255, 0.2)",  # Light blue fill
    line=dict(color="rgba(255,255,255,0)"),
    hoverinfo="skip",
    name="String Reversal - Llama-3.2-1B-Instruct - Vanilla - Confidence Interval"
))

# Add the string reversal reinforced model (LLama3-SFT-AttentionReinforcement) accuracy curve
fig.add_trace(go.Scatter(
    x=string_reversal_reinforced["Input Length"],
    y=string_reversal_reinforced["Accuracy"],
    mode="lines+markers",
    name="String Reversal - Llama-3.2-1B-Instruct - AttentionReinforcement (Ours)",
    line=dict(color="green", width=2),
    marker=dict(size=6)
))

# Add the confidence interval for the string reversal reinforced model
fig.add_trace(go.Scatter(
    x=np.concatenate([string_reversal_reinforced["Input Length"], string_reversal_reinforced["Input Length"][::-1]]),
    y=np.concatenate([string_reversal_reinforced["Upper Bound"], string_reversal_reinforced["Lower Bound"][::-1]]),
    fill="toself",
    fillcolor="rgba(0, 255, 0, 0.2)",  # Light green fill
    line=dict(color="rgba(255,255,255,0)"),
    hoverinfo="skip",
    name="String Reversal - Llama-3.2-1B-Instruct - AttentionReinforcement (Ours) - Confidence Interval"
))

# Add the value assignment vanilla model (LLama3-SFT) accuracy curve
fig.add_trace(go.Scatter(
    x=value_assignment_vanilla["Input Length"],
    y=value_assignment_vanilla["Accuracy"],
    mode="lines+markers",
    name="Value Assignment - Llama-3.2-1B-Instruct - Vanilla",
    line=dict(color="red", width=2),
    marker=dict(size=6)
))

# Add the confidence interval for the value assignment vanilla model
fig.add_trace(go.Scatter(
    x=np.concatenate([value_assignment_vanilla["Input Length"], value_assignment_vanilla["Input Length"][::-1]]),
    y=np.concatenate([value_assignment_vanilla["Upper Bound"], value_assignment_vanilla["Lower Bound"][::-1]]),
    fill="toself",
    fillcolor="rgba(255, 0, 0, 0.2)",  # Light red fill
    line=dict(color="rgba(255,255,255,0)"),
    hoverinfo="skip",
    name="Value Assignment - Llama-3.2-1B-Instruct - Vanilla - Confidence Interval"
))

# Add the value assignment reinforced model (LLama3-SFT-AttentionReinforcement) accuracy curve
fig.add_trace(go.Scatter(
    x=value_assignment_reinforced["Input Length"],
    y=value_assignment_reinforced["Accuracy"],
    mode="lines+markers",
    name="Value Assignment - Llama-3.2-1B-Instruct - AttentionReinforcement (Ours)",
    line=dict(color="purple", width=2),
    marker=dict(size=6)
))

# Add the confidence interval for the value assignment reinforced model
fig.add_trace(go.Scatter(
    x=np.concatenate([value_assignment_reinforced["Input Length"], value_assignment_reinforced["Input Length"][::-1]]),
    y=np.concatenate([value_assignment_reinforced["Upper Bound"], value_assignment_reinforced["Lower Bound"][::-1]]),
    fill="toself",
    fillcolor="rgba(128, 0, 128, 0.2)",  # Light purple fill
    line=dict(color="rgba(255,255,255,0)"),
    hoverinfo="skip",
    name="Value Assignment - Llama-3.2-1B-Instruct - AttentionReinforcement (Ours) - Confidence Interval"
))

# Update layout for better visualization
fig.update_layout(
    title="Accuracy vs Input Length for String Reversal and Value Assignment Tasks",
    xaxis_title="Input Length (tokens)",
    yaxis_title="Accuracy",
    template="plotly_white",
    showlegend=True
)

# Show the plot
fig.show()