Skip to content

Commit

Permalink
Copy smiles column names when dropping extra columns during prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
swansonk14 committed Jan 3, 2021
1 parent 0065637 commit 0613395
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions chemprop/train/make_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .predict import predict
from chemprop.args import PredictArgs, TrainArgs
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset
from chemprop.data import get_data, get_data_from_smiles, get_header, MoleculeDataLoader, MoleculeDataset
from chemprop.utils import load_args, load_checkpoint, load_scalers, makedirs, timeit


Expand Down Expand Up @@ -125,11 +125,13 @@ def make_predictions(args: PredictArgs, smiles: List[List[str]] = None) -> List[
if args.drop_extra_columns:
datapoint.row = OrderedDict()

if len(datapoint.smiles) == 1:
datapoint.row['smiles'] = datapoint.smiles[0]
else:
for i, smiles in enumerate(datapoint.smiles):
datapoint.row[f'smiles_{i}'] = smiles
smiles_columns = args.smiles_columns

if None in smiles_columns:
smiles_columns = get_header(args.test_path)[:len(smiles_columns)]

for column, smiles in zip(smiles_columns, datapoint.smiles):
datapoint.row[column] = smiles

# Add predictions columns
for pred_name, pred in zip(task_names, preds):
Expand Down

0 comments on commit 0613395

Please sign in to comment.