In [1]:
!pip install seqeval

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16162 sha256=008f9729f10939bf743a2f32679ea3644b03e5049fc3159af320663e293b7f82
  Stored in directory: /root/.cache/pip/wheels/5f/b8/73/0b2c1a76b701a677653dd79ece07cfabd7457989dbfbdcd8d7
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [2]:

import os
import random
import logging
import torch
import numpy as np
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass

from transformers import (
    DistilBertConfig,
    DistilBertForTokenClassification,
    DistilBertTokenizer,
    # AdamW,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset, RandomSampler, SequentialSampler
from torch.nn import CrossEntropyLoss
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [9]:
labels = ['O', 'Authentication']

In [10]:
@dataclass
class Config:
    """Configuration for DistilBERT model training"""
    model_name: str = "distilbert-base-cased"
    max_seq_length: int = 256
    batch_size: int = 8
    learning_rate: float = 5e-5
    num_epochs: int = 10
    weight_decay: float = 0.0
    warmup_steps: int = 0
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0
    seed: int = 42
    do_lower_case: bool = False

In [16]:
class InputExample:
    """A single training/test example for token classification"""

    def __init__(self, guid: str, words: List[str], labels: List[str]):
        self.guid = guid
        self.words = words
        self.labels = labels


class InputFeatures:
    """A single set of features of data"""

    def __init__(self, input_ids, attention_mask, label_ids):
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.label_ids = label_ids

In [56]:
class DistilBertNERInference:
    """
    Inference pipeline for DistilBERT-based NER models
    """

    def __init__(
        self,
        model_dir: str,
        labels: List[str],
        max_seq_length: int = 256,
        device: Optional[str] = None,
        batch_size: int = 16

    ):
      self.model_dir = model_dir
      self.max_seq_length = max_seq_length
      self.batch_size = batch_size
      self.label_map = {label: i for i, label in enumerate(labels)}
      self.num_labels = len(labels)
      self.config = DistilBertConfig()
          # Set device
      if device:
          self.device = torch.device(device)
      else:
          self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      self.do_lower_case = Config.do_lower_case
      # Padding label ID
      self.pad_token_label_id = CrossEntropyLoss().ignore_index


    def _set_seed(self):
        """Set random seed for reproducibility"""

        random.seed(Config.seed)
        np.random.seed(Config.seed)
        torch.manual_seed(Config.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(Config.seed)

    def load_model(self):
        """Load fintuned DistilBERT model and tokenizer"""
        logger.info(f"Loading DistilBERT model: {self.model_dir}")

        # Load tokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            self.model_dir,
            do_lower_case=self.do_lower_case
        )

        # Load model configuration
        model_config = DistilBertConfig.from_pretrained(
            self.model_dir,
            num_labels=self.num_labels
        )

        # Load model
        self.model = DistilBertForTokenClassification.from_pretrained(
            self.model_dir,
            config=model_config
        )

        self.model.to(self.device)
        logger.info("DistilBERT model loaded successfully")

    def read_examples(self, conll_string: str, mode: str = "train") -> List[InputExample]:
        """
        Read examples from a CoNLL-format file

        Args:
            file_path: Path to the input file
            mode: Mode indicator (train/dev/test)

        Returns:
            List of InputExample objects
        """
        examples = []
        guid_index = 1




        words = []
        labels = []

        for line in conll_string.split("\n"):
            line = line.strip()

            if line.startswith("-DOCSTART-") or line == "":
                if words:
                    examples.append(InputExample(
                        guid=f"{mode}-{guid_index}",
                        words=words,
                        labels=labels
                    ))
                    guid_index += 1
                    words = []
                    labels = []
            else:
                splits = line.split()
                if len(splits) >= 2:
                    words.append(splits[0])
                    labels.append(splits[-1])

        # Add last example if exists
        if words:
            examples.append(InputExample(
                guid=f"{mode}-{guid_index}",
                words=words,
                labels=labels
            ))

        # logger.info(f"Read {len(examples)} examples from {file_path}")
        return examples

    def _read_examples_from_file(self, file_path: str, mode: str = "train") -> List[InputExample]:
        """
        Read examples from a CoNLL-format file

        Args:
            file_path: Path to the input file
            mode: Mode indicator (train/dev/test)

        Returns:
            List of InputExample objects
        """
        examples = []
        guid_index = 1

        with open(file_path, 'r', encoding='utf-8') as f:
            words = []
            labels = []

            for line in f:
                line = line.strip()

                if line.startswith("-DOCSTART-") or line == "":
                    if words:
                        examples.append(InputExample(
                            guid=f"{mode}-{guid_index}",
                            words=words,
                            labels=labels
                        ))
                        guid_index += 1
                        words = []
                        labels = []
                else:
                    splits = line.split()
                    if len(splits) >= 2:
                        words.append(splits[0])
                        labels.append(splits[-1])

            # Add last example if exists
            if words:
                examples.append(InputExample(
                    guid=f"{mode}-{guid_index}",
                    words=words,
                    labels=labels
                ))

        logger.info(f"Read {len(examples)} examples from {file_path}")
        return examples

    def _convert_examples_to_features(
        self,
        examples: List[InputExample]
    ) -> List[InputFeatures]:
        """
        Convert examples to features suitable for DistilBERT
        Note: DistilBERT doesn't use token_type_ids

        Args:
            examples: List of InputExample objects

        Returns:
            List of InputFeatures objects
        """
        features = []

        for ex_index, example in enumerate(tqdm(examples, desc="Converting examples")):
            tokens = []
            label_ids = []

            # Tokenize each word and align labels
            for word, label in zip(example.words, example.labels):
                word_tokens = self.tokenizer.tokenize(word)

                if len(word_tokens) > 0:
                    tokens.extend(word_tokens)
                    # Use the real label for the first token, pad for others
                    label_ids.extend([self.label_map[label]] +
                                   [self.pad_token_label_id] * (len(word_tokens) - 1))

            # Truncate if necessary
            max_length = self.max_seq_length - 2  # Account for [CLS] and [SEP]
            if len(tokens) > max_length:
                tokens = tokens[:max_length]
                label_ids = label_ids[:max_length]

            # Add special tokens
            tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
            label_ids = [self.pad_token_label_id] + label_ids + [self.pad_token_label_id]

            # Convert tokens to IDs
            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

            # Create attention mask
            attention_mask = [1] * len(input_ids)

            # Pad to max length
            padding_length = self.max_seq_length - len(input_ids)
            input_ids += [self.tokenizer.pad_token_id] * padding_length
            attention_mask += [0] * padding_length
            label_ids += [self.pad_token_label_id] * padding_length

            assert len(input_ids) == self.max_seq_length
            assert len(attention_mask) == self.max_seq_length
            assert len(label_ids) == self.max_seq_length

            features.append(InputFeatures(
                input_ids=input_ids,
                attention_mask=attention_mask,
                label_ids=label_ids
            ))

        return features

    def _create_dataloader(
        self,
        features: List[InputFeatures],
        shuffle: bool = False
    ) -> DataLoader:
        """
        Create a DataLoader from features

        Args:
            features: List of InputFeatures
            shuffle: Whether to shuffle the data

        Returns:
            DataLoader object
        """
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)

        dataset = TensorDataset(all_input_ids, all_attention_mask, all_label_ids)

        sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
        dataloader = DataLoader(dataset, sampler=sampler, batch_size=self.batch_size)

        return dataloader

    def evaluate(self, eval_file: str, return_predictions: bool = True) -> Tuple[Dict[str, float], Optional[str]]:
        """
        Evaluate the model

        Args:
            eval_file: Path to evaluation data file
            return_predictions: If True, return predictions text along with metrics

        Returns:
            Tuple of (metrics_dict, predictions_text)
            - metrics_dict: Dictionary containing evaluation metrics
            - predictions_text: String with "token prediction" format (None if return_predictions=False)
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not loaded.")

        logger.info("Running evaluation...")

        # Load and prepare evaluation data
        # eval_examples = self.read_examples(eval_file, mode="eval")
        eval_examples = self._read_examples_from_file(eval_file, mode="eval")
        eval_features = self._convert_examples_to_features(eval_examples)
        eval_dataloader = self._create_dataloader(eval_features, shuffle=False)

        # Evaluation
        self.model.eval()
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)

            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[2]
                }

                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.item()

            nb_eval_steps += 1

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs["labels"].detach().cpu().numpy(),
                    axis=0
                )

        eval_loss = eval_loss / nb_eval_steps
        preds = np.argmax(preds, axis=2)

        # Convert IDs to labels
        id_to_label = {i: label for label, i in self.label_map.items()}

        out_label_list = []
        preds_list = []

        for i in range(out_label_ids.shape[0]):
            temp_out = []
            temp_pred = []
            for j in range(out_label_ids.shape[1]):
                if out_label_ids[i, j] != self.pad_token_label_id:
                    temp_out.append(id_to_label[out_label_ids[i][j]])
                    temp_pred.append(id_to_label[preds[i][j]])
            out_label_list.append(temp_out)
            preds_list.append(temp_pred)

        # Calculate metrics
        results = {
            "loss": eval_loss,
            "precision": precision_score(out_label_list, preds_list),
            "recall": recall_score(out_label_list, preds_list),
            "f1": f1_score(out_label_list, preds_list),
        }

        # Print detailed classification report
        logger.info("\nClassification Report:")
        logger.info("\n" + classification_report(out_label_list, preds_list))

        # Generate predictions text if requested
        predictions_text = None
        if return_predictions:
            # Read original tokens from file
            original_tokens = []
            with open(eval_file, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("-DOCSTART-") or line == "":
                        if original_tokens:
                            original_tokens.append("")  # Empty line between sentences
                    else:
                        splits = line.split()
                        if len(splits) >= 1:
                            original_tokens.append(splits[0])

            # Flatten predictions list
            flat_predictions = []
            for sent_preds in preds_list:
                flat_predictions.extend(sent_preds)
                flat_predictions.append("")  # Empty line between sentences

            # Create predictions text
            predictions_lines = []
            token_idx = 0
            pred_idx = 0

            for token in original_tokens:
                if token == "":
                    predictions_lines.append("")
                else:
                    if pred_idx < len(flat_predictions) and flat_predictions[pred_idx] != "":
                        predictions_lines.append(f"{token} {flat_predictions[pred_idx]}")
                        pred_idx += 1
                    else:
                        predictions_lines.append(f"{token} O")
                        if pred_idx < len(flat_predictions):
                            pred_idx += 1

            predictions_text = "\n".join(predictions_lines)

        return results, predictions_text
    def evaluate_over_text(self, conll_string: str, return_predictions: bool = True) -> Tuple[Dict[str, float], Optional[str]]:
        """
        Evaluate the model

        Args:
            eval_file: Path to evaluation data file
            return_predictions: If True, return predictions text along with metrics

        Returns:
            Tuple of (metrics_dict, predictions_text)
            - metrics_dict: Dictionary containing evaluation metrics
            - predictions_text: String with "token prediction" format (None if return_predictions=False)
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not loaded.")

        logger.info("Running evaluation...")

        # Load and prepare evaluation data
        eval_examples = self.read_examples(conll_string, mode="eval")
        # eval_examples = self._read_examples_from_file(eval_file, mode="eval")
        eval_features = self._convert_examples_to_features(eval_examples)
        eval_dataloader = self._create_dataloader(eval_features, shuffle=False)

        # Evaluation
        self.model.eval()
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)

            with torch.no_grad():
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[2]
                }

                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.item()

            nb_eval_steps += 1

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs["labels"].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs["labels"].detach().cpu().numpy(),
                    axis=0
                )

        eval_loss = eval_loss / nb_eval_steps
        preds = np.argmax(preds, axis=2)

        # Convert IDs to labels
        id_to_label = {i: label for label, i in self.label_map.items()}

        out_label_list = []
        preds_list = []

        for i in range(out_label_ids.shape[0]):
            temp_out = []
            temp_pred = []
            for j in range(out_label_ids.shape[1]):
                if out_label_ids[i, j] != self.pad_token_label_id:
                    temp_out.append(id_to_label[out_label_ids[i][j]])
                    temp_pred.append(id_to_label[preds[i][j]])
            out_label_list.append(temp_out)
            preds_list.append(temp_pred)

        # Calculate metrics
        # results = {
        #     "loss": eval_loss,
        #     "precision": precision_score(out_label_list, preds_list),
        #     "recall": recall_score(out_label_list, preds_list),
        #     "f1": f1_score(out_label_list, preds_list),
        # }

        # # Print detailed classification report
        # logger.info("\nClassification Report:")
        # logger.info("\n" + classification_report(out_label_list, preds_list))

        # Generate predictions text if requested
        predictions_text = None
        if return_predictions:
            # Read original tokens from file
            original_tokens = []
            # with open(eval_file, 'r', encoding='utf-8') as f:
            for line in conll_string.split('\n'):
                line = line.strip()
                if line.startswith("-DOCSTART-") or line == "":
                    if original_tokens:
                        original_tokens.append("")  # Empty line between sentences
                else:
                    splits = line.split()
                    if len(splits) >= 1:
                        original_tokens.append(splits[0])

            # Flatten predictions list
            flat_predictions = []
            for sent_preds in preds_list:
                flat_predictions.extend(sent_preds)
                flat_predictions.append("")  # Empty line between sentences

            # Create predictions text
            predictions_lines = []
            token_idx = 0
            pred_idx = 0

            for token in original_tokens:
                if token == "":
                    predictions_lines.append("")
                else:
                    if pred_idx < len(flat_predictions) and flat_predictions[pred_idx] != "":
                        predictions_lines.append(f"{token} {flat_predictions[pred_idx]}")
                        pred_idx += 1
                    else:
                        predictions_lines.append(f"{token} O")
                        if pred_idx < len(flat_predictions):
                            pred_idx += 1

            predictions_text = "\n".join(predictions_lines)

        return results, predictions_text



