Skip to content

Commit

Permalink
Merge pull request #351 from shihchengli/return_invalid_smiles
Browse files Browse the repository at this point in the history
Molecule fingerprinting with invalid SMILES in list
  • Loading branch information
kevingreenman committed Feb 10, 2023
2 parents 959176d + 516feb1 commit befa630
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions chemprop/train/molecule_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
from chemprop.models import MoleculeModel

@timeit()
def molecule_fingerprint(args: FingerprintArgs, smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
def molecule_fingerprint(args: FingerprintArgs,
smiles: List[List[str]] = None,
return_invalid_smiles: bool = True) -> List[List[Optional[float]]]:
"""
Loads data and a trained model and uses the model to encode fingerprint vectors for the data.
:param args: A :class:`~chemprop.args.FingerprintArgs` object containing arguments for
loading data and a model and encoding fingerprints.
:param smiles: List of list of SMILES to encode fingerprints for.
:param return_invalid_smiles: Whether to return fingerprint entries of "Invalid SMILES" for invalid SMILES; if False, invalid SMILES are skipped in the returned fingerprints.
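:return: A list of fingerprint vectors (list of floats). If ``return_invalid_smiles`` is True, an object array with one entry per input datapoint is returned instead, with entries for invalid SMILES filled with "Invalid SMILES".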
"""

Expand Down Expand Up @@ -173,7 +176,15 @@ def molecule_fingerprint(args: FingerprintArgs, smiles: List[List[str]] = None)
for datapoint in full_data:
writer.writerow(datapoint.row)

return all_fingerprints
if return_invalid_smiles:
full_fingerprints = np.zeros((len(full_data), total_fp_size, len(args.checkpoint_paths)), dtype='object')
for full_index in range(len(full_data)):
valid_index = full_to_valid_indices.get(full_index, None)
preds = all_fingerprints[valid_index] if valid_index is not None else np.full((total_fp_size, len(args.checkpoint_paths)), 'Invalid SMILES')
full_fingerprints[full_index] = preds
return full_fingerprints
else:
return all_fingerprints

def model_fingerprint(model: MoleculeModel,
data_loader: MoleculeDataLoader,
Expand Down

0 comments on commit befa630

Please sign in to comment.