[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/mol_atom_bond/MAB_subsets.ipynb)

In [None]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples/mol_atom_bond

In [None]:
import ast
from pathlib import Path
import shutil

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import numpy as np
import pandas as pd
import torch

from chemprop import data, featurizers, models, nn

# CSV layout: SMILES, two molecule-level targets, two atom-level targets, two bond-level
# targets (stored as stringified lists), and a per-molecule weight.
columns = ["smiles", "mol_y1", "mol_y2", "atom_y1", "atom_y2", "bond_y1", "bond_y2", "weight"]
# The notebook is expected to run from `examples/mol_atom_bond`, two levels below the repo root.
chemprop_dir = Path.cwd().parent.parent
data_dir = chemprop_dir / "tests" / "data" / "mol_atom_bond"

# Seed the Python/NumPy/PyTorch RNGs so shuffling and weight initialisation are
# reproducible when the notebook is re-run from a fresh kernel.
seed = 42
pl.seed_everything(seed)

In [None]:
df_input = pd.read_csv(data_dir / "regression.csv")


def parse_per_element_targets(rows):
    """Convert rows of stringified per-atom/per-bond lists into float arrays.

    Each row holds one string per target column (e.g. "[1.0, 2.0, ...]"). The result is one
    array per row with shape (n_elements, n_target_columns).
    """
    return [
        np.array([ast.literal_eval(cell) for cell in row], dtype=float).T
        for row in rows
    ]


smis = df_input.loc[:, columns[0]].to_numpy()
mol_ys = df_input.loc[:, columns[1:3]].to_numpy()
atoms_ys = parse_per_element_targets(df_input.loc[:, columns[3:5]].to_numpy())
bonds_ys = parse_per_element_targets(df_input.loc[:, columns[5:7]].to_numpy())
weights = df_input.loc[:, columns[7]].to_numpy()

datapoints = [
    data.MolAtomBondDatapoint.from_smi(
        smi,
        keep_h=True,
        add_h=False,
        reorder_atoms=True,
        y=mol_y,
        atom_y=atom_y,
        bond_y=bond_y,
        weight=weight,
    )
    for smi, mol_y, atom_y, bond_y, weight in zip(smis, mol_ys, atoms_ys, bonds_ys, weights)
]

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Every split wraps the same datapoints. That is enough to exercise the API,
# but the metrics say nothing about generalisation.
train_dataset, val_dataset, test_dataset, predict_dataset = (
    data.MolAtomBondDataset(datapoints, featurizer=featurizer) for _ in range(4)
)

# Only the training loader shuffles.
train_dataloader = data.build_dataloader(train_dataset, shuffle=True, batch_size=4)
val_dataloader, test_dataloader, predict_dataloader = (
    data.build_dataloader(dataset, shuffle=False, batch_size=4)
    for dataset in (val_dataset, test_dataset, predict_dataset)
)

In [3]:
# Mol, Atom, Bond

mp = nn.MABBondMessagePassing()
agg = nn.NormAggregation()
mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[1])
atom_predictor = nn.RegressionFFN(n_tasks=atoms_ys[0].shape[1])
# `input_dim` is twice `mp.output_dims[1]` (presumably two edge embeddings concatenated per bond).
bond_predictor = nn.RegressionFFN(input_dim=(mp.output_dims[1] * 2), n_tasks=bonds_ys[0].shape[1])
model = models.MolAtomBondMPNN(
    message_passing=mp,
    agg=agg,
    mol_predictor=mol_predictor,
    atom_predictor=atom_predictor,
    bond_predictor=bond_predictor,
)
display(model)

