# Parse divergence tree and output tips with evidence of onward transmission using nextstrain jsons

October 14, 2019 

Using the same method as used for mumps, I would like to try parsing the H5N1 nextstrain tree (json format) and quantifying tips with evidence of onward transmission vs. tips without evidence of onward transmission. My thought here is that this could be used to replace my current aim 2. This way, I could actually associated metadata from tips with a metric for transmission. 

In [2]:
import sys, subprocess, glob, os, shutil, re, importlib
from subprocess import call
import imp
bt = imp.load_source('baltic', '/Users/lmoncla/src/baltic/baltic-iqtree.py')

%matplotlib inline
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.lines as mlines
from matplotlib.font_manager import FontProperties
import matplotlib.colors as clr
import textwrap as textwrap
from textwrap import wrap

import pandas as pd

import numpy as np
from scipy.special import binom

import rpy2
import seaborn as sns
%load_ext rpy2.ipython

import json
import collections
from collections import Counter
from Bio import SeqIO
from Bio import Seq
import Bio.Phylo

The rpy2.ipython extension is already loaded. To reload it, use:
  %reload_ext rpy2.ipython


# Functions from augur to read in nextstrain jsons and convert to trees

In [3]:
# function to use the json module to read in a json file and store it as "data"                
def read_json(file_name):
    try:
        handle = open(file_name, 'r')
    except IOError:
        pass
    else:
        data = json.load(handle)
        handle.close()
    return data

In [4]:
# original code that Trevor gave me for parsing through tree jsons and returning descendents
def all_descendants(node):
    """Take node, ie. dict, and return a flattened list of all nodes descending from this node"""
    yield node
    
    # this will recursively return all internal nodes (nodes with children)
    if 'children' in node:
        for child in node['children']:
            for desc in all_descendants(child):
                yield desc

In [5]:
# Biopython's trees don't store links to node parents, so we need to build
# a map of each node to its parent.
# Code from the Bio.Phylo cookbook: http://biopython.org/wiki/Phylo_cookbook
def all_parents(tree):
    parents = {}
    for clade in tree.find_clades(order='level'):
        for child in clade:
            parents[child] = clade
    return parents

In [6]:
def annotate_parents(tree):
    # Get all parent nodes by node.
    parents_by_node = all_parents(tree)

    # Next, annotate each node with its parent.
    for node in tree.find_clades():
        if node == tree.root:
            node.up = None
        else:
            node.up = parents_by_node[node]

    # Return the tree.
    return tree

In [7]:
def json_to_tree(json_dict, root=True):
    """Returns a Bio.Phylo tree corresponding to the given JSON dictionary exported
    by `tree_to_json`.

    Assigns links back to parent nodes for the root of the tree.

    >>> import json
    >>> json_fh = open("tests/data/json_tree_to_nexus/flu_h3n2_ha_3y_tree.json", "r")
    >>> json_dict = json.load(json_fh)
    >>> tree = json_to_tree(json_dict)
    >>> tree.name
    u'NODE_0002020'
    >>> len(tree.clades)
    2
    >>> tree.clades[0].name
    u'NODE_0001489'
    >>> hasattr(tree, "attr")
    True
    >>> "dTiter" in tree.attr
    True
    """
    node = Bio.Phylo.Newick.Clade()
    node.name = json_dict["strain"]

    if "children" in json_dict:
        # Recursively add children to the current node.
        node.clades = [json_to_tree(child, root=False) for child in json_dict["children"]]

    # Assign all non-children attributes.
    for attr, value in json_dict.items():
        if attr != "children":
            setattr(node, attr, value)

    node.numdate = node.attr.get("num_date")
    node.divergence = node.attr.get("div")

    if "translations" in node.attr:
        node.translations = node.attr["translations"]

    if root:
        node = annotate_parents(node)

    return node

# Read in metadata, find proper parent node, and add branch lengths

