Skip to content

Commit

Permalink
Refactored molnet loader
Browse files Browse the repository at this point in the history
  • Loading branch information
peastman committed Oct 14, 2020
1 parent 47006c5 commit 407db0e
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 96 deletions.
2 changes: 2 additions & 0 deletions deepchem/molnet/__init__.py
Expand Up @@ -37,6 +37,8 @@
from deepchem.molnet.load_function.material_datasets.load_mp_formation_energy import load_mp_formation_energy
from deepchem.molnet.load_function.material_datasets.load_mp_metallicity import load_mp_metallicity

from deepchem.molnet.load_function.molnet_loader import featurizers, splitters, _MolnetLoader

from deepchem.molnet.dnasim import simulate_motif_density_localization
from deepchem.molnet.dnasim import simulate_motif_counting
from deepchem.molnet.dnasim import simple_motif_embedding
Expand Down
17 changes: 0 additions & 17 deletions deepchem/molnet/defaults.py
Expand Up @@ -16,23 +16,6 @@

logger = logging.getLogger(__name__)

featurizers = {
'ecfp': dc.feat.CircularFingerprint(size=1024),
'graphconv': dc.feat.ConvMolFeaturizer(),
'weave': dc.feat.WeaveFeaturizer(),
'raw': dc.feat.RawFeaturizer(),
'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
}

splitters = {
'index': dc.splits.IndexSplitter(),
'random': dc.splits.RandomSplitter(),
'scaffold': dc.splits.ScaffoldSplitter(),
'butina': dc.splits.ButinaSplitter(),
'task': dc.splits.TaskSplitter(),
'stratified': dc.splits.RandomStratifiedSplitter()
}


def get_defaults(module_name: str = None) -> Dict[str, Any]:
"""Get featurizers, transformers, and splitters.
Expand Down
113 changes: 34 additions & 79 deletions deepchem/molnet/load_function/delaney_datasets.py
Expand Up @@ -4,13 +4,32 @@
import os
import logging
import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from deepchem.molnet.load_function.molnet_loader import _MolnetLoader
from deepchem.data import Dataset
from typing import List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

DEFAULT_DIR = dc.utils.data_utils.get_data_dir()
DELANEY_URL = "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/delaney-processed.csv"
DELANEY_TASKS = ['measured log solubility in mols per litre']


class _DelaneyLoader(_MolnetLoader):
  """Loader for the Delaney (ESOL) aqueous solubility dataset."""

  def create_dataset(self) -> Dataset:
    """Download the Delaney CSV file (if not already present) and featurize it."""
    logger.info("About to featurize Delaney dataset.")
    csv_path = os.path.join(self.data_dir, "delaney-processed.csv")
    if not os.path.exists(csv_path):
      dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=self.data_dir)
    csv_loader = dc.data.CSVLoader(
        tasks=DELANEY_TASKS, feature_field="smiles", featurizer=self.featurizer)
    return csv_loader.create_dataset(csv_path, shard_size=8192)

  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
    """Create the y-normalization transformer, fit on the given dataset."""
    normalizer = dc.trans.NormalizationTransformer(
        transform_y=True, dataset=dataset, move_mean=self.args['move_mean'])
    return [normalizer]


def load_delaney(
Expand All @@ -22,9 +41,9 @@ def load_delaney(
save_dir: Optional[str] = None,
**kwargs
) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
"""Load delaney dataset
"""Load Delaney dataset
The Delaney(ESOL) dataset a regression dataset containing structures and
The Delaney (ESOL) dataset is a regression dataset containing structures and
water solubility data for 1128 compounds. The dataset is widely used to
validate machine learning models on estimating solubility directly from
molecular structures (as encoded in SMILES strings).
Expand All @@ -42,11 +61,11 @@ def load_delaney(
----------
featurizer: Featurizer or str
the featurizer to use for processing the data. Alternatively you can pass
one of the names from dc.molnet.defaults.featurizers as a shortcut.
one of the names from dc.molnet.featurizers as a shortcut.
splitter: Splitter or str
the splitter to use for splitting the data into training, validation, and
test sets. Alternatively you can pass one of the names from
dc.molnet.defaults.splitters as a shortcut. If this is None, all the data
dc.molnet.splitters as a shortcut. If this is None, all the data
will be included in a single dataset.
reload: bool
if True, the first call for a particular featurizer and splitter will cache
Expand All @@ -64,76 +83,12 @@ def load_delaney(
molecular structure." Journal of chemical information and computer
sciences 44.3 (2004): 1000-1005.
"""
if 'split' in kwargs:
splitter = kwargs['split']
logger.warning("'split' is deprecated. Use 'splitter' instead.")
if isinstance(featurizer, str):
featurizer = dc.molnet.defaults.featurizers[featurizer.lower()]
if isinstance(splitter, str):
splitter = dc.molnet.defaults.splitters[splitter.lower()]
if data_dir is None:
data_dir = DEFAULT_DIR
if save_dir is None:
save_dir = DEFAULT_DIR
tasks = ['measured log solubility in mols per litre']