In [57]:
infer = DistilBertNERInference(
    model_dir="drive/MyDrive/exl/Authentication/model",
    max_seq_length=Config.max_seq_length,
    labels=labels
)

In [58]:
infer.load_model()

# Prediction

## Prediction over text file

In [24]:
results, predictions = infer.evaluate(eval_file="drive/MyDrive/exl/Authentication/data/test.txt", return_predictions=True)

Converting examples: 100%|██████████| 50/50 [00:00<00:00, 137.48it/s]
Evaluating: 100%|██████████| 4/4 [00:26<00:00,  6.66s/it]


In [25]:
results

{'loss': 0.05847675009863451,
 'precision': np.float64(0.8269230769230769),
 'recall': np.float64(0.9148936170212766),
 'f1': np.float64(0.8686868686868686)}

In [22]:
with open("drive/MyDrive/exl/Authentication/data/pred__text_v1.txt","w+") as f:
  f.write(predictions)

In [21]:
predictions



### validation

In [44]:
with open("drive/MyDrive/exl/Authentication/data/pred_v0.txt","r") as f:
  pred_v0_text = f.read()

In [46]:
pred_v1_rows = predictions.split("\n")
pred_v0_rows = pred_v0_text.split("\n")


mismatch_count=0
for i, pred_v0_row in enumerate(pred_v0_rows):
  pred_v1_row = pred_v1_rows[i]

  if pred_v0_row != pred_v1_row:
    print(i)
    print("pred_v0_row : ",pred_v0_row)
    print("pred_v1_row : ",pred_v1_row)
    mismatch_count += 1

