# Transfer Learning / Pretraining
Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training, as demonstrated in this notebook.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/transfer_learning.ipynb)

In [None]:
import os
import sys

# Resolve the working directory and the directory that contains it, then echo
# both so the resolved locations show up in the cell output.
current_path = os.getcwd()
parent_path = os.path.dirname(current_path)
print(current_path)
print(parent_path)

# Put the parent directory on the import path, but only if it is not there yet.
already_importable = parent_path in sys.path
if not already_importable:
    sys.path.append(parent_path)

In [None]:
import pandas as pd
from lightning import pytorch as pl
from sklearn.preprocessing import StandardScaler

from pathlib import Path

from chemprop import data, featurizers, models, nn

# Change data inputs here

In [None]:
# Parent of the current working directory (presumably the root of the chemprop checkout).
chemprop_dir = Path.cwd().parent

# Path to your data .csv file.
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")

# Number of workers for the dataloader; 0 means the main process does the data loading.
num_workers = 0

# Name of the column containing SMILES strings.
smiles_column = "smiles"

# Names of the columns containing targets.
target_columns = ["lipo"]

## Load data

In [None]:
import numpy as np
# Number of workers for the dataloader; 0 means the main process does the data loading.
num_workers = 0
# Parent of the current working directory, used as the base for the data paths.
chemprop_dir = Path.cwd().parent
# Disabled column settings, left here for reference:
# smiles_column = 'AAM'
# target_columns = ['lograte']

In [None]:
# Raw reaction CSV splits. These paths are only defined here; nothing in this
# cell reads them -- the graph arrays come from the preprocessed .npz archives.
rxn_csv_dir = chemprop_dir / "tests" / "data" / "regression" / "rxn" / "barriers_rdb7"
train_path = rxn_csv_dir / "train.csv"
val_path = rxn_csv_dir / "val.csv"
test_path = rxn_csv_dir / "test.csv"

# Directory holding one preprocessed .npz archive per split. Built from
# `chemprop_dir` (the parent of the working directory) so that it follows the
# same base as the CSV paths instead of a separate hard-coded "../" string.
npz_dir = chemprop_dir / "chemprop" / "data" / "normal" / "barriers_rdb7"

# NOTE: allow_pickle=True lets NumPy unpickle object arrays, which can execute
# arbitrary code. Only load archives that come from a trusted source.
# Each archive is opened in a `with` block so its file handle is closed as soon
# as the arrays have been read into memory.
with np.load(npz_dir / "barriers_rdb7_aam_train_processed_data.npz", allow_pickle=True) as train_npz:
    train_v = train_npz['node_attrs']
    train_e = train_npz['edge_attrs']
    train_idx_g = train_npz['edge_indices']
    train_y = train_npz['ys']

with np.load(npz_dir / "barriers_rdb7_aam_val_processed_data.npz", allow_pickle=True) as val_npz:
    val_v = val_npz['node_attrs']
    val_e = val_npz['edge_attrs']
    val_idx_g = val_npz['edge_indices']
    val_y = val_npz['ys']

with np.load(npz_dir / "barriers_rdb7_aam_test_processed_data.npz", allow_pickle=True) as test_npz:
    test_v = test_npz['node_attrs']
    test_e = test_npz['edge_attrs']
    test_idx_g = test_npz['edge_indices']
    test_y = test_npz['ys']

In [None]:
# Shape of the training node-attribute array read from the .npz archive.
print(train_v.shape)

In [None]:
# Shapes of the training edge-index array and of the validation and test target arrays.
print(train_idx_g.shape, val_y.shape, test_y.shape)

## Build reaction datasets

In [None]:
# Wrap the training arrays in a ReactionDataset.
train_dset = data.ReactionDataset(train_v, train_e, train_idx_g, train_y)
# Element 3 of the first datapoint is printed before and after normalization so
# the two can be compared (presumably the target entry -- TODO confirm against
# ReactionDataset).
print(train_dset[0][3])
# normalize_targets() returns the scaler that is reused for the validation set below.
scaler = train_dset.normalize_targets()
# print(scaler)
print(train_dset[0][3])

# Normalize the validation targets with the scaler obtained from the training set.
val_dset = data.ReactionDataset(val_v, val_e, val_idx_g, val_y)
val_dset.normalize_targets(scaler)
# normalize_targets is not called on the test dataset here.
test_dset = data.ReactionDataset(test_v, test_e, test_idx_g, test_y)

# Change checkpoint model inputs here
Both message-passing neural networks (MPNNs) and multi-component MPNNs can have their weights initialized from a checkpoint file.

In [None]:
# Path to the checkpoint file.
# If the checkpoint file is generated using the training notebook, it will be in the
# `checkpoints` folder with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol.ckpt")

In [None]:
# Model class used to load the checkpoint in the next cell.
mpnn_cls = models.MPNN

In [None]:
# Build the model from the checkpoint file. The bare name on the last line makes
# the notebook display the loaded model.
mpnn = mpnn_cls.load_from_file(checkpoint_path)
mpnn

# Scale fine-tuning data with the model's target scaler

If the pre-trained model was a regression model, it was probably trained on a scaled dataset. The scaler is saved as part of the model and used during prediction. For further training, we need to scale the fine-tuning data with the same target scaler.

