In [None]:
import numpy as np
import pandas as pd
import torch
import torch.utils.data as Data
import torch.nn.functional as F
import deepchem as dc
from torch_geometric.utils import dense_to_sparse
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import StratifiedKFold
from rdkit import Chem


import os
import glob
import sys
from model import Drug_Molecular, Cell_Line, GO_Network, ATC_Network, CNN_Drug, CNN_GO, CNN_ATC, FCNN, Synergy
from drug_util import GraphDataset, collate, drug_feature_extract
from process_data import getData
from utils import  metric, set_seed_all, SynergyDataset, get_MACCS
import warnings
warnings.filterwarnings("ignore")

In [None]:
# ---data_load
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
drug_smiles_file = '../Data/TRAIN/drug_smiles.csv'
# NOTE(review): the GO and ATC paths are declared but not read in the cells visible here.
drug_go_file = '../Data/TRAIN/drug_go.csv'
drug_atc_file = '../Data/TRAIN/drug_atc.csv'

# The loop below reads the 'drugbank_id' and 'smiles' columns of this table.
drug_smiles = pd.read_csv(drug_smiles_file, sep=",", header=0)
drug_data = pd.DataFrame()
drug_smiles_fea = []

featurizer = dc.feat.ConvMolFeaturizer()
for drug_id, smiles in zip(drug_smiles['drugbank_id'], drug_smiles['smiles']):
    mol = Chem.MolFromSmiles(smiles)
    mol_f = featurizer.featurize(mol)
    conv_mol = mol_f[0]
    # One column per drug: row 0 = atom feature matrix, row 1 = adjacency list.
    drug_data[str(drug_id)] = [conv_mol.get_atom_features(), conv_mol.get_adjacency_list()]
    # Fingerprint computed by the project helper get_MACCS from the same SMILES string.
    drug_smiles_fea.append(get_MACCS(smiles))

# Map every drug id (column name) to its position in the table.
drug_ids = drug_data.keys()
drug_num = len(drug_ids)
d_map = {drug_id: position for position, drug_id in zip(range(drug_num), drug_ids)}
drug_feature = drug_feature_extract(drug_data)

In [None]:
# ---model_build
# Layer widths of every sub-network, in the order they are passed to the constructors below.
# NOTE(review): `cell_feature` is not defined in any cell visible above; it must already
# exist in the kernel for this cell to run on a fresh start — confirm where it is loaded.
DM_dim = [75,512,256] # Drug_Molecular (75 is presumably the per-atom feature length — TODO confirm)
CL_dim = [len(cell_feature[1]),256] # Cell_Line
# NOTE(review): the GO and ATC input widths are both taken from the length of one entry of
# `drug_smiles_fea` (the get_MACCS output) — confirm this matches the trained checkpoint.
GN_dim = [len(drug_smiles_fea[1]),512,256] # GO_Network
AN_dim = [len(drug_smiles_fea[1]),512,256] # ATC_Network

CD_dim = [512,256] #Drug_Molecular+Cell_Line
CO_dim = [512,256] #GO_Network+Cell_Line
CT_dim = [512,256] #ATC_Network+Cell_Line
# Input width of the final network: two embeddings per CNN branch plus the cell-line embedding.
FN_dim = [(CD_dim[1] * 2 + CO_dim[1] * 2 + CT_dim[1]* 2 + CL_dim[1]),[1024,512,128]]

# Initialize the full synergy prediction model
model = Synergy(Drug_Molecular(dim_drug = DM_dim[0], hidden_dim = DM_dim[1], output_dim = DM_dim[2], heads=4),
                Cell_Line(dim_cellline = CL_dim[0], hidden_dim = CL_dim[1]),
                GO_Network(feature_dim = GN_dim[0], hidden_dim = GN_dim[1], output_dim = GN_dim[2]),
                ATC_Network(feature_dim = AN_dim[0], hidden_dim = AN_dim[1], output_dim = AN_dim[2]),
                CNN_Drug(embed_dim = CD_dim[0], hidden_dim = CD_dim[1]),
                CNN_GO(embed_dim = CO_dim[0], hidden_dim = CO_dim[1]),
                CNN_ATC(embed_dim = CT_dim[0], hidden_dim = CT_dim[1]),
                FCNN(embed_dim = FN_dim[0], hidden_dim = FN_dim[1])
                ).to(device)

# Restore the trained weights onto the selected device.
model.load_state_dict(torch.load('the_best_model.pth', map_location=device))
# The notebook only inspects the trained model, so switch every sub-module to
# inference behaviour (no dropout, frozen normalisation statistics) before any forward pass.
model.eval()

