[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/mol_atom_bond/regression.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

# The whole block is skipped when COLAB_RELEASE_TAG is unset or empty, so
# nothing is cloned or installed outside of Colab.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        # Only install when chemprop is not already importable.
        import chemprop
    except ImportError:
        # Clone the repository, install it from source, and finish inside
        # examples/mol_atom_bond so that later cells can locate the repository
        # root relative to the current working directory.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples/mol_atom_bond

In [2]:
import ast
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# Column names of the input CSV, in the order the later slices rely on:
# SMILES first, then two molecule-level, two atom-level and two bond-level
# target columns, and finally the per-datapoint weight.
columns = [
    "smiles",
    "mol_y1",
    "mol_y2",
    "atom_y1",
    "atom_y2",
    "bond_y1",
    "bond_y2",
    "weight",
]

# The data directory is resolved two levels above the current working
# directory, under tests/data/mol_atom_bond.
chemprop_dir = Path.cwd().parent.parent
data_dir = Path(chemprop_dir, "tests", "data", "mol_atom_bond")

In [3]:
# Load the extra inputs shipped with the test data.
#
# `np.load` on an .npz archive returns a lazy, file-backed NpzFile. Each
# archive is therefore opened in a `with` block so the underlying file handle
# is closed as soon as the arrays have been read into memory, and no name is
# ever bound first to an NpzFile and then to a list.
with np.load(data_dir / "descriptors.npz") as npz:
    x_ds = npz["arr_0"]  # indexed per molecule (x_ds[i]) when building datapoints

with np.load(data_dir / "atom_features_descriptors.npz") as npz:
    # Read arr_0, arr_1, ... by explicit index so that list position i always
    # corresponds to archive entry arr_i.
    V_fs = [npz[f"arr_{i}"] for i in range(len(npz.files))]

# The same per-atom arrays are used both as extra atom features and as atom
# descriptors (V_ds is an alias of V_fs, not a copy).
V_ds = V_fs

with np.load(data_dir / "bond_features_descriptors.npz") as npz:
    E_fs = [npz[f"arr_{i}"] for i in range(len(npz.files))]

# Bond descriptors repeat every row of the bond feature array twice in a row
# (along axis 0) - presumably one copy per direction of each bond; confirm
# against the chemprop datapoint documentation.
E_ds = [np.repeat(E_f, repeats=2, axis=0) for E_f in E_fs]

In [4]:
def _parse_listed_targets(rows):
    """Turn rows of stringified lists into one float array per row.

    Every cell of ``rows`` is parsed with ``ast.literal_eval``; the parsed
    cells of one row are stacked and transposed, so each resulting array has
    one column per target column of the row.
    """
    return [
        np.array([ast.literal_eval(cell) for cell in row], dtype=float).T for row in rows
    ]


# Read the regression table and split it by column role (see `columns`).
df_input = pd.read_csv(data_dir / "regression.csv")
smis = df_input[columns[0]].values
mol_ys = df_input[columns[1:3]].values
weights = df_input[columns[7]].values

# Atom- and bond-level targets are stored as literal lists inside the CSV
# cells, so they need parsing before they can be used as arrays.
atoms_ys = _parse_listed_targets(df_input[columns[3:5]].values)
bonds_ys = _parse_listed_targets(df_input[columns[5:7]].values)

def _make_datapoint(idx):
    """Build the datapoint for row ``idx`` from the parallel per-molecule inputs.

    The SMILES, targets, weight and all extra features/descriptors are looked
    up at the same index, so the inputs must be aligned row by row.
    """
    return data.MolAtomBondDatapoint.from_smi(
        smis[idx],
        # SMILES parsing options - NOTE(review): by name, hydrogens already
        # present are kept and none are added; confirm in the chemprop docs.
        keep_h=True,
        add_h=False,
        reorder_atoms=True,
        # Targets at molecule, atom and bond level, plus the datapoint weight.
        y=mol_ys[idx],
        atom_y=atoms_ys[idx],
        bond_y=bonds_ys[idx],
        weight=weights[idx],
        # Extra molecule descriptors, atom/bond features and atom/bond descriptors.
        x_d=x_ds[idx],
        V_f=V_fs[idx],
        V_d=V_ds[idx],
        E_f=E_fs[idx],
        E_d=E_ds[idx],
    )


datapoints = [_make_datapoint(idx) for idx in range(len(smis))]

# The extra feature dimensions must describe the extra per-atom / per-bond
# *input features* (V_f / E_f) that are passed to every datapoint, so they are
# taken from the widths of V_fs / E_fs. The target arrays (atoms_ys / bonds_ys)
# only coincidentally have the same number of columns in this dataset and
# would silently give the wrong graph dimensions for any other data.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(
    extra_atom_fdim=V_fs[0].shape[1], extra_bond_fdim=E_fs[0].shape[1]
)

# Every split wraps the same list of datapoints with the same featurizer; each
# split still gets its own dataset object (presumably so that normalizing one
# split leaves the others untouched - confirm).
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