In [18]:
def read_metadata(metadata_path):
    metadata = {}
    
    with open(metadata_path, "r") as infile:
        metadata_categories = {}
        
        for line in infile: 
            if "isolate_id" in line: # skip first line
                for i in range(len(line.split("\t"))):
                    metadata_categories[i] = line.split("\t")[i]
                    
            elif "isolate_id" not in line:
                strain = line.split("\t")[0]
                metadata[strain] = {}
                
                for i in range(len(line.split("\t"))):
                    ID = metadata_categories[i]
                    metadata[strain][ID] = line.split("\t")[i]

    return(metadata)

In [24]:
def return_country(strain_name, metadata):
        
    country = metadata[strain_name]["country"]
    
    return(country)

In [72]:
def return_region(strain_name, metadata):
        
    region = metadata[strain_name]["region"]
    
    return(region)

In [73]:
def return_host(strain_name, metadata):
        
    host = metadata[strain_name]["host"]
    
    return(host)

In [97]:
def return_proper_parent_node(node):
    """given an internal node, traverse back up the tree to find a parental node with a 
    real branch length (basically, collapse the polytomy)"""
    
    #print(node, node.length)
    if abs(node.divergence - node.up.divergence) < (1/length_alignment): 
        
        #print("going up 1 node")
        if node.up !=None:
            parent_node = return_proper_parent_node(node.up)
        
        else:
            #print("root is proper parent")
            parent_node = node
    
    else: 
        #print("current node has proper length")
        parent_node = node
    
    return(parent_node)

In [113]:
def add_nodes(node):
    """Take node and add up branch lengths for total subtending tree from that node"""
    total_lengths = 0
    
    if node.is_terminal() == True: 
        total_lengths += node.branch_length

    
    else:
        total_lengths += node.branch_length
        for child in node.clades:
            
            total_lengths += add_nodes(child)
                            
    return(total_lengths)

In [192]:
def return_dates_of_children(proper_parent,current_tip,long_spillover_threshold):
    current_tip_date = float(current_tip.attr['num_date'])
    number_long = 0
    number_short = 0
        
    for i in proper_parent.get_terminals():
        date = float(i.attr['num_date'])

        if date - current_tip_date >= long_spillover_threshold:
            number_long += 1
        else:
            number_short += 1
            
#         if current_tip.strain == 'A/chicken/Vietnam/DT-171/2004':
#             print(current_tip_date, date, number_long, number_short)

            
    return(number_short, number_long)

In [193]:
def return_descendants_dict(tree, metadata, length_alignment):
    
    output_dict = {}
    not_polytomies = []
    polytomies = []
    
    for i in tree.find_clades(): ## iterate over objects in tree            
        if i.is_terminal() == True:

            country = return_country(i.strain, metadata)
            region = return_region(i.strain,metadata)
            host = return_host(i.strain, metadata)
            date = i.attr['num_date']

            if abs(i.divergence - i.up.divergence < (1/length_alignment)): 
                polytomies.append(i.strain)

                proper_parent = return_proper_parent_node(i.up)
                branch_length = add_nodes(proper_parent) - i.branch_length
                number_children = len(proper_parent.get_terminals()) - 1
                short_spillovers,long_spillovers = return_dates_of_children(proper_parent,i,long_spillover_threshold)

                output_dict[i.strain] = {}
                output_dict[i.strain]['branch_lengths'] = branch_length
                output_dict[i.strain]['number_children'] = number_children
                output_dict[i.strain]['country'] = country
                output_dict[i.strain]['region'] = region
                output_dict[i.strain]['host'] = host
                output_dict[i.strain]['long_spillovers'] = long_spillovers
                output_dict[i.strain]['short_spillovers'] = short_spillovers

            else:
                not_polytomies.append(i.strain)
                output_dict[i.strain] = {}
                output_dict[i.strain]['branch_lengths'] = 0
                output_dict[i.strain]['number_children'] = 0
                output_dict[i.strain]['country'] = country
                output_dict[i.strain]['region'] = region
                output_dict[i.strain]['host'] = host
                output_dict[i.strain]['long_spillovers'] = 0
                output_dict[i.strain]['short_spillovers'] = 0
    
    return(polytomies,not_polytomies,output_dict)

# Set paths, run it

