**Official implementation of paper "Revisiting In-context Learning Inference Circuit in Large Language Models" (ICLR 2025)**:

### Forerunner Token Head Counting

This experiment counts the number of forerunner token heads in each layer, along with the maximum copy magnitude, for Fig. 5 (Middle) and Figs. 22-25.

Author: Hakaze Cho, yfzhao@jaist.ac.jp, 2024/09

Organized, commented, and modified by: Hakaze Cho, 2025/01/28

**Part I: Import, Define, and Load Everything**

What you should do:
1. [Cell 1] Set the path (relative to your working directory) to the directory containing the README.md file.
2. [Cell 2] Define your experiment parameters.
3. Run Cells 1-4.


In [None]:
# Cell 1: Import libraries and change the working directory.

## Change the working directory
import os
try:
    # Change to the path from your working directory to the directory containing the README.md file.
    # Re-running this cell after a successful chdir typically fails here, because the relative
    # path no longer exists from the new working directory; that case is handled below.
    os.chdir("ICL_Inference_Dynamics_Released")
except OSError:
    # Only filesystem errors are expected (missing directory, not a directory, no permission).
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    print("Already in the correct directory or the directory does not exist.")

## Import libraries
from util import load_model_and_data, inference
import StaICC
import matplotlib.pyplot as plt
import pickle
import numpy as np

## Plot styling shared by the figures in this notebook: matplotlib's default style with a
## serif font family whose first choice is Cambria Math.
plt.style.use('default')
plt.rc('font', family='Cambria Math')
_serif_fallbacks = plt.rcParams['font.serif']
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Cambria Math'] + _serif_fallbacks

## Calculate the attention threshold from the prompt.
def get_thereshold_from_prompt(tokenizer, prompt, induction_threthold_times):
    """Return the per-prompt attention threshold ``induction_threthold_times / n_tokens``.

    ``n_tokens`` is the length of ``tokenizer(prompt)['input_ids']``, so longer prompts
    receive proportionally smaller thresholds (e.g. 5 / n_t when the constant is 5).

    Args:
        tokenizer: callable returning a mapping with an ``'input_ids'`` sequence.
        prompt: the prompt text to tokenize.
        induction_threthold_times: the numerator constant of the threshold.
    """
    token_ids = tokenizer(prompt)['input_ids']
    n_tokens = len(token_ids)
    return induction_threthold_times / n_tokens

In [2]:
# Cell 2: Model and huggingface token configurations

## The huggingface model name to be tested as the LM for ICL.
## Recommended: "meta-llama/Meta-Llama-3-8B", "tiiuae/falcon-7b", "meta-llama/Meta-Llama-3-70B", "tiiuae/falcon-40b"
ICL_model_name = "meta-llama/Meta-Llama-3-70B"

## Whether to use the quantized version of the model.
## The models listed below are loaded unquantized; any other model is quantized.
## Recommended: Keep it default.
quantized = ICL_model_name not in ("meta-llama/Meta-Llama-3-8B", "EleutherAI/pythia-6.9b", "tiiuae/falcon-7b")

## The huggingface token to access the model. If you use the Llama model, you need to set this.
huggingface_token = "your token here"

## Use CPU instead of GPU to process this experiment.
## Recommended: Only use this when the intermediate results already exist in path `experiment_matrial` (see Cells 4 and 5).
cpu_process = False

# Experiment parameters

## The number of demonstrations in each prompt. Recommended: 0, 1, 2, 4, 8, 12.
k = 4

## The index of the dataset used from the StaICC library. Alternatives: 0, 1, 2, 3, 4, 5. See the README.md for more information.
dataset_index = 2

## Whether the **last** label in the prompt is the correct label (True) or a wrong label (False).
corr_label = True

## Force ICL_model to reload, even if ICL_model is already defined.
## Recommended: False.
model_forced_reload = False

## Force the experiment to be redone, even if the intermediate results already exist in path `experiment_matrial`.
## Recommended: False.
experiment_forced_redo = False

## The constant in the threshold used to judge a forerunner token head (e.g., the "5" in 5/n_t,
## where n_t is the number of tokens in the prompt).
forerunner_token_head_threshold_times = 5

In [None]:
# Cell 3: Load the data and build the test inputs.

## Build the StaICC benchmark with k demonstrations, pick the dataset at `dataset_index`,
## and turn it into (prompts, queries) using "label_words"; `corr_label` controls whether
## the last label in each prompt is correct.
bench = StaICC.Normal(k)
dataset_experimentor = bench[dataset_index]
prompts, queries = load_model_and_data.load_data_from_StaICC_experimentor(
    dataset_experimentor,
    "label_words",
    corr_label,
)

In [None]:
# Cell 4: Load the model. If the intermediate results in path `experiment_matrial` (shown below) have been detected,
# loading is skipped, unless the tokenizer `ICL_tknz` is still undefined: Cells 6 and 7 need it to compute
# the per-prompt thresholds even when the cached results are used.

