In [16]:
from augur.utils import json_to_tree
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import MutableSeq
from scipy import stats
import matplotlib.gridspec as gridspec
from collections import Counter
import requests
import math

Find the average time until the next S1 mutation on every descending path from a given mutation, ignoring S1 mutations that occur on the same branch as the given mutation.

In [2]:
# Download the tree JSON from data.nextstrain.org
tree_url = "https://data.nextstrain.org/ncov_global.json"

# Use a timeout so a stalled connection cannot hang the kernel, and raise on
# HTTP errors so an error response is never parsed as if it were the tree.
tree_response = requests.get(tree_url, timeout=120)
tree_response.raise_for_status()
tree_json = tree_response.json()

# Put tree in Bio.Phylo format
tree = json_to_tree(tree_json)

In [14]:
# make dictionary with gene name as key and reference sequence of that gene as value
reference_sequence_aa = {}
reference_sequence_nt = {}


# make dictionary giving gene by genomic location (0-based genome position -> gene name)
reference_gene_locations = {}

# make dictionary saying what codon within the gene a certain genomic location falls within
# and whether the mutation is at pos 0, 1 or 2 within codon
# (0-based genome position -> (0-based codon number, position in codon))
reference_gene_codon = {}

# Open the GenBank file in a context manager so the handle is closed once parsing is done
with open("reference_seq_edited.gb", "r") as reference_handle:
    for record in SeqIO.parse(reference_handle, "genbank"):
        genome_seq = record.seq
        for feature in record.features:
            if feature.type != 'CDS':
                continue

            gene_name = feature.qualifiers['gene'][0]
            feature_start = int(feature.location.start)
            feature_end = int(feature.location.end)

            # allow RdRp to overwrite Orf1a and Orf1b,
            # to take care of changed reading frame due to  ribosome slippage
            # S1 and S2 will also overwrite spike
            # (later features overwrite earlier ones, so this relies on the
            # order of the CDS features in the GenBank file)
            for pos in range(feature_start, feature_end):
                reference_gene_locations[pos] = gene_name
                # offset from the feature start -> (codon number, position within codon)
                codon_num, pos_in_codon = divmod(pos - feature_start, 3)
                reference_gene_codon[pos] = (codon_num, pos_in_codon)

            gene_seq = feature.location.extract(record.seq)
            reference_sequence_nt[gene_name] = gene_seq
            gene_seq_aa = gene_seq.translate()
            reference_sequence_aa[gene_name] = gene_seq_aa



In [3]:
def get_parent(tree, child_clade):
    """
    Return the path to `child_clade` exactly as reported by `tree.get_path`.

    Despite the name, this is the whole path rather than a single parent node;
    callers in this notebook slice the result themselves (e.g. `[:-1]`).
    """
    return tree.get_path(child_clade)

In [65]:
def get_branches_with_mutation(gene, mutation):
    """
    Find the internal branches of the global `tree` that carry `mutation` in `gene`.

    Only internal nodes with at least 10 terminal descendants are considered.

    Returns a dictionary with the node name as key and the node's
    `node_attrs['num_date']['value']` as value.
    """
    dated_branches = {}

    for clade in tree.find_clades(terminal=False):
        # skip clades with fewer than 10 tips
        if len(clade.get_terminals()) < 10:
            continue

        # skip nodes that have no mutation annotation at all
        if not hasattr(clade, "branch_attrs"):
            continue
        if "mutations" not in clade.branch_attrs:
            continue

        mutations_by_gene = clade.branch_attrs["mutations"]
        if gene in mutations_by_gene and mutation in mutations_by_gene[gene]:
            dated_branches[clade.name] = clade.node_attrs['num_date']['value']

    return dated_branches

