In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import math
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from Bio import SeqIO
from Bio.Seq import Seq
from evo import Evo, positional_entropies
from evo.scoring import prepare_batch, score_sequences
from tqdm import tqdm

from analysis_utils import read_fasta, violin
from generating_utils import perposition_scores, remove_gaps, complement_5_strand, parse_genbank_to_dataframe

In [None]:
'''
This notebook is used to generate and analyze additional data for the poster

'''

# Set up

In [None]:
# Seed every RNG used in this notebook (torch, numpy, stdlib random).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Load the Evo checkpoint, move it to the GPU and switch it to eval mode.
device = 'cuda:0'
evo_model = Evo('evo-1-131k-base')
model = evo_model.model
tokenizer = evo_model.tokenizer
model.to(device)
model.eval()

# NOTE(review): `dnp` is not read by any cell visible here — confirm it is still needed.
dnp = True

In [None]:
# Aligned nucleotide (`aligned_records_*`) and protein (`aligned_aa_*`) FASTA files
# for the two genes analysed below (clpA and proRS).
# NOTE(review): the protein file names follow two different patterns
# ("picked_clpA_prot.txt" vs "picked_prot_proRS.txt") — confirm both exist as spelled.
aligned_records_clp=read_fasta("picked_clpA.txt")
aligned_aa_clp=read_fasta("picked_clpA_prot.txt")
aligned_records_pro=read_fasta("picked_proRS.txt")
aligned_aa_pro=read_fasta("picked_prot_proRS.txt")

In [None]:
# Hex colours for the four accuracy curves plotted at the end of the notebook.
color1='#68e7a9'
color2='#fe6100'
color3='#dc267f'
color4='#788ef0'

# Functions

