# Analyze mutations on host switch branches

As another test for mutations associated with host switching, this notebook performs a branch-level test of whether any mutations are specifically associated with cross-species transmission events. In this notebook, I will do the following:

1. For each branch on the tree, determine whether it is part of a host switch. A branch is part of a host switch if its host state differs from its parent's host state, or if any of its child branches has a different host state than it does.
2. If the branch is part of a host switch, categorize it as the parent or the child branch. In this MASCOT model, migration events are mapped onto individual branches, so each host switch is represented by 2 contiguous branches with distinct host states. For example, the parent branch has a single child, which is the portion of the original branch that is in a different host state.
3. Record which mutations, if any, occur on the branch.

In [32]:
import glob
import re, copy
import importlib.util
import sys
import time
from io import StringIO

import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq

# for this to work, you will need to download the most recent version of baltic (https://github.com/evogytis/baltic);
# this notebook loads a locally modified copy ("baltic-modified-for-muts.py"), presumably the one whose trees expose
# the per-branch 'mutations' trait used below.
# `imp` is deprecated and was removed in Python 3.12, so load the file with importlib instead. Registering the module
# in sys.modules under 'baltic' mirrors what imp.load_source did.
_baltic_spec = importlib.util.spec_from_file_location('baltic', '../baltic/baltic/baltic-modified-for-muts.py')
bt = importlib.util.module_from_spec(_baltic_spec)
sys.modules['baltic'] = bt
_baltic_spec.loader.exec_module(bt)

In [33]:
from datetime import date

# ISO-formatted (YYYY-MM-DD) date stamp appended to output file names
current_date = date.today().isoformat()

In [34]:
# define colors: one hex code per host category
human_color = "#DE4428"
wild_color = "#CEB540"
domestic_color = "#4E83AE"

# Translations and ancestral sequence reconstruction

In [35]:
def read_alignment(alignment_file):
    """Read a FASTA alignment into a dict.

    Keys are each record's full FASTA description line (``record.description``); values are the sequence as a
    plain string. If two records share a description, the later one wins.
    """
    return({record.description: str(record.seq) for record in SeqIO.parse(alignment_file, "fasta")})

In [36]:
def return_cds_coordinates(genbank_ref_file):
    """Return the (start, stop) coordinates of the CDS annotated in a GenBank reference file.

    Coordinates are taken from the feature location string (GenBank convention: 1-based, inclusive), with any
    partial-end markers ('<' on the start, '>' on the stop) stripped. If the file contains several CDS features,
    the last one encountered is returned.

    Raises
    ------
    ValueError
        If no CDS feature is present, or the CDS location is not a simple "start..stop" range
        (e.g. "join(...)" or "complement(...)").
    """
    from Bio import GenBank

    cds_start, cds_stop = None, None
    with open(genbank_ref_file) as handle:
        for record in GenBank.parse(handle):
            # the CDS coordinates are in feature.location, e.g. "22..1728"; see help(feature) for details
            for f in record.features:
                if f.key == "CDS":
                    bounds = f.location.split("..")
                    if len(bounds) != 2:
                        raise ValueError("unsupported CDS location: " + str(f.location))
                    cds_start = int(bounds[0].lstrip("<"))
                    cds_stop = int(bounds[1].lstrip(">"))

    if cds_start is None:
        raise ValueError("no CDS feature found in " + str(genbank_ref_file))
    return(cds_start, cds_stop)

In [37]:
def return_mutations_on_branch(branch):
    """Return the list of nucleotide mutations annotated on `branch` (e.g. ["C72T", "G173A"]).

    Returns an empty list when `branch` is None or has no 'mutations' trait; otherwise the comma-separated
    'mutations' trait is split into a list.
    """
    if branch is None or "mutations" not in branch.traits:
        return([])
    return(branch.traits["mutations"].split(","))

