Skip to content

Commit

Permalink
Option to drop extra columns during prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Jan 2, 2021
1 parent 206950c commit 83ea4c0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 2 additions & 0 deletions chemprop/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ class PredictArgs(CommonArgs):
"""Path to CSV file containing testing data for which predictions will be made."""
preds_path: str
"""Path to CSV file where predictions will be saved."""
drop_extra_columns: bool = False
"""Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns."""

@property
def ensemble_size(self) -> int:
Expand Down
14 changes: 13 additions & 1 deletion chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import OrderedDict
import csv
from typing import List, Optional, Union

Expand Down Expand Up @@ -50,7 +51,7 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[
)
else:
full_data = get_data(path=args.test_path, target_columns=[], ignore_columns=[], skip_invalid_smiles=False,
args=args, store_row=True)
args=args, store_row=not args.drop_extra_columns)

print('Validating SMILES')
full_to_valid_indices = {}
Expand Down Expand Up @@ -120,6 +121,17 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[
valid_index = full_to_valid_indices.get(full_index, None)
preds = avg_preds[valid_index] if valid_index is not None else ['Invalid SMILES'] * len(task_names)

# If extra columns have been dropped, add back in SMILES columns
if args.drop_extra_columns:
datapoint.row = OrderedDict()

if len(datapoint.smiles) == 1:
datapoint.row['smiles'] = datapoint.smiles[0]
else:
for i, smiles in enumerate(datapoint.smiles):
datapoint.row[f'smiles_{i}'] = smiles

# Add predictions columns
for pred_name, pred in zip(task_names, preds):
datapoint.row[pred_name] = pred

Expand Down

0 comments on commit 83ea4c0

Please sign in to comment.