## Chemprop wMPNN models

In [None]:
from chemprop.models import wMPNN

### Composition

Like a Chemprop [MPNN](./basic_mpnn_model.ipynb) model, a Chemprop `wMPNN` model is made up of several submodules: a [message passing](./message_passing.ipynb) layer, an [aggregation](./aggregation.ipynb) layer, an optional batch normalization layer, and a [predictor](./predictor.ipynb) feed-forward network. `wMPNN` also defines the training and prediction logic that `lightning` uses when a Chemprop model is run in its framework.

Beyond what an `MPNN` model does, a `wMPNN` model applies atom and bond weights during message passing. This is useful when the probability of an atom or bond being present is not always 1. Polymers are one example: in a 50:50 mixture of two monomers, the atoms of each monomer have a probability of 0.5, and the probability of a bond between monomers is determined by the type of polymer (e.g. block or alternating), the type of linking functionalities, and the proportion of the monomers.

Unlike the `MPNN` model, `wMPNN` uses `WeightedBondMessagePassing` as its default message passing scheme.
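To make the weighting idea concrete, here is a minimal plain-Python sketch of a weighted sum over vectors. It is only an illustration of scaling each contribution by its probability; the helper name `weighted_sum` is hypothetical and this is not how `WeightedBondMessagePassing` is implemented internally.

```python
# Illustrative only: a plain-Python weighted sum, not Chemprop's implementation.
# `weighted_sum` is a hypothetical helper introduced for this sketch.

def weighted_sum(vectors, weights):
    """Sum equal-length vectors, scaling each one by its weight."""
    dim = len(vectors[0])
    return [sum(w * v[k] for v, w in zip(vectors, weights)) for k in range(dim)]

# Two incoming messages to the same atom.
incoming = [[1.0, 2.0], [3.0, 4.0]]

# With weights of 1 this reduces to the ordinary unweighted sum.
unweighted = weighted_sum(incoming, [1.0, 1.0])

# With bond weights of 0.5 each message contributes half as much.
weighted = weighted_sum(incoming, [0.5, 0.5])

print(unweighted, weighted)
```

The same kind of scaling can be applied at readout, where each atom's hidden state would be multiplied by its atom weight before aggregation.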

In [None]:
from chemprop.nn import WeightedBondMessagePassing, NormAggregation, RegressionFFN

mp = WeightedBondMessagePassing()
agg = NormAggregation()
ffn = RegressionFFN()

basic_model = wMPNN(mp, agg, ffn)
basic_model

### Batch normalization

Batch normalization can improve training by keeping the inputs to the FFN small and centered around zero. It is off by default, but can be turned on.

In [None]:
wMPNN(mp, agg, ffn, batch_norm=True)
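As a rough illustration of what "small and centered around zero" means, the sketch below normalizes a single feature column across a batch. It assumes the standard batch normalization formula in training mode and leaves out the learnable scale and shift and the running statistics that a real layer such as `torch.nn.BatchNorm1d` also tracks; the function name `batch_norm_column` is hypothetical.

```python
import math

def batch_norm_column(column, eps=1e-5):
    """Normalize one feature across the batch to zero mean and unit variance.

    Simplified sketch: no learnable scale/shift and no running statistics.
    """
    mean = sum(column) / len(column)
    # Biased (population) variance, as used for normalization during training.
    var = sum((x - mean) ** 2 for x in column) / len(column)
    return [(x - mean) / math.sqrt(var + eps) for x in column]

# One feature of the aggregated representation for a batch of three molecules.
values = [10.0, 12.0, 14.0]
normed = batch_norm_column(values)
print(normed)
```

After normalization the values are spread symmetrically around zero regardless of the original scale of the feature.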

### Optimizer

`wMPNN` also configures the optimizer used by lightning during training. The `torch.optim.Adam` optimizer is used with a Noam learning rate scheduler (defined in `chemprop.scheduler.NoamLR`). The following parameters are customizable:

 - number of warmup epochs, defaults to 2
 - the initial learning rate, defaults to $10^{-4}$
 - the max learning rate, defaults to $10^{-3}$
 - the final learning rate, defaults to $10^{-4}$

In [None]:
model = wMPNN(mp, agg, ffn, warmup_epochs=5, init_lr=1e-3, max_lr=1e-2, final_lr=1e-5)
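To show how the four parameters interact, here is a hedged sketch of a Noam-style schedule: a linear warmup from the initial to the max learning rate, followed by an exponential decay to the final learning rate. The function `noam_like_lr` and its step-based arguments are assumptions made for illustration; the exact per-step behaviour of `chemprop.scheduler.NoamLR` may differ.

```python
def noam_like_lr(step, warmup_steps, total_steps,
                 init_lr=1e-4, max_lr=1e-3, final_lr=1e-4):
    """Learning rate at `step` for a linear-warmup, exponential-decay schedule.

    Hypothetical helper: assumes 0 < warmup_steps < total_steps.
    """
    if step < warmup_steps:
        # Linear increase from init_lr to max_lr during warmup.
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    # Exponential decay from max_lr to final_lr over the remaining steps.
    decay_steps = total_steps - warmup_steps
    gamma = (final_lr / max_lr) ** (1 / decay_steps)
    return max_lr * gamma ** (step - warmup_steps)

# Example: 2 warmup epochs out of 50 epochs, with 10 steps per epoch (assumed numbers).
warmup_steps, total_steps = 20, 500
for step in (0, 10, 20, 260, 500):
    print(step, noam_like_lr(step, warmup_steps, total_steps))
```

In this sketch the learning rate peaks exactly at the end of warmup, so a longer warmup delays the point at which the largest updates are made.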

### Metrics

During the validation and testing loops, lightning uses the metrics stored in `wMPNN` to evaluate the current model's performance. `wMPNN` has a default metric determined by the type of predictor used. Other [metrics](../metrics.ipynb) can be passed to `wMPNN` to use instead.

In [None]:
from chemprop.nn import metrics

metrics_list = [metrics.RMSE(), metrics.MAE()]
model = wMPNN(mp, agg, ffn, metrics=metrics_list)
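For reference, the two metrics named above follow the usual textbook definitions. The plain-Python sketch below computes them on made-up numbers; the functions `rmse` and `mae` are hypothetical stand-ins and do not reproduce the extra handling (such as masking or per-sample weights) that a library metric class may perform.

```python
import math

def rmse(preds, targets):
    """Root mean squared error: penalizes large errors more strongly."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds))

def mae(preds, targets):
    """Mean absolute error: average magnitude of the errors."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(preds)

# Made-up predictions and targets for three samples.
preds = [1.0, 2.0, 3.0]
targets = [1.0, 2.0, 5.0]
print(rmse(preds, targets), mae(preds, targets))
```

Because a single large error dominates the squared term, RMSE is never smaller than MAE on the same data.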

### Fingerprinting and encoding

`wMPNN` has two helper functions to get the hidden representations at different parts of the model. The fingerprint is the learned representation of the message passing layer after aggregation and batch normalization. The encoding is the hidden representation after a number of layers of the predictor. Note that the 0th encoding is equivalent to the fingerprint.

In [None]:
import numpy as np
from chemprop.data import PolymerDatapoint, PolymerDataset
from chemprop.data import build_dataloader

smis = [
    "[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-3:0.5:0.5<1-4:0.5:0.5<2-3:0.5:0.5<2-4:0.5:0.5~10",
    "[*:1]c1cc(F)c([*:2])cc1F.[*:3]c1c(O)cc(O)c([*:4])c1O|0.5|0.5|<1-2:0.375:0.375<1-1:0.375:0.375<2-2:0.375:0.375<3-4:0.375:0.375<3-3:0.375:0.375<4-4:0.125:0.125<1-3:0.125:0.125<1-4:0.125:0.125<2-3:0.125:0.125<2-4:0.125:0.125",
    ]
# Random placeholder targets, one per polymer string
ys = np.random.rand(len(smis), 1)
dataset = PolymerDataset([PolymerDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])
dataloader = build_dataloader(dataset)
# Take one batch and unpack the inputs the model needs
batch = next(iter(dataloader))
bmg, V_d, X_d, *_ = batch

In [None]:
basic_model(bmg, V_d, X_d)

In [None]:
basic_model.fingerprint(bmg, V_d, X_d).shape

In [None]:
basic_model.encoding(bmg, V_d, X_d, i=1).shape

In [None]:
(basic_model.fingerprint(bmg, V_d, X_d) == basic_model.encoding(bmg, V_d, X_d, i=0)).all()
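The relationship between the fingerprint and the encodings can be sketched without Chemprop. The toy example below treats the predictor as a list of layers and returns the hidden representation after the first `i` of them, so `i=0` returns the fingerprint unchanged. The function `encoding_sketch` and the toy layers are hypothetical and only mirror the behaviour described in this section.

```python
def encoding_sketch(fingerprint, layers, i):
    """Apply the first `i` predictor layers to the fingerprint."""
    hidden = fingerprint
    for layer in layers[:i]:
        hidden = layer(hidden)
    return hidden

# Toy stand-ins for predictor layers (not real neural network layers).
layers = [
    lambda h: [2 * x for x in h],  # first layer: scale
    lambda h: [x + 1 for x in h],  # second layer: shift
]

fingerprint = [1.0, 2.0]
print(encoding_sketch(fingerprint, layers, 0))  # same as the fingerprint
print(encoding_sketch(fingerprint, layers, 1))
print(encoding_sketch(fingerprint, layers, 2))
```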