Skip to content

Commit

Permalink
Multimolecule features (#242)
Browse files Browse the repository at this point in the history
* Allows fingerprinting of single-molecule lists when using a multi-molecule model within limitations. Added fingerprinting and multi-molecule classification tests. Fingerprinting output now differentiates between molecules for multi-molecule input, rather than treating the fingerprints of all molecules for a row as a single set of fingerprints from a single molecule.

* Added multimolecule classification test file

* Fixed column ordering for fingerprint outout

* Allows number-of-molecules mismatch between fingerprinting data and multi-molecule models only when the fingerprinting data has one molecule per entry.

* Fixed accidentally hardcoded value for the number of molecules in a fingerprint output file.

* Update tests/test_integration.py

Co-authored-by: Charles McGill <44245643+cjmcgill@users.noreply.github.com>

* Updated single task fingerprinting test to aggregrate fingerprint values and compare sum rather than comparing against a reference true_fingerprint file

Co-authored-by: Charles McGill <44245643+cjmcgill@users.noreply.github.com>
  • Loading branch information
colinrsmall and cjmcgill committed Mar 15, 2022
1 parent ac6bc9f commit 0f04bfa
Show file tree
Hide file tree
Showing 4 changed files with 445 additions and 12 deletions.
26 changes: 19 additions & 7 deletions chemprop/train/molecule_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,25 @@ def molecule_fingerprint(args: FingerprintArgs, smiles: List[List[str]] = None)

# Set column names
fingerprint_columns = []
if len(args.checkpoint_paths) == 1:
for j in range(total_fp_size):
fingerprint_columns.append(f'fp_{j}')
else:
for j in range(total_fp_size):
for i in range(len(args.checkpoint_paths)):
fingerprint_columns.append(f'fp_{j}_model_{i}')
if args.fingerprint_type == 'MPN':
if len(args.checkpoint_paths) == 1:
for j in range(total_fp_size//args.number_of_molecules):
for k in range(args.number_of_molecules):
fingerprint_columns.append(f'fp_{j}_mol_{k}')
else:
for j in range(total_fp_size//args.number_of_molecules):
for i in range(len(args.checkpoint_paths)):
for k in range(args.number_of_molecules):
fingerprint_columns.append(f'fp_{j}_mol_{k}_model_{i}')

else: # args == 'last_FNN'
if len(args.checkpoint_paths) == 1:
for j in range(total_fp_size):
fingerprint_columns.append(f'fp_{j}')
else:
for j in range(total_fp_size):
for i in range(len(args.checkpoint_paths)):
fingerprint_columns.append(f'fp_{j}_model_{i}')

# Copy predictions over to full_data
for full_index, datapoint in enumerate(full_data):
Expand Down
9 changes: 5 additions & 4 deletions chemprop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

from chemprop.args import PredictArgs, TrainArgs
from chemprop.args import PredictArgs, TrainArgs, FingerprintArgs
from chemprop.data import StandardScaler, MoleculeDataset, preprocess_smiles_columns, get_task_names
from chemprop.models import MoleculeModel
from chemprop.nn_utils import NoamLR
Expand Down Expand Up @@ -577,10 +577,11 @@ def update_prediction_args(predict_args: PredictArgs,
setattr(predict_args,key,override_defaults.get(key,value))

# Same number of molecules must be used in training as in making predictions
if train_args.number_of_molecules != predict_args.number_of_molecules:
if train_args.number_of_molecules != predict_args.number_of_molecules \
and not (isinstance(predict_args, FingerprintArgs) and predict_args.fingerprint_type == "MPN" and predict_args.mpn_shared and predict_args.number_of_molecules == 1):
raise ValueError('A different number of molecules was used in training '
f'model than is specified for prediction, {train_args.number_of_molecules} '
'smiles fields must be provided')
'model than is specified for prediction. This is only supported for models with shared MPN networks'
f'and a fingerprint type of MPN. {train_args.number_of_molecules} smiles fields must be provided.')

# If atom-descriptors were used during training, they must be used when predicting and vice-versa
if train_args.atom_descriptors != predict_args.atom_descriptors:
Expand Down

0 comments on commit 0f04bfa

Please sign in to comment.