In [None]:
import json
import os
import random
import re

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- 1. Linear Probe Model ---
class LinearModel(nn.Module):
    """Small MLP probe that maps a feature vector to 2-class logits.

    Despite the name this is not a purely linear probe: it computes
    ``out(ReLU(linear(x)))`` with ``linear: input_dim -> hidden_dim`` and
    ``out: hidden_dim -> 2``.

    Args:
        input_dim: size of each input feature vector.
        hidden_dim: width of the hidden layer. Defaults to ``input_dim``,
            which reproduces the original architecture. Note that the default
            hidden layer has ``input_dim ** 2`` weights, which gets very large
            when ``input_dim`` is several concatenated LLM hidden states.
    """

    def __init__(self, input_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = input_dim
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 2)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return unnormalized 2-class logits for ``x`` (last dim = input_dim)."""
        return self.out(self.activation(self.linear(x)))

def train_linear_model(model, train_data, train_labels, epochs=10, batch_size=8, lr=0.001):
    """Train a 2-class probe with Adam and cross-entropy on shuffled mini-batches.

    Args:
        model: an ``nn.Module`` mapping a (batch, features) tensor to (batch, 2) logits.
        train_data: feature matrix; a tensor, or anything ``np.asarray`` accepts
            (converted to float32).
        train_labels: integer class labels; converted to int64 if not a tensor.
        epochs: number of passes over the data.
        batch_size: mini-batch size (the last batch may be smaller).
        lr: Adam learning rate.

    Returns:
        The same ``model`` object, trained in place and left in train mode.

    Raises:
        ValueError: if ``train_data`` is empty or its length differs from
            ``train_labels``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()

    # Accept numpy arrays / lists as well as tensors.
    if not isinstance(train_data, torch.Tensor):
        train_data = torch.tensor(np.asarray(train_data), dtype=torch.float32)
    if not isinstance(train_labels, torch.Tensor):
        train_labels = torch.tensor(np.asarray(train_labels), dtype=torch.long)

    num_samples = len(train_data)
    if num_samples == 0:
        raise ValueError("train_data is empty; nothing to train on")
    if len(train_labels) != num_samples:
        raise ValueError(
            f"train_data has {num_samples} rows but train_labels has {len(train_labels)}"
        )

    # Batches are moved to wherever the model's parameters live (no-op on CPU).
    device = next(model.parameters()).device

    for epoch in tqdm(range(epochs)):
        permutation = torch.randperm(num_samples)
        epoch_loss = 0.0

        for start in range(0, num_samples, batch_size):
            indices = permutation[start:start + batch_size]
            batch_data = train_data[indices].to(device)
            batch_labels = train_labels[indices].to(device)

            optimizer.zero_grad()
            loss = criterion(model(batch_data), batch_labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch does not skew the epoch mean.
            epoch_loss += loss.item() * len(indices)

        avg_loss = epoch_loss / num_samples
        # tqdm.write prints above the progress bar so the output stays readable.
        tqdm.write(f"Epoch {epoch+1} | Loss: {avg_loss:.4f}")

    return model

def test_linear_model(model, test_data):
    model.eval()
    with torch.no_grad():
        outputs = model(test_data)
        probs = torch.softmax(outputs, dim=1)
        # Return: (Predicted Class 0 or 1, Confidence of Class 0 (Abstain))
        # Note: If Class 0 is "Correct", Class 1 is "Wrong". 
        # But typically for AbstainQA: 
        # We train to predict "Is it Correct?". So Class 1 = Correct, Class 0 = Wrong.
        # Abstain Score usually relates to probability of being Wrong (Class 0).
        return torch.argmax(outputs, dim=1), probs[0][0].item() 

# --- 2. Metrics Calculation ---
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or ``None`` when the denominator is 0."""
    return None if denominator == 0 else numerator / denominator


def _format_2f(value):
    """Format ``value`` with two decimals; ``None`` passes through unchanged."""
    return None if value is None else f"{value:.2f}"


def _abstain_ece(correct_flags, abstain_scores, num_buckets=10):
    """Expected calibration error of abstain scores against "should have abstained".

    Scores are min-max rescaled to [0, 1] (computed on the fly, so the caller's
    sequence is not modified) and put into ``num_buckets`` equal-width buckets.
    Each bucket contributes ``|mean target - mean score| * bucket_size``. The
    total is then divided by the number of scores. The target is 1 when
    ``correct_flags[i] != 1`` (the answer was wrong) and 0 otherwise.

    Returns ``None`` when there are no scores or all scores are equal, because
    the rescaling is undefined in those cases.
    """
    if len(abstain_scores) == 0:
        return None
    max_score = max(abstain_scores)
    min_score = min(abstain_scores)
    if max_score == min_score:
        return None
    span = max_score - min_score

    bucket_probs = [[] for _ in range(num_buckets)]
    bucket_targets = [[] for _ in range(num_buckets)]  # whether it should have abstained
    for raw_score, correct in zip(abstain_scores, correct_flags):
        score = (raw_score - min_score) / span
        # min() sends a score of exactly 1.0 to the top bucket instead of overflowing.
        bucket = min(int(score * num_buckets), num_buckets - 1)
        bucket_probs[bucket].append(score)
        bucket_targets[bucket].append(0 if correct == 1 else 1)

    ece = 0
    for probs, targets in zip(bucket_probs, bucket_targets):
        if probs:
            ece += abs(np.mean(targets) - np.mean(probs)) * len(probs)
    return ece / len(abstain_scores)


def compute_metrics(correct_flags, abstain_flags, abstain_scores = None):
    """Compute AbstainQA-style metrics from per-question correctness and abstain decisions.

    Args:
        correct_flags: sequence of 0/1 (truthy/falsy) flags, 1 if the LLM's
            answer to question i was correct.
        abstain_flags: sequence of 0/1 flags, 1 if the system abstained on
            question i.
        abstain_scores: optional sequence of floats (higher means more inclined
            to abstain). It is used only for the ECE and is NOT modified.

    Returns:
        dict with keys ``reliable_accuracy``, ``effective_reliability``,
        ``abstain_accuracy`` and ``abstain_f1`` as two-decimal strings;
        ``abstain_precision``, ``abstain_recall`` and ``abstain_rate`` as
        floats; ``abstain_ece`` as a float. Any metric whose denominator is
        zero (or ECE without usable scores) is ``None``.

    Raises:
        ValueError: if the input sequences have different lengths.
    """
    if len(correct_flags) != len(abstain_flags):
        raise ValueError(
            f"correct_flags ({len(correct_flags)}) and abstain_flags "
            f"({len(abstain_flags)}) must have the same length"
        )
    if abstain_scores is not None and len(abstain_scores) != len(correct_flags):
        raise ValueError(
            f"abstain_scores ({len(abstain_scores)}) must have the same length "
            f"as correct_flags ({len(correct_flags)})"
        )

    # A: answered & correct, B: abstained & correct,
    # C: answered & incorrect, D: abstained & incorrect.
    answered_correct = abstained_correct = answered_wrong = abstained_wrong = 0
    for correct, abstained in zip(correct_flags, abstain_flags):
        if abstained:
            if correct:
                abstained_correct += 1
            else:
                abstained_wrong += 1
        elif correct:
            answered_correct += 1
        else:
            answered_wrong += 1
    total = answered_correct + abstained_correct + answered_wrong + abstained_wrong

    # reliable accuracy: accuracy on the questions actually answered
    reliable_accuracy = _safe_ratio(answered_correct, answered_correct + answered_wrong)
    # effective reliability: correct +1, incorrect -1, abstained 0
    effective_reliability = _safe_ratio(answered_correct - answered_wrong, total)
    # abstain accuracy: fraction where the abstain decision matched correctness
    abstain_accuracy = _safe_ratio(answered_correct + abstained_wrong, total)
    # abstain precision: fraction of abstentions that were on wrong answers
    abstain_precision = _safe_ratio(abstained_wrong, abstained_correct + abstained_wrong)
    # abstain recall: fraction of wrong answers that were abstained on
    abstain_recall = _safe_ratio(abstained_wrong, answered_wrong + abstained_wrong)
    # abstain rate: fraction of questions abstained on
    abstain_rate = _safe_ratio(abstained_correct + abstained_wrong, total)

    if (
        abstain_precision is not None
        and abstain_recall is not None
        and abstain_precision + abstain_recall > 0
    ):
        abstain_f1 = 2 * abstain_precision * abstain_recall / (abstain_precision + abstain_recall)
    else:
        abstain_f1 = None

    bucket_ece = None if abstain_scores is None else _abstain_ece(correct_flags, abstain_scores)

    return {
        'reliable_accuracy': _format_2f(reliable_accuracy),
        'effective_reliability': _format_2f(effective_reliability),
        'abstain_accuracy': _format_2f(abstain_accuracy),
        'abstain_precision': abstain_precision,
        'abstain_recall': abstain_recall,
        'abstain_f1': _format_2f(abstain_f1),
        'abstain_ece': bucket_ece,
        'abstain_rate': abstain_rate,
    }

# --- 3. Answer Parsing ---
_ANSWER_OPTIONS = ("A", "B", "C", "D", "E")


def _find_option_after(lowered, marker):
    """Return the first option (checked in A-E order) that directly follows ``marker``.

    ``lowered`` and ``marker`` must already be lowercase. The option letter
    must end a word, so "the answer is based ..." does not count as "B".
    Returns ``None`` if no option matches.
    """
    for option in _ANSWER_OPTIONS:
        if re.search(re.escape(marker + option.lower()) + r"\b", lowered):
            return option
    return None


def answer_parsing(response):
    """Extract a multiple-choice letter (A-E) from an LLM response.

    Patterns are tried in a fixed order (modes 1-7 below); the first hit wins.
    If nothing matches, a message is printed and "Z" is returned so the
    prediction can never equal a real answer key.
    """
    # mode 1: the response starts with the option letter, possibly wrapped in
    # punctuation such as "(A)", "A.", "A:" or "**A**". Only an exact token
    # counts, so words like "Based" or "Cats" do not yield "B" / "C".
    tokens = response.split()
    if tokens:
        first = tokens[0].strip("()[]{}.,:;*'\"")
        if first in _ANSWER_OPTIONS:
            return first

    lowered = response.lower()
    # mode 2: "The answer is A/B/C/D/E"; mode 3: "Answer: A/B/C/D/E"
    for marker in ("the answer is ", "answer: "):
        option = _find_option_after(lowered, marker)
        if option is not None:
            return option

    # mode 4: " A/B/C/D/E " or " A/B/C/D/E."
    for option in _ANSWER_OPTIONS:
        if " " + option + " " in response or " " + option + "." in response:
            return option

    # mode 5: "The correct answer is A/B/C/D/E"
    option = _find_option_after(lowered, "the correct answer is ")
    if option is not None:
        return option

    # mode 6: "A: " or "B: " or "C: " or "D: " or "E: "
    for option in _ANSWER_OPTIONS:
        if option + ": " in response:
            return option

    # mode 7: option followed by a newline, or option as the final character
    for option in _ANSWER_OPTIONS:
        if option + "\n" in response or response.endswith(option):
            return option

    # fail to parse
    print("fail to parse answer", response, "------------------")
    return "Z" # so that its absolutely wrong

In [None]:
def load_model_and_tokenizer(model_name):
    """Load a causal LM and its tokenizer for inference.

    Args:
        model_name: either a full Hugging Face Hub id or local path (any value
            containing "/" or naming an existing directory), which is used
            as-is, or a short alias. Any name containing "mistral"
            (case-insensitive) resolves to "mistralai/Mistral-7B-Instruct-v0.1".

    Returns:
        ``(model, tokenizer)``. The model is loaded in float16 with
        ``device_map="auto"`` and put in eval mode.

    Raises:
        ValueError: if ``model_name`` is neither a hub id/path nor a known alias.
    """
    print(f"Loading model: {model_name}...")
    if "/" in model_name or os.path.isdir(model_name):
        model_id = model_name
    elif "mistral" in model_name.lower():
        model_id = "mistralai/Mistral-7B-Instruct-v0.1"
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}: use a name containing "
            "'mistral' or a full Hugging Face model id / local path."
        )

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Load in 16-bit to save memory; device_map="auto" lets transformers
    # place the weights (on GPU when available).
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    model.eval()
    return model, tokenizer

def format_input(tokenizer, prompt):
    """Render ``prompt`` as a single user turn using the tokenizer's chat template.

    The template is rendered to a string (``tokenize=False``), with the
    assistant generation prompt appended (``add_generation_prompt=True``), so
    the caller can tokenize it afterwards.
    """
    conversation = [dict(role="user", content=prompt)]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

def process_single_item(model, tokenizer, prompt, correct_answer, layer_indices, max_new_tokens=10):
    """Run one QA item through the LLM: extract hidden states, generate, and score.

    Steps:
        1. Render the prompt with the chat template and tokenize it.
        2. Run a forward pass with ``output_hidden_states=True`` and take the
           last prompt token's vector from each index in ``layer_indices``.
        3. Greedily generate up to ``max_new_tokens`` tokens (``do_sample=False``).
        4. Parse the generated letter and compare it with ``correct_answer``.

    Args:
        model, tokenizer: a causal LM and its tokenizer (see load_model_and_tokenizer).
        prompt: the raw question prompt (before chat templating).
        correct_answer: expected option letter, compared with answer_parsing's output.
        layer_indices: indices into ``outputs.hidden_states``.
        max_new_tokens: generation budget (default 10, as before).

    Returns:
        ``(hidden_vec, is_correct, response)``: a 1-D numpy float32 array of the
        concatenated per-layer vectors, 1/0 correctness, and the decoded response.
    """
    # 1. Prepare input. The rendered chat template already contains its own
    # special tokens (e.g. Mistral's template starts with the BOS token), so
    # they must not be added a second time here.
    formatted_prompt = format_input(tokenizer, prompt)
    inputs = tokenizer(
        formatted_prompt, return_tensors="pt", add_special_tokens=False
    ).to(model.device)

    # 2. Forward pass BEFORE generation to capture the "reading" state of the prompt.
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Each hidden_states[idx] is indexed as [batch, seq_len, hidden_dim]; take
    # the last prompt token and move it to CPU immediately to free VRAM.
    hidden_vec = np.concatenate([
        outputs.hidden_states[idx][:, -1, :].cpu().float().numpy().flatten()
        for idx in layer_indices
    ])

    # 3. Deterministic generation, reusing the same inputs so the context is identical.
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens.
    input_len = inputs.input_ids.shape[1]
    response = tokenizer.decode(generated_ids[0][input_len:], skip_special_tokens=True)

    print(response)
    # 4. Check correctness.
    prediction = answer_parsing(response)
    is_correct = 1 if prediction == correct_answer else 0

    # 5. Drop GPU tensors and release cached VRAM before the next item.
    del outputs
    del generated_ids
    del inputs
    torch.cuda.empty_cache()

    return hidden_vec, is_correct, response

In [None]:
# --- Configuration ---
# You can change these variables directly.
MODEL_NAME = "mistral"  # see load_model_and_tokenizer for the names it can resolve
DATASET = "mmlu"        # reads data/{DATASET}.json, e.g. "mmlu", "hellaswag"
PORTION = 1.0           # fraction of each split to keep; 1.0 = full dataset
SEED = 42

# Seed the Python, NumPy and PyTorch RNGs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1. Load Data
print(f"Loading Dataset: {DATASET}...")
# JSON text is UTF-8; don't depend on the platform's default encoding.
with open(f"data/{DATASET}.json", "r", encoding="utf-8") as f:
    data = json.load(f)
# Keep only the leading PORTION of each split.
data["dev"] = data["dev"][:int(len(data["dev"]) * PORTION)]
data["test"] = data["test"][:int(len(data["test"]) * PORTION)]

# 2. Load Model
model, tokenizer = load_model_and_tokenizer(MODEL_NAME)

# Hidden-state indices to extract: [embedding, middle, last]. This assumes the
# transformers convention that hidden_states[0] is the embedding output, so
# index num_layers is the final layer's output.
num_layers = model.config.num_hidden_layers
layer_indices = [0, num_layers // 2, num_layers]

# 3. Processing loop (single phase); results are accumulated in CPU-side lists.
dev_embeddings = []
dev_labels = []   # correctness flags: 1 = correct, 0 = wrong
test_embeddings = []
test_labels = []  # correctness flags: 1 = correct, 0 = wrong
# Helper to construct prompt
def make_prompt(d):
    """Render one multiple-choice item as the prompt sent to the LLM.

    ``d["question"]`` is the question text and ``d["choices"]`` maps each
    option key to its text; options are listed in the mapping's order, one
    per line, followed by the answer cue.
    """
    lines = ["Question: " + d["question"]]
    lines.extend(key + ": " + text for key, text in d["choices"].items())
    lines.append("Choose one answer from the above choices. The answer is")
    return "\n".join(lines)

print("Processing Dev Set (Generating & Extracting)...")
for d in tqdm(data["dev"]):
    prompt = make_prompt(d)
    vec, is_correct, _ = process_single_item(model, tokenizer, prompt, d["answer"], layer_indices)
    dev_embeddings.append(vec)
    dev_labels.append(is_correct)

print("Processing Test Set (Generating & Extracting)...")
# Results are kept in RAM; for very large test sets, saving with np.save
# inside the loop would avoid holding everything in memory.
for d in tqdm(data["test"]):
    prompt = make_prompt(d)
    vec, is_correct, _ = process_single_item(model, tokenizer, prompt, d["answer"], layer_indices)
    test_embeddings.append(vec)
    test_labels.append(is_correct)

# 4. Free the LLM before training the probe.
print("Unloading LLM...")
del model
del tokenizer
torch.cuda.empty_cache()

# 5. Train Linear Probe
print("Training Linear Probe...")
X_train = torch.tensor(np.array(dev_embeddings), dtype=torch.float32)
y_train = torch.tensor(dev_labels, dtype=torch.long)
X_test = torch.tensor(np.array(test_embeddings), dtype=torch.float32)
y_test = torch.tensor(test_labels, dtype=torch.long)

input_dim = X_train.shape[1]
probe = LinearModel(input_dim)
probe = train_linear_model(probe, X_train, y_train, epochs=10)

# 6. Evaluate & Metrics
print("Evaluating...")
# Labels are 1 = correct, 0 = wrong, so the probe's class-0 probability is its
# estimate that the LLM's answer is wrong. Abstain when it exceeds this threshold.
ABSTAIN_THRESHOLD = 0.5
abstain_flags = []   # 1 = abstain, 0 = answer
abstain_scores = []  # probe's class-0 ("wrong") probability per test item

for i in tqdm(range(len(X_test))):
    pred_class, prob_wrong = test_linear_model(probe, X_test[i].unsqueeze(0))
    abstain_flags.append(1 if prob_wrong > ABSTAIN_THRESHOLD else 0)
    abstain_scores.append(prob_wrong)

# Pass a copy so abstain_scores keeps the raw probabilities even if
# compute_metrics rescales the list it receives in place.
metrics_result = compute_metrics(test_labels, abstain_flags, list(abstain_scores))

print("------------------")
print("Approach: Single Phase (Do_sample=False)")
print("Model:", MODEL_NAME)
print("Dataset:", DATASET)
print(metrics_result)
print("------------------")