In [351]:
from augur.utils import json_to_tree
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Bio import SeqIO
from collections import Counter
from sklearn import linear_model
from sklearn.model_selection import ShuffleSplit
from sklearn.metrics import mean_squared_error
import requests
import re


Get global ncov tree and convert to Bio Phylo format

In [352]:
tree_url = "https://data.nextstrain.org/ncov_global.json"
tree_json = requests.get(tree_url).json()
tree = json_to_tree(tree_json)

Download entropy manually from nextstrain.org/ncov/global. Find the 0.5% of sites with highest entropy throughout genome

In [312]:
entropy_file = 'nextstrain_ncov_global_diversity.tsv'
entropy_df = pd.read_csv(entropy_file, sep='\t')
# genome is 29902 nt long, 
# which would be roughly 9967 codons (but this isn't quite right since there are non-coding regions)
# estimate about 9950 codons -> top 1% would 99.5 codons... so take 100 highest entropy values
top_entropy_df = entropy_df.nlargest(100,'entropy')

Read in sequence file, that will be used to find the genotype of all members of a clade

aws s3 cp s3://nextstrain-ncov-private/global_subsampled_sequences.fasta.xz .

In [353]:
sequences_file = 'global_subsampled_sequences.fasta'
record_dict = SeqIO.to_dict(SeqIO.parse(sequences_file, "fasta"))

Make dataframe where each row is a clade and columns contain information about logistic growth rate, how many S1 mutations from root to clade, and percentage of tips in clade that are genotype X

First, need to find genotype of isolates at desired positions (of highest entropy). Do this for all tips and store in a dictionary

In [354]:
# read in reference file and find genome position for each codon in the top entropy sites

# find the start position of the codon encoding each entropic site
genome_location_of_entropic_sites = {}

for k,v in top_entropy_df.iterrows():

    for record in SeqIO.parse(open("reference_seq_edited.gb","r"), "genbank"):
        for feature in record.features:
            if feature.type == 'CDS':
                if feature.qualifiers['gene'][0] == v['gene']:
                    mut_location_start = feature.location.start + int(v['position'] -1)*3
#                     print((record.seq[mut_location_start:mut_location_start+3]).translate())
                    genome_location_of_entropic_sites[mut_location_start] = f"{v['gene']}_{v['position']}"

                    

In [355]:
# initialize dictionary to store genotype info
tip_genotypes = {}

for k,v in record_dict.items():
    this_tip = {}
    #store genotype at each entropic site
    for nt_start, gene_codon in genome_location_of_entropic_sites.items():
        nt_codon = v.seq[nt_start:nt_start+3]
        if all([characters in ['A', 'C', 'G', 'T'] for characters in nt_codon]):
            this_tip[gene_codon] = str(nt_codon.translate())
        elif all([characters in ['-'] for characters in nt_codon]):
            this_tip[gene_codon] = '-'
        else:
            this_tip[gene_codon] = None

    tip_genotypes[k] = this_tip

In [356]:
#initiate list to store all clade info
clade_stats = []

#Only want to look at clades, don't care about tips
for node in tree.find_clades(terminal=False):
    
    # only care about clade if logistic growth value is associated
    # find stored logistic growth value
    if "logistic_growth" in node.node_attrs:
        logistic_growth = node.node_attrs["logistic_growth"]["value"]
        
        #all S1 muts from root to clade already stored as json value
        if "S1_mutations" in node.node_attrs:
            s1_mutations = node.node_attrs["S1_mutations"]["value"]
        else:
            s1_mutations = None

        # find all tips in this clade
        tips_in_clade = node.get_terminals()
        tip_names_in_clade = [tip.name for tip in tips_in_clade]
        
        # tally the genotype of each tip within the clade
        clade_genotype_tally = {}
        for tip in tip_names_in_clade:
            for site, aa in tip_genotypes[tip].items():
                # ignore the isolates that were not sequenced
                if aa != None:
                    if site in clade_genotype_tally.keys():
                        clade_genotype_tally[site]+=[aa]
                    else:
                        clade_genotype_tally[site] = [aa]
        
        # find proportion of tips that have given genotype
        clade_genotype_freqs = {}
        for k, v in clade_genotype_tally.items():
            counts = Counter(v)
            for item, count in counts.items():
                site_genotype = k+item
                clade_genotype_freqs[site_genotype] = count/len(v)

        
        clade_stats.append({'clade': node.name, 'logistic_growth': logistic_growth, 
                            'num_s1_muts': s1_mutations, **clade_genotype_freqs})


