# Reconstruct internal branch sequences 

Given a tree and an alignment, infer the sequences of the internal nodes, then translate the nucleotide changes on each branch into amino acid changes. This notebook contains all the work I did to develop the method; the important cells were later copied into the enumerate-mutations-on-branches notebook, which is where the method is actually used.

In [256]:
import importlib, json
import glob
import pandas as pd 
import numpy as np
from Bio.Seq import Seq

import copy
import datetime as dt
import time
    
# For this to work you need a local copy of baltic (the "modified-for-muts" version,
# per the filename) at the path below.
# NOTE(review): the original comment said "available here" without a link; upstream
# baltic is presumably https://github.com/evogytis/baltic -- confirm.
import importlib.util
import sys

# `imp` was never imported in this cell (and the module was removed in Python 3.12);
# importlib.util is the documented replacement for imp.load_source.
_BALTIC_PATH = '../baltic/baltic/baltic-modified-for-muts.py'
_baltic_spec = importlib.util.spec_from_file_location('baltic', _BALTIC_PATH)
bt = importlib.util.module_from_spec(_baltic_spec)
sys.modules['baltic'] = bt  # imp.load_source also registered the module under this name
_baltic_spec.loader.exec_module(bt)

In [231]:
def read_alignment(alignment_file):
    """Read a FASTA alignment into a dict of {full header line: sequence string}.

    Keys are the full record description (not just the ID) because tree tips are
    looked up by name in this dict (see return_sequence_map). If a header occurs
    more than once, the last record wins.
    """
    # Imported locally (like Bio.GenBank in return_cds_coordinates) because SeqIO is
    # not among the notebook's top-level imports.
    from Bio import SeqIO

    return {record.description: str(record.seq)
            for record in SeqIO.parse(alignment_file, "fasta")}

In [257]:
def return_cds_coordinates(genbank_ref_file):
    """Return the (start, stop) coordinates of the CDS feature in a GenBank file.

    The file is parsed with Bio.GenBank, whose feature locations are plain strings
    such as "1..1707". If several CDS features are present, the last one is used.
    Partial-location markers ("<1..>1707") are ignored.

    Raises:
        ValueError: if the file has no CDS feature, or its location is not a
            simple "start..stop" range (e.g. join(...) or complement(...)).
    """
    from Bio import GenBank

    cds_location = None
    with open(genbank_ref_file) as handle:
        for record in GenBank.parse(handle):
            # the gene coordinates are in feature.location; for details see help(feature)
            for feature in record.features:
                if feature.key == "CDS":
                    cds_location = feature.location

    if cds_location is None:
        raise ValueError("no CDS feature found in {}".format(genbank_ref_file))

    bounds = cds_location.split("..")
    if len(bounds) != 2:
        raise ValueError("unsupported CDS location {!r} in {}".format(cds_location, genbank_ref_file))
    try:
        # strip "<" / ">" partial-feature markers before converting
        cds_start = int(bounds[0].strip().lstrip("<>"))
        cds_stop = int(bounds[1].strip().lstrip("<>"))
    except ValueError as err:
        raise ValueError("unsupported CDS location {!r} in {}".format(cds_location, genbank_ref_file)) from err

    return cds_start, cds_stop

In [258]:
def return_mutations_on_branch(branch):
    """Return the list of nucleotide mutations annotated on a branch.

    Mutations are read from branch.traits["mutations"], a comma-separated string
    such as "A123G,C456T". Returns [] when branch is None, when the branch has no
    "mutations" trait, or when that trait is empty.
    """
    if branch is None or "mutations" not in branch.traits:
        return []
    # Strip whitespace and drop empty items: "".split(",") is [""], and items like
    # "" or " C456T" would break the site parsing in return_mutated_sequence.
    return [mut.strip() for mut in branch.traits["mutations"].split(",") if mut.strip()]

