**Official implementation of the paper "Token-based Decision Criteria Are Suboptimal in In-context Learning" (NAACL 2025)**

### Source code for the Main Experiment

This notebook contains the code for the main experiments of the paper: it tests the performance of Hidden Calibration and other methods on in-context learning (ICL) tasks.

It is used in the following experiments:

1. The main experiments described in Sec. 4, with results shown in Sec. 4.2 of the paper.
2. The Prompt Template Complexity experiments described in Sec. 4.3 and shown in Fig. 5 (left), run by varying the prompt template parameters of `prompt_function`.
3. The Demonstration Sampling / Order Complexity experiments described in Sec. 4.3 and shown in Fig. 5 (middle, right), run by varying the `pre_sampled_demonstration` parameter of `prompt_function`.
4. The data efficiency experiment (Sec. 4.3, Training Data Complexity, Fig. 6), run by varying the `calibration_sample_number_for_each_label` parameter.

Author: Hakaze Cho, yfzhao@jaist.ac.jp

**Experiment Configs**

- `huggingface_model_name`: a model name from the HuggingFace model hub, for example `facebook/opt-2.7b`.
- `huggingface_token`: a HuggingFace API token. It is only used for models that require one, such as `Llama2`.
- `quantization`: a boolean value. If it is `True`, the model is loaded with 4-bit quantization.
- `dataset_name`: one of the supported dataset names: `"SemEvalR", "SemEvalL", "poem_sentiment", "TEE", "TEH", "TES", "FP", "AGNews", "MR", "hate_speech"`.
- `k`: the number of demonstrations used for ICL.
- `calibration_sample_number_for_each_label`: the number of calibration samples per label category used by some of the calibration methods, i.e., the horizontal axis of Fig. 6 in the paper.

In [121]:
# Configs
huggingface_model_name = "facebook/opt-2.7b"
huggingface_token = "API_TOKEN"
quantization = False

dataset_name = "SemEvalR" # Alternative: "SemEvalR", "SemEvalL", "poem_sentiment", "TEE", "TEH", "TES", "FP", "AGNews", "MR", "hate_speech"

k = 4
calibration_sample_number_for_each_label = 16

**Load everything.**

In [122]:
# Import libraries and necessary definitions

import sys
sys.path.append("hidden_calibration_released") # Replace with the path from the working directory to the root of this project. If the working directory is already the root of the project, this line is not needed.

from functools import partial
import util.prompting as prompting
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import util.dataset_loader as dataset_loader
import util.calibrations as calibrations
import numpy as np
from tqdm import tqdm
from torcheval.metrics.functional import multiclass_accuracy, multiclass_f1_score
from scipy import spatial
import copy

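def softmax(x):
    # Subtract the maximum before exponentiating to avoid overflow; the result is mathematically unchanged.
    shifted = np.asarray(x) - np.max(x)
    f_x = np.exp(shifted) / np.sum(np.exp(shifted))
    return f_x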

dataset_name_to_class = {
    "SemEvalR": dataset_loader.SemEval2014_Restaurants,
    "SemEvalL": dataset_loader.SemEval2014_Laptops,
    "poem_sentiment": dataset_loader.poem_sentiment,
    "TEE": dataset_loader.tweet_eval_emotion,
    "TEH": dataset_loader.tweet_eval_hate,
    "TES": dataset_loader.tweet_eval_sentiment,
    "FP": dataset_loader.financial_phrasebank,
    "AGNews": dataset_loader.agnews,
    "MR": dataset_loader.rooten_tomato,
    "hate_speech": dataset_loader.hate_speech18,
}

In [123]:
# Load model and tokenizer from Huggingface

torch.cuda.empty_cache()

tokenizer = AutoTokenizer.from_pretrained(huggingface_model_name, token = huggingface_token)
if quantization:
    model = AutoModelForCausalLM.from_pretrained(huggingface_model_name, token = huggingface_token, quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16
    ))
else:
    model = AutoModelForCausalLM.from_pretrained(huggingface_model_name, token = huggingface_token).cuda()

In [124]:
train_data = dataset_name_to_class[dataset_name]().default_training_division()
test_data = dataset_name_to_class[dataset_name]().default_testing_division()
if dataset_name == "AGNews":
    train_data.cut_by_length(226)
    test_data.cut_by_length(226)

Define the prompt template by:

``` python
default_prompting(dataset, demos_amount, query_index = -1, input_start = 'Input: ', input_end_label_start = ', Label: ', inter_div = '\n')
```
Here, `dataset` is the dataset object (for example, `train_data`), `demos_amount` is the number of demonstrations, `query_index` is the index of the query sample, `input_start` is the string placed at the start of each input text, `input_end_label_start` is the string placed between the end of the input text and the start of the label, and `inter_div` is the divider that separates consecutive input-label pairs.

We recommend using `partial` to define the template by fixing `input_start`, `input_end_label_start`, and `inter_div`.
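
As a minimal sketch of this pattern, the snippet below fixes the three template strings with `functools.partial`. The function `toy_prompting` is a hypothetical stand-in that only mirrors the signature shown above; it is not the repository's `default_prompting`, and the way it picks demonstrations and assembles the prompt is an assumption made for illustration.

```python
from functools import partial

# Hypothetical stand-in mirroring the signature of default_prompting.
# The real function lives in util/prompting.py and may build prompts differently.
def toy_prompting(dataset, demos_amount, query_index=-1,
                  input_start='Input: ', input_end_label_start=', Label: ',
                  inter_div='\n'):
    # Assumption: `dataset` is a list of (text, label) pairs.
    query_position = query_index % len(dataset)
    candidates = [pair for pos, pair in enumerate(dataset) if pos != query_position]
    # Each demonstration is rendered as: start + text + middle + label + divider.
    parts = [input_start + text + input_end_label_start + label + inter_div
             for text, label in candidates[:demos_amount]]
    # The query is rendered without its label, which the model should predict.
    query_text, query_label = dataset[query_position]
    parts.append(input_start + query_text + input_end_label_start)
    return "".join(parts), query_label

toy_data = [("great food", "positive"),
            ("slow service", "negative"),
            ("nice view", "positive")]

# Fix the template strings once; the remaining arguments are supplied per call.
custom_prompt_function = partial(
    toy_prompting,
    input_start='Review: ',
    input_end_label_start=' Sentiment: ',
    inter_div='\n\n',
)

prompt, label = custom_prompt_function(toy_data, 2, query_index=2)
print(prompt)
```

The resulting callable keeps the `(dataset, k, query_index=i)` call shape that the rest of the notebook uses for `prompt_function`.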

In [125]:
prompt_function = partial(prompting.default_prompting)

**Train all the calibration methods.**

In [None]:
# Train the Contextual Calibration and Domain Calibration

torch.cuda.empty_cache()
observed_background_prob_for_CC = calibrations.empty_query_base_logits(model, tokenizer, train_data, k, prompt_function, calibration_sample_number_for_each_label * len(train_data.label_space))
observed_background_prob_for_DC = calibrations.domain_text_base_logits(model, tokenizer, train_data, k, prompt_function, calibration_sample_number_for_each_label * len(train_data.label_space))

In [None]:
# Train the Centroid Calibration and Hidden Calibration

## Collect the hidden states and full vocabulary probabilities for the calibration samples.
ground_truth_labels_for_calibration_samples = []
observed_full_vocabulary_probabilities = []
observed_last_hidden_state = []

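torch.set_grad_enabled(False) # A bare `torch.no_grad()` call has no effect outside a `with` block or decorator, so gradients are disabled explicitly here.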
for j in range(len(train_data.label_space)):
    label_set = dataset_loader.get_label_set_from_label_index(train_data, j)
    for i in tqdm(label_set[:calibration_sample_number_for_each_label]):
        torch.cuda.empty_cache()
        prpt = prompt_function(train_data, k, query_index=i)
        tknzd_data = tokenizer(prpt[0], return_tensors="pt").input_ids.cuda()
        result = model(tknzd_data, output_hidden_states = True)
        result_vector = result['logits'][0][-1].detach().cpu().numpy()
        one_last_hidden_state = result.hidden_states[-1][-1][-1].detach().cpu().numpy()
        observed_last_hidden_state.append(one_last_hidden_state)
        observed_full_vocabulary_probabilities.append(softmax(result_vector))
        ground_truth_labels_for_calibration_samples.append(train_data.label_space.index(prpt[1]))

## Divide the collected hidden states and full vocabulary probabilities by label.
observed_full_vocabulary_probabilities_indexed_by_label = []
observed_last_hidden_state_indexed_by_label = []

for label in train_data.label_space:
    observed_full_vocabulary_probabilities_indexed_by_label.append([])
    observed_last_hidden_state_indexed_by_label.append([])
for i in range(len(ground_truth_labels_for_calibration_samples)):
    observed_full_vocabulary_probabilities_indexed_by_label[ground_truth_labels_for_calibration_samples[i]].append(observed_full_vocabulary_probabilities[i])
    observed_last_hidden_state_indexed_by_label[ground_truth_labels_for_calibration_samples[i]].append(observed_last_hidden_state[i])

## Calculate the centroids from the collected hidden states and full vocabulary probabilities.
observed_full_vocabulary_probabilities_CENTROID_indexed_by_label = []
observed_last_hidden_state_CENTROID_indexed_by_label = []
for lists in observed_full_vocabulary_probabilities_indexed_by_label:
    observed_full_vocabulary_probabilities_CENTROID_indexed_by_label.append(np.mean(lists, axis=0))
for lists in observed_last_hidden_state_indexed_by_label:
    observed_last_hidden_state_CENTROID_indexed_by_label.append(np.mean(lists, axis=0))
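
The grouping and averaging steps above can be sketched in vectorized NumPy. This is an illustrative equivalent under the assumption that every label has at least one calibration sample; the names `features`, `labels`, and `class_centroids` are hypothetical, and the toy data is made up.

```python
import numpy as np

def class_centroids(features, labels, num_labels):
    """Return an array of shape (num_labels, dim): the mean feature vector of each label."""
    features = np.asarray(features, dtype=float)
    labels = np.asarray(labels)
    # Boolean masking selects the rows of one label; the mean over axis 0 is its centroid.
    return np.stack([features[labels == j].mean(axis=0) for j in range(num_labels)])

# Toy calibration set: four 2-dimensional vectors, two per label.
features = [[0.0, 2.0], [2.0, 4.0], [10.0, 0.0], [12.0, 2.0]]
labels = [0, 0, 1, 1]

centroids = class_centroids(features, labels, num_labels=2)
print(centroids)
```

The same helper applies to either kind of feature collected above, hidden states or full-vocabulary probability vectors.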

**Inference.**

In [None]:
# Inference by vanilla model, Contextual Calibration, and Domain Calibration, and collect the hidden states and full vocabulary probabilities for Hidden Calibration and Centroid Calibration

vanilla_logits_softmaxed = [] # Prepare for the Batch Calibration

inference_full_vocabulary_probabilities = []
inference_last_hidden_state = []

predicted_by_vanilla = []
predicted_by_CC = []
predicted_by_DC = []
predicted_by_centroid_calibration = []
predicted_by_hidden_calibration_cosine = []
predicted_by_hidden_calibration_L2 = []
predicted_by_knn_withlabel = []
groundtruth = []

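torch.set_grad_enabled(False) # A bare `torch.no_grad()` call has no effect outside a `with` block or decorator, so gradients are disabled explicitly here.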
for i in tqdm(range(test_data.get_max())):
    torch.cuda.empty_cache()
    predicted_single = []
    prpt = prompt_function(test_data, k, query_index=i)
    tknzd_data = tokenizer(prpt[0], return_tensors="pt").input_ids.cuda()
    result = model(tknzd_data, output_hidden_states = True)
    result_vector = result['logits'][0][-1].detach().cpu().numpy()
    ahidden_state = result.hidden_states[-1][-1][-1].detach().cpu().numpy()
    direct_label_logits = []
    inference_last_hidden_state.append(ahidden_state)
    inference_full_vocabulary_probabilities.append(softmax(result_vector))
    for label in test_data.label_space:
        index = tokenizer(label).input_ids[-1]
        direct_label_logits.append(result_vector[index])
    direct_label_logits = softmax(direct_label_logits)
    vanilla_logits_softmaxed.append(direct_label_logits)
    predicted_by_vanilla.append(np.argmax(direct_label_logits))
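    # Calibrate on copies: `direct_label_logits` is the same array object stored in `vanilla_logits_softmaxed`,
    # so dividing it in place would also alter the input later passed to Batch Calibration.
    cc_direct_label_logits = copy.deepcopy(direct_label_logits)
    dc_direct_label_logits = copy.deepcopy(direct_label_logits)
    for label_idx in range(len(direct_label_logits)):
        cc_direct_label_logits[label_idx] /= observed_background_prob_for_CC[label_idx] + 1e-10
        dc_direct_label_logits[label_idx] /= observed_background_prob_for_DC[label_idx] + 1e-10
    predicted_by_CC.append(np.argmax(cc_direct_label_logits))
    predicted_by_DC.append(np.argmax(dc_direct_label_logits))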
    groundtruth.append(test_data.label_space.index(prpt[1]))
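
The division step used for Contextual Calibration and Domain Calibration in the loop above can be sketched on toy numbers. The probabilities below are invented for illustration, and `calibrate_by_background` is a hypothetical helper name; it returns a new array rather than modifying its input.

```python
import numpy as np

def calibrate_by_background(label_probs, background_probs, eps=1e-10):
    """Divide each label probability by its observed background probability."""
    label_probs = np.asarray(label_probs, dtype=float)
    background_probs = np.asarray(background_probs, dtype=float)
    # Element-wise division creates a new array, so the input probabilities stay untouched.
    return label_probs / (background_probs + eps)

vanilla = np.array([0.6, 0.4])      # label probabilities for one query (toy values)
background = np.array([0.8, 0.2])   # background probabilities per label (toy values)

calibrated = calibrate_by_background(vanilla, background)
print(np.argmax(vanilla), np.argmax(calibrated))
```

In this toy case the model favors label 0 before calibration, but label 0 is also strongly favored in the background estimate, so the calibrated scores select label 1 instead.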

In [129]:
## Predicting by Batch Calibration
predicted_by_batch_calibration = calibrations.batch_calibration_for_result(
    vanilla_logits_softmaxed,
    calibration_sample_number_for_each_label
)

## Predicting by KNN
for result_vector in inference_full_vocabulary_probabilities:
    predicted_by_knn_withlabel.append(calibrations.predict_by_knn(result_vector, observed_full_vocabulary_probabilities, ground_truth_labels_for_calibration_samples, spatial.distance.jensenshannon, len(train_data.label_space), 3))

def pdf(x, mu, sigma): 
    return 1/(sigma*np.sqrt(2*np.pi))*np.exp(-(x-mu)**2/(2*sigma**2))

## Predicting by Centroid Calibration
for result_vector in inference_full_vocabulary_probabilities:
    predicted_single = []
    ablation_cosine_single = []
    ablation_l2_single = []
    for l in range(len(observed_full_vocabulary_probabilities_CENTROID_indexed_by_label)):
        ablation_cosine_single.append(np.abs(1 - spatial.distance.cosine(observed_full_vocabulary_probabilities_CENTROID_indexed_by_label[l], result_vector)))
        ablation_l2_single.append(-spatial.distance.euclidean(observed_full_vocabulary_probabilities_CENTROID_indexed_by_label[l], result_vector))
        predicted_single.append(-spatial.distance.jensenshannon(observed_full_vocabulary_probabilities_CENTROID_indexed_by_label[l], result_vector))
    predicted_single = softmax(predicted_single)
    predicted_by_centroid_calibration.append(np.argmax(predicted_single))

## Predicting by Hidden Calibration
for result_vector in inference_last_hidden_state:
    cosine_single = []
    l2_single = []
    for l in range(len(observed_last_hidden_state_CENTROID_indexed_by_label)):
        cosine_single.append(-spatial.distance.cosine(observed_last_hidden_state_CENTROID_indexed_by_label[l], result_vector))
        l2_single.append(-spatial.distance.euclidean(observed_last_hidden_state_CENTROID_indexed_by_label[l], result_vector))
    predicted_by_hidden_calibration_cosine.append(np.argmax(cosine_single))
    predicted_by_hidden_calibration_L2.append(np.argmax(l2_single))
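
The centroid-based decision rule used in the loops above can be sketched with plain NumPy. This is a simplified illustration, not the repository code: `nearest_centroid` is a hypothetical helper, the vectors are toy data, and the distances are written out by hand instead of calling `scipy.spatial.distance`.

```python
import numpy as np

def cosine_distance(u, v):
    # 1 - cosine similarity: depends only on the angle between the vectors.
    return 1.0 - np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))

def l2_distance(u, v):
    # Euclidean distance: depends on both direction and magnitude.
    return np.linalg.norm(u - v)

def nearest_centroid(vector, centroids, distance):
    """Return the index of the centroid closest to `vector` under `distance`."""
    scores = [-distance(centroid, vector) for centroid in centroids]
    return int(np.argmax(scores))

# Toy centroids for two labels and one query vector.
centroids = np.array([[1.0, 0.0], [0.0, 5.0]])
query = np.array([1.0, 2.0])

label_by_cosine = nearest_centroid(query, centroids, cosine_distance)
label_by_l2 = nearest_centroid(query, centroids, l2_distance)
print(label_by_cosine, label_by_l2)
```

On this toy input the two metrics pick different labels, so it matters which distance each prediction list is paired with.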

**Test and output results.**

In [None]:
print("Result report on " + dataset_name + " dataset, metric: Macro F1 Score\n")

print("Vanilla ICL: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_vanilla), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Contextual Calibration: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_CC), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Domain Calibration: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_DC), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Batch Calibration: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_batch_calibration), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Centroid Calibration: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_centroid_calibration), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("KNN: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_knn_withlabel), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Hidden Calibration cosine: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_hidden_calibration_cosine), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Hidden Calibration L2: " + str(multiclass_f1_score(torch.LongTensor(predicted_by_hidden_calibration_L2), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
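
For reference, macro F1 averages the per-class F1 scores with equal weight per class. The sketch below recomputes it by hand in NumPy on toy predictions; `macro_f1` is a hypothetical helper and is not guaranteed to match `torcheval` in edge cases such as classes that never appear.

```python
import numpy as np

def macro_f1(predicted, target, num_classes):
    """Unweighted mean of per-class F1 scores."""
    predicted = np.asarray(predicted)
    target = np.asarray(target)
    f1_per_class = []
    for c in range(num_classes):
        tp = np.sum((predicted == c) & (target == c))  # true positives for class c
        fp = np.sum((predicted == c) & (target != c))  # false positives for class c
        fn = np.sum((predicted != c) & (target == c))  # false negatives for class c
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 here when the class is absent everywhere.
        f1_per_class.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(f1_per_class))

# Toy example with an imbalanced label distribution.
predicted = [0, 0, 0, 0, 0, 1]
target = [0, 0, 0, 0, 1, 1]

score = macro_f1(predicted, target, num_classes=2)
print(score)
```

Here class 0 has F1 = 8/9 and class 1 has F1 = 2/3, so the macro average is 7/9.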

In [None]:
print("Result report on " + dataset_name + " dataset, metric: Accuracy (macro-averaged over classes)\n")

print("Vanilla ICL: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_vanilla), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Contextual Calibration: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_CC), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Domain Calibration: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_DC), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Batch Calibration: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_batch_calibration), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Centroid Calibration: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_centroid_calibration), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("KNN: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_knn_withlabel), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Hidden Calibration cosine: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_hidden_calibration_cosine), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
print("Hidden Calibration L2: " + str(multiclass_accuracy(torch.LongTensor(predicted_by_hidden_calibration_L2), torch.LongTensor(groundtruth), num_classes = len(test_data.label_space), average = 'macro').item()))
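
Because the calls above pass `average = 'macro'`, the reported number is the unweighted mean of per-class accuracies, which can differ from the plain fraction of correct predictions on imbalanced data. The NumPy sketch below shows the difference on toy labels; the helper names are hypothetical, and the handling of empty classes is an assumption rather than a statement about `torcheval` internals.

```python
import numpy as np

def micro_accuracy(predicted, target):
    """Fraction of all samples that are predicted correctly."""
    predicted, target = np.asarray(predicted), np.asarray(target)
    return float(np.mean(predicted == target))

def macro_accuracy(predicted, target, num_classes):
    """Unweighted mean of the accuracy computed separately within each true class."""
    predicted, target = np.asarray(predicted), np.asarray(target)
    per_class = []
    for c in range(num_classes):
        mask = target == c
        if mask.any():  # assumption: classes with no true instances are skipped
            per_class.append(np.mean(predicted[mask] == c))
    return float(np.mean(per_class))

# Toy example: four samples of class 0, two of class 1.
predicted = [0, 0, 0, 0, 0, 1]
target = [0, 0, 0, 0, 1, 1]

micro = micro_accuracy(predicted, target)
macro = macro_accuracy(predicted, target, num_classes=2)
print(micro, macro)
```

On this toy input five of six predictions are correct, but the minority class is only half right, so the macro-averaged value is lower than the plain accuracy.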