In [88]:
def find_next_s1_muts(branches_with_mutation):
    """
    Measure the wait between a given mutation and the next S1 mutation below it.

    branches_with_mutation: dictionary with the name of each branch carrying the
        given mutation as key and its date as value (as returned by
        `get_branches_with_mutation`).

    Works on the global `tree`, whose internal nodes are expected to carry the
    `s1_nonsyn_at_node` counter. Nodes without that attribute are treated as
    having no S1 mutations.

    Returns a tuple `(wait_times, mean_wait_time)`. `wait_times` holds one date
    difference per first S1-mutated descendent; `mean_wait_time` is their mean,
    or NaN when there are none (instead of raising ZeroDivisionError).
    """

    # dictionary with descendent that has S1 mutation as key and
    # parent with given mutation as value
    descendents_with_s1_muts_parents = {}
    descendents_with_s1_muts_dates = {}
    # names of all clades on the path above each such descendent, kept so the
    # path does not have to be looked up a second time
    ancestor_names_by_descendent = {}

    # find all descendents of given mutation that have S1 mutations
    for node in tree.find_clades(terminal=False):

        if getattr(node, 's1_nonsyn_at_node', 0) <= 0:
            continue

        # drop the last path element (the node itself): S1 mutations on the
        # same branch as the given mutation are ignored
        ancestor_names = [p.name for p in get_parent(tree, node)[:-1]]

        has_mutated_ancestor = False
        for p in ancestor_names:
            if p in branches_with_mutation:
                # a later match overwrites an earlier one, so the last
                # matching clade on the path is the one recorded
                descendents_with_s1_muts_parents[node.name] = p
                descendents_with_s1_muts_dates[node.name] = node.node_attrs['num_date']['value']
                has_mutated_ancestor = True

        if has_mutated_ancestor:
            ancestor_names_by_descendent[node.name] = ancestor_names

    # limit this list of descendents to the FIRST descendent (along all paths)
    # after given mutation that has an S1 mutation. This means excluding
    # elements of `descendents_with_s1_muts` that are descendents of other elements of this list
    # NOTE(review): a candidate is dropped when ANY clade above it is itself a
    # candidate, including one that sits above the candidate's own mutation
    # branch (possible if the given mutation arises again further down) —
    # confirm this is the intended behaviour.
    first_descendents_with_s1_muts = {}

    for name, ancestor_names in ancestor_names_by_descendent.items():
        if not any(n in descendents_with_s1_muts_parents for n in ancestor_names):
            first_descendents_with_s1_muts[name] = descendents_with_s1_muts_dates[name]

    # find the wait time between given mutation occuring and the first S1 mutation on each descending path
    wait_times = []

    for descendent_with_s1_mut, descendent_date in first_descendents_with_s1_muts.items():
        parent_with_specified_mut = descendents_with_s1_muts_parents[descendent_with_s1_mut]
        parent_date = branches_with_mutation[parent_with_specified_mut]
        wait_times.append(float(descendent_date) - float(parent_date))

    # the mean of no observations is undefined
    if not wait_times:
        return wait_times, float('nan')

    mean_wait_time = sum(wait_times)/len(wait_times)

    return wait_times, mean_wait_time
        

In [90]:
# Wait times until the next S1 mutation after ORF1a S3675- (a trailing '-' marks a deletion)
branches_with_mutation = get_branches_with_mutation('ORF1a', 'S3675-')
find_next_s1_muts(branches_with_mutation)

([0.6019116337904507,
  0.497929925717699,
  0.7400282892451742,
  0.6191670429293481,
  0.41901656461300263,
  0.25428153157963607,
  0.08909429248956258,
  0.11337948171239987,
  0.08063851339375105,
  0.18090011035201314,
  0.2192113258640802],
 0.34686897378973797)

In [82]:
# Wait times until the next S1 mutation after S N501Y
# NOTE(review): the saved output is a single number, but find_next_s1_muts as defined
# in this notebook returns (wait_times, mean_wait_time) — re-run to refresh the output.
branches_with_mutation = get_branches_with_mutation('S', 'N501Y')
find_next_s1_muts(branches_with_mutation)