In [None]:
# Drug-structure sub-network of the loaded model.
model_drug_structure = model.Drug_Molecular
# batch_size equals the number of drug graphs and shuffling is off, so the drugs
# keep their original order inside the collated batch.
drug_set = Data.DataLoader(dataset=GraphDataset(graphs_dict=drug_feature),
                        collate_fn=collate, batch_size=len(drug_feature), shuffle=False)
# Pure inference: skip autograd bookkeeping so no graph is kept for the whole drug set.
with torch.no_grad():
    for i, drug in enumerate(drug_set, 0):
        # The call returns the drug embedding and a pair of attention outputs
        # (presumably one per attention layer; the second argument presumably
        # requests them — confirm against model.Drug_Molecular).
        drug_embed, (attention_weights_1, attention_weights_2) = model_drug_structure(drug, True)

In [None]:
def aggregate_multihead_attention_scores(edge_index, attention_weights, num_nodes):
    """Sum head-averaged attention weights onto the source node of every edge.

    Args:
        edge_index: Tensor of shape (2, num_edges); row 0 holds the source node
            index of each edge. Indices are expected to be non-negative.
        attention_weights: 2-D tensor with one row per edge and one column per
            attention head. Only the first ``num_edges`` rows are used, so extra
            trailing rows are ignored.
        num_nodes: Length of the returned score vector.

    Returns:
        A 1-D float32 CPU tensor of length ``num_nodes``. Entry ``k`` is the sum,
        over all edges whose source is node ``k``, of that edge's attention weight
        averaged across heads. Nodes that are never a source get 0; edges whose
        source index is ``>= num_nodes`` are ignored.

    Raises:
        IndexError: If ``attention_weights`` has fewer rows than there are edges.
    """
    source_nodes = edge_index[0].detach().cpu().long()
    num_edges = source_nodes.numel()

    # Average over the heads first: (rows, heads) -> (rows,)
    head_mean = attention_weights.detach().mean(dim=1)
    if head_mean.size(0) < num_edges:
        raise IndexError(
            f"attention_weights has {head_mean.size(0)} rows but edge_index has {num_edges} edges"
        )
    if num_edges == 0:
        return torch.zeros(num_nodes, dtype=torch.float32)

    # Accumulate in float64 and cast once at the end, all in a single vectorised
    # call instead of one Python iteration (and two .item() syncs) per edge.
    edge_weights = head_mean[:num_edges].cpu().double()
    totals = torch.bincount(source_nodes, weights=edge_weights, minlength=num_nodes)
    # bincount grows to max(source)+1; trim so out-of-range sources are dropped.
    return totals[:num_nodes].to(torch.float32)
def get_drug_attention_scores(batch, node_attention_scores):
    """Split a flat per-node score tensor into one tensor per drug.

    Args:
        batch: 1-D integer tensor giving, for every node, the index of the drug
            (graph) it belongs to.
        node_attention_scores: Tensor with one entry per node, in the same order
            as ``batch``.

    Returns:
        A list with ``batch.max() + 1`` tensors; element ``k`` holds the scores of
        the nodes whose batch index equals ``k`` (empty if no node has that index).
    """
    num_drugs = batch.max().item() + 1
    return [node_attention_scores[batch == drug_idx] for drug_idx in range(num_drugs)]


In [None]:
# Walk the loader to its end so that `drug` refers to the last collated batch.
for drug in drug_set:
    pass

# Attention output of the first layer: element 0 is an edge index, element 1 the per-edge weights.
edge_index, attention_weights = attention_weights_1[0], attention_weights_1[1]
# Aggregate the edge attention onto nodes.
# NOTE(review): the aggregation is fed `drug.edge_index`, not the `edge_index` returned
# with the attention weights — confirm the two list the edges in the same order.
num_nodes = drug.x.size(0)
node_attention_scores = aggregate_multihead_attention_scores(
    drug.edge_index, attention_weights, num_nodes
).to(device)
# Split the node scores into one tensor per drug.
drug_attention_scores = get_drug_attention_scores(drug.batch, node_attention_scores)


