# Factuality Slice: Evidence-Grounded QA Training Pipeline

This notebook implements an end-to-end training pipeline for evidence-grounded question answering. The model learns to answer from supplied evidence chunks and to cite the chunks that support its answer.

## Pipeline Overview

1. **Data Building**: Create a unified dataset from FEVER, HotpotQA, NQ-Open, and PopQA
2. **Prompt Templates**: Generate standardized prompts and output schemas
3. **Supervised Fine-Tuning (SFT)**: Train the base model to answer questions with citations
4. **Evaluation**: Compare the SFT model against the base model
5. **Preference Generation**: Create preference pairs using a judge model
6. **Reward Model Training**: Train a reward model with anti-gaming penalties

## Requirements

- Google Colab with A100 GPU (recommended)
- Hugging Face account for model downloads
- ~50GB storage for models and data

## Step 1: Environment Setup

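Mount Google Drive for persistent storage and authenticate with Hugging Face for model access. The later cells write to paths relative to the working directory, so run them from a folder on the mounted Drive (for example, `%cd` into it after mounting) if you want data and checkpoints to survive a runtime reset. Gemma-2 weights are gated on the Hub, so accept the license on the model page before downloading.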

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Login to Hugging Face (you'll need to paste your access token)
# Get your token from: https://huggingface.co/settings/tokens
!huggingface-cli login

# Install required packages
!pip install -U bitsandbytes transformers peft accelerate datasets
!pip install scipy scikit-learn pyyaml tqdm

## Step 2: Build the Dataset

Download and process data from multiple QA sources. This creates:
- `data/processed/train.jsonl`: Training data
- `data/processed/val.jsonl`: Validation data
- `data/processed/test.jsonl`: Test data

Each sample contains:
- Question
- Evidence chunks (context)
- Gold answer
- Support spans (which evidence chunks contain the answer)
- Hard negatives (topically related evidence that does not support the answer)
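To make the record layout concrete, here is a minimal sketch of one sample and a loader that checks it. The field names (`question`, `evidence`, `answer`, `support_spans`, `hard_negatives`) and the choice to store support spans as chunk indices are assumptions for illustration; `build_data.py` may use different keys.

```python
import json
from dataclasses import dataclass, field


# Hypothetical schema: the real keys written by build_data.py may differ.
@dataclass
class QASample:
    question: str
    evidence: list[str]        # evidence chunks, indexed from 0
    answer: str                # gold answer
    support_spans: list[int]   # indices of chunks that contain the answer
    hard_negatives: list[str] = field(default_factory=list)

    def validate(self) -> None:
        # Every support index must point at an existing evidence chunk.
        for idx in self.support_spans:
            if not 0 <= idx < len(self.evidence):
                raise ValueError(f"support span {idx} is out of range")


def parse_line(line: str) -> QASample:
    # One JSONL line -> validated sample (unknown keys raise TypeError).
    sample = QASample(**json.loads(line))
    sample.validate()
    return sample


line = json.dumps({
    "question": "Who wrote Hamlet?",
    "evidence": ["Hamlet is a tragedy by William Shakespeare.",
                 "Macbeth is set in Scotland."],
    "answer": "William Shakespeare",
    "support_spans": [0],
    "hard_negatives": ["Christopher Marlowe wrote Doctor Faustus."],
})
sample = parse_line(line)
print(sample.support_spans)
```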

In [None]:
# Build the unified dataset from multiple sources
# --accept-new-hash: Required on first run to create source verification hashes
# --seed: Ensures reproducible train/val/test splits
!python build_data.py --data-dir data --seed 42 --accept-new-hash

# Check the generated data
!echo "Dataset statistics:"
!wc -l data/processed/*.jsonl
!printf "\nSample from training data:\n"
!head -n 1 data/processed/train.jsonl | python -m json.tool | head -20

## Step 3: Generate Prompt Templates

Create standardized prompt templates and output schemas. This ensures consistent formatting across training and inference.
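As an illustration of what a template plus output validator can look like, the sketch below numbers evidence chunks so the model can cite them and checks a JSON reply. The template wording and the output keys (`answer`, `citations`, `confidence`) are hypothetical; `prompt_schema.py` defines the actual templates and schema.

```python
import json

# Hypothetical template; the real one is generated by prompt_schema.py.
PROMPT_TEMPLATE = (
    "Answer the question using only the evidence below. "
    "Respond with JSON containing 'answer', 'citations', and 'confidence'.\n\n"
    "{evidence}\n\nQuestion: {question}\n"
)


def render_prompt(question: str, evidence: list[str]) -> str:
    # Number the chunks so citations can refer to them by index.
    numbered = "\n".join(f"[{i}] {chunk}" for i, chunk in enumerate(evidence))
    return PROMPT_TEMPLATE.format(evidence=numbered, question=question)


def validate_output(raw: str, num_chunks: int) -> dict:
    out = json.loads(raw)  # malformed JSON raises json.JSONDecodeError (a ValueError)
    if not isinstance(out.get("answer"), str):
        raise ValueError("'answer' must be a string")
    cites = out.get("citations")
    if not isinstance(cites, list) or not all(
        type(c) is int and 0 <= c < num_chunks for c in cites
    ):
        raise ValueError("'citations' must be valid chunk indices")
    conf = out.get("confidence")
    if not isinstance(conf, (int, float)) or not 0.0 <= conf <= 1.0:
        raise ValueError("'confidence' must be in [0, 1]")
    return out


prompt = render_prompt("Who wrote Hamlet?", ["Hamlet is a tragedy by William Shakespeare."])
parsed = validate_output(
    '{"answer": "William Shakespeare", "citations": [0], "confidence": 0.9}', num_chunks=1
)
```

Validating against the number of chunks actually shown in the prompt is what lets later stages treat an out-of-range citation as an error rather than a guess.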

In [None]:
# Generate prompt templates and validation schemas
!python prompt_schema.py save --prompts-dir prompts

# View the generated templates
!echo "Generated prompt templates:"
!ls -la prompts/

# Test the output validator
!python prompt_schema.py test

## Step 4: Supervised Fine-Tuning (SFT)

Fine-tune a base model using LoRA/QLoRA for parameter-efficient training.

### Training Details:
- **Method**: QLoRA (4-bit quantization with LoRA adapters)
- **Base Model**: Gemma-2-9B (configurable)
- **Training Objective**: Next-token prediction on JSON answers
- **Key Features**:
  - Answer-only loss (prompt tokens masked)
  - Gradient checkpointing for memory efficiency
  - Mixed precision training (bf16)
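Answer-only loss usually means the label sequence copies the answer tokens and replaces every prompt token with an ignore value, so the loss is computed on the answer alone. The sketch below assumes the common convention of `-100` (the default `ignore_index` of PyTorch's `CrossEntropyLoss`, which Hugging Face causal LMs also use); the token IDs are made up, and `build_labels` is a hypothetical helper, not a function from `train_sft.py`.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_labels(prompt_ids: list[int], answer_ids: list[int],
                 max_len: int) -> tuple[list[int], list[int]]:
    # Inputs contain prompt + answer; labels mask the prompt positions.
    input_ids = (prompt_ids + answer_ids)[:max_len]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + answer_ids)[:max_len]
    return input_ids, labels


input_ids, labels = build_labels([1, 2, 3], [7, 8, 9], max_len=2048)
print(labels)  # [-100, -100, -100, 7, 8, 9]
```

Labels stay aligned with `input_ids`; Hugging Face causal LM heads shift them by one position internally. Truncating from the end, as here, can cut the answer off entirely when the prompt is long, which is why a real pipeline may prefer trimming evidence instead.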

### Expected Training Time:
- A100 GPU: ~2-3 hours for 2 epochs
- V100 GPU: ~4-5 hours for 2 epochs (V100 lacks native bf16 support, so the `bf16: true` setting below may need to change)

In [None]:
# Create configuration file for SFT training
# You can modify these parameters based on your needs
!mkdir -p configs
config_content = """model_name: google/gemma-2-9b
output_dir: checkpoints/sft/gemma2-9b
load_in_4bit: true
bf16: true
trust_remote_code: false
attn_implementation: eager  # Use 'flash_attention_2' if available

lora:
  r: 16
  alpha: 32
  dropout: 0.05
  target_modules: [q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj]

training:
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 2
  gradient_accumulation_steps: 64
  learning_rate: 1.5e-5
  num_train_epochs: 2
  max_seq_length: 2048
  warmup_ratio: 0.03
  weight_decay: 0.1
  lr_scheduler_type: cosine
  logging_steps: 10
  eval_steps: 100
  save_steps: 100
  save_total_limit: 2
  seed: 42

data:
  train_path: data/processed/train.jsonl
  val_path: data/processed/val.jsonl
  template: default
  max_samples: null  # Use all training data
  max_val_samples: 500
"""

with open('configs/sft_gemma2.yaml', 'w') as f:
    f.write(config_content)

# Train the SFT model
!python train_sft.py -c configs/sft_gemma2.yaml
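With `per_device_train_batch_size: 1` and `gradient_accumulation_steps: 64`, each optimizer step consumes 64 samples, so a modest dataset yields only tens of steps. Assuming `train_sft.py` follows the Hugging Face Trainer convention where `checkpoint-N` names optimizer step N, this can explain the small checkpoint numbers used in the evaluation step. The sketch uses ceiling division; the exact rounding of a partial final accumulation depends on the trainer, and the 2,000-sample figure is purely illustrative.

```python
import math


def optimizer_steps(num_samples: int, per_device_bs: int, grad_accum: int,
                    epochs: int, num_devices: int = 1) -> int:
    # Samples consumed by one optimizer update.
    effective_bs = per_device_bs * grad_accum * num_devices
    # A partial final batch still triggers an update under ceiling rounding.
    steps_per_epoch = math.ceil(num_samples / effective_bs)
    return steps_per_epoch * epochs


# Illustrative only: 2,000 training samples with the config above.
print(optimizer_steps(2000, per_device_bs=1, grad_accum=64, epochs=2))  # 64
```

If the total is below `save_steps: 100`, no intermediate checkpoint is written and only a final save (if the script makes one) will appear.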

## Step 5: Evaluate SFT Model

Compare the fine-tuned model against the base model on test data.

### Metrics Evaluated:
- **Exact Match (EM)**: Percentage of exact answer matches
- **F1 Score**: Token-level overlap with gold answers
- **Hallucination Rate**: Answers without proper evidence citations
- **Citation Correctness**: Whether citations match gold support spans
- **Refusal Rate**: Frequency of "insufficient evidence" responses
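EM and F1 are commonly computed SQuAD-style: normalize both strings (lowercase, strip punctuation and articles, collapse whitespace), then compare whole strings for EM and bags of tokens for F1. This is a sketch of that standard recipe, under the assumption that `evaluate_sft_fast_robust.py` normalizes in a similar way; its exact rules may differ.

```python
import re
import string
from collections import Counter

PUNCT = set(string.punctuation)


def normalize(text: str) -> str:
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(pred: str, gold: str) -> float:
    return float(normalize(pred) == normalize(gold))


def f1(pred: str, gold: str) -> float:
    p, g = normalize(pred).split(), normalize(gold).split()
    if not p or not g:
        return float(p == g)  # both empty counts as a match
    overlap = sum((Counter(p) & Counter(g)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(p), overlap / len(g)
    return 2 * precision * recall / (precision + recall)


def best_over_golds(metric, pred: str, golds: list[str]) -> float:
    # Datasets such as NQ-Open list several acceptable answers.
    return max(metric(pred, g) for g in golds)


print(exact_match("The Eiffel Tower!", "eiffel tower"))  # 1.0
print(f1("Paris, France", "Paris"))  # precision 0.5, recall 1.0
```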

In [None]:
# List saved checkpoints (save_total_limit: 2 keeps at most two; the newest is not necessarily the best)
!ls -la checkpoints/sft/gemma2-9b/

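# Evaluate the SFT model and compare with the base model
# Adjust the checkpoint number to match a directory listed above
# --attn-impl flash2 assumes the flash-attn package is installed (it is not part of Step 1)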
!python evaluate_sft_fast_robust.py \
  --model-path checkpoints/sft/gemma2-9b/checkpoint-35 \
  --base-model google/gemma-2-9b \
  --test-data data/processed/test.jsonl \
  --compare-baseline \
  --batch-size 4 \
  --max-new-tokens 256 \
  --fast \
  --buckets 2048,3072,4096,5120,6144,7168,7936 \
  --attn-impl flash2 \
  --output eval_results.json

# Display evaluation results
!printf "\nEvaluation Results:\n"
!python -m json.tool eval_results.json
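A `--buckets` list like the one above typically means each prompt is padded up to the smallest bucket that fits it, so batches contain similar lengths and padding waste stays low. The sketch below shows that reading; how `evaluate_sft_fast_robust.py` actually uses the flag is an assumption, and `assign_bucket` and `group_by_bucket` are hypothetical helpers.

```python
import bisect
from typing import Optional

BUCKETS = [2048, 3072, 4096, 5120, 6144, 7168, 7936]


def assign_bucket(length: int, buckets: list[int]) -> Optional[int]:
    # Smallest bucket >= length; None means the prompt exceeds every bucket.
    i = bisect.bisect_left(buckets, length)
    return buckets[i] if i < len(buckets) else None


def group_by_bucket(lengths: list[int], buckets: list[int]) -> dict:
    # Map bucket size -> indices of prompts that share it (batch within a group).
    groups: dict = {}
    for idx, n in enumerate(lengths):
        groups.setdefault(assign_bucket(n, buckets), []).append(idx)
    return groups


print(group_by_bucket([100, 5000, 9000], BUCKETS))
```

Prompts that land in the `None` group would need truncation or a larger final bucket.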

## Step 6: Generate Preference Data (RLAIF)

Create preference pairs for reward model training using a two-phase approach:

### Phase 1: Generate Candidates
- Generate multiple responses from both SFT and base models
- Responses are cached for efficiency

### Phase 2: Judge Candidates
- Use a separate judge model to evaluate response pairs
- Create preference pairs based on judge scores
- Filter pairs by minimum margin for quality

This two-phase approach allows:
- Memory-efficient processing (only one model loaded at a time)
- Resume capability if interrupted
- Different quantization settings for generators vs judge
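The pairing rule in Phase 2 can be sketched as: rank a question's judged candidates, take the best and worst, and keep the pair only if their score gap reaches `--min-margin`. This is an illustrative reconstruction, not the logic inside `make_prefs.py`; the `Candidate` fields and the assumption that judge scores lie in [0, 1] are hypothetical.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Candidate:
    text: str
    source: str   # "sft" or "base"
    score: float  # judge score, assumed to lie in [0, 1]


def make_pair(question_id: str, candidates: list[Candidate],
              min_margin: float = 0.12) -> Optional[dict]:
    if len(candidates) < 2:
        return None
    ranked = sorted(candidates, key=lambda c: c.score, reverse=True)
    chosen, rejected = ranked[0], ranked[-1]
    margin = chosen.score - rejected.score
    # Small gaps are likely judge noise, so drop them.
    if margin < min_margin:
        return None
    return {"id": question_id, "chosen": chosen.text,
            "rejected": rejected.text, "margin": margin}


pair = make_pair("q1", [Candidate("a", "sft", 0.9), Candidate("b", "base", 0.5)])
skipped = make_pair("q2", [Candidate("a", "sft", 0.55), Candidate("b", "base", 0.5)])
```

Raising the margin trades pair count for label reliability.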

In [None]:
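# Phase 1: Generate candidate responses from SFT and base models
# This creates cached generations in prefs/cache/<signature>/
# Use the same SFT checkpoint you evaluated in Step 5
# Note: --temperature 0.0 decodes greedily, so repeated generations from one model are likely identical
!python make_prefs.py --phase generate \
  --sft-model checkpoints/sft/gemma2-9b/checkpoint-35 \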
  --base-model google/gemma-2-9b \
  --data-path data/processed/train.jsonl \
  --max-samples 600 \
  --n-generations 2 \
  --batch-size 12 \
  --max-new-tokens 80 \
  --max-input-tokens 1536 \
  --temperature 0.0 \
  --use-8bit  # Use 8-bit quantization for memory efficiency

# Phase 2: Judge the generated candidates
# This creates preference pairs in prefs/preferences_<signature>.jsonl
!python make_prefs.py --phase judge \
  --judge-model Qwen/Qwen2.5-3B-Instruct \
  --data-path data/processed/train.jsonl \
  --max-samples 600 \
  --n-generations 2 \
  --judge-batch-size 12 \
  --min-margin 0.12 \
  --judge-use-8bit \
  --output-mode rich  # Include parsed outputs and metadata

# Check the generated preference data
!printf "\nPreference data statistics:\n"
!ls -lh prefs/
!wc -l prefs/preferences_*.jsonl
!printf "\nSample preference pair (most recent file):\n"
!head -n 1 "$(ls -t prefs/preferences_*.jsonl | head -n 1)" | python -m json.tool | head -30

## Step 7: Train Reward Model

Train a reward model on the preference pairs using Bradley-Terry loss.
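The Bradley-Terry model treats the probability that the chosen response beats the rejected one as the sigmoid of their reward difference, so the loss per pair is the negative log of that sigmoid. The scalar sketch below computes it in a numerically stable form; in PyTorch the batched equivalent is `-torch.nn.functional.logsigmoid(r_chosen - r_rejected).mean()`.

```python
import math


def bradley_terry_loss(r_chosen: float, r_rejected: float) -> float:
    # -log(sigmoid(d)) == softplus(-d), with d = r_chosen - r_rejected.
    d = r_chosen - r_rejected
    if d > 0:
        return math.log1p(math.exp(-d))
    # For d <= 0, rewrite to avoid overflow in exp(-d).
    return -d + math.log1p(math.exp(d))


def mean_bt_loss(pairs: list[tuple[float, float]]) -> float:
    return sum(bradley_terry_loss(c, r) for c, r in pairs) / len(pairs)


print(bradley_terry_loss(0.0, 0.0))  # log 2: the model cannot tell the pair apart
```

The loss falls toward 0 as the chosen reward pulls ahead and grows roughly linearly when the ranking is wrong.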

### Key Features:
- **Anti-gaming penalties**: Penalizes invalid citations, over-citation, and excessive length
- **Support span checking**: Ensures cited evidence actually contains answer support
- **Confidence calibration**: Rewards appropriate confidence levels
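One way to combine these penalties is to subtract weighted terms from the raw reward. In the sketch below only the weight values come from the `*_penalty_weight` entries in this step's config; the penalty formulas, the `max_tokens` and `max_citations` thresholds, and the `shaped_reward` helper are all hypothetical, since `train_rm.py` defines its own versions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PenaltyWeights:
    # Values copied from configs/rm.yaml; everything else here is illustrative.
    citation: float = 0.2
    confidence: float = 0.1
    length: float = 0.05
    overcitation: float = 0.15
    span: float = 0.25


def shaped_reward(raw_reward: float, *, citations: list[int], num_chunks: int,
                  gold_support: list[int], confidence: float, is_correct: bool,
                  num_tokens: int, weights: PenaltyWeights = PenaltyWeights(),
                  max_tokens: int = 80, max_citations: int = 3) -> float:
    w = weights
    # Citations pointing at chunks that do not exist.
    penalty = w.citation * sum(1 for c in citations if not 0 <= c < num_chunks)
    # Citing more chunks than a (hypothetical) budget allows.
    penalty += w.overcitation * max(0, len(citations) - max_citations)
    # Valid citations that miss every gold support chunk.
    valid = {c for c in citations if 0 <= c < num_chunks}
    if valid and not valid & set(gold_support):
        penalty += w.span
    # Calibration: confidence should track correctness.
    penalty += w.confidence * abs(confidence - float(is_correct))
    # Length beyond the budget, scaled by the budget.
    penalty += w.length * max(0, num_tokens - max_tokens) / max_tokens
    return raw_reward - penalty


ok = shaped_reward(1.0, citations=[0], num_chunks=3, gold_support=[0],
                   confidence=1.0, is_correct=True, num_tokens=50)
```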

### Training Details:
- Base model frozen, only reward head trained
- Validation includes AUC, calibration, and baseline comparison
- Best model saved based on ROC-AUC score
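For preference data, two validation numbers are easy to compute from rewards alone: pairwise accuracy (how often the chosen response outscores its own rejected partner) and ROC-AUC with chosen responses as positives and rejected as negatives. The sketch computes AUC as the Mann-Whitney win rate with ties counted as 0.5, which matches `sklearn.metrics.roc_auc_score`; whether `train_rm.py` defines AUC this way is an assumption.

```python
def pairwise_accuracy(chosen: list[float], rejected: list[float]) -> float:
    return sum(c > r for c, r in zip(chosen, rejected)) / len(chosen)


def roc_auc(pos_scores: list[float], neg_scores: list[float]) -> float:
    # Probability that a random positive outscores a random negative.
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos_scores) * len(neg_scores))


chosen = [2.0, 1.0, 0.5]
rejected = [1.0, 1.5, 0.0]
acc = pairwise_accuracy(chosen, rejected)  # 2 of 3 pairs ranked correctly
auc = roc_auc(chosen, rejected)            # 5.5 wins out of 9 comparisons
```

The two can disagree: AUC also compares responses across different questions, while pairwise accuracy only checks each pair against itself.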

In [None]:
# Create reward model configuration
rm_config = """base_model: google/gemma-2-9b
output_dir: checkpoints/rm
learning_rate: 5e-6
batch_size: 16
eval_batch_size: 32
num_epochs: 3
warmup_ratio: 0.1
weight_decay: 0.01
max_seq_length: 2048
gradient_accumulation_steps: 1
eval_steps: 100
save_steps: 200
logging_steps: 10
load_in_8bit: false
seed: 42

# Anti-gaming penalty weights (tune these based on your needs)
citation_penalty_weight: 0.2      # Invalid citations penalty
confidence_penalty_weight: 0.1     # Confidence mismatch penalty
length_penalty_weight: 0.05        # Excessive length penalty
overcitation_penalty_weight: 0.15  # Citing too many chunks
span_penalty_weight: 0.25          # Support span mismatch penalty
"""

with open('configs/rm.yaml', 'w') as f:
    f.write(rm_config)

# Train the reward model
!python train_rm.py -c configs/rm.yaml

# Validate the trained reward model
!python train_rm.py -c configs/rm.yaml --validate-only --checkpoint best

# Display training history
!printf "\nTraining History:\n"
!python -m json.tool checkpoints/rm/training_history.json | tail -50

## Step 8: Analysis and Next Steps

### Analyze Results
View the evaluation metrics and training curves to assess model performance.

In [None]:
# Load and visualize results
import json
import math
import matplotlib.pyplot as plt

# Load evaluation results
with open('eval_results.json', 'r') as f:
    eval_data = json.load(f)

# Display SFT-vs-base deltas. The sign convention comes from the evaluation script:
# if deltas are SFT minus base, a negative hallucination-rate delta is an improvement.
deltas = eval_data.get('deltas', {})
if deltas:
    print("Change relative to base model:")
    for key, label in [('em', 'EM'), ('f1', 'F1'),
                       ('hallucination_rate', 'Hallucination rate'),
                       ('citation_correctness', 'Citation correctness')]:
        if key in deltas:
            print(f"  {label} delta: {deltas[key]:+.3f}")

# Load reward model training history
with open('checkpoints/rm/training_history.json', 'r') as f:
    rm_history = json.load(f)

train_losses = rm_history.get('train_losses', [])
eval_metrics = rm_history.get('eval_metrics', [])

# Plot reward model training curves
if train_losses:
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title('Reward Model Training Loss')
    plt.xlabel('Step')
    plt.ylabel('Loss')

    if eval_metrics:
        plt.subplot(1, 2, 2)
        # Missing AUC values plot as gaps (NaN) instead of misleading zeros
        eval_aucs = [m.get('roc_auc', math.nan) for m in eval_metrics]
        plt.plot(eval_aucs)
        plt.title('Reward Model Validation AUC')
        plt.xlabel('Evaluation Step')
        plt.ylabel('ROC-AUC')

    plt.tight_layout()
    plt.show()

### Next Steps

With the trained SFT model and reward model, you can:

1. **Deploy for Inference**: Use the SFT model for evidence-grounded QA
2. **Further RLHF Training**: Use the reward model for PPO training
3. **DPO Training**: Use preference pairs directly for Direct Preference Optimization
4. **Iterative Improvement**: Generate new preferences from improved models
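For option 3, the DPO objective skips the reward model: it compares how much the policy raises the log-probability of the chosen response relative to a frozen reference model, versus the rejected one, and applies a logistic loss to that gap scaled by β. The scalar sketch below shows the per-pair loss; the log-probabilities are made-up summed sequence values, and β = 0.1 is a common choice rather than a setting from this pipeline. TRL's `DPOTrainer` implements the batched version.

```python
import math


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    # Implicit reward gap: (policy - reference) on chosen minus on rejected.
    logits = beta * ((policy_chosen_logp - ref_chosen_logp)
                     - (policy_rejected_logp - ref_rejected_logp))
    # -log(sigmoid(logits)), computed stably.
    if logits > 0:
        return math.log1p(math.exp(-logits))
    return -logits + math.log1p(math.exp(logits))


# Policy identical to the reference: loss is log 2.
neutral = dpo_loss(-10.0, -20.0, -10.0, -20.0)
# Policy shifted toward the chosen response: loss drops below log 2.
improved = dpo_loss(-10.0, -20.0, -12.0, -18.0)
```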

### Saving Models to Hub

To share your models on Hugging Face Hub:
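```python
from huggingface_hub import HfApi

api = HfApi()
repo_id = "your-username/factuality-gemma-9b-sft"  # replace with your namespace

# upload_folder expects the repo to exist, so create it first (no-op if it already does)
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

# Point this at the checkpoint you evaluated; QLoRA checkpoints typically hold
# the LoRA adapter rather than merged full weights
api.upload_folder(
    folder_path="checkpoints/sft/gemma2-9b/checkpoint-35",
    repo_id=repo_id,
    repo_type="model",
)
```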

### Tips for Production

- **Inference Optimization**: Use Flash Attention 2 and torch.compile for faster inference
- **Batching**: Implement dynamic batching for production serving
- **Monitoring**: Track citation accuracy and hallucination rates in production
- **Continuous Learning**: Collect user feedback for iterative improvement