# Serbian Legal NER Pipeline with Class Weights - Refactored

This notebook demonstrates the class-weighted approach for Serbian Legal NER using shared modules.
Class weights help address the imbalanced distribution of entity types in legal documents.

## Key Features:
- **Class Weighting**: Inverse-frequency weights computed from the training split, so rare entity types receive higher loss weights
- **Imbalance Handling**: Better performance on underrepresented entities
- **Weighted Loss Function**: Penalizes misclassification of rare entities more heavily
- **Improved Recall**: Aims at better detection of low-frequency entities such as CASE_NUMBER and JUDGE

In [None]:
# Mount Google Drive at /content/drive (Colab only); the shared NER modules
# are imported from a folder on this mount in the next cells.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

## 1. Environment Setup and Dependencies

In [None]:
# Install required packages
!pip install transformers torch datasets tokenizers scikit-learn seqeval pandas numpy matplotlib seaborn tqdm

In [None]:
# Import shared modules
import sys
import os
import warnings
# NOTE(review): this silences *every* warning, including library deprecation
# notices; consider narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

# Make the shared modules importable. The membership check keeps this cell
# idempotent: re-running it no longer appends duplicate sys.path entries.
SHARED_MODULES_DIR = '/content/drive/MyDrive/NER_Master/ner/'
if SHARED_MODULES_DIR not in sys.path:
    sys.path.append(SHARED_MODULES_DIR)

# Reload shared modules to get latest changes (edits on Drive are picked up
# without restarting the kernel). Submodules are reloaded first and the
# package last, so names the package presumably re-exports from them are
# rebound to the freshly reloaded code.
import importlib
import shared
import shared.model_utils
import shared.data_processing
import shared.dataset
import shared.evaluation
import shared.config
importlib.reload(shared.config)
importlib.reload(shared.data_processing)
importlib.reload(shared.dataset)
importlib.reload(shared.model_utils)
importlib.reload(shared.evaluation)
importlib.reload(shared)

# Import from shared modules
from shared import (
    # Configuration
    ENTITY_TYPES, BIO_LABELS, DEFAULT_TRAINING_ARGS,
    get_default_model_config, get_paths, setup_environment, get_default_training_args,
    
    # Data processing
    LabelStudioToBIOConverter, load_labelstudio_data, 
    analyze_labelstudio_data, validate_bio_examples,
    
    # Dataset
    NERDataset, split_dataset, tokenize_and_align_labels_with_sliding_window,
    print_sequence_analysis, create_huggingface_datasets,
    
    # Model utilities
    load_model_and_tokenizer, create_training_arguments, create_trainer,
    detailed_evaluation, save_model_info, setup_device_and_seed,
    load_inference_pipeline,

    
    # Evaluation
    generate_evaluation_report, plot_training_history, plot_entity_distribution
)

from transformers import DataCollatorForTokenClassification, Trainer
import torch
import torch.nn as nn
from collections import Counter
import numpy as np

# Class Weights Configuration (notebook-specific).
# Hand-set per-entity weights: five entity types carry large weights, the
# other five stay at 1.0.
# NOTE(review): DEFAULT_CLASS_WEIGHTS is never read in this notebook (and the
# next cell defines it again with the same values); the weights used for
# training come from calculate_class_weights on the training split. Where
# these numbers came from is not recorded — TODO document or remove.
DEFAULT_CLASS_WEIGHTS = dict(
    CASE_NUMBER=43.24,
    JUDGE=22.90,
    REGISTRAR=23.68,
    SANCTION_TYPE=25.78,
    PROCEDURE_COSTS=23.13,
    COURT=1.0,
    CRIMINAL_ACT=1.0,
    DEFENDANT=1.0,
    PROSECUTOR=1.0,
    PROVISION=1.0,
)

# Select the compute device; the fixed seed (42) presumably seeds the
# python/numpy/torch RNGs inside setup_device_and_seed — confirm.
device = setup_device_and_seed(42)
print(f"🔧 Using device: {device}")

