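# Develop the PNA graph model

Convert SMILES strings into DGL graphs with OGB atom/bond features, batch them, and instantiate the PNA model from `src.modules.molecules.pna`.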

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import datamol as dm
import dgl
import torch
from ogb.utils.features import (
    atom_to_feature_vector,
    bond_to_feature_vector,
    get_atom_feature_dims,
    get_bond_feature_dims,
)
from rdkit import Chem
from rdkit.Chem.rdmolops import GetAdjacencyMatrix

from src.modules.molecules.pna import PNA

In [3]:
def get_graph_from_mol(smiles) -> "dgl.DGLGraph":
    """Build a DGL graph from a SMILES string using OGB atom/bond features.

    Each RDKit bond becomes two directed edges (i -> j, then j -> i) that share
    the same bond feature vector.

    Args:
        smiles: SMILES string, parsed with ``datamol.to_mol``.

    Returns:
        A ``dgl.DGLGraph`` with one node per atom, where
        ``ndata["feat"]`` has shape ``(num_atoms, len(get_atom_feature_dims()))`` and
        ``edata["feat"]`` has shape ``(2 * num_bonds, len(get_bond_feature_dims()))``,
        both ``torch.long``. Molecules without bonds (e.g. a single atom)
        yield a graph with zero edges and a correctly shaped empty edge tensor.

    Raises:
        ValueError: if ``smiles`` cannot be parsed into a molecule.
    """
    mol = dm.to_mol(smiles)
    if mol is None:
        # dm.to_mol returns None when parsing fails; fail early with a clear message
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    n_atoms = mol.GetNumAtoms()
    n_atom_feats = len(get_atom_feature_dims())
    n_bond_feats = len(get_bond_feature_dims())

    # reshape keeps the 2-D (rows, features) layout even when there are zero rows
    atom_features = torch.tensor(
        [atom_to_feature_vector(atom) for atom in mol.GetAtoms()],
        dtype=torch.long,
    ).reshape(n_atoms, n_atom_feats)

    src, dst, edge_feature_rows = [], [], []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_feature = bond_to_feature_vector(bond)

        # add edges in both directions (i -> j, then j -> i) with shared features
        src += [i, j]
        dst += [j, i]
        edge_feature_rows += [edge_feature, edge_feature]

    # an empty bond list would otherwise produce a 1-D tensor of shape (0,)
    edge_features = torch.tensor(edge_feature_rows, dtype=torch.long).reshape(
        len(edge_feature_rows), n_bond_feats
    )

    graph = dgl.graph(
        data=(
            torch.tensor(src, dtype=torch.long),
            torch.tensor(dst, dtype=torch.long),
        ),
        num_nodes=n_atoms,
    )

    graph.ndata["feat"] = atom_features
    graph.edata["feat"] = edge_features

    return graph

In [6]:
# Small test set: ethanol, butane, acetic acid, methyl acetate.
smiles = ["CCO", "CCCC", "CC(=O)O", "CC(=O)OC"]

In [88]:
# Featurize every SMILES string and merge the graphs into one batched DGLGraph.
batch = dgl.batch(list(map(get_graph_from_mol, smiles)))

In [24]:
batch  # display the batched graph: total node/edge counts and feature schemes

Graph(num_nodes=16, num_edges=24,
      ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
      edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)})

In [45]:
get_atom_feature_dims()  # one entry per atom-feature column (9 here, matching the (9,) ndata "feat" scheme)

[119, 5, 12, 12, 10, 6, 6, 2, 2]

In [9]:
# Instantiate the PNA model. Keyword arguments are grouped by apparent role
# (inferred from parameter names only — confirm against
# src.modules.molecules.pna.PNA).
pna = PNA(
    # checkpoint path, relative to the notebook's working directory
    ckpt_path="models/best_checkpoint_35epochs.pt",
    # widths and depth
    hidden_dim=200,
    target_dim=256,
    propagation_depth=7,
    pretrans_layers=2,
    posttrans_layers=1,
    residual=True,
    pairwise_distances=False,
    # presumably neighbourhood aggregation and degree scaling
    aggregators=["mean", "max", "min", "std"],
    scalers=["identity", "amplification", "attenuation"],
    # readout settings
    readout_aggregators=["min", "max", "mean"],
    readout_hidden_dim=200,
    readout_layers=2,
    readout_batchnorm=True,
    # activations and batch normalization
    activation="relu",
    last_activation="none",
    mid_batch_norm=True,
    last_batch_norm=True,
    batch_norm_momentum=0.93,
    # regularization
    dropout=0.0,
)