clade_stats_df = pd.DataFrame(clade_stats).fillna(0)
                

    

    

    

Filter to only mutations that occur at least 4 times on the tree. And don't allow reversion mutations

In [357]:
# find root sequence, in order to exclude reversions

reference_sequence = {}

for record in SeqIO.parse(open("reference_seq_edited.gb","r"), "genbank"):
    for feature in record.features:
        if feature.type == 'CDS':
            gene_seq = feature.location.extract(record.seq).translate()
            reference_sequence[feature.qualifiers['gene'][0]] = gene_seq




In [358]:
all_mutations = []

all_positions = []


#only look at mutations on internal branches
for node in tree.find_clades(terminal=False):
    
    num_descendents = len(node.get_terminals())

    # mutations must be propagated to at least a small clade
    if num_descendents >=10:
        # look just at nonsyn muts
        if hasattr(node, 'branch_attrs'):
            for gene, mut_list in node.branch_attrs["mutations"].items():
                if gene!= 'nuc':
                    for mut in mut_list:
                        gene_mutation = f'{gene}_{mut[1:]}'
                        gene_position = f'{gene}_{mut[1:-1]}'
                        all_mutations.append(gene_mutation)
                        all_positions.append(gene_position)



independent_occurences_all_mutations = Counter(all_mutations)
independent_occurences_all_positions = Counter(all_positions)

In [337]:
# filter the dataframe to only include mutations that are not reversions, and occur at least 4 times over tree
clade_stats_filtered_df = clade_stats_df.copy()

to_drop = []
for m, count in independent_occurences_all_mutations.items():

    
    if m in clade_stats_filtered_df.columns:
        # drop mutations that appear less than 4 independent times
        if count<4:
            to_drop.append(m)
            
for m in clade_stats_filtered_df.columns[3:]:
    # find whether mutation is a reversion, and drop it if so
    gene = m.split('_')[0]
    mut = m.split('_')[1]

    if reference_sequence[gene][int(mut[:-1])-1] == mut[-1]:
        to_drop.append(m)


clade_stats_filtered_df.drop(to_drop, inplace=True, axis=1)


Want to find muts that drive/predict logisitic growth. Intuitively, the prevalence of these muts should be enriched in clades with high logistic growth compared to clades with low logistic growth

In [320]:
def log_growth_enrichment(threshold, df):
    # get rid of first 3 columns (clade, logistic_growth, num_s1_muts)
    all_mut_keys = list(df.keys())[3:]

    avg_prevalence_high_growth = df[df['logistic_growth']>=threshold][all_mut_keys].mean()
    avg_prevalence_low_growth = df[df['logistic_growth']<threshold][all_mut_keys].mean()
    
    avg_prevalence_overall = df[all_mut_keys].mean()
    

    print(avg_prevalence_high_growth.subtract(avg_prevalence_low_growth).nlargest(20))
    
#     print((avg_prevalence_high_growth/avg_prevalence_overall).nlargest(20))


In [247]:
log_growth_enrichment(3.0, clade_stats_df)

ORF8_73C       0.690065
ORF1a_1708D    0.687376
S_1118H        0.687325
S_570D         0.686958
S_982A         0.686948
ORF1a_2230T    0.686314
N_235F         0.685546
ORF8_27*       0.684941
ORF8_52I       0.684897
ORF1a_1001I    0.684401
S_501Y         0.678876
N_3L           0.677292
S_716I         0.669107
S_144-         0.660464
S_69-          0.636046
S_681H         0.615012
ORF1a_3675-    0.611973
ORF1a_3676-    0.606666
ORF1a_3677-    0.604868
N_203K         0.554096
dtype: float64


After 501Y and nsp6 deletion, many of the hits above are just in 501Y.v1 clade. Because that clade is so large it skews the average. Look at the filtered df, which requires that the mutation has occurred at least 4 on the phylogeny, spreading to at least 10 descendents each time

In [338]:
log_growth_enrichment(3.0, clade_stats_filtered_df)

S_501Y         0.314463
ORF1a_3675-    0.290252
ORF1a_3676-    0.290138
ORF1a_3677-    0.289553
S_681H         0.276434
S_484K         0.042783
M_82T          0.041230
S_452R         0.037842
S_18F          0.028800
S_95I          0.027188
ORF8_68-       0.010788
ORF8_24-       0.010748
ORF8_27-       0.010748
ORF8_52-       0.010748
ORF8_84-       0.010710
ORF8_92-       0.010710
ORF8_73-       0.010710
S_138H         0.006039
S_144F         0.004523
S_477I         0.004105
dtype: float64