In [45]:
pred_v0_text



In [47]:
mismatch_count

0

In [48]:
len(pred_v0_rows)

6733

In [50]:
len(pred_v1_rows)

6733

## Prediction over text string

In [59]:
with open("drive/MyDrive/exl/Authentication/data/test.txt", "r") as f:
  test_data = f.read()

test_data_rows = test_data.split("\n")
test_data_processed_rows = []

for row in test_data_rows:
  tk_label = row.split(' ')
  if len(tk_label) < 2:
    test_data_processed_rows.append(row)
    continue

  test_data_processed_rows.append(tk_label[0] + " O")
processed_test_data = "\n".join(test_data_processed_rows)


In [61]:
results, predictions = infer.evaluate_over_text(processed_test_data, return_predictions=True)

Converting examples: 100%|██████████| 50/50 [00:00<00:00, 204.94it/s]
Evaluating: 100%|██████████| 4/4 [00:21<00:00,  5.35s/it]


## validation

In [62]:
with open("drive/MyDrive/exl/Authentication/data/pred_v0.txt","r") as f:
  pred_v0_text = f.read()

In [63]:
pred_v1_rows = predictions.split("\n")
pred_v0_rows = pred_v0_text.split("\n")


mismatch_count=0
for i, pred_v0_row in enumerate(pred_v0_rows):
  pred_v1_row = pred_v1_rows[i]

  if pred_v0_row != pred_v1_row:
    print(i)
    print("pred_v0_row : ",pred_v0_row)
    print("pred_v1_row : ",pred_v1_row)
    mismatch_count += 1