In [259]:
def return_mutated_sequence(sequence, muts, cds_start, cds_stop):
    """Reconstruct the parent sequence by undoing the mutations on a branch.

    Each mutation is a string like "A123G": ancestral nucleotide, 1-based site,
    derived nucleotide. Because we walk from the tips towards the root, every
    listed site is reset to its ancestral nucleotide.

    Returns:
        (parent_sequence, parent_aa_sequence, aa_muts): the parent's nucleotide
        sequence, its translated CDS, and the amino-acid changes on this branch
        as produced by return_mutated_aa_sequence.

    Raises:
        ValueError: if a mutation has a site < 1 (which would otherwise index
            from the end of the sequence and silently corrupt it).
    """
    bases = list(sequence)  # strings are immutable; edit a list copy

    for mut in muts:
        site = int(mut[1:-1]) - 1  # mutation sites are 1-based
        if site < 0:
            raise ValueError("mutation {!r} has a site below 1".format(mut))
        bases[site] = mut[0]  # ancestral nucleotide

    parent_sequence = "".join(bases)
    parent_aa_sequence, aa_muts = return_mutated_aa_sequence(sequence, parent_sequence, cds_start, cds_stop)
    return parent_sequence, parent_aa_sequence, aa_muts

In [266]:
def return_aa_sequence(sequence, cds_start, cds_stop):
    """Translate the CDS region of a nucleotide sequence and return it as a str.

    cds_start and cds_stop are 1-based, inclusive coordinates (the GenBank
    "start..stop" convention parsed by return_cds_coordinates), so the CDS is
    sequence[cds_start - 1:cds_stop]. Ending the slice at cds_stop - 1 would drop
    the last nucleotide, leave a partial final codon, and lose that codon from
    the translation.
    """
    cds = Seq(str(sequence)[cds_start - 1:cds_stop])  # inclusive end -> exclusive slice end
    return str(cds.translate())

In [267]:
def return_mutated_aa_sequence(sequence, mutated_sequence, cds_start, cds_stop):
    """Translate the CDS of two sequences and list their amino-acid differences.

    Returns:
        (mutated_translation, aa_muts): the translation of mutated_sequence and a
        list of differences, each formatted as
        "<aa in mutated_sequence><1-based aa position><aa in sequence>".
        When mutated_sequence is the reconstructed parent, this reads
        ancestral -> derived.
    """
    # Use the shared helper so both translations slice the CDS identically.
    translation = return_aa_sequence(sequence, cds_start, cds_stop)
    mutated_translation = return_aa_sequence(mutated_sequence, cds_start, cds_stop)

    aa_muts = [
        mutated_aa + str(position) + original_aa
        for position, (original_aa, mutated_aa) in enumerate(zip(translation, mutated_translation), start=1)
        if mutated_aa != original_aa
    ]
    return mutated_translation, aa_muts

In [262]:
def return_all_parents(k, parents_dict, sequence, cds_start, cds_stop):
    """Reconstruct every ancestor of node k, walking from k up to the root.

    Parameters:
        k: the starting tree node; its nucleotide sequence is `sequence`.
        parents_dict: dict to fill in; it is updated in place and also returned.
        sequence: k's nucleotide sequence.
        cds_start, cds_stop: CDS coordinates passed to the translation helpers.

    Returns:
        parents_dict with one entry per ancestor of k (k itself excluded):
        {"nt_muts": mutations on that ancestor's own branch,
         "sequence": its reconstructed nucleotide sequence,
         "aa_sequence": its translated CDS,
         "aa_muts": amino-acid changes on its own branch ([] for the root)}.
        Each entry describes the ancestor's own branch, matching how tip entries
        describe the tip's own branch.
    """
    node = k
    node_sequence = sequence

    # Iterate rather than recurse so deep trees cannot hit Python's recursion limit.
    while node.parent is not None:
        # Undoing the mutations on node's branch yields the parent's sequence; the
        # aa changes returned alongside belong to node's own branch.
        parent_sequence, parent_aa_sequence, node_aa_muts = return_mutated_sequence(
            node_sequence, return_mutations_on_branch(node), cds_start, cds_stop)

        if node is not k:
            # node was recorded on the previous step; its branch's aa changes are now known.
            parents_dict[node]["aa_muts"] = node_aa_muts

        parents_dict[node.parent] = {"nt_muts": return_mutations_on_branch(node.parent),
                                     "sequence": parent_sequence,
                                     "aa_sequence": parent_aa_sequence,
                                     # filled in on the next step unless node.parent is the root
                                     "aa_muts": []}

        node, node_sequence = node.parent, parent_sequence

    return parents_dict