In [None]:
# Class-weight helpers: frequency-based weight calculation and a
# class-weighted-loss Trainer subclass.

import torch
import torch.nn as nn
from transformers import Trainer
from collections import Counter
import numpy as np

# Class Weights Configuration (notebook-specific).
# Same values as the DEFAULT_CLASS_WEIGHTS defined in the imports cell; this
# constant is not read anywhere in the notebook.
# The first five entity types (presumably the rarer ones) carry large
# hand-set weights; the remaining types are left at 1.0.
DEFAULT_CLASS_WEIGHTS = {
    **{
        "CASE_NUMBER": 43.24,
        "JUDGE": 22.90,
        "REGISTRAR": 23.68,
        "SANCTION_TYPE": 25.78,
        "PROCEDURE_COSTS": 23.13,
    },
    **dict.fromkeys(
        ("COURT", "CRIMINAL_ACT", "DEFENDANT", "PROSECUTOR", "PROVISION"),
        1.0,
    ),
}

def calculate_class_weights(examples, label_to_id, method="inverse_frequency"):
    """Compute one loss weight per label id from entity-label frequencies.

    Entity tokens are counted per entity type: the ``B-``/``I-`` prefix is
    stripped, so ``B-JUDGE`` and ``I-JUDGE`` share a single count. ``O``
    tokens are not counted. The per-type weights are then written back onto
    every BIO label, so the result can be passed straight to
    ``nn.CrossEntropyLoss(weight=...)`` and indexed by label id.

    Args:
        examples: iterable of dicts, each with a ``'labels'`` sequence. Labels
            may be BIO strings (``'B-JUDGE'``) or integer ids from
            ``label_to_id``. Integer ids not present in ``label_to_id``
            (e.g. ``-100``) are skipped.
        label_to_id: mapping from BIO label to integer id. Ids are assumed to
            be ``0 .. len(label_to_id) - 1``.
        method: one of
            * ``"inverse_frequency"``: 1 / frequency, rescaled so the most
              frequent entity type gets exactly 1.0;
            * ``"balanced"``: sklearn-style ``total / (n_classes * count)``.

    Returns:
        ``np.ndarray`` of float64 with shape ``(len(label_to_id),)``, indexed
        by label id. ``O`` gets 1.0. Entity types that never occur in
        ``examples`` also get 1.0, as does every label when there are no
        entity tokens at all.

    Raises:
        ValueError: if ``method`` is not one of the values above.
    """
    if method not in ("inverse_frequency", "balanced"):
        raise ValueError(f"Unknown class-weight method: {method!r}")

    id_to_label = {idx: label for label, idx in label_to_id.items()}

    def _entity_type(label):
        # 'B-CASE_NUMBER' -> 'CASE_NUMBER'. Splitting only once keeps any
        # hyphen inside the type name intact.
        return label.split('-', 1)[1] if '-' in label else label

    # Count entity tokens per entity type ('O' is ignored).
    label_counts = Counter()
    for example in examples:
        for label in example['labels']:
            if not isinstance(label, str):
                # Integer ids: map back to the BIO string. Unknown ids
                # (e.g. the -100 ignore index) are skipped.
                label = id_to_label.get(label)
                if label is None:
                    continue
            if label != 'O':
                label_counts[_entity_type(label)] += 1
    total_tokens = sum(label_counts.values())

    # Per-entity-type weights; left empty when there are no entity tokens.
    type_weights = {}
    if total_tokens:
        if method == "inverse_frequency":
            raw = {etype: total_tokens / count for etype, count in label_counts.items()}
            smallest = min(raw.values())
            type_weights = {etype: w / smallest for etype, w in raw.items()}
        else:  # "balanced"
            n_classes = len(label_counts)
            type_weights = {
                etype: total_tokens / (n_classes * count)
                for etype, count in label_counts.items()
            }

    # Spread the type weights over every BIO label id; 'O' stays at 1.0.
    weights = np.ones(len(label_to_id), dtype=np.float64)
    for label, idx in label_to_id.items():
        if label != 'O':
            weights[idx] = type_weights.get(_entity_type(label), 1.0)
    return weights

