In [2]:
import torch
import numpy as np
import pandas as pd
import copy
import torch.nn.functional as F
import scipy.sparse as sp
from torch_geometric.nn import GCNConv,GATConv,SAGEConv
from torch_geometric.datasets import Planetoid
from torch.nn import Linear
from sklearn.preprocessing import StandardScaler
from sklearn import preprocessing
import datetime
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool

In [3]:
from torch.nn import Linear
from torch_geometric.nn import GraphConv

class GNN(torch.nn.Module):
    """Graph-level classifier built from three stacked ``GraphConv`` layers.

    Node features go through three graph convolutions (ReLU after the first
    two), are mean-pooled per graph, passed through dropout and finally
    projected to class logits by a linear layer.
    """

    def __init__(self, num_node_features, hidden_channels, num_classes):
        """Build the layers.

        Args:
            num_node_features: Input feature size of every node.
            hidden_channels: Output width shared by all three convolutions.
            num_classes: Number of logits produced per graph.
        """
        super().__init__()
        # Seeding here makes weight initialisation repeatable; note that it
        # reseeds torch's global RNG every time a GNN is constructed.
        torch.manual_seed(12345)
        self.conv1 = GraphConv(num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, num_classes)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return per-graph logits.

        Args:
            x: Node feature tensor.
            edge_index: Edge list handed to every convolution.
            batch: Graph assignment of each node, used by the mean pooling.
            edge_weight: Optional per-edge weights forwarded to every
                convolution.
        """
        # 1. Node embeddings (no activation after the last convolution).
        hidden = F.relu(self.conv1(x, edge_index, edge_weight))
        hidden = F.relu(self.conv2(hidden, edge_index, edge_weight))
        hidden = self.conv3(hidden, edge_index, edge_weight)

        # 2. Readout layer
        pooled = global_mean_pool(hidden, batch)   # [batch_size, hidden_channels]

        # 3. Classifier (dropout is only active while self.training is True)
        pooled = F.dropout(pooled, training=self.training)
        return self.lin(pooled)

In [4]:
# Device string used as map_location for torch.load and as the .to() target
# for tensors created in model_forward.
device = "cpu"

In [6]:
# NOTE(security): torch.load unpickles arbitrary Python objects; only load
# checkpoints that come from a trusted source.
model = torch.load("/public/home/liujunwu/workdir/scripts/GNN_Reactome/reaction_file/GSE16879/GSE16879.IBD.Response.pt",map_location=device)
# Switch to inference mode: GNN.forward applies dropout whenever
# self.training is True, which would randomly perturb every forward pass and
# therefore every attribution computed from this model.
model.eval()
print (model)

GNN(
  (conv1): GraphConv(1, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


In [8]:
# Print every parameter tensor next to its qualified name.
for name, param in model.named_parameters():
    print(name, param)

conv1.lin_rel.weight Parameter containing:
tensor([[-1.0173],
        [ 0.2253],
        [ 0.1614],
        [ 0.4782],
        [ 0.7046],
        [ 0.5788],
        [ 0.7098],
        [-0.2629],
        [-0.7253],
        [-0.3379],
        [-0.3475],
        [-0.0324],
        [ 0.1350],
        [-0.5218],
        [-0.8372],
        [-0.5203],
        [ 0.2390],
        [ 0.3491],
        [-0.7442],
        [-0.9946],
        [-0.1148],
        [-0.2876],
        [-0.3221],
        [-0.9241],
        [ 0.4518],
        [-0.2196],
        [ 0.8028],
        [-0.4908],
        [-0.7956],
        [-0.1230],
        [ 0.0697],
        [ 0.6123],
        [-0.0452],
        [ 0.6832],
        [-0.5684],
        [-0.7486],
        [-0.4536],
        [-0.3731],
        [ 0.2716],
        [ 0.7436],
        [ 0.7142],
        [-0.5476],
        [-0.3486],
        [-0.3352],
        [ 0.3811],
        [-0.2897],
        [-0.2741],
        [ 0.1367],
        [-0.7182],
        [ 0.6556],
       

In [8]:
# Load the saved per-sample graphs onto `device`.
bulk_total_sets = torch.load(
    "/public/home/liujunwu/workdir/scripts/GNN_Reactome/reaction_file/GSE16879/GSE16879.IBD.Response.data.pt",
    map_location=device,
)

In [9]:
from captum.attr import Saliency, IntegratedGradients
def model_forward(e_mask, dt):
    """Run the global ``model`` on one graph with ``e_mask`` as edge weights.

    The edge mask is the first positional argument so that an attribution
    method can treat it as the input being explained.

    Args:
        e_mask: Per-edge weights passed to the model as ``edge_weight``.
        dt: Graph object providing ``x`` and ``edge_index``.

    Returns:
        The model output for this graph.
    """
    # Every node is assigned to batch index 0, i.e. a single graph.
    num_nodes = dt.x.shape[0]
    single_graph_batch = torch.zeros(num_nodes, dtype=int).to(device)
    return model(dt.x, dt.edge_index, single_graph_batch, e_mask)
def explain(method, dt, target=0):
    """Compute a normalised per-edge importance mask for one graph.

    Args:
        method: ``'ig'`` for IntegratedGradients or ``'saliency'`` for
            Saliency.
        dt: Graph to explain; must provide ``x`` and ``edge_index``.
        target: Output index (class) whose score is attributed.

    Returns:
        A numpy array with one non-negative value per edge, scaled so the
        largest entry is 1 (left unscaled when every attribution is zero).

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # One mask entry per edge; this is the tensor being attributed.
    num_edges = dt.edge_index.shape[1]
    input_mask = torch.ones(num_edges).requires_grad_(True)
    if method == 'ig':
        ig = IntegratedGradients(model_forward)
        mask = ig.attribute(input_mask, target=target,
                            additional_forward_args=(dt,),
                            internal_batch_size=num_edges)
    elif method == 'saliency':
        saliency = Saliency(model_forward)
        # Pass the graph given to this function, not a module-level variable,
        # so the explanation always refers to `dt`.
        mask = saliency.attribute(input_mask, target=target,
                                  additional_forward_args=(dt,))
    else:
        raise ValueError('Unknown explanation method')

    edge_mask = np.abs(mask.cpu().detach().numpy())
    # Skip normalisation for an empty or all-zero mask (avoids max() on an
    # empty array and division by zero).
    if edge_mask.size > 0 and edge_mask.max() > 0:
        edge_mask = edge_mask / edge_mask.max()
    return edge_mask