In [64]:
mismatch_count

0

In [67]:
len(predictions.split("\n"))

6733

In [66]:
pred_v0_text



# Ref

In [None]:
# class DistilBertNERTraining:
#     """
#     Class for training DistilBERT-based NER models
#     DistilBERT is 40% smaller and 60% faster than BERT while retaining 97% of performance
#     """

#     def __init__(
#         self,
#         labels: List[str],
#         model_save_dir: str,
#         config: Optional[DistilBertConfig] = None,
#         device: Optional[str] = None
#     ):
#         """
#         Initialize the training class

#         Args:
#             labels: List of NER labels (e.g., ['O', 'Authentication'])
#             model_save_dir: Directory to save the trained model
#             config: Training configuration
#             device: Device to use for training ('cuda' or 'cpu')
#         """
#         self.labels = labels
#         self.label_map = {label: i for i, label in enumerate(labels)}
#         self.num_labels = len(labels)
#         self.model_save_dir = model_save_dir
#         self.config = config or DistilBertConfig()

#         # Set device
#         if device:
#             self.device = torch.device(device)
#         else:
#             self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#         # Padding label ID
#         self.pad_token_label_id = CrossEntropyLoss().ignore_index

#         # Initialize model and tokenizer as None
#         self.model = None
#         self.tokenizer = None

