In [None]:
# Fine-Tuning Mistral-7B for Legal QA
## RAG vs Fine-Tuning: A Comparative Study

This notebook fine-tunes the Mistral-7B model on the processed Indian Legal dataset using QLoRA (Quantized LoRA) for efficient training.

**Key Features:**
- QLoRA for memory-efficient training
- Legal domain-specific instruction tuning  
- Comprehensive monitoring and evaluation
- Model saving and deployment preparation


In [None]:
## 1. Setup and Imports


In [None]:
import os
import json
import torch
import numpy as np
import pandas as pd
from datasets import load_from_disk
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Report which accelerator (if any) this session will run on
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("Running on CPU - training will be slower")
else:
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {gpu_total_bytes / 1024**3:.1f} GB")

# Seed the torch and NumPy global RNGs with the same value for reproducibility
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(42)


In [None]:
## 2. Load Processed Data


In [None]:
# Read the train/validation splits and the metadata file from ./processed_data.
# A missing file is reported with a hint and then re-raised so the run stops here.
try:
    train_dataset = load_from_disk('./processed_data/train')
    val_dataset = load_from_disk('./processed_data/val')

    for split_label, split in (("Training", train_dataset), ("Validation", val_dataset)):
        print(f"✅ {split_label} dataset loaded: {len(split)} examples")

    # Metadata is stored as a flat JSON object next to the splits
    with open('./processed_data/metadata.json', 'r') as meta_file:
        metadata = json.load(meta_file)

    print(f"\n📊 Dataset Statistics:")
    for stat_name in metadata:
        print(f"  {stat_name}: {metadata[stat_name]}")

except FileNotFoundError:
    print("❌ Processed data not found. Please run 1_data_preparation.ipynb first.")
    raise

# Show the first 800 characters of the first training example
separator = "=" * 60
preview = train_dataset[0]['text'][:800]
print(f"\n📝 Sample Training Example:")
print(separator)
print(preview + "...")
print(separator)


In [None]:
## 3. Model and Tokenizer Setup


In [None]:
# Model configuration
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
OUTPUT_DIR = "./fine_tuned_legal_mistral"

# Choose a 4-bit compute dtype the current GPU can run: bfloat16 where it is
# supported, float16 otherwise (hard-coding bfloat16 is wrong on older GPUs).
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
compute_dtype = torch.bfloat16 if bf16_supported else torch.float16

# QLoRA configuration for efficient training: NF4 weights with double quantization
qlora_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype
)

print(f"🔄 Loading model: {MODEL_NAME}")
print("Using QLoRA (4-bit quantization) for memory efficiency...")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Pad with <unk> rather than EOS when it exists: the language-modeling
    # collator replaces every pad-token label with -100, so padding with EOS
    # would also hide genuine end-of-sequence tokens from the loss.
    # Falls back to EOS if the tokenizer defines no <unk> token.
    tokenizer.pad_token = tokenizer.unk_token or tokenizer.eos_token

print(f"✅ Tokenizer loaded")
print(f"  Vocab size: {len(tokenizer)}")
print(f"  Pad token: {tokenizer.pad_token}")

# Load model with quantization.
# trust_remote_code is intentionally not enabled: the architecture ships with
# transformers, so there is no reason to allow code execution from the Hub repo.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=qlora_config,
    device_map="auto",
)

print(f"✅ Model loaded with 4-bit quantization")
print(f"  Compute dtype: {compute_dtype}")
print(f"  Model device: {next(model.parameters()).device}")
print(f"  Model dtype: {next(model.parameters()).dtype}")


In [None]:
## 4. LoRA Configuration and Model Preparation


In [None]:
# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

# LoRA configuration for Mistral
lora_config = LoraConfig(
    r=16,  # Rank
    lora_alpha=32,  # Alpha parameter for LoRA scaling
    # Adapters go on the attention projections (q/k/v/o) and on the MLP
    # projections (gate/up/down)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                   "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,  # Dropout probability for LoRA layers
    bias="none",  # Bias type
    task_type=TaskType.CAUSAL_LM,  # Task type
)

