# APPLYING PaiNN: Molecular Property Prediction Using GNNs

By s214983, s214659, s204618, s183624

Code inspiration from the following sources:
- Cosine cutoff: https://schnetpack.readthedocs.io/en/latest/_modules/schnetpack/nn/cutoff.html#CosineCutoff
- RBF: https://schnetpack.readthedocs.io/en/latest/_modules/schnetpack/nn/radial.html#GaussianRBF (the models below use the sine/Bessel variant, `BesselRBF`, from the same module)
- Reduce LR on Plateau (`ReduceLROnPlateau`): https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ReduceLROnPlateau.html
- PaiNN: https://github.com/nityasagarjena/PaiNN-model/tree/main
- Wandb: https://wandb.ai
- Edge construction (`radius_graph`): https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.radius_graph.html?highlight=radius_graph


In [1]:
import torch
import argparse
from tqdm import trange
import torch.nn.functional as F
from pytorch_lightning import seed_everything

## QM9 Datamodule

Classes used to load, shuffle and split the QM9 dataset.

In [2]:
import numpy as np
import pytorch_lightning as pl
from torch_geometric.data import Data
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from typing import Optional, List, Union, Tuple
from torch_geometric.transforms import BaseTransform


class GetTarget(BaseTransform):
    """Transform that keeps only the selected regression target in ``data.y``."""

    def __init__(self, target: Optional[int] = None) -> None:
        """
        Args:
            target: Column index of ``data.y`` to keep, or ``None`` to leave
                ``data.y`` untouched.
        """
        # Wrap the index in a list so ``data.y[:, [t]]`` keeps a trailing
        # dimension of size 1. Only wrap real indices: wrapping ``None`` would
        # make the ``is not None`` check in forward() always true.
        self.target = [target] if target is not None else None

    def forward(self, data: Data) -> Data:
        """Select the target column of ``data.y`` (no-op when target is None)."""
        if self.target is not None:
            data.y = data.y[:, self.target]
        return data


