In [25]:
import json
import math
import os
from collections import Counter

import numpy as np
import pandas as pd
from augur.utils import json_to_tree

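## Find adaptive sites
Adaptive sites are those mutated in more than 2.5% of egg-passaged strains overall,
or carrying a mutation that reaches at least 7.5% of egg-passaged strains within a time window, in at least 3 one-year windows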

In [16]:
def get_muts_over_threshold(virus, segment, gene, threshold):
    """
    List the sites recorded as mutated above a prevalence threshold.

    Reads the egg-mut-counts JSON for this virus/segment/gene and returns the
    keys stored under 'sites_mutated_above_{threshold}percent' as a list.
    `threshold` is the string used inside that key (e.g. '2_5').
    """
    counts_path = (
        "../egg-mutation-analysis/egg-mut-counts/"
        f"{virus}_{segment.upper()}_{gene.upper()}_egg-mut-counts.json"
    )

    with open(counts_path) as counts_handle:
        mut_counts = json.load(counts_handle)

    threshold_key = f'sites_mutated_above_{threshold}percent'
    return [*mut_counts[threshold_key].keys()]
        
        

In [17]:
def get_muts(virus, segment, gene):
    """
    Map every strain in the curated egg-mutation file to its mutations in `gene`.

    Strains that have no entry for `gene` map to an empty list.
    """
    curated_path = f"../egg-mutation-analysis/egg-muts-by-strain/{virus}_{segment}_curated-egg-muts.json"

    with open(curated_path) as curated_handle:
        muts_all_genes = json.load(curated_handle)

    # keep only the requested gene for each strain
    return {strain: per_gene.get(gene, []) for strain, per_gene in muts_all_genes.items()}

In [18]:
def get_strain_year(virus, segment):
    """
    Map each egg-passaged tip of the egg-enriched tree to its year.

    The year is the integer part of the tip's 'num_date' value. Only tips whose
    'passage_category' value is 'egg' are included in the returned dict.
    """
    auspice_json = f'../nextstrain_builds/egg-enriched/auspice/{virus}_{segment}_egg.json'

    # load the auspice JSON
    with open(auspice_json, 'r') as tree_handle:
        auspice_tree = json.load(tree_handle)

    # convert to a Bio.Phylo tree
    phylo_tree = json_to_tree(auspice_tree)

    year_by_egg_strain = {}
    for tip in phylo_tree.find_clades(terminal=True):
        # both attributes are read for every tip, egg-passaged or not
        passage_category = tip.node_attrs['passage_category']['value']
        tip_year = math.trunc(tip.node_attrs['num_date']['value'])
        if passage_category != 'egg':
            continue
        year_by_egg_strain[tip.name] = tip_year

    return year_by_egg_strain

In [19]:
def get_muts_in_windows(virus, segment, gene, window_len, sites, min_strains_per_window=5):
    """
    Tally egg-strain mutations within sliding windows of consecutive years.

    Every run of `window_len` consecutive years between the earliest and the
    latest egg-strain year forms one window. Windows holding fewer than
    `min_strains_per_window` egg strains are skipped and do not appear in any
    of the returned dicts.

    Parameters
    ----------
    virus, segment, gene : str
        Passed to get_strain_year and get_muts to locate the input files.
    window_len : int
        Number of consecutive years in each window.
    sites : 'gene' or collection of int
        'gene' tallies every mutation in the gene; otherwise only mutations
        whose residue number is in `sites` are tallied.
    min_strains_per_window : int, default 5
        Minimum number of egg strains a window needs to be reported.

    Returns
    -------
    tuple of five dicts, each keyed by window midpoint (mean of the window's years):
        res_muts_all_windows_pct  : {residue: count / number of egg strains in window}
        aa_muts_all_windows_pct   : {mutation: count / number of egg strains in window}
        res_muts_all_window_count : Counter of residues
        aa_muts_all_window_count  : Counter of mutations
        eggstrains_per_window     : number of egg strains in the window
    The "pct" values are fractions (count divided by strain count), not 0-100.
    All five dicts are empty when there are no egg strains at all.

    Raises
    ------
    KeyError
        If an egg strain from the tree has no entry in the curated mutation file.
    """

    # get year for each egg strain
    eggstrains_to_year = get_strain_year(virus, segment)

    # get aa muts in this gene for each strain
    muts_by_strain = get_muts(virus, segment, gene)

    # muts per window as a fraction of all egg strains in that window
    res_muts_all_windows_pct = {}
    aa_muts_all_windows_pct = {}

    # and also as a count
    res_muts_all_window_count = {}
    aa_muts_all_window_count = {}

    # total number of egg strains in window
    eggstrains_per_window = {}

    results = (res_muts_all_windows_pct, aa_muts_all_windows_pct,
               res_muts_all_window_count, aa_muts_all_window_count, eggstrains_per_window)

    # without egg strains there are no windows; max()/min() would raise on empty input
    if not eggstrains_to_year:
        return results

    # make sliding windows
    min_year = min(eggstrains_to_year.values())
    max_year = max(eggstrains_to_year.values())
    all_years = list(range(min_year, max_year + 1))
    windows = [all_years[i:i + window_len] for i in range(len(all_years) - window_len + 1)]

    # decide once whether to filter by site; a set makes the membership test O(1)
    whole_gene = sites == 'gene'
    site_set = None if whole_gene else set(sites)

    for window in windows:
        window_years = set(window)

        # get all egg strains in window
        eggstrains = [k for k, v in eggstrains_to_year.items() if v in window_years]

        if len(eggstrains) < min_strains_per_window:
            continue

        # every mutation (restricted to `sites` if requested) seen in the window's strains
        aa_muts_in_window = []
        for e in eggstrains:
            strain_muts = muts_by_strain[e]
            if not whole_gene:
                # a mutation is the residue number followed by one amino-acid character
                strain_muts = [x for x in strain_muts if int(x[:-1]) in site_set]
            aa_muts_in_window.extend(strain_muts)

        # number of times each specific mutation occurs in the window
        aa_muts_in_window_count = Counter(aa_muts_in_window)
        # number of times each residue is mutated (any amino acid) in the window
        res_muts_in_window_count = Counter(x[:-1] for x in aa_muts_in_window)

        # convert to fraction of all egg strains in window
        n_strains = len(eggstrains)
        res_muts_in_window_pct = {m: c / n_strains for m, c in res_muts_in_window_count.items()}
        aa_muts_in_window_pct = {m: c / n_strains for m, c in aa_muts_in_window_count.items()}

        # store results for this window by its midpoint
        window_midpoint = sum(window) / len(window)

        eggstrains_per_window[window_midpoint] = n_strains
        res_muts_all_windows_pct[window_midpoint] = res_muts_in_window_pct
        aa_muts_all_windows_pct[window_midpoint] = aa_muts_in_window_pct
        res_muts_all_window_count[window_midpoint] = res_muts_in_window_count
        aa_muts_all_window_count[window_midpoint] = aa_muts_in_window_count

    return results
    
    

In [91]:
def get_overall_mut_count(virus, segment, gene, min_occurrences = 5):
    """
    Get the overall number of occurrences of each egg mutation, keeping only
    those seen at least `min_occurrences` times.

    Counts are reconstructed as 'egg_mut_freqs' * 'total_num_egg_strains' from
    the egg-mut-counts JSON and rounded to the nearest integer.
    Assumes each stored frequency is count / total_num_egg_strains - TODO confirm.

    Returns {mutation: count (int)} for mutations meeting the minimum.
    """
    muts_file = f"../egg-mutation-analysis/egg-mut-counts/{virus}_{segment.upper()}_{gene.upper()}_egg-mut-counts.json"

    with open(muts_file) as json_handle:
        egg_mut_info = json.load(json_handle)

    total_egg_strains = egg_mut_info['total_num_egg_strains']

    # freq * total is a float product and can land just below the true whole
    # number (e.g. (1/49) * 49 == 0.9999999999999999), which would wrongly fail
    # the >= comparison below; rounding recovers the integer count
    egg_mut_counts = {m: round(f * total_egg_strains) for m, f in egg_mut_info['egg_mut_freqs'].items()}

    egg_muts_over_min = {k: v for k, v in egg_mut_counts.items() if v >= min_occurrences}

    return egg_muts_over_min

In [153]:
def get_transient_adaptive_sites(virus, segment, gene, window_len, min_window_freq=0.075, min_windows=3):
    """
    Find mutations that are common in egg strains only during some time
    windows and whose site is not already in the overall adaptive-site list.

    A mutation is returned when all of the following hold:
      * it is in the result of get_overall_mut_count (called with its default minimum),
      * its site is not among the sites from get_muts_over_threshold(..., '2_5'),
      * its fraction of egg strains is >= `min_window_freq` in at least
        `min_windows` of the sliding windows of length `window_len` years.

    Parameters
    ----------
    virus, segment, gene : str
        Passed through to the helper functions to locate input files.
    window_len : int
        Length, in years, of each sliding window.
    min_window_freq : float, default 0.075
        Minimum fraction of a window's egg strains carrying the mutation.
    min_windows : int, default 3
        Minimum number of windows in which the mutation must reach that fraction.

    Returns
    -------
    list of str
        Mutations (residue number followed by amino acid), not bare site
        numbers, in order of first qualifying appearance.
    """

    # only the per-mutation window fractions are needed here
    (_, aa_muts_all_windows_pct, _, _, _) = get_muts_in_windows(virus, segment, gene, window_len, 'gene')

    # muts that occur often enough overall to be considered at all
    egg_muts_over_min = get_overall_mut_count(virus, segment, gene)

    # sites already called adaptive from overall prevalence (set for O(1) lookups)
    egg_adaptive_sites_overall = {int(x) for x in get_muts_over_threshold(virus, segment, gene, '2_5')}

    # for each qualifying mutation, count the windows in which it reaches the threshold
    window_counts = Counter(
        m
        for window_freqs in aa_muts_all_windows_pct.values()
        for m, p in window_freqs.items()
        if m in egg_muts_over_min
        and p >= min_window_freq
        and int(m[:-1]) not in egg_adaptive_sites_overall
    )

    transient_muts = [m for m, c in window_counts.items() if c >= min_windows]

    return transient_muts
    

In [184]:
# set aas to manually exclude (not convinced they are adaptive based on tree)
# layout: {virus: {site number: [amino acids to drop]}}; read by get_predominant_aas
manual_exclude = {'h3n2': {186:['D', 'R']}}

In [191]:
def get_predominant_aas(virus, segment, gene, sites, cutoff=8):
    """
    For each site in `sites`, list the amino acids seen in mutations at that
    site at least `cutoff` times across strains, most frequent first.

    Amino acids listed for this virus and site in `manual_exclude` are dropped.
    Returns {site (int): [amino acids]}.
    """

    # muts in each strain
    muts_by_strain = get_muts(virus, segment, gene)

    # one tally per requested site, keyed by the residue number as a string
    aa_tally_by_site = {str(site): Counter() for site in sites}

    for strain_muts in muts_by_strain.values():
        for mut in strain_muts:
            # a mutation is the residue number followed by one amino-acid character
            residue, amino_acid = mut[:-1], mut[-1]
            if int(residue) in sites:
                aa_tally_by_site[residue][amino_acid] += 1

    # manual exclusions for this virus, if any
    excluded_by_site = manual_exclude[virus] if virus in manual_exclude.keys() else {}

    predominant_aas = {}
    for residue, tally in aa_tally_by_site.items():
        site = int(residue)
        # most_common() orders by count, highest first
        kept = [aa for aa, n_seen in tally.most_common() if n_seen >= cutoff]
        if site in excluded_by_site.keys():
            kept = [aa for aa in kept if aa not in excluded_by_site[site]]
        predominant_aas[site] = kept

    return predominant_aas

In [193]:
def save_adaptive_muts(segment, gene, viruses=('h3n2', 'h1n1pdm', 'vic', 'yam'), window_len=1):
    """
    Store adaptive mutations, from overall and time-window analyses.

    For each virus this collects:
      * sites from get_muts_over_threshold(..., '2_5') (overall analysis),
      * sites of the mutations from get_transient_adaptive_sites (time-window analysis),
      * their union, sorted, with each site listed once,
      * the predominant amino acids at each site in the union (get_predominant_aas).

    If at least one virus has an adaptive site, the result is written to
    'egg-adaptive-muts/{segment}_{gene}_adaptive-muts.json'; the directory is
    created if missing. Nothing is written otherwise.

    Parameters
    ----------
    segment, gene : str
        Segment and gene to analyse; also used in the output filename.
    viruses : iterable of str, default ('h3n2', 'h1n1pdm', 'vic', 'yam')
        Viruses to include, in output order.
    window_len : int, default 1
        Window length in years passed to get_transient_adaptive_sites.
    """

    egg_adaptive_sites_overall = {}
    transient_adaptive_sites = {}
    adaptive_sites = {}
    aas_at_adaptive_sites = {}

    for virus in viruses:
        overall_sites = [int(x) for x in get_muts_over_threshold(virus, segment, gene, '2_5')]

        # several transient mutations can share one site (same number, different aa);
        # dict.fromkeys keeps first-seen order and drops the repeats
        transient_muts = get_transient_adaptive_sites(virus, segment, gene, window_len)
        transient_sites = list(dict.fromkeys(int(m[:-1]) for m in transient_muts))

        egg_adaptive_sites_overall[virus] = overall_sites
        transient_adaptive_sites[virus] = transient_sites
        adaptive_sites[virus] = sorted(set(overall_sites + transient_sites))

        # get the predominant aas at each adaptive site
        aas_at_adaptive_sites[virus] = get_predominant_aas(virus, segment, gene, adaptive_sites[virus])

    # only save file if there is at least one adaptive mut
    if not any(adaptive_sites.values()):
        return

    filename = f'egg-adaptive-muts/{segment}_{gene}_adaptive-muts.json'

    json_to_save = {'all_adaptive': adaptive_sites,
                    'adaptive_sites_overall': egg_adaptive_sites_overall,
                    'adaptive_muts_transient': transient_adaptive_sites,
                    'aas_at_adaptive_sites': aas_at_adaptive_sites
                   }

    # serialize before opening the file so a serialization error cannot leave a partial file
    json_object_to_save = json.dumps(json_to_save, indent=2)

    # make sure the output directory exists, then write the adaptive-mutation info
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    with open(filename, "w") as outfile:
        outfile.write(json_object_to_save)
    
    

In [194]:
# every (segment, gene) pair to analyse; one output file is attempted per pair
segment_gene_combos = [
    ('pb1', 'PB1'), ('pb2', 'PB2'), ('pa', 'PA'),
    ('ha', 'HA1'), ('ha', 'HA2'), ('np', 'NP'),
    ('na', 'NA'), ('mp', 'M1'), ('ns', 'NS1'),
]
for x in segment_gene_combos:
    # x is a (segment, gene) pair
    save_adaptive_muts(*x)

In [None]:
# printed out

In [192]:
# NOTE(review): this cell repeats the segment/gene loop from the earlier cell and carries
# a lower execution count (192 vs 194); the dict output shown below it presumably comes
# from an earlier version of the code that printed its results - confirm, then remove
# the duplicate so the notebook runs top to bottom.
segment_gene_combos = [('pb1', 'PB1'), ('pb2', 'PB2'), ('pa', 'PA'), 
                       ('ha', 'HA1'), ('ha', 'HA2'), ('np', 'NP'), 
                       ('na', 'NA'), ('mp', 'M1'), ('ns', 'NS1')]
for x in segment_gene_combos:
    save_adaptive_muts(x[0], x[1])

{'h3n2': {138: ['S'], 156: ['Q', 'R'], 160: ['K', 'I'], 183: ['L'], 186: ['V', 'N', 'S'], 190: ['N', 'G', 'V'], 193: ['R'], 194: ['P', 'I'], 195: ['Y'], 196: ['T'], 203: ['I'], 219: ['Y', 'F'], 225: ['G', 'N'], 226: ['I'], 246: ['K', 'H', 'S', 'T']}, 'h1n1pdm': {127: ['E'], 187: ['V', 'N', 'T'], 191: ['I'], 222: ['G', 'N'], 223: ['R']}, 'vic': {141: ['R'], 196: ['S', 'D', 'K', 'T'], 198: ['I', 'A', 'N']}, 'yam': {141: ['R'], 195: ['D', 'S', 'K'], 197: ['I', 'A', 'N', 'P']}}
{'h3n2': {290: ['N'], 384: ['R']}, 'h1n1pdm': {101: ['N'], 102: ['R'], 375: ['N']}, 'vic': {}, 'yam': {}}
