In [2]:
%load_ext autoreload
%autoreload 2
import torch
import esm
from Bio import SeqIO
from Bio.Seq import Seq
import pandas as pd
import numpy as np
from numpy import dot
from numpy.linalg import norm
from Shared_Functions import *
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import matplotlib.pyplot as plt
from scipy.special import softmax


# Experiment: Embed Probabilities for looing at epi-stasis 

In [3]:
model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model.eval()
batch_converter = alphabet.get_batch_converter()
device = torch.device("cuda:0")
if torch.cuda.is_available():
    model =  model.to(device)
    print("Transferred model to GPU")

In [4]:
from Bio import Entrez
from Bio import SeqIO

Entrez.email = "sample@example.org"

handle = Entrez.efetch(db="nucleotide",
                       id="NC_045512",
                       rettype="gb",
                       retmode="gb")
whole_sequence = SeqIO.read(handle, "genbank")
model_layers = 36

In [5]:
reference_embeddings = process_sequence_genbank(whole_sequence.seq,whole_sequence,model,model_layers,device,batch_converter,alphabet)

# Generate Omicron reference

In [6]:
reference_sequences_aligned = SeqIO.to_dict(SeqIO.parse('Sequences/BA1_with_Wuhan_reference.fasta', 'fasta'))
reference_sequences_aligned

{'BA.1': SeqRecord(seq=Seq('MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDL...HYT'), id='BA.1', name='BA.1', description='BA.1', dbxrefs=[]),
 'Wuhan': SeqRecord(seq=Seq('MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDL...HYT'), id='Wuhan', name='Wuhan', description='Wuhan', dbxrefs=[])}

In [7]:
amino_acids = ["A","R","N","D","C","Q","E","G","H","I","L","K","M","F","P","S","T","W","Y","V"]

In [8]:
reversion_mutations = get_mutations(reference_sequences_aligned['BA.1'].seq,reference_sequences_aligned['Wuhan'].seq)
# reversion_mutations

In [9]:
reverted_positions = {}
for mutation in reversion_mutations:
    mut_seq = mutate_sequence(reference_sequences_aligned['BA.1'].seq,[mutation])
    if mut_seq != None:
        reverted_positions[mutation] =  {'Gapped':str(mut_seq),'Ungapped':str(mut_seq.replace('-',''))}

In [10]:
reverted_positions['Reference'] = {'Gapped':str(reference_sequences_aligned['BA.1'].seq),'Ungapped':str(reference_sequences_aligned['BA.1'].seq).replace('-','')}
reverted_positions['Reference']