## NOTE: this file name (including the doubled comma) must stay identical to the one built in Cell 5,
## otherwise previously cached results will not be found.
data_file_name = "experiment_matrial/" + ICL_model_name.replace('/', '_')+ "," + ",copy_Hidd_att" + ',' + ("" if corr_label else "wrong,") + str(k) + ',' + str(dataset_index + 1) + ".pickle"
need_model = not os.path.exists(data_file_name) or experiment_forced_redo
## At notebook top level, globals() holds every variable defined by earlier cells.
model_missing = "ICL_model" not in globals()
tokenizer_missing = "ICL_tknz" not in globals()
if (need_model and (model_missing or model_forced_reload)) or tokenizer_missing:
    ## load_ICL_model returns the model together with its tokenizer, so a missing tokenizer
    ## requires a full load even when the cached results make the model itself unnecessary.
    ICL_model, ICL_tknz = load_model_and_data.load_ICL_model(ICL_model_name, huggingface_token = huggingface_token, quantized = quantized, device = "cpu" if cpu_process else "cuda")
    loaded = True

**Part II: Run the Experiment**

What you should do:

1. Run Cells 5-7.

In [None]:
# Cell 5: Run inference to get the hidden states and save the intermediate results. If the intermediate results in path `experiment_matrial` (shown below) have been detected, load them instead.
# Cells 6 and 7 read ICL_hidden_states[1], indexed as [sample][layer][head][...].

## Must match the file name built in Cell 4 (including the doubled comma) so the cache is found.
data_file_name = "experiment_matrial/" + ICL_model_name.replace('/', '_')+ "," + ",copy_Hidd_att" + ',' + ("" if corr_label else "wrong,") + str(k) + ',' + str(dataset_index + 1) + ".pickle"
if os.path.exists(data_file_name) and not experiment_forced_redo:
    ## NOTE: pickle.load can execute arbitrary code; only load cache files produced by this notebook.
    with open(data_file_name, 'rb') as f:
        ICL_hidden_states = pickle.load(f)
    print("Intermediate results loaded")
else:
    ICL_hidden_states = inference.step2_get_fl_feature_and_lastftol_attention(ICL_model, ICL_tknz, prompts)
    ## open(..., 'wb') does not create parent directories; create the cache directory first so the
    ## (expensive) inference result is not lost on a FileNotFoundError.
    os.makedirs(os.path.dirname(data_file_name), exist_ok=True)
    with open(data_file_name, 'wb') as f:
        pickle.dump(ICL_hidden_states, f)

In [None]:
# Cell 6: Calculate the counted times for each attention head.
#
# counted_times_for_each_head[layer][head] = number of samples for which
# ICL_hidden_states[1][sample][layer][head][0] exceeds the per-prompt threshold
# forerunner_token_head_threshold_times / (number of prompt tokens).

attention_per_sample = ICL_hidden_states[1]
n_samples = len(attention_per_sample)
n_layers = len(attention_per_sample[0])
n_heads = len(inference.get_copy_magnitude_for_single_layer(attention_per_sample, 0, 0))
counted_times_for_each_head = [[0] * n_heads for _ in range(n_layers)]

## The threshold depends only on the prompt, so tokenize each prompt once instead of once per layer.
thresholds = [
    get_thereshold_from_prompt(ICL_tknz, prompts[sample], forerunner_token_head_threshold_times)
    for sample in range(n_samples)
]

for layers in range(n_layers):
    for sample in range(n_samples):
        thre = thresholds[sample]
        magnitudes = inference.get_copy_magnitude_for_single_layer(attention_per_sample, sample, layers)
        ## NOTE(review): `magnitudes` only determines the number of heads here; the test below reads the
        ## raw entry [...][headindex][0], whereas Cell 7 compares the `magnitudes` values against the same
        ## threshold. Confirm that the two criteria are intended to differ.
        for headindex in range(len(magnitudes)):
            if attention_per_sample[sample][layers][headindex][0] > thre:
                counted_times_for_each_head[layers][headindex] += 1

In [16]:
# Cell 7: Calculate the counted times summary for each layer.
#
# For each layer:
#   mean_max_magnitude[layer] = mean over samples of the largest copy magnitude among that layer's heads;
#   mean_head_count[layer]    = mean over samples of the number of heads whose copy magnitude exceeds
#                               forerunner_token_head_threshold_times / (number of prompt tokens).

mean_max_magnitude = []
mean_head_count = []

n_samples = len(ICL_hidden_states[1])
## The threshold depends only on the prompt, so tokenize each prompt once rather than once per layer.
thresholds = [
    get_thereshold_from_prompt(ICL_tknz, prompts[sample], forerunner_token_head_threshold_times)
    for sample in range(n_samples)
]

