In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from tqdm import tqdm


# Global Matplotlib defaults applied to every figure produced below:
# a sans-serif font stack (Microsoft YaHei / SimHei are listed as fallbacks,
# presumably so CJK glyphs can render), a uniform 34 pt text size, and
# 800 dpi for both on-screen rendering and saved files.
plt.rcParams.update({
    'font.family': ['sans-serif'],
    'font.sans-serif': ['DejaVu Sans', 'Arial', 'Microsoft YaHei', 'SimHei'],
    'font.size': 34,
    'axes.titlesize': 34,
    'axes.labelsize': 34,
    'xtick.labelsize': 34,
    'ytick.labelsize': 34,
    'figure.dpi': 800,
    'savefig.dpi': 800,
})



def load_data(fasta_file, label_file, existing_label_to_idx=None):
    """Load FASTA sequences and pair them with subtype labels from a CSV.

    Each FASTA record ID is cut at the first ``|`` and looked up in the
    ``accession`` column of the label CSV; the matching ``subtype`` value
    becomes the record's label. Records without a match are dropped.

    Args:
        fasta_file: Path to the FASTA file to parse.
        label_file: Path to a CSV with ``accession`` and ``subtype`` columns.
        existing_label_to_idx: Optional label-name -> integer mapping to
            reuse. When ``None``, a mapping is built from the sorted set of
            labels that were actually matched.

    Returns:
        A tuple ``(sequences, numeric_labels, seq_ids, label_to_idx)`` where
        the first three lists are aligned and contain only matched records.

    Raises:
        KeyError: If a matched label is missing from ``existing_label_to_idx``.
    """
    # (accession, sequence) pairs in file order.
    records = [
        (rec.id.split('|')[0], str(rec.seq))
        for rec in SeqIO.parse(fasta_file, "fasta")
    ]
    print(f"加载了 {len(records)} 条序列")

    label_table = pd.read_csv(label_file)
    print(f"标签文件包含 {len(label_table)} 条记录")
    accession_to_subtype = dict(
        zip(label_table['accession'], label_table['subtype'])
    )

    # Keep only the records whose accession has a label.
    matched = [
        (accession, seq, accession_to_subtype[accession])
        for accession, seq in records
        if accession in accession_to_subtype
    ]
    valid_seq_ids = [accession for accession, _, _ in matched]
    valid_sequences = [seq for _, seq, _ in matched]
    labels = [subtype for _, _, subtype in matched]

    print(f"成功匹配了 {len(labels)} 条序列的标签")

    if existing_label_to_idx is not None:
        label_to_idx = existing_label_to_idx
    else:
        label_to_idx = {
            name: position for position, name in enumerate(sorted(set(labels)))
        }

    numeric_labels = [label_to_idx[name] for name in labels]

    return valid_sequences, numeric_labels, valid_seq_ids, label_to_idx

def analyze_feature_importance(model, sequence, tokenizer, label_idx, device, window_size=5):
    """Score every sliding window of a sequence by occlusion.

    For each start position, ``window_size`` consecutive characters are
    replaced with ``'N'`` and the model is re-run; the window's score is the
    absolute change in the softmax probability of ``label_idx`` relative to
    the unmasked sequence.

    Args:
        model: Sequence-classification model; called as ``model(**encoding)``
            and expected to expose ``.logits``.
        sequence: The sequence string to analyse.
        tokenizer: Callable tokenizer returning an encoding with ``.to()``.
        label_idx: Index of the class whose probability is tracked.
        device: Device the encoded inputs are moved to.
        window_size: Number of characters masked per window (positive).

    Returns:
        A 1-D ``numpy`` array of length ``len(sequence) - window_size + 1``
        holding one score per window start, or ``None`` when the sequence is
        too short to contain a single window.
    """
    n_windows = len(sequence) - window_size + 1
    if n_windows <= 0:
        # No window fits. Previously this reached np.zeros() with a negative
        # (or zero) size; callers already treat None as "could not compute".
        return None

    model.eval()

    def label_probability(text):
        """Return the softmax probability of ``label_idx`` for ``text``."""
        with torch.no_grad():
            encoding = tokenizer(
                text,
                max_length=512,
                padding="max_length",
                truncation=True,
                return_tensors="pt"
            ).to(device)
            logits = model(**encoding).logits
            return torch.softmax(logits, dim=-1)[0, label_idx].item()

    # Baseline prediction on the unmasked sequence.
    base_prob = label_probability(sequence)

    # NOTE(review): inputs are truncated to 512 tokens, so windows that lie
    # past the truncation point presumably cannot affect the prediction and
    # will score ~0 - confirm against the tokenizer's granularity.
    window_importance = np.zeros(n_windows)
    mask = 'N' * window_size

    for start in tqdm(range(n_windows), desc="分析窗口重要性"):
        masked_seq = sequence[:start] + mask + sequence[start + window_size:]
        window_importance[start] = abs(base_prob - label_probability(masked_seq))

    return window_importance