class QM9DataModule(pl.LightningDataModule):
    """
    LightningDataModule for QM9: downloads the data, shuffles it with a fixed
    seed, optionally keeps a subset, and splits it into train/val/test sets.
    """

    target_types = ['atomwise' for _ in range(19)]
    target_types[0] = 'dipole_moment'
    target_types[5] = 'electronic_spatial_extent'

    # Specify unit conversions (eV to meV).
    unit_conversion = {
        i: (lambda t: 1000*t) if i not in [0, 1, 5, 11, 16, 17, 18]
        else (lambda t: t)
        for i in range(19)
    }

    def __init__(
        self,
        target: int = 7,
        data_dir: str = 'data/',
        batch_size_train: int = 100,
        batch_size_inference: int = 1000,
        num_workers: int = 0,
        splits: Optional[Union[List[int], List[float]]] = None,
        seed: int = 0,
        subset_size: Optional[int] = None,
    ) -> None:
        """
        Args:
            target: Column of QM9's ``y`` to predict.
            data_dir: Root directory for the QM9 download.
            batch_size_train: Batch size of the training loader.
            batch_size_inference: Batch size of the validation/test loaders.
            num_workers: Number of DataLoader worker processes.
            splits: Either three absolute sizes (ints) or three fractions
                (floats) for train/val/test. Defaults to
                ``[110000, 10000, 10831]``. The test set always receives every
                molecule left after the train and validation sets.
            seed: Seed for the dataset shuffle.
            subset_size: If given, keep only the first ``subset_size``
                molecules after shuffling.
        """
        super().__init__()
        self.target = target
        self.data_dir = data_dir
        self.batch_size_train = batch_size_train
        self.batch_size_inference = batch_size_inference
        self.num_workers = num_workers
        # Copy so a caller's list (or a shared default) is never aliased.
        self.splits = list(splits) if splits is not None else [110000, 10000, 10831]
        self.seed = seed
        self.subset_size = subset_size

        self.data_train = None
        self.data_val = None
        self.data_test = None

    def prepare_data(self) -> None:
        # Download data
        QM9(root=self.data_dir)

    def setup(self, stage: Optional[str] = None) -> None:
        """Shuffle, optionally subset, and split QM9 into train/val/test."""
        dataset = QM9(root=self.data_dir, transform=GetTarget(self.target))

        # Shuffle dataset reproducibly.
        rng = np.random.default_rng(seed=self.seed)
        dataset = dataset[rng.permutation(len(dataset))]

        # Subset dataset
        if self.subset_size is not None:
            dataset = dataset[:self.subset_size]

        split_idx = np.cumsum(self._split_sizes(len(dataset)))
        self.data_train = dataset[:split_idx[0]]
        self.data_val = dataset[split_idx[0]:split_idx[1]]
        self.data_test = dataset[split_idx[1]:]

    def _split_sizes(self, num_samples: int) -> List[int]:
        """Convert ``self.splits`` (sizes or fractions) into absolute sizes."""
        if all(isinstance(split, (int, np.integer)) for split in self.splits):
            return [int(split) for split in self.splits]
        if all(isinstance(split, (float, np.floating)) for split in self.splits):
            return [int(num_samples * prop) for prop in self.splits]
        # Previously a mix of ints and floats fell through to a NameError.
        raise ValueError(
            'splits must be all ints (sizes) or all floats (fractions), '
            f'got {self.splits!r}'
        )

    def get_target_stats(
        self,
        remove_atom_refs: bool = True,
        divide_by_atoms: bool = True
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Compute target statistics over the training split.

        Args:
            remove_atom_refs: Subtract the summed atomic reference values from
                each target (skipped when the target has no atom refs).
            divide_by_atoms: Divide each target by its molecule's atom count.

        Returns:
            (mean, std, atom_refs), where atom_refs is whatever
            ``atomref(target)`` returns and may be ``None``.

        Raises:
            RuntimeError: If called before ``setup()``.
        """
        if self.data_train is None:
            raise RuntimeError('Call setup() before get_target_stats().')
        atom_refs = self.data_train.atomref(self.target)

        ys = list()
        for batch in self.train_dataloader(shuffle=False):
            y = batch.y.clone()
            if remove_atom_refs and atom_refs is not None:
                y.index_add_(
                    dim=0, index=batch.batch, source=-atom_refs[batch.z]
                )
            if divide_by_atoms:
                _, num_atoms = torch.unique(batch.batch, return_counts=True)
                y = y / num_atoms.unsqueeze(-1)
            ys.append(y)

        y = torch.cat(ys, dim=0)
        return y.mean(), y.std(), atom_refs

    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        return DataLoader(
            self.data_train,
            batch_size=self.batch_size_train,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_val,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.data_test,
            batch_size=self.batch_size_inference,
            num_workers=self.num_workers,
            shuffle=False,
            pin_memory=True,
        )

  Referenced from: <E87A820F-D734-3F45-AFBE-9D80043A97C0> /Users/clarasofiechristiansen/anaconda3/envs/painn_venv4/lib/python3.12/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.12/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.12/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.12/Python' (no such file)
  Referenced from: <E87A820F-D734-3F45-AFBE-9D80043A97C0> /Users/clarasofiechristiansen/anaconda3/envs/painn_venv4/lib/python3.12/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.12/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.12/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.12/Python' (no such file)


## Post-processing module

Module that post-processes the output of the PaiNN model: it rescales the atomic contributions predicted by the model, adds the atomic reference values, and sums the contributions over each molecule.

In [3]:
import torch.nn as nn

class AtomwisePostProcessing(nn.Module):
    """
    Post-processing for (QM9) properties that are predicted as sums of atomic
    contributions.
    """
    def __init__(
        self,
        num_outputs: int,
        mean: torch.FloatTensor,
        std: torch.FloatTensor,
        atom_refs: Optional[torch.FloatTensor],
    ) -> None:
        """
        Args:
            num_outputs: Integer with the number of model outputs. In most
                cases 1.
            mean: torch.FloatTensor with mean value to shift atomwise
                contributions by.
            std: torch.FloatTensor with standard deviation to scale atomwise
                contributions by.
            atom_refs: torch.FloatTensor of size [num_atom_types, 1] with
                atomic reference values, or None when the target has no
                atomic references (nothing is then added in forward()).
        """
        super().__init__()
        self.num_outputs = num_outputs
        self.register_buffer('scale', std)
        self.register_buffer('shift', mean)
        if atom_refs is None:
            self.atom_refs = None
        else:
            self.atom_refs = nn.Embedding.from_pretrained(atom_refs, freeze=True)

    def forward(
        self,
        atomic_contributions: torch.FloatTensor,
        atoms: torch.LongTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Atomwise post-processing operations and atomic sum.

        Args:
            atomic_contributions: torch.FloatTensor of size [num_nodes,
                num_outputs] with each node's contribution to the overall graph
                prediction, i.e., each atom's contribution to the overall
                molecular property prediction.
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_graphs, num_outputs] with
            predictions for each graph (molecule), where num_graphs is
            ``max(graph_indexes) + 1`` (0 for empty input).
        """
        atomic_contributions = atomic_contributions * self.scale + self.shift
        if self.atom_refs is not None:
            atomic_contributions = atomic_contributions + self.atom_refs(atoms)

        # Size the output by the largest graph index (not the number of unique
        # indices) so index_add_ stays in range even if an index is skipped.
        num_graphs = int(graph_indexes.max()) + 1 if graph_indexes.numel() > 0 else 0

        # Sum contributions for each graph (same dtype/device as the input).
        output_per_graph = atomic_contributions.new_zeros((num_graphs, self.num_outputs))
        output_per_graph.index_add_(0, graph_indexes, atomic_contributions)

        return output_per_graph

## Hyperparameters

In [4]:
def cli(args: Optional[list] = None):
    """
    Parse the hyperparameters used in this notebook.

    Args:
        args: Command-line style strings, e.g. ``['--target', '7']``. ``None``
            (the default) is treated as an empty list, so the notebook
            kernel's own ``sys.argv`` is never parsed.

    Returns:
        argparse.Namespace with all hyperparameters.
    """
    parser = argparse.ArgumentParser()
    # type=int so a seed given on the command line is not passed on as a str.
    parser.add_argument('--seed', default=0, type=int)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=3, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--num_epochs', default=1000, type=int)

    return parser.parse_args(args=[] if args is None else args)

In [5]:
def get_device():
    """Return the preferred torch device name: 'cuda', then 'mps', else 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


# get_device() is available, but evaluation in this notebook is pinned to the CPU.
device = 'cpu'
print(device)

cpu


## PaiNN 1.0

Model architecture found by optimizing the number of layers. It has two message-passing blocks, each consisting of a message step and an update step.

In [6]:
from torch_geometric.nn import radius_graph
from schnetpack.nn.radial import BesselRBF
from schnetpack.nn.cutoff import CosineCutoff

class MessageBlock(nn.Module):
    """
    One PaiNN 1.0 interaction: a message step followed by an update step.

    NOTE(review): this block deviates from PaiNN (Schuett et al., 2021) in two
    places, flagged inline in forward(). The PaiNN 1.0 checkpoints loaded in
    this notebook presumably depend on this exact behaviour, so changing it
    likely requires retraining -- confirm before "fixing".
    """
    def __init__(self, num_features, num_rbf_features, cutoff_dist):
        """
        Args:
            num_features: Number of scalar and vector feature channels.
            num_rbf_features: Number of radial basis features per edge.
            cutoff_dist: Cutoff passed to CosineCutoff.
        """
        super().__init__()
        self.num_features = num_features
        self.num_rbf_features = num_rbf_features

        self.cutoff_dist = cutoff_dist

        # Message
        self.cutoff_function = CosineCutoff(self.cutoff_dist)

        # phi: scalar features -> 3 * num_features filter channels.
        self.phi_path = nn.Sequential(
            nn.Linear(self.num_features, self.num_features),
            nn.SiLU(),
            nn.Linear(self.num_features, self.num_features * 3))
        # W: radial basis features -> 3 * num_features filter channels.
        self.W_path = nn.Sequential(
            nn.Linear(self.num_rbf_features, self.num_features * 3))

        # Update
        self.U = nn.Linear(self.num_features, self.num_features)
        self.V = nn.Linear(self.num_features, self.num_features)

        # Maps [||Vv||, s] to the three gates a_vv, a_sv, a_ss.
        self.mlp_update = nn.Sequential(
            nn.Linear(self.num_features * 2, self.num_features),
            nn.SiLU(),
            nn.Linear(self.num_features, self.num_features * 3)
        )

    def forward(self, s, v, i_index, j_index, rbf, r_ij_direction):
        """
        Args:
            s: Scalar features, [num_nodes, num_features].
            v: Vector features, [num_nodes, num_features, 3]; dimension 1 is
                the feature axis (it is permuted last before U and V).
            i_index: Receiving node of each edge; messages are summed into it.
            j_index: Sending (neighbour) node of each edge.
            rbf: Radial basis features per edge, [num_edges, num_rbf_features].
            r_ij_direction: Edge direction vectors, [num_edges, 3].

        Returns:
            Updated (s, v) with the same shapes as the inputs.
        """
        # Message
        phi = self.phi_path(s)
        # NOTE(review): CosineCutoff is applied to the *filter values* here,
        # not to the interatomic distances as in the PaiNN paper.
        W = self.cutoff_function(self.W_path(rbf))
        # Filter each edge with the sending neighbour's (j) features.
        split = phi[j_index] * W
        Ws, Wvs, Wvv = torch.split(split, self.num_features, dim=-1)

        # Vector message: direction term (Wvs) + neighbour vector term (Wvv).
        delta_v_all = Wvs.unsqueeze(-1) * r_ij_direction.unsqueeze(1) + Wvv.unsqueeze(-1) * v[j_index]

        # Sum per-edge messages into their receiving node i.
        delta_v = torch.zeros_like(v)
        delta_v = delta_v.index_add_(0, i_index, delta_v_all)
        v = v + delta_v

        delta_s = torch.zeros_like(s)
        delta_s = delta_s.index_add_(0, i_index, Ws)
        s = s + delta_s

        # Update
        # U and V act on the feature axis, so move it last and back again.
        v_permuted = torch.permute(v, (0,2,1))
        Uv = torch.permute(self.U(v_permuted), (0,2,1))
        Vv = torch.permute(self.V(v_permuted), (0,2,1))
        Vv_norm = torch.linalg.norm(Vv, dim=2)
        mlp_input = torch.hstack([Vv_norm, s])
        mlp_result = self.mlp_update(mlp_input)

        a_vv, a_sv, a_ss = torch.split(mlp_result, self.num_features, dim=-1)

        dv = a_vv.unsqueeze(-1) * Uv

        dot_prod = torch.sum(Uv * Vv, dim=2) # <Uv, Vv> per feature channel
        ds = dot_prod * a_sv + a_ss

        s = s + ds
        # NOTE(review): this adds the aggregated *message* delta_v a second
        # time; the update term dv computed above is never used. PaiNN would
        # use `v = v + dv` here.
        v = v + delta_v

        return s, v



class PaiNN1(nn.Module):
    """
    Polarizable Atom Interaction Neural Network with PyTorch ("PaiNN 1.0").

    Embeds atoms, builds a radius graph, expands edge lengths in a Bessel
    radial basis, runs ``num_message_passing_layers`` MessageBlocks and maps
    the final scalar features to per-atom outputs with a small MLP.
    """
    def __init__(
        self,
        num_message_passing_layers: int = 3,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
        device: str='cpu',
    ) -> None:
        """
        Args:
            num_message_passing_layers: Number of message passing layers in
                the PaiNN model.
            num_features: Size of the node embeddings (scalar features) and
                vector features.
            num_outputs: Number of model outputs. In most cases 1.
            num_rbf_features: Number of radial basis functions to represent
                distances.
            num_unique_atoms: Number of unique atoms in the data that we want
                to learn embeddings for.
            cutoff_dist: Euclidean distance threshold for determining whether
                two nodes (atoms) are neighbours.
            device: Device the parameters are moved to at construction. Kept
                as ``self.device`` for backward compatibility; forward()
                allocates its buffers next to the embedding weights instead.
        """
        super().__init__()
        self.num_message_passing_layers = num_message_passing_layers
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.cutoff_dist = cutoff_dist
        self.num_rbf_features = num_rbf_features
        self.num_unique_atoms = num_unique_atoms
        self.device = device

        # Initial embeddings function
        self.embedding = nn.Embedding(self.num_unique_atoms, self.num_features)
        # RBF
        self.rbf = BesselRBF(self.num_rbf_features, self.cutoff_dist)

        # Message blocks (both message and update)
        self.blocks = nn.ModuleList([
            MessageBlock(num_features, num_rbf_features, self.cutoff_dist)
            for _ in range(self.num_message_passing_layers)
        ])

        # Last MLP
        self.last_mlp = nn.Sequential(
            nn.Linear(self.num_features, self.num_features),
            nn.SiLU(),
            nn.Linear(self.num_features, self.num_outputs)
        )

        # Move to the requested device only once all submodules exist;
        # calling .to() before creating them did not move them.
        self.to(device)

    def forward(
        self,
        atoms: torch.LongTensor,
        atom_positions: torch.FloatTensor,
        graph_indexes: torch.LongTensor,
    ) -> torch.FloatTensor:
        """
        Forward pass of PaiNN. Includes the readout network highlighted in blue
        in Figure 2 in (Schütt et al., 2021) with normal linear layers which is
        used for predicting properties as sums of atomic contributions. The
        post-processing and final sum is perfomed with
        src.models.AtomwisePostProcessing.

        Args:
            atoms: torch.LongTensor of size [num_nodes] with atom type of each
                node in the graph.
            atom_positions: torch.FloatTensor of size [num_nodes, 3] with
                euclidean coordinates of each node / atom.
            graph_indexes: torch.LongTensor of size [num_nodes] with the graph
                index each node belongs to.

        Returns:
            A torch.FloatTensor of size [num_nodes, num_outputs] with atomic
            contributions to the overall molecular property prediction.
        """
        s = self.embedding(atoms)
        # Allocate vector features where the parameters live, so the model
        # still works after a later .to(...) call.
        v = torch.zeros(
            (atoms.shape[0], self.num_features, 3), device=s.device, dtype=s.dtype
        )
        i_index, j_index = build_edge_index(atom_positions, self.cutoff_dist, graph_indexes)
        r_ij = atom_positions[j_index] - atom_positions[i_index]
        distance = torch.linalg.norm(r_ij, dim=1, keepdim=True)
        # squeeze(-1), not squeeze(): keeps a 1-D tensor when there is one edge.
        rbf = self.rbf(distance.squeeze(-1))
        r_ij_direction = r_ij / (distance + 1e-8)

        # Message passing
        for block in self.blocks:
            s, v = block(s, v, i_index, j_index, rbf, r_ij_direction)

        return self.last_mlp(s)
        

def build_edge_index(atom_positions, cutoff_distance, graph_indexes):
    """
    Connect every pair of atoms in the same graph that lie within
    ``cutoff_distance`` of each other.

    Returns:
        The edge index produced by ``radius_graph`` with
        ``flow='target_to_source'``; callers unpack its two rows as
        ``i_index, j_index``.
    """
    return radius_graph(
        atom_positions,
        r=cutoff_distance,
        batch=graph_indexes,
        flow='target_to_source',
    )

In [7]:
# Load best model
# Trained with the settings: Early stopping, reduced learning rate on plateau
painn = PaiNN1(num_message_passing_layers=2)
# weights_only=True restricts unpickling to tensors and plain containers
# (all a state_dict needs) instead of executing arbitrary pickled code.
state_dict = torch.load(
    'model1_p5mqh3xg.pth', map_location=torch.device(device), weights_only=True
)
painn.load_state_dict(state_dict)

  painn.load_state_dict(torch.load('model1_p5mqh3xg.pth', map_location=torch.device(device)))


<All keys matched successfully>

### Testing

Evaluation takes less than 5 minutes on a CPU (including data loading).

In [8]:
overrides = []  # Specify non-default arguments in this list
args = cli(overrides)
seed_everything(args.seed)

# Build the QM9 datamodule from the parsed settings.
dm = QM9DataModule(
    target=args.target,
    data_dir=args.data_dir,
    batch_size_train=args.batch_size_train,
    batch_size_inference=args.batch_size_inference,
    num_workers=args.num_workers,
    splits=args.splits,
    seed=args.seed,
    subset_size=args.subset_size,
)
dm.prepare_data()
print('Data Loaded')
dm.setup()

# Training-split statistics rescale the model's atomic contributions.
y_mean, y_std, atom_refs = dm.get_target_stats(remove_atom_refs=True, divide_by_atoms=True)
print('Target Stats Loaded')
post_processing = AtomwisePostProcessing(args.num_outputs, y_mean, y_std, atom_refs)

for module in (painn, post_processing):
    module.to(device)

painn.eval()
mae = 0
with torch.no_grad():
    for i, batch in enumerate(dm.test_dataloader()):
        print('Batch:', i)
        batch = batch.to(device)
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=painn(
                atoms=batch.z,
                atom_positions=batch.pos,
                graph_indexes=batch.batch,
            ),
        )
        # Accumulate the summed absolute error; normalised by test-set size below.
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')

Seed set to 0


Data Loaded
Target Stats Loaded
Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Test MAE: 17.992


## PaiNN 1.0 with SWA

In [9]:
# Load best model with SWA
# Trained with the best model seen above.
# Trained with the settings: Early stopping, reduced learning rate on plateau, and SWA
painn_swa = PaiNN1(num_message_passing_layers=2)
# weights_only=True restricts unpickling to tensors and plain containers
# (all a state_dict needs) instead of executing arbitrary pickled code.
swa_state_dict = torch.load(
    'model1_swa_iccjgnmp.pth', map_location=torch.device(device), weights_only=True
)
painn_swa.load_state_dict(swa_state_dict)

  painn_swa.load_state_dict(torch.load('model1_swa_iccjgnmp.pth', map_location=torch.device(device)))


<All keys matched successfully>

In [10]:
overrides = []  # Specify non-default arguments in this list
args = cli(overrides)
seed_everything(args.seed)

# Reuses `dm` and `post_processing` from the PaiNN 1.0 evaluation cell.
painn_swa.to(device)
post_processing.to(device)

mae = 0
painn_swa.eval()
test_loader = dm.test_dataloader()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        print('Batch:', i)
        batch = batch.to(device)
        atomic_contributions = painn_swa(
            atoms=batch.z, atom_positions=batch.pos, graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')

Seed set to 0


Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Test MAE: 17.272


## PaiNN 2.0

In [11]:
import torch
from torch import nn
from torch_geometric.nn import radius_graph

def rbf_generator(distance, num_rbf_features, cutoff_distance):
    """
    Sine radial basis expansion: sin(n * pi * d / cutoff) / d for
    n = 1..num_rbf_features.

    Args:
        distance: Tensor of distances; a new trailing axis is added.
        num_rbf_features: Number of basis functions.
        cutoff_distance: Cutoff radius used to scale the frequencies.

    Returns:
        Tensor with ``distance``'s shape plus a trailing axis of size
        ``num_rbf_features``.
    """
    d = distance.unsqueeze(-1)
    orders = torch.arange(1, num_rbf_features + 1, device=distance.device)
    return torch.sin(d * orders * torch.pi / cutoff_distance) / d

def f_cut(distance, cutoff_distance):
    """
    Cosine cutoff: 0.5 * (cos(pi * d / cutoff) + 1) inside the cutoff, 0 outside.

    Args:
        distance: Tensor of distances.
        cutoff_distance: Cutoff radius.

    Returns:
        Tensor with the same shape, device and dtype as ``distance``.
    """
    # https://schnetpack.readthedocs.io/en/latest/_modules/schnetpack/nn/cutoff.html#CosineCutoff
    within_cutoff = distance < cutoff_distance
    smooth = 0.5 * (torch.cos(torch.pi * distance / cutoff_distance) + 1)
    zero = torch.tensor(0.0, device=distance.device, dtype=distance.dtype)
    return torch.where(within_cutoff, smooth, zero)

class MessageLayer(nn.Module):
    """
    PaiNN 2.0 message step: builds per-edge filters from radial features and
    neighbour scalar features, then sums scalar and vector messages per node.
    """
    def __init__(self, num_features, num_rbf_features, cutoff_distance):
        """
        Args:
            num_features: Number of scalar/vector feature channels.
            num_rbf_features: Number of radial basis functions per edge.
            cutoff_distance: Cutoff used by rbf_generator and f_cut.
        """
        super().__init__()

        self.num_features = num_features
        self.num_rbf_features = num_rbf_features
        self.cutoff_distance = cutoff_distance

        # phi: scalar features -> 3 * num_features filter channels.
        self.phi_path = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_features * 3),
        )

        # W: radial basis features -> 3 * num_features filter channels.
        self.W_path = nn.Linear(num_rbf_features, num_features * 3)

    def forward(self, s, v, edge_indexes, r_ij, distance):
        """
        Args:
            s: Scalar features, [num_nodes, num_features].
            v: Vector features, [num_nodes, 3, num_features].
            edge_indexes: [num_edges, 2]; column 0 is the receiving node,
                column 1 the neighbour whose features form the message.
            r_ij: Edge vectors, [num_edges, 3].
            distance: Edge lengths, [num_edges].

        Returns:
            (new_s, new_v): inputs plus the aggregated messages.
        """
        W = self.W_path(rbf_generator(distance, self.num_rbf_features, self.cutoff_distance))
        # Smoothly switch the filters off at the cutoff distance.
        W = W * f_cut(distance, self.cutoff_distance).unsqueeze(-1)
        phi = self.phi_path(s)
        split = W * phi[edge_indexes[:, 1]]

        # NOTE(review): `Wvs` is bound twice, so the second chunk is discarded
        # and the third chunk drives both the direction message (v_2) and the
        # scalar message (delta_s). The PaiNN 2.0 checkpoint loaded in this
        # notebook presumably depends on this; fixing it likely needs retraining.
        Wvv, Wvs, Wvs = torch.split(split, self.num_features, dim = 1)

        # Vector message: neighbour vectors gated by Wvv ...
        v_1 =  v[edge_indexes[:, 1]] * Wvv.unsqueeze(1)
        # ... plus unit edge directions gated by Wvs.
        v_2 = Wvs.unsqueeze(1) * (r_ij / distance.unsqueeze(-1)).unsqueeze(-1)
        v_sum = v_1 + v_2

        # Sum per-edge messages into the receiving node (column 0).
        delta_s = torch.zeros_like(s)
        delta_s.index_add_(0, edge_indexes[:, 0], Wvs)
        new_s = s + delta_s

        delta_v = torch.zeros_like(v)
        delta_v.index_add_(0, edge_indexes[:, 0], v_sum)
        new_v = v + delta_v

        return new_s, new_v

class UpdateLayer(nn.Module):
    """PaiNN 2.0 update step: couples each node's scalar and vector features."""

    def __init__(self, num_features: int):
        super().__init__()
        # Submodules are created in the original order (V, U, MLP) so seeded
        # initialisation is unchanged.
        self.V = nn.Linear(num_features, num_features)
        self.U = nn.Linear(num_features, num_features)
        hidden = nn.Linear(num_features * 2, num_features)
        activation = nn.SiLU()
        head = nn.Linear(num_features, num_features * 3)
        self.mlp_update = nn.Sequential(hidden, activation, head)

    def forward(self, s, v):
        """
        Args:
            s: Scalar features, [num_nodes, num_features].
            v: Vector features, [num_nodes, 3, num_features]; U and V act on
                the last (feature) axis.

        Returns:
            (new_s, new_v) with the same shapes as (s, v).
        """
        width = v.shape[-1]
        projected_v = self.V(v)
        projected_u = self.U(v)

        gate_input = torch.cat((torch.linalg.norm(projected_v, dim=1), s), dim=1)
        a_vv, a_sv, a_ss = torch.split(self.mlp_update(gate_input), width, dim=1)

        inner = torch.sum(projected_u * projected_v, dim=1)
        new_s = s + (a_sv * inner + a_ss)
        new_v = v + a_vv.unsqueeze(1) * projected_u
        return new_s, new_v

class PaiNN2(nn.Module):
    """
    "PaiNN 2.0": each interaction is a MessageLayer followed by an
    UpdateLayer, and a final MLP maps scalar features to per-atom outputs.
    """
    def __init__(
        self,
        num_message_passing_layers: int=5,
        num_features: int = 128,
        num_outputs: int = 1,
        num_rbf_features: int = 20,
        num_unique_atoms: int = 100,
        cutoff_dist: float = 5.0,
        device: str='cpu'
    ):
        """
        Args:
            num_message_passing_layers: Number of message/update pairs.
            num_features: Number of scalar/vector feature channels.
            num_outputs: Number of per-atom outputs.
            num_rbf_features: Number of radial basis functions.
            num_unique_atoms: Size of the atom-type embedding table.
            cutoff_dist: Neighbour cutoff distance.
            device: Accepted for signature compatibility; not used (buffers
                are created on the device of the inputs).
        """
        super().__init__()

        self.num_unique_atoms = num_unique_atoms   # number of all elements
        self.cutoff_distance = cutoff_dist
        self.num_message_passing_layers = num_message_passing_layers
        self.num_features = num_features
        self.num_rbf_features = num_rbf_features
        self.num_outputs = num_outputs

        self.embedding = nn.Embedding(num_unique_atoms, num_features)

        self.message_layers = nn.ModuleList([
            MessageLayer(num_features, num_rbf_features, cutoff_dist)
            for _ in range(num_message_passing_layers)
        ])
        self.update_layers = nn.ModuleList([
            UpdateLayer(num_features)
            for _ in range(num_message_passing_layers)
        ])

        self.last_mlp = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.SiLU(),
            nn.Linear(num_features, num_outputs),
        )

    def forward(self, atoms, atom_positions, graph_indexes):
        """
        Args:
            atoms: Atom type of each node, [num_nodes].
            atom_positions: Node coordinates, [num_nodes, 3].
            graph_indexes: Graph index of each node, [num_nodes].

        Returns:
            Per-atom outputs, [num_nodes, num_outputs].
        """
        # Transposed to [num_edges, 2]: column 0 receives, column 1 sends.
        edges = build_edge_index(atom_positions, self.cutoff_distance, graph_indexes).T
        receivers, senders = edges[:, 0], edges[:, 1]

        r_ij = atom_positions[senders] - atom_positions[receivers]
        distance = torch.linalg.norm(r_ij, dim=1)

        s = self.embedding(atoms)
        v = r_ij.new_zeros((atom_positions.shape[0], 3, self.num_features))

        for message, update in zip(self.message_layers, self.update_layers):
            s, v = update(*message(s, v, edges, r_ij, distance))

        return self.last_mlp(s)

