In [5]:
import math
import json
import string
import requests
from augur.utils import json_to_tree
from os import path
import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Align import AlignInfo
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

In [6]:
# Download the Nextstrain tree JSON.
# Alternative builds are kept below (commented out) for quick switching.
# tree_url = "https://data.nextstrain.org/ncov_gisaid_global.json"
tree_url = "https://data.nextstrain.org/ncov_global_2021-05-19.json"
# tree_url = "https://data.nextstrain.org/ncov_global_2021-05-28.json"

# Fail fast on HTTP errors (otherwise an error page surfaces as a confusing
# JSON decode error) and do not hang forever on a stalled connection.
response = requests.get(tree_url, timeout=60)
response.raise_for_status()
tree_json = response.json()

# Put tree in Bio.Phylo format
tree = json_to_tree(tree_json)

In [None]:
def _count_substitutions_vs_root(virus_seq, root_seq, root_aa_seq):
    """Count unambiguous sites and syn/nonsyn substitutions of one virus versus the root.

    A site is skipped when either the virus or the root nucleotide is 'N'.
    A substitution is classified by translating the virus codon that contains
    the site and comparing it with the root amino acid at that codon; it is
    left unclassified when the root amino acid is 'X'.

    Args:
        virus_seq: Bio.Seq.Seq of the virus, same length as ``root_seq``.
        root_seq: root nucleotide sequence (indexable, one nucleotide per item).
        root_aa_seq: root amino-acid sequence, indexed by codon number.

    Returns:
        Tuple ``(count_total_unambiguous, count_syn_subs, count_nonsyn_subs)``.
    """
    virus_nts = str(virus_seq)
    count_total_unambiguous = 0
    count_syn_subs = 0
    count_nonsyn_subs = 0

    for pos, root_item in enumerate(root_seq):
        root_nt = str(root_item)
        virus_nt = virus_nts[pos]
        # skip ambiguous sites
        if virus_nt == 'N' or root_nt == 'N':
            continue
        count_total_unambiguous += 1
        if virus_nt == root_nt:
            continue

        # determine syn or nonsyn from the codon containing this position
        codon = pos // 3
        codon_start = codon * 3
        codon_aa = virus_seq[codon_start:codon_start + 3].translate()
        root_aa = root_aa_seq[codon]
        # skip ambiguous root amino acids
        if root_aa == 'X':
            continue
        if codon_aa != root_aa:
            count_nonsyn_subs += 1
        else:
            count_syn_subs += 1

    return count_total_unambiguous, count_syn_subs, count_nonsyn_subs


def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
    if not values:
        return float('nan')
    return sum(values) / len(values)


