Skip to content

Commit

Permalink
Merge pull request #1998 from ncfrey/json_loaders
Browse files Browse the repository at this point in the history
JSON file support
  • Loading branch information
Bharath Ramsundar committed Jul 15, 2020
2 parents d941e08 + b7e2a8a commit 5e025e4
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 3 deletions.
1 change: 1 addition & 0 deletions deepchem/data/__init__.py
Expand Up @@ -14,6 +14,7 @@
from deepchem.data.data_loader import DataLoader
from deepchem.data.data_loader import CSVLoader
from deepchem.data.data_loader import UserCSVLoader
from deepchem.data.data_loader import JsonLoader
from deepchem.data.data_loader import SDFLoader
from deepchem.data.data_loader import FASTALoader
from deepchem.data.data_loader import ImageLoader
192 changes: 190 additions & 2 deletions deepchem/data/data_loader.py
Expand Up @@ -12,10 +12,12 @@
import sys
import logging
import warnings
from deepchem.utils.save import load_csv_files
from typing import List, Optional, Dict, Tuple

from deepchem.utils.save import load_csv_files, load_json_files
from deepchem.utils.save import load_sdf_files
from deepchem.utils.genomics import encode_fasta_sequence
from deepchem.feat import UserDefinedFeaturizer
from deepchem.feat import UserDefinedFeaturizer, Featurizer
from deepchem.data import DiskDataset, NumpyDataset, ImageDataset
import zipfile

Expand Down Expand Up @@ -437,6 +439,192 @@ def _featurize_shard(self, shard):
return (X, np.ones(len(X), dtype=bool))


