In [1]:
import math
import json
import random
import ast
import re
import os
from os import path
import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.SeqFeature import SeqFeature, FeatureLocation, CompoundLocation
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Align import AlignInfo
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
from scipy import stats
from collections import Counter

In [2]:
def readin_virus_config(virus):
    """
    Load the analysis configuration for one virus.

    Reads `config/adaptive_evo_config_{virus}.json` (path relative to the
    current working directory) and returns the parsed JSON content.
    """
    with open(f'config/adaptive_evo_config_{virus}.json') as config_handle:
        return json.load(config_handle)

In [140]:
def subset_alignment(virus, subtype, gene, location, window, min_seqs, year_max, year_min):
    """
    Get the alignment file and subset it into sliding time windows (defined by window).
    Each window must have at least min_seqs in it.
    Find the outgroup_seq as the consensus sequence at the first time point.

    Parameters:
        virus, subtype, gene: used to pick the config file and to fill in the
            `alignment_file` / `meta_file` path templates from that config
        location: [start, end] of the locus within each aligned sequence
        window: width of each time window, in years
        min_seqs: windows with fewer strains than this are discarded
        year_max, year_min: optional (falsy to disable) bounds on sampling year

    Returns a tuple of:
        virus_time_subset: {'start-end': [strain, ...]} for each retained window
        alignment_time_subset: {'start-end': [locus sequence, ...]} in alignment-file order
        outgroup_seq: consensus of the first retained window (duplication removed, if any)
        outgroup_seq_aa: translation of outgroup_seq
        year_windows: list of the retained window labels
        seqs_in_window: number of strains in each retained window

    The alignment file is parsed a single time; every window is then assembled
    from the in-memory records.
    """

    configs = readin_virus_config(virus)

    alignment_file = configs['alignment_file'].format(virus=virus, subtype=subtype, gene=gene)
    meta_file = configs['meta_file'].format(virus=virus, subtype=subtype, gene=gene)
    #some are comma-separated, some are tab-separated
    metafile_sep = configs['metafile_sep']

    meta = pd.read_csv(meta_file, sep=metafile_sep)
    #discard strains without a usable sampling year
    meta = meta[meta['date'] != '?']
    meta = meta.dropna(subset=['date'])
    meta = meta[meta['date'].str.contains('20XX') == False].copy()
    meta['year'] = meta['date'].str[:4].astype('int')
    if year_max:
        meta = meta[meta['year'] <= year_max]
    if year_min:
        meta = meta[meta['year'] >= year_min]

    #Remove egg- and cell-passaged strains
    meta = meta[meta['strain'].str[-4:] != '-egg']
    meta = meta[meta['strain'].str[-5:] != '-cell']

    #for location of sub-genic locus, change list-format to SeqFeature
    locus_location = SeqFeature(FeatureLocation(location[0], location[1]))

    #Read the alignment once, keeping (strain id, locus sequence) in file order
    aligned_records = []
    with open(alignment_file, "r") as aligned_handle:
        for isolate in SeqIO.parse(aligned_handle, "fasta"):
            aligned_records.append((isolate.id, locus_location.extract(isolate.seq)))

    #Limit meta data to only strains in alignment file
    aligned_isolates_df = pd.DataFrame([strain_id for strain_id, _ in aligned_records],
                                       columns=['strain'])
    meta = meta.merge(aligned_isolates_df, on='strain', how='inner')

    #Group viruses by time windows
    #Each window covers start <= year < end and is labelled 'start-end'
    #NOTE(review): because the end is exclusive and the loop stops once the end
    #passes the latest year, strains from the latest year fall in no window -
    #confirm this is intended
    virus_time_subset = {}
    last_year = meta['year'].max()
    date_window_start = meta['year'].min()
    date_window_end = date_window_start + window
    while date_window_end <= last_year:
        years = str(date_window_start) + '-' + str(date_window_end)
        in_window = (meta['year'] >= date_window_start) & (meta['year'] < date_window_end)
        virus_time_subset[years] = meta[in_window]['strain'].tolist()

        #sliding window
        date_window_end += 1
        date_window_start += 1

    #Only use time points with enough data:
    virus_time_subset = {k: v for k, v in virus_time_subset.items() if len(v) >= min_seqs}

    year_windows = list(virus_time_subset)
    seqs_in_window = [len(strains) for strains in virus_time_subset.values()]

    #Locus sequences of the strains in each window
    alignment_time_subset = {}
    for years, subset_viruses in virus_time_subset.items():
        strains_in_window = set(subset_viruses)
        alignment_time_subset[years] = [locus_seq for strain_id, locus_seq in aligned_records
                                        if strain_id in strains_in_window]

    #Find outgroup sequence from strains at first time point (to make consensus from)
    first_window_strains = set(next(iter(virus_time_subset.values()), []))
    first_window_sequences = [SeqRecord(seq=locus_seq, id=strain_id, description=gene)
                              for strain_id, locus_seq in aligned_records
                              if strain_id in first_window_strains]

    first_window_alignment = MultipleSeqAlignment(first_window_sequences)
    if virus == 'rsv':
        outgroup_seq = AlignInfo.SummaryInfo(first_window_alignment).gap_consensus(ambiguous='N')
    else:
        outgroup_seq = AlignInfo.SummaryInfo(first_window_alignment).dumb_consensus(ambiguous='N')

    has_dup = find_duplication(outgroup_seq)

    #if virus has duplication, want to run Bhatt on entire alignment excluding dup,
    #and then separately on the duplicated sequence to look at evolution occurring on top of it
    if has_dup:
        outgroup_seq, outgroup_seq_aa, alignment_time_subset = adjust_for_duplications(outgroup_seq, alignment_time_subset)
    else:
        outgroup_seq_aa = outgroup_seq.translate()

    return virus_time_subset, alignment_time_subset, outgroup_seq, outgroup_seq_aa, year_windows, seqs_in_window