print("🔧 LoRA Configuration:")
print(f"  Rank (r): {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Target modules: {lora_config.target_modules}")
print(f"  Dropout: {lora_config.lora_dropout}")

# Get PEFT model
model = get_peft_model(model, lora_config)

# Count parameters. bitsandbytes stores 4-bit weights packed into a wider
# storage dtype, so numel() on a Params4bit tensor undercounts the logical
# weights; scale it back up (same correction PEFT applies in its own counter)
# so the totals and the trainable percentage refer to real weight counts.
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
    param_count = param.numel()
    if param.__class__.__name__ == "Params4bit":
        param_count = param_count * 2 * param.element_size()
    all_param += param_count
    if param.requires_grad:
        trainable_params += param_count

print(f"\n📊 Model Parameters:")
print(f"  Trainable params: {trainable_params:,}")
print(f"  All params: {all_param:,}")
print(f"  Trainable%: {100 * trainable_params / all_param:.2f}%")

# Enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
model.config.use_cache = False  # Disable cache for training

print(f"\n✅ Model prepared for training with LoRA")
print(f"  Gradient checkpointing: Enabled")
print(f"  Cache: Disabled for training")


In [None]:
## 5. Training Configuration


In [None]:
# Data collator for language modeling
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # We're doing causal language modeling, not masked
    pad_to_multiple_of=8,  # For efficiency on modern hardware
)

# Pick the mixed-precision mode the hardware supports: bf16 where available,
# fp16 on other CUDA GPUs, and none on CPU (fp16 AMP is CUDA-only).
cuda_ready = torch.cuda.is_available()
use_bf16 = cuda_ready and torch.cuda.is_bf16_supported()
use_fp16 = cuda_ready and not use_bf16

# Training arguments for QLoRA fine-tuning
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,  # Small batch size for memory efficiency
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,  # Effective batch size = 1 * 8 = 8
    num_train_epochs=3,  # Start with 3 epochs
    learning_rate=2e-4,  # Higher learning rate for LoRA
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=10,
    eval_steps=50,
    save_steps=100,  # Must stay a multiple of eval_steps for load_best_model_at_end
    evaluation_strategy="steps",
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    bf16=use_bf16,  # Mixed precision matched to the GPU's capabilities
    fp16=use_fp16,
    dataloader_pin_memory=False,
    remove_unused_columns=False,
    # The string "none" is what disables wandb/tensorboard; passing None makes
    # the library fall back to every installed logging integration.
    report_to="none",
    run_name="mistral-legal-qa-lora",
    seed=42,
    data_seed=42,
    optim="adamw_torch",
    max_grad_norm=1.0,
    group_by_length=True,  # Group similar length sequences for efficiency
)

if training_args.bf16:
    precision_mode = "bf16"
elif training_args.fp16:
    precision_mode = "fp16"
else:
    precision_mode = "disabled"

print("🔧 Training Configuration:")
print(f"  Output directory: {training_args.output_dir}")
print(f"  Batch size (per device): {training_args.per_device_train_batch_size}")
print(f"  Gradient accumulation: {training_args.gradient_accumulation_steps}")
print(f"  Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"  Learning rate: {training_args.learning_rate}")
print(f"  Number of epochs: {training_args.num_train_epochs}")
print(f"  Scheduler: {training_args.lr_scheduler_type}")
print(f"  Mixed precision: {precision_mode}")
print(f"  Optimizer: {training_args.optim}")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"✅ Training configuration ready")


In [None]:
## 6. Data Preprocessing for Training


In [None]:
def tokenize_function(examples):
    """Tokenize the "text" field of a batch and attach causal-LM labels.

    The text is truncated to at most 2048 tokens and left unpadded (padding is
    applied later, per batch). The returned encoding carries a "labels" entry
    that is a copy of "input_ids".
    """
    encoding = tokenizer(
        examples["text"],
        max_length=2048,
        truncation=True,
        padding=False,  # padded dynamically at batch time
        return_tensors=None,
    )

    # Causal language modeling predicts the inputs themselves
    encoding["labels"] = list(encoding["input_ids"])
    return encoding