# Start from an empty checkpoint directory so files left behind by an earlier cell that
# failed before its cleanup (possibly a different architecture) cannot interfere.
shutil.rmtree("checkpoints", ignore_errors=True)
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# `predss` holds one (mol, atom, bond) tuple per batch: regroup by output and concatenate,
# keeping None for outputs the model does not produce.
mol_preds, atom_preds, bond_preds = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
# One prediction row per molecule / atom / bond and one column per task.
assert mol_preds.shape == (len(mol_ys), mol_ys.shape[1])
assert atom_preds.shape == (sum(len(atom_y) for atom_y in atoms_ys), atoms_ys[0].shape[1])
assert bond_preds.shape == (sum(len(bond_y) for bond_y in bonds_ys), bonds_ys[0].shape[1])

models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
# Reload the file the callback actually wrote rather than assuming its name.
display(models.MolAtomBondMPNN.load_from_checkpoint(checkpointing.last_model_path))

# Portable cleanup (no shell `rm`); tolerant of files that are already gone.
shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Lin

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 322 K  | train
1 | agg             | NormAggregation       | 0      | train
2 | mol_predictor   | RegressionFFN         | 90.9 K | train
3 | atom_predictor  | RegressionFFN         | 90.9 K | train
4 | bond_predictor  | RegressionFFN         | 180 K  | train
5 | bns             | ModuleList            | 0      | trai

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 20.06it/s, mol_train_loss_step=1.72e+3, atom_train_loss_step=54.20, bond_train_loss_step=100.0, train_loss_step=1.88e+3, mol_val_loss=652.0, atom_val_loss=49.20, bond_val_loss=92.20, val_loss=738.0, mol_train_loss_epoch=652.0, atom_train_loss_epoch=50.50, bond_train_loss_epoch=94.40, train_loss_epoch=841.0]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 14.41it/s, mol_train_loss_step=1.72e+3, atom_train_loss_step=54.20, bond_train_loss_step=100.0, train_loss_step=1.88e+3, mol_val_loss=652.0, atom_val_loss=49.20, bond_val_loss=92.20, val_loss=738.0, mol_train_loss_epoch=652.0, atom_train_loss_epoch=50.50, bond_train_loss_epoch=94.40, train_loss_epoch=841.0]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=737.93.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=737.93.ckpt
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 83.31it/s]


/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 174.30it/s]


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Lin

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Lin

In [4]:
# Mol, Atom, None

# No bond predictor, so edge embeddings are not requested from message passing.
mp = nn.MABBondMessagePassing(return_edge_embeddings=False)
agg = nn.NormAggregation()
mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[1])
atom_predictor = nn.RegressionFFN(n_tasks=atoms_ys[0].shape[1])
model = models.MolAtomBondMPNN(
    message_passing=mp, agg=agg, mol_predictor=mol_predictor, atom_predictor=atom_predictor
)
display(model)

# Start from an empty checkpoint directory so files left behind by an earlier cell that
# failed before its cleanup (possibly a different architecture) cannot interfere.
shutil.rmtree("checkpoints", ignore_errors=True)
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# `predss` holds one (mol, atom, bond) tuple per batch: regroup by output and concatenate,
# keeping None for outputs the model does not produce.
mol_preds, atom_preds, _ = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
# One prediction row per molecule / atom and one column per task.
assert mol_preds.shape == (len(mol_ys), mol_ys.shape[1])
assert atom_preds.shape == (sum(len(atom_y) for atom_y in atoms_ys), atoms_ys[0].shape[1])

models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
# Reload the file the callback actually wrote rather than assuming its name.
display(models.MolAtomBondMPNN.load_from_checkpoint(checkpointing.last_model_path))