#         # Set random seed for reproducibility
#         self._set_seed()

#         logger.info(f"Initialized DistilBertNERTraining with {self.num_labels} labels")
#         logger.info(f"Device: {self.device}")
#         logger.info(f"Model: {self.config.model_name}")

#     def _set_seed(self):
#         """Set random seed for reproducibility"""
#         import random
#         random.seed(self.config.seed)
#         np.random.seed(self.config.seed)
#         torch.manual_seed(self.config.seed)
#         if torch.cuda.is_available():
#             torch.cuda.manual_seed_all(self.config.seed)

#     def load_pretrained_model(self):
#         """Load pretrained DistilBERT model and tokenizer"""
#         logger.info(f"Loading pretrained DistilBERT model: {self.config.model_name}")

#         # Load tokenizer
#         self.tokenizer = DistilBertTokenizer.from_pretrained(
#             self.config.model_name,
#             do_lower_case=False
#         )

#         # Load model configuration
#         model_config = DistilBertConfig.from_pretrained(
#             self.config.model_name,
#             num_labels=self.num_labels
#         )

#         # Load model
#         self.model = DistilBertForTokenClassification.from_pretrained(
#             self.config.model_name,
#             config=model_config
#         )

#         self.model.to(self.device)
#         logger.info("DistilBERT model loaded successfully")

#     def _read_examples_from_file(self, file_path: str, mode: str = "train") -> List[InputExample]:
#         """
#         Read examples from a CoNLL-format file

#         Args:
#             file_path: Path to the input file
#             mode: Mode indicator (train/dev/test)

#         Returns:
#             List of InputExample objects
#         """
#         examples = []
#         guid_index = 1

