## Saving and loading models

```python
import torch
from chemprop.models.utils import save_model, load_model
from chemprop.models.model import MPNN
from chemprop.models.multi import MulticomponentMPNN
from chemprop import nn
```

This example saves to and loads from an in-memory buffer so that running this notebook doesn't create new files. A real use case would probably save to and load from a file such as `model.pt`.
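The buffer needs rewinding before it is read because writing to an `io.BytesIO` leaves its position at the end of the written data, so an immediate read would find nothing left to read. Here is a minimal sketch using plain `torch.save`/`torch.load` on a toy dictionary. The dictionary contents are arbitrary placeholders, not a Chemprop model.

```python
import io

import torch

buffer = io.BytesIO()
torch.save({"weights": torch.ones(3)}, buffer)

# After writing, the stream position sits at the end of the data.
end_position = buffer.tell()

# Rewind before reading; a file path is opened fresh and needs no rewind.
buffer.seek(0)
restored = torch.load(buffer)
print(end_position > 0, restored["weights"].tolist())  # prints: True [1.0, 1.0, 1.0]
```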

```python
import io

saved_model = io.BytesIO()

# from pathlib import Path
# saved_model = Path("model.pt")
```

### Saving models

A valid model save file is a dictionary containing the model's hyperparameters and its state dict. `torch.save` is used to serialize (pickle) the dictionary.
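The following sketch makes the file layout concrete by writing the same two-key dictionary for a plain `torch.nn.Linear` layer instead of an `MPNN`. The `hparams` dictionary is a hypothetical stand-in for `model.hparams`. In Chemprop the hyperparameters come from the model itself.

```python
import io

import torch

layer = torch.nn.Linear(4, 1)
hparams = {"in_features": 4, "out_features": 1}  # hypothetical stand-in for model.hparams

buffer = io.BytesIO()
torch.save({"hyper_parameters": hparams, "state_dict": layer.state_dict()}, buffer)

# Read the dictionary back and inspect its structure.
buffer.seek(0)
model_dict = torch.load(buffer)
print(sorted(model_dict))                # prints: ['hyper_parameters', 'state_dict']
print(sorted(model_dict["state_dict"]))  # prints: ['bias', 'weight']
```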

```python
model = MPNN(nn.BondMessagePassing(), nn.MeanAggregation(), nn.RegressionFFN())

save_model(saved_model, model)

# Equivalent manual approach:
# model_dict = {"hyper_parameters": model.hparams, "state_dict": model.state_dict()}
# torch.save(model_dict, saved_model)
```

By default, `lightning` also creates checkpoint files automatically during training. These `.ckpt` files are like `.pt` model files, but they also store training state (such as the current epoch and the optimizer state), so they can be used to resume training. See the `lightning` documentation for more details.
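You can see the difference between the two file types by looking at their top-level keys. The sketch below assumes Lightning's checkpoint layout, where training progress is stored under keys such as `epoch`, `global_step`, `optimizer_states`, and `lr_schedulers`. A Chemprop `.pt` file holds only `hyper_parameters` and `state_dict`. The helper name `training_state_keys` is hypothetical, and `weights_only=False` assumes PyTorch 1.13 or newer.

```python
import io

import torch

# Keys that Lightning writes to .ckpt files to record training progress.
TRAINING_KEYS = ("epoch", "global_step", "optimizer_states", "lr_schedulers")


def training_state_keys(source) -> list:
    """Return the training-state keys found in a saved file or buffer (hypothetical helper)."""
    # weights_only=False may be needed for real Lightning checkpoints; only load trusted files.
    saved = torch.load(source, map_location="cpu", weights_only=False)
    return [key for key in TRAINING_KEYS if key in saved]
```

For example, `training_state_keys("mycheckpoints/last.ckpt")` would list the training keys in the checkpoint written by `save_last=True`. A `.pt` file saved with `save_model` should give an empty list.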

```python
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch import Trainer

checkpointing = ModelCheckpoint(
    dirpath="mycheckpoints",
    filename="best-{epoch}-{val_loss:.2f}",
    monitor="val_loss",
    mode="min",
    save_last=True,
)
trainer = Trainer(callbacks=[checkpointing])
```


### Loading models

`MPNN` and `MulticomponentMPNN` each have class methods to load a model from a model file (`.pt`) or a checkpoint file (`.ckpt`). `load_from_file` works with either kind of file, but it won't load the saved training information from a checkpoint file.
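One loader can accept both formats because both files contain `hyper_parameters` and `state_dict`. A reader that uses only those two keys works on either file and simply ignores any training state. The following is a minimal sketch of that idea for a plain `torch.nn.Linear`, assuming hyperparameters that map directly onto the constructor's arguments. It illustrates the concept and is not Chemprop's `load_from_file` implementation. `load_linear` is a hypothetical name.

```python
import io

import torch


def load_linear(source) -> torch.nn.Linear:
    """Rebuild a Linear layer from any file holding hyper_parameters and state_dict."""
    saved = torch.load(source, map_location="cpu")
    model = torch.nn.Linear(**saved["hyper_parameters"])
    model.load_state_dict(saved["state_dict"])
    # Any other keys (e.g. optimizer_states in a checkpoint) are ignored.
    return model


original = torch.nn.Linear(4, 1)
contents = {
    "hyper_parameters": {"in_features": 4, "out_features": 1},
    "state_dict": original.state_dict(),
}

# A model-style file and a checkpoint-style file with extra training state load the same way.
for extra in ({}, {"epoch": 3, "optimizer_states": []}):
    buffer = io.BytesIO()
    torch.save({**contents, **extra}, buffer)
    buffer.seek(0)
    restored = load_linear(buffer)
    assert torch.equal(restored.weight, original.weight)
```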

```python
# Rewind the buffer to the beginning before reading; not necessary when loading from a file path
saved_model.seek(0)

model = MPNN.load_from_file(saved_model)

# Other options
# model = MPNN.load_from_checkpoint(saved_model)
# model = MulticomponentMPNN.load_from_file(saved_model)
# model = MulticomponentMPNN.load_from_checkpoint(saved_model)
```