# Portable cleanup (no shell `rm`); tolerant of files that are already gone.
shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 227 K  | train
1 | agg             | NormAggregation       | 0      | train
2 | mol_predictor   | RegressionFFN         | 90.9 K | train
3 | atom_predictor  | RegressionFFN         | 90.9 K | train
4 | bns             | ModuleList            | 0      | train
5 | X_d_transform   | Identity              | 0      | train
6 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
409 K     Trainable params
0         Non-trainable params
409 K     Total params
1.638     Total estimated model params size (MB)
42        Modules in train mode
0         Modules in eval mode


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 21.21it/s, mol_train_loss_step=1.35e+3, atom_train_loss_step=48.10, train_loss_step=1.39e+3, mol_val_loss=653.0, atom_val_loss=46.50, val_loss=667.0, mol_train_loss_epoch=653.0, atom_train_loss_epoch=48.40, train_loss_epoch=733.0]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 16.63it/s, mol_train_loss_step=1.35e+3, atom_train_loss_step=48.10, train_loss_step=1.39e+3, mol_val_loss=653.0, atom_val_loss=46.50, val_loss=667.0, mol_train_loss_epoch=653.0, atom_train_loss_epoch=48.40, train_loss_epoch=733.0]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=667.20.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=667.20.ckpt


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 74.50it/s] 


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 170.50it/s]


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (

In [5]:
# Mol, None, Bond

mp = nn.MABBondMessagePassing()
agg = nn.NormAggregation()
mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[1])
# `input_dim` is twice `mp.output_dims[1]` (presumably two edge embeddings concatenated per bond).
bond_predictor = nn.RegressionFFN(input_dim=(mp.output_dims[1] * 2), n_tasks=bonds_ys[0].shape[1])
model = models.MolAtomBondMPNN(
    message_passing=mp, agg=agg, mol_predictor=mol_predictor, bond_predictor=bond_predictor
)
display(model)

# Start from an empty checkpoint directory so files left behind by an earlier cell that
# failed before its cleanup (possibly a different architecture) cannot interfere.
shutil.rmtree("checkpoints", ignore_errors=True)
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# `predss` holds one (mol, atom, bond) tuple per batch: regroup by output and concatenate,
# keeping None for outputs the model does not produce.
mol_preds, _, bond_preds = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
# One prediction row per molecule / bond and one column per task.
assert mol_preds.shape == (len(mol_ys), mol_ys.shape[1])
assert bond_preds.shape == (sum(len(bond_y) for bond_y in bonds_ys), bonds_ys[0].shape[1])

models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
# Reload the file the callback actually wrote rather than assuming its name.
display(models.MolAtomBondMPNN.load_from_checkpoint(checkpointing.last_model_path))

# Portable cleanup (no shell `rm`); tolerant of files that are already gone.
shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Lin

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 322 K  | train
1 | agg             | NormAggregation       | 0      | train
2 | mol_predictor   | RegressionFFN         | 90.9 K | train
3 | bond_predictor  | RegressionFFN         | 180 K  | train
4 | bns             | ModuleList            | 0      | train
5 | X_d_transform   | Identity              | 0      | train
6 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
594 K     Trainable params
0         Non-trainable params
594 K     Total params
2.376     Total estimated model params size (MB)
43        Modules in train mode
0         Modules in eval mode


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 25.14it/s, mol_train_loss_step=1.39e+3, bond_train_loss_step=102.0, train_loss_step=1.49e+3, mol_val_loss=650.0, bond_val_loss=86.70, val_loss=691.0, mol_train_loss_epoch=651.0, bond_train_loss_epoch=89.00, train_loss_epoch=717.0]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 20.11it/s, mol_train_loss_step=1.39e+3, bond_train_loss_step=102.0, train_loss_step=1.49e+3, mol_val_loss=650.0, bond_val_loss=86.70, val_loss=691.0, mol_train_loss_epoch=651.0, bond_train_loss_epoch=89.00, train_loss_epoch=717.0]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=690.71.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=690.71.ckpt


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 91.64it/s] 


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 165.91it/s]


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Lin

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Lin

In [6]:
# None, Atom, Bond

# No molecule predictor, so no aggregation module is passed.
mp = nn.MABBondMessagePassing()
atom_predictor = nn.RegressionFFN(n_tasks=atoms_ys[0].shape[1])
# `input_dim` is twice `mp.output_dims[1]` (presumably two edge embeddings concatenated per bond).
bond_predictor = nn.RegressionFFN(input_dim=(mp.output_dims[1] * 2), n_tasks=bonds_ys[0].shape[1])
model = models.MolAtomBondMPNN(
    message_passing=mp, atom_predictor=atom_predictor, bond_predictor=bond_predictor
)
display(model)

