In [10]:
from augur.utils import json_to_tree
import pandas as pd
import json
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.lines import Line2D  
import numpy as np
from sklearn.linear_model import LinearRegression

### Linear model to predict growth rate from the number of mutations in each gene

In [2]:
# Genes whose mutation counts are tracked in this notebook.
# The order of this list fixes the order of the per-gene values in every
# feature vector built from it.
all_genes = [
    # Nsp1 through Nsp16; there is no Nsp11 entry
    *(f'Nsp{n}' for n in range(1, 11)),
    *(f'Nsp{n}' for n in range(12, 17)),
    # S is listed together with S1 and S2
    'S', 'S1', 'S2',
    # remaining genes
    'ORF3a', 'E', 'M', 'ORF6', 'ORF7a', 'ORF7b', 'ORF8', 'N', 'ORF9b', 'ORF10',
]

In [3]:
def readin_tree(date, tree_type):
    """
    Read in the 2m tree json for the specified date and build type.

    date: date string that appears in the tree file name (e.g. '2023-05')
    tree_type: which build to load, either 'sars2' or '21L'

    Returns the tree object produced by `json_to_tree`.
    Raises ValueError if `tree_type` is not one of the supported builds.
    """

    # path to tree json
    if tree_type == 'sars2':
        tree_file = f'trees_w_mut_counts/sars2_{date}_2m.json'
    elif tree_type == '21L':
        tree_file = f'trees_w_mut_counts/sars2_21L_{date}_2m.json'
    else:
        # fail with a clear message instead of an UnboundLocalError on tree_file
        raise ValueError(
            f"unknown tree_type {tree_type!r}; expected 'sars2' or '21L'"
        )

    with open(tree_file, 'r') as f:
        tree_json = json.load(f)

    # put tree in Bio.phylo format
    tree = json_to_tree(tree_json)

    return tree

In [68]:
def _center_columns(rows):
    """
    Subtract each column's mean from every entry of that column.

    `rows` is a list of equal-length lists (one sublist per node).
    Returns a new list of lists with the same layout; an empty input
    gives an empty list.
    """
    if not rows:
        return []

    centered_columns = []
    for column in zip(*rows):
        column_avg = sum(column)/len(column)
        centered_columns.append([value - column_avg for value in column])

    # transpose back so that each sublist again corresponds to one node
    return [list(row) for row in zip(*centered_columns)]


def get_mut_and_growth_info(date, tree_type):
    """
    Get logistic growth and number of mutations for each node in the tree.

    Only nodes that have a 'logistic_growth' entry in `node_attrs` are used.

    Returns a tuple (df, X_list, X_relative_list, Y_list):

    df: dataframe where each row has the date, node name, logistic growth rate
        and the per-gene nonsynonymous/synonymous/stop mutation counts
        (columns `<gene>_nonsyn`, `<gene>_syn`, `<gene>_stop`) for a given node.
    X_list: X's for multiple linear regression. A list of lists where each
        sublist is the nonsynonymous mutation count per gene for one node.
        Order within a sublist is given by the list `all_genes`.
    X_relative_list: same layout as X_list, but each count has the average
        count for that gene (over the nodes of this tree) subtracted.
    Y_list: logistic growth rate at each node, in the same node order as the X's.

    If no node in the tree has a logistic growth rate, the dataframe and all
    three lists are empty.
    """
    tree = readin_tree(date, tree_type)

    mut_and_growth_info = []

    # Xs will be nonsynonymous mutations at each gene, for each node
    X_list = []
    # Ys will be logistic growth rate, for each node
    Y_list = []

    for node in tree.find_clades():
        # only for nodes that have an assigned logistic growth rate (those that are in the past 6 weeks)
        if 'logistic_growth' not in node.node_attrs:
            continue
        logistic_growth = node.node_attrs['logistic_growth']['value']

        # per-gene mutation counts stored on this node, keyed by gene name
        nonsyn_by_gene = node.mut_accumulation['Nonsyn_muts']
        syn_by_gene = node.mut_accumulation['Syn_muts']
        stop_by_gene = node.mut_accumulation['Stop_muts']

        # list of mutation count per gene at this node. Ordered by `all_genes`
        nonsyn_mut_list = [nonsyn_by_gene[gene] for gene in all_genes]
        nonsyn_muts = {f'{gene}_nonsyn': nonsyn_by_gene[gene] for gene in all_genes}
        syn_muts = {f'{gene}_syn': syn_by_gene[gene] for gene in all_genes}
        stop_muts = {f'{gene}_stop': stop_by_gene[gene] for gene in all_genes}

        X_list.append(nonsyn_mut_list)
        Y_list.append(logistic_growth)

        # NOTE(review): the 'gene' column has always held the last entry of
        # `all_genes` (it was filled from a loop variable after the loop ended),
        # so it carries no per-row information. The value is kept unchanged here
        # so the dataframe schema stays the same -- consider dropping the column.
        mut_and_growth_info.append({'date': date, 'gene': all_genes[-1], 'node': node.name,
                                    'logistic_growth': logistic_growth,
                                    **nonsyn_muts, **syn_muts, **stop_muts})

    df = pd.DataFrame(mut_and_growth_info)

    # make mutation counts "relative"
    # meaning, for each timepoint, give the mutation count as (absolute_count - avg_count_in_gene)
    X_relative_list = _center_columns(X_list)

    return df, X_list, X_relative_list, Y_list

