Skip to content

Commit

Permalink
Global control of graph caching
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Sep 10, 2020
1 parent a321542 commit 93e6420
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 31 deletions.
3 changes: 2 additions & 1 deletion chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,12 @@ class TrainArgs(CommonArgs):
"""The number of batches between each logging of the training loss."""
show_individual_scores: bool = False
"""Show all scores for individual targets, not just average, at the end."""
cache_cutoff: int = 10000
cache_cutoff: float = 10000
"""
Maximum number of molecules in dataset to allow caching.
Below this number, caching is used and data loading is sequential.
Above this number, caching is not used and data loading is parallel.
Use "inf" to always cache.
"""
save_preds: bool = False
"""Whether to save test split predictions during training."""
Expand Down
8 changes: 4 additions & 4 deletions chemprop/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .data import (
get_cache_graph,
get_cache_mol,
cache_graph,
cache_mol,
MoleculeDatapoint,
MoleculeDataset,
MoleculeDataLoader,
Expand All @@ -24,8 +24,8 @@
)

__all__ = [
'get_cache_graph',
'get_cache_mol',
'cache_graph',
'cache_mol',
'MoleculeDatapoint',
'MoleculeDataset',
'MoleculeDataLoader',
Expand Down
26 changes: 8 additions & 18 deletions chemprop/data/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import threading
from collections import OrderedDict
from functools import partial
from random import Random
from typing import Dict, Iterator, List, Optional, Union

Expand All @@ -18,7 +17,7 @@
SMILES_TO_GRAPH: Dict[str, MolGraph] = {}


def get_cache_graph() -> bool:
def cache_graph() -> bool:
r"""Returns whether :class:`~chemprop.features.MolGraph`\ s will be cached."""
return CACHE_GRAPH

Expand All @@ -34,7 +33,7 @@ def set_cache_graph(cache_graph: bool) -> None:
SMILES_TO_MOL: Dict[str, Chem.Mol] = {}


def get_cache_mol() -> bool:
def cache_mol() -> bool:
r"""Returns whether RDKit molecules will be cached."""
return CACHE_MOL

Expand Down Expand Up @@ -94,8 +93,7 @@ def mol(self) -> Chem.Mol:
"""Gets the corresponding RDKit molecule for this molecule's SMILES."""
mol = SMILES_TO_MOL.get(self.smiles, Chem.MolFromSmiles(self.smiles))

if CACHE_MOL:
print('cache')
if cache_mol():
SMILES_TO_MOL[self.smiles] = mol

return mol
Expand Down Expand Up @@ -157,7 +155,7 @@ def mols(self) -> List[Chem.Mol]:
"""
return [d.mol for d in self._data]

def batch_graph(self, cache: bool = False) -> BatchMolGraph:
def batch_graph(self) -> BatchMolGraph:
r"""
Constructs a :class:`~chemprop.features.BatchMolGraph` with the graph featurization of all the molecules.
Expand All @@ -167,8 +165,6 @@ def batch_graph(self, cache: bool = False) -> BatchMolGraph:
set of :class:`MoleculeDatapoint`\ s changes, then the returned :class:`~chemprop.features.BatchMolGraph`
will be incorrect for the underlying data.
:param cache: Whether to store the individual :class:`~chemprop.features.MolGraph` featurizations
for each molecule in a global cache.
:return: A :class:`~chemprop.features.BatchMolGraph` containing the graph featurization of all the molecules.
"""
if self._batch_graph is None:
Expand All @@ -178,7 +174,7 @@ def batch_graph(self, cache: bool = False) -> BatchMolGraph:
mol_graph = SMILES_TO_GRAPH[d.smiles]
else:
mol_graph = MolGraph(d.mol)
if cache:
if cache_graph():
SMILES_TO_GRAPH[d.smiles] = mol_graph
mol_graphs.append(mol_graph)

Expand Down Expand Up @@ -366,20 +362,18 @@ def __len__(self) -> int:
return self.length


def construct_molecule_batch(data: List[MoleculeDatapoint], cache: bool = False) -> MoleculeDataset:
def construct_molecule_batch(data: List[MoleculeDatapoint]) -> MoleculeDataset:
r"""
Constructs a :class:`MoleculeDataset` from a list of :class:`MoleculeDatapoint`\ s.
Additionally, precomputes the :class:`~chemprop.features.BatchMolGraph` for the constructed
:class:`MoleculeDataset`.
:param data: A list of :class:`MoleculeDatapoint`\ s.
:param cache: Whether to store the individual :class:`~chemprop.features.MolGraph` featurizations
for each molecule in a global cache.
:return: A :class:`MoleculeDataset` containing all the :class:`MoleculeDatapoint`\ s.
"""
data = MoleculeDataset(data)
data.batch_graph(cache=cache) # Forces computation and caching of the BatchMolGraph for the molecules
data.batch_graph() # Forces computation and caching of the BatchMolGraph for the molecules

return data

Expand All @@ -391,16 +385,13 @@ def __init__(self,
dataset: MoleculeDataset,
batch_size: int = 50,
num_workers: int = 8,
cache: bool = False,
class_balance: bool = False,
shuffle: bool = False,
seed: int = 0):
"""
:param dataset: The :class:`MoleculeDataset` containing the molecules to load.
:param batch_size: Batch size.
:param num_workers: Number of workers used to build batches.
:param cache: Whether to store the individual :class:`~chemprop.features.MolGraph` featurizations
for each molecule in a global cache.
:param class_balance: Whether to perform class balancing (i.e., use an equal number of positive
and negative molecules). Class balance is only available for single task
classification datasets. Set shuffle to True in order to get a random
Expand All @@ -411,7 +402,6 @@ def __init__(self,
self._dataset = dataset
self._batch_size = batch_size
self._num_workers = num_workers
self._cache = cache
self._class_balance = class_balance
self._shuffle = shuffle
self._seed = seed
Expand All @@ -434,7 +424,7 @@ def __init__(self,
batch_size=self._batch_size,
sampler=self._sampler,
num_workers=self._num_workers,
collate_fn=partial(construct_molecule_batch, cache=self._cache),
collate_fn=construct_molecule_batch,
multiprocessing_context=self._context,
timeout=self._timeout
)
Expand Down
13 changes: 5 additions & 8 deletions chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .train import train
from chemprop.args import TrainArgs
from chemprop.constants import MODEL_FILE_NAME
from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, split_data, StandardScaler
from chemprop.data import get_class_sizes, get_data, MoleculeDataLoader, MoleculeDataset, set_cache_graph, split_data
from chemprop.models import MoleculeModel
from chemprop.nn_utils import param_count
from chemprop.utils import build_optimizer, build_lr_scheduler, get_loss_func, load_checkpoint,makedirs, \
Expand Down Expand Up @@ -106,33 +106,30 @@ def run_training(args: TrainArgs,

# Automatically determine whether to cache
if len(data) <= args.cache_cutoff:
cache = True
set_cache_graph(True)
num_workers = 0
else:
cache = False
set_cache_graph(False)
num_workers = args.num_workers

# Create data loaders
train_data_loader = MoleculeDataLoader(
dataset=train_data,
batch_size=args.batch_size,
num_workers=num_workers,
cache=cache,
class_balance=args.class_balance,
shuffle=True,
seed=args.seed
)
val_data_loader = MoleculeDataLoader(
dataset=val_data,
batch_size=args.batch_size,
num_workers=num_workers,
cache=cache
num_workers=num_workers
)
test_data_loader = MoleculeDataLoader(
dataset=test_data,
batch_size=args.batch_size,
num_workers=num_workers,
cache=cache
num_workers=num_workers
)

if args.class_balance:
Expand Down

0 comments on commit 93e6420

Please sign in to comment.