In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from sklearn.metrics import classification_report
from seqeval.metrics import f1_score, classification_report as seqeval_report
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from torchcrf  import CRF
# Pick the compute device once, preferring CUDA whenever a GPU is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Data preprocessing
def read_conll_format(file_path):
    """Read a CoNLL-style token/tag file into parallel sentence and tag lists.

    Each non-blank line is split on any whitespace (spaces or tabs). The
    first column is taken as the token and the last column as its tag.
    Lines with fewer than two columns are skipped. Blank lines separate
    sentences, and runs of blank lines never produce empty sentences.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        A ``(sentences, tags)`` tuple where ``sentences[k]`` is the token list
        of sentence ``k`` and ``tags[k]`` is its parallel tag list.

    Raises:
        Any exception raised while opening or reading the file is re-raised
        after an error message is printed.
    """
    sentences = []
    tags = []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            sentence = []
            sentence_tags = []

            for line in f:
                # str.split() with no argument splits on any whitespace run,
                # so tab-separated files parse too (split(' ') would leave
                # "token\tTAG" as a single column and drop the line).
                parts = line.split()
                if parts:
                    if len(parts) >= 2:
                        sentence.append(parts[0])
                        sentence_tags.append(parts[-1])
                elif sentence:
                    # Blank line: close the current (non-empty) sentence.
                    sentences.append(sentence)
                    tags.append(sentence_tags)
                    sentence = []
                    sentence_tags = []

            # The file may not end with a blank line.
            if sentence:
                sentences.append(sentence)
                tags.append(sentence_tags)
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        raise

    return sentences, tags

# Create mappings for tags to indices and vice versa
def create_tag_mappings(tags_list):
    """Build bidirectional tag <-> index lookup tables.

    Args:
        tags_list: Iterable of tag sequences (one list of tag strings per
            sentence).

    Returns:
        ``(tag2idx, idx2tag)``. ``tag2idx`` assigns each distinct tag its
        position in lexicographic order, and ``idx2tag`` is the inverse.
    """
    vocabulary = sorted({tag for sequence in tags_list for tag in sequence})
    tag2idx = dict(zip(vocabulary, range(len(vocabulary))))
    idx2tag = dict(enumerate(vocabulary))
    return tag2idx, idx2tag

# Custom Dataset
class NERDataset(Dataset):
    """Word-level NER dataset that tokenizes into subwords on the fly.

    Each sentence is tokenized with ``is_split_into_words=True`` and padded or
    truncated to ``max_len``. Word-level tags are aligned to subwords through
    ``encoding.word_ids()``.

    Every label position starts as the ``'O'`` tag index, not ``-100``. The
    CRF scores every position whose attention mask is 1 and needs a valid tag
    index there. So special tokens, padding and (by default) non-first
    subwords are all labelled ``'O'``.

    Args:
        sentences: List of token lists.
        tags: List of tag lists, parallel to ``sentences``.
        tokenizer: Tokenizer whose output supports ``word_ids()`` (a Hugging
            Face fast tokenizer).
        tag2idx: Mapping from tag string to index; must contain ``'O'``.
        max_len: Fixed sequence length after padding/truncation.
        label_subwords: If True, continuation subwords take their word's tag
            instead of ``'O'``; a ``B-X`` word propagates ``I-X`` when that
            tag exists in ``tag2idx``. This stops the CRF from learning
            ``B-X -> O -> I-X`` transitions inside multi-subword words.
            Defaults to False, which reproduces the original labelling.
    """

    def __init__(self, sentences, tags, tokenizer, tag2idx, max_len=128,
                 label_subwords=False):
        self.sentences = sentences
        self.tags = tags
        self.tokenizer = tokenizer
        self.tag2idx = tag2idx
        self.max_len = max_len
        self.label_subwords = label_subwords

    def __len__(self):
        return len(self.sentences)

    def _continuation_tag_id(self, tag):
        """Return the tag index for a non-first subword of a word tagged ``tag``.

        ``B-X`` maps to ``I-X`` when ``I-X`` is a known tag. Any other tag
        (including ``B-X`` without a matching ``I-X``) is reused unchanged.
        """
        if tag.startswith('B-'):
            inside_tag = 'I-' + tag[2:]
            if inside_tag in self.tag2idx:
                return self.tag2idx[inside_tag]
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        """Tokenize sentence ``idx`` and align its word tags to subwords.

        Returns:
            Dict with 'input_ids', 'attention_mask', 'labels' (LongTensor of
            length ``max_len``) and 'word_ids' (the tokenizer's per-subword
            word index list, ``None`` for special/pad tokens).
        """
        sentence = self.sentences[idx]
        tags = self.tags[idx]

        encoding = self.tokenizer(
            sentence,
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Drop the batch dimension added by return_tensors='pt'.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Fill with 'O' (a valid CRF state) rather than an ignore index.
        tag_ids = torch.full((self.max_len,), fill_value=self.tag2idx['O'], dtype=torch.long)

        word_ids = encoding.word_ids()

        for i, word_idx in enumerate(word_ids):
            if word_idx is None or word_idx >= len(tags):
                # Special tokens ([CLS], [SEP], [PAD]) keep the 'O' filler.
                continue
            is_first_subword = i == 0 or word_ids[i - 1] != word_idx
            if is_first_subword:
                tag_ids[i] = self.tag2idx[tags[word_idx]]
            elif self.label_subwords:
                tag_ids[i] = self._continuation_tag_id(tags[word_idx])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': tag_ids,
            'word_ids': word_ids
        }