# Custom Weighted Trainer (notebook-specific)
# NOTE(review): the "Training Setup with Weighted Loss" section defines another
# class with this same name and a different constructor. Once that cell runs it
# shadows this class, and it is the one actually instantiated.
class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy over token labels.

    Args:
        class_weights: per-label weights, or None to use the model's own loss.
            Accepted forms:
              * a sequence / ``np.ndarray`` / ``torch.Tensor`` indexed by
                label id (length ``num_labels``), used as-is;
              * a dict keyed by BIO label (``'B-JUDGE'``) and/or entity type
                (``'JUDGE'``), expanded to one weight per label id through
                ``model.config.label2id``. Labels with no entry get 1.0.
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``.

    Raises:
        ValueError: if a dict is given but the model config has no label2id.
    """

    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        if class_weights is not None:
            config = getattr(self.model, "config", None)
            label2id = getattr(config, "label2id", None) or {}
            # Built on CPU; moved to the logits' device on first use.
            self.class_weights_tensor = self._weights_to_tensor(class_weights, label2id)

    @staticmethod
    def _weights_to_tensor(class_weights, label2id):
        """Return ``class_weights`` as a 1-D float32 tensor indexed by label id."""
        if isinstance(class_weights, dict):
            # list(dict.values()) would follow insertion order and have one
            # entry per key. Neither matches the label-id axis that
            # CrossEntropyLoss(weight=...) indexes, so map through label2id.
            # NOTE(review): if the config only has the default LABEL_i names,
            # nothing matches and every weight silently becomes 1.0.
            if not label2id:
                raise ValueError(
                    "dict class_weights require model.config.label2id to map labels to ids"
                )
            vector = [1.0] * len(label2id)
            for label, idx in label2id.items():
                entity_type = label.split('-', 1)[1] if label[:2] in ('B-', 'I-') else label
                vector[int(idx)] = float(
                    class_weights.get(label, class_weights.get(entity_type, 1.0))
                )
            return torch.tensor(vector, dtype=torch.float32)
        if isinstance(class_weights, torch.Tensor):
            return class_weights.detach().to(torch.float32)
        return torch.tensor(np.asarray(class_weights, dtype=np.float32))

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the class-weighted cross-entropy (and the outputs if requested).

        Positions labelled -100, and padded positions when an attention mask
        is present, are ignored. Falls back to the model's own ``outputs.loss``
        when the batch has no labels or no weights were given.
        ``num_items_in_batch`` is accepted for Trainer compatibility but unused.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)

        if labels is None or self.class_weights is None:
            loss = outputs.loss
        else:
            logits = outputs.logits
            # Move class weights to the same device as the model outputs.
            if self.class_weights_tensor.device != logits.device:
                self.class_weights_tensor = self.class_weights_tensor.to(logits.device)

            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights_tensor, ignore_index=-100)

            flat_labels = labels.reshape(-1)
            attention_mask = inputs.get("attention_mask")
            if attention_mask is not None:
                # Also ignore padding positions, in case they carry real ids.
                flat_labels = torch.where(
                    attention_mask.reshape(-1) == 1,
                    flat_labels,
                    torch.full_like(flat_labels, loss_fct.ignore_index),
                )
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), flat_labels)

        return (loss, outputs) if return_outputs else loss

print("✅ Class weights functions and WeightedTrainer defined")


## 2. Configuration and Environment Setup

In [None]:
# Setup environment and paths for Google Colab
env_setup = setup_environment(use_local=False, create_dirs=True)
paths = env_setup['paths']

# get_experiment_config is not in the `from shared import (...)` list of the
# imports cell, so bring it into scope before using it below.
# NOTE(review): assumes it is defined in shared.config — confirm, and consider
# moving this import into the main imports cell.
from shared.config import get_experiment_config

# Model configuration for class weights
MODEL_NAME = "classla/bcms-bertic"
experiment_config = get_experiment_config("class_weights")

# Output directory
OUTPUT_DIR = f"{paths['models_dir']}/bertic_class_weights_refactored"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"🔧 Class Weights Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Entity types: {len(ENTITY_TYPES)}")
print(f"  BIO labels: {len(BIO_LABELS)}")
print(f"  Learning rate: {experiment_config['learning_rate']}")
print(f"  Epochs: {experiment_config['num_train_epochs']}")
print(f"  Batch size: {experiment_config['per_device_train_batch_size']}")

## 3. Data Loading and Preprocessing

In [None]:
# Load the raw LabelStudio export and stop right away if nothing came back.
labelstudio_data = load_labelstudio_data(paths['labelstudio_json'])
if not labelstudio_data:
    print("❌ No data loaded. Please check your paths.")
    raise Exception("Data loading failed")

# Inspect the raw annotations (the returned analysis is not used further below).
analysis = analyze_labelstudio_data(labelstudio_data)

# Convert the LabelStudio annotations to BIO-tagged examples.
converter = LabelStudioToBIOConverter(
    judgments_dir=paths['judgments_dir'],
    labelstudio_files_dir=paths.get('labelstudio_files_dir'),
)
bio_examples = converter.convert_to_bio(labelstudio_data)
print(f"✅ Converted {len(bio_examples)} examples to BIO format")

# Validate the converted examples; `valid_examples` feeds the rest of the notebook.
valid_examples, stats = validate_bio_examples(bio_examples)
print(f"📊 Validation complete: {stats['valid_examples']} valid examples")

## 4. Dataset Preparation and Class Weight Calculation

In [None]:
# Wrap the validated examples and turn them into training-ready records.
ner_dataset = NERDataset(valid_examples)
prepared_examples = ner_dataset.prepare_for_training()

# Train / validation / test split with a fixed random_state for repeatability.
train_examples, val_examples, test_examples = split_dataset(
    prepared_examples,
    test_size=0.2,
    val_size=0.1,
    random_state=42,
)

# Class weights come from the training split only, so the validation and test
# label distributions do not influence the loss.
class_weights = calculate_class_weights(
    train_examples,
    ner_dataset.label_to_id,
    method='inverse_frequency',
)

print(f"📊 Dataset split:")
for _split_name, _split in (("Training", train_examples),
                            ("Validation", val_examples),
                            ("Test", test_examples)):
    print(f"  {_split_name}: {len(_split)} examples")
print(f"  Total labels: {ner_dataset.get_num_labels()}")

print(f"\n⚖️  Calculated class weights:")
for _idx, _w in enumerate(class_weights):
    print(f"  {ner_dataset.id_to_label[_idx]}: {_w:.2f}")

## 5. Model Creation with Class Weights

In [None]:
# Load the model (configured with our label count and id<->label mappings)
# together with its tokenizer.
model, tokenizer = load_model_and_tokenizer(
    MODEL_NAME,
    ner_dataset.get_num_labels(),
    ner_dataset.id_to_label,
    ner_dataset.label_to_id,
)

# Per-label-id weight vector for the loss, as float32 on `device`.
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Create custom weighted loss function
import torch.nn as nn

class WeightedCrossEntropyLoss(nn.Module):
    """Class-weighted cross-entropy for token classification.

    Positions whose label is -100 are excluded from the loss.

    Args:
        weights: 1-D float tensor of per-label-id weights (length num_labels).
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights
        self.loss_fn = nn.CrossEntropyLoss(weight=weights, ignore_index=-100)

    def forward(self, logits, labels):
        """Return the weighted mean loss over labelled positions.

        Args:
            logits: scores of shape (..., num_labels); flattened to (N, num_labels).
            labels: integer label ids with as many elements as
                ``logits.shape[:-1]``; -100 marks positions to ignore.

        Returns:
            Scalar tensor. It is 0 (still attached to the graph) when no
            position is labelled, instead of the NaN a weighted mean over zero
            items would give.
        """
        # Follow the logits' device (e.g. if the Trainer placed the model
        # elsewhere). CrossEntropyLoss stores `weight` as a buffer, so .to()
        # moves it.
        if self.loss_fn.weight is not None and self.loss_fn.weight.device != logits.device:
            self.loss_fn.to(logits.device)

        # reshape (not view) also accepts non-contiguous tensors.
        flat_labels = labels.reshape(-1)
        active_loss = flat_labels != -100
        active_logits = logits.reshape(-1, logits.shape[-1])[active_loss]
        active_labels = flat_labels[active_loss]

        if active_labels.numel() == 0:
            return logits.sum() * 0.0

        return self.loss_fn(active_logits, active_labels)

