In [1]:
import __future__
import json
import os
import sys

sys.path.append('/Users/bpotter/nextstrain/augur')
from base import io_util

In [2]:
# Set default data directories.
# The AUSPICE_DATA_DIR environment variable, when set, overrides the built-in
# location so the notebook is not tied to one machine's directory layout.
data_dir = os.environ.get('AUSPICE_DATA_DIR',
                          '/Users/bpotter/nextstrain/auspice/data/')
# os.path.join copes with data_dir given with or without a trailing slash.
ha_file = os.path.join(data_dir, 'flu_seasonal_h3n2_ha_2y_tree.json')
na_file = os.path.join(data_dir, 'flu_seasonal_h3n2_na_2y_tree.json')

# Mutations that identify A and B identities
A_mut = u'P21S'
B_mut = u'P386S'

In [3]:
# Parse each tree JSON from disk and convert it with io_util.json_to_tree.
with open(ha_file, 'r') as ha_handle:
    ha_tree = io_util.json_to_tree(json.load(ha_handle))

with open(na_file, 'r') as na_handle:
    na_tree = io_util.json_to_tree(json.load(na_handle))

In [4]:
def label_na_clade(clade, na_label, B_mutation, na_strains_lookup):
    '''Assign labels b or B to nodes in the NA tree recursively.

    A node receives 'B' when the label inherited from its parent is already
    'B', or when B_mutation is listed in the node's own aa_muts[u'NA'];
    otherwise it keeps 'b'. The result is stored on each node as the
    attribute na_mut_label.

    Adds leaves to dictionary that allows for easy state-lookup based
    on strain name, which gets used by label_ha_na_co_clades.

    Parameters:
        clade: tree node exposing .clades (list of children), optionally
            .aa_muts (dict of gene name -> mutations) and, on leaves, .strain.
        na_label: label inherited from the parent node ('b' or 'B').
        B_mutation: mutation string that switches the label from 'b' to 'B'.
        na_strains_lookup: dict filled in place with strain -> label for
            every leaf below (and including) clade.

    Returns:
        None; results are written to the nodes and to na_strains_lookup.
    '''
    label = na_label
    if na_label == 'b' and hasattr(clade, 'aa_muts'):
        # Tolerate nodes whose aa_muts has no u'NA' entry instead of raising
        # KeyError.
        if B_mutation in clade.aa_muts.get(u'NA', ()):
            label = 'B'
    clade.na_mut_label = label
    if not clade.clades:
        # Record the label computed for this leaf, not the inherited one, so
        # a mutation on the terminal branch itself is reflected in the lookup
        # and agrees with na_mut_label.
        na_strains_lookup[clade.strain] = label
    else:
        for child in clade.clades:
            label_na_clade(child, label, B_mutation, na_strains_lookup)

In [5]:
def label_ha_na_co_clades(clade, ha_label, A_mutation,
                          na_strains,
                          counts=None):
    '''Assign ab, aB, Ab, and AB labels recursively to HA tree.

    Relies on lookup dictionary to b/B status from NA tree and keeps track of
    counts of each state. Returns those counts.

    Count return is mostly for debugging, can be removed later.

    Parameters:
        clade: HA tree node exposing .clades (list of children), optionally
            .aa_muts (dict of gene name -> mutations) and, on leaves, .strain.
        ha_label: label inherited from the parent node ('a' or 'A').
        A_mutation: mutation string that switches the label from 'a' to 'A'
            when found in aa_muts[u'HA1'].
        na_strains: dict mapping strain name -> NA label ('b' or 'B').
        counts: optional dict of tallies to update in place. When omitted, a
            fresh zeroed dict is created for this call.

    Returns:
        The counts dict, keyed by co-clade label plus 'unmatched'.

    Side effects:
        Sets co_clade_label on every leaf: the HA label followed by the NA
        label, or 'unmatched' when the leaf's strain is not in na_strains.
    '''
    if counts is None:
        # Build the tally per call; a dict literal as the default argument
        # would be shared between calls and keep accumulating on re-runs.
        counts = {'ab': 0, 'Ab': 0, 'aB': 0, 'AB': 0, 'unmatched': 0}

    label = ha_label
    if ha_label == 'a' and hasattr(clade, 'aa_muts'):
        try:
            if A_mutation in clade.aa_muts[u'HA1']:
                label = 'A'
        except (KeyError, TypeError):
            # No usable HA1 entry on this node: keep the inherited label and
            # show what the node carries, as a debugging aid.
            print(clade.aa_muts)

    if not clade.clades:
        # Leaves without a strain attribute or without an NA match are
        # counted as 'unmatched' rather than raising.
        na_state = na_strains.get(getattr(clade, 'strain', None))
        if na_state is None:
            clade.co_clade_label = 'unmatched'
        else:
            clade.co_clade_label = label + na_state
        counts[clade.co_clade_label] = counts.get(clade.co_clade_label, 0) + 1

    for child in clade.clades:
        label_ha_na_co_clades(child, label, A_mutation, na_strains, counts)

    return counts

In [6]:
# Walk the NA tree from its root, starting in state 'b'. label_na_clade fills
# NA_strains in place with a strain-name -> NA label entry for every leaf.
NA_strains = {}
label_na_clade(na_tree, 'b', B_mut, NA_strains)

In [7]:
# Walk the HA tree from its root, starting in state 'a', tagging each leaf with
# its combined HA/NA label; C holds the returned per-label tallies.
C = label_ha_na_co_clades(ha_tree, 'a', A_mut, NA_strains)

In [8]:
# Display the number of HA leaves per co-clade label (plus 'unmatched').
C

{'AB': 16, 'Ab': 1, 'aB': 286, 'ab': 1514, 'unmatched': 134}