{'Gapped': 'MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHVI--SGTNGTKRFDNPVLPFNDGVYFASIEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLD---HKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPI-IVREPADLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFDEVFNATRFASVYAWNRKRISNCVADYSVLYNLAPFFTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGNIADYNYKLPDDFTGCVIAWNSNKLDSKVSGNYNYLYRLFRKSNLKPFERDISTEIYQAGNKPCNGVAGFNCYFPLRSYSFRPTYGVGHQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLKGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQGVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEYVNNSYECDIPIGAGICASYQTQTKSHRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLKRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKYFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFKGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNHNAQALNTLVKQLSSKFGAISSVLNDIFSRLD

In [11]:
reverted_positions_dataframe = pd.DataFrame(reverted_positions).T
reverted_positions_dataframe

Unnamed: 0,Gapped,Ungapped
V67A,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
-69H,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
-70V,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
I95T,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
D142G,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
-143V,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
-144Y,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
-145Y,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
-211N,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...
I212L,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...,MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSS...


# REMOVE CLS AND END OF SEQUENCE TOKENS!!!!!!!

In [12]:
all_results = {}
for i,row in reverted_positions_dataframe.iterrows():
    print(i)
    all_results[i] = process_protein_sequence(row.Ungapped,model,model_layers,batch_converter,alphabet,device)
    all_results[i]['Logits'] = pd.DataFrame(all_results[i]['Logits'][1:-1])
    all_results[i]['Logits'].columns = alphabet.all_toks
    all_results[i]['Logits'] = all_results[i]['Logits'].T.loc[amino_acids].T
    all_results[i]['Logits'].index = all_results[i]['Logits'].index+1

V67A
-69H
-70V
I95T
D142G
-143V
-144Y
-145Y
-211N
I212L
E215del
P216del
A217del
D342G
L374S
P376S
F378S
N420K
K443N
S449G
N480S
K481T
A487E
R496Q
S499G
R501Q
Y504N
H508Y
K550T
G617D
Y658H
K682N
H684P
K767N
Y799D
K859N
H957Q
K972N
F984L
Reference


In [13]:
reference_logits = all_results['Reference']['Logits']
reference_embedding =all_results['Reference']['Mean_Embedding']

amino_acids_from_logits = all_results['Reference']['Logits'].columns
amino_acids_from_logits

Index(['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F',
       'P', 'S', 'T', 'W', 'Y', 'V'],
      dtype='object')

# Find Insertion Unaligned position

In [14]:
def map_omicron_to_wuhan(refseq_wuhan,refseq_omicron_gapped,refseq_omicron_ungapped,logits):
    #Make Blank rows for gaps
    blank_row_for_gap = pd.DataFrame(np.full(20,np.nan)).T
    blank_row_for_gap.columns = amino_acids_from_logits
    #Re-index on gaps
    gap_seqs = []
    refseq_wuhan = reference_sequences_aligned['Wuhan'].seq
    subtract_insertion_number = 0
    for i,amino in enumerate(refseq_omicron_gapped):
        if refseq_wuhan[i] == '-':
            subtract_insertion_number+=1
        if amino == '-':
            insert_blank_row_for_gap = blank_row_for_gap.copy()
            new_index = i-subtract_insertion_number
            insert_blank_row_for_gap.index = [new_index]
            gap_seqs.append(insert_blank_row_for_gap)
    gap_seqs_df = pd.concat(gap_seqs)

    indexes = [0]
    ngaps = ""
    for i,v in enumerate(gap_seqs_df.index.values[1:]):
        if v-1 == gap_seqs_df.index.values[i]:
            ngaps+="-"
            indexes.append(indexes[i]+1)

        else:
            ngaps+="|"
            indexes.append(ngaps.count('-')+ngaps.count('|'))
    gap_seqs_df.index = (gap_seqs_df.index.values - np.array(indexes))+0.5
    gap_seqs_df['Omicron_Position'] = gap_seqs_df.index 
    
    # Map Gapped Positions to sequence characters (mostly for test purposes)
    logits['Omicron_Reference'] = list(refseq_omicron_ungapped)
    logits['Omicron_Position']= logits.index
    logits = pd.concat([logits, gap_seqs_df]).sort_values('Omicron_Position')  
    #Map index to sequence position
    logits = logits.reset_index()
    logits.index = logits.index +1
    #Remove insertion at known insertion site (VERY NOT GENERALISABLE, DO NOT USE FOR OTHER OMICRONS OR SEQS)
    #Drop_insertion
    logits = logits.drop([215,216,217],axis=0)
    # temp = temp.drop(['index','pos'],axis=1)
    #Remap index to sequence position now insertion is gone
    logits = logits.reset_index()
    logits.index = logits.index +1
    logits = logits.drop(['index','level_0'],axis=1)
    logits['Wuhan_Reference'] = list(refseq_wuhan.replace('-',''))
    logits['Wuhan_Position'] = logits.index
    return logits
    

In [15]:
Annotation_Columns = ['Omicron_Reference','Omicron_Position','Wuhan_Reference','Wuhan_Position']

# Heatmap

In [16]:

mutation_map = pd.DataFrame([['67V','95I','142D','212I','339D','371L','373P','375F','417N','440K','446S','477N','478K','484A','493R','496S','498R','501Y','505H','547K','614G','655Y','679K','681H','764K','796Y','856K','954H','969K','981F'],[m for m in list(reverted_positions_dataframe.index) if ('-' not in m) and ('del' not in m) ]]).T
mutation_map

Unnamed: 0,0,1
0,67V,V67A
1,95I,I95T
2,142D,D142G
3,212I,I212L
4,339D,D342G
5,371L,L374S
6,373P,P376S
7,375F,F378S
8,417N,N420K
9,440K,K443N


In [17]:
Mapped_reference_logits = map_omicron_to_wuhan(reference_sequences_aligned['Wuhan'].seq,
                     reference_sequences_aligned['BA.1'].seq,
                     reverted_positions_dataframe.loc['Reference'].Ungapped,
                     all_results['Reference']['Logits'])

In [18]:
Mapped_reference_logits[amino_acids]

Unnamed: 0,A,R,N,D,C,Q,E,G,H,I,L,K,M,F,P,S,T,W,Y,V
1,-10.400224,-11.848193,-11.863545,-13.121102,-11.839265,-12.598226,-13.108079,-10.771518,-12.459828,-10.639020,-9.266152,-11.791100,-0.000451,-9.118547,-10.963687,-10.380225,-11.107462,-11.155540,-11.364346,-10.107835
2,-7.125550,-6.670306,-7.305034,-9.736857,-8.892907,-8.503089,-8.509578,-8.362180,-8.157375,-3.596326,-2.449400,-7.124961,-6.851434,-0.147262,-9.030330,-6.704473,-7.480440,-4.464797,-6.272464,-6.153117
3,-4.074319,-6.699069,-6.802918,-9.358687,-6.095057,-7.096568,-8.075085,-5.243113,-6.457380,-3.682251,-2.521280,-7.587955,-6.406395,-3.738767,-5.724722,-4.063729,-4.868388,-4.468842,-5.228050,-0.230804
4,-5.909041,-10.482908,-9.814790,-12.383938,-7.636454,-9.309292,-10.823387,-7.532139,-9.117239,-4.391000,-2.810306,-10.291072,-7.994823,-0.094790,-7.274489,-6.354386,-6.671182,-5.995200,-6.786162,-5.094680
5,-6.228170,-11.550839,-10.759787,-13.017173,-7.491961,-10.593906,-12.442479,-8.269938,-10.846786,-5.676775,-0.023728,-12.407825,-7.886197,-4.484540,-9.360921,-6.876320,-7.336955,-7.299205,-7.994340,-5.910733
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1269,-4.905125,-4.075939,-4.823470,-5.496413,-8.183576,-5.077892,-5.254810,-5.180232,-5.415587,-6.157300,-6.226329,-0.109359,-7.021934,-7.196352,-4.385451,-4.568635,-4.693594,-9.027787,-6.102598,-5.264412
1270,-3.916694,-3.137348,-4.066595,-5.399303,-6.057154,-4.399049,-4.994042,-3.886073,-4.146449,-3.492923,-0.522582,-2.985913,-5.155276,-4.992611,-3.235482,-3.804076,-3.356941,-7.290349,-4.289998,-2.813949
1271,-3.777554,-2.458016,-3.193762,-4.052520,-6.310438,-3.434392,-3.836904,-3.916681,-0.851010,-3.981353,-4.356826,-1.982015,-5.140470,-5.374985,-3.874377,-3.069950,-3.199960,-7.222559,-4.805934,-3.383032
1272,-3.728924,-2.806745,-3.531278,-4.390131,-6.063446,-3.792853,-4.267377,-4.310506,-3.930003,-4.091583,-3.677981,-2.308872,-4.960357,-4.719212,-3.190723,-3.124561,-3.285889,-6.653604,-0.700367,-3.711202


In [82]:
all_dfs = []
for i,mut in enumerate(mutation_map[1][:-1]):
    mutant_probabilities = map_omicron_to_wuhan(reference_sequences_aligned['Wuhan'].seq,
                         reference_sequences_aligned['BA.1'].seq,
                         reverted_positions_dataframe.loc[mut].Ungapped,
                         all_results[mut]['Logits'])[amino_acids]
    mutant_probabilities = pd.DataFrame(softmax(mutant_probabilities,-1).T)
    
    reference_probabilities = pd.DataFrame(softmax(Mapped_reference_logits[amino_acids],-1).T)
#     print(reference_probabilities.sum())
    df = pd.DataFrame(reference_probabilities - mutant_probabilities).T
    df.columns = amino_acids
    reference_rows = []
    other_mut = mutation_map[0][:-1][i]
    for j,row in df.iterrows():
        pos = j+1
        amino = str(Mapped_reference_logits.Omicron_Reference.iloc[j])
        #Mask out reference change since this is usually massive
        if int(other_mut[:-1]) != j+1:  
            if amino == 'nan':
                val = 0
            else:
                val = row[Mapped_reference_logits.Omicron_Reference.iloc[j]]
            new_row = pd.DataFrame([pos,amino,val])
            reference_rows.append(new_row.T)
        else:
            print(i,j,amino,mut, str(Mapped_reference_logits.Omicron_Reference.iloc[j]),row[Mapped_reference_logits.Omicron_Reference.iloc[j]])  
            new_row = pd.DataFrame([pos,amino,0])
            reference_rows.append(new_row.T)
    df = pd.concat(reference_rows)
    df['most_likely_reference_amino'] = reference_probabilities.idxmax().values
    df['most_likely_mutant_amino'] = mutant_probabilities.idxmax().values
    df['is_changed'] =df['most_likely_reference_amino'] != df['most_likely_mutant_amino']
    df['mutation'] = mutation_map[0][:-1][i]
    all_dfs.append(df)
all_dfs = pd.concat(all_dfs)  
all_dfs.columns = ['pos','reference','change','most_likely_reference_amino','most_likely_mutant_amino','is_changed','mutation']

all_dfs['reference'] =  all_dfs.reference.str.replace('nan','-')
all_dfs['most_likely_reference_amino'] = all_dfs.most_likely_reference_amino.fillna('-')
all_dfs['most_likely_mutant_amino'] = all_dfs.most_likely_mutant_amino.fillna('-')
all_dfs['is_changed'] = [False if row.reference == '-' else row.is_changed for i,row in all_dfs.iterrows()]

0 66 V V67A V 0.6632507552758656
1 94 I I95T I 0.6399677831630597
2 141 D D142G D 0.46809364366052864
3 211 I I212L I 0.6811693277686027
4 338 D D342G D 0.6702668745428441
5 370 L L374S L 0.5981164860643255
6 372 P P376S P 0.6404487951216322
7 374 F F378S F 0.553687403602542
8 416 N N420K N 0.5681506807824485
9 439 K K443N K 0.5550667101874298
10 445 S S449G S 0.663536394758246
11 476 N N480S N 0.5853512491429638
12 477 K K481T K 0.47173229923744353
13 483 A A487E A 0.5963193246117537
14 492 R R496Q R 0.4631296637419956
15 495 S S499G S 0.7405882922151783
16 497 R R501Q R 0.5003630553340124
17 500 Y Y504N Y 0.6567590368631814
18 504 H H508Y H 0.19330287770217616
19 546 K K550T K 0.7478693844842157
20 613 G G617D G 0.6019175570124584
21 654 Y Y658H Y 0.5288087432876305
22 678 K K682N K 0.47565467558663
23 680 H H684P H 0.18268104725982678
24 763 K K767N K 0.17622819631929593
25 795 Y Y799D Y 0.004290211270500779
26 855 K K859N K 0.1766849015302701
27 953 H H957Q H 0.07970364307478332
28

In [83]:
all_dfs[all_dfs.mutation == '67V']

Unnamed: 0,pos,reference,change,most_likely_reference_amino,most_likely_mutant_amino,is_changed,mutation
0,1,M,-0.000007,12.0,12.0,False,67V
0,2,F,0.000784,13.0,13.0,False,67V
0,3,V,-0.000424,19.0,19.0,False,67V
0,4,F,0.00288,13.0,13.0,False,67V
0,5,L,0.000126,10.0,10.0,False,67V
...,...,...,...,...,...,...,...
0,1269,K,-0.000135,11.0,11.0,False,67V
0,1270,L,0.000255,10.0,10.0,False,67V
0,1271,H,-0.000614,8.0,8.0,False,67V
0,1272,Y,-0.001023,18.0,18.0,False,67V


In [84]:
rounded_all_dfs = all_dfs
# rounded_all_dfs.change = np.around(rounded_all_dfs.change.astype(float),5)

In [85]:
import plotly.io as pio
pio.renderers.default = "vscode"

In [86]:
import plotly.express as px
df = px.data.tips()
fig = px.line(rounded_all_dfs, x="pos", y="change",color="is_changed", facet_col='mutation',facet_col_wrap=6,height=1000, width=1500, hover_data=['pos','reference','change','most_likely_reference_amino','most_likely_mutant_amino','is_changed','mutation'])
fig.update_traces(marker={'size': 3})
fig.show()

In [87]:
all_subsets = []
for mut in rounded_all_dfs.mutation.unique():
    subset = rounded_all_dfs[rounded_all_dfs.mutation == mut]
    subset = subset[subset.pos.isin(pd.Series(rounded_all_dfs.mutation.unique()).str[:-1].astype(int))]
    all_subsets.append(subset)
all_subsets = pd.concat(all_subsets)

In [88]:
import plotly.express as px
df = px.data.tips()
fig = px.line(all_subsets, x="pos", y="change", color="mutation", facet_col='mutation',facet_col_wrap=6,height=1000, width=1500,hover_data=['pos','reference','change','most_likely_reference_amino','most_likely_mutant_amino','is_changed','mutation'])
fig.show()

# Epistasis experiments

In [89]:
rounded_all_dfs.to_csv('DMS/Epistasis/Omicron_Epistasis_Softmax.csv')