# 02456 Molecular Property Prediction

Basic example of how to train the PaiNN model to predict the QM9 property
"internal energy at 0K". This property (and the majority of the other QM9
properties) is computed as a sum of atomic contributions.

In [1]:
import torch
import argparse
from tqdm import trange
import torch.nn.functional as F
from pytorch_lightning import seed_everything

## QM9 Datamodule

In [3]:
import numpy as np
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from typing import Optional, List, Union, Tuple
from torch_geometric.transforms import BaseTransform


class GetTarget(BaseTransform):
    def __init__(self, target: Optional[int] = None) -> None:
        self.target = [target]


    def forward(self, data: Data) -> Data:
        if self.target is not None:
            data.y = data.y[:, self.target]
        return data


class QM9DataModule(pl.LightningDataModule):

    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV).
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }

    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Union[List[int], List[float]] = [110000, 10000, 10831],
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        self.splits = splits
        self.seed = seed
        self.subset_size = subset_size

        self.data_train = None
        self.data_val = None
        self.data_test = None


    def prepare_data(self) -> None:
        # Download data
        QM9(root=self.data_dir)


    def setup(self, stage: Optional[str] = None) -> None:
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[torch.tensor(rng.permutation(len(dataset))).long()]

        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]

        # Split dataset
        if all([type(split) == int for split in self.splits]):
            split_sizes = self.splits
        elif all([type(split) == float for split in self.splits]):
            split_sizes = [int(len(dataset) * prop) for prop in self.splits]

        split_idx = np.cumsum(split_sizes)
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]


    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        atom_refs = self.data_train.atomref(self.target)

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:
                _, num_atoms  = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs


    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )


    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )


    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

## Post-processing module

In [4]:
import torch.nn as nn

class AtomwisePostProcessing(nn.Module):
    """
    Post-processing for (QM9) properties that are predicted as sums of atomic
    contributions.
    """
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: torch.FloatTensor,
    ) -> None:
        """
        Args:
            num_outputs: Integer with the number of model outputs. In most
                cases 1.
            mean: torch.FloatTensor with mean value to shift atomwise
                contributions by.
            std: torch.FloatTensor with standard deviation to scale atomwise
                contributions by.
            atom_refs: torch.FloatTensor of size [num_atom_types, 1] with
                atomic reference values.
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)
        self.atom_refs = nn.Embedding.from_pretrained(atom_refs, freeze=True)


    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: torch.FloatTensor of size [num_nodes,
                num_outputs] with each node's contribution to the overall graph
                prediction, i.e., each atom's contribution to the overall
                molecular property prediction.
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FLoatTensor of size [num_graphs, num_outputs] with
            predictions for each graph (molecule).
        """
        num_graphs = torch.unique(graph_indexes).shape[0]

        atomic_contributions = atomic_contributions*self.scale + self.shift
        atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Sum contributions for each graph
        output_per_graph = torch.zeros(
            (num_graphs, self.num_outputs),
            device=atomic_contributions.device,
        )
        output_per_graph.index_add_(
            dim=0,
            index=graph_indexes,
            source=atomic_contributions,
        )

        return output_per_graph

## PaiNN

In [5]:
import torch
import torch.nn as nn
import src.data.AtomNeighbours as AN

