# Inference with METL models
This notebook shows how to run inference with METL models trained in this repository.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# define the name of the project root directory
project_root_dir_name = "metl"

# find the project root by checking each parent directory
current_dir = os.getcwd()
while os.path.basename(current_dir) != project_root_dir_name and current_dir != os.path.dirname(current_dir):
    current_dir = os.path.dirname(current_dir)

# change the current working directory to the project root directory
if os.path.basename(current_dir) == project_root_dir_name:
    os.chdir(current_dir)
else:
    print("project root directory not found")
    
# add the project code folder to the system path so imports work
module_path = os.path.abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)

# Using our inference framework

We provide the script [inference.py](../code/inference.py) for running inference with models trained in this repository. It supports many of the same arguments and datamodule capabilities used for training the models.

The arguments `--write_interval` and `--batch_write_mode` control how often predictions are saved and in what format.

The `--write_interval` argument can be set to "batch", "epoch", or "batch_and_epoch". With "batch", predictions are saved to disk after each batch. With "epoch", predictions are held in RAM until all data has been processed and are then written to disk. If you have a lot of data and the predictions might not fit in RAM, set `--write_interval` to "batch" (the default).
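To make the difference concrete, here is a toy simulation of when predictions reach disk under "batch" versus "epoch". It is an illustration under stated assumptions, not the code in inference.py: the batch generator and the list that stands in for disk are hypothetical, and "batch_and_epoch" is not modeled.

```python
import numpy as np


def fake_prediction_batches(num_batches=3, batch_size=4, num_outputs=2, seed=0):
    """Hypothetical stand-in for the per-batch predictions a model would produce."""
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        yield rng.normal(size=(batch_size, num_outputs))


def simulate_writes(write_interval):
    """Return the arrays that would be written to disk, in write order.

    "batch": each batch is written as soon as it is ready, so nothing accumulates.
    "epoch": every batch is buffered in RAM and a single array is written at the end.
    """
    writes = []    # stands in for files written to disk
    buffered = []  # stands in for predictions held in RAM
    for batch_preds in fake_prediction_batches():
        if write_interval == "batch":
            writes.append(batch_preds)    # flushed immediately
        elif write_interval == "epoch":
            buffered.append(batch_preds)  # held until all batches are done
        else:
            raise ValueError(f"unsupported write_interval: {write_interval!r}")
    if buffered:
        writes.append(np.concatenate(buffered, axis=0))
    return writes


print([w.shape for w in simulate_writes("batch")])
print([w.shape for w in simulate_writes("epoch")])
```

Both modes produce the same numbers; what differs is that the "epoch" buffer grows with the dataset, which is why "batch" is the safer choice for large inputs.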

The `--batch_write_mode` argument can be set to "combined_csv", "separate_csv", or "separate_npy". With "combined_csv", there is a single output CSV file, which is appended to after each batch is processed. With "separate_csv" or "separate_npy", each batch is written to its own output file in .csv or .npy format, respectively.
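The sketch below mimics the three output layouts with NumPy. The `write_batch` helper and the file names it uses (`predictions.csv`, `batch_<i>.csv`, `batch_<i>.npy`) are hypothetical; the real script's file naming may differ.

```python
import os
import tempfile

import numpy as np


def write_batch(preds, batch_idx, out_dir, batch_write_mode):
    """Write one batch of predictions using one of three layouts (illustrative names)."""
    if batch_write_mode == "combined_csv":
        # one shared CSV file; each batch is appended to the end
        with open(os.path.join(out_dir, "predictions.csv"), "a") as f:
            np.savetxt(f, preds, delimiter=",")
    elif batch_write_mode == "separate_csv":
        # one CSV file per batch
        np.savetxt(os.path.join(out_dir, f"batch_{batch_idx}.csv"), preds, delimiter=",")
    elif batch_write_mode == "separate_npy":
        # one binary .npy file per batch
        np.save(os.path.join(out_dir, f"batch_{batch_idx}.npy"), preds)
    else:
        raise ValueError(f"unsupported batch_write_mode: {batch_write_mode!r}")


def demo(batch_write_mode, num_batches=3):
    """Write toy batches into a temporary directory and return the resulting file names."""
    batches = [np.full((4, 2), float(i)) for i in range(num_batches)]
    with tempfile.TemporaryDirectory() as out_dir:
        for i, preds in enumerate(batches):
            write_batch(preds, i, out_dir, batch_write_mode)
        return sorted(os.listdir(out_dir))


for mode in ("combined_csv", "separate_csv", "separate_npy"):
    print(mode, demo(mode))
```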

## Source model example
This repository contains a sample GFP Rosetta dataset and a pretrained METL-Local GFP source model, which we can use as examples. 

We specify the following arguments:

| Argument               | Description                                                | Value                                      |
|:------------------------|:------------------------------------------------------------|:--------------------------------------------|
| `pretrained_ckpt_path` | Path to the pretrained model checkpoint                    | `pretrained_models/Hr4GNHws.pt`            |
| `dataset_type`         | Type of dataset being used (rosetta or dms)                                | `rosetta`                                  |
| `ds_fn`                | Path to the database file for the dataset                  | `data/rosetta_data/avgfp/avgfp.db`         |
| `batch_size`           | Batch size used during inference               | `512`                                     |

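The inference script automatically saves its output under the `output/inference` directory; the exact output directory and file path are printed when the run starts. With the default `--write_interval` of "batch", predictions are written to disk as each batch is processed.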
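If you launch inference from Python rather than from a shell cell, it can help to keep the arguments from the table in a dictionary. The helper below is hypothetical (it is not part of the repository); it only turns keyword arguments into `--name=value` flags in the same style as the command in the next cell.

```python
import shlex


def build_inference_cmd(script="code/inference.py", **args):
    """Build an argument list for the inference script from keyword arguments."""
    cmd = ["python", script]
    cmd += [f"--{name}={value}" for name, value in args.items()]
    return cmd


cmd = build_inference_cmd(
    pretrained_ckpt_path="pretrained_models/Hr4GNHws.pt",
    dataset_type="rosetta",
    ds_fn="data/rosetta_data/avgfp/avgfp.db",
    batch_size=512,
)
print(shlex.join(cmd))
# To run it from the project root: subprocess.run(cmd, check=True)
```

Keeping the arguments in a dictionary makes it easy to reuse them, for example by adding `split_dir` and `predict_mode` for a test-set run.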

In [3]:
!python code/inference.py --pretrained_ckpt_path=pretrained_models/Hr4GNHws.pt --dataset_type=rosetta --ds_fn=data/rosetta_data/avgfp/avgfp.db --batch_size=512

Using example_input_array with pdb_fn='1gfl_cm.pdb' and aa_seq_len=237
Output directory: output/inference/Hr4GNHws/rosetta_avgfp/full_dataset
Writing predictions to output/inference/Hr4GNHws/rosetta_avgfp/full_dataset/predictions.npy
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Predicting DataLoader 0: 100%|██████████████████| 20/20 [00:08<00:00,  2.41it/s]


By default, the script computes predictions for the full dataset. If you only need predictions for a particular train, validation, or test set, set the `--split_dir` and `--predict_mode` arguments. The command below computes predictions for the test set only.

In [4]:
!python code/inference.py --pretrained_ckpt_path=pretrained_models/Hr4GNHws.pt --dataset_type=rosetta --ds_fn=data/rosetta_data/avgfp/avgfp.db --batch_size=512 --split_dir=data/rosetta_data/avgfp/splits/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991 --predict_mode=test

Using example_input_array with pdb_fn='1gfl_cm.pdb' and aa_seq_len=237
Output directory: output/inference/Hr4GNHws/rosetta_avgfp/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991/test
Writing predictions to output/inference/Hr4GNHws/rosetta_avgfp/standard_tr0.8_tu0.1_te0.1_w1aea30517f4f_r4991/test/predictions.npy
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Predicting DataLoader 0: 100%|████████████████████| 2/2 [00:01<00:00,  1.10it/s]


## Target (finetuned) model example
We first need to finetune a model using experimental data. Run the command below, which finetunes the pretrained model above on the GFP experimental dataset. Note that we manually specify the UUID `examplemodel` for this model. See the [finetuning.ipynb](finetuning.ipynb) notebook for more details.

In [5]:
!python code/train_target_model.py @args/finetune_avgfp_local.txt --enable_progress_bar false --enable_simple_progress_messages --max_epochs 50 --unfreeze_backbone_at_epoch 25 --uuid examplemodel  

Random seed not specified, using: 855922268
Global seed set to 855922268
User gave model UUID: examplemodel
Did not find existing log directory corresponding to given UUID: examplemodel
Created log directory: output/training_logs/examplemodel
Final UUID: examplemodel
Final log directory: output/training_logs/examplemodel
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Loading `train_dataloader` to estimate number of stepping batches.
Number of training steps is 50
Number of warmup steps is 0.5
Second warmup phase starts at step 25
total_steps 50
phase1_total_steps 25
phase2_total_steps 25

We can now run inference with this finetuned model using the [inference.py](../code/inference.py) script.

| Argument               | Description                                                              | Value                                                                 |
|:-----------------------|:-------------------------------------------------------------------------|:----------------------------------------------------------------------|
| `pretrained_ckpt_path` | Path to the model checkpoint (here, the finetuned model)                 | `output/training_logs/examplemodel/checkpoints/epoch=49-step=50.ckpt` |
| `dataset_type`         | Type of dataset being used (rosetta or dms)                              | `dms`                                                                 |
| `ds_name`              | Name of the predefined dataset to use                                    | `avgfp`                                                               |
| `encoding`             | Input encoding method (use `int_seqs` for transformer-based METL models) | `int_seqs`                                                            |
| `predict_mode`         | Which data to predict on (`full_dataset` covers the whole dataset)       | `full_dataset`                                                        |
| `batch_size`           | Batch size used during inference                                         | `512`                                                                 |

In [6]:
!python code/inference.py --pretrained_ckpt_path=output/training_logs/examplemodel/checkpoints/epoch=49-step=50.ckpt --dataset_type=dms --ds_name=avgfp --encoding=int_seqs --predict_mode full_dataset --batch_size 512 

Output directory: output/inference/examplemodel/dms_avgfp
Writing predictions to output/inference/examplemodel/dms_avgfp/predictions.npy
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Predicting DataLoader 0: 100%|████████████████| 102/102 [00:39<00:00,  2.59it/s]


# Using your own inference loop
If you prefer more control and want to run your own inference loop, we provide easy-to-use functions to load models and encode data. The [inference.py](../code/inference.py) file contains the functions to load models, and the [encode.py](../code/encode.py) file contains the functions to encode data.

In [7]:
import torch

import inference
import encode as enc
import utils  # for loading dataset metadata

First, let's load the GFP wild-type sequence, its offset, and the PDB filename from the predefined dataset metadata. This information is needed to encode variants in the correct format and to use models with 3D relative position embeddings (RPE).

In [8]:
datasets = utils.load_dataset_metadata()
wt = datasets["avgfp"]["wt_aa"]
wt_offset = datasets["avgfp"]["wt_ofs"]
pdb_fn = datasets["avgfp"]["pdb_fn"]

You can load a model using `inference.load_pytorch_module()`. It supports both source and target models, in either PyTorch Lightning's `.ckpt` format or regular PyTorch `.pt` format. You can use keyword arguments to override any hyperparameters that may be stored in the checkpoint.

In [9]:
source_model = inference.load_pytorch_module("pretrained_models/Hr4GNHws.pt")
# for the target model, use the example model we trained above
target_model = inference.load_pytorch_module("output/training_logs/examplemodel/checkpoints/epoch=49-step=50.ckpt")



You might get a warning about transforming checkpoint keys. This transformation happens automatically. It is necessary when loading a PyTorch Lightning checkpoint because our PyTorch Lightning module wraps the model, so the resulting checkpoint has an additional prefix in its state dictionary keys. If you want to avoid the warning, convert the saved checkpoint with [convert_ckpt.py](../code/convert_ckpt.py); otherwise, you can safely ignore it.
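The key transformation amounts to stripping a prefix from each state-dict key. Here is a minimal sketch, assuming the prefix is `model.`. That prefix is a placeholder, so inspect the keys of your own checkpoint (PyTorch Lightning `.ckpt` files store them under `"state_dict"`) before relying on it. This is not the code used by inference.py or convert_ckpt.py.

```python
from collections import OrderedDict


def strip_key_prefix(state_dict, prefix="model."):
    """Return a copy of `state_dict` with `prefix` removed from keys that start with it.

    The default prefix is a hypothetical placeholder; keys without it are kept unchanged.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Toy example: a real state dict maps parameter names to tensors.
wrapped = {"model.embedding.weight": 0, "model.head.bias": 1, "extra_key": 2}
print(list(strip_key_prefix(wrapped)))
```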

Let's define some variants to feed through the models. Note that these variants use 0-based indexing. You can use 1-based indexing instead, but be sure to specify the matching type of indexing in the encode function below. The default is 0-based indexing.
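If your variants come with 1-based positions, you can either specify the matching indexing when encoding or convert the strings yourself. Below is a minimal sketch of the conversion, assuming each mutation is written as wild-type residue, position, new residue (e.g. `E3K`). `shift_variant_indexing` is a hypothetical helper, not part of encode.py.

```python
import re

# One mutation: wild-type residue, position, substituted residue (e.g. "E3K").
MUTATION_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def shift_variant_indexing(variant, shift):
    """Shift every position in a comma-separated variant string by `shift`.

    Use shift=+1 to go from 0-based to 1-based positions and shift=-1 for the reverse.
    """
    shifted = []
    for mutation in variant.split(","):
        match = MUTATION_RE.match(mutation.strip())
        if match is None:
            raise ValueError(f"unrecognized mutation: {mutation!r}")
        wt_aa, pos, mut_aa = match.groups()
        shifted.append(f"{wt_aa}{int(pos) + shift}{mut_aa}")
    return ",".join(shifted)


print(shift_variant_indexing("E3K,G102S", +1))  # E4K,G103S
```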

In [10]:
variants = ["E3K,G102S",
            "T36P,S203T,K207R"]

You can encode variants using the `enc.encode()` function. The correct encoding for transformer-based METL models is "int_seqs". If you trained a custom model with a different encoding, such as one-hot encoding, you would specify that instead.

In [11]:
encoded_variants = enc.encode(
    encoding="int_seqs",
    variants=variants,
    wt_aa=wt,
    wt_offset=wt_offset,
    indexing="0_indexed"
)
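Conceptually, encoding a variant for a sequence model starts from the full mutant sequence. The sketch below shows one way to build it from the wild type, with a check that each listed wild-type residue matches. `apply_variant` and its `zero_based` flag are hypothetical, and the sketch does not model `wt_offset`, so treat it as an illustration rather than a substitute for `enc.encode()`.

```python
def apply_variant(wt_seq, variant, zero_based=True):
    """Return the mutant sequence for a comma-separated variant such as "E3K,G102S"."""
    shift = 0 if zero_based else 1
    seq = list(wt_seq)
    for mutation in variant.split(","):
        mutation = mutation.strip()
        wt_aa, pos, mut_aa = mutation[0], int(mutation[1:-1]) - shift, mutation[-1]
        if not 0 <= pos < len(wt_seq):
            raise ValueError(f"position out of range in {mutation!r}")
        # guard against variants written for a different sequence or indexing
        if wt_seq[pos] != wt_aa:
            raise ValueError(f"{mutation!r}: wild type has {wt_seq[pos]!r} at position {pos}")
        seq[pos] = mut_aa
    return "".join(seq)


# Toy wild-type sequence for illustration (not the full GFP sequence)
print(apply_variant("SKGEELFTG", "E3K,T7A"))  # SKGKELFAG
```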

You can also encode full sequences instead of variants by specifying `char_seqs`. In that case, there would be no need to specify `variants`, `wt_aa`, `wt_offset`, or `indexing`. The cell below shows an example.  

In [12]:
full_seqs = ["SMART", "MAGIC"]  # sample amino acid sequences
encoding_example = enc.encode(encoding="int_seqs", char_seqs=full_seqs)
print(encoding_example)

[[16 11  1 15 17]
 [11  1  6  8  2]]
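The integers above are consistent with a simple lookup in which the 20 standard amino acids, in alphabetical order of their one-letter codes, map to 1 through 20. The sketch below reproduces that output. The mapping is inferred from these two examples rather than read from encode.py, so treat it as an assumption and use `enc.encode()` for real work.

```python
import numpy as np

# Assumption inferred from the printed example: A=1, C=2, ..., Y=20 (0 is unused here).
AA_ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INT = {aa: i + 1 for i, aa in enumerate(AA_ALPHABET)}


def int_encode(seqs):
    """Integer-encode equal-length amino acid sequences into a 2D array."""
    if len({len(s) for s in seqs}) != 1:
        raise ValueError("this sketch only handles sequences of equal length")
    return np.array([[AA_TO_INT[aa] for aa in s] for s in seqs])


print(int_encode(["SMART", "MAGIC"]))
```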


Finally, we can run inference by calling the model with our encoded variants.

In [13]:
# set model to eval mode
source_model.eval()

# no need to compute gradients for inference
with torch.no_grad():
    # note we are specifying the pdb_fn because this model uses 3D RPE
    predictions = source_model(torch.tensor(encoded_variants), pdb_fn=pdb_fn)

print(predictions)

tensor([[ 0.2894,  0.1854, -0.5471, -0.0754, -0.3709,  0.1079, -0.4791,  0.2374,
          0.1379,  0.9030,  0.3907,  0.5771,  0.3447,  0.3692,  0.4965, -0.4149,
          0.1715,  0.1173,  0.1156, -0.2475,  0.0904,  0.1284,  1.1474,  0.8472,
          0.3155,  0.5036,  0.5245,  0.4521, -0.8744,  0.2048,  0.5267,  0.5939,
         -0.3658, -0.0320, -0.1717,  0.2009,  1.0826, -0.0399,  0.3710,  0.3503,
          0.3202,  0.5267, -0.0226,  1.1644, -0.1571, -1.3753,  0.4995, -0.9204,
          0.1762,  0.8513,  0.3808, -0.9562, -0.0644, -0.1857, -0.2437],
        [-0.1340, -0.3133, -0.8069, -0.0609, -0.0690, -0.5090,  1.5713, -0.6018,
          1.5414, -0.5417,  0.0078,  0.2171,  1.4525,  0.1137, -0.0883,  0.8100,
          0.1789, -0.1843, -0.2339, -0.0281,  0.0298, -0.2055, -0.5315,  0.9129,
          0.0417, -0.1724,  0.2707, -0.7365, -0.2668,  1.3744,  0.5445, -0.0232,
          0.0581, -0.1810,  0.0594, -0.4147,  0.4928, -0.0676, -0.9483,  0.2872,
          0.3291,  0.5446, -0.8568, 
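Calling the model on the whole tensor at once, as above, is fine for two variants. For thousands of variants, you may want to process them in chunks to bound memory use. The helper below is a hypothetical sketch, written against a generic `predict_fn` so it runs without a model; the docstring shows how you might wrap one of the METL models loaded above.

```python
import numpy as np


def predict_in_chunks(predict_fn, encoded, chunk_size=512):
    """Apply `predict_fn` to consecutive row chunks of `encoded` and stack the results.

    `predict_fn` takes a 2D integer array and returns a 2D array of predictions.
    With a METL model you could use, for example (after calling model.eval() once):
        def predict_fn(x):
            with torch.no_grad():
                return source_model(torch.tensor(x), pdb_fn=pdb_fn).cpu().numpy()
    """
    encoded = np.asarray(encoded)
    if len(encoded) == 0:
        raise ValueError("no sequences to predict")
    outputs = [
        np.asarray(predict_fn(encoded[start:start + chunk_size]))
        for start in range(0, len(encoded), chunk_size)
    ]
    return np.concatenate(outputs, axis=0)


# Toy usage with a stand-in "model" that sums each row
toy_data = np.arange(20).reshape(10, 2)
print(predict_in_chunks(lambda x: x.sum(axis=1, keepdims=True), toy_data, chunk_size=3).ravel())
```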

The source model (above) outputs predictions for each Rosetta energy term. The target model (below), which was finetuned on the experimental GFP dataset, outputs a functional score prediction (brightness).

In [14]:
# set model to eval mode
target_model.eval()

# no need to compute gradients for inference
with torch.no_grad():
    # note we are specifying the pdb_fn because this model uses 3D RPE
    predictions = target_model(torch.tensor(encoded_variants), pdb_fn=pdb_fn)

print(predictions)

tensor([[-0.4834],
        [-0.4877]])
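A common next step is to pair each variant with its predicted score and sort. The sketch below uses the two scores printed above and assumes that a higher score means higher predicted brightness. The two scores are nearly identical, so this ordering says little on its own; the point is only the mechanics. `rank_variants` is a hypothetical helper.

```python
def rank_variants(variants, scores):
    """Pair each variant with its predicted score, highest score first."""
    if len(variants) != len(scores):
        raise ValueError("variants and scores must have the same length")
    return sorted(zip(variants, scores), key=lambda pair: pair[1], reverse=True)


# Scores copied from the target-model output above (one value per variant).
# With the live tensor you could use: scores = predictions.squeeze(-1).tolist()
scores = [-0.4834, -0.4877]
for variant, score in rank_variants(["E3K,G102S", "T36P,S203T,K207R"], scores):
    print(f"{variant}\t{score:.4f}")
```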
