In [1]:
from augur.utils import json_to_tree
import pandas as pd
from Bio import SeqIO
from Bio.Seq import MutableSeq
import json

### Count mutations in each gene at each node in the tree

Strategy:
1. Convert the auspice JSON tree to Bio.Phylo format so it can be traversed, and count the mutations at each node (relative to the root)
2. Save the mutation tallies for each node in a dictionary keyed by node name
3. Traverse the tree in JSON format and add the mutation information to each JSON node by looking that node up in the dictionary above
4. Replace `json['tree']` with this edited tree and save the whole JSON again
5. Optionally add a `color_by` option so auspice can color the tree by the mutation tally

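First, get the location of every annotated gene (CDS) in the reference genome and extract each gene's Wuhan-1 nucleotide sequence; these are used later to find which codons each mutation changes.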

In [47]:
# gene name -> CDS location within the reference genome
location_by_gene = {}
# gene name -> wuhan-1 nucleotide sequence extracted for that gene
wuhan_seq_by_gene = {}

with open("reference_all_genes.gb") as gb_handle:
    for gb_record in SeqIO.parse(gb_handle, "genbank"):
        # whole reference genome; the nucleotide->gene mapping cell below reads its length
        wuhan_seq = gb_record.seq
        cds_features = (feat for feat in gb_record.features if feat.type == 'CDS')
        for cds in cds_features:
            gene_name = cds.qualifiers['gene'][0]
            location_by_gene[gene_name] = cds.location
            wuhan_seq_by_gene[gene_name] = cds.location.extract(wuhan_seq)

In [3]:
# every annotated gene name, in location_by_gene's key order
all_genes = [gene_name for gene_name in location_by_gene]

In [4]:
# map each 1-based nucleotide position of the reference genome to the gene(s) it lies in
# a position may belong to several genes (ex: this will happen for S1 and S, orf9b and N)
# positions that are not in any of the annotated genes are labelled 'noncoding'
nt_to_gene_mapper = dict((genome_pos, []) for genome_pos in range(1, len(wuhan_seq) + 1))

for gene_name, gene_loc in location_by_gene.items():
    covered_positions = range(gene_loc.start + 1, gene_loc.end + 1)
    for genome_pos in covered_positions:
        nt_to_gene_mapper[genome_pos].append(gene_name)

uncovered_positions = [pos for pos, genes_here in nt_to_gene_mapper.items() if not genes_here]
for genome_pos in uncovered_positions:
    nt_to_gene_mapper[genome_pos] = ['noncoding']

In [52]:
def readin_tree(date, tree_type):
    """
    Read in the auspice tree json for the specified build and return it as a Bio.Phylo tree

    date: date string used in the tree file name (e.g. '2023-01'); not used for 'alltime'
    tree_type: one of 'sars2' (2m build), '21L' (2m 21L build) or 'alltime'
        (all-time global build)
    Raises ValueError for any other tree_type
    """
    # path to tree json for each supported tree type
    tree_files = {
        'sars2': f'../ncov_builds/auspice_sars2_2m/sars2_{date}_2m.json',
        '21L': f'../ncov_builds/auspice_21L_2m/sars2_21L_{date}_2m.json',
        'alltime': '../ncov_builds/auspice_sars2_alltime/sars2_global_all-time.json',
    }
    # fail with a clear message instead of an UnboundLocalError on an unknown type
    if tree_type not in tree_files:
        raise ValueError(f"tree_type must be one of {sorted(tree_files)}, got {tree_type!r}")

    with open(tree_files[tree_type], 'r') as f:
        tree_json = json.load(f)

    # put tree in Bio.phylo format
    return json_to_tree(tree_json)

In [41]:
def _gene_indices(location, genome_pos):
    """Return every index into location.extract(...) whose nt is the 1-based genome position genome_pos."""
    # NOTE(review): assumes all parts are on the + strand, as the original offset arithmetic did
    zero_based_pos = genome_pos - 1
    indices = []
    offset = 0
    for part in location.parts:
        if part.start <= zero_based_pos < part.end:
            indices.append(offset + zero_based_pos - int(part.start))
        offset += len(part)
    return indices


def count_mutations_by_type(gene, muts):
    """
    Given a list of mutations in this gene, apply them to the wuhan-1 reference sequence
    for this gene and count the nonsynonymous, synonymous and stop-gained mutations

    gene: gene name, a key of wuhan_seq_by_gene and location_by_gene
    muts: nucleotide mutation strings of the form <ref><1-based genome position><alt>,
        listed oldest to newest
    Returns (nonsyn_in_gene, syn_in_gene, stops_in_gene)
    """
    # get the gene sequence of wuhan-1 and its translation
    wuhan_gene_seq = wuhan_seq_by_gene[gene]
    translated_wuhan_seq = wuhan_gene_seq.translate()

    gene_location = location_by_gene[gene]

    # apply each mutation in this gene to the wuhan-1 reference sequence
    mutated_gene_seq = MutableSeq(wuhan_gene_seq)

    # muts are listed oldest to newest, so if multiple mutations occur at same nt,
    # the newer one should correctly overwrite the older one
    for mut in muts:
        genome_pos = int(mut[1:-1])
        # a joined location can read the same nt twice (e.g. the ribosomal frameshift at
        # nt 13468 in Nsp12), so write the new base to every gene position that reads it;
        # later positions in such a gene are offset accordingly by _gene_indices
        for gene_idx in _gene_indices(gene_location, genome_pos):
            mutated_gene_seq[gene_idx] = mut[-1]

    translated_mut_seq = mutated_gene_seq.translate()

    # compare codon by codon: every changed amino acid is either a gained stop or a nonsyn
    # change (a reference terminal stop that is kept is unchanged and so not counted)
    changed_aas = [mut_aa for ref_aa, mut_aa in zip(translated_wuhan_seq, translated_mut_seq)
                   if ref_aa != mut_aa]
    stops_in_gene = changed_aas.count('*')
    nonsyn_in_gene = len(changed_aas) - stops_in_gene

    # find number of syn mutations (as the total number of muts that are not stop or nonsyn);
    # clamp at 0 because a single mutation at a duplicated (frameshift) nt can change two codons
    syn_in_gene = max(len(muts) - nonsyn_in_gene - stops_in_gene, 0)

    return nonsyn_in_gene, syn_in_gene, stops_in_gene

In [49]:
def _nuc_mutations(clade):
    """Return the nucleotide mutations on the branch leading into `clade`, dropping deletions."""
    # tolerate branches whose json has no branch_attrs or no 'mutations' entry
    branch_attrs = getattr(clade, 'branch_attrs', {})
    nuc_muts = branch_attrs.get('mutations', {}).get('nuc', [])
    # ignore deletions and "reversions" from deletions
    return [mut for mut in nuc_muts if '-' not in mut]


def _group_mutations_by_gene(nt_muts):
    """Group nucleotide mutations by the coding gene(s) they fall in, keeping order and repeats."""
    muts_per_gene = {}
    for mut in nt_muts:
        # some mutations will map to multiple genes
        for gene in nt_to_gene_mapper[int(mut[1:-1])]:
            # don't need to keep track of the noncoding muts
            if gene != 'noncoding':
                muts_per_gene.setdefault(gene, []).append(mut)
    return muts_per_gene


def count_mutations(date, tree_type):
    """
    For each node in the tree, find the number of mutations (nonsyn, syn and stop) per gene
    accumulated on the path from the root to that node, and add these tallies as node attributes

    date, tree_type: passed to readin_tree to choose the tree json
    Return a dictionary keyed by node name whose value is
    {'Nonsyn_muts': {gene: n}, 'Syn_muts': {gene: n}, 'Stop_muts': {gene: n}}
    """
    tree = readin_tree(date, tree_type)

    # initialize dictionary to store mut info
    muts_by_node = {}

    # nucleotide mutations accumulated from the root (exclusive, oldest first) down to each
    # clade still waiting to be visited; keyed by id() so clades need not be hashable
    path_muts = {id(tree.root): []}

    # in preorder a parent is always visited before its children, so each node's path
    # mutations are built once from its parent's instead of searching the tree per node
    for node in tree.find_clades(order='preorder'):
        nt_muts = path_muts.pop(id(node))
        for child in node.clades:
            path_muts[id(child)] = nt_muts + _nuc_mutations(child)

        # keep track of nonsyn, syn, stop muts on this path, by gene
        nonsyn_total_by_gene = {x: 0 for x in all_genes}
        syn_total_by_gene = {x: 0 for x in all_genes}
        stop_total_by_gene = {x: 0 for x in all_genes}

        # mutations stay in path order (repeats such as A100G, G100A, A100G included),
        # so the newest mutation at a nt is the one applied last
        for gene, gene_nt_muts in _group_mutations_by_gene(nt_muts).items():
            nonsyn_in_gene, syn_in_gene, stops_in_gene = count_mutations_by_type(gene, gene_nt_muts)
            nonsyn_total_by_gene[gene] = nonsyn_in_gene
            syn_total_by_gene[gene] = syn_in_gene
            stop_total_by_gene[gene] = stops_in_gene

        node.node_attrs['Nonsyn_muts'] = nonsyn_total_by_gene
        node.node_attrs['Syn_muts'] = syn_total_by_gene
        node.node_attrs['Stop_muts'] = stop_total_by_gene

        muts_by_node[node.name] = {'Nonsyn_muts': nonsyn_total_by_gene, 'Syn_muts': syn_total_by_gene,
                                   'Stop_muts': stop_total_by_gene}

    return muts_by_node

In [8]:
def traverse_tree(branch, muts_by_node, color_by_genes):
    """
    Annotate `branch` and every node nested under its 'children' with mutation tallies

    Each node gets 'mut_accumulation' = muts_by_node[<node name>]; for each gene in
    color_by_genes (when truthy) it also gets
    node_attrs['<gene>_mutations'] = {"value": <Nonsyn_muts tally for that gene>}
    Nodes are edited in place and the same `branch` object is returned
    """
    # explicit stack instead of recursion; pushing children reversed keeps preorder
    pending = [branch]
    while pending:
        node = pending.pop()
        node_tally = muts_by_node[node['name']]
        node['mut_accumulation'] = node_tally
        for gene in (color_by_genes or ()):
            node['node_attrs'][f'{gene}_mutations'] = {"value": node_tally['Nonsyn_muts'][gene]}
        pending.extend(reversed(node.get('children', [])))
    return branch

In [53]:
def edit_tree(date, tree_type, color_by_genes=False):
    """
    Edit the tree as a JSON object to have the mutation tally information at each node
    Can also supply a list of genes so that the auspice tree can be colored by mutation tally in these genes
    Save the edited tree under trees_w_mut_counts/

    date: date string used in the input and output file names (the 'alltime' input file name
        has no date, but its output file name does)
    tree_type options = ['21L', 'sars2', 'alltime']
    color_by_genes: list of gene names to add a continuous '<gene>_mutations' coloring for,
        or False for none
    Raises ValueError for an unsupported tree_type
    """
    # (input tree json, output tree json) for each supported tree type
    paths_by_type = {
        'sars2': (f'../ncov_builds/auspice_sars2_2m/sars2_{date}_2m.json',
                  f'trees_w_mut_counts/sars2_{date}_2m.json'),
        '21L': (f'../ncov_builds/auspice_21L_2m/sars2_21L_{date}_2m.json',
                f'trees_w_mut_counts/sars2_21L_{date}_2m.json'),
        'alltime': ('../ncov_builds/auspice_sars2_alltime/sars2_global_all-time.json',
                    f'trees_w_mut_counts/sars2_alltime_{date}.json'),
    }
    # validate before the (slow) mutation counting
    if tree_type not in paths_by_type:
        raise ValueError(f"tree_type must be one of {sorted(paths_by_type)}, got {tree_type!r}")
    tree_file, out_file = paths_by_type[tree_type]

    # get dictionary of the mutation tally at each node
    muts_by_node = count_mutations(date, tree_type)

    with open(tree_file, 'r') as f:
        tree_json = json.load(f)

    # add mutation tallies to each node and
    # replace the 'tree' in the auspice json with the edited tree
    tree_json['tree'] = traverse_tree(tree_json['tree'], muts_by_node, color_by_genes)

    # add color by option for the specified genes
    if color_by_genes:
        colorings = tree_json['meta'].setdefault('colorings', [])
        for gene in color_by_genes:
            colorings.append({
                "key": f"{gene}_mutations",
                "title": f"{gene} mutations",
                "type": "continuous"})

    # serialize before opening the output so a serialization error leaves no partial file
    json_object = json.dumps(tree_json, indent=2)

    # Write out edited tree file
    with open(out_file, "w") as outfile:
        outfile.write(json_object)

In [54]:
# annotate the all-time tree, adding a coloring for every gene
edit_tree(date='2023-06', tree_type='alltime', color_by_genes=all_genes)

In [44]:
# edit all 2m trees (one build every other month, 2020-03 through 2023-05)
all_dates = [
    '2020-03', '2020-05', '2020-07', '2020-09', '2020-11',
    '2021-01', '2021-03', '2021-05', '2021-07', '2021-09', '2021-11',
    '2022-01', '2022-03', '2022-05', '2022-07', '2022-09', '2022-11',
    '2023-01', '2023-03', '2023-05',
]

for build_date in all_dates:
    edit_tree(date=build_date, tree_type='sars2', color_by_genes=all_genes)

In [45]:
# edit all 2m 21L trees (monthly builds, 2022-03 through 2023-06)
# NOTE(review): '2022-12' is not in this list — confirm that is intentional
all_dates = [
    '2022-03', '2022-04', '2022-05', '2022-06', '2022-07',
    '2022-08', '2022-09', '2022-10', '2022-11',
    '2023-01', '2023-02', '2023-03', '2023-04', '2023-05', '2023-06',
]

for build_date in all_dates:
    edit_tree(date=build_date, tree_type='21L', color_by_genes=all_genes)