## Chemprop MPNN models

```python
from chemprop.models.model import MPNN
```

### Composition

A Chemprop `MPNN` model is made up of several submodules: a [message passing](./message_passing.ipynb) layer, an [aggregation](./aggregation.ipynb) layer, an optional batch normalization layer, and a [predictor](./predictor.ipynb) feed-forward network (FFN). `MPNN` also defines the training and prediction logic that `lightning` runs when a Chemprop model is used within its framework.

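```python
from chemprop.nn import BondMessagePassing, NormAggregation, RegressionFFN

# Build each submodule with its default settings
mp = BondMessagePassing()
agg = NormAggregation()
ffn = RegressionFFN()

# Combine the submodules into one model and display its structure
basic_model = MPNN(mp, agg, ffn)
basic_model
```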

```
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (X_d_transform): Identity()
)
```
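To make the data flow through these submodules concrete, here is a toy NumPy sketch of a forward pass: message passing produces one hidden vector per atom, aggregation pools them into one vector per molecule, and the FFN maps each molecule vector to a prediction. The optional batch normalization step is left out here, as with `batch_norm=False`. Everything in the sketch is an illustrative assumption rather than Chemprop's implementation: the random stand-in for message passing, the hypothetical `toy_*` helper names, a hidden size of 8 instead of the 300 printed above, and a fixed divisor of 100 for the norm aggregation (assumed to match `NormAggregation`'s default).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy batch of 3 molecules with 1, 2 and 3 heavy atoms ("C", "CC", "CCC").
# batch_index[k] is the molecule that atom k belongs to.
batch_index = np.array([0, 1, 1, 2, 2, 2])
n_mols, d_h = 3, 8


def toy_message_passing(n_atoms: int, d_h: int) -> np.ndarray:
    """Stand-in for message passing: one random hidden vector per atom."""
    return rng.standard_normal((n_atoms, d_h))


def toy_norm_aggregation(H: np.ndarray, batch_index: np.ndarray,
                         n_mols: int, norm: float = 100.0) -> np.ndarray:
    """Sum the atom vectors of each molecule, then divide by a constant (assumed 100)."""
    out = np.zeros((n_mols, H.shape[1]))
    np.add.at(out, batch_index, H)  # unbuffered scatter-add into molecule rows
    return out / norm


def toy_ffn(Z, W1, b1, W2, b2):
    """Linear -> ReLU -> Linear, the same layer order as the printed MLP."""
    return np.maximum(Z @ W1 + b1, 0.0) @ W2 + b2


H = toy_message_passing(len(batch_index), d_h)    # atom level: (6, 8)
fp = toy_norm_aggregation(H, batch_index, n_mols)  # molecule level: (3, 8)
W1, b1 = rng.standard_normal((d_h, d_h)), np.zeros(d_h)
W2, b2 = rng.standard_normal((d_h, 1)), np.zeros(1)
preds = toy_ffn(fp, W1, b1, W2, b2)                # one value per molecule: (3, 1)
print(H.shape, fp.shape, preds.shape)  # (6, 8) (3, 8) (3, 1)
```

The key point is the change of granularity at the aggregation step: everything before it is per atom, everything after it is per molecule.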

### Batch normalization

Batch normalization can improve training by keeping the inputs to the FFN small and centered around zero. It is enabled by default and can be turned off with `batch_norm=False`, in which case the `bn` submodule becomes an `Identity` layer.

```python
MPNN(mp, agg, ffn, batch_norm=False)
```

```
MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (bn): Identity()
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (X_d_transform): Identity()
)
```
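What the `bn` layer does to a batch of fingerprints can be sketched in a few lines of NumPy. In training mode, a `BatchNorm1d` layer shifts each feature by its batch mean, divides by its batch standard deviation (biased estimator, plus a small `eps`), and then applies a learnable scale `gamma` and shift `beta`, which start at 1 and 0. The helper name and the example values below are hypothetical.

```python
import numpy as np


def batch_norm_train(Z: np.ndarray, gamma=None, beta=None, eps: float = 1e-5) -> np.ndarray:
    """Normalize each feature (column) of Z using the statistics of this batch."""
    mean = Z.mean(axis=0)
    var = Z.var(axis=0)  # biased variance, as used for normalization in training mode
    Z_hat = (Z - mean) / np.sqrt(var + eps)
    # Learnable affine parameters, shown at their initial values by default
    gamma = np.ones(Z.shape[1]) if gamma is None else gamma
    beta = np.zeros(Z.shape[1]) if beta is None else beta
    return gamma * Z_hat + beta


# Hypothetical fingerprints whose features are large and off-center
Z = np.array([[120.0, -3.0], [80.0, -5.0], [100.0, -4.0]])
out = batch_norm_train(Z)
print(out.mean(axis=0), out.std(axis=0))  # each feature: mean ~0, std ~1
```

This covers only training mode. At inference, PyTorch's `BatchNorm1d` normalizes with the running mean and variance it tracked during training (`momentum=0.1` and `track_running_stats=True` in the first printout).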

### Optimizer

`MPNN` also configures the optimizer that Lightning uses during training: `torch.optim.Adam` with a Noam learning rate scheduler (defined in `chemprop.scheduler.NoamLR`). The following parameters are customizable:

 - `warmup_epochs`: the number of warmup epochs, defaults to 2
 - `init_lr`: the initial learning rate, defaults to $10^{-4}$
 - `max_lr`: the maximum learning rate, defaults to $10^{-3}$
 - `final_lr`: the final learning rate, defaults to $10^{-4}$

```python
# Longer warmup and a wider learning rate range than the defaults
model = MPNN(mp, agg, ffn, warmup_epochs=5, init_lr=1e-3, max_lr=1e-2, final_lr=1e-5)
```
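The schedule implied by these four parameters can be sketched in plain Python. This follows the Noam-like scheduler from earlier Chemprop versions as we understand it: the learning rate rises linearly from `init_lr` to `max_lr` over the warmup steps, then decays exponentially from `max_lr` so that it reaches `final_lr` at the last step. The per-step formulation, the hypothetical `noam_lr` helper, and the batch and epoch counts are assumptions; `NoamLR` in your installed version may differ in details.

```python
def noam_lr(step: int, warmup_steps: int, total_steps: int,
            init_lr: float = 1e-4, max_lr: float = 1e-3, final_lr: float = 1e-4) -> float:
    """Learning rate at a given optimizer step (a sketch, not Chemprop's code)."""
    if step < warmup_steps:
        # Linear warmup from init_lr toward max_lr
        return init_lr + step * (max_lr - init_lr) / warmup_steps
    # Exponential decay from max_lr that reaches final_lr at total_steps
    gamma = (final_lr / max_lr) ** (1 / (total_steps - warmup_steps))
    return max_lr * gamma ** (step - warmup_steps)


# Hypothetical run: 10 batches per epoch for 50 epochs, with warmup_epochs=5
# and the same learning rates as the MPNN cell above.
steps_per_epoch, epochs, warmup_epochs = 10, 50, 5
warmup_steps = warmup_epochs * steps_per_epoch
total_steps = epochs * steps_per_epoch
for step in (0, 25, warmup_steps, total_steps):
    lr = noam_lr(step, warmup_steps, total_steps, init_lr=1e-3, max_lr=1e-2, final_lr=1e-5)
    print(step, lr)
```

Because the warmup is counted in steps, the number of batches per epoch (and therefore the batch size) changes how many optimizer updates the warmup spans.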

### Metrics

During the validation and testing loops, Lightning uses the metrics stored in `MPNN` to evaluate the current model's performance. By default, `MPNN` uses a metric determined by the type of predictor. Other [metrics](../metrics.ipynb) can be passed to `MPNN` to use instead.

```python
from chemprop.nn import metrics

# Report both RMSE and MAE during validation and testing
metrics_list = [metrics.RMSEMetric(), metrics.MAEMetric()]
model = MPNN(mp, agg, ffn, metrics=metrics_list)
```
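For reference, the two regression metrics chosen here reduce to simple formulas. The plain-Python functions below are hypothetical helpers written for illustration; Chemprop's `RMSEMetric` and `MAEMetric` operate on tensors and may additionally handle details such as missing targets, which this sketch ignores.

```python
import math


def rmse(preds, targets):
    """Root-mean-square error: square root of the mean squared residual."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds))


def mae(preds, targets):
    """Mean absolute error: mean of the absolute residuals."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)


preds, targets = [1.0, 2.0, 4.0], [1.0, 3.0, 2.0]
print(rmse(preds, targets), mae(preds, targets))
```

RMSE squares the residuals before averaging, so it weights large errors more heavily than MAE; in this example the single residual of 2 pushes RMSE above the MAE of 1.0.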

### Fingerprinting and encoding

`MPNN` has two helper methods, `fingerprint` and `encoding`, that return the hidden representations at different points in the model. The fingerprint is the learned representation produced by the message passing layer after aggregation and batch normalization. The encoding is the hidden representation after the first `i` layers of the predictor; see the [predictor](./predictor.ipynb) notebook for more details. Note that the 0th encoding is equivalent to the fingerprint.

The following cell builds an example batch for the model. See the [data notebooks](../data/dataloaders.ipynb) for more details.

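```python
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset
from chemprop.data import build_dataloader

# Three small molecules ("C", "CC", "CCC") with random regression targets
smis = ["C" * i for i in range(1, 4)]
ys = np.random.rand(len(smis), 1)
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])
dataloader = build_dataloader(dataset)

# Take one batch and keep the batched molecular graph plus the optional
# extra atom (V_d) and molecule (X_d) descriptors; ignore the remaining fields
batch = next(iter(dataloader))
bmg, V_d, X_d, *_ = batch
```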

```python
basic_model(bmg, V_d, X_d)
```

```
tensor([[-0.0777],
        [-0.0422],
        [-0.0684]], grad_fn=<AddmmBackward0>)
```

```python
basic_model.fingerprint(bmg, V_d, X_d).shape
```

```
torch.Size([3, 300])
```

```python
basic_model.encoding(bmg, V_d, X_d, i=1).shape
```

```
torch.Size([3, 300])
```

```python
(basic_model.fingerprint(bmg, V_d, X_d) == basic_model.encoding(bmg, V_d, X_d, i=0)).all()
```

```
tensor(True)
```
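The equality of the 0th encoding and the fingerprint is consistent with the encoding applying only the first `i` blocks of the predictor's MLP to the fingerprint, so that `i=0` applies none; the shapes printed above (a 300-dimensional output after the first block) fit this reading. The toy sketch below mimics that slicing with plain Python callables. The block contents, sizes, and the `toy_encoding` helper are hypothetical, chosen to mirror the two blocks in the printed `MLP` (a first `Linear`, then `ReLU`, `Dropout`, and a final `Linear`).

```python
from typing import Callable, List

Vector = List[float]


def linear(weight: List[List[float]], bias: Vector) -> Callable[[Vector], Vector]:
    """Return a function computing weight @ x + bias for a single vector x."""
    def apply(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]
    return apply


def relu(x: Vector) -> Vector:
    return [max(0.0, xi) for xi in x]


# Block 0: a single Linear layer (2 -> 2)
block0 = linear([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0])
final_linear = linear([[2.0, 1.0]], [0.1])


def block1(x: Vector) -> Vector:
    """Block 1: ReLU then the output Linear (dropout with p=0.0 is a no-op, so it is omitted)."""
    return final_linear(relu(x))


blocks = [block0, block1]


def toy_encoding(fingerprint: Vector, i: int) -> Vector:
    """Apply only the first i blocks to the fingerprint (i=0 applies none)."""
    out = fingerprint
    for block in blocks[:i]:
        out = block(out)
    return out


fp = [3.0, 1.0]
print(toy_encoding(fp, 0), toy_encoding(fp, 1), toy_encoding(fp, 2))
```

With `i=0` the slice `blocks[:0]` is empty, so the fingerprint is returned unchanged, while `i=len(blocks)` runs the full predictor.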