Skip to content

Commit

Permalink
Merge branch 'master' of github.com:chemprop/chemprop
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Aug 16, 2020
2 parents acc3d43 + 7ec25da commit c5d3545
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
9 changes: 8 additions & 1 deletion chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def get_data(path: str,
features_generator: List[str] = None,
max_data_size: int = None,
store_row: bool = False,
logger: Logger = None) -> MoleculeDataset:
logger: Logger = None,
skip_none_targets: bool = False) -> MoleculeDataset:
"""
Gets SMILES and target values from a CSV file.
Expand All @@ -129,6 +130,8 @@ def get_data(path: str,
:param max_data_size: The maximum number of data points to load.
:param logger: A logger for recording output.
:param store_row: Whether to store the raw CSV row in each :class:`~chemprop.data.data.MoleculeDatapoint`.
:param skip_none_targets: Whether to skip targets that are all 'None'. This is mostly relevant when --target_columns
are passed in, so only a subset of tasks are examined.
:return: A :class:`~chemprop.data.MoleculeDataset` containing SMILES and target values along
with other info such as additional features when desired.
"""
Expand Down Expand Up @@ -179,6 +182,10 @@ def get_data(path: str,

targets = [float(row[column]) if row[column] != '' else None for column in target_columns]

if skip_none_targets and all(x is None for x in targets):
# Check whether all targets are None and skip if so
continue

all_smiles.append(smiles)
all_targets.append(targets)

Expand Down
2 changes: 1 addition & 1 deletion chemprop/train/run_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def run_training(args: TrainArgs, logger: Logger = None) -> List[float]:

# Get data
debug('Loading data')
data = get_data(path=args.data_path, args=args, logger=logger)
data = get_data(path=args.data_path, args=args, logger=logger, skip_none_targets=True)
validate_dataset_type(data, dataset_type=args.dataset_type)
args.features_size = data.features_size()
debug(f'Number of tasks = {args.num_tasks}')
Expand Down

0 comments on commit c5d3545

Please sign in to comment.