def divergence_weighted(cov, gene, window, clade, min_seqs, year_max=None, year_min=None):
    """Find the fraction of sites that differ from the root, averaged over viruses per time window.

    Args:
        cov: virus name; used to build the alignment and metadata file paths.
        gene: gene name; used to build the alignment and metadata file paths.
        window: ``'all'`` for a single window covering every year, otherwise the
            window width in years (windows slide forward one year at a time).
        clade: ``None`` to use all strains, otherwise only strains whose
            'clade' (as returned by ``separate_clades``) equals this value.
        min_seqs: windows with fewer strains than this are skipped.
        year_max: if given, strains from later years are dropped.
        year_min: if given, strains from earlier years are dropped.

    Returns:
        Tuple of six lists with one entry per retained window:
        ``(year_windows, seqs_in_window, syn_divergences, nonsyn_divergences,
        syn_divergences_window_average, nonsyn_divergences_window_average)``.
        ``syn_divergences`` / ``nonsyn_divergences`` hold the per-virus values
        of each window; the ``*_window_average`` lists hold their means (NaN
        when no virus of the window could be scored).
    """
    input_file_alignment = f'../{cov}/results/aligned_{cov}_{gene}.fasta'
    metafile = f'../{cov}/results/metadata_{cov}_{gene}.tsv'

    # Subset data based on time windows; rows without a usable date are dropped
    meta = pd.read_csv(metafile, sep='\t')
    meta.drop(meta[meta['date'] == '?'].index, inplace=True)
    meta.dropna(subset=['date'], inplace=True)
    meta['year'] = meta['date'].str[:4].astype('int')

    if year_max is not None:
        meta.drop(meta[meta['year'] > year_max].index, inplace=True)
    if year_min is not None:
        meta.drop(meta[meta['year'] < year_min].index, inplace=True)

    if clade is not None:
        clade_df = separate_clades(cov, gene)
        meta = meta.merge(clade_df, on='strain')
        meta.drop(meta[meta['clade'] != clade].index, inplace=True)

    # Group viruses by time windows
    first_year = meta['year'].min()
    last_year = meta['year'].max()
    virus_time_subset = {}
    if window == 'all':
        virus_time_subset[f'{first_year}-{last_year}'] = meta['strain'].tolist()
    else:
        # NOTE(review): each window is [start, end) and the last window ends at
        # last_year, so strains from the final year never fall in any window.
        # Kept as-is because changing it alters results -- confirm intent.
        date_window_start = first_year
        date_window_end = first_year + window
        while date_window_end <= last_year:
            in_window = (meta['year'] >= date_window_start) & (meta['year'] < date_window_end)
            years = f'{date_window_start}-{date_window_end}'
            virus_time_subset[years] = meta[in_window]['strain'].tolist()
            # sliding window
            date_window_end += 1
            date_window_start += 1

    root_seq, root_aa_seq, _first_window_years = find_founder_consensus(
        virus_time_subset, input_file_alignment, min_seqs)

    nonsyn_denominator, syn_denominator = find_nonsyn_syn_denominators(
        root_seq, root_aa_seq, cov, clade)

    # initiate lists to record all time windows
    year_windows = []
    seqs_in_window = []
    nonsyn_divergences = []
    syn_divergences = []
    nonsyn_divergences_window_average = []
    syn_divergences_window_average = []

    for years, subset_viruses in virus_time_subset.items():
        # don't use windows with fewer than min_seqs
        if len(subset_viruses) < min_seqs:
            continue
        year_windows.append(years)
        seqs_in_window.append(len(subset_viruses))

        # set gives O(1) membership tests while scanning the alignment
        strains_in_window = set(subset_viruses)
        syn_div_allviruses_in_window = []
        nonsyn_div_allviruses_in_window = []

        with open(input_file_alignment, "r") as aligned_handle:
            for virus in SeqIO.parse(aligned_handle, "fasta"):
                # Only viruses in time window
                if virus.id not in strains_in_window:
                    continue
                # Sequences whose length differs from the root are reported and skipped
                if len(virus.seq) != len(root_seq):
                    print(virus)
                    continue

                count_total_unambiguous, count_syn_subs, count_nonsyn_subs = (
                    _count_substitutions_vs_root(virus.seq, root_seq, root_aa_seq))

                # A virus with no unambiguous site carries no information and
                # would otherwise cause a division by zero below.
                if count_total_unambiguous == 0:
                    print(f'{virus.id}: no unambiguous sites, skipped')
                    continue

                # Multiply div by fraction of sites that were unambiguously sequenced
                # NOTE(review): the ratio scales numerator and denominator alike,
                # so it cancels out and the result equals count/denominator.
                # Kept unchanged to preserve published numbers -- confirm whether
                # only the denominator was meant to be scaled.
                unambiguous_ratio = count_total_unambiguous / len(root_seq)
                syn_div_for_virus = (
                    (count_syn_subs * unambiguous_ratio) / (syn_denominator * unambiguous_ratio))
                nonsyn_div_for_virus = (
                    (count_nonsyn_subs * unambiguous_ratio) / (nonsyn_denominator * unambiguous_ratio))
                syn_div_allviruses_in_window.append(syn_div_for_virus)
                nonsyn_div_allviruses_in_window.append(nonsyn_div_for_virus)

        syn_divergences.append(syn_div_allviruses_in_window)
        nonsyn_divergences.append(nonsyn_div_allviruses_in_window)

        # NaN (rather than a crash) keeps all returned lists aligned when no
        # virus of this window could be scored.
        syn_divergences_window_average.append(_mean_or_nan(syn_div_allviruses_in_window))
        nonsyn_divergences_window_average.append(_mean_or_nan(nonsyn_div_allviruses_in_window))

    return (year_windows, seqs_in_window, syn_divergences, nonsyn_divergences,
            syn_divergences_window_average, nonsyn_divergences_window_average)
