# Constrained Atom and Bond Prediction

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/constrained_mol_atom_bond.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

In [2]:
import ast
from pathlib import Path

from lightning import pytorch as pl
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn


# Seed Python, NumPy and PyTorch RNGs so weight initialization and training-data shuffling are
# reproducible across runs.
SEED = 42
pl.seed_everything(SEED)

# Paths assume the notebook runs from `chemprop/examples` (the Colab cell above `%cd`s there).
chemprop_dir = Path.cwd().parent
data_dir = chemprop_dir / "tests" / "data" / "mol_atom_bond"

If any of the atom or bond properties should sum to a known molecule-level value, we can constrain the atom and bond predictions to sum to that value. For example, atom partial charges should sum to the total charge of the molecule.

## Make datapoints

In [3]:
# One row per molecule: a SMILES, a molecule-level target, and atom/bond targets stored as
# list-like strings (parsed in the next cell).
df_input = pd.read_csv(data_dir.joinpath("constrained_regression.csv"))
df_input

Unnamed: 0,smiles,mol_y,atom_y1,atom_y2,bond_y1,bond_y2
0,[H][H],0,"[0, 0]","[1.008, 1.008]",[2],[2]
1,C,0,[0],[12.011],[],[]
2,CN,0,"[0, 0]","[12.011, 14.007]",[13],[2]
3,CN,0,"[0, 0]","[12.011, 14.007]",[13],[2]
4,CC,0,"[0, 0]","[12.011, 12.011]",[12],[2]
5,[CH2:3]=[N+:1]([H:4])[H:2],1,"[1, 0, 0, 0]","[14.007, 1.008, 12.011, 1.008]","[13, 8, 8]","[4, 2, 2]"
6,CCCC,0,"[0, 0, 0, 0]","[12.011, 12.011, 12.011, 12.011]","[12, 12, 12]","[2, 2, 2]"
7,CO,0,"[0, 0]","[12.011, 15.999]",[14],[2]
8,CC#N,0,"[0, 0, 0]","[12.011, 12.011, 14.007]","[12, 13]","[2, 6]"
9,C1NN1,0,"[0, 0, 0]","[12.011, 14.007, 14.007]","[13, 14, 13]","[2, 2, 2]"


In [4]:
columns = ["smiles", "mol_y", "atom_y1", "atom_y2", "bond_y1", "bond_y2"]
smis = df_input[columns[0]].values
mol_ys = df_input[columns[1:2]].values

# Each atom/bond target cell is a string such as "[12.011, 14.007]". For every molecule, parse each
# cell into a list, stack the lists (one row per target column), then transpose so the resulting
# array has one row per atom (or bond) and one column per target.
atoms_ys = [
    np.array(list(map(ast.literal_eval, row)), dtype=float).T
    for row in df_input[columns[2:4]].values
]
bonds_ys = [
    np.array(list(map(ast.literal_eval, row)), dtype=float).T
    for row in df_input[columns[4:6]].values
]

### Load constraints

Not all atom and bond targets need to be constrained. Here both atom targets are constrained, while only one of the two bond targets (`bond_y2`) is constrained.

In [5]:
# Constraint columns are named `<target>_constraint`; rows are assumed to line up with `df_input`
# one-to-one (same molecule order) — verify when swapping in other data.
df_constraints = pd.read_csv(data_dir.joinpath("constrained_regression_constraints.csv"))
df_constraints

Unnamed: 0,atom_y1_constraint,atom_y2_constraint,bond_y2_constraint
0,0,2.016,2
1,0,12.011,0
2,0,26.018,2
3,0,26.018,2
4,0,24.022,2
5,1,28.034,8
6,0,48.044,6
7,0,28.01,2
8,0,38.029,8
9,0,40.025,6


In [6]:
n_mols = len(df_constraints)
# Constraints are matched to molecules by row position, so both tables must have the same length.
assert n_mols == len(df_input), "expected exactly one row of constraints per molecule"

# A dictionary to map the atom and bond target columns to the corresponding constraint column,
# given as a positional index into `df_constraints`. Targets missing from it are unconstrained.
constraints_cols_to_target_cols = {
    "atom_y1": 0,
    "atom_y2": 1,
    "bond_y2": 2,
}


def stack_constraints(constraints_df, constraint_cols):
    """Collect per-molecule constraints into a float array.

    Parameters
    ----------
    constraints_df : pd.DataFrame
        One row per molecule; columns are selected by position.
    constraint_cols : list
        For each target column, a positional index into ``constraints_df``, or ``None`` if that
        target is unconstrained.

    Returns
    -------
    np.ndarray
        Shape ``(len(constraints_df), len(constraint_cols))``, always float. Unconstrained targets
        are filled with NaN, which is how they are marked as unconstrained downstream.
    """
    stacked = np.full((len(constraints_df), len(constraint_cols)), np.nan)
    for j, col in enumerate(constraint_cols):
        if col is not None:
            # Cast explicitly so integer constraint columns don't yield an integer array.
            stacked[:, j] = constraints_df.iloc[:, col].to_numpy(dtype=float)
    return stacked


