-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement DCLightningModule #2945
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
b386895
implement DCLightningModule
Chahalprincy d4d0082
Address review comments
Chahalprincy 37cd7e0
Move imports under try-except
Chahalprincy f202b51
Move helper classes inside test function
Chahalprincy 2806bd3
update pytorch lightning dependency
Chahalprincy f874cbf
update transformers dependency version, corresponding changes
Chahalprincy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
try: | ||
import torch | ||
import pytorch_lightning as pl # noqa | ||
PYTORCH_LIGHTNING_IMPORT_FAILED = False | ||
except ImportError: | ||
PYTORCH_LIGHTNING_IMPORT_FAILED = True | ||
|
||
|
||
class DCLightningModule(pl.LightningModule): | ||
"""DeepChem Lightning Module to be used with Lightning trainer. | ||
|
||
Example code | ||
>>> import deepchem as dc | ||
>>> from deepchem.models import MultitaskClassifier | ||
>>> import numpy as np | ||
>>> import torch | ||
>>> from torch.utils.data import DataLoader | ||
>>> from deepchem.models.lightning.dc_lightning_module | ||
... import DCLightningModule | ||
>>> model = MultitaskClassifier(params) | ||
>>> valid_dataloader = DataLoader(test_dataset) | ||
>>> lightning_module = DCLightningModule(model) | ||
>>> trainer = pl.Trainer(max_epochs=1) | ||
>>> trainer.fit(lightning_module, valid_dataloader) | ||
|
||
The lightning module is a wrapper over deepchem's torch model. | ||
This module directly works with pytorch lightning trainer | ||
which runs training for multiple epochs and also is responsible | ||
for setting up and training models on multiple GPUs. | ||
""" | ||
|
||
def __init__(self, dc_model): | ||
"""Create a new DCLightningModule | ||
|
||
Parameters | ||
---------- | ||
dc_model: deepchem.models.torch_models.torch_model.TorchModel | ||
TorchModel to be wrapped inside the lightning module. | ||
""" | ||
super().__init__() | ||
self.dc_model = dc_model | ||
|
||
self.pt_model = self.dc_model.model | ||
self.loss = self.dc_model._loss_fn | ||
|
||
def configure_optimizers(self): | ||
"""Configure optimizers, for details refer to: | ||
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html?highlight=LightningModule | ||
""" | ||
return self.dc_model.optimizer._create_pytorch_optimizer( | ||
self.pt_model.parameters(),) | ||
|
||
def training_step(self, batch, batch_idx): | ||
"""For details refer to: | ||
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html?highlight=LightningModule # noqa | ||
|
||
Args: | ||
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): | ||
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. | ||
batch_idx (``int``): Integer displaying index of this batch | ||
optimizer_idx (``int``): When using multiple optimizers, this argument will also be present. | ||
hiddens (``Any``): Passed in if | ||
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0. | ||
|
||
Return: | ||
Any of. | ||
|
||
- :class:`~torch.Tensor` - The loss tensor | ||
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'`` | ||
- ``None`` - Training will skip to the next batch. This is only for automatic optimization. | ||
This is not supported for multi-GPU, TPU, IPU, or DeepSpeed. | ||
""" | ||
batch = batch.batch_list | ||
inputs, labels, weights = self.dc_model._prepare_batch(batch) | ||
|
||
outputs = self.pt_model(inputs[0]) | ||
|
||
if isinstance(outputs, torch.Tensor): | ||
outputs = [outputs] | ||
|
||
if self.dc_model._loss_outputs is not None: | ||
outputs = [outputs[i] for i in self.dc_model._loss_outputs] | ||
|
||
loss_outputs = self.loss(outputs, labels, weights) | ||
|
||
self.log( | ||
"train_loss", | ||
loss_outputs, | ||
on_epoch=True, | ||
sync_dist=True, | ||
reduce_fx="mean", | ||
prog_bar=True, | ||
) | ||
|
||
return loss_outputs |
65 changes: 65 additions & 0 deletions
65
deepchem/models/lightning/tests/test_dc_lightning_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import unittest | ||
import deepchem as dc | ||
import numpy as np | ||
|
||
try: | ||
from deepchem.models import MultitaskClassifier | ||
import torch | ||
from torch.utils.data import DataLoader | ||
from deepchem.models.lightning.dc_lightning_module import DCLightningModule | ||
import pytorch_lightning as pl # noqa | ||
PYTORCH_LIGHTNING_IMPORT_FAILED = False | ||
except ImportError: | ||
PYTORCH_LIGHTNING_IMPORT_FAILED = True | ||
|
||
|
||
class TestDCLightningModule(unittest.TestCase): | ||
|
||
@unittest.skipIf(PYTORCH_LIGHTNING_IMPORT_FAILED, | ||
'PyTorch Lightning is not installed') | ||
def test_multitask_classifier(self): | ||
|
||
class TestDatasetBatch: | ||
|
||
def __init__(self, batch): | ||
X = [np.array([b[0] for b in batch])] | ||
y = [np.array([b[1] for b in batch])] | ||
w = [np.array([b[2] for b in batch])] | ||
self.batch_list = [X, y, w] | ||
|
||
def collate_dataset_wrapper(batch): | ||
return TestDatasetBatch(batch) | ||
|
||
class TestDataset(torch.utils.data.Dataset): | ||
|
||
def __init__(self, dataset): | ||
self._samples = dataset | ||
|
||
def __len__(self): | ||
return len(self._samples) | ||
|
||
def __getitem__(self, index): | ||
y = np.zeros((1, 2)) | ||
y[0, int(self._samples.y[index][0])] = 1.0 | ||
return ( | ||
self._samples.X[index], | ||
y, | ||
self._samples.w[index], | ||
) | ||
|
||
tasks, datasets, _ = dc.molnet.load_clintox() | ||
_, valid_dataset, _ = datasets | ||
|
||
model = MultitaskClassifier(n_tasks=len(tasks), | ||
n_features=1024, | ||
layer_sizes=[1000], | ||
dropouts=0.2, | ||
learning_rate=0.0001) | ||
|
||
valid_dataloader = DataLoader(TestDataset(valid_dataset), | ||
batch_size=64, | ||
collate_fn=collate_dataset_wrapper) | ||
|
||
lightning_module = DCLightningModule(model) | ||
trainer = pl.Trainer(max_epochs=1) | ||
trainer.fit(lightning_module, valid_dataloader) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The formatting here and elsewhere in the docstring has issues. Are you planning to fix in #2958 @Chahalprincy? This new class also needs to be added to the docs/ folder to render on readthedocs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I will do it in another PR. Sorry, I missed it.