class JsonLoader(DataLoader):
    """Creates `Dataset` objects from input json files.

    This class provides conveniences to load data from json files.
    It's possible to directly featurize data from json files using
    pandas, but this class may prove useful if you're processing
    large json files that you don't want to manipulate directly in
    memory.

    It is meant to load JSON files formatted as "records" in line
    delimited format (one ``{column -> value}`` object per line),
    which is what allows the file to be read shard by shard.

    Examples
    --------
    >> import pandas as pd
    >> df = pd.DataFrame(some_data)
    >> df.columns.tolist()
    .. ['sample_data', 'sample_name', 'weight', 'task']
    >> df.to_json('file.json', orient='records', lines=True)
    >> loader = JsonLoader(tasks=['task'], feature_field='sample_data',
         label_field='task', weight_field='weight', id_field='sample_name')
    >> dataset = loader.create_dataset('file.json')
    """

    def __init__(self,
                 tasks: List[str],
                 feature_field: str,
                 label_field: Optional[str] = None,
                 weight_field: Optional[str] = None,
                 id_field: Optional[str] = None,
                 featurizer: Optional[Featurizer] = None,
                 log_every_n: int = 1000):
        """Initializes JsonLoader.

        Parameters
        ----------
        tasks : List[str]
            List of task names
        feature_field : str
            JSON field with data to be featurized.
        label_field : str, default None
            Field with target variables.
        weight_field : str, default None
            Field with weights.
        id_field : str, default None
            Field for identifying samples.
        featurizer : dc.feat.Featurizer, optional
            Featurizer to use to process data
        log_every_n : int, optional
            Writes a logging statement this often.

        Raises
        ------
        ValueError
            If `tasks` is not a list.
        """
        if not isinstance(tasks, list):
            raise ValueError("Tasks must be a list.")
        self.tasks = tasks
        self.feature_field = feature_field
        self.label_field = label_field
        self.weight_field = weight_field
        self.id_field = id_field

        # Only a UserDefinedFeaturizer carries a list of feature fields.
        self.user_specified_features = None
        if isinstance(featurizer, UserDefinedFeaturizer):
            self.user_specified_features = featurizer.feature_fields
        self.featurizer = featurizer
        self.log_every_n = log_every_n

    def create_dataset(self,
                       input_files: List[str],
                       data_dir: Optional[str] = None,
                       shard_size: Optional[int] = 8192) -> DiskDataset:
        """Creates a `Dataset` from input JSON files.

        Parameters
        ----------
        input_files: List[str]
            List of JSON filenames. A single filename (not wrapped in a
            list) is also accepted.
        data_dir: Optional[str], default None
            Name of directory where featurized data is stored.
        shard_size: Optional[int], default 8192
            Shard size when loading data.

        Returns
        -------
        dataset: dc.data.Dataset
            A `Dataset` object containing a featurized representation of
            data from `input_files`.
        """
        if not isinstance(input_files, list):
            input_files = [input_files]

        def shard_generator():
            """Yield X, y, w, and ids for shards."""
            for shard_num, shard in enumerate(
                    self._get_shards(input_files, shard_size)):

                time1 = time.time()
                X, valid_inds = self._featurize_shard(shard)
                if self.id_field:
                    ids = shard[self.id_field].values
                else:
                    # Placeholder ids must have one entry per *shard row*:
                    # X only holds the successfully featurized rows, so
                    # sizing by len(X) would not line up with the mask.
                    ids = np.ones(len(shard))
                ids = ids[valid_inds]

                if len(self.tasks) > 0:
                    # Featurize task results if they exist.
                    y, w = _convert_df_to_numpy(shard, self.tasks)

                    # Explicit label/weight columns override the
                    # task-derived values; hand on plain arrays rather
                    # than pandas Series.
                    if self.label_field:
                        y = np.asarray(shard[self.label_field])
                    if self.weight_field:
                        w = np.asarray(shard[self.weight_field])

                    # Filter out examples where featurization failed.
                    y, w = (y[valid_inds], w[valid_inds])
                    assert len(X) == len(ids) == len(y) == len(w)
                else:
                    # For prospective data where results are unknown, it
                    # makes no sense to have y values or weights.
                    y, w = (None, None)
                    assert len(X) == len(ids)

                time2 = time.time()
                logger.info("TIMING: featurizing shard %d took %0.3f s",
                            shard_num, time2 - time1)
                yield X, y, w, ids

        return DiskDataset.create_dataset(shard_generator(), data_dir)

    def _get_shards(self, input_files, shard_size):
        """Defines a generator which returns data for each shard"""
        return load_json_files(input_files, shard_size)

    def _featurize_shard(self, shard):
        """Featurizes a shard of an input dataframe."""
        return self._featurize_df(
            shard, self.featurizer, log_every_n=self.log_every_n)

    def _featurize_df(self,
                      shard,
                      featurizer: Featurizer,
                      log_every_n: int = 1000
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Featurize individual samples in dataframe.

        Each entry of the `feature_field` column is passed to the
        featurizer on its own; entries for which the featurizer returns
        an empty array are marked invalid and dropped from `features`.

        Parameters
        ----------
        shard: pd.DataFrame
            DataFrame that holds data to be featurized.
        featurizer: Featurizer
            An instance of `dc.feat.Featurizer`.
        log_every_n: int, optional (default 1000)
            Emit a logging statement every `log_every_n` rows.

        Returns
        -------
        features : np.ndarray
            Array of feature vectors, one per valid sample.
        valid_inds : np.ndarray
            Boolean mask over the rows of `shard` indicating successful
            featurization.

        Raises
        ------
        ValueError
            If no featurizer was supplied.
        """
        if featurizer is None:
            raise ValueError(
                "JsonLoader requires a featurizer to featurize data.")

        features = []
        valid_inds = []
        data = shard[self.feature_field].tolist()

        for idx, datapoint in enumerate(data):
            if log_every_n and idx % log_every_n == 0:
                logger.info("Featurizing sample %d", idx)
            feat = featurizer.featurize([datapoint])
            is_valid = feat.size > 0
            valid_inds.append(is_valid)
            if is_valid:
                features.append(feat)

        mask = np.array(valid_inds, dtype=bool)
        if not features:
            # Nothing featurized: np.array([]) is 1-D, so squeezing
            # axis 1 would raise. Return the empty array as is.
            return np.array(features), mask

        # Each `feat` came from a one-element list, so it has a leading
        # axis of length 1 that is removed after stacking.
        return np.squeeze(np.array(features), axis=1), mask


class SDFLoader(DataLoader):
"""
Creates `Dataset` from SDF input files.
Expand Down
5 changes: 5 additions & 0 deletions deepchem/data/tests/inorganic_crystal_sample_data.json
@@ -0,0 +1,5 @@
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[3.9545311068,0.0,0.0],[0.0,3.9545311068,0.0],[0.0,0.0,3.9545311068]],"a":3.9545311068,"b":3.9545311068,"c":3.9545311068,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":61.8422081649},"sites":[{"species":[{"element":"Rh","occu":1}],"abc":[0.0,0.0,0.0],"xyz":[0.0,0.0,0.0],"label":"Rh","properties":{}},{"species":[{"element":"Te","occu":1}],"abc":[0.5,0.5,0.5],"xyz":[1.9772655534,1.9772655534,1.9772655534],"label":"Te","properties":{}},{"species":[{"element":"N","occu":1}],"abc":[0.5,0.0,0.5],"xyz":[1.9772655534,0.0,1.9772655534],"label":"N","properties":{}},{"species":[{"element":"N","occu":1}],"abc":[0.5,0.5,0.0],"xyz":[1.9772655534,1.9772655534,0.0],"label":"N","properties":{}},{"species":[{"element":"N","occu":1}],"abc":[0.0,0.5,0.5],"xyz":[0.0,1.9772655534,1.9772655534],"label":"N","properties":{}}]},"e_form":2.16,"formula":"TeRhN3"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.2894318978,0.0,0.0],[0.0,4.2894318978,0.0],[0.0,0.0,4.2894318978]],"a":4.2894318978,"b":4.2894318978,"c":4.2894318978,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":78.9222269246},"sites":[{"species":[{"element":"Hf","occu":1}],"abc":[0.5922504528,0.0,0.0],"xyz":[2.5404179838,0.0,0.0],"label":"Hf","properties":{}},{"species":[{"element":"Te","occu":1}],"abc":[0.2378848852,0.5,0.5],"xyz":[1.0203910146,2.1447159489,2.1447159489],"label":"Te","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.5012320713,0.0,0.5],"xyz":[2.1500008347,0.0,2.1447159489],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.5012320713,0.5,0.0],"xyz":[2.1500008347,2.1447159489,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.7980811547,0.5,0.5],"xyz":[3.4233147622,2.1447159489,2.1447159489],"label":"O","properties":{}}]},"e_form":1.52,"formula":"HfTeO3"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.2926387638,0.0,0.0],[0.0,4.2926387638,0.0],[0.0,0.0,4.2926387638]],"a":4.2926387638,"b":4.2926387638,"c":4.2926387638,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":79.0993708544},"sites":[{"species":[{"element":"Re","occu":1}],"abc":[0.1416166515,0.0,0.0],"xyz":[0.6079091278,0.0,0.0],"label":"Re","properties":{}},{"species":[{"element":"As","occu":1}],"abc":[0.5093856748,0.5,0.5],"xyz":[2.1866086932,2.1463193819,2.1463193819],"label":"As","properties":{}},{"species":[{"element":"F","occu":1}],"abc":[0.5316865005,0.0,0.5],"xyz":[2.2823380822,0.0,2.1463193819],"label":"F","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.3074869463,0.5,0.0],"xyz":[1.319930385,2.1463193819,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.927582418,0.5,0.5],"xyz":[3.9817762444,2.1463193819,2.1463193819],"label":"O","properties":{}}]},"e_form":1.48,"formula":"ReAsO2F"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.1837305646,0.0,0.0],[0.0,4.1837305646,0.0],[0.0,0.0,4.1837305646]],"a":4.1837305646,"b":4.1837305646,"c":4.1837305646,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":73.2303523231},"sites":[{"species":[{"element":"W","occu":1}],"abc":[0.676648156,0.0,0.0],"xyz":[2.8309135716,0.0,0.0],"label":"W","properties":{}},{"species":[{"element":"Re","occu":1}],"abc":[0.6351628832,0.5,0.5],"xyz":[2.6573503678,2.0918652823,2.0918652823],"label":"Re","properties":{}},{"species":[{"element":"S","occu":1}],"abc":[0.3728524724,0.0,0.5],"xyz":[1.5599142849,0.0,2.0918652823],"label":"S","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.7238489421,0.5,0.0],"xyz":[3.0283889434,2.0918652823,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.0978520248,0.5,0.5],"xyz":[0.4093865068,2.0918652823,2.0918652823],"label":"O","properties":{}}]},"e_form":1.24,"formula":"ReWSO2"}
{"structure":{"@module":"pymatgen.core.structure","@class":"Structure","charge":null,"lattice":{"matrix":[[4.2811442539,0.0,0.0],[0.0,4.2811442539,0.0],[0.0,0.0,4.2811442539]],"a":4.2811442539,"b":4.2811442539,"c":4.2811442539,"alpha":90.0,"beta":90.0,"gamma":90.0,"volume":78.4656515166},"sites":[{"species":[{"element":"Bi","occu":1}],"abc":[0.0012121467,0.0,0.0],"xyz":[0.0051893747,0.0,0.0],"label":"Bi","properties":{}},{"species":[{"element":"Hf","occu":1}],"abc":[0.5074940801,0.5,0.5],"xyz":[2.1726553651,2.140572127,2.140572127],"label":"Hf","properties":{}},{"species":[{"element":"F","occu":1}],"abc":[0.4990106707,0.0,0.5],"xyz":[2.1363366656,0.0,2.140572127],"label":"F","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.499996373,0.5,0.0],"xyz":[2.1405565992,2.140572127,0.0],"label":"O","properties":{}},{"species":[{"element":"O","occu":1}],"abc":[0.002611863,0.5,0.5],"xyz":[0.0111817624,2.140572127,2.140572127],"label":"O","properties":{}}]},"e_form":0.62,"formula":"HfBiO2F"}
45 changes: 45 additions & 0 deletions deepchem/data/tests/test_json_loader.py
@@ -0,0 +1,45 @@
"""
Tests for JsonLoader class.
"""

import os
import unittest
import tempfile
import shutil
import numpy as np
import deepchem as dc
from deepchem.data.data_loader import JsonLoader
from deepchem.feat.materials_featurizers import SineCoulombMatrix


class TestJsonLoader(unittest.TestCase):
    """Exercises JsonLoader against the bundled crystal-structure sample."""

    def setUp(self):
        """Remember the directory that holds this test module."""
        super().setUp()
        here = os.path.abspath(__file__)
        self.current_dir = os.path.dirname(here)

    def test_json_loader(self):
        """Load the sample file with several shard sizes and check X."""
        sample_path = os.path.join(self.current_dir,
                                   'inorganic_crystal_sample_data.json')
        loader = JsonLoader(
            tasks=['e_form'],
            feature_field='structure',
            id_field='formula',
            label_field='e_form',
            featurizer=SineCoulombMatrix(max_atoms=5))

        # One record per shard.
        one_per_shard = loader.create_dataset(sample_path, shard_size=1)
        expected_first = [
            4625.32086965, 6585.20209678, 61.00680193, 48.72230922,
            48.72230922
        ]
        assert one_per_shard.X.shape == (5, 1, 5)
        assert np.allclose(one_per_shard.X[0][0], expected_first, atol=.5)

        # Whole file read in a single shard.
        unsharded = loader.create_dataset(sample_path, shard_size=None)
        assert unsharded.X.shape == (5, 1, 5)

        # Two input files, each read as one five-record shard.
        doubled = loader.create_dataset(
            [sample_path, sample_path], shard_size=5)
        assert doubled.X.shape == (10, 1, 5)
assert dataset.X.shape == (10, 1, 5)
2 changes: 1 addition & 1 deletion deepchem/feat/materials_featurizers.py
Expand Up @@ -153,7 +153,7 @@ def _featurize(self, struct: "pymatgen.Structure"):

if self.flatten:
eigs, _ = np.linalg.eig(sine_mat)
zeros = np.zeros((self.max_atoms,))
zeros = np.zeros((1, self.max_atoms))
zeros[:len(eigs)] = eigs
features = zeros
else:
Expand Down
43 changes: 43 additions & 0 deletions deepchem/utils/save.py
Expand Up @@ -10,8 +10,12 @@
import os
import deepchem
import warnings
import logging
from typing import List, Optional, Iterator
from deepchem.utils.genomics import encode_bio_sequence as encode_sequence, encode_fasta_sequence as fasta_sequence, seq_one_hot_encode as seq_one_hotencode

logger = logging.getLogger(__name__)


def log(string, verbose=True):
"""Print string if verbose."""
Expand Down Expand Up @@ -116,6 +120,45 @@ def load_csv_files(filenames, shard_size=None, verbose=True):
yield df


def load_json_files(filenames: List[str],
                    shard_size: Optional[int] = None
                   ) -> Iterator[pd.DataFrame]:
    """Load data as pandas dataframe.

    Parameters
    ----------
    filenames : List[str]
        List of json filenames. A single filename (string or path-like)
        is also accepted and treated as a one-element list.
    shard_size : int, optional
        Chunksize for reading json files. If None, each file is read
        whole and yielded as a single dataframe.

    Yields
    ------
    df : pandas.DataFrame
        Shard of dataframe.

    Notes
    -----
    To load shards from a json file into a Pandas dataframe, the file
    must be originally saved with
    ``df.to_json('filename.json', orient='records', lines=True)``

    NaN values are replaced by empty strings only when `shard_size` is
    given; whole-file reads are yielded exactly as pandas parsed them.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(filenames, (str, os.PathLike)):
        filenames = [filenames]

    shard_num = 1
    for filename in filenames:
        if shard_size is None:
            yield pd.read_json(filename, orient='records', lines=True)
            continue

        logger.info("About to start loading json from %s.", filename)
        reader = pd.read_json(
            filename, orient='records', chunksize=shard_size, lines=True)
        try:
            for df in reader:
                logger.info("Loading shard %d of size %s.", shard_num,
                            str(shard_size))
                df = df.replace(np.nan, str(""), regex=True)
                shard_num += 1
                yield df
        finally:
            # Release the file handle even if the consumer stops
            # iterating before the file is exhausted.
            reader.close()


def seq_one_hot_encode(sequences, letters='ATCGN'):
"""One hot encodes list of genomic sequences.
Expand Down
14 changes: 14 additions & 0 deletions docs/dataloaders.rst
Expand Up @@ -22,6 +22,20 @@ UserCSVLoader
.. autoclass:: deepchem.data.UserCSVLoader
:members:

JsonLoader
^^^^^^^^^^
JSON is a flexible file format that is human-readable, lightweight,
and more compact than other open standard formats like XML. JSON files
are similar to python dictionaries of key-value pairs. All keys must
be strings, but values can be any of (string, number, object, array,
boolean, or null), so the format is more flexible than CSV. JSON is
used for describing structured data and to serialize objects. It is
conveniently used to read/write Pandas dataframes with the
`pandas.read_json` function and the `pandas.DataFrame.to_json` method.

.. autoclass:: deepchem.data.JsonLoader
:members:

FASTALoader
^^^^^^^^^^^

Expand Down
2 changes: 2 additions & 0 deletions docs/utils.rst
Expand Up @@ -54,6 +54,8 @@ File Handling

.. autofunction:: deepchem.utils.save.load_csv_files

.. autofunction:: deepchem.utils.save.load_json_files

.. autofunction:: deepchem.utils.save.save_metadata

.. autofunction:: deepchem.utils.save.load_from_disk
Expand Down

0 comments on commit 5e025e4

Please sign in to comment.