print("🔄 Tokenizing datasets...")


def _tokenize_split(split, description):
    """Apply tokenize_function to a split in batches, dropping its original columns."""
    return split.map(
        tokenize_function,
        batched=True,
        remove_columns=split.column_names,
        desc=description,
    )


# Tokenize both splits the same way
tokenized_train = _tokenize_split(train_dataset, "Tokenizing training data")
tokenized_val = _tokenize_split(val_dataset, "Tokenizing validation data")

print(f"✅ Tokenization complete")
print(f"  Training examples: {len(tokenized_train)}")
print(f"  Validation examples: {len(tokenized_val)}")

# Inspect the first tokenized training example
sample_tokens = tokenized_train[0]
sample_ids = sample_tokens['input_ids']
print(f"\n📊 Sample tokenized example:")
print(f"  Input length: {len(sample_ids)} tokens")
print(f"  Labels length: {len(sample_tokens['labels'])} tokens")
print(f"  First 10 tokens: {sample_ids[:10]}")

# Summarize how long the tokenized training examples are
train_lengths = []
for example in tokenized_train:
    train_lengths.append(len(example['input_ids']))

mean_len = np.mean(train_lengths)
std_len = np.std(train_lengths)
print(f"\n📏 Token length statistics:")
print(f"  Mean: {mean_len:.1f}")
print(f"  Max: {np.max(train_lengths)}")
print(f"  Min: {np.min(train_lengths)}")
print(f"  Std: {std_len:.1f}")


In [None]:
## 7. Initialize Trainer and Start Training


In [None]:
# Build the Trainer from the prepared model, arguments, data and collator
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
)

# Derived figures for the overview printed below
n_train = len(tokenized_train)
n_val = len(tokenized_val)
sequences_per_update = (
    training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
)
total_steps = n_train // sequences_per_update * training_args.num_train_epochs

print("🚀 Trainer initialized successfully!")
print(f"📊 Training overview:")
print(f"  Model: {MODEL_NAME}")
print(f"  Training samples: {n_train:,}")
print(f"  Validation samples: {n_val:,}")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Total steps: {total_steps}")

# Snapshot of GPU memory on device 0 before training begins
if torch.cuda.is_available():
    bytes_per_gib = 1024**3
    allocated_bytes = torch.cuda.memory_allocated(0)
    reserved_bytes = torch.cuda.memory_reserved(0)
    free_bytes = torch.cuda.get_device_properties(0).total_memory - reserved_bytes
    print(f"\n🔧 GPU Memory Status:")
    print(f"  Allocated: {allocated_bytes / bytes_per_gib:.2f} GB")
    print(f"  Reserved: {reserved_bytes / bytes_per_gib:.2f} GB")
    print(f"  Free: {free_bytes / bytes_per_gib:.2f} GB")

print(f"\n⏰ Starting training at {pd.Timestamp.now()}")
print("=" * 60)


In [None]:
# Start training: run the Trainer, report its metrics, then evaluate once more.
# Any failure is reported with a hint and re-raised so the notebook stops here.
try:
    print("🔥 Training started...")
    training_output = trainer.train()
    
    print(f"\n✅ Training completed successfully!")
    print(f"📊 Training results:")
    print(f"  Final training loss: {training_output.training_loss:.4f}")
    print(f"  Training time: {training_output.metrics['train_runtime']:.2f} seconds")
    print(f"  Samples per second: {training_output.metrics['train_samples_per_second']:.2f}")
    print(f"  Steps per second: {training_output.metrics['train_steps_per_second']:.2f}")
    
    # Get final evaluation
    print("\n📈 Running final evaluation...")
    eval_results = trainer.evaluate()
    print(f"Final evaluation loss: {eval_results['eval_loss']:.4f}")
    print(f"Perplexity: {np.exp(eval_results['eval_loss']):.2f}")
    