class PaiNN(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch.
    """
    def __init__(
        self,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether
                two nodes (atoms) are neighbours.
        """
        super().__init__()

        self.scalar_embedding = nn.Embedding(num_unique_atoms, num_features)

        self.message_layer = nn.ModuleList(
            [MessagePaiNN(num_features, num_rbf_features, cutoff_dist) for _ in range(num_message_passing_layers)]
        )

        self.update_layer = nn.ModuleList(
            [UpdatePaiNN(num_features) for _ in range(num_message_passing_layers)]
        )

        self.last_layer = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_outputs),
        )

        self.AN = AN.AtomNeighbours(cutoff_dist)

        self.num_features = num_features


    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Forward pass of PaiNN. Includes the readout network highlighted in blue
        in Figure 2 in (Schütt et al., 2021) with normal linear layers which is
        used for predicting properties as sums of atomic contributions. The
        post-processing and final sum is perfomed with
        src.models.AtomwisePostProcessing.

        Args:
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_nodes, num_outputs] with atomic
            contributions to the overall molecular property prediction.
        """
        # Neighbourhood matrix
        self.adj_matrix = self.AN.neigbourhood_matrix(atom_positions, graph_indexes)

        node_scalar = self.scalar_embedding(atoms)
        node_vector = torch.zeros(atoms.size(0), self.num_features, 3, device=atoms.device)

        for message_layer, update_layer in zip(self.message_layer, self.update_layer):
            node_scalar, node_vector = message_layer(node_scalar, node_vector, self.adj_matrix)
            node_scalar, node_vector = update_layer(node_scalar, node_vector)

        node_scalar = self.last_layer(node_scalar)


        return node_scalar


def sinc_expansion(r_ij: torch.Tensor, n: int, cutoff: float):

    n_vals = torch.arange(n, device=r_ij.device) + 1

    return torch.sin(r_ij.unsqueeze(-1) * n_vals * torch.pi / cutoff) / r_ij.unsqueeze(-1)


def cosine_cutoff(r_ij: torch.Tensor, cutoff: float):
    return torch.where(
        r_ij < cutoff,
        0.5 * (torch.cos(torch.pi * r_ij / cutoff) + 1),
        torch.tensor(0.0),
    )

class MessagePaiNN(nn.Module):
    """
    Message passing.
    """
    def __init__(
        self,
        # num_message_passing_layers: int = 3,
        num_features: int = 128,
        # num_outputs: int = 1,
        num_rbf_features: int = 20,
        # num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
    ) -> None:
        # """
        # Args:
        #     num_message_passing_layers: Number of message passing layers in
        #         the PaiNN model.
        #     num_features: Size of the node embeddings (scalar features) and
        #         vector features.
        #     num_outputs: Number of model outputs. In most cases 1.
        #     num_rbf_features: Number of radial basis functions to represent
        #         distances.
        #     num_unique_atoms: Number of unique atoms in the data that we want
        #         to learn embeddings for.
        #     cutoff_dist: Euclidean distance threshold for determining whether
        #         two nodes (atoms) are neighbours.
        # """
        super().__init__()

        self.scalar_message = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, 3 * num_features),
        )

        self.layer_rbf = nn.Linear(num_rbf_features, 3* num_features)

        self.num_features = num_features
        self.num_rbf_features = num_rbf_features
        self.cutoff_dist = cutoff_dist


    def forward(
        self,
        node_scalar,
        node_vector,
        adj_matrix
    ) -> torch.FloatTensor:
        """
        xxxx

        Args:
            djfvkd:ajcnac

        Returns:
            XXXXX
        """
        atom_scalar = self.scalar_message(node_scalar)
        # print("atom scalar shape", atom_scalar.shape)

        # RBF

        r_ij_dist = adj_matrix[:, 5]

        rbf = self.layer_rbf(sinc_expansion(r_ij_dist, self.num_rbf_features, self.cutoff_dist))

        rbf_cos_cutoff = rbf * cosine_cutoff(r_ij_dist, self.cutoff_dist).unsqueeze(-1)
        # print("rbf_cos_cutoff shape", rbf_cos_cutoff.shape)

        # print("rbf_cos_cutoff type", rbf_cos_cutoff.shape)

        # print("rbf type", rbf.shape)

        # print("atom scalar", atom_scalar[adj_matrix[:, 1].long()].shape)


        pre_split = atom_scalar[adj_matrix[:, 1].long()] * rbf_cos_cutoff

        # Split
        split1, split2, split3 = torch.split(pre_split, self.num_features, dim = -1)

        r_ij = adj_matrix[:, 2:5]

        r_ij_standardized = r_ij /r_ij_dist.unsqueeze(-1)

        # print("r_ij_standardized shape", r_ij_standardized.unsqueeze(1).shape)

        message_edge = split3.unsqueeze(-1) * r_ij_standardized.unsqueeze(1)

        message_vector = node_vector[adj_matrix[:, 1].long()] * split1.unsqueeze(-1) + message_edge

        delta_v = torch.zeros_like(node_vector)
        delta_s = torch.zeros_like(node_scalar)

        # list_neighbours: index of the neighbours of atom i
        delta_s.index_add_(0, adj_matrix[:, 0].long(), split2)
        delta_v.index_add_(0, adj_matrix[:, 0].long(), message_vector)

        return node_scalar + delta_s, node_vector + delta_v




