# Serbian Legal NER Pipeline with BERT-CRF - Refactored

This notebook demonstrates the BERT-CRF approach for Serbian Legal NER using shared modules.
BERT-CRF combines BERT embeddings with a Conditional Random Field (CRF) layer for better sequence modeling.

## Key Features:
- **BERT Embeddings**: Contextual word representations
- **CRF Layer**: Enforces valid BIO sequence constraints
- **Improved Performance**: Better handling of sequence dependencies
- **Entity Boundary Detection**: More accurate entity span detection

In [None]:
# Mount Google Drive at /content/drive. A later cell appends
# /content/drive/MyDrive/NER_Master/ner/ to sys.path in order to import the
# shared modules, so this cell has to run first.
from google.colab import drive
drive.mount('/content/drive')

## 1. Environment Setup and Dependencies

In [None]:
# Install required packages including pytorch-crf for CRF layer
!pip install transformers torch datasets tokenizers scikit-learn seqeval pandas numpy matplotlib seaborn tqdm pytorch-crf

In [None]:
# Import shared modules
import sys
import os
import warnings
warnings.filterwarnings('ignore')

# Add the shared modules to path
sys.path.append('/content/drive/MyDrive/NER_Master/ner/')

# Import from shared modules
from shared import (
    # Configuration
    ENTITY_TYPES, BIO_LABELS, DEFAULT_TRAINING_ARGS,
    get_default_model_config, get_paths, setup_environment, get_default_training_args,
    
    # Data processing
    LabelStudioToBIOConverter, load_labelstudio_data, 
    analyze_labelstudio_data, validate_bio_examples,
    
    # Dataset
    NERDataset, split_dataset, tokenize_and_align_labels_with_sliding_window,
    print_sequence_analysis, create_huggingface_datasets,
    
    # Model utilities
    load_model_and_tokenizer, create_training_arguments, create_trainer,
    detailed_evaluation, save_model_info, setup_device_and_seed,
    
    
    # Evaluation
    generate_evaluation_report, plot_training_history, plot_entity_distribution
)

from transformers import DataCollatorForTokenClassification, Trainer
import torch
import torch.nn as nn

# Install pytorch-crf if not available
try:
    from torchcrf import CRF
except ImportError:
    !pip install pytorch-crf
    from torchcrf import CRF

# Select the compute device and fix the random seed through the shared helper.
# NOTE(review): 42 is presumably the seed value (the same number is passed as
# `random_state` when the dataset is split) — confirm against the shared module.
device = setup_device_and_seed(42)
print(f"🔧 Using device: {device}")

## 2. Configuration and Environment Setup

In [None]:
# Resolve the Colab paths through the shared helper (use_local=False) and let
# it create any missing directories.
env_setup = setup_environment(use_local=False, create_dirs=True)
paths = env_setup['paths']

# Pretrained checkpoint used as the encoder underneath the CRF head.
MODEL_NAME = "classla/bcms-bertic"

# Notebook-specific BERT-CRF hyperparameters.
# NOTE(review): in the cells visible in this notebook only `bert_lr` is handed
# to `create_training_arguments` as the learning rate. `classifier_lr`,
# `crf_lr` and `learning_rate` are never used to configure the optimizer
# (`crf_lr` is only printed and stored as metadata) — confirm whether
# per-component learning rates are meant to be applied.
BERT_CRF_CONFIG = dict(
    dropout_rate=0.1,
    bert_lr=3e-5,
    classifier_lr=1e-4,
    crf_lr=1e-3,
    max_length=512,
    stride=128,
    num_train_epochs=8,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=3e-5,
    warmup_steps=500,
    weight_decay=0.01,
)

# Checkpoints, tokenizer files and model metadata are written here.
OUTPUT_DIR = f"{paths['models_dir']}/bertic_bert_crf"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Echo the effective configuration, one indented "name: value" line each.
print(
    "🔧 BERT-CRF Configuration:",
    f"  Model: {MODEL_NAME}",
    f"  Output directory: {OUTPUT_DIR}",
    f"  Entity types: {len(ENTITY_TYPES)}",
    f"  BIO labels: {len(BIO_LABELS)}",
    *(
        f"  {caption}: {BERT_CRF_CONFIG[option]}"
        for caption, option in (
            ("Dropout rate", 'dropout_rate'),
            ("BERT LR", 'bert_lr'),
            ("CRF LR", 'crf_lr'),
        )
    ),
    sep="\n",
)

