# Natural sequence and DMS analysis

This notebook is a preliminary analysis of antigenic selection across Lassa virus (LASV) sequences and is not included in the final HTML documents.

In [None]:
# Imports
import os
import json
import warnings
import dmslogo
import neutcurve
import numpy as np
import scipy as sp
import pandas as pd
import altair as alt
import seaborn as sns
import statsmodels.api
import matplotlib.colors
from Bio import SeqIO, AlignIO 
from matplotlib import pyplot as plt
from matplotlib import ticker as mticker

# Qualitative color palette, re-ordered by hand so the phylogenetic tree
# figures look nicer. (The name indicates it is adapted from Paul Tol's
# "muted" scheme.)
tol_muted_adjusted = [
    "#000000",
    "#CC6677",
    "#1f78b4",
    "#88CCEE",
    "#DDDDDD",
    "#882255",
    "#117733",
    "#DDCC77",
    "#44AA99",
    "#EE7733",
    "#AA4499",
    "#999933",
    "#CC3311",
]

# Matplotlib rc overrides applied through seaborn: 300 dpi for on-screen and
# saved figures, and SVG text written as text elements rather than paths.
SEABORN_RC = {
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "svg.fonttype": "none",
}
sns.set(rc=SEABORN_RC)
sns.set_style("ticks")
sns.set_palette(tol_muted_adjusted)

# Silence every warning for the rest of the session.
# NOTE(review): this also hides pandas/altair deprecation warnings.
warnings.simplefilter("ignore")

# Lift Altair's default row limit on embedded data; binding the return value
# to `_` keeps it from being echoed as cell output.
_ = alt.data_transformers.disable_max_rows()

In [None]:
# this cell is tagged as `parameters` for papermill parameterization
# Every input defaults to None; papermill (or the commented-out interactive
# cell below) must supply real values before the analysis cells can run.

# Filtered per-antibody escape CSVs. The antibody name used in column names
# is the file name up to the first "_".
filtered_escape_377H = None
filtered_escape_89F = None
filtered_escape_2510C = None
filtered_escape_121F = None
filtered_escape_256A = None
filtered_escape_372D = None

# Functional-score (cell entry) CSV and its filters: keep mutations with
# times_seen >= min_times_seen and n_selections >= n_selections.
func_scores = None
min_times_seen = None
n_selections = None

# Phylogenetic selection inputs: tree mutation table (CSV) and FEL / FUBAR
# results (JSON).
GPC_tree_mutations = None
GPC_FEL_results = None
GPC_FUBAR_results = None

# Natural-isolate sequence metadata (TSV) and GPC protein alignment (FASTA)
natural_seq_metadata = None
natural_seq_alignment = None

# Output directory for natural-isolate results and the path of the OLS
# regression HTML
out_dir_natural = None
ols_regression = None

In [None]:
# # Uncomment for running interactive
# filtered_escape_377H = "../results/filtered_antibody_escape_CSVs/377H_filtered_mut_effect.csv"
# filtered_escape_89F = "../results/filtered_antibody_escape_CSVs/89F_filtered_mut_effect.csv"
# filtered_escape_2510C = "../results/filtered_antibody_escape_CSVs/2510C_filtered_mut_effect.csv"
# filtered_escape_121F = "../results/filtered_antibody_escape_CSVs/121F_filtered_mut_effect.csv"
# filtered_escape_256A = "../results/filtered_antibody_escape_CSVs/256A_filtered_mut_effect.csv"
# filtered_escape_372D = "../results/filtered_antibody_escape_CSVs/372D_filtered_mut_effect.csv"

# func_scores = "../results/func_effects/averages/293T_entry_func_effects.csv"
# min_times_seen = 2
# n_selections = 8

# GPC_tree_mutations = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/GPC_tree_mutations.csv"
# GPC_FEL_results = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/GPC_FEL_results.json"
# GPC_FUBAR_results = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/GPC_FUBAR_results.json"

# natural_seq_metadata = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/LASV_S_segment_metadata.tsv"
# natural_seq_alignment = "../non-pipeline_analyses/LASV_phylogeny_analysis/Results/LASV_GPC_protein_alignment.fasta"

# out_dir_natural = "../results/natural_isolate_escape/"
# ols_regression = "../results/natural_isolate_escape/ols_regression.html"

Process functional scores, antibody escape, and the phylogenetic selection metrics (tree mutation counts, FEL, and FUBAR) into one dataframe with site-level information.

In [None]:
# Site-level functional score: keep mutations measured in enough selections
# and seen enough times (excluding stop codons "*"), then average the
# per-mutation effect over all mutations at each (site, wildtype).
functional_scores = (
    pd.read_csv(func_scores)
    .loc[
        lambda df: (df["n_selections"] >= n_selections)
        & (df["times_seen"] >= min_times_seen)
        & (df["mutant"] != "*")
    ]
    .groupby(["site", "wildtype"], as_index=False)["effect"]
    .mean()
)

# Per-antibody filtered escape CSVs. Each antibody's name is taken from the
# file name up to the first "_" (e.g. "377H_filtered_mut_effect.csv" -> "377H").
antibody_files = [
    filtered_escape_377H,
    filtered_escape_89F,
    filtered_escape_2510C,
    filtered_escape_121F,
    filtered_escape_256A,
    filtered_escape_372D,
]

