# Experiment 4: Few-Shot Learning Analysis

**Goal:** Understand how few-shot examples affect model behavior.

**Key Questions:**
- How does the number of examples affect performance?
- Does example order matter?
- What makes a "good" example vs a "bad" one?
- Does the benefit come from the examples' content (correct input–label pairs) or only from their format?

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import random

from src.model_utils import load_model
from src.prompt_utils import FewShotExample, FewShotPromptBuilder
from src.metrics import ExperimentResults, SequenceMetrics
from src.visualization import set_style

# Apply the project's plotting style (src.visualization.set_style) to the figures below.
set_style()
# Seed Python's global RNG: test_example_order uses random.shuffle, so orderings are
# reproducible only when the notebook runs top-to-bottom in a fresh kernel
# (re-running individual cells consumes additional RNG state).
random.seed(42)

In [None]:
# Load the model under study. Every experiment below queries it through
# get_next_token_distribution(prompt) and get_sequence_log_probs(prompt, continuation).
model = load_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

## 1. Define Few-Shot Examples

In [None]:
# Sentiment-classification demonstrations; labels alternate positive / negative.
SENTIMENT_EXAMPLES = [
    FewShotExample(text, label)
    for text, label in [
        ("This movie was fantastic!", "positive"),
        ("I absolutely hated it.", "negative"),
        ("Best purchase I ever made.", "positive"),
        ("Complete waste of money.", "negative"),
        ("The service was excellent!", "positive"),
        ("Terrible experience, never again.", "negative"),
        ("I'm so happy with this!", "positive"),
        ("Very disappointing product.", "negative"),
    ]
]

# Arithmetic demonstrations: (problem, answer) pairs.
# NOTE(review): the math query's expected answer "15" already appears as an output
# here ("5 * 3", "7 + 8"), so a rising target log-prob may partly reflect copying
# a token from context rather than arithmetic.
MATH_EXAMPLES = [
    FewShotExample(problem, answer)
    for problem, answer in [
        ("2 + 3", "5"),
        ("10 - 4", "6"),
        ("5 * 3", "15"),
        ("20 / 4", "5"),
        ("7 + 8", "15"),
        ("15 - 9", "6"),
    ]
]

# Per-task configuration: demonstration pool, the held-out query, its expected
# answer, and the templates used to render demonstrations and the query.
TEST_QUERIES = {
    "sentiment": dict(
        examples=SENTIMENT_EXAMPLES,
        query="This restaurant exceeded all my expectations!",
        expected="positive",
        format="Input: {input}\nSentiment: {output}",
        query_format="Input: {query}\nSentiment:",
    ),
    "math": dict(
        examples=MATH_EXAMPLES,
        query="9 + 6",
        expected="15",
        format="Problem: {input}\nAnswer: {output}",
        query_format="Problem: {query}\nAnswer:",
    ),
}

## 2. N-Shot Scaling

In [None]:
def test_n_shot_scaling(model, task_data, n_values):
    """Measure how the expected answer's likelihood changes with the number of demonstrations.

    Args:
        model: Wrapper exposing ``get_next_token_distribution(prompt)`` and
            ``get_sequence_log_probs(prompt, continuation)``.
        task_data: A task configuration like the entries of ``TEST_QUERIES``
            (keys ``examples``, ``query``, ``expected``, ``format``, ``query_format``).
        n_values: Example counts to evaluate, in order.

    Returns:
        One dict per entry of ``n_values`` with the rendered prompt, the total
        log-prob of ``" " + expected``, next-token entropy, the first five
        ``top_tokens``, and the prompt length in characters (not tokens).
    """
    builder = FewShotPromptBuilder(task_data["examples"])
    target = " " + task_data["expected"]

    def _measure(n):
        prompt = builder.build(
            query=task_data["query"],
            n_examples=n,
            example_format=task_data["format"],
            query_format=task_data["query_format"],
        )
        next_dist = model.get_next_token_distribution(prompt)
        target_scores = model.get_sequence_log_probs(prompt, target)
        return {
            "n_shot": n,
            "prompt": prompt,
            "target_log_prob": target_scores["total_log_prob"],
            "entropy": next_dist["entropy"],
            "top_5": next_dist["top_tokens"][:5],
            "prompt_length": len(prompt),
        }

    return [_measure(n) for n in n_values]

In [None]:
# Sweep 0..6 in-context examples, capped per task by how many demonstrations it defines.
n_values = [0, 1, 2, 3, 4, 5, 6]

scaling_results = {}
for task_name, task_data in TEST_QUERIES.items():
    max_n = min(max(n_values), len(task_data["examples"]))
    valid_n = list(filter(lambda n: n <= max_n, n_values))

    print(f"\nTesting {task_name}...")
    results = scaling_results[task_name] = test_n_shot_scaling(model, task_data, valid_n)

    for r in results:
        line = f"  {r['n_shot']}-shot: log_prob={r['target_log_prob']:.3f}, entropy={r['entropy']:.3f}"
        print(line)

In [None]:
# Visualize n-shot scaling: target log-prob (left) and next-token entropy (right)
# versus the number of in-context examples, one line per task.
import os

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (metric key, y-axis label, panel title) for each subplot, in axes order.
panels = [
    ("target_log_prob", 'Target Log Probability', 'N-Shot Scaling: Target Probability'),
    ("entropy", 'Output Entropy', 'N-Shot Scaling: Output Entropy'),
]

for ax, (metric, ylabel, title) in zip(axes, panels):
    for task_name, results in scaling_results.items():
        n_shots = [r["n_shot"] for r in results]
        values = [r[metric] for r in results]
        ax.plot(n_shots, values, 'o-', label=task_name, linewidth=2, markersize=8)

    ax.set_xlabel('Number of Examples')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
# The results directory is otherwise only created in the final "Save results" cell,
# so savefig would raise FileNotFoundError on a fresh checkout.
os.makedirs('../results', exist_ok=True)
plt.savefig('../results/exp4_nshot_scaling.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Example Order Effects

In [None]:
def test_example_order(model, task_data, n_examples=4, n_permutations=10, rng=None):
    """Measure how the ordering of a fixed set of demonstrations affects the target.

    The first ``n_examples`` demonstrations are shuffled independently
    ``n_permutations`` times. Draws are with replacement, so duplicate orderings
    are possible (e.g. 4 examples have only 24 orderings).

    Args:
        model: Wrapper exposing ``get_sequence_log_probs`` and
            ``get_next_token_distribution``.
        task_data: Task configuration (``examples``, ``query``, ``expected``,
            ``format``, ``query_format``).
        n_examples: How many leading demonstrations to permute.
        n_permutations: How many random orderings to evaluate.
        rng: Optional ``random.Random`` for reproducible shuffles independent of
            global RNG state. Defaults to the module-level ``random`` (the
            original behavior).

    Returns:
        One dict per ordering with the first 20 characters of each input
        (``order``), the total target log-prob, and next-token entropy.
    """
    if rng is None:
        rng = random
    examples = list(task_data["examples"][:n_examples])
    # May be fewer than n_examples when the task defines fewer demonstrations;
    # ask the builder for exactly as many as we pass in.
    n_shown = len(examples)
    target = " " + task_data["expected"]

    results = []
    for _ in range(n_permutations):
        shuffled = examples.copy()
        rng.shuffle(shuffled)

        builder = FewShotPromptBuilder(shuffled)
        prompt = builder.build(
            query=task_data["query"],
            n_examples=n_shown,
            example_format=task_data["format"],
            query_format=task_data["query_format"],
        )

        seq_probs = model.get_sequence_log_probs(prompt, target)
        dist = model.get_next_token_distribution(prompt)

        results.append({
            "order": [ex.input[:20] for ex in shuffled],
            "target_log_prob": seq_probs["total_log_prob"],
            "entropy": dist["entropy"],
        })

    return results

In [None]:
# Order sensitivity: 20 random orderings of each task's first 4 demonstrations.
order_results = {}

for task_name, task_data in TEST_QUERIES.items():
    print(f"\nTesting order effects for {task_name}...")
    results = order_results[task_name] = test_example_order(
        model, task_data, n_examples=4, n_permutations=20
    )

    log_probs = [r["target_log_prob"] for r in results]
    lowest, highest = min(log_probs), max(log_probs)
    print(f"  Log-prob range: [{lowest:.3f}, {highest:.3f}]")
    print(f"  Variance: {np.var(log_probs):.4f}")

In [None]:
# Visualize order effects: one box per task summarizing target log-prob across orderings.
import os

fig, ax = plt.subplots(figsize=(12, 5))

labels = list(order_results)
data_for_plot = [[r["target_log_prob"] for r in order_results[name]] for name in labels]

# Tick labels are set explicitly because boxplot's ``labels=`` keyword is deprecated
# since Matplotlib 3.9 (renamed ``tick_labels``); this form works on all versions.
bp = ax.boxplot(data_for_plot, patch_artist=True)
ax.set_xticks(range(1, len(labels) + 1))
ax.set_xticklabels(labels)

colors = ['steelblue', 'coral']
for i, patch in enumerate(bp['boxes']):
    # Cycle the palette so boxes beyond the second task are still filled.
    patch.set_facecolor(colors[i % len(colors)])
    patch.set_alpha(0.7)

# Report the permutation count actually run rather than a hard-coded number.
n_perms = "/".join(str(c) for c in sorted({len(v) for v in data_for_plot}))
ax.set_ylabel('Target Log Probability')
ax.set_title(f'Effect of Example Order ({n_perms} random permutations each)')

plt.tight_layout()
# Ensure the output directory exists; it is otherwise created only in the last cell.
os.makedirs('../results', exist_ok=True)
plt.savefig('../results/exp4_order_effects.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Example Quality Analysis

In [None]:
def test_individual_examples(model, task_data):
    """Score every demonstration on its own as a 1-shot prompt.

    Args:
        model: Wrapper exposing ``get_sequence_log_probs`` and
            ``get_next_token_distribution``.
        task_data: Task configuration (``examples``, ``query``, ``expected``,
            ``format``, ``query_format``).

    Returns:
        One dict per demonstration, in pool order, with its index, input,
        output, the total log-prob of ``" " + expected``, and next-token entropy.
    """
    target = " " + task_data["expected"]
    scored = []

    for idx, demo in enumerate(task_data["examples"]):
        prompt = FewShotPromptBuilder([demo]).build(
            query=task_data["query"],
            n_examples=1,
            example_format=task_data["format"],
            query_format=task_data["query_format"],
        )

        target_scores = model.get_sequence_log_probs(prompt, target)
        next_dist = model.get_next_token_distribution(prompt)

        scored.append(dict(
            example_idx=idx,
            example_input=demo.input,
            example_output=demo.output,
            target_log_prob=target_scores["total_log_prob"],
            entropy=next_dist["entropy"],
        ))

    return scored

In [None]:
# Rank each task's demonstrations by the target log-prob they yield as a 1-shot prompt.
example_quality = {}

for task_name, task_data in TEST_QUERIES.items():
    print(f"\nAnalyzing examples for {task_name}...")
    results = test_individual_examples(model, task_data)
    example_quality[task_name] = results

    sorted_results = sorted(results, key=lambda x: x["target_log_prob"], reverse=True)
    print("\nExample ranking (best to worst):")
    for r in sorted_results:
        text = r["example_input"]
        # Mark truncation only when the input was actually cut; previously "..." was
        # appended to every input, including short ones shown in full.
        preview = text if len(text) <= 30 else text[:30] + "..."
        print(f"  '{preview}' -> {r['example_output']}: log_prob={r['target_log_prob']:.3f}")

## 5. Format vs Content

In [None]:
def _wrong_label(output):
    """Return a deliberately incorrect label for a demonstration output.

    Integer answers are shifted by +1, the sentiment labels "positive" and
    "negative" are swapped, and anything else becomes "wrong".
    """
    try:
        return str(int(output) + 1)
    except ValueError:
        pass
    if output == "positive":
        return "negative"
    if output == "negative":
        return "positive"
    return "wrong"


def test_format_vs_content(model, task_data, n_examples=3):
    """Disentangle prompt format from demonstration content.

    Three prompts are built from the first ``n_examples`` demonstrations:
      * ``full``: real inputs with correct outputs;
      * ``format_only``: the same template filled with placeholder text;
      * ``wrong_labels``: real inputs with incorrect outputs (see ``_wrong_label``).

    Args:
        model: Wrapper exposing ``get_sequence_log_probs`` and
            ``get_next_token_distribution``.
        task_data: Task configuration (``examples``, ``query``, ``expected``,
            ``format``, ``query_format``).
        n_examples: Number of leading demonstrations to use (default 3).

    Returns:
        One dict per condition, in the order above, with ``condition``, the total
        log-prob of ``" " + expected``, and next-token entropy.
    """
    def build(examples):
        return FewShotPromptBuilder(examples).build(
            query=task_data["query"],
            n_examples=len(examples),
            example_format=task_data["format"],
            query_format=task_data["query_format"],
        )

    real = list(task_data["examples"][:n_examples])
    # Same number of placeholder / corrupted demos as real ones, so only content differs.
    placeholders = [FewShotExample("[example input]", "[output]") for _ in real]
    # Label-based corruption instead of comparing task_data to a global config.
    corrupted = [FewShotExample(ex.input, _wrong_label(ex.output)) for ex in real]

    prompts = [
        ("full", build(real)),
        ("format_only", build(placeholders)),
        ("wrong_labels", build(corrupted)),
    ]

    target = " " + task_data["expected"]
    results = []
    for name, prompt in prompts:
        seq_probs = model.get_sequence_log_probs(prompt, target)
        dist = model.get_next_token_distribution(prompt)

        results.append({
            "condition": name,
            "target_log_prob": seq_probs["total_log_prob"],
            "entropy": dist["entropy"],
        })

    return results

In [None]:
# Evaluate full / format-only / wrong-label prompts for every task.
format_content_results = {}

for task_name, task_data in TEST_QUERIES.items():
    print(f"\n{task_name}:")
    results = format_content_results[task_name] = test_format_vs_content(model, task_data)

    for r in results:
        row = f"  {r['condition']:15s}: log_prob={r['target_log_prob']:.3f}, entropy={r['entropy']:.3f}"
        print(row)

In [None]:
# Visualize format vs content: grouped bars of target log-prob per task and condition.
import os

fig, ax = plt.subplots(figsize=(10, 6))

task_names = list(TEST_QUERIES)
x = np.arange(len(task_names))
width = 0.25

conditions = ["full", "format_only", "wrong_labels"]
colors = ['green', 'steelblue', 'red']

for i, condition in enumerate(conditions):
    # Look each value up by task so a missing condition leaves a gap (NaN) instead
    # of shifting later values onto the wrong task's bar group.
    values = []
    for task_name in task_names:
        by_condition = {r["condition"]: r["target_log_prob"] for r in format_content_results[task_name]}
        values.append(by_condition.get(condition, np.nan))

    ax.bar(x + i*width, values, width, label=condition, color=colors[i], alpha=0.7)

ax.set_xlabel('Task')
ax.set_ylabel('Target Log Probability')
ax.set_title('Format vs Content: What Makes Few-Shot Work?')
# Center each tick under its group of bars.
ax.set_xticks(x + width * (len(conditions) - 1) / 2)
ax.set_xticklabels(task_names)
ax.legend()

plt.tight_layout()
# Ensure the output directory exists; it is otherwise created only in the last cell.
os.makedirs('../results', exist_ok=True)
plt.savefig('../results/exp4_format_vs_content.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Key Findings

In [None]:
print("="*60)
print("EXPERIMENT 4 SUMMARY: Few-Shot Analysis")
print("="*60)

print("\n1. N-Shot Scaling:")
for task_name, results in scaling_results.items():
    best_n = max(results, key=lambda x: x["target_log_prob"])
    print(f"   {task_name}: Best at {best_n['n_shot']}-shot (log_prob={best_n['target_log_prob']:.3f})")

print("\n2. Order Sensitivity:")
for task_name, results in order_results.items():
    log_probs = [r["target_log_prob"] for r in results]
    print(f"   {task_name}: variance={np.var(log_probs):.4f}, range={max(log_probs)-min(log_probs):.3f}")

print("\n3. Format vs Content:")
for task_name in TEST_QUERIES.keys():
    results = format_content_results[task_name]
    full = next(r for r in results if r["condition"] == "full")
    format_only = next(r for r in results if r["condition"] == "format_only")
    wrong = next(r for r in results if r["condition"] == "wrong_labels")
    
    format_contrib = format_only["target_log_prob"]
    content_contrib = full["target_log_prob"] - format_only["target_log_prob"]
    
    print(f"   {task_name}: format contribution ≈ {format_contrib:.3f}, content adds ≈ {content_contrib:.3f}")

In [None]:
# Persist a compact JSON summary of the scaling, order and format-vs-content analyses.
import json
import os

os.makedirs('../results', exist_ok=True)

save_data = {}
save_data["scaling"] = {
    task: [{"n_shot": row["n_shot"], "log_prob": row["target_log_prob"]} for row in rows]
    for task, rows in scaling_results.items()
}
save_data["order_variance"] = {
    task: np.var([row["target_log_prob"] for row in rows])
    for task, rows in order_results.items()
}
save_data["format_vs_content"] = format_content_results

# default=float coerces non-JSON numeric scalars (e.g. model-returned values) to float.
with open('../results/exp4_fewshot_results.json', 'w') as f:
    json.dump(save_data, f, indent=2, default=float)

print("Results saved.")