# Using trained models to make predictions for new variants

This notebook shows how to use trained models to make predictions for new variants.

Prerequisites
- A trained model. You can use the pre-trained models we provide in the `pub/trained_models` directory, train your own models similar to ours using the arguments in the `pub/regression_args` directory, or train your own models using your preferred arguments.

Main steps
- Encode variants with a combination of one-hot and AAindex encoding.
- Use the trained models to get predictions for those variants.

In [1]:
# reload modules before executing code in order to make development and debugging easier
%load_ext autoreload
%autoreload 2

In [2]:
# this jupyter notebook is running inside of the "notebooks" directory
# for relative paths to work properly, we need to set the current working directory to the root of the project
# for imports to work properly, we need to add the code folder to the system path
import os
from os.path import abspath, join, isdir
import sys
if not isdir("notebooks"):
    # if there's a "notebooks" directory in the cwd, we've already set the cwd so no need to do it again
    os.chdir("..")
module_path = abspath("code")
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
import numpy as np
import constants
import utils
import encode as enc
import inference as inf

# Encode variants
This example uses a few avGFP variants. For more detailed information on how to encode variants, see the data encoding notebook.

In [4]:
variants = ["Y64C,E170V", "I126T,N210H", "E15V,D17G,I169F", "A108G"]
# specifying "ds_name" only works if the dataset is defined in constants.py
# alternatively, you can specify the wild-type sequence and offset
# see the encoding notebook for details
encoded_variants = enc.encode(encoding="one_hot,aa_index", variants=variants, ds_name="avgfp")
encoded_variants.shape

(4, 237, 40)
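The variant strings above appear to follow a "wild-type residue, position, replacement residue" pattern, with mutations separated by commas. The following is a minimal sketch of how such strings could be parsed and applied to a sequence. `parse_variant`, `apply_variant`, the `offset` parameter, and the toy sequence are hypothetical and are not part of the repository's `encode` module, whose actual parsing and indexing may differ.

```python
import re

# Hypothetical helpers for illustration; not the repository's encode module.
MUTATION_RE = re.compile(r"^([A-Z*])(\d+)([A-Z*])$")


def parse_variant(variant):
    """Split a string such as "Y64C,E170V" into (wild_type, position, replacement) tuples."""
    mutations = []
    for token in variant.split(","):
        match = MUTATION_RE.match(token.strip())
        if match is None:
            raise ValueError(f"unrecognized mutation: {token!r}")
        wild_type, position, replacement = match.groups()
        mutations.append((wild_type, int(position), replacement))
    return mutations


def apply_variant(wt_seq, variant, offset=0):
    """Return the mutated sequence, checking each wild-type residue against wt_seq.

    Assumption: position minus offset is a 0-based index into wt_seq.
    """
    seq = list(wt_seq)
    for wild_type, position, replacement in parse_variant(variant):
        idx = position - offset
        if seq[idx] != wild_type:
            raise ValueError(f"expected {wild_type} at position {position}, found {seq[idx]}")
        seq[idx] = replacement
    return "".join(seq)


toy_wt = "ACDEFGHIK"  # artificial sequence, for illustration only
print(parse_variant("Y64C,E170V"))       # [('Y', 64, 'C'), ('E', 170, 'V')]
print(apply_variant(toy_wt, "C1A,F4W"))  # AADEWGHIK
```

Checking the wild-type residue before substituting catches off-by-one errors in the position offset early, which is the most likely mistake when supplying your own wild-type sequence.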

# Use a pre-trained model to get predictions for these variants
Each saved model consists of three files (meta, index, data) because of how TensorFlow saves checkpoints. All three files of a given model share the same prefix. This example uses the avGFP linear regression model.

In [5]:
lr_prefix = "pub/trained_models/avgfp/avgfp_lr"  # just the prefix, no file extension
lr_predictions = inf.run_inference(encoded_data=encoded_variants, ckpt_prefix_fn=lr_prefix)
lr_predictions

W0910 19:33:12.652569 4449762752 deprecation.py:323] From /Users/sg/miniconda3/envs/rad_mutants/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


array([2.688005 , 3.7444515, 2.8850722, 3.215506 ], dtype=float32)
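Because a model is identified by a prefix rather than a single file, a wrong prefix can be hard to spot. The sketch below checks that the meta, index, and data files exist for a given prefix before attempting inference. `missing_checkpoint_parts` is a hypothetical helper, not part of `inference.py`, and it assumes the data file carries a shard suffix such as `data-00000-of-00001`, which is how TensorFlow usually names it.

```python
import glob
import os
import tempfile


def missing_checkpoint_parts(ckpt_prefix):
    """Return which of the meta / index / data files are absent for a prefix."""
    paths = glob.glob(glob.escape(ckpt_prefix) + ".*")
    # keep only what follows "<prefix>." in each matching path
    suffixes = [path[len(ckpt_prefix) + 1:] for path in paths]
    return [
        part for part in ("meta", "index", "data")
        # data files usually carry a shard suffix, e.g. "data-00000-of-00001"
        if not any(s == part or s.startswith(part + "-") for s in suffixes)
    ]


# demonstration on empty placeholder files in a temporary directory
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_prefix = os.path.join(tmp_dir, "demo_model")
    for suffix in ("meta", "index", "data-00000-of-00001"):
        open(demo_prefix + "." + suffix, "w").close()
    complete = missing_checkpoint_parts(demo_prefix)
    os.remove(demo_prefix + ".index")
    incomplete = missing_checkpoint_parts(demo_prefix)

print(complete)    # []
print(incomplete)  # ['index']
```

In this notebook, the same check could be run as `missing_checkpoint_parts(lr_prefix)` before calling `inf.run_inference()`.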

# Predictions for lots of variants (batches)
If you want to make predictions for more than 64 variants at a time, `inf.run_inference()` automatically breaks the input into batches of size 64. You can change the batch size with the `batch_size` argument. A progress bar shows progress through the batches.

In [6]:
lots_of_variants = enc.encode(encoding="one_hot,aa_index", variants=["Y64C,E170V"] * 200, ds_name="avgfp")
preds = inf.run_inference(encoded_data=lots_of_variants, ckpt_prefix_fn=lr_prefix, batch_size=32)

100%|██████████| 7/7 [00:00<00:00, 149.98it/s]
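The progress bar above reports 7 batches because 200 variants split into batches of 32 leave a final, smaller batch of 8. The following sketch reproduces that arithmetic. `batch_slices` is a hypothetical helper and only illustrates the index ranges; how `inf.run_inference()` batches internally is not shown in this notebook.

```python
def batch_slices(num_examples, batch_size=64):
    """Yield (start, end) index pairs that cover num_examples in order."""
    for start in range(0, num_examples, batch_size):
        # the last batch is shorter when batch_size does not divide num_examples
        yield start, min(start + batch_size, num_examples)


slices = list(batch_slices(200, batch_size=32))
print(len(slices))  # 7
print(slices[-1])   # (192, 200)
```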


# Repeated predictions (single session)
If you need to run inference many times in a loop, the code above is inefficient because it restores the TensorFlow model on each call to `inf.run_inference()`.
Instead, restore the model once into a single TensorFlow session and reuse that session in the loop.


In [7]:
# open the session
lr_sess = inf.restore_sess(lr_prefix)

# run inference many times in a loop
for i in range(3):
    display(inf.run_inference(encoded_data=encoded_variants, sess=lr_sess))

# close the session when you're done
lr_sess.close()

array([2.688005 , 3.7444515, 2.8850722, 3.215506 ], dtype=float32)

array([2.688005 , 3.7444515, 2.8850722, 3.215506 ], dtype=float32)

array([2.688005 , 3.7444515, 2.8850722, 3.215506 ], dtype=float32)
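If an exception is raised inside the loop, the explicit `close()` call at the end is never reached. One way to guard against that is `contextlib.closing`, which calls `close()` on exit even when an error occurs. The sketch below uses `FakeSession`, a hypothetical stand-in that only tracks whether it has been closed, so that the example runs without TensorFlow or a checkpoint.

```python
from contextlib import closing


class FakeSession:
    """Hypothetical stand-in for a restored session; it only tracks open/closed state."""

    def __init__(self):
        self.closed = False

    def run(self, batch):
        if self.closed:
            raise RuntimeError("session is closed")
        return [len(item) for item in batch]

    def close(self):
        self.closed = True


sess = FakeSession()
try:
    with closing(sess):
        for i in range(3):
            sess.run(["abc"])
        raise ValueError("simulated failure inside the loop")
except ValueError:
    pass

print(sess.closed)  # True: close() ran despite the exception
```

Assuming the object returned by `inf.restore_sess()` only needs its `close()` method called, the same pattern would read `with closing(inf.restore_sess(lr_prefix)) as lr_sess:` followed by the inference loop.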

You can have multiple sessions open at the same time.

In [8]:
# restore one session per model
lr_sess = inf.restore_sess(lr_prefix)
cnn_prefix = "pub/trained_models/avgfp/avgfp_cnn"
cnn_sess = inf.restore_sess(cnn_prefix)

# run both models on the same encoded variants
for i in range(3):
    print("LR:", inf.run_inference(encoded_data=encoded_variants, sess=lr_sess))
    print("CNN:", inf.run_inference(encoded_data=encoded_variants, sess=cnn_sess))

# close both sessions when you're done
lr_sess.close()
cnn_sess.close()

LR: [2.688005  3.7444515 2.8850722 3.215506 ]
CNN: [1.4410194 3.88649   2.8706326 3.687386 ]
LR: [2.688005  3.7444515 2.8850722 3.215506 ]
CNN: [1.4410194 3.88649   2.8706326 3.687386 ]
LR: [2.688005  3.7444515 2.8850722 3.215506 ]
CNN: [1.4410194 3.88649   2.8706326 3.687386 ]
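The prediction arrays are returned in the same order as the input variants, so they can be paired back up for comparison. The sketch below does this with the values printed above, copied in by hand so that it runs without the models. `rank_variants` is a hypothetical helper, and sorting by score is only one possible way to compare the two models.

```python
variants = ["Y64C,E170V", "I126T,N210H", "E15V,D17G,I169F", "A108G"]
# values copied from the printed output above
lr_preds = [2.688005, 3.7444515, 2.8850722, 3.215506]
cnn_preds = [1.4410194, 3.88649, 2.8706326, 3.687386]


def rank_variants(variants, scores):
    """Return variants sorted from highest to lowest predicted score."""
    paired = sorted(zip(variants, scores), key=lambda pair: pair[1], reverse=True)
    return [variant for variant, _ in paired]


lr_ranking = rank_variants(variants, lr_preds)
cnn_ranking = rank_variants(variants, cnn_preds)
print(lr_ranking)                 # ['I126T,N210H', 'A108G', 'E15V,D17G,I169F', 'Y64C,E170V']
print(lr_ranking == cnn_ranking)  # True
```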