In [None]:
def plot_drug_structure_with_attention(edge_index, attention_scores, title="Drug Structure with Attention", save_path=None, vmin=None, vmax=None):
    """Draw one drug graph with nodes coloured, sized and labelled by attention score.

    Args:
        edge_index: Tensor of shape (2, num_edges) listing the edges of the drug;
            the node ids it contains define the nodes of the plotted graph.
        attention_scores: One score per node (tensor or array-like), ordered by
            ascending node id.
        title: Figure title.
        save_path: If given, the figure is also written to this path as a PDF.
        vmin, vmax: Limits of the colour scale. Pass the same values for several
            figures to make them comparable; when left as None they are taken
            from ``attention_scores``.

    Raises:
        ValueError: If the number of scores differs from the number of nodes
            found in ``edge_index``.
    """
    import networkx as nx
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    import numpy as np

    # Build the NetworkX graph from the edge list.
    G = nx.Graph()
    G.add_edges_from(edge_index.t().tolist())

    # Work on a plain NumPy array; detach/cpu so CUDA or grad-tracking tensors are accepted.
    if isinstance(attention_scores, torch.Tensor):
        attention_scores = attention_scores.detach().cpu().numpy()
    else:
        attention_scores = np.asarray(attention_scores)

    # G.nodes() follows first appearance in the edge list, which need not be
    # ascending node id. Sort so that position k in `nodes` matches score k.
    nodes = sorted(G.nodes())
    if len(attention_scores) != len(nodes):
        raise ValueError("The length of attention_scores must match the number of nodes in the graph.")

    # Normalise with the shared vmin/vmax (filled in from the data when None).
    global_norm = Normalize(vmin=vmin, vmax=vmax)
    normalized_scores = np.asarray(global_norm(attention_scores), dtype=float)
    node_sizes = normalized_scores * 500 + 200  # map the normalised score to a marker size

    # Kamada-Kawai layout keeps node overlap low.
    pos = nx.kamada_kawai_layout(G)

    plt.figure(figsize=(10, 8))
    ax = plt.gca()

    # Colour-bar source that uses the same norm as the nodes.
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=global_norm)
    sm.set_array([])

    # Edges are drawn once, below the nodes.
    nx.draw_networkx_edges(G, pos, edge_color='gray', width=1.0, alpha=0.7, ax=ax)
    # Passing nodelist keeps sizes/colours aligned with `nodes`; passing the norm's
    # limits makes the node colours agree with the colour bar instead of being
    # rescaled to this figure's own min/max.
    nx.draw_networkx_nodes(
        G, pos,
        nodelist=nodes,
        node_size=node_sizes,
        node_color=attention_scores,
        cmap=plt.cm.Reds,
        vmin=global_norm.vmin,
        vmax=global_norm.vmax,
        ax=ax,
    )

    # Print the raw attention value on every node.
    labels = {node: f"{score:.2f}" for node, score in zip(nodes, attention_scores)}
    nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, font_color='black', ax=ax)

    plt.colorbar(sm, orientation='vertical', label='Attention Score', ax=ax)

    plt.title(title, fontsize=16)
    plt.axis('off')

    # Optionally save as PDF before showing.
    if save_path:
        plt.savefig(save_path, format='pdf', bbox_inches='tight', dpi=300)
        print(f"Figure saved to {save_path}")
    plt.show()

In [None]:
# Per-node scores of the two drugs that are plotted below (batch positions 1444 and 1462).
drug1_attention = drug_attention_scores[1444].cpu()
drug2_attention = drug_attention_scores[1462].cpu()
# Pool both score vectors and take their overall extremes as the shared colour range.
combined_attention = torch.cat((drug1_attention, drug2_attention))
vmin, vmax = combined_attention.min().item(), combined_attention.max().item()

In [None]:
# First drug (batch position 1444): keep only the edges whose source node belongs to it.
first_drug_edge_mask = drug.batch[edge_index[0]] == 1444
first_drug_edge_index = edge_index[:, first_drug_edge_mask]
first_drug_attention_scores = drug_attention_scores[1444]
plot_drug_structure_with_attention(
    edge_index=first_drug_edge_index.cpu(),
    attention_scores=first_drug_attention_scores.cpu(),
    save_path="DB11967_attention.pdf",
    vmin=vmin,
    vmax=vmax,
)

In [None]:
# Second drug (batch position 1462): keep only the edges whose source node belongs to it.
# The `first_drug_*` names from the previous cell are reused here and now hold the second drug.
second_drug_edge_mask = drug.batch[edge_index[0]] == 1462
first_drug_edge_index = edge_index[:, second_drug_edge_mask]
first_drug_attention_scores = drug_attention_scores[1462]
plot_drug_structure_with_attention(
    edge_index=first_drug_edge_index.cpu(),
    attention_scores=first_drug_attention_scores.cpu(),
    save_path="DB12267_attention.pdf",
    vmin=vmin,
    vmax=vmax,
)