# Setup/Config

In [15]:
import json
from collections import defaultdict
from copy import deepcopy
from itertools import combinations

import baltic as bt
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import cm

%matplotlib inline

In [16]:
# Input files: the annotated tree JSON (read with bt.loadJSON below) and the
# frequencies JSON (read by parse_cluster_frequencies below).
tree_path = '../../../data/titer-model/all-branch-effects/jsons/dengue_all_tree.json'
frequencies_path = '../../../data/titer-model/frequencies/dengue_all_frequencies.json'

# Plotting palette: two hex colors and a matplotlib colormap.
c1 = '#4477AA'
c2 = '#aa4455'
cmap = cm.YlOrRd

# Global seaborn styling for every figure in the notebook.
sns.set(style='white', font_scale=1.5)

# Directory prefix that the CSV exports below are written under
# (used by plain string concatenation, so the trailing slash matters).
out_path = '../../../data/titer-model/frequencies/'

# Define plotting functions

In [17]:
def plot_tree(tree, orientation = 'v', colorby=None, color_callable=None, vmax=None,):
    '''Draw a baltic tree on the current matplotlib axes.

    tree: baltic tree object; its `Objects`, `treeHeight`, `traverse_tree`
        and `drawTree` members are used.
    orientation: 'v' plots each object's x/y as stored. Any other value swaps
        the axes and negates the original x so the tree grows downwards; the
        child order is reversed for the duration of the plot and restored
        before returning (also when plotting raises).
    colorby: name of a trait; objects that carry it are colored by
        `color_callable`, all others are gray.
    color_callable: called with the trait value (divided by `vmax` when
        `vmax` is truthy) and expected to return a matplotlib color.
    vmax: optional normalizer for the trait value.

    Returns None; the function only draws.
    '''
    sns.set_style('white')
    plt.yticks(size=0)
    branchWidth=3 ## default branch width
    vertical = orientation == 'v'

    def flip(tree):
        '''Reverse the child order of every node, then recompute coordinates.'''
        for k in tree.traverse_tree(include_all=True):
            if k.branchType == 'node':
                k.children.reverse()
        tree.drawTree()

    if not vertical:
        flip(tree)

    try:
        for k in tree.Objects: ## iterate over objects in tree
            if vertical:
                x = k.x
                y = k.y

                xp = k.parent.x ## get x position of current object's parent
                if x is None: ## matplotlib won't plot Nones, like root
                    x = 0.0
                if xp is None:
                    xp = x
            else:         ## If not a vertical plot, just flip x and y
                x = k.y
                # Same None guard as the vertical branch applies to k.x;
                # without it -1*None raises a TypeError.
                y = -1*k.x if k.x is not None else 0.

                xp = k.parent.y
                yp = k.parent.x
                if x is None:
                    x = 0.
                if xp is None:
                    xp = x
                if yp is None:
                    yp = 0.
                yp = -1.*yp

            c = 'gray'
            if colorby and colorby in k.traits:
                if vmax:
                    c = color_callable(k.traits[colorby]/vmax)
                else:
                    try:
                        c = color_callable(k.traits[colorby])
                    except Exception: ## fall back to gray when the value cannot be mapped
                        c = 'gray'

            if isinstance(k,bt.leaf): ## if leaf...
                s=10 ## tip size can be fixed

                plt.scatter(x,y,s=s,facecolor=c,edgecolor='none',zorder=11) ## plot circle for every tip
                plt.scatter(x,y,s=s+0.8*s,facecolor='k',edgecolor='none',zorder=10) ## plot black circle underneath

            if isinstance(k, bt.clade): ## collapsed clade: dashed extension out to the tree height
                s = 50
                pad = tree.treeHeight

                if vertical:
                    plt.plot([x, pad], [y,y], lw=branchWidth, color=c, ls='--', zorder=9) # extend branches so tips line up
                    plt.scatter(pad,y,s=s,facecolor=c,edgecolor='none',zorder=11) ## plot circle for every tip
                    plt.scatter(pad,y,s=s+0.8*s,facecolor='k',edgecolor='none',zorder=10) ## plot black circle underneath

                else:
                    pad = -1.*pad
                    plt.plot([x,x],[y,pad], lw=branchWidth, color=c, ls='--', zorder=9)
                    plt.scatter(x,pad,s=s,facecolor=c,edgecolor='none',zorder=11) ## plot circle for every tip
                    plt.scatter(x,pad,s=s+0.8*s,facecolor='k',edgecolor='none',zorder=10) ## plot black circle underneath

            elif isinstance(k,bt.node) or k.branchType=='node': ## if node...
                if vertical:
                    plt.plot([x,x],[k.children[-1].y,k.children[0].y],lw=branchWidth,color=c,ls='-',zorder=9)
                else: ## Flip x and y for node.children coordinates, draw from top -> bottom (*-1)
                    plt.plot([k.children[-1].y,k.children[0].y],[y,y],lw=branchWidth,color=c,ls='-',zorder=9)

            ## branch connecting the object to its parent
            if vertical:
                plt.plot([xp,x],[y,y],lw=branchWidth,color=c,ls='-',zorder=9)
            else:
                plt.plot([x,x], [y,yp], lw=branchWidth, color=c, ls='-', zorder=9)
    finally:
        if not vertical:
            flip(tree) ## undo the flip so the caller's tree is left as it was passed in

