In [None]:
import torch
import torch.nn as nn
import shap
import os
import glob
import sys
import numpy as np
import pandas as pd
import torch
import torch.utils.data as Data
import torch.nn.functional as F
from scipy.stats import pearsonr
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import GCNConv, global_max_pool as gmp, global_mean_pool as gap
from sklearn.model_selection import StratifiedKFold

from model import Drug_Molecular, Cell_Line, GO_Network, ATC_Network, CNN_Drug, CNN_GO, CNN_ATC, FCNN, FCNN, Synergy
from drug_util import GraphDataset, collate
from process_data import getData
from utils import  metric, set_seed_all, SynergyDataset
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
_device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
def load_data(cell_exp_path, drug_synergy_path, pos_threshold=10, neg_threshold=0):
    """Load the model inputs through ``getData`` and binarise the synergy score.

    Field 3 of every synergy row is treated as the score. Rows scoring strictly
    above ``pos_threshold`` are labelled 1, rows scoring strictly below
    ``neg_threshold`` are labelled 0, and all remaining rows (including scores
    equal to either threshold) are dropped.

    Kept rows are relabelled in place, i.e. the row objects returned by
    ``getData`` are the ones placed in the result.

    Args:
        cell_exp_path: Forwarded unchanged to ``getData`` (first argument).
        drug_synergy_path: Forwarded unchanged to ``getData`` (second argument).
        pos_threshold: Scores above this value become label 1 (default 10).
        neg_threshold: Scores below this value become label 0 (default 0).

    Returns:
        The eight values produced by ``getData`` in their original order, with
        the synergy entry replaced by the list of kept, relabelled rows.
    """
    (drug_feature, drug_go_adj_weight, drug_atc_adj_weight, drug_smiles_fea,
     cell_feature, synergy, d_map, c_map) = getData(cell_exp_path, drug_synergy_path)

    # Keep only clearly positive / clearly negative samples.
    filtered_synergy = []
    for row in synergy:
        score = row[3]
        if score > pos_threshold:      # positive sample
            row[3] = 1
        elif score < neg_threshold:    # negative sample
            row[3] = 0
        else:                          # ambiguous score: discard the sample
            continue
        filtered_synergy.append(row)

    return (drug_feature, drug_go_adj_weight, drug_atc_adj_weight, drug_smiles_fea,
            cell_feature, filtered_synergy, d_map, c_map)

In [None]:
# --- Input locations ---------------------------------------------------------
cell_exp_path = '../Data/TRAIN/train_cell_exp.csv'
# The CSV path gets its own name so that `cancer_drug_combination` only ever
# refers to the labelled DataFrame built below (it used to be rebound from
# str -> list -> DataFrame within this cell).
cancer_drug_combination_path = '../Data/SHAP/cancer_drug_combination.csv'

dataset_name = "SHAP"
(drug_feature, drug_go_adj_weight, drug_atc_adj_weight, drug_smiles_fea,
 cell_feature, synergy_rows, d_map, c_map) = load_data(cell_exp_path, cancer_drug_combination_path)
columns = ['Drug1', 'Drug2', 'Cell_line', 'Loewe']
cancer_drug_combination = pd.DataFrame(synergy_rows, columns=columns)
drug_smiles_fea = torch.tensor(drug_smiles_fea).to(device)
# Inverse (value -> key) lookups for the two maps returned by load_data.
reverse_d_map = {v: k for k, v in d_map.items()}
reverse_c_map = {v: k for k, v in c_map.items()}

# Each *_adj_weight value is indexed as a pair: element 0 and element 1 are
# moved to the compute device separately.
GO_adj = drug_go_adj_weight[0].to(device)
GO_weight = drug_go_adj_weight[1].to(device)
ATC_adj = drug_atc_adj_weight[0].to(device)
ATC_weight = drug_atc_adj_weight[1].to(device)

cell_feature = torch.tensor(cell_feature, dtype=torch.float32).to(device)

# Single-batch loaders: batch_size equals the dataset length, nothing is shuffled.
drug_set = Data.DataLoader(dataset=GraphDataset(graphs_dict=drug_feature),
                           collate_fn=collate, batch_size=len(drug_feature), shuffle=False)
cell_set = Data.DataLoader(dataset=cell_feature,
                           batch_size=len(cell_feature), shuffle=False)


