# Notebook 10: LegalBench Generation Evaluation

**Objective:** Compare generation methods on the LegalBench-RAG benchmark.

**Methods:**
1. **No-RAG**: Direct generation without retrieval (baseline)
2. **Basic RAG**: Retrieve once, then generate
3. **Self-RAG**: Adaptive retrieval with reflection tokens
4. **Self-RAG + INSIDE**: Self-RAG enhanced with hallucination detection

**Metrics:**
- **F1 Score**: Token-level overlap with ground truth
- **ROUGE-L**: Longest common subsequence
- **Hallucination Rate**: From reflection tokens (ISSUP)
- **Utility Score**: From ISUSE tokens
- **EigenScore**: Internal state-based hallucination detection (INSIDE)
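
As a rough illustration of the first two metrics, the sketch below computes token-level F1 and an LCS-based ROUGE-L F-measure on whitespace tokens. It is a minimal stand-in, not the project's `evaluation.legalbench_generation_eval` implementation: the helper names (`tokenize`, `token_f1`, `lcs_length`, `rouge_l_f1`) and the lowercase/whitespace tokenization are assumptions made for this example.

```python
from collections import Counter


def tokenize(text: str) -> list[str]:
    """Lowercase and split on whitespace (a deliberately simple tokenizer)."""
    return text.lower().split()


def token_f1(prediction: str, reference: str) -> float:
    """Token-level F1: harmonic mean of precision and recall over shared tokens."""
    pred, ref = tokenize(prediction), tokenize(reference)
    if not pred or not ref:
        return 0.0
    # Multiset intersection: each shared token counts min(pred count, ref count) times
    overlap = sum((Counter(pred) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def lcs_length(a: list[str], b: list[str]) -> int:
    """Length of the longest common subsequence, via row-by-row dynamic programming."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction: str, reference: str) -> float:
    """ROUGE-L F-measure: like token F1, but the overlap is the LCS length."""
    pred, ref = tokenize(prediction), tokenize(reference)
    if not pred or not ref:
        return 0.0
    lcs = lcs_length(pred, ref)
    if lcs == 0:
        return 0.0
    precision = lcs / len(pred)
    recall = lcs / len(ref)
    return 2 * precision * recall / (precision + recall)
```

The two metrics differ in their sensitivity to word order: token F1 treats both texts as bags of tokens, while ROUGE-L only rewards tokens that appear in the same relative order.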

In [None]:
import sys
from pathlib import Path
import json
import yaml
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Add src to path
sys.path.insert(0, str(Path.cwd().parent / 'src'))

# Set random seeds for reproducibility
import random
import torch
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Setup complete")

## 1. Load Data and Models

In [None]:
# Load LegalBench mini dataset
QUERIES_FILE = "../data/legalbench-rag/queries.json"
NUM_QUERIES = 776  # Mini dataset

with open(QUERIES_FILE, 'r') as f:
    queries_data = json.load(f)

queries = queries_data['tests'][:NUM_QUERIES]

print(f"Loaded {len(queries)} queries")
print(f"\nExample query:")
print(f"  Query: {queries[0]['query'][:100]}...")
print(f"  Dataset: {queries[0]['dataset_source']}")
print(f"  Num snippets: {len(queries[0]['snippets'])}")

In [None]:
# Count queries by subdataset
dataset_counts = defaultdict(int)
for q in queries:
    dataset_counts[q['dataset_source']] += 1

print("Queries by subdataset:")
for dataset, count in sorted(dataset_counts.items()):
    print(f"  {dataset}: {count}")

### 1.1 Load Retriever

In [None]:
from retrieval.retriever import LegalRetriever
from retrieval.embedding import EmbeddingModel

# Load retriever
print("Loading retriever...")
embedding_model = EmbeddingModel(model_name="sentence-transformers/all-mpnet-base-v2")
retriever = LegalRetriever(embedding_model=embedding_model, top_k=3)
retriever.load_index("../data/legalbench_embeddings")

print("✓ Retriever loaded")

### 1.2 Load Generator Model

**Note:** This notebook expects trained Self-RAG models. If training is not complete:
- Base model will be used (no LoRA)
- Reflection tokens may not be accurate
- Results will demonstrate the pipeline, not final performance

In [None]:
from self_rag.generator import SelfRAGGenerator
from self_rag.critic import Critic
from self_rag.inference import SelfRAGPipeline

# Load generator config
with open('../configs/generator_config.yaml', 'r') as f:
    generator_config = yaml.safe_load(f)

# Load critic config
with open('../configs/critic_config.yaml', 'r') as f:
    critic_config = yaml.safe_load(f)

print("Loading generator model...")
generator = SelfRAGGenerator(generator_config)

print("Loading critic model...")
critic = Critic(critic_config)

print("✓ Models loaded")

# Check if LoRA weights exist
lora_path = Path(generator_config['model']['lora_weights_path'])
if lora_path.exists():
    print("✓ Using trained LoRA weights")
else:
    print("⚠️  LoRA weights not found - using base model")
    print("   Train models first with: python src/training/train_generator_qlora.py")

## 2. Implement 4 Methods

### Method 1: No-RAG (Baseline)

In [None]:
def no_rag_generation(query: str, generator: SelfRAGGenerator) -> dict:
    """
    Generate answer without retrieval (baseline).
    
    Args:
        query: Legal question
        generator: Generator model
    
    Returns:
        Dictionary with answer and metadata
    """
    prompt = f"""Question: {query}

Answer the legal question based on your knowledge. Be concise and factual.

Answer:"""
    
    result = generator.generate(
        prompt=prompt,
        max_new_tokens=512,
        temperature=0.0,  # Greedy for reproducibility
        do_sample=False,
    )
    
    return {
        'answer': result['generated_text'],
        'method': 'No-RAG',
        # Whitespace-delimited word count, used as a rough proxy for token count
        'num_tokens': len(result['generated_text'].split()),
    }

# Test
test_result = no_rag_generation(queries[0]['query'], generator)
print("No-RAG test:")
print(f"  Answer: {test_result['answer'][:100]}...")
print(f"  Length: {test_result['num_tokens']} tokens")

### Method 2: Basic RAG

In [None]:
def basic_rag_generation(query: str, retriever: LegalRetriever, generator: SelfRAGGenerator) -> dict:
    """
    Generate answer with basic RAG (retrieve once, then generate).
    
    Args:
        query: Legal question
        retriever: Retrieval system
        generator: Generator model
    
    Returns:
        Dictionary with answer and metadata
    """
    # Retrieve passages
    retrieved_docs = retriever.retrieve(query, top_k=3)
    
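    if not retrieved_docs:
        # Fall back to no-RAG generation, but keep this method's label and
        # result keys so downstream code can treat the result uniformly
        fallback = no_rag_generation(query, generator)
        fallback['method'] = 'Basic RAG'
        fallback['retrieved_docs'] = []
        return fallback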
    
    # Concatenate the top 2 passages, each truncated to 500 characters
    passages_text = "\n\n".join([doc['text'][:500] for doc in retrieved_docs[:2]])
    
    prompt = f"""Question: {query}

Relevant Legal Documents:
{passages_text}

Based on the provided documents, answer the legal question. Be concise and cite specific clauses when relevant.

Answer:"""
    
    result = generator.generate(
        prompt=prompt,
        max_new_tokens=512,
        temperature=0.0,
        do_sample=False,
    )
    
    return {
        'answer': result['generated_text'],
        'method': 'Basic RAG',
        'num_tokens': len(result['generated_text'].split()),
        # Sources of all retrieved documents (only the top 2 appear in the prompt)
        'retrieved_docs': [doc['source'] for doc in retrieved_docs],
    }

# Test
test_result = basic_rag_generation(queries[0]['query'], retriever, generator)
print("Basic RAG test:")
print(f"  Answer: {test_result['answer'][:100]}...")
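# Guard against a missing or empty document list (e.g., when retrieval returned nothing)
retrieved_sources = test_result.get('retrieved_docs') or ['N/A']
print(f"  Retrieved from: {retrieved_sources[0]}")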

### Method 3: Self-RAG

In [None]:
def self_rag_generation(query: str, retriever: LegalRetriever, generator: SelfRAGGenerator, critic: Critic) -> dict:
    """
    Generate answer with Self-RAG (adaptive retrieval + reflection tokens).
    
    Args:
        query: Legal question
        retriever: Retrieval system
        generator: Self-RAG generator
        critic: Critic model for reflection tokens
    
    Returns:
        Dictionary with answer, reflection tokens, and metadata
    """
    # Use Self-RAG pipeline
    pipeline = SelfRAGPipeline(
        generator=generator,
        retriever=retriever,
        critic=critic,
    )
    
    result = pipeline.answer_question(
        question=query,
        include_retrieval=True,
        max_new_tokens=512,
        temperature=0.0,
    )
    
    return {
        'answer': result['answer'],
        'method': 'Self-RAG',
        'num_tokens': len(result['answer'].split()),
        'reflection_tokens': result.get('reflection_tokens', {}),
        'score': result.get('score', 0.0),
        'retrieved_docs': result.get('retrieved_docs', []),
    }

# Test
test_result = self_rag_generation(queries[0]['query'], retriever, generator, critic)
print("Self-RAG test:")
print(f"  Answer: {test_result['answer'][:100]}...")
print(f"  Reflection tokens: {test_result['reflection_tokens']}")
print(f"  Score: {test_result['score']:.3f}")

### Method 4: Self-RAG + INSIDE
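
For orientation, the sketch below shows one way to compute an EigenScore from the embeddings of K sampled responses, following the definition in the INSIDE paper as I understand it: the mean log-determinant of the regularized covariance (Gram) matrix of the centered embeddings. The function name `eigenscore`, the `(K, d)` row layout, and the default `alpha` are assumptions for this example. The project's generator may extract internal states and compute the score differently.

```python
import numpy as np


def eigenscore(embeddings: np.ndarray, alpha: float = 1e-3) -> float:
    """
    Semantic-divergence score over K sampled responses.

    embeddings: array of shape (K, d), one internal-state embedding per response.
    alpha: small ridge term that keeps the matrix positive definite.
    """
    k, _ = embeddings.shape
    # Center each embedding across its d feature dimensions
    centered = embeddings - embeddings.mean(axis=1, keepdims=True)
    # K x K Gram matrix of the centered embeddings
    cov = centered @ centered.T
    # slogdet is numerically safer than log(det(...)) for near-singular matrices
    _, logdet = np.linalg.slogdet(cov + alpha * np.eye(k))
    return float(logdet / k)


# Illustrative inputs: five identical embeddings vs. five unrelated ones
rng = np.random.default_rng(0)
consistent = np.repeat(rng.normal(size=(1, 16)), 5, axis=0)
diverse = rng.normal(size=(5, 16))

print(f"consistent responses: {eigenscore(consistent):.2f}")
print(f"diverse responses:    {eigenscore(diverse):.2f}")
```

A higher score means the sampled responses disagree more in embedding space, which INSIDE treats as a sign of possible hallucination. The threshold that separates "consistent" from "divergent" has to be tuned and is not fixed by this sketch.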

In [None]:
def self_rag_inside_generation(query: str, retriever: LegalRetriever, generator: SelfRAGGenerator, critic: Critic) -> dict:
    """
    Generate answer with Self-RAG + INSIDE hallucination detection.
    
    Args:
        query: Legal question
        retriever: Retrieval system
        generator: Self-RAG generator with INSIDE enabled
        critic: Critic model
    
    Returns:
        Dictionary with answer, reflection tokens, EigenScore, and metadata
    """
    # Temporarily enable INSIDE, remembering the original config so it can be restored
    had_inside_config = 'inside' in generator.config
    original_inside_config = dict(generator.config.get('inside', {}))
    
    generator.config.setdefault('inside', {})
    generator.config['inside']['enabled'] = True
    generator.config['inside']['extract_internal_states'] = True
    
    try:
        # Use Self-RAG pipeline with INSIDE
        pipeline = SelfRAGPipeline(
            generator=generator,
            retriever=retriever,
            critic=critic,
        )
        
        result = pipeline.answer_question(
            question=query,
            include_retrieval=True,
            max_new_tokens=512,
            temperature=0.0,
            detect_hallucination=True,  # Enable INSIDE
        )
        
        return {
            'answer': result['answer'],
            'method': 'Self-RAG+INSIDE',
            'num_tokens': len(result['answer'].split()),
            'reflection_tokens': result.get('reflection_tokens', {}),
            'eigenscore': result.get('eigenscore', None),
            'hallucination_detected': result.get('hallucination_detected', None),
            'combined_score': result.get('combined_score', result.get('score', 0.0)),
            'retrieved_docs': result.get('retrieved_docs', []),
        }
    
    finally:
        # Restore every INSIDE setting, not just 'enabled', so later
        # Self-RAG runs are not affected by 'extract_internal_states'
        if had_inside_config:
            generator.config['inside'].clear()
            generator.config['inside'].update(original_inside_config)
        else:
            generator.config.pop('inside', None)

# Test
test_result = self_rag_inside_generation(queries[0]['query'], retriever, generator, critic)
print("Self-RAG+INSIDE test:")
print(f"  Answer: {test_result['answer'][:100]}...")
print(f"  EigenScore: {test_result.get('eigenscore', 'N/A')}")
print(f"  Hallucination detected: {test_result.get('hallucination_detected', 'N/A')}")
print(f"  Combined score: {test_result['combined_score']:.3f}")

## 3. Run Evaluation

**Note:** A full run (all 776 queries × 4 methods) takes roughly 1-2 hours.
For testing, use a smaller subset (e.g., 50 queries).
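
Because a full run is long, it can help to write results to disk periodically so an interrupted run resumes where it stopped. The sketch below is one possible pattern, not part of this project: `run_with_checkpoints`, `evaluate_one`, and `save_every` are hypothetical names, and each result is assumed to be a JSON-serializable dict.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Callable


def run_with_checkpoints(
    items: list,
    evaluate_one: Callable[[object], dict],
    cache_path: Path,
    save_every: int = 10,
) -> list[dict]:
    """Evaluate items in order, skipping those already stored in cache_path."""
    cache_path = Path(cache_path)
    results = json.loads(cache_path.read_text()) if cache_path.exists() else []
    done = {entry['query_id'] for entry in results}

    def save() -> None:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file, then swap it in, so a crash mid-write
        # cannot leave a truncated cache file behind
        fd, tmp_name = tempfile.mkstemp(dir=cache_path.parent, suffix='.tmp')
        with os.fdopen(fd, 'w') as f:
            json.dump(results, f, indent=2)
        os.replace(tmp_name, cache_path)

    for i, item in enumerate(items):
        if i in done:
            continue  # already evaluated in a previous run
        entry = evaluate_one(item)
        entry['query_id'] = i
        results.append(entry)
        if len(results) % save_every == 0:
            save()

    save()
    return results
```

With this pattern, `evaluate_one` would wrap the four method calls for a single query. A smaller `save_every` loses less work on a crash at the cost of more frequent writes.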

In [None]:
# Configuration
USE_SUBSET = True  # Set to False to run on full 776 queries
SUBSET_SIZE = 50 if USE_SUBSET else len(queries)

# Cache file
CACHE_FILE = f"../results/generation_results_{'subset' if USE_SUBSET else 'full'}.json"
Path(CACHE_FILE).parent.mkdir(parents=True, exist_ok=True)

print(f"Will evaluate on {SUBSET_SIZE} queries")
print(f"Cache file: {CACHE_FILE}")

In [None]:
# Check if cached results exist
if Path(CACHE_FILE).exists():
    print(f"✓ Loading cached results from {CACHE_FILE}")
    with open(CACHE_FILE, 'r') as f:
        all_results = json.load(f)
    print(f"  Loaded {len(all_results)} results")
else:
    print("No cached results found. Will run evaluation...")
    all_results = None

In [None]:
# Run evaluation (skip if cached)
if all_results is None:
    all_results = []
    
    for i, query_data in enumerate(tqdm(queries[:SUBSET_SIZE], desc="Evaluating")):
        query = query_data['query']
        
        # Get ground truth
        if query_data['snippets']:
            ground_truth = query_data['snippets'][0]['answer']
        else:
            ground_truth = ""
        
        result_entry = {
            'query_id': i,
            'query': query,
            'ground_truth': ground_truth,
            'dataset_source': query_data['dataset_source'],
            'methods': {}
        }
        
        # Method 1: No-RAG
        try:
            no_rag_result = no_rag_generation(query, generator)
            result_entry['methods']['No-RAG'] = no_rag_result
        except Exception as e:
            print(f"\nError in No-RAG for query {i}: {e}")
            result_entry['methods']['No-RAG'] = {'error': str(e)}
        
        # Method 2: Basic RAG
        try:
            basic_rag_result = basic_rag_generation(query, retriever, generator)
            result_entry['methods']['Basic RAG'] = basic_rag_result
        except Exception as e:
            print(f"\nError in Basic RAG for query {i}: {e}")
            result_entry['methods']['Basic RAG'] = {'error': str(e)}
        
        # Method 3: Self-RAG
        try:
            self_rag_result = self_rag_generation(query, retriever, generator, critic)
            result_entry['methods']['Self-RAG'] = self_rag_result
        except Exception as e:
            print(f"\nError in Self-RAG for query {i}: {e}")
            result_entry['methods']['Self-RAG'] = {'error': str(e)}
        
        # Method 4: Self-RAG + INSIDE
        try:
            inside_result = self_rag_inside_generation(query, retriever, generator, critic)
            result_entry['methods']['Self-RAG+INSIDE'] = inside_result
        except Exception as e:
            print(f"\nError in Self-RAG+INSIDE for query {i}: {e}")
            result_entry['methods']['Self-RAG+INSIDE'] = {'error': str(e)}
        
        all_results.append(result_entry)
    
    # Save results
    with open(CACHE_FILE, 'w') as f:
        json.dump(all_results, f, indent=2)
    
    print(f"\n✓ Results saved to {CACHE_FILE}")

## 4. Compute Metrics

In [None]:
from evaluation.legalbench_generation_eval import (
    evaluate_generation,
    aggregate_metrics,
    compare_methods,
)

# Compute metrics for each method
metrics_by_method = {}

for method_name in ['No-RAG', 'Basic RAG', 'Self-RAG', 'Self-RAG+INSIDE']:
    method_results = []
    
    for result in all_results:
        if method_name not in result['methods']:
            continue
        
        method_data = result['methods'][method_name]
        
        if 'error' in method_data:
            continue
        
        # Evaluate this example
        metrics = evaluate_generation(
            prediction=method_data.get('answer', ''),
            ground_truth=result['ground_truth'],
            reflection_tokens=method_data.get('reflection_tokens'),
            eigenscore=method_data.get('eigenscore'),
        )
        
        method_results.append(metrics)
    
    # Aggregate
    metrics_by_method[method_name] = aggregate_metrics(method_results)

print("✓ Metrics computed for all methods")

In [None]:
# Display comparison table
import pandas as pd

comparison_data = []

for method_name, metrics in metrics_by_method.items():
    row = {
        'Method': method_name,
        'F1': f"{metrics['avg_f1_score']:.3f}",
        'ROUGE-L': f"{metrics['avg_rouge_l']:.3f}",
        'Halluc%': f"{metrics.get('hallucination_rate', 0) * 100:.1f}%" if 'hallucination_rate' in metrics else 'N/A',
        'Utility': f"{metrics.get('avg_utility_score', 0):.2f}" if 'avg_utility_score' in metrics else 'N/A',
        'Avg Length': f"{metrics['avg_prediction_length']:.0f}",
    }
    comparison_data.append(row)

comparison_df = pd.DataFrame(comparison_data)
print("\n" + "=" * 80)
print("Generation Method Comparison")
print("=" * 80)
print(comparison_df.to_string(index=False))
print("=" * 80)

## 5. Visualizations

In [None]:
# Bar chart: F1 and ROUGE-L comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

methods = list(metrics_by_method.keys())
f1_scores = [metrics_by_method[m]['avg_f1_score'] for m in methods]
rouge_scores = [metrics_by_method[m]['avg_rouge_l'] for m in methods]

# F1 scores
ax1.bar(methods, f1_scores, color=['#e74c3c', '#3498db', '#2ecc71', '#9b59b6'])
ax1.set_ylabel('F1 Score')
ax1.set_title('F1 Score by Method')
ax1.set_ylim(0, max(f1_scores) * 1.2)
ax1.grid(axis='y', alpha=0.3)

for i, v in enumerate(f1_scores):
    ax1.text(i, v + 0.01, f'{v:.3f}', ha='center', fontweight='bold')

# ROUGE-L scores
ax2.bar(methods, rouge_scores, color=['#e74c3c', '#3498db', '#2ecc71', '#9b59b6'])
ax2.set_ylabel('ROUGE-L Score')
ax2.set_title('ROUGE-L Score by Method')
ax2.set_ylim(0, max(rouge_scores) * 1.2)
ax2.grid(axis='y', alpha=0.3)

for i, v in enumerate(rouge_scores):
    ax2.text(i, v + 0.01, f'{v:.3f}', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig('../results/generation_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved to results/generation_comparison.png")

In [None]:
# Radar chart: Multi-metric comparison
from math import pi

# Prepare data (metrics are assumed to already lie in [0, 1]; no rescaling is applied here)
categories = ['F1', 'ROUGE-L', 'Support\n(1-Halluc)', 'Utility', 'Relevance']
N = len(categories)

angles = [n / float(N) * 2 * pi for n in range(N)]
angles += angles[:1]

fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection='polar'))

colors = ['#e74c3c', '#3498db', '#2ecc71', '#9b59b6']

for i, method in enumerate(methods):
    metrics = metrics_by_method[method]
    
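    # Methods without reflection tokens (No-RAG, Basic RAG) do not report
    # hallucination, utility, or relevance; the 0.5 defaults below are neutral
    # placeholders for plotting, not measured values
    values = [
        metrics['avg_f1_score'],
        metrics['avg_rouge_l'],
        1 - metrics.get('hallucination_rate', 0.5),  # Invert hallucination
        metrics.get('avg_utility_score', 0.5),
        metrics.get('relevance_rate', 0.5),
    ]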
    values += values[:1]
    
    ax.plot(angles, values, 'o-', linewidth=2, label=method, color=colors[i])
    ax.fill(angles, values, alpha=0.15, color=colors[i])

ax.set_xticks(angles[:-1])
ax.set_xticklabels(categories)
ax.set_ylim(0, 1)
ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'])
ax.grid(True)
ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))
ax.set_title('Multi-Metric Comparison (Radar Chart)', pad=20, fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig('../results/radar_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved to results/radar_comparison.png")

In [None]:
# Per-subdataset breakdown
from evaluation.legalbench_generation_eval import compute_f1_score

subdataset_metrics = defaultdict(lambda: defaultdict(list))

for result in all_results:
    dataset = result['dataset_source']
    ground_truth = result['ground_truth']
    
    for method_name in methods:
        if method_name not in result['methods']:
            continue
        
        method_data = result['methods'][method_name]
        
        if 'error' in method_data:
            continue
        
        # Compute F1
        f1 = compute_f1_score(method_data.get('answer', ''), ground_truth)
        
        subdataset_metrics[dataset][method_name].append(f1)

# Plot
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for i, (dataset, method_scores) in enumerate(sorted(subdataset_metrics.items())):
    if i >= 4:
        break
    
    method_names = list(method_scores.keys())
    avg_scores = [np.mean(method_scores[m]) for m in method_names]
    
    axes[i].bar(method_names, avg_scores, color=['#e74c3c', '#3498db', '#2ecc71', '#9b59b6'])
    axes[i].set_title(f'{dataset} (n={len(method_scores[method_names[0]])})')
    axes[i].set_ylabel('Avg F1 Score')
    axes[i].set_ylim(0, max(avg_scores) * 1.2 if avg_scores else 1)
    axes[i].grid(axis='y', alpha=0.3)
    axes[i].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.savefig('../results/subdataset_breakdown.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Saved to results/subdataset_breakdown.png")

## 6. Example Outputs

Show side-by-side comparison of the same query across all 4 methods.

In [None]:
# Pick an interesting example
example_idx = 5
example = all_results[example_idx]

print("=" * 80)
print(f"Example Query (ID: {example['query_id']})")
print("=" * 80)
print(f"Dataset: {example['dataset_source']}")
print()
print(f"Query: {example['query']}")
print()
print(f"Ground Truth Snippet: {example['ground_truth'][:200]}...")
print()
print("=" * 80)
print("Method Outputs")
print("=" * 80)

# Import the metric helpers once, outside the loop
from evaluation.legalbench_generation_eval import compute_f1_score, compute_rouge_l

for method_name in methods:
    if method_name not in example['methods']:
        continue
    
    method_data = example['methods'][method_name]
    
    if 'error' in method_data:
        print(f"\n{method_name}: ERROR - {method_data['error']}")
        continue
    
    answer = method_data.get('answer', '')
    
    # Compute metrics
    f1 = compute_f1_score(answer, example['ground_truth'])
    rouge = compute_rouge_l(answer, example['ground_truth'])
    
    print(f"\n{method_name}:")
    print(f"  F1: {f1:.3f} | ROUGE-L: {rouge:.3f}")
    
    if 'reflection_tokens' in method_data:
        tokens = method_data['reflection_tokens']
        print(f"  Reflection: {tokens}")
    
    # The key can be present with a None value, which cannot be formatted as a float
    if method_data.get('eigenscore') is not None:
        print(f"  EigenScore: {method_data['eigenscore']:.2f}")
    
    print(f"  Answer: {answer}")
    print("-" * 80)

## 7. Summary and Analysis

### Expected Results (with trained models):

| Method | F1 | ROUGE-L | Halluc% | Utility | Notes |
|--------|-----|---------|---------|---------|-------|
| No-RAG | ~0.10-0.15 | ~0.12-0.18 | N/A | N/A | Baseline, no external knowledge |
| Basic RAG | ~0.25-0.35 | ~0.30-0.40 | N/A | N/A | Simple retrieval helps significantly |
| Self-RAG | ~0.40-0.50 | ~0.45-0.55 | ~15-25% | ~0.70-0.80 | Adaptive retrieval + self-assessment |
| Self-RAG+INSIDE | ~0.42-0.52 | ~0.47-0.57 | ~10-18% | ~0.75-0.85 | Best overall, combines reflection + internal states |

### Key Insights (expected):

1. **RAG is essential**: No-RAG is expected to perform poorly on legal questions that require knowledge of specific documents
2. **Self-reflection helps**: Self-RAG's reflection tokens enable self-assessment and adaptive retrieval
3. **INSIDE reduces hallucinations**: EigenScore provides an additional hallucination signal alongside the reflection tokens
4. **Trade-offs**: More sophisticated methods take longer per query, in exchange for the expected gains in quality and reliability

### Next Steps:

1. **Error Analysis**: Identify which types of queries each method struggles with
2. **Qualitative Evaluation**: Manual review of outputs for legal correctness
3. **Ablation Studies**: Test individual components (e.g., adaptive retrieval vs. fixed retrieval)
4. **Hyperparameter Tuning**: Optimize reflection token weights, retrieval threshold, etc.
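
As a starting point for the error analysis step, the sketch below reports, per subdataset, the fraction of queries on which a method scores below a threshold. It assumes the `all_results` structure built in Section 3. The function name `low_score_rate_by_dataset`, the `threshold` of 0.2, and the toy `jaccard` scorer are illustrative assumptions; in the notebook, `compute_f1_score` would be the natural scorer to pass in.

```python
from collections import defaultdict
from typing import Callable


def low_score_rate_by_dataset(
    all_results: list[dict],
    method: str,
    score_fn: Callable[[str, str], float],
    threshold: float = 0.2,  # arbitrary cutoff for a "struggling" answer
) -> dict[str, float]:
    """Fraction of queries per dataset_source whose score falls below threshold."""
    totals = defaultdict(int)
    low = defaultdict(int)
    for entry in all_results:
        method_data = entry['methods'].get(method)
        if not method_data or 'error' in method_data:
            continue  # skip queries where this method did not produce an answer
        dataset = entry['dataset_source']
        totals[dataset] += 1
        score = score_fn(method_data.get('answer', ''), entry['ground_truth'])
        if score < threshold:
            low[dataset] += 1
    return {dataset: low[dataset] / totals[dataset] for dataset in totals}


def jaccard(prediction: str, reference: str) -> float:
    """Toy scorer for the example: overlap of the two lowercase token sets."""
    a, b = set(prediction.lower().split()), set(reference.lower().split())
    return len(a & b) / len(a | b) if a | b else 0.0


# Made-up entries that mimic the shape of all_results
toy_results = [
    {'dataset_source': 'A', 'ground_truth': 'tenant pays rent',
     'methods': {'Basic RAG': {'answer': 'tenant pays rent'}}},
    {'dataset_source': 'A', 'ground_truth': 'tenant pays rent',
     'methods': {'Basic RAG': {'answer': 'no idea'}}},
    {'dataset_source': 'B', 'ground_truth': 'x y',
     'methods': {'Basic RAG': {'error': 'out of memory'}}},
    {'dataset_source': 'B', 'ground_truth': 'x y',
     'methods': {'Basic RAG': {'answer': 'x y z'}}},
]

rates = low_score_rate_by_dataset(toy_results, 'Basic RAG', jaccard)
print(rates)
```

Subdatasets with a high rate for one method but not another are good candidates for the manual review described in the qualitative evaluation step.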