# Load trees, calculate pairwise dTiter values, collapse antigenically uniform clades

In [18]:
def trace_between(k1, k2):
    '''Naive path tracing between two tree objects.

    Walks from k1 towards the root, then climbs from k2 until it meets
    k1's path; the meeting point is the mrca. Returns the objects on the
    path from k1 down to k2 in travel order, excluding the mrca itself.'''

    # k1 and its ancestors, nearest first; stops before the object whose index is 'Root'
    ancestors_of_k1 = []
    cursor = k1
    while cursor.index != 'Root':
        ancestors_of_k1.append(cursor)
        cursor = cursor.parent

    # climb from k2 until we land on something already on k1's path
    climb_from_k2 = []
    cursor = k2
    while cursor not in ancestors_of_k1:
        climb_from_k2.append(cursor)
        cursor = cursor.parent
    mrca = cursor # not included in the returned trace

    k1_side = ancestors_of_k1[:ancestors_of_k1.index(mrca)]
    # reverse k2's climb so the whole path reads k1 -> mrca -> k2
    return k1_side + climb_from_k2[::-1]
    
def sum_attr(trace, attr='dTiter'):
    '''Add up traits[attr] over every object in the passed path.
    An empty path sums to 0.'''
    if len(trace) > 0:
        return sum(branch.traits[attr] for branch in trace)
    return 0.

In [19]:
def antigenically_uniform(node, tree):
    '''Return True when `node` roots a maximal antigenically uniform subtree.

    A subtree is uniform when all the 'cTiter' traits found in it (objects
    without the trait are ignored) round to one single value at 2 decimals.
    The result is True only if the subtree under `node` is uniform while the
    subtree under `node.parent` is not.

    node: tree object with `.traits` and `.parent`.
    tree: object providing `traverse_tree(start, include_all=True)`.
    '''
    def _is_uniform(subtree_root):
        # `in` instead of dict.has_key, which no longer exists in Python 3
        cTiters = set(round(k.traits['cTiter'], 2)
                      for k in tree.traverse_tree(subtree_root, include_all=True)
                      if 'cTiter' in k.traits)
        return len(cTiters) == 1

    # `and` short-circuits: the parent subtree is only walked when needed
    return _is_uniform(node) and not _is_uniform(node.parent)
                          
