# Constitutional Law LLM Training

This notebook provides an interactive training workflow for the Constitutional Law LLM.

## Overview
- **Model**: OpenLLaMA 7B (`openlm-research/open_llama_7b`) with LoRA fine-tuning
- **Data**: Supreme Court cases (First and Fourth Amendments)
- **Task**: Constitutional law question answering

## Setup

In [None]:
# Install required packages
# !pip install -r ../requirements.txt

import os
import sys
import json
import torch
import wandb
from pathlib import Path

# Add src to path
sys.path.append('../src')

# Import our modules
from config import config
from data_processing import preprocess_data
from model_training import ConstitutionalLawTrainer, train_model
from model_utils import ModelManager

# Report the PyTorch build and, when CUDA is usable, the first visible device.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## Configuration

Configure training parameters and paths:

In [None]:
# Configure environment
# Disable Hugging Face tokenizers' internal parallelism to avoid fork-related
# warnings once data-loader workers are spawned.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Set up wandb (optional): only log in when enabled AND a token is configured.
USE_WANDB = True
if USE_WANDB and config.wandb_token:
    wandb.login(key=config.wandb_token)
    print("Weights & Biases configured successfully")
else:
    print("Training without Weights & Biases logging")
    # In Hugging Face `TrainingArguments`, `report_to=None` means "report to all
    # installed integrations" (which still includes wandb); the string "none"
    # is what disables reporting.
    # NOTE(review): assumes config.training.report_to is forwarded to
    # TrainingArguments — confirm in config.py / model_training.py.
    config.training.report_to = "none"

# Training configuration (forwarded as keyword arguments to trainer.train)
TRAIN_CONFIG = {
    'num_train_epochs': 3,
    'per_device_train_batch_size': 2,
    'learning_rate': 2e-4,
    'weight_decay': 0.01,
    'warmup_steps': 100,
    # Effective batch size per device = batch_size * accumulation steps (2 * 4)
    'gradient_accumulation_steps': 4,
    'logging_steps': 10,
    'eval_steps': 500,
    'save_steps': 500,
    # fp16 mixed precision requires a CUDA device; enabling it on CPU raises.
    'fp16': torch.cuda.is_available(),
    'gradient_checkpointing': True
}

# Model configuration
MODEL_NAME = "openlm-research/open_llama_7b"
OUTPUT_DIR = "../models/constitutional_law_trained"

print(f"Model: {MODEL_NAME}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Training config: {TRAIN_CONFIG}")

## Data Preparation

Load and inspect the training data:

In [None]:
# Locations of the preprocessed JSONL splits
train_file = "../data/processed/train_cleaned.jsonl"
val_file = "../data/processed/validation_cleaned.jsonl"

if not os.path.exists(train_file) or not os.path.exists(val_file):
    print("Processed data not found. Running preprocessing...")
    preprocess_data("../data/raw", "../data/processed")
    # Fail fast with a clear message if preprocessing did not produce the file
    # names this notebook expects, rather than failing later in load_dataset.
    _missing = [p for p in (train_file, val_file) if not os.path.exists(p)]
    if _missing:
        raise FileNotFoundError(
            f"Preprocessing did not produce expected files: {_missing}"
        )
    print("Preprocessing completed.")
else:
    print("Using existing processed data.")

# Load and inspect data
from datasets import load_dataset

# The 'json' builder places all rows from a single `data_files` entry into one
# split named 'train', so both files are indexed with ['train'].
train_ds = load_dataset('json', data_files=train_file)['train']
val_ds = load_dataset('json', data_files=val_file)['train']

print(f"Training examples: {len(train_ds)}")
print(f"Validation examples: {len(val_ds)}")

# Show example (guarded: indexing an empty dataset raises IndexError)
if len(train_ds) == 0:
    print("\nTraining set is empty; no example to show.")
else:
    print("\nExample training instance:")
    example = train_ds[0]
    print(f"Case: {example['name']}")
    print(f"Instruction: {example['instruction'][:200]}...")
    print(f"Response: {example['response'][:200]}...")

## Model Training

Train the model with LoRA fine-tuning:

In [None]:
# Build the trainer around the base checkpoint.
trainer = ConstitutionalLawTrainer(MODEL_NAME)

lora_cfg = config.lora
print("Starting training...")
print(f"Base model: {MODEL_NAME}")
print(f"LoRA config: r={lora_cfg.r}, alpha={lora_cfg.lora_alpha}, dropout={lora_cfg.lora_dropout}")

# Run fine-tuning; every TRAIN_CONFIG entry is passed through as a keyword argument.
training_results = trainer.train(
    train_file=train_file,
    val_file=val_file,
    output_dir=OUTPUT_DIR,
    **TRAIN_CONFIG,
)

print("\nTraining completed!")
# Metrics missing from the returned dicts are shown as 'N/A'.
train_metrics = training_results['training_metrics']
print(f"Final training loss: {train_metrics.get('train_loss', 'N/A')}")
eval_metrics = training_results['evaluation_metrics']
print(f"Final validation loss: {eval_metrics.get('eval_loss', 'N/A')}")
print(f"Final validation accuracy: {eval_metrics.get('eval_accuracy', 'N/A')}")
print(f"Model saved to: {training_results['model_path']}")

## Model Testing

Test the trained model with sample questions:

In [None]:
# Test the trained model
def test_model(model_path, test_questions, base_model=None):
    """Load a fine-tuned model and print its answers to sample cases.

    Args:
        model_path: Location of the fine-tuned weights, passed as the first
            argument to ``ModelManager.load_model_local``.
        test_questions: Iterable of dicts, each with ``'facts'`` and
            ``'question'`` string entries.
        base_model: Base checkpoint name passed alongside ``model_path``.
            Defaults to the notebook-level ``MODEL_NAME``.
    """
    if base_model is None:
        base_model = MODEL_NAME

    # Materialize once so an empty input can be detected (and generators work).
    cases = list(test_questions)
    if not cases:
        # Avoid loading a large model when there is nothing to ask it.
        print("No test cases provided.")
        return

    # NOTE(review): this loads a second model instance while `trainer` may
    # still hold one in memory; on a single GPU this can run out of memory.
    model_manager = ModelManager()
    model_manager.load_model_local(model_path, base_model)

    print("Testing trained model...\n")

    for i, test_case in enumerate(cases, 1):
        print(f"Test Case {i}:")
        print(f"Facts: {test_case['facts']}")
        print(f"Question: {test_case['question']}")

        # Generate response
        response = model_manager.generate_response(
            test_case['facts'],
            test_case['question']
        )

        print(f"Model Response: {response}")
        print("-" * 80)

# Hand-written fact patterns with First and Fourth Amendment questions,
# used as a quick qualitative check of the fine-tuned model.
test_questions = [
    dict(
        facts="A high school student wore a black armband to school to protest the Vietnam War. The school suspended the student for violating a policy against political demonstrations.",
        question="Did the school's suspension of the student for wearing the armband violate the First Amendment?",
    ),
    dict(
        facts="Police officers conducted a warrantless search of a person's home after pursuing them for a misdemeanor traffic violation.",
        question="Was the warrantless search constitutional under the Fourth Amendment?",
    ),
    dict(
        facts="A city passed an ordinance banning all public demonstrations in the downtown area, citing traffic concerns.",
        question="Does this ordinance violate the First Amendment right to freedom of assembly?",
    ),
]

# Query the checkpoint written to OUTPUT_DIR by the training cell.
test_model(OUTPUT_DIR, test_questions)

## Evaluation

Run a comprehensive evaluation on the test cases in `../evaluation/test_cases.json` (skipped if the file is missing):

In [None]:
# Load test cases and evaluate
test_file = "../evaluation/test_cases.json"

if os.path.exists(test_file):
    # Run evaluation with the trainer built in the training cell
    evaluation_results = trainer.evaluate(test_file)

    print(f"Evaluation completed on {evaluation_results['total_cases']} test cases")

    # Show the first few generated answers next to their references
    for i, result in enumerate(evaluation_results['results'][:3], 1):
        print(f"\nTest Case {i}:")
        print(f"Question: {result['question'][:100]}...")
        print(f"Generated: {result['generated'][:150]}...")
        print(f"Reference: {result['reference'][:150]}...")

else:
    print(f"Test file not found: {test_file}")
    # No fallback evaluation runs here, so say so instead of implying one does.
    print("Skipping comprehensive evaluation; see the sample test cases in the Model Testing section above.")

## Save and Export

Optionally upload the model to the Hugging Face Hub, and save a JSON summary of the training run:

In [None]:
# Save model locally (already done during training): report the path the
# trainer returned rather than assuming it equals OUTPUT_DIR.
print(f"Model saved locally at: {training_results['model_path']}")

# Optionally save to Hugging Face Hub
SAVE_TO_HUB = False  # Set to True to upload to HF Hub
HF_MODEL_NAME = "your-username/constitutional-law-llama"  # Change this

if SAVE_TO_HUB and HF_MODEL_NAME.startswith("your-username/"):
    # Guard against uploading under the placeholder repository name.
    raise ValueError("Set HF_MODEL_NAME to your own Hub repository before uploading")

if SAVE_TO_HUB and config.hf_token:
    print("Uploading to Hugging Face Hub...")
    trainer.model_manager.save_model_hub(HF_MODEL_NAME)
    print(f"Model uploaded to: https://huggingface.co/{HF_MODEL_NAME}")
elif SAVE_TO_HUB:
    # Upload was requested but cannot proceed; explain why instead of skipping silently.
    print("Skipping Hugging Face Hub upload: no Hugging Face token configured (config.hf_token)")
else:
    print("Skipping Hugging Face Hub upload")

# Save training summary
summary = {
    "model_name": MODEL_NAME,
    "training_config": TRAIN_CONFIG,
    "lora_config": {
        "r": config.lora.r,
        "alpha": config.lora.lora_alpha,
        "dropout": config.lora.lora_dropout
    },
    "data_stats": {
        "train_examples": len(train_ds),
        "val_examples": len(val_ds)
    },
    # NOTE(review): assumes training_results holds only JSON-serializable
    # values; json.dump raises TypeError otherwise.
    "results": training_results
}

# Make sure the directory exists even if training wrote elsewhere.
os.makedirs(OUTPUT_DIR, exist_ok=True)
summary_file = os.path.join(OUTPUT_DIR, "training_summary.json")
with open(summary_file, 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print(f"Training summary saved to: {summary_file}")
print("\nTraining completed successfully!")