In [198]:
# set paths
tree_path = "/Users/lmoncla/src/avian-flu/auspice/flu_avian_h5n1_ha_tree.json"
metadata_path = "/Users/lmoncla/src/avian-flu/results/metadata_h5n1_ha.tsv"
length_alignment = 1762
long_spillover_threshold = 1  #0.083 # this is in terms of years

# run
metadata = read_metadata(metadata_path)
tree = read_json(tree_path)
tree = json_to_tree(tree)
polytomies,not_polytomies,output_dict = return_descendants_dict(tree, metadata, length_alignment)

## Convert to dataframe 

In [199]:
df = pd.DataFrame.from_dict(output_dict, orient="index")
df.reset_index(inplace=True)
df.head()

Unnamed: 0,index,branch_lengths,number_children,country,region,host,long_spillovers,short_spillovers
0,A/Alberta/1/2014,0.0,0,canada,north_america,human,0,0
1,A/American_green_winged_teal/Washington/195750...,0.000608,1,usa,north_america,avian,0,2
2,A/American_wigeon/Washington/196336/2015,0.020912,7,usa,north_america,avian,0,8
3,A/American_wigeon/Washington/196340/2015,0.000653,1,usa,north_america,avian,0,2
4,A/Americanblackduck/Alberta/118/2016,0.005564,2,canada,north_america,avian,0,3


In [180]:
slothing = df[df['index'] == 'A/chicken/Vietnam/DT-171/2004']
slothing

Unnamed: 0,index,branch_lengths,number_children,country,region,host,long_spillovers,short_spillovers
915,A/chicken/Vietnam/DT-171/2004,0.306898,109,vietnam,southeast_asia,avian,0,110


## Given a category to compare and 2 members of that category, perform a Fisher's exact test and return the results 

In [206]:
def return_fishers_exact_test(df,variable,group1,group2):
    group1_with_desc = len(df[(df['number_children'] > 0) & (df[variable] == group1)])
    group1_no_desc = len(df[df[variable] == group1]) - group1_with_desc

    group2_with_desc = len(df[(df['number_children'] > 0) & (df[variable] == group2)])
    group2_no_desc = len(df[df[variable] == group2]) - group2_with_desc
    
    table = [[group1_with_desc,group1_no_desc],[group2_with_desc,group2_no_desc]]
    result = scipy.stats.fisher_exact(table)
    return(table, result)

In [207]:
def return_dateframe_of_Fishers_exact_results(variable,group1,group2,table, result):
    
    group1_with_desc = table[0][0]
    group1_no_desc = table[0][1]
    group2_with_desc = table[1][0]
    group2_no_desc = table[1][1]
    p_value = result[1]
    
    x = {'variable':[variable],'group1':[group1],'group2':[group2],
         'group1_with_desc':[group1_with_desc],'group1_no_desc':[group1_no_desc],
        'group2_with_desc':[group2_with_desc],'group2_no_desc':[group2_no_desc],'p-value':[p_value]}
    
    d = pd.DataFrame.from_dict(x)
    return(d)

In [208]:
variable = "country"
group1 = "china"
group2 = "vietnam"
table,result = return_fishers_exact_test(df,variable,group1,group2)
Fishers_df = return_dateframe_of_Fishers_exact_results(variable,group1,group2,table, result)
Fishers_df

Unnamed: 0,variable,group1,group2,group1_with_desc,group1_no_desc,group2_with_desc,group2_no_desc,p-value
0,country,china,vietnam,23,142,44,108,0.001436


## Run on multiple variables 

Now, given a list of variables to test, run them all. Here, I will first subset to include only avian tips, and then compare regions or countries to each other. 

In [209]:
Fishers_df1 = pd.DataFrame()

variables = [{"country":['china','vietnam']},{'region':['china','southeast_asia']},
            {'region':['china','south_asia']},{'region':['africa','south_asia']}]
df_test = df[df['host'] == 'avian']

for i in variables: 
    for key, value in i.items() :
        variable = key
        group1 = value[0]
        group2 = value[1]

        table,result = return_fishers_exact_test(df_test,variable,group1,group2)
        Fishers_df = return_dateframe_of_Fishers_exact_results(variable,group1,group2,table, result)
        Fishers_df1 = Fishers_df1.append(Fishers_df)
    