In [69]:
# get data for all genes, all timepoints
all_dates = ['2020-03', '2020-05', '2020-07', '2020-09', '2020-11', 
             '2021-01', '2021-03', '2021-05', '2021-07', '2021-09', '2021-11', 
             '2022-01', '2022-03', '2022-05', '2022-07', '2022-09', '2022-11', 
             '2023-01', '2023-03', '2023-05']

# combine Xs and Ys for all timepoints
sars2_X = []
sars2_relative_X = []
sars2_Y = []
# per-date dataframes are collected here and concatenated once after the loop,
# rather than re-copying the growing dataframe on every iteration
sars2_frames = []

for d in all_dates:
    df_date, X_list_date, X_relative_list_date, Y_list_date = get_mut_and_growth_info(d, 'sars2')
    sars2_frames.append(df_date)
    sars2_X.extend(X_list_date)
    sars2_relative_X.extend(X_relative_list_date)
    sars2_Y.extend(Y_list_date)

# pd.concat raises on an empty list, so fall back to an empty dataframe
sars2_df = pd.concat(sars2_frames) if sars2_frames else pd.DataFrame()

In [94]:
# Bundle the regression inputs (for Trevor) so they can be shared as one JSON file.
# `gene_order` records which gene each position of the per-node count lists refers to.
data_to_save = dict(
    gene_order=all_genes,
    relative_mut_counts=sars2_relative_X,
    mut_counts=sars2_X,
    logistic_growth_rates=sars2_Y,
)

# serialize with a 2-space indent
json_object = json.dumps(data_to_save, indent=2)

# write the serialized text out in one go
with open("allSARS2_muts_v_growth.json", "w") as outfile:
    outfile.write(json_object)

In [70]:
# get data for all genes, all timepoints for 21L-only builds
all_dates_21L = ['2022-03', '2022-04', '2022-05', '2022-06', '2022-07', 
             '2022-08', '2022-09', '2022-10', '2022-11', 
             '2023-01', '2023-02', '2023-03', '2023-04', '2023-05', '2023-06']

# combine Xs and Ys for all timepoints
sars221L_X = []
sars221L_relative_X = []
sars221L_Y = []
# per-date dataframes are collected here and concatenated once after the loop,
# rather than re-copying the growing dataframe on every iteration
sars221L_frames = []

for d in all_dates_21L:
    df_date, X_list_date, X_relative_list_date, Y_list_date = get_mut_and_growth_info(d, '21L')
    sars221L_frames.append(df_date)
    sars221L_X.extend(X_list_date)
    sars221L_relative_X.extend(X_relative_list_date)
    sars221L_Y.extend(Y_list_date)

# pd.concat raises on an empty list, so fall back to an empty dataframe
sars221L_df = pd.concat(sars221L_frames) if sars221L_frames else pd.DataFrame()

