/
dc_lightning_module.py
81 lines (66 loc) · 2.72 KB
/
dc_lightning_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import lightning as L # noqa
from deepchem.models.torch_models import ModularTorchModel, TorchModel
class DCLightningModule(L.LightningModule):
    """DeepChem Lightning Module to be used with Lightning trainer.

    TODO: Add dataloader, example code and fit, once datasetmodule
    is ready

    The lightning module is a wrapper over deepchem's torch model.
    This module directly works with pytorch lightning trainer
    which runs training for multiple epochs and also is responsible
    for setting up and training models on multiple GPUs.
    https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html?highlight=LightningModule

    Notes
    -----
    This class requires PyTorch to be installed.
    """

    def __init__(self, dc_model):
        """Create a new DCLightningModule.

        Parameters
        ----------
        dc_model: deepchem.models.torch_models.torch_model.TorchModel
            TorchModel to be wrapped inside the lightning module.
        """
        super().__init__()
        self.dc_model = dc_model
        # The underlying nn.Module whose parameters Lightning will optimize.
        self.pt_model = self.dc_model.model
        # NOTE: `_loss_fn` is only used for plain TorchModel instances.
        # ModularTorchModel computes its loss directly via `loss_func`
        # inside `training_step`, so `self.loss` is bypassed there.
        self.loss = self.dc_model._loss_fn

    def configure_optimizers(self):
        """Build the torch optimizer from the wrapped DeepChem optimizer.

        Returns
        -------
        A torch optimizer created by DeepChem's optimizer wrapper over
        the parameters of the wrapped pytorch model.
        """
        return self.dc_model.optimizer._create_pytorch_optimizer(
            self.pt_model.parameters(),)

    def training_step(self, batch, batch_idx):
        """Perform a training step.

        Parameters
        ----------
        batch: A tensor, tuple or list.
            Expected to expose a `batch_list` attribute holding the
            DeepChem-style batch to be prepared by the wrapped model.
        batch_idx: Integer displaying index of this batch

        Returns
        -------
        loss_outputs: outputs of losses.

        Raises
        ------
        TypeError
            If the wrapped model is neither a ``ModularTorchModel`` nor a
            ``TorchModel``.
        """
        batch = batch.batch_list
        inputs, labels, weights = self.dc_model._prepare_batch(batch)

        # _prepare_batch may wrap a single input tensor in a list; unwrap it.
        if isinstance(inputs, list):
            assert len(inputs) == 1
            inputs = inputs[0]

        # ModularTorchModel must be checked before TorchModel: it is a
        # subclass, and it computes loss through `loss_func` rather than
        # through the `_loss_fn` captured in __init__.
        if isinstance(self.dc_model, ModularTorchModel):
            loss = self.dc_model.loss_func(inputs, labels, weights)
        elif isinstance(self.dc_model, TorchModel):
            outputs = self.pt_model(inputs)
            if isinstance(outputs, torch.Tensor):
                outputs = [outputs]
            # Select only the outputs that participate in the loss, if the
            # model declares such a subset.
            if self.dc_model._loss_outputs is not None:
                outputs = [outputs[i] for i in self.dc_model._loss_outputs]
            loss = self.loss(outputs, labels, weights)
        else:
            # Fail loudly instead of hitting an UnboundLocalError on `loss`
            # in the logging call below.
            raise TypeError(
                f"Unsupported model type: {type(self.dc_model).__name__}; "
                "expected TorchModel or ModularTorchModel.")

        self.log(
            "train_loss",
            loss,
            on_epoch=True,
            sync_dist=True,
            reduce_fx="mean",
            prog_bar=True,
            batch_size=self.dc_model.batch_size,
        )
        return loss