# Model Architecture
class NERModel(nn.Module):
    """Pretrained transformer encoder -> BiLSTM -> linear emissions -> CRF.

    Args:
        pretrained_model_name: Model id/path passed to
            ``AutoModel.from_pretrained``.
        num_tags: Number of output tags (CRF states).
        lstm_hidden_size: Hidden units per LSTM direction (default 256).
        lstm_num_layers: Number of stacked LSTM layers (default 2).
        lstm_dropout: Dropout between stacked LSTM layers (default 0.2).
        encoder_dropout: Dropout on the encoder's last hidden state
            (default 0.1).

    The defaults reproduce the original architecture. Submodule names and
    shapes are unchanged, so previously saved state dicts still load.
    """

    def __init__(self, pretrained_model_name, num_tags, lstm_hidden_size=256,
                 lstm_num_layers=2, lstm_dropout=0.2, encoder_dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(encoder_dropout)
        self.bilstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden_size, num_layers=lstm_num_layers,
            bidirectional=True, batch_first=True,
            dropout=lstm_dropout
        )
        # A bidirectional LSTM concatenates both directions, so the emission
        # layer input is 2 * hidden size (previously hard-coded as 512).
        self.classifier = nn.Linear(2 * lstm_hidden_size, num_tags)
        # CRF from torchcrf; batch_first matches the LSTM output layout.
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute CRF loss (training) or Viterbi-decoded paths (inference).

        Args:
            input_ids: Token ids, batch-first.
            attention_mask: 1 for real tokens, 0 for padding. It is also used
                as the CRF mask; torchcrf requires the first timestep to be
                unmasked (assumes a leading special token such as [CLS]).
            labels: Optional gold tag indices, same shape as ``input_ids``.

        Returns:
            With ``labels``: ``(loss, logits)``, where ``loss`` is the mean
            negative CRF log-likelihood over the batch.
            Without ``labels``: ``(best_paths, logits)``, where
            ``best_paths`` is the ``CRF.decode`` output, a list of
            tag-index lists, one per sequence.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_out = self.dropout(outputs.last_hidden_state)
        lstm_out, _ = self.bilstm(seq_out)
        logits = self.classifier(lstm_out)  # (batch, seq_len, num_tags)

        mask = attention_mask.bool()

        if labels is not None:
            # torchcrf returns the log-likelihood; negate it to get a loss.
            loss = -self.crf(logits, labels, mask=mask, reduction='mean')
            return loss, logits
        else:
            best_paths = self.crf.decode(logits, mask=mask)
            return best_paths, logits