for layers in range(len(ICL_hidden_states[1][0])):
    max_per_sample = []    # per-sample maximum copy magnitude in this layer
    count_per_sample = []  # per-sample number of heads above the threshold in this layer
    for sample in range(n_samples):
        thre = thresholds[sample]
        magnitudes = inference.get_copy_magnitude_for_single_layer(ICL_hidden_states[1], sample, layers)
        max_per_sample.append(max(magnitudes))
        ## `sum(1 for ...)` yields a plain int count, like the original manual counter.
        count_per_sample.append(sum(1 for magnitude in magnitudes if magnitude > thre))
    mean_max_magnitude.append(np.mean(max_per_sample))
    mean_head_count.append(np.mean(count_per_sample))

**Part III: Plot and Save the Result**

What you should do:

1. Run Cells 8-11. You can set your own file name and directory for saving the results in Cells 9 and 11.

In [None]:
# Cell 8: Plot the counted times for each attention head in a heatmap.
#
# Rows are transformer blocks (ticks labelled 1-based, every 10 blocks); columns are heads.
# NOTE(review): the colour range is fixed to [128, 512]; presumably chosen for the default
# number of test samples. Adjust it if the sample count differs.

n_blocks = len(counted_times_for_each_head)
label_kind = "Correct Label" if corr_label else "Wrong Label"
heatmap_cmap = 'Blues' if corr_label else "Purples"

r = plt.imshow(counted_times_for_each_head, cmap=heatmap_cmap, vmin=128, vmax=512)
plt.yticks(range(0, n_blocks, 10), range(1, n_blocks + 1, 10))
plt.colorbar(r, shrink=0.5)
plt.xlabel('Head #', fontsize=12)
plt.ylabel('Transformer Block', fontsize=12)
plt.title(f"Forerunner Token Head to {label_kind}\n Dataset {dataset_index + 1} with k = {k}\n model: {ICL_model_name}", fontsize=12)

In [14]:
# Cell 9: Save the heat map data.
# Result file organization:
# counted_times_for_each_head: list[layer_index][head_index] = number of samples in which the head passed the Cell 6 threshold

import os
import pickle

data_file_name = "data/" + ICL_model_name.replace('/', '_')+ "," + "copy_magnitude,headstat"  + ',' + ("" if corr_label else "wrong,") + str(k) + ',' + str(dataset_index + 1) + ".pickle"
## open(..., 'wb') does not create parent directories; make sure the output directory exists.
os.makedirs(os.path.dirname(data_file_name), exist_ok=True)
with open(data_file_name, 'wb') as f:
    pickle.dump(counted_times_for_each_head, f)

In [None]:
# Cell 10: Plot the figure similar to the Fig. 5 (Middle)
#
# Left axis: mean forerunner token head count per block (line).
# Right axis: mean maximum copy magnitude per block (shaded area), fixed to [0, 1.1].

series_color = "#023858" if corr_label else "#ff7f0e"
series_label = "Correct Label" if corr_label else "Wrong Label"
n_blocks = len(mean_head_count)

plt.figure(figsize=(4, 3))
ax = plt.gca()

ax.plot(range(1, n_blocks + 1), mean_head_count, label=series_label, color=series_color)

ax.set_xlim(-1, n_blocks + 1)
ax.set_xlabel("Transformer Block Number", fontsize=12)
ax.set_ylabel("Forerunner Token Head #", fontsize=12)

# Re-tick the x axis: force the second automatic tick to 1 (blocks are numbered from 1) and
# drop the outermost ticks. The y-limits are captured before re-ticking and re-applied afterwards.
saved_ylim = ax.get_ylim()
ticks = ax.get_xticks()
ticks[1] = 1
plt.xticks(ticks[1:-1])
ax.set_ylim(saved_ylim)

ax2 = ax.twinx()
ax2.set_ylabel("Maximum Copy Magnitude", fontsize=12)
ax2.fill_between(range(1, len(mean_max_magnitude) + 1), mean_max_magnitude, color=series_color, alpha=0.2)
ax2.set_ylim((0, 1.1))

# Draw the head-count line below the magnitude axis layer, as originally ordered.
ax.set_zorder(3)
ax2.set_zorder(4)

In [23]:
# Cell 11: Save the figure data.
# Result file organization:
# (mean_max_magnitude: list[layer_index] = mean over samples of the per-layer maximum copy magnitude,
#  mean_head_count: list[layer_index] = mean over samples of the forerunner token head count)

import os

data_file_name = "data/" + ICL_model_name.replace('/', '_') + ",copy_magnitude" + ',' + ("" if corr_label else "wrong,") + str(k) + ',' + str(dataset_index + 1) + ".pickle"
## open(..., 'wb') does not create parent directories; make sure the output directory exists.
os.makedirs(os.path.dirname(data_file_name), exist_ok=True)
with open(data_file_name, 'wb') as f:
    pickle.dump([mean_max_magnitude, mean_head_count], f)