## 3. Data Loading and Preprocessing

In [None]:
# Load the Label Studio export from the configured path.
labelstudio_data = load_labelstudio_data(paths['labelstudio_json'])

# Fail fast: an empty/falsy result means nothing usable was loaded, and every
# later cell depends on this data. RuntimeError (instead of a bare Exception)
# keeps the failure specific while remaining catchable as Exception.
if not labelstudio_data:
    print("❌ No data loaded. Please check your paths.")
    raise RuntimeError(
        f"Data loading failed: nothing loaded from {paths['labelstudio_json']!r}"
    )

# Run the shared analysis on the raw annotations.
analysis = analyze_labelstudio_data(labelstudio_data)

# Convert the annotations to BIO-tagged examples.
converter = LabelStudioToBIOConverter(
    judgments_dir=paths['judgments_dir'],
    labelstudio_files_dir=paths.get('labelstudio_files_dir')
)

bio_examples = converter.convert_to_bio(labelstudio_data)
print(f"✅ Converted {len(bio_examples)} examples to BIO format")

# Keep only the examples that pass the shared BIO validation.
valid_examples, stats = validate_bio_examples(bio_examples)
print(f"📊 Validation complete: {stats['valid_examples']} valid examples")

## 4. Dataset Preparation and Splitting

In [None]:
# Wrap the validated BIO examples and bring them into training format.
ner_dataset = NERDataset(valid_examples)
prepared_examples = ner_dataset.prepare_for_training()

# Train / validation / test split with a fixed random_state.
train_examples, val_examples, test_examples = split_dataset(
    prepared_examples,
    random_state=42,
    val_size=0.1,
    test_size=0.2,
)

# Report the size of every split and the size of the label set.
print(
    "📊 Dataset split:",
    *(
        f"  {split_name}: {len(split)} examples"
        for split_name, split in (
            ("Training", train_examples),
            ("Validation", val_examples),
            ("Test", test_examples),
        )
    ),
    f"  Total labels: {ner_dataset.get_num_labels()}",
    sep="\n",
)

## 5. BERT-CRF Model Creation

In [None]:
# BERT-CRF model definition (specific to this notebook, not part of the shared modules).
# Run this cell after the imports cell and before the cell that creates the model.

try:
    from torchcrf import CRF
except ImportError:
    !pip install pytorch-crf
    from torchcrf import CRF

import torch
import torch.nn as nn
from transformers import Trainer