# Training function
def train(model, dataloader, optimizer, device, max_grad_norm=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module whose ``forward(input_ids=, attention_mask=, labels=)``
            returns ``(loss, logits)``.
        dataloader: Yields dicts with 'input_ids', 'attention_mask', 'labels'.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch tensors are moved to.
        max_grad_norm: If not None, clip the global gradient L2 norm to this
            value before each optimizer step. Default None (no clipping),
            which matches the original behaviour.

    Returns:
        Mean loss over the batches seen, or 0.0 if the dataloader yielded
        no batches (previously a ZeroDivisionError).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        loss, _ = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        # Count batches ourselves so loaders without __len__ also work.
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0

# Evaluation function
def evaluate(model, dataloader, idx2tag, device, detailed=False):
    """Decode a dataset with the CRF and score it with seqeval.

    Gold and predicted tags are compared only at the first subword of each
    word, which is where NERDataset stores the word's tag. Special tokens,
    continuation subwords and padding are skipped. Words cut off by
    tokenizer truncation have no ``word_ids`` entry and are not scored.

    Args:
        model: Module whose ``forward(input_ids=, attention_mask=)`` returns
            ``(best_paths, logits)``, with one list of tag indices per
            sequence in ``best_paths``.
        dataloader: Yields batches built by ``collate_fn`` ('input_ids',
            'attention_mask', 'labels', 'word_ids').
        idx2tag: Mapping from tag index to tag string.
        device: Device the model inputs are moved to.
        detailed: If True, print a per-entity-type breakdown.

    Returns:
        ``(precision, recall, f1)``, the seqeval micro averages. When
        ``detailed`` is True, also returns a fourth element, ``report``: the
        seqeval ``output_dict`` report.
    """
    model.eval()
    all_preds, all_trues = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids  = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Get the best Viterbi paths directly from the model.
            best_paths, _ = model(input_ids=ids, attention_mask=mask)

            # Read labels and per-sequence lengths from the host-side batch
            # once. Calling .item() on device tensors for every token forces
            # a device synchronisation per token.
            labels = batch['labels'].tolist()
            lengths = batch['attention_mask'].sum(dim=1).tolist()

            # Align subwords -> words.
            for i, wids in enumerate(batch['word_ids']):
                path = best_paths[i]
                # Positions at or past this index are padding or undecoded.
                limit = min(len(path), lengths[i])
                prev = None
                pred_seq, true_seq = [], []

                for j, widx in enumerate(wids):
                    # Skip special tokens, repeated word ids and padding.
                    if widx is None or widx == prev or j >= limit:
                        continue
                    pred_seq.append(idx2tag[path[j]])
                    true_seq.append(idx2tag[labels[i][j]])
                    prev = widx

                # Keep only non-empty sequences.
                if pred_seq and true_seq:
                    all_preds.append(pred_seq)
                    all_trues.append(true_seq)

    # Entity-level scores. seqeval merges B-/I- tags into entity types, so
    # the report keys are bare types (e.g. "Tool") plus the average rows.
    report = seqeval_report(all_trues, all_preds, output_dict=True, zero_division=0)
    micro = report['micro avg']
    precision = micro['precision']
    recall    = micro['recall']
    f1        = micro['f1-score']

    if detailed:
        print("="*50)
        print("Secbert-BiLstm-CRF")
        print("="*50, "\n")
        print(f"Test Results - Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}\n")
        print("Detailed Test Results by Entity Type:")

        # One line per entity type; skip the aggregate rows.
        for tag, m in sorted(report.items()):
            if tag in ['O','macro avg','micro avg','weighted avg']:
                continue
            print(f"{tag} {m['precision']:.2f} {m['recall']:.2f} {m['f1-score']:.2f}")
        return precision, recall, f1, report

    return precision, recall, f1

def collate_fn(batch):
    """Merge a list of dataset items into one batch dict.

    The 'input_ids', 'attention_mask' and 'labels' tensors are stacked along
    a new leading dimension. The 'word_ids' lists (which contain None) are
    kept as a plain Python list, so no tensor conversion is attempted.
    """
    merged = {
        field: torch.stack([example[field] for example in batch])
        for field in ('input_ids', 'attention_mask', 'labels')
    }
    merged['word_ids'] = [example['word_ids'] for example in batch]
    return merged



# Plot training history
def plot_training_history(train_losses, val_f1s, save_path='training_history.png'):
    """Save a two-panel figure of training loss and validation F1 per epoch.

    Args:
        train_losses: Mean training loss for each epoch, in order.
        val_f1s: Validation micro-F1 for each epoch, in order.
        save_path: Output image path.
    """
    # Number epochs from 1 to match the "Epoch k/N" lines printed in training.
    loss_epochs = range(1, len(train_losses) + 1)
    f1_epochs = range(1, len(val_f1s) + 1)

    fig = plt.figure(figsize=(12, 5))
    try:
        # Plot training loss
        plt.subplot(1, 2, 1)
        plt.plot(loss_epochs, train_losses, label='Training Loss')
        plt.title('Training Loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()

        # Plot validation F1
        plt.subplot(1, 2, 2)
        plt.plot(f1_epochs, val_f1s, label='Validation F1')
        plt.title('Validation F1 Score')
        plt.xlabel('Epochs')
        plt.ylabel('F1 Score')
        plt.legend()

        plt.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if saving fails, so figures don't accumulate.
        plt.close(fig)

# Main function
def main():
    """Train the SecBERT-BiLSTM-CRF tagger, keep the best checkpoint, test it.

    Steps:
      1. Read the CoNLL train/valid/test splits from ``data/``.
      2. Build the tag vocabulary over all three splits.
      3. Train for ``num_epochs`` epochs with AdamW, using ReduceLROnPlateau
         on validation micro-F1.
      4. Save the best checkpoint to ``model_output/best_model.pt`` and plot
         the training history.
      5. Reload the best checkpoint, evaluate on the test split and write
         ``model_output/test_results.txt``.
    """
    # Create output directory
    output_dir = "model_output"
    os.makedirs(output_dir, exist_ok=True)
    best_model_path = os.path.join(output_dir, 'best_model.pt')

    # Load data
    train_sents, train_tags = read_conll_format('data/train.txt')
    valid_sents, valid_tags = read_conll_format('data/valid.txt')
    test_sents,  test_tags  = read_conll_format('data/test.txt')

    # Tag vocabulary spans every split so indices are stable across them.
    tag2idx, idx2tag = create_tag_mappings(train_tags + valid_tags + test_tags)
    tokenizer = AutoTokenizer.from_pretrained('jackaduma/SecBERT')

    # Datasets & loaders
    train_ds = NERDataset(train_sents, train_tags, tokenizer, tag2idx, max_len=256)
    valid_ds = NERDataset(valid_sents, valid_tags, tokenizer, tag2idx, max_len=256)
    test_ds  = NERDataset(test_sents,  test_tags,  tokenizer, tag2idx, max_len=256)
    train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)
    test_loader  = DataLoader(test_ds,  batch_size=8, shuffle=False, collate_fn=collate_fn)

    # Model, optimizer, scheduler
    model = NERModel('jackaduma/SecBERT', len(tag2idx)).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4)
    # `verbose=True` is deprecated for PyTorch LR schedulers; LR reductions
    # are reported explicitly in the loop below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=2
    )

    # Training loop
    num_epochs = 10
    # Start at -inf (not 0) so the first epoch always saves a checkpoint.
    # Otherwise the reload before testing fails when validation F1 stays 0.
    best_f1 = float('-inf')
    train_losses, val_f1s = [], []

    print(f"Starting training for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        loss = train(model, train_loader, optimizer, device)
        train_losses.append(loss)

        valid_precision, valid_recall, valid_f1 = evaluate(model, valid_loader, idx2tag, device, detailed=False)
        val_f1s.append(valid_f1)

        lr_before = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(valid_f1)
        lr_after = [group['lr'] for group in optimizer.param_groups]

        print(f"\nEpoch {epoch+1}/{num_epochs} — Train Loss: {loss:.4f}, Valid F1: {valid_f1:.4f}")
        print(f"  Validation Precision: {valid_precision:.4f}, Recall: {valid_recall:.4f}")
        if lr_after != lr_before:
            print(f"  Learning rate reduced to {lr_after[0]:.2e}")
        if valid_f1 > best_f1:
            best_f1 = valid_f1
            torch.save(model.state_dict(), best_model_path)
            print(f"  Saved new best model (F1: {valid_f1:.4f})")

    # Plot history
    plot_training_history(train_losses, val_f1s, os.path.join(output_dir, 'training_history.png'))

    # Final evaluation on test set
    print("\nLoading best model for test evaluation...")
    # map_location keeps the load working when the checkpoint was saved on
    # a different device than the current one.
    model.load_state_dict(torch.load(best_model_path, map_location=device))

    print("Evaluating on test set...")
    test_precision, test_recall, test_f1, test_report = evaluate(model, test_loader, idx2tag, device, detailed=True)
    print(f"Test Results - Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

    # Save detailed test results
    with open(os.path.join(output_dir, "test_results.txt"), 'w', encoding='utf-8') as f:
        f.write("="*50 + "\n")
        f.write("Detailed Test Set Evaluation Results\n")
        f.write("="*50 + "\n\n")
        f.write(f"Test Results - Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}\n\n")
        f.write("Detailed Test Results by Entity Type:\n")
        f.write(f"{'Entity':<10} {'Precision':>9} {'Recall':>7} {'F1-score':>9}\n")
        f.write("-"*40 + "\n")
        for tag in sorted(test_report):
            if tag in ['O','macro avg','micro avg','weighted avg']:
                continue
            m = test_report[tag]
            # Field widths mirror the header so the columns line up.
            f.write(f"{tag:<10} {m['precision']:>9.2f} {m['recall']:>7.2f} {m['f1-score']:>9.2f}\n")

    print(f"Detailed results saved to {output_dir}/test_results.txt")


# Script entry point: run the full train/validate/test pipeline.
# In a notebook cell __name__ is "__main__", so executing the cell runs it.
if __name__ == "__main__":
    main()

Using device: cuda
Starting training for 10 epochs...


Training: 100%|██████████| 657/657 [05:03<00:00,  2.17it/s]
Evaluating: 100%|██████████| 83/83 [00:04<00:00, 20.08it/s]



Epoch 1/10 — Train Loss: 30.7040, Valid F1: 0.1000
  Validation Precision: 0.1371, Recall: 0.0787
  Saved new best model (F1: 0.1000)


Training: 100%|██████████| 657/657 [02:33<00:00,  4.28it/s]
Evaluating: 100%|██████████| 83/83 [00:04<00:00, 20.46it/s]



Epoch 2/10 — Train Loss: 14.5786, Valid F1: 0.5482
  Validation Precision: 0.5707, Recall: 0.5274
  Saved new best model (F1: 0.5482)


Training: 100%|██████████| 657/657 [02:35<00:00,  4.23it/s]
Evaluating: 100%|██████████| 83/83 [00:04<00:00, 20.38it/s]



Epoch 3/10 — Train Loss: 6.6345, Valid F1: 0.6940
  Validation Precision: 0.6799, Recall: 0.7087
  Saved new best model (F1: 0.6940)


Training: 100%|██████████| 657/657 [02:33<00:00,  4.27it/s]
Evaluating: 100%|██████████| 83/83 [00:04<00:00, 19.80it/s]



Epoch 4/10 — Train Loss: 3.7665, Valid F1: 0.7353
  Validation Precision: 0.7572, Recall: 0.7147
  Saved new best model (F1: 0.7353)


Training: 100%|██████████| 657/657 [03:26<00:00,  3.19it/s]
Evaluating: 100%|██████████| 83/83 [00:09<00:00,  8.67it/s]



Epoch 5/10 — Train Loss: 2.5707, Valid F1: 0.7686
  Validation Precision: 0.7615, Recall: 0.7759
  Saved new best model (F1: 0.7686)


Training: 100%|██████████| 657/657 [06:15<00:00,  1.75it/s]
Evaluating: 100%|██████████| 83/83 [00:09<00:00,  8.70it/s]



Epoch 6/10 — Train Loss: 1.8379, Valid F1: 0.7685
  Validation Precision: 0.7435, Recall: 0.7952


Training: 100%|██████████| 657/657 [06:14<00:00,  1.75it/s]
Evaluating: 100%|██████████| 83/83 [00:09<00:00,  8.39it/s]



Epoch 7/10 — Train Loss: 1.5190, Valid F1: 0.7774
  Validation Precision: 0.7587, Recall: 0.7971
  Saved new best model (F1: 0.7774)


Training: 100%|██████████| 657/657 [06:16<00:00,  1.75it/s]
Evaluating: 100%|██████████| 83/83 [00:09<00:00,  8.39it/s]



Epoch 8/10 — Train Loss: 1.2637, Valid F1: 0.7768
  Validation Precision: 0.7506, Recall: 0.8049


Training: 100%|██████████| 657/657 [06:12<00:00,  1.76it/s]
Evaluating: 100%|██████████| 83/83 [00:09<00:00,  8.52it/s]



Epoch 9/10 — Train Loss: 1.0275, Valid F1: 0.7867
  Validation Precision: 0.7842, Recall: 0.7892
  Saved new best model (F1: 0.7867)


Training: 100%|██████████| 657/657 [06:12<00:00,  1.76it/s]
Evaluating: 100%|██████████| 83/83 [00:09<00:00,  8.45it/s]



Epoch 10/10 — Train Loss: 0.9225, Valid F1: 0.7886
  Validation Precision: 0.7791, Recall: 0.7984
  Saved new best model (F1: 0.7886)

Loading best model for test evaluation...
Evaluating on test set...


Evaluating: 100%|██████████| 83/83 [00:10<00:00,  8.23it/s]


Secbert-BiLstm-CRF

Test Results - Precision: 0.8246, Recall: 0.8450, F1: 0.8347

Detailed Test Results by Entity Type:
Area 0.82 0.89 0.86
Exp 0.98 0.99 0.99
Features 0.96 0.96 0.96
HackOrg 0.80 0.79 0.79
Idus 0.83 0.92 0.87
OffAct 0.80 0.81 0.80
Org 0.70 0.72 0.71
Purp 0.79 0.97 0.87
SamFile 0.97 0.77 0.86
SecTeam 0.91 0.86 0.89
Time 0.89 0.90 0.90
Tool 0.65 0.76 0.70
Way 0.92 0.96 0.94
Test Results - Precision: 0.8246, Recall: 0.8450, F1: 0.8347
Detailed results saved to model_output/test_results.txt


In [5]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer
from seqeval.metrics import classification_report as seqeval_report
import os
from tqdm import tqdm
from torchcrf import CRF

# Pick the compute device once, preferring CUDA whenever a GPU is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Data preprocessing
def read_conll_format(file_path):
    """Read a CoNLL-style token/tag file into parallel sentence and tag lists.

    Each non-blank line is split on any whitespace (spaces or tabs). The
    first column is taken as the token and the last column as its tag.
    Lines with fewer than two columns are skipped. Blank lines separate
    sentences, and runs of blank lines never produce empty sentences.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        A ``(sentences, tags)`` tuple where ``sentences[k]`` is the token list
        of sentence ``k`` and ``tags[k]`` is its parallel tag list.

    Raises:
        Any exception raised while opening or reading the file is re-raised
        after an error message is printed.
    """
    sentences = []
    tags = []

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            sentence = []
            sentence_tags = []

            for line in f:
                # str.split() with no argument splits on any whitespace run,
                # so tab-separated files parse too (split(' ') would leave
                # "token\tTAG" as a single column and drop the line).
                parts = line.split()
                if parts:
                    if len(parts) >= 2:
                        sentence.append(parts[0])
                        sentence_tags.append(parts[-1])
                elif sentence:
                    # Blank line: close the current (non-empty) sentence.
                    sentences.append(sentence)
                    tags.append(sentence_tags)
                    sentence = []
                    sentence_tags = []

            # The file may not end with a blank line.
            if sentence:
                sentences.append(sentence)
                tags.append(sentence_tags)
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        raise

    return sentences, tags

# Create tag mappings
def create_tag_mappings(tags_list):
    """Build bidirectional tag <-> index lookup tables.

    Args:
        tags_list: Iterable of tag sequences (one list of tag strings per
            sentence).

    Returns:
        ``(tag2idx, idx2tag)``. ``tag2idx`` assigns each distinct tag its
        position in lexicographic order, and ``idx2tag`` is the inverse.
    """
    vocabulary = sorted({tag for sequence in tags_list for tag in sequence})
    tag2idx = dict(zip(vocabulary, range(len(vocabulary))))
    idx2tag = dict(enumerate(vocabulary))
    return tag2idx, idx2tag

# Custom dataset
class NERDataset(Dataset):
    """Word-level NER dataset that tokenizes into subwords on the fly.

    Each sentence is tokenized with ``is_split_into_words=True`` and padded or
    truncated to ``max_len``. Word-level tags are aligned to subwords through
    ``encoding.word_ids()``.

    Every label position starts as the ``'O'`` tag index, not ``-100``. The
    CRF scores every position whose attention mask is 1 and needs a valid tag
    index there. So special tokens, padding and (by default) non-first
    subwords are all labelled ``'O'``.

    Args:
        sentences: List of token lists.
        tags: List of tag lists, parallel to ``sentences``.
        tokenizer: Tokenizer whose output supports ``word_ids()`` (a Hugging
            Face fast tokenizer).
        tag2idx: Mapping from tag string to index; must contain ``'O'``.
        max_len: Fixed sequence length after padding/truncation.
        label_subwords: If True, continuation subwords take their word's tag
            instead of ``'O'``; a ``B-X`` word propagates ``I-X`` when that
            tag exists in ``tag2idx``. This stops the CRF from learning
            ``B-X -> O -> I-X`` transitions inside multi-subword words.
            Defaults to False, which reproduces the original labelling.
    """

    def __init__(self, sentences, tags, tokenizer, tag2idx, max_len=128,
                 label_subwords=False):
        self.sentences = sentences
        self.tags = tags
        self.tokenizer = tokenizer
        self.tag2idx = tag2idx
        self.max_len = max_len
        self.label_subwords = label_subwords

    def __len__(self):
        return len(self.sentences)

    def _continuation_tag_id(self, tag):
        """Return the tag index for a non-first subword of a word tagged ``tag``.

        ``B-X`` maps to ``I-X`` when ``I-X`` is a known tag. Any other tag
        (including ``B-X`` without a matching ``I-X``) is reused unchanged.
        """
        if tag.startswith('B-'):
            inside_tag = 'I-' + tag[2:]
            if inside_tag in self.tag2idx:
                return self.tag2idx[inside_tag]
        return self.tag2idx[tag]

    def __getitem__(self, idx):
        """Tokenize sentence ``idx`` and align its word tags to subwords.

        Returns:
            Dict with 'input_ids', 'attention_mask', 'labels' (LongTensor of
            length ``max_len``) and 'word_ids' (the tokenizer's per-subword
            word index list, ``None`` for special/pad tokens).
        """
        sentence = self.sentences[idx]
        tags = self.tags[idx]

        encoding = self.tokenizer(
            sentence,
            is_split_into_words=True,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Drop the batch dimension added by return_tensors='pt'.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Fill with 'O' (a valid CRF state) rather than an ignore index.
        tag_ids = torch.full((self.max_len,), fill_value=self.tag2idx['O'], dtype=torch.long)

        word_ids = encoding.word_ids()

        for i, word_idx in enumerate(word_ids):
            if word_idx is None or word_idx >= len(tags):
                # Special tokens ([CLS], [SEP], [PAD]) keep the 'O' filler.
                continue
            is_first_subword = i == 0 or word_ids[i - 1] != word_idx
            if is_first_subword:
                tag_ids[i] = self.tag2idx[tags[word_idx]]
            elif self.label_subwords:
                tag_ids[i] = self._continuation_tag_id(tags[word_idx])

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': tag_ids,
            'word_ids': word_ids
        }

# Model architecture
class NERModel(nn.Module):
    """Pretrained transformer encoder -> BiLSTM -> linear emissions -> CRF.

    Args:
        pretrained_model_name: Model id/path passed to
            ``AutoModel.from_pretrained``.
        num_tags: Number of output tags (CRF states).
        lstm_hidden_size: Hidden units per LSTM direction (default 256).
        lstm_num_layers: Number of stacked LSTM layers (default 2).
        lstm_dropout: Dropout between stacked LSTM layers (default 0.2).
        encoder_dropout: Dropout on the encoder's last hidden state
            (default 0.1).

    The defaults reproduce the original architecture. Submodule names and
    shapes are unchanged, so previously saved state dicts still load.
    """

    def __init__(self, pretrained_model_name, num_tags, lstm_hidden_size=256,
                 lstm_num_layers=2, lstm_dropout=0.2, encoder_dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(pretrained_model_name)
        self.dropout = nn.Dropout(encoder_dropout)
        self.bilstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden_size, num_layers=lstm_num_layers,
            bidirectional=True, batch_first=True,
            dropout=lstm_dropout
        )
        # A bidirectional LSTM concatenates both directions, so the emission
        # layer input is 2 * hidden size (previously hard-coded as 512).
        self.classifier = nn.Linear(2 * lstm_hidden_size, num_tags)
        # CRF from torchcrf; batch_first matches the LSTM output layout.
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute CRF loss (training) or Viterbi-decoded paths (inference).

        Args:
            input_ids: Token ids, batch-first.
            attention_mask: 1 for real tokens, 0 for padding. It is also used
                as the CRF mask; torchcrf requires the first timestep to be
                unmasked (assumes a leading special token such as [CLS]).
            labels: Optional gold tag indices, same shape as ``input_ids``.

        Returns:
            With ``labels``: ``(loss, logits)``, where ``loss`` is the mean
            negative CRF log-likelihood over the batch.
            Without ``labels``: ``(best_paths, logits)``, where
            ``best_paths`` is the ``CRF.decode`` output, a list of
            tag-index lists, one per sequence.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_out = self.dropout(outputs.last_hidden_state)
        lstm_out, _ = self.bilstm(seq_out)
        logits = self.classifier(lstm_out)  # (batch, seq_len, num_tags)

        mask = attention_mask.bool()

        if labels is not None:
            # torchcrf returns the log-likelihood; negate it to get a loss.
            loss = -self.crf(logits, labels, mask=mask, reduction='mean')
            return loss, logits
        else:
            best_paths = self.crf.decode(logits, mask=mask)
            return best_paths, logits

def collate_fn(batch):
    """Merge a list of dataset items into one batch dict.

    The 'input_ids', 'attention_mask' and 'labels' tensors are stacked along
    a new leading dimension. The 'word_ids' lists (which contain None) are
    kept as a plain Python list, so no tensor conversion is attempted.
    """
    merged = {
        field: torch.stack([example[field] for example in batch])
        for field in ('input_ids', 'attention_mask', 'labels')
    }
    merged['word_ids'] = [example['word_ids'] for example in batch]
    return merged

# Calculate metrics manually
def calculate_metrics(true_tags, pred_tags):
    """Compute token-level precision, recall and F1 for every non-'O' tag.

    Unlike seqeval's entity-level scores, this compares tags position by
    position, so ``B-X`` and ``I-X`` are scored as separate labels. For a
    tag T, a position counts as:
      - a true positive when both sequences carry T there;
      - a false positive when only the prediction carries T;
      - a false negative when only the gold sequence carries T.

    Args:
        true_tags: List of gold tag sequences.
        pred_tags: List of predicted tag sequences, paired with ``true_tags``
            by index (pairs beyond the shorter list are ignored, as with
            ``zip``). Within a pair, positions present in only one sequence
            count as FN / FP respectively.

    Returns:
        Dict from each tag seen in either input (except ``'O'``), in sorted
        order, to ``{'precision', 'recall', 'f1-score'}``. Ratios with a zero
        denominator are reported as 0.
    """
    from collections import Counter
    from itertools import zip_longest

    # Every tag observed anywhere gets an entry, even if it never scores.
    all_tags = {tag for seq in true_tags + pred_tags for tag in seq}
    all_tags.discard('O')

    # One pass over aligned positions instead of one full pass per tag.
    tp, fp, fn = Counter(), Counter(), Counter()
    for true_seq, pred_seq in zip(true_tags, pred_tags):
        # zip_longest pads the shorter sequence with None, so unmatched
        # positions still count as FN (gold side) or FP (predicted side).
        for true_tag, pred_tag in zip_longest(true_seq, pred_seq):
            if true_tag == pred_tag:
                tp[true_tag] += 1
            else:
                fp[pred_tag] += 1
                fn[true_tag] += 1

    tag_metrics = {}
    for tag in sorted(all_tags):
        tag_tp, tag_fp, tag_fn = tp[tag], fp[tag], fn[tag]
        precision = tag_tp / (tag_tp + tag_fp) if (tag_tp + tag_fp) > 0 else 0
        recall = tag_tp / (tag_tp + tag_fn) if (tag_tp + tag_fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        tag_metrics[tag] = {
            'precision': precision,
            'recall': recall,
            'f1-score': f1
        }

    return tag_metrics

# Evaluation function with detailed breakdown
def evaluate(model, dataloader, idx2tag, device, output_dir="model_output"):
    """Decode ``dataloader`` with the CRF, then print and save the scores.

    Gold and predicted tags are compared only at the first subword of each
    word, which is where NERDataset stores the word's tag. Overall scores
    are seqeval's entity-level micro averages. The per-tag breakdown comes
    from ``calculate_metrics`` and is token-level, keeping B-/I- prefixes.

    Args:
        model: Module whose ``forward(input_ids=, attention_mask=)`` returns
            ``(best_paths, logits)``, with one list of tag indices per
            sequence in ``best_paths``.
        dataloader: Yields batches built by ``collate_fn``.
        idx2tag: Mapping from tag index to tag string.
        device: Device the model inputs are moved to.
        output_dir: Directory for ``detailed_test_results.txt``, created if
            missing. Defaults to ``"model_output"``.

    Returns:
        ``(precision, recall, f1, tag_metrics)``.
    """
    model.eval()
    all_preds, all_trues = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            ids  = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)

            # Get the best Viterbi paths directly from the model.
            best_paths, _ = model(input_ids=ids, attention_mask=mask)

            # Read labels and per-sequence lengths from the host-side batch
            # once. Calling .item() on device tensors for every token forces
            # a device synchronisation per token.
            labels = batch['labels'].tolist()
            lengths = batch['attention_mask'].sum(dim=1).tolist()

            # Align subwords -> words.
            for i, wids in enumerate(batch['word_ids']):
                path = best_paths[i]
                # Positions at or past this index are padding or undecoded.
                limit = min(len(path), lengths[i])
                prev = None
                pred_seq, true_seq = [], []

                for j, widx in enumerate(wids):
                    # Skip special tokens, repeated word ids and padding.
                    if widx is None or widx == prev or j >= limit:
                        continue
                    pred_seq.append(idx2tag[path[j]])
                    true_seq.append(idx2tag[labels[i][j]])
                    prev = widx

                # Keep only non-empty sequences.
                if pred_seq and true_seq:
                    all_preds.append(pred_seq)
                    all_trues.append(true_seq)

    # Overall entity-level metrics from seqeval.
    seqeval_report_dict = seqeval_report(all_trues, all_preds, output_dict=True, zero_division=0)
    micro = seqeval_report_dict['micro avg']
    precision = micro['precision']
    recall    = micro['recall']
    f1        = micro['f1-score']

    # Token-level metrics per tag, B-/I- prefixes kept.
    tag_metrics = calculate_metrics(all_trues, all_preds)
    # Format once; the same lines go to stdout and to the results file.
    tag_lines = [
        f"{tag} {metrics['precision']:.2f} {metrics['recall']:.2f} {metrics['f1-score']:.2f}"
        for tag, metrics in sorted(tag_metrics.items())
    ]

    print("="*50)
    print("Secbert-BiLstm-CRF")
    print("="*50, "\n")
    print(f"Test results - Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}\n")
    print("Detailed test results by entity type:")
    for line in tag_lines:
        print(line)

    # Save detailed results
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "detailed_test_results.txt"), 'w', encoding='utf-8') as f:
        f.write("="*50 + "\n")
        f.write("Detailed test evaluation results\n")
        f.write("="*50 + "\n\n")
        f.write(f"Test results - Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}\n\n")
        f.write("Detailed test results by entity type:\n")
        for line in tag_lines:
            f.write(line + "\n")

    print(f"Detailed results saved to {output_dir}/detailed_test_results.txt")

    return precision, recall, f1, tag_metrics