class UpdatePaiNN(nn.Module):
    """
    Update passing.
    """
    def __init__(
        self,
        # num_message_passing_layers: int = 3,
        num_features: int = 128,
        # num_outputs: int = 1,
        # num_rbf_features: int = 20,
        # num_unique_atoms: int = 100,
        # cutoff_dist: float = 5.0,
    ) -> None:
        super().__init__()

        self.update_U = nn.Linear(num_features, num_features, bias=False)
        self.update_V = nn.Linear(num_features, num_features, bias=False)

        self.num_features = num_features
        self.scalar_update = nn.Sequential(
            nn.Linear(num_features * 2, num_features),
            nn.SiLU(),
            nn.Linear(num_features, 3 * num_features),
        )


    def forward(
        self,
        node_scalar,
        node_vector
    ) -> torch.FloatTensor:
        """
        xxxx

        Args:
            djfvkd:ajcnac

        Returns:
            XXXXX
        """
        U = self.update_U(node_vector.permute(0, 2, 1))
        V = self.update_V(node_vector.permute(0, 2, 1))
        U = U.permute(0,2,1)
        V = V.permute(0,2,1)

        V_norm = torch.norm(V, dim = -1)

        pre_split_s = self.scalar_update(torch.cat((V_norm, node_scalar), dim = 1))

        a_vv, a_sv, a_ss = torch.split(pre_split_s, self.num_features, dim = 1)

        delta_v = a_vv.unsqueeze(2) * U

        inner_prod = torch.sum(U * V, dim=2)

        delta_s = inner_prod * a_sv + a_ss

        return node_scalar + delta_s, node_vector + delta_v

## Hyperparameters

In [6]:
def cli(args: list = []):
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', default=0)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=3, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--patience', default=30, type=int)
    parser.add_argument('--swa_lr', default=0.0001, type=float)

    args = parser.parse_args(args=args)
    return args

## Training and testing

In [7]:
args = [] # Specify non-default arguments in this list
args = cli(args)
seed_everything(args.seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

dm = QM9DataModule(
    target=args.target,
    data_dir=args.data_dir,
    batch_size_train=args.batch_size_train,
    batch_size_inference=args.batch_size_inference,
    num_workers=args.num_workers,
    splits=args.splits,
    seed=args.seed,
    subset_size=args.subset_size,
)
dm.prepare_data()
dm.setup()
y_mean, y_std, atom_refs = dm.get_target_stats(
    remove_atom_refs=True, divide_by_atoms=True
)


Seed set to 0


In [None]:
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

painn = PaiNN(
    num_message_passing_layers=args.num_message_passing_layers,
    num_features=args.num_features,
    num_outputs=args.num_outputs,
    num_rbf_features=args.num_rbf_features,
    num_unique_atoms=args.num_unique_atoms,
    cutoff_dist=args.cutoff_dist,
).to(device)

# Load the pre-trained model weights
painn.load_state_dict(torch.load("best_model_3_layer.pth", map_location = device))

post_processing = AtomwisePostProcessing(
    args.num_outputs, y_mean, y_std, atom_refs
).to(device)

# Define optimizer
optimizer = optim.SGD(painn.parameters(), lr=args.swa_lr, momentum=0.9)

# Wrap PaiNN with SWA
swa_model = AveragedModel(painn).to(device)

train_losses = []

# Training Loop
painn.train()
pbar = trange(50)
for epoch in pbar:
    loss_epoch = 0.
    for batch in dm.train_dataloader():
        batch = batch.to(device)

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )

        loss_step = F.mse_loss(preds, batch.y, reduction='sum')
        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
        break 
    loss_epoch /= len(dm.data_train)
    train_losses.append(loss_epoch)

    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

    swa_model.update_parameters(painn)