In [38]:
def return_mutated_sequence(sequence, muts, cds_start, cds_stop):
    """Revert a branch's nucleotide mutations to reconstruct the sequence at the branch's parent end.

    Parameters
    ----------
    sequence : str
        Nucleotide sequence at the child end of the branch.
    muts : list of str
        Mutations on the branch, formatted "<ancestral nt><1-based site><derived nt>", e.g. "C72T".
    cds_start, cds_stop : int
        CDS coordinates, passed through to return_mutated_aa_sequence.

    Returns
    -------
    (ancestral_sequence, ancestral_aa_sequence, aa_muts)
        The reverted nucleotide string, its CDS translation, and the amino acid differences reported by
        return_mutated_aa_sequence, formatted "<parent aa><position><child aa>".
    """
    # we are going backwards up the tree, so each site is set back to its ancestral base;
    # strings are immutable, so edit a list copy
    ancestral_bases = list(sequence)
    for mut in muts:
        ancestral_bases[int(mut[1:-1]) - 1] = mut[0]    # -1 converts the 1-based site to a 0-based index
    ancestral_sequence = "".join(ancestral_bases)

    ancestral_aa_sequence, aa_muts = return_mutated_aa_sequence(sequence, ancestral_sequence, cds_start, cds_stop)
    return(ancestral_sequence, ancestral_aa_sequence, aa_muts)

In [39]:
def return_aa_sequence(sequence, cds_start, cds_stop):
    """Translate the coding region of a nucleotide sequence.

    Parameters
    ----------
    sequence : str or Seq
        Nucleotide sequence in the same coordinate system as the CDS coordinates.
    cds_start, cds_stop : int
        1-based, inclusive CDS coordinates (GenBank convention).

    Returns
    -------
    str
        Amino acid translation of positions cds_start..cds_stop.
    """
    # 1-based inclusive [cds_start, cds_stop] -> Python slice [cds_start-1, cds_stop). Slicing to cds_stop-1 dropped
    # the final nucleotide, leaving a partial last codon that Biopython warns about and does not translate.
    ha_cds_seq = Seq(str(sequence)[cds_start-1:cds_stop])
    return(str(ha_cds_seq.translate()))

In [40]:
def return_mutated_aa_sequence(sequence, mutated_sequence, cds_start, cds_stop):
    """Translate the CDS of two aligned nucleotide sequences and list the amino acid positions where they differ.

    Parameters
    ----------
    sequence, mutated_sequence : str or Seq
        Nucleotide sequences in the same coordinate system as the CDS coordinates.
    cds_start, cds_stop : int
        1-based, inclusive CDS coordinates (GenBank convention).

    Returns
    -------
    (mutated_translation, aa_muts)
        The translation of `mutated_sequence`'s CDS, and a list of strings
        "<mutated_sequence aa><1-based CDS position><sequence aa>" for every differing position.
        When called from return_mutated_sequence, `mutated_sequence` is the parent (ancestral) sequence,
        so entries read "<parent aa><position><child aa>".
    """
    # 1-based inclusive [cds_start, cds_stop] -> Python slice [cds_start-1, cds_stop); the previous cds_stop-1
    # bound dropped the last base and therefore the final codon
    translation = str(Seq(str(sequence)[cds_start-1:cds_stop]).translate())
    mutated_translation = str(Seq(str(mutated_sequence)[cds_start-1:cds_stop]).translate())

    aa_muts = []
    for position, (mutated_aa, original_aa) in enumerate(zip(mutated_translation, translation), start=1):
        if mutated_aa != original_aa:
            aa_muts.append(mutated_aa + str(position) + original_aa)

    return(mutated_translation, aa_muts)