def main():
    """Reload the best saved checkpoint and run a detailed test-set evaluation.

    The tag vocabulary is rebuilt from the train, valid and test splits, the
    same three splits the training cell concatenates. This keeps tag indices
    aligned with the saved classifier/CRF weights. ``evaluate`` prints the
    scores and writes ``detailed_test_results.txt``.
    """
    splits = ('train', 'valid', 'test')
    loaded = {split: read_conll_format(f'data/{split}.txt') for split in splits}

    tag2idx, idx2tag = create_tag_mappings(
        [sequence for split in splits for sequence in loaded[split][1]]
    )
    tokenizer = AutoTokenizer.from_pretrained('jackaduma/SecBERT')

    test_sents, test_tags = loaded['test']
    test_loader = DataLoader(
        NERDataset(test_sents, test_tags, tokenizer, tag2idx, max_len=256),
        batch_size=8,
        shuffle=False,
        collate_fn=collate_fn,
    )

    model = NERModel('jackaduma/SecBERT', len(tag2idx)).to(device)

    model_path = os.path.join("model_output", 'best_model.pt')
    print(f"Loading model: {model_path}")
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    print("Evaluating on the test set...")
    test_precision, test_recall, test_f1, _ = evaluate(model, test_loader, idx2tag, device)

    print(f"Test results - Precision: {test_precision:.4f}, Recall: {test_recall:.4f}, F1: {test_f1:.4f}")

