# Legal Document Role Classifier Training Notebook

This notebook walks through training your own rhetorical role classifier for legal documents on your existing dataset.

## Dataset Structure
Your data should be tab-separated, with one `sentence\trole` pair per line and a blank line between documents:
```
sentence1\trole1
sentence2\trole2
\n
sentence1\trole1  # New document
sentence2\trole2
```
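
The format above can be read with a few lines of plain Python. The following is a minimal sketch, not the project's actual loader: the function name `parse_labeled_text` and the sample sentences are made up for illustration, and it assumes the role is everything after the first tab on a line.

```python
from typing import List, Tuple

Document = List[Tuple[str, str]]  # (sentence, role) pairs for one document


def parse_labeled_text(text: str) -> List[Document]:
    """Split tab-separated 'sentence<TAB>role' lines into documents.

    A blank line closes the current document. Lines without a tab are skipped.
    """
    documents: List[Document] = []
    current: Document = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            if current:  # blank line: end of the current document
                documents.append(current)
                current = []
            continue
        sentence, sep, role = line.partition("\t")
        if not sep:
            continue  # no tab found, so this is not a labeled line
        current.append((sentence.strip(), role.strip()))
    if current:  # last document may not be followed by a blank line
        documents.append(current)
    return documents


# Hypothetical sample: two documents separated by a blank line
sample = (
    "The appellant filed a suit.\tFacts\n"
    "The appeal is dismissed.\tDecision\n"
    "\n"
    "Whether the notice was valid.\tIssue\n"
)
docs = parse_labeled_text(sample)
print(len(docs), [len(d) for d in docs])
```

Running a check like this on one or two files before training is a quick way to confirm that the files really are tab-separated.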

## Supported Roles
- Facts
- Issue
- Arguments of Petitioner
- Arguments of Respondent
- Reasoning
- Decision
- None

In [3]:
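# Install required packages if needed
# (tensorboard is imported by the training modules; psutil is used by the memory check)
!pip install torch transformers scikit-learn pandas matplotlib seaborn spacy tensorboard psutil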
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
Installing collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.8.0
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')


In [None]:
# Setup and imports
import os
import sys
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Add your project path - Using relative path for portability
NOTEBOOK_DIR = Path.cwd()
PROJECT_ROOT = NOTEBOOK_DIR / "server"
sys.path.append(str(PROJECT_ROOT))
sys.path.append(str(PROJECT_ROOT / "src" / "models" / "training"))

print(f"Notebook directory: {NOTEBOOK_DIR}")
print(f"Project root: {PROJECT_ROOT}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

Project root: /home/nyaya/server
PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA device: Quadro RTX 6000


In [2]:
# Import training modules
try:
    from train import RoleClassifierTrainer
    from data_loader import create_data_loaders, LegalDocumentDataset
    from evaluate import ModelEvaluator
    from src.models.role_classifier import RoleClassifier, RhetoricalRole
    print("✅ Successfully imported training modules")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("Please ensure you're running from the correct directory")

❌ Import error: No module named 'tensorboard'
Please ensure you're running from the correct directory


## Configuration

In [None]:
# Training Configuration
config = {
    # Data paths - Using relative paths for portability
    "train_data": str(PROJECT_ROOT / "dataset" / "Hier_BiLSTM_CRF" / "train"),
    "val_data": str(PROJECT_ROOT / "dataset" / "Hier_BiLSTM_CRF" / "val"),
    "test_data": str(PROJECT_ROOT / "dataset" / "Hier_BiLSTM_CRF" / "test"),
    
    # Model configuration
    "model_type": "inlegalbert",  # Options: "inlegalbert", "bilstm_crf"
    "model_name": "law-ai/InLegalBERT",  # Pre-trained model
    "context_mode": "prev",  # Options: "single", "prev", "prev_two", "surrounding"
    
    # Training hyperparameters
    "batch_size": 16,
    "num_epochs": 5,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "max_length": 512,
    "warmup_steps": 500,
    
    # Class imbalance handling
    "use_class_weights": True,  # IMPORTANT: Handle "None" label dominance
    "class_weight_method": "inverse_freq",  # Options: "inverse_freq", "balanced", "manual"
    
    # Device and output
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "output_dir": "./trained_models",
    "save_best_model": True
}

print("Configuration:")
for key, value in config.items():
    print(f"  {key}: {value}")

print("\n⚠️  Class Imbalance Handling: ENABLED")
print("   → This helps keep the model from over-predicting 'None'")

Configuration:
  train_data: /home/nyaya/server/dataset/Hier_BiLSTM_CRF/train
  val_data: /home/nyaya/server/dataset/Hier_BiLSTM_CRF/val
  test_data: /home/nyaya/server/dataset/Hier_BiLSTM_CRF/test
  model_type: inlegalbert
  model_name: law-ai/InLegalBERT
  context_mode: prev
  batch_size: 16
  num_epochs: 10
  learning_rate: 2e-05
  weight_decay: 0.01
  max_length: 512
  warmup_steps: 500
  device: cuda
  output_dir: ./trained_models
  save_best_model: True


## ⚠️ Handling Class Imbalance

**Problem**: The "None" label often dominates the dataset, which can cause the model to:
- Predict "None" too frequently
- Ignore minority classes (Issue, Decision, etc.)
- Achieve high accuracy but poor per-class performance

**Solutions Implemented**:
1. **Class Weights**: Give higher importance to minority classes during training
2. **Focal Loss**: Focus on hard-to-classify examples
3. **Data Filtering**: Option to reduce "None" samples
4. **Balanced Sampling**: Sample from each class equally
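
To make the focal loss idea concrete, here is a small worked example of the standard formula `FL = -alpha * (1 - p)^gamma * log(p)`, where `p` is the probability the model assigns to the true class. It is a plain-Python sketch for intuition only, not the loss used by the training modules; the defaults mirror the `focal_alpha` and `focal_gamma` values in this notebook's imbalance configuration.

```python
import math


def cross_entropy(p_true: float) -> float:
    """Standard cross-entropy for one example."""
    return -math.log(min(max(p_true, 1e-12), 1.0))


def focal_loss(p_true: float, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Focal loss for one example, given the probability of the true class."""
    p_true = min(max(p_true, 1e-12), 1.0)  # clamp to avoid log(0)
    return -alpha * (1.0 - p_true) ** gamma * math.log(p_true)


# Easy examples (p close to 1) are down-weighted far more than hard ones
for p in (0.9, 0.5, 0.1):
    print(f"p={p:.1f}  CE={cross_entropy(p):.4f}  focal={focal_loss(p):.4f}")
```

With `gamma = 0` and `alpha = 1` the expression reduces to ordinary cross-entropy, which is a handy sanity check when wiring it into a training loop.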

In [None]:
# Analyze class distribution to understand the imbalance
def analyze_class_distribution(data_path, sample_size=50):
    """Analyze label distribution across dataset"""
    data_path = Path(data_path)
    files = list(data_path.glob("*.txt"))[:sample_size]
    
    all_labels = []
    
    for file_path in files:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line and '\t' in line:
                    parts = line.split('\t')
                    if len(parts) >= 2:
                        label = parts[1].strip()
                        all_labels.append(label)
    
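    # Count labels
    from collections import Counter
    label_counts = Counter(all_labels)
    total = len(all_labels)
    
    if total == 0:
        # No tab-separated lines were found; return early so max() below is never called on an empty sequence
        print(f"⚠️  No labeled sentences found in {len(files)} files under {data_path}")
        return label_counts, {}
    
    print(f"📊 Class Distribution Analysis (from {len(files)} files):")
    print("=" * 70)
    print(f"{'Label':<30} {'Count':>10} {'Percentage':>12} {'Imbalance Ratio':>15}")
    print("-" * 70)
    
    max_count = max(label_counts.values())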
    
    for label, count in sorted(label_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total) * 100
        imbalance_ratio = max_count / count
        print(f"{label:<30} {count:>10,} {percentage:>11.2f}% {imbalance_ratio:>14.1f}x")
    
    print("=" * 70)
    print(f"Total sentences: {total:,}")
    
    # Calculate class weights for handling imbalance
    class_weights = {}
    for label, count in label_counts.items():
        weight = total / (len(label_counts) * count)
        class_weights[label] = weight
    
    print(f"\n💡 Recommended Class Weights (to balance training):")
    for label, weight in sorted(class_weights.items(), key=lambda x: x[1], reverse=True):
        print(f"  {label:<30} → {weight:.3f}")
    
    return label_counts, class_weights

# Analyze training data
print("🔍 Analyzing Training Data Distribution...\n")
label_counts, class_weights = analyze_class_distribution(config["train_data"], sample_size=100)

### Strategy Selection

Based on the imbalance severity, choose one or more strategies:

1. **Mild Imbalance (2-5x)**: Use class weights only
2. **Moderate Imbalance (5-20x)**: Use class weights + focal loss
3. **Severe Imbalance (>20x)**: Consider filtering "None" samples or undersampling

**Current Configuration**: The notebook will use **class weights** by default.
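
The tiers above can be turned into a small helper that reads the label counts and suggests which strategies to enable. This is an illustrative sketch: the function name `suggest_strategies`, the returned strategy names, and the example counts are assumptions, and class weights are included at every tier because the notebook enables them by default.

```python
from typing import Dict, List


def suggest_strategies(label_counts: Dict[str, int]) -> List[str]:
    """Map the largest imbalance ratio (max count / min count) to a strategy tier."""
    counts = [c for c in label_counts.values() if c > 0]
    if not counts:
        return []  # nothing to analyze
    ratio = max(counts) / min(counts)
    if ratio > 20:
        return ["class_weights", "filter_none_or_undersample"]
    if ratio > 5:
        return ["class_weights", "focal_loss"]
    return ["class_weights"]


# Hypothetical counts: "None" is 10x more frequent than "Issue"
print(suggest_strategies({"None": 1000, "Facts": 400, "Issue": 100}))
```

The thresholds are the ones listed above; they are rules of thumb, so treat the output as a starting point rather than a verdict.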

In [None]:
# Configure imbalance handling strategies
imbalance_config = {
    # Strategy 1: Use class weights (RECOMMENDED - Always use this)
    "use_class_weights": True,
    
    # Strategy 2: Filter excessive "None" samples (Optional - for severe imbalance)
    "filter_none_samples": False,  # Set to True if "None" > 50% of dataset
    "none_keep_ratio": 0.3,  # Keep only 30% of "None" samples if filtering
    
    # Strategy 3: Focal loss parameters (Optional - for hard examples)
    "use_focal_loss": False,  # Set to True for severe imbalance
    "focal_alpha": 0.25,  # Balance between positive/negative
    "focal_gamma": 2.0,   # Focus on hard examples
    
    # Strategy 4: Oversampling minority classes (Optional)
    "oversample_minority": False,  # Duplicate rare class samples
    "target_balance_ratio": 5.0,  # Max imbalance ratio after balancing
}

print("⚙️ Imbalance Handling Configuration:")
print("=" * 60)
for key, value in imbalance_config.items():
    print(f"  {key:<25} → {value}")
print("=" * 60)

if imbalance_config["use_class_weights"]:
    print("\n✅ Class weights will be applied during training")
    print("   → Minority classes will have higher importance in loss")

if imbalance_config["filter_none_samples"]:
    print("\n✅ 'None' samples will be reduced")
    print(f"   → Keeping {imbalance_config['none_keep_ratio']*100:.0f}% of 'None' samples")

if imbalance_config["use_focal_loss"]:
    print("\n✅ Focal loss will be used")
    print("   → Model will focus on hard-to-classify examples")

In [None]:
# Visualize class imbalance and weights
if 'label_counts' in locals() and 'class_weights' in locals():
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot 1: Class distribution (shows imbalance problem)
    labels = list(label_counts.keys())
    counts = list(label_counts.values())
    colors = ['red' if label == 'None' else 'skyblue' for label in labels]
    
    ax1.bar(labels, counts, color=colors, alpha=0.7)
    ax1.set_title('⚠️ Class Imbalance Problem\n("None" dominates)', fontsize=14, fontweight='bold')
    ax1.set_xlabel('Rhetorical Role', fontsize=12)
    ax1.set_ylabel('Number of Samples', fontsize=12)
    ax1.tick_params(axis='x', rotation=45)
    
    # Add count labels
    for i, (label, count) in enumerate(zip(labels, counts)):
        ax1.text(i, count + max(counts)*0.01, f'{count:,}', ha='center', va='bottom')
    
    # Plot 2: Class weights (shows solution)
    weight_labels = list(class_weights.keys())
    weight_values = list(class_weights.values())
    colors2 = ['green' if w > 1.0 else 'orange' for w in weight_values]
    
    ax2.bar(weight_labels, weight_values, color=colors2, alpha=0.7)
    ax2.set_title('✅ Class Weights Solution\n(Higher weights = More importance)', fontsize=14, fontweight='bold')
    ax2.set_xlabel('Rhetorical Role', fontsize=12)
    ax2.set_ylabel('Weight Multiplier', fontsize=12)
    ax2.tick_params(axis='x', rotation=45)
    ax2.axhline(y=1.0, color='black', linestyle='--', alpha=0.3, label='Baseline (1.0)')
    
    # Add weight labels
    for i, (label, weight) in enumerate(zip(weight_labels, weight_values)):
        ax2.text(i, weight + max(weight_values)*0.01, f'{weight:.2f}x', ha='center', va='bottom')
    
    ax2.legend()
    plt.tight_layout()
    plt.show()
    
    print("\n💡 Interpretation:")
    print("  LEFT: Raw data shows severe imbalance (None dominates)")
    print("  RIGHT: Class weights compensate by giving minority classes higher importance")
    print("         → Rare classes get amplified during training loss calculation")
else:
    print("⚠️  Run the class distribution analysis cell first!")

### 📚 Practical Strategies to Handle "None" Dominance

#### **Strategy 1: Class Weights (RECOMMENDED - Already Enabled)**
✅ **What it does**: Multiplies the loss for each class by a weight inversely proportional to its frequency
- "None" (abundant) → Low weight (e.g., 0.2x)
- "Issue" (rare) → High weight (e.g., 5.0x)

✅ **Pros**:
- Easy to implement
- No data loss
- Works well for moderate imbalance

❌ **Cons**: May not fully solve severe imbalance (>50x ratio)

---

#### **Strategy 2: Filter "None" Samples**
What it does: Randomly discard some "None" samples to balance the dataset

```python
# Example: Keep only 30% of "None" samples
imbalance_config["filter_none_samples"] = True
imbalance_config["none_keep_ratio"] = 0.3
```

✅ **Pros**: Directly balances the dataset

❌ **Cons**: Loses potentially useful data
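
A filtering step along these lines could look as follows. This is a sketch under stated assumptions: `filter_none_samples` here is a hypothetical standalone function (not the data loader's implementation), samples are `(sentence, role)` tuples, and the seed is fixed only to make the subsample reproducible.

```python
import random
from typing import List, Tuple


def filter_none_samples(
    samples: List[Tuple[str, str]],
    keep_ratio: float = 0.3,
    none_label: str = "None",
    seed: int = 42,
) -> List[Tuple[str, str]]:
    """Keep every non-'None' sample and roughly keep_ratio of the 'None' samples."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    kept = []
    for sentence, role in samples:
        if role != none_label or rng.random() < keep_ratio:
            kept.append((sentence, role))
    return kept


# Hypothetical data: 1000 "None" sentences and 10 "Issue" sentences
samples = [("filler sentence", "None")] * 1000 + [("the issue is ...", "Issue")] * 10
filtered = filter_none_samples(samples, keep_ratio=0.3)
print(len(filtered), sum(1 for _, role in filtered if role == "None"))
```

One caveat: with a context mode such as `"prev"`, dropping sentences changes which sentence precedes which, so it may be safer to filter after the context has been attached to each sample.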

---

#### **Strategy 3: Stratified Sampling**
What it does: Ensure each batch has balanced representation of all classes

✅ **Pros**: Guarantees balanced learning in each batch

❌ **Cons**: Requires custom data loader
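
One common way to approximate this is weighted random sampling: give each sample a weight of `1 / (count of its class)` and draw with replacement. Note that this balances classes in expectation rather than guaranteeing an exact mix in every batch. The sketch below uses only the standard library; in PyTorch the same per-sample weights could be handed to `torch.utils.data.WeightedRandomSampler`. The function name and toy labels are illustrative.

```python
import random
from collections import Counter
from typing import List


def balanced_sample_weights(labels: List[str]) -> List[float]:
    """Per-sample weights so that every class has the same total weight."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Toy dataset: 90% "None", 10% "Issue"
labels = ["None"] * 90 + ["Issue"] * 10
weights = balanced_sample_weights(labels)

rng = random.Random(0)
drawn = rng.choices(range(len(labels)), weights=weights, k=1000)
drawn_counts = Counter(labels[i] for i in drawn)
print(drawn_counts)  # both classes should appear roughly equally often
```

Because rare samples are drawn repeatedly, this behaves like oversampling, so watch for overfitting on very small classes.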

---

#### **Strategy 4: Two-Stage Training**
1. **Stage 1**: Train on balanced subset (filter "None" heavily)
2. **Stage 2**: Fine-tune on full dataset with class weights

✅ **Pros**: Best of both worlds

❌ **Cons**: Takes more time

---

### 🎯 **Recommended Approach for Your Dataset**

If the class distribution analysis shows severe "None" dominance:

1. **Start with**: Class weights (already enabled in config)
2. **If results are poor**: Enable "None" filtering:
   ```python
   imbalance_config["filter_none_samples"] = True
   imbalance_config["none_keep_ratio"] = 0.4  # Keep 40% of "None"
   ```
3. **Monitor**: Per-class F1 scores (especially for "Issue", "Decision")
4. **Adjust**: If minority classes still perform poorly, reduce `none_keep_ratio` to 0.2-0.3
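
Per-class F1 is easy to compute by hand, which helps show why accuracy alone is misleading here. The sketch below is a plain-Python illustration with made-up labels; in practice `sklearn.metrics.classification_report` gives the same per-class numbers.

```python
from collections import Counter
from typing import Dict, List


def per_class_f1(y_true: List[str], y_pred: List[str]) -> Dict[str, float]:
    """Compute F1 separately for every label seen in y_true or y_pred."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted label gets a false positive
            fn[truth] += 1  # true label gets a false negative
    scores = {}
    for label in sorted(set(y_true) | set(y_pred)):
        p_den = tp[label] + fp[label]
        r_den = tp[label] + fn[label]
        precision = tp[label] / p_den if p_den else 0.0
        recall = tp[label] / r_den if r_den else 0.0
        denom = precision + recall
        scores[label] = 2 * precision * recall / denom if denom else 0.0
    return scores


# A model that over-predicts "None": 4 of 6 correct, yet "Issue" is never found
y_true = ["None", "None", "None", "Issue", "Decision", "Issue"]
y_pred = ["None", "None", "None", "None", "Decision", "None"]
scores = per_class_f1(y_true, y_pred)
for label, f1 in scores.items():
    print(f"{label:<10} F1={f1:.2f}")
```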

## Data Exploration

In [4]:
# Explore your dataset
def explore_dataset(data_path):
    """Explore the structure and statistics of your dataset"""
    data_path = Path(data_path)
    
    if not data_path.exists():
        print(f"❌ Data path does not exist: {data_path}")
        return
    
    print(f"📂 Exploring dataset: {data_path}")
    
    if data_path.is_file():
        files = [data_path]
    else:
        files = list(data_path.glob("*.txt"))
    
    print(f"📄 Found {len(files)} files")
    
    total_sentences = 0
    total_documents = 0
    role_counts = {}
    
    for file_path in files[:5]:  # Only the first 5 files are sampled; the statistics below cover this sample
        print(f"\n📝 File: {file_path.name}")
        
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read().strip()
        
        lines = content.split('\n')
        file_sentences = 0  # labeled sentences in this file
        doc_sentences = 0   # labeled sentences in the current document
        
        for line in lines:
            line = line.strip()
            if not line:
                # A blank line marks a document boundary
                if doc_sentences > 0:
                    total_documents += 1
                    doc_sentences = 0
                continue
            
            parts = line.split('\t')
            if len(parts) >= 2:
                role = parts[1].strip()
                
                total_sentences += 1
                file_sentences += 1
                doc_sentences += 1
                role_counts[role] = role_counts.get(role, 0) + 1
        
        # Count the last document if the file does not end with a blank line
        if doc_sentences > 0:
            total_documents += 1
        
        print(f"  Sentences in this file: {file_sentences}")
    
    print(f"\n📊 Dataset Statistics:")
    print(f"  Total files: {len(files)}")
    print(f"  Total documents: {total_documents}")
    print(f"  Total sentences: {total_sentences}")
    print(f"  Average sentences per document: {total_sentences/max(total_documents, 1):.1f}")
    
    print(f"\n🏷️ Role Distribution:")
    for role, count in sorted(role_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = (count / total_sentences) * 100
        print(f"  {role}: {count} ({percentage:.1f}%)")
    
    return role_counts

# Explore training data
print("🔍 Exploring Training Data")
train_role_counts = explore_dataset(config["train_data"])

🔍 Exploring Training Data
📂 Exploring dataset: /home/nyaya/server/dataset/Hier_BiLSTM_CRF/train
📄 Found 4994 files

📝 File: file_1.txt
  Sentences in this file: 0

📝 File: file_10.txt
  Sentences in this file: 0

📝 File: file_100.txt
  Sentences in this file: 0

📝 File: file_1000.txt
  Sentences in this file: 0

📝 File: file_1002.txt
  Sentences in this file: 0

📊 Dataset Statistics:
  Total files: 4994
  Total documents: 0
  Total sentences: 0
  Average sentences per document: 0.0

🏷️ Role Distribution:


In [None]:
# Visualize role distribution
if train_role_counts:
    plt.figure(figsize=(12, 6))
    
    roles = list(train_role_counts.keys())
    counts = list(train_role_counts.values())
    
    plt.bar(roles, counts, color='skyblue', alpha=0.7)
    plt.title('Role Distribution in Training Data', fontsize=16)
    plt.xlabel('Rhetorical Role', fontsize=12)
    plt.ylabel('Number of Sentences', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    
    # Add value labels on bars
    for i, v in enumerate(counts):
        plt.text(i, v + max(counts)*0.01, str(v), ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    # Pie chart
    plt.figure(figsize=(10, 8))
    plt.pie(counts, labels=roles, autopct='%1.1f%%', startangle=90)
    plt.title('Role Distribution (Percentage)', fontsize=16)
    plt.axis('equal')
    plt.show()

## Data Loading and Validation

In [None]:
# Test data loading
print("🔄 Testing Data Loading...")

try:
    # Create data loaders
    data_loaders = create_data_loaders(
        train_path=config["train_data"],
        val_path=config["val_data"],
        test_path=config["test_data"],
        tokenizer_name=config["model_name"],
        context_mode=config["context_mode"],
        batch_size=config["batch_size"],
        max_length=config["max_length"]
    )
    
    print("✅ Data loaders created successfully!")
    print(f"📦 Training batches: {len(data_loaders['train'])}")
    print(f"📦 Validation batches: {len(data_loaders['val'])}")
    if 'test' in data_loaders:
        print(f"📦 Test batches: {len(data_loaders['test'])}")
    
    # Check a sample batch
    sample_batch = next(iter(data_loaders['train']))
    print(f"\n🔍 Sample Batch Shape:")
    print(f"  Input IDs: {sample_batch['input_ids'].shape}")
    print(f"  Attention Mask: {sample_batch['attention_mask'].shape}")
    print(f"  Labels: {sample_batch['labels'].shape}")
    print(f"  Unique labels in batch: {torch.unique(sample_batch['labels'])}")
    
except Exception as e:
    print(f"❌ Error loading data: {e}")
    print("Please check your data paths and format")

In [None]:
# Sample some examples from the dataset
try:
    sample_batch = next(iter(data_loaders['train']))
    
    print("📝 Sample Training Examples:")
    print("=" * 80)
    
    # Show first 3 examples
    for i in range(min(3, len(sample_batch['text']))):
        text = sample_batch['text'][i]
        label_id = sample_batch['labels'][i].item()
        
        # Map label ID to role name
        role_names = [role.value for role in RhetoricalRole]
        role_name = role_names[label_id] if label_id < len(role_names) else "Unknown"
        
        print(f"\nExample {i+1}:")
        print(f"Text: {text[:200]}{'...' if len(text) > 200 else ''}")
        print(f"Role: {role_name} (ID: {label_id})")
        print("-" * 40)
        
except Exception as e:
    print(f"Error sampling examples: {e}")

## Model Training

## 🔧 Preprocessing Raw Test/Val Data

**Problem Detected**: Your test and validation datasets contain **raw legal text** without labels (not in `sentence\trole` format).

**Solution**: We need to:
1. Use the **trained model** to predict labels for test/val data
2. Create labeled versions for evaluation
3. Or evaluate directly on raw text if you have gold labels separately

### Two Scenarios:

#### **Scenario A: You have gold labels separately**
- Test/val files are raw text
- Gold labels exist in another file/format
- **Action**: Use the preprocessing cell below

#### **Scenario B: No gold labels (truly unlabeled data)**
- Test/val files are just for inference
- No evaluation possible
- **Action**: Use the model to predict and save results

In [None]:
# Preprocess raw legal documents for testing
import re
import spacy

def preprocess_raw_document(file_path, output_path=None):
    """
    Convert raw legal document to sentence\trole format for testing.
    Initially labels everything as 'None' - will be relabeled by model.
    """
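    # Load spaCy for sentence segmentation
    try:
        nlp = spacy.load("en_core_web_sm")
    except OSError:  # raised when the model package is not installed
        print("⚠️  Installing spaCy model...")
        import subprocess
        # Use the notebook's own interpreter so the model lands in the active environment
        subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
        nlp = spacy.load("en_core_web_sm")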
    
    # Read raw text
    with open(file_path, 'r', encoding='utf-8') as f:
        raw_text = f.read()
    
    # Clean text
    raw_text = re.sub(r'\s+', ' ', raw_text)  # Remove extra whitespace
    raw_text = raw_text.strip()
    
    # Segment into sentences
    doc = nlp(raw_text)
    sentences = [sent.text.strip() for sent in doc.sents if sent.text.strip()]
    
    # Create labeled format (initially all 'None')
    labeled_lines = []
    for sentence in sentences:
        if len(sentence) > 10:  # Filter very short sentences
            labeled_lines.append(f"{sentence}\tNone")
    
    # Save if output path provided
    if output_path:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(labeled_lines))
        print(f"✅ Preprocessed: {output_path}")
    
    return labeled_lines

# Test on a sample file
sample_test_file = Path(config["test_data"]) / "file_6409.txt"
if sample_test_file.exists():
    print("🔍 Testing preprocessing on sample file...")
    result = preprocess_raw_document(sample_test_file)
    print(f"📊 Extracted {len(result)} sentences")
    print(f"\n📝 Sample preprocessed output:")
    for line in result[:3]:
        parts = line.split('\t')
        print(f"  Sentence: {parts[0][:80]}...")
        print(f"  Label: {parts[1]}\n")
else:
    print("⚠️  Sample test file not found. Check your test data path.")

In [None]:
# Batch preprocess all test/val files
def batch_preprocess_dataset(input_dir, output_dir, max_files=None):
    """
    Preprocess all raw files in a directory
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    files = list(input_path.glob("*.txt"))
    if max_files:
        files = files[:max_files]
    
    print(f"📂 Processing {len(files)} files from {input_dir}")
    print(f"📂 Output directory: {output_dir}")
    
    total_sentences = 0
    
    for i, file_path in enumerate(files, 1):
        output_file = output_path / file_path.name
        labeled_lines = preprocess_raw_document(file_path, output_file)
        total_sentences += len(labeled_lines)
        
        if i % 50 == 0:
            print(f"  Processed {i}/{len(files)} files...")
    
    print(f"\n✅ Preprocessing complete!")
    print(f"   Total files: {len(files)}")
    print(f"   Total sentences: {total_sentences:,}")
    # max(..., 1) avoids a ZeroDivisionError when the input directory has no .txt files
    print(f"   Average sentences per file: {total_sentences/max(len(files), 1):.1f}")
    
    return output_path

# Option to preprocess test and val datasets
preprocess_data = False  # Set to True to preprocess

if preprocess_data:
    print("🔄 Preprocessing Test Dataset...")
    preprocessed_test_dir = batch_preprocess_dataset(
        input_dir=config["test_data"],
        output_dir=str(PROJECT_ROOT / "dataset" / "Hier_BiLSTM_CRF" / "test_preprocessed"),
        max_files=None  # Process all files
    )
    
    print("\n🔄 Preprocessing Validation Dataset...")
    preprocessed_val_dir = batch_preprocess_dataset(
        input_dir=config["val_data"],
        output_dir=str(PROJECT_ROOT / "dataset" / "Hier_BiLSTM_CRF" / "val_preprocessed"),
        max_files=None
    )
    
    # Update config to use preprocessed data
    config["test_data"] = str(preprocessed_test_dir)
    config["val_data"] = str(preprocessed_val_dir)
    
    print("\n✅ Config updated to use preprocessed data")
else:
    print("⚠️  Preprocessing disabled. Set preprocess_data=True to enable.")
    print("   Current approach: Assuming test/val data is already in correct format.")

### Alternative: Predict and Evaluate Without Gold Labels

If you don't have gold labels for test/val data, you can:
1. **Use the trained model to predict** on raw documents
2. **Manually review** a sample of predictions
3. **Calculate inter-annotator agreement** if you have multiple annotators

This approach is useful for:
- **Unlabeled inference**: Classify new documents
- **Semi-supervised learning**: Use predictions to create training data
- **Active learning**: Identify uncertain predictions for manual labeling
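
For the active-learning case, a simple starting point is to rank predictions by confidence and send the least confident ones for manual labeling. The sketch below assumes predictions are dictionaries with `sentence`, `predicted_role`, and `confidence` keys, as produced by the inference cell in this notebook; the function name `select_uncertain` and the 0.6 threshold are illustrative choices, not tuned values.

```python
from typing import Dict, List, Optional


def select_uncertain(
    predictions: List[Dict],
    threshold: float = 0.6,
    limit: Optional[int] = None,
) -> List[Dict]:
    """Return predictions below the confidence threshold, least confident first."""
    uncertain = [p for p in predictions if p["confidence"] < threshold]
    uncertain.sort(key=lambda p: p["confidence"])
    return uncertain if limit is None else uncertain[:limit]


# Hypothetical predictions for three sentences
preds = [
    {"sentence": "The suit was filed in 1998.", "predicted_role": "Facts", "confidence": 0.95},
    {"sentence": "Whether the notice was valid.", "predicted_role": "Issue", "confidence": 0.41},
    {"sentence": "Heard learned counsel.", "predicted_role": "None", "confidence": 0.55},
]
for p in select_uncertain(preds):
    print(f"{p['confidence']:.2f}  {p['predicted_role']:<8} {p['sentence']}")
```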

In [None]:
# Inference on raw documents (no gold labels needed)
def predict_on_raw_document(model_evaluator, file_path, context_mode="prev"):
    """
    Predict roles for a raw legal document
    """
    # Read and preprocess
    with open(file_path, 'r', encoding='utf-8') as f:
        raw_text = f.read()
    
    # Collapse whitespace so each sentence stays on one line in the tab-separated output
    raw_text = re.sub(r'\s+', ' ', raw_text).strip()
    
    # Segment sentences (requires the en_core_web_sm model installed at the top of the notebook)
    nlp = spacy.load("en_core_web_sm")
    doc = nlp(raw_text)
    sentences = [sent.text.strip() for sent in doc.sents if len(sent.text.strip()) > 10]
    
    # Predict for each sentence
    predictions = []
    for sentence in sentences:
        try:
            result = model_evaluator.predict_single(sentence, context_mode=context_mode)
            predictions.append({
                'sentence': sentence,
                'predicted_role': result['predicted_role'],
                'confidence': result['confidence']
            })
        except Exception as e:
            print(f"⚠️  Error predicting: {str(e)[:50]}")
            predictions.append({
                'sentence': sentence,
                'predicted_role': 'None',
                'confidence': 0.0
            })
    
    return predictions

def save_predictions(predictions, output_path):
    """Save predictions in sentence\trole format"""
    with open(output_path, 'w', encoding='utf-8') as f:
        for pred in predictions:
            f.write(f"{pred['sentence']}\t{pred['predicted_role']}\n")
    print(f"✅ Saved predictions to: {output_path}")

# Example usage (after model is trained)
# This cell should be run AFTER training is complete
inference_mode = False  # Set to True after training

if inference_mode and 'evaluator' in locals():
    print("🔮 Running inference on raw test document...")
    
    # Pick a test file
    test_file = Path(config["test_data"]) / "file_6409.txt"
    
    if test_file.exists():
        predictions = predict_on_raw_document(evaluator, test_file, context_mode=config["context_mode"])
        
        # Display summary
        print(f"\n📊 Prediction Summary:")
        print(f"   Total sentences: {len(predictions)}")
        
        role_dist = {}
        for pred in predictions:
            role = pred['predicted_role']
            role_dist[role] = role_dist.get(role, 0) + 1
        
        print(f"\n🏷️ Predicted Role Distribution:")
        for role, count in sorted(role_dist.items(), key=lambda x: x[1], reverse=True):
            percentage = (count / len(predictions)) * 100
            print(f"   {role:<30} {count:>5} ({percentage:>5.1f}%)")
        
        # Show sample predictions
        print(f"\n📝 Sample Predictions:")
        for i, pred in enumerate(predictions[:5], 1):
            print(f"\n{i}. Sentence: {pred['sentence'][:80]}...")
            print(f"   Predicted: {pred['predicted_role']} (confidence: {pred['confidence']:.3f})")
        
        # Save predictions (create the output directory first in case it does not exist yet)
        output_path = Path(config["output_dir"]) / "predictions_file_6409.txt"
        output_path.parent.mkdir(parents=True, exist_ok=True)
        save_predictions(predictions, output_path)
    else:
        print(f"❌ Test file not found: {test_file}")
else:
    if not inference_mode:
        print("⚠️  Inference mode disabled. Set inference_mode=True after training.")
    else:
        print("⚠️  Model evaluator not available. Train the model first.")

In [None]:
# Initialize trainer
print("🚀 Initializing Role Classifier Trainer...")

# Create output directory with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = Path(config["output_dir"]) / f"{config['model_type']}_{timestamp}"
output_dir.mkdir(parents=True, exist_ok=True)

print(f"📂 Output directory: {output_dir}")

try:
    trainer = RoleClassifierTrainer(
        model_type=config["model_type"],
        model_name=config["model_name"],
        device=config["device"],
        output_dir=str(output_dir),
        num_labels=7  # 7 rhetorical roles
    )
    
    print("✅ Trainer initialized successfully!")
    print(f"🖥️  Using device: {config['device']}")
    print(f"🤖 Model type: {config['model_type']}")
    print(f"📏 Model parameters: {sum(p.numel() for p in trainer.model.parameters()):,}")
    
except Exception as e:
    print(f"❌ Error initializing trainer: {e}")

## 💾 Memory Optimization for Large Datasets (Kaggle/Colab)

**Problem**: Training on large datasets can cause:
- ❌ Out of memory errors when saving checkpoints
- ❌ Kaggle kernel crashes
- ❌ Slow training due to memory swapping

**Solutions**:
1. **Gradient Accumulation**: Simulate larger batches without memory overhead
2. **Checkpointing Strategy**: Save only essential weights, not optimizer states
3. **Mixed Precision Training**: Use FP16 to lower activation memory (actual savings depend on the model and batch size)
4. **Data Streaming**: Load batches on-the-fly instead of all at once
5. **Periodic Cleanup**: Clear cache and garbage collect
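
Gradient accumulation works because the gradient of a mean loss over a large batch equals the average of the gradients over equally sized micro-batches. The following plain-Python check illustrates this on a one-parameter linear model with made-up data; it is an arithmetic demonstration, not the trainer's implementation. In a PyTorch loop the same effect is usually obtained by dividing each micro-batch loss by the number of accumulation steps and calling the optimizer step once every N micro-batches.

```python
from typing import List, Tuple


def mse_grad(w: float, batch: List[Tuple[float, float]]) -> float:
    """Gradient of mean((w * x - y)^2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


# Toy data: 16 points, matching an effective batch size of 16
data = [(float(i), 3.0 * i + 1.0) for i in range(16)]
w = 0.5

# Gradient computed on the full batch of 16 at once
full_grad = mse_grad(w, data)

# Same gradient accumulated over 2 micro-batches of 8
accum_steps = 2
micro_size = len(data) // accum_steps
accumulated = 0.0
for step in range(accum_steps):
    micro = data[step * micro_size:(step + 1) * micro_size]
    accumulated += mse_grad(w, micro) / accum_steps  # scale so the sum is an average

print(full_grad, accumulated)
```

The two numbers agree, while only 8 samples need to be held in memory at any one time.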

In [None]:
# Memory optimization configuration for Kaggle/Colab
memory_config = {
    # Strategy 1: Reduce batch size
    "reduced_batch_size": 8,  # Down from 16
    "gradient_accumulation_steps": 2,  # Effective batch size = 8 * 2 = 16
    
    # Strategy 2: Mixed precision training (FP16)
    "use_mixed_precision": True,  # Lowers activation memory; savings vary by model
    
    # Strategy 3: Checkpoint saving
    "save_frequency": "epoch",  # Options: "epoch", "steps", or number
    "save_optimizer_state": False,  # Smaller checkpoints, but training cannot be resumed exactly
    "keep_only_best": True,  # Delete intermediate checkpoints
    
    # Strategy 4: Data loading
    "num_workers": 2,  # Reduce data loader workers
    "pin_memory": False,  # Disable if running out of memory
    
    # Strategy 5: Periodic cleanup
    "clear_cache_every": 100,  # Clear CUDA cache every N steps
    "garbage_collect_every": 500,  # Run garbage collection
    
    # Strategy 6: Reduce training data (for testing)
    "use_subset": False,  # Set True to use only subset for quick testing
    "subset_ratio": 0.2,  # Use 20% of training data
}

print("💾 Memory Optimization Configuration:")
print("=" * 60)
for key, value in memory_config.items():
    print(f"  {key:<30} → {value}")
print("=" * 60)

# Calculate effective batch size
if memory_config["gradient_accumulation_steps"] > 1:
    effective_batch_size = (
        memory_config["reduced_batch_size"] * 
        memory_config["gradient_accumulation_steps"]
    )
    print(f"\n📊 Effective batch size: {effective_batch_size}")
    print(f"   (Physical: {memory_config['reduced_batch_size']} × "
          f"Accumulation: {memory_config['gradient_accumulation_steps']})")

if memory_config["use_mixed_precision"]:
    print(f"\n✅ Mixed precision (FP16) enabled")
    print(f"   → Lower activation memory (savings vary by model)")
    print(f"   → Often faster on GPUs with tensor cores")

if not memory_config["save_optimizer_state"]:
    print(f"\n✅ Optimizer states excluded from checkpoints")
    print(f"   → Smaller checkpoint files, but training cannot be resumed exactly")

In [None]:
# Check current memory usage (Kaggle/Colab)
import gc
import psutil

def check_memory_usage():
    """Check CPU and GPU memory usage"""
    # CPU Memory
    process = psutil.Process()
    cpu_memory_mb = process.memory_info().rss / 1024 / 1024
    
    vm = psutil.virtual_memory()
    total_memory_gb = vm.total / 1024 / 1024 / 1024
    available_memory_gb = vm.available / 1024 / 1024 / 1024
    used_percent = vm.percent
    
    print("💾 Memory Status:")
    print("=" * 60)
    print(f"📊 CPU Memory:")
    print(f"   Process: {cpu_memory_mb:.1f} MB")
    print(f"   Total: {total_memory_gb:.1f} GB")
    print(f"   Available: {available_memory_gb:.1f} GB")
    print(f"   Used: {used_percent:.1f}%")
    
    # GPU Memory
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            allocated = torch.cuda.memory_allocated(i) / 1024 / 1024 / 1024
            reserved = torch.cuda.memory_reserved(i) / 1024 / 1024 / 1024
            total = torch.cuda.get_device_properties(i).total_memory / 1024 / 1024 / 1024
            
            print(f"\n🎮 GPU {i} ({torch.cuda.get_device_name(i)}):")
            print(f"   Allocated: {allocated:.2f} GB")
            print(f"   Reserved: {reserved:.2f} GB")
            print(f"   Total: {total:.2f} GB")
            print(f"   Used: {(allocated/total)*100:.1f}%")
    else:
        print("\n⚠️  No GPU available")
    
    print("=" * 60)
    
    # Warning if memory is high
    if used_percent > 80:
        print("\n⚠️  WARNING: CPU memory usage is high (>80%)")
        print("   Consider enabling memory optimization strategies")
    
    if torch.cuda.is_available():
        gpu_used_percent = (torch.cuda.memory_allocated(0) / 
                           torch.cuda.get_device_properties(0).total_memory) * 100
        if gpu_used_percent > 80:
            print("\n⚠️  WARNING: GPU memory usage is high (>80%)")
            print("   Consider:")
            print("   - Reducing batch size")
            print("   - Enabling mixed precision")
            print("   - Clearing CUDA cache")

# Check memory before training
print("🔍 Checking memory before training...\n")
check_memory_usage()

# Cleanup to free memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("\n✅ Garbage collection run and CUDA cache cleared")
else:
    print("\n✅ Garbage collection run")
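The `clear_cache_every` and `garbage_collect_every` settings in `memory_config` describe a periodic version of the cleanup above. How the notebook's trainer consumes them is not shown here, so the following is only a sketch of one way such a schedule could be applied inside a training loop. `periodic_cleanup` is a hypothetical helper, and PyTorch is treated as optional so the sketch runs without it.

```python
import gc

try:
    import torch  # optional: the sketch still runs without PyTorch
except ImportError:
    torch = None


def periodic_cleanup(step, clear_cache_every=100, garbage_collect_every=500):
    """Run the cleanup actions that are due at this step; report what was done."""
    did_gc = step > 0 and step % garbage_collect_every == 0
    if did_gc:
        gc.collect()

    did_cache = (
        step > 0
        and step % clear_cache_every == 0
        and torch is not None
        and torch.cuda.is_available()
    )
    if did_cache:
        # Releases cached, unused blocks back to the driver; it does not
        # free tensors that are still referenced.
        torch.cuda.empty_cache()

    return {"gc": did_gc, "cuda_cache": did_cache}


for step in (0, 7, 100, 500):
    print(step, periodic_cleanup(step))
```

Clearing the cache too often can slow training, because the allocator has to request memory again, so intervals in the hundreds of steps are a reasonable starting point.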

### 🔧 Apply Memory Optimizations to Config

Based on the memory check above, update your training configuration:

In [None]:
# Apply memory optimizations to training config
# Run this cell to update your config for memory-constrained environments

apply_memory_optimizations = True  # Set to True for Kaggle/Colab

if apply_memory_optimizations:
    print("⚙️  Applying memory optimizations to training config...\n")
    
    # Store original values
    original_batch_size = config["batch_size"]
    
    # Update config with memory-optimized settings
    config.update({
        "batch_size": memory_config["reduced_batch_size"],
        "gradient_accumulation_steps": memory_config["gradient_accumulation_steps"],
        "use_mixed_precision": memory_config["use_mixed_precision"],
        "num_workers": memory_config["num_workers"],
        "pin_memory": memory_config["pin_memory"],
        "save_optimizer_state": memory_config["save_optimizer_state"],
        "clear_cache_frequency": memory_config["clear_cache_every"],
    })
    
    print("✅ Configuration updated for memory optimization:")
    print("=" * 60)
    print(f"  batch_size: {original_batch_size} → {config['batch_size']}")
    print(f"  gradient_accumulation_steps: {config['gradient_accumulation_steps']}")
    print(f"  Effective batch size: {config['batch_size'] * config['gradient_accumulation_steps']}")
    print(f"  use_mixed_precision: {config['use_mixed_precision']}")
    print(f"  save_optimizer_state: {config['save_optimizer_state']}")
    print("=" * 60)
    
    # Optional: Use subset for quick testing
    if memory_config["use_subset"]:
        print(f"\n⚠️  Using {memory_config['subset_ratio']*100:.0f}% of training data (subset mode)")
        config["use_data_subset"] = True
        config["subset_ratio"] = memory_config["subset_ratio"]
    
    print("\n💡 Memory-saving settings now in config:")
    print("   ✅ Smaller batch size with gradient accumulation")
    if config["use_mixed_precision"]:
        print("   ✅ Mixed precision training (FP16)")
    if not config["save_optimizer_state"]:
        print("   ✅ Lighter checkpoints (no optimizer state)")
    print("   ✅ Periodic cache clearing")
    
else:
    print("⚠️  Memory optimizations NOT applied.")
    print("   Set apply_memory_optimizations=True to enable.")
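The `use_subset` and `subset_ratio` settings only record the intent to train on part of the data; the sampling itself happens elsewhere. Because sentences in this dataset belong to documents, one reasonable approach is to sample whole documents so that context modes still see neighbouring sentences. The sketch below assumes documents are already parsed into lists of `(sentence, role)` pairs; `sample_documents` is a hypothetical helper, not a function from the project.

```python
import random


def sample_documents(documents, subset_ratio, seed=42):
    """Return a reproducible random subset of whole documents."""
    if not 0 < subset_ratio <= 1:
        raise ValueError("subset_ratio must be in (0, 1]")
    if not documents:
        return []
    # Keep at least one document so a tiny ratio never yields an empty set.
    keep = max(1, round(len(documents) * subset_ratio))
    rng = random.Random(seed)
    # Sort the chosen indices so the subset preserves the original order.
    chosen = sorted(rng.sample(range(len(documents)), keep))
    return [documents[i] for i in chosen]


# Made-up data: 10 documents with 3 labelled sentences each.
documents = [
    [(f"sentence {d}-{s}", "Facts") for s in range(3)]
    for d in range(10)
]
subset = sample_documents(documents, subset_ratio=0.2)
print(f"kept {len(subset)} of {len(documents)} documents")
```

Fixing the seed keeps quick test runs comparable with each other.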

In [None]:
# Start training
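print("🎯 Starting Training...")
print(f"⏱️  Training for {config['num_epochs']} epochs")
print(f"📚 Batch size: {config['batch_size']}")
print(f"🧠 Learning rate: {config['learning_rate']}")
print(f"📝 Context mode: {config['context_mode']}")
print("=" * 60)

try:
    # Train the model. Of the memory settings, only batch_size is passed
    # explicitly in this call; the other keys (gradient accumulation, mixed
    # precision, etc.) take effect only if the trainer reads them itself.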
    trainer.train(
        train_data_path=config["train_data"],
        val_data_path=config["val_data"],
        test_data_path=config["test_data"],
        context_mode=config["context_mode"],
        batch_size=config["batch_size"],
        num_epochs=config["num_epochs"],
        learning_rate=config["learning_rate"],
        weight_decay=config["weight_decay"],
        warmup_steps=config["warmup_steps"],
        max_length=config["max_length"],
        save_best_model=config["save_best_model"]
    )
    
    print("\n🎉 Training completed successfully!")
    
except Exception as e:
    print(f"❌ Training failed: {e}")
    import traceback
    traceback.print_exc()

## Training Results Analysis

In [None]:
# Load and display training history
history_path = output_dir / "training_history.json"

if history_path.exists():
    with open(history_path, 'r') as f:
        history = json.load(f)
    
    print("📈 Training History:")
    print("=" * 50)
    
    # Display final metrics
    if history:
        def last_value(key):
            """Format the last recorded value of a metric, or 'N/A' if absent."""
            values = history.get(key) or []
            if not values:
                return "N/A"
            last = values[-1]
            return f"{last:.4f}" if isinstance(last, (int, float)) else str(last)

        print(f"Final Training Loss: {last_value('train_loss')}")
        print(f"Final Validation Loss: {last_value('val_loss')}")
        print(f"Final Training F1: {last_value('train_f1')}")
        print(f"Final Validation F1: {last_value('val_f1')}")
        
        # Plot training curves
        if all(key in history and history[key] for key in ['epoch', 'train_loss', 'val_loss', 'train_f1', 'val_f1']):
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
            
            # Loss plot
            ax1.plot(history['epoch'], history['train_loss'], label='Train Loss', marker='o')
            ax1.plot(history['epoch'], history['val_loss'], label='Val Loss', marker='s')
            ax1.set_xlabel('Epoch')
            ax1.set_ylabel('Loss')
            ax1.set_title('Training and Validation Loss')
            ax1.legend()
            ax1.grid(True, alpha=0.3)
            
            # F1 Score plot
            ax2.plot(history['epoch'], history['train_f1'], label='Train F1', marker='o')
            ax2.plot(history['epoch'], history['val_f1'], label='Val F1', marker='s')
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('F1 Score')
            ax2.set_title('Training and Validation F1 Score')
            ax2.legend()
            ax2.grid(True, alpha=0.3)
            
            plt.tight_layout()
            plt.show()
else:
    print("❌ Training history not found")
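Beyond the final-epoch numbers, it is often useful to know which epoch had the best validation F1 and how far training F1 had pulled ahead at that point, since a large gap suggests overfitting. The sketch below assumes the history dictionary uses the same keys as the plotting code above (`epoch`, `train_f1`, `val_f1`). The values are made up for illustration and `summarize_history` is a hypothetical helper.

```python
def summarize_history(history):
    """Find the epoch with the best validation F1 and the train/val gap there."""
    val_f1 = history["val_f1"]
    best_idx = max(range(len(val_f1)), key=val_f1.__getitem__)
    return {
        "best_epoch": history["epoch"][best_idx],
        "best_val_f1": val_f1[best_idx],
        "f1_gap": history["train_f1"][best_idx] - val_f1[best_idx],
    }


# Made-up history: validation F1 peaks at epoch 3 while training F1 keeps rising.
example_history = {
    "epoch": [1, 2, 3, 4],
    "train_f1": [0.55, 0.70, 0.82, 0.90],
    "val_f1": [0.52, 0.66, 0.71, 0.69],
}
summary = summarize_history(example_history)
print(summary)
```

If the best epoch is well before the last one, fewer epochs or early stopping may be worth trying.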

## Model Evaluation

In [None]:
# Evaluate the trained model
print("🔍 Evaluating Trained Model...")

# Path to the best model
best_model_path = output_dir / "best_model.pt"

if best_model_path.exists():
    try:
        # Initialize evaluator
        evaluator = ModelEvaluator(
            model_path=str(best_model_path),
            device=config["device"]
        )
        
        print("✅ Evaluator initialized")
        
        # Create evaluation output directory
        eval_output_dir = output_dir / "evaluation_results"
        eval_output_dir.mkdir(parents=True, exist_ok=True)
        
        # Evaluate on test set
        metrics = evaluator.evaluate_dataset(
            test_data_path=config["test_data"],
            context_mode=config["context_mode"],
            batch_size=config["batch_size"],
            output_dir=str(eval_output_dir)
        )
        
        print("\n🎯 Evaluation Results:")
        print("=" * 50)
        print(f"📊 Accuracy: {metrics['accuracy']:.4f}")
        print(f"📊 Weighted F1: {metrics['weighted_f1']:.4f}")
        print(f"📊 Macro F1: {metrics['macro_f1']:.4f}")
        print(f"📊 Weighted Precision: {metrics['weighted_precision']:.4f}")
        print(f"📊 Weighted Recall: {metrics['weighted_recall']:.4f}")
        
        # Display per-class metrics
        if 'per_class' in metrics:
            print("\n📈 Per-Class Metrics:")
            print("-" * 60)
            for role, class_metrics in metrics['per_class'].items():
                print(f"{role:20} | F1: {class_metrics['f1']:.3f} | Prec: {class_metrics['precision']:.3f} | Rec: {class_metrics['recall']:.3f}")
        
    except Exception as e:
        print(f"❌ Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"❌ Best model not found at {best_model_path}")
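The report above shows both weighted and macro F1, and the two can differ noticeably when roles are imbalanced. The sketch below computes both from scratch on a tiny made-up example, using the standard definitions: macro is the unweighted mean of per-class F1, weighted is the mean weighted by class support. Whether `ModelEvaluator` computes its metrics in exactly this way is an assumption.

```python
from collections import Counter


def f1_scores(y_true, y_pred):
    """Return per-class F1, macro F1, and support-weighted F1."""
    labels = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)
    pairs = list(zip(y_true, y_pred))
    per_class = {}
    for label in labels:
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        denom = 2 * tp + fp + fn
        per_class[label] = 2 * tp / denom if denom else 0.0
    macro = sum(per_class.values()) / len(labels)
    weighted = sum(per_class[l] * support[l] for l in labels) / len(y_true)
    return per_class, macro, weighted


# Made-up labels: "Facts" dominates, "Decision" is never predicted correctly.
y_true = ["Facts", "Facts", "Facts", "Facts", "Issue", "Decision"]
y_pred = ["Facts", "Facts", "Facts", "Issue", "Issue", "Facts"]

per_class, macro, weighted = f1_scores(y_true, y_pred)
print(per_class)
print(f"macro F1: {macro:.4f}, weighted F1: {weighted:.4f}")
```

Here the weighted score is higher than the macro score because the frequent class is predicted well while a rare class is missed entirely. A low macro F1 next to a high weighted F1 is a hint to inspect the per-class table.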

## Test Single Predictions

In [None]:
# Test single predictions
if 'evaluator' in locals():
    print("🧪 Testing Single Predictions...")
    
    # Test sentences
    test_sentences = [
        "The petitioner filed a writ petition challenging the constitutional validity of Section 377.",
        "The main issue in this case is whether Section 377 violates fundamental rights.",
        "The petitioner argues that Section 377 is discriminatory and violates Article 14.",
        "The respondent contends that Section 377 is constitutionally valid and necessary.",
        "The court finds that Section 377 infringes upon the right to privacy and equality.",
        "Therefore, Section 377 is hereby declared unconstitutional and is struck down."
    ]
    
    expected_roles = ["Facts", "Issue", "Arguments of Petitioner", "Arguments of Respondent", "Reasoning", "Decision"]
    
    print("\n📝 Prediction Results:")
    print("=" * 100)
    
    correct_predictions = 0
    
    for i, (sentence, expected) in enumerate(zip(test_sentences, expected_roles)):
        result = evaluator.predict_single(sentence, context_mode=config["context_mode"])
        
        predicted_role = result['predicted_role']
        confidence = result['confidence']
        
        is_correct = predicted_role == expected
        if is_correct:
            correct_predictions += 1
        
        status = "✅" if is_correct else "❌"
        
        print(f"\n{status} Example {i+1}:")
        print(f"Text: {sentence[:80]}{'...' if len(sentence) > 80 else ''}")
        print(f"Expected: {expected}")
        print(f"Predicted: {predicted_role} (Confidence: {confidence:.3f})")
        
        # Show top predictions
        print("Top 3 predictions:")
        for j, pred in enumerate(result['top_predictions'][:3]):
            print(f"  {j+1}. {pred['role']}: {pred['confidence']:.3f}")
        print("-" * 80)
    
    accuracy = correct_predictions / len(test_sentences)
    print(f"\n🎯 Test Accuracy: {correct_predictions}/{len(test_sentences)} ({accuracy:.1%})")
else:
    print("❌ Evaluator not available. Please complete the evaluation step first.")
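The `confidence` and `top_predictions` fields used above are typically derived from the classifier's raw scores with a softmax. The following sketch shows that conversion in plain Python. It assumes softmax-based confidences and a particular role order, neither of which is confirmed by the notebook, and `top_predictions` here is a hypothetical standalone function rather than the evaluator's method.

```python
import math

# Assumed label order; the real model's index-to-role mapping may differ.
ROLES = [
    "Facts", "Issue", "Arguments of Petitioner", "Arguments of Respondent",
    "Reasoning", "Decision", "None",
]


def top_predictions(logits, roles=ROLES, k=3):
    """Convert raw scores to softmax probabilities and return the top k."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(roles, probs), key=lambda pair: pair[1], reverse=True)
    return [{"role": role, "confidence": conf} for role, conf in ranked[:k]]


example_logits = [2.0, 0.5, 0.1, -1.0, 1.2, -0.3, -2.0]  # made-up scores
for pred in top_predictions(example_logits):
    print(f"{pred['role']}: {pred['confidence']:.3f}")
```

A softmax confidence is a relative score across the roles, not a calibrated probability, so a high value on a hand-written test sentence should be read with some caution.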

## Save and Load Model for Production

In [None]:
# Demonstrate how to save and load the model for production use
print("💾 Model Save/Load for Production")

if best_model_path.exists():
    print(f"\n📂 Best model saved at: {best_model_path}")
    
    # Show how to load the model in production
    print("\n🔧 To use this model in your Nyaya system:")
    print("=" * 60)
    
    # repr() (!r) quotes each value and escapes backslashes, so the generated
    # snippet stays valid Python even for Windows-style paths
    production_code = f'''
# In your production code (e.g., in role_classifier.py):
from src.models.role_classifier import RoleClassifier

# Initialize classifier
classifier = RoleClassifier(
    model_type={str(config['model_type'])!r},
    device={str(config['device'])!r}
)

# Load your trained weights
classifier.load_pretrained_weights({str(best_model_path)!r})

# Use for classification
results = classifier.classify_document(
    document_text="Your legal document text here...",
    context_mode={str(config['context_mode'])!r}
)
'''
    
    print(production_code)
    
    # Save production instructions
    instructions_path = output_dir / "production_usage.py"
    with open(instructions_path, 'w') as f:
        f.write(production_code)
    
    print(f"\n📄 Production usage instructions saved to: {instructions_path}")
    
    # Model info
    model_info = {
        "model_type": config["model_type"],
        "model_name": config["model_name"],
        "context_mode": config["context_mode"],
        "training_config": config,
        "model_path": str(best_model_path),
        "evaluation_metrics": metrics if 'metrics' in locals() else None,
        "timestamp": timestamp
    }
    
    info_path = output_dir / "model_info.json"
    with open(info_path, 'w') as f:
        json.dump(model_info, f, indent=2, default=str)
    
    print(f"📋 Model information saved to: {info_path}")
else:
    print("❌ No trained model found")
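Because `model_info.json` is written with `default=str`, values such as `Path` objects come back as plain strings when the file is read. The sketch below shows a round trip and a small sanity check before the file is used in production. `load_model_info` and the required-key set are hypothetical, and the example values are placeholders rather than real training output.

```python
import json
import tempfile
from pathlib import Path

REQUIRED_KEYS = {"model_type", "context_mode", "model_path"}  # assumed minimum


def load_model_info(path):
    """Read model_info.json and check that the expected keys are present."""
    info = json.loads(Path(path).read_text(encoding="utf-8"))
    missing = REQUIRED_KEYS - info.keys()
    if missing:
        raise KeyError(f"model_info is missing keys: {sorted(missing)}")
    return info


with tempfile.TemporaryDirectory() as tmp:
    example_path = Path(tmp) / "model_info.json"
    example_info = {
        "model_type": "example_type",               # placeholder value
        "context_mode": "single",
        "model_path": Path(tmp) / "best_model.pt",  # Path becomes str on dump
    }
    example_path.write_text(
        json.dumps(example_info, indent=2, default=str), encoding="utf-8"
    )
    loaded = load_model_info(example_path)

print(type(loaded["model_path"]).__name__, loaded["context_mode"])
```

Code that consumes the file should wrap `model_path` in `Path(...)` again if it needs path operations.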

## Summary and Next Steps

In [None]:
# Training summary
print("🎊 TRAINING SUMMARY")
print("=" * 60)

if 'metrics' in locals():
    print("✅ Training completed successfully!")
    print(f"📊 Final Test Accuracy: {metrics['accuracy']:.4f}")
    print(f"📊 Final Test F1 Score: {metrics['weighted_f1']:.4f}")
    print(f"📂 Model saved at: {best_model_path}")
    print(f"📂 Results saved at: {output_dir}")
else:
    print("⚠️  Training may not have completed successfully.")
    print("Please check the error messages above.")

print("\n🚀 NEXT STEPS:")
print("1. 📋 Review the evaluation results and confusion matrix")
print("2. 🔧 Integrate the trained model into your Nyaya system")
print("3. 🧪 Test with real legal documents")
print("4. 📈 Consider further fine-tuning if needed")
print("5. 🔄 Update the role_classifier.py to use your trained weights")

print("\n📚 FILES GENERATED:")
if output_dir.exists():
    generated_files = list(output_dir.rglob("*"))
    for file_path in generated_files:
        if file_path.is_file():
            print(f"  📄 {file_path.relative_to(output_dir)}")
else:
    print("  ❌ No output directory found")

## Optional: Hyperparameter Tuning

If you want to experiment with different hyperparameters, you can modify the configuration and re-run the training cells above. Consider trying:

- Different context modes: `"single"`, `"prev_two"`, `"surrounding"`
- Different learning rates: `1e-5`, `3e-5`, `5e-5`
- Different batch sizes: `8`, `32` (depending on your GPU memory)
- More epochs for better convergence
- Different model types: `"bilstm_crf"` for sequence modeling
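The options listed above can be enumerated systematically rather than edited by hand. The sketch below builds every combination of context mode, learning rate, and batch size as a separate config dictionary. `iter_configs` and the `base_config` contents are hypothetical, and each candidate would still need to be passed through the training cells.

```python
from itertools import product

# Values taken from the list above; extend or trim as needed.
search_space = {
    "context_mode": ["single", "prev_two", "surrounding"],
    "learning_rate": [1e-5, 3e-5, 5e-5],
    "batch_size": [8, 32],
}


def iter_configs(base_config, space):
    """Yield one config per combination, leaving base_config unmodified."""
    keys = list(space)
    for values in product(*(space[key] for key in keys)):
        yield {**base_config, **dict(zip(keys, values))}


base_config = {"num_epochs": 3}  # placeholder for the notebook's full config
candidates = list(iter_configs(base_config, search_space))
print(f"{len(candidates)} configurations to try")
print(candidates[0])
```

A full grid grows quickly (3 × 3 × 2 runs here), so on limited hardware it may be more practical to vary one setting at a time or to combine the grid with the data subset option.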

## Troubleshooting

1. **CUDA Out of Memory**: Reduce `batch_size` or `max_length`, or apply the memory optimizations described above
2. **Low Accuracy**: Try more epochs, different context modes, or data augmentation
3. **Import Errors**: Check file paths and ensure all dependencies are installed
4. **Data Format Issues**: Ensure each line follows the tab-separated `sentence\trole` format, with a blank line between documents and a consistent file encoding (for example UTF-8)
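For data format issues in particular, a quick validation pass can locate the offending lines before training starts. The sketch below parses text in the tab-separated format described in this notebook (one `sentence\trole` pair per line, blank line between documents) and collects malformed lines instead of failing on the first one. `parse_role_file` is a hypothetical helper, and the role set is copied from the Supported Roles list.

```python
VALID_ROLES = {
    "Facts", "Issue", "Arguments of Petitioner", "Arguments of Respondent",
    "Reasoning", "Decision", "None",
}


def parse_role_file(text):
    """Split tab-separated text into documents and collect malformed lines."""
    documents, current, errors = [], [], []
    for line_no, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            # A blank line closes the current document.
            if current:
                documents.append(current)
                current = []
            continue
        parts = line.split("\t")
        if len(parts) != 2 or not parts[0].strip():
            errors.append((line_no, "expected 'sentence<TAB>role'"))
            continue
        sentence, role = parts[0].strip(), parts[1].strip()
        if role not in VALID_ROLES:
            errors.append((line_no, f"unknown role: {role!r}"))
            continue
        current.append((sentence, role))
    if current:
        documents.append(current)
    return documents, errors


sample = (
    "The petitioner filed a writ petition.\tFacts\n"
    "The appeal is dismissed.\tDecision\n"
    "\n"
    "A line without a tab\n"
    "The court considered the record.\tReasoning\n"
    "Some sentence.\tVerdict\n"
)
documents, errors = parse_role_file(sample)
print(f"{len(documents)} documents, {len(errors)} problem lines")
for line_no, message in errors:
    print(f"  line {line_no}: {message}")
```

To check a real file, read it with an explicit encoding, for example `Path(path).read_text(encoding="utf-8")`, and pass the result to the function.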