Zinc 250k dataset support #276
Merged
Changes from 8 commits (9 commits total):
- d6fb90d copy script from QM9 dataset
- de434b9 zinc.py for automatically download zinc250k dataset
- e09fd38 fix docstring for qm9, tox21
- 5a19ca7 update __init__
- b5cc634 test, wip: copy test from qm9
- 37b4061 test: added zinc250k tests
- 35ca37e added reference to zinc dataset
- 79f3af1 update support version
- a07d403 add slow
zinc.py (new file, +114 lines):

```python
from logging import getLogger
import os

from chainer.dataset import download
import numpy
import pandas

from chainer_chemistry.dataset.parsers.csv_file_parser import CSVFileParser
from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor  # NOQA

download_url = 'https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv'  # NOQA
file_name_250k = 'zinc250k.csv'

_root = 'pfnet/chainer/zinc'

_label_names = ['logP', 'qed', 'SAS']
_smiles_column_names = ['smiles']


def get_zinc250k_label_names():
    """Returns label names of the ZINC250k dataset."""
    return _label_names


def get_zinc250k(preprocessor=None, labels=None, return_smiles=False,
                 target_index=None):
    """Downloads, caches and preprocesses the ZINC250k dataset.

    Args:
        preprocessor (BasePreprocessor): Preprocessor.
            This should be chosen based on the network to be trained.
            If it is None, the default `AtomicNumberPreprocessor` is used.
        labels (str or list): List of target labels.
        return_smiles (bool): If set to ``True``,
            the smiles array is also returned.
        target_index (list or None): target index list to partially extract
            the dataset. If None (default), all examples are parsed.

    Returns:
        dataset, which is composed of `features`, which depends on
        `preprocess_method`.

    """
    labels = labels or get_zinc250k_label_names()
    if isinstance(labels, str):
        labels = [labels, ]

    def postprocess_label(label_list):
        # This is a regression task, so cast labels to float values.
        return numpy.asarray(label_list, dtype=numpy.float32)

    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()
    parser = CSVFileParser(preprocessor, postprocess_label=postprocess_label,
                           labels=labels, smiles_col='smiles')
    result = parser.parse(get_zinc250k_filepath(), return_smiles=return_smiles,
                          target_index=target_index)

    if return_smiles:
        return result['dataset'], result['smiles']
    else:
        return result['dataset']


def get_zinc250k_filepath(download_if_not_exist=True):
    """Construct a filepath which stores the ZINC250k dataset.

    This method checks whether the file exists or not, and downloads it if
    necessary.

    Args:
        download_if_not_exist (bool): If `True`, download the dataset
            if it is not downloaded yet.

    Returns (str): file path for the ZINC250k dataset (csv format)

    """
    cache_path = _get_zinc250k_filepath()
    if not os.path.exists(cache_path):
        if download_if_not_exist:
            is_successful = download_and_extract_zinc250k(
                save_filepath=cache_path)
            if not is_successful:
                logger = getLogger(__name__)
                logger.warning('Download failed.')
    return cache_path


def _get_zinc250k_filepath():
    """Construct a filepath which stores the ZINC250k dataset in csv format.

    This method does not check whether the file is already downloaded or not.

    Returns (str): filepath for the ZINC250k dataset

    """
    cache_root = download.get_dataset_directory(_root)
    cache_path = os.path.join(cache_root, file_name_250k)
    return cache_path


def _remove_new_line(s):
    return s.replace('\n', '')


def download_and_extract_zinc250k(save_filepath):
    logger = getLogger(__name__)
    logger.info('Extracting ZINC250k dataset...')
    download_file_path = download.cached_download(download_url)
    df = pandas.read_csv(download_file_path)
    # The 'smiles' column contains '\n'; it needs to be removed.
    df['smiles'] = df['smiles'].apply(_remove_new_line)
    df.to_csv(save_filepath, columns=_smiles_column_names + _label_names)
    return True
```
Test file for the ZINC250k dataset (new file, +113 lines):

```python
import os

import numpy
import pytest

from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor  # NOQA
from chainer_chemistry.datasets import zinc


ZINC250K_NUM_LABEL = 3
ZINC250K_NUM_DATASET = 249455


def test_get_zinc_filepath_without_download():
    filepath = zinc.get_zinc250k_filepath(download_if_not_exist=False)
    if os.path.exists(filepath):
        os.remove(filepath)  # ensure the cache file does not exist.

    filepath = zinc.get_zinc250k_filepath(download_if_not_exist=False)
    assert isinstance(filepath, str)
    assert not os.path.exists(filepath)


def test_get_zinc_filepath_with_download():
    filepath = zinc.get_zinc250k_filepath(download_if_not_exist=False)
    if os.path.exists(filepath):
        os.remove(filepath)  # ensure the cache file does not exist.

    # This method downloads the file if it does not exist.
    filepath = zinc.get_zinc250k_filepath(download_if_not_exist=True)
    assert isinstance(filepath, str)
    assert os.path.exists(filepath)


@pytest.mark.slow
def test_get_zinc():
    # test default behavior
    pp = AtomicNumberPreprocessor()
    dataset = zinc.get_zinc250k(preprocessor=pp)

    # --- Test that the dataset is correctly obtained ---
    index = numpy.random.choice(len(dataset), None)
    atoms, label = dataset[index]

    assert atoms.ndim == 1  # (atom, )
    assert atoms.dtype == numpy.int32
    assert label.ndim == 1
    assert label.shape[0] == ZINC250K_NUM_LABEL
    assert label.dtype == numpy.float32

    # --- Test the dataset size ---
    assert len(dataset) == ZINC250K_NUM_DATASET


def test_get_zinc_smiles():
    # test smiles extraction and dataset order
    pp = AtomicNumberPreprocessor()
    target_index = [0, 7777, 249454]  # set target_index for fast testing
    dataset, smiles = zinc.get_zinc250k(preprocessor=pp, return_smiles=True,
                                        target_index=target_index)

    # --- Test that the dataset is correctly obtained ---
    index = numpy.random.choice(len(dataset), None)
    atoms, label = dataset[index]

    assert atoms.ndim == 1  # (atom, )
    assert atoms.dtype == numpy.int32
    assert label.ndim == 1
    assert label.shape[0] == ZINC250K_NUM_LABEL
    assert label.dtype == numpy.float32

    # --- Test the dataset size ---
    assert len(dataset) == len(target_index)
    assert len(smiles) == len(target_index)

    # --- Test the order of the dataset ---
    assert smiles[0] == 'CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1'
    atoms0, labels0 = dataset[0]
    assert numpy.alltrue(atoms0 == numpy.array(
        [6, 6, 6, 6, 6, 6, 6, 6, 8, 6, 6, 6, 6, 8, 7, 6, 6, 6, 6, 6, 6, 9, 6,
         6], dtype=numpy.int32))
    assert numpy.alltrue(labels0 == numpy.array(
        [5.0506, 0.70201224, 2.0840945], dtype=numpy.float32))

    assert smiles[1] == 'CCCc1cc(NC(=O)Nc2ccc3c(c2)OCCO3)n(C)n1'
    atoms7777, labels7777 = dataset[1]
    assert numpy.alltrue(atoms7777 == numpy.array(
        [6, 6, 6, 6, 6, 6, 7, 6, 8, 7, 6, 6, 6, 6, 6, 6, 8, 6, 6, 8, 7, 6, 7],
        dtype=numpy.int32))
    assert numpy.alltrue(labels7777 == numpy.array(
        [2.7878, 0.9035222, 2.3195992], dtype=numpy.float32))

    assert smiles[2] == 'O=C(CC(c1ccccc1)c1ccccc1)N1CCN(S(=O)(=O)c2ccccc2[N+](=O)[O-])CC1'  # NOQA
    atoms249454, labels249454 = dataset[2]
    assert numpy.alltrue(atoms249454 == numpy.array(
        [8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7,
         6, 6, 7, 16, 8, 8, 6, 6, 6, 6, 6, 6, 7, 8, 8, 6, 6],
        dtype=numpy.int32))
    assert numpy.alltrue(labels249454 == numpy.array(
        [3.6499, 0.37028658, 2.2142494], dtype=numpy.float32))


def test_get_zinc_label_names():
    label_names = zinc.get_zinc250k_label_names()
    assert isinstance(label_names, list)
    for label in label_names:
        assert isinstance(label, str)


if __name__ == '__main__':
    args = [__file__, '-v', '-s']
    pytest.main(args=args)
```
Reviewer: Please add `@pytest.mark.slow`.
Author: This is intentional. I used `target_index`, so this test is not slow.
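The author's point is that passing `target_index` makes the parser process only the listed rows, so `test_get_zinc_smiles` stays fast without the slow marker. A standalone sketch of that extraction pattern, using a plain list in place of the real `CSVFileParser` (names here are illustrative, not the library's internals):

```python
# Hypothetical illustration of partial extraction via an index list,
# mirroring the target_index argument of get_zinc250k.
def extract_by_index(rows, target_index=None):
    if target_index is None:
        return list(rows)  # parse everything (the slow path)
    return [rows[i] for i in target_index]  # parse only selected rows

rows = ['mol0', 'mol1', 'mol2', 'mol3', 'mol4']
subset = extract_by_index(rows, target_index=[0, 3])
print(subset)  # ['mol0', 'mol3']
```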