
Commit

changes
Bharath Ramsundar authored and committed on Jul 22, 2020
1 parent 692d22c commit 534a8e2
Showing 2 changed files with 41 additions and 40 deletions.
10 changes: 4 additions & 6 deletions deepchem/data/data_loader.py
@@ -407,11 +407,9 @@ def _featurize_shard(self,
     valid_inds: np.ndarray
       Indices of rows in source CSV with valid data.
     """
-    features = [
-        np.array(elt) for elt in self.featurizer(shard[self.feature_field])
-    ]
+    features = [elt for elt in self.featurizer(shard[self.feature_field])]
     valid_inds = np.array(
-        [1 if elt.size > 0 else 0 for elt in features], dtype=bool)
+        [1 if np.array(elt).size > 0 else 0 for elt in features], dtype=bool)
     features = [
         elt for (is_valid, elt) in zip(valid_inds, features) if is_valid
     ]
@@ -734,9 +732,9 @@ def _get_shards(self, input_files, shard_size):
 
   def _featurize_shard(self, shard):
     """Featurizes a shard of an input dataframe."""
-    features = [np.array(elt) for elt in featurizer(shard[self.mol_field])]
+    features = [elt for elt in featurizer(shard[self.mol_field])]
     valid_inds = np.array(
-        [1 if elt.size > 0 else 0 for elt in features], dtype=bool)
+        [1 if np.array(elt).size > 0 else 0 for elt in features], dtype=bool)
     features = [
         elt for (is_valid, elt) in zip(valid_inds, features) if is_valid
     ]
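A note on the data_loader.py change above: featurizer outputs are no longer coerced to np.ndarray up front; each element is wrapped in np.array() only for the emptiness check, so non-array outputs (for example ConvMol objects) pass through unchanged while failed featurizations are still dropped. A minimal standalone sketch of that filtering pattern follows; the helper name filter_failed_featurizations and the toy inputs are illustrative assumptions, not part of the commit.

import numpy as np


def filter_failed_featurizations(features):
  """Drop entries whose featurization came back empty.

  Outputs are kept in their native form; np.array() is used only as a
  temporary wrapper to test for emptiness (hypothetical helper, mirroring
  the pattern in _featurize_shard above).
  """
  valid_inds = np.array(
      [1 if np.array(elt).size > 0 else 0 for elt in features], dtype=bool)
  features = [
      elt for (is_valid, elt) in zip(valid_inds, features) if is_valid
  ]
  return features, valid_inds


# Example: the second entry failed to featurize (empty array) and is dropped,
# while the non-array third entry passes through untouched.
feats, valid = filter_failed_featurizations([np.ones(3), np.array([]), "CCO"])
assert valid.tolist() == [True, False, True]
assert len(feats) == 2

This keeps featurized values in whatever type the featurizer produces while still filtering out rows whose featurization produced nothing.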
71 changes: 37 additions & 34 deletions deepchem/trans/tests/test_transformers.py
@@ -69,6 +69,43 @@ def load_unlabelled_data():
   return loader.featurize(input_file)
 
 
+def test_featurization_transformer():
+  fp_size = 2048
+  tasks, all_dataset, transformers = load_delaney('Raw')
+  train = all_dataset[0]
+  transformer = FeaturizationTransformer(
+      transform_X=True,
+      dataset=train,
+      featurizer=dc.feat.CircularFingerprint(size=fp_size))
+  new_train = transformer.transform(train)
+
+  assert new_train.y.shape == train.y.shape
+  assert new_train.X.shape[-1] == fp_size
+
+
+def test_DAG_transformer():
+  """Tests the DAG transformer."""
+  np.random.seed(123)
+  tf.random.set_seed(123)
+  n_tasks = 1
+
+  # Load mini log-solubility dataset.
+  current_dir = os.path.dirname(os.path.abspath(__file__))
+  featurizer = dc.feat.ConvMolFeaturizer()
+  tasks = ["outcome"]
+  input_file = os.path.join(current_dir,
+                            "../../models/tests/example_regression.csv")
+  loader = dc.data.CSVLoader(
+      tasks=tasks, smiles_field="smiles", featurizer=featurizer)
+  dataset = loader.create_dataset(input_file)
+  transformer = dc.trans.DAGTransformer(max_atoms=50)
+  dataset = transformer.transform(dataset)
+  # The transformer generates n DAGs for a molecule with n
+  # atoms. These are denoted the "parents"
+  for idm, mol in enumerate(dataset.X):
+    assert dataset.X[idm].get_num_atoms() == len(dataset.X[idm].parents)
+
+
 class TestTransformers(unittest.TestCase):
   """
   Test top-level API for transformer objects.
@@ -474,19 +511,6 @@ def test_IRV_transformer(self):
     assert np.allclose(test_dataset_trans.X[0, 10:20], [0] * 10)
     assert not np.isclose(dataset_trans.X[0, 0], 1.)
 
-  def test_featurization_transformer(self):
-    fp_size = 2048
-    tasks, all_dataset, transformers = load_delaney('Raw')
-    train = all_dataset[0]
-    transformer = FeaturizationTransformer(
-        transform_X=True,
-        dataset=train,
-        featurizer=dc.feat.CircularFingerprint(size=fp_size))
-    new_train = transformer.transform(train)
-
-    self.assertEqual(new_train.y.shape, train.y.shape)
-    self.assertEqual(new_train.X.shape[-1], fp_size)
-
   def test_blurring(self):
     # Check Blurring
     dt = DataTransforms(self.d)
@@ -593,27 +617,6 @@ def test_salt_pepper_noise(self):
     check_random_noise = dt.salt_pepper_noise(prob, salt=255, pepper=0)
     assert np.allclose(random_noise, check_random_noise)
 
-  def test_DAG_transformer(self):
-    """Tests the DAG transformer."""
-    np.random.seed(123)
-    tf.random.set_seed(123)
-    n_tasks = 1
-
-    # Load mini log-solubility dataset.
-    featurizer = dc.feat.ConvMolFeaturizer()
-    tasks = ["outcome"]
-    input_file = os.path.join(self.current_dir,
-                              "../../models/tests/example_regression.csv")
-    loader = dc.data.CSVLoader(
-        tasks=tasks, smiles_field="smiles", featurizer=featurizer)
-    dataset = loader.create_dataset(input_file)
-    transformer = dc.trans.DAGTransformer(max_atoms=50)
-    dataset = transformer.transform(dataset)
-    # The transformer generates n DAGs for a molecule with n
-    # atoms. These are denoted the "parents"
-    for idm, mol in enumerate(dataset.X):
-      assert dataset.X[idm].get_num_atoms() == len(dataset.X[idm].parents)
-
   def test_median_filter(self):
     #Check median filter
     from PIL import Image, ImageFilter
