# Log Probs Experiments

## Setup

In [None]:
import sys
import os
import matplotlib.pyplot as plt
import seaborn as sns

# Add the project root to the path to allow importing from `part2`
sys.path.insert(0, os.path.abspath('.'))
from part2.logprobs_cli import get_logprobs

# --- Mocking Configuration ---
# Set to False to use the live API. Requires API keys and environment variables to be set.
USE_MOCK_DATA = True

# Canned next-token candidates per prompt, stored as (token, weight) pairs.
# NOTE: despite the name, the weights are plain positive numbers in [0, 1]
# rather than logarithms, and a row does not necessarily sum to 1;
# get_data() renormalises them whenever temperature > 0.
MOCK_LOGPROBS = {
    "The capital of the Netherlands is": [(' Amsterdam', 0.9999), (' The', 0.0067), (' The Hague', 0.0003), (' Rotterdam', 0.0001), (' Utrecht', 0.00004)],
    "Is Amsterdam the capital of the Netherlands? (Yes/No)": [(' Yes', 0.9048), (' No', 0.0820), (' The', 0.0024), (' It', 0.0009), (' Amsterdam', 0.0003)],
    "The capital of the Netherlands is: A) The Hague, B) Amsterdam, C) Rotterdam": [(' B', 0.8187), (' B)', 0.3011), (' A', 0.0183), (' C', 0.0067), (' Amsterdam', 0.0024)],
    "Netherlands capital?": [(' Amsterdam', 0.6065), (' The', 0.0497), (' The Hague', 0.0067), (' What', 0.0024), (' ', 0.0009)]
}

# Prompt whose mock row is used when the requested prompt has no entry.
_DEFAULT_MOCK_PROMPT = "The capital of the Netherlands is"


def _apply_temperature(pairs, temperature):
    """Re-weight (token, weight) pairs as weight ** (1 / temperature), normalised to sum to 1.

    Weights are divided by the largest weight before exponentiation. This is
    mathematically equivalent to normalising ``w ** (1 / temperature)``
    directly, but the top token always maps to exactly 1.0, so the
    normalising sum can never underflow to zero at very small temperatures.
    """
    if not pairs:
        return []
    tokens, weights = zip(*pairs)
    peak = max(weights)
    if peak <= 0:
        # Nothing meaningful to rescale; hand back the row unchanged (copied).
        return list(pairs)
    exponent = 1 / temperature
    # Clamp at zero so a stray negative weight cannot produce a complex power.
    scaled = [(max(w, 0.0) / peak) ** exponent for w in weights]
    total = sum(scaled)
    return [(token, s / total) for token, s in zip(tokens, scaled)]


def get_data(provider, prompt, model_id, temperature, **kwargs):
    """Fetches logprobs from the live API or returns mock data based on the USE_MOCK_DATA flag.

    Args:
        provider: Provider name forwarded to ``get_logprobs`` (ignored for mock data).
        prompt: Prompt text; in mock mode it selects a row of ``MOCK_LOGPROBS``,
            falling back to the default prompt's row when there is no entry.
        model_id: Model identifier forwarded to ``get_logprobs`` (ignored for mock data).
        temperature: Sampling temperature. In mock mode a value > 0 rescales the
            canned weights and normalises them; any other value returns the
            canned row as stored.
        **kwargs: Extra keyword arguments forwarded to ``get_logprobs``.

    Returns:
        In mock mode, a new list of ``(token, weight)`` tuples. In live mode,
        whatever ``get_logprobs`` returns.
    """
    if not USE_MOCK_DATA:
        return get_logprobs(
            provider=provider,
            prompt=prompt,
            model_id=model_id,
            temperature=temperature,
            **kwargs
        )

    # The mock data doesn't change with temperature,
    # so we return a rescaled version for visualization purposes.
    base_probs = MOCK_LOGPROBS.get(prompt, MOCK_LOGPROBS[_DEFAULT_MOCK_PROMPT])
    if temperature > 0:
        return _apply_temperature(base_probs, temperature)
    # Copy so callers cannot mutate the shared mock table by accident.
    return list(base_probs)

