# Atom and Bond Prediction

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/mol_atom_bond.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

# Only the presence of the COLAB_RELEASE_TAG environment variable is checked;
# when it is unset, this cell does nothing.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        # chemprop is not importable yet: clone the repository and install it from source.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        # Finish inside examples/ so that the later `Path.cwd().parent` points at the clone root.
        %cd examples

In [2]:
import ast
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# The notebook is expected to run from the repository's examples/ directory, so the
# repository root is the parent of the current working directory.
chemprop_dir = Path.cwd().parent
# Example data for molecule/atom/bond models lives under tests/data/mol_atom_bond.
data_dir = chemprop_dir.joinpath("tests", "data", "mol_atom_bond")

This notebook shows how to use Chemprop to fit models on atom and bond property data. One model can predict molecule-, atom-, and bond-level properties at the same time. 

## Make datapoints

The atom and bond targets are saved as strings that look like lists. This example uses regression targets, but classification (including multiclass) is also supported.

In [3]:
# Read the example regression targets; the atom/bond target columns hold list-like strings.
regression_csv = data_dir / "regression.csv"
df_input = pd.read_csv(regression_csv)
df_input

Unnamed: 0,smiles,mol_y1,mol_y2,atom_y1,atom_y2,bond_y1,bond_y2,weight
0,[H][H],2.016,2.0,"[1, 1]","[1.008, 1.008]",[2],[-2],0.090909
1,C,16.043,1.0,[6],[12.011],[],[],0.181818
2,CN,31.058,2.0,"[6, 7]","[12.011, 14.007]",[13],[-13],0.272727
3,CN,31.058,,"[6, 7]","[None, 14.007]",[13],[None],0.363636
4,CC,30.07,2.0,"[6, 6]","[12.011, 12.011]",[12],[-12],0.454545
5,[CH2:3]=[N+:1]([H:4])[H:2],30.05,4.0,"[7, 1, 6, 1]","[14.007, 1.008, 12.011, 1.008]","[13, 8, 8]","[-13, -8, -8]",0.545455
6,CCCC,58.124,4.0,"[6, 6, 6, 6]","[12.011, 12.011, 12.011, 12.011]","[12, 12, 12]","[-12, -12, -12]",0.636364
7,CO,32.042,2.0,"[6, 8]","[12.011, 15.999]",[14],[-14],0.727273
8,CC#N,41.053,3.0,"[6, 6, 7]","[12.011, 12.011, 14.007]","[12, 13]","[-12, -13]",0.818182
9,C1NN1,44.057,3.0,"[6, 7, 7]","[12.011, 14.007, 14.007]","[13, 14, 13]","[-13, -14, -13]",0.909091


### Load optional extra features and descriptors
Extra bond descriptors can be used when making bond property predictions, analogous to extra atom descriptors.

In [4]:
def _load_npz_arrays(path):
    """Return the arrays stored in an ``.npz`` archive as a list ordered ``arr_0, arr_1, ...``.

    The archive is opened in a ``with`` block so its file handle is closed once the arrays
    have been read, instead of staying open for the lifetime of the kernel.
    """
    with np.load(path) as archive:
        return [archive[f"arr_{i}"] for i in range(len(archive.files))]


# Extra datapoint descriptors: a single array stored under "arr_0".
with np.load(data_dir / "descriptors.npz") as descriptor_archive:
    x_ds = descriptor_archive["arr_0"]

# One array of extra atom features per archive entry.
V_fs = _load_npz_arrays(data_dir / "atom_features_descriptors.npz")
# The same arrays are reused as the extra atom descriptors (V_ds refers to the same list).
V_ds = V_fs

# One array of extra bond features per archive entry.
E_fs = _load_npz_arrays(data_dir / "bond_features_descriptors.npz")
# Bond descriptors repeat every bond-feature row twice in a row
# (presumably one row per bond direction -- TODO confirm against chemprop's edge ordering).
E_ds = [np.repeat(E_f, repeats=2, axis=0) for E_f in E_fs]

In [5]:
columns = ["smiles", "mol_y1", "mol_y2", "atom_y1", "atom_y2", "bond_y1", "bond_y2", "weight"]

# Column groups, by position in `columns`.
smiles_column = columns[0]
mol_columns = columns[1:3]
atom_columns = columns[3:5]
bond_columns = columns[5:7]
weight_column = columns[7]