# Start from an empty checkpoint directory so files left behind by an earlier cell that
# failed before its cleanup (possibly a different architecture) cannot interfere.
shutil.rmtree("checkpoints", ignore_errors=True)
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# `predss` holds one (mol, atom, bond) tuple per batch: regroup by output and concatenate,
# keeping None for outputs the model does not produce.
_, atom_preds, bond_preds = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
# One prediction row per atom / bond and one column per task.
assert atom_preds.shape == (sum(len(atom_y) for atom_y in atoms_ys), atoms_ys[0].shape[1])
assert bond_preds.shape == (sum(len(bond_y) for bond_y in bonds_ys), bonds_ys[0].shape[1])

models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
# Reload the file the callback actually wrote rather than assuming its name.
display(models.MolAtomBondMPNN.load_from_checkpoint(checkpointing.last_model_path))

# Portable cleanup (no shell `rm`); tolerant of files that are already gone.
shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_f

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 322 K  | train
1 | atom_predictor  | RegressionFFN         | 90.9 K | train
2 | bond_predictor  | RegressionFFN         | 180 K  | train
3 | bns             | ModuleList            | 0      | train
4 | X_d_transform   | Identity              | 0      | train
5 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
594 K     Trainable params
0         Non-trainable params
594 K     Total params
2.376     Total estimated model params size (MB)
42        Modules in train mode
0         Modules in eval mode


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 13.91it/s, atom_train_loss_step=62.60, bond_train_loss_step=110.0, train_loss_step=172.0, atom_val_loss=49.70, bond_val_loss=91.10, val_loss=111.0, atom_train_loss_epoch=51.30, bond_train_loss_epoch=94.00, train_loss_epoch=137.0]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00,  9.84it/s, atom_train_loss_step=62.60, bond_train_loss_step=110.0, train_loss_step=172.0, atom_val_loss=49.70, bond_val_loss=91.10, val_loss=111.0, atom_train_loss_epoch=51.30, bond_train_loss_epoch=94.00, train_loss_epoch=137.0]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=110.92.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=110.92.ckpt


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 60.47it/s]


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 102.33it/s]


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_f

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_f

In [7]:
# Mol, None, None

# No bond predictor, so edge embeddings are not requested from message passing.
mp = nn.MABBondMessagePassing(return_edge_embeddings=False)
agg = nn.NormAggregation()
mol_predictor = nn.RegressionFFN(n_tasks=mol_ys.shape[1])
model = models.MolAtomBondMPNN(message_passing=mp, agg=agg, mol_predictor=mol_predictor)
display(model)

# Start from an empty checkpoint directory so files left behind by an earlier cell that
# failed before its cleanup (possibly a different architecture) cannot interfere.
shutil.rmtree("checkpoints", ignore_errors=True)
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# `predss` holds one (mol, atom, bond) tuple per batch: regroup by output and concatenate,
# keeping None for outputs the model does not produce.
mol_preds, _, _ = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
# One prediction row per molecule and one column per task.
assert mol_preds.shape == (len(mol_ys), mol_ys.shape[1])

models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
# Reload the file the callback actually wrote rather than assuming its name.
display(models.MolAtomBondMPNN.load_from_checkpoint(checkpointing.last_model_path))

