# Multi-Task PII Detection and Co-reference Detection Training

This notebook trains a BERT-style encoder (`distilbert-base-cased` by default) on two tasks:
1. **PII Detection**: Identifying Personally Identifiable Information in text
2. **Co-reference Detection**: Identifying mentions that refer to the same entity

## Features
- Multi-task learning with shared encoder and separate classification heads
- Comprehensive metrics: Precision, Recall, F1 (weighted, macro, per-class)
- **Apple Silicon (M4) GPU support** via Metal Performance Shaders (MPS)
- Automatic device detection (MPS > CUDA > CPU)
- Support for custom loss functions and class weights
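
The shared-encoder idea from the list above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the project's `MultiTaskPIIDetectionModel`: the class `TinyMultiTaskTagger`, the function `multi_task_loss`, the small Transformer standing in for a pretrained encoder, and the assumption that both tasks are token-level classification are all hypothetical. The label counts (49 and 4) are taken from the run logged later in this notebook.

```python
import torch
from torch import nn


class TinyMultiTaskTagger(nn.Module):
    """Shared encoder with one token-classification head per task (illustrative)."""

    def __init__(self, vocab_size, hidden, num_pii_labels, num_coref_labels):
        super().__init__()
        # Stand-in for a pretrained encoder such as DistilBERT.
        self.embed = nn.Embedding(vocab_size, hidden)
        layer = nn.TransformerEncoderLayer(
            hidden, nhead=2, dim_feedforward=2 * hidden, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)
        # Separate linear heads read the same hidden states.
        self.pii_head = nn.Linear(hidden, num_pii_labels)
        self.coref_head = nn.Linear(hidden, num_coref_labels)

    def forward(self, input_ids):
        hidden_states = self.encoder(self.embed(input_ids))
        return self.pii_head(hidden_states), self.coref_head(hidden_states)


def multi_task_loss(pii_logits, coref_logits, pii_labels, coref_labels,
                    pii_weight=1.0, coref_weight=1.0, pii_class_weights=None):
    """Weighted sum of per-task cross-entropy; label -100 marks ignored tokens."""
    ce_pii = nn.CrossEntropyLoss(weight=pii_class_weights, ignore_index=-100)
    ce_coref = nn.CrossEntropyLoss(ignore_index=-100)
    pii_loss = ce_pii(pii_logits.flatten(0, 1), pii_labels.flatten())
    coref_loss = ce_coref(coref_logits.flatten(0, 1), coref_labels.flatten())
    return pii_weight * pii_loss + coref_weight * coref_loss


torch.manual_seed(0)
model = TinyMultiTaskTagger(vocab_size=100, hidden=16,
                            num_pii_labels=49, num_coref_labels=4)
input_ids = torch.randint(0, 100, (2, 8))    # batch of 2 sequences, 8 tokens each
pii_labels = torch.randint(0, 49, (2, 8))
coref_labels = torch.randint(0, 4, (2, 8))

pii_logits, coref_logits = model(input_ids)
loss = multi_task_loss(pii_logits, coref_logits, pii_labels, coref_labels)
print(pii_logits.shape, coref_logits.shape)
```

Because both heads share one encoder, a single backward pass through `loss` updates the encoder with gradients from both tasks; the two weights control how much each task contributes.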

## Requirements
- PyTorch with MPS support (for M4 Mac): `pip install torch torchvision torchaudio`
- Training data (JSON files) in the `model/dataset/training_samples/` directory
- Python packages: `transformers`, `datasets`, `scikit-learn`, `accelerate`


## 1. Setup and Installation (Local M4 Mac)


In [1]:
# Install required packages (if not already installed)
# Uncomment the line below if you need to install packages
# %pip install -q transformers datasets scikit-learn torch accelerate

# For M4 Mac, make sure you have PyTorch with MPS support:
# pip install torch torchvision torchaudio


## 2. Verify Training Data

Make sure your training samples are in the `model/dataset/training_samples/` directory.


In [None]:
# Verify local training data directory exists
from pathlib import Path

# Get project root (assuming notebook is in model/ directory)
project_root = Path.cwd().parent if Path.cwd().name == "model" else Path.cwd()
training_samples_dir = project_root / "model" / "dataset" / "training_samples"

print("📁 Checking for training data...")
print(f"   Project root: {project_root}")
print(f"   Training samples dir: {training_samples_dir}")

if training_samples_dir.exists():
    json_files = list(training_samples_dir.glob("*.json"))
    print("✅ Found training samples directory")
    print(f"   Number of JSON files: {len(json_files)}")
    if len(json_files) > 0:
        print(f"   Sample files: {json_files[:3]}")
    else:
        print(f"⚠️  Warning: No JSON files found in {training_samples_dir}")
        print("   Make sure your training samples are in this directory")
else:
    print(f"❌ Training samples directory not found: {training_samples_dir}")
    print("   Please ensure the model/dataset/training_samples directory exists")
    print("   and contains your training JSON files")
    raise FileNotFoundError(
        f"Training samples directory not found: {training_samples_dir}"
    )

print("\n✅ Training data ready!")

📁 Checking for training data...
   Project root: /Users/hannes/opensource/kiji-proxy
   Training samples dir: /Users/hannes/opensource/kiji-proxy/dataset/training_samples
✅ Found training samples directory
   Number of JSON files: 5
   Sample files: [PosixPath('/Users/hannes/opensource/kiji-proxy/dataset/training_samples/20251124103832_fb0dd1a3caa8842bb2a1c9af9bbf3592e9c98d2545f4224d93c176e0e9ba7612.json'), PosixPath('/Users/hannes/opensource/kiji-proxy/dataset/training_samples/20251124103840_8dc19700e93885415fa096ea64e4717a6ec6c474a2eb14efe55d32d88226a158.json'), PosixPath('/Users/hannes/opensource/kiji-proxy/dataset/training_samples/20251124103848_29a096d927d7e4ee857641892e54733d3e7fd03c08dbaf454bd6b52845bbd532.json')]

✅ Training data ready!
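
The check above only confirms that the directory and its files exist. A possible follow-up, sketched here under the assumption that each sample is a standalone JSON document, is to confirm that every file parses. The helper `find_unreadable_json` and the `"text"` key in the demo files are hypothetical and not part of the project.

```python
import json
import tempfile
from pathlib import Path


def find_unreadable_json(directory):
    """Return the names of *.json files in `directory` that fail to parse."""
    bad = []
    for path in sorted(Path(directory).glob("*.json")):
        try:
            json.loads(path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            bad.append(path.name)
    return bad


# Demo on a throwaway directory; in the notebook, pass training_samples_dir instead.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "good.json").write_text('{"text": "hello"}', encoding="utf-8")
    Path(tmp, "broken.json").write_text('{"text": ', encoding="utf-8")
    unreadable = find_unreadable_json(tmp)

print(unreadable)  # ['broken.json']
```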


## 3. Configuration


In [None]:
import logging
import sys
import time
from pathlib import Path

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Determine project root - handle both notebook execution from model/ and project root
current_dir = Path.cwd()
if current_dir.name == "model":
    # Running from model/ directory
    project_root = current_dir.parent
    # Add both project root and model directory to path
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    if str(current_dir) not in sys.path:
        sys.path.insert(0, str(current_dir))
else:
    # Running from project root
    project_root = current_dir
    model_dir = current_dir / "model"
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    if str(model_dir) not in sys.path:
        sys.path.insert(0, str(model_dir))

print(f"📁 Current directory: {current_dir}")
print(f"📁 Project root: {project_root}")
print(f"📁 Python path includes: {list(sys.path[:3])}")

# Import training modules - try multiple import strategies
try:
    # Try absolute import from project root
    from model.config import EnvironmentSetup, TrainingConfig
    from model.preprocessing import DatasetProcessor
    from model.trainer import PIITrainer

    print("✅ Imports successful (absolute from project root)")
except ImportError:
    try:
        # Try relative import (when running from model/ directory)
        from config import EnvironmentSetup, TrainingConfig
        from preprocessing import DatasetProcessor
        from trainer import PIITrainer

        print("✅ Imports successful (relative from model directory)")
    except ImportError as e:
        print(f"❌ Import failed: {e}")
        raise

📁 Current directory: /Users/hannes/opensource/kiji-proxy/model
📁 Project root: /Users/hannes/opensource/kiji-proxy
📁 Python path includes: ['/Users/hannes/opensource/kiji-proxy/model', '/Users/hannes/.local/share/uv/python/cpython-3.13.9-macos-aarch64-none/lib/python313.zip', '/Users/hannes/.local/share/uv/python/cpython-3.13.9-macos-aarch64-none/lib/python3.13']


✅ Imports successful (relative from model directory)


In [None]:
# Configure training parameters
# Get project root for proper path resolution
project_root = Path.cwd().parent if Path.cwd().name == "model" else Path.cwd()

config = TrainingConfig(
    # Model settings
    model_name="distilbert-base-cased",  # or "bert-base-cased", "roberta-base", etc.
    # Training parameters
    num_epochs=3,
    batch_size=32,  # Adjust based on available memory (M4 Mac can handle this)
    learning_rate=3e-5,
    # Training optimization
    warmup_steps=500,
    weight_decay=0.01,
    save_steps=1000,
    eval_steps=500,
    logging_steps=100,
    seed=42,
    # Output settings
    output_dir=str(project_root / "model" / "trained"),  # Save in model directory
    use_wandb=False,  # Set to True if using Weights & Biases
    use_custom_loss=True,
    # Dataset settings
    eval_size_ratio=0.2,  # 20% for validation
    training_samples_dir=str(
        project_root / "model" / "dataset" / "training_samples"
    ),  # Local path
    # Multi-task learning weights
    pii_loss_weight=1.0,
    coref_loss_weight=1.0,
)

# Print configuration summary
config.print_summary()

2025-11-24 14:23:49,411 - INFO - 
📋 Training Configuration:
2025-11-24 14:23:49,412 - INFO -   Model: distilbert-base-cased
2025-11-24 14:23:49,412 - INFO -   Epochs: 3
2025-11-24 14:23:49,412 - INFO -   Batch Size: 32
2025-11-24 14:23:49,412 - INFO -   Learning Rate: 3e-05
2025-11-24 14:23:49,413 - INFO -   Max Samples: 400000
2025-11-24 14:23:49,413 - INFO -   Output Dir: /Users/hannes/opensource/kiji-proxy/pii_model
2025-11-24 14:23:49,413 - INFO -   Custom Loss: True
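
It can help to compare `warmup_steps`, `eval_steps`, and `logging_steps` with the number of optimizer steps a run actually takes. The sketch below assumes a linear warmup schedule and no gradient accumulation; neither is confirmed by this notebook, and it ignores any decay after warmup. The helpers `steps_per_run` and `linear_warmup_lr` are hypothetical. The sample count of 4 matches the training split logged later in this notebook.

```python
import math


def steps_per_run(num_samples, batch_size, num_epochs):
    """Optimizer steps for a run without gradient accumulation."""
    return math.ceil(num_samples / batch_size) * num_epochs


def linear_warmup_lr(step, base_lr, warmup_steps):
    """Learning rate under a linear warmup, capped at base_lr afterwards."""
    return base_lr * min(1.0, step / warmup_steps)


total_steps = steps_per_run(num_samples=4, batch_size=32, num_epochs=3)
final_lr = linear_warmup_lr(total_steps, base_lr=3e-5, warmup_steps=500)
print(f"total steps: {total_steps}, learning rate at last step: {final_lr:.2e}")
```

Under these assumptions a 4-sample run finishes after 3 steps, long before a 500-step warmup ends, so the learning rate never gets close to `3e-5`. The same comparison applies to `eval_steps=500` and `logging_steps=100`: with only 3 steps, neither threshold is reached during training.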


## 4. Environment Setup


In [5]:
# Disable wandb (if not using it)
EnvironmentSetup.disable_wandb()

# Check device availability (MPS for M4 Mac, CUDA for NVIDIA, or CPU)
EnvironmentSetup.check_gpu()

# Get the device that will be used for training
device = EnvironmentSetup.get_device()
print(f"\n🖥️  Training will use device: {device}")
if device.type == "mps":
    print("   ✅ Using Apple Silicon GPU acceleration (Metal Performance Shaders)")
elif device.type == "cuda":
    print("   ✅ Using NVIDIA GPU acceleration")
else:
    print("   ⚠️  Using CPU (training will be slower)")

2025-11-24 14:23:49,417 - INFO - ✅ Weights & Biases (wandb) disabled
2025-11-24 14:23:49,430 - INFO - 
✅ MPS (Metal) available: True
2025-11-24 14:23:49,430 - INFO -    Using Apple Silicon GPU acceleration
2025-11-24 14:23:49,430 - INFO -    Device: mps



🖥️  Training will use device: mps
   ✅ Using Apple Silicon GPU acceleration (Metal Performance Shaders)
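
The MPS > CUDA > CPU preference order can be written as a small pure function. This is a sketch of the ordering only, not the implementation of `EnvironmentSetup.get_device()`; the name `select_device` is hypothetical.

```python
def select_device(mps_available: bool, cuda_available: bool) -> str:
    """Pick a device name using the MPS > CUDA > CPU preference order."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


print(select_device(mps_available=True, cuda_available=False))   # mps
print(select_device(mps_available=False, cuda_available=True))   # cuda
print(select_device(mps_available=False, cuda_available=False))  # cpu
```

In a PyTorch session the two flags would typically come from `torch.backends.mps.is_available()` and `torch.cuda.is_available()`, and the returned name can be passed to `torch.device()`.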


## 5. Load and Prepare Datasets


In [6]:
# Initialize dataset processor
logger.info("\n📥 Preparing datasets...")
dataset_processor = DatasetProcessor(config)

# Load and prepare training/validation datasets
train_dataset, val_dataset, mappings, coref_info = dataset_processor.prepare_datasets()

print("\n✅ Datasets prepared:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")

2025-11-24 14:23:49,434 - INFO - 
📥 Preparing datasets...
2025-11-24 14:23:49,857 - INFO - 
📥 Loading training samples from /Users/hannes/opensource/kiji-proxy/dataset/training_samples...
2025-11-24 14:23:49,857 - INFO - Found 5 JSON files
2025-11-24 14:23:49,861 - INFO - ✅ Loaded 5 training samples
2025-11-24 14:23:49,951 - INFO - ✅ Label mappings saved to /Users/hannes/opensource/kiji-proxy/pii_model/label_mappings.json
2025-11-24 14:23:49,952 - INFO - 
📊 Dataset Summary:
2025-11-24 14:23:49,952 - INFO -   Training samples: 4
2025-11-24 14:23:49,952 - INFO -   Validation samples: 1
2025-11-24 14:23:49,952 - INFO -   PII labels: 49
2025-11-24 14:23:49,952 - INFO -   Co-reference labels: 4



✅ Datasets prepared:
  Training samples: 4
  Validation samples: 1
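
The 4/1 split is what a 20% hold-out of 5 samples produces. The sketch below shows one way such a split could be computed; it is not the `DatasetProcessor` implementation, and the helper `split_samples` and its rounding rule are assumptions.

```python
import random


def split_samples(samples, eval_ratio=0.2, seed=42):
    """Shuffle deterministically, then hold out a fraction for validation."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    # Keep at least one validation sample, even for very small datasets.
    n_eval = max(1, round(len(shuffled) * eval_ratio))
    return shuffled[n_eval:], shuffled[:n_eval]


train, val = split_samples(["a", "b", "c", "d", "e"])
print(len(train), len(val))  # 4 1
```

Fixing the seed keeps the split reproducible across runs, which matters when comparing metrics between training configurations.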


## 6. Initialize Model and Trainer


In [7]:
# Initialize trainer
logger.info("\n🔧 Initializing trainer...")
trainer = PIITrainer(config)

# Load label mappings
trainer.load_label_mappings(mappings, coref_info)

# Initialize model
trainer.initialize_model()

print("\n✅ Model and trainer initialized successfully!")

2025-11-24 14:23:49,957 - INFO - 
🔧 Initializing trainer...
2025-11-24 14:23:50,320 - INFO - ✅ Loaded 49 PII label mappings
2025-11-24 14:23:50,320 - INFO - ✅ Loaded 4 co-reference label mappings
2025-11-24 14:23:50,462 - INFO - ✅ Initialized multi-task loss (PII: 49 classes, Co-ref: 4 classes)
2025-11-24 14:23:50,462 - INFO - ✅ Model initialized with 49 PII labels and 4 co-reference labels



✅ Model and trainer initialized successfully!


## 7. Train the Model


In [8]:
# Start training
logger.info("\n🏋️  Starting training...")
logger.info("=" * 60)

start_time = time.time()
trained_trainer = trainer.train(train_dataset, val_dataset)
training_time = time.time() - start_time

logger.info(f"\n⏱️  Training completed in {training_time / 60:.1f} minutes")

2025-11-24 14:23:50,469 - INFO - 
🏋️  Starting training...
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
2025-11-24 14:23:50,561 - INFO - ✅ Using MultiTaskTrainer with multi-task loss
2025-11-24 14:23:50,561 - INFO - 
🏋️  Starting multi-task training...


Step,Training Loss,Validation Loss


2025-11-24 14:23:53,386 - INFO - 
✅ Training completed. Model saved to /Users/hannes/opensource/kiji-proxy/pii_model
2025-11-24 14:23:53,386 - INFO - 
⏱️  Training completed in 0.0 minutes


## 8. Evaluate the Model


In [9]:
# Evaluate on validation set
logger.info("\n📊 Evaluating model...")
results = trainer.evaluate(val_dataset, trained_trainer)


def fmt_metric(key):
    """Format a metric to four decimals; return 'N/A' if it is missing."""
    # Applying ':.4f' directly to the 'N/A' fallback string would raise ValueError.
    value = results.get(key)
    return "N/A" if value is None else f"{value:.4f}"


# Display key metrics
print("\n" + "=" * 60)
print("📊 EVALUATION RESULTS")
print("=" * 60)
print("\n🔍 PII Detection Metrics:")
print(f"  F1 (weighted): {fmt_metric('eval_pii_f1_weighted')}")
print(f"  F1 (macro): {fmt_metric('eval_pii_f1_macro')}")
print(f"  Precision (weighted): {fmt_metric('eval_pii_precision_weighted')}")
print(f"  Precision (macro): {fmt_metric('eval_pii_precision_macro')}")
print(f"  Recall (weighted): {fmt_metric('eval_pii_recall_weighted')}")
print(f"  Recall (macro): {fmt_metric('eval_pii_recall_macro')}")

if "eval_coref_f1_weighted" in results:
    print("\n🔍 Co-reference Detection Metrics:")
    print(f"  F1 (weighted): {fmt_metric('eval_coref_f1_weighted')}")
    print(f"  F1 (macro): {fmt_metric('eval_coref_f1_macro')}")
    print(f"  Precision (weighted): {fmt_metric('eval_coref_precision_weighted')}")
    print(f"  Precision (macro): {fmt_metric('eval_coref_precision_macro')}")
    print(f"  Recall (weighted): {fmt_metric('eval_coref_recall_weighted')}")
    print(f"  Recall (macro): {fmt_metric('eval_coref_recall_macro')}")

2025-11-24 14:23:53,391 - INFO - 
📊 Evaluating model...


2025-11-24 14:23:53,711 - INFO - 
📊 Evaluation Results:
2025-11-24 14:23:53,711 - INFO - 
🔍 PII Detection Metrics:
2025-11-24 14:23:53,711 - INFO -   F1:
2025-11-24 14:23:53,711 - INFO -     eval_pii_f1: 0.0143
2025-11-24 14:23:53,712 - INFO -     eval_pii_f1_macro: 0.0174
2025-11-24 14:23:53,712 - INFO -     eval_pii_f1_weighted: 0.0143
2025-11-24 14:23:53,712 - INFO -   PRECISION:
2025-11-24 14:23:53,712 - INFO -     eval_pii_precision_macro: 0.0145
2025-11-24 14:23:53,712 - INFO -     eval_pii_precision_weighted: 0.0119
2025-11-24 14:23:53,713 - INFO -   RECALL:
2025-11-24 14:23:53,713 - INFO -     eval_pii_recall_macro: 0.0217
2025-11-24 14:23:53,713 - INFO -     eval_pii_recall_weighted: 0.0179
2025-11-24 14:23:53,713 - INFO - 
🔍 Co-reference Detection Metrics:
2025-11-24 14:23:53,713 - INFO -   F1:
2025-11-24 14:23:53,713 - INFO -     eval_coref_f1: 0.0785
2025-11-24 14:23:53,714 - INFO -     eval_coref_f1_macro: 0.0505
2025-11-24 14:23:53,714 - INFO -     eval_coref_f1_


📊 EVALUATION RESULTS

🔍 PII Detection Metrics:
  F1 (weighted): 0.0143
  F1 (macro): 0.0174
  Precision (weighted): 0.0119
  Precision (macro): 0.0145
  Recall (weighted): 0.0179
  Recall (macro): 0.0217

🔍 Co-reference Detection Metrics:
  F1 (weighted): 0.0785
  F1 (macro): 0.0505
  Precision (weighted): 0.4206
  Precision (macro): 0.1852
  Recall (weighted): 0.0536
  Recall (macro): 0.0478
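
The gap between weighted and macro scores (for example, co-reference precision of 0.4206 versus 0.1852) comes from how per-class scores are averaged: macro gives every class equal weight, while weighted scales each class by its support. The pure-Python sketch below illustrates the difference for F1 on a toy example. The labels `O`, `NAME`, and `EMAIL` and both helper functions are hypothetical, and the project itself may compute these metrics differently (for instance via scikit-learn).

```python
from collections import Counter


def per_class_f1(y_true, y_pred):
    """Return {label: (f1, support)} for every label that occurs in y_true."""
    support = Counter(y_true)
    scores = {}
    for label in support:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        predicted = sum(p == label for p in y_pred)
        precision = tp / predicted if predicted else 0.0
        recall = tp / support[label]
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        scores[label] = (f1, support[label])
    return scores


def macro_and_weighted(scores):
    """Average per-class F1 with equal weights (macro) and by support (weighted)."""
    total = sum(n for _, n in scores.values())
    macro = sum(f1 for f1, _ in scores.values()) / len(scores)
    weighted = sum(f1 * n for f1, n in scores.values()) / total
    return macro, weighted


# The frequent class "O" is predicted well; the rare class "NAME" is missed.
y_true = ["O", "O", "O", "O", "NAME", "EMAIL"]
y_pred = ["O", "O", "O", "O", "O", "EMAIL"]
macro, weighted = macro_and_weighted(per_class_f1(y_true, y_pred))
print(f"macro F1: {macro:.4f}, weighted F1: {weighted:.4f}")
# macro F1: 0.6296, weighted F1: 0.7593
```

The missed rare class pulls the macro score down more than the weighted one, which is why macro metrics are the more sensitive signal when rare PII types matter.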


## 9. Model Saved Locally


In [10]:
# Model saved locally
model_path = Path(config.output_dir).absolute()
print(f"\n💾 Model saved locally at: {model_path}")

# List model files
if model_path.exists():
    model_files = list(model_path.glob("*"))
    print("\n📁 Model files:")
    for f in sorted(model_files):
        if f.is_file():
            size_mb = f.stat().st_size / (1024 * 1024)
            print(f"   {f.name} ({size_mb:.2f} MB)")

    print("\n✅ Model is ready to use!")
    print("💡 You can load it using:")
    print("   from model.model import MultiTaskPIIDetectionModel")
    print(f"   model = MultiTaskPIIDetectionModel.from_pretrained('{model_path}')")
else:
    print(f"⚠️  Model directory not found: {model_path}")


💾 Model saved locally at: /Users/hannes/opensource/kiji-proxy/pii_model

📁 Model files:
   README.md (0.01 MB)
   label_mappings.json (0.00 MB)
   model.safetensors (248.85 MB)
   special_tokens_map.json (0.00 MB)
   tokenizer.json (0.64 MB)
   tokenizer_config.json (0.00 MB)
   training_args.bin (0.01 MB)
   vocab.txt (0.20 MB)

✅ Model is ready to use!
💡 You can load it using:
   from model.model import MultiTaskPIIDetectionModel
   model = MultiTaskPIIDetectionModel.from_pretrained('/Users/hannes/opensource/kiji-proxy/pii_model')


## 10. Final Summary


In [11]:
print("\n" + "=" * 60)
print("🎉 TRAINING COMPLETE!")
print("=" * 60)
print("\n📊 Final Metrics:")
# Fall back to "N/A" instead of applying a float format to a missing metric
pii_f1 = results.get("eval_pii_f1_weighted", results.get("eval_pii_f1"))
print(f"  PII F1 (weighted): {'N/A' if pii_f1 is None else format(pii_f1, '.4f')}")
if "eval_coref_f1_weighted" in results:
    print(f"  Co-reference F1 (weighted): {results['eval_coref_f1_weighted']:.4f}")
print(f"\n💾 Model saved to: {config.output_dir}")
print("=" * 60)


🎉 TRAINING COMPLETE!

📊 Final Metrics:
  PII F1 (weighted): 0.0143
  Co-reference F1 (weighted): 0.0785

💾 Model saved to: /Users/hannes/opensource/kiji-proxy/pii_model


## Next Steps

1. **Locate the model**: The trained model is saved locally in `config.output_dir`
2. **Use the model**: Load it using `MultiTaskPIIDetectionModel.from_pretrained()`
3. **Evaluate on a test set**: Use the `evaluate()` method with your test dataset
4. **Fine-tune**: The run above used only 5 samples and produced near-zero F1 scores, so add training data, adjust hyperparameters, and retrain as needed

### Model Files
- `model.safetensors`: Model weights
- `label_mappings.json`: Label mappings for both tasks
- `tokenizer.json`, `tokenizer_config.json`, `special_tokens_map.json`, `vocab.txt`: Tokenizer files
- `training_args.bin`: Saved training arguments