# Build one row per (site, wildtype) with one `escape_<antibody>` column per
# antibody: the per-mutation median escape, clipped below at 0, summed over
# all mutations at the site.
# The first file seeds the table directly. This avoids concatenating onto an
# empty placeholder frame. Pandas has deprecated how such empty entries affect
# result dtypes, and the placeholder's object-dtype columns could otherwise
# leak into `site`/`wildtype`.
antibody_escape = None
escape_columns = []
for antibody_file in antibody_files:
    antibody_name = os.path.basename(antibody_file).split("_")[0]
    escape_column = "escape_" + antibody_name
    escape_columns.append(escape_column)

    site_escape = (
        pd.read_csv(antibody_file)
        .assign(escape_median=lambda df: df["escape_median"].clip(lower=0))
        .rename(columns={"escape_median": escape_column})
        .groupby(["site", "wildtype"], as_index=False)
        .aggregate({escape_column: "sum"})
    )

    if antibody_escape is None:
        antibody_escape = site_escape
    else:
        # Left merge: the first antibody's (site, wildtype) rows define the
        # table, so rows present only in later files are not added.
        antibody_escape = antibody_escape.merge(
            site_escape,
            how="left",
            on=["site", "wildtype"],
            validate="one_to_one",
        )

# Total escape summed across all antibodies above (missing values count as 0)
antibody_escape["total_escape"] = antibody_escape[escape_columns].sum(axis=1)

# Attach escape to the site-level functional scores; sites with functional
# scores but no escape rows keep NaN escape values.
merged_df = functional_scores.merge(
    antibody_escape,
    how="left",
    on=["site", "wildtype"],
    validate="one_to_one",
)

# Count phylogenetic-tree mutations at each site.
# Columns: all mutations (`mut_type` count), synonymous, nonsynonymous, and
# the nonsynonymous/synonymous ratio. The ratio is inf at sites with
# nonsynonymous but no synonymous mutations, and NaN when a site has neither
# type.
tree_mutations = (
    pd.read_csv(GPC_tree_mutations)
    .assign(
        synonymous=lambda df: df["mut_type"].eq("synonymous").astype("int64"),
        nonsynonymous=lambda df: df["mut_type"].eq("nonsynonymous").astype("int64"),
    )
    .groupby("site", as_index=False)
    .aggregate({
        "mut_type": "count",
        "synonymous": "sum",
        "nonsynonymous": "sum",
    })
)
tree_mutations["non/syn"] = tree_mutations["nonsynonymous"] / tree_mutations["synonymous"]

# Attach the counts by site.
# NOTE(review): sites absent from the tree-mutation table end up with NaN
# (not 0) counts after this left merge -- confirm that is intended.
merged_df = merged_df.merge(
    tree_mutations,
    how="left",
    on="site",
    validate="one_to_one",
)

# Load FEL per-site estimates.
# The per-site table sits at json["MLE"]["content"]["0"], one list of values
# per site. It is read by key with the json module rather than by taking
# row 0 of a pd.read_json() frame, which depends on the row order pandas
# assigns to the nested keys. `with` closes the file.
with open(GPC_FEL_results) as fel_file:
    fel_json = json.load(fel_file)

# Column names for the per-site values, prefixed with "FEL_" so they stay
# distinct from other selection metrics after merging
headers = [
    "FEL_alpha",
    "FEL_beta",
    "FEL_alpha=beta",
    "FEL_LRT",
    "FEL_p-value",
    "FEL_Total branch length",
    "FEL_p-asmp",
]
FEL_results = pd.DataFrame(fel_json["MLE"]["content"]["0"], columns=headers)

# Rows are numbered from 0 in file order; convert to 1-based site numbers
# so they can be merged on `site`.
FEL_results.insert(0, "site", FEL_results.index + 1)

# Merge FEL estimates
merged_df = merged_df.merge(
    FEL_results,
    how="left",
    on="site",
    validate="one_to_one",
)

# Load FUBAR per-site estimates. The output file mixes dicts and lists, so it
# is read with the json module; `with` closes the file handle.
with open(GPC_FUBAR_results) as fubar_file:
    json_data = json.load(fubar_file)

# Column names for the eight values reported per site
# (json["MLE"]["content"]["0"]). The last two are not used here and are
# dropped.
headers = [
    "FUBAR_alpha",
    "FUBAR_beta",
    "FUBAR_beta-alpha",
    "FUBAR_Prob[alpha>beta]",
    "FUBAR_Prob[alpha<beta]",
    "FUBAR_BayesFactor[alpha<beta]",
    "empty_1", # dummy columns
    "empty_2", # dummy columns
]
FUBAR_results = (
    pd.DataFrame(json_data["MLE"]["content"]["0"], columns=headers)
    .drop(columns=["empty_1", "empty_2"])
)

# Rows are numbered from 0 in file order; convert to 1-based site numbers
# so they can be merged on `site`.
FUBAR_results.insert(0, "site", FUBAR_results.index + 1)

# Merge FUBAR estimates
merged_df = merged_df.merge(
    FUBAR_results,
    how="left",
    on="site",
    validate="one_to_one",
)

## Correlation of phylogenetic tree metrics

