Merge pull request #2903 from arunppsg/qm9
Extracting molecular coordinates for the QM9 dataset from SDF files
rbharath committed May 27, 2022
2 parents 553c91e + ad98123 commit 4491b78
Showing 20 changed files with 197 additions and 71 deletions.
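
In short, the change threads the per-atom coordinates stored in QM9's .sdf files through load_sdf_files, SDFLoader and MolGraphConvFeaturizer, so that featurized graphs carry a per-atom pos array. A minimal usage sketch, not part of the commit; the file name and the "gap" task are illustrative placeholders:

import deepchem as dc

# Hedged sketch of the workflow this commit enables: any QM9-style .sdf file
# with per-molecule properties should behave the same way.
featurizer = dc.feat.MolGraphConvFeaturizer()
loader = dc.data.SDFLoader(tasks=["gap"], featurizer=featurizer, sanitize=True)
dataset = loader.create_dataset("gdb9.sdf", shard_size=4096)

graph = dataset.X[0]      # GraphData for the first molecule
print(graph.pos.shape)    # expected (num_atoms, 3) after this change
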
1 change: 0 additions & 1 deletion contrib/vina_model/test_vina_model.py
@@ -23,7 +23,6 @@
from deepchem.models.tensorflow_models.vina_model import get_cells_for_atoms
from deepchem.models.tensorflow_models.vina_model import compute_neighbor_list
import deepchem.utils.rdkit_util as rdkit_util
from deepchem.utils.save import load_sdf_files
from deepchem.utils import pad_array


61 changes: 35 additions & 26 deletions deepchem/data/data_loader.py
@@ -13,7 +13,7 @@
import numpy as np

from deepchem.utils.typing import OneOrMany
from deepchem.utils.data_utils import load_image_files, load_csv_files, load_json_files, load_sdf_files
from deepchem.utils.data_utils import load_image_files, load_csv_files, load_json_files, load_sdf_files, unzip_file
from deepchem.feat import UserDefinedFeaturizer, Featurizer
from deepchem.data import Dataset, DiskDataset, NumpyDataset, ImageDataset
from deepchem.feat.molecule_featurizers import OneHotFeaturizer
@@ -507,8 +507,8 @@ def _featurize_shard(self,
shard[feature_fields] = shard[feature_fields].apply(pd.to_numeric)
X_shard = shard[feature_fields].to_numpy()
time2 = time.time()
logger.info(
"TIMING: user specified processing took %0.3f s" % (time2 - time1))
logger.info("TIMING: user specified processing took %0.3f s" %
(time2 - time1))
return (X_shard, np.ones(len(X_shard), dtype=bool))


@@ -795,10 +795,10 @@ def create_dataset(self,
processed_files.append(input_file)
elif extension == ".zip":
zip_dir = tempfile.mkdtemp()
zip_ref = zipfile.ZipFile(input_file, 'r')
zip_ref.extractall(path=zip_dir)
zip_ref.close()
zip_files = [os.path.join(zip_dir, name) for name in zip_ref.namelist()]
unzip_file(input_file, zip_dir)
zip_files = [
os.path.join(zip_dir, name) for name in os.listdir(zip_dir)
]
for zip_file in zip_files:
_, extension = os.path.splitext(zip_file)
extension = extension.lower()
@@ -850,11 +850,10 @@ def _get_shards(self, input_files: List[str],
Iterator[pd.DataFrame]
Iterator over shards
"""
return load_sdf_files(
input_files=input_files,
clean_mols=self.sanitize,
tasks=self.tasks,
shard_size=shard_size)
return load_sdf_files(input_files=input_files,
clean_mols=self.sanitize,
tasks=self.tasks,
shard_size=shard_size)

def _featurize_shard(self,
shard: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]:
@@ -876,7 +875,16 @@ def _featurize_shard(self,
Boolean values indicating successful featurization for corresponding
sample in the source.
"""
features = [elt for elt in self.featurizer(shard[self.mol_field])]
pos_cols = ['pos_x', 'pos_y', 'pos_z']
if set(pos_cols).issubset(shard.columns):
features = [
elt for elt in self.featurizer(shard[self.mol_field],
pos_x=shard['pos_x'],
pos_y=shard['pos_y'],
pos_z=shard['pos_z'])
]
else:
features = [elt for elt in self.featurizer(shard[self.mol_field])]
valid_inds = np.array(
[1 if np.array(elt).size > 0 else 0 for elt in features], dtype=bool)
features = [
@@ -953,8 +961,8 @@ def __init__(self,
if isinstance(featurizer, UserDefinedFeaturizer): # User defined featurizer
self.user_specified_features = featurizer.feature_fields
elif featurizer is None: # Default featurizer
featurizer = OneHotFeaturizer(
charset=["A", "C", "T", "G"], max_length=None)
featurizer = OneHotFeaturizer(charset=["A", "C", "T", "G"],
max_length=None)

# Set self.featurizer
self.featurizer = featurizer
@@ -1158,16 +1166,17 @@ def create_dataset(self,

if in_memory:
if data_dir is None:
return NumpyDataset(
load_image_files(image_files), y=labels, w=weights, ids=image_files)
return NumpyDataset(load_image_files(image_files),
y=labels,
w=weights,
ids=image_files)
else:
dataset = DiskDataset.from_numpy(
load_image_files(image_files),
y=labels,
w=weights,
ids=image_files,
tasks=self.tasks,
data_dir=data_dir)
dataset = DiskDataset.from_numpy(load_image_files(image_files),
y=labels,
w=weights,
ids=image_files,
tasks=self.tasks,
data_dir=data_dir)
if shard_size is not None:
dataset.reshard(shard_size)
return dataset
@@ -1311,8 +1320,8 @@ def _get_shards(self, inputs: List,

# FIXME: Signature of "_featurize_shard" incompatible with supertype "DataLoader"
def _featurize_shard( # type: ignore[override]
self, shard: List, global_index: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
self, shard: List, global_index: int) -> Tuple[np.ndarray, np.ndarray,
np.ndarray, np.ndarray]:
"""Featurizes a shard of an input data.
Parameters
5 changes: 4 additions & 1 deletion deepchem/feat/base_classes.py
@@ -292,7 +292,10 @@ def featurize(self, datapoints, log_every_n=1000, **kwargs) -> np.ndarray:
new_order = rdmolfiles.CanonicalRankAtoms(mol)
mol = rdmolops.RenumberAtoms(mol, new_order)

features.append(self._featurize(mol, **kwargs))
kwargs_per_datapoint = {}
for key in kwargs.keys():
kwargs_per_datapoint[key] = kwargs[key][i]
features.append(self._featurize(mol, **kwargs_per_datapoint))
except Exception as e:
if isinstance(mol, Chem.rdchem.Mol):
mol = Chem.MolToSmiles(mol)
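
The featurize change above forwards element i of every keyword argument to _featurize for datapoint i. A toy sketch of that indexing behaviour, standalone rather than the library code, assuming each keyword value is list-like and aligned with the datapoints:

def featurize_with_per_datapoint_kwargs(datapoints, featurize_one, **kwargs):
    # Slice every keyword argument so that datapoint i receives kwargs[key][i],
    # mirroring the loop added to MolecularFeaturizer.featurize above.
    features = []
    for i, datapoint in enumerate(datapoints):
        kwargs_per_datapoint = {key: value[i] for key, value in kwargs.items()}
        features.append(featurize_one(datapoint, **kwargs_per_datapoint))
    return features

This is what lets SDFLoader pass whole dataframe columns (pos_x, pos_y, pos_z) while each molecule's _featurize call only sees its own coordinate vectors.
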
38 changes: 29 additions & 9 deletions deepchem/feat/molecule_featurizers/mol_graph_conv_featurizer.py
@@ -26,9 +26,10 @@
from deepchem.utils.rdkit_utils import compute_pairwise_ring_info


def _construct_atom_feature(
atom: RDKitAtom, h_bond_infos: List[Tuple[int, str]], use_chirality: bool,
use_partial_charge: bool) -> np.ndarray:
def _construct_atom_feature(atom: RDKitAtom, h_bond_infos: List[Tuple[int,
str]],
use_chirality: bool,
use_partial_charge: bool) -> np.ndarray:
"""Construct an atom feature from a RDKit atom object.
Parameters
@@ -227,10 +228,29 @@ def _featurize(self, datapoint: RDKitMol, **kwargs) -> GraphData:
features += 2 * [_construct_bond_feature(bond)]
bond_features = np.asarray(features, dtype=float)

return GraphData(
node_features=atom_features,
edge_index=np.asarray([src, dest], dtype=int),
edge_features=bond_features)
# load_sdf_files returns pos as strings but user can also specify
# numpy arrays for atom coordinates
pos = []
if 'pos_x' in kwargs and 'pos_y' in kwargs and 'pos_z' in kwargs:
if isinstance(kwargs['pos_x'], str):
pos_x = eval(kwargs['pos_x'])
elif isinstance(kwargs['pos_x'], np.ndarray):
pos_x = kwargs['pos_x']
if isinstance(kwargs['pos_y'], str):
pos_y = eval(kwargs['pos_y'])
elif isinstance(kwargs['pos_y'], np.ndarray):
pos_y = kwargs['pos_y']
if isinstance(kwargs['pos_z'], str):
pos_z = eval(kwargs['pos_z'])
elif isinstance(kwargs['pos_z'], np.ndarray):
pos_z = kwargs['pos_z']

for x, y, z in zip(pos_x, pos_y, pos_z):
pos.append([x, y, z])
return GraphData(node_features=atom_features,
edge_index=np.asarray([src, dest], dtype=int),
edge_features=bond_features,
pos=np.asarray(pos))


class PagtnMolGraphFeaturizer(MolecularFeaturizer):
@@ -328,8 +348,8 @@ def _pagtn_atom_featurizer(self, atom: RDKitAtom) -> np.ndarray:
numpy vector of atom features.
"""
atom_type = get_atom_type_one_hot(atom, self.SYMBOLS, False)
formal_charge = get_atom_formal_charge_one_hot(
atom, include_unknown_set=False)
formal_charge = get_atom_formal_charge_one_hot(atom,
include_unknown_set=False)
degree = get_atom_total_degree_one_hot(atom, list(range(11)), False)
exp_valence = get_atom_explicit_valence_one_hot(atom, list(range(7)), False)
imp_valence = get_atom_implicit_valence_one_hot(atom, list(range(6)), False)
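
For clarity, the coordinate handling above reduces to stacking three aligned per-atom vectors into the (num_atoms, 3) pos array stored on GraphData. A standalone sketch using the methane coordinates from the gdb9_small.sdf test asset further down:

import numpy as np

# The three 1-D vectors are aligned per atom; load_sdf_files serializes them
# as strings, while users may also pass numpy arrays directly as kwargs.
pos_x = np.array([-0.0127, 0.0022, 1.0117, -0.5408, -0.5238])
pos_y = np.array([1.0858, -0.0060, 1.4638, 1.4475, 1.4379])
pos_z = np.array([0.0080, 0.0020, 0.0003, -0.8766, 0.9064])
pos = np.column_stack([pos_x, pos_y, pos_z])  # shape (5, 3), one row per atom
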
19 changes: 18 additions & 1 deletion deepchem/feat/tests/test_mol_graph_conv_featurizer.py
@@ -1,5 +1,5 @@
import unittest

import numpy as np
from deepchem.feat import MolGraphConvFeaturizer
from deepchem.feat import PagtnMolGraphFeaturizer

@@ -72,6 +72,23 @@ def test_featurizer_with_use_partial_charge(self):
assert graph_feat[1].num_node_features == 31
assert graph_feat[1].num_edges == 44

def test_featurizer_with_pos_kwargs(self):
# Test featurizer with atom 3-D coordinates as kwargs
smiles = ["C1=CC=CN=C1", "CC"]
pos_x = [np.random.randn(6), np.random.randn(2)]
pos_y, pos_z = pos_x, pos_x
featurizer = MolGraphConvFeaturizer()
graph_feat = featurizer.featurize(smiles,
pos_x=pos_x,
pos_y=pos_y,
pos_z=pos_z)

assert len(graph_feat) == 2
assert graph_feat[0].num_nodes == 6
assert graph_feat[0].pos.shape == (6, 3)
assert graph_feat[1].num_nodes == 2
assert graph_feat[1].pos.shape == (2, 3)


class TestPagtnMolGraphConvFeaturizer(unittest.TestCase):

7 changes: 4 additions & 3 deletions deepchem/molnet/load_function/qm9_datasets.py
@@ -23,9 +23,10 @@ def create_dataset(self) -> Dataset:
dc.utils.data_utils.download_url(url=GDB9_URL, dest_dir=self.data_dir)
dc.utils.data_utils.untargz_file(
os.path.join(self.data_dir, "gdb9.tar.gz"), self.data_dir)
loader = dc.data.SDFLoader(
tasks=self.tasks, featurizer=self.featurizer, sanitize=True)
return loader.create_dataset(dataset_file, shard_size=8192)
loader = dc.data.SDFLoader(tasks=self.tasks,
featurizer=self.featurizer,
sanitize=True)
return loader.create_dataset(dataset_file, shard_size=4096)


def load_qm9(
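
For reference, a hedged sketch of exercising this loader through MolNet with a graph featurizer; the exact load_qm9 signature and return layout may differ between DeepChem versions:

import deepchem as dc

# Sketch only: featurizer and splitter choices here are illustrative.
tasks, (train, valid, test), transformers = dc.molnet.load_qm9(
    featurizer=dc.feat.MolGraphConvFeaturizer(), splitter='random')
print(train.X[0].pos.shape)  # per-atom coordinates should now be populated
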
68 changes: 48 additions & 20 deletions deepchem/utils/data_utils.py
@@ -249,11 +249,15 @@ def load_sdf_files(input_files: List[str],

df_rows = []
for input_file in input_files:
# Tasks are either in .sdf.csv file or in the .sdf file itself
# Tasks are either in .sdf.csv file or in the .sdf file itself for QM9 dataset
has_csv = os.path.isfile(input_file + ".csv")
# Structures are stored in .sdf file
logger.info("Reading structures from %s." % input_file)
suppl = Chem.SDMolSupplier(str(input_file), clean_mols, False, False)

suppl = Chem.SDMolSupplier(str(input_file),
sanitize=clean_mols,
removeHs=False,
strictParsing=False)
for ind, mol in enumerate(suppl):
if mol is None:
continue
@@ -262,28 +266,47 @@
if not has_csv: # Get task targets from .sdf file
for task in tasks:
df_row.append(mol.GetProp(str(task)))

conf = mol.GetConformer()
positions = conf.GetPositions()
pos_x, pos_y, pos_z = zip(*positions)
df_row.append(str(pos_x))
df_row.append(str(pos_y))
df_row.append(str(pos_z))
df_rows.append(df_row)

if shard_size is not None and len(df_rows) == shard_size:
if has_csv:
mol_df = pd.DataFrame(df_rows, columns=('mol_id', 'smiles', 'mol'))
mol_df = pd.DataFrame(df_rows,
columns=('mol_id', 'smiles', 'mol', 'pos_x',
'pos_y', 'pos_z'))
raw_df = next(load_csv_files([input_file + ".csv"], shard_size=None))
yield pd.concat([mol_df, raw_df], axis=1, join='inner')
else:
mol_df = pd.DataFrame(
df_rows, columns=('mol_id', 'smiles', 'mol') + tuple(tasks))
# Note: Here, the order of columns is based on the order in which the values
# are appended to `df_row`. Since pos_x, pos_y, pos_z are appended after appending
# tasks above, they occur after `tasks` here.
# FIXME Ideally, we should use something like a dictionary here to keep it independent
# of column ordering.
mol_df = pd.DataFrame(df_rows,
columns=('mol_id', 'smiles', 'mol') +
tuple(tasks) + ('pos_x', 'pos_y', 'pos_z'))
yield mol_df
# Reset aggregator
df_rows = []

# Handle final leftovers for this file
if len(df_rows) > 0:
if has_csv:
mol_df = pd.DataFrame(df_rows, columns=('mol_id', 'smiles', 'mol'))
mol_df = pd.DataFrame(df_rows,
columns=('mol_id', 'smiles', 'mol', 'pos_x',
'pos_y', 'pos_z'))
raw_df = next(load_csv_files([input_file + ".csv"], shard_size=None))
yield pd.concat([mol_df, raw_df], axis=1, join='inner')
else:
mol_df = pd.DataFrame(
df_rows, columns=('mol_id', 'smiles', 'mol') + tuple(tasks))
mol_df = pd.DataFrame(df_rows,
columns=('mol_id', 'smiles', 'mol') +
tuple(tasks) + ('pos_x', 'pos_y', 'pos_z'))
yield mol_df
df_rows = []

@@ -312,8 +335,8 @@ def load_csv_files(input_files: List[str],
else:
logger.info("About to start loading CSV from %s" % input_file)
for df in pd.read_csv(input_file, chunksize=shard_size):
logger.info(
"Loading shard %d of size %s." % (shard_num, str(shard_size)))
logger.info("Loading shard %d of size %s." %
(shard_num, str(shard_size)))
df = df.replace(np.nan, str(""), regex=True)
shard_num += 1
yield df
@@ -346,10 +369,12 @@ def load_json_files(input_files: List[str],
yield pd.read_json(input_file, orient='records', lines=True)
else:
logger.info("About to start loading json from %s." % input_file)
for df in pd.read_json(
input_file, orient='records', chunksize=shard_size, lines=True):
logger.info(
"Loading shard %d of size %s." % (shard_num, str(shard_size)))
for df in pd.read_json(input_file,
orient='records',
chunksize=shard_size,
lines=True):
logger.info("Loading shard %d of size %s." %
(shard_num, str(shard_size)))
df = df.replace(np.nan, str(""), regex=True)
shard_num += 1
yield df
@@ -504,9 +529,11 @@ def load_from_disk(filename: str) -> Any:
raise ValueError("Unrecognized filetype for %s" % filename)


def load_dataset_from_disk(save_dir: str) -> Tuple[bool, Optional[Tuple[
"dc.data.DiskDataset", "dc.data.DiskDataset", "dc.data.DiskDataset"]], List[
"dc.trans.Transformer"]]:
def load_dataset_from_disk(
save_dir: str
) -> Tuple[bool, Optional[Tuple["dc.data.DiskDataset", "dc.data.DiskDataset",
"dc.data.DiskDataset"]],
List["dc.trans.Transformer"]]:
"""Loads MoleculeNet train/valid/test/transformers from disk.
Expects that data was saved using `save_dataset_to_disk` below. Expects the
@@ -556,9 +583,10 @@ def load_dataset_from_disk(save_dir: str) -> Tuple[bool, Optional[Tuple[
return loaded, all_dataset, transformers


def save_dataset_to_disk(
save_dir: str, train: "dc.data.DiskDataset", valid: "dc.data.DiskDataset",
test: "dc.data.DiskDataset", transformers: List["dc.trans.Transformer"]):
def save_dataset_to_disk(save_dir: str, train: "dc.data.DiskDataset",
valid: "dc.data.DiskDataset",
test: "dc.data.DiskDataset",
transformers: List["dc.trans.Transformer"]):
"""Utility used by MoleculeNet to save train/valid/test datasets.
This utility function saves a train/valid/test split of a dataset along
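
A small sketch of how the updated load_sdf_files is consumed; it assumes a local SDF file whose task values, if any, are stored as molecule properties, so no companion .csv file is needed:

from deepchem.utils.data_utils import load_sdf_files

# Iterate the shard generator and inspect the coordinate columns added here.
for shard in load_sdf_files(["gdb9_small.sdf"], clean_mols=True,
                            tasks=[], shard_size=None):
    print(list(shard.columns))     # ['mol_id', 'smiles', 'mol', 'pos_x', 'pos_y', 'pos_z']
    print(shard["pos_x"].iloc[0])  # stringified tuple of per-atom x coordinates
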
File renamed without changes.
File renamed without changes.
28 changes: 28 additions & 0 deletions deepchem/utils/test/assets/gdb9_small.sdf
@@ -0,0 +1,28 @@
gdb_1
-OEChem-03231823243D

5 4 0 0 0 0 0 0 0999 V2000
-0.0127 1.0858 0.0080 C 0 0 0 0 0 0 0 0 0 0 0 0
0.0022 -0.0060 0.0020 H 0 0 0 0 0 0 0 0 0 0 0 0
1.0117 1.4638 0.0003 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.5408 1.4475 -0.8766 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.5238 1.4379 0.9064 H 0 0 0 0 0 0 0 0 0 0 0 0
1 2 1 0 0 0 0
1 3 1 0 0 0 0
1 4 1 0 0 0 0
1 5 1 0 0 0 0
M END
$$$$
gdb_2
-OEChem-03231823233D

4 3 0 0 0 0 0 0 0999 V2000
-0.0404 1.0241 0.0626 N 0 0 0 0 0 0 0 0 0 0 0 0
0.0173 0.0125 -0.0274 H 0 0 0 0 0 0 0 0 0 0 0 0
0.9158 1.3587 -0.0288 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.5203 1.3435 -0.7755 H 0 0 0 0 0 0 0 0 0 0 0 0
1 2 1 0 0 0 0
1 3 1 0 0 0 0
1 4 1 0 0 0 0
M END
$$$$
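
The asset above holds the first two GDB-9 entries (methane and ammonia). A sketch of reading it back with RDKit, mirroring how load_sdf_files recovers the coordinates:

from rdkit import Chem

# Read the test asset and print each molecule's (num_atoms, 3) coordinates,
# the same values that end up in the pos_x/pos_y/pos_z columns.
suppl = Chem.SDMolSupplier("gdb9_small.sdf", sanitize=True,
                           removeHs=False, strictParsing=False)
for mol in suppl:
    if mol is None:
        continue
    positions = mol.GetConformer().GetPositions()
    print(mol.GetProp("_Name"), positions.shape)
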
File renamed without changes.