# --- Model construction ------------------------------------------------------
# Input widths are read from the tensors' second dimension (previously
# len(x[1]), which needed at least two rows to exist).
DM_dim = [75, 512, 256]                          # Drug_Molecular
CL_dim = [cell_feature.shape[1], 256]            # Cell_Line
GN_dim = [drug_smiles_fea.shape[1], 512, 256]    # GO_Network
AN_dim = [drug_smiles_fea.shape[1], 512, 256]    # ATC_Network

CD_dim = [512, 256]  # Drug_Molecular + Cell_Line
CO_dim = [512, 256]  # GO_Network + Cell_Line
CT_dim = [512, 256]  # ATC_Network + Cell_Line
# FCNN input width: two outputs from each of the three CNN branches plus the
# Cell_Line output; followed by the three hidden layer sizes.
FN_dim = [(CD_dim[1] * 2 + CO_dim[1] * 2 + CT_dim[1] * 2 + CL_dim[1]), [1024, 512, 128]]

model = Synergy(Drug_Molecular(dim_drug=DM_dim[0], hidden_dim=DM_dim[1], output_dim=DM_dim[2], heads=4),
                Cell_Line(dim_cellline=CL_dim[0], hidden_dim=CL_dim[1]),
                GO_Network(feature_dim=GN_dim[0], hidden_dim=GN_dim[1], output_dim=GN_dim[2]),
                ATC_Network(feature_dim=AN_dim[0], hidden_dim=AN_dim[1], output_dim=AN_dim[2]),
                CNN_Drug(embed_dim=CD_dim[0], hidden_dim=CD_dim[1]),
                CNN_GO(embed_dim=CO_dim[0], hidden_dim=CO_dim[1]),
                CNN_ATC(embed_dim=CT_dim[0], hidden_dim=CT_dim[1]),
                FCNN(embed_dim=FN_dim[0], hidden_dim=FN_dim[1])
                ).to(device)

learning_rate = 0.0001
# Class weights live on the same device as the model so the loss can be
# evaluated on model outputs without a device mismatch.
weights = torch.tensor([0.5, 16], dtype=torch.float32, device=device)
loss_func = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999),
                             weight_decay=1e-4, amsgrad=False)

# Restore the trained weights and switch to inference mode: without eval()
# dropout stays active and BatchNorm uses (and updates) batch statistics, so
# the features extracted for SHAP would be noisy and non-reproducible.
model.load_state_dict(torch.load('the_bestt_model.pth', map_location=device))
model.eval();

In [None]:
# The loader is configured with a single batch that spans the whole dataset,
# so one draw yields every (index triple, label) pair.
cancer_loader = Data.DataLoader(SynergyDataset(cancer_drug_combination),
                                batch_size=len(cancer_drug_combination), shuffle=False)
cancer_index, cancer_labels = next(iter(cancer_loader))

# Inference only: eval() disables dropout and makes BatchNorm use its running
# statistics; no_grad() avoids building an autograd graph for the whole dataset.
model.eval()
with torch.no_grad():
    feature = model(drug_set, cell_set, GO_adj, GO_weight,
                    ATC_adj, ATC_weight, drug_smiles_fea, cancer_index)
# Element 5 of the model output is what the SHAP cells below consume.
feature_cancer_combination = feature[5]

In [None]:
class Cell_Line_SHAP(torch.nn.Module):
    """Stand-alone cell-line encoder used for SHAP attribution.

    Applies one linear layer followed by batch normalisation. The parameter
    names (``cell_fc1``, ``cell_bn1``) are what ``load_state_dict`` matches
    against when weights are copied in from another module.

    Args:
        dim_cellline: Width of the input expression vector.
        hidden_dim: Width of the encoded output.
        dropout: Dropout probability for ``self.dropout`` (created but not
            applied in ``forward``).
    """

    def __init__(self, dim_cellline, hidden_dim, dropout=0.2):
        # Was super(Cell_Line, self): this class does not derive from
        # Cell_Line, so that call raised TypeError on construction.
        super().__init__()
        # relu/dropout are kept as attributes but are not used by forward().
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)
        self.cell_fc1 = nn.Linear(dim_cellline, hidden_dim)
        self.cell_bn1 = nn.BatchNorm1d(hidden_dim)

    def forward(self, gexpr_data):
        """Return ``BatchNorm(Linear(gexpr_data))`` for a (batch, dim_cellline) input."""
        x_cell = self.cell_fc1(gexpr_data)
        x_cell = self.cell_bn1(x_cell)
        return x_cell