# Portable cleanup (no shell `rm`); tolerant of files that are already gone.
shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0): ModuleList(
      (0): MSE(task_weights

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 227 K  | train
1 | agg             | NormAggregation       | 0      | train
2 | mol_predictor   | RegressionFFN         | 90.9 K | train
3 | bns             | ModuleList            | 0      | train
4 | X_d_transform   | Identity              | 0      | train
5 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
318 K     Trainable params
0         Non-trainable params
318 K     Total params
1.274     Total estimated model params size (MB)
29        Modules in train mode
0         Modules in eval mode


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 13.88it/s, mol_train_loss_step=434.0, train_loss_step=434.0, mol_val_loss=651.0, val_loss=625.0, mol_train_loss_epoch=651.0, train_loss_epoch=677.0]    

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 11.45it/s, mol_train_loss_step=434.0, train_loss_step=434.0, mol_val_loss=651.0, val_loss=625.0, mol_train_loss_epoch=651.0, train_loss_epoch=677.0]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=625.36.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=625.36.ckpt


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 87.99it/s] 


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 44.02it/s]


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0): ModuleList(
      (0): MSE(task_weights

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): NormAggregation()
  (mol_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0): ModuleList(
      (0): MSE(task_weights

In [8]:
# Predictors (mol, atom, bond) = (None, Atom, None): train an atom-level head only.

# Only atom (vertex) embeddings are consumed, so edge embeddings are not returned.
mp = nn.MABBondMessagePassing(return_edge_embeddings=False)
atom_predictor = nn.RegressionFFN(n_tasks=atoms_ys[0].shape[1])
model = models.MolAtomBondMPNN(message_passing=mp, atom_predictor=atom_predictor)
display(model)

checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
# With no model passed, `test` restores the best checkpoint (see the log);
# `predict` is given the model explicitly and uses its in-memory weights.
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# Each batch yields a (mol, atom, bond) tuple; concatenate each kind across batches,
# keeping None for heads this model does not have.
_, atom_preds, _ = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
assert atom_preds.shape == (30, 2)

# Round-trip the trained model through both serialization formats.
models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
display(models.MolAtomBondMPNN.load_from_checkpoint("checkpoints/last.ckpt"))

# Clean up in pure Python instead of `!rm`, so the cell also runs on Windows and
# does not fail when an artifact is already missing.
import shutil

shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0): None
    (1): ModuleList(
      (0): MSE(task_weights=[[1.0]])
  

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 227 K  | train
1 | atom_predictor  | RegressionFFN         | 90.9 K | train
2 | bns             | ModuleList            | 0      | train
3 | X_d_transform   | Identity              | 0      | train
4 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
318 K     Trainable params
0         Non-trainable params
318 K     Total params
1.274     Total estimated model params size (MB)
28        Modules in train mode
0         Modules in eval mode


Epoch 3: 100%|██████████| 3/3 [00:00<00:00, 11.75it/s, atom_train_loss_step=50.50, train_loss_step=50.50, atom_val_loss=49.00, val_loss=42.10, atom_train_loss_epoch=50.70, train_loss_epoch=50.90]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:00<00:00,  9.78it/s, atom_train_loss_step=50.50, train_loss_step=50.50, atom_val_loss=49.00, val_loss=42.10, atom_train_loss_epoch=50.70, train_loss_epoch=50.90]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=42.06.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=42.06.ckpt


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 68.98it/s]


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 87.09it/s] 


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0): None
    (1): ModuleList(
      (0): MSE(task_weights=[[1.0]])
  

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_vo): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (atom_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0): None
    (1): ModuleList(
      (0): MSE(task_weights=[[1.0]])
  

In [9]:
# Predictors (mol, atom, bond) = (None, None, Bond): train a bond-level head only.

# Only edge embeddings are consumed, so vertex embeddings are not returned.
mp = nn.MABBondMessagePassing(return_vertex_embeddings=False)
# The bond head's input is twice the edge-embedding width (600 in the repr below);
# presumably the two directed-edge embeddings of each bond are concatenated — confirm.
bond_predictor = nn.RegressionFFN(input_dim=(mp.output_dims[1] * 2), n_tasks=bonds_ys[0].shape[1])
model = models.MolAtomBondMPNN(message_passing=mp, bond_predictor=bond_predictor)
display(model)

checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    max_epochs=4,
    callbacks=[checkpointing],
)

