# Model Training From Pretrained Model Creating Mini Piidgeon models

In [8]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForTokenClassification,AutoConfig, BertTokenizerFast
import pytorch_lightning as pl
import ast
from torchmetrics import Precision, Recall, F1Score, Accuracy
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from typing import Optional, Dict, Any, List
import numpy as np
import os
import json
from torchmetrics.classification import MulticlassPrecision
import datetime
from safetensors.torch import save_file,load_file


class DetailedEarlyStopping(EarlyStopping):
    def __init__(self, *args, **kwargs):
        # Override monitor parameter to use a custom metric
        kwargs['monitor'] = 'avg_target_precision'  # Custom metric name
        super().__init__(*args, **kwargs)
        self.best_metrics = None

    def _run_early_stopping_check(self, trainer):
        # Calculate average of precision for labels 1 and 2
        precision_1 = trainer.callback_metrics.get('val_precision_label_1', torch.tensor(0.0))
        precision_2 = trainer.callback_metrics.get('val_precision_label_2', torch.tensor(0.0))
        
        # Calculate average precision for monitored labels
        avg_precision = (precision_1 + precision_2) / 2
        
        # Update the monitor value in trainer's callback metrics
        trainer.callback_metrics['avg_target_precision'] = avg_precision
        
        stop_training = super()._run_early_stopping_check(trainer)
        
        if stop_training:
            print("\nEarly stopping triggered! Saving best model and metrics...")
            
            # Save the model
            best_model_path = os.path.join(trainer.checkpoint_callback.dirpath, 'best_model_early_stop')
            trainer.lightning_module.model.save_pretrained(best_model_path)
            
            # Store best metrics
            self.best_metrics = {
                name: value.item() if hasattr(value, 'item') else value
                for name, value in trainer.callback_metrics.items()
            }
            
            # Print detailed metrics
            print("\nBest Model Metrics:")
            print(f"Precision Label 1: {precision_1:.4f}")
            print(f"Precision Label 2: {precision_2:.4f}")
            print(f"Average Target Precision: {avg_precision:.4f}")

        return stop_training

class LabelMapper:
    def __init__(self):
        # Start with 'O' as it's always present
        self.label_to_id: Dict[str, int] = {'O': 0}
        self.id_to_label: Dict[int, str] = {0: 'O'}
        self.num_labels: int = 1
        
    def fit(self, label_sequences: List[List[str]]) -> None:
        """
        Fit the mapper on a list of label sequences, combining B- and I- prefixes.
        Args:
            label_sequences: List of label sequences, where each sequence is a list of string labels
        """
        # Collect all unique base labels (without B- or I- prefixes)
        base_labels = set()
        for sequence in label_sequences:
            for label in sequence:
                if label != 'O':
                    # Strip B- or I- prefix and add base label
                    base_label = label[2:] if label.startswith(('B-', 'I-')) else label
                    base_labels.add(base_label)
        
        # Sort base labels for consistency
        sorted_base_labels = sorted(base_labels)
        
        # Create mappings for O, B- and I- variants
        label_mappings = {'O': 0}  # Start with O
        current_idx = 1
        
        # Create id_to_label mapping that properly preserves B- prefixes
        id_to_label_mapping = {0: 'O'}

        for base_label in sorted_base_labels:
            # Add B- and I- variants with same index
            b_label = f'B-{base_label}'
            i_label = f'I-{base_label}'
            label_mappings[b_label] = current_idx
            label_mappings[i_label] = current_idx
            # Store the B- variant in id_to_label mapping
            id_to_label_mapping[current_idx] = b_label
            current_idx += 1
            
        # Store mappings
        self.label_to_id = label_mappings
        self.id_to_label = id_to_label_mapping
        self.num_labels = current_idx
        
        # Print mapping information
        print(f"\nFound {len(base_labels)} unique base labels (excluding O)")
        print(f"Total number of labels after mapping: {self.num_labels}")
        print("\nLabel mapping:")
        for label, idx in sorted(self.label_to_id.items()):
            print(f"{label}: {idx}")

    def encode(self, labels: List[str]) -> List[int]:
        """Convert string labels to IDs, maintaining B-/I- sequence."""
        return [self.label_to_id[label] for label in labels]

    def decode(self, ids: List[int]) -> List[str]:
        """
        Convert IDs back to string labels.
        For non-O labels, uses B- prefix for first token of an entity,
        I- prefix for subsequent tokens of the same entity.
        """
        decoded_labels = []
        prev_id = 0  # O tag
        
        for id in ids:
            if id == 0:
                decoded_labels.append('O')
                prev_id = 0
            else:
                base_label = self.id_to_label[id][2:]  # Strip B- prefix from stored label
                if id != prev_id:
                    # New entity starts
                    decoded_labels.append(f'B-{base_label}')
                else:
                    # Continue existing entity
                    decoded_labels.append(f'I-{base_label}')
                prev_id = id
                
        return decoded_labels

    def decode_with_bio(self, logits: torch.Tensor, attention_mask: torch.Tensor) -> List[List[str]]:
        """
        Decode model logits to BIO labels, handling B-/I- prefixes properly.
        
        Args:
            logits: Model logits (batch_size, sequence_length, num_labels)
            attention_mask: Attention mask (batch_size, sequence_length)
            
        Returns:
            List of label sequences with proper BIO tagging
        """
        predictions = torch.argmax(logits, dim=-1)  # (batch_size, sequence_length)
        batch_labels = []
        
        for pred, mask in zip(predictions, attention_mask):
            # Only process tokens that aren't padding
            valid_length = torch.sum(mask).item()
            pred = pred[:valid_length].cpu().tolist()
            
            # Convert to BIO format
            labels = self.decode(pred)
            batch_labels.append(labels)
            
        return batch_labels