In [41]:
def return_all_parents(k, parents_dict, sequence, cds_start, cds_stop):
    """Walk from branch `k` up to the root, reconstructing ancestral sequences by reverting mutations.

    Parameters
    ----------
    k : baltic leaf or node
        Branch to start from. `sequence` must be the nucleotide sequence at the tip end of this branch
        (for a leaf, its sequence from the alignment).
    parents_dict : dict
        Filled in place and returned. For every internal node visited, maps the node to
        {"nt_muts": mutations on its branch,
         "sequence": nucleotide sequence after reverting those mutations (i.e. the sequence at the node's parent),
         "aa_sequence": CDS translation of that sequence,
         "aa_muts": amino acid changes on the node's branch}.
        Leaves are walked through but not recorded.
    sequence : str
    cds_start, cds_stop : int
        CDS coordinates, passed through to the translation helpers.

    Returns
    -------
    dict
        `parents_dict`.
    """
    branch = k
    # stop at the object with no parent (the tree root); it has no branch of its own to revert
    while branch.parent is not None:
        mutations = return_mutations_on_branch(branch)
        # reverting this branch's mutations yields the sequence at the parent end of the branch
        parent_sequence, parent_aa_sequence, aa_muts = return_mutated_sequence(sequence, mutations, cds_start, cds_stop)

        if branch.branchType != "leaf":
            parents_dict[branch] = {"nt_muts": mutations, "sequence": parent_sequence,
                                    "aa_sequence": parent_aa_sequence, "aa_muts": aa_muts}

        # leaf mutations are reverted too; otherwise tip-specific bases leak into every ancestral sequence and can
        # change the codon context used to compute ancestral aa_muts
        sequence = parent_sequence
        branch = branch.parent

    return(parents_dict)

In [42]:
def return_sequence_map(tree, alignment_dict, cds_start, cds_stop):
    """Map every branch reachable from a leaf of `tree` to the mutations that occur on it.

    For each leaf, the leaf's own mutations are recorded, then return_all_parents reconstructs every ancestor
    back to the root from that leaf's aligned sequence.

    Parameters
    ----------
    tree : baltic tree
    alignment_dict : dict
        Leaf name -> aligned nucleotide sequence (see read_alignment).
    cds_start, cds_stop : int
        CDS coordinates, passed through to the translation helpers.

    Returns
    -------
    dict
        Branch object -> {"muts": nt mutations, "aa_muts": amino acid mutations, "leaves": ...}; "leaves" is "NA"
        for leaves and the node's `leaves` attribute for internal nodes.
    """
    all_nodes = {}

    for k in tree.Objects:
        if k.branchType != "leaf":
            continue

        sequence = alignment_dict[k.name]
        mutations = return_mutations_on_branch(k)
        _, _, aa_muts = return_mutated_sequence(sequence, mutations, cds_start, cds_stop)
        all_nodes[k] = {"muts": mutations, "aa_muts": aa_muts, "leaves": "NA"}

        # every walk goes all the way to the root, so if this leaf's parent is already mapped, all of its
        # ancestors are too and the keep-first rule below would leave them unchanged; skip the redundant walk
        if k.parent in all_nodes:
            continue

        # parents_dict holds every internal node from the tip back to the root with its mutations and sequences
        parents_dict = return_all_parents(k, {}, sequence, cds_start, cds_stop)

        for p, info in parents_dict.items():
            leaves = p.leaves
            # keep the first reconstruction of each node; overwrite only if the recorded leaf set differs
            if p not in all_nodes or all_nodes[p]['leaves'] != leaves:
                all_nodes[p] = {"muts": info['nt_muts'], "aa_muts": info['aa_muts'], "leaves": leaves}

    return(all_nodes)

## Grab all trees and process posterior

In [43]:
def get_taxa_lines(tree_path):
    """Return the NEXUS text of a BEAST .trees file with the tree lines removed.

    Tree lines are those that start (after leading whitespace) with "tree STATE", case-insensitive. Every other
    line, including the taxa and Translate blocks, is kept verbatim with its line ending, so the result can be
    prepended to a single tree line and parsed as a one-tree NEXUS file.
    """
    header_lines = []
    # 'rU' is rejected on Python >= 3.11; plain text mode already uses universal newlines
    with open(tree_path, 'r') as infile:
        for line in infile:
            # match only the tree lines: a bare substring test on 'state' would also drop taxon names that contain
            # it (e.g. ".../United_States/...") from the Taxlabels and Translate blocks
            if not line.lstrip().lower().startswith('tree state'):
                header_lines.append(line)

    return("".join(header_lines))