trainer.fit(model, train_dataloader, val_dataloader)
# With no model passed, `test` restores the best checkpoint (see the log);
# `predict` is given the model explicitly and uses its in-memory weights.
results = trainer.test(dataloaders=test_dataloader)
predss = trainer.predict(model, predict_dataloader)
# Each batch yields a (mol, atom, bond) tuple; concatenate each kind across batches,
# keeping None for heads this model does not have.
_, _, bond_preds = (
    torch.concat(tensors) if tensors[0] is not None else None for tensors in zip(*predss)
)
assert bond_preds.shape == (21, 2)

# Round-trip the trained model through both serialization formats.
models.utils.save_model("temp.pt", model)
display(models.MolAtomBondMPNN.load_from_file("temp.pt"))
display(models.MolAtomBondMPNN.load_from_checkpoint("checkpoints/last.ckpt"))

# Clean up in pure Python instead of `!rm`, so the cell also runs on Windows and
# does not fail when an artifact is already missing.
import shutil

shutil.rmtree("checkpoints", ignore_errors=True)
Path("temp.pt").unlink(missing_ok=True)

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0-1): 2 x None
    (2): ModuleList(
      (0): MSE(task_weights=[[1.0

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                  | Params | Mode 
------------------------------------------------------------------
0 | message_passing | MABBondMessagePassing | 210 K  | train
1 | bond_predictor  | RegressionFFN         | 180 K  | train
2 | bns             | ModuleList            | 0      | train
3 | X_d_transform   | Identity              | 0      | train
4 | metricss        | ModuleList            | 0      | train
------------------------------------------------------------------
391 K     Trainable params
0         Non-trainable params
391 K     Total params
1.565     Total estimated model params size (MB)
28        Modules in train mode
0         Modules in eval mode


Epoch 3: 100%|██████████| 3/3 [00:02<00:00,  1.41it/s, bond_train_loss_step=87.70, train_loss_step=87.70, bond_val_loss=84.30, val_loss=63.60, bond_train_loss_epoch=89.30, train_loss_epoch=90.20]

`Trainer.fit` stopped: `max_epochs=4` reached.


Epoch 3: 100%|██████████| 3/3 [00:02<00:00,  1.37it/s, bond_train_loss_step=87.70, train_loss_step=87.70, bond_val_loss=84.30, val_loss=63.60, bond_train_loss_epoch=89.30, train_loss_epoch=90.20]


Restoring states from the checkpoint path at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=63.59.ckpt
Loaded model weights from the checkpoint at /home/knathan/chemprop/molatombond_notebooks/checkpoints/best-epoch=3-val_loss=63.59.ckpt


Testing DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 49.45it/s]


Predicting DataLoader 0: 100%|██████████| 3/3 [00:00<00:00, 32.55it/s]


MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0-1): 2 x None
    (2): ModuleList(
      (0): MSE(task_weights=[[1.0

MolAtomBondMPNN(
  (message_passing): MABBondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_eo): Linear(in_features=314, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (E_d_transform): Identity()
    (graph_transform): Identity()
  )
  (bond_predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=2, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0, 1.0]])
    (output_transform): Identity()
  )
  (bns): ModuleList(
    (0-2): 3 x Identity()
  )
  (X_d_transform): Identity()
  (metricss): ModuleList(
    (0-1): 2 x None
    (2): ModuleList(
      (0): MSE(task_weights=[[1.0

In [10]:
# Predictors (mol, atom, bond) = (None, None, None): building a model with no
# prediction head is expected to raise a ValueError.

mp = nn.MABBondMessagePassing()
try:
    models.MolAtomBondMPNN(message_passing=mp)
except ValueError:
    print("Caught expected ValueError: No predictors provided.")
else:
    # Fail loudly rather than silently passing if the constructor stops validating.
    raise AssertionError("Expected a ValueError when no predictors are provided.")

Caught expected ValueError: No predictors provided.