# Pattern used for every input kind below: fit a scaler on the training split,
# apply that same scaler to the validation split, then wrap the scaler in a
# transform that is handed to the model. The test and predict splits are not
# normalized here - NOTE(review): presumably the model-side transforms take
# care of raw inputs at inference time; confirm.

# Extra atom features (part of the graph). `pad` is the number of featurizer
# columns that are not extra features (total width minus extra width).
V_f_scaler = train_dataset.normalize_inputs("V_f")
val_dataset.normalize_inputs("V_f", V_f_scaler)
V_f_transform = nn.ScaleTransform.from_standard_scaler(
    V_f_scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim
)

# Extra bond features (part of the graph), padded the same way.
E_f_scaler = train_dataset.normalize_inputs("E_f")
val_dataset.normalize_inputs("E_f", E_f_scaler)
E_f_transform = nn.ScaleTransform.from_standard_scaler(
    E_f_scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim
)
graph_transform = nn.GraphTransform(V_f_transform, E_f_transform)

# Atom descriptors.
V_d_scaler = train_dataset.normalize_inputs("V_d")
val_dataset.normalize_inputs("V_d", V_d_scaler)
V_d_transform = nn.ScaleTransform.from_standard_scaler(V_d_scaler)

# Bond descriptors.
E_d_scaler = train_dataset.normalize_inputs("E_d")
val_dataset.normalize_inputs("E_d", E_d_scaler)
E_d_transform = nn.ScaleTransform.from_standard_scaler(E_d_scaler)

# Molecule-level descriptors.
X_d_scaler = train_dataset.normalize_inputs("X_d")
val_dataset.normalize_inputs("X_d", X_d_scaler)
X_d_transform = nn.ScaleTransform.from_standard_scaler(X_d_scaler)

# Targets: same fit-on-train / apply-to-val pattern per prediction level. The
# resulting *un*scale transforms are attached to the predictors as their
# output transforms.
mol_target_scaler = train_dataset.normalize_targets("mol")
val_dataset.normalize_targets("mol", mol_target_scaler)
mol_target_transform = nn.UnscaleTransform.from_standard_scaler(mol_target_scaler)

atom_target_scaler = train_dataset.normalize_targets("atom")
val_dataset.normalize_targets("atom", atom_target_scaler)
atom_target_transform = nn.UnscaleTransform.from_standard_scaler(atom_target_scaler)

bond_target_scaler = train_dataset.normalize_targets("bond")
val_dataset.normalize_targets("bond", bond_target_scaler)
bond_target_transform = nn.UnscaleTransform.from_standard_scaler(bond_target_scaler)

# Only the training loader shuffles; the evaluation loaders keep dataset order.
# All loaders use batches of 4.
train_dataloader = data.build_dataloader(train_dataset, batch_size=4, shuffle=True)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(eval_dataset, batch_size=4, shuffle=False)
    for eval_dataset in (val_dataset, test_dataset, predict_dataset)
)

In [5]:
# Message passing block shared by the molecule, atom and bond predictors.
mp = nn.MABBondMessagePassing(
    # Input widths: featurized atoms/bonds plus the extra atom/bond descriptors.
    d_v=featurizer.atom_fdim,
    d_e=featurizer.bond_fdim,
    d_vd=V_ds[0].shape[1],
    d_ed=E_ds[0].shape[1],
    # Architecture hyperparameters.
    d_h=100,
    depth=4,
    activation="tanh",
    dropout=0.1,
    # Input transforms built from the scalers fitted on the training split.
    graph_transform=graph_transform,
    V_d_transform=V_d_transform,
    E_d_transform=E_d_transform,
)

In [6]:
# Metrics handed to the model constructor: MAE and RMSE, both with default arguments.
metrics = [nn.MAE(), nn.RMSE()]

In [7]:
# Aggregation passed to the model alongside the three prediction heads.
agg = nn.NormAggregation(norm=10)

# Molecule head: its input is the first message-passing output width plus the
# width of the extra molecule descriptors. The two molecule tasks are weighted
# 0.5 and 0.1 in the MSE criterion.
mol_predictor = nn.RegressionFFN(
    n_tasks=mol_ys.shape[1],
    input_dim=mp.output_dims[0] + x_ds.shape[1],
    criterion=nn.MSE(task_weights=[0.5, 0.1]),
    output_transform=mol_target_transform,
)

# Atom head: consumes the first message-passing output width directly and uses
# the default criterion.
atom_predictor = nn.RegressionFFN(
    n_tasks=atoms_ys[0].shape[1],
    input_dim=mp.output_dims[0],
    output_transform=atom_target_transform,
)

# Bond head: its input is twice the second message-passing output width -
# presumably because a bond combines two directed-edge representations; confirm.
bond_predictor = nn.RegressionFFN(
    n_tasks=bonds_ys[0].shape[1],
    input_dim=mp.output_dims[1] * 2,
    output_transform=bond_target_transform,
)