# Try to reload cached datasets.

if reload:
featurizer_name = str(featurizer)
splitter_name = str(splitter)
if not move_mean:
featurizer_name = featurizer_name + "_mean_unmoved"
save_folder = os.path.join(save_dir, "delaney-featurized", featurizer_name,
splitter_name)
if splitter is None:
if os.path.exists(save_folder):
transformers = dc.utils.data_utils.load_transformers(save_folder)
return tasks, (DiskDataset(save_folder),), transformers
else:
loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
save_folder)
if all_dataset is not None:
return tasks, all_dataset, transformers

# Featurize Delaney dataset

logger.info("About to featurize Delaney dataset.")
dataset_file = os.path.join(data_dir, "delaney-processed.csv")
if not os.path.exists(dataset_file):
dc.utils.data_utils.download_url(url=DELANEY_URL, dest_dir=data_dir)
loader = dc.data.CSVLoader(
tasks=tasks, feature_field="smiles", featurizer=featurizer)
dataset = loader.create_dataset(dataset_file, shard_size=8192)

# Split and transform the dataset.

if splitter is None:
transformer_dataset: Dataset = dataset
else:
logger.info("About to split dataset with {} splitter.".format(
splitter.__class__.__name__))
train, valid, test = splitter.train_valid_test_split(dataset)
transformer_dataset = train
transformers = [
dc.trans.NormalizationTransformer(
transform_y=True, dataset=transformer_dataset, move_mean=move_mean)
]
logger.info("About to transform data.")
if splitter is None:
for transformer in transformers:
dataset = transformer.transform(dataset)
if reload and isinstance(dataset, DiskDataset):
dataset.move(save_folder)
dc.utils.data_utils.save_transformers(save_folder, transformers)
return tasks, (dataset,), transformers

for transformer in transformers:
train = transformer.transform(train)
valid = transformer.transform(valid)
test = transformer.transform(test)
if reload and isinstance(train, DiskDataset) and isinstance(
valid, DiskDataset) and isinstance(test, DiskDataset):
dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
transformers)
return tasks, (train, valid, test), transformers
loader = _DelaneyLoader(
featurizer, splitter, data_dir, save_dir, move_mean=move_mean, **kwargs)
featurizer_name = str(loader.featurizer)
splitter_name = 'None' if loader.splitter is None else str(loader.splitter)
if not move_mean:
featurizer_name = featurizer_name + "_mean_unmoved"
save_folder = os.path.join(loader.save_dir, "delaney-featurized",
featurizer_name, splitter_name)
return loader.load_dataset(DELANEY_TASKS, save_folder, reload)
139 changes: 139 additions & 0 deletions deepchem/molnet/load_function/molnet_loader.py
@@ -0,0 +1,139 @@
"""
Common code for loading MoleculeNet datasets.
"""
import os
import logging
import deepchem as dc
from deepchem.data import Dataset, DiskDataset
from typing import List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

# Shortcut names that loader functions accept in place of a Featurizer object.
featurizers = dict(
    ecfp=dc.feat.CircularFingerprint(size=1024),
    graphconv=dc.feat.ConvMolFeaturizer(),
    weave=dc.feat.WeaveFeaturizer(),
    raw=dc.feat.RawFeaturizer(),
    smiles2img=dc.feat.SmilesToImage(img_size=80, img_spec='std'))

# Shortcut names that loader functions accept in place of a Splitter object.
splitters = dict(
    index=dc.splits.IndexSplitter(),
    random=dc.splits.RandomSplitter(),
    scaffold=dc.splits.ScaffoldSplitter(),
    butina=dc.splits.ButinaSplitter(),
    task=dc.splits.TaskSplitter(),
    stratified=dc.splits.RandomStratifiedSplitter())