In [80]:
def linear_model(X_list, Y_list):
    """
    Fit a linear regression predicting Y (logistic clade growth) from a set of
    X's (mutations in each gene), using every row passed in (all dates combined).

    X_list: list of per-node lists of mutation counts, ordered like `all_genes`
    Y_list: list of logistic growth rates, one per node

    Prints the coefficient of determination followed by one coefficient per
    gene in `all_genes`. Nothing is returned.
    """

    # the regression expects array inputs
    features = np.array(X_list)
    target = np.array(Y_list)

    fitted = LinearRegression()
    fitted.fit(features, target)

    print(f"coefficient of determination: {fitted.score(features, target)}\n")

    # coefficients come back in column order, i.e. the order of `all_genes`
    for position, gene_name in enumerate(all_genes):
        print(f'{gene_name} coefficient: {fitted.coef_[position]}')
    
    

In [81]:
# run model on all dates pooled together, using relative mutation counts
# (each count minus that gene's average, computed separately for each date)
linear_model(sars2_relative_X, sars2_Y)

coefficient of determination: 0.3386739245559056

Nsp1 coefficient: 2.285234139028699
Nsp2 coefficient: -0.030412781247482368
Nsp3 coefficient: -0.33586862432714115
Nsp4 coefficient: 0.8212857158345233
Nsp5 coefficient: -0.6106388827318456
Nsp6 coefficient: -0.7870579832935483
Nsp7 coefficient: 1.1788284202224413
Nsp8 coefficient: -0.9321742530151544
Nsp9 coefficient: 0.0742433925190154
Nsp10 coefficient: -0.6934564422508988
Nsp12 coefficient: 0.75263748560281
Nsp13 coefficient: 0.5255435897590162
Nsp14 coefficient: -0.05975816548452148
Nsp15 coefficient: 0.9209301965395342
Nsp16 coefficient: 0.16905972478643289
S coefficient: -0.3826027410028662
S1 coefficient: 1.545801551692052
S2 coefficient: 0.08522783388095792
ORF3a coefficient: -0.3374398679776279
E coefficient: 1.3677258206554226
M coefficient: 1.933970533793066
ORF6 coefficient: -0.9913685501397023
ORF7a coefficient: 0.5196287806972777
ORF7b coefficient: 1.2066538006483005
ORF8 coefficient: 0.18989630874766172
N coefficient: 0.

In [82]:
# run model on all dates pooled together, using absolute (uncentered) mutation counts
linear_model(sars2_X, sars2_Y)

coefficient of determination: 0.03518091300566972

Nsp1 coefficient: 0.6612223926192435
Nsp2 coefficient: -0.1506792311623942
Nsp3 coefficient: -0.1896739602047245
Nsp4 coefficient: -0.19950757233343025
Nsp5 coefficient: -0.5370080412712569
Nsp6 coefficient: -0.4062194476064762
Nsp7 coefficient: 1.0073163331905304
Nsp8 coefficient: -0.9306738653318737
Nsp9 coefficient: -0.29953552871903755
Nsp10 coefficient: -0.37655783774183943
Nsp12 coefficient: 0.0615228628508013
Nsp13 coefficient: 0.1228746569592328
Nsp14 coefficient: -0.1366762087285849
Nsp15 coefficient: -0.20128718209757643
Nsp16 coefficient: 0.138678985117319
S coefficient: -0.5096602651870573
S1 coefficient: 0.5352329046241211
S2 coefficient: 0.6580600427259035
ORF3a coefficient: -0.4749512632193728
E coefficient: 1.619349699275942
M coefficient: -0.24880385242946232
ORF6 coefficient: -0.13745313954998062
ORF7a coefficient: 0.4821984178879251
ORF7b coefficient: 0.48808097081436413
ORF8 coefficient: -0.06639313899234497
N coeff

In [92]:
# run model on all 21L-only build dates pooled together, using relative mutation counts
linear_model(sars221L_relative_X, sars221L_Y)

coefficient of determination: 0.3637581916427015

