## Message passing

In [1]:
from chemprop.nn.message_passing.base import BondMessagePassing, AtomMessagePassing

The following cell builds an example [dataloader](../data/dataloaders.ipynb) that produces inputs for the message passing layer.

In [2]:
import numpy as np
from chemprop.data import MoleculeDatapoint, MoleculeDataset, build_dataloader

smis = ["C" * i for i in range(1, 4)]  # ["C", "CC", "CCC"]
ys = np.random.rand(len(smis), 1)  # random single-task targets
dataset = MoleculeDataset([MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)])
dataloader = build_dataloader(dataset)

### Message passing schemes

Chemprop provides two message passing schemes. It prefers the directed MPNN (D-MPNN) scheme (`BondMessagePassing`), in which messages are passed between directed edges (bonds), over the traditional MPNN scheme (`AtomMessagePassing`), in which messages are passed between nodes (atoms).
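To make the difference concrete, here is a minimal NumPy sketch (not Chemprop's implementation) of one message step under each scheme on propane. The edge arrays `src`, `dst`, and `rev` and the helper functions are hypothetical names chosen for illustration. An atom-level MPNN sums the states of neighboring atoms, while a D-MPNN computes the message for edge u→v from the edges entering u, excluding the reverse edge v→u so that information does not immediately echo back.

```python
import numpy as np

# Propane (C-C-C) as a toy graph: 3 atoms, 2 undirected bonds.
# Each bond becomes two directed edges; edge i goes src[i] -> dst[i].
src = np.array([0, 1, 1, 2])
dst = np.array([1, 0, 2, 1])
# rev[i] is the index of the directed edge pointing the opposite way.
rev = np.array([1, 0, 3, 2])
n_atoms = 3


def atom_messages(h_atoms):
    """Traditional MPNN: each atom sums the states of its neighboring atoms."""
    m = np.zeros_like(h_atoms)
    np.add.at(m, dst, h_atoms[src])  # each directed edge delivers its source atom's state
    return m


def bond_messages(h_edges):
    """D-MPNN: edge u->v sums the states of edges k->u, excluding the reverse edge v->u."""
    incoming = np.zeros((n_atoms, h_edges.shape[1]))
    np.add.at(incoming, dst, h_edges)  # total over all edges entering each atom
    return incoming[src] - h_edges[rev]  # drop the message that would echo back


h_atoms = np.array([[1.0], [2.0], [3.0]])
h_edges = np.array([[0.0], [1.0], [2.0], [3.0]])
print(atom_messages(h_atoms).ravel())  # [2. 4. 2.]
print(bond_messages(h_edges).ravel())  # [0. 3. 0. 0.]
```

Propane's two bonds become four directed edges, so the D-MPNN carries twice as many hidden states as there are bonds.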

In [3]:
mp = AtomMessagePassing()  # traditional MPNN: messages between atoms
mp = BondMessagePassing()  # D-MPNN: messages between directed bonds (used below)

### Input dimensions

By default, the input dimension of the bond message passing layer is the sum of the atom and bond feature lengths produced by the default [atom](../featurizers/atom_featurizers.ipynb) and [bond](../featurizers/bond_featurizers.ipynb) featurizers. If you use a custom featurizer, you must pass its feature dimensions to the message passing layer when you create it.

Note also that the default input dimension of the atom message passing layer is only the length of the atom features from the default atom featurizer.
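These input dimensions follow from what each scheme feeds into its first weight matrix. In a D-MPNN, each directed edge's initial input is its source atom's features concatenated with the bond's features, so that matrix takes `d_v + d_e` inputs; in the atom scheme, each atom's input is its own features, so it takes `d_v`. The NumPy sketch below shows these shapes with hypothetical sizes (`d_v=5`, `d_e=3`, `d_h=6`) and random weights; it illustrates shapes only and is not Chemprop's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical feature sizes, chosen only for illustration.
d_v, d_e, d_h = 5, 3, 6

# Propane: 3 atoms and 4 directed edges (src[i] -> dst[i]).
src = np.array([0, 1, 1, 2])
X_atoms = rng.random((3, d_v))  # one row of atom features per atom
X_edges = rng.random((4, d_e))  # one row of bond features per directed edge

# Bond (D-MPNN) scheme: input is [source atom features || bond features].
bond_inputs = np.concatenate([X_atoms[src], X_edges], axis=1)
W_i_bond = rng.random((d_v + d_e, d_h))  # takes d_v + d_e inputs
H0_edges = np.maximum(bond_inputs @ W_i_bond, 0.0)  # ReLU

# Atom (MPNN) scheme: input is each atom's own features only.
W_i_atom = rng.random((d_v, d_h))  # takes d_v inputs
H0_atoms = np.maximum(X_atoms @ W_i_atom, 0.0)

print(bond_inputs.shape, H0_edges.shape)  # (4, 8) (4, 6)
print(X_atoms.shape, H0_atoms.shape)  # (3, 5) (3, 6)
```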

In [4]:
from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer

n_atom_features, n_bond_features = SimpleMoleculeMolGraphFeaturizer().shape
# The bond layer's first weight matrix takes atom + bond features as input.
(n_atom_features + n_bond_features) == mp.W_i.in_features

True

In [5]:
from chemprop.featurizers import MultiHotAtomFeaturizer

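n_extra_bond_features = 12
# Custom featurizer: "organic" atom features plus 12 extra bond features per bond.
featurizer = SimpleMoleculeMolGraphFeaturizer(
    atom_featurizer=MultiHotAtomFeaturizer.organic(), extra_bond_fdim=n_extra_bond_features
)

# Tell the layer the featurizer's atom (d_v) and bond (d_e) feature lengths.
mp = BondMessagePassing(d_v=featurizer.atom_fdim, d_e=featurizer.bond_fdim)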

If extra atom descriptors are used, the message passing layer must also be told how many there are (`d_vd`). A separate weight matrix is then created and optimized to transform the concatenated hidden representation and extra descriptors back to the hidden dimension after message passing is complete.
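As a rough illustration of that extra step, the NumPy sketch below concatenates each atom's hidden representation with its descriptors and applies a separate linear projection. The sizes mirror this page (6 atoms, `d_h=300`, 28 descriptors), the weights are random stand-ins for trained parameters, and the projection width follows the description above; treat it as an assumption-laden sketch rather than Chemprop's exact layer.

```python
import numpy as np

rng = np.random.default_rng(0)

n_atoms, d_h, d_vd = 6, 300, 28  # d_vd mirrors n_extra_atom_descriptors in the next cell

H = rng.random((n_atoms, d_h))  # hidden atom representations after message passing
V_d = rng.random((n_atoms, d_vd))  # one row of extra descriptors per atom

# Concatenate hidden state and descriptors, then apply a separate projection.
# Assumption (per the text): the projection maps back to the hidden dimension d_h.
HV = np.concatenate([H, V_d], axis=1)  # (n_atoms, d_h + d_vd)
W_d = rng.normal(scale=0.01, size=(d_h + d_vd, d_h))
b_d = np.zeros(d_h)
H_out = HV @ W_d + b_d

print(HV.shape, H_out.shape)  # (6, 328) (6, 300)
```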

In [6]:
n_extra_atom_descriptors = 28
mp = BondMessagePassing(d_vd=n_extra_atom_descriptors)

### Customization

The following hyperparameters of the message passing layer are customizable:

 - the hidden dimension during message passing (`d_h`), default: 300
 - whether a bias term is used (`bias`), default: False
 - the number of message passing iterations (`depth`), default: 3
 - whether to pass messages on undirected edges (`undirected`), default: False
 - the dropout probability (`dropout`), default: 0.0 (i.e., no dropout)
 - the activation function (`activation`), default: ReLU
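To show where each of these hyperparameters enters the computation, here is a simplified NumPy sketch of a D-MPNN forward pass. The function `dmpnn_forward` and its signature are hypothetical; its keyword names mirror the `BondMessagePassing` arguments used in the next cell, but the weights are random, and two details are assumptions rather than Chemprop's confirmed behavior: the initial projection counts as the first of `depth` steps, and `undirected=True` averages each directed edge's state with its reverse.

```python
import numpy as np

# Activation functions selectable by name in this sketch.
ACTIVATIONS = {"relu": lambda x: np.maximum(x, 0.0), "tanh": np.tanh}


def dmpnn_forward(X_atoms, X_edges, src, dst, rev, d_h=300, bias=False, depth=3,
                  undirected=False, dropout=0.0, activation="relu", rng=None):
    """Simplified NumPy D-MPNN forward pass with randomly initialized weights."""
    rng = np.random.default_rng(0) if rng is None else rng
    act = ACTIVATIONS[activation]
    n_atoms, d_v = X_atoms.shape
    d_e = X_edges.shape[1]

    def linear(d_in, d_out):
        # Random weights stand in for trained parameters; `bias` toggles the offset.
        W = rng.normal(scale=0.1, size=(d_in, d_out))
        b = rng.normal(scale=0.1, size=d_out) if bias else np.zeros(d_out)
        return lambda x: x @ W + b

    def drop(x):
        # Inverted dropout: zero each entry with probability `dropout`, rescale the rest.
        if dropout == 0.0:
            return x
        keep = rng.random(x.shape) >= dropout
        return x * keep / (1.0 - dropout)

    W_i, W_h, W_o = linear(d_v + d_e, d_h), linear(d_h, d_h), linear(d_v + d_h, d_h)

    # Initial edge states from [source atom features || bond features].
    H0 = act(W_i(np.concatenate([X_atoms[src], X_edges], axis=1)))
    H = H0
    # Assumption: the initial projection counts as the first of `depth` steps.
    for _ in range(1, depth):
        if undirected:
            H = (H + H[rev]) / 2  # assumption: average each edge with its reverse
        incoming = np.zeros((n_atoms, d_h))
        np.add.at(incoming, dst, H)
        M = incoming[src] - H[rev]  # exclude the reverse edge's message
        H = drop(act(H0 + W_h(M)))  # skip connection to the initial state

    # Readout: sum the edge states entering each atom, combine with atom features.
    M_v = np.zeros((n_atoms, d_h))
    np.add.at(M_v, dst, H)
    return drop(act(W_o(np.concatenate([X_atoms, M_v], axis=1))))


# Usage on propane with hypothetical feature sizes (5 atom, 3 bond features).
src, dst, rev = np.array([0, 1, 1, 2]), np.array([1, 0, 2, 1]), np.array([1, 0, 3, 2])
data_rng = np.random.default_rng(1)
X_atoms, X_edges = data_rng.random((3, 5)), data_rng.random((4, 3))
out = dmpnn_forward(X_atoms, X_edges, src, dst, rev, d_h=600, bias=True, depth=5,
                    undirected=True, dropout=0.5, activation="tanh")
print(out.shape)  # (3, 600)
```

The result has one row per atom and `d_h` columns, the same layout as the real layer's output.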

In [7]:
mp = BondMessagePassing(
    d_h=600, bias=True, depth=5, undirected=True, dropout=0.5, activation="tanh"
)

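The output of message passing is a torch tensor of shape (number of atoms in the batch) × (hidden dimension). The example batch contains "C", "CC", and "CCC", which have 1 + 2 + 3 = 6 atoms (hydrogens are implicit), and `d_h` was set to 600, so the output is 6 × 600.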
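Because the molecules in a batch are stacked, the output has one row per atom across the whole batch. A minimal sketch of that bookkeeping follows, with a hypothetical `batch_index` array recording which molecule each row belongs to; the zero rows are placeholders, not real hidden representations.

```python
import numpy as np

# Atom counts for the three molecules in the example dataset: "C", "CC", "CCC".
atoms_per_mol = [1, 2, 3]
d_h = 600

# One row per atom, with all molecules' atoms stacked together.
n_atoms_total = sum(atoms_per_mol)
H = np.zeros((n_atoms_total, d_h))

# Hypothetical per-atom molecule index mapping each row back to its molecule.
batch_index = np.repeat(np.arange(len(atoms_per_mol)), atoms_per_mol)
print(H.shape)  # (6, 600)
print(batch_index)  # [0 1 1 2 2 2]
```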

In [8]:
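# Take the first batch; the remaining batch fields are unused here.
batch_molgraph, extra_atom_descriptors, *_ = next(iter(dataloader))
hidden_atom_representations = mp(batch_molgraph, extra_atom_descriptors)
hidden_atom_representations.shape  # (number of atoms in batch, d_h)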

torch.Size([6, 600])