smis = df_input[smiles_column].values
mol_ys = df_input[mol_columns].values
atoms_ys = df_input[atom_columns].values
bonds_ys = df_input[bond_columns].values
weights = df_input[weight_column].values


def _parse_list_targets(rows):
    """Turn rows of list-like strings into one float array per molecule.

    Each cell is parsed with ``ast.literal_eval``; the parsed lists of a row are stacked and
    transposed, so every returned array has one column per target column. ``None`` entries
    become NaN through the float conversion.
    """
    parsed = []
    for row in rows:
        per_column = [ast.literal_eval(cell) for cell in row]
        parsed.append(np.array(per_column, dtype=float).T)
    return parsed


# String lists are converted to lists using ast.literal_eval
atoms_ys = _parse_list_targets(atoms_ys)
bonds_ys = _parse_list_targets(bonds_ys)


def _build_datapoint(idx, smi):
    """Create the datapoint for molecule ``idx`` from its SMILES, targets, and extra inputs."""
    return data.MolAtomBondDatapoint.from_smi(
        smi,
        keep_h=True,
        add_h=False,
        # If the atom targets follow the order of an atom mapping in the SMILES string instead of
        # the order of the atoms in the SMILES string (i.e. [F:2][Cl:1]), set reorder_atoms=True.
        reorder_atoms=True,
        y=mol_ys[idx],
        atom_y=atoms_ys[idx],
        bond_y=bonds_ys[idx],
        weight=weights[idx],
        x_d=x_ds[idx],
        V_f=V_fs[idx],
        V_d=V_ds[idx],
        E_f=E_fs[idx],
        E_d=E_ds[idx],
    )


datapoints = [_build_datapoint(idx, smi) for idx, smi in enumerate(smis)]

If the regression targets are bounded (e.g. they look like "<3" or ">0.1"), parsing the atom and bond targets is a bit more complicated. Note that `BoundedMSE` should be used as the loss function (`RegressionFFN(criterion=BoundedMSE())`), and the less-than and greater-than masks should be given to the datapoints.

In [6]:
# Set `bounded = True` only when `df_input` holds bounded targets: molecule targets written like
# "<3" or ">0.1", and atom/bond columns that are string lists of such values (or None).
bounded = False
if bounded:
    # Re-select the raw target columns from `df_input`. The `mol_ys`, `atoms_ys` and `bonds_ys`
    # produced by the previous cell are a NumPy array and plain Python lists, which do not have
    # the pandas `.map` / `.apply(axis=1)` methods used throughout this branch.
    mol_ys = df_input.loc[:, columns[1:3]].astype(str)
    lt_mask = mol_ys.map(lambda x: "<" in x).to_numpy()
    gt_mask = mol_ys.map(lambda x: ">" in x).to_numpy()
    # Strip the inequality signs and convert what remains to single-precision floats.
    mol_ys = mol_ys.map(lambda x: x.strip("<").strip(">")).to_numpy(np.single)

    # Atom targets: parse the string lists, then derive element-wise "<" / ">" masks.
    atoms_ys = df_input.loc[:, columns[3:5]].map(ast.literal_eval)
    atom_lt_masks = atoms_ys.map(lambda L: ["<" in v if v else False for v in L])
    atom_gt_masks = atoms_ys.map(lambda L: [">" in v if v else False for v in L])

    # Stack the per-column lists of each row and transpose, giving one array per molecule
    # with one column per target column.
    atom_lt_masks = atom_lt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()
    atom_gt_masks = atom_gt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()
    # Falsy entries (e.g. None) become NaN targets.
    atoms_ys = atoms_ys.map(
        lambda L: np.array([v.strip("<").strip(">") if v else "nan" for v in L], dtype=np.single)
    )
    atoms_ys = atoms_ys.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    # Bond targets: same procedure as for the atom targets.
    bonds_ys = df_input.loc[:, columns[5:7]].map(ast.literal_eval)
    bond_lt_masks = bonds_ys.map(lambda L: ["<" in v if v else False for v in L])
    bond_gt_masks = bonds_ys.map(lambda L: [">" in v if v else False for v in L])

    bond_lt_masks = bond_lt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()
    bond_gt_masks = bond_gt_masks.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    # Force a boolean dtype on the stacked bond masks.
    bond_lt_masks = [bond_lt_mask.astype(bool) for bond_lt_mask in bond_lt_masks]
    bond_gt_masks = [bond_gt_mask.astype(bool) for bond_gt_mask in bond_gt_masks]

    bonds_ys = bonds_ys.map(
        lambda L: np.array([v.strip("<").strip(">") if v else "nan" for v in L], dtype=np.single)
    )
    bonds_ys = bonds_ys.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    # NOTE(review): unlike the unbounded datapoints, no x_d / V_f / V_d / E_f / E_d are passed
    # here, while later cells size the featurizer and model from V_fs / E_fs / x_ds -- confirm
    # whether the extra inputs should be supplied in the bounded case as well.
    datapoints = [
        data.MolAtomBondDatapoint.from_smi(
            smi,
            keep_h=True,
            add_h=False,
            reorder_atoms=True,
            y=mol_ys[i],
            atom_y=atoms_ys[i],
            bond_y=bonds_ys[i],
            weight=weights[i],
            lt_mask=lt_mask[i],
            gt_mask=gt_mask[i],
            atom_lt_mask=atom_lt_masks[i],
            atom_gt_mask=atom_gt_masks[i],
            bond_lt_mask=bond_lt_masks[i],
            bond_gt_mask=bond_gt_masks[i],
        )
        for i, smi in enumerate(smis)
    ]