Nsp1 coefficient: 0.43798124029717245
Nsp2 coefficient: 0.025123015595105658
Nsp3 coefficient: -0.0332638347881984
Nsp4 coefficient: -1.550012626484849
Nsp5 coefficient: -0.5980677802710976
Nsp6 coefficient: 0.1453015192057636
Nsp7 coefficient: 0.15929061356584437
Nsp8 coefficient: -0.6410039339482044
Nsp9 coefficient: 1.713413100640249
Nsp10 coefficient: -0.5807024728244552
Nsp12 coefficient: -0.5447157805197039
Nsp13 coefficient: 0.459336210087338
Nsp14 coefficient: -0.10240906673558307
Nsp15 coefficient: -0.7403706068037291
Nsp16 coefficient: -0.18903233518455945
S coefficient: -0.7353926375902275
S1 coefficient: 2.3013664341644806
S2 coefficient: 0.619304455614992
ORF3a coefficient: -0.3012767814397545
E coefficient: 1.242454854627707
M coefficient: 4.060183058584254
ORF6 coefficient: -0.41981286305995213
ORF7a coefficient: -0.14536172864676317
ORF7b coefficient: 1.2818127655942837
ORF8 coefficient: -1.1237433950620588
N coefficient

Try each date individually, and without S (S overlaps with S1 and S2, so including all three gives the model redundant predictors, which may be distorting the coefficients)

In [74]:
def linear_model_by_date(date, tree_type):
    """
    Predict Y (logistic clade growth) from a set of X's (mutations in gene)
    for a single date, leaving S out of the predictors.

    date: date string identifying which tree to load
    tree_type: which build to load (passed through to `get_mut_and_growth_info`)

    Uses the absolute (not relative) nonsynonymous mutation counts.
    Prints the coefficient of determination followed by one coefficient per
    gene (every gene in `all_genes` except S). Nothing is returned.
    """

    _, X_list, _, Y_list = get_mut_and_growth_info(date, tree_type)

    # find S by name instead of relying on a hard-coded position in `all_genes`
    s_index = all_genes.index('S')
    all_genes_wo_S = all_genes[:s_index] + all_genes[s_index+1:]

    # build new rows without the S column, leaving X_list itself untouched
    X_wo_S = [row[:s_index] + row[s_index+1:] for row in X_list]

    # make lists into numpy arrays
    x, y = np.array(X_wo_S), np.array(Y_list)

    # create instance of LinearRegression and fit it to data 
    model = LinearRegression().fit(x, y)
    r_sq = model.score(x, y)
    print(f"coefficient of determination: {r_sq}\n")

    # coefficients are in column order, which matches all_genes_wo_S
    for g, c in zip(all_genes_wo_S, model.coef_):
        print(f'{g} coefficient: {c}')

In [None]:
# all_dates = ['2020-03', '2020-05', '2020-07', '2020-09', '2020-11', 
#              '2021-01', '2021-03', '2021-05', '2021-07', '2021-09', '2021-11', 
#              '2022-01', '2022-03', '2022-05', '2022-07', '2022-09', '2022-11', 
#              '2023-01', '2023-03', '2023-05']

In [93]:
# run model on a single date only (standard 'sars2' build, S left out of the predictors)

linear_model_by_date('2023-05', 'sars2')

coefficient of determination: 0.7470534969438616

Nsp1 coefficient: 1.4057200775668062
Nsp2 coefficient: -0.7550708612741635
Nsp3 coefficient: -0.16879180175335107
Nsp4 coefficient: 0.2720460918394415
Nsp5 coefficient: 0.7629031558794598
Nsp6 coefficient: 0.1635630410071834
Nsp7 coefficient: 0.8792391566051786
Nsp8 coefficient: -2.0055092100616445
Nsp9 coefficient: -0.8612183827183594
Nsp10 coefficient: -0.5191439508057052
Nsp12 coefficient: -0.5737575837289781
Nsp13 coefficient: -0.666317792690625
Nsp14 coefficient: -0.1455834102456039
Nsp15 coefficient: -0.5966779317244327
Nsp16 coefficient: -1.5707790822100776
S1 coefficient: 0.5239127648906412
S2 coefficient: -0.2967926624901597
ORF3a coefficient: 0.040951047100089115
E coefficient: 1.9287812661212846
M coefficient: -0.9856540339491371
ORF6 coefficient: 1.336004865076616
ORF7a coefficient: 1.802155845107335
ORF7b coefficient: -0.6531741244862826
ORF8 coefficient: -0.1852010560377825
N coefficient: -1.5075904915720044
ORF9b coeffici