In [None]:
# Rebuild a StandardScaler by hand from the mean and scale stored on the
# loaded model's output transform, so it can be applied to new targets.
pretraining_scaler = StandardScaler()
output_transform = mpnn.predictor.output_transform
pretraining_scaler.mean_ = output_transform.mean.numpy()
pretraining_scaler.scale_ = output_transform.scale.numpy()

## Get ReactionDataset

In [None]:
# Rebuild the training dataset and normalize its targets with the scaler taken
# from the pre-trained model.
train_dset = data.ReactionDataset(train_v, train_e, train_idx_g, train_y)
train_dset.normalize_targets(pretraining_scaler)


# The validation targets are normalized with the same pre-trained scaler.
val_dset = data.ReactionDataset(val_v, val_e, val_idx_g, val_y)
val_dset.normalize_targets(pretraining_scaler)
# normalize_targets is not called on the test dataset here.
test_dset = data.ReactionDataset(test_v, test_e, test_idx_g, test_y)

In [None]:
# Element 3 of the first training datapoint (presumably the normalized target --
# TODO confirm against ReactionDataset).
train_dset[0][3]

# Look at the last two fields of the first element of the second datapoint;
# the variable names suggest they are the edge index and the reverse-edge index.
edge_index=train_dset[1][0][-2]
print(f'edge_index: {edge_index}')
reverse_index=train_dset[1][0][-1]
print(f'reverse_index: {reverse_index}')

import numpy as np

# Swaps each consecutive pair of indices: [0, 1, 2, 3, 4, 5] -> [1, 0, 3, 2, 5, 4].
# NOTE(review): presumably illustrates how a reverse-edge index pairs edge 2k
# with edge 2k+1 -- confirm against the printed reverse_index above.
np.arange(6).reshape(-1,2)[:, ::-1].ravel()

## Get DataLoader

In [None]:
# The training loader relies on build_dataloader's default shuffle setting.
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
# The validation and test loaders are built the same way, with shuffling turned off.
val_loader, test_loader = (
    data.build_dataloader(eval_dset, num_workers=num_workers, shuffle=False)
    for eval_dset in (val_dset, test_dset)
)

# Freezing MPNN and FFN layers
Certain layers of a pre-trained model can be kept unchanged during further training on a new task.

## Freezing the MPNN

In [None]:
def _disable_grad(module):
    """Turn off gradient tracking for *module* and return the result of the call."""
    return module.requires_grad_(False)


# Stop gradient updates for the message-passing block and put it in eval mode.
mpnn.message_passing.apply(_disable_grad)
mpnn.message_passing.eval()
# Same for batch norm. Set batch norm layers to eval mode to freeze running mean and running var.
mpnn.bn.apply(_disable_grad)
mpnn.bn.eval()

## Freezing FFN layers

In [None]:
# The number of consecutive FFN layers to freeze, counted from index 0 of
# `mpnn.predictor.ffn` (see the loop in the next cell).
frzn_ffn_layers = 1

In [None]:
# Disable gradients for the first `frzn_ffn_layers` entries of the predictor's FFN.
for idx in range(frzn_ffn_layers):
    mpnn.predictor.ffn[idx].requires_grad_(False)
    # NOTE(review): eval() targets entry idx + 1, not idx -- presumably because the
    # dropout applied to a frozen layer's output lives in the following FFN entry.
    # Confirm against the FFN layout (and that entry idx + 1 always exists) before
    # changing this.
    mpnn.predictor.ffn[idx + 1].eval()

# Set up trainer

In [None]:
trainer = pl.Trainer(
    # Number of epochs to train for.
    max_epochs=20,
    accelerator="auto",
    devices=1,
    # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=False,
)

# Start training

In [None]:
# Continue training the loaded (partially frozen) model on the training loader,
# with the validation loader passed alongside it.
trainer.fit(mpnn, train_loader, val_loader)

# Test results

In [None]:
# Run the trainer's test pass on the test loader and keep what it returns.
results = trainer.test(mpnn, test_loader)


# Transfer learning with multicomponent models
Multi-component MPNN models have an individual MPNN block for each molecule they parse in one input. These MPNN blocks can be independently frozen for transfer learning.

## Change data inputs here

In [None]:
# Path to the checkpoint file.
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol+mol.ckpt")

## Change checkpoint model inputs here

In [None]:
# Switch the model class to the multicomponent variant and load it from the
# checkpoint. The bare name on the last line makes the notebook display the model.
mpnn_cls = models.MulticomponentMPNN
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
mcmpnn

In [None]:
# A list of indices of the individual MPNN blocks to freeze before training.
# Each index is used to look up `mcmpnn.message_passing.blocks[i]` in the next cell.
blocks_to_freeze = [0, 1]

In [None]:
def _disable_grad(module):
    """Turn off gradient tracking for *module* and return the result of the call."""
    return module.requires_grad_(False)


# Load the model from the checkpoint again, then freeze the selected
# message-passing blocks one by one.
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
for block_idx in blocks_to_freeze:
    frozen_block = mcmpnn.message_passing.blocks[block_idx]
    frozen_block.apply(_disable_grad)
    frozen_block.eval()

# Batch norm is frozen as well and switched to eval mode.
mcmpnn.bn.apply(_disable_grad)
mcmpnn.bn.eval()