[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/mol_atom_bond/bounded.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

# Only the presence of COLAB_RELEASE_TAG is checked; when it is unset the whole cell does nothing.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        # Probe import: skip the clone/install when chemprop is already importable.
        import chemprop
    except ImportError:
        # Clone the repository and install it from the checked-out source tree.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        # Later cells build paths from Path.cwd(), so move into this example's folder.
        %cd examples/mol_atom_bond

In [2]:
import ast
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# CSV layout: the SMILES string, two molecule-level targets, two atom-level targets,
# two bond-level targets, and a per-datapoint weight.
columns = [
    "smiles",
    "mol_y1", "mol_y2",
    "atom_y1", "atom_y2",
    "bond_y1", "bond_y2",
    "weight",
]
# Two levels above the working directory; assumes the notebook is run from
# examples/mol_atom_bond inside the repository.
chemprop_dir = Path.cwd().parent.parent
data_dir = chemprop_dir.joinpath("tests", "data", "mol_atom_bond")

In [3]:
df_input = pd.read_csv(data_dir / "bounded.csv")

# Split the frame by target level, following the order of `columns`:
# SMILES and weight as 1-D value arrays, each target pair as a two-column frame.
smis = df_input[columns[0]].values
mol_ys = df_input[columns[1:3]]
atoms_ys = df_input[columns[3:5]]
bonds_ys = df_input[columns[5:7]]
weights = df_input[columns[7]].values

# Work on the text form of each cell so the "<" / ">" bound markers can be detected.
mol_text = mol_ys.astype(str)
lt_mask = mol_text.map(lambda cell: "<" in cell).to_numpy()
gt_mask = mol_text.map(lambda cell: ">" in cell).to_numpy()
# Drop the markers and convert what is left to float32.
mol_ys = mol_text.map(lambda cell: cell.strip("<").strip(">")).to_numpy(np.single)

def _parse_bounded_targets(raw_targets):
    """Parse per-atom or per-bond bounded target columns.

    Parameters
    ----------
    raw_targets : pd.DataFrame
        One column per task. Each cell is the text of a Python list with one
        entry per atom/bond. A falsy entry (such as ``None``) means "no target";
        any other entry is a numeric string optionally marked with "<" or ">".

    Returns
    -------
    tuple of three lists ``(ys, lt_masks, gt_masks)``
        Each list holds one array per dataframe row. Every array has one row
        per list entry and one column per task column. ``ys`` is float32 with
        NaN for missing entries; the masks are boolean and flag entries that
        carried the "<" or ">" marker, respectively.
    """
    parsed = raw_targets.map(ast.literal_eval)

    def per_row_arrays(frame):
        # vstack yields (n_columns, n_entries); transpose to (n_entries, n_columns).
        return frame.apply(lambda row: np.vstack(row.values).T, axis=1).tolist()

    def bound_masks(marker):
        flagged = parsed.map(lambda entries: [marker in v if v else False for v in entries])
        # Explicit cast: stacking empty entry lists produces a float array, not a bool one.
        return [mask.astype(bool) for mask in per_row_arrays(flagged)]

    values = parsed.map(
        lambda entries: np.array(
            [v.strip("<").strip(">") if v else "nan" for v in entries], dtype=np.single
        )
    )
    return per_row_arrays(values), bound_masks("<"), bound_masks(">")


# Atom- and bond-level targets share the same encoding, so both go through the same parser.
atoms_ys, atom_lt_masks, atom_gt_masks = _parse_bounded_targets(atoms_ys)
bonds_ys, bond_lt_masks, bond_gt_masks = _parse_bounded_targets(bonds_ys)

# Build one datapoint per SMILES, pairing it with the targets, weight and
# bound masks that sit at the same position in each parsed collection.
datapoints = []
for idx, smiles in enumerate(smis):
    datapoint = data.MolAtomBondDatapoint.from_smi(
        smiles,
        keep_h=True,
        add_h=False,
        reorder_atoms=True,
        y=mol_ys[idx],
        atom_y=atoms_ys[idx],
        bond_y=bonds_ys[idx],
        weight=weights[idx],
        lt_mask=lt_mask[idx],
        gt_mask=gt_mask[idx],
        atom_lt_mask=atom_lt_masks[idx],
        atom_gt_mask=atom_gt_masks[idx],
        bond_lt_mask=bond_lt_masks[idx],
        bond_gt_mask=bond_gt_masks[idx],
    )
    datapoints.append(datapoint)
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# All four splits are built from the same datapoints; each gets its own dataset object.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

# Only the training loader shuffles.
train_dataloader = data.build_dataloader(train_dataset, shuffle=True)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, shuffle=False)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

In [4]:
# Message-passing module, constructed with its default arguments.
mp = nn.MABAtomMessagePassing()

In [5]:
# Evaluation metrics handed to the model below via its `metrics` argument.
metrics = [nn.MAE(), nn.RMSE()]

In [6]:
agg = nn.SumAggregation()

# Number of tasks per level, read from the last axis of the parsed target arrays.
n_mol_tasks = mol_ys.shape[1]
n_atom_tasks = atoms_ys[0].shape[1]
n_bond_tasks = bonds_ys[0].shape[1]
# The bond head takes twice the second message-passing output dimension as input
# (presumably the bond-level output; confirm against the chemprop docs).
bond_input_dim = mp.output_dims[1] * 2