#         with open(file_path, 'r', encoding='utf-8') as f:
#             words = []
#             labels = []

#             for line in f:
#                 line = line.strip()

#                 if line.startswith("-DOCSTART-") or line == "":
#                     if words:
#                         examples.append(InputExample(
#                             guid=f"{mode}-{guid_index}",
#                             words=words,
#                             labels=labels
#                         ))
#                         guid_index += 1
#                         words = []
#                         labels = []
#                 else:
#                     splits = line.split()
#                     if len(splits) >= 2:
#                         words.append(splits[0])
#                         labels.append(splits[-1])

#             # Add last example if exists
#             if words:
#                 examples.append(InputExample(
#                     guid=f"{mode}-{guid_index}",
#                     words=words,
#                     labels=labels
#                 ))

#         logger.info(f"Read {len(examples)} examples from {file_path}")
#         return examples

#     def _convert_examples_to_features(
#         self,
#         examples: List[InputExample]
#     ) -> List[InputFeatures]:
#         """
#         Convert examples to features suitable for DistilBERT
#         Note: DistilBERT doesn't use token_type_ids

#         Args:
#             examples: List of InputExample objects

#         Returns:
#             List of InputFeatures objects
#         """
#         features = []

#         for ex_index, example in enumerate(tqdm(examples, desc="Converting examples")):
#             tokens = []
#             label_ids = []

#             # Tokenize each word and align labels
#             for word, label in zip(example.words, example.labels):
#                 word_tokens = self.tokenizer.tokenize(word)

#                 if len(word_tokens) > 0:
#                     tokens.extend(word_tokens)
#                     # Use the real label for the first token, pad for others
#                     label_ids.extend([self.label_map[label]] +
#                                    [self.pad_token_label_id] * (len(word_tokens) - 1))

#             # Truncate if necessary
#             max_length = self.config.max_seq_length - 2  # Account for [CLS] and [SEP]
#             if len(tokens) > max_length:
#                 tokens = tokens[:max_length]
#                 label_ids = label_ids[:max_length]

#             # Add special tokens
#             tokens = [self.tokenizer.cls_token] + tokens + [self.tokenizer.sep_token]
#             label_ids = [self.pad_token_label_id] + label_ids + [self.pad_token_label_id]

#             # Convert tokens to IDs
#             input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

#             # Create attention mask
#             attention_mask = [1] * len(input_ids)

#             # Pad to max length
#             padding_length = self.config.max_seq_length - len(input_ids)
#             input_ids += [self.tokenizer.pad_token_id] * padding_length
#             attention_mask += [0] * padding_length
#             label_ids += [self.pad_token_label_id] * padding_length

#             assert len(input_ids) == self.config.max_seq_length
#             assert len(attention_mask) == self.config.max_seq_length
#             assert len(label_ids) == self.config.max_seq_length

#             features.append(InputFeatures(
#                 input_ids=input_ids,
#                 attention_mask=attention_mask,
#                 label_ids=label_ids
#             ))

#         return features

#     def _create_dataloader(
#         self,
#         features: List[InputFeatures],
#         shuffle: bool = False
#     ) -> DataLoader:
#         """
#         Create a DataLoader from features

#         Args:
#             features: List of InputFeatures
#             shuffle: Whether to shuffle the data

#         Returns:
#             DataLoader object
#         """
#         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
#         all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
#         all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)

#         dataset = TensorDataset(all_input_ids, all_attention_mask, all_label_ids)

#         sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
#         dataloader = DataLoader(dataset, sampler=sampler, batch_size=self.config.batch_size)

#         return dataloader

#     def train(self, train_file: str, dev_file: Optional[str] = None):
#         """
#         Train the model

#         Args:
#             train_file: Path to training data file
#             dev_file: Optional path to development/validation data file
#         """
#         if self.model is None or self.tokenizer is None:
#             raise ValueError("Model not loaded. Call load_pretrained_model() first.")

#         logger.info("Starting training...")

#         # Load and prepare training data
#         train_examples = self._read_examples_from_file(train_file, mode="train")
#         train_features = self._convert_examples_to_features(train_examples)
#         train_dataloader = self._create_dataloader(train_features, shuffle=True)