# Evaluate SWA model
swa_model.eval()
mae = 0
with torch.no_grad():
    for batch in dm.test_dataloader():
        batch = batch.to(device)

        atomic_contributions = swa_model(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch,
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE (SWA): {unit_conversion(mae):.3f}')

100%|██████████| 50/50 [01:32<00:00,  1.84s/it, Train loss: 3.484e-08]


Test MAE (SWA): 7.835


In [8]:
from swag.posteriors.swag import SWAG
import torch.optim as optim

painn = PaiNN(
    num_message_passing_layers=args.num_message_passing_layers,
    num_features=args.num_features,
    num_outputs=args.num_outputs,
    num_rbf_features=args.num_rbf_features,
    num_unique_atoms=args.num_unique_atoms,
    cutoff_dist=args.cutoff_dist,
).to(device)

# dict of args
dict_args = {
    'atoms': torch.zeros((1,), dtype=torch.long).to(device),
    'atom_positions': torch.zeros((1, 3), dtype=torch.float32).to(device),
    'graph_indexes': torch.zeros((1,), dtype=torch.long).to(device),
}

# Load the pre-trained model weights
painn.load_state_dict(torch.load("best_model_3_layer.pth", map_location = device))

post_processing = AtomwisePostProcessing(
    args.num_outputs, y_mean, y_std, atom_refs
).to(device)

# Define optimizer
optimizer = optim.SGD(painn.parameters(), lr=0.0001, momentum=0.9)

# Wrap PaiNN with SWA
swag_model = SWAG(
    PaiNN).to(device)

train_losses = []

# Training Loop
painn.train()
pbar = trange(3)
for epoch in pbar:
    loss_epoch = 0.
    batch_counter = 0
    for batch in dm.train_dataloader():
        batch = batch.to(device)
        if batch_counter >= 3:  # Break after processing 3 batches
            break

        atomic_contributions = painn(
            atoms=batch.z,
            atom_positions=batch.pos,
            graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )

        loss_step = F.mse_loss(preds, batch.y, reduction='sum')
        loss = loss_step / len(batch.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_epoch += loss_step.detach().item()
        batch_counter += 1
    loss_epoch /= len(dm.data_train)
    train_losses.append(loss_epoch)
    
    pbar.set_postfix_str(f'Train loss: {loss_epoch:.3e}')

    swag_model.collect_model(painn)


# Evaluate SWAG model
# Sample weights from SWAG posterior
num_samples = 10  # Number of posterior samples
swag_model.eval()
mae = 0
for i in range(num_samples):
    swag_model.sample()  # Sample weights

    # Evaluate sampled model
    with torch.no_grad():
        for batch in dm.test_dataloader():
            batch = batch.to(device)

            atomic_contributions = swag_model(
                atoms=batch.z,
                atom_positions=batch.pos,
                graph_indexes=batch.batch,
            )
            preds = post_processing(
                atoms=batch.z,
                graph_indexes=batch.batch,
                atomic_contributions=atomic_contributions,
            )
            mae += F.l1_loss(preds, batch.y, reduction='sum')
        break


mae /= (len(dm.data_test) * num_samples)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE (SWAG): {unit_conversion(mae):.3f}')


100%|██████████| 3/3 [00:14<00:00,  4.97s/it, Train loss: 1.037e-07]


Test MAE (SWAG): 0.790