# BERT-CRF Model Implementation (notebook-specific)
class BertCrfForTokenClassification(nn.Module):
    """BERT encoder with a linear emission layer and a CRF on top.

    The encoder's last hidden states go through dropout and a linear layer to
    produce per-token emission scores (``logits``); the CRF scores or decodes
    label sequences over those emissions.

    Args:
        bert_model: encoder module whose output exposes ``last_hidden_state``
            and whose ``config.hidden_size`` gives the hidden dimension.
        num_labels: number of labels (size of the emission and CRF label set).
        dropout_rate: dropout probability applied before the linear layer.
    """

    def __init__(self, bert_model, num_labels, dropout_rate=0.1):
        super().__init__()
        self.num_labels = num_labels
        self.bert = bert_model
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(bert_model.config.hidden_size, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    @staticmethod
    def _build_crf_mask(input_ids, attention_mask):
        """Return the boolean CRF mask, treating a missing attention mask as all-valid."""
        if attention_mask is None:
            # `None == 1` would evaluate to the Python bool False and break the
            # CRF; without a mask every position counts as a real token.
            return torch.ones_like(input_ids, dtype=torch.bool)
        return attention_mask == 1

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute emissions and either the CRF loss or the decoded label paths.

        Args:
            input_ids: token ids for the encoder.
            attention_mask: optional mask with 1 for real tokens; when omitted,
                all positions are treated as valid.
            labels: optional label ids; ``-100`` marks ignored positions.
            **kwargs: extra batch fields (accepted and ignored).

        Returns:
            With ``labels``: ``{'loss', 'logits'}`` where ``loss`` is the
            negative CRF log-likelihood with ``reduction='mean'``.
            Without ``labels``: ``{'logits', 'predictions'}`` where
            ``predictions`` is the output of ``CRF.decode``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(self.dropout(outputs.last_hidden_state))

        # The CRF mask follows the attention mask only, so padding is excluded
        # but special tokens and other -100 positions inside it are still
        # scored. torchcrf requires the first timestep of every sequence to be
        # unmasked, which the attention mask satisfies.
        mask = self._build_crf_mask(input_ids, attention_mask)

        if labels is None:
            # Inference: decode the best label path per sequence.
            return {'logits': logits, 'predictions': self.crf.decode(logits, mask=mask)}

        # The CRF cannot index with -100, so ignored positions are mapped to
        # label id 0 on a copy (the caller's tensor is left untouched).
        # NOTE(review): unmasked -100 positions therefore contribute to the
        # loss as label id 0 — confirm that id 0 is the intended label for them.
        crf_labels = labels.clone()
        crf_labels[labels == -100] = 0

        # The CRF returns a log-likelihood; negate it to obtain a loss.
        loss = -self.crf(logits, crf_labels, mask=mask, reduction='mean')
        return {'loss': loss, 'logits': logits}

def create_bert_crf_model(model_name, num_labels, dropout_rate=0.1):
    """Build a BERT-CRF tagger on top of a pretrained encoder.

    Args:
        model_name: identifier or path handed to ``AutoModel.from_pretrained``.
        num_labels: number of labels for the emission layer and the CRF.
        dropout_rate: dropout probability applied before the emission layer.

    Returns:
        A ``BertCrfForTokenClassification`` wrapping the loaded encoder.
    """
    # Imported locally because AutoModel is not imported at notebook level.
    from transformers import AutoModel

    # AutoModel loads the bare encoder, i.e. without a classification head.
    encoder = AutoModel.from_pretrained(model_name)

    return BertCrfForTokenClassification(
        encoder,
        num_labels=num_labels,
        dropout_rate=dropout_rate,
    )

print("✅ BERT-CRF classes and functions defined")


In [None]:
# Instantiate the BERT-CRF model, sized to the dataset's label count.
model = create_bert_crf_model(
    MODEL_NAME,
    ner_dataset.get_num_labels(),
    dropout_rate=BERT_CRF_CONFIG['dropout_rate'],
)

# Tokenizer belonging to the same checkpoint as the encoder.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print("✅ BERT-CRF model created successfully")
# Total parameter count (encoder + emission layer + CRF), thousands-separated.
print(f"📊 Model parameters: {sum(map(torch.Tensor.numel, model.parameters())):,}")

## 6. Data Tokenization

In [None]:
# Tokenize every split with the same sliding-window settings.
print("🔤 Tokenizing datasets...")

# One call per split, in train / validation / test order; window length and
# stride come from the notebook configuration.
train_tokenized, val_tokenized, test_tokenized = [
    tokenize_and_align_labels_with_sliding_window(
        split_examples, tokenizer, ner_dataset.label_to_id,
        max_length=BERT_CRF_CONFIG['max_length'],
        stride=BERT_CRF_CONFIG['stride'],
    )
    for split_examples in (train_examples, val_examples, test_examples)
]

# Wrap the tokenized splits as HuggingFace datasets.
train_dataset, val_dataset, test_dataset = create_huggingface_datasets(
    train_tokenized, val_tokenized, test_tokenized
)

# Collator that pads each batch and returns PyTorch tensors.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer, padding=True, return_tensors="pt"
)

print("✅ Tokenization complete")

## 7. BERT-CRF Training Setup

In [None]:
# Training arguments: schedule-related values are copied from BERT_CRF_CONFIG
# under their own names; the learning rate is the BERT one.
# NOTE(review): a single learning rate is configured here, so the
# `classifier_lr` / `crf_lr` config entries have no effect in this cell —
# confirm whether separate parameter groups are intended.
training_args = create_training_arguments(
    output_dir=OUTPUT_DIR,
    learning_rate=BERT_CRF_CONFIG['bert_lr'],
    **{
        option: BERT_CRF_CONFIG[option]
        for option in (
            'num_train_epochs',
            'per_device_train_batch_size',
            'per_device_eval_batch_size',
            'warmup_steps',
            'weight_decay',
        )
    },
    logging_steps=50,
    eval_steps=100,
    save_steps=500,
    early_stopping_patience=3,
)