In [111]:
def find_duplication(outgroup_seq):
    """
    Report whether the outgroup sequence contains a duplication (or any insertion).

    Such an event shows up in the outgroup sequence as a run of consecutive
    '-' placeholders. A run must be at least 3 codons (9 positions) long to count.
    Returns True if such a run exists, otherwise False.
    """
    gap_run = re.search("-{9,}", str(outgroup_seq))
    return gap_run is not None

In [115]:
def adjust_for_duplications(outgroup_seq, alignment_time_subset):
    """
    Find the position and length of the duplication.
    Remove the duplicated region from the outgroup sequence and from every sequence in the alignment.
    Evolution on the duplicated region will be considered separately because the outgroup consensus
    for this region needs to be made from the first timepoint where there are sequences with the duplication.

    The duplication is taken to be the first run of 9 or more '-' in the outgroup
    sequence; only that first run is removed. If the outgroup has no such run,
    nothing is removed and the sequences are returned unchanged (as new Seq objects).

    Returns (outgroup without dup, its translation, {window: [sequences without dup]}).
    """

    outgroup_seq_str = str(outgroup_seq)
    #find where the duplication is by locating ---s in the outgroup_seq
    dup_match = re.search(r'-{9,}', outgroup_seq_str)
    if dup_match:
        dup_start, dup_end = dup_match.start(), dup_match.end()
    else:
        #empty span: the slicing below then keeps the whole sequence
        dup_start = dup_end = 0

    outgroup_wo_dup = Seq(outgroup_seq_str[:dup_start] + outgroup_seq_str[dup_end:])
    outgroup_wo_dup_aa = outgroup_wo_dup.translate()

    # remove the duplicated portion from the main alignment
    alignment_time_subset_wo_dup = {}
    for dates, strain_seqs in alignment_time_subset.items():
        strain_strs = [str(strain_seq) for strain_seq in strain_seqs]
        alignment_time_subset_wo_dup[dates] = [Seq(s[:dup_start] + s[dup_end:]) for s in strain_strs]

    return outgroup_wo_dup, outgroup_wo_dup_aa, alignment_time_subset_wo_dup