## Temperature Experiment

In [None]:
# Temperature experiment: plot the next-token candidates for one fixed prompt
# at several temperatures, one subplot per temperature.
sns.set_theme(style="whitegrid")
sns.set_context("notebook", font_scale=1.2)

temperatures = [0.1, 0.5, 1.0, 1.5] # Using 0.1 instead of 0.0 for mock calculation stability
prompt_temp = "The capital of the Netherlands is"

# 2x2 grid: one panel per entry in `temperatures`; a shared y-axis keeps the
# bar heights comparable across panels.
fig, axes = plt.subplots(2, 2, figsize=(15, 10), sharey=True)
axes = axes.flatten()

for i, temp in enumerate(temperatures):
    ax = axes[i]
    logprobs = get_data(
        provider="gemini",
        prompt=prompt_temp,
        model_id="gemini-1.5-flash",
        temperature=temp
    )

    # Each item is a (token, value) pair.
    tokens = [item[0] for item in logprobs]
    probs = [item[1] for item in logprobs]

    # Passing `palette` without `hue` is deprecated in seaborn >= 0.13, so the
    # tokens are mapped to `hue` explicitly. `dodge=False` keeps full-width
    # bars, and any legend seaborn adds for the hue levels is removed because
    # it would only repeat the x tick labels.
    sns.barplot(x=tokens, y=probs, hue=tokens, palette="viridis", dodge=False, ax=ax)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    ax.set_title(f"Temperature: {temp}")
    ax.set_ylabel("Probability")
    ax.set_xlabel("Token")
    ax.tick_params(axis='x', rotation=45)

plt.suptitle("Next Token Probabilities for Different Temperatures", fontsize=20)
# Leave headroom at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

## Prompt Experiment

In [None]:
# Prompt experiment: plot the next-token candidates for differently phrased
# prompts at one fixed temperature, one subplot per prompt.
# Maps the panel title to the prompt text sent to get_data().
prompts = {
    "Freeform": "The capital of the Netherlands is",
    "Binary": "Is Amsterdam the capital of the Netherlands? (Yes/No)",
    "Multiple Choice": "The capital of the Netherlands is: A) The Hague, B) Amsterdam, C) Rotterdam",
    "Poor Prompt": "Netherlands capital?"
}

# 2x2 grid: one panel per entry in `prompts`; a shared y-axis keeps the bar
# heights comparable across panels.
fig, axes = plt.subplots(2, 2, figsize=(15, 12), sharey=True)
axes = axes.flatten()
fixed_temp = 1.0

for i, (title, prompt) in enumerate(prompts.items()):
    ax = axes[i]
    logprobs = get_data(
        provider="gemini",
        prompt=prompt,
        model_id="gemini-1.5-flash",
        temperature=fixed_temp
    )

    # Each item is a (token, value) pair.
    tokens = [item[0] for item in logprobs]
    probs = [item[1] for item in logprobs]

    # Passing `palette` without `hue` is deprecated in seaborn >= 0.13, so the
    # tokens are mapped to `hue` explicitly. `dodge=False` keeps full-width
    # bars, and any legend seaborn adds for the hue levels is removed because
    # it would only repeat the x tick labels.
    sns.barplot(x=tokens, y=probs, hue=tokens, palette="plasma", dodge=False, ax=ax)
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

    ax.set_title(f"Prompt: {title}")
    ax.set_ylabel("Probability")
    ax.set_xlabel("Token")
    ax.tick_params(axis='x', rotation=45)

plt.suptitle(f"Next Token Probabilities for Different Prompts (Temp={fixed_temp})", fontsize=20)
# Leave headroom at the top of the figure for the suptitle.
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()