class PiiDataset(Dataset):
    def __init__(self, data: pd.DataFrame, tokenizer, label_mapper: LabelMapper = None, max_length: int = 128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapper = label_mapper
        
        # Validate source_text column
        if 'source_text' not in self.data.columns:
            raise ValueError("DataFrame must contain 'source_text' column")
        
        # Ensure source_text is string type
        self.data['source_text'] = self.data['source_text'].astype(str)

    def __len__(self):
        return len(self.data)
    
    def parse_labels(self, labels_str) -> List[str]:
        """Safely parse label strings into list of label strings."""
        try:
            if isinstance(labels_str, str):
                if '[' in labels_str:
                    return ast.literal_eval(labels_str)
                return labels_str.split(',')
            elif isinstance(labels_str, list):
                return labels_str
            return ['O'] * self.max_length
        except Exception as e:
            print(f"Error parsing labels: {e}, value: {labels_str}")
            return ['O'] * self.max_length

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Ensure text is string type
        text = str(row['source_text'])
        
        # Handle empty strings
        if not text.strip():
            text = " "  # Use single space for empty strings
            
        try:
            encoding = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
        except Exception as e:
            print(f"Tokenization error at index {idx}: {e}")
            print(f"Text: {text}")
            raise
        
        if 'token_type_ids' in encoding:
            del encoding['token_type_ids']
            
        encoding = {key: value.squeeze(0) for key, value in encoding.items()}
        
        # Parse and process labels
        labels = self.parse_labels(row['mbert_token_classes'])
        if self.label_mapper:
            labels = self.label_mapper.encode(labels)
        labels = torch.tensor(labels, dtype=torch.long)
            
        if len(labels) < self.max_length:
            labels = torch.cat([labels, torch.zeros(self.max_length - len(labels), dtype=torch.long)])
        elif len(labels) > self.max_length:
            labels = labels[:self.max_length]

        return {**encoding, 'labels': labels}

class PiiDataModule(pl.LightningDataModule):
    def __init__(self, train_df, val_df, batch_size=16):
        super().__init__()
        # Ensure DataFrames are copied to avoid modifying original data
        self.train_df = train_df.copy()
        self.val_df = val_df.copy()
        self.batch_size = batch_size
        self.tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-multilingual-cased")
        self.label_mapper = LabelMapper()
        
        # Validate and clean data
        self._validate_and_clean_data()

    def _validate_and_clean_data(self):
        """Validate and clean the input DataFrames."""
        required_columns = ['source_text', 'mbert_token_classes']
        
        for df_name, df in [('train', self.train_df), ('val', self.val_df)]:
            # Check required columns
            missing_cols = [col for col in required_columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns in {df_name}_df: {missing_cols}")
            
            # Convert source_text to string
            df['source_text'] = df['source_text'].astype(str)
            
            # Remove empty rows
            empty_mask = df['source_text'].str.strip().eq('')
            if empty_mask.any():
                print(f"Warning: Removing {empty_mask.sum()} empty rows from {df_name}_df")
                if df_name == 'train':
                    self.train_df = df[~empty_mask].reset_index(drop=True)
                else:
                    self.val_df = df[~empty_mask].reset_index(drop=True)

    def setup(self, stage=None):
        # Create and fit label mapper on training data
        train_labels = [self.parse_labels(labels) for labels in self.train_df['mbert_token_classes']]
        self.label_mapper.fit(train_labels)
        
        # Create datasets
        self.train_dataset = PiiDataset(self.train_df, self.tokenizer, self.label_mapper)
        self.val_dataset = PiiDataset(self.val_df, self.tokenizer, self.label_mapper)
    
    def parse_labels(self, labels_str):
        """Helper method to parse labels for fitting the mapper."""
        try:
            if isinstance(labels_str, str):
                if '[' in labels_str:
                    return ast.literal_eval(labels_str)
                return labels_str.split(',')
            elif isinstance(labels_str, list):
                return labels_str
            return ['O']
        except Exception as e:
            print(f"Error parsing labels during setup: {e}")
            return ['O']

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=False,
            persistent_workers=False,
            prefetch_factor=2
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=2,
            pin_memory=False,
            persistent_workers=False,
            prefetch_factor=2
        )

class PiiModel(pl.LightningModule):
    def __init__(self, num_labels: int, model_name="iiiorg/piiranha-v1-detect-personal-information", lr=2.5e-5):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.lr = lr
        
        # Load model configuration
        config = AutoConfig.from_pretrained(model_name)
        config.num_labels = num_labels
        config.use_cache = False
        
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name,
            config=config,
            ignore_mismatched_sizes=True
        )
        
        self.model.gradient_checkpointing_enable()
        
        # Initialize per-label metrics
        self.train_metrics = None
        self.val_metrics = None
        self.test_metrics = None
        
        # Track best validation metrics
        self.best_val_metrics = {}

    def _create_metrics(self) -> Dict[str, Any]:
        """Create metrics including per-label metrics."""
        metrics = {
            'precision': Precision(task="multiclass", num_classes=self.num_labels, average='macro'),
            'recall': Recall(task="multiclass", num_classes=self.num_labels, average='macro'),
            'f1': F1Score(task="multiclass", num_classes=self.num_labels, average='macro'),
            'accuracy': Accuracy(task="multiclass", num_classes=self.num_labels, average='macro')
        }
        
        # Add per-label metrics
        for i in range(self.num_labels):
            metrics.update({
                f'precision_label_{i}': Precision(task="multiclass", num_classes=self.num_labels, average=None),
                f'recall_label_{i}': Recall(task="multiclass", num_classes=self.num_labels, average=None),
                f'f1_label_{i}': F1Score(task="multiclass", num_classes=self.num_labels, average=None)
            })
        
        return {name: metric.to(self.device) for name, metric in metrics.items()}
    
    def on_train_start(self):
        """Initialize metrics on the correct device at the start of training."""
        if self.train_metrics is None:
            self.train_metrics = self._create_metrics()
        if self.val_metrics is None:
            self.val_metrics = self._create_metrics()
        if self.test_metrics is None:
            self.test_metrics = self._create_metrics()

    def forward(self, input_ids, attention_mask, labels=None):
        """Forward pass of the model."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

    def _compute_metrics(self, preds, labels, metrics):
        """Compute all metrics including per-label metrics."""
        preds = preds.to(self.device)
        labels = labels.to(self.device)
        
        results = {}
        for name, metric in metrics.items():
            if 'label_' in name:
                # For per-label metrics, get the specific label index
                label_idx = int(name.split('_')[-1])
                value = metric(preds, labels)[label_idx]
                results[name] = value
            else:
                results[name] = metric(preds, labels)
        
        return results

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Training step."""
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        
        # Initialize metrics if not already done
        if self.train_metrics is None:
            self.train_metrics = self._create_metrics()
        
        # Compute and log metrics
        metrics = self._compute_metrics(preds, batch['labels'], self.train_metrics)
        
        # Log metrics
        self.log("train_loss", loss, prog_bar=True)
        for name, value in metrics.items():
            self.log(f"train_{name}", value, prog_bar=True)
            
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        
        if self.val_metrics is None:
            self.val_metrics = self._create_metrics()
        
        metrics = self._compute_metrics(preds, batch['labels'], self.val_metrics)
        
        # Log all metrics
        self.log("val_loss", loss, prog_bar=True)
        for name, value in metrics.items():
            self.log(f"val_{name}", value, prog_bar=True)
        
        # Update best metrics if needed
        if not self.best_val_metrics or metrics['f1'] > self.best_val_metrics.get('f1', 0):
            self.best_val_metrics = metrics

    def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
        """Test step."""
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        
        # Initialize metrics if not already done
        if self.test_metrics is None:
            self.test_metrics = self._create_metrics()
        
        # Compute and log metrics
        metrics = self._compute_metrics(preds, batch['labels'], self.test_metrics)
        
        # Log metrics
        self.log("test_loss", loss)
        for name, value in metrics.items():
            self.log(f"test_{name}", value)

    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=2,
            verbose=True
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def predict_bio_labels(self, input_ids, attention_mask, label_mapper):
        """Get predictions with proper BIO tagging."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return label_mapper.decode_with_bio(outputs.logits, attention_mask)

def train_pii_model(train_df, val_df, batch_size=16, max_epochs=10, save_dir="pii_model"):
    """Train the PII model with aggressive memory optimizations."""
    import gc
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Clear GPU cache and run garbage collection
    if device.type == "cuda":
        torch.cuda.empty_cache()
        gc.collect()

    os.makedirs(save_dir, exist_ok=True)
    checkpoints_dir = os.path.join(save_dir, "checkpoints")
    metrics_dir = os.path.join(save_dir, "metrics")
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    # Initialize data module with reduced num_workers
    data_module = PiiDataModule(
        train_df=train_df,
        val_df=val_df,
        batch_size=batch_size
    )
    data_module.setup()
    
    # Initialize model with memory optimizations
    model = PiiModel(
        num_labels=data_module.label_mapper.num_labels,
        model_name="iiiorg/piiranha-v1-detect-personal-information"
    )
    
    initial_weights_dir = os.path.join(save_dir, "initial_weights")
    os.makedirs(initial_weights_dir, exist_ok=True)
    model.model.save_pretrained(initial_weights_dir)
    print(f"Initial weights of the model saved to {initial_weights_dir}")
    
    # Additional memory optimizations for the model
    if device.type == "cuda":
        model.model.gradient_checkpointing_enable()
        # Optimize memory allocation for attention mechanisms
        model.model.config.use_cache = False
    
    label_mapper_path = os.path.join(save_dir, "label_mapper.json")
    label_mapper_data = {
        'label_to_id': data_module.label_mapper.label_to_id,
        'id_to_label': data_module.label_mapper.id_to_label,
        'num_labels': data_module.label_mapper.num_labels
    }
    with open(label_mapper_path, 'w') as f:
        json.dump(label_mapper_data, f, indent=2)

    class EnhancedMetricCheckpoint(pl.callbacks.ModelCheckpoint):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.metrics_history = []
            self.last_metrics=None
        def on_validation_end(self, trainer, pl_module):
            # Always save the current epoch checkpoint
            current_epoch = trainer.current_epoch
            current_loss = trainer.callback_metrics.get('val_loss', 0.0)
            
            # Create checkpoint filename
            checkpoint_filename = f'checkpoint-epoch-{current_epoch:02d}-loss-{current_loss:.4f}.ckpt'
            checkpoint_path = os.path.join(self.dirpath, checkpoint_filename)
            
            # Save checkpoint
            trainer.save_checkpoint(checkpoint_path)
            print(f"\nSaved checkpoint for epoch {current_epoch}: {checkpoint_filename}")
            
            # Collect metrics
            metrics = {
                'epoch': current_epoch,
                'global_step': trainer.global_step,
                'timestamp': datetime.datetime.now().isoformat(),
                'train_metrics': {},
                'val_metrics': {}
            }
            
            for key, value in trainer.callback_metrics.items():
                if isinstance(value, torch.Tensor):
                    value = value.item()
                if key.startswith('train_'):
                    metrics['train_metrics'][key] = value
                elif key.startswith('val_'):
                    metrics['val_metrics'][key] = value
            
            self.metrics_history.append(metrics)
            self.last_metrics = metrics
            # Save metrics
            metrics_file = os.path.join(metrics_dir, f'metrics_epoch_{current_epoch}.json')
            with open(metrics_file, 'w') as f:
                json.dump(metrics, f, indent=2)
            
            # Save the model in safetensors format
            model_dir = os.path.join(self.dirpath, f'model_epoch_{current_epoch:02d}')
            os.makedirs(model_dir, exist_ok=True)
            config_path = os.path.join(model_dir, 'config.json')
            model_weights_path = os.path.join(model_dir, 'model.safetensors')

            # Save model configuration
            pl_module.model.config.save_pretrained(model_dir)

            # Save model weights using safetensors
            save_file(pl_module.model.state_dict(), model_weights_path)

            print(f"\nSaved model configuration and weights for epoch {current_epoch}.")


            # Update best metrics if needed
            if self.best_model_path:
                best_metrics_file = os.path.join(metrics_dir, 'best_metrics.json')
                with open(best_metrics_file, 'w') as f:
                    json.dump(metrics, f, indent=2)

    
    class FinalEpochCallback(pl.Callback):
        def on_train_end(self, trainer, pl_module):
            # Save final model state
            final_model_path = os.path.join(save_dir, "final_model")
            os.makedirs(final_model_path, exist_ok=True)
            pl_module.model.save_pretrained(final_model_path)
            print(f"\nSaved final model to {final_model_path}")

    # Setup callbacks with reduced checkpoint frequency
    checkpoint_callback = EnhancedMetricCheckpoint(
        dirpath=checkpoints_dir,
        filename='checkpoint-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        mode='min',
        save_top_k=3,  # Save top 3 models
        save_last=True,  # Save last model
        every_n_epochs=1  # Save checkpoint every epoch
    )
    
    early_stopping = DetailedEarlyStopping(
        patience=1,
        mode="max"
    )
    
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
    
    # Setup logger with reduced logging frequency
    logger = TensorBoardLogger(
        "tb_logs", 
        name="piiranha_training",
        log_graph=False,  # Disable computational graph logging
        max_queue=10
    )
    callbacks = [
        checkpoint_callback, 
        early_stopping, 
        pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
        FinalEpochCallback()
    ]
    # Initialize trainer with memory optimizations
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        gradient_clip_val=1.0,
        log_every_n_steps=50,  # Reduced logging frequency
        precision="16-mixed",  # Use mixed precision training with reduced precision
        accumulate_grad_batches=2,  # Gradient accumulation to reduce memory usage
        strategy='ddp_notebook' if torch.cuda.is_available() else "auto",
        enable_progress_bar=True,
        enable_model_summary=False,  # Disable model summary to save memory
        inference_mode=True,  # Enable inference mode optimization
        profiler=None,  # Disable profiling
    )

    
    # Modified DataLoader settings in PiiDataModule
    data_module.train_dataloader = lambda: DataLoader(
        data_module.train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,  # Reduced number of workers
        pin_memory=False,  # Disable pin_memory
        persistent_workers=False,  # Disable persistent workers
        prefetch_factor=2  # Reduced prefetch factor
    )
    
    data_module.val_dataloader = lambda: DataLoader(
        data_module.val_dataset,
        batch_size=batch_size,
        num_workers=2,
        pin_memory=False,
        persistent_workers=False,
        prefetch_factor=2
    )
    
    # Train the model
    trainer.fit(model, datamodule=data_module)
    
    # Print final best metrics
    final_metrics = {
        'best_metrics': model.best_val_metrics,
        'final_metrics': checkpoint_callback.last_metrics if checkpoint_callback.last_metrics else {},
        'training_completed': True,
        'total_epochs': trainer.current_epoch,
        'early_stopped': trainer.should_stop,
        'best_model_path': checkpoint_callback.best_model_path,
        'last_model_path': checkpoint_callback.last_model_path
    }
    
    with open(os.path.join(save_dir, "training_summary.json"), 'w') as f:
        json.dump({k: float(v) if isinstance(v, torch.Tensor) else v 
                  for k, v in final_metrics.items()}, f, indent=2)
    
    print("\nTraining Summary:")
    print(f"Total epochs completed: {trainer.current_epoch + 1}")
    print(f"Best model saved at: {checkpoint_callback.best_model_path}")
    print("\nFinal Best Metrics:")
    for name, value in model.best_val_metrics.items():
        if isinstance(value, torch.Tensor):
            value = value.item()
        print(f"{name}: {value:.4f}")

    # Save the model and related files
    final_model_path = os.path.join(save_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    model.model.save_pretrained(final_model_path)
    
    print(f"\nModel saved to {final_model_path}")
    
    return model, data_module.label_mapper

def load_checkpoint(checkpoint_path: str, model_dir="pii_model"):
    """
    Load a specific checkpoint.
    
    Args:
        checkpoint_path: Path to the checkpoint file
        model_dir: Directory containing label_mapper.json and tokenizer files
    """
    
    # Load label mapper
    with open(os.path.join(model_dir, 'label_mapper.json'), 'r') as f:
        label_mapper_data = json.load(f)
    
    label_mapper = LabelMapper()
    label_mapper.label_to_id = label_mapper_data['label_to_id']
    label_mapper.id_to_label = {int(k): v for k, v in label_mapper_data['id_to_label'].items()}
    label_mapper.num_labels = label_mapper_data['num_labels']
    
    # Load tokenizer
    tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-multilingual-cased")
    
    # Initialize model
    model = PiiModel(num_labels=label_mapper.num_labels)
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    
    return model, tokenizer, label_mapper

def list_checkpoints(save_dir="pii_model"):
    """List all available checkpoints."""
    import os
    
    checkpoint_dir = os.path.join(save_dir, "epoch_checkpoints")
    best_model_dir = os.path.join(save_dir, "best_model")
    
    print("Regular checkpoints:")
    if os.path.exists(checkpoint_dir):
        checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')])
        for ckpt in checkpoints:
            print(f"  - {ckpt}")
    else:
        print("  No regular checkpoints found")
        
    print("\nBest model checkpoints:")
    if os.path.exists(best_model_dir):
        checkpoints = sorted([f for f in os.listdir(best_model_dir) if f.endswith('.ckpt')])
        for ckpt in checkpoints:
            print(f"  - {ckpt}")
    else:
        print("  No best model checkpoints found")

def list_saved_metrics(save_dir="pii_model"):
    """List all saved metrics and checkpoints."""
    metrics_dir = os.path.join(save_dir, "metrics")
    checkpoints_dir = os.path.join(save_dir, "checkpoints")
    
    print("Available metric files:")
    if os.path.exists(metrics_dir):
        metric_files = sorted([f for f in os.listdir(metrics_dir) if f.endswith('.json')])
        for file in metric_files:
            print(f"  - {file}")
    else:
        print("  No metric files found")
    
    print("\nAvailable checkpoints:")
    if os.path.exists(checkpoints_dir):
        checkpoint_files = sorted([f for f in os.listdir(checkpoints_dir) if f.endswith('.ckpt')])
        for file in checkpoint_files:
            print(f"  - {file}")
    else:
        print("  No checkpoints found")

In [2]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"]='expandable_segments:True'

In [3]:
import pandas as pd
import json
import ast

train_df = pd.read_csv('train.csv')

print(len(train_df))
val_df = pd.read_csv('validation.csv')

def contains_labels(privacy_mask_str):
    try:
        # Convert the string representation to a list of dictionaries
        privacy_mask = ast.literal_eval(privacy_mask_str)
        # Check if any label is ACCOUNTNUM or IDCARTNUM
        return any(item['label'] in ['ACCOUNTNUM', 'IDCARDNUM'] for item in privacy_mask)
    except (ValueError, SyntaxError):
        return False

# Apply the filtering function to the DataFrame
filtered_train_df = train_df[train_df['privacy_mask'].apply(contains_labels)]
filtered_val_df = val_df[val_df['privacy_mask'].apply(contains_labels)]


325517


In [4]:
def filter_labels(labels_str, valid_labels):
    try:
        if isinstance(labels_str, str):
            labels = ast.literal_eval(labels_str)
        elif isinstance(labels_str, list):
            labels = labels_str
        else:
            return ['O'] * 128  # Assuming max length of 128

        # Replace labels not in valid_labels with 'O'
        return [label if label in valid_labels else 'O' for label in labels]
    except Exception as e:
        print(f"Error parsing labels: {e}, value: {labels_str}")
        return ['O'] * 128  # Assuming max length of 128
valid_labels = ['O', 'B-ACCOUNTNUM', 'I-ACCOUNTNUM', 'B-IDCARDNUM', 'I-IDCARDNUM']
filtered_train_df['mbert_token_classes'] = filtered_train_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))
filtered_val_df['mbert_token_classes'] = filtered_val_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_train_df['mbert_token_classes'] = filtered_train_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_val_df['mbert_token_classes'] = filtered_val_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))


In [9]:
model, label_mapper = train_pii_model(filtered_train_df, filtered_val_df,batch_size=32, save_dir="pii_model", max_epochs=10)

list_saved_metrics() # Best metrics and best Checkpoint

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]


Found 2 unique base labels (excluding O)
Total number of labels after mapping: 3

Label mapping:
B-ACCOUNTNUM: 1
B-IDCARDNUM: 2
I-ACCOUNTNUM: 1
I-IDCARDNUM: 2
O: 0


Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at iiiorg/piiranha-v1-detect-personal-information and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([18]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([18, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Initial weights of the model saved to pii_model/initial_weights


You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------




Found 2 unique base labels (excluding O)
Total number of labels after mapping: 3

Label mapping:
B-ACCOUNTNUM: 1
B-IDCARDNUM: 2
I-ACCOUNTNUM: 1
I-IDCARDNUM: 2
O: 0


/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /teamspace/studios/this_studio/pii_model/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_precision', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connector


Saved checkpoint for epoch 0: checkpoint-epoch-00-loss-0.4331.ckpt

Saved model configuration and weights for epoch 0.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Saved checkpoint for epoch 0: checkpoint-epoch-00-loss-0.0230.ckpt

Saved model configuration and weights for epoch 0.


Validation: |          | 0/? [00:00<?, ?it/s]


Saved checkpoint for epoch 1: checkpoint-epoch-01-loss-0.0207.ckpt

Saved model configuration and weights for epoch 1.


In [18]:

# Load the config
config = AutoConfig.from_pretrained("pii_model/checkpoints/model_epoch_03")

# Initialize the model with the config
model = AutoModelForTokenClassification.from_config(config)

# Load the safetensors weights
state_dict = load_file("pii_model/checkpoints/model_epoch_03/model.safetensors")

# Load the state dict into the model
model.load_state_dict(state_dict)

# Load the tokenizer
tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-multilingual-cased")

# Load the label mapper if needed
with open("pii_model/label_mapper.json", 'r') as f:
    label_mapper_data = json.load(f)

label_mapper = LabelMapper()
label_mapper.label_to_id = label_mapper_data['label_to_id']
label_mapper.id_to_label = {int(k): v for k, v in label_mapper_data['id_to_label'].items()}
label_mapper.num_labels = label_mapper_data['num_labels']

In [33]:
# Set model to evaluation mode
model.eval()

# Get the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Get the text
text = filtered_val_df.iloc[27].source_text


# Tokenize with padding and truncation
encoded_text = tokenizer(
    text,
    padding='max_length',
    truncation=True,
    max_length=128,  # or whatever max_length you used during training
    return_tensors="pt"  # Return PyTorch tensors
)

# Move inputs to the same device as model
inputs = {
    'input_ids': encoded_text['input_ids'].to(device),
    'attention_mask': encoded_text['attention_mask'].to(device)
}

# Perform inference
with torch.no_grad():  # Disable gradient calculation
    outputs = model(**inputs)

# Get predictions
predictions = torch.argmax(outputs.logits, dim=-1)

# Convert predictions to labels using label_mapper
predicted_labels = label_mapper.decode(predictions[0].cpu().tolist())

# Print results
print("Original text:", text)
print("\nPredicted labels:", predicted_labels)

# If you want to see the token-label alignment:
tokens = tokenizer.convert_ids_to_tokens(encoded_text['input_ids'][0])
print("\nToken-label alignment:")
for token, label in zip(tokens, predicted_labels):
    if token != '[PAD]':  # Skip padding tokens
        print(f"{token}: {label}")

Original text: <p>Validatie ID: 663148812, Sociale beveiliging: 862460623.</p><p>Vertrek via: Duplex 69, Schiedam, 156.</p>

Predicted labels: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-IDCARDNUM', 'I-IDCARDNUM', 'I-IDCARDNUM', 'I-IDCARDNUM', 'I-IDCARDNUM', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

Token-label alignment:
[CLS]: O
<: O
p: O
>: O
Val: O
##idat: O
##ie: O
ID: O
:: B-IDCARDNUM
663: I-IDCARDNUM
##14: I-IDCARDNUM
##8: I-IDCARDNUM
##8: I-IDCARDNUM
##1

# Accuracy dependant Piidgeon training

In [36]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"]='expandable_segments:True'

In [1]:
class DetailedEarlyStopping(EarlyStopping):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.best_metrics = None

    def _run_early_stopping_check(self, trainer):
        stop_training = super()._run_early_stopping_check(trainer)
        
        if stop_training:
            print("\nEarly stopping triggered! Saving best model and metrics...")
            
            # Save the model
            best_model_path = os.path.join(trainer.checkpoint_callback.dirpath, 'best_model_early_stop')
            trainer.lightning_module.model.save_pretrained(best_model_path)
            
            # Store best metrics
            self.best_metrics = {
                name: value.item() if hasattr(value, 'item') else value
                for name, value in trainer.callback_metrics.items()
            }
            
            # Print detailed metrics
            print("\nBest Model Metrics:")
            for name, value in self.best_metrics.items():
                print(f"{name}: {value:.4f}")

        return stop_training

class LabelMapper:
    def __init__(self):
        # Start with 'O' as it's always present
        self.label_to_id: Dict[str, int] = {'O': 0}
        self.id_to_label: Dict[int, str] = {0: 'O'}
        self.num_labels: int = 1
        
    def fit(self, label_sequences: List[List[str]]) -> None:
        """
        Fit the mapper on a list of label sequences, combining B- and I- prefixes.
        Args:
            label_sequences: List of label sequences, where each sequence is a list of string labels
        """
        # Collect all unique base labels (without B- or I- prefixes)
        base_labels = set()
        for sequence in label_sequences:
            for label in sequence:
                if label != 'O':
                    # Strip B- or I- prefix and add base label
                    base_label = label[2:] if label.startswith(('B-', 'I-')) else label
                    base_labels.add(base_label)
        
        # Sort base labels for consistency
        sorted_base_labels = sorted(base_labels)
        
        # Create mappings for O, B- and I- variants
        label_mappings = {'O': 0}  # Start with O
        current_idx = 1
        
        # Create id_to_label mapping that properly preserves B- prefixes
        id_to_label_mapping = {0: 'O'}

        for base_label in sorted_base_labels:
            # Add B- and I- variants with same index
            b_label = f'B-{base_label}'
            i_label = f'I-{base_label}'
            label_mappings[b_label] = current_idx
            label_mappings[i_label] = current_idx
            # Store the B- variant in id_to_label mapping
            id_to_label_mapping[current_idx] = b_label
            current_idx += 1
            
        # Store mappings
        self.label_to_id = label_mappings
        self.id_to_label = id_to_label_mapping
        self.num_labels = current_idx
        
        # Print mapping information
        print(f"\nFound {len(base_labels)} unique base labels (excluding O)")
        print(f"Total number of labels after mapping: {self.num_labels}")
        print("\nLabel mapping:")
        for label, idx in sorted(self.label_to_id.items()):
            print(f"{label}: {idx}")

    def encode(self, labels: List[str]) -> List[int]:
        """Convert string labels to IDs, maintaining B-/I- sequence."""
        return [self.label_to_id[label] for label in labels]

    def decode(self, ids: List[int]) -> List[str]:
        """
        Convert IDs back to string labels.
        For non-O labels, uses B- prefix for first token of an entity,
        I- prefix for subsequent tokens of the same entity.
        """
        decoded_labels = []
        prev_id = 0  # O tag
        
        for id in ids:
            if id == 0:
                decoded_labels.append('O')
                prev_id = 0
            else:
                base_label = self.id_to_label[id][2:]  # Strip B- prefix from stored label
                if id != prev_id:
                    # New entity starts
                    decoded_labels.append(f'B-{base_label}')
                else:
                    # Continue existing entity
                    decoded_labels.append(f'I-{base_label}')
                prev_id = id
                
        return decoded_labels

    def decode_with_bio(self, logits: torch.Tensor, attention_mask: torch.Tensor) -> List[List[str]]:
        """
        Decode model logits to BIO labels, handling B-/I- prefixes properly.
        
        Args:
            logits: Model logits (batch_size, sequence_length, num_labels)
            attention_mask: Attention mask (batch_size, sequence_length)
            
        Returns:
            List of label sequences with proper BIO tagging
        """
        predictions = torch.argmax(logits, dim=-1)  # (batch_size, sequence_length)
        batch_labels = []
        
        for pred, mask in zip(predictions, attention_mask):
            # Only process tokens that aren't padding
            valid_length = torch.sum(mask).item()
            pred = pred[:valid_length].cpu().tolist()
            
            # Convert to BIO format
            labels = self.decode(pred)
            batch_labels.append(labels)
            
        return batch_labels

class PiiDataset(Dataset):
    def __init__(self, data: pd.DataFrame, tokenizer, label_mapper: LabelMapper = None, max_length: int = 128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.label_mapper = label_mapper
        
        # Validate source_text column
        if 'source_text' not in self.data.columns:
            raise ValueError("DataFrame must contain 'source_text' column")
        
        # Ensure source_text is string type
        self.data['source_text'] = self.data['source_text'].astype(str)

    def __len__(self):
        return len(self.data)
    
    def parse_labels(self, labels_str) -> List[str]:
        """Safely parse label strings into list of label strings."""
        try:
            if isinstance(labels_str, str):
                if '[' in labels_str:
                    return ast.literal_eval(labels_str)
                return labels_str.split(',')
            elif isinstance(labels_str, list):
                return labels_str
            return ['O'] * self.max_length
        except Exception as e:
            print(f"Error parsing labels: {e}, value: {labels_str}")
            return ['O'] * self.max_length

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Ensure text is string type
        text = str(row['source_text'])
        
        # Handle empty strings
        if not text.strip():
            text = " "  # Use single space for empty strings
            
        try:
            encoding = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt"
            )
        except Exception as e:
            print(f"Tokenization error at index {idx}: {e}")
            print(f"Text: {text}")
            raise
        
        if 'token_type_ids' in encoding:
            del encoding['token_type_ids']
            
        encoding = {key: value.squeeze(0) for key, value in encoding.items()}
        
        # Parse and process labels
        labels = self.parse_labels(row['mbert_token_classes'])
        if self.label_mapper:
            labels = self.label_mapper.encode(labels)
        labels = torch.tensor(labels, dtype=torch.long)
            
        if len(labels) < self.max_length:
            labels = torch.cat([labels, torch.zeros(self.max_length - len(labels), dtype=torch.long)])
        elif len(labels) > self.max_length:
            labels = labels[:self.max_length]

        return {**encoding, 'labels': labels}

class PiiDataModule(pl.LightningDataModule):
    def __init__(self, train_df, val_df, batch_size=16):
        super().__init__()
        # Ensure DataFrames are copied to avoid modifying original data
        self.train_df = train_df.copy()
        self.val_df = val_df.copy()
        self.batch_size = batch_size
        self.tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-multilingual-cased")
        self.label_mapper = LabelMapper()
        
        # Validate and clean data
        self._validate_and_clean_data()

    def _validate_and_clean_data(self):
        """Validate and clean the input DataFrames."""
        required_columns = ['source_text', 'mbert_token_classes']
        
        for df_name, df in [('train', self.train_df), ('val', self.val_df)]:
            # Check required columns
            missing_cols = [col for col in required_columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Missing required columns in {df_name}_df: {missing_cols}")
            
            # Convert source_text to string
            df['source_text'] = df['source_text'].astype(str)
            
            # Remove empty rows
            empty_mask = df['source_text'].str.strip().eq('')
            if empty_mask.any():
                print(f"Warning: Removing {empty_mask.sum()} empty rows from {df_name}_df")
                if df_name == 'train':
                    self.train_df = df[~empty_mask].reset_index(drop=True)
                else:
                    self.val_df = df[~empty_mask].reset_index(drop=True)

    def setup(self, stage=None):
        # Create and fit label mapper on training data
        train_labels = [self.parse_labels(labels) for labels in self.train_df['mbert_token_classes']]
        self.label_mapper.fit(train_labels)
        
        # Create datasets
        self.train_dataset = PiiDataset(self.train_df, self.tokenizer, self.label_mapper)
        self.val_dataset = PiiDataset(self.val_df, self.tokenizer, self.label_mapper)
    
    def parse_labels(self, labels_str):
        """Helper method to parse labels for fitting the mapper."""
        try:
            if isinstance(labels_str, str):
                if '[' in labels_str:
                    return ast.literal_eval(labels_str)
                return labels_str.split(',')
            elif isinstance(labels_str, list):
                return labels_str
            return ['O']
        except Exception as e:
            print(f"Error parsing labels during setup: {e}")
            return ['O']

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=False,
            persistent_workers=False,
            prefetch_factor=2
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=2,
            pin_memory=False,
            persistent_workers=False,
            prefetch_factor=2
        )

class PiiModel(pl.LightningModule):
    def __init__(self, num_labels: int, model_name="iiiorg/piiranha-v1-detect-personal-information", lr=2.5e-5):
        super().__init__()
        self.save_hyperparameters()
        self.num_labels = num_labels
        self.lr = lr
        
        # Load model configuration
        config = AutoConfig.from_pretrained(model_name)
        config.num_labels = num_labels
        config.use_cache = False
        
        self.model = AutoModelForTokenClassification.from_pretrained(
            model_name,
            config=config,
            ignore_mismatched_sizes=True
        )
        
        self.model.gradient_checkpointing_enable()
        
        # Initialize per-label metrics
        self.train_metrics = None
        self.val_metrics = None
        self.test_metrics = None
        
        # Track best validation metrics
        self.best_val_metrics = {}

    def _create_metrics(self) -> Dict[str, Any]:
        """Create metrics including per-label metrics."""
        metrics = {
            'precision': Precision(task="multiclass", num_classes=self.num_labels, average='macro'),
            'recall': Recall(task="multiclass", num_classes=self.num_labels, average='macro'),
            'f1': F1Score(task="multiclass", num_classes=self.num_labels, average='macro'),
            'accuracy': Accuracy(task="multiclass", num_classes=self.num_labels, average='macro')
        }
        
        # Add per-label metrics
        for i in range(self.num_labels):
            metrics.update({
                f'precision_label_{i}': Precision(task="multiclass", num_classes=self.num_labels, average=None),
                f'recall_label_{i}': Recall(task="multiclass", num_classes=self.num_labels, average=None),
                f'f1_label_{i}': F1Score(task="multiclass", num_classes=self.num_labels, average=None)
            })
        
        return {name: metric.to(self.device) for name, metric in metrics.items()}
    
    def on_train_start(self):
        """Initialize metrics on the correct device at the start of training."""
        if self.train_metrics is None:
            self.train_metrics = self._create_metrics()
        if self.val_metrics is None:
            self.val_metrics = self._create_metrics()
        if self.test_metrics is None:
            self.test_metrics = self._create_metrics()

    def forward(self, input_ids, attention_mask, labels=None):
        """Forward pass of the model."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        return outputs

    def _compute_metrics(self, preds, labels, metrics):
        """Compute all metrics including per-label metrics."""
        preds = preds.to(self.device)
        labels = labels.to(self.device)
        
        results = {}
        for name, metric in metrics.items():
            if 'label_' in name:
                # For per-label metrics, get the specific label index
                label_idx = int(name.split('_')[-1])
                value = metric(preds, labels)[label_idx]
                results[name] = value
            else:
                results[name] = metric(preds, labels)
        
        return results

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Training step."""
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        
        # Initialize metrics if not already done
        if self.train_metrics is None:
            self.train_metrics = self._create_metrics()
        
        # Compute and log metrics
        metrics = self._compute_metrics(preds, batch['labels'], self.train_metrics)
        
        # Log metrics
        self.log("train_loss", loss, prog_bar=True)
        for name, value in metrics.items():
            self.log(f"train_{name}", value, prog_bar=True)
            
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        
        if self.val_metrics is None:
            self.val_metrics = self._create_metrics()
        
        metrics = self._compute_metrics(preds, batch['labels'], self.val_metrics)
        
        # Log all metrics
        self.log("val_loss", loss, prog_bar=True)
        for name, value in metrics.items():
            self.log(f"val_{name}", value, prog_bar=True)
        
        # Update best metrics if needed
        if not self.best_val_metrics or metrics['f1'] > self.best_val_metrics.get('f1', 0):
            self.best_val_metrics = metrics

    def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
        """Test step."""
        outputs = self(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels']
        )
        
        loss = outputs.loss
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        
        # Initialize metrics if not already done
        if self.test_metrics is None:
            self.test_metrics = self._create_metrics()
        
        # Compute and log metrics
        metrics = self._compute_metrics(preds, batch['labels'], self.test_metrics)
        
        # Log metrics
        self.log("test_loss", loss)
        for name, value in metrics.items():
            self.log(f"test_{name}", value)

    def configure_optimizers(self):
        """Configure optimizers and learning rate schedulers."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=2,
            verbose=True
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
            },
        }

    def predict_bio_labels(self, input_ids, attention_mask, label_mapper):
        """Get predictions with proper BIO tagging."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return label_mapper.decode_with_bio(outputs.logits, attention_mask)

def train_pii_model(train_df, val_df, batch_size=16, max_epochs=10, save_dir="pii_model"):
    """Train the PII model with aggressive memory optimizations."""
    import gc
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Clear GPU cache and run garbage collection
    if device.type == "cuda":
        torch.cuda.empty_cache()
        gc.collect()

    os.makedirs(save_dir, exist_ok=True)
    checkpoints_dir = os.path.join(save_dir, "checkpoints")
    metrics_dir = os.path.join(save_dir, "metrics")
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    # Initialize data module with reduced num_workers
    data_module = PiiDataModule(
        train_df=train_df,
        val_df=val_df,
        batch_size=batch_size
    )
    data_module.setup()
    
    # Initialize model with memory optimizations
    model = PiiModel(
        num_labels=data_module.label_mapper.num_labels,
        model_name="iiiorg/piiranha-v1-detect-personal-information"
    )
    
    initial_weights_dir = os.path.join(save_dir, "initial_weights")
    os.makedirs(initial_weights_dir, exist_ok=True)
    model.model.save_pretrained(initial_weights_dir)
    print(f"Initial weights of the model saved to {initial_weights_dir}")
    
    # Additional memory optimizations for the model
    if device.type == "cuda":
        model.model.gradient_checkpointing_enable()
        # Optimize memory allocation for attention mechanisms
        model.model.config.use_cache = False
    
    label_mapper_path = os.path.join(save_dir, "label_mapper.json")
    label_mapper_data = {
        'label_to_id': data_module.label_mapper.label_to_id,
        'id_to_label': data_module.label_mapper.id_to_label,
        'num_labels': data_module.label_mapper.num_labels
    }
    with open(label_mapper_path, 'w') as f:
        json.dump(label_mapper_data, f, indent=2)

    class EnhancedMetricCheckpoint(pl.callbacks.ModelCheckpoint):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.metrics_history = []
            self.last_metrics=None
        def on_validation_end(self, trainer, pl_module):
            # Always save the current epoch checkpoint
            current_epoch = trainer.current_epoch
            current_loss = trainer.callback_metrics.get('val_loss', 0.0)
            
            # Create checkpoint filename
            checkpoint_filename = f'checkpoint-epoch-{current_epoch:02d}-loss-{current_loss:.4f}.ckpt'
            checkpoint_path = os.path.join(self.dirpath, checkpoint_filename)
            
            # Save checkpoint
            trainer.save_checkpoint(checkpoint_path)
            print(f"\nSaved checkpoint for epoch {current_epoch}: {checkpoint_filename}")
            
            # Collect metrics
            metrics = {
                'epoch': current_epoch,
                'global_step': trainer.global_step,
                'timestamp': datetime.datetime.now().isoformat(),
                'train_metrics': {},
                'val_metrics': {}
            }
            
            for key, value in trainer.callback_metrics.items():
                if isinstance(value, torch.Tensor):
                    value = value.item()
                if key.startswith('train_'):
                    metrics['train_metrics'][key] = value
                elif key.startswith('val_'):
                    metrics['val_metrics'][key] = value
            
            self.metrics_history.append(metrics)
            self.last_metrics = metrics
            # Save metrics
            metrics_file = os.path.join(metrics_dir, f'metrics_epoch_{current_epoch}.json')
            with open(metrics_file, 'w') as f:
                json.dump(metrics, f, indent=2)
            
            # Save the model in safetensors format
            model_dir = os.path.join(self.dirpath, f'model_epoch_{current_epoch:02d}')
            os.makedirs(model_dir, exist_ok=True)
            config_path = os.path.join(model_dir, 'config.json')
            model_weights_path = os.path.join(model_dir, 'model.safetensors')

            # Save model configuration
            pl_module.model.config.save_pretrained(model_dir)

            # Save model weights using safetensors
            save_file(pl_module.model.state_dict(), model_weights_path)

            print(f"\nSaved model configuration and weights for epoch {current_epoch}.")


            # Update best metrics if needed
            if self.best_model_path:
                best_metrics_file = os.path.join(metrics_dir, 'best_metrics.json')
                with open(best_metrics_file, 'w') as f:
                    json.dump(metrics, f, indent=2)

    
    class FinalEpochCallback(pl.Callback):
        def on_train_end(self, trainer, pl_module):
            # Save final model state
            final_model_path = os.path.join(save_dir, "final_model")
            os.makedirs(final_model_path, exist_ok=True)
            pl_module.model.save_pretrained(final_model_path)
            print(f"\nSaved final model to {final_model_path}")

    # Setup callbacks with reduced checkpoint frequency
    checkpoint_callback = EnhancedMetricCheckpoint(
        dirpath=checkpoints_dir,
        filename='checkpoint-{epoch:02d}-{val_loss:.4f}',
        monitor='val_loss',
        mode='min',
        save_top_k=3,  # Save top 3 models
        save_last=True,  # Save last model
        every_n_epochs=1  # Save checkpoint every epoch
    )
    
    early_stopping = DetailedEarlyStopping(
        monitor="val_loss",
        patience=2,
        mode="min"

    )
    
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
    
    # Setup logger with reduced logging frequency
    logger = TensorBoardLogger(
        "tb_logs", 
        name="piiranha_training",
        log_graph=False,  # Disable computational graph logging
        max_queue=10
    )
    callbacks = [
        checkpoint_callback, 
        early_stopping, 
        pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
        FinalEpochCallback()
    ]
    # Initialize trainer with memory optimizations
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[checkpoint_callback, early_stopping, lr_monitor],
        logger=logger,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        gradient_clip_val=1.0,
        log_every_n_steps=50,  # Reduced logging frequency
        precision="16-mixed",  # Use mixed precision training with reduced precision
        accumulate_grad_batches=2,  # Gradient accumulation to reduce memory usage
        strategy='ddp_notebook' if torch.cuda.is_available() else "auto",
        enable_progress_bar=True,
        enable_model_summary=False,  # Disable model summary to save memory
        inference_mode=True,  # Enable inference mode optimization
        profiler=None,  # Disable profiling
    )

    
    # Modified DataLoader settings in PiiDataModule
    data_module.train_dataloader = lambda: DataLoader(
        data_module.train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,  # Reduced number of workers
        pin_memory=False,  # Disable pin_memory
        persistent_workers=False,  # Disable persistent workers
        prefetch_factor=2  # Reduced prefetch factor
    )
    
    data_module.val_dataloader = lambda: DataLoader(
        data_module.val_dataset,
        batch_size=batch_size,
        num_workers=2,
        pin_memory=False,
        persistent_workers=False,
        prefetch_factor=2
    )
    
    # Train the model
    trainer.fit(model, datamodule=data_module)
    
    # Print final best metrics
    final_metrics = {
        'best_metrics': model.best_val_metrics,
        'final_metrics': checkpoint_callback.last_metrics if checkpoint_callback.last_metrics else {},
        'training_completed': True,
        'total_epochs': trainer.current_epoch,
        'early_stopped': trainer.should_stop,
        'best_model_path': checkpoint_callback.best_model_path,
        'last_model_path': checkpoint_callback.last_model_path
    }
    
    with open(os.path.join(save_dir, "training_summary.json"), 'w') as f:
        json.dump({k: float(v) if isinstance(v, torch.Tensor) else v 
                  for k, v in final_metrics.items()}, f, indent=2)
    
    print("\nTraining Summary:")
    print(f"Total epochs completed: {trainer.current_epoch + 1}")
    print(f"Best model saved at: {checkpoint_callback.best_model_path}")
    print("\nFinal Best Metrics:")
    for name, value in model.best_val_metrics.items():
        if isinstance(value, torch.Tensor):
            value = value.item()
        print(f"{name}: {value:.4f}")

    # Save the model and related files
    final_model_path = os.path.join(save_dir, "final_model")
    os.makedirs(final_model_path, exist_ok=True)
    model.model.save_pretrained(final_model_path)
    
    print(f"\nModel saved to {final_model_path}")
    
    return model, data_module.label_mapper

def load_checkpoint(checkpoint_path: str, model_dir="pii_model"):
    """
    Load a specific checkpoint.
    
    Args:
        checkpoint_path: Path to the checkpoint file
        model_dir: Directory containing label_mapper.json and tokenizer files
    """
    
    # Load label mapper
    with open(os.path.join(model_dir, 'label_mapper.json'), 'r') as f:
        label_mapper_data = json.load(f)
    
    label_mapper = LabelMapper()
    label_mapper.label_to_id = label_mapper_data['label_to_id']
    label_mapper.id_to_label = {int(k): v for k, v in label_mapper_data['id_to_label'].items()}
    label_mapper.num_labels = label_mapper_data['num_labels']
    
    # Load tokenizer
    tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-multilingual-cased")
    
    # Initialize model
    model = PiiModel(num_labels=label_mapper.num_labels)
    
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    
    return model, tokenizer, label_mapper

def list_checkpoints(save_dir="pii_model"):
    """List all available checkpoints."""
    import os
    
    checkpoint_dir = os.path.join(save_dir, "epoch_checkpoints")
    best_model_dir = os.path.join(save_dir, "best_model")
    
    print("Regular checkpoints:")
    if os.path.exists(checkpoint_dir):
        checkpoints = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith('.ckpt')])
        for ckpt in checkpoints:
            print(f"  - {ckpt}")
    else:
        print("  No regular checkpoints found")
        
    print("\nBest model checkpoints:")
    if os.path.exists(best_model_dir):
        checkpoints = sorted([f for f in os.listdir(best_model_dir) if f.endswith('.ckpt')])
        for ckpt in checkpoints:
            print(f"  - {ckpt}")
    else:
        print("  No best model checkpoints found")

def list_saved_metrics(save_dir="pii_model"):
    """List all saved metrics and checkpoints."""
    metrics_dir = os.path.join(save_dir, "metrics")
    checkpoints_dir = os.path.join(save_dir, "checkpoints")
    
    print("Available metric files:")
    if os.path.exists(metrics_dir):
        metric_files = sorted([f for f in os.listdir(metrics_dir) if f.endswith('.json')])
        for file in metric_files:
            print(f"  - {file}")
    else:
        print("  No metric files found")
    
    print("\nAvailable checkpoints:")
    if os.path.exists(checkpoints_dir):
        checkpoint_files = sorted([f for f in os.listdir(checkpoints_dir) if f.endswith('.ckpt')])
        for file in checkpoint_files:
            print(f"  - {file}")
    else:
        print("  No checkpoints found")

In [5]:
import pandas as pd
import json
import ast

train_df = pd.read_csv('train.csv')

print(len(train_df))
val_df = pd.read_csv('validation.csv')

def contains_labels(privacy_mask_str):
    try:
        # Convert the string representation to a list of dictionaries
        privacy_mask = ast.literal_eval(privacy_mask_str)
        # Check if any label is ACCOUNTNUM or IDCARTNUM
        return any(item['label'] in ['ACCOUNTNUM', 'IDCARDNUM'] for item in privacy_mask)
    except (ValueError, SyntaxError):
        return False

# Apply the filtering function to the DataFrame
filtered_train_df = train_df[train_df['privacy_mask'].apply(contains_labels)]
filtered_val_df = val_df[val_df['privacy_mask'].apply(contains_labels)]

print(len(filtered_train_df))
len(filtered_val_df)


325517
31440


7898

In [6]:
def filter_labels(labels_str, valid_labels):
    try:
        if isinstance(labels_str, str):
            labels = ast.literal_eval(labels_str)
        elif isinstance(labels_str, list):
            labels = labels_str
        else:
            return ['O'] * 128  # Assuming max length of 128

        # Replace labels not in valid_labels with 'O'
        return [label if label in valid_labels else 'O' for label in labels]
    except Exception as e:
        print(f"Error parsing labels: {e}, value: {labels_str}")
        return ['O'] * 128  # Assuming max length of 128
valid_labels = ['O', 'B-ACCOUNTNUM', 'I-ACCOUNTNUM', 'B-IDCARDNUM', 'I-IDCARDNUM']
filtered_train_df['mbert_token_classes'] = filtered_train_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))
filtered_val_df['mbert_token_classes'] = filtered_val_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_train_df['mbert_token_classes'] = filtered_train_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  filtered_val_df['mbert_token_classes'] = filtered_val_df['mbert_token_classes'].apply(lambda x: filter_labels(x, valid_labels))


In [4]:
model, label_mapper = train_pii_model(filtered_train_df, filtered_val_df,batch_size=32, save_dir="pii_model", max_epochs=10)

list_saved_metrics() # Best metrics and best Checkpoint


Found 2 unique base labels (excluding O)
Total number of labels after mapping: 3

Label mapping:
B-ACCOUNTNUM: 1
B-IDCARDNUM: 2
I-ACCOUNTNUM: 1
I-IDCARDNUM: 2
O: 0


Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at iiiorg/piiranha-v1-detect-personal-information and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([18]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.weight: found shape torch.Size([18, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Initial weights of the model saved to pii_model/initial_weights


You are using a CUDA device ('NVIDIA L4') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------




Found 2 unique base labels (excluding O)
Total number of labels after mapping: 3

Label mapping:
B-ACCOUNTNUM: 1
B-IDCARDNUM: 2
I-ACCOUNTNUM: 1
I-IDCARDNUM: 2
O: 0


/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /teamspace/studios/this_studio/pii_model/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_precision', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:431: It is recommended to use `self.log('val_recall', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connector


Saved checkpoint for epoch 0: checkpoint-epoch-00-loss-1.2656.ckpt

Saved model configuration and weights for epoch 0.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Saved checkpoint for epoch 0: checkpoint-epoch-00-loss-0.0242.ckpt

Saved model configuration and weights for epoch 0.


Validation: |          | 0/? [00:00<?, ?it/s]


Saved checkpoint for epoch 1: checkpoint-epoch-01-loss-0.0172.ckpt

Saved model configuration and weights for epoch 1.


Validation: |          | 0/? [00:00<?, ?it/s]


Saved checkpoint for epoch 2: checkpoint-epoch-02-loss-0.0197.ckpt

Saved model configuration and weights for epoch 2.


Validation: |          | 0/? [00:00<?, ?it/s]


Saved checkpoint for epoch 3: checkpoint-epoch-03-loss-0.0177.ckpt

Saved model configuration and weights for epoch 3.

Training Summary:
Total epochs completed: 1
Best model saved at: /teamspace/studios/this_studio/pii_model/checkpoints/checkpoint-epoch=01-val_loss=0.0172.ckpt

Final Best Metrics:

Model saved to pii_model/final_model
Available metric files:
  - best_metrics.json
  - metrics_epoch_0.json
  - metrics_epoch_1.json
  - metrics_epoch_2.json
  - metrics_epoch_3.json

Available checkpoints:
  - checkpoint-epoch-00-loss-0.0242.ckpt
  - checkpoint-epoch-00-loss-1.2656.ckpt
  - checkpoint-epoch-01-loss-0.0172.ckpt
  - checkpoint-epoch-02-loss-0.0197.ckpt
  - checkpoint-epoch-03-loss-0.0177.ckpt
  - checkpoint-epoch=01-val_loss=0.0172.ckpt
  - checkpoint-epoch=02-val_loss=0.0197.ckpt
  - checkpoint-epoch=03-val_loss=0.0177.ckpt
  - last-v1.ckpt
  - last-v2.ckpt
  - last.ckpt


In [7]:
# Load the config
config = AutoConfig.from_pretrained("pii_model/checkpoints/model_epoch_01")

# Initialize the model with the config
model = AutoModelForTokenClassification.from_config(config)

# Load the safetensors weights
state_dict = load_file("pii_model/checkpoints/model_epoch_01/model.safetensors")

# Load the state dict into the model
model.load_state_dict(state_dict)

# Load the tokenizer
tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-multilingual-cased")

# Load the label mapper if needed
with open("pii_model/label_mapper.json", 'r') as f:
    label_mapper_data = json.load(f)

label_mapper = LabelMapper()
label_mapper.label_to_id = label_mapper_data['label_to_id']
label_mapper.id_to_label = {int(k): v for k, v in label_mapper_data['id_to_label'].items()}
label_mapper.num_labels = label_mapper_data['num_labels']

In [11]:

# Set model to evaluation mode
model.eval()

# Get the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Get the text
text = filtered_val_df.iloc[58].source_text


# Tokenize with padding and truncation
encoded_text = tokenizer(
    text,
    padding='max_length',
    truncation=True,
    max_length=128,  # or whatever max_length you used during training
    return_tensors="pt"  # Return PyTorch tensors
)

# Move inputs to the same device as model
inputs = {
    'input_ids': encoded_text['input_ids'].to(device),
    'attention_mask': encoded_text['attention_mask'].to(device)
}

# Perform inference
with torch.no_grad():  # Disable gradient calculation
    outputs = model(**inputs)

# Get predictions
predictions = torch.argmax(outputs.logits, dim=-1)

# Convert predictions to labels using label_mapper
predicted_labels = label_mapper.decode(predictions[0].cpu().tolist())

# Print results
print("Original text:", text)
print("\nPredicted labels:", predicted_labels)

# If you want to see the token-label alignment:
tokens = tokenizer.convert_ids_to_tokens(encoded_text['input_ids'][0])
print("\nToken-label alignment:")
for token, label in zip(tokens, predicted_labels):
    if token != '[PAD]':  # Skip padding tokens
        print(f"{token}: {label}")

Original text: 2    Budget Proposal    Proposal outlining the budget allocation for a specific project or department.
Begrotingsaanvraag voor de marketing afdeling, specifiek project ID 389680211. Budget: 115,190.18 euro. Bevestiging via postcode 6718 en rekeningnummer 614170101. Adres: Quadruplex 90.

Predicted labels: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-IDCARDNUM', 'I-IDCARDNUM', 'I-IDCARDNUM', 'I-IDCARDNUM', 'I-IDCARDNUM', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ACCOUNTNUM', 'I-ACCOUNTNUM', 'I-ACCOUNTNUM', 'I-ACCOUNTNUM', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'