[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/mol_atom_bond/mve.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples/mol_atom_bond

In [2]:
import ast
from pathlib import Path
import shutil

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# CSV columns: SMILES, two molecule-level targets, two atom-level targets, two bond-level
# targets, and a per-row sample weight.
columns = ["smiles", "mol_y1", "mol_y2", "atom_y1", "atom_y2", "bond_y1", "bond_y2", "weight"]
# Paths assume the notebook runs from examples/mol_atom_bond inside a chemprop checkout
# (the Colab cell above changes into that directory).
chemprop_dir = Path.cwd().parent.parent
data_dir = chemprop_dir / "tests" / "data" / "mol_atom_bond"

# Seed the Python, NumPy and PyTorch RNGs before any model is built, so that weight
# initialization (and, presumably, the shuffled training batches) are reproducible
# across Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

In [3]:
def _parse_list_columns(rows):
    """Parse stringified per-element target lists into float arrays.

    ``rows`` is a 2-D array of strings, with one row per molecule and one column per target.
    Each string is parsed with ``ast.literal_eval``. The parsed columns are stacked and
    transposed, so each returned array has one row per atom/bond and one column per target.
    Assumes every string in a row is a flat list of numbers of the same length (TODO confirm
    against the CSV).
    """
    return [np.array([ast.literal_eval(cell) for cell in row], dtype=float).T for row in rows]


# SMILES, molecule-level targets and sample weights are plain columns. Atom- and bond-level
# targets are stored as stringified lists, one list per target column.
df_input = pd.read_csv(data_dir / "regression.csv")
smis = df_input[columns[0]].values
mol_ys = df_input[columns[1:3]].values
atoms_ys = _parse_list_columns(df_input[columns[3:5]].values)
bonds_ys = _parse_list_columns(df_input[columns[5:7]].values)
weights = df_input[columns[7]].values

# Hydrogens are kept as written and atoms are reordered. Presumably this makes the per-atom
# and per-bond target rows line up with the featurized molecule (TODO confirm).
datapoints = [
    data.MolAtomBondDatapoint.from_smi(
        smi,
        keep_h=True,
        add_h=False,
        reorder_atoms=True,
        y=mol_y,
        atom_y=atom_y,
        bond_y=bond_y,
        weight=weight,
    )
    for smi, mol_y, atom_y, bond_y, weight in zip(smis, mol_ys, atoms_ys, bonds_ys, weights)
]

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# NOTE: all four splits wrap the same datapoints, so the validation and test numbers are not
# a held-out estimate. This cell only demonstrates the API.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

# Only the training loader shuffles; the evaluation loaders keep the dataset order.
train_dataloader = data.build_dataloader(train_dataset, shuffle=True, batch_size=4)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, shuffle=False, batch_size=4)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

In [4]:
# Shared message-passing encoder, built with default arguments. The model printout below
# shows a hidden size of 300 and separate output layers for atoms (W_vo) and bonds (W_eo).
# Its `output_dims` are indexed when sizing the prediction heads.
mp = nn.MABBondMessagePassing()

In [5]:
# NormAggregation pools atom embeddings into one vector per molecule without changing their
# width. The molecule head and the atom head therefore both take the atom hidden size,
# mp.output_dims[0], rather than relying on MveFFN's default input_dim.
# The bond head takes twice the bond hidden size: presumably the two directed-edge states of
# each bond are concatenated (confirm against MolAtomBondMPNN).
# Each MVE head emits 2 outputs per task, as the printed model shows.
agg = nn.NormAggregation()
mol_predictor = nn.MveFFN(input_dim=mp.output_dims[0], n_tasks=mol_ys.shape[1])
atom_predictor = nn.MveFFN(input_dim=mp.output_dims[0], n_tasks=atoms_ys[0].shape[1])
bond_predictor = nn.MveFFN(input_dim=mp.output_dims[1] * 2, n_tasks=bonds_ys[0].shape[1])

In [6]:
# Assemble the multi-level model: one shared message-passing encoder feeding molecule-,
# atom- and bond-level MVE heads. The molecule head sees atom embeddings pooled by `agg`.
model = models.MolAtomBondMPNN(
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    bond_predictor=bond_predictor,
    message_passing=mp,
    agg=agg,
)

