Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zinc 250k dataset support #276

Merged
merged 9 commits into from
Nov 28, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ currently supported:
| ------------------: | --------------: | -------------: |
| v0.1.0 ~ v0.3.0 | v2.0 ~ v3.0 | 2017.09.3.0 |
| v0.4.0 | v3.0 ~ v4.0 *1 | 2017.09.3.0 |
| master branch | v3.0 ~ v4.0 | 2017.09.3.0 |
| master branch | v3.0 ~ v5.0 | 2017.09.3.0 |

## Installation

Expand Down Expand Up @@ -90,6 +90,7 @@ The following datasets are currently supported:
- QM9 [7, 8]
- Tox21 [9]
- MoleculeNet [11]
- ZINC (only 250k dataset) [12, 13]
- User (own) dataset

## Research Projects
Expand Down Expand Up @@ -149,3 +150,7 @@ papers. Use the library at your own risk.
[10] Kipf, Thomas N. and Welling, Max. Semi-Supervised Classification with Graph Convolutional Networks. *International Conference on Learning Representations (ICLR)*, 2017.

[11] Zhenqin Wu, Bharath Ramsundar, Evan N. Feinberg, Joseph Gomes, Caleb Geniesse, Aneesh S. Pappu, Karl Leswing, Vijay Pande, MoleculeNet: A Benchmark for Molecular Machine Learning, arXiv preprint, arXiv: 1703.00564, 2017.

[12] J. J. Irwin, T. Sterling, M. M. Mysinger, E. S. Bolstad, and R. G. Coleman. ZINC: a free tool to discover chemistry for biology. *Journal of Chemical Information and Modeling*, 52(7):1757–1768, 2012.

[13] Preprocessed csv file downloaded from https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv
4 changes: 4 additions & 0 deletions chainer_chemistry/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from chainer_chemistry.datasets import molnet # NOQA
from chainer_chemistry.datasets import qm9 # NOQA
from chainer_chemistry.datasets import tox21 # NOQA
from chainer_chemistry.datasets import zinc # NOQA


# import class and function
Expand All @@ -11,3 +12,6 @@
from chainer_chemistry.datasets.tox21 import get_tox21 # NOQA
from chainer_chemistry.datasets.tox21 import get_tox21_filepath # NOQA
from chainer_chemistry.datasets.tox21 import get_tox21_label_names # NOQA
from chainer_chemistry.datasets.zinc import get_zinc250k # NOQA
from chainer_chemistry.datasets.zinc import get_zinc250k_filepath # NOQA
from chainer_chemistry.datasets.zinc import get_zinc250k_label_names # NOQA
5 changes: 3 additions & 2 deletions chainer_chemistry/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ def get_qm9_filepath(download_if_not_exist=True):
necessary.

Args:
config_name: either 'train', 'val', or 'test'
download_if_not_exist (bool): If `True` download dataset
if it is not downloaded yet.

Returns (str): filepath for qm9 dataset
Returns (str): file path for qm9 dataset (formatted to csv)