def build_edge_index(atom_positions, cutoff_distanceance, graph_indexes):
    """
    Radius-graph edges between atoms of the same graph within the cutoff.

    Returns:
        Edge index from ``radius_graph`` with ``flow='target_to_source'``;
        PaiNN2 transposes it so column 0 is the receiving node.
    """
    # NOTE: the parameter name is misspelled but kept for keyword callers.
    edges = radius_graph(
        atom_positions,
        batch=graph_indexes,
        r=cutoff_distanceance,
        flow='target_to_source',
    )
    return edges

In [12]:
# Load best model
# Trained with the settings: Early stopping, reduced learning rate on plateau
painn = PaiNN2(num_message_passing_layers=5)
# weights_only=True restricts unpickling to tensors and plain containers
# (all a state_dict needs) instead of executing arbitrary pickled code.
state_dict = torch.load(
    'model2_dvyocn32.pth', map_location=torch.device(device), weights_only=True
)
painn.load_state_dict(state_dict)

  painn.load_state_dict(torch.load('model2_dvyocn32.pth', map_location=torch.device(device)))


<All keys matched successfully>

In [13]:
overrides = []  # Specify non-default arguments in this list
args = cli(overrides)
seed_everything(args.seed)

# Reuses `dm` and `post_processing` built in the PaiNN 1.0 evaluation cell.
painn.to(device)
post_processing.to(device)
painn.eval()