def visualize_importance(sequence, window_importance, seq_id, subtype, output_dir, window_size=5):
    """Plot per-window importance as a one-row heatmap and save it.

    Windows whose score is strictly above the 90th percentile are shaded in
    blue, with consecutive windows merged into a single span. The figure is
    written to ``output_dir`` as ``{subtype}_{seq_id}_importance.svg`` and
    ``.png``.

    Args:
        sequence: The analysed sequence; only its length is used, to place
            the x-axis tick marks.
        window_importance: 1-D numpy array with one score per window start.
        seq_id: Sequence identifier, used in the title and file names.
        subtype: Subtype name, used in the title and file names.
        output_dir: Existing directory the two image files are written to.
        window_size: Window length, used for the axis label and tick range.

    Returns:
        Indices of the (up to) 20 highest-scoring windows, ordered from
        lowest to highest score.
    """
    fig = plt.figure(figsize=(30, 10))
    try:
        # Single-row heatmap: one cell per window start.
        sns.heatmap(
            window_importance.reshape(1, -1),
            cmap='YlOrRd',
            xticklabels=False,
            yticklabels=False,
            cbar_kws={'label': 'Window Importance', 'shrink': 0.8},
        )

        plt.title(f"Window Importance Distribution - {seq_id} ({subtype})")
        plt.xlabel(f"Window Position (size={window_size})", fontsize=34)

        # Position ticks every 50 windows.
        seq_length = len(sequence) - window_size + 1
        step = 50
        positions = range(0, seq_length, step)
        plt.xticks(positions, positions, fontsize=34, rotation=50)
        plt.gca().margins(x=0.05)

        # Shade runs of consecutive windows above the 90th percentile.
        threshold = np.percentile(window_importance, 90)
        important_regions = np.where(window_importance > threshold)[0]

        if important_regions.size > 0:
            run_breaks = np.where(np.diff(important_regions) != 1)[0] + 1
            for run in np.split(important_regions, run_breaks):
                # Heatmap cell i occupies [i, i + 1] on the x-axis, so the
                # span must end at last + 1; otherwise an isolated window
                # gets a zero-width (invisible) highlight.
                plt.axvspan(run[0], run[-1] + 1, color='blue', alpha=0.3)

        plt.tight_layout(pad=2.0)

        output_file_svg = os.path.join(output_dir, f"{subtype}_{seq_id}_importance.svg")
        output_file_png = os.path.join(output_dir, f"{subtype}_{seq_id}_importance.png")

        # SVG for later editing, PNG for direct viewing.
        plt.savefig(output_file_svg, format='svg', bbox_inches='tight')
        plt.savefig(output_file_png, dpi=800, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving failed, so
        # repeated calls do not accumulate open high-resolution figures.
        plt.close(fig)

    return np.argsort(window_importance)[-20:]

def analyze_sequences(sequences, labels, seq_ids, model, tokenizer, label_to_idx, output_dir, device, window_size=5):
    """Run window-occlusion analysis on each sequence and collect top windows.

    For every sequence this computes per-window importance, saves a heatmap
    into ``output_dir``, and records the highest-scoring windows grouped by
    subtype. A failure on one sequence is reported and does not stop the rest.

    Args:
        sequences: Sequence strings to analyse.
        labels: Integer class index for each sequence.
        seq_ids: Identifier for each sequence.
        model: Classification model passed to the importance analysis.
        tokenizer: Tokenizer passed to the importance analysis.
        label_to_idx: Mapping from subtype name to integer class index.
        output_dir: Directory for the generated figures; created if missing.
        device: Device passed to the importance analysis.
        window_size: Length of the occlusion window. It is used for the
            analysis, the plot, and the extracted window sequences so the
            three always agree.

    Returns:
        Dict mapping subtype name to a list of dicts with keys
        ``'window_start'``, ``'window_sequence'`` and ``'importance'``.

    Raises:
        KeyError: If a label has no corresponding entry in ``label_to_idx``.
    """
    os.makedirs(output_dir, exist_ok=True)
    subtype_patterns = {}

    # Invert the mapping once; on duplicate indices the first name wins.
    idx_to_label = {}
    for name, idx in label_to_idx.items():
        idx_to_label.setdefault(idx, name)

    total = len(sequences)
    for i, (sequence, label, seq_id) in enumerate(zip(sequences, labels, seq_ids)):
        subtype = idx_to_label[label]
        print(f"\n分析序列 {i+1}/{total}: {seq_id} ({subtype})")

        try:
            window_importance = analyze_feature_importance(
                model, sequence, tokenizer, label, device, window_size=window_size
            )

            if window_importance is None:
                print(f"跳过序列 {seq_id}：无法计算特征重要性")
                continue

            important_windows = visualize_importance(
                sequence, window_importance, seq_id, subtype, output_dir,
                window_size=window_size
            )

            patterns = subtype_patterns.setdefault(subtype, [])
            for window_start in important_windows:
                # Plain int keeps the record free of numpy scalar types.
                window_start = int(window_start)
                patterns.append({
                    'window_start': window_start,
                    'window_sequence': sequence[window_start:window_start + window_size],
                    'importance': float(window_importance[window_start])
                })

        except Exception as e:
            # Per-sequence boundary: report and move on to the next sequence.
            print(f"处理序列 {seq_id} 时发生错误: {str(e)}")

    return subtype_patterns

def generate_summary_report(subtype_patterns, output_dir, top_n=30):
    """Write a per-subtype summary of the most important window positions.

    Patterns are grouped by ``window_start``. For each start position the
    report lists the mean importance and how often each window sequence was
    seen there, sorted by mean importance (highest first).

    Args:
        subtype_patterns: Dict mapping subtype name to a list of dicts with
            keys ``'window_start'``, ``'window_sequence'`` and
            ``'importance'``.
        output_dir: Directory for ``feature_importance_summary.txt``;
            created if it does not exist.
        top_n: Maximum number of window positions listed per subtype.

    Returns:
        None. The report is written to disk and its path is printed.
    """
    # Create the directory so the function also works when called on its own.
    os.makedirs(output_dir, exist_ok=True)
    report_file = os.path.join(output_dir, "feature_importance_summary.txt")

    with open(report_file, "w", encoding="utf-8") as f:
        f.write("# Window Importance Analysis Summary\n\n")

        for subtype, patterns in subtype_patterns.items():
            f.write(f"## {subtype} Analysis\n\n")

            # Aggregate count, sequence frequencies and summed importance
            # for every window start position.
            window_stats = {}
            for pattern in patterns:
                stats = window_stats.setdefault(
                    pattern['window_start'],
                    {'count': 0, 'sequences': {}, 'avg_importance': 0}
                )
                stats['count'] += 1
                window_seq = pattern['window_sequence']
                stats['sequences'][window_seq] = stats['sequences'].get(window_seq, 0) + 1
                stats['avg_importance'] += pattern['importance']

            # Turn the running sums into means.
            for stats in window_stats.values():
                stats['avg_importance'] /= stats['count']

            sorted_windows = sorted(
                window_stats.items(),
                key=lambda item: item[1]['avg_importance'],
                reverse=True
            )

            f.write("Window Start\tImportance\tSequence Distribution\n")
            for window_start, stats in sorted_windows[:top_n]:
                seq_dist = ", ".join(
                    f"{seq}:{cnt}" for seq, cnt in stats['sequences'].items()
                )
                f.write(f"{window_start}\t{stats['avg_importance']:.4f}\t{seq_dist}\n")

            f.write("\n")

    print(f"已生成汇总报告: {report_file}")

def main():
    """Run the window-importance analysis on a sample of test sequences.

    Loads the fine-tuned checkpoint and the labelled test set, draws 8 random
    sequences (fixed seed) for each target subtype that has at least 8
    available, analyses them, and writes figures plus a summary report to the
    output directory.
    """
    model_path = "/root/autodl-tmp/Influenza_BERT/new_class_10/checkpoint-37400"
    test_fasta_file = "/root/autodl-tmp/data/new_class_10/test_sequences_10.fasta"
    test_label_file = "/root/autodl-tmp/data/new_class_10/test_labels_10.csv"
    output_dir = "./feature_importance_results_window_test"
    target_subtypes = ['H5N1', "H9N2", "H7N9", "H5N8", "H3N2", "H13N6", "H13N8"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    print("加载模型和tokenizer...")
    # NOTE(review): trust_remote_code=True runs code shipped with the
    # checkpoint; only point this at checkpoints you trust.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        config=config,
        trust_remote_code=True
    )
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    print("加载测试数据...")
    sequences, labels, seq_ids, label_to_idx = load_data(test_fasta_file, test_label_file)

    # Class index -> subtype name (first name wins on duplicate indices).
    idx_to_subtype = {}
    for name, idx in label_to_idx.items():
        idx_to_subtype.setdefault(idx, name)

    # Bucket the (sequence, label, id) triples of the target subtypes.
    subtype_sequences = {}
    for entry in zip(sequences, labels, seq_ids):
        subtype = idx_to_subtype[entry[1]]
        if subtype not in target_subtypes:
            continue
        subtype_sequences.setdefault(subtype, []).append(entry)

    test_sequences, test_labels, test_seq_ids = [], [], []

    # Fixed seed so the sampled sequences are reproducible.
    np.random.seed(42)

    for subtype in target_subtypes:
        candidates = subtype_sequences.get(subtype)
        if candidates is None:
            continue
        if len(candidates) < 8:
            print(f"警告：亚型 {subtype} 的序列数量不足8条（仅有 {len(candidates)} 条）")
            continue

        # Draw 8 distinct sequences for this subtype.
        for pick in np.random.choice(len(candidates), 8, replace=False):
            sequence, label, seq_id = candidates[pick]
            test_sequences.append(sequence)
            test_labels.append(label)
            test_seq_ids.append(seq_id)
            print(f"随机选择亚型 {subtype} 的序列: {seq_id}")

    print("\n开始特征重要性分析...")
    model.eval()
    subtype_patterns = analyze_sequences(
        test_sequences, test_labels, test_seq_ids,
        model, tokenizer, label_to_idx,
        output_dir, device
    )

    generate_summary_report(subtype_patterns, output_dir)
    print("分析完成！")

# Run the full analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()

使用设备: cuda
加载模型和tokenizer...




加载测试数据...
加载了 5610 条序列
标签文件包含 5610 条记录
成功匹配了 5610 条序列的标签
随机选择亚型 H5N1 的序列: KY614903.1
随机选择亚型 H5N1 的序列: KY614875.1
随机选择亚型 H5N1 的序列: MF116327.1
随机选择亚型 H5N1 的序列: KY614823.1
随机选择亚型 H5N1 的序列: KY614942.1
随机选择亚型 H5N1 的序列: KY614926.1
随机选择亚型 H5N1 的序列: KY614898.1
随机选择亚型 H5N1 的序列: KY614970.1
随机选择亚型 H9N2 的序列: MF164904.1
随机选择亚型 H9N2 的序列: KY785875.1
随机选择亚型 H9N2 的序列: KY872759.1
随机选择亚型 H9N2 的序列: KY983113.1
随机选择亚型 H9N2 的序列: MF172927.1
随机选择亚型 H9N2 的序列: KY785757.1
随机选择亚型 H9N2 的序列: KY785873.1
随机选择亚型 H9N2 的序列: MF319196.1
随机选择亚型 H7N9 的序列: MF357791.1
随机选择亚型 H7N9 的序列: KP418097.1
随机选择亚型 H7N9 的序列: MF357746.1
随机选择亚型 H7N9 的序列: MF357718.1
随机选择亚型 H7N9 的序列: MF510873.1
随机选择亚型 H7N9 的序列: MF510864.1
随机选择亚型 H7N9 的序列: MF357786.1
随机选择亚型 H7N9 的序列: MF357720.1
随机选择亚型 H5N8 的序列: KY576113.1
随机选择亚型 H5N8 的序列: KY828814.1
随机选择亚型 H5N8 的序列: MF073917.1
随机选择亚型 H5N8 的序列: MF073915.1
随机选择亚型 H5N8 的序列: MF037850.1
随机选择亚型 H5N8 的序列: MF073910.1
随机选择亚型 H5N8 的序列: MF073908.1
随机选择亚型 H5N8 的序列: KY576103.1
随机选择亚型 H3N2 的序列: CY234179.1
随机选择亚型 H3N2 的序列: CY

分析窗口重要性: 100%|██████████| 2229/2229 [00:22<00:00, 100.10it/s]



分析序列 2/56: KY614875.1 (H5N1)


分析窗口重要性: 100%|██████████| 1023/1023 [00:09<00:00, 104.49it/s]



分析序列 3/56: MF116327.1 (H5N1)


分析窗口重要性: 100%|██████████| 1561/1561 [00:15<00:00, 104.06it/s]



分析序列 4/56: KY614823.1 (H5N1)


分析窗口重要性: 100%|██████████| 2016/2016 [00:19<00:00, 101.27it/s]



分析序列 5/56: KY614942.1 (H5N1)


分析窗口重要性: 100%|██████████| 2011/2011 [00:19<00:00, 101.20it/s]



分析序列 6/56: KY614926.1 (H5N1)


分析窗口重要性: 100%|██████████| 1083/1083 [00:10<00:00, 104.92it/s]



分析序列 7/56: KY614898.1 (H5N1)


分析窗口重要性: 100%|██████████| 1383/1383 [00:13<00:00, 104.66it/s]



分析序列 8/56: KY614970.1 (H5N1)


分析窗口重要性: 100%|██████████| 1023/1023 [00:09<00:00, 105.84it/s]



分析序列 9/56: MF164904.1 (H9N2)


分析窗口重要性: 100%|██████████| 545/545 [00:05<00:00, 106.86it/s]



分析序列 10/56: KY785875.1 (H9N2)


分析窗口重要性: 100%|██████████| 1341/1341 [00:12<00:00, 104.26it/s]



分析序列 11/56: KY872759.1 (H9N2)


分析窗口重要性: 100%|██████████| 438/438 [00:04<00:00, 106.46it/s]



分析序列 12/56: KY983113.1 (H9N2)


分析窗口重要性: 100%|██████████| 1561/1561 [00:15<00:00, 102.95it/s]



分析序列 13/56: MF172927.1 (H9N2)


分析窗口重要性: 100%|██████████| 2121/2121 [00:21<00:00, 100.52it/s]



分析序列 14/56: KY785757.1 (H9N2)


分析窗口重要性: 100%|██████████| 2258/2258 [00:22<00:00, 100.13it/s]



分析序列 15/56: KY785873.1 (H9N2)


分析窗口重要性: 100%|██████████| 838/838 [00:07<00:00, 105.91it/s]



分析序列 16/56: MF319196.1 (H9N2)


分析窗口重要性: 100%|██████████| 425/425 [00:03<00:00, 107.07it/s]



分析序列 17/56: MF357791.1 (H7N9)


分析窗口重要性: 100%|██████████| 998/998 [00:09<00:00, 104.75it/s]



分析序列 18/56: KP418097.1 (H7N9)


分析窗口重要性: 100%|██████████| 978/978 [00:09<00:00, 106.10it/s]



分析序列 19/56: MF357746.1 (H7N9)


分析窗口重要性: 100%|██████████| 2312/2312 [00:23<00:00, 99.79it/s] 



分析序列 20/56: MF357718.1 (H7N9)


分析窗口重要性: 100%|██████████| 1418/1418 [00:13<00:00, 104.60it/s]



分析序列 21/56: MF510873.1 (H7N9)


分析窗口重要性: 100%|██████████| 1679/1679 [00:16<00:00, 102.74it/s]



分析序列 22/56: MF510864.1 (H7N9)


分析窗口重要性: 100%|██████████| 1679/1679 [00:16<00:00, 101.77it/s]



分析序列 23/56: MF357786.1 (H7N9)


分析窗口重要性: 100%|██████████| 2312/2312 [00:23<00:00, 99.86it/s] 



分析序列 24/56: MF357720.1 (H7N9)


分析窗口重要性: 100%|██████████| 861/861 [00:08<00:00, 103.00it/s]



分析序列 25/56: KY576113.1 (H5N8)


分析窗口重要性: 100%|██████████| 2276/2276 [00:22<00:00, 99.65it/s] 



分析序列 26/56: KY828814.1 (H5N8)


分析窗口重要性: 100%|██████████| 983/983 [00:09<00:00, 99.26it/s] 



分析序列 27/56: MF073917.1 (H5N8)


分析窗口重要性: 100%|██████████| 1741/1741 [00:16<00:00, 102.94it/s]



分析序列 28/56: MF073915.1 (H5N8)


分析窗口重要性: 100%|██████████| 2324/2324 [00:22<00:00, 101.56it/s]



分析序列 29/56: MF037850.1 (H5N8)


分析窗口重要性: 100%|██████████| 2198/2198 [00:21<00:00, 101.16it/s]



分析序列 30/56: MF073910.1 (H5N8)


分析窗口重要性: 100%|██████████| 1554/1554 [00:14<00:00, 104.66it/s]



分析序列 31/56: MF073908.1 (H5N8)


分析窗口重要性: 100%|██████████| 2191/2191 [00:22<00:00, 98.19it/s] 



分析序列 32/56: KY576103.1 (H5N8)


分析窗口重要性: 100%|██████████| 2147/2147 [00:21<00:00, 101.00it/s]



分析序列 33/56: CY234179.1 (H3N2)


分析窗口重要性: 100%|██████████| 998/998 [00:09<00:00, 104.24it/s]



分析序列 34/56: CY230784.1 (H3N2)


分析窗口重要性: 100%|██████████| 2312/2312 [00:23<00:00, 99.07it/s] 



分析序列 35/56: KU590053.1 (H3N2)


分析窗口重要性: 100%|██████████| 2270/2270 [00:22<00:00, 100.21it/s]



分析序列 36/56: KT839992.1 (H3N2)


分析窗口重要性: 100%|██████████| 978/978 [00:09<00:00, 105.45it/s]



分析序列 37/56: CY228902.1 (H3N2)


分析窗口重要性: 100%|██████████| 1537/1537 [00:14<00:00, 103.27it/s]



分析序列 38/56: CY230927.1 (H3N2)


分析窗口重要性: 100%|██████████| 1438/1438 [00:13<00:00, 104.05it/s]



分析序列 39/56: KC883406.1 (H3N2)


分析窗口重要性: 100%|██████████| 1406/1406 [00:13<00:00, 104.36it/s]



分析序列 40/56: KT837333.1 (H3N2)


分析窗口重要性: 100%|██████████| 1406/1406 [00:13<00:00, 104.54it/s]



分析序列 41/56: MF147670.1 (H13N6)


分析窗口重要性: 100%|██████████| 2337/2337 [00:23<00:00, 98.45it/s]



分析序列 42/56: MF147162.1 (H13N6)


分析窗口重要性: 100%|██████████| 1023/1023 [00:09<00:00, 104.15it/s]



分析序列 43/56: MF148142.1 (H13N6)


分析窗口重要性: 100%|██████████| 886/886 [00:08<00:00, 106.59it/s]



分析序列 44/56: MF147063.1 (H13N6)


分析窗口重要性: 100%|██████████| 2229/2229 [00:21<00:00, 101.36it/s]



分析序列 45/56: MF146699.1 (H13N6)


分析窗口重要性: 100%|██████████| 2337/2337 [00:23<00:00, 98.62it/s]



分析序列 46/56: MF146211.1 (H13N6)


分析窗口重要性: 100%|██████████| 1761/1761 [00:17<00:00, 101.45it/s]



分析序列 47/56: MF147617.1 (H13N6)


分析窗口重要性: 100%|██████████| 886/886 [00:08<00:00, 104.83it/s]



分析序列 48/56: MF148039.1 (H13N6)


分析窗口重要性: 100%|██████████| 1561/1561 [00:15<00:00, 102.28it/s]



分析序列 49/56: MF147764.1 (H13N8)


分析窗口重要性: 100%|██████████| 1561/1561 [00:15<00:00, 102.23it/s]



分析序列 50/56: MF146050.1 (H13N8)


分析窗口重要性: 100%|██████████| 886/886 [00:08<00:00, 105.02it/s]



分析序列 51/56: MF148077.1 (H13N8)


分析窗口重要性: 100%|██████████| 2337/2337 [00:23<00:00, 100.75it/s]



分析序列 52/56: MF146365.1 (H13N8)


分析窗口重要性: 100%|██████████| 1561/1561 [00:14<00:00, 104.18it/s]



分析序列 53/56: MF147732.1 (H13N8)


分析窗口重要性: 100%|██████████| 1023/1023 [00:09<00:00, 105.27it/s]



分析序列 54/56: MF147905.1 (H13N8)


分析窗口重要性: 100%|██████████| 1561/1561 [00:14<00:00, 105.54it/s]



分析序列 55/56: MF146733.1 (H13N8)


分析窗口重要性: 100%|██████████| 886/886 [00:08<00:00, 103.27it/s]



分析序列 56/56: MF145996.1 (H13N8)


分析窗口重要性: 100%|██████████| 1761/1761 [00:17<00:00, 101.52it/s]


已生成汇总报告: ./feature_importance_results_window_test/feature_importance_summary.txt
分析完成！
