# Legal Document Rhetorical Role Classifier Training

This notebook trains a rhetorical role classifier for legal documents using InLegalBERT.

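**Dataset Structure:**
- `/kaggle/input/dataset/Hier_BiLSTM_CRF/train/` - Training data
- `/kaggle/input/dataset/Hier_BiLSTM_CRF/val/` - Validation data (the configuration reads these files from the nested `val/val/` directory)
- `/kaggle/input/dataset/Hier_BiLSTM_CRF/test/` - Test data

Each split directory holds `.txt` files, one document per file. Each line is a `sentence<TAB>role` pair; lines whose role is not one of the configured roles are skipped.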

**Model Output:**
- Saved to `/kaggle/working/`

## 1. Install Dependencies and Setup

In [None]:
# Install required packages
# Uninstall existing transformers to avoid conflicts
!pip uninstall -y transformers

# Install compatible versions
# transformers is pinned to 4.35.2 below.
!pip install -q transformers==4.35.2
# NOTE(review): torch/torchvision/torchaudio are NOT version-pinned, so whatever the newest
# CUDA 11.8 build on this index is gets installed. Confirm that version works with the
# pinned transformers release and with the GPU driver on the Kaggle image.
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q scikit-learn pandas numpy matplotlib seaborn tqdm
!pip install -q accelerate sentencepiece protobuf

In [None]:
# Import libraries
import os
import gc
import random
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.metrics import classification_report, confusion_matrix, f1_score

# Seed every random number generator the notebook uses (Python, NumPy, torch CPU and
# all CUDA devices) so that shuffling, sampling and weight init are repeatable.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# INFO-level root logging so the dataset's logger.info(...) progress messages show up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the first GPU when CUDA is available, otherwise run on CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"Memory: {_total_bytes / 1e9:.2f} GB")

## 2. Configuration

In [None]:
# Training configuration
class Config:
    """Notebook-wide settings: data/output paths, model choice, training
    hyper-parameters and the rhetorical-role label set.

    Everything is a class attribute, so ``Config.X`` and ``Config().X`` both work.
    """

    # ---- Paths -------------------------------------------------------------
    DATA_PATH = "/kaggle/input/dataset/Hier_BiLSTM_CRF"
    TRAIN_PATH = os.path.join(DATA_PATH, "train")
    # NOTE(review): validation files are read from a nested "val/val" directory while
    # train/test are not nested -- confirm this matches the uploaded dataset layout.
    VAL_PATH = os.path.join(DATA_PATH, "val/val")
    TEST_PATH = os.path.join(DATA_PATH, "test")
    OUTPUT_PATH = "/kaggle/working"

    # ---- Model -------------------------------------------------------------
    MODEL_NAME = "law-ai/InLegalBERT"
    MAX_LENGTH = 256  # tokens per sample; kept small to save GPU memory
    # How many neighbouring sentences are joined into each sample's text:
    # "single", "prev", "prev_two" or "surrounding".
    CONTEXT_MODE = "prev"

    # ---- Optimisation schedule --------------------------------------------
    BATCH_SIZE = 16  # tune to the available GPU memory
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 3
    WARMUP_STEPS = 500
    WEIGHT_DECAY = 0.01
    MAX_GRAD_NORM = 1.0

    # ---- Fraction of files used per split (1.0 = everything) ---------------
    TRAIN_SAMPLE_RATIO = 1.0
    VAL_SAMPLE_RATIO = 1.0
    TEST_SAMPLE_RATIO = 1.0

    # ---- Memory / speed -----------------------------------------------------
    USE_AMP = True  # automatic mixed precision
    GRADIENT_ACCUMULATION_STEPS = 2

    # ---- Label set ------------------------------------------------------------
    # The loader keeps only lines whose role string exactly matches one of these.
    # NOTE(review): confirm these names match the label strings in the dataset files.
    ROLES = [
        "None",
        "Facts",
        "Issue",
        "Arguments of Petitioner",
        "Arguments of Respondent",
        "Reasoning",
        "Decision",
    ]
    NUM_LABELS = len(ROLES)
    ID2ROLE = dict(enumerate(ROLES))
    ROLE2ID = {role: idx for idx, role in ID2ROLE.items()}

config = Config()
print(f"Configuration loaded. Training on {config.CONTEXT_MODE} context mode.")

## 3. Dataset Class

In [None]:
class LegalDocumentDataset(Dataset):
    """Sentence-level dataset for rhetorical role classification.

    Every ``*.txt`` file in ``data_path`` is read as one document containing one
    ``sentence<TAB>role`` pair per line. Each kept sentence becomes one sample. Its
    text may include neighbouring sentences (see ``config.CONTEXT_MODE``) and its label
    is ``config.ROLE2ID[role]``. Tokenization is deferred to ``__getitem__``.

    Args:
        data_path: Directory holding the split's ``.txt`` files.
        tokenizer: Hugging Face-style tokenizer, called in ``__getitem__``.
        config: Object exposing the ``Config`` attributes used below.
        split: ``'train'``, ``'val'`` or ``'test'``. It selects the file sampling
            ratio and labels log/statistics output.
    """

    # (sentences before, sentences after) the current one, per CONTEXT_MODE.
    # Unknown modes fall back to the current sentence only.
    _CONTEXT_WINDOWS = {
        "single": (0, 0),
        "prev": (1, 0),
        "prev_two": (2, 0),
        "surrounding": (1, 1),
    }

    def __init__(self, data_path, tokenizer, config, split='train'):
        self.data_path = Path(data_path)
        self.tokenizer = tokenizer
        self.config = config
        self.split = split
        self.data = []

        # Load and process data
        self._load_data()

    def _sample_ratio(self):
        """Return the fraction of files to keep for this split (1.0 keeps all)."""
        ratios = {
            'train': self.config.TRAIN_SAMPLE_RATIO,
            'val': self.config.VAL_SAMPLE_RATIO,
            'test': self.config.TEST_SAMPLE_RATIO,
        }
        return ratios.get(self.split, 1.0)

    def _load_data(self):
        """Find, optionally subsample, and parse every document of the split."""
        logger.info(f"Loading {self.split} data from {self.data_path}")

        txt_files = sorted(self.data_path.glob("*.txt"))
        if not txt_files:
            # An empty split almost always means a wrong path, so report it loudly.
            logger.warning(f"No .txt files found in {self.data_path}")

        ratio = self._sample_ratio()
        if txt_files and ratio < 1.0:
            # Keep at least one file so a very small ratio never empties the split.
            num_files = max(1, int(len(txt_files) * ratio))
            txt_files = random.sample(txt_files, num_files)

        logger.info(f"Processing {len(txt_files)} files...")

        for file_path in tqdm(txt_files, desc=f"Loading {self.split}"):
            document = self._load_file(file_path)
            if document:
                self._process_document(document)

        logger.info(f"Loaded {len(self.data)} sentences from {self.split} set")
        self._print_statistics()

    def _load_file(self, file_path):
        """Parse one document file.

        Returns:
            A list of ``{'sentence': str, 'role': str}`` dicts. There is one per line
            that has exactly two tab-separated fields and a role in
            ``config.ROLE2ID``. Returns ``None`` if the file cannot be read or decoded
            as UTF-8.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
        except (OSError, UnicodeDecodeError) as e:
            logger.warning(f"Error loading {file_path}: {e}")
            return None

        document = []
        unknown_roles = Counter()
        for line in lines:
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue  # not a "sentence<TAB>role" line
            sentence, role = parts[0].strip(), parts[1].strip()
            if role in self.config.ROLE2ID:
                document.append({'sentence': sentence, 'role': role})
            else:
                unknown_roles[role] += 1

        if unknown_roles:
            # Report label mismatches rather than silently shrinking the dataset.
            logger.warning(
                f"{file_path}: skipped lines with unknown roles {dict(unknown_roles)}"
            )
        return document

    def _process_document(self, document):
        """Append one (context text, label id) sample per sentence of ``document``."""
        for idx, item in enumerate(document):
            self.data.append({
                'text': self._create_context(document, idx),
                'label': self.config.ROLE2ID[item['role']],
            })

    def _create_context(self, document, idx):
        """Join the current sentence with its neighbours, as set by CONTEXT_MODE.

        Sentences are joined with the literal text ``" [SEP] "``. NOTE(review): this
        assumes the tokenizer maps "[SEP]" to its separator token, as BERT-style
        vocabularies do.
        """
        before, after = self._CONTEXT_WINDOWS.get(self.config.CONTEXT_MODE, (0, 0))
        window = document[max(0, idx - before): idx + after + 1]
        return " [SEP] ".join(item['sentence'] for item in window)

    def _print_statistics(self):
        """Print the sample count and per-role label distribution of the split."""
        label_counts = Counter(item['label'] for item in self.data)

        print(f"\n{self.split.upper()} Dataset Statistics:")
        print(f"Total samples: {len(self.data)}")
        print("\nLabel distribution:")
        for label_id, count in sorted(label_counts.items()):
            role_name = self.config.ID2ROLE[label_id]
            percentage = (count / len(self.data)) * 100
            print(f"  {role_name}: {count} ({percentage:.2f}%)")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx``, padded/truncated to ``config.MAX_LENGTH``."""
        item = self.data[idx]

        encoding = self.tokenizer(
            item['text'],
            max_length=self.config.MAX_LENGTH,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer returns a batch of one; drop that leading dimension.
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(item['label'], dtype=torch.long)
        }

## 4. Model Definition

In [None]:
class InLegalBERTClassifier(nn.Module):
    """Pretrained transformer encoder with a dropout + linear classification head.

    Args:
        model_name: Hugging Face model id or local path, loaded with ``AutoModel``.
        num_labels: Number of output classes.
        dropout: Dropout probability applied to the pooled representation
            (default 0.1, the value previously hard-coded).
    """

    def __init__(self, model_name, num_labels, dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-softmax) class logits; callers take argmax over dim 1."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Use the encoder's pooler output when it has one. Encoders without a pooler
        # (attribute missing or None) fall back to the first token's final hidden state.
        pooled_output = getattr(outputs, 'pooler_output', None)
        if pooled_output is None:
            pooled_output = outputs.last_hidden_state[:, 0]

        return self.classifier(self.dropout(pooled_output))

## 5. Training Functions

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler, scaler, device, config):
    """Train ``model`` for one pass over ``dataloader`` with gradient accumulation.

    Gradients from ``config.GRADIENT_ACCUMULATION_STEPS`` consecutive batches are
    accumulated before each optimizer + scheduler step. If the number of batches is
    not a multiple of that size, the trailing shorter group also gets its own step
    instead of having its gradients discarded.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning logits.
        dataloader: Sized iterable of dicts with ``input_ids``, ``attention_mask``
            and ``labels`` tensors.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per optimizer update.
        scaler: ``GradScaler`` used when ``config.USE_AMP`` is true. If it is None,
            the non-AMP path is taken.
        device: Device the batch tensors are moved to.
        config: Provides ``USE_AMP``, ``GRADIENT_ACCUMULATION_STEPS`` and
            ``MAX_GRAD_NORM``.

    Returns:
        ``(mean per-batch loss, accuracy as a fraction in [0, 1])``.
    """
    from contextlib import nullcontext

    model.train()
    criterion = nn.CrossEntropyLoss()
    accum_steps = max(1, config.GRADIENT_ACCUMULATION_STEPS)
    num_batches = len(dataloader)
    use_amp = config.USE_AMP and scaler is not None

    total_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc="Training")
    optimizer.zero_grad()

    for step, batch in enumerate(progress_bar):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Divide by the real size of this batch's accumulation group, so the
        # accumulated gradient is a mean even for a short final group.
        group_start = step - step % accum_steps
        group_size = min(accum_steps, num_batches - group_start)

        with (autocast() if use_amp else nullcontext()):
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

        scaled_loss = loss / group_size
        if use_amp:
            scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        if (step + 1) % accum_steps == 0 or step + 1 == num_batches:
            if use_amp:
                scaler.unscale_(optimizer)  # clip the true, unscaled gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.MAX_GRAD_NORM)
            if use_amp:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Statistics use the unscaled loss so every batch is weighted equally.
        total_loss += loss.item()
        predictions = torch.argmax(logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        progress_bar.set_postfix({
            'loss': total_loss / (step + 1),
            'acc': 100. * correct / total
        })

    # An empty loader yields zeros instead of a ZeroDivisionError.
    mean_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy


def evaluate(model, dataloader, device, config):
    """Score ``model`` on ``dataloader`` without tracking gradients.

    Runs under autocast when ``config.USE_AMP`` is set.

    Returns:
        ``(avg_loss, accuracy, f1_macro, f1_weighted, predictions, labels)``.
        ``avg_loss`` is the mean of per-batch losses, and ``predictions``/``labels``
        are flat lists of per-sample class ids.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    running_loss = 0
    preds_out, gold_out = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            gold = batch['labels'].to(device)

            if not config.USE_AMP:
                logits = model(ids, mask)
                batch_loss = loss_fn(logits, gold)
            else:
                with autocast():
                    logits = model(ids, mask)
                    batch_loss = loss_fn(logits, gold)

            running_loss += batch_loss.item()
            preds_out.extend(logits.argmax(dim=1).cpu().numpy())
            gold_out.extend(gold.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    hit_rate = np.mean(np.array(preds_out) == np.array(gold_out))
    macro = f1_score(gold_out, preds_out, average='macro')
    weighted = f1_score(gold_out, preds_out, average='weighted')
    return mean_loss, hit_rate, macro, weighted, preds_out, gold_out

## 6. Data Loading

In [None]:
# One tokenizer, shared by all three splits.
print(f"Loading tokenizer: {config.MODEL_NAME}")
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)

# Each constructor reads its split from disk and prints label statistics.
train_dataset = LegalDocumentDataset(config.TRAIN_PATH, tokenizer, config, split='train')
val_dataset = LegalDocumentDataset(config.VAL_PATH, tokenizer, config, split='val')
test_dataset = LegalDocumentDataset(config.TEST_PATH, tokenizer, config, split='test')

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=config.BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("\nDataloaders created:")
for caption, loader in (("Training", train_loader), ("Validation", val_loader), ("Test", test_loader)):
    print(f"  {caption} batches: {len(loader)}")

## 7. Model Initialization

In [None]:
import math

# Initialize model
print(f"Initializing model: {config.MODEL_NAME}")
model = InLegalBERTClassifier(config.MODEL_NAME, config.NUM_LABELS)
model.to(device)

# Optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY)

# One optimizer update per GRADIENT_ACCUMULATION_STEPS batches. A trailing partial
# group can also produce an update, hence ceil rather than floor.
updates_per_epoch = math.ceil(len(train_loader) / config.GRADIENT_ACCUMULATION_STEPS)
total_steps = updates_per_epoch * config.NUM_EPOCHS

# If warmup is as long as the whole run, the LR never reaches its peak and never
# decays. In that case fall back to warming up over 10% of the schedule.
warmup_steps = config.WARMUP_STEPS
if warmup_steps >= total_steps:
    warmup_steps = total_steps // 10
    print(f"Warning: WARMUP_STEPS ({config.WARMUP_STEPS}) >= total steps ({total_steps}); "
          f"using {warmup_steps} warmup steps instead.")

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# Mixed precision scaler (None when AMP is disabled)
scaler = GradScaler() if config.USE_AMP else None

print("\nTraining configuration:")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")
print(f"  Learning rate: {config.LEARNING_RATE}")
print(f"  Mixed precision: {config.USE_AMP}")

## 8. Training Loop

In [None]:
# Training history: one entry per epoch for each metric.
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1_macro': [],
    'val_f1_weighted': []
}

# Start below any attainable F1 so the first epoch always writes a checkpoint, even
# when its weighted F1 is 0.0. The test-set cell loads best_model.pt unconditionally.
best_val_f1 = -1.0
best_epoch = 0

os.makedirs(config.OUTPUT_PATH, exist_ok=True)
model_path = os.path.join(config.OUTPUT_PATH, 'best_model.pt')

# Plain-data copy of the UPPER_CASE Config attributes. Pickling the Config instance
# would tie the checkpoint to this notebook's class definition and prevent loading it
# with torch.load(weights_only=True), the default from torch 2.6 onwards.
config_snapshot = {name: getattr(config, name) for name in dir(config) if name.isupper()}

print("\n" + "="*60)
print("Starting Training")
print("="*60)

for epoch in range(config.NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{config.NUM_EPOCHS}")
    print("-" * 60)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, scheduler, scaler, device, config)

    # Validate
    val_loss, val_acc, val_f1_macro, val_f1_weighted, _, _ = evaluate(model, val_loader, device, config)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1_macro'].append(val_f1_macro)
    history['val_f1_weighted'].append(val_f1_weighted)

    # Print epoch results
    print(f"\nEpoch {epoch + 1} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")
    print(f"  Val F1 (Macro): {val_f1_macro:.4f} | Val F1 (Weighted): {val_f1_weighted:.4f}")

    # Keep the checkpoint with the best weighted F1 on the validation set.
    if val_f1_weighted > best_val_f1:
        best_val_f1 = float(val_f1_weighted)
        best_epoch = epoch + 1

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            # Built-in float: numpy scalars are also rejected by weights_only loading.
            'val_f1_weighted': best_val_f1,
            'config': config_snapshot,
        }, model_path)

        print(f"  ✓ New best model saved! (F1: {val_f1_weighted:.4f})")

    # Garbage collection
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\n" + "="*60)
print(f"Training completed! Best epoch: {best_epoch} (F1: {best_val_f1:.4f})")
print("="*60)

## 9. Evaluation on Test Set

In [None]:
# Load best model
print("Loading best model for testing...")
best_model_path = os.path.join(config.OUTPUT_PATH, 'best_model.pt')
# map_location: load onto the current device even if the checkpoint was written on
# another one.
# weights_only=False: this checkpoint was written by this notebook (trusted), and
# older runs pickled the Config object, which the torch>=2.6 default
# (weights_only=True) refuses to load. Never use weights_only=False on untrusted files.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate on test set
print("\nEvaluating on test set...")
test_loss, test_acc, test_f1_macro, test_f1_weighted, test_predictions, test_labels = evaluate(
    model, test_loader, device, config
)

print("\n" + "="*60)
print("Test Set Results")
print("="*60)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Test F1 (Macro): {test_f1_macro:.4f}")
print(f"Test F1 (Weighted): {test_f1_weighted:.4f}")

## 10. Classification Report

In [None]:
# Generate classification report
print("\nDetailed Classification Report:")
print("="*80)
# Pass the full label set explicitly. Otherwise sklearn infers the classes from the
# data and raises ValueError (target_names length mismatch) whenever some role occurs
# in neither the gold labels nor the predictions.
all_label_ids = list(range(config.NUM_LABELS))
report = classification_report(
    test_labels,
    test_predictions,
    labels=all_label_ids,
    target_names=config.ROLES,
    digits=4,
    zero_division=0,  # report 0 for undefined precision/recall without warning spam
)
print(report)

# Save report
with open(os.path.join(config.OUTPUT_PATH, 'classification_report.txt'), 'w', encoding='utf-8') as f:
    f.write(report)

## 11. Confusion Matrix

In [None]:
# Compute the confusion matrix over the full label set. Its rows/columns then always
# line up with the config.ROLES tick labels, even if some role is absent from the test
# labels and predictions (without `labels=`, absent classes are dropped and every
# later row/column would be mislabeled).
cm = confusion_matrix(test_labels, test_predictions, labels=list(range(config.NUM_LABELS)))

# Plot confusion matrix
plt.figure(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=config.ROLES,
    yticklabels=config.ROLES
)
plt.title('Confusion Matrix - Test Set')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_PATH, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Confusion matrix saved!")

## 12. Training History Plots

In [None]:
# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1-based epoch numbers, matching `best_epoch` and the printed "Epoch N" headers.
epochs = list(range(1, len(history['train_loss']) + 1))

# (axis, [(history key, legend label, marker, scale)], y-label, title)
curve_panels = [
    (axes[0, 0],
     [('train_loss', 'Train Loss', 'o', 1), ('val_loss', 'Val Loss', 's', 1)],
     'Loss', 'Training and Validation Loss'),
    (axes[0, 1],
     [('train_acc', 'Train Acc', 'o', 100), ('val_acc', 'Val Acc', 's', 100)],
     'Accuracy (%)', 'Training and Validation Accuracy'),
    (axes[1, 0],
     [('val_f1_macro', 'Val F1 (Macro)', 'o', 1), ('val_f1_weighted', 'Val F1 (Weighted)', 's', 1)],
     'F1 Score', 'Validation F1 Scores'),
]
for ax, series, ylabel, title in curve_panels:
    for key, label, marker, scale in series:
        ax.plot(epochs, [value * scale for value in history[key]], label=label, marker=marker)
    ax.set_xlabel('Epoch')
    ax.set_xticks(epochs)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Summary table
axes[1, 1].axis('off')
summary_data = [
    ['Metric', 'Value'],
    ['Best Epoch', f"{best_epoch}"],
    ['Best Val F1', f"{best_val_f1:.4f}"],
    ['Test Accuracy', f"{test_acc*100:.2f}%"],
    ['Test F1 (Macro)', f"{test_f1_macro:.4f}"],
    ['Test F1 (Weighted)', f"{test_f1_weighted:.4f}"],
    ['Context Mode', config.CONTEXT_MODE],
    ['Batch Size', f"{config.BATCH_SIZE}"],
]
table = axes[1, 1].table(cellText=summary_data, cellLoc='left', loc='center',
                         colWidths=[0.5, 0.5])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)
axes[1, 1].set_title('Training Summary', pad=20)

plt.tight_layout()
plt.savefig(os.path.join(config.OUTPUT_PATH, 'training_history.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Training history plots saved!")

## 13. Save Final Model and Tokenizer

In [None]:
# Save tokenizer
tokenizer_path = os.path.join(config.OUTPUT_PATH, 'tokenizer')
tokenizer.save_pretrained(tokenizer_path)
print(f"Tokenizer saved to {tokenizer_path}")

# Save the model for inference. After the test-set cell has run, `model` holds the
# best-epoch weights reloaded from best_model.pt.
# - Tensors are moved to CPU so the file loads on CPU-only hosts without map_location.
# - Metrics are cast to built-in floats. np.mean returns a numpy scalar, and numpy
#   scalars are rejected by torch.load(weights_only=True), the torch>=2.6 default.
final_model_path = os.path.join(config.OUTPUT_PATH, 'role_classifier_final.pt')
torch.save({
    'model_state_dict': {name: tensor.cpu() for name, tensor in model.state_dict().items()},
    'config': {
        'model_name': config.MODEL_NAME,
        'num_labels': config.NUM_LABELS,
        'max_length': config.MAX_LENGTH,
        'context_mode': config.CONTEXT_MODE,
        'roles': list(config.ROLES),
        'role2id': dict(config.ROLE2ID),
        'id2role': dict(config.ID2ROLE)
    },
    'test_metrics': {
        'accuracy': float(test_acc),
        'f1_macro': float(test_f1_macro),
        'f1_weighted': float(test_f1_weighted)
    }
}, final_model_path)
print(f"Final model saved to {final_model_path}")

# Save training history
history_df = pd.DataFrame(history)
history_df.to_csv(os.path.join(config.OUTPUT_PATH, 'training_history.csv'), index=False)
print("Training history saved!")

print("\n" + "="*60)
print("All artifacts saved successfully!")
print("="*60)
print(f"\nSaved files in {config.OUTPUT_PATH}:")
print("  - best_model.pt (checkpoint with optimizer state)")
print("  - role_classifier_final.pt (final model for inference)")
print("  - tokenizer/ (tokenizer files)")
print("  - classification_report.txt")
print("  - confusion_matrix.png")
print("  - training_history.png")
print("  - training_history.csv")

## 14. Test Inference (Sample Predictions)

In [None]:
# Test inference function
def predict_role(text, model, tokenizer, config, device):
    """Predict rhetorical role for a given text"""
    model.eval()
    
    encoding = tokenizer(
        text,
        max_length=config.MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        probabilities = torch.softmax(logits, dim=1)
        prediction = torch.argmax(logits, dim=1)
    
    pred_label = config.ID2ROLE[prediction.item()]
    confidence = probabilities[0][prediction].item()
    
    return pred_label, confidence, probabilities[0].cpu().numpy()


# Run the classifier on a few hand-written example sentences.
separator = "=" * 60
print("\n" + separator)
print("Sample Predictions")
print(separator)

sample_texts = [
    "The petitioner filed a writ petition challenging the constitutional validity of Section 377.",
    "The main issue in this case is whether Section 377 violates fundamental rights.",
    "The petitioner argues that Section 377 is discriminatory and violates Article 14.",
    "The respondent contends that Section 377 is constitutionally valid.",
    "The court finds that Section 377 infringes upon the right to privacy and equality.",
    "Therefore, Section 377 is hereby declared unconstitutional."
]

for idx, sample in enumerate(sample_texts, start=1):
    role, score, _probs = predict_role(sample, model, tokenizer, config, device)
    # Texts longer than 80 characters are shown truncated with a trailing "...".
    shown = f"{sample[:80]}..." if len(sample) > 80 else sample
    print(f"\n{idx}. Text: {shown}")
    print(f"   Predicted Role: {role}")
    print(f"   Confidence: {score*100:.2f}%")

print("\n" + separator)

## 15. Memory Cleanup

In [None]:
# Cleanup: drop every global that keeps model weights or data alive. Deleting `model`
# alone frees little: the optimizer (parameter references + Adam state), the
# scheduler (which holds the optimizer), the scaler and the checkpoint dict loaded
# for testing also keep references to tensors that may live on the GPU.
# globals().pop(..., None) keeps the cell safe to re-run when names are already gone.
for _name in (
    'model', 'optimizer', 'scheduler', 'scaler', 'checkpoint',
    'train_dataset', 'val_dataset', 'test_dataset',
    'train_loader', 'val_loader', 'test_loader',
):
    globals().pop(_name, None)
del _name

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print("Memory cleaned up successfully!")

## 16. Summary

### Training Complete!

This notebook has:
1. ✅ Loaded the legal document dataset from `/kaggle/input/dataset/Hier_BiLSTM_CRF/`
2. ✅ Preprocessed data with configurable context modes
3. ✅ Trained an InLegalBERT-based classifier for rhetorical role classification
4. ✅ Evaluated the model on validation and test sets
5. ✅ Generated classification reports and visualizations
6. ✅ Saved all artifacts to `/kaggle/working/`

### Optimizations Implemented:
- Mixed Precision Training (AMP)
- Gradient Accumulation
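- Lazy tokenization (each sentence is tokenized on the fly in `__getitem__`, so no pre-tokenized copy of the dataset is kept in memory)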
- Configurable dataset sampling
- Automatic garbage collection

### Next Steps:
1. Download the trained model from `/kaggle/working/`
2. Integrate the model into your FastAPI backend
3. Use the model for document segmentation in your RAG pipeline
4. Fine-tune hyperparameters if needed (learning rate, batch size, context mode)

### Model Files:
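- `best_model.pt` - Checkpoint of the best validation epoch (model, optimizer and scheduler state) for resuming training
- `role_classifier_final.pt` - Model for inference: the best-epoch weights (reloaded before testing) plus label/config metadata and test metrics
- `tokenizer/` - Tokenizer for preprocessing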