In [3]:
import sys
sys.path.append('../../')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import esm
import torch
from Functions import *

In [None]:
# Load the pretrained ESM-2 checkpoint named below together with its alphabet,
# and switch the model to evaluation mode for inference.
model, alphabet = esm.pretrained.load_model_and_alphabet('esm2_t36_3B_UR50D')
model.eval()
# Callable that turns (label, sequence) pairs into model inputs.
batch_converter = alphabet.get_batch_converter()

# Choose the device that actually exists. Previously `device` was always
# "cuda", so on a CPU-only machine it named a device the model was never on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Layer count handed to process_fasta; presumably must match the "t36" in the
# checkpoint name above -- update both together if the checkpoint changes.
model_layers = 36
if device.type == "cuda":
    model = model.to(device)
    print("Transferred model to GPU")

In [None]:
# Input paths -- intentionally left blank; fill both in before running the
# cells below.
# sequence_file: FASTA file, read by process_fasta and by SeqIO.parse(..., "fasta").
# metadata_file: tab-separated table read with pd.read_csv(sep='\t'); it is joined
#   on its 'Accession.ID' column and must also provide the columns renamed later
#   ('Pango.lineage', 'VOC', 'n', 'Collection.date') for those renames to apply.
sequence_file = ''
metadata_file = ''
# Reference sequence (one-letter amino-acid codes) that every FASTA record is
# compared against by process_fasta and get_mutations.
# NOTE(review): presumably the ancestral SARS-CoV-2 spike protein -- confirm the
# exact source/accession and record it here.
ref_spike_seq = 'MFVFLVLLPLVSSQCVNLTTRTQLPPAYTNSFTRGVYYPDKVFRSSVLHSTQDLFLPFFSNVTWFHAIHVSGTNGTKRFDNPVLPFNDGVYFASTEKSNIIRGWIFGTTLDSKTQSLLIVNNATNVVIKVCEFQFCNDPFLGVYYHKNNKSWMESEFRVYSSANNCTFEYVSQPFLMDLEGKQGNFKNLREFVFKNIDGYFKIYSKHTPINLVRDLPQGFSALEPLVDLPIGINITRFQTLLALHRSYLTPGDSSSGWTAGAAAYYVGYLQPRTFLLKYNENGTITDAVDCALDPLSETKCTLKSFTVEKGIYQTSNFRVQPTESIVRFPNITNLCPFGEVFNATRFASVYAWNRKRISNCVADYSVLYNSASFSTFKCYGVSPTKLNDLCFTNVYADSFVIRGDEVRQIAPGQTGKIADYNYKLPDDFTGCVIAWNSNNLDSKVGGNYNYLYRLFRKSNLKPFERDISTEIYQAGSTPCNGVEGFNCYFPLQSYGFQPTNGVGYQPYRVVVLSFELLHAPATVCGPKKSTNLVKNKCVNFNFNGLTGTGVLTESNKKFLPFQQFGRDIADTTDAVRDPQTLEILDITPCSFGGVSVITPGTNTSNQVAVLYQDVNCTEVPVAIHADQLTPTWRVYSTGSNVFQTRAGCLIGAEHVNNSYECDIPIGAGICASYQTQTNSPRRARSVASQSIIAYTMSLGAENSVAYSNNSIAIPTNFTISVTTEILPVSMTKTSVDCTMYICGDSTECSNLLLQYGSFCTQLNRALTGIAVEQDKNTQEVFAQVKQIYKTPPIKDFGGFNFSQILPDPSKPSKRSFIEDLLFNKVTLADAGFIKQYGDCLGDIAARDLICAQKFNGLTVLPPLLTDEMIAQYTSALLAGTITSGWTFGAGAALQIPFAMQMAYRFNGIGVTQNVLYENQKLIANQFNSAIGKIQDSLSSTASALGKLQDVVNQNAQALNTLVKQLSSNFGAISSVLNDILSRLDKVEAEVQIDRLITGRLQSLQTYVTQQLIRAAEIRASANLAATKMSECVLGQSKRVDFCGKGYHLMSFPQSAPHGVVFLHVTYVPAQEKNFTTAPAICHDGKAHFPREGVFVSNGTHWFVTQRNFYEPQIITTDNTFVSGNCDVVIGIVNNTVYDPLQPELDSFKEELDKYFKNHTSPDVDLGDISGINASVVNIQKEIDRLNEVAKNLNESLIDLQELGKYEQYIKWPWYIWLGFIAGLIAIVMVTIMLCCMTSCCSCLKGCCSCGSCCKFDEDDSEPVLKGVKLHYT'