In [44]:
def convert_strain_to_number(taxa_lines):
    """Parse the NEXUS Translate block into a {numeric_id: strain_name} dict.

    Everything after the first "Translate\\n" is read line by line. Tabs, leading whitespace and all commas are
    removed, and only lines that then split on single spaces into exactly two fields are kept.
    Raises IndexError if `taxa_lines` has no "Translate\\n".
    """
    translate_section = taxa_lines.split("Translate\n")[1]

    id_to_strain = {}
    for raw_line in translate_section.replace("\t", "").split("\n"):
        fields = raw_line.lstrip().replace(",", "").split(" ")
        if len(fields) != 2:
            continue
        numeric_id, strain_name = fields
        id_to_strain[numeric_id] = strain_name

    return(id_to_strain)

In [45]:
def convert_leaves_to_strains(input_leaves, strains_dict):
    """Return the set of strain names for the given numeric leaf ids (KeyError if an id is unknown)."""
    return({strains_dict[leaf_id] for leaf_id in input_leaves})

In [46]:
def get_burnin_value(tree_path, burnin_percent):
    """Return how many trees to discard as burn-in from a BEAST .trees file.

    Parameters
    ----------
    tree_path : str
        Path to the .trees file.
    burnin_percent : float
        FRACTION of trees to discard (e.g. 0.1 for 10%); it is multiplied directly by the number of trees.

    Returns
    -------
    number
        (number of tree lines) * burnin_percent, not rounded.
    """
    numtrees = 0
    # 'rU' is rejected on Python >= 3.11; plain text mode already uses universal newlines
    with open(tree_path, 'r') as infile:
        for line in infile:
            # count only tree lines; a bare substring test on 'state' also counts taxon lines such as
            # ".../United_States/..."
            if line.lstrip().lower().startswith('tree state'):
                numtrees += 1

    return(numtrees * burnin_percent)

In [47]:
"""this is a function I'm not currently using, but it will enumerate mutations as well as which host type branches
they occur on"""

def enumerate_mutations(tree):
    muts = {}
    host_branches = {}
    
    for k in tree.Objects:
        host = k.traits['typeTrait']
        if host in host_branches: 
            host_branches[host] += 1
        else:
            host_branches[host] = 1
        
        if 'mutations' in k.traits: 
            mutations = k.traits['mutations'].split(",")
            
            for m in mutations: 
                site = m[1:-1]
                if site not in muts: 
                    muts[site] = [host]
                else: 
                    muts[site].append(host)
            
    return(muts, host_branches)

In [48]:
def tabulate_mutations(mutations_dict):
    """Collapse {site: [host, host, ...]} into {site: {host: count}}."""
    return({site: {host: hosts.count(host) for host in set(hosts)}
            for site, hosts in mutations_dict.items()})

In [49]:
def muts_on_branch(branch):
    """Return the list of nucleotide mutations on `branch`, or an empty list if it has no 'mutations' trait.

    An empty list (rather than "") keeps the "mutations" column type consistent with mut_near_migration_event,
    which also returns [] for branches without mutations.
    """
    if 'mutations' in branch.traits:
        return(branch.traits['mutations'].split(","))
    return([])

In [50]:
"""given a branch, determine whether any of its children are of a different host state"""

def enumerate_child_traits(branch):
    child_traits = []
    
    current_trait = branch.traits['typeTrait']
    
    for c in branch.children: 
        child_trait = c.traits['typeTrait']
        child_traits.append(child_trait)
    
    if set(child_traits) == {current_trait}:
        host_switch_branch = "no"
    
    else:
        host_switch_branch = "yes"
        
    return(host_switch_branch)

In [51]:
"""given a parent node and its child, infer whether either the parent or child has a mutation on it"""

