# Legal Document Role Classifier Training Notebook

This notebook provides a comprehensive guide to training your own rhetorical role classifier for legal documents on your existing dataset.

## Dataset Structure
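Your data should contain one `sentence<TAB>role` pair per line (`\t` below stands for a tab character), with a blank line separating consecutive documents:
```
sentence1\trole1
sentence2\trole2

sentence1\trole1
sentence2\trole2
```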

## Supported Roles
- Facts
- Issue
- Arguments of Petitioner
- Arguments of Respondent
- Reasoning
- Decision
- None

In [None]:
# Install required packages if needed
!pip install torch transformers scikit-learn pandas matplotlib seaborn spacy
!python -m spacy download en_core_web_sm

In [None]:
# Setup and imports
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Add your project path.
# Set the NYAYA_PROJECT_ROOT environment variable to point at another checkout
# without editing this cell; the default below is the original location.
PROJECT_ROOT = os.environ.get(
    "NYAYA_PROJECT_ROOT",
    "/home/uttam/B.Tech Major Project/nyaya/server",
)

# Register the project directories only once, so re-running this cell does not
# keep appending duplicate entries to sys.path.
for _import_dir in (
    PROJECT_ROOT,
    os.path.join(PROJECT_ROOT, "src", "models", "training"),
):
    if _import_dir not in sys.path:
        sys.path.append(_import_dir)

# Report the environment the rest of the notebook will run in.
print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

In [None]:
# Import training modules.
# These come from the project directories added to sys.path in the setup cell.
try:
    from train import RoleClassifierTrainer
    from data_loader import create_data_loaders, LegalDocumentDataset
    from evaluate import ModelEvaluator
    from src.models.role_classifier import RoleClassifier, RhetoricalRole
    print("✅ Successfully imported training modules")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please ensure you're running from the correct directory")
    # Every later cell needs these names; stop here instead of letting them
    # fail further down with harder-to-diagnose NameErrors.
    raise

## Configuration

In [None]:
# Training Configuration

# Folder that holds the train/val/test splits. Set the NYAYA_DATA_ROOT
# environment variable (or edit the default) to match your dataset location.
DATA_ROOT = os.environ.get(
    "NYAYA_DATA_ROOT",
    "/home/uttam/B.Tech Major Project/nyaya/Hier_BiLSTM_CRF-20250827T113256Z-1-001/Hier_BiLSTM_CRF",
)

config = {
    # Data paths - derived from DATA_ROOT above
    "train_data": os.path.join(DATA_ROOT, "train"),
    "val_data": os.path.join(DATA_ROOT, "val"),
    "test_data": os.path.join(DATA_ROOT, "test"),

    # Model configuration
    "model_type": "inlegalbert",  # Options: "inlegalbert", "bilstm_crf"
    "model_name": "law-ai/InLegalBERT",  # Pre-trained model
    "context_mode": "prev",  # Options: "single", "prev", "prev_two", "surrounding"

    # Training hyperparameters
    "batch_size": 16,
    "num_epochs": 10,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "max_length": 512,
    "warmup_steps": 500,

    # Device and output
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "output_dir": "./trained_models",
    "save_best_model": True
}

# Echo the effective settings so each run records what it used.
print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

## Data Exploration

In [None]:
# Explore your dataset
def explore_dataset(data_path, max_files=5):
    """Print summary statistics for a role-labelled dataset and return role counts.

    Every non-blank line is expected to hold ``sentence<TAB>role``; a blank
    line ends the current document. Lines without a tab-separated role are
    skipped and reported in the summary.

    Args:
        data_path: A single data file, or a directory whose ``*.txt`` files
            are scanned.
        max_files: Maximum number of files to scan, taken in sorted order.
            Pass ``None`` to scan every file. Defaults to 5, so for larger
            datasets the statistics describe a sample, not the whole corpus.

    Returns:
        A dict mapping each role label to the number of sentences carrying it
        in the scanned files, or ``None`` if ``data_path`` does not exist.
    """
    data_path = Path(data_path)

    if not data_path.exists():
        print(f"❌ Data path does not exist: {data_path}")
        return None

    print(f"📂 Exploring dataset: {data_path}")

    # Sort so the scanned subset is the same on every run; glob order depends
    # on the filesystem.
    files = [data_path] if data_path.is_file() else sorted(data_path.glob("*.txt"))
    scanned_files = files if max_files is None else files[:max_files]

    print(f"📄 Found {len(files)} files")

    total_sentences = 0
    total_documents = 0
    skipped_lines = 0
    role_counts = {}

    for file_path in scanned_files:
        print(f"\n📝 File: {file_path.name}")

        file_sentences = 0  # sentences in the whole file
        doc_sentences = 0   # sentences in the document currently being read

        with open(file_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Blank line: close the current document if it has content.
                    if doc_sentences > 0:
                        total_documents += 1
                        doc_sentences = 0
                    continue

                parts = line.split('\t')
                if len(parts) < 2:
                    # No role column on this line; count it instead of
                    # dropping it silently.
                    skipped_lines += 1
                    continue

                role = parts[1].strip()
                file_sentences += 1
                doc_sentences += 1
                role_counts[role] = role_counts.get(role, 0) + 1

        # The last document of a file need not be followed by a blank line.
        if doc_sentences > 0:
            total_documents += 1

        total_sentences += file_sentences
        # Report the per-file total, not the size of the file's last document.
        print(f"  Sentences in this file: {file_sentences}")

    print(f"\n📊 Dataset Statistics:")
    print(f"  Total files: {len(files)}")
    print(f"  Files scanned: {len(scanned_files)}")
    print(f"  Total documents: {total_documents}")
    print(f"  Total sentences: {total_sentences}")
    print(f"  Average sentences per document: {total_sentences/max(total_documents, 1):.1f}")
    if skipped_lines:
        print(f"  Skipped malformed lines: {skipped_lines}")

    print(f"\n🏷️ Role Distribution:")
    # role_counts is empty whenever total_sentences is 0, so the division
    # below is never reached with a zero denominator.
    for role, count in sorted(role_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_sentences) * 100
        print(f"  {role}: {count} ({percentage:.1f}%)")

    return role_counts

# Explore training data
# explore_dataset returns None when the path does not exist, so anything that
# uses train_role_counts should check it first.
print("🔍 Exploring Training Data")
train_role_counts = explore_dataset(config["train_data"])

In [None]:
# Visualize role distribution
# Skipped entirely when exploration found nothing (None or an empty dict).
if train_role_counts:
    roles = list(train_role_counts)
    counts = [train_role_counts[name] for name in roles]

    # Bar chart: absolute number of sentences per role
    bar_fig, bar_ax = plt.subplots(figsize=(12, 6))
    bar_ax.bar(roles, counts, color='skyblue', alpha=0.7)
    bar_ax.set_title('Role Distribution in Training Data', fontsize=16)
    bar_ax.set_xlabel('Rhetorical Role', fontsize=12)
    bar_ax.set_ylabel('Number of Sentences', fontsize=12)
    plt.setp(bar_ax.get_xticklabels(), rotation=45, ha='right')

    # Write each count just above its bar, lifted by 1% of the tallest bar
    label_gap = max(counts) * 0.01
    for i, v in enumerate(counts):
        bar_ax.text(i, v + label_gap, str(v), ha='center', va='bottom')

    bar_fig.tight_layout()
    plt.show()

    # Pie chart: the same distribution as percentages
    pie_fig, pie_ax = plt.subplots(figsize=(10, 8))
    pie_ax.pie(counts, labels=roles, autopct='%1.1f%%', startangle=90)
    pie_ax.set_title('Role Distribution (Percentage)', fontsize=16)
    pie_ax.axis('equal')
    plt.show()

## Data Loading and Validation

In [None]:
# Test data loading
print("🔄 Testing Data Loading...")

try:
    # Collect the loader arguments from the shared configuration first, then
    # build the loaders for all splits in one call.
    loader_kwargs = dict(
        train_path=config["train_data"],
        val_path=config["val_data"],
        test_path=config["test_data"],
        tokenizer_name=config["model_name"],
        context_mode=config["context_mode"],
        batch_size=config["batch_size"],
        max_length=config["max_length"],
    )
    data_loaders = create_data_loaders(**loader_kwargs)

    print("✅ Data loaders created successfully!")
    print(f"📦 Training batches: {len(data_loaders['train'])}")
    print(f"📦 Validation batches: {len(data_loaders['val'])}")
    # The test split is only reported when the loader dict contains it.
    if 'test' in data_loaders:
        print(f"📦 Test batches: {len(data_loaders['test'])}")

    # Pull one training batch and show the shape of each field
    sample_batch = next(iter(data_loaders['train']))
    print(f"\n🔍 Sample Batch Shape:")
    print(f"  Input IDs: {sample_batch['input_ids'].shape}")
    print(f"  Attention Mask: {sample_batch['attention_mask'].shape}")
    print(f"  Labels: {sample_batch['labels'].shape}")
    print(f"  Unique labels in batch: {torch.unique(sample_batch['labels'])}")

except Exception as e:
    print(f"❌ Error loading data: {e}")
    print("Please check your data paths and format")

In [None]:
# Sample some examples from the dataset
try:
    sample_batch = next(iter(data_loaders['train']))

    # Label IDs are mapped to names by position in the RhetoricalRole enum.
    # NOTE(review): this assumes the enum's definition order matches the label
    # IDs produced by the data loader - confirm against the loader.
    role_names = [role.value for role in RhetoricalRole]

    print("📝 Sample Training Examples:")
    print("=" * 80)

    # Show first 3 examples
    for i in range(min(3, len(sample_batch['text']))):
        text = sample_batch['text'][i]
        label_id = sample_batch['labels'][i].item()

        # Check both bounds: a negative ID would otherwise index from the end
        # of the list and silently report the wrong role.
        role_name = role_names[label_id] if 0 <= label_id < len(role_names) else "Unknown"

        print(f"\nExample {i+1}:")
        print(f"Text: {text[:200]}{'...' if len(text) > 200 else ''}")
        print(f"Role: {role_name} (ID: {label_id})")
        print("-" * 40)

except Exception as e:
    print(f"Error sampling examples: {e}")

## Model Training

In [None]:
# Initialize trainer
print("🚀 Initializing Role Classifier Trainer...")

# Each run writes into its own timestamped folder under the configured output dir
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = Path(config["output_dir"]).joinpath(f"{config['model_type']}_{timestamp}")
output_dir.mkdir(parents=True, exist_ok=True)

print(f"📂 Output directory: {output_dir}")

try:
    trainer_kwargs = dict(
        model_type=config["model_type"],
        model_name=config["model_name"],
        device=config["device"],
        output_dir=str(output_dir),
        num_labels=7,  # 7 rhetorical roles
    )
    trainer = RoleClassifierTrainer(**trainer_kwargs)

    print("✅ Trainer initialized successfully!")
    print(f"🖥️  Using device: {config['device']}")
    print(f"🤖 Model type: {config['model_type']}")

    # Total number of elements across all model parameter tensors
    parameter_count = sum(p.numel() for p in trainer.model.parameters())
    print(f"📏 Model parameters: {parameter_count:,}")

except Exception as e:
    print(f"❌ Error initializing trainer: {e}")

In [None]:
# Start training
print("🎯 Starting Training...")
print(f"⏱️  Training for {config['num_epochs']} epochs")
print(f"📚 Batch size: {config['batch_size']}")
print(f"🧠 Learning rate: {config['learning_rate']}")
print(f"📝 Context mode: {config['context_mode']}")
print("=" * 60)

try:
    # These settings are forwarded to the trainer under the same names they
    # have in the config dict.
    passthrough_keys = (
        "context_mode",
        "batch_size",
        "num_epochs",
        "learning_rate",
        "weight_decay",
        "warmup_steps",
        "max_length",
        "save_best_model",
    )
    train_options = {name: config[name] for name in passthrough_keys}

    # Train the model; dataset locations use trainer-specific argument names
    trainer.train(
        train_data_path=config["train_data"],
        val_data_path=config["val_data"],
        test_data_path=config["test_data"],
        **train_options,
    )

    print("\n🎉 Training completed successfully!")

except Exception as e:
    # Show the full traceback as well as the message, then let the notebook continue
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

## Training Results Analysis

In [None]:
# Load and display training history
def _final_metric(history, key):
    """Return the last value recorded under ``key``, or 'N/A' if the series is missing or empty."""
    series = history.get(key)
    return series[-1] if series else 'N/A'


def _format_metric(value):
    """Render numeric metrics with 4 decimals; pass placeholders such as 'N/A' through unchanged."""
    return f"{value:.4f}" if isinstance(value, (int, float)) else f"{value}"


history_path = output_dir / "training_history.json"

if history_path.exists():
    with open(history_path, 'r', encoding='utf-8') as f:
        history = json.load(f)

    print("📈 Training History:")
    print("=" * 50)

    # Display final metrics
    if history:
        # A missing series is reported as 'N/A' instead of raising KeyError.
        final_train_loss = _final_metric(history, 'train_loss')
        final_val_loss = _final_metric(history, 'val_loss')
        final_train_f1 = _final_metric(history, 'train_f1')
        final_val_f1 = _final_metric(history, 'val_f1')

        print(f"Final Training Loss: {_format_metric(final_train_loss)}")
        print(f"Final Validation Loss: {_format_metric(final_val_loss)}")
        print(f"Final Training F1: {_format_metric(final_train_f1)}")
        print(f"Final Validation F1: {_format_metric(final_val_f1)}")

        # Plot training curves only when every required series is present and non-empty
        if all(history.get(key) for key in ['epoch', 'train_loss', 'val_loss', 'train_f1', 'val_f1']):
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

            # Loss plot
            ax1.plot(history['epoch'], history['train_loss'], label='Train Loss', marker='o')
            ax1.plot(history['epoch'], history['val_loss'], label='Val Loss', marker='s')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)

            # F1 Score plot
            ax2.plot(history['epoch'], history['train_f1'], label='Train F1', marker='o')
            ax2.plot(history['epoch'], history['val_f1'], label='Val F1', marker='s')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('F1 Score')
            ax2.set_title('Training and Validation F1 Score')
            ax2.legend()
            ax2.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()
else:
    print("❌ Training history not found")

## Model Evaluation

In [None]:
# Evaluate the trained model
print("🔍 Evaluating Trained Model...")

# Path to the best model
best_model_path = output_dir / "best_model.pt"

if best_model_path.exists():
    try:
        # Initialize evaluator
        evaluator = ModelEvaluator(
            model_path=str(best_model_path),
            device=config["device"]
        )

        print("✅ Evaluator initialized")

        # Create evaluation output directory. Creating it here is idempotent
        # and does not rely on the evaluator to create it.
        eval_output_dir = output_dir / "evaluation_results"
        eval_output_dir.mkdir(parents=True, exist_ok=True)

        # Evaluate on test set
        metrics = evaluator.evaluate_dataset(
            test_data_path=config["test_data"],
            context_mode=config["context_mode"],
            batch_size=config["batch_size"],
            output_dir=str(eval_output_dir)
        )

        print("\n🎯 Evaluation Results:")
        print("=" * 50)
        print(f"📊 Accuracy: {metrics['accuracy']:.4f}")
        print(f"📊 Weighted F1: {metrics['weighted_f1']:.4f}")
        print(f"📊 Macro F1: {metrics['macro_f1']:.4f}")
        print(f"📊 Weighted Precision: {metrics['weighted_precision']:.4f}")
        print(f"📊 Weighted Recall: {metrics['weighted_recall']:.4f}")

        # Display per-class metrics when the evaluator provides them
        if 'per_class' in metrics:
            print("\n📈 Per-Class Metrics:")
            print("-" * 60)
            for role, class_metrics in metrics['per_class'].items():
                print(f"{role:20} | F1: {class_metrics['f1']:.3f} | Prec: {class_metrics['precision']:.3f} | Rec: {class_metrics['recall']:.3f}")

    except Exception as e:
        # Show the full traceback as well as the message, then let the notebook continue
        print(f"❌ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"❌ Best model not found at {best_model_path}")

## Test Single Predictions

In [None]:
# Test single predictions
# Needs the `evaluator` object created by the evaluation cell.
if 'evaluator' not in locals():
    print("❌ Evaluator not available. Please complete the evaluation step first.")
else:
    print("🧪 Testing Single Predictions...")

    # One hand-written sentence per role, paired by position with expected_roles
    test_sentences = [
        "The petitioner filed a writ petition challenging the constitutional validity of Section 377.",
        "The main issue in this case is whether Section 377 violates fundamental rights.",
        "The petitioner argues that Section 377 is discriminatory and violates Article 14.",
        "The respondent contends that Section 377 is constitutionally valid and necessary.",
        "The court finds that Section 377 infringes upon the right to privacy and equality.",
        "Therefore, Section 377 is hereby declared unconstitutional and is struck down."
    ]
    expected_roles = ["Facts", "Issue", "Arguments of Petitioner", "Arguments of Respondent", "Reasoning", "Decision"]

    print("\n📝 Prediction Results:")
    print("=" * 100)

    correct_predictions = 0

    for example_no, (sentence, expected) in enumerate(zip(test_sentences, expected_roles), start=1):
        result = evaluator.predict_single(sentence, context_mode=config["context_mode"])
        predicted_role, confidence = result['predicted_role'], result['confidence']

        is_correct = predicted_role == expected
        correct_predictions += int(is_correct)
        status = "✅" if is_correct else "❌"

        print(f"\n{status} Example {example_no}:")
        print(f"Text: {sentence[:80]}{'...' if len(sentence) > 80 else ''}")
        print(f"Expected: {expected}")
        print(f"Predicted: {predicted_role} (Confidence: {confidence:.3f})")

        # Show top predictions
        print("Top 3 predictions:")
        for rank, pred in enumerate(result['top_predictions'][:3], start=1):
            print(f"  {rank}. {pred['role']}: {pred['confidence']:.3f}")
        print("-" * 80)

    # Share of the hand-written examples that were predicted correctly
    accuracy = correct_predictions / len(test_sentences)
    print(f"\n🎯 Test Accuracy: {correct_predictions}/{len(test_sentences)} ({accuracy:.1%})")

## Save and Load Model for Production

In [None]:
# Demonstrate how to save and load the model for production use
print("💾 Model Save/Load for Production")

if best_model_path.exists():
    print(f"\n📂 Best model saved at: {best_model_path}")

    # Show how to load the model in production
    print("\n🔧 To use this model in your Nyaya system:")
    print("=" * 60)

    # Values are embedded with !r so the generated snippet always contains
    # valid Python string literals, even when a path holds backslashes or
    # quote characters.
    production_code = f'''
# In your production code (e.g., in role_classifier.py):
from src.models.role_classifier import RoleClassifier

# Initialize classifier
classifier = RoleClassifier(
    model_type={config['model_type']!r},
    device={config['device']!r}
)

# Load your trained weights
classifier.load_pretrained_weights({str(best_model_path)!r})

# Use for classification
results = classifier.classify_document(
    document_text="Your legal document text here...",
    context_mode={config['context_mode']!r}
)
'''

    print(production_code)

    # Save production instructions
    instructions_path = output_dir / "production_usage.py"
    with open(instructions_path, 'w', encoding='utf-8') as f:
        f.write(production_code)

    print(f"\n📄 Production usage instructions saved to: {instructions_path}")

    # Model info; evaluation metrics are included only if the evaluation cell
    # defined `metrics` in this session.
    model_info = {
        "model_type": config["model_type"],
        "model_name": config["model_name"],
        "context_mode": config["context_mode"],
        "training_config": config,
        "model_path": str(best_model_path),
        "evaluation_metrics": metrics if 'metrics' in locals() else None,
        "timestamp": timestamp
    }

    info_path = output_dir / "model_info.json"
    with open(info_path, 'w', encoding='utf-8') as f:
        # default=str stringifies any value json cannot serialise natively
        json.dump(model_info, f, indent=2, default=str)

    print(f"📋 Model information saved to: {info_path}")
else:
    print("❌ No trained model found")

## Summary and Next Steps

In [None]:
# Training summary
print("🎊 TRAINING SUMMARY")
print("=" * 60)

# `metrics` only exists if the evaluation cell ran through to the end
if 'metrics' not in locals():
    print("⚠️  Training may not have completed successfully.")
    print("Please check the error messages above.")
else:
    print("✅ Training completed successfully!")
    print(f"📊 Final Test Accuracy: {metrics['accuracy']:.4f}")
    print(f"📊 Final Test F1 Score: {metrics['weighted_f1']:.4f}")
    print(f"📂 Model saved at: {best_model_path}")
    print(f"📂 Results saved at: {output_dir}")

print("\n🚀 NEXT STEPS:")
for step in (
    "1. 📋 Review the evaluation results and confusion matrix",
    "2. 🔧 Integrate the trained model into your Nyaya system",
    "3. 🧪 Test with real legal documents",
    "4. 📈 Consider further fine-tuning if needed",
    "5. 🔄 Update the role_classifier.py to use your trained weights",
):
    print(step)

print("\n📚 FILES GENERATED:")
if not output_dir.exists():
    print("  ❌ No output directory found")
else:
    # List every regular file below the run folder, relative to that folder
    generated_files = list(output_dir.rglob("*"))
    for file_path in (entry for entry in generated_files if entry.is_file()):
        print(f"  📄 {file_path.relative_to(output_dir)}")

## Optional: Hyperparameter Tuning

If you want to experiment with different hyperparameters, you can modify the configuration and re-run the training cells above. Consider trying:

- Different context modes: `"single"`, `"prev_two"`, `"surrounding"`
- Different learning rates: `1e-5`, `3e-5`, `5e-5`
- Different batch sizes: `8`, `32` (depending on your GPU memory)
- More epochs for better convergence
- Different model types: `"bilstm_crf"` for sequence modeling

## Troubleshooting

1. **CUDA Out of Memory**: Reduce batch size or max_length
2. **Low Accuracy**: Try more epochs, different context modes, or data augmentation
3. **Import Errors**: Check file paths and ensure all dependencies are installed
4. **Data Format Issues**: Ensure your data follows the tab-separated `sentence<TAB>role` format (one pair per line) and is saved as UTF-8