In [10]:
# Explain the first graph with integrated gradients, attributing to its own label.
data = bulk_total_sets[0]
print(data.y)
# NOTE(review): `target` is assigned here but the call below passes data.y.
target = 0
edge_mask = explain(method="ig", dt=data, target=data.y)

tensor(1)


In [38]:
# Display the loaded graphs (one entry per sample, each carrying its label y).
bulk_total_sets

[Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=1),
 Data(x=[11714, 1], edge_index=[2, 240272], y=0),
 Data(x=[11714, 1], edge_index=[2, 240272], y=0),
 Data(x=[11714, 1], edge_index=[2, 240272], y=0),
 Data(x=[11714, 1], edge_index=[2, 240272], y=0),
 Data(x=[11714, 1], edge_index=[2, 240272], y=0),
 Data(x=[11714, 1], edge_index=[2, 240272], y=0),


In [12]:
# Position(s) of the edge(s) carrying the largest attribution.
np.nonzero(edge_mask == edge_mask.max())

(array([233823]),)

In [20]:
# Indices of the 50 highest-scoring edges, best first.
np.argsort(edge_mask)[:-51:-1]

array([233823, 239196, 125117, 120636,  44181,  11672, 185087, 185323,
       121488, 147154, 112694,  70468, 205397,  23721, 196262,  44245,
       169383,  70484,  82611, 196265,   6904, 166421,  91236, 182182,
        36298,  58184, 194504, 148175,  91067,  23722, 145474, 135187,
        82278,  67134, 155934, 115213, 115214,  68240,  38615,  23808,
       189287, 189286, 182114,  36221,  38687, 127183,  35852, 133597,
       205399, 192133])