def mut_near_migration_event(parent, child):
    
    if 'mutations' in parent.traits:
        parent_mutations = parent.traits['mutations'].split(",")
    else:
        parent_mutations = []
        
    if 'mutations' in child.traits:
        child_mutations = child.traits['mutations'].split(",")
        
    else:
        child_mutations = []
        
    return(parent_mutations, child_mutations)

In [52]:
"""this function will iterate through the tree, and record each inferred migration event along the phylogeny. It
will also capture the inferred time of the migration event as a decimal date, and whether the parent or child branch
surrounding the migration event have any mutations on them. These values are stored in a dictionary where the 
key is a number to keep track of the total number of migration events on the tree, and the value is a dictionary
housing data about the migration event. this includes its type, the parent and child host states, date, and mutations"""

def enumerate_mutations_and_migration_events(tree, sequence_map):
    output_dict = {}
    branch_counter = 0
        
    for k in tree.Objects:
                
        trait = k.traits['typeTrait']
        parent_node = k.parent
        
        if 'typeTrait' not in parent_node.traits:
            parent_trait = "root"
        
        # only write out migration events that are not from root to deme
        else:
            parent_trait = parent_node.traits['typeTrait']
                        
            # if the current trait differs from the parent trait, then this is a host switch event. Enumerate 
            # mutations on the parent and current branch; here, I'm denoting the current branch as the child 
            # branch because I am labelling host switches as being comprised of a parent and child state/branch
            if trait != parent_trait:
                host_switch_event = "yes"
                migration_event = parent_trait + "-to-" + trait

                # does a mutation occur on the parent or child branch? 
                parent_mut, child_mut = mut_near_migration_event(parent_node, k)
                
                # write out the information for the parent and child branches
                branch_counter += 1
                if parent_node in sequence_map: 
                    parent_aa_muts = sequence_map[parent_node]['aa_muts']
                else:
                    print(str(parent_node) + "not in sequence map")
                    
                output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":parent_trait, "role":"parent", "mutations":parent_mut, 
                                               "aa_muts":parent_aa_muts,"branch_id":parent_node}

                if k in sequence_map: 
                    aa_muts = sequence_map[k]['aa_muts']
                else:
                    print(str(k) + "not in sequence map")
                branch_counter += 1
                output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":trait, "role":"child", "mutations":child_mut, 
                                               "aa_muts":aa_muts,"branch_id":k}


            
            # The below if else business is to make sure that we don't count mutations on some branches twice
            # If we have a non-host switch branch that will be the parent to a host switch, we have to make sure that 
            # that branch's mutations aren't counted twice 
            # if the current trait is the same as the parent trait:
            else:
                
                # if the branch is a node and it has no children or if it is a leaf, then you are not a host 
                # switch branch, and we can record your mutations and be done
                # here, we record the role as "on" because the mutations are just on the branch and we don't have
                # parent and child branches
                if (k.branchType == 'node' and len(k.children) == 0) or k.branchType == 'leaf':
                    mutations = muts_on_branch(k)
                    aa_muts = sequence_map[k]["aa_muts"]
                    migration_event = "within-" + trait
                    host_switch_event = "no"
                    
                    branch_counter += 1
                    output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":trait, "role":"on", "mutations":mutations,
                                                   "aa_muts":aa_muts,"branch_id":k}
                
                # if you are a node and you do have children, we need to check your children to make sure that 
                # they do not have a trait change. If they do, then you are part of a host switch.
                elif len(k.children) > 0:
                    host_switch_event = enumerate_child_traits(k)
                    
                    # if we are on a host_switch branch, pass; we will detect this when we get down to the child
                    # branch that is a host switch 
                    if host_switch_event == "yes":
                        to_record = "no"
                    else: 
                        mutations = muts_on_branch(k)
                        aa_muts = sequence_map[k]['aa_muts']
                        migration_event = "within-" + trait
                        
                        branch_counter += 1
                        output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":trait, "role":"on", "mutations":mutations, 
                                                       "aa_muts":aa_muts,"branch_id":k}
                
    return(output_dict)

