Merge pull request #193 from chemprop/FFNN_fingerprint
Latent Representations for Ensembles and from FFN
cjmcgill committed Aug 11, 2021
2 parents b1a6ad4 + 0290622 commit ef3a26c
Showing 7 changed files with 112 additions and 118 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -314,10 +314,11 @@ To load a trained model and encode the fingerprint latent representation of mole
* `--checkpoint_dir <dir>` Directory where the model checkpoint is saved (i.e. `--save_dir` during training).
* `--checkpoint_path <path>` Path to a model checkpoint file (`.pt` file).
* `--preds_path` Path where a CSV file containing the encoded fingerprint vectors will be saved.
* Any other arguments that you would supply for a prediction, such as atom or bond features.

SMILES from the provided file are encoded using the MPNN weights loaded from a trained checkpoint file. Fingerprint encoding uses the same set of arguments as making predictions. Unlike making predictions, fingerprint encoding only supports a single saved checkpoint file.
Latent representations of molecules are taken from intermediate stages of the prediction model. This latent representation can be taken at the output of the MPNN (default) or from the input to the last layer of the FFN, specified using `--fingerprint_type <MPN or last_FFN>`. Fingerprint encoding uses the same set of arguments as making predictions. If multiple checkpoint files are supplied through `--checkpoint_dir`, the fingerprint encodings from each model are concatenated into a single longer vector.

For example:
Example input:
```
chemprop_fingerprint --test_path data/tox21.csv --checkpoint_dir tox21_checkpoints --preds_path tox21_fingerprint.csv
```
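To take the fingerprint from the input to the final readout layer instead, the same command can be extended with the new flag. A sketch reusing the paths from the example above (this requires a model trained with more than one FFN layer):
```
chemprop_fingerprint --test_path data/tox21.csv --checkpoint_dir tox21_checkpoints --preds_path tox21_fingerprint.csv --fingerprint_type last_FFN
```
If `tox21_checkpoints` holds more than one model checkpoint, one fingerprint is generated per model and the output columns are suffixed with the model index (`fp_0_model_0`, `fp_0_model_1`, ...).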
7 changes: 7 additions & 0 deletions chemprop/args.py
@@ -682,6 +682,13 @@ def process_args(self) -> None:
'--checkpoint_dir <dir> containing at least one checkpoint.')


class FingerprintArgs(PredictArgs):
""":class:`FingerprintArgs` includes :class:`PredictArgs` with additional arguments for the generation of latent fingerprint vectors."""

fingerprint_type: Literal['MPN', 'last_FFN'] = 'MPN'
"""Choice of which type of latent fingerprint vector to use. Default is the output of the MPNN, excluding molecular features."""


class HyperoptArgs(TrainArgs):
""":class:`HyperoptArgs` includes :class:`TrainArgs` along with additional arguments used for optimizing Chemprop hyperparameters."""

54 changes: 17 additions & 37 deletions chemprop/models/model.py
@@ -14,18 +14,14 @@
class MoleculeModel(nn.Module):
"""A :class:`MoleculeModel` is a model which contains a message passing network following by feed-forward layers."""

def __init__(self, args: TrainArgs, featurizer: bool = False):
def __init__(self, args: TrainArgs):
"""
:param args: A :class:`~chemprop.args.TrainArgs` object containing model arguments.
:param featurizer: Whether the model should act as a featurizer, i.e., outputting the
learned features from the last layer prior to prediction rather than
outputting the actual property predictions.
"""
super(MoleculeModel, self).__init__()

self.classification = args.dataset_type == 'classification'
self.multiclass = args.dataset_type == 'multiclass'
self.featurizer = featurizer

self.output_size = args.num_tasks
if self.multiclass:
@@ -112,46 +108,34 @@ def create_ffn(self, args: TrainArgs) -> None:
for param in list(self.ffn.parameters())[0:2*args.frzn_ffn_layers]: # Freeze weights and bias for given number of layers
param.requires_grad=False


def featurize(self,
def fingerprint(self,
batch: Union[List[List[str]], List[List[Chem.Mol]], List[List[Tuple[Chem.Mol, Chem.Mol]]], List[BatchMolGraph]],
features_batch: List[np.ndarray] = None,
atom_descriptors_batch: List[np.ndarray] = None,
atom_features_batch: List[np.ndarray] = None,
bond_features_batch: List[np.ndarray] = None) -> torch.FloatTensor:
"""
Computes feature vectors of the input by running the model except for the last layer.
:param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
list of :class:`~chemprop.features.featurization.BatchMolGraph`.
The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
:param features_batch: A list of numpy arrays containing additional features.
:param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
:param atom_features_batch: A list of numpy arrays containing additional atom features.
:param bond_features_batch: A list of numpy arrays containing additional bond features.
:return: The feature vectors computed by the :class:`MoleculeModel`.
bond_features_batch: List[np.ndarray] = None,
fingerprint_type: str = 'MPN') -> torch.FloatTensor:
"""
return self.ffn[:-1](self.encoder(batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch))

def fingerprint(self,
batch: Union[List[List[str]], List[List[Chem.Mol]], List[List[Tuple[Chem.Mol, Chem.Mol]]], List[BatchMolGraph]],
features_batch: List[np.ndarray] = None,
atom_descriptors_batch: List[np.ndarray] = None) -> torch.FloatTensor:
"""
Encodes the fingerprint vectors of the input molecules by passing the inputs through the MPNN and returning
the latent representation before the FFNN.
Encodes the latent representations of the input molecules from intermediate stages of the model.
:param batch: A list of list of SMILES, a list of list of RDKit molecules, or a
list of :class:`~chemprop.features.featurization.BatchMolGraph`.
The outer list or BatchMolGraph is of length :code:`num_molecules` (number of datapoints in batch),
the inner list is of length :code:`number_of_molecules` (number of molecules per datapoint).
:param features_batch: A list of numpy arrays containing additional features.
:param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
:return: The fingerprint vectors calculated through the MPNN.
:param fingerprint_type: The choice of which type of latent representation to return as the molecular fingerprint. Currently
supported options are 'MPN' for the output of the MPNN portion of the model and 'last_FFN' for the input to the final readout layer.
:return: The latent fingerprint vectors.
"""
return self.encoder(batch, features_batch, atom_descriptors_batch)
if fingerprint_type == 'MPN':
return self.encoder(batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch)
elif fingerprint_type == 'last_FFN':
return self.ffn[:-1](self.encoder(batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch))
else:
raise ValueError(f'Unsupported fingerprint type {fingerprint_type}.')

def forward(self,
batch: Union[List[List[str]], List[List[Chem.Mol]], List[List[Tuple[Chem.Mol, Chem.Mol]]], List[BatchMolGraph]],
@@ -170,12 +154,8 @@ def forward(self,
:param atom_descriptors_batch: A list of numpy arrays containing additional atom descriptors.
:param atom_features_batch: A list of numpy arrays containing additional atom features.
:param bond_features_batch: A list of numpy arrays containing additional bond features.
:return: The output of the :class:`MoleculeModel`, which is either property predictions
or molecule features if :code:`self.featurizer=True`.
:return: The output of the :class:`MoleculeModel`, containing a list of property predictions.
"""
if self.featurizer:
return self.featurize(batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch)

output = self.ffn(self.encoder(batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch))
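As a quick illustration of the refactored model API, here is a minimal Python sketch of calling the new `fingerprint` method directly on a loaded model. The checkpoint path and SMILES are placeholders, and `last_FFN` assumes the model was trained with more than one FFN layer:
```python
import torch

from chemprop.utils import load_checkpoint

# Placeholder checkpoint path; any trained Chemprop checkpoint works.
model = load_checkpoint('model.pt', device=torch.device('cpu'))
model.eval()

# One inner list per datapoint (number_of_molecules entries each).
smiles = [['CCO'], ['c1ccccc1']]

with torch.no_grad():
    mpn_fp = model.fingerprint(smiles, fingerprint_type='MPN')       # output of the MPNN
    ffn_fp = model.fingerprint(smiles, fingerprint_type='last_FFN')  # input to the final readout layer

print(mpn_fp.shape, ffn_fp.shape)
```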
40 changes: 0 additions & 40 deletions chemprop/nn_utils.py
@@ -6,9 +6,6 @@
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm

from chemprop.data import MoleculeDataLoader, MoleculeDataset


def compute_pnorm(model: nn.Module) -> float:
Expand Down Expand Up @@ -115,43 +112,6 @@ def initialize_weights(model: nn.Module) -> None:
nn.init.xavier_normal_(param)


def compute_molecule_vectors(model: nn.Module,
data: MoleculeDataset,
batch_size: int,
num_workers: int = 8) -> List[np.ndarray]:
"""
Computes the molecule vectors output from the last layer of a :class:`~chemprop.models.MoleculeModel`.
:param model: A :class:`~chemprop.models.MoleculeModel`.
:param data: A :class:`~chemprop.data.MoleculeDataset`.
:param batch_size: Batch size.
:param num_workers: Number of parallel data loading workers.
:return: A list of 1D numpy arrays of length hidden_size containing
the molecule vectors generated by the model for each molecule provided.
"""
training = model.training
model.eval()
data_loader = MoleculeDataLoader(
dataset=data,
batch_size=batch_size,
num_workers=num_workers
)

vecs = []
for batch in tqdm(data_loader, total=len(data_loader)):
# Apply model to batch
with torch.no_grad():
batch_vecs = model.featurize(batch.batch_graph(), batch.features())

# Collect vectors
vecs.extend(batch_vecs.data.cpu().numpy())

if training:
model.train()

return vecs


class NoamLR(_LRScheduler):
"""
Noam learning rate scheduler with piecewise linear increase and exponential decay.
2 changes: 1 addition & 1 deletion chemprop/train/__init__.py
@@ -1,7 +1,7 @@
from .cross_validate import chemprop_train, cross_validate, TRAIN_LOGGER_NAME
from .evaluate import evaluate, evaluate_predictions
from .make_predictions import chemprop_predict, make_predictions
from .molecule_fingerprint import chemprop_fingerprint
from .molecule_fingerprint import chemprop_fingerprint, model_fingerprint
from .predict import predict
from .run_training import run_training
from .train import train
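With `model_fingerprint` now exported from `chemprop.train`, fingerprints can also be computed for an in-memory dataset without writing a CSV. A minimal sketch, assuming a hypothetical checkpoint path and placeholder SMILES:
```python
import torch

from chemprop.data import get_data_from_smiles, MoleculeDataLoader
from chemprop.train import model_fingerprint
from chemprop.utils import load_checkpoint

# Hypothetical checkpoint path.
model = load_checkpoint('tox21_checkpoints/fold_0/model_0/model.pt', device=torch.device('cpu'))

data = get_data_from_smiles(smiles=[['CCO'], ['c1ccccc1']])
data_loader = MoleculeDataLoader(dataset=data, batch_size=50, num_workers=0)

# One fingerprint vector (a list of floats) per molecule in the loader.
fps = model_fingerprint(model=model, data_loader=data_loader, fingerprint_type='MPN')
```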
101 changes: 67 additions & 34 deletions chemprop/train/molecule_fingerprint.py
@@ -2,17 +2,18 @@
from typing import List, Optional, Union

import torch
import numpy as np
from tqdm import tqdm

from chemprop.args import PredictArgs, TrainArgs
from chemprop.args import FingerprintArgs, TrainArgs
from chemprop.data import get_data, get_data_from_smiles, MoleculeDataLoader, MoleculeDataset
from chemprop.utils import load_args, load_checkpoint, makedirs, timeit, load_scalers, update_prediction_args
from chemprop.data import MoleculeDataLoader, MoleculeDataset
from chemprop.features import set_reaction, set_explicit_h
from chemprop.models import MoleculeModel

@timeit()
def molecule_fingerprint(args: PredictArgs, smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
def molecule_fingerprint(args: FingerprintArgs, smiles: List[List[str]] = None) -> List[List[Optional[float]]]:
"""
Loads data and a trained model and uses the model to encode fingerprint vectors for the data.
@@ -26,8 +27,12 @@ def molecule_fingerprint(args: PredictArgs, smiles: List[List[str]] = None) -> L
train_args = load_args(args.checkpoint_paths[0])

# Update args with training arguments
update_prediction_args(predict_args=args, train_args=train_args, validate_feature_sources=False)
args: Union[PredictArgs, TrainArgs]
if args.fingerprint_type == 'MPN': # only need to supply input features if using FFN latent representation and if model calls for them.
validate_feature_sources = False
else:
validate_feature_sources = True
update_prediction_args(predict_args=args, train_args=train_args, validate_feature_sources=validate_feature_sources)
args: Union[FingerprintArgs, TrainArgs]

# Set explicit H option and reaction option
set_explicit_h(train_args.explicit_h)
@@ -67,42 +72,67 @@
num_workers=args.num_workers
)

# Set fingerprint size
if args.fingerprint_type == 'MPN':
total_fp_size = args.hidden_size * args.number_of_molecules
if args.features_only:
raise ValueError('With features_only models, there is no latent MPN representation. Use last_FFN fingerprint type instead.')
elif args.fingerprint_type == 'last_FFN':
if args.ffn_num_layers != 1:
total_fp_size = args.ffn_hidden_size
else:
raise ValueError('With a ffn_num_layers of 1, there is no latent FFN representation. Use MPN fingerprint type instead.')
else:
raise ValueError(f'Fingerprint type {args.fingerprint_type} not supported')
all_fingerprints = np.zeros((len(test_data), total_fp_size, len(args.checkpoint_paths)))

# Load model
print(f'Encoding smiles into a fingerprint vector from a single model')
if len(args.checkpoint_paths) != 1:
raise ValueError("Fingerprint generation only supports one model, cannot use an ensemble")

model = load_checkpoint(args.checkpoint_paths[0], device=args.device)
scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(args.checkpoint_paths[0])

# Normalize features
if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
test_data.reset_features_and_targets()
if args.features_scaling:
test_data.normalize_features(features_scaler)
if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
test_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
if train_args.bond_feature_scaling and args.bond_features_size > 0:
test_data.normalize_features(bond_feature_scaler, scale_bond_features=True)

# Make fingerprints
model_preds = model_fingerprint(
model=model,
data_loader=test_data_loader
)
print(f'Encoding smiles into a fingerprint vector from {len(args.checkpoint_paths)} models.')

for index, checkpoint_path in enumerate(tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths))):
model = load_checkpoint(checkpoint_path, device=args.device)
scaler, features_scaler, atom_descriptor_scaler, bond_feature_scaler = load_scalers(args.checkpoint_paths[index])

# Normalize features
if args.features_scaling or train_args.atom_descriptor_scaling or train_args.bond_feature_scaling:
test_data.reset_features_and_targets()
if args.features_scaling:
test_data.normalize_features(features_scaler)
if train_args.atom_descriptor_scaling and args.atom_descriptors is not None:
test_data.normalize_features(atom_descriptor_scaler, scale_atom_descriptors=True)
if train_args.bond_feature_scaling and args.bond_features_size > 0:
test_data.normalize_features(bond_feature_scaler, scale_bond_features=True)

# Make fingerprints
model_fp = model_fingerprint(
model=model,
data_loader=test_data_loader,
fingerprint_type=args.fingerprint_type
)
if args.fingerprint_type == 'MPN' and (args.features_path is not None or args.features_generator): # truncate any features from MPN fingerprint
model_fp = np.array(model_fp)[:,:total_fp_size]
all_fingerprints[:,:,index] = model_fp

# Save predictions
print(f'Saving predictions to {args.preds_path}')
assert len(test_data) == len(model_preds)
assert len(test_data) == len(all_fingerprints)
makedirs(args.preds_path, isfile=True)

# Set column names
fingerprint_columns = []
if len(args.checkpoint_paths) == 1:
for j in range(total_fp_size):
fingerprint_columns.append(f'fp_{j}')
else:
for j in range(total_fp_size):
for i in range(len(args.checkpoint_paths)):
fingerprint_columns.append(f'fp_{j}_model_{i}')

# Copy predictions over to full_data
total_hidden_size = args.hidden_size * args.number_of_molecules
for full_index, datapoint in enumerate(full_data):
valid_index = full_to_valid_indices.get(full_index, None)
preds = model_preds[valid_index] if valid_index is not None else ['Invalid SMILES'] * total_hidden_size
preds = all_fingerprints[valid_index].reshape((len(args.checkpoint_paths) * total_fp_size)) if valid_index is not None else ['Invalid SMILES'] * len(args.checkpoint_paths) * total_fp_size

fingerprint_columns=[f'fp_{i}' for i in range(total_hidden_size)]
for i in range(len(fingerprint_columns)):
datapoint.row[fingerprint_columns[i]] = preds[i]

@@ -113,10 +143,11 @@ def molecule_fingerprint(args: PredictArgs, smiles: List[List[str]] = None) -> L
for datapoint in full_data:
writer.writerow(datapoint.row)

return model_preds
return all_fingerprints

def model_fingerprint(model: MoleculeModel,
data_loader: MoleculeDataLoader,
fingerprint_type: str = 'MPN',
disable_progress_bar: bool = False) -> List[List[float]]:
"""
Encodes the provided molecules into the latent fingerprint vectors, according to the provided model.
@@ -133,11 +164,13 @@ def model_fingerprint(model: MoleculeModel,
for batch in tqdm(data_loader, disable=disable_progress_bar, leave=False):
# Prepare batch
batch: MoleculeDataset
mol_batch, features_batch, atom_descriptors_batch = batch.batch_graph(), batch.features(), batch.atom_descriptors()
mol_batch, features_batch, atom_descriptors_batch, atom_features_batch, bond_features_batch = \
batch.batch_graph(), batch.features(), batch.atom_descriptors(), batch.atom_features(), batch.bond_features()

# Make predictions
with torch.no_grad():
batch_fp = model.fingerprint(mol_batch, features_batch, atom_descriptors_batch)
batch_fp = model.fingerprint(mol_batch, features_batch, atom_descriptors_batch,
atom_features_batch, bond_features_batch, fingerprint_type)

# Collect vectors
batch_fp = batch_fp.data.cpu().tolist()
@@ -151,4 +184,4 @@ def chemprop_fingerprint() -> None:
Parses Chemprop predicting arguments and returns the latent representation vectors for
provided molecules, according to a previously trained model.
"""
molecule_fingerprint(args=PredictArgs().parse_args())
molecule_fingerprint(args=FingerprintArgs().parse_args())
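The full pipeline, including the CSV output, can also be driven programmatically through the new `FingerprintArgs`; a sketch mirroring the README example (paths are placeholders):
```python
from chemprop.args import FingerprintArgs
from chemprop.train.molecule_fingerprint import molecule_fingerprint

# Placeholder paths mirroring the README example.
args = FingerprintArgs().parse_args([
    '--test_path', 'data/tox21.csv',
    '--checkpoint_dir', 'tox21_checkpoints',
    '--preds_path', 'tox21_fingerprint.csv',
    '--fingerprint_type', 'MPN',
])

# Returns an array of shape (num valid molecules, fingerprint size, num models)
# and writes the concatenated fingerprints to --preds_path.
fingerprints = molecule_fingerprint(args)
```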
