# Pytorch-Lightning Integration for DeepChem Models


In this tutorial we will go through how to set up a DeepChem model inside the [pytorch-lightning](https://www.pytorchlightning.ai/) framework. Lightning is a PyTorch framework that makes experimenting with PyTorch models easier. A few key features of PyTorch Lightning that DeepChem users may find useful are:

1. Multi-GPU training: PyTorch Lightning provides easy multi-GPU, multi-node training. It also simplifies launching multi-GPU, multi-node jobs on different cluster infrastructures, e.g. AWS or Slurm-based clusters.
1. Less boilerplate PyTorch code: Lightning takes care of details like `optimizer.zero_grad()`, `model.train()` and `model.eval()`. Lightning also provides experiment logging: whether you train on a CPU, a GPU or multiple nodes, you can call `self.log` inside the `LightningModule` and the metrics will be logged appropriately.
1. Features that can speed up training: half-precision training, gradient checkpointing and code profiling.
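For comparison, here is a minimal sketch of a plain PyTorch training loop written out by hand. The calls it makes explicitly (`model.train()`, `optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`, `model.eval()`) are the bookkeeping that Lightning's `Trainer` performs for you. The toy linear-regression data and the `manual_train` helper are illustrative assumptions, not part of DeepChem or Lightning.

```python
import torch
from torch import nn


def manual_train(model, data, targets, epochs=50, lr=0.1):
    """Hypothetical helper: a hand-written loop with the boilerplate Lightning hides."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    history = []
    for _ in range(epochs):
        model.train()              # enable training-mode behaviour (dropout, batchnorm)
        optimizer.zero_grad()      # clear gradients left over from the previous step
        loss = loss_fn(model(data), targets)
        loss.backward()            # backpropagate
        optimizer.step()           # update the parameters
        history.append(loss.item())
    model.eval()                   # switch to inference mode when training is done
    return history


torch.manual_seed(0)
x = torch.randn(32, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])  # toy linear targets
model = nn.Linear(3, 1)
history = manual_train(model, x, y)
print(history[0], history[-1])
```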

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/deepchem/blob/master/examples/tutorials/PytorchLightning_Integration.ipynb)

## Setup

- This notebook assumes that you have already installed DeepChem; if you have not, follow the instructions on the DeepChem installation page: https://deepchem.readthedocs.io/en/latest/get_started/installation.html.
- Install pytorch lightning following the instructions on lightning's home page: https://www.pytorchlightning.ai/

In [1]:
!pip install --pre deepchem
!pip install pytorch_lightning


Import the relevant packages.

In [2]:
import deepchem as dc
from deepchem.models import GCNModel
import numpy as np
import pytorch_lightning as pl
import torch

## Deepchem Example

Below we show an example of a Graph Convolutional Network (GCN). Note that this is a simple example which uses a `GCNModel` to predict a label from a molecule given as a SMILES string. We do not showcase the complete functionality of DeepChem in this example, since the goal is to restructure the DeepChem code so that it can be easily plugged into PyTorch Lightning. This example was inspired by the `GCNModel` documentation present [here](https://github.com/deepchem/deepchem/blob/a68f8c072b80a1bce5671250aef60f9cc8519bec/deepchem/models/torch_models/gcn.py#L200).

**Prepare the dataset**: to train our DeepChem models we need a dataset. Below we prepare a small sample dataset for the purposes of this tutorial, using DeepChem's `MolGraphConvFeaturizer` to encode the SMILES strings directly into graph features.

In [3]:
smiles = ["C1CCC1", "CCC"]
labels = [0., 1.]
featurizer = dc.feat.MolGraphConvFeaturizer()
X = featurizer.featurize(smiles)
dataset = dc.data.NumpyDataset(X=X, y=labels)

**Set up the model**: now we initialize the Graph Convolutional Network model that we will use for training.

In [4]:
model = GCNModel(
    mode='classification',
    n_tasks=1,
    batch_size=2,
    learning_rate=0.001
)



**Train the model**: fit the model on our training dataset, specifying the number of epochs to run.

In [5]:
loss = model.fit(dataset, nb_epoch=5)
print(loss)

0.18830760717391967




## Pytorch-Lightning + Deepchem example

Now we will look at an example of the GCN model adapted for PyTorch Lightning. Using PyTorch Lightning involves two important components:
1. `LightningDataModule`: This module defines how the data is prepared and fed into the model for training. It defines the `train_dataloader` method, which the trainer calls directly to generate batches for the `LightningModule`. To learn more about the `LightningDataModule` refer to the [datamodules documentation](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html).
2. `LightningModule`: This module defines the training and validation steps for our model. We can use it to initialize our model based on the hyperparameters. It also provides a number of built-in utilities for tracking experiments; for example, we can save all the hyperparameters used for training with the `self.save_hyperparameters()` method. For more details on how to use this module refer to the [lightningmodules documentation](https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html).
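To make the division of labour concrete, here is a deliberately simplified, hypothetical sketch of the order in which a trainer calls the hooks of these two components. It is not Lightning's actual implementation (the real `Trainer` also handles devices, precision, logging, callbacks and distributed training); `mini_fit`, `ToyDataModule` and `ToyModule` are illustrative names that only mimic the hook names.

```python
import torch
from torch import nn


def mini_fit(module, datamodule, max_epochs=1):
    """Hypothetical sketch of the hook order a Lightning-style trainer follows."""
    datamodule.setup("fit")                    # build the datasets
    optimizer = module.configure_optimizers()  # user-defined optimizer
    losses = []
    for _ in range(max_epochs):
        module.train()
        for batch_idx, batch in enumerate(datamodule.train_dataloader()):
            loss = module.training_step(batch, batch_idx)  # user-defined step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return losses


class ToyDataModule:
    """Hypothetical stand-in exposing the same hook names as a LightningDataModule."""
    def setup(self, stage):
        x = torch.randn(16, 4)
        self.train_dataset = torch.utils.data.TensorDataset(x, x.sum(dim=1, keepdim=True))

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=4, shuffle=True)


class ToyModule(nn.Module):
    """Hypothetical stand-in exposing the same hook names as a LightningModule."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.net(x), y)


torch.manual_seed(0)
losses = mini_fit(ToyModule(), ToyDataModule(), max_epochs=3)
print(len(losses))  # 3 epochs x 4 batches of 4 samples
```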

**Set up the torch dataset**: note that here we need to create a custom `SmilesDataset` so that we can easily interface with the DeepChem featurizers. We also need to define a collate function so that the `DataLoader` can combine individual samples into batches.

In [6]:
# prepare the torch Dataset and collate function used by the LightningDataModule
class SmilesDataset(torch.utils.data.Dataset):
    def __init__(self, smiles, labels):
        assert len(smiles) == len(labels)
        featurizer = dc.feat.MolGraphConvFeaturizer()
        X = featurizer.featurize(smiles)
        self._samples = dc.data.NumpyDataset(X=X, y=labels)
        
    def __len__(self):
        return len(self._samples)
        
    def __getitem__(self, index):
        return (
            self._samples.X[index],
            self._samples.y[index],
            self._samples.w[index],
        )
    
    
class SmilesDatasetBatch:
    def __init__(self, batch):
        X = [np.array([b[0] for b in batch])]
        y = [np.array([b[1] for b in batch])]
        w = [np.array([b[2] for b in batch])]
        self.batch_list = [X, y, w]
        
        
def collate_smiles_dataset_wrapper(batch):
    return SmilesDatasetBatch(batch)
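The custom collate function matters because PyTorch's default collation tries to turn each field into tensors, which does not work for DeepChem `GraphData` objects; instead, `SmilesDatasetBatch` stacks each field into a NumPy array wrapped in a one-element list, giving the `[[X], [y], [w]]` layout that is later passed to `_prepare_batch`. The sketch below reproduces that batching logic with a hypothetical `ToyXYWDataset` of plain NumPy arrays, so it runs without DeepChem.

```python
import numpy as np
import torch


class ToyXYWDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for SmilesDataset: each item is an (X, y, w) tuple."""
    def __init__(self, n):
        self.X = np.arange(n * 3, dtype=np.float32).reshape(n, 3)
        self.y = np.arange(n, dtype=np.float32) % 2
        self.w = np.ones(n, dtype=np.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        return self.X[index], self.y[index], self.w[index]


def collate_xyw(batch):
    # `batch` is a list of (X, y, w) tuples; stack each field and wrap it in a
    # one-element list, the same [[X], [y], [w]] layout SmilesDatasetBatch builds
    X = [np.array([b[0] for b in batch])]
    y = [np.array([b[1] for b in batch])]
    w = [np.array([b[2] for b in batch])]
    return [X, y, w]


loader = torch.utils.data.DataLoader(ToyXYWDataset(5), batch_size=2, collate_fn=collate_xyw)
batches = list(loader)
for bX, by, bw in batches:
    print(bX[0].shape, by[0], bw[0])  # the last batch holds the single leftover sample
```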

**Create the `LightningDataModule`**: in this part we wrap the `SmilesDataset` defined above in a `SmilesDatasetModule`, which builds the dataset in `setup` and serves shuffled batches through `train_dataloader`.

In [7]:
class SmilesDatasetModule(pl.LightningDataModule):
    def __init__(self, train_smiles, train_labels, batch_size):
        super().__init__()
        self._train_smiles = train_smiles
        self._train_labels = train_labels
        self._batch_size = batch_size
        
    def setup(self, stage):
        self.train_dataset = SmilesDataset(
            self._train_smiles,
            self._train_labels,
        )
        
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self._batch_size,
            collate_fn=collate_smiles_dataset_wrapper,
            shuffle=True,  
        )

**Create the lightning module**: in this part we create the GCN-specific `LightningModule`. This class specifies the logic of the training step; it also creates the required model, optimizer and loss by reusing those of a DeepChem `GCNModel`.

In [8]:
# prepare the LightningModule
class GCNModule(pl.LightningModule):
    def __init__(self, mode, n_tasks, learning_rate):
        super().__init__()
        self.save_hyperparameters(
            "mode",
            "n_tasks",
            "learning_rate",
        )
        self.gcn_model = GCNModel(
            mode=self.hparams.mode,
            n_tasks=self.hparams.n_tasks,
            learning_rate=self.hparams.learning_rate,
        )
        self.pt_model = self.gcn_model.model
        self.loss = self.gcn_model._loss_fn
        
    def configure_optimizers(self):
        return self.gcn_model.optimizer._create_pytorch_optimizer(
            self.pt_model.parameters(),
        )
    
    def training_step(self, batch, batch_idx):
        batch = batch.batch_list
        inputs, labels, weights = self.gcn_model._prepare_batch(batch)
        outputs = self.pt_model(inputs)
        
        if isinstance(outputs, torch.Tensor):
            outputs = [outputs]
    
        if self.gcn_model._loss_outputs is not None:
            outputs = [outputs[i] for i in self.gcn_model._loss_outputs]
    
        loss_outputs = self.loss(outputs, labels, weights)
        
        self.log(
            "train_loss",
            loss_outputs,
            on_epoch=True,
            sync_dist=True,
            reduce_fx="mean",
            prog_bar=True,
        )
        
        return loss_outputs
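In `training_step`, `self.loss` is DeepChem's loss callable, which receives the model outputs together with the labels and per-sample weights. As a rough illustration of a weighted classification loss, and not DeepChem's exact code (which also handles `(batch, n_tasks, n_classes)` shapes and may add regularization terms), here is a plain PyTorch sketch; `weighted_cross_entropy` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def weighted_cross_entropy(logits, labels, weights):
    """Hypothetical per-sample softmax cross-entropy scaled by sample weights.

    logits:  (batch, n_classes) raw scores
    labels:  (batch,) integer class indices
    weights: (batch,) non-negative sample weights
    """
    per_sample = F.cross_entropy(logits, labels, reduction="none")
    return (per_sample * weights).mean()


logits = torch.tensor([[2.0, -1.0], [0.5, 0.5]])
labels = torch.tensor([0, 1])
uniform = weighted_cross_entropy(logits, labels, torch.ones(2))         # plain mean
masked = weighted_cross_entropy(logits, labels, torch.tensor([1.0, 0.0]))  # ignore sample 2
print(uniform.item(), masked.item())
```

With all weights equal to 1 this reduces to the ordinary mean cross-entropy, while a zero weight removes a sample's contribution, which is how padded or missing labels can be masked out.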

**Create the relevant objects**

In [9]:
# create module objects
smiles_datasetmodule = SmilesDatasetModule(
    train_smiles=["C1CCC1", "CCC", "C1CCC1", "CCC", "C1CCC1", "CCC", "C1CCC1", "CCC", "C1CCC1", "CCC"],
    train_labels=[0., 1., 0., 1., 0., 1., 0., 1., 0., 1.],
    batch_size=2,
)

gcnmodule = GCNModule(
    mode="classification",
    n_tasks=1,
    learning_rate=1e-3,
)
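The hard-coded training lists above simply repeat two molecules and their labels five times. Python list repetition builds the same inputs more compactly; `base_smiles`, `base_labels` and `n_repeats` are illustrative names.

```python
# repeat a two-molecule pattern to build the 10-example toy training set
base_smiles = ["C1CCC1", "CCC"]
base_labels = [0., 1.]
n_repeats = 5

train_smiles = base_smiles * n_repeats
train_labels = base_labels * n_repeats
print(len(train_smiles), train_smiles[:4], train_labels[:4])
```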

## Lightning Trainer

The `Trainer` is the wrapper that builds on top of the `LightningDataModule` and `LightningModule`. When constructing the Lightning trainer you can also specify the number of epochs, the maximum number of steps to run, and the number of GPUs and nodes to use for training. The trainer acts as a wrapper over your distributed training setup, so you can build your models in the same simple way you would for your local runs.

In [10]:
trainer = pl.Trainer(
    max_epochs=5,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


**Call the fit function to run model training**

In [11]:
# train
trainer.fit(
    model=gcnmodule,
    datamodule=smiles_datasetmodule,
)


  | Name     | Type | Params
----------------------------------
0 | pt_model | GCN  | 29.4 K
----------------------------------
29.4 K    Trainable params
0         Non-trainable params
29.4 K    Total params
0.118     Total estimated model params size (MB)