# Target columns without constraints have their constraints set to np.nan
atom_constraint_cols = [constraints_cols_to_target_cols.get(col) for col in columns[2:4]]
atom_constraints = stack_constraints(df_constraints, atom_constraint_cols)

bond_constraint_cols = [constraints_cols_to_target_cols.get(col) for col in columns[4:6]]
bond_constraints = stack_constraints(df_constraints, bond_constraint_cols)

In [7]:
# One datapoint per molecule. Hydrogens written in the SMILES are kept and none are added;
# `reorder_atoms=True` presumably orders atoms by their atom-map numbers (see row 5 of the input)
# — confirm against the `MolAtomBondDatapoint.from_smi` docs.
datapoints = [
    data.MolAtomBondDatapoint.from_smi(
        smis[i],
        keep_h=True,
        add_h=False,
        reorder_atoms=True,
        y=mol_ys[i],
        atom_y=atoms_ys[i],
        bond_y=bonds_ys[i],
        atom_constraint=atom_constraints[i],
        bond_constraint=bond_constraints[i],
    )
    for i in range(len(smis))
]

# The demo data is tiny, so every split is built from the same datapoints.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints) for _ in range(4)
)

# Only the atom targets are standardized: training statistics are fitted on the train set and reused
# for validation. Test/predict sets are left unscaled; `atom_target_transform` presumably maps the
# atom predictions back to original units outside training. Scaling the atom/bond targets also
# scales the corresponding constraints automatically.
atom_target_scaler = train_dataset.normalize_targets("atom")
val_dataset.normalize_targets("atom", atom_target_scaler)
atom_target_transform = nn.UnscaleTransform.from_standard_scaler(atom_target_scaler)

# Shuffle only the training data; evaluation order stays fixed so predictions line up with
# `predict_dataset.mols` when they are split per molecule later.
train_dataloader = data.build_dataloader(train_dataset, shuffle=True)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, shuffle=False)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

## Set up model

In [8]:
# Message passing produces atom (node) and bond (edge) fingerprints; `mp.output_dims[1]` is used
# below as the bond fingerprint width.
mp = nn.MABBondMessagePassing()
agg = nn.NormAggregation()

# One regression head per level, each predicting as many tasks as there are target columns.
mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[-1])
# Only the atom targets were standardized, so only the atom head gets the unscaling transform.
atom_predictor = nn.RegressionFFN(
    output_transform=atom_target_transform,
    n_tasks=atoms_ys[0].shape[-1],
)
# The bond head's input is twice the bond fingerprint width (presumably the fingerprints of the two
# directed edges of each bond are combined — TODO confirm in the chemprop docs).
bond_predictor = nn.RegressionFFN(
    n_tasks=bonds_ys[0].shape[-1],
    input_dim=2 * mp.output_dims[1],
)

Each atom/bond prediction for a constrained target is adjusted so that, within each molecule, the predictions sum to the constraint. How much each individual prediction is adjusted is determined from the node/edge fingerprints by a separate feed-forward network.

In [9]:
# One constrainer per level. `n_constraints` counts the target columns that carry a constraint:
# NaN marks an unconstrained target, and that pattern is the same for every molecule, so the first
# row is representative.
atom_constrainer = nn.ConstrainerFFN(
    n_constraints=int(np.count_nonzero(~np.isnan(atom_constraints[0])))
)
# The bond constrainer works on the same fingerprint width as the bond predictor's input. Derive it
# from the message-passing output rather than hard-coding 600, which would silently go stale if the
# message-passing hidden size changed.
bond_constrainer = nn.ConstrainerFFN(
    n_constraints=int(np.count_nonzero(~np.isnan(bond_constraints[0]))),
    fp_dim=mp.output_dims[1] * 2,
)

In [10]:
# Assemble the full model: message passing and aggregation, one prediction head per level
# (molecule, atom, bond), and the atom/bond constrainers that enforce the summation constraints.
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    atom_constrainer=atom_constrainer,
    bond_predictor=bond_predictor,
    bond_constrainer=bond_constrainer,
)