In [28]:
def count_codon_fixations(alignment_sequences, outgroup_seq, midfreq_high):
    """
    Tally fixations (or near-fixations) per codon across all time windows.

    Windows are visited in the iteration order of `alignment_sequences`. In each
    window a codon counts as fixed when a mutation relative to the current
    reference reaches a frequency of at least midfreq_high. Once a codon fixes,
    the reference is updated at that position so later windows are compared
    against the newly fixed codon.

    Returns two arrays with one entry per codon of outgroup_seq:
    (nonsynonymous fixation counts, synonymous fixation counts).
    """
    #running reference, starts as the outgroup split into codons
    reference_codons = get_codons(outgroup_seq)
    n_sites = len(reference_codons)

    #per-codon totals over all windows
    nonsyn_total = np.zeros(n_sites)
    syn_total = np.zeros(n_sites)

    for window_seqs in alignment_sequences.values():
        window_nonsyn, window_syn, newly_fixed = walk_through_codons(reference_codons,
                                                                     window_seqs,
                                                                     midfreq_high)
        nonsyn_total += window_nonsyn
        syn_total += window_syn

        #carry the fixed codons forward into the reference
        for site, codon in newly_fixed.items():
            reference_codons[site] = codon

    return nonsyn_total, syn_total


In [7]:
def get_codons(seq):
    """
    Split the sequence up into a list of consecutive 3-long slices (codons).

    If the length is not divisible by 3 a warning is printed and the split
    still proceeds, leaving a shorter final element.
    """
    seq_length = len(seq)
    if seq_length % 3:
        print('Sequence not divisible by 3. Check specific location')

    return [seq[start:start + 3] for start in range(0, seq_length, 3)]

In [8]:
def find_mutation_fixations(pos, outgroup_codon, alignment_codons, midfreq_high):
    """
    Decide whether a mutation has fixed (or nearly fixed) at one codon position.

    Only codons made up entirely of A, G, C and T are considered. A codon that
    differs from outgroup_codon is called a fixation when it is the only codon
    observed, or when its frequency among the unambiguous codons is at least
    midfreq_high.

    Returns (fixation_type, fixed_codon): fixation_type is 'synonymous' or
    'nonsynonymous' and fixed_codon is the fixed codon; both are None when
    nothing fixed. `pos` is accepted but not used.
    """
    #only consider unambiguous sequencing
    nucleotides = set('AGCT')
    usable_codons = [cod for cod in alignment_codons if set(str(cod)) <= nucleotides]

    #tally the distinct codon sequences seen at this position
    tallies = Counter(usable_codons)
    fixed_codon = None

    if len(tallies) == 1:
        #a single codon in the whole window: fixed unless it matches the outgroup
        (only_codon,) = tallies
        if outgroup_codon == only_codon:
            fixed_codon = None
        else:
            fixed_codon = only_codon
    else:
        #several codons (or none): look for a non-outgroup codon at high frequency
        n_usable = len(usable_codons)
        for cod, count in tallies.items():
            if count / n_usable >= midfreq_high and cod != outgroup_codon:
                fixed_codon = cod

    if fixed_codon is None:
        return None, None

    #synonymous if the amino acid is unchanged, nonsynonymous otherwise
    outgroup_aa = Seq(outgroup_codon).translate()
    fixed_aa = Seq(fixed_codon).translate()
    fixation_type = 'synonymous' if outgroup_aa == fixed_aa else 'nonsynonymous'

    return fixation_type, fixed_codon

