[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/mol_atom_bond/constrained.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples/mol_atom_bond

In [2]:
import ast
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# Column layout of constrained_regression.csv: SMILES, one molecule target,
# two atom targets and two bond targets.
columns = ["smiles", "mol_y", "atom_y1", "atom_y2", "bond_y1", "bond_y2"]
# Assumes the notebook runs from examples/mol_atom_bond inside the chemprop repo.
chemprop_dir = Path.cwd().parent.parent
data_dir = chemprop_dir / "tests" / "data" / "mol_atom_bond"

# Seed Python/NumPy/PyTorch RNGs so weight initialization and the shuffled
# training loader are repeatable between runs.
SEED = 42
pl.seed_everything(SEED)

In [3]:
df_input = pd.read_csv(data_dir / "constrained_regression.csv")

# Molecule-level columns: the SMILES strings and a 2-D array of molecule targets.
smis = df_input[columns[0]].values
mol_ys = df_input[columns[1:2]].values


def _parse_per_element_targets(target_cells):
    """Convert each row of stringified lists into a float array with elements along axis 0.

    Each cell holds a Python-literal list (presumably one value per atom or bond).
    The lists of a row are stacked and then transposed, giving ``(n_elements, n_targets)``.
    """
    parsed = []
    for row in target_cells:
        per_target = [ast.literal_eval(cell) for cell in row]
        parsed.append(np.array(per_target, dtype=float).T)
    return parsed


atoms_ys = _parse_per_element_targets(df_input[columns[2:4]].values)
bonds_ys = _parse_per_element_targets(df_input[columns[4:6]].values)

df_constraints = pd.read_csv(data_dir / "constrained_regression_constraints.csv")
n_mols = len(df_constraints)
# Constraints are looked up by position (row i belongs to smis[i]), so the two
# files must line up one-to-one; otherwise extra rows would be silently ignored.
assert n_mols == len(smis), (
    f"{n_mols} constraint rows but {len(smis)} molecules in the input file"
)

# Keys name the j-th atom or bond target ("atom_target_col_<j>" / "bond_target_col_<j>").
# Values are the positional column of df_constraints holding that target's constraint.
# Targets without an entry are unconstrained and get NaN.
constraints_cols_to_target_cols = {
    "atom_target_col_0": 0,
    "atom_target_col_1": 1,
    "bond_target_col_1": 2,
}


def build_constraint_matrix(prefix: str, n_targets: int) -> tuple[list[int | None], np.ndarray]:
    """Assemble the per-molecule constraint matrix for one element type.

    Parameters
    ----------
    prefix : str
        ``"atom"`` or ``"bond"``; selects keys of ``constraints_cols_to_target_cols``.
    n_targets : int
        Number of target columns for that element type.

    Returns
    -------
    tuple[list[int | None], np.ndarray]
        The df_constraints column used for each target (``None`` if unconstrained),
        and an ``(n_mols, n_targets)`` matrix that is NaN wherever a target is unconstrained.
    """
    target_cols = [
        constraints_cols_to_target_cols.get(f"{prefix}_target_col_{j}") for j in range(n_targets)
    ]
    constraint_matrix = np.hstack(
        [
            df_constraints.iloc[:, col].values.reshape(-1, 1)
            if col is not None
            else np.full((n_mols, 1), np.nan)
            for col in target_cols
        ]
    )
    return target_cols, constraint_matrix


atom_constraint_cols, atom_constraints = build_constraint_matrix("atom", atoms_ys[0].shape[1])
bond_constraint_cols, bond_constraints = build_constraint_matrix("bond", bonds_ys[0].shape[1])

# One datapoint per molecule. zip(strict=True) raises if the SMILES and any
# per-molecule array differ in length, instead of silently truncating or ignoring rows.
datapoints = [
    data.MolAtomBondDatapoint.from_smi(
        smi,
        keep_h=True,
        add_h=False,
        reorder_atoms=True,
        y=mol_y,
        atom_y=atom_y,
        bond_y=bond_y,
        atom_constraint=atom_constraint,
        bond_constraint=bond_constraint,
    )
    for smi, mol_y, atom_y, bond_y, atom_constraint, bond_constraint in zip(
        smis, mol_ys, atoms_ys, bonds_ys, atom_constraints, bond_constraints, strict=True
    )
]

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# This demo deliberately reuses the same datapoints for every split.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

# Only the training loader shuffles. Evaluation and prediction keep dataset order
# so outputs can be matched back to molecules.
train_dataloader = data.build_dataloader(train_dataset, shuffle=True)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, shuffle=False)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

In [4]:
# Message passing for the mol/atom/bond model. Its printed repr shows separate atom
# (W_vo) and bond (W_eo) output layers. `mp.output_dims[1]` is used later to size the
# bond head (presumably output_dims is (atom_dim, bond_dim); confirm).
mp = nn.MABBondMessagePassing()

In [5]:
agg = nn.NormAggregation()

# One regression head per level; each head predicts one task per target column at that level.
n_mol_tasks = mol_ys.shape[1]
n_atom_tasks = atoms_ys[0].shape[1]
n_bond_tasks = bonds_ys[0].shape[1]

mol_predictor = nn.RegressionFFN(n_tasks=n_mol_tasks)
atom_predictor = nn.RegressionFFN(n_tasks=n_atom_tasks)
# The bond head's input is twice mp.output_dims[1] wide (presumably two bond-level
# vectors concatenated; confirm against MABBondMessagePassing).
bond_predictor = nn.RegressionFFN(n_tasks=n_bond_tasks, input_dim=mp.output_dims[1] * 2)

In [6]:
# One constraint per non-NaN entry in the first molecule's constraint row
# (NaN columns are unconstrained targets). int() passes a plain Python int
# rather than a NumPy scalar.
atom_constrainer = nn.Constrainer(n_constraints=int((~np.isnan(atom_constraints[0])).sum()))
# fp_dim presumably has to equal the bond feature width fed to bond_predictor
# (mp.output_dims[1] * 2, i.e. 600 with the defaults here). Derive it instead of
# hard-coding 600 so the two stay in sync if the message-passing size changes.
bond_constrainer = nn.Constrainer(
    n_constraints=int((~np.isnan(bond_constraints[0])).sum()),
    fp_dim=mp.output_dims[1] * 2,
)

In [7]:
# Shared message passing and aggregation, then one predictor per output level.
# The atom and bond levels also get a constrainer.
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    # molecule level
    mol_predictor=mol_predictor,
    # atom level
    atom_predictor=atom_predictor,
    atom_constrainer=atom_constrainer,
    # bond level
    bond_predictor=bond_predictor,
    bond_constrainer=bond_constrainer,
)

In [8]:
# Display the module tree (layers and their sizes) as the cell's rich output.
model

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(i

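Check that, for every molecule, the atom (bond) predictions summed over its atoms (bonds) match the corresponding constraints. Unconstrained targets appear as NaN in the error arrays.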

In [9]:
# Forward pass of the (still untrained) model on the first prediction batch.
batch = next(iter(predict_dataloader))
# The first four fields are the model inputs and the last is the constraints;
# any fields in between are not needed here.
bmg, V_d, E_d, X_d, *_, constraints = batch

# Predict in eval mode (standard for inference; affects layers such as dropout and
# batch norm). Then restore the previous mode so `trainer.fit` receives the model
# in the same state it was constructed in.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        mol_preds, atom_preds_tensor, bond_preds_tensor = model(bmg, V_d, E_d, X_d, constraints)
finally:
    model.train(was_training)

In [10]:
# `atom_preds_tensor` covers only the first batch of `predict_dataloader`, while the
# constraints below cover every molecule in `predict_dataset`. The comparison is only
# valid if that one batch holds the whole dataset, so fail clearly otherwise.
atoms_per_mol = [mol.GetNumAtoms() for mol in predict_dataset.mols]
assert sum(atoms_per_mol) == len(atom_preds_tensor), (
    "predict_dataloader yielded more than one batch; compare constraints batch by batch"
)
atom_preds = torch.split(atom_preds_tensor, atoms_per_mol)
# Sum each molecule's atom predictions and compare with that molecule's constraints.
atom_sums = torch.vstack([p.sum(dim=0) for p in atom_preds]).numpy()
errors = predict_dataset.atom_constraints - atom_sums
print(errors)
# NaN marks an unconstrained target, so only the non-NaN entries are checked.
assert np.all(np.isclose(errors[~np.isnan(errors)], 0.0, atol=1e-5))

[[ 0.00000000e+00 -3.24249267e-08]
 [ 0.00000000e+00  3.20434570e-07]
 [ 9.31322575e-10  2.25830078e-06]
 [ 9.31322575e-10  2.25830078e-06]
 [ 0.00000000e+00  6.40869139e-07]
 [-1.19209290e-07 -3.96728517e-07]
 [ 1.86264515e-09 -2.53295899e-06]
 [ 9.31322575e-10 -2.13623047e-06]
 [ 2.79396772e-09  6.71386722e-07]
 [-1.86264515e-09 -1.52587891e-06]
 [ 0.00000000e+00  5.12695313e-06]]


In [11]:
# `bond_preds_tensor` covers only the first batch of `predict_dataloader`, while the
# constraints below cover every molecule in `predict_dataset`. The comparison is only
# valid if that one batch holds the whole dataset, so fail clearly otherwise.
bonds_per_mol = [mol.GetNumBonds() for mol in predict_dataset.mols]
assert sum(bonds_per_mol) == len(bond_preds_tensor), (
    "predict_dataloader yielded more than one batch; compare constraints batch by batch"
)
bond_preds = torch.split(bond_preds_tensor, bonds_per_mol)
# Sum each molecule's bond predictions and compare with that molecule's constraints.
bond_sums = torch.vstack([p.sum(dim=0) for p in bond_preds]).numpy()
errors = predict_dataset.bond_constraints - bond_sums
print(errors)
# NaN marks an unconstrained target, so only the non-NaN entries are checked.
assert np.all(np.isclose(errors[~np.isnan(errors)], 0.0, atol=1e-5))

[[           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 1.19209290e-07]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]
 [           nan 9.53674316e-07]
 [           nan 0.00000000e+00]
 [           nan 0.00000000e+00]]


In [12]:
# Keep the checkpoint with the lowest validation loss, plus `last.ckpt` for the final epoch.
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    monitor="val_loss",
    mode="min",
    filename="best-{epoch}-{val_loss:.2f}",
    save_last=True,
)

trainer = pl.Trainer(
    # hardware
    accelerator="auto",
    devices=1,
    # schedule
    max_epochs=20,
    # logging and checkpointing
    logger=False,
    enable_progress_bar=True,
    enable_checkpointing=True,
    callbacks=[checkpointing],
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [13]:
# Remove artifacts from earlier runs so checkpoint names and `last.ckpt` reflect this run.
# Pure-Python and cross-platform; neither call errors if the artifact does not exist yet
# (plain `rm temp.pt` fails on a fresh run).
import shutil

shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

In [14]:
# Train, validating each epoch on the validation loader.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name             | Type                  | Params | Mode 
-------------------------------------------------------------------
0 | message_passing  | MABBondMessagePassing | 322 K  | train
1 | agg              | NormAggregation       | 0      | train
2 | mol_predictor    | RegressionFFN         | 90.6 K | train
3 | atom_predictor   | RegressionFFN         | 90.9 K | train
4 | atom_constrainer | Constrainer           | 90.9 K | train
5 | bond_predictor   | RegressionFFN         | 180 K  | train
6 | bond_constrainer | Constrainer           | 180 K  | train
7 | bns              | ModuleList      

Sanity Checking DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  5.83it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  9.05it/s, mol_train_loss_step=0.166, atom_train_loss_step=0.166, bond_train_loss_step=8.520, train_loss_step=8.850, mol_val_loss=0.166, atom_val_loss=0.146, bond_val_loss=8.270, val_loss=8.590, mol_train_loss_epoch=0.166, atom_train_loss_epoch=0.166, bond_train_loss_epoch=8.520, train_loss_epoch=8.850]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  4.40it/s, mol_train_loss_step=0.166, atom_train_loss_step=0.166, bond_train_loss_step=8.520, train_loss_step=8.850, mol_val_loss=0.166, atom_val_loss=0.146, bond_val_loss=8.270, val_loss=8.590, mol_train_loss_epoch=0.166, atom_train_loss_epoch=0.166, bond_train_loss_epoch=8.520, train_loss_epoch=8.850]


In [15]:
# Evaluate the weights restored from the best checkpoint of the preceding fit. This is
# what Lightning selects anyway when no model is passed; stating it silences its warning.
results = trainer.test(dataloaders=test_dataloader, ckpt_path="best")

Restoring states from the checkpoint path at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=19-val_loss=8.59.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=19-val_loss=8.59.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 36.00it/s]


In [16]:
# `predss` holds one (mol, atom, bond) prediction tuple per batch. Regroup by output
# level and concatenate each level across batches.
predss = trainer.predict(model, dataloaders=predict_dataloader)
mol_preds, atom_preds, bond_preds = map(torch.concat, zip(*predss))

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 96.45it/s] 


In [17]:
# Round-trip the model through chemprop's save/load helpers and confirm the reloaded
# copy reproduces the in-memory model's outputs on the batch taken earlier.
model.eval()
with torch.no_grad():
    reference_preds = model(bmg, V_d, E_d, X_d, constraints)

models.utils.save_model("temp.pt", model)
model = models.MolAtomBondMPNN.load_from_file("temp.pt")
model.eval()  # inference mode for the reloaded copy too
with torch.no_grad():
    mol_preds, atom_preds_tensor, bond_preds_tensor = model(bmg, V_d, E_d, X_d, constraints)

torch.testing.assert_close(
    (mol_preds, atom_preds_tensor, bond_preds_tensor), tuple(reference_preds)
)

In [18]:
# Reload the final-epoch weights from `last.ckpt` (written because of save_last=True).
model = models.MolAtomBondMPNN.load_from_checkpoint("checkpoints/last.ckpt")
# Switch to inference mode explicitly, as Lightning's docs recommend after loading;
# this is harmless if the model is already in eval mode.
model.eval()
with torch.no_grad():
    mol_preds, atom_preds_tensor, bond_preds_tensor = model(bmg, V_d, E_d, X_d, constraints)