class FCNN_SHAP(torch.nn.Module):
    """Fully connected classification head used for SHAP attribution.

    Three Linear -> BatchNorm -> LeakyReLU -> Dropout stages followed by a
    final linear layer producing two logits.

    Args:
        embed_dim: Width of the concatenated input embedding.
        hidden_dim: Sequence of three hidden layer widths.
        dropout: Dropout probability applied after each hidden stage.
    """

    def __init__(self, embed_dim, hidden_dim, dropout=0.2):
        # Was super(FCNN, self): this class does not derive from FCNN, so
        # that call raised TypeError on construction.
        super().__init__()
        self.last_layer_feature = None
        self.relu = nn.LeakyReLU()
        self.dropout = nn.Dropout(dropout)
        self.embed_fc1 = nn.Linear(embed_dim, hidden_dim[0])
        self.embed_bn1 = nn.BatchNorm1d(hidden_dim[0])
        self.embed_fc2 = nn.Linear(hidden_dim[0], hidden_dim[1])
        self.embed_bn2 = nn.BatchNorm1d(hidden_dim[1])
        self.embed_fc3 = nn.Linear(hidden_dim[1], hidden_dim[2])
        self.embed_bn3 = nn.BatchNorm1d(hidden_dim[2])
        self.embed_fc4 = nn.Linear(hidden_dim[2], 2)

    def forward(self, Drug1, Drug2=None, GO1=None, GO2=None, ATC1=None, ATC2=None,
                cell_embed=None, index=None):
        """Compute the two output logits.

        Two call forms are supported:

        * ``forward(embed)`` - ``Drug1`` is an already concatenated
          (batch, embed_dim) tensor. This is the form needed when the head is
          called with a single tensor, as a single-input SHAP wrapper does.
        * ``forward(Drug1, Drug2, GO1, GO2, ATC1, ATC2, cell_embed, index)`` -
          the original form; the six embeddings are concatenated with the rows
          of ``cell_embed`` selected by column 2 of ``index``.

        Returns:
            Tensor of shape (batch, 2) with unnormalised class scores.
        """
        if Drug2 is None:
            embed = Drug1
        else:
            embed = torch.cat((Drug1, Drug2, GO1, GO2, ATC1, ATC2, cell_embed[index[:, 2], :]), 1)

        embed = self.embed_fc1(embed)
        embed = self.embed_bn1(embed)
        embed = self.relu(embed)
        embed = self.dropout(embed)
        embed = self.embed_fc2(embed)
        embed = self.embed_bn2(embed)
        embed = self.relu(embed)
        embed = self.dropout(embed)
        embed = self.embed_fc3(embed)
        embed = self.embed_bn3(embed)
        embed = self.relu(embed)
        embed = self.dropout(embed)
        embed = self.embed_fc4(embed)

        return embed
    
class Synergy_SHAP(torch.nn.Module):
    """Single-input wrapper that SHAP can differentiate end to end.

    The input is one flat feature matrix. Columns ``[cell_start, cell_end)``
    are passed through the cell-line encoder; its output is appended to
    columns ``[0, cell_start)`` and the result is fed to the classification
    head. Class probabilities (softmax over dim 1) are returned.

    Args:
        Cell_Line_SHAP: Module applied to the cell-line column slice.
        FCNN_SHAP: Module applied to the re-assembled feature matrix; it is
            called with that single tensor.
        cell_start: First column of the cell-line slice (default 1536).
        cell_end: One past the last column of the cell-line slice (default 2485).
    """

    def __init__(self, Cell_Line_SHAP, FCNN_SHAP, cell_start=1536, cell_end=2485):
        super().__init__()
        self.Cell_Line_SHAP = Cell_Line_SHAP
        self.FCNN_SHAP = FCNN_SHAP
        # Column layout of the flat input; previously hard-coded in forward().
        self.cell_start = cell_start
        self.cell_end = cell_end

    def forward(self, feature):
        """Return softmax probabilities for a (batch, n_columns) feature matrix."""
        cell_dim = self.Cell_Line_SHAP(feature[:, self.cell_start:self.cell_end])
        feature_dim = torch.cat((feature[:, 0:self.cell_start], cell_dim), 1)
        synergy_score = self.FCNN_SHAP(feature_dim)
        probs = F.softmax(synergy_score, dim=1)
        return probs