mae = 0
with torch.no_grad():
    for i, batch in enumerate(dm.test_dataloader()):
        print('Batch:', i)
        batch = batch.to(device)
        atomic_contributions = painn(
            atoms=batch.z, atom_positions=batch.pos, graph_indexes=batch.batch
        )
        preds = post_processing(
            atoms=batch.z,
            graph_indexes=batch.batch,
            atomic_contributions=atomic_contributions,
        )
        mae += F.l1_loss(preds, batch.y, reduction='sum')

mae /= len(dm.data_test)
unit_conversion = dm.unit_conversion[args.target]
print(f'Test MAE: {unit_conversion(mae):.3f}')

Seed set to 0


Batch: 0
Batch: 1
Batch: 2
Batch: 3
Batch: 4
Batch: 5
Batch: 6
Batch: 7
Batch: 8
Batch: 9
Batch: 10
Test MAE: 6.996


## Training loop
```python
"""
Basic example of how to train the PaiNN model to predict the QM9 property
"internal energy at 0K". This property (and the majority of the other QM9
properties) is computed as a sum of atomic contributions.
"""
import torch
import argparse
import os
import pandas as pd
import time
from tqdm import trange
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from src.data import QM9DataModule
from pytorch_lightning import seed_everything
from src.models import AtomwisePostProcessing
from src.models.painn import PaiNN# this is the working one!
import wandb

# Log the working directory, since --data_dir defaults to the relative 'data/'.
_cwd = os.getcwd()
print('currrent working directory: ', _cwd)
load = True  # NOTE(review): not used in the visible part of this script
def cli(args: Optional[list] = None):
    """
    Parse the training hyperparameters.

    Args:
        args: Optional list of command-line strings; ``None`` (the default)
            parses ``sys.argv`` as before.

    Returns:
        argparse.Namespace with all hyperparameters.
    """
    parser = argparse.ArgumentParser()
    # type=int so a seed given on the command line is not passed on as a str.
    parser.add_argument('--seed', default=0, type=int)

    # Data
    parser.add_argument('--target', default=7, type=int) # 7 => Internal energy at 0K
    parser.add_argument('--data_dir', default='data/', type=str)
    parser.add_argument('--batch_size_train', default=100, type=int)
    parser.add_argument('--batch_size_inference', default=1000, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--splits', nargs=3, default=[110000, 10000, 10831], type=int) # [num_train, num_val, num_test]
    parser.add_argument('--subset_size', default=None, type=int)

    # Model
    parser.add_argument('--num_message_passing_layers', default=6, type=int)
    parser.add_argument('--num_features', default=128, type=int)
    parser.add_argument('--num_outputs', default=1, type=int)
    parser.add_argument('--num_rbf_features', default=20, type=int)
    parser.add_argument('--num_unique_atoms', default=100, type=int)
    parser.add_argument('--cutoff_dist', default=5.0, type=float)

    # Training
    parser.add_argument('--lr', default=5e-4, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--num_epochs', default=1000, type=int)

    return parser.parse_args(args)

def get_device():
    """Return the best available device name, preferring 'cuda', then 'mps', then 'cpu'."""
    preference = (
        ('cuda', torch.cuda.is_available),
        ('mps', torch.backends.mps.is_available),
    )
    for name, is_available in preference:
        if is_available():
            return name
    return 'cpu'

device = get_device()
print(device)

# Never commit API keys: read the W&B key from the WANDB_API_KEY environment
# variable (None falls back to wandb's usual netrc/prompt lookup).
# NOTE(review): the key previously hard-coded here is exposed and should be revoked.
wandb.login(key=os.environ.get('WANDB_API_KEY'))
# Initialize wandb sweep config
sweep_config = {
    'method': 'grid',  # Options: 'random', 'grid', 'bayes'
    'metric': {
        'name': 'val_loss',
        'goal': 'minimize'
    },
    'parameters': {
        #'num_message_passing_layers': {'values': [1]}, # [2, 3, 4, 5]
        #'clip_value': {'values': [1, 10, 100]}
        # NOTE(review): lr=4 looks like a typo (cf. the 5e-4 / 1e-4 range
        # below); confirm before running the sweep.
        'lr': {'values': [5e-4, 4]}, # [1e-3, 5e-4, 1e-4]
        #'batch_size_train': {'values': [32, 64, 100]},
        #'weight_decay': {'values': [0.01]} # [0.01, 0.001, 0.0001]
    }
}

sweep_id = wandb.sweep(sweep_config, project="layeropt_6")


def _forward_batch(painn, post_processing, batch):
    """Return the per-molecule predictions for one batch.

    Runs ``painn`` to obtain per-atom contributions and aggregates them with
    ``post_processing``. ``batch`` must already live on the model's device.
    """
    atomic_contributions = painn(
        atoms=batch.z,
        atom_positions=batch.pos,
        graph_indexes=batch.batch,
    )
    return post_processing(
        atoms=batch.z,
        graph_indexes=batch.batch,
        atomic_contributions=atomic_contributions,
    )


def _evaluate(painn, post_processing, dataloader, device):
    """Return ``(sum of squared errors, sum of absolute errors)`` over ``dataloader``.

    Gradients are disabled. The caller is responsible for putting the model in
    eval mode and for dividing the sums by the dataset size.
    """
    sq_err_sum = 0.0
    abs_err_sum = 0.0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            preds = _forward_batch(painn, post_processing, batch)
            sq_err_sum += F.mse_loss(preds, batch.y, reduction='sum').item()
            abs_err_sum += F.l1_loss(preds, batch.y, reduction='sum').item()
    return sq_err_sum, abs_err_sum


def main():
    """Run one wandb sweep trial: train PaiNN on QM9, early-stop, test, save.

    Static settings come from ``cli()``. Any of ``batch_size_train``,
    ``num_message_passing_layers``, ``lr``, ``weight_decay`` and ``clip_value``
    present in ``wandb.config`` overrides the corresponding CLI value.
    Per-epoch train/val loss and MAE, plus the final test MAE (after the
    target's unit conversion), are logged to wandb. The final model weights
    are saved to ``./src/results/model_<run id>.pth``.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Parse static arguments
    args = cli()
    seed_everything(args.seed)

    # wandb configuration (dynamic hyperparameters)
    with wandb.init(config=args.__dict__) as run:
        config = wandb.config  # Access wandb-specified parameters
        print(run.id)

        # Use a mix of CLI and wandb settings
        dm = QM9DataModule(
            target=args.target,
            data_dir=args.data_dir,
            batch_size_train=config.get("batch_size_train", args.batch_size_train),
            batch_size_inference=args.batch_size_inference,
            num_workers=args.num_workers,
            splits=args.splits,
            seed=args.seed,
            subset_size=args.subset_size,
        )
        dm.prepare_data()
        dm.setup()
        y_mean, y_std, atom_refs = dm.get_target_stats(
            remove_atom_refs=True, divide_by_atoms=True
        )
        painn = PaiNN(
            num_message_passing_layers=config.get("num_message_passing_layers", args.num_message_passing_layers),
            num_features=args.num_features,
            num_outputs=args.num_outputs,
            num_rbf_features=args.num_rbf_features,
            num_unique_atoms=args.num_unique_atoms,
            cutoff_dist=args.cutoff_dist,
            device=device,
        )
        post_processing = AtomwisePostProcessing(
            args.num_outputs, y_mean, y_std, atom_refs
        )
        painn.to(device)
        post_processing.to(device)

        # Optimizer hyperparameters come from wandb
        optimizer = torch.optim.AdamW(
            painn.parameters(),
            lr=config.get("lr", args.lr),
            weight_decay=config.get("weight_decay", args.weight_decay),
        )
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
        clip_value = config.get("clip_value", 100)

        # Early stopping on an exponentially smoothed validation loss.
        alpha = 0.9        # weight of the newest validation loss in the smoothed value
        patience = 30      # stop after this many epochs without improvement
        min_delta = 1e-7   # smallest decrease that counts as an improvement
        smoothed_loss = None
        best_val_loss_smooth = float('inf')
        patience_counter = 0

        num_train = len(dm.data_train)
        num_val = len(dm.data_val)

        pbar = trange(args.num_epochs)
        for epoch in pbar:
            # ---- Training ----
            painn.train()
            train_loss = 0.0
            train_mae = 0.0
            for batch in dm.train_dataloader():
                batch = batch.to(device)
                preds = _forward_batch(painn, post_processing, batch)
                sq_err = F.mse_loss(preds, batch.y, reduction='sum')
                loss = sq_err / len(batch.y)  # per-molecule mean for the update
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_value_(painn.parameters(), clip_value=clip_value)
                optimizer.step()

                # Accumulate per-batch *sums* (not batch means) so that dividing by
                # the dataset size yields a per-sample mean on the same scale as
                # val_loss below.
                train_loss += sq_err.item()
                train_mae += F.l1_loss(preds.detach(), batch.y, reduction='sum').item()
            train_loss /= num_train
            train_mae /= num_train

            # ---- Validation ----
            painn.eval()
            val_loss, val_mae = _evaluate(painn, post_processing, dm.val_dataloader(), device)
            val_loss /= num_val
            val_mae /= num_val

            # Seed the smoother with the first observed value. Starting from 0.0
            # would bias early smoothed losses low, so genuine improvements in the
            # first epochs would look like regressions and burn patience.
            if smoothed_loss is None:
                smoothed_loss = val_loss
            else:
                smoothed_loss = alpha * val_loss + (1 - alpha) * smoothed_loss
            scheduler.step(smoothed_loss)  # update rlr
            pbar.set_postfix_str(f'Train loss: {train_loss:.3e}, Val loss: {val_loss:.3e}')

            # Log metrics to wandb
            wandb.log({
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_mae': train_mae,
                'val_mae': val_mae,
            })

            # Early stopping
            if best_val_loss_smooth - smoothed_loss > min_delta:
                best_val_loss_smooth = smoothed_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break

        # ---- Test evaluation ----
        painn.eval()
        _, test_abs_err = _evaluate(painn, post_processing, dm.test_dataloader(), device)
        mae = test_abs_err / len(dm.data_test)
        unit_conversion = dm.unit_conversion[args.target]
        wandb.log({'Test MAE': unit_conversion(mae)})

        # Create the results directory if needed so a finished run does not crash
        # at the very last step and lose the trained weights.
        results_dir = './src/results'
        os.makedirs(results_dir, exist_ok=True)
        torch.save(painn.state_dict(), os.path.join(results_dir, f'model_{run.id}.pth'))

# Start the sweep agent: it calls main() once per hyperparameter combination
# drawn from the sweep registered above as `sweep_id`.
wandb.agent(sweep_id=sweep_id, function=main)
