In [1]:
# Add this first
from transformers import AutoTokenizer, AutoModelForTokenClassification
from datasets import load_from_disk
from torch.utils.data import DataLoader
from transformers import DataCollatorForTokenClassification
import torch

model_path = "../outputs/ner_model"

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Tokenizer and token-classification model are both loaded from the same
# fine-tuned checkpoint directory; the model is moved to the chosen device.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForTokenClassification.from_pretrained(model_path).to(device)

# Validation split of the pre-tokenized dataset stored on disk.
val_dataset = load_from_disk("../data/tokenized_dataset")["validation"]

# The collator pads every batch to a common length so the examples can be
# stacked into tensors; batches of 8 are fed to the model.
data_collator = DataCollatorForTokenClassification(tokenizer)
val_dataloader = DataLoader(val_dataset, collate_fn=data_collator, batch_size=8)





from seqeval.metrics import classification_report
import numpy as np

# Tag names indexed by class id, in BIO format: "O" first, then a B-/I- pair
# for each entity type. Predicted and gold ids are looked up in this list, so
# its order must match the class ids the model was trained with (presumably
# recorded in model.config.id2label; verify) -- adjust based on your project.
label_names = ["O"] + [
    f"{prefix}-{entity}"
    for entity in ("COMPONENT", "MEASUREMENT", "WARNING", "PROCEDURE", "TOOL")
    for prefix in ("B", "I")
]

# Helper function to decode predictions
def get_predictions_labels(model, dataloader, *, label_list=None, device=None):
    """Run ``model`` over ``dataloader`` and return aligned tag-string sequences.

    Positions whose gold label id is -100 (the ignored index: padding added by
    the collator, and presumably special/sub-word tokens masked during
    preprocessing) are dropped. This keeps predictions and references aligned
    one-to-one, as ``seqeval`` expects.

    Args:
        model: Token-classification model. It is called as ``model(**batch)``
            and must return an object with a ``logits`` attribute.
        dataloader: Iterable of dict batches of tensors. Each batch must include
            a ``"labels"`` entry.
        label_list: Sequence mapping class id -> tag string. Defaults to the
            module-level ``label_names``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``(pred_labels, true_labels)``: two lists holding one list of tag
        strings per sequence, in dataloader order.
    """
    if label_list is None:
        label_list = label_names
    if device is None:
        # Infer the device from the model instead of relying on a global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    true_labels = []
    pred_labels = []

    with torch.no_grad():
        # Iterate the dataloader that was passed in (not a global one).
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)

            # One host transfer per batch instead of per-token tensor indexing.
            predictions = outputs.logits.argmax(dim=-1).cpu().tolist()
            labels = batch["labels"].cpu().tolist()

            for pred_ids, label_ids in zip(predictions, labels):
                # Keep only positions that carry a real gold label.
                kept = [(p, l) for p, l in zip(pred_ids, label_ids) if l != -100]
                true_labels.append([label_list[l] for _, l in kept])
                pred_labels.append([label_list[p] for p, _ in kept])

    return pred_labels, true_labels


# Run the model over the validation split and collect aligned tag sequences.
preds, trues = get_predictions_labels(model, val_dataloader)

# seqeval scores complete entity spans, so the report has one row per entity
# type plus micro/macro/weighted averages.
report = classification_report(y_true=trues, y_pred=preds)
print(report)


              precision    recall  f1-score   support

   COMPONENT       0.14      0.20      0.17        15
 MEASUREMENT       0.00      0.00      0.00         4
   PROCEDURE       0.00      0.00      0.00        14
        TOOL       0.25      0.57      0.35         7

   micro avg       0.10      0.14      0.12        49
   macro avg       0.08      0.15      0.10        49
weighted avg       0.08      0.14      0.10        49

