In [41]:
from augur.utils import json_to_tree
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO
import numpy as np
import pandas as pd
import itertools
import requests
from collections import Counter
from itertools import dropwhile
import math

In [2]:
# Nextstrain JSON export of the phylogeny analysed in this notebook
tree_url = 'https://nextstrain-blab.s3.amazonaws.com/ncov_adaptive-evolution.json'

# Fail fast: a timeout avoids hanging the kernel on a stalled connection, and
# raise_for_status() stops an HTTP error page from being parsed as tree JSON
tree_response = requests.get(tree_url, timeout=60)
tree_response.raise_for_status()
tree_json = tree_response.json()

#Put tree in Bio.Phylo format
tree = json_to_tree(tree_json)

In [9]:
# make dictionary with gene name as key and reference sequence of that gene as value
# (values are the translated CDS; downstream cells compare mutation positions
# against len() of these sequences)

reference_sequence = {}

# the context manager guarantees the GenBank file handle is closed after parsing
with open("../reference_seq_edited.gb", "r") as reference_handle:
    for record in SeqIO.parse(reference_handle, "genbank"):
        for feature in record.features:
            if feature.type == 'CDS':
                gene_seq = feature.location.extract(record.seq).translate()
                reference_sequence[feature.qualifiers['gene'][0]] = gene_seq




In [3]:
def get_parent(tree, child_clade):
    """Return the path of clades that leads to ``child_clade`` within ``tree``.

    This is a thin wrapper: the result of ``tree.get_path(child_clade)`` is
    handed back unchanged. Callers in this notebook drop the last element of
    the returned list to obtain only the ancestors of ``child_clade``.

    Parameters
    ----------
    tree : object exposing ``get_path`` (presumably the Bio.Phylo tree built
        by ``json_to_tree`` -- verify against caller)
    child_clade : the clade whose path is requested

    Returns
    -------
    Whatever ``tree.get_path`` returns for ``child_clade``.
    """
    return tree.get_path(child_clade)

In [12]:
def track_independent_occurrences(mutation_type='aa', min_tips=15, min_occurrences=3,
                                  phylo_tree=None, ref_seqs=None):
    """Collect non-'nuc' mutations that occur on several internal branches.

    Every internal branch that gives rise to at least ``min_tips`` tips is
    scanned; each mutation found there is recorded together with the name of
    the node it sits on. Mutations whose position equals the length of the
    gene's reference sequence (the stop codon position) are ignored.

    Parameters
    ----------
    mutation_type : {'aa', 'site'}
        'aa' keys the result by the full mutation (``'S:D614G'``); 'site' keys
        it by position only (``'S:614'``), regardless of the residue change.
    min_tips : int
        Minimum number of tips descending from a branch for it to be counted.
    min_occurrences : int
        Minimum number of recorded occurrences for a mutation to be returned
        (with fewer occurrences, potential epistatic interactions are more
        likely to be coincidence).
    phylo_tree : optional
        Tree to scan; must provide ``find_clades(terminal=False)``. Defaults to
        the notebook-level ``tree``.
    ref_seqs : optional
        Mapping gene name -> reference sequence. Defaults to the notebook-level
        ``reference_sequence``.

    Returns
    -------
    dict
        Mutation label -> list of node names where the mutation occurs.

    Raises
    ------
    ValueError
        If ``mutation_type`` is not 'aa' or 'site'.
    KeyError
        If a mutated gene is missing from ``ref_seqs``.
    """
    # Reject unknown modes up front; otherwise the label below would be
    # undefined (or silently reuse the label of a previous mutation)
    if mutation_type not in ('aa', 'site'):
        raise ValueError(f"mutation_type must be 'aa' or 'site', got {mutation_type!r}")

    # Fall back to the notebook globals so existing zero-argument calls still work
    if phylo_tree is None:
        phylo_tree = tree
    if ref_seqs is None:
        ref_seqs = reference_sequence

    # dictionary with mutation as key and list of nodes it appears at as value
    all_mutation_occurrences = {}

    # only look at mutations on internal branches
    for node in phylo_tree.find_clades(terminal=False):

        # only consider mutations on branches that give rise to a large enough clade
        if len(node.get_terminals()) < min_tips:
            continue

        # a node without branch_attrs, or without a "mutations" entry, has nothing to record
        branch_mutations = getattr(node, 'branch_attrs', {}).get("mutations", {})

        for gene, mut_list in branch_mutations.items():
            # not considering synonymous mutations for now
            if gene == 'nuc':
                continue

            for mut in mut_list:
                position = mut[1:-1]

                # exclude if mutation is in stop codon position of gene
                if int(position) == len(ref_seqs[gene]):
                    continue

                if mutation_type == 'aa':
                    # look at specific nonsyn muts
                    gene_mutation = f'{gene}:{mut}'
                else:
                    # look at mutation to a codon (without regard to identity of mutation)
                    gene_mutation = f'{gene}:{position}'

                all_mutation_occurrences.setdefault(gene_mutation, []).append(node.name)

    # keep only mutations that have occurred often enough
    return {k: v for k, v in all_mutation_occurrences.items() if len(v) >= min_occurrences}