In [None]:
# Compare the same kind of selection signal across the three sources (raw
# tree mutation counts, FEL, FUBAR). Each entry is one figure row:
# (row title, [(y column, x column), ...]).
correlation_rows = [
    (
        "Correlations of nonsynonymous mutation metrics",
        [
            ("FEL_beta", "nonsynonymous"),
            ("FUBAR_beta", "nonsynonymous"),
            ("FUBAR_beta", "FEL_beta"),
        ],
    ),
    (
        "Correlations of synonymous mutation metrics",
        [
            ("FEL_alpha", "synonymous"),
            ("FUBAR_alpha", "synonymous"),
            ("FUBAR_alpha", "FEL_alpha"),
        ],
    ),
    (
        "Correlations of combined nonsynonymous and synonymous mutation metrics",
        [
            ("FEL_alpha=beta", "non/syn"),
            ("FUBAR_beta-alpha", "non/syn"),
            ("FUBAR_beta-alpha", "FEL_alpha=beta"),
        ],
    ),
]

# Site-level fields shown on hover in every panel
tooltip_fields = [
    "site",
    "wildtype",
    "effect",
    "total_escape",
    "nonsynonymous",
    "synonymous",
    "non/syn",
    "FEL_beta",
    "FEL_alpha",
    "FEL_alpha=beta",
    "FUBAR_beta",
    "FUBAR_alpha",
    "FUBAR_beta-alpha",
]

