In [1]:
import argparse
import json
import re
import ast
import collections
import pandas as pd
import numpy as np
from Bio import SeqIO

In [2]:
# Tree exported for auspice: a nested JSON object whose nodes are walked by traverse()
tree_path = '../auspice/flu_seasonal_h3n2_ha_12y_hi_tree.json'
with open(tree_path) as tree_json:
    tree = json.load(tree_json)

# Amino-acid sequences (FASTA), indexed by record id so tips can be looked up by strain name
seq_path = '../results/aa-seq_who_h3n2_ha_12y_concat_hi_HA1.fasta'
seq_records = SeqIO.parse(seq_path, "fasta")
seqs = SeqIO.to_dict(seq_records)

# Root sequence JSON; its 'HA1' entry is the starting point for reconstructing ancestors
root_path = '../auspice/flu_seasonal_h3n2_ha_12y_hi_root-sequence.json'
with open(root_path) as root_json:
    root_seq = json.load(root_json)

# 1-based HA1 positions tracked throughout the notebook (order fixes the column order of df)
positions = [
    160, 194, 186,
    225, 219, 203,
    156, 138, 246,
]

# Module-level accumulators used by traverse():
#   tip_muts    -- strain name -> list of per-tip values (one future DataFrame row)
#   traverse_aa -- stack of {clade: HA1 mutations} dicts, one per branch on the
#                  path from the root to the node currently being visited
#                  (only branches whose 'aa_muts' has an 'HA1' key are pushed)
tip_muts = {}
traverse_aa = []


def traverse(branch, seq, root, pos_list):
    """Walk the tree depth-first and record one entry in ``tip_muts`` per tip.

    Parameters
    ----------
    branch : dict
        Current tree node. Internal nodes have a 'children' list; every child
        must have 'aa_muts' and 'clade'. Tips must have 'strain', 'aa_muts'
        and an 'attr' dict with 'num_date', 'clade_membership' and 'kk_clade'.
    seq : mapping
        Strain name -> record exposing a ``.seq`` attribute that can be
        indexed by position.
    root : mapping
        Root sequences; ``root['HA1']`` is the string that the mutations on
        the path are applied to.
    pos_list : list of int
        1-based HA1 positions to report for each tip.

    Side effects
    ------------
    Pushes/pops entries on the module-level ``traverse_aa`` stack and writes
    ``tip_muts[strain]`` for every tip. The stored list holds, in order:
    HA1, HA2 and SigPep mutations on the tip branch, num_date, nucleotide
    mutations, dTiterSub, cTiterSub, clade_membership, kk_clade, the
    stringified root-to-tip path, the tip residue at each position and the
    residue of the node one branch in at each position.
    """
    if 'children' in branch:
        for child in branch['children']:
            if 'HA1' in child['aa_muts']:
                # Push this branch while its subtree is visited, then pop it.
                # pop() is used instead of remove(): the entry is always the
                # last one, and remove() would delete the first *equal* dict.
                traverse_aa.append({str(child['clade']): child['aa_muts']['HA1']})
                traverse(child, seq, root, pos_list)
                traverse_aa.pop()
            else:
                traverse(child, seq, root, pos_list)
        return

    # ---- tip ----
    # Stringified path entries are kept as-is: later cells parse them back
    # with ast.literal_eval.
    aa_mut_clade_list = [str(entry) for entry in traverse_aa]
    # Flatten the HA1 mutations along the path directly from the stack
    # (no need to round-trip through str() and literal_eval()).
    path_muts = [str(mut)
                 for entry in traverse_aa
                 for mut in list(entry.values())[0]]

    # The tip's own branch is the last stack entry when it has an 'HA1' key,
    # so dropping its mutations leaves those leading to the node one branch in.
    # Slicing by explicit length avoids the [:-0] == [:0] trap when the tip's
    # HA1 list is present but empty.
    own_ha1 = branch['aa_muts'].get('HA1', [])
    last_node = path_muts[:max(len(path_muts) - len(own_ha1), 0)]

    # Find sequence of tip and sequence one branch in
    tip_sequence = seq[branch['strain']].seq
    last_node_sequence = root['HA1']
    for mut in last_node:
        internal_mut_pos = int(re.findall(r'\d+', mut)[0])
        internal_mut_aa = mut[-1:]
        last_node_sequence = (last_node_sequence[:internal_mut_pos - 1]
                              + internal_mut_aa
                              + last_node_sequence[internal_mut_pos:])

    aa_muts = branch['aa_muts']
    attr = branch['attr']
    tip_muts[branch['strain']] = (
        [aa_muts.get('HA1', []),
         aa_muts.get('HA2', []),
         aa_muts.get('SigPep', []),
         attr['num_date'],
         branch.get('muts', []),
         attr.get('dTiterSub'),
         attr.get('cTiterSub'),
         attr['clade_membership'],
         attr['kk_clade'],
         aa_mut_clade_list]
        + [tip_sequence[pos - 1] for pos in pos_list]
        + [last_node_sequence[pos - 1] for pos in pos_list])

