In [1]:
import torch
from chemprop.nn.agg import MeanAggregation, SumAggregation, NormAggregation, AttentiveAggregation

Example output from message passing, used as the input to aggregation. See the message passing notebook for more details.

In [2]:
n_atoms_in_batch = 7
hidden_dim = 3
example_message_passing_output = torch.randn(n_atoms_in_batch, hidden_dim)
which_atoms_in_which_molecule = torch.tensor([0, 0, 1, 1, 1, 1, 2]).long()
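
`which_atoms_in_which_molecule` assigns each row of the atom-level matrix to a molecule: molecule 0 has two atoms, molecule 1 has four and molecule 2 has one. If you start from per-molecule atom counts instead, the sketch below builds the same vector with `torch.repeat_interleave`. The counts `[2, 4, 1]` are read off the tensor above; `torch.bincount` goes the other way and recovers them.

```python
import torch

# Number of atoms in each molecule of the batch (2 + 4 + 1 = 7 atoms in total)
atoms_per_molecule = torch.tensor([2, 4, 1])

# Repeat each molecule index once per atom: 0 twice, 1 four times, 2 once
batch = torch.repeat_interleave(torch.arange(len(atoms_per_molecule)), atoms_per_molecule)

# Counting how often each molecule index appears recovers the atom counts
counts_back = torch.bincount(batch)
```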

### Aggregation

The aggregation layer combines the node-level representations into a single graph-level representation (usually atoms -> molecule).

### Mean and sum aggregation 

Mean aggregation is recommended when the property to predict does not depend on the number of atoms in the molecule (an intensive property). Sum aggregation is recommended when the property is extensive (it scales with the number of atoms), though norm aggregation is usually the better choice.
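
The intensive/extensive distinction can be seen in a minimal sketch that computes per-molecule sums and means by hand with `Tensor.index_add`. `segment_sum` and `segment_mean` are hypothetical helpers, not chemprop functions, and assume a plain per-molecule sum and mean. Duplicating every atom of a molecule doubles its sum but leaves its mean unchanged:

```python
import torch

def segment_sum(H: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sum the rows of H that share a molecule index in `batch`."""
    n_mols = int(batch.max()) + 1
    return torch.zeros(n_mols, H.shape[1], dtype=H.dtype).index_add(0, batch, H)

def segment_mean(H: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average the rows of H that share a molecule index in `batch`."""
    counts = torch.bincount(batch, minlength=int(batch.max()) + 1)
    return segment_sum(H, batch) / counts.clamp(min=1).unsqueeze(1).to(H.dtype)

# A two-atom molecule, and a "doubled" version in which every atom appears twice
H_small = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
H_doubled = torch.cat([H_small, H_small])
batch_small = torch.zeros(2, dtype=torch.long)
batch_doubled = torch.zeros(4, dtype=torch.long)

print(segment_sum(H_small, batch_small), segment_sum(H_doubled, batch_doubled))    # sum doubles: [4, 6] -> [8, 12]
print(segment_mean(H_small, batch_small), segment_mean(H_doubled, batch_doubled))  # mean stays [2, 3]
```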

In [3]:
mean_agg = MeanAggregation()
sum_agg = SumAggregation()

In [4]:
mean_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)

tensor([[-0.1990,  0.3492,  0.6883],
        [-0.0078, -0.2862, -0.2961],
        [ 2.0586,  0.0895, -1.9913]])

In [5]:
sum_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)

tensor([[-0.3979,  0.6984,  1.3765],
        [-0.0314, -1.1449, -1.1844],
        [ 2.0586,  0.0895, -1.9913]])
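
As a sanity check on the two outputs above, each row of the sum should equal the corresponding mean row times that molecule's atom count (2, 4 and 1), so the single-atom molecule gives identical rows. A sketch using the printed values, which agree up to the four-decimal rounding of the display:

```python
import torch

# Values copied from the mean and sum aggregation outputs above
mean_out = torch.tensor([[-0.1990, 0.3492, 0.6883],
                         [-0.0078, -0.2862, -0.2961],
                         [2.0586, 0.0895, -1.9913]])
sum_out = torch.tensor([[-0.3979, 0.6984, 1.3765],
                        [-0.0314, -1.1449, -1.1844],
                        [2.0586, 0.0895, -1.9913]])

# Atoms per molecule, as a column so it broadcasts across the hidden dimension
counts = torch.tensor([[2.0], [4.0], [1.0]])

# Tolerance covers rounding of the printed values (4 atoms x 0.00005 at most)
consistent = torch.allclose(mean_out * counts, sum_out, atol=1e-3)
print(consistent)
```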

### Norm aggregation

Norm aggregation divides the summed representation by a fixed normalization constant. It can be better than sum aggregation for large molecules, because it is best to keep the hidden representation values on the order of 1 (though this matters less when batch normalization is used). The normalization constant can be customized with the `norm` argument (default 100.0).
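
A minimal sketch of the scaling argument, assuming norm aggregation is the per-molecule sum divided by `norm` (which is what the outputs below show for this notebook). `norm_aggregate` is a hypothetical helper, not the chemprop class. For a 250-atom molecule whose hidden values are all 1.0, the raw sum grows to 250 while the normalized sum stays near 1:

```python
import torch

def norm_aggregate(H: torch.Tensor, batch: torch.Tensor, norm: float = 100.0) -> torch.Tensor:
    """Hypothetical helper: per-molecule sum of atom rows divided by the constant `norm`."""
    n_mols = int(batch.max()) + 1
    sums = torch.zeros(n_mols, H.shape[1], dtype=H.dtype).index_add(0, batch, H)
    return sums / norm

# A single large "molecule" of 250 atoms whose hidden values are all 1.0
H_big = torch.ones(250, 3)
batch_big = torch.zeros(250, dtype=torch.long)

raw_sum = norm_aggregate(H_big, batch_big, norm=1.0)  # plain sum: every entry is 250.0
scaled = norm_aggregate(H_big, batch_big)             # default norm: every entry is 2.5
```

Because the constant is fixed rather than the atom count, the result still grows with molecule size (so it remains suited to extensive properties), just on a smaller scale than the plain sum.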

In [6]:
norm_agg = NormAggregation()
big_norm = NormAggregation(norm=1000.0)

In [7]:
norm_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)

tensor([[-0.0040,  0.0070,  0.0138],
        [-0.0003, -0.0114, -0.0118],
        [ 0.0206,  0.0009, -0.0199]])

In [8]:
big_norm(H=example_message_passing_output, batch=which_atoms_in_which_molecule)

tensor([[-3.9791e-04,  6.9843e-04,  1.3765e-03],
        [-3.1359e-05, -1.1449e-03, -1.1844e-03],
        [ 2.0586e-03,  8.9455e-05, -1.9913e-03]])
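
The two norm outputs are the sum aggregation output divided by 100.0 and by 1000.0 respectively. A sketch checking this against the printed values, up to display rounding:

```python
import torch

# Values copied from the sum, default-norm and norm=1000.0 outputs above
sum_out = torch.tensor([[-0.3979, 0.6984, 1.3765],
                        [-0.0314, -1.1449, -1.1844],
                        [2.0586, 0.0895, -1.9913]])
norm_out = torch.tensor([[-0.0040, 0.0070, 0.0138],
                         [-0.0003, -0.0114, -0.0118],
                         [0.0206, 0.0009, -0.0199]])
big_norm_out = torch.tensor([[-3.9791e-04, 6.9843e-04, 1.3765e-03],
                             [-3.1359e-05, -1.1449e-03, -1.1844e-03],
                             [2.0586e-03, 8.9455e-05, -1.9913e-03]])

# Tolerances reflect how many digits each output was printed with
default_matches = torch.allclose(sum_out / 100.0, norm_out, atol=1e-4)
big_matches = torch.allclose(sum_out / 1000.0, big_norm_out, atol=1e-6)
```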

### Attentive aggregation 

This uses a learned weighted average to combine the atom representations within each molecule graph. It must be told the size of the hidden dimension (`output_size`) because it computes each atom's weight from that atom's hidden representation.
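
A minimal sketch of one common form of attentive pooling, assuming each atom's weight is a softmax within its molecule over scores from a linear layer on the atom's hidden vector. `AttentivePoolingSketch` is a hypothetical class, and chemprop's `AttentiveAggregation` may differ in its details:

```python
import torch
from torch import nn

class AttentivePoolingSketch(nn.Module):
    """Hypothetical attentive pooling: softmax-weighted sum of atom rows within each molecule."""

    def __init__(self, hidden_dim: int):
        super().__init__()
        # Maps each atom's hidden vector to a single unnormalized attention score
        self.W = nn.Linear(hidden_dim, 1)

    def forward(self, H: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        n_mols = int(batch.max()) + 1
        logits = self.W(H)  # shape (n_atoms, 1)
        # Subtracting one global constant leaves every per-molecule softmax unchanged but avoids overflow
        scores = (logits - logits.max()).exp()
        # Total score per molecule; each atom's weight = its score / its molecule's total
        totals = torch.zeros(n_mols, 1, dtype=H.dtype).index_add(0, batch, scores)
        weights = scores / totals[batch]
        # Weighted sum of atom rows per molecule (the weights within each molecule sum to 1)
        return torch.zeros(n_mols, H.shape[1], dtype=H.dtype).index_add(0, batch, weights * H)

torch.manual_seed(0)
H = torch.randn(7, 3)  # stand-in for the message passing output
batch = torch.tensor([0, 0, 1, 1, 1, 1, 2])
pooled = AttentivePoolingSketch(hidden_dim=3)(H, batch)  # shape (3, 3)
```

A single-atom molecule gives its only atom a weight of 1, so its pooled row is that atom's own vector; this is why the last row of the mean, sum and attentive outputs in this notebook is identical. The learnable `nn.Linear` weights are also why the attentive output carries a `grad_fn`.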

In [9]:
att_agg = AttentiveAggregation(output_size=hidden_dim)

In [10]:
att_agg(H=example_message_passing_output, batch=which_atoms_in_which_molecule)

tensor([[-0.3872,  0.4476,  0.8692],
        [-0.4625, -0.4167, -0.2669],
        [ 2.0586,  0.0895, -1.9913]], grad_fn=<ScatterReduceBackward0>)