In [1]:
%load_ext autoreload
%autoreload 2
import torch
import esm
from Bio import SeqIO
from Bio.Seq import Seq
import pandas as pd
import numpy as np
from numpy import dot
from numpy.linalg import norm
from Shared_Functions import *
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import matplotlib.pyplot as plt
from scipy.special import softmax

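# Wuhan-Hu-1 Spike Epistasis

For each spike mutant, this notebook uses ESM-2 (`esm2_t36_3B_UR50D`) to compare, at every position, the probability of the Wuhan-Hu-1 (NC_045512) reference residue in the mutant against the unmutated reference spike. This shows how a mutation shifts the model's preferences at other sites.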

In [2]:
# Load the ESM-2 model (esm2_t36_3B_UR50D) and its alphabet/tokenizer.
# eval() puts the model in inference mode (disables dropout).
model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model.eval()
batch_converter = alphabet.get_batch_converter()
# Fall back to CPU when no GPU is present, so `device` always matches where the model
# actually lives (it is passed to every embedding call below).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    model = model.to(device)
    print("Transferred model to GPU")

In [3]:
from Bio import Entrez
from Bio import SeqIO

# NCBI asks E-utilities users to identify themselves; this is a placeholder address.
Entrez.email = "sample@example.org"

# Fetch the Wuhan-Hu-1 reference genome (RefSeq NC_045512) as a GenBank flat file.
# EFetch pairs rettype="gb" with retmode="text"; the `with` block closes the network handle.
with Entrez.efetch(db="nucleotide",
                   id="NC_045512",
                   rettype="gb",
                   retmode="text") as handle:
    whole_sequence = SeqIO.read(handle, "genbank")

# Layer count of esm2_t36_3B_UR50D (the "t36" in the model name).
model_layers = 36

In [5]:
# Run ESM-2 over the proteins of the reference GenBank record (helper from Shared_Functions).
# The result is a dict keyed by protein label; 'S:0' (presumably the spike CDS) is used below,
# with 'Sequence', 'Logits' and 'Mean_Embedding' entries.
# NOTE(review): here `device` comes before batch_converter/alphabet, whereas the
# process_protein_sequence call in the mutant loop passes it last — confirm each order matches its signature.
reference_embeddings = process_sequence_genbank(whole_sequence.seq,whole_sequence,model,model_layers,device,batch_converter,alphabet)

In [6]:
# The 20 standard amino acids; this order fixes the column order of every probability table below.
amino_acids = list("ARNDCQEGHILKMFPSTWYV")

In [7]:
# Mutant spikes to score: two single mutants and their double mutant.
# To scan every single substitution in spike instead, use:
# mutated_spikes = DMS_Table(reference_embeddings['S:0']['Sequence'])
reference_spike = reference_embeddings['S:0']['Sequence']
mutation_sets = [['S13I'], ['W152C'], ['S13I', 'W152C']]
mutated_spikes = pd.DataFrame(
    [[','.join(mutations), mutate_sequence(reference_spike, mutations)] for mutations in mutation_sets],
    columns=['Mutations', 'Sequence'],
)
mutated_spikes

Unnamed: 0,Mutations,Sequence
0,S13I,MFVFLVLLPLVSIQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
1,W152C,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
2,"S13I,W152C",MFVFLVLLPLVSIQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...


In [9]:
all_results = {}
for _, spike in mutated_spikes.iterrows():
    result = process_protein_sequence(spike.Sequence, model, model_layers, batch_converter, alphabet, device)
    # Drop the first and last logit rows (presumably the special tokens the ESM batch converter adds
    # around the sequence), label columns by token, keep the 20 amino-acid columns,
    # and index rows by 1-based residue position.
    logits = pd.DataFrame(result['Logits'][1:-1]).set_axis(alphabet.all_toks, axis=1)[amino_acids]
    logits.index = logits.index + 1
    result['Logits'] = logits
    all_results[spike.Mutations] = result

In [10]:
# Reference (unmutated) spike logits: first/last rows removed, columns labelled by token and
# restricted to the 20 amino acids, rows indexed by 1-based residue position.
spike_reference = reference_embeddings['S:0']
reference_logits = pd.DataFrame(spike_reference['Logits'][1:-1]).set_axis(alphabet.all_toks, axis=1)[amino_acids]
reference_logits.index = reference_logits.index + 1
reference_embedding = spike_reference['Mean_Embedding']

In [11]:
# For each mutant, build a per-position table of the change in probability of the
# *reference* residue (mutant minus reference). Probabilities are a softmax over the
# 20 amino-acid logits only. Positive change: the mutant makes the reference residue
# more likely at that position; negative: less likely. Mutated positions themselves
# are set to 0, since the reference residue was replaced there.
reference_sequence = reference_embeddings['S:0']['Sequence']
# The reference distribution is identical for every mutant, so compute it once.
reference_probabilities = pd.DataFrame(softmax(reference_logits[amino_acids].values, 1), columns=amino_acids)
# Column offset of each amino acid in the probability arrays (KeyError on a non-standard residue).
amino_acid_column = {aa: k for k, aa in enumerate(amino_acids)}

all_dfs = []
for mutant in all_results.keys():
    print(mutant)
    mutant_probabilities = pd.DataFrame(softmax(all_results[mutant]['Logits'][amino_acids].values, 1), columns=amino_acids)
    delta = (mutant_probabilities - reference_probabilities).to_numpy()

    n_positions = len(delta)
    positions = np.arange(1, n_positions + 1)                       # 1-based residue numbers
    residues = [reference_sequence[j] for j in range(n_positions)]  # reference residue at each position
    residue_columns = np.array([amino_acid_column[aa] for aa in residues], dtype=int)
    # Pick, row by row, the probability change of that row's reference residue.
    change = delta[np.arange(n_positions), residue_columns]

    # 'S13I,W152C' -> [13, 152]: drop the reference letter and the substituted letter.
    mutated_positions = [int(m[1:-1]) for m in mutant.split(',')]
    change = np.where(np.isin(positions, mutated_positions), 0.0, change)

    all_dfs.append(pd.DataFrame({'pos': positions, 'reference': residues, 'change': change, 'mutation': mutant}))

S13I
W152C
S13I,W152C


In [13]:
# Stack every mutant's table into one long frame. ignore_index gives unique row labels
# instead of repeated per-mutant indices. `change` is cast to float and rounded to 5 decimals.
rounded_all_dfs = pd.concat(all_dfs, ignore_index=True)
rounded_all_dfs['change'] = np.around(rounded_all_dfs['change'].astype(float), 5)

In [15]:
import plotly.express as px

# One panel per mutant: change in reference-residue probability along the spike.
fig = px.line(
    rounded_all_dfs,
    x="pos",
    y="change",
    facet_col='mutation',
    color="mutation",
    facet_col_wrap=3,
    height=600,
    width=1600,
    hover_data=['pos', 'reference', 'change', 'mutation'],
    labels={'pos': 'Spike position', 'change': 'Change in P(reference residue), mutant - reference'},
)
# px.line draws lines only by default, so this marker size only matters if markers are enabled.
fig.update_traces(marker={'size': 3})
fig.show()

In [16]:
from pathlib import Path

# Write the long-format results table. Create the output folder first so a fresh checkout
# (where DMS/Results/Epistasis does not exist yet) does not fail.
output_path = Path('DMS/Results/Epistasis/Spike_Epistasis_Softmax.csv')
output_path.parent.mkdir(parents=True, exist_ok=True)
rounded_all_dfs.to_csv(output_path, index=False)