Skip to content

Commit

Permalink
Changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bharath Ramsundar authored and Bharath Ramsundar committed Jul 16, 2020
1 parent 5d64e28 commit 89d8cd3
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
3 changes: 2 additions & 1 deletion deepchem/data/data_loader.py
Expand Up @@ -976,7 +976,8 @@ def create_dataset(self,
Parameters
----------
inputs: Sequence[Any]
List of inputs to process. Entries can be filenames or arbitrary objects.
List of inputs to process. Entries can be arbitrary objects so long as
they are understood by `self.featurizer`.
data_dir: str, optional
Directory to store featurized dataset.
shard_size: int, optional
Expand Down
7 changes: 6 additions & 1 deletion deepchem/data/datasets.py
Expand Up @@ -1943,7 +1943,12 @@ def __init__(self,
self._X_shape = self._find_array_shape(X)
self._y_shape = self._find_array_shape(y)
if w is None:
if len(self._y_shape) == 1:
if len(self._y_shape) == 0:
# y has an empty shape (a scalar), which is only valid when n_samples == 1
if n_samples != 1:
raise ValueError("y can only be a scalar if n_samples == 1")
w = np.ones_like(y)
elif len(self._y_shape) == 1:
w = np.ones(self._y_shape[0], np.float32)
else:
w = np.ones((self._y_shape[0], 1), np.float32)
Expand Down
16 changes: 10 additions & 6 deletions deepchem/utils/save.py
Expand Up @@ -11,7 +11,7 @@
import deepchem
import warnings
import logging
from typing import List, Optional, Iterator
from typing import List, Optional, Iterator, Any
from deepchem.utils.genomics import encode_bio_sequence as encode_sequence, encode_fasta_sequence as fasta_sequence, seq_one_hot_encode as seq_one_hotencode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -45,7 +45,8 @@ def get_input_type(input_file):
raise ValueError("Unrecognized extension %s" % file_extension)


def load_data(input_files, shard_size=None):
def load_data(input_files: List[str],
shard_size: Optional[int] = None) -> Iterator[Any]:
"""Loads data from disk.
For CSV files, supports sharded loading for large files.
Expand Down Expand Up @@ -77,7 +78,9 @@ def load_data(input_files, shard_size=None):
yield load_pickle_from_disk(input_file)


def load_sdf_files(input_files, clean_mols, tasks=[]):
def load_sdf_files(input_files: List[str],
clean_mols: bool = True,
tasks: List[str] = []) -> List[pd.DataFrame]:
"""Load SDF file into dataframe.
Parameters
Expand All @@ -99,7 +102,7 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
-------
dataframes: list
This function returns a list of pandas dataframes. Each dataframe will
columns `('mol_id', 'smiles', 'mol')`.
contain columns `('mol_id', 'smiles', 'mol')`.
"""
from rdkit import Chem
dataframes = []
Expand Down Expand Up @@ -130,12 +133,13 @@ def load_sdf_files(input_files, clean_mols, tasks=[]):
return dataframes


def load_csv_files(filenames, shard_size=None):
def load_csv_files(filenames: List[str],
shard_size: Optional[int] = None) -> Iterator[pd.DataFrame]:
"""Load data as pandas dataframe.
Parameters
----------
input_files: list[str]
filenames: list[str]
List of filenames
shard_size: int, optional (default None)
The shard size to yield at one time.
Expand Down

0 comments on commit 89d8cd3

Please sign in to comment.