# Pretrain METL on Rosetta data

This notebook demonstrates how to pretrain a METL source model on Rosetta data.  

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# define the name of the project root directory
project_root_dir_name = "metl"

# find the project root by checking each parent directory
current_dir = os.getcwd()
while os.path.basename(current_dir) != project_root_dir_name and current_dir != os.path.dirname(current_dir):
    current_dir = os.path.dirname(current_dir)

# change the current working directory to the project root directory
if os.path.basename(current_dir) == project_root_dir_name:
    os.chdir(current_dir)
else:
    print("project root directory not found")
    
# add the project code folder to the system path so imports work
module_path = os.path.abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)

# Training arguments

The script for pretraining on Rosetta data is [train_source_model.py](../code/train_source_model.py).
This script has a number of arguments to specify various aspects of the model, training, and logging.
You can view the arguments by uncommenting and running the following line.

In [3]:
# %run code/train_source_model.py -h

Note that this won't show model-specific arguments. To find out which arguments configure a given model, see the function in [code/models.py](../code/models.py) that corresponds to the `--model_name` you want to use.

We will set up arguments to pretrain a toy METL-Local model on the sample avGFP Rosetta dataset located in [data/rosetta_data](data/rosetta_data). See the README in that directory for more information about this sample dataset. 

The arguments are contained in the file [pretrain_avgfp_local.txt](../args/pretrain_avgfp_local.txt) in the [args](../args) directory. Uncomment and run the cell below to view the contents of the argument file. The sections below will walk through and explain the key arguments.

In [21]:
# with open("args/pretrain_avgfp_local.txt", "r") as file:
#     contents = file.read()
#     print(contents)

## Dataset arguments

Define the path to the Rosetta dataset (the .db file).

```
--ds_fn
data/rosetta_data/avgfp/avgfp.db
```
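If you want to peek inside the dataset before training, the following is a minimal sketch that lists the tables in the file. It assumes the `.db` file is an SQLite database, which the extension suggests but this notebook does not state; `list_tables` is a hypothetical helper, not part of the METL code.

```python
import os
import sqlite3


def list_tables(db_path):
    """Return the table names in an SQLite database file (assumes the file is SQLite)."""
    # Open read-only so a typo in the path raises an error instead of creating an empty file.
    con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        rows = con.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
    finally:
        con.close()
    return [row[0] for row in rows]


# Only runs if the sample dataset is present relative to the current working directory.
db_path = "data/rosetta_data/avgfp/avgfp.db"
if os.path.exists(db_path):
    print(list_tables(db_path))
```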

Define the path to the train/val/test split to use for this run.

```
--split_dir
data/rosetta_data/avgfp/splits/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991
```


## Optimizer arguments
Source models are trained with the **AdamW** optimizer by default. See `RosettaTask.configure_optimizers()` in [tasks.py](../code/tasks.py) if you want to change the optimizer.

Basic optimizer arguments include the batch size, learning rate, and maximum number of epochs to train for. Unless early stopping is enabled, the model will train for the given number of epochs. 

```
--batch_size
128
--learning_rate
0.001
--max_epochs
30
```

The learning rate schedule determines how the learning rate changes over the course of training. We implemented a few different schedules and settled on a constant learning rate with a linear warmup period for our models. The number of warmup steps can be specified as an integer or as a fraction of the total number of training steps (determined by the number of epochs and the batch size).

```
--lr_scheduler
warmup_constant
--warmup_steps
.02
```
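To make the warmup behavior concrete, here is a minimal sketch of a constant schedule with linear warmup. It is not the implementation in the METL code: the function names are hypothetical, and the rule that values below 1 are treated as a fraction of the total steps is an assumption about how the two forms are distinguished.

```python
def resolve_warmup_steps(warmup_steps, total_steps):
    """Turn a warmup setting into a step count.

    Assumption: values below 1 are a fraction of total_steps,
    values of 1 or more are an absolute number of steps.
    """
    if warmup_steps < 1:
        return warmup_steps * total_steps
    return int(warmup_steps)


def warmup_constant_factor(step, warmup_steps):
    """Multiplier applied to the base learning rate at a given training step."""
    if warmup_steps <= 0:
        return 1.0
    # Ramp up linearly during warmup, then stay at 1.0 (constant learning rate).
    return min(1.0, step / warmup_steps)


# Hypothetical run with 1000 total training steps and --warmup_steps .02
base_lr = 0.001
warmup = resolve_warmup_steps(0.02, total_steps=1000)
for step in (0, 5, 10, 20, 500):
    print(step, base_lr * warmup_constant_factor(step, warmup))
```

With these settings the learning rate climbs linearly over the first 2% of steps and stays at `--learning_rate` for the rest of training.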


Gradient clipping can help prevent exploding gradients and make training smoother. We used gradient norm clipping with a threshold of 0.5.
```
--gradient_clip_val
0.5
```
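As an illustration of what norm clipping does, here is a small sketch in plain Python. It treats the gradient as a flat list of numbers and rescales it when its L2 norm exceeds the threshold. The training framework applies its own implementation; `clip_by_global_norm` is a hypothetical name used only for this example.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale grads so their L2 norm is at most max_norm (illustrative sketch)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        # Small gradients pass through unchanged.
        return list(grads)
    # Large gradients keep their direction but are shrunk to the threshold.
    scale = max_norm / total_norm
    return [g * scale for g in grads]


clipped = clip_by_global_norm([3.0, 4.0], max_norm=0.5)
print(clipped, math.sqrt(sum(g * g for g in clipped)))
```

The direction of the update is preserved; only its magnitude is capped.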

## Model architecture arguments

METL models are transformer-based models. Our implementation supports a number of options including structure-based (3D) or sequence-based (1D) relative position embeddings, different numbers of layers, etc. 

The arguments below are what we used for our 2M parameter METL-Local models.

```
--model_name
transformer_encoder
```

The 3D structure-based relative position embeddings (RPE) need a contact threshold and a clipping threshold. The contact threshold sets the cutoff distance used to build the contact map (structure graph), which the 3D RPE uses to determine neighboring residues. The clipping threshold sets the maximum relative distance in that graph. With a clipping threshold of 3, a relative distance of 0 represents a node with itself, 1 signifies direct neighbors, 2 signifies second-degree neighbors, and 3 covers every other node not captured by the previous categories.

```
--pos_encoding
relative_3D
--contact_threshold
8
--clipping_threshold
3
```
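The relationship between the two thresholds can be sketched as a shortest-path computation on the contact graph. This is an illustrative sketch, not the METL implementation: the toy coordinates, the strict `<` comparison, the choice of one point per residue, and both function names are assumptions.

```python
import math
from collections import deque


def contact_graph(coords, contact_threshold):
    """Adjacency lists: two residues are neighbors if their distance is below the threshold."""
    n = len(coords)
    neighbors = [[] for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            if math.dist(coords[i], coords[j]) < contact_threshold:
                neighbors[i].append(j)
                neighbors[j].append(i)
    return neighbors


def relative_positions_3d(neighbors, clipping_threshold):
    """Shortest-path distances in the contact graph, clipped at clipping_threshold."""
    n = len(neighbors)
    # Every pair starts at the clip value; closer pairs are filled in by BFS.
    rel = [[clipping_threshold] * n for _ in range(n)]
    for src in range(n):
        rel[src][src] = 0
        queue = deque([src])
        while queue:
            node = queue.popleft()
            next_dist = rel[src][node] + 1
            if next_dist >= clipping_threshold:
                continue  # anything further out already holds the clip value
            for nxt in neighbors[node]:
                if rel[src][nxt] > next_dist:
                    rel[src][nxt] = next_dist
                    queue.append(nxt)
    return rel


# Toy structure: five residues on a line, spaced 5 units apart.
coords = [(5.0 * i, 0.0, 0.0) for i in range(5)]
graph = contact_graph(coords, contact_threshold=8)
for row in relative_positions_3d(graph, clipping_threshold=3):
    print(row)
```

In this toy example only adjacent residues are in contact, so the first residue sees distances 0, 1, 2 and then the clip value 3 for everything further away.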

To use a 1D sequence-based embedding, you would instead specify `relative` for `--pos_encoding` and a `--clipping_threshold` that sets the maximum sequence-based relative distance. We used a `clipping_threshold` of `8` for our 1D-based models.
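For comparison, a sequence-based relative position can be sketched as the clipped offset between two sequence indices. The signed convention below follows the common formulation of relative position embeddings and is an assumption about this codebase; `relative_positions_1d` is a hypothetical name.

```python
def relative_positions_1d(seq_len, clipping_threshold):
    """Signed sequence offsets j - i, clipped to [-clipping_threshold, clipping_threshold]."""
    return [
        [max(-clipping_threshold, min(clipping_threshold, j - i)) for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Small example: 4 positions with a clipping threshold of 2.
for row in relative_positions_1d(4, 2):
    print(row)
```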

The following arguments determine the transformer encoder architecture.
```
--embedding_len
256
--num_hidden
1024
--num_heads
4
--num_enc_layers
3
--enc_layer_dropout
0.1
--use_final_encoder_norm
--global_average_pooling
--use_final_hidden_layer
--final_hidden_size
256
--use_final_hidden_layer_norm
--final_hidden_layer_norm_before_activation
--use_final_hidden_layer_dropout
--final_hidden_layer_dropout_rate
0.1
--activation
relu
```
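As a rough sanity check on the 2M parameter figure, the arithmetic below counts only the attention and feed-forward weight matrices. It assumes a standard transformer encoder layer and ignores biases, layer norms, input and position embeddings, and the final hidden layer, so it is an estimate rather than the exact count reported by the model summary.

```python
embedding_len = 256
num_hidden = 1024
num_enc_layers = 3

# Query, key, value and output projections: four (embedding_len x embedding_len) matrices.
attention_weights = 4 * embedding_len * embedding_len
# Two feed-forward matrices: embedding_len -> num_hidden -> embedding_len.
feedforward_weights = 2 * embedding_len * num_hidden

per_layer = attention_weights + feedforward_weights
encoder_weights = num_enc_layers * per_layer
print(f"{encoder_weights:,} weights in the encoder layers")
```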

## Logging arguments
We have built-in functionality for tracking model training with Weights & Biases. If you have a Weights & Biases account, set the environment variable `WANDB_API_KEY` to your API key and use the flag `--use_wandb` instead of `--no-use_wandb` below.

```
--no-use_wandb
--wandb_project
metl-source
--wandb_online
--experiment
default
```
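If you launch training from a script, one way to pick the flag is to check whether the API key is set. This is a small sketch; `wandb_flag` is a hypothetical helper and not part of the METL code.

```python
import os


def wandb_flag(environ=os.environ):
    """Return the W&B flag to pass, based on whether WANDB_API_KEY is set and non-empty."""
    return "--use_wandb" if environ.get("WANDB_API_KEY") else "--no-use_wandb"


print(wandb_flag())
```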

Another flag that may be of interest is `--wandb_log_grad`, which additionally logs gradients.

The argument below determines where to place the log directory locally.
```
--log_dir_base
output/training_logs
```

# Running training
In addition to the arguments explained above and defined in the arguments file, we are going to limit the amount of training, validation, and test data using `--limit_(train,val,test)_batches`. We also override the number of training epochs, reducing it from 30 to 5. Both changes speed up training for this example.
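The `@file` syntax and the override behavior can be illustrated with Python's `argparse`. This sketch assumes the training script reads argument files through `fromfile_prefix_chars`, which the `@` prefix and the one-token-per-line file format suggest but which is not confirmed here. The parser below is a toy with only two arguments.

```python
import argparse
import os
import tempfile

# Toy parser standing in for the training script's parser (assumption).
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=32)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Argument files hold one token per line, as in the args directory.
    args_path = os.path.join(tmp_dir, "example_args.txt")
    with open(args_path, "w") as f:
        f.write("--batch_size\n128\n--max_epochs\n30\n")

    # Arguments given after the file take precedence over those inside it.
    args = parser.parse_args([f"@{args_path}", "--max_epochs", "5"])

print(args.max_epochs, args.batch_size)
```

Because the file contents are expanded in place, a later `--max_epochs` on the command line replaces the value read from the file.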

The training script will do the following:
1. Assign a UUID for this model
2. Create a log directory with the UUID name in the base log directory [output/training_logs](../output/training_logs)
3. Print out various information including a summary of the model architecture (the output may not fit in this Jupyter Notebook)
4. Initiate training, which will be tracked via a progress bar printed in this notebook
5. Evaluate the final model on the test set and print metrics

Running the training script will generate output in the cell below and place training logs and checkpoints in the log directory in `output/training_logs`.
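Once a run has finished, you can locate its checkpoints programmatically. The sketch below assumes the layout `<log_dir_base>/<UUID>/checkpoints/*.ckpt`, inferred from the paths the training script prints; `find_checkpoints` is a hypothetical helper.

```python
from pathlib import Path


def find_checkpoints(log_dir_base, uuid):
    """Return the sorted checkpoint file names for a model UUID (assumed directory layout)."""
    ckpt_dir = Path(log_dir_base) / uuid / "checkpoints"
    # glob yields nothing if the directory does not exist, so the result is an empty list.
    return sorted(p.name for p in ckpt_dir.glob("*.ckpt"))


# Replace the UUID with the one printed by your own training run.
print(find_checkpoints("output/training_logs", "En8Ux9wH"))
```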

In [4]:
!python code/train_source_model.py @args/pretrain_avgfp_local.txt --max_epochs 5 --limit_train_batches 5 --limit_val_batches 5 --limit_test_batches 5


Global seed set to 1
Created model UUID: En8Ux9wH
Created log directory: output/training_logs/En8Ux9wH
Final UUID: En8Ux9wH
Final log directory: output/training_logs/En8Ux9wH
This is version: 0
Version-specific logs will be saved to: output/training_logs/En8Ux9wH/version_0
No checkpoint found, training from scratch
Using example_input_array with pdb_fn='1gfl_cm.pdb' and aa_seq_len=237
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Loading `train_dataloader` to estimate number of stepping batches.
  rank_zero_warn(
Number of training steps is 25
Number of warmup steps is 0.5

# Resuming training
If you need to train your model for more epochs, you can adjust `--max_epochs` and resume training by providing the existing UUID of the model. Here, we provide the previous UUID `En8Ux9wH` and change `--max_epochs` to 7 to train for an additional 2 epochs.

In [5]:
!python code/train_source_model.py @args/pretrain_avgfp_local.txt --uuid En8Ux9wH --max_epochs 7 --limit_train_batches 5 --limit_val_batches 5 --limit_test_batches 5


Global seed set to 1
User gave model UUID: En8Ux9wH
Found existing log directory corresponding to given UUID: output/training_logs/En8Ux9wH
Final UUID: En8Ux9wH
Final log directory: output/training_logs/En8Ux9wH
This is version: 1
Version-specific logs will be saved to: output/training_logs/En8Ux9wH/version_1
Found checkpoint, resuming training from: output/training_logs/En8Ux9wH/checkpoints/last.ckpt
Using example_input_array with pdb_fn='1gfl_cm.pdb' and aa_seq_len=237
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
Restoring states from the checkpoint path at output/training_logs/En8Ux9wH/checkpoints/last