# Zero-shot mutant prediction with Prime

This tutorial demonstrates how to predict the effect of mutations on a protein using a pretrained Prime model.

In this example, we load the basic Prime model (not tuned on homologous sequences) and predict the effects of the mutations in the GAL4_YEAST_Kitzman_2015 experiment.

We provide:

- The wild-type sequence, as a FASTA file.
- The mutant list, as a CSV file.

Goal: obtain a predicted score for each mutant.


## Config for imports

In [1]:
import sys
sys.path.append('..')

## Import the necessary modules

In [2]:
import torch
import pandas as pd
from prime.utils import read_seq, score_mutant
from prime.model import Config, ZeroShot
from tqdm.notebook import tqdm


## Read wildtype sequence and mutant list

In [3]:
wild_type = "../protein_gym_data/fasta/GAL4_YEAST_Kitzman_2015.fasta"
mutant = "../protein_gym_data/mutant/GAL4_YEAST_Kitzman_2015.csv"

In [4]:
sequence = read_seq(wild_type)
df = pd.read_csv(mutant)
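The implementation of `read_seq` is not shown in this tutorial. As a rough illustration of what reading a single-record FASTA file involves, here is a minimal sketch; the helper name `read_fasta_sequence` is hypothetical and is not part of the Prime package.

```python
from io import StringIO


def read_fasta_sequence(handle):
    """Return the first record's sequence from an open FASTA handle as one string.

    Hypothetical helper for illustration only; not Prime's read_seq.
    """
    parts = []
    seen_header = False
    for line in handle:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            if seen_header:
                break  # a second header starts the next record; stop here
            seen_header = True
            continue
        parts.append(line)  # sequence lines may be wrapped, so join them
    return "".join(parts)


# Toy input: the first 20 residues of the GAL4 sequence, wrapped over two lines.
example = StringIO(">GAL4_YEAST\nMKLLSSIEQA\nCDICRLKKLK\n")
print(read_fasta_sequence(example))  # MKLLSSIEQACDICRLKKLK
```

With a real file, the same function could be called inside `with open(path) as handle:`.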

In [5]:
sequence

'MKLLSSIEQACDICRLKKLKCSKEKPKCAKCLKNNWECRYSPKTKRSPLTRAHLTEVESRLERLEQLFLLIFPREDLDMILKMDSLQDIKALLTGLFVQDNVNKDAVTDRLASVETDMPLTLRQHRISATSSSEESSNKGQRQLTVSIDSAAHHDNSTIPLDFMPRDALHGFDWSEEDDMSDGLPFLKTDPNNNGFFGDGSLLCILRSIGFKPENYTNSNVNRLPTMITDRYTLASRSTTSRLLQSYLNNFHPYCPIVHSPTLMMLYNNQIEIASKDQWQILFNCILAIGAWCIEGESTDIDVFYYQNAKSHLTSKVFESGSIILVTALHLLSRYTQWRQKTNTSYNFHSFSIRMAISLGLNRDLPSSFSDSSILEQRRRIWWSVYSWEIQLSLLYGRSIQLSQNTISFPSSVDDVQRTTTGPTIYHGIIETARLLQVFTKIYELDKTVTAEKSPICAKKCLMICNEIEEVSRQAPKFLQMDISTTALTNLLKEHPWLSFTRFELKWKQLSLIIYVLRDFFTNFTQKKSQLEQDQNDHQSYEVKRCSIMLSDAAQRTVMSVSSYMDNHNVTPYFAWNCSYYLFNAVLVPIKTLLSNSKSNAENNETAQLLQQINTVLMLLKKLATFKIQTCEKYIQVLEEVCAPFLLSQCAIPLPHISYNNSNGSAIKNIVGSATIAQYPTLPEENVNNISVKYVSPGSVGPSPVPLKSGASFSDLVKLLSNRPPSRNSPVTIPRSTPSHRSVTPFLGQQQQLQSLVPLTPSALFGGANFNQSGNIADSSLSFTFTNSSNGPNLITTQTNSQALSQPIASSNVHDNFMNNEITASKIDDGNNSKPLSPGWTDQTAYNAFGITTGMFNTTTMDDVYNYLFDDEDTPPNPKKE'

In [6]:
df.head()

Unnamed: 0,mutant,score
0,K2R,-0.98
1,K2Y,0.13
2,K2W,-11.64
3,K2T,3.73
4,K2S,2.46


The 'score' column is the experimental score of each mutant. It is used only for evaluation. In practice it is not available, and the 'mutant' column alone is enough.
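Entries in the 'mutant' column such as `K2R` follow the usual substitution notation: wild-type residue, position, mutated residue. The sketch below parses such a string and checks it against the wild-type sequence. The helper `parse_mutant` is hypothetical, and it assumes that positions are 1-based, which matches `K` being the second residue of the sequence above; it handles single substitutions only.

```python
import re

# One uppercase residue, a position, one uppercase residue (e.g. "K2R").
MUTANT_PATTERN = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def parse_mutant(mutant, sequence, offset=1):
    """Split a substitution string into (wild_type, zero_based_index, mutated).

    Hypothetical helper; `offset` is assumed to be the number of the first
    residue (1 for the usual 1-based numbering).
    """
    match = MUTANT_PATTERN.match(mutant)
    if match is None:
        raise ValueError(f"Unrecognised mutant string: {mutant!r}")
    wild_type, position, mutated = match.group(1), int(match.group(2)), match.group(3)
    index = position - offset
    if not 0 <= index < len(sequence):
        raise ValueError(f"Position {position} is outside the sequence")
    if sequence[index] != wild_type:
        raise ValueError(
            f"{mutant!r} expects {wild_type} at position {position}, "
            f"found {sequence[index]}"
        )
    return wild_type, index, mutated


prefix = "MKLLSSIEQA"  # first ten residues of the GAL4 sequence
print(parse_mutant("K2R", prefix))  # ('K', 1, 'R')
```

Validating the wild-type residue up front catches numbering mismatches between the CSV and the FASTA file before any scoring is done.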

## Load Prime model

In [7]:
model_path = "../checkpoints/prime.pt"

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ZeroShot(Config())
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
model = model.to(device)

### Compute the logits of the wild type sequence

In [9]:
with torch.no_grad():
    sequence_ids = model.tokenize(sequence).to(device)
    attention_mask = torch.ones_like(sequence_ids).to(device)
    logits = model(input_ids=sequence_ids, attention_mask=attention_mask)[0]
    logits = torch.log_softmax(logits, dim=-1)
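The log-softmax step turns the logits at each position into log-probabilities over the vocabulary. A common zero-shot scoring rule for a substitution is the log-probability of the mutated residue minus that of the wild-type residue at the same position. Whether `score_mutant` computes exactly this is an assumption, since its implementation is not shown here. The sketch below uses plain Python on a toy vocabulary; `toy_vocab`, `toy_logits`, and `score_substitution` are hypothetical names, and real model outputs may include special tokens that shift the position index.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one list of logits."""
    peak = max(row)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_z for x in row]


def score_substitution(log_probs, vocab, wild_type, index, mutated):
    """Log-odds of the mutated residue versus the wild type at one position."""
    row = log_probs[index]
    return row[vocab[mutated]] - row[vocab[wild_type]]


# Toy setup: three residues in the vocabulary, two sequence positions.
toy_vocab = {"K": 0, "R": 1, "W": 2}
toy_logits = [
    [2.0, 1.0, -1.0],  # position 0
    [0.5, 1.5, -2.0],  # position 1
]
toy_log_probs = [log_softmax(row) for row in toy_logits]

# K -> R at zero-based position 1: about 1.0, i.e. R is favoured over K.
print(score_substitution(toy_log_probs, toy_vocab, "K", 1, "R"))
```

Under this rule a positive score means the model finds the mutated residue more likely than the wild type at that position, and a negative score means the opposite.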

### Score mutants

In [10]:
scores = []
for mutant in tqdm(df["mutant"]):
    score = score_mutant(
        mutant, sequence, logits=logits, vocab=model.VOCAB, offset=1
    )
    scores.append(score)
df["predict_score"] = scores

### Check the results

In [11]:
df.head()

Unnamed: 0,mutant,score,predict_score
0,K2R,-0.98,0.029868
1,K2Y,0.13,0.147601
2,K2W,-11.64,-0.089627
3,K2T,3.73,0.212623
4,K2S,2.46,0.131806


### Evaluation (optional)

In [12]:
from scipy.stats import spearmanr

In [13]:
spearmanr(df["score"], df["predict_score"])

SignificanceResult(statistic=0.6680778571562757, pvalue=2.445648005445124e-155)
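Spearman correlation compares the ranks of the two columns rather than their raw values, so it only asks whether the predictions order the mutants the same way the experiment does. As an illustration, the sketch below computes it by hand (Pearson correlation on average ranks) for the five rows shown by `df.head()` above. The function names are hypothetical, and five rows are far too few to say anything about the full dataset.

```python
def average_ranks(values):
    """Return 1-based ranks, giving tied values the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = rank
        i = j + 1
    return ranks


def spearman(x, y):
    """Spearman correlation: Pearson correlation of the two rank vectors."""
    rx, ry = average_ranks(x), average_ranks(y)
    n = len(rx)
    mean_x, mean_y = sum(rx) / n, sum(ry) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(rx, ry))
    var_x = sum((a - mean_x) ** 2 for a in rx)
    var_y = sum((b - mean_y) ** 2 for b in ry)
    return cov / (var_x * var_y) ** 0.5


# The five rows from df.head(): experimental score and predicted score.
measured = [-0.98, 0.13, -11.64, 3.73, 2.46]
predicted = [0.029868, 0.147601, -0.089627, 0.212623, 0.131806]
print(round(spearman(measured, predicted), 3))  # 0.9
```

Because only ranks matter, the predicted scores do not need to be on the same scale as the experimental scores.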