In [28]:
def find_sequential_muts(all_mutation_occurrences, mutation_type='aa', min_tips=3,
                         phylo_tree=None, ref_seqs=None):
    """List the mutations observed downstream of each recurrent mutation.

    For every recurrent mutation and every node where it occurs, gather the
    non-'nuc' mutations found on internal branches that descend from that node
    and give rise to at least ``min_tips`` tips. Mutations whose position
    equals the length of the gene's reference sequence (the stop codon
    position) are ignored.

    Parameters
    ----------
    all_mutation_occurrences : dict
        Recurrent mutation -> list of node names where it occurs (the output
        of ``track_independent_occurrences``).
    mutation_type : {'aa', 'site'}
        'aa' labels following mutations in full (``'S:D614G'``); 'site' labels
        them by position only (``'S:614'``).
    min_tips : int
        Minimum number of tips descending from a branch for its mutations to
        be collected.
    phylo_tree : optional
        Tree to scan; must provide ``find_clades(terminal=False)`` and
        ``get_path``. Defaults to the notebook-level ``tree``.
    ref_seqs : optional
        Mapping gene name -> reference sequence. Defaults to the notebook-level
        ``reference_sequence``.

    Returns
    -------
    dict
        Recurrent mutation -> {occurrence node name -> list of following
        mutation labels}. Recurrent mutations and occurrence nodes with no
        following mutations are absent from the result.

    Raises
    ------
    ValueError
        If ``mutation_type`` is not 'aa' or 'site'.
    KeyError
        If a collected mutation's gene is missing from ``ref_seqs``.
    """
    if mutation_type not in ('aa', 'site'):
        raise ValueError(f"mutation_type must be 'aa' or 'site', got {mutation_type!r}")

    # Fall back to the notebook globals so existing calls still work
    if phylo_tree is None:
        phylo_tree = tree
    if ref_seqs is None:
        ref_seqs = reference_sequence

    # The tip filter and the ancestors of a branch do not depend on the
    # recurrent mutation, so compute them in a single pass over the tree
    # instead of once per recurrent mutation.
    candidate_branches = []
    for node in phylo_tree.find_clades(terminal=False):
        # only consider mutations on branches that give rise to a large enough clade
        if len(node.get_terminals()) < min_tips:
            continue
        # the path ends with the node itself, so drop the last element
        ancestor_names = {p.name for p in phylo_tree.get_path(node)[:-1]}
        candidate_branches.append((ancestor_names, node))

    # labelled mutations per branch, built lazily so that only branches that
    # actually descend from a recurrent mutation are ever parsed
    labelled_cache = {}

    def _labelled_mutations(node):
        """Return the filtered, labelled mutations sitting on ``node``'s branch."""
        cache_key = id(node)
        if cache_key not in labelled_cache:
            labels = []
            # a node without branch_attrs / "mutations" has nothing to report
            branch_mutations = getattr(node, 'branch_attrs', {}).get("mutations", {})
            for gene, mut_list in branch_mutations.items():
                if gene == 'nuc':
                    continue
                for mut in mut_list:
                    position = mut[1:-1]
                    # exclude mutations in the stop codon position of the gene
                    if int(position) == len(ref_seqs[gene]):
                        continue
                    if mutation_type == 'aa':
                        # look at specific nonsyn muts
                        labels.append(f'{gene}:{mut}')
                    else:
                        # look at mutation to a codon (without regard to identity of mutation)
                        labels.append(f'{gene}:{position}')
            labelled_cache[cache_key] = labels
        return labelled_cache[cache_key]

    # dictionary with recurrent mut as key and dictionary as value
    # dictionary has node where recurrent mut occurs as key and list of muts observed after this node as value
    muts_after_recurrentmuts = {}

    for recurrent_mut, mut_nodes in all_mutation_occurrences.items():
        for ancestor_names, node in candidate_branches:
            # check whether one of the recurrent mutations occurred in a parent
            for n in mut_nodes:
                if n not in ancestor_names:
                    continue
                following = _labelled_mutations(node)
                if following:
                    per_node = muts_after_recurrentmuts.setdefault(recurrent_mut, {})
                    per_node.setdefault(n, []).extend(following)

    return muts_after_recurrentmuts
    