In [339]:
log_growth_enrichment(1.0, clade_stats_filtered_df)

S_501Y         0.459344
S_681H         0.437056
ORF1a_3677-    0.414989
ORF1a_3675-    0.414634
ORF1a_3676-    0.414170
M_82T          0.007173
ORF8_68-       0.006133
ORF8_52-       0.006110
ORF8_27-       0.006110
ORF8_24-       0.006110
ORF8_73-       0.006089
ORF8_92-       0.006089
ORF8_84-       0.006089
S_138H         0.005783
ORF1a_3676S    0.005642
ORF1a_3677L    0.004800
ORF1a_3675T    0.004634
S_95I          0.004116
N_3Q           0.003649
ORF9b_32H      0.002812
dtype: float64


Want to fit a multiple linear regression to see which mutations best predict logisitic growth

In [340]:
# get all possible explanatory variables
all_keys = list(clade_stats_filtered_df.keys())[2:]


X = clade_stats_filtered_df[all_keys]
# X = clade_stats_df[['num_s1_muts']]
# X = clade_stats_df[['S_501Y']]
y = clade_stats_filtered_df['logistic_growth']

# make a linear regression object using sklearn
regr = linear_model.LinearRegression()
regr.fit(X, y)

coefficients = list(regr.coef_)
ranked_coefficients_by_index = sorted(range(len(coefficients)), key=lambda x: coefficients[x])
top_ranked_idicies = reversed(ranked_coefficients_by_index[-20:])
for index in top_ranked_idicies:
    variable = all_keys[index]
    coefficient = coefficients[index]
    print(variable, coefficient)

ORF8_92- 2386889.7506320127
ORF8_27- 61597.79208677484
ORF8_24- 34653.172970179105
ORF8_52- 27754.962856982314
N_204- 25547.223332664355
N_203- 25547.223317929474
ORF1a_1640H 6947.862432278775
N_3G 2916.68335020463
N_3N 1423.3744379028765
N_202I 703.3072557435502
ORF1b_218S 560.4135788484829
N_203S 509.31792933764126
S_732S 455.68478078565056
S_716V 227.8339048337713
M_82V 212.26524559558314
S_215N 158.6357238632068
S_142V 148.37199179189588
ORF8_52K 135.73354802351815
S_26L 107.14246619430516
N_3E 91.13888667173268


In [341]:
all_keys = list(clade_stats_filtered_df.keys())[2:]

X = clade_stats_filtered_df[all_keys]
y = clade_stats_filtered_df['logistic_growth']

lasso = linear_model.LassoCV().fit(X, y)
importance = np.abs(lasso.coef_)
feature_names = np.array(all_keys)
# plt.bar(height=importance, x=feature_names)

ranked_importance_by_index = sorted(range(len(importance)), key=lambda x: importance[x])
top_ranked_idicies = reversed(ranked_importance_by_index[-20:])
for index in top_ranked_idicies:
    variable = all_keys[index]
    i = importance[index]
    print(variable, i)

S_681H 2.5001424292532115
num_s1_muts 1.223358381816525
S_18F 0.5535908780562372
ORF9b_32H 0.0
S_20A 0.0
ORF1a_1640H 0.0
N_2F 0.0
S_138V 0.0
S_138N 0.0
S_138H 0.0
S_138- 0.0
ORF8_92- 0.0
ORF9b_10H 0.0
S_477R 0.0
S_142V 0.0
N_377E 0.0
ORF1a_959T 0.0
S_732S 0.0
S_677L 0.0
ORF8_84- 0.0


In [116]:
all_keys = list(clade_stats_df.keys())[2:]

X = clade_stats_df[all_keys]
y = clade_stats_df['logistic_growth']

lin_reg = linear_model.LinearRegression()

# number of features to select
sel_num = 20
selected_feature = []

# Split into 5 training and test sets
rs = ShuffleSplit(n_splits=5, test_size=.25, random_state=0)

for i in all_keys:

    for train_index, test_index in rs.split(X):
        X_train, X_valid = X[i][train_index], X[i][test_index]
        y_train, y_valid = y[train_index], y[test_index]

        
        model_train = lin_reg.fit(np.array(X_train).reshape(-1, 1), np.array(y_train).reshape(-1, 1))
        predictions_train = model_train.predict(np.array(X_valid).reshape(-1, 1))
        print(mean_squared_error(np.array(y_valid).reshape(-1, 1), predictions_train))



