In [2]:
import torch
import numpy
import sklearn
import random
import time
import torch.nn.functional as F
from IPython.display import Javascript
from torch.nn import Linear
from sklearn import preprocessing
from torch_geometric.datasets import TUDataset
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GraphConv, global_mean_pool
import matplotlib.pyplot as plt

# Seed Python's RNG so the random.shuffle() of the graph list is reproducible.
# This must be a call: assigning to random.seed replaces the function with an
# int and leaves the generator unseeded.
random.seed(88888888)

In [3]:
# --- Devices ---------------------------------------------------------------
# `device` is the handle passed to train()/test(); it is an alias of `cuda0`.
# NOTE(review): the first CUDA GPU is hard-wired here; there is no CPU fallback.
cpu = torch.device('cpu')
cuda0 = torch.device('cuda:0')
device = cuda0

# --- Input files -----------------------------------------------------------
# NOTE(review): absolute, user-specific paths; adjust before running elsewhere.
graph_targets_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/input/graph_targets.txt'
node_features_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/input/node_features.txt'
edges_fn = '/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/input/edges.txt'

# --- Model / training hyper-parameters -------------------------------------
INPUT_CHANNELS = 1       # input width of the first GraphConv layer
HIDDEN_CHANNELS = 64     # width of every hidden GraphConv layer
OUTPUT_CHANNELS = 26     # output width of the final Linear classifier
BATCH_SIZE = 64          # graphs per DataLoader batch
EPOCHS = 500             # suggested range: 200 - 2000
BENCHMARKING = False     # cell-to-cell control flag, flipped by later cells

In [4]:
# Sanity check: load the node-feature file once and echo it.
# NOTE(review): the output recorded for this cell is a FileNotFoundError for
# node_features_fn, i.e. the path was not reachable when it was last run.
feature_v = numpy.loadtxt(node_features_fn)
print(feature_v)

FileNotFoundError: /mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/input/node_features.txt not found.

In [None]:
def read_reactome_graph(edges_fn, node_features_fn):
    """Read a whitespace-delimited edge list into two parallel index lists.

    Every non-blank line of ``edges_fn`` must start with two integer node ids.
    The ids in the file are 1-based (R indexing) and are shifted to 0-based
    Python indices here. Blank lines are skipped.

    Args:
        edges_fn: path of the edge-list text file.
        node_features_fn: unused; kept so existing callers keep working.

    Returns:
        ``(edge_v1, edge_v2)``: lists holding the first and second endpoint
        of each edge, in file order.

    Raises:
        ValueError: if an endpoint is not an integer.
        IndexError: if a non-blank line has fewer than two fields.
    """
    edge_v1 = []
    edge_v2 = []

    # The context manager closes the handle even when a line fails to parse.
    with open(edges_fn, 'r') as edges_file:
        for line in edges_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            # Parse both endpoints before appending so the lists stay aligned.
            node1 = int(fields[0]) - 1
            node2 = int(fields[1]) - 1
            edge_v1.append(node1)
            edge_v2.append(node2)

    return edge_v1, edge_v2

In [None]:
(edge_v1, edge_v2) = read_reactome_graph(edges_fn, node_features_fn)
# Show the size and a short preview rather than dumping both full edge lists
# into the notebook output.
print(f'{len(edge_v1)} edges loaded')
print(list(zip(edge_v1[:10], edge_v2[:10])))

In [None]:
def build_scratch_loader(batch_size):
  """Build a shuffling DataLoader over the MUTAG TUDataset for quick testing.

  Only column 1 of each graph's node-feature matrix is kept, reshaped to a
  single feature column per node.

  Args:
      batch_size: number of graphs per batch.

  Returns:
      A DataLoader over the rebuilt ``Data`` objects with ``shuffle=True``.
  """
  dataset = TUDataset(root='data/TUDataset', name='MUTAG')
  data_list = []
  for graph_obj in dataset:
    # clone().detach() is the supported way to copy an existing tensor;
    # torch.tensor(<tensor>) emits a UserWarning for this use.
    x = graph_obj.x[:, 1].clone().detach().to(torch.float)
    x = x.unsqueeze(1)
    data_list.append(Data(x=x, y=graph_obj.y, edge_index=graph_obj.edge_index))

  return DataLoader(data_list, batch_size=batch_size, shuffle=True)