In [None]:
# Rebuild the two sub-networks with the same layer sizes as the trained model
# so that the copied state dicts line up. The earlier hidden_dim=[1024, 256]
# had only two entries, while FCNN_SHAP indexes three hidden sizes and the
# trained FCNN was built from FN_dim[1] = [1024, 512, 128].
fcnn_shap = FCNN_SHAP(embed_dim=FN_dim[0], hidden_dim=FN_dim[1])
cell_shap = Cell_Line_SHAP(CL_dim[0], CL_dim[1])
fcnn_shap.load_state_dict(model.FCNN.state_dict())
cell_shap.load_state_dict(model.Cell_Line.state_dict())
# Keep the wrapper on the same device as the features it will be applied to.
model_SHAP = Synergy_SHAP(cell_shap, fcnn_shap).to(device)
model_SHAP.eval()

In [None]:
import random

# Fixed seed so the background sample and the attributions are reproducible
# under Restart & Run All.
SHAP_SEED = 42
N_BACKGROUND = 500

# Draw the background set without replacement; cap at the number of available
# rows because random.sample raises ValueError when asked for more.
n_samples = len(feature_cancer_combination[0])
random_indices = random.Random(SHAP_SEED).sample(range(n_samples), min(N_BACKGROUND, n_samples))
feature_cancer_backgroud = [feature_cancer_combination[0][random_indices]]
explainer = shap.GradientExplainer(model_SHAP, feature_cancer_backgroud)
# rseed fixes the random draws GradientExplainer makes internally.
shap_values_combination = explainer.shap_values(feature_cancer_combination, rseed=SHAP_SEED)

In [None]:
# Gene names are the expression-table columns after the first one.
cell_exp = pd.read_csv(cell_exp_path)
gene_names = cell_exp.columns[1:]

# Column range of the flat feature matrix that holds the cell-line expression
# values (the same slice Synergy_SHAP routes through the cell-line encoder).
CELL_START, CELL_END = 1536, 2485

# NOTE(review): recent shap versions return one array per model input when the
# inputs are passed as a list - unwrap that case; a plain array passes through.
shap_array = shap_values_combination
if isinstance(shap_array, (list, tuple)):
    shap_array = shap_array[0]
shap_array = np.asarray(shap_array)
assert shap_array.ndim == 3 and shap_array.shape[0] == len(feature_cancer_combination[0]), \
    "expected SHAP values shaped (samples, features, outputs)"

# Attributions towards output index 1, restricted to the gene columns.
shaps_gene = pd.DataFrame(shap_array[:, CELL_START:CELL_END, 1], columns=gene_names)
# .cpu() makes the conversion work when the features live on a GPU.
cell_feature_pd = pd.DataFrame(
    feature_cancer_combination[0][:, CELL_START:CELL_END].detach().cpu().numpy(),
    columns=gene_names,
)

In [None]:
# Per-gene Pearson correlation between expression and SHAP attribution.
# shaps_gene is a DataFrame, so NumPy-style `shaps_gene[:, i]` raises; index
# the underlying array instead.
shap_matrix = np.asarray(shaps_gene)
correlations = [
    pearsonr(cell_feature_pd.iloc[:, i].values, shap_matrix[:, i])[0]
    for i in range(shap_matrix.shape[1])
]
corr_frame = pd.DataFrame(index=cell_feature_pd.columns)
corr_frame['corr'] = correlations

In [None]:
import matplotlib.pyplot as plt

# Select the 20 genes with the most positive expression/attribution
# correlation (undefined correlations are treated as 0).
zero_frame = corr_frame.fillna(0).sort_values('corr', ascending=False)
top_pos_mask = cell_feature_pd.columns.isin(zero_frame.index[:20])
X_sig = cell_feature_pd.loc[:, top_pos_mask]
# shaps_gene is a DataFrame; NumPy-style `shaps_gene[:, mask]` raises, so the
# boolean column mask is applied to its underlying array.
shaps_sig = np.asarray(shaps_gene)[:, top_pos_mask]
zero_frame_pos = zero_frame

In [None]:
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

# Three-stop gradient (blue -> near-white -> red) used to colour the dots.
colors = ["#6587BC","#f7f7f7", "#CF6A69"]
cmap = mcolors.LinearSegmentedColormap.from_list("", colors)