row_charts = []
for row_title, comparisons in correlation_rows:
    subplots = []
    for stat1, stat2 in comparisons:
        # Pearson r over sites where both metrics are finite; non/syn is inf
        # at sites without synonymous mutations, so inf is treated as missing.
        paired = (
            merged_df[[stat1, stat2]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna()
        )
        r = sp.stats.pearsonr(paired[stat1], paired[stat2])[0]
        print(f"r correlation of {stat1} and {stat2}: {r:.2f}")

        subplots.append(
            alt.Chart(merged_df)
            .mark_point(filled=True, color="black", size=75, opacity=0.15)
            .encode(
                alt.Y(stat1, axis=alt.Axis(title=stat1)),
                alt.X(stat2, axis=alt.Axis(title=stat2)),
                tooltip=tooltip_fields,
            )
            .properties(width=300, height=300)
        )

    row_charts.append(alt.hconcat(*subplots, spacing=5, title=[row_title]))

nonsyn, syn, nonsyn_and_syn = row_charts

# Shared axis and title styling for the whole figure
combined_plot = alt.vconcat(
    nonsyn,
    syn,
    nonsyn_and_syn,
    spacing=5,
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

combined_plot

At low nonsynonymous mutation counts (i.e., 0 to 10), the metrics are less well correlated, even though the overall correlation looks strong when zoomed out.

## Tree mutations compared to functional effects and antibody escape

In [None]:
# Compare tree mutation counts with the DMS measurements: one figure row per
# DMS measurement (y axis), one panel per tree mutation count (x axis).
tree_metrics = ["nonsynonymous", "synonymous", "non/syn"]
dms_rows = [
    ("Correlations of functional effects and tree mutation counts", "effect"),
    ("Correlations of antibody escape and tree mutation counts", "total_escape"),
]

# Site-level fields shown on hover in every panel
tooltip_fields = [
    "site",
    "wildtype",
    "effect",
    "total_escape",
    "nonsynonymous",
    "synonymous",
    "non/syn",
    "FEL_beta",
    "FEL_alpha",
    "FEL_alpha=beta",
    "FUBAR_beta",
    "FUBAR_alpha",
    "FUBAR_beta-alpha",
]

row_charts = []
for row_title, stat1 in dms_rows:
    subplots = []
    for stat2 in tree_metrics:
        # Pearson r over sites where both values are finite (non/syn can be inf)
        paired = (
            merged_df[[stat1, stat2]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna()
        )
        r = sp.stats.pearsonr(paired[stat1], paired[stat2])[0]
        print(f"r correlation of {stat1} and {stat2}: {r:.2f}")

        subplots.append(
            alt.Chart(merged_df)
            .mark_point(filled=True, color="black", size=75, opacity=0.15)
            .encode(
                alt.Y(stat1, axis=alt.Axis(title=stat1)),
                alt.X(stat2, axis=alt.Axis(title=stat2)),
                tooltip=tooltip_fields,
            )
            .properties(width=300, height=300)
        )

    row_charts.append(alt.hconcat(*subplots, spacing=5, title=[row_title]))

effects, escape = row_charts

# Shared axis and title styling for the whole figure
combined_plot = alt.vconcat(
    effects,
    escape,
    spacing=5,
    title="Tree mutations",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

combined_plot

## FEL analysis compared to functional effects and antibody escape

In [None]:
# Compare FEL site estimates with the DMS measurements: one figure row per
# DMS measurement (y axis), one panel per FEL metric (x axis).
FEL_metrics = ["FEL_beta", "FEL_alpha", "FEL_alpha=beta"]
dms_rows = [
    ("Correlations of functional effects and FEL metrics", "effect"),
    ("Correlations of antibody escape and FEL metrics", "total_escape"),
]

# Site-level fields shown on hover in every panel
tooltip_fields = [
    "site",
    "wildtype",
    "effect",
    "total_escape",
    "nonsynonymous",
    "synonymous",
    "non/syn",
    "FEL_beta",
    "FEL_alpha",
    "FEL_alpha=beta",
    "FUBAR_beta",
    "FUBAR_alpha",
    "FUBAR_beta-alpha",
]

row_charts = []
for row_title, stat1 in dms_rows:
    subplots = []
    for stat2 in FEL_metrics:
        # Pearson r over sites where both values are finite
        paired = (
            merged_df[[stat1, stat2]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna()
        )
        r = sp.stats.pearsonr(paired[stat1], paired[stat2])[0]
        print(f"r correlation of {stat1} and {stat2}: {r:.2f}")

        subplots.append(
            alt.Chart(merged_df)
            .mark_point(filled=True, color="black", size=75, opacity=0.15)
            .encode(
                alt.Y(stat1, axis=alt.Axis(title=stat1)),
                alt.X(stat2, axis=alt.Axis(title=stat2)),
                tooltip=tooltip_fields,
            )
            .properties(width=300, height=300)
        )

    row_charts.append(alt.hconcat(*subplots, spacing=5, title=[row_title]))

effects, escape = row_charts

# Shared axis and title styling for the whole figure
combined_plot = alt.vconcat(
    effects,
    escape,
    spacing=5,
    title="FEL",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

combined_plot

## FUBAR analysis compared to functional effects and antibody escape

In [None]:
# Compare FUBAR site estimates with the DMS measurements: one figure row per
# DMS measurement (y axis), one panel per FUBAR metric (x axis).
FUBAR_metrics = ["FUBAR_beta", "FUBAR_alpha", "FUBAR_beta-alpha"]
dms_rows = [
    ("Correlations of functional effects and FUBAR metrics", "effect"),
    ("Correlations of antibody escape and FUBAR metrics", "total_escape"),
]

# Site-level fields shown on hover in every panel
tooltip_fields = [
    "site",
    "wildtype",
    "effect",
    "total_escape",
    "nonsynonymous",
    "synonymous",
    "non/syn",
    "FEL_beta",
    "FEL_alpha",
    "FEL_alpha=beta",
    "FUBAR_beta",
    "FUBAR_alpha",
    "FUBAR_beta-alpha",
]

row_charts = []
for row_title, stat1 in dms_rows:
    subplots = []
    for stat2 in FUBAR_metrics:
        # Pearson r over sites where both values are finite
        paired = (
            merged_df[[stat1, stat2]]
            .replace([np.inf, -np.inf], np.nan)
            .dropna()
        )
        r = sp.stats.pearsonr(paired[stat1], paired[stat2])[0]
        print(f"r correlation of {stat1} and {stat2}: {r:.2f}")

        subplots.append(
            alt.Chart(merged_df)
            .mark_point(filled=True, color="black", size=75, opacity=0.15)
            .encode(
                alt.Y(stat1, axis=alt.Axis(title=stat1)),
                alt.X(stat2, axis=alt.Axis(title=stat2)),
                tooltip=tooltip_fields,
            )
            .properties(width=300, height=300)
        )

    row_charts.append(alt.hconcat(*subplots, spacing=5, title=[row_title]))

effects, escape = row_charts

# Shared axis and title styling for the whole figure
combined_plot = alt.vconcat(
    effects,
    escape,
    spacing=5,
    title="FUBAR",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

combined_plot

## OLS regression for DMS measurements and nonsynonymous substitution metrics

Perform multiple linear regression of each nonsynonymous substitution metric on the DMS measurements (i.e., effect on cell entry and antibody escape). We perform OLS regression both on the untransformed metric and on its log2. When taking the log2, a site is kept only if all three metrics are nonzero, because the log2 of zero is undefined. As a result, sites without any observed nonsynonymous mutations are excluded, and all three log2 regressions use the same set of sites. Finally, we plot the value predicted by each OLS model against the actual value.

In [None]:
ols_df = (
    merged_df[[
        "site",
        "wildtype",
        "FEL_beta",
        "nonsynonymous",
        "FUBAR_beta",
        "effect",
        "escape_377H",
        "escape_2510C",
        "escape_89F",
        "escape_256A",
        "escape_372D",
        "escape_121F",
        "total_escape",
    ]]
    # .replace([np.inf, -np.inf], np.nan)
    .rename(columns={
        "effect" : "effect on cell entry",
        "total_escape" : "antibody escape",
    })
    .dropna() # some sites do not have measurements for antibody escape
    .reset_index(drop=True)
)


X_vars = [
    # "site",
    "effect on cell entry",
    # "escape_377H",
    # "escape_2510C",
    # "escape_89F",
    # "escape_256A",
    # "escape_372D",
    # "escape_121F",
    "antibody escape",
]

Y_vars = [
    "nonsynonymous",
    "FEL_beta",
    "FUBAR_beta",
]

subplots = []
for Y_var in Y_vars:
    
    Y_values = ols_df[[Y_var]]
    X_values = ols_df[X_vars]

    # https://www.statsmodels.org/dev/examples/notebooks/generated/ols.html
    ols_model = statsmodels.api.OLS(
        endog=Y_values,
        exog=statsmodels.api.add_constant(X_values.astype(float)),
    )
    res_ols = ols_model.fit()
    full_r2 = res_ols.rsquared
    # print(res_ols.summary())
    ols_df[Y_var+"_predicted"] = res_ols.predict()

    # Calculate correlation 
    actual_r, p = sp.stats.pearsonr(
        ols_df[[Y_var, Y_var+"_predicted"]].replace([np.inf, -np.inf], np.nan).dropna()[Y_var+"_predicted"],
        ols_df[[Y_var, Y_var+"_predicted"]].replace([np.inf, -np.inf], np.nan).dropna()[Y_var],
    )
    print(f"r correlation of predicted {Y_var} vs actual {Y_var}: {actual_r:.2f}")
    
    unique_var = {}
    # https://blog.minitab.com/en/adventures-in-statistics-2/how-to-identify-the-most-important-predictor-variables-in-regression-models
    for vremove in X_vars:
        vremove_ols_model = statsmodels.api.OLS(
            endog=Y_values,
            exog=statsmodels.api.add_constant(ols_df[[v for v in X_vars if v != vremove]].astype(float)),
        )
        vremove_res_ols = vremove_ols_model.fit()
        unique_var[vremove] = full_r2 - vremove_res_ols.rsquared
    
    # https://stackoverflow.com/a/53966201
    subtitle = [
        f"{var}: {unique_var[var] * 100:.1f}% of variance (coef {res_ols.params[var]:.3f} \u00B1 {res_ols.bse[var]:.3f})"
        for var in X_vars
    ]

    chart_title = alt.TitleParams(
        Y_var + " (n = " + str(Y_values.shape[0]) + ")",
        subtitle=subtitle,
        fontSize=16,
    )

    curr_subplot = alt.Chart(
        ols_df,
        title=chart_title,
    ).mark_point(
        filled=True, 
        color="black", 
        size=75,
        opacity=0.15,
    ).encode(
        alt.X(
            Y_var+"_predicted",
            axis=alt.Axis(
                title=["predicted " + Y_var],
                # values=[-4,-3,-2,-1,0,1],
                # domainWidth=1,
                # domainColor="black",
                # tickColor="black",
            ),
            # scale=alt.Scale(type="symlog")
        ),
        alt.Y(
            Y_var,
            axis=alt.Axis(
                title=["actual " + Y_var],
                # values=[-3,-2,-1,0],
                # domainWidth=1,
                # domainColor="black",
                # tickColor="black",
            ),
            # scale=alt.Scale(type="symlog")
        ),
        tooltip=[
            "site",
            "wildtype",
            Y_var,
            Y_var+"_predicted",
            "effect on cell entry",
            "escape_377H",
            "escape_2510C",
            "escape_89F",
            "escape_256A",
            "escape_372D",
            "escape_121F",
            "antibody escape",
            "nonsynonymous",
            "FEL_beta",
            "FUBAR_beta",
            # "synonymous",
            # "non/syn",
            # "FEL_beta",
            # "FEL_alpha",
            # "FEL_alpha=beta",
            # "FUBAR_beta",
            # "FUBAR_alpha",
            # "FUBAR_beta-alpha",
        ],
    ).properties(
        width=400,
        height=400,
    )

    subplots.append(curr_subplot)

# Lay out the untransformed-metric regressions (built by the loop above) side by side
no_log_plots = alt.hconcat(
    *subplots,
    spacing=5,
)

# Repeat the regressions on log2-transformed substitution metrics.
# Rows where a log2 value is -inf/inf (metric == 0) or NaN are dropped before
# refitting, since OLS cannot handle non-finite values.
log_metrics = ["nonsynonymous", "FEL_beta", "FUBAR_beta"]
for metric in log_metrics:
    ols_df["log2 " + metric] = np.log2(ols_df[metric])
ols_df = (
    ols_df
    .replace([np.inf, -np.inf], np.nan)
    .dropna()
    .reset_index(drop=True)
)

Y_vars = ["log2 " + metric for metric in log_metrics]

subplots = []
for Y_var in Y_vars:
    predicted_col = Y_var + "_predicted"

    Y_values = ols_df[[Y_var]]
    X_values = ols_df[X_vars]

    # Full model: Y ~ const + all predictors in X_vars
    # https://www.statsmodels.org/dev/examples/notebooks/generated/ols.html
    ols_model = statsmodels.api.OLS(
        endog=Y_values,
        exog=statsmodels.api.add_constant(X_values.astype(float)),
    )
    res_ols = ols_model.fit()
    full_r2 = res_ols.rsquared
    ols_df[predicted_col] = res_ols.predict()

    # Pearson r of predicted vs actual on finite rows (frame built once)
    finite_pairs = (
        ols_df[[Y_var, predicted_col]]
        .replace([np.inf, -np.inf], np.nan)
        .dropna()
    )
    actual_r, p = sp.stats.pearsonr(
        finite_pairs[predicted_col],
        finite_pairs[Y_var],
    )
    print(f"r correlation of predicted {Y_var} vs actual {Y_var}: {actual_r:.2f}")

    # Unique variance of each predictor = drop in R^2 when only that predictor
    # is removed from the model
    # https://blog.minitab.com/en/adventures-in-statistics-2/how-to-identify-the-most-important-predictor-variables-in-regression-models
    unique_var = {}
    for vremove in X_vars:
        vremove_ols_model = statsmodels.api.OLS(
            endog=Y_values,
            exog=statsmodels.api.add_constant(ols_df[[v for v in X_vars if v != vremove]].astype(float)),
        )
        vremove_res_ols = vremove_ols_model.fit()
        unique_var[vremove] = full_r2 - vremove_res_ols.rsquared

    # One subtitle line per predictor (https://stackoverflow.com/a/53966201)
    subtitle = [
        f"{var}: {unique_var[var] * 100:.1f}% of variance (coef {res_ols.params[var]:.3f} \u00B1 {res_ols.bse[var]:.3f})"
        for var in X_vars
    ]

    chart_title = alt.TitleParams(
        Y_var + " (n = " + str(Y_values.shape[0]) + ")",
        subtitle=subtitle,
        fontSize=16,
    )

    curr_subplot = alt.Chart(
        ols_df,
        title=chart_title,
    ).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.15,
    ).encode(
        alt.X(
            predicted_col,
            axis=alt.Axis(
                title=["predicted " + Y_var],
            ),
        ),
        alt.Y(
            Y_var,
            axis=alt.Axis(
                title=["actual " + Y_var],
            ),
        ),
        tooltip=[
            "site",
            "wildtype",
            Y_var,
            predicted_col,
            "effect on cell entry",
            "escape_377H",
            "escape_2510C",
            "escape_89F",
            "escape_256A",
            "escape_372D",
            "escape_121F",
            "antibody escape",
            "nonsynonymous",
            "FEL_beta",
            "FUBAR_beta",
        ],
    ).properties(
        width=400,
        height=400,
    )

    subplots.append(curr_subplot)

log_plots = alt.hconcat(
    *subplots,
    spacing=5,
)

# Untransformed row on top, log2-transformed row below
combined_plot = alt.vconcat(
    no_log_plots,
    log_plots,
    title=["Correlations of predicted vs actual substitution metrics"],
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

# Make output dir (including any missing parents) if it doesn't exist
os.makedirs(out_dir_natural, exist_ok=True)

combined_plot.save(ols_regression)

combined_plot

## Antibody escape compared to natural sequence time of isolation

In [None]:
# Load metadata as dataframe
all_metadata = pd.read_csv(natural_seq_metadata, sep="\t")

# Strain ID of the Josiah sequence that every natural sequence is compared to
josiah_strain = "Josiah_NC_004296_reverse_complement_2018-08-13"

# Load strain IDs and aligned sequences in one pass (no row-by-row growth)
natural_seqs_df = pd.DataFrame(
    [
        {"strain": str(record.id), "sequence": str(record.seq)}
        for record in SeqIO.parse(natural_seq_alignment, "fasta")
    ],
    columns=["strain", "sequence"],
)
strains = natural_seqs_df["strain"].tolist()

# Keep metadata only for strains present in the alignment
GPC_metadata = (
    all_metadata.loc[all_metadata["strain"].isin(strains)].copy()
)
# Add Sierra Leone for Josiah strain and fix date for Josiah
# using paper Wulff and Johnson, 1979.
# Note that some dates might be wrong because they were incorrectly
# dated on NCBI virus.
# Josiah is selected by strain ID, not by row label, so the fix does not
# depend on Josiah's position in the metadata file.
is_josiah = GPC_metadata["strain"] == josiah_strain
GPC_metadata.loc[is_josiah, "country"] = "Sierra Leone"
GPC_metadata.loc[is_josiah, "date"] = "1976-XX-XX"
# Also add host for Pinneo strain.
# NOTE(review): assumes Pinneo has row label 1 in the metadata file. The guard
# keeps `.at` from silently appending a new row if label 1 was filtered out.
if 1 in GPC_metadata.index:
    GPC_metadata.at[1, "host"] = "Homo sapiens"

# Add sequence data and metadata
natural_seqs_df = (
    natural_seqs_df.merge(
        GPC_metadata,
        how="left",
        on="strain",
        validate="one_to_one",
    )
)

# Create column of collection year (missing dates stay missing instead of raising)
natural_seqs_df["date_year"] = natural_seqs_df["date"].str[:4]

# Re-group host into simpler categories; hosts absent from the metadata
# (NaN after the left merge) are "unknown", not "rodent"
natural_seqs_df["grouped_host"] = (
    natural_seqs_df["host"]
    .fillna("MISSING")
    .apply(lambda x: "human" if x == "Homo sapiens" else ("unknown" if x == "MISSING" else "rodent"))
)

# Extract josiah sequence for comparisons; `.item()` raises unless exactly one match
josiah_seq = (
    natural_seqs_df.loc[natural_seqs_df["strain"] == josiah_strain, "sequence"].item()
)

# Read functional score data
functional_scores = pd.read_csv(func_scores)

# Filter for minimum selections, times seen and no stop codons
functional_scores = (
    functional_scores.query(
        "n_selections >= @n_selections and times_seen >= @min_times_seen and mutant != '*'"
    )
    .reset_index(drop=True)
)

# Initialize list of escape files
antibody_files = [
    filtered_escape_377H,
    filtered_escape_89F,
    filtered_escape_2510C,
    filtered_escape_121F,
    filtered_escape_256A,
    filtered_escape_372D,
]

# Build one table with an escape_<antibody> column per antibody
antibody_escape = None
for antibody_file in antibody_files:

    # e.g. ".../377H_filtered_mut_effect.csv" -> "377H"
    antibody_name = os.path.basename(antibody_file).split("_")[0]
    escape_col = "escape_" + antibody_name

    # Load data as dataframe
    escape_df = pd.read_csv(antibody_file)

    # Clip lower scores to 0 and keep only the merge keys plus escape
    escape_df = (
        escape_df
        .assign(**{escape_col: escape_df["escape_median"].clip(lower=0)})
        [["site", "wildtype", "mutant", escape_col]]
    )

    if antibody_escape is None:
        antibody_escape = escape_df
    else:
        # Outer join: a mutation filtered out for one antibody still keeps
        # its escape values for the other antibodies
        antibody_escape = (
            antibody_escape.merge(
                escape_df,
                how="outer",
                on=["site", "wildtype", "mutant"],
                validate="one_to_one",
            )
        )

merged_df = (
    functional_scores.merge(
        antibody_escape,
        how="left",
        on=["site", "wildtype", "mutant"],
        validate="one_to_one",
    )
)

# Add column of mutation, e.g. "A123C"
merged_df["mutation"] = (
    merged_df["wildtype"] + merged_df["site"].astype(str) + merged_df["mutant"]
)

def calculate_natural_sequence_scores(antibody, sequence, jos_sequence):
    """
    Calculate a cumulative escape or functional effects score for a natural sequence.

    Every aligned position where ``sequence`` differs from ``jos_sequence``
    (ignoring gaps in either) becomes a mutation string
    ``f"{josiah_aa}{site}{natural_aa}"``. ``site`` counts only non-gap Josiah
    residues, so alignment columns that are insertions relative to Josiah do
    not shift the numbering. The ``antibody`` column of the module-level
    ``merged_df`` is summed over those mutations. Mutations that are absent
    from ``merged_df`` or appear in it more than once are skipped, and a
    missing (NaN) score counts as 0.

    Parameters
    ----------
    antibody : str
        Column of ``merged_df`` to sum, e.g. ``"escape_377H"`` or ``"effect"``.
    sequence : str
        Aligned natural sequence.
    jos_sequence : str
        Aligned Josiah sequence from the same alignment.

    Returns
    -------
    float
        Sum of per-mutation scores (0 if no scored mutations).
    """
    # Mutations relative to Josiah, numbered by Josiah residue position
    list_of_muts = []
    site = 0
    for s1, s2 in zip(jos_sequence, sequence):
        if s1 == "-":
            # Insertion relative to Josiah: no Josiah site number exists here
            continue
        site += 1
        if s2 != "-" and s1 != s2:
            list_of_muts.append(f"{s1}{site}{s2}")

    # mutation -> score lookup built once per call (only uniquely-keyed mutations)
    unique_rows = merged_df.loc[~merged_df["mutation"].duplicated(keep=False)]
    scores = unique_rows.set_index("mutation")[antibody].fillna(0)

    total_escape = 0
    for mut in list_of_muts:
        if mut in scores.index:
            total_escape += scores[mut]

    # Return total escape
    return total_escape

def identify_sequences_with_escape_muts(antibody, sequence, jos_sequence, threshold, strain, antibody_name):
    """
    Identify sequences with single escape mutations over a threshold.

    Mutations are derived exactly as in the cumulative scorer. Each aligned
    position where ``sequence`` differs from ``jos_sequence`` (gaps in either
    ignored) is named ``f"{josiah_aa}{site}{natural_aa}"``, with ``site``
    counting only non-gap Josiah residues. Each mutation's value is looked up
    in the ``antibody`` column of the module-level ``merged_df``. Mutations
    absent from ``merged_df`` or appearing in it more than once are skipped,
    and NaN counts as 0.

    Parameters
    ----------
    antibody : str
        Column of ``merged_df`` to test, e.g. ``"escape_377H"``.
    sequence : str
        Aligned natural sequence.
    jos_sequence : str
        Aligned Josiah sequence from the same alignment.
    threshold : float
        A mutation with value ``>= threshold`` counts as a strong escape mutation.
    strain, antibody_name : str
        Not used in the computation; kept for call compatibility (debug labels).

    Returns
    -------
    bool
        True as soon as one mutation reaches ``threshold``, else False.
    """
    # mutation -> escape lookup built once per call (only uniquely-keyed mutations)
    unique_rows = merged_df.loc[~merged_df["mutation"].duplicated(keep=False)]
    escape_by_mut = unique_rows.set_index("mutation")[antibody].fillna(0)

    site = 0
    for s1, s2 in zip(jos_sequence, sequence):
        if s1 == "-":
            # Insertion relative to Josiah: no Josiah site number exists here
            continue
        site += 1
        if s2 == "-" or s1 == s2:
            continue
        mut = f"{s1}{site}{s2}"
        if mut in escape_by_mut.index and escape_by_mut[mut] >= threshold:
            return True

    # Return false if no muts found
    return False

# Antibodies whose escape is scored on every natural sequence
escape_antibodies = ["377H", "256A", "372D", "121F", "89F", "2510C"]

# A single mutation with escape >= this value counts as a "strong" escape mutation
strong_escape_threshold = 1

# Calculate antibody escape for individual antibodies
# (apply runs immediately, so the loop variable in the lambda is the current one)
for antibody in escape_antibodies:
    natural_seqs_df["total_" + antibody + "_escape"] = (
        natural_seqs_df["sequence"].apply(lambda x: calculate_natural_sequence_scores(
            "escape_" + antibody,
            x,
            josiah_seq,
        ))
    )

# Determine if escape muts exist for individual antibodies
for antibody in escape_antibodies:
    natural_seqs_df["strong_escape_mut_" + antibody] = (
        natural_seqs_df.apply(lambda x: identify_sequences_with_escape_muts(
            "escape_" + antibody,
            x["sequence"],
            josiah_seq,
            strong_escape_threshold,
            x["strain"],
            antibody,
        ), axis=1)
    )

# Calculate total antibody escape (sum of the per-antibody totals)
natural_seqs_df["total_escape"] = (
    natural_seqs_df[["total_" + ab + "_escape" for ab in escape_antibodies]].sum(axis=1)
)

# Calculate total functional scores
natural_seqs_df["total_func_effect"] = (
    natural_seqs_df["sequence"].apply(lambda x: calculate_natural_sequence_scores(
        "effect",
        x,
        josiah_seq,
    ))
)

antibody_list = [
    "total_2510C_escape",
    "total_121F_escape",
    "total_377H_escape",
    "total_256A_escape",
    "total_372D_escape",
    "total_89F_escape",
]

# Isolates chosen to validate for escape
chosen_isolates = [
    "Josiah_NC_004296_reverse_complement_2018-08-13",
    "GA391_OL774861_reverse_complement_1977-XX-XX",
    "LASV_H-sapiens-tc_NGA_2016_IRR_007_MK107922_2016-01-18",
    "Lassa_virus_H-sapiens-wt_NGA_2018_ISTH_1024_MH157037_2018-02-14",
    "LM395-SLE-2009_KM822115_2009-XX-XX",
]

In [None]:
# Report how many sequences come from each host group, and how many of those
# carry at least one strong escape mutation for any antibody
strong_escape_cols = [
    "strong_escape_mut_2510C",
    "strong_escape_mut_121F",
    "strong_escape_mut_377H",
    "strong_escape_mut_256A",
    "strong_escape_mut_372D",
    "strong_escape_mut_89F",
]
has_strong_escape = natural_seqs_df[strong_escape_cols].any(axis=1)

# (label printed, grouped_host value)
host_groups = [("humans", "human"), ("rodents", "rodent"), ("missing", "unknown")]

for label, host in host_groups:
    n_host = int((natural_seqs_df["grouped_host"] == host).sum())
    print(f"Total number of sequences isolated from {label}: {n_host}")

print(f"Total number of sequences: {len(natural_seqs_df)}")

for label, host in host_groups:
    n_host_escape = int(((natural_seqs_df["grouped_host"] == host) & has_strong_escape).sum())
    print(f"Total number of sequences isolated from {label} with strong escape mutations: {n_host_escape}")

print(f"Total number of sequences without strong escape mutations: {int((~has_strong_escape).sum())}")

In [None]:
def _pearson_r_vs_year(score_col):
    """Pearson correlation (r, p) between collection year and ``score_col``.

    Rows missing either value are dropped, and both columns are cast to float
    (``date_year`` holds 4-character year strings).
    """
    paired = natural_seqs_df[["date_year", score_col]].dropna().astype(float)
    return sp.stats.pearsonr(paired["date_year"], paired[score_col])


def _time_scatter(score_col, y_title):
    """Scatter of ``score_col`` against collection year for every natural
    sequence, colored by country and shaped by grouped host."""
    return alt.Chart(
        natural_seqs_df,
    ).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.5,
    ).encode(
        alt.Y(
            score_col,
            axis=alt.Axis(
                title=y_title,
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
        ),
        alt.X(
            "date_year:T",
            axis=alt.Axis(
                title="date",
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
        ),
        color=alt.Color(
            "country:N",
            title="country",
            scale=alt.Scale(
                domain=natural_seqs_df["country"].unique().tolist(),
                range=tol_muted_adjusted[1:]
            ),
        ),
        shape=alt.Shape(
            "grouped_host:N",
            title="host",
        ),
        tooltip=[
            "strain",
            "country",
            "host",
            "date",
            "total_escape",
            "total_377H_escape",
            "total_256A_escape",
            "total_372D_escape",
            "total_121F_escape",
            "total_89F_escape",
            "total_2510C_escape",
            "total_func_effect",
        ],
    ).properties(
        width=500,
        height=300,
    )


# Antibody escape over time
r, p = _pearson_r_vs_year("total_escape")
print(f"r correlation of time and antibody escape: {r:.2f}")
antibody_escape = _time_scatter("total_escape", "total antibody escape")

# Functional effects over time
r, p = _pearson_r_vs_year("total_func_effect")
print(f"r correlation of time and functional effects: {r:.2f}")
func_effects = _time_scatter("total_func_effect", "total functional effect")

combined_plot = alt.hconcat(
    antibody_escape,
    func_effects,
    spacing=5,
    title="Antibody and functional effects across time",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

combined_plot

In [None]:
def _host_scatter(score_col, y_title):
    """Jittered strip plot of ``score_col`` per grouped host for every natural
    sequence, colored by country."""
    return alt.Chart(
        natural_seqs_df,
    ).mark_point(
        filled=True,
        color="black",
        size=75,
        opacity=0.5,
    ).encode(
        alt.Y(
            score_col,
            axis=alt.Axis(
                title=y_title,
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
        ),
        alt.X(
            "grouped_host:N",
            axis=alt.Axis(
                title="host",
                domainWidth=1,
                domainColor="black",
                tickColor="black",
            ),
        ),
        xOffset="jitter:Q",
        color=alt.Color(
            "country:N",
            title="country",
            scale=alt.Scale(
                domain=natural_seqs_df["country"].unique().tolist(),
                range=tol_muted_adjusted[1:]
            ),
        ),
        tooltip=[
            "strain",
            "country",
            "host",
            "date",
            "total_escape",
            "total_377H_escape",
            "total_256A_escape",
            "total_372D_escape",
            "total_121F_escape",
            "total_89F_escape",
            "total_2510C_escape",
            "total_func_effect",
        ],
    ).transform_calculate(
        # Generate Gaussian jitter with a Box-Muller transform
        jitter="sqrt(-2*log(random()))*cos(2*PI*random())"
    ).properties(
        width=500,
        height=300,
    )


antibody_escape = _host_scatter("total_escape", "total antibody escape")
func_effects = _host_scatter("total_func_effect", "total functional effect")

combined_plot = alt.hconcat(
    antibody_escape,
    func_effects,
    spacing=5,
    title="Antibody and functional effects for different hosts",
).configure_axis(
    grid=False,
    labelFontSize=16,
    titleFontSize=16,
    labelFontWeight="normal",
    titleFontWeight="normal",
).configure_title(
    fontSize=24,
)

combined_plot