# Reset the root-to-node path stack, then walk the whole tree to fill tip_muts
traverse_aa = []
traverse(tree, seqs, root_seq, positions)

# One row per tip; the former index (strain name) becomes the first column
df = pd.DataFrame(tip_muts).T.reset_index()
df.columns = ['strain', 'tip_HA1_muts', 'tip_HA2_muts', 'tip_SigPep_muts', 'date', 'tip_nt_muts', 'dTiterSub','cTiterSub', 'clade', 'kk_clade', 'aa_mut_list'] + positions + [str(x)+'_lastnode' for x in positions]

# astype() returns a new Series and has no `inplace` argument
# (passing one raises TypeError on current pandas); missing values become NaN
df['dTiterSub'] = df['dTiterSub'].astype(float)
df['cTiterSub'] = df['cTiterSub'].astype(float)

# Passage type is read off the strain name; anything that is neither egg nor
# cell is labelled 'unpassaged' via an explicit default rather than relying
# on np.select's implicit 0 being coerced to the string '0'
df['passage'] = np.select(
    [df.strain.str.contains('egg'), df.strain.str.contains('cell')],
    ['egg', 'cell'],
    default='unpassaged')
#Identify pairs where strain sequence exists for multiple passage types
# Strip the passage suffix so every passage of one strain shares a `source` name
df['source'] = np.select(
    [df.passage=='egg', df.passage=='cell', df.passage=='unpassaged'],
    [df.strain.str.replace('-egg',''), df.strain.str.replace('-cell',''), df.strain])

# Sources that have an egg sequence plus an unpassaged (pairs_u) or a
# cell-passaged (pairs_c) sequence; duplicated(keep=False) keeps every member
e_u_df = df[(df['passage']=='egg') | (df['passage']=='unpassaged')]
pairs_u = e_u_df[e_u_df.duplicated(subset='source', keep=False)]
e_c_df = df[(df['passage']=='egg') | (df['passage']=='cell')]
pairs_c = e_c_df[e_c_df.duplicated(subset='source', keep=False)]
pairs = pd.concat([pairs_u, pairs_c])

# Number the paired sources 1..n. Building the two columns directly keeps the
# 'source' column present even when there are no pairs, so the merge below
# cannot fail on a missing key.
pair_sources = pairs['source'].unique()
pair_ids = pd.DataFrame({'source': pair_sources,
                         'pair_id': np.arange(1, len(pair_sources) + 1)})

# Left merge keeps every tip; tips without a partner get pair_id 0
df = df.merge(pair_ids, on='source', how='left')
df['pair_id'] = df['pair_id'].fillna(0).astype(int)

# For each tracked position, flag whether the tip residue differs from the
# residue of its ancestor one branch in (1 = differs, 0 = identical)
for p in positions:
    pos_label = str(p)
    ancestral_aa = df[pos_label + '_lastnode']
    df['mut' + pos_label] = (df[p] != ancestral_aa).astype(int)

