## Multicomponent models

```python
from chemprop.nn.message_passing import MulticomponentMessagePassing
from chemprop.models import MulticomponentMPNN
```

### Overview

The basic Chemprop model takes a single molecule or reaction as input. A multicomponent Chemprop model combines several of these basic building blocks so that it can take multiple molecules and/or reactions as input. This is useful for properties that depend on more than one component, such as properties of a solute in a particular solvent.

### Message passing

`MulticomponentMessagePassing` organizes the single-component [message passing](./message_passing.ipynb) modules, one for each component in the multicomponent dataset. The individual message passing modules can be unique to each component, shared between some components, or shared between all components. If all components share the same message passing module, pass a single block and set `shared=True`. Components that use different featurizers (e.g., molecules and reactions) should not share a message passing module, because their featurizers produce different input features.

```python
from chemprop.nn import BondMessagePassing

# Unique blocks: each component gets its own message passing module
mp1 = BondMessagePassing(d_h=100)
mp2 = BondMessagePassing(d_h=600)
blocks = [mp1, mp2]
mcmp = MulticomponentMessagePassing(blocks=blocks, n_components=len(blocks))

# Shared block: one module (default d_h=300) is reused for both components
mp = BondMessagePassing()
mcmp = MulticomponentMessagePassing(blocks=[mp], n_components=2, shared=True)
```
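To make the difference between unique and shared blocks concrete, here is a plain-Python sketch of the bookkeeping. It assumes that a shared module is simply the same object reused for every component, which is consistent with the `(0-1): 2 x BondMessagePassing` line in the model summary at the end of this section. `Block` and `expand_blocks` are hypothetical stand-ins, not Chemprop classes or functions.

```python
class Block:
    """Hypothetical stand-in for a message passing module with hidden size d_h."""

    def __init__(self, d_h=300):
        self.d_h = d_h


def expand_blocks(blocks, n_components, shared=False):
    """Return one block per component (hypothetical helper).

    With shared=True, a single block is reused, so every component points
    at the same object and therefore trains the same weights.
    """
    if shared:
        if len(blocks) != 1:
            raise ValueError("shared=True expects exactly one block")
        return [blocks[0]] * n_components
    if len(blocks) != n_components:
        raise ValueError("shared=False expects one block per component")
    return list(blocks)


# One independent block per component, with different hidden sizes
unique = expand_blocks([Block(d_h=100), Block(d_h=600)], n_components=2)

# One block reused for both components
shared = expand_blocks([Block()], n_components=2, shared=True)

print(unique[0] is unique[1])  # False: separate parameters
print(shared[0] is shared[1])  # True: same module, same parameters
```

Sharing only some components would correspond to listing the same object more than once in a `shared=False` block list.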

During the forward pass of the model, each component's message passing output is aggregated separately, and the resulting per-component vectors are concatenated to form the input to the predictor.
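A minimal pure-Python sketch of that data flow for one datapoint with two components, assuming mean aggregation and using lists in place of tensors. The atom vectors are made-up toy values, and `mean_pool` and `combine_components` are hypothetical helpers rather than Chemprop functions; the real model operates on batched tensors.

```python
def mean_pool(atom_vectors):
    """Average a list of equal-length atom vectors into one molecule vector."""
    n_atoms = len(atom_vectors)
    return [sum(column) / n_atoms for column in zip(*atom_vectors)]


def combine_components(per_component_atoms):
    """Aggregate each component separately, then concatenate the results."""
    combined = []
    for atom_vectors in per_component_atoms:
        combined.extend(mean_pool(atom_vectors))
    return combined


# Toy hidden states: the solute has 3 atoms (d_h = 2), the solvent 2 atoms (d_h = 3)
solute_atoms = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
solvent_atoms = [[0.0, 1.0, 2.0], [2.0, 3.0, 4.0]]

x = combine_components([solute_atoms, solvent_atoms])
print(x)       # [3.0, 4.0, 1.0, 2.0, 3.0]
print(len(x))  # 5, the sum of the per-component hidden sizes
```

Because each component is pooled before concatenation, the length of the combined vector depends only on the blocks' hidden sizes, not on how many atoms each molecule has.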

### Aggregation

A single [aggregation](./aggregation.ipynb) module is applied to each component's message passing output.

```python
from chemprop.nn import MeanAggregation

# Average the atom-level hidden vectors of each molecule into one vector
agg = MeanAggregation()
```
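Within one component, a batch usually holds several molecules at once. The sketch below shows mean aggregation over such a batch in plain Python, assuming that the atoms of all molecules are stacked into one list and that a batch index records which molecule each atom belongs to. This mirrors our understanding of how Chemprop batches graphs, but treat it as an assumption; `scatter_mean` is a hypothetical helper, not the library's implementation.

```python
def scatter_mean(atom_vectors, batch):
    """Mean-pool stacked atom vectors into one vector per molecule.

    atom_vectors: list of atom hidden vectors, all of the same length
    batch: batch[i] is the index of the molecule that atom i belongs to;
           every index from 0 to max(batch) must have at least one atom
    """
    n_mols = max(batch) + 1
    dim = len(atom_vectors[0])
    sums = [[0.0] * dim for _ in range(n_mols)]
    counts = [0] * n_mols
    for vec, mol in zip(atom_vectors, batch):
        counts[mol] += 1
        for j, value in enumerate(vec):
            sums[mol][j] += value
    return [[s / counts[m] for s in sums[m]] for m in range(n_mols)]


# Two molecules stacked together: atoms 0-1 belong to molecule 0, atoms 2-4 to molecule 1
H = [[1.0, 0.0], [3.0, 2.0], [0.0, 3.0], [6.0, 3.0], [3.0, 3.0]]
batch = [0, 0, 1, 1, 1]
print(scatter_mean(H, batch))  # [[2.0, 1.0], [3.0, 3.0]]
```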

### Predictor

The [predictor](./predictor.ipynb) needs to know the output dimension of the multicomponent message passing module, which is the sum of the output dimensions of its blocks.

```python
from chemprop.nn import RegressionFFN

# The predictor input is the concatenation of all components' aggregated outputs
ffn = RegressionFFN(input_dim=mcmp.output_dim)
```
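The arithmetic behind `mcmp.output_dim` can be checked by hand for the two configurations in the message passing example. This sketch assumes each `BondMessagePassing` block outputs `d_h` features (300 by default) when no extra atom descriptors are used; `multicomponent_output_dim` is a hypothetical helper working on plain integers, not a Chemprop function.

```python
# Hypothetical bookkeeping; these are plain ints, not Chemprop objects.
DEFAULT_D_H = 300  # default hidden size, matching out_features=300 in the model summary


def multicomponent_output_dim(block_dims, n_components, shared=False):
    """Width of the concatenated representation fed to the predictor."""
    if shared:
        # One block reused for every component
        return block_dims[0] * n_components
    # One block per component; widths simply add up
    return sum(block_dims)


# Unique blocks with d_h=100 and d_h=600
print(multicomponent_output_dim([100, 600], n_components=2))  # 700
# One shared block with the default d_h, used for two components
print(multicomponent_output_dim([DEFAULT_D_H], n_components=2, shared=True))  # 600
```

Because `mcmp` is reassigned to the shared configuration at the end of the message passing example, `ffn` is built with `input_dim=600`, which matches the `in_features=600` of the first predictor layer in the model summary.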

### Multicomponent MPNN

The submodules are composed together in a `MulticomponentMPNN` model.

```python
# Combine the message passing, aggregation, and predictor modules into one model
mc_model = MulticomponentMPNN(mcmp, agg, ffn)
mc_model
```

MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): Identity()
      )
    )
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSELoss(task_weights=[[1.0]])
    (output_transform): Iden