0.4028468310553863

In [83]:
# Wait times until the next S1 mutation after S E484K
# NOTE(review): the saved output is a single number, but find_next_s1_muts as defined
# in this notebook returns (wait_times, mean_wait_time) — re-run to refresh the output.
branches_with_mutation = get_branches_with_mutation('S', 'E484K')
find_next_s1_muts(branches_with_mutation)

0.49186554893942064

In [85]:
# Wait times until the next S1 mutation after S L452R
# NOTE(review): the saved output is a single number, but find_next_s1_muts as defined
# in this notebook returns (wait_times, mean_wait_time) — re-run to refresh the output.
branches_with_mutation = get_branches_with_mutation('S', 'L452R')
find_next_s1_muts(branches_with_mutation)

0.25672516972209297

In [84]:
# Wait times until the next S1 mutation after ORF1a T3255I
# NOTE(review): the saved output is a single number, but find_next_s1_muts as defined
# in this notebook returns (wait_times, mean_wait_time) — re-run to refresh the output.
branches_with_mutation = get_branches_with_mutation('ORF1a', 'T3255I')
find_next_s1_muts(branches_with_mutation)

0.31538265432609097

In [86]:
# Wait times until the next S1 mutation after ORF1a L3606F
# NOTE(review): the saved output is a single number, but find_next_s1_muts as defined
# in this notebook returns (wait_times, mean_wait_time) — re-run to refresh the output.
branches_with_mutation = get_branches_with_mutation('ORF1a', 'L3606F')
find_next_s1_muts(branches_with_mutation)

0.6203002526914361

In [89]:
# Wait times until the next S1 mutation after N T205I
branches_with_mutation = get_branches_with_mutation('N', 'T205I')
find_next_s1_muts(branches_with_mutation)

([0.20375446860657576,
  0.30385106667836226,
  0.24148756189492815,
  0.05771895346811107,
  0.15097368617512075,
  0.11168921981425228,
  0.6698805024457215,
  0.5051454694123549,
  0.08089307470004314],
 0.2583771114661633)

In [91]:
# Wait times until the next S1 mutation after M I82T
branches_with_mutation = get_branches_with_mutation('M', 'I82T')
find_next_s1_muts(branches_with_mutation)

([0.2681981131554494,
  0.25444178291286335,
  0.2529228037053599,
  0.18750767846927374,
  0.43247443910649963,
  0.4066869186654003,
  0.39012483676992815,
  0.2280170079347954,
  0.2187517573149762,
  0.7400282892451742,
  0.6191670429293481,
  0.11633888354981536,
  0.8343882090384795,
  0.6713679100150785],
 0.4014582623437458)

In [6]:
def consolidate_deletions(mutation_list):
    """
    For deletion mutations, consider adjacent sites as part of the same deletion

    mutation_list: mutation strings of the form <from><position><to>, where a
        '-' as first or last character marks a deletion (or its reversion).

    Returns a new list holding the first mutation of every run of consecutive
    deleted positions (ordered by position), followed by all other mutations
    in their original order. The input list is not modified.
    """

    def position(mutation):
        # numeric site of the mutation, e.g. 'H69-' -> 69
        return int(mutation[1:-1])

    without_deletions = [x for x in mutation_list if x[-1]!='-' and x[0]!='-']

    # consolidate deletions and reversions
    # Sort numerically: sorting the position as text puts '100' before '99',
    # which splits a run such as 99-100 into two separate deletions.
    deletions_only = sorted(
        (x for x in mutation_list if x[-1]=='-' or x[0]=='-'),
        key=position,
    )

    # keep track of start of separate deletions
    separate_deletions = []
    previous_position = None

    # count a run of consecutive sites as a single deletion/mutation
    for deletion in deletions_only:
        deleted_pos = position(deletion)
        if previous_position is None or deleted_pos != previous_position + 1:
            separate_deletions.append(deletion)
        previous_position = deleted_pos

    return separate_deletions + without_deletions