def collapse_antigenic_phenotypes(tree):
    '''Return a copy of `tree` with every antigenically uniform clade collapsed.

    The input tree is deep-copied first, so it is not modified. Nodes of the
    copy are visited in order of increasing `height`; each node for which
    antigenically_uniform() is True is passed to the copy's collapseSubtree,
    named by the node's 'clade' trait and given a constant width of 1.
    '''
    tree_copy = deepcopy(tree)
    sorted_branches = sorted(tree_copy.nodes, key = lambda k: k.height)

    for k in sorted_branches:
        if antigenically_uniform(k, tree_copy):
            tree_copy.collapseSubtree(k, k.traits['clade'], widthFunction=lambda x: 1)
    return tree_copy

In [20]:
# Key mapping handed to bt.loadJSON together with the tree path.
# NOTE(review): presumably maps baltic attribute names ('name', 'height') to
# the keys used in the JSON ('strain', 'xvalue') -- confirm against baltic.
json_translation = {
    'name': 'strain',
    'height': 'xvalue',
}
tree = bt.loadJSON(tree_path, json_translation) ## baltic Tree object


Tree height: 0.368260
Tree length: 6.105600
annotations present

Numbers of objects in tree: 2996 (1426 nodes and 1570 leaves)



In [21]:
def parse_cluster_frequencies(path):
    '''Read a frequencies JSON and regroup it by region and integer clade.

    The file is expected to hold one flat mapping, e.g.
        {'pivots': [1917., 1918., ...],
         'southeast_asia_clade:179': [0.1, 0.23, ...],
         'south_america_II': [...],
         'africa_DENV4_II': [...]}

    Labels are split into (region, clade):
      * '<region>_clade:<n>'  -> region, '<n>'
      * labels containing 'DENV' -> the last two '_' tokens are the clade
      * anything else           -> the last '_' token is the clade
    Only entries whose clade converts to an int are kept (the tree uses
    integers for clade indices); all other labels are skipped.

    Returns
        {'frequencies': {'global': {0: [0.1, 0.23, ...], ...}, ...},
         'pivots': [1917., 1918., ...]}
    Raises KeyError if the file has no 'pivots' entry.
    '''
    with open(path, 'r') as handle: # close the file even if parsing fails
        raw_frequencies = json.load(handle) ## {'southeast_asia_clade:179':[0.1, 0.23, ...]}
    pivots = raw_frequencies.pop('pivots') # [1917., 1918., ...]

    regional_clade_frequencies = defaultdict(dict) ## {'global': {0: [0.1, 0.23, ...] } }

    for label, freqs in raw_frequencies.items():
        if 'clade:' not in label: #south_america_II, africa_denv4_II
            split_label = label.split('_')
            n_clade_tokens = 2 if 'DENV' in label else 1
            region = '_'.join(split_label[:-n_clade_tokens])
            clade = '_'.join(split_label[-n_clade_tokens:])
        else: # 'global_clade:0'
            region, clade = label.split('_clade:')

        try:
            clade = int(clade) # the tree uses integers for clade indices
        except ValueError: # named genotype rather than a clade index: skip
            continue
        regional_clade_frequencies[region][clade] = freqs

    return {'frequencies': dict(regional_clade_frequencies),
            'pivots': pivots }

# Keep the parsed wrapper under its own name so `all_frequencies` only ever
# refers to the {region: {clade: frequencies}} mapping.
parsed_frequencies = parse_cluster_frequencies(frequencies_path)
pivots = parsed_frequencies['pivots']
all_frequencies = parsed_frequencies['frequencies']

In [22]:
# One CSV per region: clades as columns, pivots as the row index.
for region in all_frequencies:
    freqs = pd.DataFrame(all_frequencies[region], index=pivots)
    out_name = ''.join([out_path, region, '_clade_frequencies.csv'])
    freqs.to_csv(out_name)

