Skip to content

Commit

Permalink
chatted with kyle and made change so that this patch does not break prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
apappu97 authored and swansonk14 committed Aug 16, 2020
1 parent 62e10dd commit 7ec25da
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def get_data(path: str,
features_generator: List[str] = None,
max_data_size: int = None,
store_row: bool = False,
logger: Logger = None) -> MoleculeDataset:
logger: Logger = None,
skip_none_targets: bool = False) -> MoleculeDataset:
"""
Gets SMILES and target values from a CSV file.
Expand All @@ -129,6 +130,8 @@ def get_data(path: str,
:param max_data_size: The maximum number of data points to load.
:param logger: A logger for recording output.
:param store_row: Whether to store the raw CSV row in each :class:`~chemprop.data.data.MoleculeDatapoint`.
:param skip_none_targets: Whether to skip data points whose targets are all None. This is mostly relevant when --target_columns
is passed in, so that only a subset of tasks is examined.
:return: A :class:`~chemprop.data.MoleculeDataset` containing SMILES and target values along
with other info such as additional features when desired.
"""
Expand Down Expand Up @@ -179,8 +182,8 @@ def get_data(path: str,

targets = [float(row[column]) if row[column] != '' else None for column in target_columns]

# Check whether all targets are None -- this is relevant when specifying target_columns
if all(x is None for x in targets):
if skip_none_targets and all(x is None for x in targets):
# Check whether all targets are None and skip if so
continue

all_smiles.append(smiles)
Expand Down
2 changes: 1 addition & 1 deletion chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def run_training(args: TrainArgs, logger: Logger = None) -> List[float]:

# Get data
debug('Loading data')
data = get_data(path=args.data_path, args=args, logger=logger)
data = get_data(path=args.data_path, args=args, logger=logger, skip_none_targets=True)
validate_dataset_type(data, dataset_type=args.dataset_type)
args.features_size = data.features_size()
debug(f'Number of tasks = {args.num_tasks}')
Expand Down

0 comments on commit 7ec25da

Please sign in to comment.