# Analyze mutations on host switch branches

As another test for host-switch-associated mutations, I want to run a branch-level test to determine whether any mutations are associated specifically with cross-species transmission events. In this notebook, I will do the following:

1. For each branch on the tree, determine whether it is part of a host switch. A branch is part of a host switch if its host state differs from its parent's host state, or if its child branch has a different host state than it does.
2. If the branch is part of a host switch, classify it as the parent or the child branch. In this MASCOT model, migration events are mapped onto individual branches, so each host switch appears as two continuous branches with distinct host states. For example, the parent branch has a single child, which is the portion of the full branch that is in the new host state.
3. Enumerate any mutations on the branch.
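The classification in steps 1 to 3 can be sketched on a toy tree. This is a minimal sketch, not the notebook's implementation: the `Node` class and `classify_branches` function are hypothetical stand-ins for baltic branches and for `enumerate_mutations_and_migration_events`, carrying only a host state, a mutation list, and parent/child links. Mirroring the notebook, branches whose parent has no host state (the root) are skipped.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a baltic branch."""
    name: str
    host: Optional[str]  # None marks the root, which carries no host state
    mutations: List[str] = field(default_factory=list)
    parent: Optional["Node"] = field(default=None, repr=False, compare=False)
    children: List["Node"] = field(default_factory=list, repr=False, compare=False)

    def add_child(self, child: "Node") -> "Node":
        child.parent = self
        self.children.append(child)
        return child


def classify_branches(branches):
    """Return (branch, role, event, number of mutations) records for each branch."""
    records = []
    for b in branches:
        # skip the root and branches descending directly from it
        if b.parent is None or b.parent.host is None:
            continue
        if b.host != b.parent.host:
            # a host switch yields two records: the parent branch and the child branch
            event = f"{b.parent.host}-to-{b.host}"
            records.append((b.parent.name, "parent", event, len(b.parent.mutations)))
            records.append((b.name, "child", event, len(b.mutations)))
        elif any(c.host != b.host for c in b.children):
            # parent half of a host switch: recorded when its switching child is visited
            continue
        else:
            records.append((b.name, "on", f"within-{b.host}", len(b.mutations)))
    return records


root = Node("root", None)
a = root.add_child(Node("a", "domestic", ["C72T"]))
b = a.add_child(Node("b", "domestic"))
c = b.add_child(Node("c", "wild", ["A9C", "T10C"]))
d = a.add_child(Node("d", "domestic", ["G75T"]))

for record in classify_branches([root, a, b, c, d]):
    print(record)
```

With this toy tree the loop prints `('b', 'parent', 'domestic-to-wild', 0)`, `('c', 'child', 'domestic-to-wild', 2)` and `('d', 'on', 'within-domestic', 1)`; branch `a` produces no record because its parent is the root.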

In [1]:
import glob
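import re, copy
import importlib.util
import pandas as pd
import numpy as np

# for this to work, you will need to download the most recent version of baltic, available here
# load baltic from a file path (imp.load_source is deprecated and was removed in Python 3.12)
spec = importlib.util.spec_from_file_location('baltic', '../baltic/baltic/baltic-modified-for-muts.py')
bt = importlib.util.module_from_spec(spec)
spec.loader.exec_module(bt)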
from io import StringIO
import time

import rpy2
%load_ext rpy2.ipython

In [2]:
from datetime import date
current_date = str(date.today())

In [3]:
# define colors 
domestic_color="#4E83AE"
wild_color="#CEB540"
human_color="#DE4428"

## Grab all trees and process posterior

In [4]:
def get_taxa_lines(tree_path):    

    lines_to_write = ""
    with open(tree_path, 'r') as infile:  # 'rU' mode was removed in Python 3.11
        for line in infile:
            if 'state' not in line.lower():  # keep everything except the tree strings (header, taxa, Translate block)
                lines_to_write = lines_to_write + line

    return(lines_to_write)

In [5]:
def convert_strain_to_number(taxa_lines):
    
    output_dict = {}
    
    translation_block = taxa_lines.split("Translate\n")[1]
    translation_list = translation_block.replace("\t","").split("\n")
    
    for t in translation_list: 
        information = t.lstrip().replace(",","")  # remove leading white spaces and commas
        
        if len(information.split(" ")) == 2:
            numeric_id = information.split(" ")[0]
            strain_name = information.split(" ")[1]
        
            output_dict[numeric_id] = strain_name
            
        else: 
            pass
        
    return(output_dict)
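To see what `get_taxa_lines` and `convert_strain_to_number` operate on, here is a self-contained sketch on a toy NEXUS string. The strain names and trees are made up, and the helper names are hypothetical. It applies the same rules as the notebook (drop lines containing "state", then split on the `Translate` keyword), but uses `str.split()` so that runs of spaces or tabs between the numeric ID and the strain name are also handled.

```python
NEXUS_TEXT = """#NEXUS
Begin taxa;
\tDimensions ntax=2;
End;
Begin trees;
\tTranslate
\t\t1 A/duck/Vietnam/1/2005,
\t\t2 A/chicken/Egypt/2/2010
;
tree STATE_0 = ((1:1.0,2:1.0):0.5);
tree STATE_1000 = ((1:1.2,2:0.8):0.4);
End;
"""


def split_header_and_trees(text):
    """Separate tree lines (those mentioning 'state') from everything else."""
    header, tree_lines = [], []
    for line in text.splitlines(keepends=True):
        (tree_lines if "state" in line.lower() else header).append(line)
    return "".join(header), tree_lines


def translate_table(header):
    """Map the numeric tip IDs in the Translate block to strain names."""
    table = {}
    block = header.split("Translate\n")[1]
    for entry in block.splitlines():
        fields = entry.replace(",", "").split()
        if len(fields) == 2:  # skips the closing ';' and 'End;' lines
            numeric_id, strain_name = fields
            table[numeric_id] = strain_name
    return table


header, tree_lines = split_header_and_trees(NEXUS_TEXT)
print(len(tree_lines))  # 2
print(translate_table(header))
```

One caveat of the line-filtering rule: a strain name that happens to contain "state" would be treated as a tree line.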

In [6]:
def convert_leaves_to_strains(input_leaves, strains_dict): 
    output_list = []
    
    for l in input_leaves: 
        strain_name = strains_dict[l]
        output_list.append(strain_name)
    
    output_set = set(output_list)
    return(output_set)

In [7]:
def get_burnin_value(tree_path, burnin_percent):
    with open(tree_path, 'r') as infile:
        numtrees = 0
        for line in infile:
            if 'state' in line.lower():  # each posterior tree is written on its own "tree STATE_..." line
                numtrees += 1
    
    burnin = numtrees * burnin_percent
    return(burnin)
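As a worked example of the burnin arithmetic: the trees file analyzed below contains 277 trees, so `get_burnin_value` returns 277 × 0.1 ≈ 27.7. Because `run_on_posterior_trees` keeps trees with `tree_counter >= burnin`, trees 28 through 277 (250 trees) are analyzed. The helper below is hypothetical and only reproduces that comparison.

```python
def retained_tree_numbers(num_trees, burnin_percent):
    """Hypothetical helper: tree numbers (counted from 1) kept after burnin."""
    burnin = num_trees * burnin_percent
    # same comparison as run_on_posterior_trees: keep trees with counter >= burnin
    return [n for n in range(1, num_trees + 1) if n >= burnin]


kept = retained_tree_numbers(277, 0.1)
print(kept[0], kept[-1], len(kept))  # 28 277 250
```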

In [8]:
def enumerate_mutations(tree):
    """Not currently used. Enumerate mutations by site, recording the host state of each branch
    they occur on, and count the number of branches in each host state."""
    muts = {}
    host_branches = {}
    
    for k in tree.Objects:
        host = k.traits['typeTrait']
        if host in host_branches: 
            host_branches[host] += 1
        else:
            host_branches[host] = 1
        
        if 'mutations' in k.traits: 
            mutations = k.traits['mutations'].split(",")
            
            for m in mutations: 
                site = m[1:-1]
                if site not in muts: 
                    muts[site] = [host]
                else: 
                    muts[site].append(host)
            
    return(muts, host_branches)
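`enumerate_mutations` takes the site as `m[1:-1]`, which assumes every label is a single-nucleotide change written as reference base, position, alternate base (e.g. `C72T`). A slightly stricter sketch makes that assumption explicit and fails loudly on anything else; the regular expression and the `parse_mutation` name are our own choices, not part of baltic.

```python
import re

# assumed label format: one reference base, a position, one alternate base
MUTATION_LABEL = re.compile(r"^([A-Z-])(\d+)([A-Z-])$")


def parse_mutation(label):
    """Split a label such as 'C72T' into (reference, position, alternate)."""
    match = MUTATION_LABEL.match(label.strip())
    if match is None:
        raise ValueError(f"unexpected mutation label: {label!r}")
    ref, pos, alt = match.groups()
    return ref, int(pos), alt


print(parse_mutation("C72T"))    # ('C', 72, 'T')
print(parse_mutation("A1228G"))  # ('A', 1228, 'G')
print("C72T"[1:-1])              # 72, the slice used in enumerate_mutations
```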

In [9]:
def tabulate_mutations(mutations_dict):
    output_dict = {}
    for mut in mutations_dict:
        output_dict[mut] = {}
        hosts = list(set(mutations_dict[mut]))
        for h in hosts: 
            output_dict[mut][h] = mutations_dict[mut].count(h)
    
    return(output_dict)
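An equivalent sketch with `collections.Counter` counts each site's host list in a single pass instead of calling `list.count` once per host. The name `tabulate_mutations_counter` is hypothetical, chosen so it does not shadow the notebook's function.

```python
from collections import Counter


def tabulate_mutations_counter(mutations_dict):
    """Count how many times each site mutated on branches of each host state."""
    return {site: dict(Counter(hosts)) for site, hosts in mutations_dict.items()}


example = {"72": ["domestic", "wild", "domestic"], "189": ["wild"]}
print(tabulate_mutations_counter(example))
# {'72': {'domestic': 2, 'wild': 1}, '189': {'wild': 1}}
```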

In [10]:
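def muts_on_branch(branch):
    """Return the list of mutations annotated on a branch, or an empty list if there are none."""
    if 'mutations' in branch.traits:
        mutations = branch.traits['mutations'].split(",")

    else:
        mutations = []  # empty list, consistent with mut_near_migration_event

    return(mutations)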

In [11]:
def enumerate_child_traits(branch):
    """Given a branch, return "yes" if any of its children has a different host state, else "no"."""
    child_traits = []
    
    current_trait = branch.traits['typeTrait']
    
    for c in branch.children: 
        child_trait = c.traits['typeTrait']
        child_traits.append(child_trait)
    
    if set(child_traits) == {current_trait}:
        host_switch_branch = "no"
    
    else:
        host_switch_branch = "yes"
        
    return(host_switch_branch)

In [12]:
def mut_near_migration_event(parent, child):
    """Given a parent node and its child, return the lists of mutations on each branch (empty if none)."""
    
    if 'mutations' in parent.traits:
        parent_mutations = parent.traits['mutations'].split(",")
    else:
        parent_mutations = []
        
    if 'mutations' in child.traits:
        child_mutations = child.traits['mutations'].split(",")
        
    else:
        child_mutations = []
        
    return(parent_mutations, child_mutations)

In [13]:
def enumerate_mutations_and_migration_events(tree):
    """Iterate through the tree and record every branch, classified as part of an inferred host switch
    (role "parent" or "child") or not (role "on"), along with the mutations on that branch.

    Results are stored in a dictionary whose keys are sequential branch counters and whose values are
    dictionaries describing the branch: the migration type, whether it is a host switch, its host state,
    its role, its mutations, and the branch object itself."""
    output_dict = {}
    branch_counter = 0
        
    for k in tree.Objects:
                
        trait = k.traits['typeTrait']
        parent_node = k.parent
        
        # branches whose parent is the root (which has no typeTrait) are skipped, so that
        # root-to-deme transitions are not written out as migration events
        if 'typeTrait' not in parent_node.traits:
            parent_trait = "root"

        else:
            parent_trait = parent_node.traits['typeTrait']
                        
            # if the current trait differs from the parent trait, then this is a host switch event. Enumerate 
            # mutations on the parent and current branch; here, I'm denoting the current branch as the child 
            # branch because I am labelling host switches as being comprised of a parent and child state/branch
            if trait != parent_trait:
                host_switch_event = "yes"
                migration_event = parent_trait + "-to-" + trait

                # does a mutation occur on the parent or child branch? 
                parent_mut, child_mut = mut_near_migration_event(parent_node, k)
                
                # write out the information for the parent and child branches
                branch_counter += 1
                output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":parent_trait, "role":"parent", "mutations":parent_mut, "branch_id":parent_node}
                
                branch_counter += 1
                output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":trait, "role":"child", "mutations":child_mut, "branch_id":k}


            
            # The if/else below makes sure that mutations on some branches are not counted twice:
            # a non-host-switch branch that is the parent of a host switch is already recorded
            # (with role "parent") when its child is visited, so it must not be recorded again here.
            # This block handles the case where the current trait is the same as the parent trait.
            else:
                
                # if the branch is a node and it has no children or if it is a leaf, then you are not a host 
                # switch branch, and we can record your mutations and be done
                # here, we record the role as "on" because the mutations are just on the branch and we don't have
                # parent and child branches
                if (k.branchType == 'node' and len(k.children) == 0) or k.branchType == 'leaf':
                    mutations = muts_on_branch(k)
                    migration_event = "within-" + trait
                    host_switch_event = "no"
                    
                    branch_counter += 1
                    output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":trait, "role":"on", "mutations":mutations, "branch_id":k}
                
                # if you are a node and you do have children, we need to check your children to make sure that 
                # they do not have a trait change. If they do, then you are part of a host switch.
                elif len(k.children) > 0:
                    host_switch_event = enumerate_child_traits(k)
                    
                    # if this branch is the parent half of a host switch, skip it here; it is recorded
                    # (with role "parent") when we reach the child branch that carries the host switch
                    if host_switch_event == "no":
                        mutations = muts_on_branch(k)
                        migration_event = "within-" + trait

                        branch_counter += 1
                        output_dict[branch_counter] = {"type":migration_event, "host_switch": host_switch_event,
                                                 "host":trait, "role":"on", "mutations":mutations, "branch_id":k}
                
    return(output_dict)

In [14]:
def run_on_posterior_trees(all_trees, burnin):
    start_time = time.time()

    with open(all_trees, "r") as infile:

        tree_counter = 0
        mutations_dict = {}

        for line in infile:
            if 'tree STATE_' in line:
                tree_counter += 1
                
                if tree_counter >= burnin:

                    # taxa_lines (the header and Translate block) is a global defined in a later cell
                    temp_tree = StringIO(taxa_lines + line)
                    tree = bt.loadNexus(temp_tree)
                    #tree.setAbsoluteTime(2019.227)

                    # iterate through the tree and pull out all migration events
                    mutations_dict[tree_counter] = enumerate_mutations_and_migration_events(tree)

    # print the amount of time this took
    total_time_seconds = time.time() - start_time
    total_time_minutes = total_time_seconds/60
    print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", tree_counter, "trees")
    return(mutations_dict)
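`run_on_posterior_trees` reads `taxa_lines` from the notebook's global scope. A sketch that passes the header in explicitly and yields one single-tree NEXUS string per retained tree could look like this; `iter_posterior_trees` is a hypothetical name, and each yielded string could then be wrapped in `StringIO` and handed to `bt.loadNexus` as in the cell above.

```python
def iter_posterior_trees(lines, header, burnin):
    """Yield (tree_number, single-tree NEXUS text) for each tree past the burnin."""
    tree_counter = 0
    for line in lines:
        if "tree STATE_" in line:
            tree_counter += 1
            # same rule as run_on_posterior_trees: keep trees with counter >= burnin
            if tree_counter >= burnin:
                yield tree_counter, header + line


toy_header = "#NEXUS\nBegin trees;\n"
toy_lines = [f"tree STATE_{i * 1000} = (1:1.0,2:1.0);\n" for i in range(5)]
for number, text in iter_posterior_trees(toy_lines, toy_header, burnin=3.5):
    print(number, text.splitlines()[-1])
```

Because it is a generator, only one tree string is held in memory at a time, just as in the original loop over the file.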

# Run on posterior trees

In [15]:
# trees file paths
skyline_human_off = {"label": "skyline-human-off", "mig_direction":"forwards in time",
                      "trees": "../../h5n1-host-classification/beast/beast-runs/2022-04-19-mascot-3deme-skyline-fixed-muts-logger/it3/2022-04-19-mascot-3deme-skyline-tipdates.muts.trees"}
# skyline_human_off = {"label": "skyline-human-off", "mig_direction":"forwards in time",
#                      "trees": "/Volumes/data-backups-post-doc/stored_files_too_big_for_laptop/h5n1-host-classification/beast/beast-runs/2021-03-15-mascot-3deme-skyline-with-mig-history/no-human-mig-2021-03-31/2021-03-15-mascot-3deme-no-human-mig.combined.trees"}

In [16]:
to_run = skyline_human_off

trees_file_path = to_run['trees']
label = to_run['label']

In [17]:
all_trees = trees_file_path
burnin_percent = 0.1

taxa_lines = get_taxa_lines(all_trees)
burnin = get_burnin_value(all_trees, burnin_percent)
print(burnin)

27.700000000000003
In [18]:
mutations_dict = run_on_posterior_trees(all_trees, burnin)

this took 31.523279905319214 seconds ( 0.5253879984219869  minutes) to run on 277 trees


In [19]:
# build a multi-index (tree number, branch number) dataframe from the mutations dictionary

mutations_df = pd.DataFrame.from_dict({(i,j): mutations_dict[i][j] 
                           for i in mutations_dict.keys() 
                           for j in mutations_dict[i].keys()},
                       orient='index')
 
mutations_df.reset_index(inplace=True)
mutations_df.rename(columns={'level_0': 'tree_number', 'level_1': 'branch_number'}, inplace=True)
mutations_df

| | tree_number | branch_number | type | host_switch | host | role | mutations | branch_id |
|---|---|---|---|---|---|---|---|---|
| 0 | 28 | 1 | within-domestic | no | domestic | on | [C72T, A189G, A349G, T482A, A856G, T1035A, A12... | `<baltic.node object at 0x7f999b614908>` |
| 1 | 28 | 2 | domestic-to-wild | yes | domestic | parent | [] | `<baltic.node object at 0x7f999b614780>` |
| 2 | 28 | 3 | domestic-to-wild | yes | wild | child | [A9C, T10C, A202G, C219A, T447A, G488A, C660T,... | `<baltic.node object at 0x7f999b6146d8>` |
| 3 | 28 | 4 | within-wild | no | wild | on | [T264A, T483G, C699A, C783T, G1287A, T1464C] | `<baltic.node object at 0x7f999b6147b8>` |
| 4 | 28 | 5 | within-wild | no | wild | on | [G486A, A672G, T789C, A1017G, C1028G, A1517G, ... | `<baltic.node object at 0x7f999b6146a0>` |
| 5 | 28 | 6 | within-wild | no | wild | on | [C9T, G320A, A744G, G771A, T783A, A1158G, C1566A] | `<baltic.node object at 0x7f999b6147f0>` |
| 6 | 28 | 7 | within-wild | no | wild | on | [G225A, G349A, A439G, G555A, A635G, C1561T, A1... | `<baltic.node object at 0x7f999b614748>` |
| 7 | 28 | 8 | wild-to-domestic | yes | wild | parent | [G533A, G750T, T874G, G998A, C1029T, A1052G, A... | `<baltic.node object at 0x7f999b614940>` |
| 8 | 28 | 9 | wild-to-domestic | yes | domestic | child | [T57C, G1056A, T1389C] | `<baltic.node object at 0x7f999b614898>` |
| 9 | 28 | 10 | within-domestic | no | domestic | on | [C43T, T69C, G75T, G202A, A219G, G227A, T243A,... | `<baltic.node object at 0x7f999b614828>` |


In [21]:
long_form_mutations_df = mutations_df.explode('mutations')
long_form_mutations_df

| | tree_number | branch_number | type | host_switch | host | role | mutations | branch_id |
|---|---|---|---|---|---|---|---|---|
| 0 | 28 | 1 | within-domestic | no | domestic | on | C72T | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | A189G | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | A349G | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | T482A | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | A856G | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | T1035A | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | A1228G | `<baltic.node object at 0x7f999b614908>` |
| 0 | 28 | 1 | within-domestic | no | domestic | on | A1476G | `<baltic.node object at 0x7f999b614908>` |
| 1 | 28 | 2 | domestic-to-wild | yes | domestic | parent | | `<baltic.node object at 0x7f999b614780>` |
| 2 | 28 | 3 | domestic-to-wild | yes | wild | child | A9C | `<baltic.node object at 0x7f999b6146d8>` |
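`explode` turns each element of a list into its own row and turns an empty list into a single row with a missing value, so branches without mutations stay in the long-form table (like branch 2 above). A small sketch on made-up rows (the values are hypothetical; only the column names follow `mutations_df`) also shows one way to carry out step 3, summarizing whether branches in each role carry any mutation:

```python
import pandas as pd

# hypothetical rows with the same columns used in mutations_df
toy = pd.DataFrame({
    "tree_number": [28, 28, 28],
    "branch_number": [1, 2, 3],
    "role": ["on", "parent", "child"],
    "mutations": [["C72T", "A189G"], [], ["A9C"]],
})

long_form = toy.explode("mutations")
print(long_form[["branch_number", "mutations"]])  # branch 2 appears once, with NaN

# per-branch summary: does the branch carry at least one mutation?
toy["n_mutations"] = toy["mutations"].apply(len)
toy["has_mutation"] = toy["n_mutations"] > 0
print(toy.groupby("role")["has_mutation"].mean())  # fraction of branches per role with a mutation
```

Summarizing per branch before exploding avoids having to distinguish "no mutations" rows from mutation rows afterwards.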


In [22]:
# write out to tsv
long_form_mutations_df.to_csv("/Users/lmoncla/src/h5n1-gwas/data/mutations-on-branches-"+current_date+".tsv", sep="\t")