"""
cache_path = _get_qm9_filepath()
Expand Down
2 changes: 2 additions & 0 deletions chainer_chemistry/datasets/tox21.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def get_tox21_filepath(dataset_type, download_if_not_exist=True):
Args:
dataset_type: Name of the target dataset type.
Either 'train', 'val', or 'test'
download_if_not_exist (bool): If `True` download dataset
if it is not downloaded yet.

Returns (str): file path for tox21 dataset

Expand Down
114 changes: 114 additions & 0 deletions chainer_chemistry/datasets/zinc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from logging import getLogger
import os

from chainer.dataset import download
import numpy
import pandas

from chainer_chemistry.dataset.parsers.csv_file_parser import CSVFileParser
from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor # NOQA

# URL of the preprocessed ZINC 250k csv file that is fetched on first use.
download_url = 'https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv' # NOQA
# File name under which the cleaned csv is saved inside the cache directory.
file_name_250k = 'zinc250k.csv'

# Dataset directory name handed to ``download.get_dataset_directory``.
_root = 'pfnet/chainer/zinc'

# Columns written to the cached csv and used as the default target labels.
_label_names = ['logP', 'qed', 'SAS']
# Column(s) holding the SMILES strings; written ahead of the label columns.
_smiles_column_names = ['smiles']


def get_zinc250k_label_names():
    """Returns label names of the ZINC250k dataset.

    Returns (list): a new list holding the label names. A copy is returned
        so that a caller modifying the result cannot alter the
        module-level default labels.

    """
    return list(_label_names)


def get_zinc250k(preprocessor=None, labels=None, return_smiles=False,
                 target_index=None):
    """Downloads, caches and preprocesses the ZINC 250k dataset.

    Args:
        preprocessor (BasePreprocessor): Preprocessor.
            This should be chosen based on the network to be trained.
            If it is None, a default `AtomicNumberPreprocessor` is used.
        labels (str or list): Target label name or list of target label
            names. If it is None (or empty), every label returned by
            `get_zinc250k_label_names` is used.
        return_smiles (bool): If set to ``True``,
            the smiles array is also returned.
        target_index (list or None): target index list to partially extract
            the dataset. If None (default), all examples are parsed.

    Returns:
        The parsed dataset, whose features depend on `preprocessor`.
        When `return_smiles` is ``True`` a tuple ``(dataset, smiles)`` is
        returned instead.

    """
    def postprocess_label(label_list):
        # This is regression task, cast to float value.
        return numpy.asarray(label_list, dtype=numpy.float32)

    if not labels:
        labels = get_zinc250k_label_names()
    if isinstance(labels, str):
        # A single label name was given; wrap it into a list.
        labels = [labels]
    if preprocessor is None:
        preprocessor = AtomicNumberPreprocessor()

    csv_parser = CSVFileParser(
        preprocessor, postprocess_label=postprocess_label,
        labels=labels, smiles_col='smiles')
    parsed = csv_parser.parse(
        get_zinc250k_filepath(), return_smiles=return_smiles,
        target_index=target_index)

    dataset = parsed['dataset']
    if not return_smiles:
        return dataset
    return dataset, parsed['smiles']


def get_zinc250k_filepath(download_if_not_exist=True):
    """Return the file path of the cached ZINC250k csv.

    This method checks whether the cached file exists and, if it does not
    and `download_if_not_exist` is true, downloads the dataset first.

    Args:
        download_if_not_exist (bool): If `True` download dataset
            if it is not downloaded yet.

    Returns (str): file path for ZINC250k dataset (csv format). The path is
        returned whether or not the file exists afterwards.

    """
    cache_path = _get_zinc250k_filepath()
    if os.path.exists(cache_path) or not download_if_not_exist:
        # Already cached, or the caller asked not to download.
        return cache_path

    downloaded = download_and_extract_zinc250k(save_filepath=cache_path)
    if not downloaded:
        getLogger(__name__).warning('Download failed.')
    return cache_path


def _get_zinc250k_filepath():
    """Build the cache file path for the ZINC250k csv.

    The path is the dataset directory obtained for `_root` joined with
    `file_name_250k`. Whether the file has already been downloaded is not
    checked here.

    Returns (str): filepath for ZINC250k dataset

    """
    return os.path.join(
        download.get_dataset_directory(_root), file_name_250k)


def _remove_new_line(s):
    """Return `s` with every newline character removed."""
    # Splitting on '\n' and re-joining drops each newline occurrence.
    return ''.join(s.split('\n'))


def download_and_extract_zinc250k(save_filepath):
    """Download the ZINC250k csv and save a cleaned copy.

    The raw csv is fetched from `download_url` via
    `download.cached_download`, newlines are stripped from the 'smiles'
    column, and the smiles and label columns are written to
    `save_filepath`.

    Args:
        save_filepath (str): destination path of the cleaned csv file.

    Returns (bool): always `True` when the function returns.

    """
    getLogger(__name__).info('Extracting ZINC250k dataset...')
    raw_csv_path = download.cached_download(download_url)
    frame = pandas.read_csv(raw_csv_path)
    # 'smiles' column contains '\n', need to remove it.
    frame['smiles'] = frame['smiles'].apply(_remove_new_line)
    output_columns = _smiles_column_names + _label_names
    frame.to_csv(save_filepath, columns=output_columns)
    return True
113 changes: 113 additions & 0 deletions tests/datasets_tests/test_zinc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import os

import numpy
import pytest

from chainer_chemistry.dataset.preprocessors.atomic_number_preprocessor import AtomicNumberPreprocessor # NOQA
from chainer_chemistry.datasets import zinc


ZINC250K_NUM_LABEL = 3
ZINC250K_NUM_DATASET = 249455


def test_get_zinc_filepath_without_download():
    """With downloading disabled, a path is returned but no file is made."""
    stale_path = zinc.get_zinc250k_filepath(download_if_not_exist=False)
    if os.path.exists(stale_path):
        # ensure a cache file does not exist.
        os.remove(stale_path)

    result_path = zinc.get_zinc250k_filepath(download_if_not_exist=False)
    assert isinstance(result_path, str)
    assert not os.path.exists(result_path)


def test_get_zinc_filepath_with_download():
    """With downloading enabled, the returned path points to a file."""
    stale_path = zinc.get_zinc250k_filepath(download_if_not_exist=False)
    if os.path.exists(stale_path):
        # ensure a cache file does not exist.
        os.remove(stale_path)

    # This call downloads the file because it does not exist any more.
    result_path = zinc.get_zinc250k_filepath(download_if_not_exist=True)
    assert isinstance(result_path, str)
    assert os.path.exists(result_path)


@pytest.mark.slow
def test_get_zinc():
    """Load the whole dataset and check one random example and the size."""
    # test default behavior
    dataset = zinc.get_zinc250k(preprocessor=AtomicNumberPreprocessor())

    # --- Test dataset is correctly obtained ---
    picked = numpy.random.choice(len(dataset), None)
    atoms, label = dataset[picked]

    # Atom array: one-dimensional int32, shape (atom, ).
    assert atoms.ndim == 1
    assert atoms.dtype == numpy.int32
    # Label array: one float32 value per label.
    assert label.ndim == 1
    assert label.shape[0] == ZINC250K_NUM_LABEL
    assert label.dtype == numpy.float32

    # --- Test number of dataset ---
    assert len(dataset) == ZINC250K_NUM_DATASET


def test_get_zinc_smiles():
    """Check smiles extraction and dataset order on three examples.

    This test is intentionally not marked with ``pytest.mark.slow``:
    `target_index` restricts parsing to three rows, so it is not slow
    (this was discussed and agreed in review).
    """
    # test smiles extraction and dataset order
    pp = AtomicNumberPreprocessor()
    target_index = [0, 7777, 249454]  # set target_index for fast testing...
    dataset, smiles = zinc.get_zinc250k(preprocessor=pp, return_smiles=True,
                                        target_index=target_index)

    # --- Test dataset is correctly obtained ---
    index = numpy.random.choice(len(dataset), None)
    atoms, label = dataset[index]

    assert atoms.ndim == 1  # (atom, )
    assert atoms.dtype == numpy.int32
    assert label.ndim == 1  # (label, )
    assert label.shape[0] == ZINC250K_NUM_LABEL
    assert label.dtype == numpy.float32

    # --- Test number of dataset ---
    assert len(dataset) == len(target_index)
    assert len(smiles) == len(target_index)

    # --- Test order of dataset ---
    # One (smiles, atomic numbers, labels) entry per element of
    # `target_index`, in the same order.
    expected_examples = [
        ('CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1',
         [6, 6, 6, 6, 6, 6, 6, 6, 8, 6, 6, 6, 6, 8, 7, 6, 6, 6, 6, 6, 6, 9,
          6, 6],
         [5.0506, 0.70201224, 2.0840945]),
        ('CCCc1cc(NC(=O)Nc2ccc3c(c2)OCCO3)n(C)n1',
         [6, 6, 6, 6, 6, 6, 7, 6, 8, 7, 6, 6, 6, 6, 6, 6, 8, 6, 6, 8, 7, 6,
          7],
         [2.7878, 0.9035222, 2.3195992]),
        ('O=C(CC(c1ccccc1)c1ccccc1)N1CCN(S(=O)(=O)c2ccccc2[N+](=O)[O-])CC1',
         [8, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7,
          6, 6, 7, 16, 8, 8, 6, 6, 6, 6, 6, 6, 7, 8, 8, 6, 6],
         [3.6499, 0.37028658, 2.2142494]),
    ]
    for position, example in enumerate(expected_examples):
        expected_smiles, expected_atoms, expected_labels = example
        assert smiles[position] == expected_smiles
        actual_atoms, actual_labels = dataset[position]
        # `numpy.alltrue` was removed in NumPy 2.0; `array_equal` exists in
        # old and new releases and also fails cleanly on a shape mismatch.
        assert numpy.array_equal(
            actual_atoms, numpy.array(expected_atoms, dtype=numpy.int32))
        assert numpy.array_equal(
            actual_labels, numpy.array(expected_labels, dtype=numpy.float32))


def test_get_zinc_label_names():
    """The label names must be a list made up of strings only."""
    names = zinc.get_zinc250k_label_names()
    assert isinstance(names, list)
    assert all(isinstance(name, str) for name in names)


if __name__ == '__main__':
    # Run this file's tests verbosely, without capturing output.
    pytest.main(args=[__file__, '-v', '-s'])