In [53]:
def run_on_posterior_trees(all_trees, burnin, genbank_ref_file, alignment, taxa_lines=None):
    """Process every post-burn-in tree in a BEAST .trees file and tabulate mutations on each branch.

    Parameters
    ----------
    all_trees : str
        Path to the .trees file.
    burnin : number
        Number of leading trees to discard (e.g. from get_burnin_value). Trees are numbered from 1, and tree `n`
        is processed only when n > burnin.
    genbank_ref_file : str
        GenBank reference used to locate the CDS.
    alignment : str
        FASTA alignment of the tip sequences.
    taxa_lines : str, optional
        NEXUS header to prepend to each tree line. Defaults to get_taxa_lines(all_trees).

    Returns
    -------
    (mutations_dict, sequence_map)
        mutations_dict maps tree number -> output of enumerate_mutations_and_migration_events. sequence_map is the
        sequence map of the LAST processed tree only (None if no tree was processed).
    """
    start_time = time.time()

    if taxa_lines is None:
        # read the header from the same file rather than relying on a notebook-global `taxa_lines`
        taxa_lines = get_taxa_lines(all_trees)
    cds_start, cds_stop = return_cds_coordinates(genbank_ref_file)
    alignment_dict = read_alignment(alignment)

    tree_counter = 0
    trees_processed = 0
    mutations_dict = {}
    sequence_map = None

    with open(all_trees, "r") as infile:
        for line in infile:
            if 'tree STATE_' not in line:
                continue
            tree_counter += 1
            # tree_counter is 1-based, so keeping only tree_counter > burnin discards exactly `burnin` trees
            if tree_counter <= burnin:
                continue

            # parse this single tree by prepending the NEXUS header (taxa / Translate blocks)
            temp_tree = StringIO(taxa_lines + line)
            tree = bt.loadNexus(temp_tree)
            #tree.setAbsoluteTime(2019.227)

            # generate the sequence map, which maps each branch to its nucleotide and amino acid mutations
            sequence_map = return_sequence_map(tree, alignment_dict, cds_start, cds_stop)

            # iterate through the tree and pull out all host-switch and within-host branches
            mutations_dict[tree_counter] = enumerate_mutations_and_migration_events(tree, sequence_map)
            trees_processed += 1

    # print the amount of time this took
    total_time_seconds = time.time() - start_time
    total_time_minutes = total_time_seconds/60
    print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", trees_processed, "trees")
    return(mutations_dict, sequence_map)

# Run on posterior trees

In [54]:
# input files: reference annotation and the tip alignment
genbank_ref_file = "../test-data/reference_h5n1_ha.gb"
alignment = "../../h5n1-host-classification/beast/alignments/aligned_h5n1_ha-3deme-1per-country-month-host-downsampled-bad-dates-2021-06-09-with-annotations-2021-07-06.fasta"

# trees file paths: each BEAST run is described by a label, the direction migration is reported in, and its .trees file
skyline_human_off = dict(
    label="skyline-human-off",
    mig_direction="forwards in time",
    trees="../../h5n1-host-classification/beast/beast-runs/2022-04-19-mascot-3deme-skyline-fixed-muts-logger/combined-2022-06-01.muts.trees",
)
# previous run, kept for reference:
# skyline_human_off = {"label": "skyline-human-off", "mig_direction":"forwards in time",
#                      "trees": "/Volumes/data-backups-post-doc/stored_files_too_big_for_laptop/h5n1-host-classification/beast/beast-runs/2021-03-15-mascot-3deme-skyline-with-mig-history/no-human-mig-2021-03-31/2021-03-15-mascot-3deme-no-human-mig.combined.trees"}

In [55]:
# choose which BEAST run to analyze
to_run = skyline_human_off

label = to_run['label']
trees_file_path = to_run['trees']

In [56]:
# burn-in is a fraction of the trees in the file; 0 keeps every tree
burnin_percent = 0
all_trees = trees_file_path

