# Part1: Predicting Molecular Properties


## Deep learning on graphs: 

In this hands-on exercise, we will guide you to design and train your first Graph Neural Network. We will focus here on a graph classification task. In this setting, you're given a set of graphs associated with a label. The goal is to model this set of graphs in order to extract relevant information from 1. the graph topology and 2. the node features to predict the associated label. In this example, the graphs are molecules where the nodes are the atoms and the edges the chemical bonds. More specifically, we are using the MUTAG dataset that consists of 188 chemical compounds divided into two 
classes according to their mutagenic effect on a bacterium. 

## Objectives:

- Familiarise yourself with the concept of graph. 
- Introduce the Deep Graph Library (DGL) -- the de facto python library for learning on graphs
- Design from scratch an instance of GNN: the Graph Isomorphism Network (GIN)
- Train a graph classification task: predicting the mutagenicity of a molecule 


## A.) Build a graph in DGL: the `DGLGraph`
- Declare a graph is as simple as `g = dgl.DGLGraph()`
- Add nodes with `add_nodes()`
- Add egdes with `add_edges()`
- Add node features using the `ndata` attribute

In [1]:
import dgl 
import numpy as np 
import torch 

# 1. declare a graph 
graph = dgl.DGLGraph()

# 2. add 10 nodes
graph.add_nodes(10)

# 3. add 7 random edges 
from_ = np.random.randint(0, 9, 7)
to_ = np.random.randint(0, 9, 7)
graph.add_edges(from_, to_)

# 4. add node features 
graph.ndata['attr'] = torch.randn(10, 5)

def print_graph_properties(g):
  print('Graph has the following properties:')
  print('- {} nodes'.format(g.number_of_nodes()))
  for key, val in g.ndata.items():
    print('- "{}" with {} node features'.format(key, list(val.shape)))
  print('- {} edges'.format(g.number_of_edges()))
  
print_graph_properties(graph)

Using backend: pytorch


Graph has the following properties:
- 10 nodes
- "attr" with [10, 5] node features
- 7 edges




## B.) Dataloading with DGL:

### DGL built-in dataloader 
- DGL provides a set of built-in dataloader for common datasets (eg `dgl.data.GINDataset`, `dgl.data.TUDataset`)
- This module is analogous to the `torchvision` library that provides an API to load popular computer vision datasets. 

### DGL is built around `networkx` & PyTorch 
- Allows to (partially) use the `networkx` API when dealing with the graph objects 
- Conversion from `DGLGraph` to networkx `Graph` is straightforward
- Allows to use the PyTorch API when manipulating the node & edges features 

### How to build a batch of graphs: The DGL `batch` 
- For graph classification, as in image classification, we need to build a batch, ie a set of samples that are fed to the model. 
- As opposed to image classification where each sample can be resized and padded to obtain the same size, adopting the same approach with graphs is not feasible. 
- A DGL batch is built using the observation that a set of N graphs can be represented as one large disconnected graph made of N connected components. 

In [2]:
import dgl 
import torch 
from torch.utils.data import DataLoader
import random 

# 1. load the data: graphs and labels 
data = dgl.data.GINDataset('MUTAG', self_loop=False)

# 2. Inspect manually the data by printing one of the samples
g, label = data[0]
print_graph_properties(g)

# 3. Batchify and train/val split the data
data = list(zip(data.graphs, data.labels))
random.shuffle(data)
train_data = data[:int(len(data)* 0.7)]
val_data = data[int(len(data)*0.7):]
batch_size = 8

def collate(batch):
  g = dgl.batch([example[0] for example in batch])
  l = torch.LongTensor([example[1] for example in batch])
  return g, l

train_dataloader = DataLoader(train_data, batch_size, shuffle=True, collate_fn=collate)
val_dataloader = DataLoader(train_data, batch_size, collate_fn=collate)

Graph has the following properties:
- 23 nodes
- "attr" with [23, 7] node features
- "label" with [23] node features
- 54 edges


## C.) Designing a DGL model

### Object oriented approach
- Define your Graph Neural Network layer as a python object 
- Define a model object that is:
  - instantiating several GNN layers 
  - implementing a global pooling operation, e.g., with a sum
  - projecting the graph embedding to the number of classes using an MLP
  
### The Graph Isomorphism Network
- The GIN proposes to update each node as:
\begin{equation}
h_v^{(k)} = \mbox{MLP}^{(k)} (h_v^{(k-1)} + \sum_{u \in N(v)} h_u^{(k-1)})
\end{equation}
- The graph-level embedding is then obtained using:
\begin{equation}
h_G = \sum_{v \in V} h_v^{(k)}
\end{equation}
- The DGL library also provides a high-level API to directly load Graph Neural Network layers in the `dgl.nn.pytorch.conv` module. 

In [3]:
import torch
import torch.nn as nn 
import torch.nn.functional as F