In [7]:
def nuc_changes_from_reference(muts_on_path):
    """
    Collapse the nucleotide changes along a path into the final state at each site.

    muts_on_path: mutation strings of the form <from><position><to>, listed in
        the order in which they occurred.

    Returns a dictionary mapping the integer position to the last character of
    the most recent mutation seen at that position.
    """
    # a later change at the same site replaces the earlier one
    return {int(change[1:-1]): change[-1] for change in muts_on_path}

In [8]:
def determine_synonymous(nuc_muts_on_branch, parent_diffs_from_ref):
    """
    Pick out the synonymous and noncoding nucleotide substitutions on a branch.

    nuc_muts_on_branch: nucleotide mutation strings of the form
        <from><1-based genome position><to>.
    parent_diffs_from_ref: dictionary mapping a 1-based genome position to the
        nucleotide found there on the path above this branch (as built by
        `nuc_changes_from_reference`).

    Uses the module-level `reference_gene_locations`, `reference_gene_codon`
    and `reference_sequence_nt` dictionaries, which are keyed by 0-based
    genome position / gene name.

    Returns a dictionary with the gene name (or 'noncoding' for positions
    outside every annotated gene) as key and the list of mutations as value.
    Deletions and nonsynonymous substitutions are left out.
    """

    # make dictionary of synonymous (and noncoding) mutations to add to tree
    syn_muts = {}

    for mut in nuc_muts_on_branch:
        # don't care about deletions because they are obviously not synonymous
        if mut[-1] == '-' or mut[0] == '-':
            continue

        # the reference dictionaries are keyed by 0-based genome position
        mut_index = int(mut[1:-1]) - 1

        # find what gene this mut happens in
        if mut_index not in reference_gene_locations:
            syn_muts.setdefault('noncoding', []).append(mut)
            continue

        mut_gene = reference_gene_locations[mut_index]
        mut_codon_num, mut_codon_pos = reference_gene_codon[mut_index]

        # find the reference sequence of the codon this mutation occurs in
        codon_start = mut_codon_num*3
        codon_ref_nt = reference_sequence_nt[mut_gene][codon_start:codon_start+3]

        # check if a mutation occurred within the same codon in a parent
        # and if so, change the reference codon sequence accordingly,
        # to tell whether the mutation at this branch is synonymous or not.
        # Every difference is applied on top of the previous ones, so several
        # earlier changes in the same codon are all kept. Only the three codon
        # positions are looked up rather than scanning every parent difference.
        codon_genome_start = mut_index - mut_codon_pos
        parent_codon = codon_ref_nt
        for offset in range(3):
            parent_diff = codon_genome_start + offset + 1  # back to 1-based
            if parent_diff in parent_diffs_from_ref:
                parent_codon = (parent_codon[:offset]
                                + parent_diffs_from_ref[parent_diff]
                                + parent_codon[offset+1:])

        # apply the mutation from this branch to the parent's codon
        codon_mutated = (parent_codon[:mut_codon_pos]
                         + mut[-1]
                         + parent_codon[mut_codon_pos+1:])

        # synonymous means the amino acid is unchanged relative to the parent,
        # which may already differ from the reference at this codon
        if str(parent_codon.translate()) == str(codon_mutated.translate()):
            syn_muts.setdefault(mut_gene, []).append(mut)

    return syn_muts