In [27]:
# Recurrent mutations with the function's defaults (mutation_type='aa'):
# mutation label -> names of the nodes where it occurs
all_mutation_occurrences = track_independent_occurrences()

In [29]:
# recurrent mutation -> {occurrence node -> mutations seen on descendant branches}
muts_after_recurrentmuts = find_sequential_muts(all_mutation_occurrences)

In [42]:
# For each recurrent mutation, count how many of its occurrence nodes are
# followed by a given mutation, keeping only followers seen after >= 2 nodes
tally_muts_after_occurrence = {}

for recurrent_mut, muts_by_node in muts_after_recurrentmuts.items():
    # de-duplicate within a node so each follower counts at most once per occurrence
    unique_followers = [
        follower
        for followers in muts_by_node.values()
        for follower in set(followers)
    ]
    follower_counts = Counter(unique_followers)

    # retain only followers observed after at least two separate occurrences
    repeated_followers = Counter(
        {follower: n_nodes for follower, n_nodes in follower_counts.items() if n_nodes >= 2}
    )

    tally_muts_after_occurrence[recurrent_mut] = {
        'occurrences': len(muts_by_node),
        'mut_counts': repeated_followers,
    }

tally_muts_after_occurrence

{'ORF1a:L3606F': {'occurrences': 5,
  'mut_counts': Counter({'N:D402Y': 2,
           'ORF1b:D1183Y': 2,
           'ORF1a:M3752I': 2,
           'ORF1a:P2376L': 2,
           'ORF1a:F3606L': 2})},
 'N:P13L': {'occurrences': 3, 'mut_counts': Counter()},
 'ORF9b:P10S': {'occurrences': 3, 'mut_counts': Counter()},
 'ORF1a:F3606L': {'occurrences': 3, 'mut_counts': Counter()},
 'N:M234I': {'occurrences': 6, 'mut_counts': Counter()},
 'ORF1a:T3255I': {'occurrences': 6, 'mut_counts': Counter({'S:T95I': 2})},
 'S:L452R': {'occurrences': 5,
  'mut_counts': Counter({'ORF1a:H3580Q': 2,
           'S:W258L': 2,
           'M:I82T': 2,
           'ORF1a:D2980N': 2})},
 'S:D614G': {'occurrences': 3,
  'mut_counts': Counter({'S:-144Y': 2,
           'ORF1a:T4065I': 2,
           'N:P365S': 2,
           'ORF1b:A1643V': 2,
           'S:D215Y': 2,
           'ORF1a:H3580Q': 2,
           'ORF1a:A486V': 2,
           'S:Y144-': 2,
           'S:G1219C': 2,
           'ORF9b:R32L': 2,
           'N:D40

In [None]:
# 

for node in tree.find_clades():
    node_path = get_parent(tree, node)
    node_names_path = [x.name for x in node_path]
    all_paths.append(node_names_path)
    
    for parent in node_path:
        if len(parent.get_terminals()) >=10:
            if hasattr(parent, "branch_attrs") and "mutations" in parent.branch_attrs: