# Expand Dataset: Multi-Domain Question Pairs

**Goal:** Generate 140 question pairs across 3 domains (50 + 45 + 45, see below) to build a robust dataset for probing.

**Domains:**
- Geography: North/South relationships (50 pairs)
- Dates: Historical event ordering (45 pairs)
- Population: City/country size comparisons (45 pairs)

**HITL Checkpoints:**
- After each domain: Review contradiction rate and CoT quality
- Before TransformerLens: Manual review of 10-15 contradiction examples

In [None]:
# Cell 0: Setup - Clone repo and install package
import os

# Clone repo (only if not already cloned)
if not os.path.exists('/content/MATS_Neel'):
    !git clone https://github.com/YOUR_USERNAME/MATS_Neel.git
    %cd /content/MATS_Neel
else:
    %cd /content/MATS_Neel
    !git pull  # Get latest changes

# Install dependencies
!pip install torch transformers accelerate pandas -q

# Install package in editable mode
!pip install -e . -q

print("Setup complete!")

In [None]:
# Cell 1: Imports
import torch
import pandas as pd
import re
from datetime import datetime
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
from IPython.display import display, HTML

# Import from our package
from src.data_generation import (
    Domain,
    LOCATION_PAIRS,
    DATE_PAIRS,
    POPULATION_PAIRS,
    generate_geography_pairs,
    generate_date_pairs,
    generate_population_pairs,
    SYSTEM_PROMPTS,
)
from src.labeling import extract_yes_no, detect_contradiction
from src.experiment_utils import (
    ExperimentConfig,
    ExperimentResults,
    create_experiment_run,
    save_results,
    log_domain_metrics,
    finalize_results,
)

# Select the compute device: CUDA when PyTorch reports one, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# On a GPU run, also print the name of device index 0 for the run log.
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 2: HuggingFace Authentication
#
# Looks for a token in Colab Secrets first, then in the HF_TOKEN environment
# variable, and logs in with the first non-empty value found. Raises
# ValueError when neither source provides a token.
import os
from huggingface_hub import login

hf_token = None

# Method 1: Colab Secrets.
# Best-effort lookup: outside Colab the import fails, and the lookup itself
# may raise (NOTE(review): e.g. missing secret or access not granted - confirm
# against the Colab userdata API). Catch Exception instead of a bare `except:`
# so KeyboardInterrupt / SystemExit still propagate.
try:
    from google.colab import userdata
    hf_token = userdata.get('HF_TOKEN')
except Exception:
    hf_token = None
else:
    # Only report success when a usable (non-empty) token actually came back.
    if hf_token:
        print("Found HF_TOKEN in Colab Secrets")

# Method 2: Environment variable (ignored when unset or empty).
if not hf_token and os.environ.get("HF_TOKEN"):
    hf_token = os.environ["HF_TOKEN"]
    print("Found HF_TOKEN in environment")

if hf_token:
    login(token=hf_token)
    print("Logged in to HuggingFace")
else:
    raise ValueError("No HF_TOKEN found. Add to Colab Secrets or environment.")

In [None]:
# Cell 3: Create experiment run
# Run settings, spelled out as a mapping and unpacked into ExperimentConfig.
config = ExperimentConfig(**{
    "name": "expand_dataset_v1",
    "description": "Initial expanded dataset across 3 domains",
    "domains": ["geography", "dates", "population"],
    "max_pairs_per_domain": 50,
    "max_new_tokens": 300,
    "temperature": 0.0,
    "notes": "First run with multi-domain data",
})

# create_experiment_run returns the location used for this run's outputs.
run_dir = create_experiment_run("expand_dataset", config)
print(f"Experiment folder: {run_dir}")

# Results accumulator, stamped with the wall-clock start time (ISO 8601).
results = ExperimentResults(start_time=datetime.now().isoformat())

In [None]:
# Cell 4: Load model
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

print(f"Loading {MODEL_NAME}...")

# Tokenizer and weights are both fetched by the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Load in bfloat16 and let the loader decide placement (device_map="auto").
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# NOTE: `device` is the string chosen in the imports cell, not a readout of
# where the weights were actually placed.
print(f"Model loaded on {device}!")