class GINLayer(nn.Module):

    def __init__(self, node_dim, out_dim):
        """
        Implementation of a GIN (Graph Isomorphism Network) layer.

        Original paper:
          - How Powerful are Graph Neural Networks: https://arxiv.org/abs/1810.00826
          - Author's public implementation: https://github.com/weihua916/powerful-gnns
          
        :param node_dim: (int) input dimension of the node features
        :param out_dim: (int) output dimension of the node features 
        """
        
        super(GINLayer, self).__init__()

        self.batchnorm_h = nn.BatchNorm1d(out_dim)

        self.mlp = nn.Sequential(
        nn.Linear(node_dim, out_dim),
        nn.ReLU(),
        nn.Linear(out_dim, out_dim),
      )

    def msg_fn(self, edges):
        """
        Message of each node
        """
        return {'msg': edges.src['h']}
      
    def reduce_fn(self, nodes):
        """
        For each node, aggregate the nodes using a reduce function.
        Current supported functions are sum and mean.
        """
        accum = torch.sum(nodes.mailbox['msg'], dim=1)
        return {'agg_msg': accum}

    def node_update_fn(self, nodes):
        """
        Node update function
        """
        h = nodes.data['h']
        h = self.mlp(h)
        h = F.relu(h)
        return {'h_out': h}

    def forward(self, g, h):
        """
        Forward-pass of a GIN layer.
        :param g: (DGLGraph) graph to process
        :param h: (FloatTensor) node features
        """

        # 1. set node features to g
        g.ndata['h'] = h

        # 2. message passing 
        g.update_all(self.msg_fn, self.reduce_fn)
        g.ndata['h'] = g.ndata.pop('agg_msg') + g.ndata.pop('h')
        g.apply_nodes(func=self.node_update_fn)

        # 3. pop node features & apply batch norm 
        h = g.ndata.pop('h_out')
        h = self.batchnorm_h(h)

        return h

      
class Model(nn.Module):

    def __init__(self, node_dim, out_dim, num_layers, num_classes):
      super(Model, self).__init__()
      
      # 1. define series of GNN layer 
      self.layers = nn.ModuleList()
      for layer_id in range(num_layers):
        self.layers.append(GINLayer(node_dim if layer_id == 0 else out_dim, out_dim))
        
      # 2. define classifier 
      self.classifier = nn.Sequential(
        nn.Linear(out_dim, out_dim),
        nn.ReLU(),
        nn.Linear(out_dim, num_classes)
      )

    def forward(self, g):
      
      # 1. loop over the GNN layers 
      h = g.ndata['attr'].type('torch.FloatTensor')
      for layer in self.layers:
        h = layer(g, h)
        
      # 2. apply pooling to build fixed-size representation
      g.ndata['h'] = h
      g_emb = dgl.sum_nodes(g, 'h')
      
      # 3. apply classifier to get the logits 
      logits = self.classifier(g_emb)
      
      return logits


## D.) Define the training and testing loop

### Use classic PyTorch training loop 
- Define the model parameters (num layers, GNN dimensions)
- Define the training parameters (optimizer, learning rate, weight decay, number of epochs)

In [5]:
import torch
from tqdm import trange

# with cuda?
cuda = torch.cuda.is_available()
device = 'cuda:0' if cuda else 'cpu'

# declare model
model = Model(
  node_dim=g.ndata['attr'].shape[1],
  out_dim=32,
  num_layers=3,
  num_classes=2
)
model.to(device)

# build optimizer
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10e-3,
    weight_decay=5e-4
)

# define loss function
loss_fn = torch.nn.CrossEntropyLoss()

# training 
val_loss = 10e5
val_accuracy = 0.

with trange(50) as t:
  for epoch in t:
    t.set_description('Validation with loss={} | accuracy={}'.format(val_loss, val_accuracy))
    # A.) train for 1 epoch 
    model.train()
    for graphs, labels in train_dataloader:

        # 1. forward pass
        labels = labels.to(device)
        logits = model(graphs)

        # 2. backward pass
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # B.) validate
    model.eval()
    all_val_logits = []
    all_val_labels = []
    for graphs, labels in val_dataloader:
        with torch.no_grad():
            labels = labels.to(device)
            logits = model(graphs)
        all_val_logits.append(logits)
        all_val_labels.append(labels)

    all_val_logits = torch.cat(all_val_logits).cpu()
    all_val_labels = torch.cat(all_val_labels).cpu()

    with torch.no_grad():
        val_loss = round(loss_fn(all_val_logits, all_val_labels).item(), 2)
        _, predictions = torch.max(all_val_logits, dim=1)
        correct = torch.sum(predictions.to(int) == all_val_labels.to(int))
        val_accuracy = round(correct.item() * 1.0 / len(all_val_labels), 2)


Validation with loss=0.25 | accuracy=0.878: 100%|██████████| 50/50 [00:20<00:00,  2.44it/s]  