except Exception as e:
    print(f"❌ Training failed with error: {e}")
    print("This might be due to memory constraints or other issues")
    print("Consider reducing batch size or sequence length")
    raise

print(f"\n⏰ Training finished at {pd.Timestamp.now()}")

# Save the final model
print(f"\n💾 Saving fine-tuned model...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Report the files that were actually written instead of a fixed list: the
# adapter weight file name depends on the installed library versions.
saved_files = sorted(
    entry for entry in os.listdir(OUTPUT_DIR)
    if os.path.isfile(os.path.join(OUTPUT_DIR, entry))
)

print(f"✅ Model saved to: {OUTPUT_DIR}")
for entry in saved_files:
    print(f"   - {entry}")


In [None]:
## 8. Model Testing and Inference


In [None]:
# Test the fine-tuned model
def test_legal_qa(question, context, max_length=512):
    """Generate the model's answer to a legal question about a document excerpt.

    Args:
        question: The question to ask about the document.
        context: Legal document text; only its first 800 characters are placed
            in the prompt.
        max_length: Maximum number of new tokens to generate.

    Returns:
        The decoded text that follows the last "[/INST]" marker, stripped of
        surrounding whitespace (this includes the answer primer that ends the
        prompt). If the marker is absent, the full decoded output is returned.
    """
    
    # Format the input like our training data
    # NOTE(review): the prompt starts with a literal "<s>" and the tokenizer may
    # add its own BOS token as well - confirm this matches the training format.
    prompt = f"""<s>[INST] You are a legal AI assistant specializing in Indian law. Based on the provided legal document, answer the following question accurately and comprehensively.

Legal Document:
{context[:800]}

Question: {question} [/INST]

Based on the legal document provided, I can analyze that:"""
    
    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1500)
    
    # Always place the inputs on the model's device, whatever that device is
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Inference mode: turns off dropout (including the LoRA dropout layers)
    model.eval()
    
    # Generate response
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1,
            # The training setup disables the KV cache on the model config;
            # re-enable it for generation so decoding does not recompute the
            # whole prefix at every step.
            use_cache=True
        )
    
    # Decode the response
    full_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract just the generated part (after [/INST])
    if "[/INST]" in full_response:
        response = full_response.split("[/INST]")[-1].strip()
    else:
        response = full_response
    
    return response

# Exercise the fine-tuned model on a fixed legal excerpt
print("🧪 Testing the fine-tuned model...")

# Excerpt used as the document for every question below
sample_text = """The Corporation made available to the Contractors different kinds of machinery and equipment for which the price paid by the Corporation inclusive of freight, insurance, customs duty etc. has to be charged to them. But the machinery and the equipment so made available to the Contractors were to remain the property of the Corporation until the full price thereof had been realised from the Contractors. There is a further condition that the Corporation will take over from the contractors items at their residual value calculated on the basis indicated in the agreement. The total approximate price is payable by the Contractors in 18 equal instalments."""

test_questions = [
    "What is the ownership arrangement for the machinery?",
    "How is the payment structured for the machinery?",
    "What happens to the machinery after full payment?",
    "What are the key terms of this agreement?"
]

print("📝 Sample Legal QA Responses:")
print("=" * 80)

rule = "-" * 60
for number, question in enumerate(test_questions, start=1):
    print(f"\n🔍 Question {number}: {question}")
    print(rule)

    # A failure on one question is reported and the loop moves on to the next
    try:
        response = test_legal_qa(question, sample_text)
        print(f"🤖 Model Response:")
        # Show at most the first 400 characters, marking any truncation
        if len(response) > 400:
            print(response[:400] + "...")
        else:
            print(response)
    except Exception as e:
        print(f"❌ Error generating response: {e}")

    print(rule)

print(f"\n✅ Model testing completed!")
print(f"💡 The model can now answer legal questions based on Indian legal documents")


