# Transfer Learning / Pretraining
Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training, as demonstrated in this notebook.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/transfer_learning.ipynb)

# Import packages

In [171]:
import os
import sys

# Put the parent of the working directory on the import path so that packages
# located there (presumably the local `chemprop` checkout) can be imported.
current_path = os.getcwd()
parent_path = os.path.dirname(current_path)

# Show both locations, working directory first.
for shown_path in (current_path, parent_path):
    print(shown_path)

# Append only once so that re-running the cell does not add duplicates.
if parent_path not in sys.path:
    sys.path.append(parent_path)

/home/labhhc2/Documents/workspace/D20/Tam/backup/chemprop_1/examples
/home/labhhc2/Documents/workspace/D20/Tam/backup/chemprop_1


In [172]:
import pandas as pd
from pathlib import Path

from lightning import pytorch as pl
from sklearn.preprocessing import StandardScaler
import torch

from chemprop import data, featurizers, models, nn

# Change data inputs here

In [173]:
import numpy as np
# Parent of the current working directory; the data and checkpoint paths used
# in this notebook are built relative to it.
chemprop_dir = Path.cwd().parent
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
# The two settings below are disabled: the data is read from pre-processed
# .npz arrays, so no SMILES/target column names are needed.
# smiles_column = 'smiles' # name of the column containing SMILES strings
# target_columns = ['lipo'] # list of names of the columns containing targets

## Load data

In [174]:
# Directories for fold 4 of the rad6re reaction data: the CSV files and the
# pre-processed .npz arrays. Both are anchored at `chemprop_dir` so that
# changing that single setting relocates every input consistently.
csv_dir = chemprop_dir / "tests" / "data" / "regression" / "rxn" / "rad6re" / "fold_4"
npz_dir = chemprop_dir / "chemprop" / "data" / "super" / "rad6re" / "fold_4"


def _load_processed_split(split):
    """Load the pre-processed arrays for one data split.

    Args:
        split: Split name used in the file name, e.g. ``"train"``.

    Returns:
        A tuple ``(npz, node_attrs, edge_attrs, edge_indices, ys)`` where
        ``npz`` is the object returned by ``np.load`` and the remaining items
        are the arrays stored under those keys.
    """
    # NOTE(review): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code -- only load files from a trusted source.
    npz = np.load(npz_dir / f"rad6re_aam_{split}_processed_data.npz", allow_pickle=True)
    return npz, npz["node_attrs"], npz["edge_attrs"], npz["edge_indices"], npz["ys"]


# The *_path variables point at the CSV files of each split; the loading code
# in this cell reads only the .npz files.
train_path = csv_dir / "aam_train.csv"
train_npz, train_v, train_e, train_idx_g, train_y = _load_processed_split("train")

val_path = csv_dir / "aam_val.csv"
val_npz, val_v, val_e, val_idx_g, val_y = _load_processed_split("val")

test_path = csv_dir / "aam_test.csv"
test_npz, test_v, test_e, test_idx_g, test_y = _load_processed_split("test")

## Build reaction datasets

In [175]:
# Wrap the training arrays in a dataset. Element [0][3] is printed before and
# after normalization to show the effect of normalization on it.
# NOTE(review): index 3 presumably selects the target field -- confirm against ReactionDataset.
train_dset = data.ReactionDataset(train_v, train_e, train_idx_g, train_y)
print(train_dset[0][3])
# Called without arguments, normalize_targets() returns the scaler that is
# reused for the validation set below.
scaler = train_dset.normalize_targets()
# print(scaler)
print(train_dset[0][3])

# Validation targets are normalized with the scaler obtained from the training set.
val_dset = data.ReactionDataset(val_v, val_e, val_idx_g, val_y)
val_dset.normalize_targets(scaler)
# The test targets are not normalized here.
test_dset = data.ReactionDataset(test_v, test_e, test_idx_g, test_y)

[-1.17180312]
[[-1.17180312]
 [-1.14678741]
 [ 0.78115952]
 ...
 [-2.7202096 ]
 [-0.77980316]
 [-3.99633527]]
[[-0.29984787]
 [-0.29344671]
 [ 0.19988769]
 ...
 [-0.69606323]
 [-0.19954062]
 [-1.02260577]]
[-0.29984787]
[[-0.52384804]
 [ 0.45034463]
 [ 0.44598369]
 ...
 [-1.14828404]
 [-1.26933186]
 [-1.22618143]]


## Perform data splitting for training, validation, and testing

# Change checkpoint model inputs here
Both message-passing neural networks (MPNNs) and multi-component MPNNs can have their weights initialized from a checkpoint file.

In [176]:
# Parent of the current working directory, recomputed here so this cell can be
# edited and run on its own.
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir / "examples" / "checkpoints" / "epoch=99-step=79900-v8.ckpt" # path to the checkpoint file.
# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`.

In [177]:
# Model class whose loader is used to restore the single-component checkpoint.
mpnn_cls = models.MPNN

In [178]:
# Restore the pre-trained model from the checkpoint file.
mpnn = mpnn_cls.load_from_file(checkpoint_path)
mpnn  # last expression of the cell: displays the module structure

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=27, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=321, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): Identity()
  )
  (agg): SumAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metrics): ModuleList(
    (0): RMSE(task_weights=[[1.0]])
    (1): MAE(task_weight

# Scale fine-tuning data with the model's target scaler

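If the pre-trained model was a regression model, it was probably trained on a scaled dataset. The scaler is saved as part of the model and used during prediction. For further training, we need to scale the fine-tuning data with the same target scaler. Note that in this notebook the two cells below are commented out; the targets were instead normalized with a scaler fitted on the fine-tuning training set when the datasets were built.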

In [179]:
# pretraining_scaler = StandardScaler()
# pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.numpy()
# pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.numpy()

## Get MoleculeDataset

In [180]:

# train_dset.normalize_targets(pretraining_scaler)

# val_dset.normalize_targets(pretraining_scaler)

# # test_dset = data.MoleculeDataset(test_dset[0], featurizer)


## Get DataLoader

In [181]:
# Build one dataloader per split. Shuffling is explicitly disabled for the
# validation and test loaders; the training loader keeps build_dataloader's default.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)
test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)

# Freezing MPNN and FFN layers
Certain layers of a pre-trained model can be kept unchanged during further training on a new task.

## Freezing the MPNN

In [182]:
# Freeze the message-passing encoder: disable gradients for all of its
# parameters (requires_grad_ covers the module and its submodules) and switch
# it to eval mode.
mpnn.message_passing.requires_grad_(False)
mpnn.message_passing.eval()

# Freeze the batch norm layer the same way. Eval mode additionally stops the
# running mean and running variance from being updated.
mpnn.bn.requires_grad_(False)
mpnn.bn.eval()

BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

## Freezing FFN layers

In [183]:
# Layers are counted from the start of `mpnn.predictor.ffn`: the loop in the
# next cell freezes ffn[0] up to ffn[frzn_ffn_layers - 1].
frzn_ffn_layers = 1  # the number of consecutive FFN layers to freeze.

In [184]:
# Freeze the first `frzn_ffn_layers` blocks of the predictor's FFN.
ffn_blocks = mpnn.predictor.ffn
for layer_idx in range(frzn_ffn_layers):
    # Stop gradient updates for the parameters of this block.
    ffn_blocks[layer_idx].requires_grad_(False)
    # In the model printed above, the activation and dropout that follow a
    # block's Linear layer belong to the *next* block, so that block is the
    # one switched to eval mode.
    ffn_blocks[layer_idx + 1].eval()

# Set up trainer

In [185]:
# Options for the Lightning trainer, collected in one place for easy editing.
trainer_options = {
    "logger": False,
    # Use `True` if you want to save model checkpoints. The checkpoints will
    # be saved in the `checkpoints` folder.
    "enable_checkpointing": True,
    "enable_progress_bar": True,
    "accelerator": "auto",
    "devices": 1,
    # Number of epochs to train for.
    "max_epochs": 100,
}
trainer = pl.Trainer(**trainer_options)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# Start training

In [186]:
# Continue training the loaded model: `train_loader` supplies the training
# batches and `val_loader` the validation batches.
trainer.fit(mpnn, train_loader, val_loader)

/home/labhhc2/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/labhhc2/Documents/workspace/D20/Tam/backup/chemprop_1/examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
/home/labhhc2/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 194 K  | eval 
1 | agg             | SumAggregation     | 0      | train
2 | bn              | BatchNorm1d        | 600    | eval 
3 | predictor     

                                                                            

/home/labhhc2/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Epoch 99: 100%|██████████| 799/799 [00:34<00:00, 23.09it/s, train_loss_step=0.000937, val_loss=0.00498, train_loss_epoch=0.000457] 

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 799/799 [00:34<00:00, 23.07it/s, train_loss_step=0.000937, val_loss=0.00498, train_loss_epoch=0.000457]


# Test results

In [187]:
# Evaluate the fine-tuned model on the test loader; the returned value is kept in `results`.
results = trainer.test(mpnn, test_loader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/labhhc2/anaconda3/envs/chemprop/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 100/100 [00:00<00:00, 439.30it/s]


# Transfer learning with multicomponent models
Multi-component MPNN models have an individual MPNN block for each molecule they parse in one input. These MPNN blocks can be frozen independently for transfer learning.

## Change data inputs here

In [188]:
# Point at the multi-component (mol+mol) regression checkpoint stored under
# the repository's `tests/data` directory.
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir / "tests" / "data" / "example_model_v2_regression_mol+mol.ckpt"  # path to the checkpoint file.

## Change checkpoint model inputs here

In [189]:
# Switch to the multi-component model class and restore the checkpoint with it.
# NOTE(review): the single-component model earlier in this notebook is restored
# with `load_from_file`, while this cell uses `load_from_checkpoint` -- confirm
# the difference is intended.
mpnn_cls = models.MulticomponentMPNN
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
mcmpnn  # last expression of the cell: displays the module structure

MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): GraphTransform(
          (V_transform): Identity()
          (E_transform): Identity()
        )
      )
    )
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
  

In [190]:
# Each index selects one entry of `mcmpnn.message_passing.blocks`.
blocks_to_freeze = [0, 1]  # a list of indices of the individual MPNN blocks to freeze before training.

In [191]:
# Reload the checkpoint so that freezing starts from a freshly restored model.
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)

# Freeze each selected message-passing block: disable gradients for all of its
# parameters (requires_grad_ covers the block and its submodules) and switch
# it to eval mode.
for block_idx in blocks_to_freeze:
    mp_block = mcmpnn.message_passing.blocks[block_idx]
    mp_block.requires_grad_(False)
    mp_block.eval()

# Freeze the batch norm layer as well; eval mode is set last so that the cell
# displays the resulting module.
mcmpnn.bn.requires_grad_(False)
mcmpnn.bn.eval()

BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)