# Script entry point: evaluate the saved best checkpoint on the test split.
# In a notebook cell __name__ is "__main__", so executing the cell runs it.
if __name__ == "__main__":
    main()

Using device: cuda
Loading model: model_output\best_model.pt
Evaluating on the test set...


Evaluating: 100%|██████████| 83/83 [00:04<00:00, 19.75it/s]


Secbert-BiLstm-CRF

Test results - Precision: 0.8246, Recall: 0.8450, F1: 0.8347

Detailed test results by entity type:
B-Area 0.86 0.92 0.89
B-Exp 0.98 0.99 0.99
B-Features 0.97 0.96 0.96
B-HackOrg 0.84 0.81 0.83
B-Idus 0.85 0.95 0.90
B-OffAct 0.84 0.83 0.84
B-Org 0.75 0.72 0.74
B-Purp 0.84 0.99 0.91
B-SamFile 0.98 0.77 0.86
B-SecTeam 0.98 0.92 0.95
B-Time 0.94 0.94 0.94
B-Tool 0.69 0.78 0.73
B-Way 0.96 0.97 0.97
I-Area 0.77 0.90 0.83
I-Exp 1.00 1.00 1.00
I-Features 0.99 0.91 0.94
I-HackOrg 0.77 0.77 0.77
I-Idus 0.69 1.00 0.82
I-OffAct 0.91 0.77 0.83
I-Org 0.75 0.76 0.76
I-Purp 0.81 1.00 0.90
I-SamFile 1.00 0.85 0.92
I-SecTeam 0.74 0.86 0.79
I-Time 0.98 0.84 0.90
I-Tool 0.69 0.78 0.73
I-Way 0.94 0.98 0.96
Detailed results saved to model_output/detailed_test_results.txt
Test results - Precision: 0.8246, Recall: 0.8450, F1: 0.8347
