In [200]:
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

In [14]:
with open ('nextstrain/auspice/data/flu_seasonal_h3n2_ha_6y_tree.json', 'r') as jsonfile:
    json_data = json.load(jsonfile)


In [243]:
#Pull out mutations at tree tips
tip_strain_muts = {}

def traverse(branch):
    if 'children' not in branch.keys():
        tip_strain_muts[branch['strain']]=[branch['aa_muts']['HA1'], branch['aa_muts']['HA2'], 
                                           branch['aa_muts']['SigPep'],branch['attr']['num_date'], 
                                           branch['attr']['clade_membership']]
    else:
        for child in branch['children']:
            traverse(child)


traverse(json_data)

In [274]:
#Make dictionary into a pandas df
df = pd.DataFrame(tip_strain_muts).T
df.reset_index(inplace=True)
df.rename(columns={'index': 'strain',0: "HA1_muts", 1: "HA2_muts", 2: "SigPep_muts", 3: "date", 4: "clade"}, inplace=True)
df['passage'] = np.select((df.strain.str.contains('egg'), df.strain.str.contains('cell')), ('egg', 'cell'))

df.head(20)

Unnamed: 0,strain,HA1_muts,HA2_muts,SigPep_muts,date,clade,passage
0,A/AbuDhabi/221/2017-cell,[],[A201T],[],2017.97,A1b/135K,cell
1,A/AbuDhabi/258/2018-cell,[],[],[],2018.0,A1b/135K,cell
2,A/Acores/SU43/2012-egg,[],[],[A16T],2012.85,unassigned,egg
3,A/Adana/A15/2017-cell,[],[],[],2017.09,A3,cell
4,A/Afghanistan/243/2016-cell,[],[],[],2016.44,3c3.A,cell
5,A/Afghanistan/437/2017,[],[],[],2017.85,A1b/135K,0
6,A/Afghanistan/624/2017,[],[],[],2017.9,A1b/135K,0
7,A/Aichi/116/2013-cell,[],[],[],2013.43,3c3,cell
8,A/Aichi/118/2013-cell,[],[],[],2013.51,3c3,cell
9,A/Aichi/119/2013-cell,[],[],[],2013.83,3c3,cell


In [245]:
#make tidy version of df where each mutation gets a row
mut_df = pd.DataFrame(columns=['mutation']+ list(df.columns))

count=0
for i, r in df.iterrows():

    for ha1_mut in r['HA1_muts']:
        mut_df.loc[count]= ['HA1'+str(ha1_mut)] + list(df.loc[i])
        count+=1
        
    for ha2_mut in r['HA2_muts']:
        mut_df.loc[count]= ['HA2'+str(ha2_mut)] + list(df.loc[i])
        count+=1
        
    for sp_mut in r['SigPep_muts']:
        mut_df.loc[count]= ['SP'+str(sp_mut)] + list(df.loc[i])
        count+=1


In [310]:
top_muts = {}
for pas_type in mut_df['passage'].unique():
    top = mut_df[mut_df.passage==pas_type].groupby('mutation')['mutation'].count().sort_values(ascending=False)[:10]
    top_muts[pas_type] = list((g_name, g) for g_name, g in top.iteritems())
pd.DataFrame(top_muts)

Unnamed: 0,0,cell,egg
0,"(HA1V323I, 5)","(HA1T160K, 13)","(HA1G186V, 52)"
1,"(HA1D53N, 5)","(HA2A201V, 11)","(HA1L194P, 51)"
2,"(HA2N116S, 5)","(HA1R261Q, 11)","(HA1T160K, 23)"
3,"(HA1I140M, 4)","(HA2V18M, 10)","(HA1D225G, 18)"
4,"(HA2A43S, 4)","(HA1D53N, 9)","(HA1S219F, 17)"
5,"(HA1S219F, 3)","(HA1N158K, 9)","(HA1T203I, 14)"
6,"(HA2I186V, 3)","(SPA16T, 9)","(HA1S219Y, 11)"
7,"(HA1I25V, 3)","(HA1N225D, 8)","(HA1H156R, 9)"
8,"(HA1R33Q, 3)","(HA1P221L, 8)","(HA1A138S, 9)"
9,"(HA1F193S, 3)","(HA1V112I, 8)","(HA1H156Q, 8)"