In [7]:
# Display the module hierarchy to check layer sizes for the encoder and each prediction head.
model

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=4, bias=True)
      )
    )
    (criterion): MVELoss(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_fea

In [8]:
# One tuple entry per prediction level (presumably ordered molecule, atom, bond).
# The printout shows output widths of 4 for 2 tasks: each MVE head emits 2 values per task,
# presumably a mean and a variance. The last line prints each level's loss criterion.
print(model.output_dimss, model.n_taskss, model.n_targetss, model.criterions, sep="\n")

(4, 4, 4)
(2, 2, 2)
(2, 2, 2)
(MVELoss(task_weights=[[1.0, 1.0]]), MVELoss(task_weights=[[1.0, 1.0]]), MVELoss(task_weights=[[1.0, 1.0]]))


In [9]:
# Keep the checkpoint with the lowest validation loss. Always also write last.ckpt for the
# final epoch.
checkpointing = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_last=True,
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
)

# Single-device training for a fixed number of epochs. No logger is attached, but the
# checkpoint callback above is active.
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    callbacks=[checkpointing],
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=False,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [10]:
# Remove artifacts from previous runs, so checkpoint selection and the saved model start fresh.
# Both calls are no-ops when the paths do not exist (e.g. on a first run), whereas `rm temp.pt`
# printed an error. They also work on platforms without a POSIX shell.
shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

In [11]:
# Train on the training loader, validating on the validation loader.
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 322 K  | train
1 | agg             | NormAggregation       | 0      | train
2 | mol_predictor   | MveFFN                | 91.5 K | train
3 | atom_predictor  | MveFFN                | 91.5 K | train
4 | bond_predictor  | MveFFN                | 181 K  | train
5 | bns             | ModuleList            | 0      | train
6 | X_d_transform   | Identity              | 0      | train
7 | metricss        | ModuleList            | 0 

                                                                           

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 3/3 [00:00<00:00, 11.75it/s, mol_train_loss_step=83.90, atom_train_loss_step=2.270, bond_train_loss_step=2.350, train_loss_step=88.50, mol_val_loss=76.80, atom_val_loss=3.220, bond_val_loss=3.360, val_loss=80.70, mol_train_loss_epoch=77.90, atom_train_loss_epoch=3.290, bond_train_loss_epoch=3.410, train_loss_epoch=86.00]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3/3 [00:00<00:00,  9.13it/s, mol_train_loss_step=83.90, atom_train_loss_step=2.270, bond_train_loss_step=2.350, train_loss_step=88.50, mol_val_loss=76.80, atom_val_loss=3.220, bond_val_loss=3.360, val_loss=80.70, mol_train_loss_epoch=77.90, atom_train_loss_epoch=3.290, bond_train_loss_epoch=3.410, train_loss_epoch=86.00]


In [12]:
# Evaluate the best checkpoint (lowest val_loss) from the preceding fit. Passing ckpt_path
# explicitly makes that choice visible, instead of relying on Lightning's implicit fallback
# when no model is passed.
results = trainer.test(dataloaders=test_dataloader, ckpt_path="best")

Restoring states from the checkpoint path at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=19-val_loss=80.74.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=19-val_loss=80.74.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 73.80it/s]


In [13]:
# `trainer.predict` yields one (mol, atom, bond) tuple per batch. Regroup the batches by
# level and concatenate each group along the first dimension.
predss = trainer.predict(model, dataloaders=predict_dataloader)
mol_preds, atom_preds, bond_preds = map(torch.concat, zip(*predss))

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 105.36it/s]


In [14]:
# Round-trip the trained model through a standalone file. The reloaded copy is displayed,
# not assigned, to show that it reconstructs the same architecture.
models.utils.save_model("temp.pt", model)
models.MolAtomBondMPNN.load_from_file("temp.pt")

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=4, bias=True)
      )
    )
    (criterion): MVELoss(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_fea

In [15]:
# Rebuild a model from the Lightning checkpoint of the final epoch (written because
# ModelCheckpoint was configured with save_last=True).
models.MolAtomBondMPNN.load_from_checkpoint("checkpoints/last.ckpt")

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=4, bias=True)
      )
    )
    (criterion): MVELoss(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): MveFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_fea