# `CheMeleon` Foundation Finetuning

This notebook demonstrates how to use the `CheMeleon` foundation model with Chemprop to achieve accurate prediction on small datasets.
The same functionality is available from the command-line interface via `--from-foundation chemeleon`.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/chemeleon_foundation_finetuning.ipynb)

In [1]:
# Install chemprop from GitHub if running in Google Colab
import os

if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        %cd examples

## Retrieving the `CheMeleon` Model

The `CheMeleon` model file is stored on Zenodo at [this link](https://zenodo.org/records/15426601).
Please cite the Zenodo record if you use this model in published work.
You can download the file manually, or execute the cell below to download it programmatically with Python:

In [2]:
from urllib.request import urlretrieve

urlretrieve(
    r"https://zenodo.org/records/15460715/files/chemeleon_mp.pt",
    "chemeleon_mp.pt",
)

('chemeleon_mp.pt', <http.client.HTTPMessage at 0x751223631d60>)

## Initializing `CheMeleon`

`CheMeleon` uses the following classes for featurization, message passing, and aggregation:

In [3]:
import torch

from chemprop import featurizers, nn

featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
agg = nn.MeanAggregation()
chemeleon_mp = torch.load("chemeleon_mp.pt", weights_only=True)
mp = nn.BondMessagePassing(**chemeleon_mp['hyper_parameters'])
mp.load_state_dict(chemeleon_mp['state_dict'])

<All keys matched successfully>

If you have existing Chemprop training code, you can replace your `agg`, `featurizer`, and `mp` with the objects above and immediately take advantage of `CheMeleon`!

In general, we suggest continuing to train the `CheMeleon` weights during finetuning.
In some cases, however, freezing the weights (`mp.eval()`, `mp.apply(lambda module: module.requires_grad_(False))`) may improve performance.
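To illustrate what freezing does, here is a minimal sketch in plain PyTorch. The `encoder` and `head` modules are small hypothetical stand-ins for `mp` and the FFN, not the real `CheMeleon` weights; the two freezing calls are the ones mentioned in this notebook.

```python
import torch
from torch import nn

# Hypothetical stand-ins: `encoder` plays the role of the pretrained `mp`,
# `head` plays the role of the task-specific FFN.
encoder = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
head = nn.Linear(16, 1)

# Freeze the encoder: eval mode plus no gradient tracking on any submodule
encoder.eval()
encoder.apply(lambda module: module.requires_grad_(False))

# One backward pass: only the head should receive gradients
loss = head(encoder(torch.randn(4, 8))).pow(2).mean()
loss.backward()

frozen_count = sum(p.numel() for p in encoder.parameters() if not p.requires_grad)
trainable_count = sum(p.numel() for p in head.parameters() if p.requires_grad)
print(f"frozen: {frozen_count}, trainable: {trainable_count}")
```

After the backward pass, the encoder parameters have no gradients, so an optimizer step would only update the head.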

## Standard Chemprop Preparation

The code below handles importing the needed modules, setting up the data, and initializing the Chemprop model.
It is **mostly** the same as the `training` example provided in the Chemprop repository; for a more detailed breakdown, [see the docs here](https://chemprop.readthedocs.io/en/latest/training.html).

The one important change is that we must set `input_dim=mp.output_dim` when we initialize our FFN.
This ensures that the dimension of the learned representation from `CheMeleon` matches the input size for the regressor.
Note also that `CheMeleon` only supplies the learned representation: to make it useful, you set up your own FFN to regress the target you care about, in this case lipophilicity.
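The need for matching dimensions can be sketched with plain PyTorch layers. This is an illustration only: the random tensor is a hypothetical stand-in for a batch of `CheMeleon` representations, the width of 2048 is taken from the model printout later in this notebook, and the width of 300 for the mismatched head is an arbitrary smaller value.

```python
import torch
from torch import nn

embedding_dim = 2048  # width of the message passing output printed in this notebook
embeddings = torch.randn(4, embedding_dim)  # hypothetical batch of 4 molecule embeddings

matched_head = nn.Linear(embedding_dim, 1)  # input size equals the embedding width
mismatched_head = nn.Linear(300, 1)  # arbitrary smaller input size

predictions = matched_head(embeddings)  # one prediction per molecule

try:
    mismatched_head(embeddings)
    mismatch_raised = False
except RuntimeError:
    # PyTorch refuses to multiply a (4, 2048) batch by a (300, 1) weight
    mismatch_raised = True
```

Passing `input_dim=mp.output_dim` to the FFN avoids this kind of shape error without hard-coding the width.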

You can also use `CheMeleon` for classification tasks.
See the [regular classification demo notebook](https://chemprop.readthedocs.io/en/latest/training_classification.html), which can be modified as shown above to load `CheMeleon`.

In [4]:
from pathlib import Path

from lightning import pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import pandas as pd

from chemprop import data, models

chemprop_dir = Path.cwd().parent
input_path = chemprop_dir / "tests" / "data" / "regression" / "mol" / "mol.csv" # path to your data .csv file
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
smiles_column = 'smiles' # name of the column containing SMILES strings
target_columns = ['lipo'] # list of names of the columns containing targets
df_input = pd.read_csv(input_path)
smis = df_input.loc[:, smiles_column].values
ys = df_input.loc[:, target_columns].values
all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))  # unpack the tuple into three separate lists
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)
train_dset = data.MoleculeDataset(train_data[0], featurizer)
scaler = train_dset.normalize_targets()
val_dset = data.MoleculeDataset(val_data[0], featurizer)
val_dset.normalize_targets(scaler)
test_dset = data.MoleculeDataset(test_data[0], featurizer)
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)
output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)
ffn = nn.RegressionFFN(output_transform=output_transform, input_dim=mp.output_dim)
metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()]
mpnn = models.MPNN(mp, agg, ffn, batch_norm=False, metrics=metric_list)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


Now we can inspect the model, which includes the large message passing block from `CheMeleon`:

In [5]:
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=2048, bias=False)
    (W_h): Linear(in_features=2048, out_features=2048, bias=False)
    (W_o): Linear(in_features=2120, out_features=2048, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): MeanAggregation()
  (bn): Identity()
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=2048, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): RMSE(task_weights=[[1.0]])
    (1): MAE(task_weights=[[1.0]])
    (2): MSE(task_weights=[[1.0]])
  )
)
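As a quick check on the size of the model, the parameter counts can be worked out from the layer shapes in the printout above. The sketch below is plain arithmetic on those printed shapes; the helper `linear_params` is a hypothetical name introduced here for illustration.

```python
def linear_params(in_features: int, out_features: int, bias: bool) -> int:
    """Number of parameters in a dense layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


# (in_features, out_features, bias) copied from the printed model
message_passing_layers = {
    "W_i": (86, 2048, False),
    "W_h": (2048, 2048, False),
    "W_o": (2120, 2048, True),
}
predictor_layers = {
    "ffn.0": (2048, 300, True),
    "ffn.1": (300, 1, True),
}

mp_total = sum(linear_params(*shape) for shape in message_passing_layers.values())
ffn_total = sum(linear_params(*shape) for shape in predictor_layers.values())

print(f"message passing: {mp_total:,}")  # message passing: 8,714,240
print(f"predictor:       {ffn_total:,}")  # predictor:       615,001
```

The message passing total of about 8.7 million is consistent with the `8.7 M` that Lightning reports for `message_passing` in the training summary of this notebook.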

## Training

The remainder of this notebook again follows the typical training routine.
With the addition of `CheMeleon`, your model may take longer to complete a single epoch because of the increased number of parameters. In return it will (hopefully!) perform better, particularly if your dataset is small, and require fewer epochs to converge!

In [6]:
# Configure model checkpointing
checkpointing = ModelCheckpoint(
    "checkpoints",  # Directory where model checkpoints will be saved
    "best-{epoch}-{val_loss:.2f}",  # Filename format for checkpoints, including epoch and validation loss
    "val_loss",  # Metric used to select the best checkpoint (based on validation loss)
    mode="min",  # Save the checkpoint with the lowest validation loss (minimization objective)
    save_last=True,  # Always save the most recent checkpoint, even if it's not the best
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True, # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=20, # number of epochs to train for
    callbacks=[checkpointing], # Use the configured checkpoint callback
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
trainer.fit(mpnn, train_loader, val_loader)

/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/jackson/chemprop/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 8.7 M  | train
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | Identity           | 0      | train
3 | predictor       | RegressionFFN      | 61

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  9.11it/s, train_loss_step=0.0298, val_loss=0.548, train_loss_epoch=0.0245]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  4.84it/s, train_loss_step=0.0298, val_loss=0.548, train_loss_epoch=0.0245]


In [8]:
results = trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at /home/jackson/chemprop/examples/checkpoints/best-epoch=16-val_loss=0.49.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/jackson/chemprop/examples/checkpoints/best-epoch=16-val_loss=0.49.ckpt
/home/jackson/miniconda3/envs/chemprop_dev/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 87.82it/s] 