In [11]:
# Display the model's submodules (message passing, aggregation, prediction heads, constrainers).
model

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(i

## The atom and bond predictions obey the constraints

In [12]:
# Run one forward pass of the still-untrained model on the first prediction batch to show that the
# constrainers already enforce the constraints. The checks below compare against every molecule in
# `predict_dataset`, which assumes the whole dataset fits in this single batch.
batch = next(iter(predict_dataloader))
# Keep the graph, the extra descriptors (presumably atom/bond/molecule-level — TODO confirm) and the
# trailing constraints; the middle elements are not needed for a forward pass.
bmg, V_d, E_d, X_d, *_, constraints = batch
# NOTE(review): `model.eval()` is not called, so the model is still in its default training mode
# here (dropout is 0.0 in this model). Gradients are not needed for this check.
with torch.no_grad():
    mol_preds, atom_preds_tensor, bond_preds_tensor = model(bmg, V_d, E_d, X_d, constraints)

In [13]:
# Sum the per-atom predictions of each molecule and compare them with that molecule's constraints.
atoms_per_mol = [mol.GetNumAtoms() for mol in predict_dataset.mols]
atom_preds = torch.split(atom_preds_tensor, atoms_per_mol)
atom_pred_sums = torch.vstack([p.sum(dim=0) for p in atom_preds]).numpy()
atom_constraint_values = np.asarray(predict_dataset.atom_constraints, dtype=float)
errors = atom_constraint_values - atom_pred_sums
print(errors)

# NaN in the constraints marks an unconstrained target. Mask on the constraints (not on `errors`)
# so a NaN prediction for a constrained target fails the check instead of being silently skipped.
# Use a relative tolerance as well, since rounding in the sums grows with the constraint magnitude.
is_constrained = ~np.isnan(atom_constraint_values)
np.testing.assert_allclose(
    atom_pred_sums[is_constrained],
    atom_constraint_values[is_constrained],
    rtol=1e-5,
    atol=1e-5,
)

[[ 0.00000000e+00 -3.24249267e-08]
 [ 0.00000000e+00  3.20434570e-07]
 [ 3.72529030e-09 -1.55639648e-06]
 [ 3.72529030e-09 -1.55639648e-06]
 [ 0.00000000e+00  6.40869139e-07]
 [ 0.00000000e+00 -3.96728517e-07]
 [-1.49011612e-08  1.28173828e-06]
 [ 3.72529030e-09  1.67846680e-06]
 [ 3.72529030e-09  6.71386722e-07]
 [ 3.72529030e-09  2.28881836e-06]
 [ 0.00000000e+00  1.31225586e-06]]


In [14]:
# Sum the per-bond predictions of each molecule and compare them with that molecule's constraints.
bonds_per_mol = [mol.GetNumBonds() for mol in predict_dataset.mols]
bond_preds = torch.split(bond_preds_tensor, bonds_per_mol)
bond_pred_sums = torch.vstack([p.sum(dim=0) for p in bond_preds]).numpy()
bond_constraint_values = np.asarray(predict_dataset.bond_constraints, dtype=float)
errors = bond_constraint_values - bond_pred_sums
print(errors)

# NaN in the constraints marks an unconstrained target (here `bond_y1`). Mask on the constraints
# (not on `errors`) so a NaN prediction for a constrained target fails the check instead of being
# silently skipped. Use a relative tolerance as well, since rounding grows with magnitude.
is_constrained = ~np.isnan(bond_constraint_values)
np.testing.assert_allclose(
    bond_pred_sums[is_constrained],
    bond_constraint_values[is_constrained],
    rtol=1e-5,
    atol=1e-5,
)

[[           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 1.19209290e-07]
 [           nan 1.19209290e-07]
 [           nan 0.00000000e+00]
 [           nan 4.76837158e-07]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 9.53674316e-07]]


## Fit the model

In [15]:
# A short demo run: 20 epochs on a single automatically chosen device, no logger, progress bar on.
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    logger=False,
    enable_progress_bar=True,
)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [16]:
# Train on the training loader, computing the validation loss on the validation loader each epoch.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name             | Type                  | Params | Mode 
-------------------------------------------------------------------
0 | message_passing  | MABBondMessagePassing | 322 K  | train
1 | agg              | NormAggregation       | 0      | train
2 | mol_predictor    | RegressionFFN         | 90.6 K | train
3 | atom_predictor   | RegressionFFN         | 90.9 K | train
4 | atom_constr

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  3.89it/s, train_loss_step=82.10, val_loss=81.20, train_loss_epoch=82.10]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  3.26it/s, train_loss_step=82.10, val_loss=81.20, train_loss_epoch=82.10]


In [17]:
# Evaluate on the test set. No model is passed, so Lightning restores the checkpoint saved during
# `fit` (see the "Restoring states from the checkpoint path" log in the output).
results = trainer.test(dataloaders=test_dataloader)

Restoring states from the checkpoint path at /home/knathan/chemprop/examples/checkpoints/epoch=19-step=20.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/checkpoints/epoch=19-step=20.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 20.36it/s]


In [18]:
# `trainer.predict` returns one (mol, atom, bond) prediction tuple per batch. Regroup the tuples by
# output type and concatenate each group across batches.
predss = trainer.predict(model, predict_dataloader)
mol_preds, atom_preds, bond_preds = [torch.cat(per_batch) for per_batch in zip(*predss)]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 25.98it/s]


In [19]:
# Split the flat atom/bond predictions into one tensor per molecule.
atoms_per_mol = [mol.GetNumAtoms() for mol in predict_dataset.mols]
bonds_per_mol = [mol.GetNumBonds() for mol in predict_dataset.mols]

# Rebuild the flat tensors from `predss` (one (mol, atom, bond) tuple per batch) instead of splitting
# `atom_preds`/`bond_preds` in place: the in-place version turned them into tuples, so re-running
# this cell crashed in `torch.split`.
atom_preds_flat = torch.concat([batch_preds[1] for batch_preds in predss])
bond_preds_flat = torch.concat([batch_preds[2] for batch_preds in predss])

atom_preds = torch.split(atom_preds_flat, atoms_per_mol)
bond_preds = torch.split(bond_preds_flat, bonds_per_mol)