Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding more tests for resharding and legacy-to-modern conversion
- Loading branch information
Bharath Ramsundar
authored and
Bharath Ramsundar
committed
Jul 31, 2020
1 parent
fdcd5ea
commit f888b87
Showing 6 changed files, with 226 additions and 96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,52 @@ | ||
""" | ||
Testing singletask/multitask dataset merging | ||
""" | ||
__author__ = "Bharath Ramsundar" | ||
__copyright__ = "Copyright 2016, Stanford University" | ||
__license__ = "MIT" | ||
|
||
import os | ||
import shutil | ||
import tempfile | ||
import unittest | ||
import deepchem as dc | ||
import numpy as np | ||
|
||
|
||
def test_merge():
  """Test that two DiskDatasets can be merged into one.

  Featurizes the same example CSV twice, merges the two resulting
  DiskDatasets, and checks that the merged dataset's length is the sum of
  the parts. Uses the modern `create_dataset` loader API (the legacy
  `featurize` alias was replaced in this commit).
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  dataset_file = os.path.join(current_dir, "../../models/tests/example.csv")

  featurizer = dc.feat.CircularFingerprint(size=1024)
  tasks = ["log-solubility"]
  loader = dc.data.CSVLoader(
      tasks=tasks, smiles_field="smiles", featurizer=featurizer)
  # Load the same file twice so the two datasets have identical contents
  # but distinct on-disk directories.
  first_dataset = loader.create_dataset(dataset_file)
  second_dataset = loader.create_dataset(dataset_file)

  merged_dataset = dc.data.DiskDataset.merge([first_dataset, second_dataset])

  assert len(merged_dataset) == len(first_dataset) + len(second_dataset)
def test_subset():
  """Test that shard-level subsetting of a DiskDataset works.

  Loads the example CSV with shard_size=2 (so the 10-row file splits into
  multiple 2-row shards), takes the subset consisting of shards 1 and 2,
  and checks that:
    - the subset holds exactly the 4 samples from those two shards,
    - the subset's ids match the concatenation of the two shards' ids,
    - subsetting does not mutate the original dataset's ids.
  """
  current_dir = os.path.dirname(os.path.realpath(__file__))
  dataset_file = os.path.join(current_dir, "../../models/tests/example.csv")

  featurizer = dc.feat.CircularFingerprint(size=1024)
  tasks = ["log-solubility"]
  loader = dc.data.CSVLoader(
      tasks=tasks, smiles_field="smiles", featurizer=featurizer)
  # shard_size=2 forces multiple shards so there is something to subset.
  dataset = loader.create_dataset(dataset_file, shard_size=2)

  shard_nums = [1, 2]

  # Capture ids before subsetting to verify the operation is non-mutating.
  orig_ids = dataset.ids
  _, _, _, ids_1 = dataset.get_shard(1)
  _, _, _, ids_2 = dataset.get_shard(2)

  subset = dataset.subset(shard_nums)
  after_ids = dataset.ids

  assert len(subset) == 4
  assert sorted(subset.ids) == sorted(np.concatenate([ids_1, ids_2]))
  assert list(orig_ids) == list(after_ids)
20 changes: 20 additions & 0 deletions
20
deepchem/data/tests/test_non_classification_regression_datasets.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import deepchem as dc | ||
import numpy as np | ||
|
||
|
||
def test_disk_generative_dataset():
  """Test DiskDataset round-trips 3-D (non classification/regression) arrays.

  Generative-style data has per-sample matrices rather than scalar labels:
  X and y are both shaped (100, 10, 10). Checks that `from_numpy` stores and
  returns them unchanged.
  """
  X = np.random.rand(100, 10, 10)
  y = np.random.rand(100, 10, 10)
  dataset = dc.data.DiskDataset.from_numpy(X, y)
  assert (dataset.X == X).all()
  assert (dataset.y == y).all()
|
||
|
||
def test_numpy_generative_dataset():
  """Test NumpyDataset round-trips 3-D (non classification/regression) arrays.

  Mirrors test_disk_generative_dataset for the in-memory NumpyDataset:
  X and y are (100, 10, 10) arrays and must come back unchanged.
  """
  X = np.random.rand(100, 10, 10)
  y = np.random.rand(100, 10, 10)
  dataset = dc.data.NumpyDataset(X, y)
  assert (dataset.X == X).all()
  assert (dataset.y == y).all()
Oops, something went wrong.