In [1]:
# Install required packages.
#!pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-geometric
!pip install -q sklearn
#!pip install -q captum

#%matplotlib inline
#import matplotlib.pyplot as plt

[K     |████████████████████████████████| 2.6MB 11.9MB/s 
[K     |████████████████████████████████| 1.5MB 11.8MB/s 
[K     |████████████████████████████████| 215kB 13.2MB/s 
[K     |████████████████████████████████| 235kB 11.2MB/s 
[K     |████████████████████████████████| 2.2MB 24.0MB/s 
[K     |████████████████████████████████| 51kB 6.9MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [2]:
import torch
import numpy
import sklearn
import random
import time
import torch.nn.functional as F
from IPython.display import Javascript
from google.colab import drive
from torch.nn import Linear
from sklearn import preprocessing
from torch_geometric.datasets import TUDataset
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GraphConv, global_mean_pool

# Seed every RNG the notebook draws from. `random.seed` is a function and has
# to be *called*; assigning to it replaces the function and seeds nothing.
SEED = 88888888
random.seed(SEED)        # random.shuffle(...) in the data-building cells
numpy.random.seed(SEED)  # numpy's legacy global generator
torch.manual_seed(SEED)  # weight initialisation, dropout, DataLoader shuffling

In [3]:
# Mount Google Drive at /content/gdrive, then `ls` the three input files
# (graph targets, edges, node features) that the following cells read.
# A missing file makes `ls` print an error here, before any loading starts.
drive.mount('/content/gdrive')

!ls '/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_graph_targets.txt'
!ls '/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_edges.txt'
!ls '/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_node_features.txt'

Mounted at /content/gdrive
'/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_graph_targets.txt'
'/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_edges.txt'
'/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_node_features.txt'


In [4]:
#!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
#import torch_xla
#import torch_xla.core.xla_model as xm
#device = xm.xla_device() # google TPU device may require Runtime -> Factory reset runtime

# Device handles. `device` is what the model and batches are moved to; it is
# the first GPU when CUDA is available and the CPU otherwise, so the notebook
# still runs on a runtime without an accelerator.
cuda0 = torch.device('cuda:0')
cpu = torch.device('cpu')
device = cuda0 if torch.cuda.is_available() else cpu

# from https://discuss.pytorch.org/t/how-to-define-train-mask-val-mask-test-mask-in-my-own-dataset/56289
# Input files on the mounted Google Drive.
edges_fn = '/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_edges.txt'
node_features_fn = '/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_node_features.txt'
graph_targets_fn = '/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_graph_targets.txt'

# magic numbers
INPUT_CHANNELS = 1     # features per node fed to the first GraphConv layer
OUTPUT_CHANNELS = 51   # width of the final Linear layer (number of classes)
HIDDEN_CHANNELS = 64   # width of every hidden GraphConv layer
BATCH_SIZE = 64        # graphs per mini-batch
EPOCHS = 500 #set this to 200 - 2000
BENCHMARKING = False   # True selects the batch-size timing branch further down

In [5]:
def read_reactome_graph(edges_fn, node_features_fn):
    """Read an edge list and return its two endpoint columns as 0-based indices.

    Each non-blank line of `edges_fn` holds two whitespace-separated, 1-based
    node indices. Both are shifted down by one for Python indexing.

    Args:
        edges_fn: path of the edge-list text file.
        node_features_fn: unused; kept so existing call sites keep working.

    Returns:
        A tuple `(edge_v1, edge_v2)` of equal-length lists of ints, where
        `edge_v1[i]` and `edge_v2[i]` are the endpoints of the i-th edge.

    Raises:
        ValueError: if a field is not an integer.
        IndexError: if a non-blank line has fewer than two fields.
    """
    edge_v1 = []
    edge_v2 = []

    # `with` closes the file even if a malformed line raises part-way through.
    with open(edges_fn, 'r') as edges_file:
        for line in edges_file:
            data = line.split()
            if not data:
                # Blank lines (e.g. a trailing newline) carry no edge.
                continue
            edge_v1.append(int(data[0]) - 1)  # subtracting to convert R idx to python idx
            edge_v2.append(int(data[1]) - 1)

    return edge_v1, edge_v2

In [6]:
# Load the shared graph topology once: parallel lists of 0-based edge endpoints.
edge_v1, edge_v2 = read_reactome_graph(edges_fn, node_features_fn)

In [7]:
def build_scratch_loader(batch_size):
  """Build a shuffling DataLoader over MUTAG with a single feature per node.

  Only column 1 of each graph's node-feature matrix is kept, reshaped to
  `[num_nodes, 1]`. Labels and edges are passed through unchanged.

  Args:
    batch_size: number of graphs per mini-batch.

  Returns:
    A DataLoader over the rebuilt `Data` objects, reshuffled on each pass.
  """
  dataset = TUDataset(root='data/TUDataset', name='MUTAG')
  data_list = []
  for graph_obj in dataset:
    # `graph_obj.x` is already a tensor as produced by TUDataset. Copying it
    # with clone().detach() instead of torch.tensor(...) gives the same values
    # without the copy-construct UserWarning.
    x = graph_obj.x[:, 1].clone().detach().to(torch.float).unsqueeze(1)
    data_list.append(Data(x = x, y = graph_obj.y, edge_index = graph_obj.edge_index))

  return DataLoader(data_list,batch_size=batch_size,shuffle=True)

In [8]:
def build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn):
    """Build one `Data` graph per row of the node-feature file.

    All graphs share the same `edge_index`. Row `i` of the feature file
    supplies the per-node values of graph `i`, and line `i` of the targets
    file supplies its class label, integer-encoded with a LabelEncoder.

    Args:
        edge_v1: source endpoint of every edge (0-based ints).
        edge_v2: target endpoint of every edge (0-based ints).
        node_features_fn: path of a numeric text matrix, one graph per row.
        graph_targets_fn: path of a text file with one label per line.

    Returns:
        A list of `Data` objects with `x` of shape `[n_values, 1]`, `y` of
        shape `[1]` and the shared `edge_index`.
    """
    edge_index = torch.tensor([edge_v1, edge_v2], dtype = torch.long)
    # ndmin=2 keeps a file with a single row as a (1, n) matrix, so the row
    # indexing below works for any number of graphs.
    feature_v = numpy.loadtxt(node_features_fn, ndmin=2)

    # Read labels line by line rather than via loadtxt(delimiter="\n"):
    # newer NumPy releases reject a newline delimiter. The handling mirrors
    # the loadtxt defaults: text after '#' is dropped and empty lines skipped.
    target_v = []
    with open(graph_targets_fn, 'r') as targets_file:
        for line in targets_file:
            label = line.split('#', 1)[0].rstrip('\r\n')
            if label:
                target_v.append(label)

    target_encoder = sklearn.preprocessing.LabelEncoder()
    target_v = target_encoder.fit_transform(target_v)

    data_list = []
    for row_idx in range(len(feature_v)):
        features = feature_v[row_idx,:]
        # Permutes this graph's values across node positions, in place.
        # NOTE(review): this decouples features from node identity; presumably
        # an intentional permutation control (outputs are tagged "rewire2").
        # Confirm before using the result as a real model.
        random.shuffle(features)
        x = torch.tensor(features,dtype=torch.float).unsqueeze(1)
        y = torch.tensor([target_v[row_idx]])
        data_list.append(Data(x = x, y = y, edge_index = edge_index))

    return data_list

def build_reactome_graph_loader(data_list,batch_size):
    """Wrap `data_list` in a DataLoader yielding shuffled batches of `batch_size` graphs."""
    return DataLoader(data_list, shuffle=True, batch_size=batch_size)

In [9]:
# from https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=CN3sRVuaQ88l
class GNN(torch.nn.Module):
    """Graph classifier: three GraphConv layers, mean-pool readout, linear head.

    Layer widths come from `hidden_channels` and the module-level
    INPUT_CHANNELS / OUTPUT_CHANNELS constants.
    """

    def __init__(self, hidden_channels):
        super().__init__()

        # Attribute names and creation order are part of the saved state_dict
        # layout and of the weight-initialisation order; keep them as they are.
        self.conv1 = GraphConv(INPUT_CHANNELS, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, OUTPUT_CHANNELS)

    def forward(self, x, edge_index, batch, edge_weight=None):
        """Return one row of class logits per graph in the batch.

        Args:
            x: node features.
            edge_index: edge endpoints, passed to every GraphConv layer.
            batch: node-to-graph assignment consumed by the readout.
            edge_weight: optional edge weights, passed to every GraphConv layer.
        """
        # Node embeddings: ReLU after the first two convolutions only.
        hidden = F.relu(self.conv1(x, edge_index, edge_weight))
        hidden = F.relu(self.conv2(hidden, edge_index, edge_weight))
        hidden = self.conv3(hidden, edge_index, edge_weight)

        # Readout: average node embeddings per graph -> [batch_size, hidden_channels].
        pooled = global_mean_pool(hidden, batch)

        # Classifier head; dropout is active only while the module is in training mode.
        pooled = F.dropout(pooled, training=self.training)
        return self.lin(pooled)

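The snippet below may keep the Colab session from disconnecting.

Source: https://medium.com/@shivamrawat_756/how-to-prevent-google-colab-from-disconnecting-717b88a128c0

Press Ctrl + Shift + I to open the browser developer tools, then paste the following into the console: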

function ClickConnect(){console.log("Working");document.querySelector("colab-connect-button").shadowRoot.getElementById('connect').click();}
setInterval(ClickConnect, 60000)

In [10]:
# Ask Colab to cap this cell's output frame at 300px (the maxHeight passed below).
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Network on the selected device, with Adam and cross-entropy at library defaults.
# These three names are module-level state shared by train() and test().
model = GNN(hidden_channels=HIDDEN_CHANNELS).to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.CrossEntropyLoss()

def train(loader,device):
  """Run one optimisation pass over `loader`, updating the module-level `model`.

  Uses the module-level `criterion` and `optimizer`. Each batch's tensors are
  moved to `device` before the forward pass. Returns nothing.
  """
  model.train()

  for graph_batch in loader:
    # Move only the tensors the model and loss need.
    features, edges, assignment = (
        tensor.to(device)
        for tensor in (graph_batch.x, graph_batch.edge_index, graph_batch.batch)
    )
    targets = graph_batch.y.to(device)

    logits = model(features, edges, assignment)
    batch_loss = criterion(logits, targets)

    batch_loss.backward()   # accumulate gradients
    optimizer.step()        # apply the update
    optimizer.zero_grad()   # leave clean gradients for the next batch

def test(loader,device):
  model.eval()

  correct = 0
  for batch in loader:  # Iterate in batches over the training/test dataset.
    x = batch.x.to(device)
    e = batch.edge_index.to(device)
    b = batch.batch.to(device)
    y = batch.y.to(device)
    out = model(x, e, b)  # Perform a single forward pass.
    loss = criterion(out, y)  # Compute the loss.
    pred = out.argmax(dim=1)  # Use the class with highest probability.
    correct += int((pred == y).sum())  # Check against ground-truth labels.
  return correct / len(loader.dataset)  # Derive ratio of correct predictions.

<IPython.core.display.Javascript object>

In [11]:
# Accuracy log as text; the training cell below appends to it and writes it to disk.
acc_str = ''
if(BENCHMARKING):
  # Benchmark branch: for each batch size, time one training pass and one
  # evaluation pass over a freshly built data list.

  test_b_sizes = [1,8,16,32,64,128]

  for test_b_size in test_b_sizes:
    print(f'Executing training routine with batch size = {test_b_size}')
    data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn)
    test_batch_size_data_loader = build_reactome_graph_loader(data_list,test_b_size)
  
    # train() updates the module-level model, which is not re-created between
    # batch sizes, so each timing run starts from the previous run's weights.
    start = time.time()
    train(test_batch_size_data_loader,device)
    end = time.time()
    training_time = end - start

    # Accuracy is measured on the same data that was just trained on.
    start = time.time()
    train_acc = test(test_batch_size_data_loader,device)
    end = time.time()
    test_time = end - start

    acc_str += f'{train_acc:.4f}\n'
    print(f'Batch Size: {test_b_size}')
    print(f'Training Time: {training_time}')
    print(f'Test Time: {test_time}')
    print(f'Accuracy: {train_acc}')
    # Clearing the flag makes the following `if(BENCHMARKING):` cell skip training.
    BENCHMARKING = False
else:
  # Normal branch: build and shuffle the data list for the training cell.
  #data_loader = build_scratch_loader(BATCH_SIZE) # testing
  data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn)
  random.shuffle(data_list)

  # The flag is reused as the gate of the next cell: setting it here is what
  # lets the training cell run. Re-running this cell alone therefore flips
  # which branch executes; re-run the config cell first to reset the flag.
  BENCHMARKING = True

In [12]:
# Train on the full data list, log per-epoch training accuracy, then save the
# accuracy log and the model weights to Google Drive. Runs only when the
# previous cell took its non-benchmark branch and set BENCHMARKING = True.
if(BENCHMARKING):
  fold_size = 911
  fold = 'full_dataset'
  #>>> train =              z[:fold_size * (fold - 1)] +         z[fold_size * fold:]
  #train_data_list = data_list[:fold_size * (fold - 1)] + data_list[fold_size * fold:]
  #>>> test =              z[fold_size * (fold - 1):fold_size * fold]
  #test_data_list = data_list[fold_size * (fold - 1):fold_size * fold]
  train_data_list = data_list

  print(f'Number of training graphs: {len(train_data_list)}')
  #print(f'Number of test graphs: {len(test_data_list)}')
  train_data_loader = build_reactome_graph_loader(train_data_list,BATCH_SIZE)
  #test_data_loader = build_reactome_graph_loader(test_data_list,BATCH_SIZE)
  for epoch in range(EPOCHS):
    train(train_data_loader,device)
    train_acc = test(train_data_loader,device)
    #test_acc = test(test_data_loader,device)
    # One accuracy per line. Without the newline the values are written as one
    # unbroken digit string that cannot be split back into epochs.
    acc_str += f'{train_acc:.4f}\n'#',{test_acc:.4f}\n'
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}')#', Test Acc: {test_acc:.4f}')

  training_acc_fn = F"graph_classification_acc_rewire2_fold_{fold}.txt"
  path = F"/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_{training_acc_fn}"
  with open(path, 'w') as writefile:
      writefile.write(acc_str)
  model_save_name = F"trained_pytorch_model_rewire2_fold_{fold}.pt"
  path = F"/content/gdrive/My Drive/Academia/OHSU/Proposal/pathway_hierarchy_{model_save_name}"
  torch.save(model.state_dict(), path)
  print(F"model saved as {path}")

Number of training graphs: 9115
Epoch: 000, Train Acc: 0.1889
Epoch: 001, Train Acc: 0.2564
Epoch: 002, Train Acc: 0.2792
Epoch: 003, Train Acc: 0.2914
Epoch: 004, Train Acc: 0.2811
Epoch: 005, Train Acc: 0.2827
Epoch: 006, Train Acc: 0.3032
Epoch: 007, Train Acc: 0.3024
Epoch: 008, Train Acc: 0.3016
Epoch: 009, Train Acc: 0.3012
Epoch: 010, Train Acc: 0.3010
Epoch: 011, Train Acc: 0.3073
Epoch: 012, Train Acc: 0.3149
Epoch: 013, Train Acc: 0.3280
Epoch: 014, Train Acc: 0.3217
Epoch: 015, Train Acc: 0.3296
Epoch: 016, Train Acc: 0.3214
Epoch: 017, Train Acc: 0.3433
Epoch: 018, Train Acc: 0.3424
Epoch: 019, Train Acc: 0.3414
Epoch: 020, Train Acc: 0.3427
Epoch: 021, Train Acc: 0.3461
Epoch: 022, Train Acc: 0.3511
Epoch: 023, Train Acc: 0.3410
Epoch: 024, Train Acc: 0.3546
Epoch: 025, Train Acc: 0.3528
Epoch: 026, Train Acc: 0.3540
Epoch: 027, Train Acc: 0.3638
Epoch: 028, Train Acc: 0.3579
Epoch: 029, Train Acc: 0.3630
Epoch: 030, Train Acc: 0.3637
Epoch: 031, Train Acc: 0.3736
Epoch: 0

In [13]:
# Switch for the optional visualisation cells below; they all test this flag.
DRAWING = False
if(DRAWING):
  import networkx as nx
  import numpy as np

  from torch_geometric.utils import to_networkx

  # NOTE(review): the helpers below use `plt`, but the only matplotlib import
  # visible in this notebook is commented out in the first cell. Confirm it is
  # imported before enabling DRAWING.

  def draw_molecule(g, edge_mask=None, draw_edge_labels=False):
      """Draw `g` undirected, optionally colouring edges by `edge_mask`.

      `edge_mask` maps `(u, v)` edge keys to numbers. When given, each edge's
      colour is its mask value and its width is ten times that value. With
      `draw_edge_labels` the values are also printed on the edges.
      """
      g = g.copy().to_undirected()
      node_labels = {node: attrs['name'] for node, attrs in g.nodes(data=True)}
      pos = nx.spring_layout(g)

      if edge_mask is not None:
          edge_color = [edge_mask[(u, v)] for u, v in g.edges()]
          widths = [weight * 10 for weight in edge_color]
      else:
          edge_color, widths = 'black', None

      nx.draw(g, pos=pos, labels=node_labels, width=widths,
              edge_color=edge_color, edge_cmap=plt.cm.Blues,
              node_color='azure')

      if draw_edge_labels and edge_mask is not None:
          edge_labels = {edge: ('%.2f' % value) for edge, value in edge_mask.items()}
          nx.draw_networkx_edge_labels(g, pos, edge_labels=edge_labels,
                                      font_color='red')
      plt.show()


  def to_molecule(data):
      """Convert `data` to a networkx graph whose node attribute `x` is renamed to `name`."""
      g = to_networkx(data, node_attrs=['x'])
      for _, attrs in g.nodes(data=True):
          attrs['name'] = attrs.pop('x')
      return g

In [14]:
if(DRAWING):
  from captum.attr import Saliency, IntegratedGradients

  def model_forward(edge_mask, data):
      """Forward `data` through the module-level model, using `edge_mask` as edge weights.

      Every node is assigned to graph 0, i.e. `data` is treated as one graph.
      """
      batch = torch.zeros(data.x.shape[0], dtype=int).to(device)
      return model(data.x, data.edge_index, batch, edge_mask)


  def explain(method, data, target=0):
      """Attribute the `target` output to the edges of `data`.

      `method` is 'ig' (Integrated Gradients) or 'saliency'; any other value
      raises `Exception`. Returns the absolute attributions as a numpy array
      with one entry per edge, divided by their maximum when that is positive.
      """
      n_edges = data.edge_index.shape[1]
      # Start from an all-ones mask: one entry per column of edge_index.
      input_mask = torch.ones(n_edges).requires_grad_(True).to(device)

      if method == 'ig':
          attribution = IntegratedGradients(model_forward).attribute(
              input_mask, target=target,
              additional_forward_args=(data,),
              internal_batch_size=n_edges)
      elif method == 'saliency':
          attribution = Saliency(model_forward).attribute(
              input_mask, target=target,
              additional_forward_args=(data,))
      else:
          raise Exception('Unknown explanation method')

      edge_mask = np.abs(attribution.cpu().detach().numpy())
      peak = edge_mask.max()
      if peak > 0:  # avoid division by zero
          edge_mask = edge_mask / peak
      return edge_mask

In [15]:
if(DRAWING):
  # build_reactome_graph_loader takes (data_list, batch_size), so the graph
  # list is built first. A separate name keeps the training cells' `data_list`
  # untouched.
  drawing_data_list = build_reactome_graph_datalist(edge_v1, edge_v2, node_features_fn, graph_targets_fn)
  data_loader = build_reactome_graph_loader(drawing_data_list, BATCH_SIZE) # reactome

In [16]:
if(DRAWING):
  # Rebuild the architecture, then restore saved weights for the explanation cells.
  model = GNN(hidden_channels=HIDDEN_CHANNELS)
  model = model.to(device)

  # NOTE(review): the training cell saves under
  # "pathway_hierarchy_trained_pytorch_model_rewire2_fold_{fold}.pt", which is
  # not this file name. Confirm which checkpoint is meant.
  model_save_name = 'trained_pytorch_model.pt'
  path = F"/content/gdrive/My Drive/Academia/OHSU/Proposal/{model_save_name}"
  # map_location remaps the stored tensors onto the current device, so a
  # checkpoint saved on a GPU also loads on a CPU-only runtime.
  model.load_state_dict(torch.load(path, map_location=device))
  model.eval()

In [17]:
if(DRAWING):
  d = data_loader.dataset[0]
  # Number of edge columns in the first graph. A bare expression inside an
  # `if` body is not echoed by the notebook, so it is printed explicitly.
  print(d.edge_index.shape[1])

In [18]:
# Explain the first graph with Integrated Gradients and Saliency and draw the
# per-edge attributions. Relies on names defined in the earlier DRAWING cells
# (to_molecule, explain, draw_molecule, data_loader).
if(DRAWING):
  import random
  from collections import defaultdict

  def aggregate_edge_directions(edge_mask, data):
      """Sum mask values per undirected edge.

      Each entry of `edge_mask` is paired with one column of
      `data.edge_index`; the endpoints are ordered so that (u, v) and (v, u)
      add into the same `(min, max)` key. Returns a defaultdict(float).
      """
      edge_mask_dict = defaultdict(float)
      # `*data.edge_index` unpacks the rows of edge_index into zip
      # (presumably two rows: sources and targets -- TODO confirm).
      for val, u, v in list(zip(edge_mask, *data.edge_index)):
          u, v = u.item(), v.item()
          if u > v:
              u, v = v, u
          edge_mask_dict[(u, v)] += val
      return edge_mask_dict
      
  data = data_loader.dataset[0]
  # NOTE(review): the trailing comment looks like an error message the author
  # saw from this call -- confirm whether it still reproduces.
  mol = to_molecule(data) # 'float' object has no attribute 'index'

  # NOTE(review): `plt` is used below, but no matplotlib import is visible in
  # this notebook (the one in the first cell is commented out).
  for title, method in [('Integrated Gradients', 'ig'), ('Saliency', 'saliency')]:
      data.to(device)
      edge_mask = explain(method, data, target=0)
      edge_mask_dict = aggregate_edge_directions(edge_mask, data)
      plt.figure(figsize=(100, 50))
      plt.title(title)
      draw_molecule(mol, edge_mask_dict)