In [8]:
# Assemble the full model from the pieces defined in the previous cells.
model = models.MolAtomBondMPNN(
    # Shared encoder and aggregation.
    message_passing=mp,
    agg=agg,
    # One prediction head per level.
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    bond_predictor=bond_predictor,
    # Remaining options.
    X_d_transform=X_d_transform,
    metrics=metrics,
    batch_norm=True,
)

In [9]:
# Display the module tree of the assembled model (bare expression -> cell output).
model

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_fea

In [10]:
# Per-level summaries of the model, one per line: output dimensions, number of
# tasks, number of targets, and the loss criteria.
print(
    model.output_dimss,
    model.n_taskss,
    model.n_targetss,
    model.criterions,
    sep="\n",
)

(2, 2, 2)
(2, 2, 2)
(1, 1, 1)
(MSE(task_weights=[[0.5, 0.10000000149011612]]), MSE(task_weights=[[1.0, 1.0]]), MSE(task_weights=[[1.0, 1.0]]))


In [11]:
# Checkpoint callback: tracks "val_loss" (lower is better), writes into
# ./checkpoints, and additionally saves the last state (save_last=True).
checkpointing = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_last=True,
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
)

# Short 20-epoch run on a single automatically selected device, without a
# logger but with the progress bar and the checkpoint callback enabled.
trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    callbacks=[checkpointing],
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=False,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [12]:
! rm -rf checkpoints/
! rm temp.pt

In [13]:
# Train the model; the validation loader supplies the "val_loss" that the
# checkpoint callback monitors.
trainer.fit(model, train_dataloader, val_dataloader)

Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 69.2 K | train
1 | agg             | NormAggregation       | 0      | train
2 | mol_predictor   | RegressionFFN         | 32.1 K | train
3 | atom_predictor  | RegressionFFN         | 31.5 K | train
4 | bond_predictor  | RegressionFFN         | 62.1 K | train
5 | bns             | ModuleList            | 612    | train
6 | X_d_transform   | ScaleTransform        | 0      | train
7 | metricss        | ModuleList            | 0 

                                                                           

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 3/3 [00:00<00:00, 13.09it/s, mol_train_loss_step=0.0121, atom_train_loss_step=0.0669, bond_train_loss_step=0.219, train_loss_step=0.298, mol_val_loss=0.0259, atom_val_loss=0.0127, bond_val_loss=0.0394, val_loss=0.115, mol_train_loss_epoch=0.0161, atom_train_loss_epoch=0.0351, bond_train_loss_epoch=0.0896, train_loss_epoch=0.151]    

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 3/3 [00:00<00:00, 11.57it/s, mol_train_loss_step=0.0121, atom_train_loss_step=0.0669, bond_train_loss_step=0.219, train_loss_step=0.298, mol_val_loss=0.0259, atom_val_loss=0.0127, bond_val_loss=0.0394, val_loss=0.115, mol_train_loss_epoch=0.0161, atom_train_loss_epoch=0.0351, bond_train_loss_epoch=0.0896, train_loss_epoch=0.151]


In [14]:
# No model is passed here; in the recorded run below, Lightning restores the
# best checkpoint written during `fit` before testing.
results = trainer.test(dataloaders=test_dataloader)

Restoring states from the checkpoint path at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=18-val_loss=0.11.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/examples/mol_atom_bond/checkpoints/best-epoch=18-val_loss=0.11.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 41.44it/s]


In [15]:
# `predss` holds one entry per batch; each entry is unpacked below into its
# molecule-, atom- and bond-level predictions.
predss = trainer.predict(model, predict_dataloader)

# Regroup from per-batch to per-level, then join each level's batches into a
# single tensor.
per_level_batches = zip(*predss)
mol_preds, atom_preds, bond_preds = map(torch.concat, per_level_batches)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 109.37it/s]


In [16]:
# Round trip through a model file: save the trained model to temp.pt, then
# load it back. The loaded model is the cell's displayed value.
models.utils.save_model("temp.pt", model)
models.MolAtomBondMPNN.load_from_file("temp.pt")

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_fea

In [17]:
# Load the model from the Lightning checkpoint instead; last.ckpt exists
# because the checkpoint callback was created with save_last=True.
models.MolAtomBondMPNN.load_from_checkpoint("checkpoints/last.ckpt")

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=90, out_features=100, bias=False)
    (W_h): Linear(in_features=100, out_features=100, bias=False)
    (W_vo): Linear(in_features=174, out_features=100, bias=True)
    (W_vd): Linear(in_features=102, out_features=102, bias=True)
    (W_eo): Linear(in_features=116, out_features=100, bias=True)
    (W_ed): Linear(in_features=102, out_features=102, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (tau): Tanh()
    (V_d_transform): ScaleTransform()
    (E_d_transform): ScaleTransform()
    (graph_transform): GraphTransform(
      (V_transform): ScaleTransform()
      (E_transform): ScaleTransform()
    )
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=104, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_fea