In [74]:
def walk_through_codons(outgroup_codons, alignment_seqs, midfreq_high):
    """
    Go through the alignment codon by codon and record, for each position,
    whether a synonymous fixation, a nonsynonymous fixation or neither occurred.

    Positions whose outgroup codon contains anything other than A, G, C, T are skipped.

    Returns (nonsynonymous counts, synonymous counts, fixed_codons), where the
    counts are arrays with one entry per outgroup codon and fixed_codons maps
    position -> codon sequence that fixed in this window.
    """
    n_sites = len(outgroup_codons)

    #regroup the alignment by position: [['ATG','ATG','ATG'],['TAG','TAG','TAG']]
    codons_by_site = [[] for _ in range(n_sites)]
    for isolate_seq in alignment_seqs:
        for site, codon in enumerate(get_codons(isolate_seq)):
            codons_by_site[site].append(str(codon))

    #per-site fixation counts for this window
    nonsyn_fixations_in_window = np.zeros(n_sites)
    syn_fixations_in_window = np.zeros(n_sites)

    #{pos: codon seq} of codons that fixed in this window, so the caller can update the outgroup
    fixed_codons = {}

    nucleotides = set('AGCT')
    for site, reference_codon in enumerate(outgroup_codons):
        #only consider unambiguous sequencing
        if not set(str(reference_codon)) <= nucleotides:
            continue

        #find fixations or near-fixations
        fixation_type, fixed_codon_seq = find_mutation_fixations(site, reference_codon,
                                                                 codons_by_site[site],
                                                                 midfreq_high)
        if fixed_codon_seq is not None:
            fixed_codons[site] = fixed_codon_seq

        if fixation_type == 'nonsynonymous':
            nonsyn_fixations_in_window[site] += 1
        elif fixation_type == 'synonymous':
            syn_fixations_in_window[site] += 1

    return nonsyn_fixations_in_window, syn_fixations_in_window, fixed_codons

In [125]:
#for location, look at the start and end numbers from the reference file and use [(start-1), end]
def main(virus, subtype, gene, location, coordinate_system, 
         window=3, min_seqs=3, year_max=False, year_min=False, midfreq_high=0.75):
    """
    Count the number of fixations at each codon and save this as csv files.

    Two files are written under adaptive_loci_results/fixations_per_site/results/
    (created if missing): one with nonsynonymous and one with synonymous fixation
    counts per codon. File names start with {virus}_{subtype}_{gene} when subtype
    is truthy, otherwise {virus}_{gene}.

    coordinate_system controls the codon numbers written to the files:
        'relative_to_atg_generef': 1-based codon index plus location[0]/3
        'relative_to_subunit': 1-based codon index within the locus

    Raises ValueError for any other coordinate_system (checked before any
    data is read, instead of silently writing empty files).
    """
    valid_coordinate_systems = ('relative_to_atg_generef', 'relative_to_subunit')
    if coordinate_system not in valid_coordinate_systems:
        raise ValueError(f'Unknown coordinate_system {coordinate_system!r}; '
                         f'expected one of {valid_coordinate_systems}')

    (virus_time_subset, alignment_time_subset, outgroup_seq, 
     outgroup_seq_aa, year_windows, seqs_in_window) = subset_alignment(virus, subtype, gene, 
                                                                       location, window, min_seqs, 
                                                                       year_max, year_min)

    nonsynonymous_fixations, synonymous_fixations = count_codon_fixations(alignment_time_subset, outgroup_seq, midfreq_high)

    #rsv has duplication- need to make sure coordinates are consistent with this
    #since the duplication was removed from the alignment, coordinates after the duplication need to be adjusted
    #(first affected 0-based codon index, number of codons to shift by)
    dup_first_codon, dup_codon_shift = None, 0
    if virus == 'rsv':
        if subtype == 'a':
            dup_first_codon, dup_codon_shift = 284, 24
        elif subtype == 'b':
            dup_first_codon, dup_codon_shift = 260, 20

    #turn arrays into dataframe with count of fixations at residue
    sites_with_nonsyn_fixation = []
    sites_with_syn_fixation = []

    for x in range(len(nonsynonymous_fixations)):
        codon = x
        if dup_first_codon is not None and x >= dup_first_codon:
            codon = x + dup_codon_shift

        #change coordinate system to make sense with coordinates in auspice and the pdb structure files
        if coordinate_system == 'relative_to_atg_generef':
            codon_number = int(codon+location[0]/3+1)
        else:
            codon_number = int(codon+1)

        sites_with_nonsyn_fixation.append({'codon': codon_number, 
                                           'nonsyn_fixations': int(nonsynonymous_fixations[x])})
        sites_with_syn_fixation.append({'codon': codon_number, 
                                        'syn_fixations': int(synonymous_fixations[x])})

    nonsyn_fixations_df = pd.DataFrame(sites_with_nonsyn_fixation)
    syn_fixations_df = pd.DataFrame(sites_with_syn_fixation)

    results_dir = 'adaptive_loci_results/fixations_per_site/results'
    os.makedirs(results_dir, exist_ok=True)

    if subtype:
        file_prefix = f'{virus}_{subtype}_{gene}'
    else:
        file_prefix = f'{virus}_{gene}'

    nonsyn_fixations_df.to_csv(f'{results_dir}/{file_prefix}_nonsyn_fixations.csv', index=False)
    syn_fixations_df.to_csv(f'{results_dir}/{file_prefix}_syn_fixations.csv', index=False)
    
            