class _MolnetLoader(object):
  """The class provides common functionality used by many molnet loader functions.

  It is an abstract class.  Subclasses implement loading of particular datasets
  by overriding create_dataset() and get_transformers().
  """

  def __init__(self, featurizer: Union[dc.feat.Featurizer, str],
               splitter: Union[dc.splits.Splitter, str, None],
               data_dir: Optional[str], save_dir: Optional[str], **kwargs):
    """Construct an object for loading a dataset.

    Parameters
    ----------
    featurizer: Featurizer or str
      the featurizer to use for processing the data.  Alternatively you can
      pass one of the names from dc.molnet.featurizers as a shortcut.
    splitter: Splitter or str
      the splitter to use for splitting the data into training, validation,
      and test sets.  Alternatively you can pass one of the names from
      dc.molnet.splitters as a shortcut.  If this is None, all the data will
      be included in a single dataset.
    data_dir: str
      a directory to save the raw data in.  If None, the default DeepChem
      data directory is used.
    save_dir: str
      a directory to save the featurized dataset in.  If None, the default
      DeepChem data directory is used.
    """
    # Accept the deprecated 'split' keyword as an alias for splitter.
    if 'split' in kwargs:
      splitter = kwargs['split']
      logger.warning("'split' is deprecated. Use 'splitter' instead.")
    # Resolve string shortcuts into concrete Featurizer/Splitter objects.
    if isinstance(featurizer, str):
      featurizer = featurizers[featurizer.lower()]
    if isinstance(splitter, str):
      splitter = splitters[splitter.lower()]
    if data_dir is None:
      data_dir = dc.utils.data_utils.get_data_dir()
    if save_dir is None:
      save_dir = dc.utils.data_utils.get_data_dir()
    self.featurizer = featurizer
    self.splitter = splitter
    self.data_dir = data_dir
    self.save_dir = save_dir
    # Remaining keyword arguments are kept for subclasses to read (e.g. the
    # Delaney loader reads self.args['move_mean']).
    self.args = kwargs

  def load_dataset(
      self, tasks: List[str], save_folder: str, reload: bool
  ) -> Tuple[List[str], Tuple[Dataset, ...], List[dc.trans.Transformer]]:
    """Load the dataset.

    Parameters
    ----------
    tasks: List[str]
      the names of the tasks in this dataset
    save_folder: str
      the directory in which the dataset should be saved
    reload: bool
      if True, the first call for a particular featurizer and splitter will
      cache the datasets to disk, and subsequent calls will reload the cached
      datasets.

    Returns
    -------
    A (tasks, datasets, transformers) tuple.  datasets is a one-element tuple
    when self.splitter is None, otherwise the (train, valid, test) triple.
    """
    # Try to reload cached datasets.

    if reload:
      if self.splitter is None:
        if os.path.exists(save_folder):
          transformers = dc.utils.data_utils.load_transformers(save_folder)
          return tasks, (DiskDataset(save_folder),), transformers
      else:
        loaded, all_dataset, transformers = dc.utils.data_utils.load_dataset_from_disk(
            save_folder)
        if all_dataset is not None:
          return tasks, all_dataset, transformers

    # Create the dataset

    dataset = self.create_dataset()

    # Split and transform the dataset.

    if self.splitter is None:
      transformer_dataset: Dataset = dataset
    else:
      logger.info("About to split dataset with {} splitter.".format(
          self.splitter.__class__.__name__))
      train, valid, test = self.splitter.train_valid_test_split(dataset)
      # Transformers are fit on the training split only, so the validation
      # and test splits do not influence the transformation parameters.
      transformer_dataset = train
    transformers = self.get_transformers(transformer_dataset)
    logger.info("About to transform data.")
    if self.splitter is None:
      # Single-dataset path: transform in place, then cache if requested.
      for transformer in transformers:
        dataset = transformer.transform(dataset)
      if reload and isinstance(dataset, DiskDataset):
        dataset.move(save_folder)
        dc.utils.data_utils.save_transformers(save_folder, transformers)
      return tasks, (dataset,), transformers

    for transformer in transformers:
      train = transformer.transform(train)
      valid = transformer.transform(valid)
      test = transformer.transform(test)
    # Cache the three splits only when every one of them is disk backed.
    if reload and isinstance(train, DiskDataset) and isinstance(
        valid, DiskDataset) and isinstance(test, DiskDataset):
      dc.utils.data_utils.save_dataset_to_disk(save_folder, train, valid, test,
                                               transformers)
    return tasks, (train, valid, test), transformers

  def create_dataset(self) -> Dataset:
    """Subclasses must implement this to load the dataset."""
    raise NotImplementedError()

  def get_transformers(self, dataset: Dataset) -> List[dc.trans.Transformer]:
    """Subclasses must implement this to create the transformers for the dataset."""
    raise NotImplementedError()

0 comments on commit 407db0e

Please sign in to comment.