In [None]:
def generate_ortholog_FW(aligned,alignedAA,
                         model,tokenizer,device,
                         context_front="full",context_back="full"):
    """Score every head/tail combination of the aligned sequences on the forward strand.

    The nucleotide alignment is cut at column ``3 * (len(alignedAA[0]["sequence"]) // 2)``.
    For every ordered pair (i, j) the gap-stripped head of sequence i is joined to the
    gap-stripped tail of sequence j and passed to ``perposition_scores``.

    Args:
        aligned: indexable collection of records exposing the aligned nucleotide
            sequence under the key "sequence".
        alignedAA: aligned protein records; only the length of the first record's
            "sequence" is used, to place the cut.
        model, tokenizer, device: forwarded unchanged to ``perposition_scores``.
        context_front: "full" to keep the whole head, otherwise the number of
            bases kept immediately before the cut.
        context_back: "full" to keep the whole tail, otherwise the number of
            bases kept immediately after the cut.

    Returns:
        A flat list of per-position scores; entry ``i * len(aligned) + j`` belongs
        to head i joined with tail j.
    """
    cutting = 3 * (len(alignedAA[0]["sequence"]) // 2)

    # Heads depend only on i and tails only on j, so strip and trim each once
    # instead of once per (i, j) pair.
    # Gotcha: a numeric context_front of 0 keeps the whole head, because
    # head[-0:] is the same slice as head[0:].
    heads = []
    tails = []
    for record in aligned:
        head = remove_gaps(record["sequence"][:cutting])
        tail = remove_gaps(record["sequence"][cutting:])
        heads.append(head if context_front == "full" else head[-context_front:])
        tails.append(tail if context_back == "full" else tail[:context_back])

    scores = []
    for i in tqdm(range(len(aligned)), desc="Processing tasks", unit="task"):
        for j in range(len(aligned)):
            seq = heads[i] + tails[j]
            scores.append(perposition_scores([seq], model, tokenizer, device)[0])

    return scores

In [None]:
def generate_ortholog_RC(aligned,alignedAA,
                         model,tokenizer,device,
                         context_front="full",context_back="full"):
    """Score every head/tail combination of the aligned sequences on the other strand.

    Builds the same head-i + tail-j sequences as the forward variant, passes each
    through ``complement_5_strand`` before scoring, and reverses the returned
    scores (``[::-1]``) so that index k refers to position k of the forward sequence.

    Args:
        aligned: indexable collection of records exposing the aligned nucleotide
            sequence under the key "sequence".
        alignedAA: aligned protein records; only the length of the first record's
            "sequence" is used, to place the cut at three times half that length.
        model, tokenizer, device: forwarded unchanged to ``perposition_scores``.
        context_front: "full" to keep the whole head, otherwise the number of
            bases kept immediately before the cut.
        context_back: "full" to keep the whole tail, otherwise the number of
            bases kept immediately after the cut.

    Returns:
        A flat list of reversed per-position scores; entry ``i * len(aligned) + j``
        belongs to head i joined with tail j.
    """
    cutting = 3 * (len(alignedAA[0]["sequence"]) // 2)

    # Heads depend only on i and tails only on j, so strip and trim each once
    # instead of once per (i, j) pair.
    # Gotcha: a numeric context_front of 0 keeps the whole head, because
    # head[-0:] is the same slice as head[0:].
    heads = []
    tails = []
    for record in aligned:
        head = remove_gaps(record["sequence"][:cutting])
        tail = remove_gaps(record["sequence"][cutting:])
        heads.append(head if context_front == "full" else head[-context_front:])
        tails.append(tail if context_back == "full" else tail[:context_back])

    scores = []
    for i in tqdm(range(len(aligned)), desc="Processing tasks", unit="task"):
        for j in range(len(aligned)):
            seq_RC = complement_5_strand(heads[i] + tails[j])
            sc_RC = perposition_scores([seq_RC], model, tokenizer, device)[0][::-1]
            scores.append(sc_RC)

    return scores

In [None]:
def context_dep_scores(aligned,alignedAA,
                       model,tokenizer,device,
                       cont_list,rem=400):
    """Run the forward and reverse ortholog scoring once per context length.

    For each value in ``cont_list`` the forward run keeps that many bases before
    the cut and ``rem`` bases after it; the reverse run keeps ``rem`` bases before
    the cut and that many bases after it.

    Returns:
        (forward_runs, reverse_runs): two lists parallel to ``cont_list``, each
        element being the list returned by the corresponding scoring function.
    """
    forward_runs, reverse_runs = [], []
    for context in cont_list:
        print("context:", context)
        forward = generate_ortholog_FW(
            aligned, alignedAA, model, tokenizer, device,
            context_front=context, context_back=rem,
        )
        forward_runs.append(forward)
        reverse = generate_ortholog_RC(
            aligned, alignedAA, model, tokenizer, device,
            context_front=rem, context_back=context,
        )
        reverse_runs.append(reverse)
    return forward_runs, reverse_runs

In [None]:
# NOTE(review): `sequence_meta` is not defined in any cell visible here, so this
# cell fails on a fresh kernel unless it is created elsewhere — confirm or remove.
sequence_meta

In [None]:
# seq1 joins a head and tail stored under the same key; seq2 joins the same head
# with the tail stored under a different key.
# NOTE(review): `heads` and `tails` are not defined in any cell visible here —
# confirm where they come from before relying on Restart & Run All.
seq1=str(heads[("TGTTCAGCGC",120683)]+tails[("TGTTCAGCGC",120683)])
seq2=str(heads[("TGTTCAGCGC",120683)]+tails[("TGTTCAGCGC",611086)])

In [None]:
# Per-position scores for seq1: directly, and via complement_5_strand with the
# result reversed so both score arrays are indexed in seq1's orientation.
seq_RC_1=complement_5_strand(seq1)
sc_FW_1=perposition_scores([seq1], model, tokenizer, device)[0]
sc_RC_1=perposition_scores([seq_RC_1], model, tokenizer, device)[0][::-1]

In [None]:
# Per-position scores for seq2: directly, and via complement_5_strand with the
# result reversed so both score arrays are indexed in seq2's orientation.
seq_RC_2=complement_5_strand(seq2)
sc_FW_2=perposition_scores([seq2], model, tokenizer, device)[0]
sc_RC_2=perposition_scores([seq_RC_2], model, tokenizer, device)[0][::-1]

In [None]:
# Geometric mean of the reversed-strand scores of seq2 over positions 790-989.
# NOTE(review): the window bounds are hard-coded; their origin is not visible here.
math.exp(np.mean(np.log(sc_RC_2[790:990])))

# Genome distribution graphs

In [None]:
# Placeholder: replace with the real path to the E. coli K-12 GenBank (.gbff) file.
filename='path to EColiK12.gbff'
# Feature table; later cells read its "Start", "End", "Type" and "Gene" columns.
feat_coli=parse_genbank_to_dataframe(filename)

In [None]:
# Collect intergenic regions: gaps of more than 50 bp between the end of one
# feature row and the start of the next. Each entry is
# (end of row i, start of row i+1, gap length).
# NOTE(review): the scan starts at row 2; the reason is not visible here — confirm.
# Pull the two columns out once instead of materialising four rows per iteration.
starts = feat_coli["Start"].to_numpy()
ends = feat_coli["End"].to_numpy()
intra_list = [
    (ends[i], starts[i + 1], starts[i + 1] - ends[i])
    for i in range(2, len(feat_coli) - 1)
    if starts[i + 1] - ends[i] > 50
]



In [None]:
# Mean positional entropy per annotated feature (CDS / tRNA / rRNA) and per
# intergenic region. Every sequence is scored together with the `context_bp`
# bases that precede it in the genome, and only the positions after that
# context are averaged.
# NOTE(review): `genome_coli` is not defined in any cell visible here — confirm
# where it is loaded.
context_bp=1000
data={"CDS":[],"tRNA":[],"rRNA":[],"Intergenic regions":[]}
names={"CDS":[],"tRNA":[],"rRNA":[]}
for i in tqdm(range(len(feat_coli)), desc="Processing tasks", unit="task"):
    feature=feat_coli.iloc[i]
    start=feature["Start"]
    if start>context_bp:
        # Named `feat_type` rather than `type` so the builtin stays usable in the kernel.
        feat_type=feature["Type"]
        # `names` has exactly the feature types of interest as keys.
        if feat_type in names:
            name=feature["Gene"]
            seq=[str(genome_coli[start-context_bp:feature["End"]])]
            data[feat_type].append(np.mean(positional_entropies(seq,model,tokenizer)[0][context_bp:]))
            names[feat_type].append(name)

for region_end, next_start, _gap in tqdm(intra_list, desc="Processing tasks", unit="task"):
    # A negative slice start would wrap around to the end of the genome and
    # yield an empty/wrong sequence, so skip regions without a full context.
    if region_end<context_bp:
        continue
    seq=[str(genome_coli[region_end-context_bp:next_start])]
    data["Intergenic regions"].append(np.mean(positional_entropies(seq,model,tokenizer)[0][context_bp:]))

In [None]:
# Violin plot of the mean entropies collected in `data`, one violin per key.
display(violin(data,"Type","Average Entropy"))

In [None]:
# Mean positional entropy of each gap-stripped clpA sequence.
real_clp = []
for record in tqdm(aligned_records_clp, desc="Processing tasks", unit="task"):
    ungapped = remove_gaps(record["sequence"])
    entropies = positional_entropies([ungapped], model, tokenizer)[0]
    real_clp.append(np.mean(entropies))

In [None]:
# Mean positional entropy of each gap-stripped proRS sequence.
real_pro = []
for record in tqdm(aligned_records_pro, desc="Processing tasks", unit="task"):
    ungapped = remove_gaps(record["sequence"])
    entropies = positional_entropies([ungapped], model, tokenizer)[0]
    real_pro.append(np.mean(entropies))

In [None]:
# NOTE: this rebinds `data`, replacing the per-feature dict built earlier.
data={"ClpA":real_clp,"ProRS":real_pro}

In [None]:
# Violin plot comparing the mean entropies of the ClpA and ProRS sequences.
display(violin(data,"Gene","Average Entropy"))

# Context dependencies (ortholog)

In [None]:
# Lookup tables between a flat index and an (i, j) pair on a 20 x 20 grid
# (flat = 20 * i + j): `meta` maps flat -> pair, `meta_fl` maps pair -> flat.
pairs = [(i, j) for i in range(20) for j in range(20)]
meta = dict(enumerate(pairs))
meta_fl = {pair: flat for flat, pair in meta.items()}
counter = len(pairs)
        

In [None]:
# Context lengths to sweep for clpA; each one triggers a forward and a reverse
# scoring pass over all sequence pairs, so this cell is expensive.
c_l=[10,50,100,200,300,400,500,600,700,800,900,1000,1100, 1200]
clp_fw_sc,clp_rc_sc=context_dep_scores(aligned_records_clp,aligned_aa_clp, 
                         model, tokenizer,device,c_l)

In [None]:
# Persist the clpA scores so the analysis section can reload them from disk.
with open("scores_clpA_FW_con.pkl", "wb") as file:
    pickle.dump(clp_fw_sc, file)
with open("scores_clpA_RC_con.pkl", "wb") as file:
    pickle.dump(clp_rc_sc, file)

In [None]:
# proRS is swept over a shorter context list (up to 900) than clpA.
# NOTE: this rebinds `c_l`; the analysis section defines its own copy again.
c_l=[10,50,100,200,300,400,500,600,700,800,900]
pro_fw_sc,pro_rc_sc=context_dep_scores(aligned_records_pro,aligned_aa_pro, 
                         model, tokenizer,device,c_l)

In [None]:
# Persist the proRS scores so the analysis section can reload them from disk.
with open("scores_proRS_FW_con.pkl", "wb") as file:
    pickle.dump(pro_fw_sc, file)    
with open("scores_proRS_RC_con.pkl", "wb") as file:
    pickle.dump(pro_rc_sc, file)

# Context dependencies analysis

In [None]:
# Reload the score lists written by the previous section.
# SECURITY: pickle.load can execute arbitrary code; only load files that this
# notebook itself produced.
with open("scores_clpA_FW_con.pkl", "rb") as file:
    clp_FW_con = pickle.load(file)
    
with open("scores_clpA_RC_con.pkl", "rb") as file:
    clp_RC_con = pickle.load(file)


with open("scores_proRS_FW_con.pkl", "rb") as file:
    pro_FW_con = pickle.load(file)
    
with open("scores_proRS_RC_con.pkl", "rb") as file:
    pro_RC_con = pickle.load(file)

In [None]:
# Rebuilt here so the analysis section can run from the pickles alone:
# `meta` maps flat index -> (i, j) and `meta_fl` maps (i, j) -> flat index
# on a 20 x 20 grid (flat = 20 * i + j).
pairs = [(i, j) for i in range(20) for j in range(20)]
meta = dict(enumerate(pairs))
meta_fl = {pair: flat for flat, pair in meta.items()}
counter = len(pairs)

In [None]:
# Full list of context lengths (the clpA sweep); proRS results use a prefix of it.
c_l=[10,50,100,200,300,400,500,600,700,800,900,1000,1100,1200]

In [None]:
def create_collections(sc,cl,coord_ind,length=20,rem=400):
    '''
    Pair each "real" score array (head and tail from the same index) with every
    "constructed" one that shares its head index, separately for each context.

    sc:        sc[k] holds the score arrays computed for context cl[k]
    cl:        context lengths, parallel to sc
    coord_ind: maps (head_index, tail_index) -> position inside sc[k]
    length:    number of indices to pair up
    rem:       a pair is kept only if both arrays have length rem + context

    structure of the output:
    context: {(ind1,ind2):[real,constructed]}
    '''
    collections = {}
    for pos, context in enumerate(cl):
        expected_len = rem + context
        collections[context] = {}
        for head_idx in range(length):
            for tail_idx in range(length):
                if head_idx == tail_idx:
                    continue
                real = sc[pos][coord_ind[(head_idx, head_idx)]]
                constructed = sc[pos][coord_ind[(head_idx, tail_idx)]]
                # Pairs whose arrays do not have the expected length are dropped.
                if len(real) == expected_len and len(constructed) == expected_len:
                    collections[context][(head_idx, tail_idx)] = [real, constructed]

    return collections

In [None]:
def get_results_FW(cc,coll):
    """Compare real and constructed forward scores for one context length.

    For each pair the geometric mean of the real scores is compared with that of
    the constructed scores over several regions: the whole array ("full"),
    everything from index ``cc`` on ("tail"), and the first 50/100/150/200/250
    positions from index ``cc`` ("tail_<n>").

    Args:
        cc: index at which the tail region starts.
        coll: mapping ``id -> [real_scores, constructed_scores]``.

    Returns:
        ``{"overview": ..., "full": ...}`` where "full" holds, per region, the
        list of ``real > constructed`` flags and (under ``<region>_n``) the
        absolute differences, plus the ids; "overview" maps each region to
        ``(fraction_true, n_true, n_total)``. The fraction is NaN when ``coll``
        is empty (instead of raising ZeroDivisionError).
    """
    def geo_mean(values):
        return math.exp(np.mean(np.log(values)))

    # Region label -> slice applied to both score arrays (insertion order is
    # the key order of the returned dicts).
    regions = {"full": slice(None), "tail": slice(cc, None)}
    for width in [50, 100, 150, 200, 250]:
        regions["tail_" + str(width)] = slice(cc, cc + width)

    res = {"id": []}
    for label in regions:
        res[label] = []
        res[label + "_n"] = []

    for pair_id, (real, constructed) in coll.items():
        res["id"].append(pair_id)
        for label, region in regions.items():
            r = geo_mean(real[region])
            c = geo_mean(constructed[region])
            res[label].append(r > c)
            res[label + "_n"].append(abs(r - c))

    total = len(res["id"])
    res_overview = {}
    for label in regions:
        wins = res[label].count(True)
        res_overview[label] = (wins / total if total else float("nan"), wins, total)

    return {"overview": res_overview, "full": res}

In [None]:
def get_results_RC(cc,coll):
    """Compare real and constructed reverse scores for one context length.

    For each pair the geometric mean of the real scores is compared with that of
    the constructed scores over several regions: everything except the last
    ``cc`` positions ("head"), and the 50/100/150/200/250 positions that end
    ``cc`` positions before the end of the array ("head_<n>").

    Args:
        cc: number of trailing positions excluded from every region. Must be
            positive: with ``cc == 0`` the slice ``[:-0]`` is empty.
        coll: mapping ``id -> [real_scores, constructed_scores]``.

    Returns:
        ``{"overview": ..., "full": ...}`` where "full" holds, per region, the
        list of ``real > constructed`` flags and (under ``<region>_n``) the
        absolute differences, plus the ids; "overview" maps each region to
        ``(fraction_true, n_true, n_total)``. The fraction is NaN when ``coll``
        is empty (instead of raising ZeroDivisionError).
    """
    def geo_mean(values):
        return math.exp(np.mean(np.log(values)))

    # Region label -> slice applied to both score arrays (insertion order is
    # the key order of the returned dicts).
    regions = {"head": slice(None, -cc)}
    for width in [50, 100, 150, 200, 250]:
        regions["head_" + str(width)] = slice(-cc - width, -cc)

    res = {"id": []}
    for label in regions:
        res[label] = []
        res[label + "_n"] = []

    for pair_id, (real, constructed) in coll.items():
        res["id"].append(pair_id)
        for label, region in regions.items():
            r = geo_mean(real[region])
            c = geo_mean(constructed[region])
            res[label].append(r > c)
            res[label + "_n"].append(abs(r - c))

    total = len(res["id"])
    res_overview = {}
    for label in regions:
        wins = res[label].count(True)
        res_overview[label] = (wins / total if total else float("nan"), wins, total)

    return {"overview": res_overview, "full": res}

In [None]:
def get_results_tag(scores,cl,tag,coord_ind):
    '''
    Build the real/constructed collections for every context in `cl` and
    summarise each one with the forward (tag == "FW") or the reverse
    (any other tag) comparison.

    results structure
    context:{overview:(%,#,total),"full":dict df style}
    '''
    per_context = create_collections(scores, cl, coord_ind)
    summarise = get_results_FW if tag == "FW" else get_results_RC
    return {context: summarise(context, per_context[context]) for context in cl}

In [None]:
# clpA uses all 14 contexts. proRS FW uses c_l[:-3] (10..900), the 11 contexts
# it was scored at.
# NOTE(review): proRS RC uses c_l[:-4] (10..800), one context fewer than FW —
# the reason is not visible here; confirm this is intentional.
results_clp_FW=get_results_tag(clp_FW_con,c_l,"FW",meta_fl)
results_clp_RC=get_results_tag(clp_RC_con,c_l,"RC",meta_fl)
results_pro_FW=get_results_tag(pro_FW_con,c_l[:-3],"FW",meta_fl)
results_pro_RC=get_results_tag(pro_RC_con,c_l[:-4],"RC",meta_fl)

In [None]:
def extract_and_combine(res1,res2):
    '''
    Pool the per-context counts of two result sets into one accuracy per
    method and context.

    res1, res2: context -> {"overview": {method: (fraction, n_true, n_total)}, ...}
    res1 drives the iteration: contexts present only in res2 are ignored, so
    the keys of res2 should be a subset of the keys of res1.

    output structure
    method:[accuracy at c1, accuracy at c2, ...]
    An accuracy is NaN when a context has no comparisons at all (total of 0).
    '''
    methods = list(res1[list(res1.keys())[0]]["overview"].keys())
    output = {m: [] for m in methods}
    for context, entry in res1.items():
        for m in methods:
            correct = entry["overview"][m][1]
            total = entry["overview"][m][2]
            if context in res2.keys():
                correct += res2[context]["overview"][m][1]
                total += res2[context]["overview"][m][2]
            # Avoid ZeroDivisionError for contexts without any valid pair.
            output[m].append(correct / total if total else float("nan"))
    return output

In [None]:
# Pool clpA and proRS counts per context; clpA is passed first because it
# covers every context in c_l.
FW=extract_and_combine(results_clp_FW,results_pro_FW)
RC=extract_and_combine(results_clp_RC,results_pro_RC)

In [None]:
# Accuracy vs. upstream context length for the forward comparisons.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(c_l, FW["full"], label='Default Method', color=color1, linewidth=2)
ax.plot(c_l, FW["tail_200"], label='Forward Comparison (200 bp)', color=color2, linewidth=2)
ax.set_xlabel('Base pairs of Genomic Context Upstream of the Cutsite')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()

In [None]:
# Accuracy vs. downstream context length for the reverse comparisons.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(c_l, RC["head_200"], label='Reverse Comparison (200 bp)', color=color3, linewidth=2)
ax.plot(c_l, RC["head"], label='Reverse Comparison (400 bp)', color=color4, linewidth=2)
ax.set_xlabel('Base pairs of Genomic Context Downstream from the Cutsite')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()