Skip to content

Commit

Permalink
Avoid saving CSV row where not necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Jul 19, 2020
1 parent e979564 commit bcb402b
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion chemprop/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self,

self.smiles = smiles
self.targets = targets
self.row = row or OrderedDict()
self.row = row
self.features = features
self.features_generator = features_generator
self._mol = 'None' # Initialize with 'None' to distinguish between None returned by invalid molecule
Expand Down
12 changes: 8 additions & 4 deletions chemprop/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def get_data(path: str,
features_path: List[str] = None,
features_generator: List[str] = None,
max_data_size: int = None,
store_row: bool = False,
logger: Logger = None) -> MoleculeDataset:
"""
Gets SMILES and target values from a CSV file.
Expand All @@ -127,6 +128,7 @@ def get_data(path: str,
in place of :code:`args.features_generator`.
:param max_data_size: The maximum number of data points to load.
:param logger: A logger for recording output.
:param store_row: Whether to store the raw CSV row in each :class:`~chemprop.data.data.MoleculeDatapoint`.
:return: A :class:`~chemprop.data.MoleculeDataset` containing SMILES and target values along
with other info such as additional features when desired.
"""
Expand Down Expand Up @@ -179,7 +181,9 @@ def get_data(path: str,

all_smiles.append(smiles)
all_targets.append(targets)
all_rows.append(row)

if store_row:
all_rows.append(row)

if len(all_smiles) >= max_data_size:
break
Expand All @@ -188,11 +192,11 @@ def get_data(path: str,
MoleculeDatapoint(
smiles=smiles,
targets=targets,
row=row,
row=all_rows[i] if store_row else None,
features_generator=features_generator,
features=features_data[i] if features_data is not None else None
) for i, (smiles, targets, row) in tqdm(enumerate(zip(all_smiles, all_targets, all_rows)),
total=len(all_smiles))
) for i, (smiles, targets) in tqdm(enumerate(zip(all_smiles, all_targets)),
total=len(all_smiles))
])

# Filter out invalid SMILES
Expand Down
7 changes: 6 additions & 1 deletion chemprop/sklearn_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ def predict_sklearn(args: SklearnPredictArgs) -> None:
loading data, loading a trained scikit-learn model, and making predictions with the model.
"""
print('Loading data')
data = get_data(path=args.test_path, smiles_column=args.smiles_column, target_columns=[])
data = get_data(
path=args.test_path,
smiles_column=args.smiles_column,
target_columns=[],
store_row=True
)

print('Loading training arguments')
with open(args.checkpoint_paths[0], 'rb') as f:
Expand Down
14 changes: 12 additions & 2 deletions chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,19 @@ def make_predictions(args: PredictArgs, smiles: List[str] = None) -> List[List[O

print('Loading data')
if smiles is not None:
full_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False, features_generator=args.features_generator)
full_data = get_data_from_smiles(
smiles=smiles,
skip_invalid_smiles=False,
features_generator=args.features_generator
)
else:
full_data = get_data(path=args.test_path, args=args, target_columns=[], skip_invalid_smiles=False)
full_data = get_data(
path=args.test_path,
args=args,
target_columns=[],
skip_invalid_smiles=False,
store_row=True
)

print('Validating SMILES')
full_to_valid_indices = {}
Expand Down

0 comments on commit bcb402b

Please sign in to comment.