<a href="https://colab.research.google.com/github/kadeng/colab_tutorials/blob/master/docs/torchdrug/TorchDrug_Pretraining_and_Finetuning_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!apt-get install ninja-build

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following package was automatically installed and is no longer required:
  libnvidia-common-460
Use 'apt autoremove' to remove it.
The following NEW packages will be installed:
  ninja-build
0 upgraded, 1 newly installed, 0 to remove and 42 not upgraded.
Need to get 93.3 kB of archives.
After this operation, 296 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 ninja-build amd64 1.8.2-1 [93.3 kB]
Fetched 93.3 kB in 1s (114 kB/s)
Selecting previously unselected package ninja-build.
(Reading database ... 155629 files and directories currently installed.)
Preparing to unpack .../ninja-build_1.8.2-1_amd64.deb ...
Unpacking ninja-build (1.8.2-1) ...
Setting up ninja-build (1.8.2-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...


In [2]:
import os
import sys
import torch
if not os.path.exists("/gdrive/MyDrive/colab/pipenv"):
  from google.colab import drive
  drive.mount('/gdrive')
if not os.path.exists("/gdrive/MyDrive/colab/pipenv"):
  print("Installing packages..")
  os.path.makedirs("/gdrive/MyDrive/colab/pipenv", exist_ok=True)
  os.environ["TORCH_VERSION"] = torch.__version__
  !pip install -t /gdrive/MyDrive/colab/pipenv/ torch-scatter -f https://pytorch-geometric.com/whl/torch-$TORCH_VERSION.html
  !pip install -t /gdrive/MyDrive/colab/pipenv/ git+https://github.com/DeepGraphLearning/torchdrug
  !pip install -t /gdrive/MyDrive/colab/pipenv/ ninja
  
  print("Done installing packages")
sys.path.insert(0, "/gdrive/MyDrive/colab/pipenv")
print("OK")

Mounted at /gdrive
OK


### Introduction

In many drug discovery tasks, it is costly in both time and money to collect labeled data. As a solution, self-supervised pretraining is introduced to learn molecular representations from massive unlabeled data.

In this tutorial, we will demonstrate how to pretrain a graph neural network on molecules, and how to finetune the model on downstream tasks.

### Manual Steps

0.   Get your own copy of this file via "File > Save a copy in Drive...",
1.   Set the runtime to **GPU** via "Runtime > Change runtime type..."

### Colab Tutorials

#### Quick Start
1. [Basic Usage and Pipeline](https://colab.research.google.com/drive/1Tbnr1Fog_YjkqU1MOhcVLuxqZ4DC-c8-#forceEdit=true&sandboxMode=true)

#### Drug Discovery Tasks
1. [Property Prediction](https://colab.research.google.com/drive/1sb2w3evdEWm-GYo28RksvzJ74p63xHMn?usp=sharing#forceEdit=true&sandboxMode=true)
2. [Pretrained Molecular Representations](https://colab.research.google.com/drive/10faCIVIfln20f2h1oQk2UrXiAMqZKLoW?usp=sharing#forceEdit=true&sandboxMode=true)
3. [De Novo Molecule Design](https://colab.research.google.com/drive/1JEMiMvSBuqCuzzREYpviNZZRVOYsgivA?usp=sharing#forceEdit=true&sandboxMode=true)
4. [Retrosynthesis](https://colab.research.google.com/drive/1IH1hk7K3MaxAEe5m6CFY7Eyej3RuiEL1?usp=sharing#forceEdit=true&sandboxMode=true)
5. [Knowledge Graph Reasoning](https://colab.research.google.com/drive/1-sjqQZhYrGM0HiMuaqXOiqhDNlJi7g_I?usp=sharing#forceEdit=true&sandboxMode=true)

# Self-Supervised Pretraining

Pretraining is an effective approach to transfer learning in Graph Neural Networks for graph-level property prediction. Here we focus on pretraining GNNs via different self-supervised strategies. These methods typically construct unsupervised loss functions based on structural information in molecules.

For illustrative purpose, we only use the ClinTox dataset in this tutorial, which is much smaller than the standard pretraining datasets. For real applications, we suggest using larger datasets like ZINC2M.



## Infograph

InfoGraph (IG) proposes to maximize the mutual information between the graph-level and node-level representations. It learns the model by distinguishing whether a node-graph pair comes from a single graph or two different graphs. The following figure illustrates the high-level idea of InfoGraph.

![infograph.png](https://raw.githubusercontent.com/DeepGraphLearning/torchdrug/master/asset/model/infograph.png)

We use GIN as our graph represenation model, and wrap it with InfoGraph.


In [3]:
import torch
from torch import nn
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models

dataset = datasets.ClinTox("/gdrive/MyDrive/colab/molecule-datasets/", node_feature="pretrain",
                           edge_feature="pretrain")

gin_model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300, 300, 300],
                       edge_input_dim=dataset.edge_feature_dim,
                       batch_norm=True, readout="mean")
model = models.InfoGraph(gin_model, separate_model=False)

task = tasks.Unsupervised(model)
optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)

solver.train(num_epoch=10)
solver.save("clintox_gin_infograph.pth")

Loading /gdrive/MyDrive/colab/molecule-datasets/clintox.csv: 100%|██████████| 1485/1485 [00:00<00:00, 87951.39it/s]
Constructing molecules from SMILES: 100%|██████████| 1484/1484 [00:02<00:00, 689.04it/s]


13:42:10   {'batch_size': 256,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': Unsupervised(
  (model): InfoGraph(
    (model): GraphIsomorphismNetwork(
      (layers): ModuleList(
        (0): GraphIsomorphismConv(
          (batch_norm): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (mlp): MultiLayerPerceptron(
            (layers): ModuleList(
              (0): Linear(in_features=22, out_features=300, bias=True)
              (1): Linear(in_features=300, out_features=300, bias=True)
            )
          )
          (edge_linear): Linear(in_features=11, out_features=22, bias=True)
        )
        (

In [7]:
!mkdir -p /gdrive/MyDrive/colab/molecule-models/
solver.save("/gdrive/MyDrive/colab/molecule-models/clintox_gin_infograph.pth")

13:43:42   Save checkpoint to /gdrive/MyDrive/colab/molecule-models/clintox_gin_infograph.pth


## Attribute Masking

The aim of Attribute Masking (AM) is to capture domain knowledge by learning the regularities of the node/edge attributes distributed over graph structure. The high-level idea is to predict atom types in molecular graphs from randomly masked node features.

![attrmasking.png](https://raw.githubusercontent.com/DeepGraphLearning/torchdrug/master/asset/model/attribute_masking.png)

Again, we use GIN as our graph representation model.



In [22]:
import torch
from torch import nn
from torch.utils import data as torch_data

from torchdrug import core, datasets, tasks, models

dataset = datasets.ClinTox("/gdrive/MyDrive/colab/molecule-datasets/", node_feature="pretrain",
                           edge_feature="pretrain")

model = models.GIN(input_dim=dataset.node_feature_dim,
                   hidden_dims=[300, 300, 300, 300, 300],
                   edge_input_dim=dataset.edge_feature_dim,
                   batch_norm=True, readout="mean")
task = tasks.AttributeMasking(model, mask_rate=0.15)

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, dataset, None, None, optimizer, gpus=[0], batch_size=256)

solver.train(num_epoch=10)
solver.save("/gdrive/MyDrive/colab/molecule-models/clintox_gin_attributemasking.pth")

Loading /gdrive/MyDrive/colab/molecule-datasets/clintox.csv: 100%|██████████| 1485/1485 [00:00<00:00, 94021.40it/s]
Constructing molecules from SMILES: 100%|██████████| 1484/1484 [00:01<00:00, 841.17it/s]

13:48:04   {'batch_size': 256,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.AttributeMasking',
          'mask_rate': 0.15,
          'model': {'activation': 'relu',
                    'batch_norm': True,
                    'class': 'models.GIN',
                    'concat_hidden': False,
                    'edge_input_dim': 11,
                    'eps': 0,
                    'hidden_dims': [300, 300, 300, 300, 300],
                    'input_dim': 22,
                    'learn_eps': False,
                    'num_mlp_layer': 2,
                    'readout': 'mean',
                    'short_cut': False},
          'num




13:48:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:48:04   accuracy: 0
13:48:04   cross entropy: 4.8936
13:48:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:48:04   Epoch 0 end
13:48:04   duration: 0.28 secs
13:48:04   speed: 21.34 batch / sec
13:48:04   ETA: 2.53 secs
13:48:04   max GPU memory: 253.7 MiB
13:48:04   ------------------------------
13:48:04   average accuracy: 0.571284
13:48:04   average cross entropy: 3.52324
13:48:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:48:04   Epoch 1 begin
13:48:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:48:04   Epoch 1 end
13:48:04   duration: 0.27 secs
13:48:04   speed: 22.25 batch / sec
13:48:04   ETA: 2.20 secs
13:48:04   max GPU memory: 251.3 MiB
13:48:04   ------------------------------
13:48:04   average accuracy: 0.710235
13:48:04   average cross entropy: 1.35047
13:48:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:48:04   Epoch 2 begin
13:48:04   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
13:48:04   Epoch 2 end
13:48:04   duration: 0.25 secs
13:48:04   speed: 23.81 batch / sec
13

# Finetune on Labeled Datasets
When the GNN pre-training is finished, we can finetune the pre-trained GNN model on downstream tasks. Here we use BACE dataset for illustration, which contains 1,513 molecules with binding affinity results a set of inhibitors of human 𝛽-secretase 1(BACE-1).

First, we download the BACE dataset and split it into training, validation and test sets. Note that we need to set the node and edge feature in the dataset as pretrain in order to make it compatible with the pretrained model.



In [24]:
from torchdrug import data

dataset = datasets.BACE("/gdrive/MyDrive/colab/molecule-datasets/",
                        node_feature="pretrain", edge_feature="pretrain")
lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
lengths += [len(dataset) - sum(lengths)]
train_set, valid_set, test_set = data.ordered_scaffold_split(dataset, lengths)

14:16:00   Downloading http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/bace.csv to /gdrive/MyDrive/colab/molecule-datasets/bace.csv


Loading /gdrive/MyDrive/colab/molecule-datasets/bace.csv: 100%|██████████| 1514/1514 [00:00<00:00, 8909.05it/s]
Constructing molecules from SMILES: 100%|██████████| 1513/1513 [00:02<00:00, 724.31it/s]


Then, we define the same model as the pre-training stage and set up the optimizer and solver for our downstream task. The only difference here is that we use PropertyPrediction task to support supervised learning.



In [25]:
model = models.GIN(input_dim=dataset.node_feature_dim,
                hidden_dims=[300, 300, 300, 300, 300],
                edge_input_dim=dataset.edge_feature_dim,
                batch_norm=True, readout="mean")
task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                criterion="bce", metric=("auprc", "auroc"))

optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=256)

14:16:30   Preprocess training set
14:16:30   {'batch_size': 256,
 'class': 'core.Engine',
 'gpus': [0],
 'gradient_interval': 1,
 'log_interval': 100,
 'logger': 'logging',
 'num_worker': 0,
 'optimizer': {'amsgrad': False,
               'betas': (0.9, 0.999),
               'class': 'optim.Adam',
               'eps': 1e-08,
               'lr': 0.001,
               'maximize': False,
               'weight_decay': 0},
 'scheduler': None,
 'task': {'class': 'tasks.PropertyPrediction',
          'criterion': 'bce',
          'metric': ('auprc', 'auroc'),
          'model': {'activation': 'relu',
                    'batch_norm': True,
                    'class': 'models.GIN',
                    'concat_hidden': False,
                    'edge_input_dim': 11,
                    'eps': 0,
                    'hidden_dims': [300, 300, 300, 300, 300],
                    'input_dim': 22,
                    'learn_eps': False,
                    'num_mlp_layer': 2,
                

Now we can load our pretrained model and finetune it on downstream datasets.



In [26]:
checkpoint = torch.load("/gdrive/MyDrive/colab/molecule-models/clintox_gin_infograph.pth")["model"]
task.load_state_dict(checkpoint, strict=False)

solver.train(num_epoch=100)
solver.evaluate("valid")

14:17:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:24   Epoch 0 begin
14:17:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:24   binary cross entropy: 0.697578
14:17:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:24   Epoch 0 end
14:17:24   duration: 53.92 secs
14:17:24   speed: 0.09 batch / sec
14:17:24   ETA: 1.48 hours
14:17:24   max GPU memory: 312.7 MiB
14:17:24   ------------------------------
14:17:24   average binary cross entropy: 0.609225
14:17:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:24   Epoch 1 begin
14:17:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:24   Epoch 1 end
14:17:24   duration: 0.22 secs
14:17:24   speed: 22.23 batch / sec
14:17:24   ETA: 44.22 mins
14:17:24   max GPU memory: 312.9 MiB
14:17:24   ------------------------------
14:17:24   average binary cross entropy: 0.537239
14:17:24   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:24   Epoch 2 begin
14:17:25   >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14:17:25   Epoch 2 end
14:17:25   duration: 0.23 secs
14:17:25   speed: 22.04 batch / sec
14:17:2

{'auprc [Class]': tensor(0.8879, device='cuda:0'),
 'auroc [Class]': tensor(0.5674, device='cuda:0')}