# InLegalBERT Role Classifier Training

This notebook trains a rhetorical role classifier for Indian legal judgments using **InLegalBERT** as the base model.

## Model Information
- **Base Model:** `law-ai/InLegalBERT` (pre-trained on 5.4M Indian legal documents)
- **Task:** Multi-class sentence classification into 7 rhetorical roles
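- **Dataset:** `train_final/` directory for training, plus `val/` and `test/` directories (under the same dataset root) for validation and final evaluation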

## The 7 Rhetorical Roles
1. **Facts** - Background and case events
2. **Issue** - Legal questions to resolve
3. **Arguments of Petitioner (AoP)** - Petitioner's claims
4. **Arguments of Respondent (AoR)** - Respondent's counter-arguments
5. **Reasoning** - Court's legal analysis
6. **Decision** - Final judgment
7. **None** - Other content

## 1. Environment Setup

In [None]:
# Install required packages
!pip install transformers>=4.35.0 torch>=2.0.0 datasets scikit-learn pandas numpy tqdm matplotlib seaborn

In [None]:
# Import libraries
import os
import json
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback,
    DataCollatorWithPadding
)
from datasets import Dataset, DatasetDict
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score
)

# Pin every random number generator in play (PyTorch, NumPy, Python) to the
# same fixed seed so repeated runs start from the same state.
import random
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Report the runtime environment; GPU details are shown only when CUDA is usable.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration

In [None]:
# Configuration
class Config:
    """Single home for the paths, label set and hyperparameters used in this notebook."""

    # --- Filesystem layout ---------------------------------------------
    DATASET_ROOT = Path("/root/dataset")
    TRAIN_DIR = DATASET_ROOT.joinpath("train_final")
    VAL_DIR = DATASET_ROOT.joinpath("val")
    TEST_DIR = DATASET_ROOT.joinpath("test")
    OUTPUT_DIR = Path("./models/inlegalbert_classifier")
    LOGS_DIR = Path("./logs")

    # --- Base model -----------------------------------------------------
    MODEL_NAME = "law-ai/InLegalBERT"
    MAX_LENGTH = 256  # tokens; longer inputs are truncated by the tokenizer call

    # --- Rhetorical roles (list position doubles as the class id) ------
    ROLES = [
        "Facts", "Issue",
        "Arguments of Petitioner", "Arguments of Respondent",
        "Reasoning", "Decision",
        "None",
    ]

    # --- Optimisation ---------------------------------------------------
    BATCH_SIZE = 16
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 10
    WEIGHT_DECAY = 0.01
    WARMUP_STEPS = 500
    GRADIENT_ACCUMULATION_STEPS = 2

    # --- Early stopping -------------------------------------------------
    EARLY_STOPPING_PATIENCE = 3
    EARLY_STOPPING_THRESHOLD = 0.01

    # --- Evaluation / checkpoint / logging cadence (in steps) ----------
    EVAL_STEPS = 500
    SAVE_STEPS = 500
    LOGGING_STEPS = 100

    # --- Compute device: GPU when CUDA is usable, otherwise CPU --------
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"

config = Config()

# Make sure the model and log folders exist before anything is written to them
for _required_dir in (config.OUTPUT_DIR, config.LOGS_DIR):
    _required_dir.mkdir(parents=True, exist_ok=True)

# Echo the key settings for this run
print(f"Training on: {config.DEVICE}")
print(f"Output directory: {config.OUTPUT_DIR}")
print(f"Number of roles: {len(config.ROLES)}")

## 3. Data Loading

The dataset format is:
```
Sentence\tRole\tConfidence
Sentence\tRole\tConfidence

Sentence\tRole\tConfidence
...
```
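Blank lines separate different documents. The loader below skips them and treats every sentence as an independent sample, so document boundaries are not preserved.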

In [None]:
def load_dataset_from_files(data_dir: Path) -> List[Dict[str, str]]:
    """
    Load labelled sentences from every ``*.txt`` file directly inside ``data_dir``.

    Each non-blank line is expected to hold tab-separated fields: the sentence,
    its role and (optionally) a confidence value, which is ignored. Blank lines
    are skipped. A row is kept only when the sentence is non-empty and the role
    is one of ``config.ROLES``; every other non-blank row is counted and
    reported instead of being dropped silently.

    Args:
        data_dir: Directory containing the split's ``*.txt`` files.

    Returns:
        A list of ``{'text': sentence, 'label': role}`` dicts. The list is empty
        when the directory does not exist or holds no usable rows.
    """
    samples: List[Dict[str, str]] = []

    if not data_dir.exists():
        print(f"Warning: Directory {data_dir} does not exist!")
        return samples

    # glob() yields files in a filesystem-dependent order; sorting makes the
    # sample order (and therefore seeded shuffling downstream) reproducible.
    txt_files = sorted(data_dir.glob("*.txt"))
    print(f"Found {len(txt_files)} files in {data_dir.name}")

    valid_roles = set(config.ROLES)  # O(1) membership test per row
    skipped_rows = 0

    for file_path in tqdm(txt_files, desc=f"Loading {data_dir.name}"):
        # Collect per file first so an unreadable file contributes nothing,
        # rather than a partial set of rows.
        file_samples: List[Dict[str, str]] = []
        file_skipped = 0
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for raw_line in f:
                    line = raw_line.strip()
                    if not line:  # Blank lines only separate documents
                        continue

                    parts = line.split('\t')
                    sentence = parts[0].strip()
                    role = parts[1].strip() if len(parts) >= 2 else ""

                    if sentence and role in valid_roles:
                        file_samples.append({'text': sentence, 'label': role})
                    else:
                        file_skipped += 1
        except (OSError, UnicodeError) as e:
            # Best effort: report the unreadable file and carry on with the rest.
            print(f"Error loading {file_path}: {e}")
            continue

        samples.extend(file_samples)
        skipped_rows += file_skipped

    if skipped_rows:
        print(
            f"  Skipped {skipped_rows:,} line(s) in {data_dir.name} "
            f"with a missing sentence or an unknown role"
        )

    return samples

# Read the three splits from disk
print("Loading training data...")
train_samples = load_dataset_from_files(config.TRAIN_DIR)

print("\nLoading validation data...")
val_samples = load_dataset_from_files(config.VAL_DIR)

print("\nLoading test data...")
test_samples = load_dataset_from_files(config.TEST_DIR)

# Summarise how many samples each split contributed
_rule = '=' * 50
print(f"\n{_rule}")
print("Dataset Statistics:")
for _split_name, _split_samples in (
    ("Training", train_samples),
    ("Validation", val_samples),
    ("Test", test_samples),
):
    print(f"  {_split_name} samples: {len(_split_samples):,}")
print(_rule)

## 4. Data Analysis

In [None]:
def analyze_dataset(samples: List[Dict], split_name: str):
    """
    Print label and sentence-length statistics for one split and plot them.

    The role histogram and the word-count histogram are drawn side by side,
    saved to ``config.LOGS_DIR`` as ``<split_name lowercased>_analysis.png``
    and shown. Returns the samples as a DataFrame with an added
    ``text_length`` column (whitespace-separated word count).
    """
    frame = pd.DataFrame(samples)
    total_rows = len(frame)

    print(f"\n{split_name} Dataset Analysis:")
    print("="*60)

    # How often each role occurs, most frequent first
    label_freq = frame['label'].value_counts()
    print("\nRole Distribution:")
    for role_name, n_rows in label_freq.items():
        share = (n_rows / total_rows) * 100
        print(f"  {role_name:30s}: {n_rows:6,} ({share:5.2f}%)")

    # Sentence length measured in whitespace-separated words
    frame['text_length'] = frame['text'].apply(lambda sentence: len(sentence.split()))
    word_counts = frame['text_length']
    print("\nText Length Statistics (words):")
    print(f"  Mean: {word_counts.mean():.2f}")
    print(f"  Median: {word_counts.median():.2f}")
    print(f"  Min: {word_counts.min()}")
    print(f"  Max: {word_counts.max()}")

    # Two panels: role counts on the left, length histogram on the right
    fig, (role_ax, length_ax) = plt.subplots(1, 2, figsize=(15, 5))

    label_freq.plot(kind='bar', ax=role_ax, color='steelblue')
    role_ax.set_title(f'{split_name} - Role Distribution')
    role_ax.set_xlabel('Role')
    role_ax.set_ylabel('Count')
    role_ax.tick_params(axis='x', rotation=45)

    length_ax.hist(word_counts, bins=50, color='coral', edgecolor='black')
    length_ax.set_title(f'{split_name} - Text Length Distribution')
    length_ax.set_xlabel('Number of Words')
    length_ax.set_ylabel('Frequency')
    length_ax.axvline(word_counts.mean(), color='red', linestyle='--', label='Mean')
    length_ax.legend()

    plt.tight_layout()
    figure_path = config.LOGS_DIR / f'{split_name.lower()}_analysis.png'
    plt.savefig(figure_path, dpi=150, bbox_inches='tight')
    plt.show()

    return frame

# Analyze datasets
# Each split is analysed only when it loaded at least one sample. As a
# consequence train_df / val_df / test_df are left undefined for a split
# that is empty or whose directory is missing.
if train_samples:
    train_df = analyze_dataset(train_samples, "Training")
if val_samples:
    val_df = analyze_dataset(val_samples, "Validation")
if test_samples:
    test_df = analyze_dataset(test_samples, "Test")

## 5. Create Label Mappings

In [None]:
# Give each role an integer class id (its position in config.ROLES) and build
# the reverse lookup from that mapping.
label2id = {role: position for position, role in enumerate(config.ROLES)}
id2label = {position: role for role, position in label2id.items()}

print("Label Mappings:")
for role, position in label2id.items():
    print(f"  {position}: {role}")

# Persist both directions next to the model. JSON object keys are always
# strings, so the integer ids in id2label come back as strings when reloaded.
_mappings_json = json.dumps(
    {'label2id': label2id, 'id2label': id2label},
    indent=2,
)
with open(config.OUTPUT_DIR / "label_mappings.json", 'w') as f:
    f.write(_mappings_json)

print(f"\nLabel mappings saved to {config.OUTPUT_DIR / 'label_mappings.json'}")

## 6. Prepare Datasets for Training

In [None]:
def prepare_dataset(samples: List[Dict]) -> Dataset:
    """
    Build a HuggingFace ``Dataset`` with ``text`` and integer ``label`` columns.

    Role names are translated to class ids through the module-level
    ``label2id`` mapping; a role missing from that mapping raises ``KeyError``.
    """
    texts = []
    label_ids = []
    for sample in samples:
        texts.append(sample['text'])
        label_ids.append(label2id[sample['label']])

    columns = {'text': texts, 'label': label_ids}
    return Dataset.from_dict(columns)

# Build a Dataset per split; a split that loaded no samples stays None
train_dataset = None
val_dataset = None
test_dataset = None
if train_samples:
    train_dataset = prepare_dataset(train_samples)
if val_samples:
    val_dataset = prepare_dataset(val_samples)
if test_samples:
    test_dataset = prepare_dataset(test_samples)

print("Dataset Preparation Complete:")
if train_dataset:
    print(f"  Training: {len(train_dataset)} samples")
if val_dataset:
    print(f"  Validation: {len(val_dataset)} samples")
if test_dataset:
    print(f"  Test: {len(test_dataset)} samples")

# Peek at the first training row as a sanity check
if train_dataset:
    _first_row = train_dataset[0]
    print("\nExample training sample:")
    print(f"  Text: {_first_row['text'][:100]}...")
    print(f"  Label ID: {_first_row['label']}")
    print(f"  Label Name: {id2label[_first_row['label']]}")

## 7. Load Tokenizer and Model

In [None]:
# Tokenizer that matches the pretrained checkpoint
print(f"Loading tokenizer: {config.MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)

# Pretrained encoder with a classification head sized to the role set; the
# label mappings are passed in so they are stored in the model config.
print(f"Loading model: {config.MODEL_NAME}")
_head_settings = dict(
    num_labels=len(config.ROLES),
    label2id=label2id,
    id2label=id2label,
    problem_type="single_label_classification",
)
model = AutoModelForSequenceClassification.from_pretrained(config.MODEL_NAME, **_head_settings)

# Place the weights on the configured device
model.to(config.DEVICE)

_all_params = list(model.parameters())
print(f"\nModel loaded successfully!")
print(f"  Parameters: {sum(param.numel() for param in _all_params):,}")
print(f"  Trainable parameters: {sum(param.numel() for param in _all_params if param.requires_grad):,}")
print(f"  Device: {_all_params[0].device}")

## 8. Tokenization

In [None]:
def tokenize_function(examples):
    """
    Tokenize the ``text`` column of a batch with the module-level tokenizer.

    Sequences are truncated to ``config.MAX_LENGTH`` and left unpadded;
    padding is applied per batch later by the data collator.
    """
    encoded = tokenizer(
        examples['text'],
        truncation=True,
        max_length=config.MAX_LENGTH,
        padding=False,  # the collator pads each batch to its own longest item
        return_tensors=None,
    )
    return encoded

# Tokenize datasets
def _tokenize_split(split, description):
    """Tokenize one split in batches; an absent or empty split is returned unchanged."""
    if not split:
        return split
    return split.map(tokenize_function, batched=True, desc=description)


print("Tokenizing datasets...")

train_dataset = _tokenize_split(train_dataset, "Tokenizing training data")
val_dataset = _tokenize_split(val_dataset, "Tokenizing validation data")
test_dataset = _tokenize_split(test_dataset, "Tokenizing test data")

print("Tokenization complete!")

# Pads every batch to the length of its longest sequence at collation time
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

## 9. Evaluation Metrics

In [None]:
def compute_metrics(eval_pred):
    """
    Turn raw evaluation output into the metrics tracked during training.

    Args:
        eval_pred: A ``(logits, labels)`` pair; the predicted class for each
            row is the argmax over the last axis of ``logits``.

    Returns:
        Dict with ``accuracy``, ``f1_macro``, ``f1_weighted``,
        ``precision_macro`` and ``recall_macro``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    # zero_division=0 is applied to every precision/recall/F1 variant so that a
    # class without predictions (or without true samples) scores 0 consistently
    # instead of emitting an undefined-metric warning on each evaluation.
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1_macro': f1_score(labels, predictions, average='macro', zero_division=0),
        'f1_weighted': f1_score(labels, predictions, average='weighted', zero_division=0),
        'precision_macro': precision_score(labels, predictions, average='macro', zero_division=0),
        'recall_macro': recall_score(labels, predictions, average='macro', zero_division=0),
    }

## 10. Training Configuration

In [None]:
# Training arguments
# All tunable values come from `config`; checkpoints are written under
# config.OUTPUT_DIR and logs under config.LOGS_DIR.
training_args = TrainingArguments(
    output_dir=str(config.OUTPUT_DIR),
    
    # Training hyperparameters
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE * 2,  # evaluation uses twice the training batch size
    learning_rate=config.LEARNING_RATE,
    weight_decay=config.WEIGHT_DECAY,
    warmup_steps=config.WARMUP_STEPS,
    gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
    
    # Evaluation
    # Evaluation and checkpointing both run on a step schedule; the best
    # checkpoint is chosen by the `f1_macro` value returned by compute_metrics.
    # NOTE(review): `eval_strategy` is the newer keyword name (older transformers
    # releases call it `evaluation_strategy`) - confirm the installed version.
    eval_strategy="steps",
    eval_steps=config.EVAL_STEPS,
    save_strategy="steps",
    save_steps=config.SAVE_STEPS,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    
    # Logging
    logging_dir=str(config.LOGS_DIR),
    logging_steps=config.LOGGING_STEPS,
    report_to=["tensorboard"],
    
    # Performance
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA device is present
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    
    # Misc
    seed=42,
    remove_unused_columns=True,
    push_to_hub=False,
)

# Echo the effective settings.
# NOTE(review): "Effective batch size" is batch size x accumulation steps and
# does not factor in the number of devices - confirm for multi-GPU runs.
print("Training Configuration:")
print(f"  Total epochs: {config.NUM_EPOCHS}")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Learning rate: {config.LEARNING_RATE}")
print(f"  Gradient accumulation steps: {config.GRADIENT_ACCUMULATION_STEPS}")
print(f"  Effective batch size: {config.BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS}")
print(f"  FP16 training: {training_args.fp16}")
print(f"  Warmup steps: {config.WARMUP_STEPS}")
print(f"  Eval steps: {config.EVAL_STEPS}")

## 11. Initialize Trainer

In [None]:
# Initialize trainer
import inspect

# Recent transformers releases take the tokenizer through `processing_class`
# and deprecate the older `tokenizer` keyword of Trainer.__init__. Use
# whichever name the installed Trainer accepts so the cell works on both.
_tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[
        # Stop once the tracked metric has failed to improve by at least the
        # threshold for `patience` consecutive evaluations.
        EarlyStoppingCallback(
            early_stopping_patience=config.EARLY_STOPPING_PATIENCE,
            early_stopping_threshold=config.EARLY_STOPPING_THRESHOLD
        )
    ],
    **{_tokenizer_kwarg: tokenizer},
)

print("Trainer initialized successfully!")

# Rough step budget: samples per optimizer step is batch size x accumulation
# steps. This is an estimate only (single device, partial batches rounded down).
if train_dataset:
    total_steps = (
        len(train_dataset) //
        (config.BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS) *
        config.NUM_EPOCHS
    )
    print(f"\nEstimated training steps: {total_steps:,}")
    print(f"Evaluation every {config.EVAL_STEPS} steps")
    print(f"Total evaluations: ~{total_steps // config.EVAL_STEPS}")

## 12. Train Model

In [None]:
# Run fine-tuning and report the metrics returned by the trainer
_rule = "=" * 70

print("Starting training...\n")
print(_rule)

train_result = trainer.train()

print("\n" + _rule)
print("Training complete!\n")

# One line per metric from the training result
print("Training Summary:")
for metric_name, metric_value in train_result.metrics.items():
    print(f"  {metric_name}: {metric_value}")

## 13. Save Model and Tokenizer

In [None]:
# Write the model, tokenizer and a small run description to OUTPUT_DIR
print(f"\nSaving model to {config.OUTPUT_DIR}...")

_output_dir = str(config.OUTPUT_DIR)
trainer.save_model(_output_dir)
tokenizer.save_pretrained(_output_dir)

# Record the settings needed to reproduce or reload this run
_run_description = {
    'model_name': config.MODEL_NAME,
    'num_epochs': config.NUM_EPOCHS,
    'batch_size': config.BATCH_SIZE,
    'learning_rate': config.LEARNING_RATE,
    'max_length': config.MAX_LENGTH,
    'num_labels': len(config.ROLES),
    'roles': config.ROLES,
}
with open(config.OUTPUT_DIR / "training_config.json", 'w') as f:
    json.dump(_run_description, f, indent=2)

print("Model saved successfully!")
print(f"\nModel files:")
_saved_entries = sorted(config.OUTPUT_DIR.glob("*"))
for file in _saved_entries:
    print(f"  - {file.name}")

## 14. Evaluate on Test Set

In [None]:
if not test_dataset:
    print("No test dataset available for evaluation.")
else:
    print("Evaluating on test set...\n")

    # Score the held-out split with the same metrics used during training
    test_results = trainer.evaluate(test_dataset)

    print("Test Set Results:")
    print("=" * 50)
    for metric_name, metric_value in test_results.items():
        print(f"  {metric_name}: {metric_value:.4f}")

    # Keep the raw numbers alongside the model
    with open(config.OUTPUT_DIR / "test_results.json", 'w') as f:
        json.dump(test_results, f, indent=2)

## 15. Detailed Classification Report

In [None]:
if test_dataset:
    print("Generating detailed classification report...\n")

    # Get predictions
    predictions = trainer.predict(test_dataset)
    pred_labels = np.argmax(predictions.predictions, axis=-1)
    true_labels = predictions.label_ids

    # Always report on the full set of class ids. Without an explicit `labels`
    # list scikit-learn only considers classes that occur in the data, so a role
    # missing from the test split would make `target_names` mismatch (an error)
    # and shrink the confusion matrix so the tick labels no longer line up.
    class_ids = list(range(len(config.ROLES)))

    # Classification report
    print("\nClassification Report:")
    print("=" * 80)
    report = classification_report(
        true_labels,
        pred_labels,
        labels=class_ids,
        target_names=config.ROLES,
        digits=4,
        zero_division=0
    )
    print(report)

    # Save classification report
    with open(config.OUTPUT_DIR / "classification_report.txt", 'w') as f:
        f.write(report)

    # Confusion matrix (rows: true role, columns: predicted role)
    cm = confusion_matrix(true_labels, pred_labels, labels=class_ids)

    # Plot confusion matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=config.ROLES,
        yticklabels=config.ROLES,
        cbar_kws={'label': 'Count'}
    )
    plt.title('Confusion Matrix - Test Set', fontsize=16, pad=20)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(config.OUTPUT_DIR / 'confusion_matrix.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\nConfusion matrix saved to {config.OUTPUT_DIR / 'confusion_matrix.png'}")

## 16. Test Inference

In [None]:
# Run the fine-tuned model on a few hand-written sentences as a smoke test
print("Testing inference on sample sentences...\n")

test_sentences = [
    "The petitioner filed a writ petition challenging the constitutional validity of the Act.",
    "The main issue in this case is whether the amendment violates Article 14 of the Constitution.",
    "The learned counsel for the petitioner submitted that the impugned order is arbitrary.",
    "The respondent contends that the petition is not maintainable in law.",
    "After careful consideration of the submissions made by both parties, we are of the view that the law is well settled.",
    "The writ petition is hereby dismissed with costs.",
    "The court was informed about the procedural requirements."
]

# Wrap the in-memory model and tokenizer in a text-classification pipeline
# (nothing is reloaded from disk here)
from transformers import pipeline

if torch.cuda.is_available():
    _pipeline_device = 0   # first GPU
else:
    _pipeline_device = -1  # CPU

classifier = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    device=_pipeline_device,
)

print("Sample Predictions:")
print("=" * 100)

for sentence in test_sentences:
    # The pipeline returns a list of predictions; take the first entry
    top_prediction = classifier(sentence)[0]

    print(f"\nSentence: {sentence}")
    print(f"Predicted: {top_prediction['label']} (confidence: {top_prediction['score']:.4f})")
    print("-" * 100)

## 17. Training Visualization

In [None]:
# Plot training history
import json
from pathlib import Path

# Locate the saved trainer state. The Trainer writes trainer_state.json into its
# checkpoint folders under OUTPUT_DIR; LOGS_DIR is only the logging directory,
# so searching there never finds the file.
state_files = list(config.OUTPUT_DIR.glob("**/trainer_state.json"))

if state_files:
    # Several checkpoints may be present; keep the snapshot that has seen the
    # most training steps so the curves cover as much of the run as possible.
    trainer_state = None
    for state_path in state_files:
        with open(state_path, 'r') as f:
            candidate_state = json.load(f)
        if (
            trainer_state is None
            or candidate_state.get('global_step', 0) > trainer_state.get('global_step', 0)
        ):
            trainer_state = candidate_state

    log_history = trainer_state['log_history']

    # Extract metrics: training-loss entries carry 'loss', evaluation entries
    # carry 'eval_loss' (and the macro F1 computed by compute_metrics).
    train_loss = []
    eval_loss = []
    eval_f1 = []
    steps = []
    eval_steps = []

    for entry in log_history:
        if 'loss' in entry:
            train_loss.append(entry['loss'])
            steps.append(entry['step'])
        if 'eval_loss' in entry:
            eval_loss.append(entry['eval_loss'])
            eval_f1.append(entry.get('eval_f1_macro', 0))
            eval_steps.append(entry['step'])

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Loss plot
    axes[0].plot(steps, train_loss, label='Training Loss', color='blue', linewidth=2)
    axes[0].plot(eval_steps, eval_loss, label='Validation Loss', color='orange', linewidth=2)
    axes[0].set_xlabel('Training Steps')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # F1 score plot
    axes[1].plot(eval_steps, eval_f1, label='Validation F1 (Macro)', color='green', linewidth=2)
    axes[1].set_xlabel('Training Steps')
    axes[1].set_ylabel('F1 Score')
    axes[1].set_title('Validation F1 Score')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(config.OUTPUT_DIR / 'training_history.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"Training history plot saved to {config.OUTPUT_DIR / 'training_history.png'}")
else:
    print("Training log not found.")

## 18. Model Summary and Export Information

In [None]:
# Final recap: where the model lives, how it scored, how to reload it, and
# which artifacts were actually written.
print("\n" + "=" * 80)
print("TRAINING COMPLETE - MODEL SUMMARY")
print("=" * 80)

print(f"\n📁 Model Location: {config.OUTPUT_DIR}")
print(f"\n📊 Performance Metrics:")
if test_dataset:
    print(f"  - Test Accuracy: {test_results.get('eval_accuracy', 0):.4f}")
    print(f"  - Test F1 (Macro): {test_results.get('eval_f1_macro', 0):.4f}")
    print(f"  - Test F1 (Weighted): {test_results.get('eval_f1_weighted', 0):.4f}")

print(f"\n🔧 Model Configuration:")
print(f"  - Base Model: {config.MODEL_NAME}")
print(f"  - Number of Labels: {len(config.ROLES)}")
print(f"  - Max Sequence Length: {config.MAX_LENGTH}")
print(f"  - Total Parameters: {sum(p.numel() for p in model.parameters()):,}")

print(f"\n📝 Usage Instructions:")
print(f"""\nTo load this model for inference:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Load model and tokenizer
model_path = "{config.OUTPUT_DIR}"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# Inference
text = "Your legal sentence here"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length={config.MAX_LENGTH})
outputs = model(**inputs)
predicted_class = torch.argmax(outputs.logits, dim=-1).item()

# Load label mappings
import json
with open("{config.OUTPUT_DIR}/label_mappings.json", 'r') as f:
    mappings = json.load(f)
    id2label = mappings['id2label']

print(f"Predicted role: {{id2label[str(predicted_class)]}}")
```
""")

# The weights file name depends on the serialization format used when saving
# (safetensors vs. the older PyTorch .bin), so report whichever is on disk.
_weights_file = next(
    (
        name
        for name in ("model.safetensors", "pytorch_model.bin")
        if (config.OUTPUT_DIR / name).exists()
    ),
    None,
)

# Artifact label -> file names expected in OUTPUT_DIR. Only entries whose files
# all exist are listed, so optional outputs (test-set results, plots) are not
# advertised when the corresponding cell did not produce them.
_artifacts = [
    ("Model weights", [_weights_file] if _weights_file else []),
    ("Tokenizer", ["tokenizer_config.json", "vocab.txt"]),
    ("Label mappings", ["label_mappings.json"]),
    ("Training config", ["training_config.json"]),
    ("Test results", ["test_results.json"]),
    ("Classification report", ["classification_report.txt"]),
    ("Confusion matrix", ["confusion_matrix.png"]),
    ("Training history", ["training_history.png"]),
]

print(f"\n✅ All outputs saved to: {config.OUTPUT_DIR}")
for _artifact_label, _artifact_files in _artifacts:
    if _artifact_files and all((config.OUTPUT_DIR / name).exists() for name in _artifact_files):
        print(f"   - {_artifact_label}: {', '.join(_artifact_files)}")

print("\n" + "=" * 80)
print("🎉 Training pipeline completed successfully!")
print("=" * 80)