#         # Calculate total training steps
#         num_train_steps = len(train_dataloader) * self.config.num_epochs

#         # Prepare optimizer and scheduler
#         no_decay = ["bias", "LayerNorm.weight"]
#         optimizer_grouped_parameters = [
#             {
#                 "params": [p for n, p in self.model.named_parameters()
#                           if not any(nd in n for nd in no_decay)],
#                 "weight_decay": self.config.weight_decay,
#             },
#             {
#                 "params": [p for n, p in self.model.named_parameters()
#                           if any(nd in n for nd in no_decay)],
#                 "weight_decay": 0.0,
#             },
#         ]

#         optimizer = AdamW(
#             optimizer_grouped_parameters,
#             lr=self.config.learning_rate,
#             eps=self.config.adam_epsilon
#         )

#         scheduler = get_linear_schedule_with_warmup(
#             optimizer,
#             num_warmup_steps=self.config.warmup_steps,
#             num_training_steps=num_train_steps
#         )

#         # Training loop
#         logger.info(f"***** Running training *****")
#         logger.info(f"  Num examples = {len(train_examples)}")
#         logger.info(f"  Num epochs = {self.config.num_epochs}")
#         logger.info(f"  Batch size = {self.config.batch_size}")
#         logger.info(f"  Total optimization steps = {num_train_steps}")

#         global_step = 0
#         train_loss = 0.0

#         self.model.zero_grad()

#         for epoch in range(self.config.num_epochs):
#             logger.info(f"Epoch {epoch + 1}/{self.config.num_epochs}")

#             self.model.train()
#             epoch_loss = 0.0

#             for step, batch in enumerate(tqdm(train_dataloader, desc="Training")):
#                 batch = tuple(t.to(self.device) for t in batch)

#                 inputs = {
#                     "input_ids": batch[0],
#                     "attention_mask": batch[1],
#                     "labels": batch[2]
#                 }

#                 outputs = self.model(**inputs)
#                 loss = outputs[0]

#                 loss.backward()

#                 train_loss += loss.item()
#                 epoch_loss += loss.item()

#                 torch.nn.utils.clip_grad_norm_(
#                     self.model.parameters(),
#                     self.config.max_grad_norm
#                 )

#                 optimizer.step()
#                 scheduler.step()
#                 self.model.zero_grad()
#                 global_step += 1

#             avg_epoch_loss = epoch_loss / len(train_dataloader)
#             logger.info(f"Epoch {epoch + 1} - Average loss: {avg_epoch_loss:.4f}")

#             # Evaluate on dev set if provided
#             if dev_file:
#                 dev_metrics, _ = self.evaluate(dev_file, return_predictions=False)
#                 logger.info(f"Dev metrics - Precision: {dev_metrics['precision']:.4f}, "
#                           f"Recall: {dev_metrics['recall']:.4f}, F1: {dev_metrics['f1']:.4f}")

#         avg_train_loss = train_loss / global_step
#         logger.info(f"Training completed. Average training loss: {avg_train_loss:.4f}")

#     def evaluate(self, eval_file: str, return_predictions: bool = True) -> Tuple[Dict[str, float], Optional[str]]:
#         """
#         Evaluate the model

#         Args:
#             eval_file: Path to evaluation data file
#             return_predictions: If True, return predictions text along with metrics

#         Returns:
#             Tuple of (metrics_dict, predictions_text)
#             - metrics_dict: Dictionary containing evaluation metrics
#             - predictions_text: String with "token prediction" format (None if return_predictions=False)
#         """
#         if self.model is None or self.tokenizer is None:
#             raise ValueError("Model not loaded.")

#         logger.info("Running evaluation...")

#         # Load and prepare evaluation data
#         eval_examples = self._read_examples_from_file(eval_file, mode="eval")
#         eval_features = self._convert_examples_to_features(eval_examples)
#         eval_dataloader = self._create_dataloader(eval_features, shuffle=False)

