**Official implementation of paper "Token-based Decision Criteria Are Suboptimal in In-context Learning" (NAACL 2025)**:

### Source code for the Analysis in Sec. 5.1.

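This notebook analyzes Hidden Calibration and other methods on in-context learning tasks. It measures the inter-category overlap of the decision scores produced by vanilla ICL, Hidden Calibration, Contextual Calibration, and Domain Calibration.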

Used in experiments of:

1. The experiments described in Sec. 5.1; the results are shown in Figs. 7, 8, and 9 of the paper.

Author: Hakaze Cho, yfzhao@jaist.ac.jp

**Experiment Configs**

- `huggingface_model_name`: should be a model name from the HuggingFace model hub. For example, `facebook/opt-2.7b`.
- `huggingface_token`: should be a HuggingFace API token. It is only needed for gated models such as `Llama2`.
- `quantization`: should be a boolean value. If it is `True`, the model is loaded with 4-bit (NF4) bitsandbytes quantization.
- `dataset_name`: should be one of the following dataset names: `"SemEvalR", "SemEvalL", "poem_sentiment", "TEE", "TEH", "TES", "FP", "AGNews", "MR", "hate_speech"`
- `k`: the number of demonstrations in each ICL prompt.
- `calibration_sample_number_for_each_label`: the number of calibration samples per label category used by some of the calibration methods, i.e., the horizontal axis of Fig. 6 in the paper.

In [1]:
# Configs
import os

huggingface_model_name = "facebook/opt-2.7b"
# HuggingFace access token, read from the HF_TOKEN environment variable (None when unset).
# The value is passed to every from_pretrained call. A placeholder string such as "API_TOKEN"
# would be sent as a real credential and may be rejected even for public models. With None,
# huggingface_hub presumably falls back to a locally cached login.
# Only gated models (e.g. Llama2) need a token; you may also paste one here directly.
huggingface_token = os.environ.get("HF_TOKEN")
quantization = False

dataset_name = "SemEvalR" # Alternative: "SemEvalR", "SemEvalL", "poem_sentiment", "TEE", "TEH", "TES", "FP", "AGNews", "MR", "hate_speech"

k = 4  # number of demonstrations per ICL prompt
calibration_sample_number_for_each_label = 16

**Load everything.**

In [2]:
# Import libraries, and nessessary definitions

import sys
sys.path.append("hidden_calibration_released") # Replace with the path from the working directory to the root of this project. If the working directory is already the root of the project, this line is not needed.

import util.prompting as prompting
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import util.dataset_loader as dataset_loader
import util.calibrations as calibrations
import numpy as np
from tqdm import tqdm
from scipy import spatial
from scipy.integrate import trapezoid
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
import seaborn as sns

def softmax(x):
    """Return the softmax of ``x``, normalized over all of its elements.

    The maximum is subtracted before exponentiating. Softmax is invariant to adding a
    constant, so the result is mathematically unchanged, but the computation is now stable:
    large inputs no longer overflow ``np.exp`` (``softmax([1000, 1000])`` used to give
    ``[nan, nan]``), and very negative inputs no longer underflow to ``0 / 0``.

    Args:
        x: array-like of real scores (any shape; normalization is over all elements).

    Returns:
        np.ndarray with the same shape as ``x``, non-negative and summing to 1.
        Empty input yields an empty array, as before.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max would raise on an empty array; keep the old empty-in/empty-out behaviour.
        return np.exp(x)
    exps = np.exp(x - np.max(x))
    return exps / np.sum(exps)

# Maps the short dataset names accepted by the `dataset_name` config to their loader classes
# in util.dataset_loader.
dataset_name_to_class = dict(
    SemEvalR=dataset_loader.SemEval2014_Restaurants,
    SemEvalL=dataset_loader.SemEval2014_Laptops,
    poem_sentiment=dataset_loader.poem_sentiment,
    TEE=dataset_loader.tweet_eval_emotion,
    TEH=dataset_loader.tweet_eval_hate,
    TES=dataset_loader.tweet_eval_sentiment,
    FP=dataset_loader.financial_phrasebank,
    AGNews=dataset_loader.agnews,
    MR=dataset_loader.rooten_tomato,
    hate_speech=dataset_loader.hate_speech18,
)

# Reset to matplotlib defaults, then use a serif font family with Times New Roman
# tried first, ahead of the default serif fallbacks.
plt.style.use('default')
plt.rc('font', family='Times New Roman')
serif_fallbacks = plt.rcParams['font.serif']
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'] + serif_fallbacks,
})

In [None]:
# Load model and tokenizer from Huggingface

torch.cuda.empty_cache()

tokenizer = AutoTokenizer.from_pretrained(huggingface_model_name, token=huggingface_token)

if not quantization:
    # Unquantized weights, moved explicitly onto the default CUDA device.
    model = AutoModelForCausalLM.from_pretrained(huggingface_model_name, token=huggingface_token).cuda()
else:
    # 4-bit NF4 weights with double quantization; compute dtype is bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        huggingface_model_name,
        token=huggingface_token,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        ),
    )

In [4]:
# Load the training and testing data

dataset_class = dataset_name_to_class[dataset_name]
train_data = dataset_class().default_training_division()
test_data = dataset_class().default_testing_division()

# AGNews only: both splits go through cut_by_length(226). This presumably limits sample
# length (e.g. to keep k-shot prompts short); see util.dataset_loader for the exact rule.
if dataset_name == "AGNews":
    for split in (train_data, test_data):
        split.cut_by_length(226)

**Train all the calibration methods.**

In [None]:
# Train the Contextual Calibration and Domain Calibration
# Both estimators receive the same total budget: calibration samples per label x number of labels.
# Their outputs are later used as per-label divisors of the label probabilities.

torch.cuda.empty_cache()
calibration_sample_number_total = calibration_sample_number_for_each_label * len(train_data.label_space)
observed_background_prob_for_CC = calibrations.empty_query_base_logits(
    model, tokenizer, train_data, k, None, calibration_sample_number_total
)
observed_background_prob_for_DC = calibrations.domain_text_base_logits(
    model, tokenizer, train_data, k, None, calibration_sample_number_total
)

In [None]:
# Train the Hidden Calibration
# For every label, run up to `calibration_sample_number_for_each_label` k-shot prompts whose
# query has that label, keep the last hidden state of the final prompt token, and average
# those hidden states per label into one centroid each.

## Collect the hidden states for the calibration samples.
ground_truth_labels_for_calibration_samples = []
observed_full_vocabulary_probabilities = []  # NOTE(review): never filled or read in this notebook.
observed_last_hidden_state = []

# torch.no_grad() only disables autograd when used as a context manager; the previous bare
# call `torch.no_grad()` was a no-op, so a gradient graph was built for every forward pass.
with torch.no_grad():
    for j in range(len(train_data.label_space)):
        label_set = dataset_loader.get_label_set_from_label_index(train_data, j)
        for i in tqdm(label_set[:calibration_sample_number_for_each_label]):
            torch.cuda.empty_cache()
            prpt = prompting.default_prompting(train_data, k, query_index=i)
            tknzd_data = tokenizer(prpt[0], return_tensors="pt").input_ids.cuda()
            result = model(tknzd_data, output_hidden_states = True)
            # Last entry of hidden_states, last batch element, last token position.
            one_last_hidden_state = result.hidden_states[-1][-1][-1].detach().cpu().numpy()
            observed_last_hidden_state.append(one_last_hidden_state)
            ground_truth_labels_for_calibration_samples.append(train_data.label_space.index(prpt[1]))

## Divide the collected hidden states by label.
observed_last_hidden_state_indexed_by_label = [[] for _ in train_data.label_space]
for gt_label_index, hidden_state in zip(ground_truth_labels_for_calibration_samples, observed_last_hidden_state):
    observed_last_hidden_state_indexed_by_label[gt_label_index].append(hidden_state)

## Calculate the centroids from the collected hidden states
# One centroid (element-wise mean hidden state) per label, in label_space order.
observed_last_hidden_state_CENTROID_indexed_by_label = [
    np.mean(states, axis=0) for states in observed_last_hidden_state_indexed_by_label
]

**Run inference and calculate the overlap**

In [7]:
# Initialize lists that store the inference output of each input sample, indexed by the
# sample's ground-truth label: outer index = label index, inner list = per-sample score vectors.

logits_set_wrt_gt_label = [[] for _ in range(len(train_data.label_space))]
hidden_state_set_wrt_gt_label = [[] for _ in range(len(train_data.label_space))]

In [None]:
# Fill the aforementioned list by inference
# For every test sample this computes two score vectors over the labels:
#   - vanilla ICL: softmax over the next-token logits of each label's token;
#   - Hidden Calibration: softmax over negative Euclidean distances between the final
#     hidden state and each label centroid.
# Each vector is stored under the sample's ground-truth label, and hits are counted for accuracy.

count = 0
correct_predicted_hidden_calibration = 0
correct_predicted_vanilla = 0

# The vocabulary id scored for each label is the last token of the tokenized label string.
# It does not depend on the query, so compute it once instead of once per test sample.
label_token_ids = [tokenizer(label).input_ids[-1] for label in test_data.label_space]

# torch.no_grad() only disables autograd when used as a context manager; the previous bare
# call was a no-op, so a gradient graph was built for every forward pass.
with torch.no_grad():
    for i in tqdm(range(test_data.get_max())):
        count += 1
        torch.cuda.empty_cache()
        prpt = prompting.default_prompting(test_data, k, query_index=i)
        gt_label_index = test_data.label_space.index(prpt[1])
        tknzd_data = tokenizer(prpt[0], return_tensors="pt").input_ids.cuda()
        result = model(tknzd_data, output_hidden_states = True)
        # Next-token logits at the final prompt position.
        result_vector = result['logits'][0][-1].detach().cpu().numpy()
        # Last entry of hidden_states, last batch element, last token position.
        ahidden_state = result.hidden_states[-1][-1][-1].detach().cpu().numpy()

        # Vanilla ICL decision.
        direct_label_logits = softmax([result_vector[index] for index in label_token_ids])
        logits_set_wrt_gt_label[gt_label_index].append(direct_label_logits)
        if gt_label_index == np.argmax(direct_label_logits):
            correct_predicted_vanilla += 1

        # Hidden Calibration decision. The +15 is a constant shift: it changes neither the
        # softmax nor the argmax, and presumably only guards exp() against underflow.
        distance_label_logits = [
            -spatial.distance.euclidean(ahidden_state, centroid) + 15
            for centroid in observed_last_hidden_state_CENTROID_indexed_by_label
        ]
        hidden_state_set_wrt_gt_label[gt_label_index].append(softmax(distance_label_logits))
        if gt_label_index == np.argmax(distance_label_logits):
            correct_predicted_hidden_calibration += 1

In [9]:
# Initialize the heatmap and mean lists for the results.
# Each heatmap is an n_labels x n_labels grid of zeros with independent rows. Only cells (i, j)
# with j < i are filled by the overlap cells. Each *_mean list collects every pairwise overlap.

n_labels = len(train_data.label_space)

vanilla_result_heatmap = [[0] * n_labels for _ in range(n_labels)]
hidden_result_heatmap = [[0] * n_labels for _ in range(n_labels)]
contextual_result_heatmap = [[0] * n_labels for _ in range(n_labels)]
domain_result_heatmap = [[0] * n_labels for _ in range(n_labels)]

vanilla_result_mean, hidden_result_mean, contextual_result_mean, domain_result_mean = [], [], [], []

In [None]:
# Calculate the overlap results by vanilla ICL, fill the corresponding heatmap and mean lists for the results, and show the density plots.
# Calculation described in the Appendix B.5.1.
#
# For each label pair (i, j) with j < i:
#   1. Compute the score margin res[i] - res[j] for every test sample of ground-truth label i,
#      and separately for every sample of label j.
#   2. Fit a Gaussian KDE to each set of margins.
#   3. Integrate min(KDE_i, KDE_j) over 500 grid points spanning both sets.
# Plot legend: 'Positive Samples' = ground truth i, 'Negative Samples' = ground truth j.
# NOTE(review): gaussian_kde needs at least two distinct values per set; a label with fewer
# test samples will raise here.

for i in range(len(train_data.label_space)):
    for j in range(i):
        margins_labeli = [res[i] - res[j] for res in logits_set_wrt_gt_label[i]]
        margins_labelj = [res[i] - res[j] for res in logits_set_wrt_gt_label[j]]
        density1 = gaussian_kde(margins_labeli)
        density2 = gaussian_kde(margins_labelj)
        x = np.linspace(min(min(margins_labeli), min(margins_labelj)),
                        max(max(margins_labeli), max(margins_labelj)), 500)
        # Evaluate each KDE once and reuse the values for both the integral and the plot.
        pdf1 = density1(x)
        pdf2 = density2(x)
        # Same trapezoidal rule as np.trapz, which NumPy 2.0 deprecates.
        overlap_area = trapezoid(np.minimum(pdf1, pdf2), x)
        vanilla_result_heatmap[i][j] = overlap_area
        vanilla_result_mean.append(overlap_area)

        plt.figure(figsize=(4, 4), dpi=300)
        plt.plot(x, pdf1,
                 label='Positive Samples',
                 color='Royalblue',
                 linewidth = 3)
        plt.plot(x, pdf2,
                 label='Negative Samples',
                 color='Coral',
                 linewidth = 3)
        plt.fill_between(x, pdf1, 0, color='Royalblue', alpha=0.2)
        plt.fill_between(x, pdf2, 0, color='Coral', alpha=0.2)
        plt.legend(prop = {'size': 8})
        plt.tick_params(width=2, labelsize=24)
        plt.show()

In [None]:
# Calculate the overlap results by contextual calibration, fill the corresponding heatmap and mean lists for the results, and show the density plots.
# Calculation described in the Appendix B.5.1.
#
# Same procedure as the vanilla cell, except each label probability is first divided by that
# label's CC background probability:
#   margin = res[i] / bg[i] - res[j] / bg[j]
# Plot legend: 'Positive Samples' = ground truth i, 'Negative Samples' = ground truth j.
# NOTE(review): gaussian_kde needs at least two distinct values per set.

for i in range(len(train_data.label_space)):
    for j in range(i):
        cc_margins_labeli = [
            res[i] / observed_background_prob_for_CC[i] - res[j] / observed_background_prob_for_CC[j]
            for res in logits_set_wrt_gt_label[i]
        ]
        cc_margins_labelj = [
            res[i] / observed_background_prob_for_CC[i] - res[j] / observed_background_prob_for_CC[j]
            for res in logits_set_wrt_gt_label[j]
        ]
        density1 = gaussian_kde(cc_margins_labeli)
        density2 = gaussian_kde(cc_margins_labelj)
        x = np.linspace(min(min(cc_margins_labeli), min(cc_margins_labelj)),
                        max(max(cc_margins_labeli), max(cc_margins_labelj)), 500)
        # Evaluate each KDE once and reuse the values for both the integral and the plot.
        pdf1 = density1(x)
        pdf2 = density2(x)
        # Same trapezoidal rule as np.trapz, which NumPy 2.0 deprecates.
        overlap_area = trapezoid(np.minimum(pdf1, pdf2), x)
        contextual_result_heatmap[i][j] = overlap_area
        contextual_result_mean.append(overlap_area)

        plt.figure(figsize=(4, 4), dpi=300)
        plt.plot(x, pdf1,
                 label='Positive Samples',
                 color='Royalblue',
                 linewidth = 3)
        plt.plot(x, pdf2,
                 label='Negative Samples',
                 color='Coral',
                 linewidth = 3)
        plt.fill_between(x, pdf1, 0, color='Royalblue', alpha=0.2)
        plt.fill_between(x, pdf2, 0, color='Coral', alpha=0.2)
        plt.legend(prop = {'size': 8})
        plt.tick_params(width=2, labelsize=24)
        plt.show()

In [None]:
# Calculate the overlap results by domain calibration, fill the corresponding heatmap and mean lists for the results, and show the density plots.
# Calculation described in the Appendix B.5.1.
#
# Same procedure as the vanilla cell, except each label probability is first divided by that
# label's DC background probability:
#   margin = res[i] / bg[i] - res[j] / bg[j]
# Plot legend: 'Positive Samples' = ground truth i, 'Negative Samples' = ground truth j.
# NOTE(review): gaussian_kde needs at least two distinct values per set.

for i in range(len(train_data.label_space)):
    for j in range(i):
        dc_margins_labeli = [
            res[i] / observed_background_prob_for_DC[i] - res[j] / observed_background_prob_for_DC[j]
            for res in logits_set_wrt_gt_label[i]
        ]
        dc_margins_labelj = [
            res[i] / observed_background_prob_for_DC[i] - res[j] / observed_background_prob_for_DC[j]
            for res in logits_set_wrt_gt_label[j]
        ]
        density1 = gaussian_kde(dc_margins_labeli)
        density2 = gaussian_kde(dc_margins_labelj)
        x = np.linspace(min(min(dc_margins_labeli), min(dc_margins_labelj)),
                        max(max(dc_margins_labeli), max(dc_margins_labelj)), 500)
        # Evaluate each KDE once and reuse the values for both the integral and the plot.
        pdf1 = density1(x)
        pdf2 = density2(x)
        # Same trapezoidal rule as np.trapz, which NumPy 2.0 deprecates.
        overlap_area = trapezoid(np.minimum(pdf1, pdf2), x)
        domain_result_heatmap[i][j] = overlap_area
        domain_result_mean.append(overlap_area)

        plt.figure(figsize=(4, 4), dpi=300)
        plt.plot(x, pdf1,
                 label='Positive Samples',
                 color='Royalblue',
                 linewidth = 3)
        plt.plot(x, pdf2,
                 label='Negative Samples',
                 color='Coral',
                 linewidth = 3)
        plt.fill_between(x, pdf1, 0, color='Royalblue', alpha=0.2)
        plt.fill_between(x, pdf2, 0, color='Coral', alpha=0.2)
        plt.legend(prop = {'size': 8})
        plt.tick_params(width=2, labelsize=24)
        plt.show()

In [None]:
# Calculate the overlap results by hidden calibration, fill the corresponding heatmap and mean lists for the results, and show the density plots.
# Calculation described in the Appendix B.5.1.
#
# Same procedure as the vanilla cell, applied to the Hidden Calibration score vectors
# (softmax over shifted negative centroid distances) collected in hidden_state_set_wrt_gt_label.
# Plot legend: 'Positive Samples' = ground truth i, 'Negative Samples' = ground truth j.
# NOTE(review): gaussian_kde needs at least two distinct values per set.

for i in range(len(train_data.label_space)):
    for j in range(i):
        hidden_margins_labeli = [res[i] - res[j] for res in hidden_state_set_wrt_gt_label[i]]
        hidden_margins_labelj = [res[i] - res[j] for res in hidden_state_set_wrt_gt_label[j]]
        density1 = gaussian_kde(hidden_margins_labeli)
        density2 = gaussian_kde(hidden_margins_labelj)
        x = np.linspace(min(min(hidden_margins_labeli), min(hidden_margins_labelj)),
                        max(max(hidden_margins_labeli), max(hidden_margins_labelj)), 500)
        # Evaluate each KDE once and reuse the values for both the integral and the plot.
        pdf1 = density1(x)
        pdf2 = density2(x)
        # Same trapezoidal rule as np.trapz, which NumPy 2.0 deprecates.
        overlap_area = trapezoid(np.minimum(pdf1, pdf2), x)
        hidden_result_heatmap[i][j] = overlap_area
        hidden_result_mean.append(overlap_area)

        plt.figure(figsize=(4, 4), dpi=300)
        plt.plot(x, pdf1,
                 label='Positive Samples',
                 color='Royalblue',
                 linewidth = 3)
        plt.plot(x, pdf2,
                 label='Negative Samples',
                 color='Coral',
                 linewidth = 3)
        plt.fill_between(x, pdf1, 0, color='Royalblue', alpha=0.2)
        plt.fill_between(x, pdf2, 0, color='Coral', alpha=0.2)
        plt.legend(prop = {'size': 8})
        plt.tick_params(width=2, labelsize=24)
        plt.show()

In [None]:
# The final result report.
# For each method, prints the mean of all pairwise inter-category overlaps (lower means the
# label categories are better separated).

print("Result report on " + dataset_name + " dataset, metric: Inter-category overlap\n")

for method_name, pairwise_overlaps in (
    ("vanilla ICL", vanilla_result_mean),
    ("hidden calibration", hidden_result_mean),
    ("contextual calibration", contextual_result_mean),
    ("domain calibration", domain_result_mean),
):
    print(method_name + ": " + str(np.mean(pairwise_overlaps)))

**Draw the heatmap**

In [None]:
# The heatmap for vanilla ICL.
# This cell also sets the diagonal of hidden_result_heatmap.

n_labels = len(train_data.label_space)

# A category's score distribution overlaps itself completely, so the diagonal is 1.
for idx in range(n_labels):
    vanilla_result_heatmap[idx][idx] = 1
    hidden_result_heatmap[idx][idx] = 1

# 1 = hidden cell: the strict upper triangle (column > row), which the overlap cells never fill.
mask = [[0] * (row + 1) + [1] * (n_labels - row - 1) for row in range(n_labels)]

fig = plt.figure(figsize=(5, 5), dpi=300)
sns.heatmap(
    vanilla_result_heatmap,
    vmin=0,
    vmax=1,
    mask=np.array(mask),
    annot=True,
    cbar=False,
    linewidths=2,
    cmap="Reds",
    annot_kws={"fontsize": 20},
)
ax = plt.gca()
ax.set_xticklabels(train_data.label_space, fontsize=16)
ax.set_yticklabels(train_data.label_space, fontsize=16)

In [None]:
# The heatmap for hidden calibration.

n_labels = len(train_data.label_space)

# A category's score distribution overlaps itself completely, so the diagonal is 1. It is set
# here as well so this plot does not depend on the vanilla-ICL heatmap cell having run first
# (previously the diagonal stayed 0 if that cell was skipped). Re-setting it is idempotent.
for idx in range(n_labels):
    hidden_result_heatmap[idx][idx] = 1

# 1 = hidden cell: the strict upper triangle (column > row), which the overlap cells never fill.
mask = [[0] * (row + 1) + [1] * (n_labels - row - 1) for row in range(n_labels)]

fig = plt.figure(figsize=(5, 5), dpi=300)
sns.heatmap(
    hidden_result_heatmap,
    vmin = 0,
    vmax = 1,
    mask = np.array(mask),
    annot=True,
    cbar=False,
    linewidths=2,
    cmap = "Reds",
    annot_kws={"fontsize":20}
)
ax=plt.gca()
ax.set_xticklabels(train_data.label_space, fontsize=16)
ax.set_yticklabels(train_data.label_space, fontsize=16)