diff --git a/deepchem/feat/graph_features.py b/deepchem/feat/graph_features.py index 770396d118..836f536d9f 100644 --- a/deepchem/feat/graph_features.py +++ b/deepchem/feat/graph_features.py @@ -5,6 +5,8 @@ from deepchem.feat.mol_graphs import ConvMol, WeaveMol from deepchem.data import DiskDataset import logging +from typing import Optional, List +from deepchem.utils.typing import RDKitMol, RDKitAtom def one_of_k_encoding(x, allowable_set): @@ -398,12 +400,75 @@ def bond_features(bond, use_chirality=False): ] if use_chirality: bond_feats = bond_feats + one_of_k_encoding_unk( - str(bond.GetStereo()), possible_bond_stereo) + str(bond.GetStereo()), GraphConvCoonstants.possible_bond_stereo) return np.array(bond_feats) -def pair_features(mol, edge_list, canon_adj_list, bt_len=6, - graph_distance=True): +def max_pair_distance_pairs(mol: RDKitMol, + max_pair_distance: Optional[int]) -> np.ndarray: + """Helper method which finds atom pairs within max_pair_distance graph distance. + + This helper method is used to find atoms which are within max_pair_distance + graph_distance of one another. This is done by using the fact that the + powers of an adjacency matrix encode path connectivity information. In + particular, if `adj` is the adjacency matrix, then `adj**k` has a nonzero + value at `(i, j)` if and only if there exists a path of graph distance `k` + between `i` and `j`. To find all atoms within `max_pair_distance` of each + other, we can compute the adjacency matrix powers `[adj, adj**2, + ...,adj**max_pair_distance]` and find pairs which are nonzero in any of + these matrices. Since adjacency matrices and their powers are positive + numbers, this is simply the nonzero elements of `adj + adj**2 + ... + + adj**max_pair_distance`. + + Parameters + ---------- + mol: rdkit.Chem.rdchem.Mol + RDKit molecules + max_pair_distance: Optional[int], (default None) + This value can be a positive integer or None. This + parameter determines the maximum graph distance at which pair + features are computed. For example, if `max_pair_distance==2`, + then pair features are computed only for atoms at most graph + distance 2 apart. If `max_pair_distance` is `None`, all pairs are + considered (effectively infinite `max_pair_distance`) + + + Returns + ------- + np.ndarray + Of shape `(2, num_pairs)` where `num_pairs` is the total number of pairs + within `max_pair_distance` of one another. + """ + from rdkit import Chem + from rdkit.Chem import rdmolops + N = len(mol.GetAtoms()) + if (max_pair_distance is None or max_pair_distance >= N): + max_distance = N + elif max_pair_distance is not None and max_pair_distance <= 0: + raise ValueError( + "max_pair_distance must either be a positive integer or None") + elif max_pair_distance is not None: + max_distance = max_pair_distance + adj = rdmolops.GetAdjacencyMatrix(mol) + # Handle edge case of self-pairs (i, i) + sum_adj = np.eye(N) + for i in range(max_distance): + # Increment by 1 since we don't want 0-indexing + power = i + 1 + sum_adj += np.linalg.matrix_power(adj, power) + nonzero_locs = np.where(sum_adj != 0) + num_pairs = len(nonzero_locs[0]) + # This creates a matrix of shape (2, num_pairs) + pair_edges = np.reshape(np.array(list(zip(nonzero_locs))), (2, num_pairs)) + return pair_edges + + +def pair_features(mol: RDKitMol, + bond_features_map: dict, + bond_adj_list: List, + bt_len: int = 6, + graph_distance: bool = True, + max_pair_distance: Optional[int] = None) -> np.ndarray: """Helper method used to compute atom pair feature vectors. Many different featurization methods compute atom pair features @@ -415,16 +480,26 @@ def pair_features(mol, edge_list, canon_adj_list, bt_len=6, ---------- mol: RDKit Mol Molecule to compute features on. - edge_list: list - List of edges to consider - canon_adj_list: list of lists - `canon_adj_list[i]` is a list of the atom indices that atom `i` shares a - list. This list is symmetrical so if `j in canon_adj_list[i]` then `i in - canon_adj_list[j]`. + bond_features_map: dict + Dictionary that maps pairs of atom ids (say `(2, 3)` for a bond between + atoms 2 and 3) to the features for the bond between them. + bond_adj_list: list of lists + `bond_adj_list[i]` is a list of the atom indices that atom `i` shares a + bond with . This list is symmetrical so if `j in bond_adj_list[i]` then `i + in bond_adj_list[j]`. bt_len: int, optional (default 6) The number of different bond types to consider. graph_distance: bool, optional (default True) - If true, use graph distance between molecules. Else use euclidean distance. + If true, use graph distance between molecules. Else use euclidean + distance. The specified `mol` must have a conformer. Atomic + positions will be retrieved by calling `mol.getConformer(0)`. + max_pair_distance: Optional[int], (default None) + This value can be a positive integer or None. This + parameter determines the maximum graph distance at which pair + features are computed. For example, if `max_pair_distance==2`, + then pair features are computed only for atoms at most graph + distance 2 apart. If `max_pair_distance` is `None`, all pairs are + considered (effectively infinite `max_pair_distance`) Note ---- @@ -433,32 +508,65 @@ def pair_features(mol, edge_list, canon_adj_list, bt_len=6, Returns ------- features: np.ndarray - Of shape `(N, N, bt_len + max_distance + 1)`. This is the array of pairwise - features for all atom pairs. + Of shape `(N_edges, bt_len + max_distance + 1)`. This is the array + of pairwise features for all atom pairs, where N_edges is the + number of edges within max_pair_distance of one another in this + molecules. + pair_edges: np.ndarray + Of shape `(2, num_pairs)` where `num_pairs` is the total number of + pairs within `max_pair_distance` of one another. """ if graph_distance: max_distance = 7 else: max_distance = 1 N = mol.GetNumAtoms() - features = np.zeros((N, N, bt_len + max_distance + 1)) + pair_edges = max_pair_distance_pairs(mol, max_pair_distance) + num_pairs = pair_edges.shape[1] + N_edges = pair_edges.shape[1] + features = np.zeros((N_edges, bt_len + max_distance + 1)) + # Get mapping + mapping = {} + for n in range(N_edges): + a1, a2 = pair_edges[:, n] + mapping[(int(a1), int(a2))] = n num_atoms = mol.GetNumAtoms() rings = mol.GetRingInfo().AtomRings() for a1 in range(num_atoms): - for a2 in canon_adj_list[a1]: + for a2 in bond_adj_list[a1]: # first `bt_len` features are bond features(if applicable) - features[a1, a2, :bt_len] = np.asarray( - edge_list[tuple(sorted((a1, a2)))], dtype=float) + if (int(a1), int(a2)) not in mapping: + raise ValueError( + "Malformed molecule with bonds not in specified graph distance.") + else: + n = mapping[(int(a1), int(a2))] + features[n, :bt_len] = np.asarray( + bond_features_map[tuple(sorted((a1, a2)))], dtype=float) for ring in rings: if a1 in ring: - # `bt_len`-th feature is if the pair of atoms are in the same ring - features[a1, ring, bt_len] = 1 - features[a1, a1, bt_len] = 0. + for a2 in ring: + if (int(a1), int(a2)) not in mapping: + # For ring pairs outside max pairs distance continue + continue + else: + n = mapping[(int(a1), int(a2))] + # `bt_len`-th feature is if the pair of atoms are in the same ring + if a2 == a1: + features[n, bt_len] = 0 + else: + features[n, bt_len] = 1 # graph distance between two atoms if graph_distance: + # distance is a matrix of 1-hot encoded distances for all atoms distance = find_distance( - a1, num_atoms, canon_adj_list, max_distance=max_distance) - features[a1, :, bt_len + 1:] = distance + a1, num_atoms, bond_adj_list, max_distance=max_distance) + for a2 in range(num_atoms): + if (int(a1), int(a2)) not in mapping: + # For ring pairs outside max pairs distance continue + continue + else: + n = mapping[(int(a1), int(a2))] + features[n, bt_len + 1:] = distance[a2] # Euclidean distance between atoms if not graph_distance: coords = np.zeros((N, 3)) @@ -469,10 +577,11 @@ def pair_features(mol, edge_list, canon_adj_list, bt_len=6, np.stack([coords] * N, axis=1) - \ np.stack([coords] * N, axis=0)), axis=2)) - return features + return features, pair_edges -def find_distance(a1, num_atoms, canon_adj_list, max_distance=7): +def find_distance(a1: RDKitAtom, num_atoms: int, bond_adj_list, + max_distance=7) -> np.ndarray: """Computes distances from provided atom. Parameters @@ -481,10 +590,10 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7): The source atom to compute distances from. num_atoms: int The total number of atoms. - canon_adj_list: list of lists - `canon_adj_list[i]` is a list of the atom indices that atom `i` shares a - list. This list is symmetrical so if `j in canon_adj_list[i]` then `i in - canon_adj_list[j]`. + bond_adj_list: list of lists + `bond_adj_list[i]` is a list of the atom indices that atom `i` shares a + bond with. This list is symmetrical so if `j in bond_adj_list[i]` then `i in + bond_adj_list[j]`. max_distance: int, optional (default 7) The max distance to search. @@ -498,7 +607,7 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7): distance = np.zeros((num_atoms, max_distance)) radial = 0 # atoms `radial` bonds away from `a1` - adj_list = set(canon_adj_list[a1]) + adj_list = set(bond_adj_list[a1]) # atoms less than `radial` bonds away all_list = set([a1]) while radial < max_distance: @@ -507,7 +616,7 @@ def find_distance(a1, num_atoms, canon_adj_list, max_distance=7): # find atoms `radial`+1 bonds away next_adj = set() for adj in adj_list: - next_adj.update(canon_adj_list[adj]) + next_adj.update(bond_adj_list[adj]) adj_list = next_adj - all_list radial = radial + 1 return distance @@ -647,6 +756,14 @@ class WeaveFeaturizer(MolecularFeaturizer): descriptors for each pair of atoms. These extra descriptors may provide for additional descriptive power but at the cost of a larger featurized dataset. + + Examples + -------- + >>> import deepchem as dc + >>> mols = ["C", "CCC"] + >>> featurizer = dc.feat.WeaveFeaturizer() + >>> X = featurizer.featurize(mols) + References ---------- .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond @@ -660,18 +777,31 @@ class WeaveFeaturizer(MolecularFeaturizer): name = ['weave_mol'] - def __init__(self, graph_distance=True, explicit_H=False, - use_chirality=False): - """ + def __init__(self, + graph_distance: bool = True, + explicit_H: bool = False, + use_chirality: bool = False, + max_pair_distance: Optional[int] = None): + """Initialize this featurizer with set parameters. + Parameters ---------- - graph_distance: bool, optional - If true, use graph distance. Otherwise, use Euclidean - distance. - explicit_H: bool, optional + graph_distance: bool, (default True) + If True, use graph distance for distance features. Otherwise, use + Euclidean distance. Note that this means that molecules that this + featurizer is invoked on must have valid conformer information if this + option is set. + explicit_H: bool, (default False) If true, model hydrogens in the molecule. - use_chirality: bool, optional + use_chirality: bool, (default False) If true, use chiral information in the featurization + max_pair_distance: Optional[int], (default None) + This value can be a positive integer or None. This + parameter determines the maximum graph distance at which pair + features are computed. For example, if `max_pair_distance==2`, + then pair features are computed only for atoms at most graph + distance 2 apart. If `max_pair_distance` is `None`, all pairs are + considered (effectively infinite `max_pair_distance`) """ # Distance is either graph distance(True) or Euclidean distance(False, # only support datasets providing Cartesian coordinates) @@ -682,9 +812,13 @@ def __init__(self, graph_distance=True, explicit_H=False, self.explicit_H = explicit_H # If uses use_chirality self.use_chirality = use_chirality + if isinstance(max_pair_distance, int) and max_pair_distance <= 0: + raise ValueError( + "max_pair_distance must either be a positive integer or None") + self.max_pair_distance = max_pair_distance if self.use_chirality: - self.bt_len = int( - GraphConvConstants.bond_fdim_base) + len(possible_bond_stereo) + self.bt_len = int(GraphConvConstants.bond_fdim_base) + len( + GraphConvConstants.possible_bond_stereo) else: self.bt_len = int(GraphConvConstants.bond_fdim_base) @@ -704,27 +838,28 @@ def _featurize(self, mol): nodes = np.vstack(nodes) # Get bond lists - edge_list = {} + bond_features_map = {} for b in mol.GetBonds(): - edge_list[tuple(sorted([b.GetBeginAtomIdx(), - b.GetEndAtomIdx()]))] = bond_features( - b, use_chirality=self.use_chirality) + bond_features_map[tuple(sorted([b.GetBeginAtomIdx(), + b.GetEndAtomIdx()]))] = bond_features( + b, use_chirality=self.use_chirality) # Get canonical adjacency list - canon_adj_list = [[] for mol_id in range(len(nodes))] - for edge in edge_list.keys(): - canon_adj_list[edge[0]].append(edge[1]) - canon_adj_list[edge[1]].append(edge[0]) + bond_adj_list = [[] for mol_id in range(len(nodes))] + for bond in bond_features_map.keys(): + bond_adj_list[bond[0]].append(bond[1]) + bond_adj_list[bond[1]].append(bond[0]) # Calculate pair features - pairs = pair_features( + pairs, pair_edges = pair_features( mol, - edge_list, - canon_adj_list, + bond_features_map, + bond_adj_list, bt_len=self.bt_len, - graph_distance=self.graph_distance) + graph_distance=self.graph_distance, + max_pair_distance=self.max_pair_distance) - return WeaveMol(nodes, pairs) + return WeaveMol(nodes, pairs, pair_edges) class AtomicConvFeaturizer(ComplexNeighborListFragmentAtomicCoordinates): diff --git a/deepchem/feat/mol_graphs.py b/deepchem/feat/mol_graphs.py index cb269be78d..6facdbad0f 100644 --- a/deepchem/feat/mol_graphs.py +++ b/deepchem/feat/mol_graphs.py @@ -1,10 +1,6 @@ """ Data Structures used to represented molecules for convolutions. """ -__author__ = "Han Altae-Tran and Bharath Ramsundar" -__copyright__ = "Copyright 2016, Stanford University" -__license__ = "MIT" - import csv import random import numpy as np @@ -375,16 +371,23 @@ class WeaveMol(object): """Molecular featurization object for weave convolutions. These objects are produced by WeaveFeaturizer, and feed into - WeaveModel. The underlying implementation is inspired by: + WeaveModel. The underlying implementation is inspired by [1]_. + - Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608. + References + ---------- + .. [1] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30.8 (2016): 595-608. """ - def __init__(self, nodes, pairs): + def __init__(self, nodes, pairs, pair_edges): self.nodes = nodes self.pairs = pairs self.num_atoms = self.nodes.shape[0] self.n_features = self.nodes.shape[1] + self.pair_edges = pair_edges + + def get_pair_edges(self): + return self.pair_edges def get_pair_features(self): return self.pairs diff --git a/deepchem/feat/tests/test_graph_features.py b/deepchem/feat/tests/test_graph_features.py index e94cb38fdf..9a4f27b631 100644 --- a/deepchem/feat/tests/test_graph_features.py +++ b/deepchem/feat/tests/test_graph_features.py @@ -1,10 +1,6 @@ """ Tests for ConvMolFeaturizer. """ -__author__ = "Han Altae-Tran and Bharath Ramsundar" -__copyright__ = "Copyright 2016, Stanford University" -__license__ = "MIT" - import unittest import os import numpy as np diff --git a/deepchem/feat/tests/test_weave.py b/deepchem/feat/tests/test_weave.py new file mode 100644 index 0000000000..40e6eee646 --- /dev/null +++ b/deepchem/feat/tests/test_weave.py @@ -0,0 +1,129 @@ +""" +Tests for weave featurizer. +""" +import numpy as np +import deepchem as dc +from deepchem.feat.graph_features import max_pair_distance_pairs + + +def test_max_pair_distance_pairs(): + """Test that max pair distance pairs are computed properly.""" + from rdkit import Chem + # Carbon + mol = Chem.MolFromSmiles('C') + # Test distance 1 + pair_edges = max_pair_distance_pairs(mol, 1) + assert pair_edges.shape == (2, 1) + assert np.all(pair_edges.flatten() == np.array([0, 0])) + # Test distance 2 + pair_edges = max_pair_distance_pairs(mol, 2) + assert pair_edges.shape == (2, 1) + assert np.all(pair_edges.flatten() == np.array([0, 0])) + + # Test alkane + mol = Chem.MolFromSmiles('CCC') + # Test distance 1 + pair_edges = max_pair_distance_pairs(mol, 1) + # 3 self connections and 2 bonds which are both counted twice because of + # symmetry for 7 total + assert pair_edges.shape == (2, 7) + # Test distance 2 + pair_edges = max_pair_distance_pairs(mol, 2) + # Everything is connected at this distance + assert pair_edges.shape == (2, 9) + + +def test_max_pair_distance_infinity(): + """Test that max pair distance pairs are computed properly with infinity distance.""" + from rdkit import Chem + # Test alkane + mol = Chem.MolFromSmiles('CCC') + # Test distance infinity + pair_edges = max_pair_distance_pairs(mol, None) + # Everything is connected at this distance + assert pair_edges.shape == (2, 9) + + # Test pentane + mol = Chem.MolFromSmiles('CCCCC') + # Test distance infinity + pair_edges = max_pair_distance_pairs(mol, None) + # Everything is connected at this distance + assert pair_edges.shape == (2, 25) + + +def test_weave_single_carbon(): + """Test that single carbon atom is featurized properly.""" + mols = ['C'] + featurizer = dc.feat.WeaveFeaturizer() + #from rdkit import Chem + mol_list = featurizer.featurize(mols) + mol = mol_list[0] + #mol = featurizer._featurize(Chem.MolFromSmiles("C")) + + # Only one carbon + assert mol.get_num_atoms() == 1 + + # Test feature sizes + assert mol.get_num_features() == 75 + + # No bonds, so only 1 pair feature (for the self interaction) + assert mol.get_pair_features().shape == (1 * 1, 14) + + +def test_weave_alkane(): + """Test on simple alkane""" + mols = ['CCC'] + featurizer = dc.feat.WeaveFeaturizer() + mol_list = featurizer.featurize(mols) + mol = mol_list[0] + + # 3 carbonds in alkane + assert mol.get_num_atoms() == 3 + + # Test feature sizes + assert mol.get_num_features() == 75 + + # Should be a 3x3 interaction grid + assert mol.get_pair_features().shape == (3 * 3, 14) + + +def test_weave_alkane_max_pairs(): + """Test on simple alkane with max pairs distance cutoff""" + mols = ['CCC'] + featurizer = dc.feat.WeaveFeaturizer(max_pair_distance=1) + #mol_list = featurizer.featurize(mols) + #mol = mol_list[0] + from rdkit import Chem + mol = featurizer._featurize(Chem.MolFromSmiles(mols[0])) + + # 3 carbonds in alkane + assert mol.get_num_atoms() == 3 + + # Test feature sizes + assert mol.get_num_features() == 75 + + # Should be a 7x14 interaction grid since there are 7 pairs within graph + # distance 1 (3 self interactions plus 2 bonds counted twice because of + # symmetry) + assert mol.get_pair_features().shape == (7, 14) + + +def test_carbon_nitrogen(): + """Test on carbon nitrogen molecule""" + # Note there is a central nitrogen of degree 4, with 4 carbons + # of degree 1 (connected only to central nitrogen). + mols = ['C[N+](C)(C)C'] + #import rdkit.Chem + #mols = [rdkit.Chem.MolFromSmiles(s) for s in raw_smiles] + featurizer = dc.feat.WeaveFeaturizer() + mols = featurizer.featurize(mols) + mol = mols[0] + + # 5 atoms in compound + assert mol.get_num_atoms() == 5 + + # Test feature sizes + assert mol.get_num_features() == 75 + + # Should be a 3x3 interaction grid + assert mol.get_pair_features().shape == (5 * 5, 14) diff --git a/deepchem/models/graph_models.py b/deepchem/models/graph_models.py index e4b929c3ae..d64aaaad54 100644 --- a/deepchem/models/graph_models.py +++ b/deepchem/models/graph_models.py @@ -197,7 +197,6 @@ def __init__(self, self.n_classes = n_classes # Build the model. - atom_features = Input(shape=(self.n_atom_feat[0],)) pair_features = Input(shape=(self.n_pair_feat[0],)) pair_split = Input(shape=tuple(), dtype=tf.int32) @@ -277,6 +276,71 @@ def __init__(self, super(WeaveModel, self).__init__( model, loss, output_types=output_types, batch_size=batch_size, **kwargs) + def compute_features_on_batch(self, X_b): + """Compute tensors that will be input into the model from featurized representation. + + The featurized input to `WeaveModel` is instances of `WeaveMol` created by + `WeaveFeaturizer`. This method converts input `WeaveMol` objects into + tensors used by the Keras implementation to compute `WeaveModel` outputs. + + Parameters + ---------- + X_b: np.ndarray + A numpy array with dtype=object where elements are `WeaveMol` objects. + + Returns + ------- + atom_feat: np.ndarray + Of shape `(N_atoms, N_atom_feat)`. + pair_feat: np.ndarray + Of shape `(N_pairs, N_pair_feat)`. Note that `N_pairs` will depend on + the number of pairs being considered. If `max_pair_distance` is + `None`, then this will be `N_atoms**2`. Else it will be the number + of pairs within the specifed graph distance. + pair_split: np.ndarray + Of shape `(N_pairs,)`. The i-th entry in this array will tell you the + originating atom for this pair (the "source"). Note that pairs are + symmetric so for a pair `(a, b)`, both `a` and `b` will separately be + sources at different points in this array. + atom_split: np.ndarray + Of shape `(N_atoms,)`. The i-th entry in this array will be the molecule + with the i-th atom belongs to. + atom_to_pair: np.ndarray + Of shape `(N_pairs, 2)`. The i-th row in this array will be the array + `[a, b]` if `(a, b)` is a pair to be considered. (Note by symmetry, this + implies some other row will contain `[b, a]`. + """ + atom_feat = [] + pair_feat = [] + atom_split = [] + atom_to_pair = [] + pair_split = [] + start = 0 + for im, mol in enumerate(X_b): + n_atoms = mol.get_num_atoms() + # pair_edges is of shape (2, N) + pair_edges = mol.get_pair_edges() + N_pairs = pair_edges[1] + # number of atoms in each molecule + atom_split.extend([im] * n_atoms) + # index of pair features + C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms)) + atom_to_pair.append(pair_edges.T + start) + # Get starting pair atoms + pair_starts = pair_edges.T[:, 0] + # number of pairs for each atom + pair_split.extend(pair_starts + start) + start = start + n_atoms + + # atom features + atom_feat.append(mol.get_atom_features()) + # pair features + pair_feat.append(mol.get_pair_features()) + + return (np.concatenate(atom_feat, axis=0), np.concatenate( + pair_feat, axis=0), np.array(pair_split), np.array(atom_split), + np.concatenate(atom_to_pair, axis=0)) + def default_generator( self, dataset: Dataset, @@ -313,40 +377,7 @@ def default_generator( if self.mode == 'classification': y_b = to_one_hot(y_b.flatten(), self.n_classes).reshape( -1, self.n_tasks, self.n_classes) - atom_feat = [] - pair_feat = [] - atom_split = [] - atom_to_pair = [] - pair_split = [] - start = 0 - for im, mol in enumerate(X_b): - n_atoms = mol.get_num_atoms() - # number of atoms in each molecule - atom_split.extend([im] * n_atoms) - # index of pair features - C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms)) - atom_to_pair.append( - np.transpose( - np.array([C1.flatten() + start, - C0.flatten() + start]))) - # number of pairs for each atom - pair_split.extend(C1.flatten() + start) - start = start + n_atoms - - # atom features - atom_feat.append(mol.get_atom_features()) - # pair features - pair_feat.append( - np.reshape(mol.get_pair_features(), - (n_atoms * n_atoms, self.n_pair_feat[0]))) - - inputs = [ - np.concatenate(atom_feat, axis=0), - np.concatenate(pair_feat, axis=0), - np.array(pair_split), - np.array(atom_split), - np.concatenate(atom_to_pair, axis=0) - ] + inputs = self.compute_features_on_batch(X_b) yield (inputs, [y_b], [w_b]) diff --git a/deepchem/models/tests/test_graph_models.py b/deepchem/models/tests/test_graph_models.py index 4392102f19..96376582ba 100644 --- a/deepchem/models/tests/test_graph_models.py +++ b/deepchem/models/tests/test_graph_models.py @@ -141,51 +141,6 @@ def test_graph_conv_atom_features(): y_pred1 = model.predict(dataset) -@flaky -@pytest.mark.slow -def test_weave_model(): - tasks, dataset, transformers, metric = get_dataset('classification', 'Weave') - - batch_size = 20 - model = WeaveModel( - len(tasks), - batch_size=batch_size, - mode='classification', - fully_connected_layer_sizes=[2000, 1000], - batch_normalize=True, - batch_normalize_kwargs={ - "fused": False, - "trainable": True, - "renorm": True - }, - learning_rage=0.0005) - model.fit(dataset, nb_epoch=200) - scores = model.evaluate(dataset, [metric], transformers) - assert scores['mean-roc_auc_score'] >= 0.9 - - -@pytest.mark.slow -def test_weave_regression_model(): - import numpy as np - import tensorflow as tf - tf.random.set_seed(123) - np.random.seed(123) - tasks, dataset, transformers, metric = get_dataset('regression', 'Weave') - - batch_size = 10 - model = WeaveModel( - len(tasks), - batch_size=batch_size, - mode='regression', - batch_normalize=False, - fully_connected_layer_sizes=[], - dropouts=0, - learning_rate=0.0005) - model.fit(dataset, nb_epoch=200) - scores = model.evaluate(dataset, [metric], transformers) - assert scores['mean_absolute_error'] < 0.1 - - @pytest.mark.slow def test_dag_model(): tasks, dataset, transformers, metric = get_dataset('classification', diff --git a/deepchem/models/tests/test_overfit.py b/deepchem/models/tests/test_overfit.py index 06aa1bb00b..c56cd058d3 100644 --- a/deepchem/models/tests/test_overfit.py +++ b/deepchem/models/tests/test_overfit.py @@ -2,10 +2,6 @@ Tests to make sure deepchem models can overfit on tiny datasets. """ -__author__ = "Bharath Ramsundar" -__copyright__ = "Copyright 2016, Stanford University" -__license__ = "MIT" - import os import numpy as np diff --git a/deepchem/models/tests/test_weave_models.py b/deepchem/models/tests/test_weave_models.py new file mode 100644 index 0000000000..c1e274a0eb --- /dev/null +++ b/deepchem/models/tests/test_weave_models.py @@ -0,0 +1,217 @@ +import unittest +import os +import numpy as np +import pytest +import scipy + +import deepchem as dc +from deepchem.data import NumpyDataset +from deepchem.models import GraphConvModel, DAGModel, WeaveModel, MPNNModel +from deepchem.molnet import load_bace_classification, load_delaney +from deepchem.feat import ConvMolFeaturizer + +from flaky import flaky + + +def get_dataset(mode='classification', featurizer='GraphConv', num_tasks=2): + data_points = 20 + if mode == 'classification': + tasks, all_dataset, transformers = load_bace_classification( + featurizer, reload=False) + else: + tasks, all_dataset, transformers = load_delaney(featurizer, reload=False) + + train, valid, test = all_dataset + for i in range(1, num_tasks): + tasks.append("random_task") + w = np.ones(shape=(data_points, len(tasks))) + + if mode == 'classification': + y = np.random.randint(0, 2, size=(data_points, len(tasks))) + metric = dc.metrics.Metric( + dc.metrics.roc_auc_score, np.mean, mode="classification") + else: + y = np.random.normal(size=(data_points, len(tasks))) + metric = dc.metrics.Metric( + dc.metrics.mean_absolute_error, mode="regression") + + ds = NumpyDataset(train.X[:data_points], y, w, train.ids[:data_points]) + + return tasks, ds, transformers, metric + + +def test_compute_features_on_infinity_distance(): + """Test that WeaveModel correctly transforms WeaveMol objects into tensors with infinite max_pair_distance.""" + featurizer = dc.feat.WeaveFeaturizer(max_pair_distance=None) + X = featurizer(["C", "CCC"]) + batch_size = 20 + model = WeaveModel( + 1, + batch_size=batch_size, + mode='classification', + fully_connected_layer_sizes=[2000, 1000], + batch_normalize=True, + batch_normalize_kwargs={ + "fused": False, + "trainable": True, + "renorm": True + }, + learning_rage=0.0005) + atom_feat, pair_feat, pair_split, atom_split, atom_to_pair = model.compute_features_on_batch( + X) + + # There are 4 atoms each of which have 75 atom features + assert atom_feat.shape == (4, 75) + # There are 10 pairs with infinity distance and 14 pair features + assert pair_feat.shape == (10, 14) + # 4 atoms in total + assert atom_split.shape == (4,) + assert np.all(atom_split == np.array([0, 1, 1, 1])) + # 10 pairs in total + assert pair_split.shape == (10,) + assert np.all(pair_split == np.array([0, 1, 1, 1, 2, 2, 2, 3, 3, 3])) + # 10 pairs in total each with start/finish + assert atom_to_pair.shape == (10, 2) + assert np.all( + atom_to_pair == np.array([[0, 0], [1, 1], [1, 2], [1, 3], [2, 1], [2, 2], + [2, 3], [3, 1], [3, 2], [3, 3]])) + + +def test_compute_features_on_distance_1(): + """Test that WeaveModel correctly transforms WeaveMol objects into tensors with finite max_pair_distance.""" + featurizer = dc.feat.WeaveFeaturizer(max_pair_distance=1) + X = featurizer(["C", "CCC"]) + batch_size = 20 + model = WeaveModel( + 1, + batch_size=batch_size, + mode='classification', + fully_connected_layer_sizes=[2000, 1000], + batch_normalize=True, + batch_normalize_kwargs={ + "fused": False, + "trainable": True, + "renorm": True + }, + learning_rage=0.0005) + atom_feat, pair_feat, pair_split, atom_split, atom_to_pair = model.compute_features_on_batch( + X) + + # There are 4 atoms each of which have 75 atom features + assert atom_feat.shape == (4, 75) + # There are 8 pairs with distance 1 and 14 pair features. (To see why 8, + # there's the self pair for "C". For "CCC" there are 7 pairs including self + # connections and accounting for symmetry.) + assert pair_feat.shape == (8, 14) + # 4 atoms in total + assert atom_split.shape == (4,) + assert np.all(atom_split == np.array([0, 1, 1, 1])) + # 10 pairs in total + assert pair_split.shape == (8,) + # The center atom is self connected and to both neighbors so it appears + # thrice. The canonical ranking used in MolecularFeaturizer means this + # central atom is ranked last in ordering. + assert np.all(pair_split == np.array([0, 1, 1, 2, 2, 3, 3, 3])) + # 10 pairs in total each with start/finish + assert atom_to_pair.shape == (8, 2) + assert np.all(atom_to_pair == np.array([[0, 0], [1, 1], [1, 3], [2, 2], + [2, 3], [3, 1], [3, 2], [3, 3]])) + + +@flaky +@pytest.mark.slow +def test_weave_model(): + tasks, dataset, transformers, metric = get_dataset('classification', 'Weave') + + batch_size = 20 + model = WeaveModel( + len(tasks), + batch_size=batch_size, + mode='classification', + fully_connected_layer_sizes=[2000, 1000], + batch_normalize=True, + batch_normalize_kwargs={ + "fused": False, + "trainable": True, + "renorm": True + }, + learning_rage=0.0005) + model.fit(dataset, nb_epoch=200) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.9 + + +@pytest.mark.slow +def test_weave_regression_model(): + import numpy as np + import tensorflow as tf + tf.random.set_seed(123) + np.random.seed(123) + tasks, dataset, transformers, metric = get_dataset('regression', 'Weave') + + batch_size = 10 + model = WeaveModel( + len(tasks), + batch_size=batch_size, + mode='regression', + batch_normalize=False, + fully_connected_layer_sizes=[], + dropouts=0, + learning_rate=0.0005) + model.fit(dataset, nb_epoch=200) + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean_absolute_error'] < 0.1 + + +def test_weave_fit_simple_infinity_distance(): + featurizer = dc.feat.WeaveFeaturizer(max_pair_distance=None) + X = featurizer(["C", "CCC"]) + y = np.array([0, 1.]) + dataset = dc.data.NumpyDataset(X, y) + + batch_size = 20 + model = WeaveModel( + 1, + batch_size=batch_size, + mode='classification', + fully_connected_layer_sizes=[2000, 1000], + batch_normalize=True, + batch_normalize_kwargs={ + "fused": False, + "trainable": True, + "renorm": True + }, + learning_rage=0.0005) + model.fit(dataset, nb_epoch=200) + transformers = [] + metric = dc.metrics.Metric( + dc.metrics.roc_auc_score, np.mean, mode="classification") + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.9 + + +def test_weave_fit_simple_distance_1(): + featurizer = dc.feat.WeaveFeaturizer(max_pair_distance=1) + X = featurizer(["C", "CCC"]) + y = np.array([0, 1.]) + dataset = dc.data.NumpyDataset(X, y) + + batch_size = 20 + model = WeaveModel( + 1, + batch_size=batch_size, + mode='classification', + fully_connected_layer_sizes=[2000, 1000], + batch_normalize=True, + batch_normalize_kwargs={ + "fused": False, + "trainable": True, + "renorm": True + }, + learning_rage=0.0005) + model.fit(dataset, nb_epoch=200) + transformers = [] + metric = dc.metrics.Metric( + dc.metrics.roc_auc_score, np.mean, mode="classification") + scores = model.evaluate(dataset, [metric], transformers) + assert scores['mean-roc_auc_score'] >= 0.9 diff --git a/deepchem/utils/typing.py b/deepchem/utils/typing.py index ebd0f8e79d..ab1423789a 100644 --- a/deepchem/utils/typing.py +++ b/deepchem/utils/typing.py @@ -19,6 +19,7 @@ # type of RDKit object RDKitMol = Any RDKitAtom = Any +RDKitBond = Any # type of Pymatgen object PymatgenStructure = Any