# Spell out each flagged change as <ancestral aa><position><tip aa>; None where unchanged.
# Kept as a second loop so all 'mut*' columns still precede the 'aa_mut*' columns.
for p in positions:
    pos_label = str(p)
    is_mutated = df['mut' + pos_label] == 1
    df['aa_mut' + pos_label] = np.where(is_mutated, df[pos_label + '_lastnode'] + pos_label + df[p], None)


In [21]:
# Every strain starts with its own (unshared) empty list of egg mutations
df['egg_muts'] = [[] for _ in range(len(df))]

# Number of entries on each tip's root-to-tip path (see aa_mut_list)
path_lengths = df['aa_mut_list'].map(len)
# Guard the empty frame: max() of an empty Series is NaN, which range() rejects
max_internal_length = int(path_lengths.max()) if len(df) else 0

for internal_branch in range(max_internal_length):
    # Only tips whose path is deep enough to have a branch at this depth
    sub_df = df[path_lengths > internal_branch]

    # Cluster tips that share the same first (internal_branch + 1) path entries
    cluster_keys = sub_df['aa_mut_list'].map(
        lambda path: tuple(path[:internal_branch + 1]))

    for k, v in sub_df.groupby(cluster_keys):
        is_egg = v['passage'] == 'egg'
        if not is_egg.any():
            continue

        # Mutations on the deepest branch defining this cluster; the path
        # entries are stringified {clade: [mutations]} dicts
        recent_muts = list(ast.literal_eval(k[-1]).values())[0]

        # For egg-only clusters: every mutation on the branch is attributed
        # to all strains in the cluster
        if v['passage'].nunique() == 1:
            for k_strain in v.index:
                # The cell holds the list object itself, so extending it in
                # place is enough; no write-back through df.at is needed
                df.at[k_strain, 'egg_muts'].extend(recent_muts)
            continue

        # Mixed clusters: keep a mutation for the egg strains only if all
        # non-egg seqs have a phylogenetically inferred 'reversion' at it,
        # i.e. none of them actually carries the mutated residue
        non_egg_strains = v.loc[~is_egg, 'strain']
        egg_labels = v[is_egg].index

        for recent_mut in recent_muts:
            site = int(re.findall(r'\d+', recent_mut)[0])
            end_aa = recent_mut[-1]

            if all(seqs[strain][site - 1] != end_aa for strain in non_egg_strains):
                for k_strain in egg_labels:
                    df.at[k_strain, 'egg_muts'].append(recent_mut)
                        

# for k,v in df.iterrows():
#     #Find mutations for all egg seqs
#     if len(v['egg_muts'])>=1:
#         revised_egg_muts = list(v['egg_muts'])
#         mutated_sites = [int(re.findall('\d+', egg_mut)[0]) for egg_mut in v['egg_muts']]
#         multiple_mut_sites = [item for item, count in collections.Counter(mutated_sites).items() if count > 1]

#         for dup_site in multiple_mut_sites:
#             mult_mutations = []
#             for egg_mut in v['egg_muts']:
#                 if str(dup_site) in egg_mut:
#                     mult_mutations.append(egg_mut)
#                     revised_egg_muts.remove(egg_mut)
#             start_aa = mult_mutations[0][0]
#             end_aa = mult_mutations[len(mult_mutations)-1][-1]
#             if start_aa != end_aa:
#                 revised_egg_muts.append(start_aa+str(dup_site)+end_aa)
#         df.at[k, 'egg_muts'] = revised_egg_muts
#         #Find mutation(s) at specified positions
#         for recent_mut in revised_egg_muts:
#             site = int(re.findall('\d+', recent_mut)[0])
#             if site in positions:
#                 df.at[k, 'mut' + str(site)] = 1
#                 df.at[k, 'aa_mut' + str(site)] = recent_mut
#                 df.at[k, str(site) + '_lastnode'] = recent_mut[0]