In [263]:
def return_sequence_map(tree, alignment_dict, cds_start, cds_stop):
    """Map every tip and reconstructed ancestor in `tree` to its sequences and mutations.

    Tips (branchType == "leaf") are keyed by name and take their sequence from
    alignment_dict; ancestors are keyed by node object and are reconstructed by
    return_all_parents. Each value is
    {"muts", "sequence", "aa_sequence", "aa_muts"}. When several tips lead to the
    same ancestor, the entry from the first tip visited is kept.

    Raises:
        KeyError: if a tip name is missing from alignment_dict.
    """
    all_nodes = {}

    for k in tree.Objects:
        if k.branchType != "leaf":
            continue

        try:
            sequence = alignment_dict[k.name]
        except KeyError as err:
            raise KeyError("tip {!r} is not in the alignment".format(k.name)) from err

        mutations = return_mutations_on_branch(k)
        _, _, aa_muts = return_mutated_sequence(sequence, mutations, cds_start, cds_stop)
        all_nodes[k.name] = {"muts": mutations,
                             "sequence": sequence,
                             "aa_sequence": return_aa_sequence(sequence, cds_start, cds_stop),
                             "aa_muts": aa_muts}

        # Every walk runs to the root, so once an ancestor is recorded all of its own
        # ancestors are recorded too; since existing entries are never overwritten,
        # walking again from here would only recompute entries that get discarded.
        if k.parent is None or k.parent in all_nodes:
            continue

        parents_dict = return_all_parents(k, {}, sequence, cds_start, cds_stop)
        for p, info in parents_dict.items():
            if p not in all_nodes:
                all_nodes[p] = {"muts": info["nt_muts"],
                                "sequence": info["sequence"],
                                "aa_sequence": info["aa_sequence"],
                                "aa_muts": info["aa_muts"]}

    return all_nodes

In [264]:
# example for slicing with the proper indices
# x = "apple"
# y = x[0:3] + "s" + x[4:]
# print(y)

In [268]:
# test_tree = "../test-data/test-1-tree.trees"
# alignment = "../../h5n1-host-classification/beast/alignments/aligned_h5n1_ha-3deme-1per-country-month-host-downsampled-bad-dates-2021-06-09-with-annotations-2021-07-06.fasta"
# genbank_ref_file = "../test-data/reference_h5n1_ha.gb"

# tree = bt.loadNexus(test_tree)

# tree

<baltic.tree at 0x7fa711e668d0>

In [269]:
# start_time = time.time()

# cds_start, cds_stop = return_cds_coordinates(genbank_ref_file)
# alignment_dict = read_alignment(alignment)
# sequence_map = return_sequence_map(tree, alignment_dict, cds_start, cds_stop)

# total_time_seconds = time.time() - start_time
# total_time_minutes = total_time_seconds/60
# print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on")

this took 12.981129169464111 seconds ( 0.21635215282440184  minutes) to run on


In [222]:
# sequence_map

In [223]:
# for s in sequence_map: 
#     if 'C666A' in sequence_map[s]['muts']:
#         print(s)
#         print(sequence_map[s]['muts'])
#         print(sequence_map[s]['aa_muts'])