# Finetune on experimental data
This notebook demonstrates how to finetune METL models on experimental data.

In [1]:
# @title Cloning metl
!git clone https://github.com/gitter-lab/metl.git
%cd metl

Cloning into 'metl'...
remote: Enumerating objects: 416, done.
remote: Counting objects: 100% (416/416), done.
remote: Compressing objects: 100% (280/280), done.
remote: Total 416 (delta 166), reused 330 (delta 98), pack-reused 0 (from 0)
Receiving objects: 100% (416/416), 18.08 MiB | 14.06 MiB/s, done.
Resolving deltas: 100% (166/166), done.
/content/metl


In [2]:
# @title Setting up conda to download notebook dependencies (this takes a while)
# @markdown This step may take 10-20 minutes.

!wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ./miniconda.sh
!chmod +x miniconda.sh
!bash ./miniconda.sh -b -u -p /usr/local
!conda env update -q -n base -f ./environment.yml

--2024-08-28 22:01:31--  https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.191.158, 104.16.32.241, 2606:4700::6810:20f1, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.191.158|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 148981743 (142M) [application/octet-stream]
Saving to: ‘./miniconda.sh’


2024-08-28 22:01:32 (87.1 MB/s) - ‘./miniconda.sh’ saved [148981743/148981743]

PREFIX=/usr/local
Unpacking payload ...

Installing base environment...

Preparing transaction: ...working... done
Executing transaction: ...working... done
installation finished.
    You currently have a PYTHONPATH environment variable set. This may cause
    unexpected behavior when running the Python interpreter in Miniconda3.
    For best results, please verify that your PYTHONPATH only points to
    directories of packages that are compatible with the Python interpreter
    in Miniconda3: /us

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import os
import sys

sys.path.append('/usr/local/lib/python3.9/site-packages')
# define the name of the project root directory
project_root_dir_name = "metl"

# find the project root by checking each parent directory
current_dir = os.getcwd()
while os.path.basename(current_dir) != project_root_dir_name and current_dir != os.path.dirname(current_dir):
    current_dir = os.path.dirname(current_dir)

# change the current working directory to the project root directory
if os.path.basename(current_dir) == project_root_dir_name:
    os.chdir(current_dir)
else:
    print("project root directory not found")

# add the project code folder to the system path so imports work
module_path = os.path.abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)

# Acquire an experimental dataset

For demonstration purposes, this repository contains the [avGFP dataset](https://github.com/gitter-lab/metl/tree/main/data/dms_data/avgfp) from [Sarkisyan et al. (2016)](https://doi.org/10.1038/nature17995).
See the [metl-pub](https://github.com/gitter-lab/metl-pub) repository to access the other experimental datasets we used in our preprint.
See the README in the [dms_data](https://github.com/gitter-lab/metl/tree/main/data/dms_data/) directory for information about how to use your own experimental dataset.

# Acquire a pretrained model
Pretrained METL models are available in the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) repository. You can use one of those, or you can pretrain your own METL model (see [pretraining.ipynb](https://github.com/gitter-lab/metl/blob/main/notebooks/pretraining.ipynb)).

For demonstration purposes, we include a pretrained avGFP METL-Local model from the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) repository in the [pretrained_models](https://github.com/gitter-lab/metl/tree/main/pretrained_models) directory. This model is `METL-L-2M-3D-GFP` (UUID: `Hr4GNHws`).
It is the avGFP METL-Local source model we used for the analysis in our preprint.

We will show how to finetune this model using the [experimental avGFP dataset](https://github.com/gitter-lab/metl/tree/main/data/dms_data/avgfp).

# Training arguments

The script for finetuning on experimental data is [train_target_model.py](https://github.com/gitter-lab/metl/blob/main/code/train_target_model.py). Uncomment and run the cell below to list its arguments. Additional architecture-related arguments do not appear in this help output; you can find them in the `TransferModel` class in [models.py](https://github.com/gitter-lab/metl/tree/main/code/models.py).

In [5]:
# !python code/train_target_model.py -h

We set up finetuning arguments for this example in [finetune_avgfp_local.txt](https://github.com/gitter-lab/metl/tree/main/args/finetune_avgfp_local.txt) in the [args](https://github.com/gitter-lab/metl/tree/main/args) directory. This argument file can be passed directly to [train_target_model.py](https://github.com/gitter-lab/metl/blob/main/code/train_target_model.py) with the command `!python code/train_target_model.py @args/finetune_avgfp_local.txt` (we do this in the "Running training" section below).

Uncomment and run the cell below to view the contents of the argument file. The sections below will walk through and explain the key arguments.

In [6]:
# with open("args/finetune_avgfp_local.txt", "r") as file:
#     contents = file.read()
#     print(contents)
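The `@args/finetune_avgfp_local.txt` syntax matches Python's `argparse` support for argument files (`fromfile_prefix_chars`), where each line of the file is read as one command-line token. The sketch below assumes the script uses that mechanism; the parser is a hypothetical stand-in that defines only three of the real arguments. Under that assumption, it also shows why flags placed after the `@` file, such as `--max_epochs 50` in the training command later in this notebook, would take effect: argparse keeps the last value given for an option.

```python
import argparse
import os
import tempfile

# Hypothetical mini-parser with only a few of the real arguments, assuming the
# script enables argparse's "@file" support via fromfile_prefix_chars="@".
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument("--ds_name", type=str)
parser.add_argument("--max_epochs", type=int)
parser.add_argument("--unfreeze_backbone_at_epoch", type=int)

# An argument file with one token per line, in the same style as finetune_avgfp_local.txt
arg_lines = ["--ds_name", "avgfp", "--max_epochs", "500", "--unfreeze_backbone_at_epoch", "250"]
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("\n".join(arg_lines) + "\n")
    args_fn = f.name

# Options given after the @file override the values read from it (last occurrence wins)
args = parser.parse_args([f"@{args_fn}", "--max_epochs", "50", "--unfreeze_backbone_at_epoch", "25"])
os.remove(args_fn)
print(args.ds_name, args.max_epochs, args.unfreeze_backbone_at_epoch)  # avgfp 50 25
```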

## Dataset arguments


Specify the dataset name and the train/val/test split. The dataset must be defined in [datasets.yml](https://github.com/gitter-lab/metl/tree/main/data/dms_data/datasets.yml). For demonstration purposes, we are using one of the reduced dataset size splits with a dataset size of 160 (train size of 128).
```
--ds_name
avgfp
--split_dir
data/dms_data/avgfp/splits/resampled/resampled_ds160_val0.2_te0.1_w1abc2f4e9a64_s1_r8099/resampled_ds160_val0.2_te0.1_w1abc2f4e9a64_s1_r8099_rep_0
```
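The split directory name appears to encode the split parameters. A minimal sketch of the arithmetic behind "dataset size of 160 (train size of 128)", assuming `ds160` is the combined train+val size and `val0.2` is the fraction of it held out for validation (the held-out test set, presumably `te0.1`, is not counted in the 160). `split_sizes` is a hypothetical helper, not part of the repository.

```python
import re


def split_sizes(split_dir_name: str) -> tuple:
    """Hypothetical helper: derive (train_size, val_size) from a resampled split name.

    Assumes "_ds<N>_" is the combined train+val size and "_val<F>_" is the
    validation fraction of that combined set.
    """
    ds_size = int(re.search(r"_ds(\d+)_", split_dir_name).group(1))
    val_frac = float(re.search(r"_val([0-9.]+)_", split_dir_name).group(1))
    val_size = round(ds_size * val_frac)
    return ds_size - val_size, val_size


name = "resampled_ds160_val0.2_te0.1_w1abc2f4e9a64_s1_r8099_rep_0"
print(split_sizes(name))  # (128, 32)
```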

Specify the names of the train, validation, and test set files in the split directory. Using "auto" for the test_name will select the super test set ("stest.txt") if it exists in the split directory, otherwise it will use the standard test set ("test.txt").

```
--train_name
train
--val_name
val
--test_name
test
```
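The `"auto"` rule for the test set name can be sketched as a small function. `resolve_test_name` is a hypothetical name used only for illustration; it follows the behavior described above (prefer `stest.txt` when present, otherwise `test.txt`) and is not taken from the repository's code.

```python
import os
import tempfile


def resolve_test_name(split_dir: str, test_name: str) -> str:
    """Hypothetical helper mirroring the documented "auto" behavior for --test_name."""
    if test_name != "auto":
        return test_name
    # prefer the super test set if the split directory provides one
    return "stest" if os.path.isfile(os.path.join(split_dir, "stest.txt")) else "test"


with tempfile.TemporaryDirectory() as split_dir:
    open(os.path.join(split_dir, "test.txt"), "w").close()
    print(resolve_test_name(split_dir, "auto"))  # test
    open(os.path.join(split_dir, "stest.txt"), "w").close()
    print(resolve_test_name(split_dir, "auto"))  # stest
```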

Specify the name of the target column in the dataset dataframe. The model is finetuned to predict the values in this column.

```
--target_names
score
```

The METL-Local model we are finetuning uses 3D structure-based relative position embeddings, so we need to specify the PDB filename. This PDB file is in the [data/pdb_files](https://github.com/gitter-lab/metl/tree/main/data/pdb_files) directory, which the script checks by default, so there is no need to specify the full path. You can also just specify "auto" to use the PDB file defined for this dataset in [datasets.yml](https://github.com/gitter-lab/metl/tree/main/data/dms_data/datasets.yml).

```
--pdb_fn
1gfl_cm.pdb
```
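A rough sketch of how the `--pdb_fn` lookup described above could work. The helper `resolve_pdb_fn` and its exact rules are assumptions, not the repository's code. `datasets_metadata` stands for the parsed contents of datasets.yml, where each dataset entry has a `pdb_fn` key (the inference section of this notebook reads `datasets["avgfp"]["pdb_fn"]`).

```python
import os


def resolve_pdb_fn(pdb_fn: str, ds_name: str, datasets_metadata: dict,
                   pdb_dir: str = "data/pdb_files") -> str:
    """Hypothetical helper: resolve a --pdb_fn value to a file path."""
    if pdb_fn == "auto":
        # use the PDB file registered for this dataset in datasets.yml
        pdb_fn = datasets_metadata[ds_name]["pdb_fn"]
    # a bare filename is looked up in the default PDB directory
    if os.path.dirname(pdb_fn) == "":
        pdb_fn = os.path.join(pdb_dir, pdb_fn)
    return pdb_fn


# toy metadata standing in for load_dataset_metadata()
toy_metadata = {"avgfp": {"pdb_fn": "1gfl_cm.pdb"}}
print(resolve_pdb_fn("auto", "avgfp", toy_metadata))
print(resolve_pdb_fn("1gfl_cm.pdb", "avgfp", toy_metadata))
```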

## Network architecture arguments

For finetuning, we implemented a special model `transfer_model` that handles pretrained checkpoints with top nets.
```
--model_name
transfer_model
```

The pretrained checkpoint can be a PyTorch checkpoint (.pt file) downloaded from the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) repository or a PyTorch Lightning checkpoint (.ckpt file) obtained from pretraining a model with this repository.
```
--pretrained_ckpt_path
pretrained_models/Hr4GNHws.pt
```

The backbone cutoff determines where to cut off the pretrained model and attach the new prediction head. For METL-Local models, we recommend backbone cutoff -1, and for METL-Global models, we recommend backbone cutoff -2.

```
--backbone_cutoff
-1
```

The remaining arguments set the encoding (which should be `int_seqs`), whether to apply dropout after the backbone cutoff and at what rate, and the architecture of the new top net. You can leave these values as-is to match what we did for the preprint.

```
--encoding
int_seqs
--dropout_after_backbone
--dropout_after_backbone_rate
0.5
--top_net_type
linear
```

## Finetuning strategy arguments

We implemented a dual-phase finetuning strategy. During the first phase, the backbone weights are frozen and only the top net is trained. During the second phase, all the network weights are unfrozen and trained at a reduced learning rate.

The `--unfreeze_backbone_at_epoch` argument determines the training epoch at which to unfreeze the backbone. We train the models for 500 epochs, so the backbone is unfrozen halfway through, at epoch 250.

```
--finetuning
--finetuning_strategy
backbone
--unfreeze_backbone_at_epoch
250
--backbone_always_align_lr
```
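The freeze-then-unfreeze idea can be sketched in plain PyTorch by toggling `requires_grad` on the backbone parameters. The tiny `backbone` and `top_net` below are hypothetical stand-ins, not the METL architecture, and this is not the repository's implementation; PyTorch Lightning also ships a `BackboneFinetuning` callback built around the same idea. One practical detail: backbone parameters must still be registered with the optimizer (or added as a new parameter group) when they are unfrozen, otherwise they will not be updated.

```python
import torch
from torch import nn

# Hypothetical stand-ins: a small "backbone" and a new linear top net
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
top_net = nn.Linear(16, 1)


def set_backbone_trainable(trainable: bool) -> None:
    """Freeze (phase 1) or unfreeze (phase 2) every backbone parameter."""
    for p in backbone.parameters():
        p.requires_grad = trainable


def count_trainable() -> int:
    """Number of scalar weights that will receive gradients."""
    params = list(backbone.parameters()) + list(top_net.parameters())
    return sum(p.numel() for p in params if p.requires_grad)


unfreeze_backbone_at_epoch = 25
for epoch in (0, unfreeze_backbone_at_epoch):
    # phase 1 trains only the top net; phase 2 trains everything
    set_backbone_trainable(epoch >= unfreeze_backbone_at_epoch)
    print(epoch, count_trainable())  # prints "0 17", then "25 433"
```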

## Optimization arguments

Basic optimizer arguments include the batch size, learning rate, and maximum number of epochs to train for. Unless early stopping is enabled, the model will train for the given number of epochs.

```
--optimizer
adamw
--weight_decay
0.1
--batch_size
128
--learning_rate
0.001
--max_epochs
500
--gradient_clip_val
0.5
```
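With these settings, the number of optimizer steps follows from the training set size and the batch size. A small sketch, assuming each epoch makes `ceil(train_size / batch_size)` steps (the last partial batch is kept). With 128 training examples and `--batch_size 128`, each epoch is one step, which is consistent with the "Number of training steps is 50" line in the 50-epoch training log later in this notebook. `num_training_steps` is a hypothetical helper name.

```python
import math


def num_training_steps(train_size: int, batch_size: int, max_epochs: int) -> int:
    """Optimizer steps over a run, assuming the last partial batch is kept (ceil)."""
    return max_epochs * math.ceil(train_size / batch_size)


print(num_training_steps(128, 128, 500))  # 500 (settings in the argument file)
print(num_training_steps(128, 128, 50))   # 50 (shortened demo run)
print(0.01 * num_training_steps(128, 128, 50))  # 0.5, i.e. 1% warmup of 50 steps
```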

The learning rate scheduler we used for finetuning is a dual-phase schedule that matches the dual-phase finetuning strategy. Each phase has a linear learning rate warmup over 1% of the total steps in that phase, and each phase also has a cosine learning rate decay. The phase 2 learning rate is 10% of the phase 1 learning rate.

```
--lr_scheduler
dual_phase_warmup_constant_cosine_decay
--warmup_steps
.01
--phase2_lr_ratio
0.1
```
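The schedule described above can be approximated as follows. This is a simplified sketch, not the repository's `dual_phase_warmup_constant_cosine_decay` implementation: `dual_phase_lr` is a hypothetical name, it omits any constant segment the real scheduler may include, and it rounds the warmup up to at least one step (the training log for the short demo run reports a fractional warmup of 0.5 steps, so the real scheduler handles this differently).

```python
import math


def dual_phase_lr(step: int, total_steps: int, phase1_steps: int, base_lr: float = 1e-3,
                  warmup_frac: float = 0.01, phase2_lr_ratio: float = 0.1) -> float:
    """Approximate dual-phase schedule: linear warmup then cosine decay, restarted for phase 2."""
    if step < phase1_steps:
        start, length, peak = 0, phase1_steps, base_lr
    else:
        start, length, peak = phase1_steps, total_steps - phase1_steps, base_lr * phase2_lr_ratio
    t = step - start
    # warm up over warmup_frac of this phase, rounded up to at least one step (assumption)
    warmup = max(1, math.ceil(warmup_frac * length))
    if t < warmup:
        return peak * (t + 1) / warmup
    # cosine decay from the phase peak toward zero over the rest of the phase
    progress = (t - warmup) / max(1, length - warmup)
    return peak * 0.5 * (1.0 + math.cos(math.pi * progress))


# 500 steps with the backbone unfrozen at step 250 (one step per epoch)
for step in (0, 2, 125, 249, 250, 252, 499):
    print(step, f"{dual_phase_lr(step, 500, 250):.2e}")
```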

## Logging arguments

We have built-in functionality for tracking model training with Weights & Biases. If you have a Weights & Biases account, set the environment variable `WANDB_API_KEY` to your API key and set the flag `--use_wandb` instead of `--no_use_wandb` below.

```
--no_use_wandb
--wandb_project
metl-target
--wandb_online
--experiment
default
```

The argument below sets where the log directory is placed locally.
```
--log_dir_base
output/training_logs
```

# Running training

All the arguments described above are contained in [finetune_avgfp_local.txt](https://github.com/gitter-lab/metl/tree/main/args/finetune_avgfp_local.txt), which can be passed directly to [train_target_model.py](https://github.com/gitter-lab/metl/blob/main/code/train_target_model.py).

PyTorch Lightning has a built-in progress bar that is convenient for seeing training progress, but it does not display correctly in Jupyter when the script is called with `!python`. We disable it by setting the flag `--enable_progress_bar false`. Instead, we implemented simple print statements to track training progress, which we enable with the flag `--enable_simple_progress_messages`. To keep this demonstration short, the command also passes `--max_epochs 50` and `--unfreeze_backbone_at_epoch 25`, so the backbone is still unfrozen halfway through training.

In [7]:
!python code/train_target_model.py @args/finetune_avgfp_local.txt --enable_progress_bar false --enable_simple_progress_messages --max_epochs 50 --unfreeze_backbone_at_epoch 25

Random seed not specified, using: 522644021
Global seed set to 522644021
Created model UUID: fmngE6sB
Created log directory: output/training_logs/fmngE6sB
Final UUID: fmngE6sB
Final log directory: output/training_logs/fmngE6sB
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.
Number of training steps is 50
Number of warmup steps is 0.5
Second warmup phase starts at step 25
total_steps 50
phase1_total_steps 25
phase2_total_steps 25

# Additional recommendations

## Model selection

Selecting the model from the epoch with the lowest validation set loss can help prevent overfitting. This requires a validation set large enough to provide an accurate estimate of performance.

We enabled model selection if the validation set size was ≥ 32 for METL-Local and ≥ 128 for METL-Global. We found that optimization was more stable for METL-Local than for METL-Global, so smaller validation sets were still reliable for METL-Local.

Enable model selection by setting the argument `--ckpt_monitor val`.
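The selection rule and thresholds above can be written as two small helpers. `use_val_model_selection` and `best_epoch` are hypothetical names for illustration, not functions from the repository. With the 160-example split used in this notebook (32 validation examples), the METL-Local threshold is met exactly.

```python
def use_val_model_selection(val_size: int, model_type: str) -> bool:
    """Hypothetical helper encoding the validation-size thresholds described above."""
    thresholds = {"local": 32, "global": 128}  # METL-Local vs METL-Global
    return val_size >= thresholds[model_type]


def best_epoch(val_losses: list) -> int:
    """Return the (0-based) epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=val_losses.__getitem__)


extra_args = ["--ckpt_monitor", "val"] if use_val_model_selection(32, "local") else []
print(extra_args)  # ['--ckpt_monitor', 'val']
print(best_epoch([0.9, 0.6, 0.5, 0.55, 0.7]))  # 2
```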


## Backbone cutoff for METL-Global
Finetuning METL-Global is largely the same as METL-Local, except we recommend using a different threshold for model selection (see above), as well as a different backbone cutoff.

For METL-Local, we set `--backbone_cutoff -1`, which attaches the new prediction head immediately after the final fully connected layer.

For METL-Global, we recommend setting `--backbone_cutoff -2`, which attaches the new prediction head immediately after the global pooling layer. We found this resulted in better finetuning performance for METL-Global.
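One way to picture the cutoff, assuming it behaves like a Python slice over the pretrained model's ordered layers: `-1` drops only the final pretraining output layer, and `-2` also drops the fully connected layer so the head sits right after global pooling. The layer names below are hypothetical placeholders chosen to match the description above, not the real METL module names.

```python
# Hypothetical, ordered layer names for illustration only
layers = ["embedding", "transformer_encoder", "global_pool", "fc", "prediction"]


def backbone_layers(layers: list, backbone_cutoff: int) -> list:
    """Pretrained layers kept before the new head, assuming slice semantics."""
    return layers[:backbone_cutoff]


print(backbone_layers(layers, -1))  # head attaches after "fc" (METL-Local)
print(backbone_layers(layers, -2))  # head attaches after "global_pool" (METL-Global)
```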

# Running inference using finetuned model

The PyTorch Lightning framework supports inference, but while we put together a working example, we recommend converting the PyTorch Lightning checkpoint to pure PyTorch and using the [metl-pretrained](https://github.com/gitter-lab/metl-pretrained) package to run inference in pure PyTorch.

## Convert to PyTorch
Lightning checkpoints are compatible with pure PyTorch, but they may contain additional items that are not needed for inference, such as optimizer state. The conversion script below loads the checkpoint and saves a smaller checkpoint with just the model weights and hyperparameters.

In [8]:
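# the Lightning checkpoint from the finetuning we performed above
# (use the most recently modified run directory in case there are several)
log_dirs = [d for d in os.listdir('output/training_logs')
            if os.path.isdir(os.path.join('output/training_logs', d))]
fine_tuning_dir_name = max(log_dirs, key=lambda d: os.path.getmtime(os.path.join('output/training_logs', d)))

ckpt_fn = f"output/training_logs/{fine_tuning_dir_name}/checkpoints/epoch=49-step=50.ckpt"

# run the conversion script
!python code/convert_ckpt.py --ckpt_path $ckpt_fn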

Processing checkpoint: output/training_logs/fmngE6sB/checkpoints/epoch=49-step=50.ckpt
Saving converted checkpoint to: output/training_logs/fmngE6sB/checkpoints/fmngE6sB.pt
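Conceptually, the conversion drops the training-only entries of a Lightning checkpoint. A minimal sketch of that filtering step, assuming the standard Lightning keys `state_dict` and `hyper_parameters`; the actual layout written by `convert_ckpt.py` may differ. In practice the dict would come from `torch.load(ckpt_fn, map_location="cpu")` and the result would be written with `torch.save`; a toy dict stands in here.

```python
def slim_lightning_checkpoint(ckpt: dict) -> dict:
    """Keep only the entries needed for inference (sketch, not convert_ckpt.py itself)."""
    return {k: ckpt[k] for k in ("state_dict", "hyper_parameters") if k in ckpt}


# toy dict standing in for a loaded Lightning .ckpt file
toy_ckpt = {"state_dict": {"w": [0.1]}, "hyper_parameters": {"learning_rate": 1e-3},
            "optimizer_states": [{}], "lr_schedulers": [{}], "epoch": 49, "global_step": 50}
print(sorted(slim_lightning_checkpoint(toy_ckpt)))  # ['hyper_parameters', 'state_dict']
```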


## Load checkpoint with metl-pretrained package
Using the Hugging Face wrapper, we can load the metl library and use it to load our newly trained model checkpoint and run inference with it.

In [9]:
from transformers import AutoModel

metl = AutoModel.from_pretrained('gitter-lab/METL', trust_remote_code=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



A new version of the following files was downloaded from https://huggingface.co/gitter-lab/METL:
- huggingface_wrapper.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Some weights of the model checkpoint at gitter-lab/METL were not used when initializing METLModel: ['model.bias', 'model.weight']
- This IS expected if you are initializing METLModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing METLModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
checkpoint_path = f"output/training_logs/{fine_tuning_dir_name}/checkpoints/{fine_tuning_dir_name}.pt"
metl.get_from_checkpoint(checkpoint_path)
model = metl.model
data_encoder = metl.encoder

Initialized PDB bucket matrices in: 0.000
Initialized PDB bucket matrices in: 0.000




## Run inference with pure PyTorch

In [11]:
import yaml
import torch

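def load_dataset_metadata(metadata_fn: str = "data/dms_data/datasets.yml") -> dict:
    """Load dataset metadata (e.g., wild-type sequence and PDB file) from datasets.yml."""
    with open(metadata_fn, "r") as stream:
        # let a yaml.YAMLError propagate; printing it and continuing would leave `datasets` undefined
        datasets = yaml.safe_load(stream)
    return datasets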

In [12]:
# load the GFP wild-type sequence and the PDB file (needed for 3D RPE)
datasets = load_dataset_metadata()
wt = datasets["avgfp"]["wt_aa"]
pdb_fn = datasets["avgfp"]["pdb_fn"]

# some example GFP variants to compute the scores for
variants = ["E3K,G102S",
            "T36P,S203T,K207R",
            "V10A,D19G,F25S,E113V"]

encoded_variants = data_encoder.encode_variants(wt, variants)

# set model to eval mode
model.eval()

# no need to compute gradients for inference
with torch.no_grad():
    # note we are specifying the pdb_fn because this model uses 3D RPE
    predictions = model(torch.tensor(encoded_variants), pdb_fn=pdb_fn)

print(predictions)

tensor([[-0.3088],
        [-0.4627],
        [-0.5127]])
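Before encoding, it can help to sanity-check variant strings against the wild-type sequence. Each mutation above has the form `<wild-type residue><position><mutant residue>`, with multiple mutations separated by commas. The hypothetical helper `apply_variant` below applies a variant and raises an error if the wild-type residue does not match. This notebook does not state whether METL positions are 0- or 1-based, so `index_base` is left as a parameter; check the repository documentation for your dataset's convention. The usage line uses a toy sequence, not avGFP.

```python
import re


def apply_variant(wt: str, variant: str, index_base: int) -> str:
    """Apply a comma-separated variant such as "E3K,G102S" to a wild-type sequence.

    index_base is 0 or 1 depending on the dataset's position convention (an assumption
    you must set yourself).
    """
    seq = list(wt)
    for mutation in variant.split(","):
        match = re.fullmatch(r"([A-Z])(\d+)([A-Z])", mutation.strip())
        if match is None:
            raise ValueError(f"unrecognized mutation format: {mutation}")
        wt_aa, pos, mut_aa = match.groups()
        idx = int(pos) - index_base
        if not 0 <= idx < len(seq):
            raise ValueError(f"{mutation}: position out of range")
        if seq[idx] != wt_aa:
            raise ValueError(f"{mutation}: wild-type residue at that position is {seq[idx]}")
        seq[idx] = mut_aa
    return "".join(seq)


print(apply_variant("MKTAYIA", "K1R,Y4F", index_base=0))  # MRTAFIA
```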