#         # Evaluation
#         self.model.eval()
#         eval_loss = 0.0
#         nb_eval_steps = 0
#         preds = None
#         out_label_ids = None

#         for batch in tqdm(eval_dataloader, desc="Evaluating"):
#             batch = tuple(t.to(self.device) for t in batch)

#             with torch.no_grad():
#                 inputs = {
#                     "input_ids": batch[0],
#                     "attention_mask": batch[1],
#                     "labels": batch[2]
#                 }

#                 outputs = self.model(**inputs)
#                 tmp_eval_loss, logits = outputs[:2]

#                 eval_loss += tmp_eval_loss.item()

#             nb_eval_steps += 1

#             if preds is None:
#                 preds = logits.detach().cpu().numpy()
#                 out_label_ids = inputs["labels"].detach().cpu().numpy()
#             else:
#                 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
#                 out_label_ids = np.append(
#                     out_label_ids,
#                     inputs["labels"].detach().cpu().numpy(),
#                     axis=0
#                 )

#         eval_loss = eval_loss / nb_eval_steps
#         preds = np.argmax(preds, axis=2)

#         # Convert IDs to labels
#         id_to_label = {i: label for label, i in self.label_map.items()}

#         out_label_list = []
#         preds_list = []

#         for i in range(out_label_ids.shape[0]):
#             temp_out = []
#             temp_pred = []
#             for j in range(out_label_ids.shape[1]):
#                 if out_label_ids[i, j] != self.pad_token_label_id:
#                     temp_out.append(id_to_label[out_label_ids[i][j]])
#                     temp_pred.append(id_to_label[preds[i][j]])
#             out_label_list.append(temp_out)
#             preds_list.append(temp_pred)

#         # Calculate metrics
#         results = {
#             "loss": eval_loss,
#             "precision": precision_score(out_label_list, preds_list),
#             "recall": recall_score(out_label_list, preds_list),
#             "f1": f1_score(out_label_list, preds_list),
#         }

#         # Print detailed classification report
#         logger.info("\nClassification Report:")
#         logger.info("\n" + classification_report(out_label_list, preds_list))

#         # Generate predictions text if requested
#         predictions_text = None
#         if return_predictions:
#             # Read original tokens from file
#             original_tokens = []
#             with open(eval_file, 'r', encoding='utf-8') as f:
#                 for line in f:
#                     line = line.strip()
#                     if line.startswith("-DOCSTART-") or line == "":
#                         if original_tokens:
#                             original_tokens.append("")  # Empty line between sentences
#                     else:
#                         splits = line.split()
#                         if len(splits) >= 1:
#                             original_tokens.append(splits[0])

#             # Flatten predictions list
#             flat_predictions = []
#             for sent_preds in preds_list:
#                 flat_predictions.extend(sent_preds)
#                 flat_predictions.append("")  # Empty line between sentences

#             # Create predictions text
#             predictions_lines = []
#             token_idx = 0
#             pred_idx = 0

#             for token in original_tokens:
#                 if token == "":
#                     predictions_lines.append("")
#                 else:
#                     if pred_idx < len(flat_predictions) and flat_predictions[pred_idx] != "":
#                         predictions_lines.append(f"{token} {flat_predictions[pred_idx]}")
#                         pred_idx += 1
#                     else:
#                         predictions_lines.append(f"{token} O")
#                         if pred_idx < len(flat_predictions):
#                             pred_idx += 1

#             predictions_text = "\n".join(predictions_lines)

#         return results, predictions_text

#     def save_model(self):
#         """Save the trained model and tokenizer"""
#         logger.info(f"Saving model to {self.model_save_dir}")

#         if not os.path.exists(self.model_save_dir):
#             os.makedirs(self.model_save_dir)

#         # Save model
#         model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
#         model_to_save.save_pretrained(self.model_save_dir)

#         # Save tokenizer
#         self.tokenizer.save_pretrained(self.model_save_dir)

#         # Save labels
#         labels_file = os.path.join(self.model_save_dir, "labels.txt")
#         with open(labels_file, 'w') as f:
#             f.write('\n'.join(self.labels))

#         logger.info("Model saved successfully")