In [15]:
#how to best deal with ambiguous sequencing? 
#Noticed that some mutations called by auspice (like 229E 89) do not show up here


In [175]:
#229e spike, default window settings
main(
    virus='229e',
    subtype=None,
    gene='spike',
    location=[45, 3522],
    coordinate_system='relative_to_atg_generef',
)

In [130]:
#oc43 lineage a spike, 5-year windows
main(
    virus='oc43',
    subtype='a',
    gene='spike',
    location=[39, 4086],
    coordinate_system='relative_to_atg_generef',
    window=5,
)

In [129]:
#nl63 spike, 5-year windows
main(
    virus='nl63',
    subtype=None,
    gene='spike',
    location=[45, 4071],
    coordinate_system='relative_to_atg_generef',
    window=5,
)

In [177]:
#h3n2 ha, default window settings
main(
    virus='h3n2',
    subtype=None,
    gene='ha',
    location=[48, 1698],
    coordinate_system='relative_to_subunit',
)

In [75]:
#h1n1pdm ha, default window settings
main(
    virus='h1n1pdm',
    subtype=None,
    gene='ha',
    location=[71, 1718],
    coordinate_system='relative_to_subunit',
)

In [69]:
#vic ha, default window settings
main(
    virus='vic',
    subtype=None,
    gene='ha',
    location=[56, 1769],
    coordinate_system='relative_to_subunit',
)

In [77]:
#yam ha, only strains sampled from 1990 onwards
main(
    virus='yam',
    subtype=None,
    gene='ha',
    location=[56, 1766],
    coordinate_system='relative_to_subunit',
    year_min=1990,
)

In [126]:
#rsv subtype a, g gene
main(
    virus='rsv',
    subtype='a',
    gene='g',
    location=[4680, 5646],
    coordinate_system='relative_to_subunit',
)

In [127]:
#rsv subtype b, g gene
main(
    virus='rsv',
    subtype='b',
    gene='g',
    location=[4689, 5649],
    coordinate_system='relative_to_subunit',
)

In [131]:
#measles h gene
main(
    virus='measles',
    subtype=None,
    gene='h',
    location=[7270, 9124],
    coordinate_system='relative_to_subunit',
)

In [132]:
#mumps hn gene
main(
    virus='mumps',
    subtype=None,
    gene='hn',
    location=[6550, 8299],
    coordinate_system='relative_to_subunit',
)

In [142]:
#dengue denv2_AA, e gene
main(
    virus='dengue',
    subtype='denv2_AA',
    gene='e',
    location=[936, 2421],
    coordinate_system='relative_to_subunit',
)