ValueError: Found input variables with inconsistent numbers of samples: [1566, 522]

In [26]:
clade_stats_df[clade_stats_df['logistic_growth']>=4.0]


Unnamed: 0,clade,logistic_growth,num_s1_muts,N_204G,N_203R,N_203K,S_681P,S_681H,S_681R,S_501N,...,S_570V,ORF1a_2230T,S_484G,N_205-,S_417T,N_199-,N_199Q,ORF1b_2613S,S_701T,S_452Q
1052,NODE_0000055,4.095786,1,0.007332,0.006180,0.992135,0.565900,0.424565,0.008974,0.598519,...,0.000552,0.361738,0.00057,0.00056,0.023889,0.00056,0.00056,0.001147,0.000562,0.011824
1053,NODE_0000056,4.095786,1,0.006772,0.006183,0.992130,0.565657,0.424804,0.008979,0.598291,...,0.000553,0.361942,0.00057,0.00056,0.023902,0.00056,0.00056,0.001148,0.000562,0.011831
1054,NODE_0001698,4.095786,2,0.076923,0.076923,0.923077,1.000000,0.000000,0.000000,1.000000,...,0.000000,0.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000
1055,NODE_0002319,4.095786,2,0.090909,0.090909,0.909091,1.000000,0.000000,0.000000,1.000000,...,0.000000,0.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000
1056,NODE_0002322,4.095786,2,0.000000,0.000000,1.000000,1.000000,0.000000,0.000000,1.000000,...,0.000000,0.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2083,NODE_0004144,4.111753,6,0.000000,0.000000,1.000000,0.000000,1.000000,0.000000,0.000000,...,0.000000,1.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000
2084,NODE_0004145,4.111753,6,0.000000,0.000000,1.000000,0.000000,1.000000,0.000000,0.000000,...,0.000000,1.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000
2085,NODE_0000645,4.111753,6,0.000000,0.000000,1.000000,0.000000,1.000000,0.000000,0.000000,...,0.000000,1.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000
2086,NODE_0004147,4.111753,6,0.000000,0.000000,1.000000,0.000000,1.000000,0.000000,0.000000,...,0.000000,1.000000,0.00000,0.00000,0.000000,0.00000,0.00000,0.000000,0.000000,0.000000


In [23]:
#add in nsp6 deletion! this is ORF1a	3675
print(clade_stats_df[clade_stats_df['ORF1a_3675-'].notnull()]['ORF1a_3675-'])



105     0.134182
106     0.846154
107     0.916667
109     1.000000
110     1.000000
          ...   
2083    1.000000
2084    1.000000
2085    1.000000
2086    1.000000
2087    1.000000
Name: ORF1a_3675-, Length: 782, dtype: float64


In [3]:
#node.__dict__

First, make dataframe containing information about number of S1 mutations from root to clade and all mutations that occurred between root and clade. Each row is a clade

In [4]:
#initiate list to store all clade history info
clades_history = []

#Function to find path from root to clade
def get_parent(tree, child_clade):
    node_path = tree.get_path(child_clade)
    return node_path


#Only want to look at clades, don't care about tips
for node in tree.find_clades(terminal=False):
    
    #find all mutations that occurred on path from root to clade
    root_to_clade_mutations = []
    for parent in get_parent(tree, node):
        if hasattr(parent, 'branch_attrs'):
            root_to_clade_mutations.append(parent.branch_attrs['mutations'])
    
    #flatten root_to_clade_mutations, making dict entry for nucleotide muts and amino acid subs in each gene
    possible_mutation_site = ['ORF1a', 'ORF1b', 'S', 'ORF3a', 'E', 'M', 'ORF6', 
                              'ORF7a', 'ORF7b', 'ORF8', 'ORF9b', 'N', 'nuc']
    mutations_on_path = {k:[] for k in possible_mutation_site}
    
    for parent_muts in root_to_clade_mutations:
        for k,v in parent_muts.items():
            mutations_on_path[k]+=v
    
    
    #all S1 muts from root to clade already stored as json value
    if "S1_mutations" in node.node_attrs:
        s1_mutations = node.node_attrs["S1_mutations"]["value"]
    else:
        s1_mutations = None
    
    #find stored logistic growth value
    if "logistic_growth" in node.node_attrs:
        logistic_growth = node.node_attrs["logistic_growth"]["value"]
    else:
        logistic_growth = None
    
    #copy dictionary and add key/values for clade name and number of s1 mutations and logistic growth
    clade_mutation_history = mutations_on_path
    clade_mutation_history['clade'] = node.name
    clade_mutation_history['num_s1_mutations'] = s1_mutations
    clade_mutation_history['logistic_growth'] = logistic_growth
    

    
    clades_history.append(clade_mutation_history)
    