# Every head is trained with the bounded MSE criterion.
mol_predictor = nn.RegressionFFN(n_tasks=n_mol_tasks, criterion=nn.BoundedMSE())
atom_predictor = nn.RegressionFFN(n_tasks=n_atom_tasks, criterion=nn.BoundedMSE())
bond_predictor = nn.RegressionFFN(
    input_dim=bond_input_dim, n_tasks=n_bond_tasks, criterion=nn.BoundedMSE()
)

In [7]:
# Assemble the full model from the message-passing module, the aggregation,
# one predictor per level (molecule, atom, bond) and the shared metric list.
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    bond_predictor=bond_predictor,
    metrics=metrics,
)

In [8]:
# Last expression of the cell: shows the model's module tree.
model

MolAtomBondMPNN(
  (message_passing): MABAtomMessagePassing(
    (W_i): Linear(in_features=72, out_features=300, bias=False)
    (W_h): Linear(in_features=314, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): SumAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): BoundedMSE(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Li

In [9]:
# One line per attribute; each attribute holds one value per predictor.
print(
    model.output_dimss,
    model.n_taskss,
    model.n_targetss,
    model.criterions,
    sep="\n",
)

(2, 2, 2)
(2, 2, 2)
(1, 1, 1)
(BoundedMSE(task_weights=[[1.0]]), BoundedMSE(task_weights=[[1.0]]), BoundedMSE(task_weights=[[1.0]]))


In [10]:
# Checkpoint callback: tracks `val_loss` (lower is better) and writes into ./checkpoints.
# The filename template embeds the epoch and the monitored value; save_last=True
# requests an additional "last" checkpoint, which a later cell reloads.
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)

# Trainer: no logger, progress bar on, one device of whatever accelerator is
# available, and a short 20-epoch run.
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20,
    callbacks=[checkpointing],
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [11]:
! rm -rf checkpoints/
! rm temp.pt

In [12]:
# Train the model; the validation loader provides the `val_loss` that the checkpoint callback monitors.
trainer.fit(model, train_dataloader, val_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABAtomMessagePassing | 322 K  | train
1 | agg             | SumAggregation        | 0      | train
2 | mol_predictor   | RegressionFFN         | 90.9 K | train
3 | atom_predictor  | RegressionFFN         | 90.9 K | train
4 | bond_predictor  | RegressionFFN         | 180 K  | train
5 | bns             | ModuleList            | 0      | train
6 | X_d_transform   | Identity              | 0      | train
7 | metricss        | ModuleList            | 0 

                                                                           

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  7.89it/s, mol_train_loss_step=93.70, atom_train_loss_step=10.70, bond_train_loss_step=7.490, train_loss_step=112.0, mol_val_loss=86.10, atom_val_loss=10.50, bond_val_loss=7.050, val_loss=104.0, mol_train_loss_epoch=93.70, atom_train_loss_epoch=10.70, bond_train_loss_epoch=7.490, train_loss_epoch=112.0]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1/1 [00:00<00:00,  5.08it/s, mol_train_loss_step=93.70, atom_train_loss_step=10.70, bond_train_loss_step=7.490, train_loss_step=112.0, mol_val_loss=86.10, atom_val_loss=10.50, bond_val_loss=7.050, val_loss=104.0, mol_train_loss_epoch=93.70, atom_train_loss_epoch=10.70, bond_train_loss_epoch=7.490, train_loss_epoch=112.0]


In [13]:
# No model is passed here; in the recorded run Lightning restored the best checkpoint before testing.
results = trainer.test(dataloaders=test_dataloader)

Restoring states from the checkpoint path at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=19-val_loss=103.71.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=19-val_loss=103.71.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 52.18it/s]


In [14]:
# `predict` returns one entry per batch; each entry holds (molecule, atom, bond) predictions.
predss = trainer.predict(model, predict_dataloader)

# Regroup by level, then join the batches of each level into one tensor.
mol_batches, atom_batches, bond_batches = zip(*predss)
mol_preds = torch.concat(mol_batches)
atom_preds = torch.concat(atom_batches)
bond_preds = torch.concat(bond_batches)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 81.18it/s]


In [15]:
# Round-trip the trained model through a file. The reloaded copy is the cell's
# last expression, so its module tree is displayed.
models.utils.save_model("temp.pt", model)
models.MolAtomBondMPNN.load_from_file("temp.pt")

MolAtomBondMPNN(
  (message_passing): MABAtomMessagePassing(
    (W_i): Linear(in_features=72, out_features=300, bias=False)
    (W_h): Linear(in_features=314, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): SumAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): BoundedMSE(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Li

In [16]:
# Reload from the "last" checkpoint, which exists because the checkpoint callback was created with save_last=True.
models.MolAtomBondMPNN.load_from_checkpoint("checkpoints/last.ckpt")

MolAtomBondMPNN(
  (message_passing): MABAtomMessagePassing(
    (W_i): Linear(in_features=72, out_features=300, bias=False)
    (W_h): Linear(in_features=314, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): SumAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): BoundedMSE(task_weights=[[1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Li