# show=False keeps the figure open so it can be written to disk before display.
shap.summary_plot(
    shaps_sig,
    X_sig,
    plot_size=(5,6),
    plot_type="dot",
    cmap=cmap,
    show=False,
)
plt.savefig("summary_plot_pos.pdf", dpi=1000, format='pdf')
plt.show()

In [None]:
# Select the 20 genes with the most negative expression/attribution
# correlation (undefined correlations are treated as 0).
zero_frame = corr_frame.fillna(0).sort_values('corr', ascending=True)
top_neg_mask = cell_feature_pd.columns.isin(zero_frame.index[:20])
X_sig = cell_feature_pd.loc[:, top_neg_mask]
# shaps_gene is a DataFrame; NumPy-style `shaps_gene[:, mask]` raises, so the
# boolean column mask is applied to its underlying array.
shaps_sig = np.asarray(shaps_gene)[:, top_neg_mask]
zero_frame_neg = zero_frame

In [None]:
# Three-stop gradient (blue -> near-white -> red) used to colour the dots.
colors = ["#6587BC","#f7f7f7", "#CF6A69"]
cmap = mcolors.LinearSegmentedColormap.from_list("", colors)

# show=False keeps the figure open so it can be written to disk before display.
shap.summary_plot(
    shaps_sig,
    X_sig,
    plot_size=(5,6),
    plot_type="dot",
    cmap=cmap,
    show=False,
)
plt.savefig("summary_plot_neg.pdf", dpi=1000, format='pdf')
plt.show()

In [None]:
# Global importance per gene = mean absolute SHAP value over all samples.
# (A signed mean lets positive and negative attributions cancel and can yield
# negative "ratios", which a bar chart cannot represent meaningfully.)
# shaps_gene is a DataFrame, so it is converted to an array before reducing.
feature_importance = np.abs(np.asarray(shaps_gene)).mean(axis=0)
feature_names = cell_feature_pd.columns
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importance
}).sort_values(by='Importance', ascending=False)

# Share of the total importance carried by each gene.
importance_df['Ratio'] = importance_df['Importance'] / importance_df['Importance'].sum()

top_20_features = importance_df.head(20)

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))

labels = top_20_features['Feature']
values = top_20_features['Ratio']

# One equally wide wedge per gene around the full circle.
angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
width = (2 * np.pi) / len(labels)

bars = ax.bar(
    angles,
    values,
    width=width,
    align='center',
    edgecolor="white",
    linewidth=1,
    alpha=0.8
)

# Give every wedge its own colour from the tab20 palette.
cmap = plt.cm.tab20
for i, bar in enumerate(bars):
    bar.set_facecolor(cmap(i % cmap.N))

ax.set_xticks(angles)
ax.set_xticklabels(labels, fontsize=10, rotation=45, ha='right')
plt.title('Top 20 gene Feature Importance Ratio', va='bottom', fontsize=14)
plt.tight_layout()
#output_path = "top_20_gene_importance_rose.pdf"
#plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Rose chart of the importance ratios for the genes found among the first 20
# index entries of zero_frame_pos.
pos_importance_df = importance_df[importance_df['Feature'].isin(zero_frame_pos.index[:20])]
labels = pos_importance_df['Feature']
values = pos_importance_df['Ratio']

# One equally wide wedge per gene around the full circle.
n_wedges = len(labels)
width = (2 * np.pi) / n_wedges
angles = np.linspace(0, 2 * np.pi, n_wedges, endpoint=False).tolist()

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
bars = ax.bar(angles, values, width=width, align='center',
              edgecolor="white", linewidth=1, alpha=0.8)

# Give every wedge its own colour from the tab20 palette.
cmap = plt.cm.tab20
for position, wedge in enumerate(bars):
    wedge.set_facecolor(cmap(position % cmap.N))

ax.set_xticks(angles)
ax.set_xticklabels(labels, fontsize=10, rotation=45, ha='right')

plt.title('Top 20 pos gene Feature Importance Ratio', va='bottom', fontsize=14)
plt.tight_layout()
#output_path = "top_20_pos_gene_importance_rose.pdf"
#plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Rose chart of the importance ratios for the genes found among the first 20
# index entries of zero_frame_neg.
neg_importance_df = importance_df[importance_df['Feature'].isin(zero_frame_neg.index[:20])]
labels = neg_importance_df['Feature']
values = neg_importance_df['Ratio']

