implement DCLightningModule #2945

Merged (6 commits, Jul 1, 2022)
18 changes: 8 additions & 10 deletions deepchem/feat/graph_data.py
@@ -108,7 +108,7 @@ def __repr__(self) -> str:
if self.edge_features is not None:
edge_features_str = str(list(self.edge_features.shape))
else:
edge_features_str = None
edge_features_str = "None"

out = "%s(node_features=%s, edge_index=%s, edge_features=%s" % (
cls, node_features_str, edge_index_str, edge_features_str)
@@ -146,12 +146,11 @@ def to_pyg_graph(self):
kwargs = {}
for key, value in self.kwargs.items():
kwargs[key] = torch.from_numpy(value).float()
return Data(
x=torch.from_numpy(self.node_features).float(),
edge_index=torch.from_numpy(self.edge_index).long(),
edge_attr=edge_features,
pos=node_pos_features,
**kwargs)
return Data(x=torch.from_numpy(self.node_features).float(),
edge_index=torch.from_numpy(self.edge_index).long(),
edge_attr=edge_features,
pos=node_pos_features,
**kwargs)

def to_dgl_graph(self, self_loop: bool = False):
"""Convert to DGL graph data instance
@@ -177,9 +176,8 @@ def to_dgl_graph(self, self_loop: bool = False):
src = self.edge_index[0]
dst = self.edge_index[1]

g = dgl.graph(
(torch.from_numpy(src).long(), torch.from_numpy(dst).long()),
num_nodes=self.num_nodes)
g = dgl.graph((torch.from_numpy(src).long(), torch.from_numpy(dst).long()),
num_nodes=self.num_nodes)
g.ndata['x'] = torch.from_numpy(self.node_features).float()

if self.node_pos_features is not None:
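
The graph_data.py changes above are yapf-style call reformatting plus a small __repr__ fix so that a missing edge_features renders as the string "None". A minimal usage sketch of the touched conversion methods (toy shapes made up here; assumes torch_geometric and dgl are installed):

import numpy as np
from deepchem.feat.graph_data import GraphData

# A toy 3-node, 2-edge graph; the feature sizes below are arbitrary.
node_features = np.random.rand(3, 4).astype(np.float32)
edge_index = np.array([[0, 1], [1, 2]], dtype=np.int64)  # shape (2, num_edges)
edge_features = np.random.rand(2, 5).astype(np.float32)

g = GraphData(node_features=node_features,
              edge_index=edge_index,
              edge_features=edge_features)
print(repr(g))                 # edge_features shown as its shape, or "None" when absent
pyg_graph = g.to_pyg_graph()   # torch_geometric Data object
dgl_graph = g.to_dgl_graph()   # DGLGraph; pass self_loop=True to add self loops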
49 changes: 27 additions & 22 deletions deepchem/feat/smiles_tokenizer.py
@@ -6,7 +6,7 @@
import os
import re
import pkg_resources
from typing import List
from typing import List, Optional
from transformers import BertTokenizer
from logging import getLogger

@@ -84,20 +84,17 @@ def __init__(
"""

super().__init__(vocab_file, **kwargs)
# take into account special tokens in max length
self.max_len_single_sentence = self.model_max_length - 2
self.max_len_sentences_pair = self.model_max_length - 3

if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocab file at path '{}'.".format(vocab_file))
self.vocab = load_vocab(vocab_file)
self.highest_unused_index = max(
[i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")])
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.ids_to_tokens = collections.OrderedDict([
(ids, tok) for tok, ids in self.vocab.items()
])
self.basic_tokenizer = BasicSmilesTokenizer()
self.init_kwargs["model_max_length"] = self.model_max_length

@property
def vocab_size(self):
@@ -107,7 +104,7 @@ def vocab_size(self):
def vocab_list(self):
return list(self.vocab.keys())

def _tokenize(self, text: str):
def _tokenize(self, text: str, max_seq_length: int = 512, **kwargs):
"""Tokenize a string into a list of tokens.

Parameters
@@ -116,7 +113,11 @@ def _tokenize(self, text: str):
Input string sequence to be tokenized.
"""

split_tokens = [token for token in self.basic_tokenizer.tokenize(text)]
max_len_single_sentence = max_seq_length - 2
split_tokens = [
token for token in self.basic_tokenizer.tokenize(text)
[:max_len_single_sentence]
]
return split_tokens

def _convert_token_to_id(self, token: str):
@@ -158,7 +159,8 @@ def convert_tokens_to_string(self, tokens: List[str]):
out_string: str = " ".join(tokens).replace(" ##", "").strip()
return out_string

def add_special_tokens_ids_single_sequence(self, token_ids: List[int]):
def add_special_tokens_ids_single_sequence(self,
token_ids: List[Optional[int]]):
"""Adds special tokens to the a sequence for sequence classification tasks.

A BERT sequence has the following format: [CLS] X [SEP]
@@ -182,8 +184,9 @@ def add_special_tokens_single_sequence(self, tokens: List[str]):
"""
return [self.cls_token] + tokens + [self.sep_token]

def add_special_tokens_ids_sequence_pair(self, token_ids_0: List[int],
token_ids_1: List[int]) -> List[int]:
def add_special_tokens_ids_sequence_pair(
self, token_ids_0: List[Optional[int]],
token_ids_1: List[Optional[int]]) -> List[Optional[int]]:
"""Adds special tokens to a sequence pair for sequence classification tasks.
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]

@@ -201,15 +204,15 @@ def add_special_tokens_ids_sequence_pair(self, token_ids_0: List[int],
return cls + token_ids_0 + sep + token_ids_1 + sep

def add_padding_tokens(self,
token_ids: List[int],
token_ids: List[Optional[int]],
length: int,
right: bool = True) -> List[int]:
right: bool = True) -> List[Optional[int]]:
"""Adds padding tokens to return a sequence of length max_length.
By default padding tokens are added to the right of the sequence.

Parameters
----------
token_ids: list[int]
token_ids: list[optional[int]]
list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
length: int
TODO
@@ -229,8 +232,10 @@ def add_padding_tokens(self,
return padding + token_ids

def save_vocabulary(
self, vocab_path: str
): # -> tuple[str]: doctest issue raised with this return type annotation
self,
save_directory: str,
filename_prefix: Optional[str] = None
): # -> Tuple[str]: doctest issue raised with this return type annotation
"""Save the tokenizer vocabulary to a file.

Parameters
@@ -247,13 +252,13 @@ def save_vocabulary(
Default vocab file is found in deepchem/feat/tests/data/vocab.txt
"""
index = 0
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
if os.path.isdir(save_directory):
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
else:
vocab_file = vocab_path
vocab_file = save_directory
with open(vocab_file, "w", encoding="utf-8") as writer:
for token, token_index in sorted(
self.vocab.items(), key=lambda kv: kv[1]):
for token, token_index in sorted(self.vocab.items(),
key=lambda kv: kv[1]):
if index != token_index:
logger.warning(
"Saving vocabulary to {}: vocabulary indices are not consecutive."
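
The smiles_tokenizer.py changes truncate over-long inputs in _tokenize to max_seq_length - 2 tokens (leaving room for [CLS] and [SEP]) and move save_vocabulary to the (save_directory, filename_prefix) signature expected by transformers 4.10. A rough usage sketch, assuming the sample vocab file shipped at deepchem/feat/tests/data/vocab.txt is available:

from deepchem.feat.smiles_tokenizer import SmilesTokenizer

vocab_path = "deepchem/feat/tests/data/vocab.txt"  # sample vocab mentioned in the docstring above
tokenizer = SmilesTokenizer(vocab_path)

smiles = "CC(=O)Oc1ccccc1C(=O)O"
tokens = tokenizer.tokenize(smiles)                    # regex-based SMILES tokens, truncated if over-long
ids = tokenizer.encode(smiles)                         # adds [CLS]/[SEP] ids around the tokens
padded = tokenizer.add_padding_tokens(ids, length=32)  # right-pads with the pad token id

# save_vocabulary now mirrors the transformers>=4.10 signature.
tokenizer.save_vocabulary(save_directory=".")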
Empty file.
95 changes: 95 additions & 0 deletions deepchem/models/lightning/dc_lightning_module.py
@@ -0,0 +1,95 @@
try:
import torch
import pytorch_lightning as pl # noqa
PYTORCH_LIGHTNING_IMPORT_FAILED = False
except ImportError:
PYTORCH_LIGHTNING_IMPORT_FAILED = True


class DCLightningModule(pl.LightningModule):
"""DeepChem Lightning Module to be used with Lightning trainer.

Example code
Review comment (Member): The formatting here and elsewhere in the docstring has issues. Are you planning to fix in #2958 @Chahalprincy? This new class also needs to be added to the docs/ folder to render on readthedocs.

Reply (Contributor Author): Yes, I will do it in another PR. Sorry, I missed it.

>>> import deepchem as dc
>>> from deepchem.models import MultitaskClassifier
>>> import numpy as np
>>> import torch
>>> from torch.utils.data import DataLoader
>>> from deepchem.models.lightning.dc_lightning_module import DCLightningModule
>>> model = MultitaskClassifier(params)
>>> valid_dataloader = DataLoader(test_dataset)
>>> lightning_module = DCLightningModule(model)
>>> trainer = pl.Trainer(max_epochs=1)
>>> trainer.fit(lightning_module, valid_dataloader)

The lightning module is a wrapper over deepchem's torch model.
This module directly works with pytorch lightning trainer
which runs training for multiple epochs and also is responsible
for setting up and training models on multiple GPUs.
"""

def __init__(self, dc_model):
"""Create a new DCLightningModule

Parameters
----------
dc_model: deepchem.models.torch_models.torch_model.TorchModel
TorchModel to be wrapped inside the lightning module.
"""
super().__init__()
self.dc_model = dc_model

self.pt_model = self.dc_model.model
self.loss = self.dc_model._loss_fn

def configure_optimizers(self):
"""Configure optimizers, for details refer to:
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html?highlight=LightningModule
"""
return self.dc_model.optimizer._create_pytorch_optimizer(
self.pt_model.parameters(),)

def training_step(self, batch, batch_idx):
"""For details refer to:
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.core.LightningModule.html?highlight=LightningModule # noqa

Args:
batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
batch_idx (``int``): Integer displaying index of this batch
optimizer_idx (``int``): When using multiple optimizers, this argument will also be present.
hiddens (``Any``): Passed in if
:paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` > 0.

Return:
Any of:

- :class:`~torch.Tensor` - The loss tensor
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``
- ``None`` - Training will skip to the next batch. This is only for automatic optimization.
This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
"""
batch = batch.batch_list
inputs, labels, weights = self.dc_model._prepare_batch(batch)

outputs = self.pt_model(inputs[0])

if isinstance(outputs, torch.Tensor):
outputs = [outputs]

if self.dc_model._loss_outputs is not None:
outputs = [outputs[i] for i in self.dc_model._loss_outputs]

loss_outputs = self.loss(outputs, labels, weights)

self.log(
"train_loss",
loss_outputs,
on_epoch=True,
sync_dist=True,
reduce_fx="mean",
prog_bar=True,
)

return loss_outputs
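
Since training_step feeds batch.batch_list into dc_model._prepare_batch, any DataLoader used with this module has to yield objects exposing a batch_list of [inputs, labels, weights]; the test added below builds exactly such a wrapper. A hedged sketch of scaling the wrapper to multiple GPUs, assuming a Lightning version that accepts the accelerator/devices/strategy arguments and that model and train_dataloader are already set up as in the docstring example:

import pytorch_lightning as pl
from deepchem.models.lightning.dc_lightning_module import DCLightningModule

# `model` is any DeepChem TorchModel and `train_dataloader` yields batches
# exposing a `.batch_list` attribute (both assumed to be defined already).
lightning_module = DCLightningModule(model)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",   # or "cpu" if no GPU is available
    devices=2,           # Lightning sets up the DDP process group
    strategy="ddp",
)
trainer.fit(lightning_module, train_dataloader)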
65 changes: 65 additions & 0 deletions deepchem/models/lightning/tests/test_dc_lightning_module.py
@@ -0,0 +1,65 @@
import unittest
import deepchem as dc
import numpy as np

try:
from deepchem.models import MultitaskClassifier
import torch
from torch.utils.data import DataLoader
from deepchem.models.lightning.dc_lightning_module import DCLightningModule
import pytorch_lightning as pl # noqa
PYTORCH_LIGHTNING_IMPORT_FAILED = False
except ImportError:
PYTORCH_LIGHTNING_IMPORT_FAILED = True


class TestDCLightningModule(unittest.TestCase):

@unittest.skipIf(PYTORCH_LIGHTNING_IMPORT_FAILED,
'PyTorch Lightning is not installed')
def test_multitask_classifier(self):

class TestDatasetBatch:

def __init__(self, batch):
X = [np.array([b[0] for b in batch])]
y = [np.array([b[1] for b in batch])]
w = [np.array([b[2] for b in batch])]
self.batch_list = [X, y, w]

def collate_dataset_wrapper(batch):
return TestDatasetBatch(batch)

class TestDataset(torch.utils.data.Dataset):

def __init__(self, dataset):
self._samples = dataset

def __len__(self):
return len(self._samples)

def __getitem__(self, index):
y = np.zeros((1, 2))
y[0, int(self._samples.y[index][0])] = 1.0
return (
self._samples.X[index],
y,
self._samples.w[index],
)

tasks, datasets, _ = dc.molnet.load_clintox()
_, valid_dataset, _ = datasets

model = MultitaskClassifier(n_tasks=len(tasks),
n_features=1024,
layer_sizes=[1000],
dropouts=0.2,
learning_rate=0.0001)

valid_dataloader = DataLoader(TestDataset(valid_dataset),
batch_size=64,
collate_fn=collate_dataset_wrapper)

lightning_module = DCLightningModule(model)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(lightning_module, valid_dataloader)
2 changes: 1 addition & 1 deletion deepchem/models/torch_models/layers.py
@@ -838,7 +838,7 @@ def forward(self,

if self.is_undirected is True:
# converting edge features to their original shape
split = torch.split(edge_features, (edge_features_len, edge_features_len))
split = torch.split(edge_features, [edge_features_len, edge_features_len])
edge_features = (split[0] + split[1]) / 2

if self.residual_connection:
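
The torch.split change passes the section sizes as a list instead of a tuple; split_size_or_sections is documented as an int or a list of ints, and the tuple form can trip type checks on some torch versions. A small standalone illustration of the split-and-average step, with made-up shapes:

import torch

edge_features_len = 4                              # hypothetical number of edges per direction
edge_features = torch.arange(16.0).reshape(8, 2)   # doubled (undirected) edge features: 8 rows = 2 * 4

# Split the stacked forward/backward copies and average them back to shape (4, 2).
split = torch.split(edge_features, [edge_features_len, edge_features_len])
edge_features = (split[0] + split[1]) / 2
print(edge_features.shape)                         # torch.Size([4, 2])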
2 changes: 1 addition & 1 deletion requirements/env_common.yml
@@ -22,7 +22,7 @@ dependencies:
- pyGPGO
- pymatgen
- simdna
- transformers==4.6.*
- transformers==4.10.*
- xgboost
- git+https://github.com/samoturk/mol2vec
- tb-nightly # @arunppsg shift to tensorboard stable version once tb 2.9.1 is released