In [9]:
def add_syn_mut_attribute(tree):
    """
    Annotate every node of `tree` with its synonymous mutations.

    Sets `node.node_attrs['syn_muts']` on each node to the dictionary returned
    by `determine_synonymous` for the nucleotide mutations on that branch, or
    to an empty dictionary when the branch has no nucleotide mutations.

    The tree is modified in place; nothing is returned.
    """

    for node in tree.find_clades():

        node.node_attrs['syn_muts'] = {}

        # only care if this branch has some nucleotide mutations
        if not hasattr(node, 'branch_attrs'):
            continue
        nuc_muts_on_branch = node.branch_attrs.get('mutations', {}).get('nuc')
        if nuc_muts_on_branch is None:
            continue

        # The last element of the path is the node itself, so everything
        # before it is an ancestor. (Indexing with [-1] would instead iterate
        # over the node itself and fail on an empty path.)
        # NOTE(review): presumably the path does not include the root clade,
        # so mutations annotated on the root are not taken into account —
        # confirm against the Bio.Phylo `get_path` documentation.
        ancestors = get_parent(tree, node)[:-1]

        # find all nucleotide mutations that happened in parents,
        # in case they affect codons mutated on this branch
        nucleotide_mut_path = []
        for parent in ancestors:
            if hasattr(parent, 'branch_attrs'):
                nucleotide_mut_path += parent.branch_attrs.get('mutations', {}).get('nuc', [])

        parent_diffs_from_ref = nuc_changes_from_reference(nucleotide_mut_path)

        node.node_attrs['syn_muts'] = determine_synonymous(nuc_muts_on_branch, parent_diffs_from_ref)


In [17]:
# Annotate the tree in place with node_attrs['syn_muts']; the per-node counting loop reads it
add_syn_mut_attribute(tree)

In [18]:



# Amino-acid positions (within the named gene) that are counted for each region.
# `range` excludes its upper bound.
S1_POSITIONS = range(14, 686)
# NOTE(review): the ORF1a bounds used to be range(4492, 4401), which is empty, so
# RdRp changes annotated in ORF1a were never counted. 4393-4401 assumes RdRp
# starts at ORF1a position 4393 and ORF1a ends at 4401 — confirm against the
# reference annotation.
RDRP_ORF1A_POSITIONS = range(4393, 4402)
# NOTE(review): position 923 is not included here — confirm whether the last
# RdRp position in ORF1b is 922 or 923.
RDRP_ORF1B_POSITIONS = range(1, 923)

for node in tree.find_clades(terminal=False):

    # give every internal node the counters, so that later cells can read
    # them without hitting an AttributeError
    node.s1_nonsyn_at_node = 0
    node.s1_syn_at_node = 0
    node.rdrp_nonsyn_at_node = 0

    # this node is left out of the counting (the notebook does not say why)
    if node.name == 'USA/OR-OSPHL00881/2021':
        continue

    # synonymous S1 mutations, as annotated by add_syn_mut_attribute
    if hasattr(node, "node_attrs") and 'S1' in node.node_attrs.get('syn_muts', {}):
        node.s1_syn_at_node = len(node.node_attrs['syn_muts']['S1'])

    if hasattr(node, 'branch_attrs'):
        mutations_by_gene = node.branch_attrs.get("mutations", {})

        # amino-acid changes in S that fall within S1
        s1_nonsyn_at_this_node = [
            mut for mut in mutations_by_gene.get("S", [])
            if int(mut[1:-1]) in S1_POSITIONS
        ]
        # a run of adjacent deleted sites counts as one mutation
        s1_consolidated = consolidate_deletions(s1_nonsyn_at_this_node)
        node.s1_nonsyn_at_node = len(s1_consolidated)

        # amino-acid changes in RdRp, collected from both ORF1a and ORF1b
        rdrp_nonsyn_at_this_node = [
            mut for mut in mutations_by_gene.get("ORF1a", [])
            if int(mut[1:-1]) in RDRP_ORF1A_POSITIONS
        ]
        rdrp_nonsyn_at_this_node += [
            mut for mut in mutations_by_gene.get("ORF1b", [])
            if int(mut[1:-1]) in RDRP_ORF1B_POSITIONS
        ]

        rdrp_consolidated = consolidate_deletions(rdrp_nonsyn_at_this_node)
        node.rdrp_nonsyn_at_node = len(rdrp_consolidated)
