Merge pull request #3766 from arunppsg/lightning-updates
Minor fixes to PyTorch Lightning
rbharath committed Jan 6, 2024
2 parents a8687a7 + 0995247 commit eaad47c
Showing 3 changed files with 35 additions and 19 deletions.
25 changes: 19 additions & 6 deletions deepchem/models/lightning/dc_lightning_dataset_module.py
@@ -1,4 +1,6 @@
-import pytorch_lightning as pl
+from typing import Callable
+import deepchem as dc
+import lightning as L
 import torch

@@ -15,7 +17,7 @@ def collate_dataset_wrapper(batch):
     return DCLightningDatasetBatch(batch)


-class DCLightningDatasetModule(pl.LightningDataModule):
+class DCLightningDatasetModule(L.LightningDataModule):
     """DeepChem Lightning Dataset Module to be used with the DCLightningModule and a Lightning trainer.

     This module wraps the deepchem pytorch dataset and dataloader, providing a generic interface to run training.
@@ -26,20 +28,30 @@ class DCLightningDatasetModule(pl.LightningDataModule):
     This class requires PyTorch to be installed.
     """

-    def __init__(self, dataset, batch_size, collate_fn):
+    def __init__(self,
+                 dataset: dc.data.Dataset,
+                 batch_size: int,
+                 collate_fn: Callable = collate_dataset_wrapper,
+                 num_workers: int = 0):
         """Create a new DCLightningDatasetModule.

         Parameters
         ----------
-        dataset: A deepchem dataset.
-        batch_size: Batch size for the dataloader.
-        collate_fn: Method to collate instances across batch.
+        dataset: dc.data.Dataset
+            A deepchem dataset.
+        batch_size: int
+            Batch size for the dataloader.
+        collate_fn: Callable
+            Method to collate instances across batch.
+        num_workers: int
+            Number of workers to load data.
         """
         super().__init__()
         self._batch_size = batch_size
         self._dataset = dataset.make_pytorch_dataset(
             batch_size=self._batch_size)
         self.collate_fn = collate_fn
+        self.num_workers = num_workers

     def setup(self, stage):
         self.train_dataset = self._dataset
@@ -56,5 +68,6 @@ def train_dataloader(self):
             batch_size=None,
             collate_fn=self.collate_fn,
             shuffle=False,
+            num_workers=self.num_workers,
         )
         return dataloader
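For context, a minimal usage sketch of the updated DCLightningDatasetModule follows. The Delaney dataset, ECFP featurizer, batch size, and worker count are illustrative assumptions, not part of this change.

import deepchem as dc
from deepchem.models.lightning.dc_lightning_dataset_module import DCLightningDatasetModule

# Any DeepChem dataset works; Delaney with ECFP fingerprints is used only for illustration.
tasks, (train_dataset, valid_dataset, test_dataset), transformers = dc.molnet.load_delaney(
    featurizer='ECFP')

# collate_fn now defaults to collate_dataset_wrapper, and num_workers is a new option
# that is passed through to the underlying torch DataLoader.
data_module = DCLightningDatasetModule(dataset=train_dataset,
                                       batch_size=32,
                                       num_workers=2)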
27 changes: 15 additions & 12 deletions deepchem/models/lightning/dc_lightning_module.py
@@ -1,8 +1,9 @@
 import torch
-import pytorch_lightning as pl # noqa
+import lightning as L # noqa
+from deepchem.models.torch_models import ModularTorchModel, TorchModel


-class DCLightningModule(pl.LightningModule):
+class DCLightningModule(L.LightningModule):
     """DeepChem Lightning Module to be used with Lightning trainer.

     TODO: Add dataloader, example code and fit, once datasetmodule
@@ -30,6 +31,7 @@ def __init__(self, dc_model):
         self.dc_model = dc_model

         self.pt_model = self.dc_model.model
+        # This will not work for ModularTorchModel as it directly uses `loss_func` to compute the loss.
         self.loss = self.dc_model._loss_fn

     def configure_optimizers(self):
@@ -55,24 +57,25 @@ def training_step(self, batch, batch_idx):
         assert len(inputs) == 1
         inputs = inputs[0]

-        outputs = self.pt_model(inputs)
-
-        if isinstance(outputs, torch.Tensor):
-            outputs = [outputs]
-
-        if self.dc_model._loss_outputs is not None:
-            outputs = [outputs[i] for i in self.dc_model._loss_outputs]
-
-        loss_outputs = self.loss(outputs, labels, weights)
+        if isinstance(self.dc_model, ModularTorchModel):
+            loss = self.dc_model.loss_func(inputs, labels, weights)
+        elif isinstance(self.dc_model, TorchModel):
+            outputs = self.pt_model(inputs)
+            if isinstance(outputs, torch.Tensor):
+                outputs = [outputs]
+
+            if self.dc_model._loss_outputs is not None:
+                outputs = [outputs[i] for i in self.dc_model._loss_outputs]
+            loss = self.loss(outputs, labels, weights)

         self.log(
             "train_loss",
-            loss_outputs,
+            loss,
             on_epoch=True,
             sync_dist=True,
             reduce_fx="mean",
             prog_bar=True,
             batch_size=self.dc_model.batch_size,
         )

-        return loss_outputs
+        return loss
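For context, a rough sketch of how the updated DCLightningModule plugs into a Lightning trainer under the new import. The MultitaskRegressor model, dataset, and trainer settings are assumptions for illustration only.

import lightning as L
import deepchem as dc
from deepchem.models.lightning.dc_lightning_dataset_module import DCLightningDatasetModule
from deepchem.models.lightning.dc_lightning_module import DCLightningModule

tasks, (train_dataset, _, _), _ = dc.molnet.load_delaney(featurizer='ECFP')

# Wrap any TorchModel (or ModularTorchModel, which now takes the loss_func path in training_step).
dc_model = dc.models.MultitaskRegressor(n_tasks=len(tasks),
                                        n_features=1024,
                                        layer_sizes=[128, 64])
lightning_module = DCLightningModule(dc_model)
data_module = DCLightningDatasetModule(train_dataset, batch_size=32)

# L.Trainer replaces pytorch_lightning.Trainer after the import switch.
trainer = L.Trainer(max_epochs=1)
trainer.fit(lightning_module, datamodule=data_module)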
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -8,7 +8,7 @@ sphinx_rtd_theme>=1.0
 tensorflow
 transformers>=4.34.1
 torch==2.1.0 --extra-index-url https://download.pytorch.org/whl/cpu
-pytorch-lightning>=2.1.2
+lightning
 jax
 dm-haiku
 optax
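With docs/requirements.txt now pulling in the unified lightning distribution instead of the standalone pytorch-lightning package, the docs environment would be set up roughly as sketched below; the unpinned version is an assumption, so pin it if reproducibility matters.

pip install lightning
python -c "import lightning as L; print(L.__version__)"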