In [None]:
# Cell 5: Generation function
def generate_response(
    question: str,
    system_prompt: str,
    max_new_tokens: int = 300,
) -> str:
    """Generate a greedy-decoded reply to *question* from the loaded chat model.

    Relies on the notebook-level ``tokenizer`` and ``model`` globals.

    Args:
        question: Text sent as the user turn.
        system_prompt: Text sent as the system turn, ahead of the question.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The decoded continuation only (prompt tokens sliced off, special
        tokens skipped), with leading/trailing whitespace stripped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    # Ask for a dict so the attention mask comes back alongside the ids.
    # The mask is passed to generate() explicitly because pad_token_id is set
    # to eos_token_id below, which leaves generate() unable to infer it.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    input_ids = inputs["input_ids"]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = input_ids.shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:],
        skip_special_tokens=True,
    )
    return response.strip()


def run_domain(
    domain: Domain,
    question_pairs: list[dict],
    max_pairs: int = 50,
) -> pd.DataFrame:
    """Ask the model both questions of each pair and tabulate the outcome.

    Args:
        domain: Domain whose entry in ``SYSTEM_PROMPTS`` is used as the system
            prompt; ``domain.value`` is stored in the ``domain`` column.
        question_pairs: Pair records; each is read for the keys ``pair_id``,
            ``entity_x``, ``entity_y``, ``difficulty``, ``question_a``,
            ``question_b``, ``ground_truth_a`` and ``ground_truth_b``.
        max_pairs: Only the first ``max_pairs`` records are run.

    Returns:
        One row per processed pair, holding the pair fields above plus the
        extracted answers, the contradiction flag and both raw responses.
    """
    rule = "=" * 60
    domain_prompt = SYSTEM_PROMPTS[domain]
    selected = question_pairs[:max_pairs]
    total = len(selected)

    print(f"\n{rule}")
    print(f"Running {domain.value.upper()}: {total} pairs")
    print(f"{rule}")

    rows = []
    for position, pair in enumerate(selected, start=1):
        print(f"\nPair {position}/{total}: {pair['entity_x']} vs {pair['entity_y']}")

        # Two independent generations, one per question of the pair.
        response_a = generate_response(pair["question_a"], domain_prompt)
        response_b = generate_response(pair["question_b"], domain_prompt)

        # Reduce each free-text response to its extracted answer.
        verdict_a = extract_yes_no(response_a)
        verdict_b = extract_yes_no(response_b)

        contradictory = detect_contradiction(verdict_a, verdict_b)

        if not contradictory:
            print(f"  Answers: A={verdict_a}, B={verdict_b}")
        else:
            print(f"  >>> CONTRADICTION: Both {verdict_a}")

        row = {
            "pair_id": pair["pair_id"],
            "domain": domain.value,
            "entity_x": pair["entity_x"],
            "entity_y": pair["entity_y"],
            "difficulty": pair["difficulty"],
            "question_a": pair["question_a"],
            "question_b": pair["question_b"],
            "answer_a": verdict_a,
            "answer_b": verdict_b,
            "ground_truth_a": pair["ground_truth_a"],
            "ground_truth_b": pair["ground_truth_b"],
            "is_contradiction": contradictory,
            "cot_a": response_a,
            "cot_b": response_b,
        }
        rows.append(row)

    return pd.DataFrame(rows)

print("Functions ready.")

## Run Geography Domain

**HITL Checkpoint:** After this section, review the contradiction rate.
Target: >= 15% contradiction rate (we saw 40% in quick test)

In [None]:
# Cell 6: Run Geography
# Build the geography pairs, report how many exist, then run at most 50.
geo_pairs = generate_geography_pairs()
print(f"Generated {len(geo_pairs)} geography pairs")

geo_df = run_domain(
    domain=Domain.GEOGRAPHY,
    question_pairs=geo_pairs,
    max_pairs=50,
)

In [None]:
# Cell 7: Geography Summary
# Counts are cast to built-in ints so plain Python numbers are handed to
# log_domain_metrics (NumPy integer scalars are not serializable by the stdlib
# json module - presumably relevant when results are saved; confirm against
# save_results).
geo_contradictions = int(geo_df["is_contradiction"].sum())
# Guard the empty frame so the rate is 0.0 instead of a ZeroDivisionError.
geo_rate = geo_contradictions / len(geo_df) * 100 if len(geo_df) else 0.0

print(f"\n{'='*60}")
print("GEOGRAPHY SUMMARY")
print(f"{'='*60}")
print(f"Total pairs: {len(geo_df)}")
print(f"Contradictions: {geo_contradictions} ({geo_rate:.1f}%)")

# Log metrics: per-question accuracy is the count of extracted answers that
# equal the stored ground truth.
correct_a = int((geo_df["answer_a"] == geo_df["ground_truth_a"]).sum())
correct_b = int((geo_df["answer_b"] == geo_df["ground_truth_b"]).sum())
log_domain_metrics(results, "geography", len(geo_df), geo_contradictions, correct_a, correct_b)

# Save intermediate results. to_csv does not create missing parent
# directories, so make sure the output folder exists first.
geo_csv_path = Path("data/trajectories/geography.csv")
geo_csv_path.parent.mkdir(parents=True, exist_ok=True)
geo_df.to_csv(geo_csv_path, index=False)
print("\nSaved to data/trajectories/geography.csv")

## Run Dates Domain

**HITL Checkpoint:** Does the model show similar contradiction patterns with historical dates?

In [None]:
# Cell 8: Run Dates
# Build the date pairs, report how many exist, then run at most 45.
date_pairs = generate_date_pairs()
print(f"Generated {len(date_pairs)} date pairs")

date_df = run_domain(
    domain=Domain.DATES,
    question_pairs=date_pairs,
    max_pairs=45,
)

In [None]:
# Cell 9: Dates Summary
# Counts are cast to built-in ints so plain Python numbers are handed to
# log_domain_metrics (NumPy integer scalars are not serializable by the stdlib
# json module - presumably relevant when results are saved; confirm against
# save_results).
date_contradictions = int(date_df["is_contradiction"].sum())
# Guard the empty frame so the rate is 0.0 instead of a ZeroDivisionError.
date_rate = date_contradictions / len(date_df) * 100 if len(date_df) else 0.0

print(f"\n{'='*60}")
print("DATES SUMMARY")
print(f"{'='*60}")
print(f"Total pairs: {len(date_df)}")
print(f"Contradictions: {date_contradictions} ({date_rate:.1f}%)")

# Log metrics: per-question accuracy is the count of extracted answers that
# equal the stored ground truth.
correct_a = int((date_df["answer_a"] == date_df["ground_truth_a"]).sum())
correct_b = int((date_df["answer_b"] == date_df["ground_truth_b"]).sum())
log_domain_metrics(results, "dates", len(date_df), date_contradictions, correct_a, correct_b)

# Save. to_csv does not create missing parent directories, so make sure the
# output folder exists first.
date_csv_path = Path("data/trajectories/dates.csv")
date_csv_path.parent.mkdir(parents=True, exist_ok=True)
date_df.to_csv(date_csv_path, index=False)
print("\nSaved to data/trajectories/dates.csv")

## Run Population Domain

**HITL Checkpoint:** Population comparisons may have different error patterns.

In [None]:
# Cell 10: Run Population
# Build the population pairs, report how many exist, then run at most 45.
pop_pairs = generate_population_pairs()
print(f"Generated {len(pop_pairs)} population pairs")

pop_df = run_domain(
    domain=Domain.POPULATION,
    question_pairs=pop_pairs,
    max_pairs=45,
)

In [None]:
# Cell 11: Population Summary
# Counts are cast to built-in ints so plain Python numbers are handed to
# log_domain_metrics (NumPy integer scalars are not serializable by the stdlib
# json module - presumably relevant when results are saved; confirm against
# save_results).
pop_contradictions = int(pop_df["is_contradiction"].sum())
# Guard the empty frame so the rate is 0.0 instead of a ZeroDivisionError.
pop_rate = pop_contradictions / len(pop_df) * 100 if len(pop_df) else 0.0

print(f"\n{'='*60}")
print("POPULATION SUMMARY")
print(f"{'='*60}")
print(f"Total pairs: {len(pop_df)}")
print(f"Contradictions: {pop_contradictions} ({pop_rate:.1f}%)")

# Log metrics: per-question accuracy is the count of extracted answers that
# equal the stored ground truth.
correct_a = int((pop_df["answer_a"] == pop_df["ground_truth_a"]).sum())
correct_b = int((pop_df["answer_b"] == pop_df["ground_truth_b"]).sum())
log_domain_metrics(results, "population", len(pop_df), pop_contradictions, correct_a, correct_b)

# Save. to_csv does not create missing parent directories, so make sure the
# output folder exists first.
pop_csv_path = Path("data/trajectories/population.csv")
pop_csv_path.parent.mkdir(parents=True, exist_ok=True)
pop_df.to_csv(pop_csv_path, index=False)
print("\nSaved to data/trajectories/population.csv")

## Final Summary

In [None]:
# Cell 12: Overall Summary
from datetime import datetime

# Stamp the end time, then let finalize_results process the results object
# before its totals are read below.
results.end_time = datetime.now().isoformat()
finalize_results(results)

# Banner
print(f"\n{'='*60}")
print("OVERALL SUMMARY")
print(f"{'='*60}")

# Run-wide totals
print(f"Total pairs: {results.total_pairs}")
print(f"Total contradictions: {results.total_contradictions}")
print(f"Overall contradiction rate: {results.contradiction_rate:.1%}")

# One line per domain: contradictions/total (rate)
print("\nPer-domain breakdown:")
for domain_name, stats in results.domain_metrics.items():
    count_part = f"{stats['contradictions']}/{stats['total_pairs']}"
    rate_part = f"({stats['contradiction_rate']:.1%})"
    print(f"  {domain_name}: {count_part} {rate_part}")

# Persist the results for this run.
save_results(run_dir, results)
print(f"\nResults saved to {run_dir}/results.json")

In [None]:
# Cell 13: Combine all results
# Stack the three per-domain frames into one table with a fresh 0..n-1 index.
all_df = pd.concat([geo_df, date_df, pop_df], ignore_index=True)

# to_csv does not create missing parent directories, so make sure the output
# folder exists before writing.
all_csv_path = Path("data/trajectories/all_domains.csv")
all_csv_path.parent.mkdir(parents=True, exist_ok=True)
all_df.to_csv(all_csv_path, index=False)
print(f"Combined dataset saved: {len(all_df)} pairs")

# Quick stats: per difficulty level, number of contradictions ("sum"),
# number of pairs ("count") and contradiction fraction ("mean").
print("\nContradictions by difficulty:")
display(all_df.groupby("difficulty")["is_contradiction"].agg(["sum", "count", "mean"]))

In [None]:
# Cell 14: Display contradiction cases for review
def _preview(text: str, limit: int = 400) -> str:
    """Return *text* cut to *limit* characters, with "..." appended if cut."""
    return (text[:limit] + "...") if len(text) > limit else text


print("\n" + "="*60)
print("CONTRADICTION CASES FOR MANUAL REVIEW")
print("="*60)

# Keep only the rows flagged as contradictions.
contradiction_df = all_df[all_df["is_contradiction"]]
print(f"\nTotal contradiction cases: {len(contradiction_df)}")

# Show first 10 for manual review
for _, row in contradiction_df.head(10).iterrows():
    print(f"\n{'='*60}")
    print(f"[{row['domain'].upper()}] {row['entity_x']} vs {row['entity_y']}")
    # Only answer_a is printed; the header assumes both answers match.
    print(f"Both answers: {row['answer_a']}")
    print("\n--- CoT A ---")
    print(_preview(row['cot_a']))
    print("\n--- CoT B ---")
    print(_preview(row['cot_b']))

In [None]:
# Cell 15: Commit results to GitHub (optional)
# Uncomment to save results back to repo

# !git add data/trajectories/
# !git add experiments/
# !git commit -m "Add expanded dataset results"
# !git push

## Decision Gate

Based on the results above:

| Contradiction Rate | Action |
|:-------------------|:-------|
| >= 25% overall | Excellent! Proceed to TransformerLens activation extraction |
| 15% to < 25% overall | Good. May want to add more pairs, but can proceed |
| < 15% overall | Review domain-specific rates. Focus on highest-yield domain |

### Next Steps
1. Download `data/trajectories/all_domains.csv` locally
2. Manually review 10-15 contradiction cases
3. Log observations in `experiments/<run_folder>/notes.md`
4. If ready: proceed to `03_extract_activations.ipynb`