Skip to content

Commit

Permalink
Minor improvements to molnet loader functions
Browse files: browse the repository at this point in the history
  • Loading branch information
peastman committed Oct 14, 2020
1 parent 413c6a4 commit 47006c5
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 11 deletions.
9 changes: 5 additions & 4 deletions deepchem/molnet/defaults.py
Expand Up @@ -17,10 +17,11 @@
logger = logging.getLogger(__name__)

featurizers = {
'ECFP': dc.feat.CircularFingerprint(size=1024),
'GraphConv': dc.feat.ConvMolFeaturizer(),
'Weave': dc.feat.WeaveFeaturizer(),
'Raw': dc.feat.RawFeaturizer()
'ecfp': dc.feat.CircularFingerprint(size=1024),
'graphconv': dc.feat.ConvMolFeaturizer(),
'weave': dc.feat.WeaveFeaturizer(),
'raw': dc.feat.RawFeaturizer(),
'smiles2img': dc.feat.SmilesToImage(img_size=80, img_spec='std')
}

splitters = {
Expand Down
8 changes: 4 additions & 4 deletions deepchem/molnet/load_function/delaney_datasets.py
Expand Up @@ -68,9 +68,9 @@ def load_delaney(
splitter = kwargs['split']
logger.warning("'split' is deprecated. Use 'splitter' instead.")
if isinstance(featurizer, str):
featurizer = dc.molnet.defaults.featurizers[featurizer]
featurizer = dc.molnet.defaults.featurizers[featurizer.lower()]
if isinstance(splitter, str):
splitter = dc.molnet.defaults.splitters[splitter]
splitter = dc.molnet.defaults.splitters[splitter.lower()]
if data_dir is None:
data_dir = DEFAULT_DIR
if save_dir is None:
Expand All @@ -80,8 +80,8 @@ def load_delaney(
# Try to reload cached datasets.

if reload:
featurizer_name = str(featurizer.__class__.__name__)
splitter_name = str(splitter.__class__.__name__)
featurizer_name = str(featurizer)
splitter_name = str(splitter)
if not move_mean:
featurizer_name = featurizer_name + "_mean_unmoved"
save_folder = os.path.join(save_dir, "delaney-featurized", featurizer_name,
Expand Down
43 changes: 40 additions & 3 deletions deepchem/splits/splitters.py
@@ -1,6 +1,7 @@
"""
Contains an abstract base class that supports chemically aware data splits.
"""
import inspect
import os
import random
import tempfile
Expand Down Expand Up @@ -270,7 +271,30 @@ def __str__(self) -> str:
>>> str(dc.splits.RandomSplitter())
'RandomSplitter'
"""
return self.__class__.__name__
args_spec = inspect.getfullargspec(self.__init__) # type: ignore
args_names = [arg for arg in args_spec.args if arg != 'self']
args_num = len(args_names)
args_default_values = [None for _ in range(args_num)]
if args_spec.defaults is not None:
defaults = list(args_spec.defaults)
args_default_values[-len(defaults):] = defaults

override_args_info = ''
for arg_name, default in zip(args_names, args_default_values):
if arg_name in self.__dict__:
arg_value = self.__dict__[arg_name]
# validation
# skip list
if isinstance(arg_value, list):
continue
if isinstance(arg_value, str):
# skip path string
if "\\/." in arg_value or "/" in arg_value or '.' in arg_value:
continue
# main logic
if default != arg_value:
override_args_info += '_' + arg_name + '_' + str(arg_value)
return self.__class__.__name__ + override_args_info

def __repr__(self) -> str:
"""Convert self to repr representation.
Expand All @@ -284,9 +308,22 @@ def __repr__(self) -> str:
--------
>>> import deepchem as dc
>>> dc.splits.RandomSplitter()
RandomSplitter
RandomSplitter[]
"""
return self.__str__()
args_spec = inspect.getfullargspec(self.__init__) # type: ignore
args_names = [arg for arg in args_spec.args if arg != 'self']
args_info = ''
for arg_name in args_names:
value = self.__dict__[arg_name]
# for str
if isinstance(value, str):
value = "'" + value + "'"
# for list
if isinstance(value, list):
threshold = get_print_threshold()
value = np.array2string(np.array(value), threshold=threshold)
args_info += arg_name + '=' + str(value) + ', '
return self.__class__.__name__ + '[' + args_info[:-2] + ']'


class RandomSplitter(Splitter):
Expand Down

0 comments on commit 47006c5

Please sign in to comment.