Fishers_df1

Unnamed: 0,variable,group1,group2,group1_with_desc,group1_no_desc,group2_with_desc,group2_no_desc,p-value
0,country,china,vietnam,15,110,42,99,0.00051
0,region,china,southeast_asia,48,174,122,257,0.006412
0,region,china,south_asia,48,174,56,159,0.31229
0,region,africa,south_asia,98,141,56,159,0.001022


This suggests that Vietnamese birds are more likely to transmit than Chinese birds, which is a little surprising. This also suggests that African birds are more likely to transmit than South Asian (Middle Eastern) birds, which makes sense. 

I think that there are a few things I will need to really think about/test. 
1. How does uneven sampling impact this? I suppose that I can just test in subsampled trees to see whether this is robust to sampling or not. The current nextstrain jsons I'm using here are subsampled for H5N1 by country to 10 sequences per country per year. 
2. How do I do power testing on this? 
3. How will I validate this as a method? It is a little strange to me that China is showing up as less likely to be an ancestor than Vietnam, especially when more internal nodes on the tree are in China than Vietnam. This makes me a little concerned that this isn't going to work super well. 
4. How do I do multiple comparisons? 

In [202]:
def return_fishers_exact_test_longevity(df,variable,group1,group2):
    group1_with_desc = len(df[(df['long_spillovers'] > 0) & (df[variable] == group1)])
    group1_no_desc = len(df[df[variable] == group1]) - group1_with_desc

    group2_with_desc = len(df[(df['long_spillovers'] > 0) & (df[variable] == group2)])
    group2_no_desc = len(df[df[variable] == group2]) - group2_with_desc
    
    table = [[group1_with_desc,group1_no_desc],[group2_with_desc,group2_no_desc]]
    result = scipy.stats.fisher_exact(table)
    return(table, result)

In [203]:
def return_dateframe_of_Fishers_exact_results_longevity(variable,group1,group2,table, result):
    
    group1_with_long = table[0][0]
    group1_no_long = table[0][1]
    group2_with_long = table[1][0]
    group2_no_long = table[1][1]
    p_value = result[1]
    
    x = {'variable':[variable],'group1':[group1],'group2':[group2],
         'group1_with_long':[group1_with_long],'group1_no_long':[group1_no_long],
        'group2_with_long':[group2_with_long],'group2_no_long':[group2_no_long],'p-value':[p_value]}
    
    d = pd.DataFrame.from_dict(x)
    return(d)

In [204]:
Fishers_df2 = pd.DataFrame()

variables = [{"country":['china','vietnam']},{'region':['china','southeast_asia']},
            {'region':['china','south_asia']},{'region':['africa','south_asia']}]
df_test = df[df['host'] == 'avian']

for i in variables: 
    for key, value in i.items() :
        variable = key
        group1 = value[0]
        group2 = value[1]

        table,result = return_fishers_exact_test_longevity(df_test,variable,group1,group2)
        Fishers_df = return_dateframe_of_Fishers_exact_results_longevity(variable,group1,group2,table, result)
        Fishers_df2 = Fishers_df2.append(Fishers_df)
    
Fishers_df2

Unnamed: 0,variable,group1,group2,group1_with_long,group1_no_long,group2_with_long,group2_no_long,p-value
0,country,china,vietnam,4,121,5,136,1.0
0,region,china,southeast_asia,8,214,12,367,0.81565
0,region,china,south_asia,8,214,9,206,0.808307
0,region,africa,south_asia,13,226,9,206,0.662846


In [200]:
sloth = df[df['number_children'] > 0]
len(sloth)

653

In [201]:
sloth = df[df['long_spillovers'] > 0]
len(sloth)

98

If I define a long spillover as 6 months or greater, there are 142 tips. If I define it as at least 1 year, then I have 98 tips. 

In [None]:
"""Using Fisher.test in R, I tried this, but it did not work: 

sloth <- matrix(c(8,214,12,367,13,226),nrow=3)
fisher.test(sloth, simulate.p.value = TRUE, B=10000)"""