taxa_lines = get_taxa_lines(all_trees)
burnin = get_burnin_value(all_trees, burnin_percent)
print(burnin)

  after removing the cwd from sys.path.
  


0


In [57]:
# expensive: processes every post-burn-in tree (the recorded run took ~4 hours for 1133 trees)
mutations_dict, sequence_map = run_on_posterior_trees(
    all_trees=all_trees,
    burnin=burnin,
    genbank_ref_file=genbank_ref_file,
    alignment=alignment,
)

this took 14966.451487064362 seconds ( 249.44085811773937  minutes) to run on 1133 trees


In [58]:
#mutations_dict

In [59]:
"""this will generate a multi-index dataframe from the migrations dictionary"""

mutations_df = pd.DataFrame.from_dict({(i,j): mutations_dict[i][j] 
                           for i in mutations_dict.keys() 
                           for j in mutations_dict[i].keys()},
                       orient='index')
 
mutations_df.reset_index(inplace=True)
mutations_df.rename(columns={'level_0': 'tree_number', 'level_1': 'branch_number'}, inplace=True)
mutations_df.head()

Unnamed: 0,tree_number,branch_number,type,host_switch,host,role,mutations,aa_muts,branch_id
0,1,1,within-wild,no,wild,on,"[C72T, G173A, T392C, A856G, C1171A, C1341T]","[R51K, I124T, T279A, Q384K]",<baltic.node object at 0x7f86187a5278>
1,1,2,within-wild,no,wild,on,"[A9C, C957T]",[],<baltic.node object at 0x7f86187a52b0>
2,1,3,within-wild,no,wild,on,"[T10C, A202G, C219A, C392T, T447A, G488A, C660...","[N61D, T124I, D142E, R156K, E228K, P233S]",<baltic.node object at 0x7f86187a52e8>
3,1,4,within-wild,no,wild,on,"[T264A, T483G, C699A, C783T, G1287A, T1464C]",[H154Q],<baltic.node object at 0x7f86187a5320>
4,1,5,within-wild,no,wild,on,"[A672G, T789C, A1017G, C1028G, A1293G, A1517G,...","[T336S, K499R, M527I]",<baltic.node object at 0x7f86187a5358>


In [60]:
# one row per amino acid mutation; branches without aa mutations keep a single row with an empty aa_muts value
long_form_mutations_df = mutations_df.explode(column='aa_muts')
long_form_mutations_df.head()

Unnamed: 0,tree_number,branch_number,type,host_switch,host,role,mutations,aa_muts,branch_id
0,1,1,within-wild,no,wild,on,"[C72T, G173A, T392C, A856G, C1171A, C1341T]",R51K,<baltic.node object at 0x7f86187a5278>
0,1,1,within-wild,no,wild,on,"[C72T, G173A, T392C, A856G, C1171A, C1341T]",I124T,<baltic.node object at 0x7f86187a5278>
0,1,1,within-wild,no,wild,on,"[C72T, G173A, T392C, A856G, C1171A, C1341T]",T279A,<baltic.node object at 0x7f86187a5278>
0,1,1,within-wild,no,wild,on,"[C72T, G173A, T392C, A856G, C1171A, C1341T]",Q384K,<baltic.node object at 0x7f86187a5278>
1,1,2,within-wild,no,wild,on,"[A9C, C957T]",,<baltic.node object at 0x7f86187a52b0>


In [61]:
# long_form_mutations_df = mutations_df.explode('mutations')
# long_form_mutations_df

In [62]:
import os

# write out to tsv. Set H5N1_GWAS_DATA_DIR to write somewhere other than the original author's local data directory.
# NOTE(review): branch_id holds baltic object reprs (memory addresses), which are not meaningful outside this session.
output_dir = os.environ.get("H5N1_GWAS_DATA_DIR", "/Users/lmoncla/src/h5n1-gwas/data")
long_form_mutations_df.to_csv(os.path.join(output_dir, "aa-mutations-on-branches-" + current_date + ".tsv"), sep="\t")