# Build the weighted criterion. The model's own loss is left untouched; the
# custom trainer defined below calls this criterion instead.
weighted_loss = WeightedCrossEntropyLoss(class_weights_tensor)

print(f"✅ Model created with weighted loss function")
n_params = sum(param.numel() for param in model.parameters())
print(f"📊 Model parameters: {n_params:,}")

## 6. Data Tokenization

In [None]:
# Tokenize every split the same way: sliding-window tokenization with the
# experiment's max_length / stride, with labels aligned to the tokens.
print("🔤 Tokenizing datasets...")


def _tokenize_split(split_examples):
    """Tokenize one split and align its labels using the experiment's window settings."""
    return tokenize_and_align_labels_with_sliding_window(
        split_examples,
        tokenizer,
        ner_dataset.label_to_id,
        max_length=experiment_config['max_length'],
        stride=experiment_config['stride'],
    )


train_tokenized = _tokenize_split(train_examples)
val_tokenized = _tokenize_split(val_examples)
test_tokenized = _tokenize_split(test_examples)

# Wrap the tokenized splits as HuggingFace datasets.
train_dataset, val_dataset, test_dataset = create_huggingface_datasets(
    train_tokenized, val_tokenized, test_tokenized
)

# Pads each batch dynamically; labels are padded with the collator's default
# label_pad_token_id (-100).
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer, padding=True, return_tensors="pt"
)

