In [None]:
# Fine-Tuning Mistral-7B for Legal QA
## RAG vs Fine-Tuning: A Comparative Study

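This notebook fine-tunes the Mistral-7B base model (`mistralai/Mistral-7B-v0.1`) on the processed Indian legal dataset using QLoRA (Quantized LoRA: small LoRA adapters trained on top of a 4-bit quantized base model), which keeps GPU memory requirements low.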

**Key Features:**
- QLoRA (4-bit quantization + LoRA adapters) for memory-efficient training
- Instruction tuning specific to the legal domain
- Monitoring and evaluation of training
- Model saving and deployment preparation


In [None]:
## 1. Setup and Imports


In [None]:
import os
import json
import random
import warnings

import torch
import numpy as np
import pandas as pd
from datasets import load_from_disk
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
# Silence all Python warnings to keep the training logs readable.
# NOTE(review): this blanket filter also hides deprecation and quantization
# warnings from transformers/peft/bitsandbytes; consider narrowing it by category.
warnings.filterwarnings('ignore')

# Report the accelerator before the (expensive) model download.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("Running on CPU - training will be slower")

# Set random seeds for reproducibility. Python's `random`, NumPy and PyTorch each
# keep an independent RNG, so all three are seeded from the same constant.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [None]:
## 2. Load Processed Data


In [None]:
# Load the processed train/validation splits and the metadata written by
# 1_data_preparation.ipynb.
PROCESSED_DATA_DIR = './processed_data'
SAMPLE_PREVIEW_CHARS = 800  # how much of the first training example to display

try:
    train_dataset = load_from_disk(os.path.join(PROCESSED_DATA_DIR, 'train'))
    val_dataset = load_from_disk(os.path.join(PROCESSED_DATA_DIR, 'val'))

    print(f"✅ Training dataset loaded: {len(train_dataset)} examples")
    print(f"✅ Validation dataset loaded: {len(val_dataset)} examples")

    # Load metadata. JSON is UTF-8, so don't depend on the platform default encoding.
    with open(os.path.join(PROCESSED_DATA_DIR, 'metadata.json'), 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    print("\n📊 Dataset Statistics:")
    for key, value in metadata.items():
        print(f"  {key}: {value}")

except FileNotFoundError:
    print("❌ Processed data not found. Please run 1_data_preparation.ipynb first.")
    raise

# Fail with a clear message instead of an IndexError below if preparation produced no rows.
if len(train_dataset) == 0:
    raise ValueError("Training dataset is empty. Please re-run 1_data_preparation.ipynb.")

# Display sample data; only mark the preview as truncated when it actually was.
sample_text = train_dataset[0]['text']
preview = sample_text[:SAMPLE_PREVIEW_CHARS]
if len(sample_text) > SAMPLE_PREVIEW_CHARS:
    preview += "..."

print("\n📝 Sample Training Example:")
print("=" * 60)
print(preview)
print("=" * 60)


In [None]:
## 3. Model and Tokenizer Setup


In [None]:
# Model configuration
MODEL_NAME = "mistralai/Mistral-7B-v0.1"
OUTPUT_DIR = "./fine_tuned_legal_mistral"

# bfloat16 compute is only available on newer GPUs (torch reports it via
# is_bf16_supported); fall back to float16 elsewhere (e.g. T4 / V100).
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    COMPUTE_DTYPE = torch.bfloat16
else:
    COMPUTE_DTYPE = torch.float16

# QLoRA configuration: 4-bit NF4 weights with double quantization,
# matmuls computed in COMPUTE_DTYPE.
qlora_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
)

print(f"🔄 Loading model: {MODEL_NAME}")
print("Using QLoRA (4-bit quantization) for memory efficiency...")

# Load tokenizer; if it defines no pad token, reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

print("✅ Tokenizer loaded")
print(f"  Vocab size: {len(tokenizer)}")
print(f"  Pad token: {tokenizer.pad_token}")

# Load model with quantization. Mistral is natively supported by transformers, so
# trust_remote_code is not needed; enabling it would allow arbitrary code from the
# Hub repository to execute locally.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=qlora_config,
    device_map="auto",
)
# Keep the model config consistent with the tokenizer's padding choice.
model.config.pad_token_id = tokenizer.pad_token_id

print("✅ Model loaded with 4-bit quantization")
print(f"  Compute dtype: {COMPUTE_DTYPE}")
print(f"  Model device: {next(model.parameters()).device}")
# Note: this is the dtype of the first parameter only, not of the quantized weights.
print(f"  Model dtype: {next(model.parameters()).dtype}")


In [None]:
## 4. LoRA Configuration and Model Preparation


In [None]:
def _count_parameters(peft_model):
    """Return ``(trainable, total)`` parameter counts for ``peft_model``.

    bitsandbytes stores 4-bit weights as ``Params4bit`` tensors packed into a
    wider storage dtype, so ``numel()`` under-reports them: each storage
    element holds ``2 * element_size()`` 4-bit values. This mirrors the
    accounting used by PEFT's ``print_trainable_parameters``.
    """
    trainable, total = 0, 0
    for param in peft_model.parameters():
        count = param.numel()
        if param.__class__.__name__ == "Params4bit":
            count *= 2 * param.element_size()
        total += count
        if param.requires_grad:
            trainable += count
    return trainable, total


# Prepare the quantized model for k-bit training (PEFT helper)
model = prepare_model_for_kbit_training(model)

# LoRA configuration for Mistral
lora_config = LoraConfig(
    r=16,  # Rank of the low-rank update matrices
    lora_alpha=32,  # LoRA scaling: updates are scaled by lora_alpha / r
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
                    "gate_proj", "up_proj", "down_proj"],      # MLP projections
    lora_dropout=0.05,  # Dropout probability for LoRA layers
    bias="none",  # Do not train bias terms
    task_type=TaskType.CAUSAL_LM,  # Task type
)

print("🔧 LoRA Configuration:")
print(f"  Rank (r): {lora_config.r}")
print(f"  Alpha: {lora_config.lora_alpha}")
print(f"  Target modules: {lora_config.target_modules}")
print(f"  Dropout: {lora_config.lora_dropout}")

# Get PEFT model
model = get_peft_model(model, lora_config)

# Count parameters, correcting for packed 4-bit base weights
trainable_params, all_param = _count_parameters(model)

print("\n📊 Model Parameters:")
print(f"  Trainable params: {trainable_params:,}")
print(f"  All params: {all_param:,}")
print(f"  Trainable%: {100 * trainable_params / all_param:.2f}%")

# Enable gradient checkpointing for memory efficiency
model.gradient_checkpointing_enable()
model.config.use_cache = False  # KV cache is useless during training and conflicts with checkpointing

print("\n✅ Model prepared for training with LoRA")
print("  Gradient checkpointing: Enabled")
print("  Cache: Disabled for training")


In [None]:
## 5. Training Configuration