In [None]:
def build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn):
    """Build one ``Data`` graph per row of the node-feature file.

    All graphs share the same ``edge_index`` tensor. Row ``i`` of the feature
    file supplies one scalar feature per node, and entry ``i`` of the target
    file supplies the graph label, which is integer-encoded with a
    ``LabelEncoder``.

    Args:
        edge_v1: first endpoint of every edge (0-based node indices).
        edge_v2: second endpoint of every edge, parallel to ``edge_v1``.
        node_features_fn: path of the numeric feature matrix (one row per graph).
        graph_targets_fn: path of the comma-delimited label file.

    Returns:
        A list of ``Data`` objects in file order.

    Raises:
        ValueError: if the feature and target files differ in row count.
    """
    edge_index = torch.tensor([edge_v1, edge_v2], dtype=torch.long)
    # ndmin=2 keeps a one-row file as a (1, n_nodes) matrix, so the row
    # indexing below also works for a single graph.
    feature_v = numpy.loadtxt(node_features_fn, ndmin=2)
    # ndmin=1 keeps a single label as a length-1 array instead of a 0-d array.
    target_v = numpy.loadtxt(graph_targets_fn, dtype=str, delimiter=",", ndmin=1)

    target_encoder = sklearn.preprocessing.LabelEncoder()
    target_v = target_encoder.fit_transform(target_v)

    print(len(feature_v))
    print(len(target_v))

    # Fail early with a clear message rather than pairing features with the
    # wrong labels or hitting an IndexError part-way through the loop.
    if len(feature_v) != len(target_v):
        raise ValueError(
            f'feature rows ({len(feature_v)}) and targets ({len(target_v)}) differ in count')

    data_list = []
    for row_idx in range(len(feature_v)):
        features = feature_v[row_idx, :]
        x = torch.tensor(features, dtype=torch.float)
        x = x.unsqueeze(1)  # one feature column per node
        # Class indices are stored as int64 so the cross-entropy loss accepts them.
        y = torch.tensor([target_v[row_idx]], dtype=torch.long)
        data_list.append(Data(x=x, y=y, edge_index=edge_index))

    return data_list

def build_reactome_graph_loader(data_list,batch_size):
    """Wrap ``data_list`` in a DataLoader with the given ``batch_size`` and shuffling enabled."""
    return DataLoader(data_list, batch_size=batch_size, shuffle=True)

In [None]:

class GNN(torch.nn.Module):
    """Graph classifier: three GraphConv layers, mean pooling, dropout and a linear head.

    Input and output widths come from the module-level ``INPUT_CHANNELS`` and
    ``OUTPUT_CHANNELS`` constants.
    """

    def __init__(self, hidden_channels):
        """Create the layers; ``hidden_channels`` is the width of every GraphConv output."""
        super().__init__()

        # Layers are created in this order so parameter registration is unchanged.
        self.conv1 = GraphConv(INPUT_CHANNELS, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, OUTPUT_CHANNELS)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return one row of class scores per graph in the batch.

        ``batch`` assigns each node to its graph for pooling; ``edge_weight``
        is forwarded unchanged to every GraphConv layer.
        """
        # Node embeddings: ReLU after the first two convolutions only.
        h = self.conv1(x, edge_index, edge_weight).relu()
        h = self.conv2(h, edge_index, edge_weight).relu()
        h = self.conv3(h, edge_index, edge_weight)

        # Readout: average node embeddings per graph.
        pooled = global_mean_pool(h, batch)

        # Dropout is active only while the module is in training mode.
        pooled = F.dropout(pooled, training=self.training)
        return self.lin(pooled)

In [None]:
# Module-level training state used by train() and test():
# the network on the training device, an Adam optimizer with default
# settings, and a cross-entropy loss.
model = GNN(hidden_channels=HIDDEN_CHANNELS).to(device)
optimizer = torch.optim.Adam(params=model.parameters())
criterion = torch.nn.CrossEntropyLoss()

def train(loader,device):
  """Run one optimisation pass over ``loader``.

  Uses the module-level ``model``, ``optimizer`` and ``criterion``; each
  batch's tensors are moved to ``device`` before the forward pass.
  Returns nothing.
  """
  model.train()

  for graphs in loader:
    features = graphs.x.to(device)
    edges = graphs.edge_index.to(device)
    membership = graphs.batch.to(device)
    labels = graphs.y.to(device)

    logits = model(features, edges, membership)

    batch_loss = criterion(logits, labels)
    batch_loss.backward()
    optimizer.step()
    # Gradients are cleared after the step, ready for the next batch.
    optimizer.zero_grad()

def test(loader,device):
  """Return the fraction of graphs in ``loader.dataset`` classified correctly.

  Uses the module-level ``model``, which is switched to eval mode and left
  in eval mode on return. Each batch's tensors are moved to ``device``.

  Raises:
      ZeroDivisionError: if ``loader.dataset`` is empty.
  """
  model.eval()

  correct = 0
  # Evaluation needs no gradients; no_grad() avoids building an autograd
  # graph for every batch.
  with torch.no_grad():
    for batch in loader:
      x = batch.x.to(device)
      e = batch.edge_index.to(device)
      b = batch.batch.to(device)
      y = batch.y.to(device)
      out = model(x, e, b)
      pred = out.argmax(dim=1)  # index of the highest class score
      correct += int((pred == y).sum())
  return correct / len(loader.dataset)

In [None]:
# Accuracy log, reset on every run of this cell. The benchmarking branch
# appends one value per line; the training cell that follows appends to the
# same string.
acc_str = ''
# NOTE(review): BENCHMARKING is used as a hand-off flag between cells, not
# only as a mode switch. With the configured value False, the else-branch
# prepares `data_list` and sets the flag to True, which is what the next
# cell's `if(BENCHMARKING)` guard checks. Re-running this cell alone therefore
# flips which branch executes.
if(BENCHMARKING):

  # Batch sizes to time; each gets one train() pass and one test() pass.
  test_b_sizes = [1,8,16,32,64,128]

  for test_b_size in test_b_sizes:
    print(f'Executing training routine with batch size = {test_b_size}')
    # The graph list is rebuilt from disk for every batch size.
    data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn)
    test_batch_size_data_loader = build_reactome_graph_loader(data_list,test_b_size)

    # Wall-clock time of one training pass. train() uses the module-level
    # model, so the same model keeps being trained across batch sizes.
    start = time.time()
    train(test_batch_size_data_loader,device)
    end = time.time()
    training_time = end - start

    # Wall-clock time of one evaluation pass over the same loader.
    start = time.time()
    train_acc = test(test_batch_size_data_loader,device)
    end = time.time()
    test_time = end - start

    acc_str += f'{train_acc:.4f}\n'
    print(f'Batch Size: {test_b_size}')
    print(f'Training Time: {training_time}')
    print(f'Test Time: {test_time}')
    print(f'Accuracy: {train_acc}')
    # Cleared inside the loop; this keeps the following training cell from running.
    BENCHMARKING = False
else:
  # Alternative loader kept for testing:
  #data_loader = build_scratch_loader(BATCH_SIZE) # testing
  data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn)
  # Shuffle the graph order in place before training.
  random.shuffle(data_list)
#   print(data_list)

  # Signal the following cell that `data_list` is ready and training should run.
  BENCHMARKING = True

In [None]:
# Train on the full dataset, log the per-epoch training accuracy, and save
# both the accuracy log and the trained weights.
if(BENCHMARKING):
  # NOTE(review): fold_size is only referenced by the disabled k-fold split below.
  fold_size = 911
  fold = 'full_dataset'
  # Disabled k-fold split (needs an integer `fold`):
  #   train_data_list = data_list[:fold_size * (fold - 1)] + data_list[fold_size * fold:]
  #   test_data_list = data_list[fold_size * (fold - 1):fold_size * fold]
  train_data_list = data_list

  print(f'Number of training graphs: {len(train_data_list)}')
  #print(f'Number of test graphs: {len(test_data_list)}')
  train_data_loader = build_reactome_graph_loader(train_data_list,BATCH_SIZE)
  #test_data_loader = build_reactome_graph_loader(test_data_list,BATCH_SIZE)
  for epoch in range(EPOCHS):
    train(train_data_loader,device)
    # Accuracy is measured on the training data itself; no held-out set is used here.
    train_acc = test(train_data_loader,device)
    #test_acc = test(test_data_loader,device)
    # One accuracy per line. The newline used to live in the now-disabled
    # test-accuracy suffix (',{test_acc:.4f}\n'); without it the log file is a
    # single run-together string of digits.
    acc_str += f'{train_acc:.4f}\n'
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}')#', Test Acc: {test_acc:.4f}')

  training_acc_fn = F"graph_classification_acc_rewired10_{fold}.txt"
  path = F"/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/GNN/{training_acc_fn}"
  with open(path, 'w') as writefile:
      writefile.write(acc_str)
  model_save_name = F"trained_pytorch_model_rewired10_fold_{fold}.pt"
  path = F"/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/GNN/{model_save_name}" 
  # Only the weights (state_dict) are saved, not the whole module.
  torch.save(model.state_dict(), path)
  print(F"model saved as {path}")
  # real network gets to 0.8417

In [None]:
DRAWING = True
if(DRAWING):
  import networkx as nx
  import numpy as np
  from torch_geometric.utils import to_networkx


  def draw_molecule(g, edge_mask=None, draw_edge_labels=False):
      """Draw ``g`` as an undirected graph, optionally coloured by edge importance.

      Args:
          g: networkx graph whose nodes all carry a 'name' attribute.
          edge_mask: optional mapping ``(u, v) -> score``; when given, the
              scores colour the edges and (scaled by 10) set their widths.
          draw_edge_labels: when true and ``edge_mask`` is given, also print
              each score (two decimals) next to its edge.
      """
      graph = g.copy().to_undirected()
      names = {node: attrs['name'] for node, attrs in graph.nodes(data=True)}
      layout = nx.spring_layout(graph)

      if edge_mask is not None:
          edge_color = [edge_mask[(u, v)] for u, v in graph.edges()]
          widths = [score * 10 for score in edge_color]
      else:
          edge_color = 'black'
          widths = None

      nx.draw(graph, pos=layout, labels=names, width=widths,
              edge_color=edge_color, edge_cmap=plt.cm.Blues,
              node_color='azure')

      if draw_edge_labels and edge_mask is not None:
          captions = {edge: ('%.2f' % score) for edge, score in edge_mask.items()}
          nx.draw_networkx_edge_labels(graph, layout, edge_labels=captions,
                                       font_color='red')
      plt.show()


  def to_molecule(data):
      """Convert a graph ``data`` object to a networkx graph for drawing.

      The node attribute 'x' produced by ``to_networkx`` is renamed to 'name',
      which is the attribute ``draw_molecule`` reads for node labels.

      Returns:
          The networkx graph.
      """
      g = to_networkx(data, node_attrs=['x'])
      print('g',g)
      # The loop variable must not be called `data`: that would shadow the
      # parameter and make the debug print below show the last node's
      # attribute dict instead of the input object.
      for _, attrs in g.nodes(data=True):
          attrs['name'] = attrs.pop('x')
      print(data,g)
      return g

In [None]:
if(DRAWING):
  from captum.attr import Saliency, IntegratedGradients

  def model_forward(edge_mask, data):
      """Forward pass of the module-level ``model`` on one graph, for attribution.

      ``edge_mask`` is passed as the model's ``edge_weight`` argument. The
      all-zero batch vector assigns every node of ``data`` to graph 0.
      """
      n_nodes = data.x.shape[0]
      batch = torch.zeros(n_nodes, dtype=int).to(device)
      return model(data.x, data.edge_index, batch, edge_mask)


  def explain(method, data, target=0):
      """Compute a per-edge importance score for one graph.

      Args:
          method: 'ig' (IntegratedGradients) or 'saliency' (Saliency).
          data: graph whose ``edge_index`` defines the edges to score.
          target: class index the attribution is computed for.

      Returns:
          A numpy array of absolute attributions, one per edge, divided by
          the maximum when that maximum is positive.

      Raises:
          Exception: if ``method`` is neither 'ig' nor 'saliency'.
      """
      n_edges = data.edge_index.shape[1]
      # Attribution starts from an all-ones mask over the edges.
      input_mask = torch.ones(n_edges).requires_grad_(True).to(device)

      if method == 'ig':
          explainer = IntegratedGradients(model_forward)
          mask = explainer.attribute(input_mask, target=target,
                                     additional_forward_args=(data,),
                                     internal_batch_size=n_edges)
      elif method == 'saliency':
          explainer = Saliency(model_forward)
          mask = explainer.attribute(input_mask, target=target,
                                     additional_forward_args=(data,))
      else:
          raise Exception('Unknown explanation method')

      scores = np.abs(mask.cpu().detach().numpy())
      peak = scores.max()
      if peak > 0:  # skip normalisation when every score is zero
          scores = scores / peak
      return scores

In [None]:
# Summarise the inputs used to rebuild the dataset. The edge lists are
# reported by length rather than printed in full.
print(len(edge_v1), len(edge_v2), node_features_fn, graph_targets_fn, BATCH_SIZE)

In [None]:
if DRAWING:
    # Rebuild the graph list from disk (replacing any previously shuffled
    # `data_list`) and wrap it in a loader for the explanation cells.
    data_list = build_reactome_graph_datalist(
        edge_v1, edge_v2, node_features_fn, graph_targets_fn)
    data_loader = build_reactome_graph_loader(data_list, batch_size=BATCH_SIZE)

In [None]:
if(DRAWING):
  # Start from a freshly constructed network and overwrite its weights with
  # the checkpoint saved for the full dataset.
  model = GNN(hidden_channels=HIDDEN_CHANNELS)
  model = model.to(device)

  model_save_name = 'trained_pytorch_model_rewired10_fold_full_dataset.pt'
  path = F"/mnt/home/yuankeji/RanceLab/reticula_new/reticula/data/GEO_model_training/GNN/{model_save_name}" 
  # map_location places the stored tensors on `device`, so a checkpoint
  # written from a different device can still be loaded here.
  model.load_state_dict(torch.load(path, map_location=device))
  # Inference mode for the attribution cells (disables dropout).
  model.eval()

In [None]:
if(DRAWING):
  d = data_loader.dataset[0]
  # A bare expression inside an `if` block is not echoed by the notebook,
  # so print the edge count of the first graph explicitly.
  print(d.edge_index.shape[1])

In [None]:
print(data_loader.dataset[0])
# Report the number of graphs instead of dumping every Data object.
print(len(data_list))
# `data` is first assigned in a later cell; bind it here so this cell also
# runs top-to-bottom on a fresh kernel.
data = data_loader.dataset[0]
mol = to_molecule(data)

In [None]:
if(DRAWING):
  import random
  from collections import defaultdict

  def aggregate_edge_directions(edge_mask, data):
      """Sum per-edge scores over both directions of each node pair.

      ``edge_mask[i]`` is the score of the edge in column ``i`` of
      ``data.edge_index``. Scores of ``(u, v)`` and ``(v, u)`` are added
      under the single key ``(min(u, v), max(u, v))``.

      Returns:
          A ``defaultdict(float)`` mapping ordered node pairs to summed scores.
      """
      totals = defaultdict(float)
      for score, src, dst in zip(edge_mask, *data.edge_index):
          a, b = src.item(), dst.item()
          key = (a, b) if a <= b else (b, a)
          totals[key] += score
      return totals
      
  # Explain the first graph of the dataset with both attribution methods.
  data = data_loader.dataset[0]
  # NOTE(review): the original trailing note on this line reads
  # "'float' object has no attribute 'index'" - presumably an error observed
  # from this call; TODO confirm whether it still occurs.
  mol = to_molecule(data) # 'float' object has no attribute 'index'

  for title, method in [('Integrated Gradients', 'ig'), ('Saliency', 'saliency')]:
      # The return value is discarded, so this relies on Data.to() updating
      # `data` in place - TODO confirm for the installed torch_geometric version.
      data.to(device)
      # Per-edge scores for class 0, then merged over both edge directions.
      edge_mask = explain(method, data, target=0)
      edge_mask_dict = aggregate_edge_directions(edge_mask, data)
      # figsize is in inches: a very large canvas.
      plt.figure(figsize=(100, 50))
      plt.title(title)
      draw_molecule(mol, edge_mask_dict)


In [None]:
# Debug: inspect the edge index tensor of the graph being explained.
print(data.edge_index)

In [None]:
  def explain(method, data, target=0):
      """Debug variant of explain(): same result, with intermediate values printed.

      Defining it replaces any earlier ``explain`` in the kernel.

      Args:
          method: 'ig' (IntegratedGradients) or 'saliency' (Saliency).
          data: graph whose ``edge_index`` defines the edges to score.
          target: class index the attribution is computed for.

      Returns:
          A numpy array of absolute attributions, one per edge, divided by
          the maximum when that maximum is positive.

      Raises:
          Exception: if ``method`` is neither 'ig' nor 'saliency'.
      """
      n_edges = data.edge_index.shape[1]
      input_mask = torch.ones(n_edges).requires_grad_(True).to(device)
      print('input_mask', input_mask)

      if method == 'ig':
          explainer = IntegratedGradients(model_forward)
          print('ig=', explainer)
          mask = explainer.attribute(input_mask, target=target,
                                     additional_forward_args=(data,),
                                     internal_batch_size=n_edges)
          print('ig_mask', mask)
      elif method == 'saliency':
          explainer = Saliency(model_forward)
          print('saliency=', explainer)
          mask = explainer.attribute(input_mask, target=target,
                                     additional_forward_args=(data,))
          print('saliency_mask', mask)
      else:
          raise Exception('Unknown explanation method')

      scores = np.abs(mask.cpu().detach().numpy())
      print('edge_mask', scores)
      peak = scores.max()
      if peak > 0:  # skip normalisation when every score is zero
          scores = scores / peak
      return scores

In [None]:
  def aggregate_edge_directions(edge_mask, data):
      """Sum per-edge scores over both directions of each node pair.

      Defining it replaces any earlier ``aggregate_edge_directions`` in the
      kernel. ``edge_mask[i]`` is the score of the edge in column ``i`` of
      ``data.edge_index``; scores of ``(u, v)`` and ``(v, u)`` are added
      under the single key ``(min(u, v), max(u, v))``.

      Returns:
          A ``defaultdict(float)`` mapping ordered node pairs to summed scores.
      """
      totals = defaultdict(float)
      for score, src, dst in zip(edge_mask, *data.edge_index):
          a, b = src.item(), dst.item()
          key = (a, b) if a <= b else (b, a)
          totals[key] += score
      return totals

In [None]:
# Debug: list the nodes of the converted graph together with their attributes.
print(mol.nodes(data=True))

In [None]:
  def draw_molecule(g, edge_mask=None, draw_edge_labels=False):
      """Debug variant of draw_molecule(): same drawing, with layout, colours and widths printed.

      Defining it replaces any earlier ``draw_molecule`` in the kernel.

      Args:
          g: networkx graph whose nodes all carry a 'name' attribute.
          edge_mask: optional mapping ``(u, v) -> score``; when given, the
              scores colour the edges and (scaled by 10) set their widths.
          draw_edge_labels: when true and ``edge_mask`` is given, also print
              each score (two decimals) next to its edge.
      """
      graph = g.copy().to_undirected()
      names = {node: attrs['name'] for node, attrs in graph.nodes(data=True)}
      layout = nx.spring_layout(graph)
      print(layout)

      if edge_mask is not None:
          edge_color = [edge_mask[(u, v)] for u, v in graph.edges()]
          widths = [score * 10 for score in edge_color]
      else:
          edge_color = 'black'
          widths = None

      print("---------------")
      print('edge_color', edge_color)
      print("---------------")
      print('widths', widths)
      nx.draw(graph, pos=layout, labels=names, width=widths,
              edge_color=edge_color, edge_cmap=plt.cm.Blues,
              node_color='azure')

      if draw_edge_labels and edge_mask is not None:
          captions = {edge: ('%.2f' % score) for edge, score in edge_mask.items()}
          nx.draw_networkx_edge_labels(graph, layout, edge_labels=captions,
                                       font_color='red')
      plt.show()


In [None]:
    # Repeat of the attribution loop with the raw edge mask printed. It uses
    # whichever explain / aggregate_edge_directions / draw_molecule
    # definitions were executed last in the kernel.
    for title, method in [('Integrated Gradients', 'ig'), ('Saliency', 'saliency')]:
      # Return value discarded; presumably Data.to() updates `data` in place - TODO confirm.
      data.to(device)
      edge_mask = explain(method, data, target=0)
      print('edge_mask', edge_mask)
      edge_mask_dict = aggregate_edge_directions(edge_mask, data)
      # figsize is in inches: a very large canvas.
      plt.figure(figsize=(100, 50))
      plt.title(title)
      draw_molecule(mol, edge_mask_dict)