In [1]:
import os
import torch
import pandas as pd
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


class VirusDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes nucleotide sequences lazily, per item.

    Each item is a dict with the flattened ``input_ids`` and
    ``attention_mask`` tensors produced by ``tokenizer`` (padded/truncated to
    ``max_length``) plus the class index as a ``torch.long`` ``labels`` tensor.
    """

    # Characters kept after upper-casing; any other character (gaps,
    # whitespace, other IUPAC symbols, ...) is removed before tokenization.
    _KEEP = frozenset("ATGCN")

    def __init__(self, sequences, labels, tokenizer, max_length=512):
        self.sequences, self.labels = sequences, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        raw, target = self.sequences[idx], self.labels[idx]
        cleaned = "".join(filter(self._KEEP.__contains__, raw.upper()))

        enc = self.tokenizer(
            cleaned,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" yields a leading batch dim of 1; drop it.
        ids, mask = enc["input_ids"].flatten(), enc["attention_mask"].flatten()
        return {
            "input_ids": ids,
            "attention_mask": mask,
            "labels": torch.tensor(target, dtype=torch.long),
        }

def load_test_data(fasta_file, label_file, label_mapping_file):
    """Load test sequences and align them with integer class labels.

    Args:
        fasta_file: FASTA file of test sequences. The ID used for matching is
            the part of each record ID before the first ``|``.
        label_file: CSV file with at least ``accession`` and ``subtype``
            columns.
        label_mapping_file: text file with one ``<index><TAB><label>`` pair
            per line; blank lines are ignored.

    Returns:
        ``(sequences, labels, seq_ids, idx_to_label)``. The first three lists
        are aligned and contain only sequences whose ID appears in
        ``label_file`` and whose subtype appears in the label mapping;
        ``idx_to_label`` maps class index -> subtype name.
    """
    label_to_idx = {}
    idx_to_label = {}
    with open(label_mapping_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing lines instead of failing to unpack.
                continue
            idx, label = line.split('\t')
            idx = int(idx)
            label_to_idx[label] = idx
            idx_to_label[idx] = label

    # Load sequences
    sequences = []
    seq_ids = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequences.append(str(record.seq))
        seq_ids.append(record.id.split('|')[0])

    print(f"加载了 {len(sequences)} 条测试序列")

    # Load labels. Read accessions as strings so purely numeric IDs still
    # match the (string) FASTA IDs instead of being parsed as integers.
    labels_df = pd.read_csv(label_file, dtype={'accession': str})
    print(f"测试标签文件包含 {len(labels_df)} 条记录")
    id_to_label = dict(zip(labels_df['accession'], labels_df['subtype']))

    # Match sequences with labels
    labels = []
    valid_seq_ids = []
    valid_sequences = []

    for seq_id, sequence in zip(seq_ids, sequences):
        if seq_id not in id_to_label:
            print(f"警告: 序列ID {seq_id} 在标签文件中未找到")
            continue
        label = id_to_label[seq_id]
        if label not in label_to_idx:
            # Subtype unknown to the model: it cannot be scored, so skip it.
            print(f"警告: 标签 {label} 在训练集中未出现，跳过序列 {seq_id}")
            continue
        labels.append(label_to_idx[label])
        valid_seq_ids.append(seq_id)
        valid_sequences.append(sequence)

    print(f"成功匹配了 {len(labels)} 条测试序列的标签")

    return valid_sequences, labels, valid_seq_ids, idx_to_label

def plot_confusion_matrix(y_true, y_pred, idx_to_label, output_dir,
                          filename='a_new_class_10_confusion_matrix.png'):
    """Draw a labelled confusion-matrix heatmap and save it as an image.

    Args:
        y_true: true class indices.
        y_pred: predicted class indices.
        idx_to_label: mapping class index -> class name. Its keys, in sorted
            order, define the rows/columns of the matrix.
        output_dir: directory the image is written to.
        filename: image file name inside ``output_dir``.

    Returns:
        Path of the saved image.
    """
    class_ids = sorted(idx_to_label)
    tick_labels = [idx_to_label[i] for i in class_ids]
    # Pass ``labels`` explicitly so the matrix has one row/column per known
    # class even when a class is absent from both y_true and y_pred; without
    # it the matrix shrinks and the tick labels no longer line up.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    out_path = os.path.join(output_dir, filename)
    fig = plt.figure(figsize=(10, 8), dpi=800)  # DPI set to 800
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=tick_labels, yticklabels=tick_labels)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix')
        plt.tight_layout()
        plt.savefig(out_path, dpi=800)  # DPI 800 also used when saving
    finally:
        # Release the (large) figure even if plotting/saving fails.
        plt.close(fig)
    return out_path

def evaluate_model(
    model_dir="/root/autodl-tmp/Influenza_BERT/new_class_5/checkpoint-6139",
    test_fasta_file="/root/autodl-tmp/Influenza_BERT/test_sequences_5.fasta",
    test_label_file="/root/autodl-tmp/Influenza_BERT/test_labels_5.csv",
    batch_size=8,
):
    """Evaluate a fine-tuned sequence classifier on a labelled test set.

    Loads the checkpoint in ``model_dir`` together with its
    ``label_mapping.txt``, predicts every test sequence, prints overall and
    per-class metrics, and writes a confusion-matrix image and
    ``test_predictions.csv`` into ``model_dir``.

    Args:
        model_dir: checkpoint directory (model, tokenizer, label mapping).
        test_fasta_file: FASTA file of test sequences.
        test_label_file: CSV with ``accession`` / ``subtype`` columns.
        batch_size: inference batch size.
    """
    label_mapping_file = os.path.join(model_dir, "label_mapping.txt")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"使用设备: {device}")

    # Load test data
    test_sequences, test_labels, test_ids, idx_to_label = load_test_data(
        test_fasta_file, test_label_file, label_mapping_file
    )
    if not test_sequences:
        # sklearn metrics fail on empty inputs; nothing to evaluate.
        print("没有可用于评估的测试序列，已跳过评估")
        return

    # Load model and tokenizer. NOTE: trust_remote_code=True executes Python
    # code shipped in model_dir; only point this at trusted checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir, trust_remote_code=True)
    model.to(device)
    model.eval()

    # DataLoader keeps the default shuffle=False so predictions stay in the
    # same order as test_ids (relied on when writing the CSV below).
    test_dataset = VirusDataset(test_sequences, test_labels, tokenizer)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

    # Predict
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in test_dataloader:
            outputs = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            )
            all_preds.extend(torch.argmax(outputs.logits, dim=1).cpu().tolist())
            # Labels are only needed on the host; no GPU round trip.
            all_labels.extend(batch['labels'].tolist())

    # Metrics. Use every known class explicitly so per-class outputs line up
    # with target_names even if a class is missing from the test set.
    class_ids = sorted(idx_to_label)
    target_names = [idx_to_label[i] for i in class_ids]

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', zero_division=0
    )

    print("\n测试集评估结果:")
    print(f"准确率 (Accuracy): {accuracy:.4f}")
    print(f"精确率 (Precision): {precision:.4f}")
    print(f"召回率 (Recall): {recall:.4f}")
    print(f"F1分数: {f1:.4f}")

    # Detailed per-class report
    class_report = classification_report(
        all_labels, all_preds, labels=class_ids,
        target_names=target_names, zero_division=0,
    )
    print("\n分类报告:")
    print(class_report)

    # Confusion matrix; the path must match the file name that
    # plot_confusion_matrix writes by default.
    plot_confusion_matrix(all_labels, all_preds, idx_to_label, model_dir)
    cm_path = os.path.join(model_dir, 'a_new_class_10_confusion_matrix.png')
    print(f"混淆矩阵已保存到 {cm_path}")

    # Per-sequence predictions
    results_df = pd.DataFrame({
        'sequence_id': test_ids,
        'true_label': [idx_to_label[label] for label in all_labels],
        'predicted_label': [idx_to_label[pred] for pred in all_preds],
        'correct': [pred == label for pred, label in zip(all_preds, all_labels)]
    })

    results_file = os.path.join(model_dir, 'test_predictions.csv')
    results_df.to_csv(results_file, index=False)
    print(f"详细预测结果已保存到 {results_file}")

if __name__ == "__main__":
    # Script entry point: run the evaluation with the paths configured in evaluate_model().
    evaluate_model()

使用设备: cuda
加载了 5261 条测试序列
测试标签文件包含 5261 条记录
成功匹配了 5261 条测试序列的标签





测试集评估结果:
准确率 (Accuracy): 0.9899
精确率 (Precision): 0.9901
召回率 (Recall): 0.9899
F1分数: 0.9900

分类报告:
              precision    recall  f1-score   support

        H1N1       0.98      0.99      0.99      1678
        H3N2       1.00      0.99      0.99      3452
        H5N1       1.00      0.94      0.97        36
        H7N9       0.92      0.97      0.94        34
        H9N2       0.90      0.98      0.94        61

    accuracy                           0.99      5261
   macro avg       0.96      0.98      0.97      5261
weighted avg       0.99      0.99      0.99      5261

混淆矩阵已保存到 /root/autodl-tmp/Influenza_BERT/new_class_5/checkpoint-6139/confusion_matrix.png
详细预测结果已保存到 /root/autodl-tmp/Influenza_BERT/new_class_5/checkpoint-6139/test_predictions.csv