print("✅ Tokenization complete")

## 7. Training Setup with Weighted Loss

In [None]:
# Create custom trainer with weighted loss
from transformers import Trainer

class WeightedTrainer(Trainer):
    """Trainer that computes the loss with a caller-supplied criterion.

    Args:
        weighted_loss_fn: callable ``(logits, labels) -> scalar loss``, e.g. a
            WeightedCrossEntropyLoss instance.
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``.
    """

    def __init__(self, weighted_loss_fn, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weighted_loss_fn = weighted_loss_fn

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the weighted loss (and the model outputs if requested).

        When the batch carries no labels, this falls back to the model's own
        loss (None if the model computed none) instead of passing
        ``labels=None`` into the criterion, where ``labels.view`` would crash.
        ``num_items_in_batch`` is accepted for Trainer compatibility but unused.
        """
        labels = inputs.get("labels")
        outputs = model(**inputs)

        if labels is None:
            loss = outputs.get("loss")
        else:
            loss = self.weighted_loss_fn(outputs.get("logits"), labels)

        return (loss, outputs) if return_outputs else loss

# Create training arguments
# NOTE(review): early_stopping_patience is passed here, but the Trainer below
# gets no `callbacks=`; confirm how create_training_arguments applies it.
training_args = create_training_arguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=experiment_config['num_train_epochs'],
    per_device_train_batch_size=experiment_config['per_device_train_batch_size'],
    per_device_eval_batch_size=experiment_config['per_device_eval_batch_size'],
    learning_rate=experiment_config['learning_rate'],
    warmup_steps=experiment_config['warmup_steps'],
    weight_decay=experiment_config['weight_decay'],
    logging_steps=50,
    eval_steps=100,
    save_steps=500,
    early_stopping_patience=3
)


def compute_token_metrics(eval_pred):
    """Token-level precision / recall / F1 over entity labels, plus accuracy.

    Used as ``compute_metrics`` during evaluation. A constant all-zero metric
    makes every ``eval_f1`` identical, so any metric-based checkpoint
    selection or early stopping set up by create_training_arguments (TODO
    confirm what it configures) could never see an improvement.

    These are token-level micro scores, not entity-level scores. The
    test-set numbers later in the notebook come separately from
    ``detailed_evaluation``.

    Args:
        eval_pred: transformers ``EvalPrediction``. ``predictions`` are
            assumed to be logits of shape (n, seq_len, num_labels) and
            ``label_ids`` (n, seq_len), with -100 marking ignored positions —
            TODO confirm.

    Returns:
        dict with float ``precision``, ``recall``, ``f1`` and ``accuracy``.
    """
    logits = eval_pred.predictions
    if isinstance(logits, tuple):  # models may return extra outputs after the logits
        logits = logits[0]
    pred_ids = np.asarray(logits).argmax(axis=-1)
    label_ids = np.asarray(eval_pred.label_ids)

    labelled = label_ids != -100
    y_true = label_ids[labelled]
    y_pred = pred_ids[labelled]
    if y_true.size == 0:
        return {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'accuracy': 0.0}

    # Positives are any non-'O' label; a hit needs the exact BIO label to match.
    o_id = ner_dataset.label_to_id.get('O', -1)
    pred_entity = y_pred != o_id
    true_entity = y_true != o_id
    true_pos = int(np.sum(pred_entity & (y_pred == y_true)))
    n_pred, n_true = int(pred_entity.sum()), int(true_entity.sum())

    precision = true_pos / n_pred if n_pred else 0.0
    recall = true_pos / n_true if n_true else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = float(np.mean(y_true == y_pred))
    return {'precision': precision, 'recall': recall, 'f1': f1, 'accuracy': accuracy}


# Create weighted trainer
# NOTE(review): newer transformers versions deprecate `tokenizer=` in favour of
# `processing_class=`; update if the installed version warns.
trainer = WeightedTrainer(
    weighted_loss_fn=weighted_loss,
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_token_metrics,
)

print("🏋️  Weighted trainer created successfully")

## 8. Model Training with Class Weights

In [None]:
# Train with the class-weighted loss.
print("🚀 Starting training with class weights...")
print("⚖️  Using weighted loss to handle class imbalance")
trainer.train()

# Persist the model (to the trainer's args.output_dir) and the tokenizer.
print("💾 Saving class-weighted model...")
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)

# Record how the class weights were produced next to the saved model.
# NOTE(review): "weight_method" is hard-coded; keep it in sync with the
# method passed to calculate_class_weights.
save_model_info(
    output_dir=OUTPUT_DIR,
    model_name=MODEL_NAME,
    model_type="class_weighted",
    num_labels=ner_dataset.get_num_labels(),
    id_to_label=ner_dataset.id_to_label,
    label_to_id=ner_dataset.label_to_id,
    training_args=training_args,
    additional_info=dict(
        class_weights=class_weights.tolist(),
        uses_class_weights=True,
        weight_method="inverse_frequency",
    ),
)

print("✅ Class-weighted training completed!")

## 9. Model Evaluation

In [None]:
# Evaluate the class-weighted model on the held-out test split.
print("📊 Evaluating class-weighted model on test set...")

test_results = detailed_evaluation(
    trainer=trainer,
    dataset=test_dataset,
    dataset_name="Test (Class Weighted)",
    id_to_label=ner_dataset.id_to_label,
)

print(f"\n📈 Class-Weighted Test Results:")
for _metric_name, _metric_key in (("Precision", "precision"),
                                  ("Recall", "recall"),
                                  ("F1-score", "f1"),
                                  ("Accuracy", "accuracy")):
    print(f"  {_metric_name}: {test_results[_metric_key]:.4f}")

# These are the hoped-for effects, not measured results; compare against the
# baseline notebook before claiming them.
print(f"\n💡 Expected improvements with class weights:")
for _expectation in ("Better recall for rare entities (CASE_NUMBER, JUDGE)",
                     "More balanced precision/recall across entity types",
                     "Reduced bias towards frequent entities"):
    print(f"  ✅ {_expectation}")

## 10. Comprehensive Analysis

In [None]:
# Per-entity evaluation report, focusing on the entity types with large weights.
evaluation_report = generate_evaluation_report(
    true_labels=test_results['true_labels'],
    predictions=test_results['true_predictions'],
    dataset_name="Test (Class Weighted)",
    focus_entities=["CASE_NUMBER", "JUDGE", "REGISTRAR", "SANCTION_TYPE", "PROCEDURE_COSTS"],
)

# List the (non-'O') BIO labels with the largest training weights.
print("\n⚖️  Class Weight Impact Analysis:")
print("\nHighest weighted entities (should see improved recall):")
weight_label_pairs = sorted(
    (
        (class_weights[ner_dataset.label_to_id[label]], label)
        for label in ner_dataset.label_to_id
        if label != 'O'
    ),
    reverse=True,
)

for weight, label in weight_label_pairs[:10]:
    print(f"  {label}: weight = {weight:.2f}")

## 11. Training History and Visualization

In [None]:
# Training curves for the trainer used above.
plot_training_history(trainer)

# Entity-type frequencies (presumably counted over the whole annotated dataset).
label_stats = ner_dataset.get_label_statistics()
entity_counts = label_stats['entity_counts']
plot_entity_distribution(entity_counts)

## 12. Inference Pipeline Testing

In [None]:
# Load the saved model as an inference pipeline, using the same sliding-window
# settings that were used for tokenization.
pipeline = load_inference_pipeline(
    model_path=OUTPUT_DIR,
    max_length=experiment_config['max_length'],
    stride=experiment_config['stride'],
)

# Test with sample text
sample_text = """Основни суд у Београду донео је пресуду у кривичном предмету К-1234/2023 против оптуженог Марка Петровића за кривично дело крађе из члана 203 Кривичног законика. Судија Ана Николић изрекла је казну затвора у трајању од 6 месеци."""

print("🔍 Testing class-weighted inference pipeline:")
print(f"Input text: {sample_text}")
print("\n📋 Detected entities (with class weighting):")

entities = pipeline.predict(sample_text)
for entity in entities:
    # Show the training weight of the entity's B- label.
    # NOTE(review): a label missing from label_to_id falls back to label id 0's
    # weight, which is misleading — confirm this is intended.
    b_label_id = ner_dataset.label_to_id.get(f"B-{entity['label']}", 0)
    weight = class_weights[b_label_id]
    print(f"  {entity['label']}: '{entity['text']}' (weight: {weight:.2f})")

print(f"\n✅ Found {len(entities)} entities using class-weighted model")

## 13. Summary and Results

In [None]:
print("\n🎯 CLASS-WEIGHTED FINAL SUMMARY")
print("=" * 50)
print(f"Model: {MODEL_NAME} + Class Weights")
for _split_name, _split in (("Training", train_examples),
                            ("Validation", val_examples),
                            ("Test", test_examples)):
    print(f"{_split_name} examples: {len(_split)}")
print(f"Entity types: {len(ENTITY_TYPES)}")
print(f"BIO labels: {len(BIO_LABELS)}")

# NOTE(review): the weight method is hard-coded here; keep it in sync with
# the method passed to calculate_class_weights.
print(f"\nClass Weighting Configuration:")
print(f"  Weight method: inverse_frequency")
print(f"  Learning rate: {experiment_config['learning_rate']}")
print(f"  Epochs: {experiment_config['num_train_epochs']}")

print(f"\nTest Performance:")
for _metric_name, _metric_key in (("Precision", "precision"),
                                  ("Recall", "recall"),
                                  ("F1-score", "f1"),
                                  ("Accuracy", "accuracy")):
    print(f"  {_metric_name}: {test_results[_metric_key]:.4f}")

print(f"\nModel saved to: {OUTPUT_DIR}")
print("\n✅ Class-weighted pipeline completed successfully!")
print("\n💡 Class weighting helps with:")
for _benefit in ("Better recall for rare entities",
                 "Handling class imbalance",
                 "More balanced performance across entity types"):
    print(f"   • {_benefit}")