# Calculate enrichment scores across tree - BEAST trees

This is an initial attempt to break this code into reasonable chunks that can be coded, tested, and executed separately. In this notebook, we will be reading in a tree, enumerating all the mutations on that tree, and calculating enrichment scores for each of them. These enrichment scores are based on [this code](https://github.com/sheppardlab/pGWAS/blob/master/assomap_given_phylo.py), written for [this paper](https://www.nature.com/articles/s41467-018-07368-7#Sec10) detailed in lines 245-273 and calculated from the following contingency table as: 

|host|presence|absence|
|:------|:-------|:------|
|host 1|A|B| 
|host 2|C|D|

where A, B, C, and D are counts of the mutation's presence and absence in host 1 and host 2. The odds ratio is then calculated as: `OR = (A * D)/(B * C)`

In this notebook, this code is written for parsing a tree json format, output from Nextstrain. In subsequent notebooks, I will alter this for running on beast trees. 


### A NOTE ON BALTIC: 
Currently, in the posterior set of trees, mutations are annotated as traits as `&typeTrait=domestic,mutations="G730A,A846G,C1203A,A1278G"`. However, baltic is not reading in all the mutations properly, instead only reading in the first mutation in the list. This is due to line 1116 of baltic, which attempts to find trait strings. In order to make this work, I added in a comma as an acceptable character in the 4th block of the string search. I saved the version with this small edit as `../baltic/baltic/baltic-modified-for-muts.py`. I will use that baltic version in this notebook. Additionally, for this to work, the typeTrait also needs to be a string. 


Also, a note on trees: as currently written, this does some switching between numNames and strain names. If the tree you are testing it on doesn't have those, the code will need to be changed. 

In [63]:
import copy
import glob
import imp
import importlib.util
import json
import re
import sys
import time
from io import StringIO

import numpy as np
import pandas as pd

# for this to work, you will need to download the most recent version of baltic, available here 
#bt = imp.load_source('baltic', '../baltic/baltic/baltic.py')
# Load the locally modified baltic (the one whose trait regex accepts commas inside
# quoted values) straight from its file path. `imp.load_source` is deprecated and the
# `imp` module no longer exists in Python 3.12+, so importlib is used instead.
baltic_spec = importlib.util.spec_from_file_location('baltic', '../baltic/baltic/baltic-modified-for-muts.py')
bt = importlib.util.module_from_spec(baltic_spec)
# register under the name 'baltic' before executing, as imp.load_source did
sys.modules['baltic'] = bt
baltic_spec.loader.exec_module(bt)

## Infer each mutation that occurs across the tree

In [2]:
def get_taxa_lines(tree_path):
    """Return everything in a .trees file except the per-state tree lines.

    Every line that does not contain the substring "state" (compared
    case-insensitively) is kept, in file order, and the kept lines are joined
    into a single string. Prepending this string to one tree line gives a
    self-contained nexus document for that tree.

    Parameters
    ----------
    tree_path : str
        Path to the combined .trees file.

    Returns
    -------
    str
        The concatenated non-tree lines, line endings included.
    """
    # NOTE(review): the filter is a plain substring test, so any header line (e.g. a
    # taxon name) that happens to contain "state" is dropped as well - confirm the
    # taxon names in your file are safe.
    # Mode 'rU' was removed in Python 3.11; plain text mode already translates newlines.
    with open(tree_path, 'r') as infile:
        kept_lines = [line for line in infile if 'state' not in line.lower()]

    return "".join(kept_lines)

In [125]:
def convert_strain_to_number(taxa_lines):
    """Build a {numeric id: strain name} lookup from the nexus Translate block.

    Everything after the first "Translate\n" in `taxa_lines` is scanned line by
    line. Tabs, leading whitespace and commas are removed from each line; a line
    that then splits on single spaces into exactly two fields is read as
    "<numeric id> <strain name>". All other lines are ignored.

    Parameters
    ----------
    taxa_lines : str
        Header text of a .trees file, as returned by get_taxa_lines.

    Returns
    -------
    dict
        Maps the numeric id (as a string) to the strain name.

    Raises
    ------
    IndexError
        If `taxa_lines` contains no "Translate\n" marker.
    """
    translate_section = taxa_lines.split("Translate\n")[1]

    number_to_strain = {}
    for raw_entry in translate_section.replace("\t", "").split("\n"):
        # strip leading whitespace and every comma, then expect "<id> <name>"
        fields = raw_entry.lstrip().replace(",", "").split(" ")
        if len(fields) != 2:
            continue
        numeric_id, strain_name = fields
        number_to_strain[numeric_id] = strain_name

    return number_to_strain

In [141]:
def convert_leaves_to_strains(input_leaves, strains_dict):
    """Translate each entry of `input_leaves` through `strains_dict`.

    Parameters
    ----------
    input_leaves : iterable
        Keys to look up (in this notebook, the numeric tip ids of a node's leaves).
    strains_dict : dict
        Lookup from those keys to strain names.

    Returns
    -------
    list
        The looked-up values, in the iteration order of `input_leaves`.

    Raises
    ------
    KeyError
        If an entry of `input_leaves` is not a key of `strains_dict`.
    """
    return [strains_dict[leaf_id] for leaf_id in input_leaves]

In [3]:
def get_burnin_value(tree_path, burnin_percent):
    """Return how many trees to discard as burn-in.

    Counts the lines of the file that contain the substring "state" (compared
    case-insensitively) and multiplies that count by `burnin_percent`.

    Parameters
    ----------
    tree_path : str
        Path to the combined .trees file.
    burnin_percent : float
        Multiplier applied to the number of trees. It is used as-is, so a value
        such as 0.1 (not 10) yields one tenth of the trees.

    Returns
    -------
    int or float
        Number of tree lines times `burnin_percent`; not rounded.
    """
    # Mode 'rU' was removed in Python 3.11; plain text mode already translates newlines.
    with open(tree_path, 'r') as infile:
        # each line mentioning "state" is taken to be one sampled tree
        numtrees = sum(1 for line in infile if 'state' in line.lower())

    return numtrees * burnin_percent

In [4]:
"""count the number of tips on the tree corresponding to each host category"""
def return_all_host_tips(tree):
    """Count the tips of `tree` in each host category.

    Parameters
    ----------
    tree : object
        Tree whose `Objects` attribute iterates over branches; each branch has a
        `branchType` and a `traits` mapping with a 'typeTrait' entry.

    Returns
    -------
    dict
        Tip counts keyed by 'human', 'domestic' and 'wild'.

    Raises
    ------
    KeyError
        If a leaf's 'typeTrait' is not one of those three categories.
    """
    tips_per_host = {'human': 0, 'domestic': 0, 'wild': 0}

    tips = (branch for branch in tree.Objects if branch.branchType == "leaf")
    for tip in tips:
        tips_per_host[tip.traits['typeTrait']] += 1

    return tips_per_host

In [5]:
"""given a branch and gene, return the mutations present on that branch"""

def return_muts_on_branch(branch):
    """Return the mutations annotated on `branch` as a list of strings.

    The 'mutations' trait is a comma-separated string; it is split on commas.
    A branch without that trait yields an empty list.
    """
    traits = branch.traits
    if 'mutations' not in traits:
        return []

    return traits['mutations'].split(",")

In [6]:
"""this function does 2 things: 1. for each branch, it records the branch name and its branch length in a 
dictionary; 2. it adds up the total branch length on the tree. For the beast trees, we only have branch lengths 
in time. However, we can get a reasonable branch length (and I think this is perfectly reasonabble for this purpose)
by just summing the total mutations on the branch and dividing by the total number of sites. The only purpose 
in this analysis for the total tree branch length is to get an idea of the number of mutations that should 
occur across the tree. So this should work."""

def return_total_tree_branch_length(tree, n_sites_alignment):
    """Compute a divergence-style length for every branch and for the whole tree.

    A branch's length is the number of mutations annotated on it divided by
    `n_sites_alignment`; the time-based `length` stored on the branch is not used.

    Parameters
    ----------
    tree : object
        Tree whose `Objects` attribute iterates over branches.
    n_sites_alignment : int
        Number of sites in the alignment (the divisor).

    Returns
    -------
    tuple
        (total_branch_length, branch_lengths) where `branch_lengths` is a dict
        keyed by the branch object itself.

    Raises
    ------
    ZeroDivisionError
        If `n_sites_alignment` is 0 and the tree has at least one branch.
    """
    total_branch_length = 0
    branch_lengths = {}

    for k in tree.Objects:
        # mutations per site on this branch
        branch_length_divergence = len(return_muts_on_branch(k))/n_sites_alignment

        total_branch_length += branch_length_divergence
        branch_lengths[k] = branch_length_divergence

    return(total_branch_length, branch_lengths)

In [7]:
"""Traverse the tree from root to tip. On each branch, tips and nodes, gather every nucleotide and amino acid 
mutation"""

def gather_all_mut_on_tree(tree):
    """Return every distinct mutation annotated anywhere on `tree`.

    Each branch's comma-separated 'mutations' trait is split and pooled; branches
    without the trait are skipped.

    Returns
    -------
    list
        The unique mutation strings. The list comes from a set, so its order is
        not meaningful.
    """
    unique_muts = set()

    for branch in tree.Objects:
        traits = branch.traits
        if 'mutations' not in traits:
            continue
        unique_muts.update(traits['mutations'].split(","))

    return list(unique_muts)

In [8]:
"""return the total number of times that the mutation arises on the phylogeny. This includes instances of mutation 
on internal nodes and on tips and counts each with the same weight"""

def return_number_times_on_tree(tree, mut):
    """Count the branches of `tree` on which `mut` is annotated.

    Internal branches and tips are counted alike, one per branch carrying `mut`
    in its comma-separated 'mutations' trait.
    """
    return sum(
        1
        for branch in tree.Objects
        if 'mutations' in branch.traits and mut in branch.traits['mutations'].split(",")
    )

In [9]:
"""return the total number of times that the mutation arises on the phylogeny. This includes instances of mutation 
on internal nodes and on tips and counts each with the same weight"""

def return_branch_length_mut_on_tree(tree, mut,n_sites_alignment):
    """Sum the divergence-style lengths of all branches on which `mut` arises.

    For each branch whose comma-separated 'mutations' trait contains `mut`, the
    branch length is taken as (number of mutations on that branch) /
    `n_sites_alignment`, and those lengths are added up.

    Returns
    -------
    int or float
        The summed length; 0 if `mut` is on no branch.
    """
    summed_length = 0

    for branch in tree.Objects:
        if 'mutations' not in branch.traits:
            continue

        muts_here = branch.traits['mutations'].split(",")
        if mut not in muts_here:
            continue

        # the mutation arises on this branch: add this branch's per-site length
        summed_length += len(muts_here)/n_sites_alignment

    return summed_length

In [149]:
"""Given a starting internal node, and a tip you would like to end at, traverse the full path from that node to
tip. Along the way, gather mutations that occur along that path. Once you have reached the ending 
tip, return the list of mutations that fell along that path. Input for the ending tip here is a tip name, while 
the starting node is a node object"""

def return_all_muts_on_path_to_tip(starting_node, ending_tip, muts, strains_dict):
    """Collect the mutations on the path from `starting_node` down to one tip.

    The children of `starting_node` are searched recursively. Mutations on every
    branch below `starting_node` that leads to the target tip (the tip's own
    branch included) are appended to `muts`. Mutations on the branch of
    `starting_node` itself are NOT included.

    Parameters
    ----------
    starting_node : object
        Internal node to start from; must expose `children`.
    ending_tip : str
        Name of the target tip. It is compared with a leaf's `name`, and with a
        node's `leaves` after translating them through `strains_dict`.
    muts : list
        Accumulator; it is extended in place and also returned.
    strains_dict : dict
        Lookup from the entries of a node's `leaves` to strain names.

    Returns
    -------
    tuple
        (host, muts) where `host` is the target tip's 'typeTrait'. `host` is None
        if the tip is not found below `starting_node` (previously this case
        failed with UnboundLocalError).
    """
    for child in starting_node.children:

        if child.branchType == "leaf":
            # only the target tip is of interest; any other leaf is skipped
            if child.name == ending_tip:
                host = child.traits["typeTrait"]
                muts.extend(return_muts_on_branch(child))
                return(host, muts)

        elif child.branchType == "node":
            # descend only into the subtree that contains the target tip
            strain_leaves = convert_leaves_to_strains(child.leaves, strains_dict)
            if ending_tip in strain_leaves:
                muts.extend(return_muts_on_branch(child))
                # the tip can sit in only one subtree, so no other sibling needs checking
                return return_all_muts_on_path_to_tip(child, ending_tip, muts, strains_dict)

    # the target tip does not descend from starting_node
    return(None, muts)

In [151]:
"""at times, will need to check whether the revertant mutation occcurs downstream. Return the revertant mutation"""

def return_opposite_mutation(mut):
    """Return the reverse of `mut`, e.g. "G730A" -> "A730G".

    The first and last characters are swapped and everything in between (the
    site) is kept as is.

    Raises
    ------
    IndexError
        If `mut` is an empty string.
    """
    original_base, new_base = mut[0], mut[-1]
    position = mut[1:-1]

    return new_base + position + original_base

In [152]:
"""given a tree, mutation, and gene, return the number of times that mutation is present in each host"""

def return_host_distribution_mutation(tree, mut, strains_dict):
    """Count, per host, the tips that carry `mut`.

    Every branch annotated with `mut` is visited:

    * if the branch is a leaf, that tip's host is counted;
    * if the branch is an internal node, each descendant tip is examined. The
      mutations on the path from the node down to the tip (the node's own branch
      excluded) are gathered, and the tip is counted only if that path contains
      neither the reverse mutation nor `mut` again.

    A tip whose path contains the reverse mutation is skipped. A tip whose path
    contains `mut` a second time without the reverse mutation is reported with a
    printed message and is also not counted from this node; the branch where
    `mut` re-occurs is visited separately by the outer loop, so such a tip is
    still counted once from there.

    Parameters
    ----------
    tree : object
        Tree whose `Objects` attribute iterates over branches.
    mut : str
        Mutation to tally, e.g. "G730A".
    strains_dict : dict
        Lookup from the entries of a node's `leaves` to strain names.

    Returns
    -------
    dict
        Counts keyed by 'human', 'domestic' and 'wild'.
    """
    host_counts_dict = {'human':0, 'domestic':0, 'wild':0}
    back_mutation = return_opposite_mutation(mut)

    # iterate through tree
    for k in tree.Objects:
        # only branches on which the target mutation arises are of interest
        if mut not in return_muts_on_branch(k):
            continue

        # the mutation occurs on a leaf: record the host and move on
        if k.branchType == 'leaf':
            host_counts_dict[k.traits['typeTrait']] += 1

        # the mutation occurs on a node: check every descendant tip
        elif k.branchType == "node":
            for leaf in k.leaves:
                strain_name = strains_dict[leaf]
                host, muts = return_all_muts_on_path_to_tip(k, strain_name, [], strains_dict)

                if back_mutation in muts:
                    # reverted somewhere between this node and the tip
                    pass
                elif mut in muts:
                    # same mutation annotated again downstream with no reversion in between
                    print("something odd happened",leaf, back_mutation, mut)
                else:
                    host_counts_dict[host] += 1

    return(host_counts_dict)

## Calculate the enrichment scores

In [53]:
"""calculate an enrichment score for an individual mutation, based on the counts across hosts"""

def calculate_enrichment_score_counts(mut_counts_dict, host1, host2, host_counts):
    """Odds-ratio enrichment of a mutation in `host1` relative to `host2`, from counts.

    With A/B the number of host1 tips with/without the mutation and C/D the same
    for host2, the score is (A * D) / (C * B). If C or B is 0 it is replaced by 1
    so the division is always defined.

    Parameters
    ----------
    mut_counts_dict : dict
        Number of tips carrying the mutation, keyed by host.
    host1, host2 : str
        Host keys; the score measures enrichment in `host1`.
    host_counts : dict
        Total number of tips on the tree, keyed by host.

    Returns
    -------
    float
        The odds ratio.
    """
    tips_host1 = host_counts[host1]
    tips_host2 = host_counts[host2]

    # contingency table, as raw counts
    with_mut_host1 = mut_counts_dict[host1]
    with_mut_host2 = mut_counts_dict[host2]
    without_mut_host1 = tips_host1 - with_mut_host1
    without_mut_host2 = tips_host2 - with_mut_host2

    # a zero in either denominator term is swapped for 1
    denom_presence = with_mut_host2 if with_mut_host2 != 0 else 1
    denom_absence = without_mut_host1 if without_mut_host1 != 0 else 1

    # expressed as enrichment in host 1
    return (with_mut_host1 * without_mut_host2)/(denom_presence * denom_absence)

In [54]:
"""calculate an enrichment score for an individual mutation, based on the counts across hosts"""

def calculate_enrichment_score_proportions(mut_counts_dict, host1, host2, host_counts):
    """Difference-based enrichment of a mutation in `host1` relative to `host2`.

    The four cells of the presence/absence table are each divided by the combined
    number of host1 and host2 tips. The score is
    (presence_host1 + absence_host2) - (presence_host2 + absence_host1).

    Parameters
    ----------
    mut_counts_dict : dict
        Number of tips carrying the mutation, keyed by host.
    host1, host2 : str
        Host keys; the score measures enrichment in `host1`.
    host_counts : dict
        Total number of tips on the tree, keyed by host.

    Returns
    -------
    float
        The score.

    Raises
    ------
    ZeroDivisionError
        If the two hosts together have no tips.
    """
    tips_host1 = host_counts[host1]
    tips_host2 = host_counts[host2]

    with_mut_host1 = mut_counts_dict[host1]
    with_mut_host2 = mut_counts_dict[host2]

    tips_both_hosts = tips_host1 + tips_host2

    # contingency table, as proportions of the tips in the two hosts
    frac_present_host1 = with_mut_host1/tips_both_hosts
    frac_absent_host1 = (tips_host1 - with_mut_host1)/tips_both_hosts
    frac_present_host2 = with_mut_host2/tips_both_hosts
    frac_absent_host2 = (tips_host2 - with_mut_host2)/tips_both_hosts

    # expressed as enrichment in host 1
    supporting = frac_present_host1 + frac_absent_host2
    opposing = frac_present_host2 + frac_absent_host1
    return supporting - opposing

In [154]:
"""for a tree and all amino acid mutations, calculate the enrichment scores across the tree"""
def calculate_enrichment_scores(tree, nt_muts, host1, host2, min_required_count, method, host_counts, n_sites_alignment, strains_dict):
    """Compute per-mutation summaries and enrichment scores for one tree.

    For every mutation in `nt_muts` this records how many branches it arises on,
    the summed divergence-style length of those branches, and how many tips of
    each host carry it. A mutation is scored only when the number of tips
    carrying it (summed over all hosts) is strictly greater than
    `min_required_count`.

    Parameters
    ----------
    tree : object
        Tree whose `Objects` attribute iterates over branches.
    nt_muts : iterable of str
        Nucleotide mutations to evaluate.
    host1, host2 : str
        Hosts to compare; scores measure enrichment in `host1`.
    min_required_count : int
        Threshold on tips carrying the mutation (exclusive).
    method : str
        "counts" (odds ratio) or "proportions" (difference of proportions).
    host_counts : dict
        Total tips per host on the tree.
    n_sites_alignment : int
        Number of sites in the alignment.
    strains_dict : dict
        Lookup from numeric tip ids to strain names.

    Returns
    -------
    tuple
        (scores, scores_dict, times_detected_dict, branch_lengths_dict,
        host_counts_dict2): the list of scores, then four dicts keyed by
        mutation. `scores` and `scores_dict` hold only the mutations that passed
        the threshold.

    Raises
    ------
    ValueError
        If `method` is neither "counts" nor "proportions".
    """
    # fail early on a bad method instead of hitting an unbound name mid-loop
    if method == "counts":
        enrichment_calculation_function = calculate_enrichment_score_counts
    elif method == "proportions":
        enrichment_calculation_function = calculate_enrichment_score_proportions
    else:
        raise ValueError('method must be "counts" or "proportions", got {!r}'.format(method))

    scores = []
    scores_dict = {}
    times_detected_dict = {}
    branch_lengths_dict = {}
    host_counts_dict2 = {}

    for n in nt_muts:
        times_detected_dict[n] = return_number_times_on_tree(tree, n)
        branch_lengths_dict[n] = return_branch_length_mut_on_tree(tree, n, n_sites_alignment)

        host_counts_dict = return_host_distribution_mutation(tree, n, strains_dict)
        host_counts_dict2[n] = host_counts_dict
        total_tips_with_mut = host_counts_dict['human'] + host_counts_dict['domestic'] + host_counts_dict['wild']

        # score only mutations seen in more than min_required_count tips
        if total_tips_with_mut > min_required_count:
            enrichment_score = enrichment_calculation_function(host_counts_dict, host1,host2, host_counts)
            scores.append(enrichment_score)
            scores_dict[n] = enrichment_score

    return(scores, scores_dict, times_detected_dict, branch_lengths_dict, host_counts_dict2)

In [None]:
def run_on_posterior_trees(all_trees, burnin):
    """Load trees from a posterior .trees file and report how long it took.

    Work in progress: at most the first two trees are loaded (a debugging cap),
    nothing is computed on them yet, and `burnin` is accepted but not applied.
    The function reads the notebook-level `taxa_lines` (nexus header text) and
    `bt` (baltic), which must be defined before it is called.

    Parameters
    ----------
    all_trees : str
        Path to the combined .trees file.
    burnin : int or float
        Number of trees to skip; currently unused.

    Returns
    -------
    None
        The elapsed time and number of trees loaded are printed.
    """
    start_time = time.time()
    tree_counter = 0

    with open(all_trees, "r") as infile:
        for line in infile:
            if 'tree STATE_' not in line:
                continue

            # debugging cap: stop after two trees rather than scanning the rest of the file
            # (the intended condition is to process trees once past `burnin`)
            if tree_counter >= 2:
                break

            tree_counter += 1
            # one tree line plus the shared header is a complete nexus document
            temp_tree = StringIO(taxa_lines + line)
            tree = bt.loadNexus(temp_tree)

    # print the amount of time this took
    total_time_seconds = time.time() - start_time
    total_time_minutes = total_time_seconds/60
    print("this took", total_time_seconds, "seconds (", total_time_minutes," minutes) to run on", tree_counter, "trees")

In [156]:
all_trees = "../test-data/2021-03-15-mascot-3deme-no-human-mig-TEST.combined.trees"
burnin_percent = 0

# nexus header (taxa + Translate block) shared by every tree in the file
taxa_lines = get_taxa_lines(all_trees)
strains_dict = convert_strain_to_number(taxa_lines)

burnin = get_burnin_value(all_trees, burnin_percent)
n_sites_alignment = 1762

with open(all_trees, "r") as infile:
    trees_seen = 0     # every tree line read so far, burn-in included
    tree_counter = 0   # trees actually analysed

    for line in infile:
        if 'tree STATE_' not in line:
            continue

        # Skip the first `burnin` trees. The skip is tracked with its own counter:
        # counting only analysed trees meant any burn-in above 0 skipped every tree.
        trees_seen += 1
        if trees_seen <= burnin:
            continue

        tree_counter += 1
        print(tree_counter)
        temp_tree = StringIO(taxa_lines + line)
        tree = bt.loadNexus(temp_tree)

        host_counts = return_all_host_tips(tree)
        x,y = return_total_tree_branch_length(tree, n_sites_alignment)
        all_nt_muts = gather_all_mut_on_tree(tree)

        # The single-mutation spot checks that used an undefined `mut` were dropped:
        # calculate_enrichment_scores computes the same three quantities for every
        # mutation and returns them in the dictionaries below.
        scores, scores_dict, times_detected_dict, branch_lengths_dict, host_counts_dict2 = calculate_enrichment_scores(tree, all_nt_muts, "human","domestic", 1, "counts",host_counts, n_sites_alignment, strains_dict)

# results are overwritten on each iteration, so these are the values for the last tree analysed
scores, scores_dict

  after removing the cwd from sys.path.
  