## Make datasets 

In [7]:
# The featurizer needs the number of extra atom / bond feature columns supplied per molecule.
n_extra_atom_features = V_fs[0].shape[1]
n_extra_bond_features = E_fs[0].shape[1]
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(
    extra_atom_fdim=n_extra_atom_features, extra_bond_fdim=n_extra_bond_features
)

# This example builds every split from the same datapoints; each split is its own dataset
# object, so scaling one of them does not replace the others.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

## Scale the extra features and descriptors
If extra features and descriptors are used, they can be scaled to make training easier. The scalers are turned into "transforms" which are given to the model to use at inference time. 

In [8]:
# Fit one scaler per extra graph input on the training set, then reuse it for validation.
graph_input_keys = ("V_f", "E_f", "V_d", "E_d")
input_scalers = {key: train_dataset.normalize_inputs(key) for key in graph_input_keys}
for key, scaler in input_scalers.items():
    val_dataset.normalize_inputs(key, scaler)
V_f_scaler, E_f_scaler, V_d_scaler, E_d_scaler = (input_scalers[key] for key in graph_input_keys)

# The extra features are only part of each atom / bond feature vector; `pad` is the number of
# remaining (non-extra) feature columns.
n_base_atom_features = featurizer.atom_fdim - featurizer.extra_atom_fdim
n_base_bond_features = featurizer.bond_fdim - featurizer.extra_bond_fdim
V_f_transform = nn.ScaleTransform.from_standard_scaler(V_f_scaler, pad=n_base_atom_features)
E_f_transform = nn.ScaleTransform.from_standard_scaler(E_f_scaler, pad=n_base_bond_features)
graph_transform = nn.GraphTransform(V_f_transform, E_f_transform)

# Extra atom / bond descriptors are scaled without padding.
V_d_transform = nn.ScaleTransform.from_standard_scaler(V_d_scaler)
E_d_transform = nn.ScaleTransform.from_standard_scaler(E_d_scaler)

# Extra datapoint descriptors follow the same fit-on-train, apply-to-validation pattern.
X_d_scaler = train_dataset.normalize_inputs("X_d")
val_dataset.normalize_inputs("X_d", X_d_scaler)
X_d_transform = nn.ScaleTransform.from_standard_scaler(X_d_scaler)

## Scale the regression targets

In [9]:
# Fit a target scaler per prediction level on the training set and apply it to validation.
target_levels = ("mol", "atom", "bond")
target_scalers = {level: train_dataset.normalize_targets(level) for level in target_levels}
for level, fitted_scaler in target_scalers.items():
    val_dataset.normalize_targets(level, fitted_scaler)
mol_target_scaler, atom_target_scaler, bond_target_scaler = (
    target_scalers[level] for level in target_levels
)

# The matching unscale transforms are handed to the predictors as their output transforms.
mol_target_transform, atom_target_transform, bond_target_transform = (
    nn.UnscaleTransform.from_standard_scaler(target_scalers[level]) for level in target_levels
)

## Make dataloaders