In [None]:
# Run every record of the FASTA file through the model against the reference
# sequence. The result is indexed by the 'S:0' key passed here and, under it,
# by sequence label (see the table-building cell below).
# NOTE(review): the exact contract of process_fasta lives in Functions -- confirm
# there what 'S:0' denotes and which fields each entry carries.
initial_sequence_embeddings = process_fasta(
    sequence_file,
    'S:0',
    ref_spike_seq,
    model,
    model_layers,
    batch_converter,
)

In [None]:
# Persist the raw per-sequence results so the expensive model pass above does
# not have to be repeated.
# NOTE(review): presumably writes a compressed pickle named after the first
# argument -- check compressed_pickle in Functions for the actual file extension.
compressed_pickle(
    'initial_sequences_aligned_spike',
    initial_sequence_embeddings,
)

In [None]:
# Flatten the per-sequence results stored under 'S:0' into one table with one
# row per entry; the 'Reference' entry is skipped.
spike_results = initial_sequence_embeddings['S:0']
mutations_list = list(spike_results.keys())
columns = ['label', 'semantic_score', 'grammaticality', 'relative_grammaticality', 'sequence_grammaticality', 'relative_sequence_grammaticality', 'probability']

# Collect plain dict records and build the DataFrame once. The previous
# approach (a transposed single-row frame per entry, then pd.concat) forced
# every column to object dtype, gave every row the index 0, and raised when
# no non-reference entries were present. Fields missing from an entry are
# still filled with None via .get().
records = [
    {column: spike_results[key].get(column) for column in columns}
    for key in mutations_list
    if key != 'Reference'
]
# Passing `columns` keeps the column order and yields an empty, correctly
# shaped table when there are no records.
initial_table = pd.DataFrame(records, columns=columns)

In [None]:
# Build a label -> mutations lookup by comparing each FASTA record against the
# reference sequence.
mutation_records = []
for fasta in SeqIO.parse(sequence_file, "fasta"):
    name, sequence = fasta.id, str(fasta.seq)
    # Reuse the already-converted sequence instead of calling str() twice.
    mutations = get_mutations(ref_spike_seq, sequence)
    mutation_records.append({
        'label': name,
        # Printed form of the collection without its enclosing brackets and
        # without single quotes (presumably a list of strings, giving
        # "A, B, C" -- confirm against get_mutations).
        'mutations': str(mutations)[1:-1].replace("'", ""),
    })
# Single construction: avoids one DataFrame per record, gives a unique
# 0..n-1 index, and yields an empty table (instead of a pd.concat error)
# when the FASTA file has no records.
mutations_table = pd.DataFrame(mutation_records, columns=['label', 'mutations'])

In [None]:
# Attach the mutation strings while 'label' still holds the full FASTA id
# (the join key), then reduce each label to its second '|'-separated field,
# which is what the metadata join further down matches on.
# NOTE: not idempotent -- re-running this cell merges 'mutations' a second
# time and splits labels that were already shortened.
initial_table = initial_table.merge(
    mutations_table,
    how='left',
    left_on='label',
    right_on='label',
)
initial_table['label'] = initial_table['label'].str.split('|', expand=True)[1]

In [None]:
# Each rank is the 1-based row position after an ascending sort on the
# corresponding score column, so the lowest score receives rank 1.
initial_table = initial_table.sort_values(by='semantic_score')
initial_table['semantic_rank'] = range(1, len(initial_table) + 1)
initial_table = initial_table.sort_values(by='grammaticality')
initial_table['grammatical_rank'] = range(1, len(initial_table) + 1)
# Combined priority: sum of the two ranks (aligned on the row index).
initial_table['acquisition_priority'] = initial_table['semantic_rank'].add(initial_table['grammatical_rank'])

# Same scheme using the sequence-level grammaticality column; the table is
# left sorted by 'sequence_grammaticality' afterwards.
initial_table = initial_table.sort_values(by='sequence_grammaticality')
initial_table['sequence_grammatical_rank'] = range(1, len(initial_table) + 1)
initial_table['sequence_acquisition_priority'] = initial_table['semantic_rank'].add(initial_table['sequence_grammatical_rank'])

In [None]:
# Left-join the tab-separated metadata on the shortened label (matched against
# 'Accession.ID'), then give the metadata columns the names used in the output.
# Rows without a metadata match keep NaN in the metadata columns.
initial_table = initial_table.merge(
    pd.read_csv(metadata_file, sep='\t'),
    how='left',
    left_on='label',
    right_on='Accession.ID',
)
initial_table = initial_table.rename(
    columns={
        'Pango.lineage': 'lineage',
        'VOC': 'Voc',
        'n': 'lineage_count',
        'Collection.date': 'sample_date',
    }
)

In [None]:
# Write the final table to the working directory. The row index is written as
# the first (unnamed) column, which is pandas' default behaviour.
initial_table.to_csv(path_or_buf='initial_lineages.csv', index=True)