In [None]:
## 9. Save Training Metrics and Results


In [None]:
# Save comprehensive training results for paper.
# Values that come out of NumPy are converted to built-in int/float first:
# json.dump cannot serialize NumPy integer scalars (np.max over the token
# lengths returns one), and a failure mid-dump would leave a truncated file.
training_results = {
    'model_name': MODEL_NAME,
    'approach': 'fine_tuning',
    'method': 'QLoRA',
    'dataset': 'ninadn/indian-legal',
    'training_params': {
        'epochs': training_args.num_train_epochs,
        'learning_rate': training_args.learning_rate,
        'batch_size': training_args.per_device_train_batch_size,
        'gradient_accumulation_steps': training_args.gradient_accumulation_steps,
        'lora_r': lora_config.r,
        'lora_alpha': lora_config.lora_alpha,
        'lora_dropout': lora_config.lora_dropout,
    },
    'data_stats': {
        'train_examples': len(tokenized_train),
        'val_examples': len(tokenized_val),
        'avg_tokens': float(np.mean(train_lengths)) if 'train_lengths' in locals() else 0,
        'max_tokens': int(np.max(train_lengths)) if 'train_lengths' in locals() else 0,
    },
    # Entries below are None when the corresponding training/evaluation cell was not run
    'training_results': {
        'final_train_loss': training_output.training_loss if 'training_output' in locals() else None,
        'final_eval_loss': eval_results['eval_loss'] if 'eval_results' in locals() else None,
        'perplexity': float(np.exp(eval_results['eval_loss'])) if 'eval_results' in locals() else None,
        'training_time': training_output.metrics['train_runtime'] if 'training_output' in locals() else None,
        'samples_per_second': training_output.metrics['train_samples_per_second'] if 'training_output' in locals() else None,
    },
    'model_info': {
        'trainable_parameters': trainable_params if 'trainable_params' in locals() else None,
        'total_parameters': all_param if 'all_param' in locals() else None,
        'trainable_percentage': (100 * trainable_params / all_param) if 'trainable_params' in locals() and 'all_param' in locals() else None,
    },
    'timestamp': pd.Timestamp.now().isoformat(),
    'output_dir': OUTPUT_DIR
}

# Save results
with open(f'{OUTPUT_DIR}/training_results.json', 'w') as f:
    json.dump(training_results, f, indent=2)

print(f"📊 Training results saved to: {OUTPUT_DIR}/training_results.json")

# Display summary for conference paper
print(f"\n📋 FINE-TUNING SUMMARY FOR CONFERENCE PAPER")
print("=" * 60)
print(f"🔬 **Method**: QLoRA Fine-tuning of Mistral-7B")
print(f"📊 **Dataset**: Indian Legal documents (ninadn/indian-legal)")
print(f"📝 **Training Examples**: {len(tokenized_train):,}")
print(f"🔧 **Parameters**: {trainable_params:,} trainable ({(100 * trainable_params / all_param):.2f}% of total)")
print(f"⏱️  **Training Time**: {training_output.metrics['train_runtime']:.1f} seconds" if 'training_output' in locals() else "⏱️  Training Time: Not recorded")
print(f"📉 **Final Loss**: {eval_results['eval_loss']:.4f}" if 'eval_results' in locals() else "📉 Final Loss: Not recorded")
print(f"🎯 **Perplexity**: {np.exp(eval_results['eval_loss']):.2f}" if 'eval_results' in locals() else "🎯 Perplexity: Not recorded")
# Size estimate assumes 4 bytes per trainable (adapter) parameter
print(f"💾 **Model Size**: LoRA adapters (~{(trainable_params * 4) / (1024**2):.1f} MB)")
print("=" * 60)

print(f"\n✅ Fine-tuning pipeline completed successfully!")
print(f"🎯 Ready for comparison with RAG approach")
print(f"📁 All files saved in: {OUTPUT_DIR}")


In [None]:
## 5. Training Configuration