In [10]:
# A small batch size, since the example dataset has only a handful of molecules.
BATCH_SIZE = 4

# Only the training loader shuffles.
train_dataloader = data.build_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

## The MAB (mol atom bond) message passing returns both learned node embeddings and learned edge embeddings
`MABBondMessagePassing` takes the same customization arguments as the usual `BondMessagePassing` class.

In [11]:
# Input sizes come from the featurizer and from the extra descriptor arrays; the scaling
# transforms fitted earlier are attached to the message passing layer.
mp_options = dict(
    d_v=featurizer.atom_fdim,
    d_e=featurizer.bond_fdim,
    d_h=100,
    d_vd=V_ds[0].shape[1],
    d_ed=E_ds[0].shape[1],
    dropout=0.1,
    activation="tanh",
    depth=4,
    graph_transform=graph_transform,
    V_d_transform=V_d_transform,
    E_d_transform=E_d_transform,
)
mp = nn.MABBondMessagePassing(**mp_options)

## A separate predictor is used for each of the molecule, atom, and bond predictions

In [12]:
agg = nn.MeanAggregation()

# Note that each predictor may have a different input dimension
# Molecule level: first message-passing output size plus the extra datapoint descriptor columns.
mol_input_dim = mp.output_dims[0] + x_ds.shape[1]
# Atom level: first message-passing output size.
atom_input_dim = mp.output_dims[0]
# Bond level: twice the second message-passing output size.
bond_input_dim = mp.output_dims[1] * 2

mol_predictor = nn.RegressionFFN(
    input_dim=mol_input_dim, n_tasks=mol_ys.shape[1], output_transform=mol_target_transform
)
atom_predictor = nn.RegressionFFN(
    input_dim=atom_input_dim,
    n_tasks=atoms_ys[0].shape[1],
    output_transform=atom_target_transform,
)
bond_predictor = nn.RegressionFFN(
    input_dim=bond_input_dim,
    n_tasks=bonds_ys[0].shape[1],
    output_transform=bond_target_transform,
)

Different predictors can be used for different types of tasks including but not limited to `MveFFN`, `BinaryClassificationFFN`, `MulticlassClassificationFFN`.

## Combine the layers into a single model

In [13]:
# Metrics handed to the model alongside the loss.
metrics = [nn.MAE(), nn.RMSE()]

# One predictor per prediction level, all fed by the shared message passing layer.
predictors = dict(
    mol_predictor=mol_predictor, atom_predictor=atom_predictor, bond_predictor=bond_predictor
)
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    **predictors,
    batch_norm=True,
    metrics=metrics,
    X_d_transform=X_d_transform,
)

In [14]:
# Display the assembled model; the notebook renders its module tree as the cell output.
model

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): MeanAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_fea

If molecule, atom, or bond targets are not used, the corresponding predictor isn't added to the model. If bond targets are not used, the message passing layer should be told not to return the bond embeddings, which avoids initializing weight matrices that won't be used. If molecule targets are not used, the aggregation layer isn't added to the model. If neither molecule nor atom targets are used, the message passing layer should be told not to return the node embeddings.

In [15]:
# Switches for the reduced setups described above. All are off by default, so the model
# built earlier is left untouched; enabling one rebinds `mp`, the predictors, and `model`.
no_bond = False
no_mol = False
no_mol_atom = False

if no_bond:
    # Molecule and atom targets only: the edge embeddings are not returned.
    mp = nn.MABBondMessagePassing(return_edge_embeddings=False)
    agg = nn.NormAggregation()
    mol_predictor, atom_predictor = nn.RegressionFFN(), nn.RegressionFFN()
    model = models.MolAtomBondMPNN(
        message_passing=mp,
        agg=agg,
        mol_predictor=mol_predictor,
        atom_predictor=atom_predictor,
    )

if no_mol:
    # Atom and bond targets only: no aggregation layer and no molecule predictor.
    mp = nn.MABBondMessagePassing()
    atom_predictor = nn.RegressionFFN()
    bond_input_dim = mp.output_dims[1] * 2
    bond_predictor = nn.RegressionFFN(input_dim=bond_input_dim)
    model = models.MolAtomBondMPNN(
        message_passing=mp, atom_predictor=atom_predictor, bond_predictor=bond_predictor
    )

