In [1]:
from nnsight import LanguageModel
import torch
import os
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# Inference-only notebook: switch autograd off globally so that no forward
# pass below records a computation graph.
torch.set_grad_enabled(mode=False)

In [None]:
# Candidate checkpoints for this experiment; only the first entry is loaded
# below. Change the index passed to LanguageModel to probe the other one.
MODELS = [
    "Qwen/Qwen2.5-3B",
    "meta-llama/Llama-3.2-3B",
]

# Wrap the selected checkpoint in an nnsight LanguageModel, forwarding the
# device map and bfloat16 dtype as loader options.
llm = LanguageModel(
    MODELS[0],
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

In [None]:
# Load the dataset.
# The CSV is expected to sit in the current working directory; print that
# directory first so a missing-file error is easy to diagnose.
working_dir = os.getcwd()
print(working_dir)
csv_path = os.path.join(working_dir, "sentences.csv")
df = pd.read_csv(csv_path)

In [None]:
def _first_token_id(tokenizer, text):
    """Return the id of the first real (non-special) token of ``text``.

    ``add_special_tokens=False`` matters: tokenizers that prepend a BOS token
    by default (e.g. the Llama checkpoint listed in ``MODELS``) would otherwise
    return the BOS id for *every* answer, making the correct and incorrect
    answer indistinguishable. Tokenizers that add no special tokens are
    unaffected by the flag.

    Raises:
        ValueError: if ``text`` tokenizes to an empty sequence.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if not token_ids:
        raise ValueError(f"Answer {text!r} produced no tokens")
    return token_ids[0]


# results[prompt framing][token kind] -> one saved probability per dataset row.
results = {
    "prompt_correct": {"correct_token_prob": [], "incorrect_token_prob": []},
    "prompt_incorrect": {"correct_token_prob": [], "incorrect_token_prob": []},
}

for row_idx in range(len(df)):
    row = df.iloc[row_idx]

    # Define the answers to these prompts, formatted as (correct, incorrect)
    answers = [
        (row["answer_sentence1"], row["answer_sentence2"]),
    ]

    # Only the first token of each answer is scored; multi-token answers are
    # represented by their leading token.
    # NOTE(review): the answer is tokenized without a leading space or newline
    # context — confirm this matches how it would follow `last_sentence`.
    answer_token_indices = [
        [_first_token_id(llm.tokenizer, answer) for answer in pair]
        for pair in answers
    ]
    print(answer_token_indices)
    print(answers)
    correct_token_id, incorrect_token_id = answer_token_indices[0]

    s1 = row["sentence1"]
    s2 = row["sentence2"]
    last_sentence = row["last_sentence"]
    prompt_correct = f"{s1}\n{s2}\nThis was the correct sentence:\n{last_sentence}"
    prompt_incorrect = f"{s1}\n{s2}\nThis was the false sentence:\n{last_sentence}"
    print(prompt_correct)
    print(prompt_incorrect)

    # Both framings are probed identically; only the prompt text, the results
    # bucket and the printed label differ.
    prompts = (
        ("prompt_correct", "correct token prob", prompt_correct),
        ("prompt_incorrect", "incorrect token prob", prompt_incorrect),
    )
    for prompt_type, label, prompt in prompts:
        with llm.trace(prompt):
            # Logits for the last position of the (single) prompt in the batch.
            logits = llm.lm_head.output[0, -1]
            # The model runs in bfloat16; upcast before the softmax so the
            # probabilities are not rounded to bfloat16 precision.
            probabilities = F.softmax(logits.float(), dim=-1)

            # Get probabilities for specific tokens
            correct_token_prob = probabilities[correct_token_id].item().save()
            incorrect_token_prob = probabilities[incorrect_token_id].item().save()

            # Store results
            results[prompt_type]["correct_token_prob"].append(correct_token_prob)
            results[prompt_type]["incorrect_token_prob"].append(incorrect_token_prob)
        print(label)
        print(correct_token_prob)
        print(incorrect_token_prob)


In [None]:
def _saved_value(val):
    """Return the concrete value behind an entry of ``results``.

    Depending on the installed nnsight release, a saved result is either a
    proxy exposing ``.value`` or already the plain Python value once the trace
    has exited. Accept both so this cell does not break on an upgrade.
    """
    return getattr(val, "value", val)


# Prepare data with all individual probabilities: one long-format record per
# (prompt framing, token kind, dataset row).
data = []
for prompt_type in ["prompt_correct", "prompt_incorrect"]:
    per_prompt = results[prompt_type]
    for correct_prob, incorrect_prob in zip(
        per_prompt["correct_token_prob"],
        per_prompt["incorrect_token_prob"],
    ):
        data.append({
            "Prompt": prompt_type,
            "Probability Type": "correct_token_prob",
            "Probability": _saved_value(correct_prob),
        })
        data.append({
            "Prompt": prompt_type,
            "Probability Type": "incorrect_token_prob",
            "Probability": _saved_value(incorrect_prob),
        })

# Explicit columns keep the schema stable even when `results` is empty.
df_plot = pd.DataFrame(data, columns=["Prompt", "Probability Type", "Probability"])

# Create the bar plot with error bars
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(
    x="Prompt",
    y="Probability",
    hue="Probability Type",
    data=df_plot,
    errorbar="ci",  # 95% confidence interval (default)
    capsize=0.1,    # Add caps to error bars for better visibility
    ax=ax,
)

ax.set_title('Probabilities by Prompt Type with Error Bars')
ax.set_xlabel('Prompt Type')
ax.set_ylabel('Probability')
ax.legend(title='Probability Type')
ax.set_ylim(0, 1)  # Probabilities range from 0 to 1
fig.tight_layout()
plt.show()