In [13]:
# The two node indices stored for edge 233823.
data.edge_index[:, 233823]

tensor([ 8736, 11287])

In [17]:
import json
# Read the node-index JSON file into a dict.
nodes_index_file = "/public/home/liujunwu/workdir/scripts/GNN_Reactome/reaction_file/reactome_reaction.uniqnodes.json"
with open(nodes_index_file, encoding="UTF-8") as f:
    nodes_index_dict = json.loads(f.read())

In [18]:
# Invert the mapping so the original values become the lookup keys.
new_d = dict(zip(nodes_index_dict.values(), nodes_index_dict.keys()))

In [19]:
# Look up the identifiers stored for node indices 8736 and 11287.
print(*(new_d[node_idx] for node_idx in (8736, 11287)))

R-HSA-8956106 R-HSA-9755303


In [37]:
# Count, over all graphs labelled 0 (no response), how often each edge lands
# in that graph's 100 highest integrated-gradients attributions.
edge_freq_time = {}
for data in bulk_total_sets:
    if data.y != 0:
        # Only graphs labelled 0 contribute to the tally.
        continue
    edge_mask = explain("ig", data, target=data.y)
    edge_mask_top100 = edge_mask.argsort()[-100:][::-1]
    for i in edge_mask_top100:
        edge_freq_time[i] = edge_freq_time.get(i, 0) + 1

In [24]:
# NOTE(review): `a` is not used anywhere in the visible cells — looks like scratch.
a = [1,2,3,4,5]
# Check the container type of the mask returned by explain().
type(edge_mask)

numpy.ndarray

In [39]:
# Show (edge index, count) pairs, most frequent first.
sorted(
    edge_freq_time.items(),
    key=lambda pair: pair[1],
    reverse=True,
)

[(70484, 33),
 (205397, 33),
 (44181, 31),
 (196265, 31),
 (6896, 31),
 (82611, 30),
 (169383, 29),
 (6905, 28),
 (184852, 28),
 (147154, 28),
 (148169, 28),
 (148132, 28),
 (112694, 27),
 (23721, 27),
 (189286, 27),
 (189287, 27),
 (6904, 26),
 (23722, 26),
 (196262, 25),
 (205399, 25),
 (70468, 24),
 (58184, 24),
 (11672, 24),
 (44245, 22),
 (233823, 21),
 (70467, 21),
 (162006, 20),
 (125117, 20),
 (159554, 20),
 (238290, 20),
 (169634, 20),
 (181740, 19),
 (239196, 19),
 (10737, 19),
 (10736, 19),
 (23730, 19),
 (162007, 18),
 (188980, 18),
 (159553, 17),
 (80648, 17),
 (159552, 17),
 (120472, 17),
 (227773, 16),
 (206150, 15),
 (60723, 15),
 (10995, 15),
 (191165, 14),
 (129620, 14),
 (44249, 14),
 (10994, 14),
 (23728, 13),
 (120150, 12),
 (23729, 12),
 (80647, 12),
 (44240, 12),
 (10738, 12),
 (155846, 12),
 (6897, 11),
 (185087, 11),
 (185323, 11),
 (112158, 11),
 (72673, 11),
 (6899, 11),
 (216524, 11),
 (216527, 11),
 (79928, 11),
 (103348, 11),
 (60724, 11),
 (238814, 10),
 

In [46]:
# Keep the (edge index, count) pairs ordered by descending count.
edge_freq_time_sort = sorted(
    edge_freq_time.items(),
    key=lambda pair: pair[1],
    reverse=True,
)

In [47]:
# sorted() returns a list of (edge index, count) tuples.
type(edge_freq_time_sort)

list

In [50]:
# Print the endpoints of the first (most frequent) edge only.
for key, value in edge_freq_time_sort[:1]:
    print(data.edge_index.T[key])

tensor([2809, 2813])