if no_mol_atom:
    # Bond targets only: the vertex embeddings are not returned.
    mp = nn.MABBondMessagePassing(return_vertex_embeddings=False)
    bond_input_dim = mp.output_dims[1] * 2
    bond_predictor = nn.RegressionFFN(input_dim=bond_input_dim)
    model = models.MolAtomBondMPNN(message_passing=mp, bond_predictor=bond_predictor)

## Set up trainer with checkpointing

In [None]:
# Track the lowest validation loss and also write a "last" checkpoint, all under
# the MABcheckpoints directory.
checkpointing = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_last=True,
    dirpath="MABcheckpoints",
    filename="best-{epoch}-{val_loss:.2f}",
)

# Short demo run: one device, 20 epochs, no logger, progress bar on.
trainer_callbacks = [checkpointing]
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=20,
    logger=False,
    enable_progress_bar=True,
    enable_checkpointing=True,
    callbacks=trainer_callbacks,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
# Fit the model, passing the training dataloader followed by the validation dataloader.
trainer.fit(model, train_dataloader, val_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 69.2 K | train
1 | agg             | MeanAggregation       | 0      | train
2 | mol_predictor   | RegressionFFN         | 32.1 K | train
3 | atom_predictor  | RegressionFFN         | 31.5 K | train
4 | bond_predictor  | RegressionFFN         | 62.1 K | train
5 | bns             | ModuleList            | 612    | train
6 | X_d_transform   | ScaleTransform        | 0      | train
7 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
195 K     Trainable params
0         Non-trainable params
195 K     Total params
0.782     Total estimated model params size (MB)
63        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  7.59it/s]

Epoch 19: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s, mol_train_loss_step=0.0765, atom_train_loss_step=0.113, bond_train_loss_step=0.148, train_loss_step=0.337, mol_val_loss=0.0379, atom_val_loss=0.0219, bond_val_loss=0.033, val_loss=0.136, mol_train_loss_epoch=0.084, atom_train_loss_epoch=0.0578, bond_train_loss_epoch=0.0589, train_loss_epoch=0.209]    

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3/3 [00:01<00:00,  2.14it/s, mol_train_loss_step=0.0765, atom_train_loss_step=0.113, bond_train_loss_step=0.148, train_loss_step=0.337, mol_val_loss=0.0379, atom_val_loss=0.0219, bond_val_loss=0.033, val_loss=0.136, mol_train_loss_epoch=0.084, atom_train_loss_epoch=0.0578, bond_train_loss_epoch=0.0589, train_loss_epoch=0.209]


In [18]:
# Evaluate on the test dataloader and keep what `trainer.test` returns.
# NOTE(review): no model is passed here, so the trainer decides which weights to evaluate --
# presumably the best checkpoint from the preceding fit; confirm against the Lightning docs.
results = trainer.test(dataloaders=test_dataloader)

Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 20.82it/s]


In [19]:
# `trainer.predict` returns one entry per batch; each entry unpacks into three parts
# (molecule, atom, bond), which are regrouped by level and concatenated across batches.
predss = trainer.predict(model, predict_dataloader)
mol_batches, atom_batches, bond_batches = zip(*predss)
mol_preds = torch.concat(mol_batches)
atom_preds = torch.concat(atom_batches)
bond_preds = torch.concat(bond_batches)

Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 49.64it/s]


## Split the atom and bond predictions into a list of tensors, one for each molecule

In [20]:
# Count atoms and bonds per molecule in the prediction set.
predict_mols = predict_dataset.mols
atoms_per_mol = [mol.GetNumAtoms() for mol in predict_mols]
bonds_per_mol = [mol.GetNumBonds() for mol in predict_mols]

# Cut the concatenated predictions into one chunk per molecule using those counts.
atom_preds = atom_preds.split(atoms_per_mol)
bond_preds = bond_preds.split(bonds_per_mol)

## Save and load the model

In [21]:
# Write the model to a file, then load it back; the loaded model is the cell's displayed value.
model_path = "temp.pt"
models.utils.save_model(model_path, model)
models.MolAtomBondMPNN.load_from_file(model_path)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): MeanAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_fea

In [None]:
# Restore the model from the "last" checkpoint in the MABcheckpoints directory (the
# ModelCheckpoint callback was configured with dirpath="MABcheckpoints" and save_last=True).
models.MolAtomBondMPNN.load_from_checkpoint("MABcheckpoints/last.ckpt")

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): MeanAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_fea