# Build the trainer through the shared factory, with the same early-stopping
# patience as the training arguments.
trainer = create_trainer(
    model=model,
    tokenizer=tokenizer,
    training_args=training_args,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    data_collator=data_collator,
    id_to_label=ner_dataset.id_to_label,
    early_stopping_patience=3,
)

print("🏋️  BERT-CRF trainer created successfully")

## 8. Model Training

In [None]:
# Run the training loop.
print("🚀 Starting BERT-CRF training...")
print("⚡ This may take longer than standard BERT due to CRF layer")

trainer.train()

# Persist the trained model and the tokenizer to the output directory.
print("💾 Saving BERT-CRF model...")
trainer.save_model()
tokenizer.save_pretrained(OUTPUT_DIR)

# Store model metadata next to the weights; the CRF-related settings are
# copied from the notebook configuration.
crf_details = {
    **{key: BERT_CRF_CONFIG[key] for key in ('dropout_rate', 'bert_lr', 'crf_lr')},
    "uses_crf": True,
}
save_model_info(
    output_dir=OUTPUT_DIR,
    model_name=MODEL_NAME,
    model_type="bert_crf",
    num_labels=ner_dataset.get_num_labels(),
    label_to_id=ner_dataset.label_to_id,
    id_to_label=ner_dataset.id_to_label,
    training_args=training_args,
    additional_info=crf_details,
)

print("✅ BERT-CRF training completed!")

## 9. Model Evaluation

In [None]:
# Evaluate the trained BERT-CRF model on the held-out test split using the
# shared evaluation helper. `test_results` is read again in the summary cell
# through its 'precision', 'recall', 'f1' and 'accuracy' keys.
# NOTE(review): when labels are present the model's forward pass returns the
# raw emission logits rather than CRF-decoded paths. If `detailed_evaluation`
# takes an argmax over those logits, the reported metrics bypass Viterbi
# decoding — confirm against the shared implementation.
print("📊 Evaluating BERT-CRF model on test set...")

test_results = detailed_evaluation(
    trainer=trainer,
    dataset=test_dataset,
    dataset_name="Test (BERT-CRF)",
    id_to_label=ner_dataset.id_to_label
)

## 10. Training History and Visualization

In [None]:
# Plot the training history recorded by the trainer (shared plotting helper).
plot_training_history(trainer)

# Plot how often each entity type occurs, using the 'entity_counts' entry of
# the dataset's label statistics.
label_stats = ner_dataset.get_label_statistics()
plot_entity_distribution(label_stats['entity_counts'])

## 11. Summary and Results

In [None]:
# Final report: dataset sizes, configuration, test metrics and output location,
# emitted as one newline-separated block.
print(
    "\n🎯 BERT-CRF FINAL SUMMARY",
    "=" * 50,
    f"Model: {MODEL_NAME} + CRF",
    # One line per data split.
    *(
        f"{split_name} examples: {len(split)}"
        for split_name, split in (
            ("Training", train_examples),
            ("Validation", val_examples),
            ("Test", test_examples),
        )
    ),
    f"Entity types: {len(ENTITY_TYPES)}",
    f"BIO labels: {len(BIO_LABELS)}",
    "\nBERT-CRF Configuration:",
    # Selected hyperparameters from the notebook configuration.
    *(
        f"  {caption}: {BERT_CRF_CONFIG[option]}"
        for caption, option in (
            ("Dropout rate", 'dropout_rate'),
            ("BERT learning rate", 'bert_lr'),
            ("CRF learning rate", 'crf_lr'),
        )
    ),
    "\nTest Performance:",
    # Test-set metrics, four decimals each.
    *(
        f"  {caption}: {test_results[metric]:.4f}"
        for caption, metric in (
            ("Precision", 'precision'),
            ("Recall", 'recall'),
            ("F1-score", 'f1'),
            ("Accuracy", 'accuracy'),
        )
    ),
    f"\nModel saved to: {OUTPUT_DIR}",
    "\n✅ BERT-CRF pipeline completed successfully!",
    "\n💡 CRF layer helps with:",
    "   • Valid BIO sequence constraints",
    "   • Better entity boundary detection",
    "   • Improved sequence modeling",
    sep="\n",
)