# One equally wide wedge per gene around the full circle.
n_wedges = len(labels)
width = (2 * np.pi) / n_wedges
angles = np.linspace(0, 2 * np.pi, n_wedges, endpoint=False).tolist()

fig, ax = plt.subplots(figsize=(8, 8), subplot_kw=dict(polar=True))
bars = ax.bar(angles, values, width=width, align='center',
              edgecolor="white", linewidth=1, alpha=0.8)

# Give every wedge its own colour from the tab20 palette.
cmap = plt.cm.tab20
for position, wedge in enumerate(bars):
    wedge.set_facecolor(cmap(position % cmap.N))

ax.set_xticks(angles)
ax.set_xticklabels(labels, fontsize=10, rotation=45, ha='right')
plt.title('Top 20 neg gene Feature Importance Ratio', va='bottom', fontsize=14)
plt.tight_layout()
#output_path = "top_20_neg_gene_importance_rose.pdf"
#plt.savefig(output_path, format='pdf', bbox_inches='tight')
plt.show()

In [None]:
# Expression vs. SHAP scatter plots for the genes among the first 5 index
# entries of zero_frame_pos (most positive correlation).
X_sig_pos = cell_feature_pd.loc[:, cell_feature_pd.columns.isin(zero_frame_pos.index[:5])]
n_genes = len(X_sig_pos.columns)
ncols = 5
nrows = max(1, int(np.ceil(n_genes / ncols)))
# squeeze=False always yields a 2-D axes array, so flatten() is safe for any grid.
fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
axes = axes.flatten()
for i, gene in enumerate(X_sig_pos.columns):
    ax = axes[i]
    expr_values = X_sig_pos[gene].values
    # Look the attribution up by gene name. The previous `shaps_sig[:, i]`
    # read whatever selection last overwrote `shaps_sig` (the negative top-20
    # by this point), and its column i did not correspond to this gene.
    shap_values = shaps_gene[gene].values
    correlation = corr_frame.loc[gene, 'corr']
    ax.scatter(expr_values, shap_values, s=5, alpha=1, color='#80b3e1', rasterized=True)

    ax.set_title(f"Gene: {gene} (Corr: {correlation:.2f})", fontsize=9, pad=5)
    ax.set_xlabel("Expression", fontsize=8)
    ax.set_ylabel("SHAP", fontsize=8)
    ax.tick_params(axis='both', which='major', labelsize=7)
    ax.grid(True, alpha=0.3)
# Hide unused panels; counting from n_genes also works when no gene was selected.
for j in range(n_genes, len(axes)):
    axes[j].set_visible(False)

plt.tight_layout()
#plt.savefig('pos_top20_genes_cor.pdf',
#            bbox_inches='tight',
#            dpi=150)
plt.show()

In [None]:
# Expression vs. SHAP scatter plots for the genes among the first 5 index
# entries of zero_frame_neg (most negative correlation).
X_sig_neg = cell_feature_pd.loc[:, cell_feature_pd.columns.isin(zero_frame_neg.index[:5])]
n_genes = len(X_sig_neg.columns)
ncols = 5
nrows = max(1, int(np.ceil(n_genes / ncols)))

# squeeze=False always yields a 2-D axes array, so flatten() is safe for any grid.
fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)

axes = axes.flatten()

for i, gene in enumerate(X_sig_neg.columns):
    ax = axes[i]
    expr_values = X_sig_neg[gene].values
    # Look the attribution up by gene name. The previous `shaps_sig[:, i]`
    # indexed the 20-column selection, whose column i does not correspond to
    # gene i of this 5-gene frame.
    shap_values = shaps_gene[gene].values
    correlation = corr_frame.loc[gene, 'corr']
    ax.scatter(expr_values, shap_values, s=5, alpha=1, color='#80b3e1', rasterized=True)

    ax.set_title(f"Gene: {gene} (Corr: {correlation:.2f})", fontsize=9, pad=5)
    ax.set_xlabel("Expression", fontsize=8)
    ax.set_ylabel("SHAP", fontsize=8)
    ax.tick_params(axis='both', which='major', labelsize=7)
    ax.grid(True, alpha=0.3)
# Hide unused panels; counting from n_genes also works when no gene was selected.
for j in range(n_genes, len(axes)):
    axes[j].set_visible(False)

plt.tight_layout()
#plt.savefig('neg_top20_genes_cor.pdf',
#            bbox_inches='tight',
#            dpi=150)

plt.show()