#turn list of clade history info into a dataframe
clades_df = pd.DataFrame(clades_history)

In [5]:
clades_df[pd.DataFrame(clades_df.nuc.tolist()).isin(['T11288-']).any(1).values]

Unnamed: 0,ORF1a,ORF1b,S,ORF3a,E,M,ORF6,ORF7a,ORF7b,ORF8,ORF9b,N,nuc,clade,num_s1_mutations,logistic_growth
292,"[S3675-, G3676-, F3677-]",[P314L],"[D614G, E484K]",[],[],[I82T],[],[],[],[],[],[T205I],"[C3037T, C14408T, C241T, A23403G, A21993-, T21...",NODE_0001779,2.0,6.511755
293,"[S3675-, G3676-, F3677-, A2123V, E2607K, M3752I]",[P314L],"[D614G, E484K, I210T, D936N, S939F, T1027I]",[],[],[I82T],[],[E22D],[],[],[P10S],"[T205I, P13L, S201I]","[C3037T, C14408T, C241T, A23403G, A21993-, T21...",NODE_0001780,3.0,6.511755
294,"[S3675-, G3676-, F3677-, A2123V, E2607K, M3752I]",[P314L],"[D614G, E484K, I210T, D936N, S939F, T1027I, N4...",[],[],[I82T],[],[E22D],[],[],[P10S],"[T205I, P13L, S201I]","[C3037T, C14408T, C241T, A23403G, A21993-, T21...",NODE_0001781,4.0,6.511755
295,"[S3675-, G3676-, F3677-, T2007I]",[P314L],"[D614G, E484K, A67V, H69-, V70-, Y144-, Q677H]",[],[L21F],[I82T],[F2-],[],[],[],[H9D],"[T205I, S2M, D3Y, A12G]","[C3037T, C14408T, C241T, A23403G, A21993-, T21...",NODE_0001782,6.0,6.511755
296,"[S3675-, G3676-, F3677-, T2007I]","[P314L, L314F]","[D614G, E484K, A67V, H69-, V70-, Y144-, Q677H,...",[],[L21F],[I82T],[F2-],[],[],[],[H9D],"[T205I, S2M, D3Y, A12G]","[C3037T, C14408T, C241T, A23403G, A21993-, T21...",NODE_0001783,7.0,6.511755
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3275,"[I2230T, S3675-, G3676-, F3677-, T1001I, A1708D]","[P314L, K1383R, P1001S]","[D614G, N501Y, H69-, V70-, Y144-, S982A, A570D...",[],[],[],[],[],[],"[Q27*, R52I, Y73C, K68*, *68K]",[],"[R203K, G204R, S235F, D3L]","[C3037T, C14408T, C241T, A23403G, G28881A, G28...",NODE_0004177,6.0,5.519968
3276,"[I2230T, S3675-, G3676-, F3677-, T1001I, A1708D]","[P314L, K1383R, P1001S]","[D614G, N501Y, H69-, V70-, Y144-, S982A, A570D...",[],[],[],[],[],[],"[Q27*, R52I, Y73C, K68*, *68K]",[],"[R203K, G204R, S235F, D3L]","[C3037T, C14408T, C241T, A23403G, G28881A, G28...",NODE_0004178,6.0,5.519968
3277,"[I2230T, S3675-, G3676-, F3677-, T1001I, A1708D]","[P314L, K1383R, P1001S]","[D614G, N501Y, H69-, V70-, Y144-, S982A, A570D...",[],[],[],[],[],[],"[Q27*, R52I, Y73C, K68*, *68K]",[],"[R203K, G204R, S235F, D3L]","[C3037T, C14408T, C241T, A23403G, G28881A, G28...",NODE_0004179,6.0,
3278,"[I2230T, S3675-, G3676-, F3677-, T1001I, A1708D]","[P314L, K1383R, P1001S, I1181S]","[D614G, N501Y, H69-, V70-, Y144-, S982A, A570D...",[],[],[],[],[],[],"[Q27*, R52I, Y73C, K68*, *68K]",[],"[R203K, G204R, S235F, D3L]","[C3037T, C14408T, C241T, A23403G, G28881A, G28...",NODE_0000816,6.0,5.519968