In [8]:
class antigenic_cluster():
    '''Summary of the clade rooted at one tree node.

    Attributes set on construction:
      clade, cTiter  -- copied from the node's traits
      mrca           -- the node itself
      descendents    -- result of tree.traverse_tree(node, include_all=True)
      leaves, N      -- descendents whose branchType is 'leaf', and their count
      serotype, name -- taken from the first '/'-separated field of the leaf
                        strain names ('DENV' is ignored); when exactly one
                        distinct value is found the name is '<serotype>_<clade>',
                        otherwise serotype is None and the name is 'DENV_<clade>'
      genotypes      -- set of 'genotype' traits found on the leaves
      dTiters, frequencies -- empty dicts, filled in by later code
    '''
    def __init__(self, node, tree):
        self.clade = node.traits['clade']
        self.mrca = node
        self.descendents = tree.traverse_tree(node, include_all=True)
        self.leaves = [d for d in self.descendents if d.branchType=='leaf']

        strain_prefixes = [leaf.traits['strain'].split('/')[0] for leaf in self.leaves]
        found_serotypes = set(prefix for prefix in strain_prefixes if prefix != 'DENV')
        if len(found_serotypes) != 1:
            self.serotype = None
            self.name = 'DENV_%d'%self.clade
        else:
            self.serotype = next(iter(found_serotypes))
            self.name = '%s_%d'%(self.serotype, self.clade)

        self.genotypes = set(leaf.traits['genotype'] for leaf in self.leaves
                             if 'genotype' in leaf.traits)

        self.N = len(self.leaves)
        self.cTiter = node.traits['cTiter']
        self.dTiters = {}
        self.frequencies = {}

# One cluster per node that roots an antigenically uniform clade.
antigenic_clusters = [ antigenic_cluster(k, tree) for k in tree.nodes 
                      if antigenically_uniform(k, tree) ]

# Symmetric lookup {cluster name: {other cluster name: summed dTiter}}.
pairwise_dTiters = defaultdict(dict)

# The loop variables deliberately avoid `c1`, which is the plotting color
# defined in the config cell and must not be overwritten here.
for (cluster_a, cluster_b) in combinations(antigenic_clusters, 2):
    dt = sum_attr(trace_between(cluster_a.mrca, cluster_b.mrca))
    cluster_a.dTiters[cluster_b.name] = dt
    cluster_b.dTiters[cluster_a.name] = dt
    pairwise_dTiters[cluster_a.name][cluster_b.name] = dt
    pairwise_dTiters[cluster_b.name][cluster_a.name] = dt

dTiters = pd.DataFrame(pairwise_dTiters)

In [35]:
def pull_cluster_frequencies(antigenic_clusters, region, frequencies):
    '''Look up each cluster's frequencies in `region` by its clade index.

    Returns {cluster.name: frequencies[region][cluster.clade]} and, as a side
    effect, stores the same value in cluster.frequencies[region].'''
    by_name = {}
    for ac in antigenic_clusters:
        by_name[ac.name] = ac.frequencies[region] = frequencies[region][ac.clade]
    return by_name

def make_dataframe(genotype_frequencies, pivots, mindate=1970.):
    '''Build a frame with one column per key of `genotype_frequencies`,
    indexed by `pivots`, keeping only rows whose pivot is >= mindate.'''
    frame = pd.DataFrame(genotype_frequencies, index=pivots)
    kept_pivots = frame.index[frame.index >= mindate]
    return frame.loc[kept_pivots]

# {region: DataFrame} with cluster names as columns and pivots as the index.
cluster_frequencies = {
    region: make_dataframe(
        pull_cluster_frequencies(antigenic_clusters, region, all_frequencies),
        pivots)
    for region in ['southeast_asia', 'south_america', 'global']
}

In [40]:
# One CSV of cluster frequencies per region, plus the pairwise dTiter matrix.
for region in cluster_frequencies:
    freqs = cluster_frequencies[region]
    out_name = ''.join([out_path, region, '_cluster_frequencies.csv'])
    freqs.to_csv(out_name)

dTiters.to_csv(''.join([out_path, 'cluster_dTiters.csv']))