# Model results for simulated and natural populations

Plots model results for all populations by timepoint including those shown in Figures 2, 3, 5, and 6 and Supplemental Figures S3, S6, and S8.

Generates tables of model results for all populations including those in Tables 1 and 2 and Supplemental Table S3.

Plots cross-validation approaches shown in Supplemental Figures S1 and S5.

In [None]:
# Define inputs.
# NOTE(review): `snakemake` is never imported in this notebook; it is presumably
# injected by the Snakemake workflow that executes it — confirm before running
# the notebook standalone.
errors_file = snakemake.input.model_distances
coefficients_file = snakemake.input.model_coefficients
bootstrap_p_values_file = snakemake.input.bootstrap_p_values

# Define outputs.
# Simulated population: model selection table and source data for export.
table_for_simulated_model_selection = snakemake.output.table_for_simulated_model_selection
source_data_for_simulated_model_coefficients = snakemake.output.source_data_for_simulated_model_coefficients
source_data_for_simulated_model_distances = snakemake.output.source_data_for_simulated_model_distances

# Simulated population: figures.
figure_for_simulated_model_controls = snakemake.output.figure_for_simulated_model_controls
figure_for_simulated_individual_models = snakemake.output.figure_for_simulated_individual_models
figure_for_simulated_composite_models = snakemake.output.figure_for_simulated_composite_models

# Natural population: model selection tables and source data for export.
table_for_natural_model_selection = snakemake.output.table_for_natural_model_selection
table_for_natural_model_complete_selection = snakemake.output.table_for_natural_model_complete_selection
source_data_for_natural_model_coefficients = snakemake.output.source_data_for_natural_model_coefficients
source_data_for_natural_model_distances = snakemake.output.source_data_for_natural_model_distances

# Natural population: figures.
figure_for_natural_epitope_vs_oracle_models = snakemake.output.figure_for_natural_epitope_vs_oracle_models
figure_for_natural_individual_models = snakemake.output.figure_for_natural_individual_models
figure_for_natural_composite_models = snakemake.output.figure_for_natural_composite_models
figure_for_natural_updated_models = snakemake.output.figure_for_natural_updated_models

# Cross-validation figures for both populations.
figure_for_simulated_cross_validation = snakemake.output.figure_for_simulated_cross_validation
figure_for_natural_cross_validation = snakemake.output.figure_for_natural_cross_validation

# Define parameters.
# Values of the "sample" column used to split the loaded tables into the
# simulated and natural populations.
simulated_sample = snakemake.params.simulated_sample
natural_sample = snakemake.params.natural_sample

## Import and define functions
[back to top](#Model-results-for-simulated-and-natural-populations)

In [None]:
from collections import OrderedDict
import matplotlib as mpl
from matplotlib.collections import LineCollection
import matplotlib.dates as mdates
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
from pandas.plotting import register_matplotlib_converters
import seaborn as sns

%matplotlib inline

In [None]:
# Register pandas' datetime converters with matplotlib so pandas timestamps can be plotted on date axes.
register_matplotlib_converters()

In [None]:
# Use seaborn's "white" style for all subsequent figures.
sns.set_style("white")

In [None]:
# Display figures at a reasonable default size.
mpl.rcParams['figure.figsize'] = (6, 4)

# Disable top and right spines.
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
    
# Display and save figures at higher resolution for presentations and manuscripts.
mpl.rcParams['savefig.dpi'] = 200
mpl.rcParams['figure.dpi'] = 200

# Display text at sizes large enough for presentations and manuscripts.
mpl.rcParams['font.weight'] = "normal"
mpl.rcParams['axes.labelweight'] = "normal"
mpl.rcParams['font.size'] = 18
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['legend.fontsize'] = 12
mpl.rcParams['xtick.labelsize'] = 14
mpl.rcParams['ytick.labelsize'] = 14

mpl.rc('text', usetex=False)

In [None]:
# Font properties for bold panel labels.
panel_labels_dict = dict(weight="bold", size=14)

In [None]:
# Six-color palette running from red (#d73027) to blue (#4575b4).
colors = "#d73027 #fc8d59 #fee090 #e0f3f8 #91bfdb #4575b4".split()

In [None]:
# Reverse the palette in place so it runs from blue (#4575b4) to red (#d73027).
colors.reverse()

In [None]:
# Display the palette for inspection.
colors

In [None]:
# Number of panel columns per figure row (coefficients on the left, distances on the right).
ncols = 2

# Line color used for each predictor's coefficients.
color_by_predictor = {
    # Naive and fitness baselines.
    "naive": "#000000",
    "offspring": "#000000",
    "normalized_fitness": "#000000",
    "fitness": "#000000",
    # Epitope- and oracle-based predictors.
    "ep": "#4575b4",
    "ep_wolf": "#4575b4",
    "ep_star": "#4575b4",
    "ep_x": "#4575b4",
    "ep_x_koel": "#4575b4",
    "ep_x_wolf": "#4575b4",
    "oracle_x": "#4575b4",
    "rb": "#4575b4",
    # HI/FRA titer-based predictors.
    "cTiter": "#91bfdb",
    "cTiter_x": "#91bfdb",
    "cTiterSub": "#91bfdb",
    "cTiterSub_star": "#91bfdb",
    "cTiterSub_x": "#91bfdb",
    "fra_cTiter_x": "#91bfdb",
    # Mutational load and DMS-based predictors.
    "ne_star": "#2ca25f",
    "dms_star": "#99d8c9",
    "dms_nonepitope": "#99d8c9",
    "dms_entropy": "#99d8c9",
    # LBI and frequency-based predictors.
    "unnormalized_lbi": "#fc8d59",
    "lbi": "#fc8d59",
    "delta_frequency": "#d73027",
}

# Human-readable name for each predictor key, used in legends and tables.
name_by_predictor = {
    "naive": "naive",
    "offspring": "observed fitness",
    "normalized_fitness": "true fitness",
    "fitness": "estimated fitness",
    "ep": "epitope mutations",
    "ep_wolf": "Wolf epitope mutations",
    "ep_star": "epitope ancestor",
    "ep_x": "epitope antigenic novelty",
    "ep_x_koel": "Koel epitope antigenic novelty",
    "ep_x_wolf": "Wolf epitope antigenic novelty",
    "oracle_x": "oracle antigenic novelty",
    "rb": "Koel epitope mutations",
    "cTiter": "antigenic advance",
    "cTiter_x": "HI antigenic novelty",
    "cTiterSub": "linear HI mut phenotypes",
    "cTiterSub_star": "ancestral HI mut phenotypes",
    "cTiterSub_x": "HI sub cross-immunity",
    "fra_cTiter_x": "FRA antigenic novelty",
    "ne_star": "mutational load",
    "dms_star": "DMS mutational effects",
    "dms_nonepitope": "DMS mutational load",
    "dms_entropy": "DMS entropy",
    "unnormalized_lbi": "unnormalized LBI",
    "lbi": "LBI",
    "delta_frequency": "delta frequency",
}

# Predictor keys filtered out of the distances and coefficients tables when they
# are loaded. Every active entry ends with a comma so that re-enabling a
# commented-out entry cannot silently concatenate two adjacent string literals.
predictors_to_drop = [
    "ep",
    "cTiter",
    "cTiterSub",
    "cTiterSub_star",
    "cTiterSub_x",
    #"delta_frequency-ne_star",
    #"lbi-ep_x-ne_star",
]

In [None]:
def get_individual_predictors_for_data_frame(df):
    """Return the single-predictor models in ``df``, excluding the naive model.

    Composite models (keys containing "-") are skipped. Results follow the
    order of first appearance in ``df["predictors"]``.
    """
    individual_predictors = []
    for candidate in df["predictors"].unique():
        is_composite = "-" in candidate
        if not is_composite and candidate != "naive":
            individual_predictors.append(candidate)
    return individual_predictors

def get_composite_predictors_for_data_frame(df):
    """Return the composite models in ``df`` (predictor keys joined by "-").

    Results follow the order of first appearance in ``df["predictors"]``.
    """
    unique_predictors = df["predictors"].unique()
    return [name for name in unique_predictors if name.find("-") >= 0]

In [None]:
def plot_model_accuracy_and_coefficients_for_build(errors_by_time_df, coefficients_by_time_df, predictors, rotation=30,
                                             years_fmt_string="%Y", date_fmt_string="%Y-%m-%d", height=12, width=12,
                                             text_vertical_padding=0.12, hspace=0.1, wspace=0.2, max_predictor_name_length=55,
                                             share_y=True, max_coefficient=None, min_normal_error=None, max_normal_error=None,
                                             error_attribute="validation_error",
                                             naive_attribute="null_validation_error",
                                             optimal_attribute="optimal_validation_error",
                                             distance_axis_label="Distance to\nfuture (AAs)",
                                             coefficient_axis_label="Coefficient",
                                             distance_tick_multiple=2):
    """Plot coefficients and distances to the future by timepoint, one row per model.

    Each row of the figure corresponds to one entry of ``predictors``. The left
    panel shows the model's coefficients per timepoint and the right panel shows
    the model's distance to the future (black) with the naive (light gray) and
    optimal (dark gray) distances. Validation timepoints use filled markers and
    test timepoints use open markers.

    Parameters
    ----------
    errors_by_time_df : pandas.DataFrame
        Per-timepoint distances with columns ``predictors``, ``error_type``
        ("validation" or "test"), ``validation_timepoint`` and the columns named
        by ``error_attribute``, ``naive_attribute`` and ``optimal_attribute``.
    coefficients_by_time_df : pandas.DataFrame
        Per-timepoint coefficients with columns ``predictors``, ``predictor``,
        ``error_type``, ``validation_timepoint`` and ``coefficient``.
    predictors : list of str
        Model keys (hyphen-joined for composite models) to plot, top to bottom.
    rotation : int
        Rotation of the date tick labels.
    years_fmt_string, date_fmt_string : str
        ``strftime`` formats for major tick labels and the x-coordinate readout.
    height, width : float
        Figure size in inches.
    text_vertical_padding, max_predictor_name_length
        Accepted for backward compatibility; currently unused.
    hspace, wspace : float
        GridSpec spacing between panels.
    share_y : bool
        If True, every coefficient panel uses the same y-range.
    max_coefficient, min_normal_error, max_normal_error : float or None
        Axis limits; derived from the data when None.
    error_attribute, naive_attribute, optimal_attribute : str
        Columns holding the model, naive and optimal distances.
    distance_axis_label, coefficient_axis_label : str
        Y-axis labels for the distance and coefficient panels.
    distance_tick_multiple : float
        Spacing between major y ticks on the distance panels.

    Returns
    -------
    tuple
        ``(fig, axes, gs)``: the figure, the Axes created by ``plt.subplots``
        (per-model panels are added on top of it through ``gs``), and the
        GridSpec.
    """
    # Determine bounds for given data to set axes domains and ranges.
    if max_normal_error is None:
        # Pad the top of the distance range by two standard deviations.
        max_normal_error = (
            errors_by_time_df[error_attribute].max()
            + 2.0 * errors_by_time_df[error_attribute].std()
        )

    if min_normal_error is None:
        min_normal_error = errors_by_time_df[optimal_attribute].min()

    min_coefficient = coefficients_by_time_df["coefficient"].min()

    if max_coefficient is None:
        max_coefficient = coefficients_by_time_df["coefficient"].max() + 2

    min_date = errors_by_time_df["validation_timepoint"].min() - pd.DateOffset(months=6)
    max_date = errors_by_time_df["validation_timepoint"].max() + pd.DateOffset(months=6)

    nrows = len(predictors)

    fig, axes = plt.subplots(figsize=(width, height), facecolor='w')
    gs = gridspec.GridSpec(
        nrows=nrows,
        ncols=ncols,
        hspace=hspace,
        wspace=wspace
    )

    # Vertical anchor of the coefficient legend, lowered as entries are added.
    legend_y_by_count = {1: 0.92, 2: 0.75, 3: 0.6}

    def _legend_y_position(count):
        # Extrapolate past three entries instead of raising KeyError.
        if count in legend_y_by_count:
            return legend_y_by_count[count]
        return max(0.0, legend_y_by_count[3] - 0.15 * (count - 3))

    def _format_date_axis(ax):
        # Give every axis its own locator/formatter instances; matplotlib
        # locators and formatters must not be shared between axes.
        ax.set_xlim(min_date, max_date)
        ax.xaxis.set_major_locator(mdates.YearLocator(3))
        ax.xaxis.set_major_formatter(mdates.DateFormatter(years_fmt_string))
        ax.xaxis.set_minor_locator(mdates.MonthLocator())
        ax.format_xdata = mdates.DateFormatter(date_fmt_string)

    def _plot_distances(ax, df, label_prefix, **style):
        # Model distance plus the naive and optimal reference lines for one error type.
        # pd.to_datetime already yields datetime64[ns]; casting to a unitless
        # np.datetime64 raises TypeError on pandas >= 2.0.
        timepoints = pd.to_datetime(df["validation_timepoint"])
        ax.plot(
            timepoints,
            df[error_attribute],
            "o-",
            color="#000000",
            label="%s: %.2f +/- %.2f" % (label_prefix, df[error_attribute].mean(), df[error_attribute].std()),
            **style
        )
        # Distance from the current timepoint to the future.
        ax.plot(timepoints, df[naive_attribute], "-", color="#cccccc", label="", zorder=-10, **style)
        # Optimal distance from the current timepoint to the future for any model.
        ax.plot(timepoints, df[optimal_attribute], "-", color="#999999", label="", zorder=-10, **style)

    def _plot_coefficients(ax, df, **style):
        # One line per predictor within the model; unknown keys fall back to
        # black and their raw key instead of raising KeyError.
        for coefficient_predictor, predictor_coefficient_df in df.groupby("predictor"):
            ax.plot(
                predictor_coefficient_df["validation_timepoint"],
                predictor_coefficient_df["coefficient"],
                "o-",
                color=color_by_predictor.get(coefficient_predictor, "#000000"),
                label="%s: %.2f +/- %.2f" % (
                    name_by_predictor.get(coefficient_predictor, coefficient_predictor),
                    predictor_coefficient_df["coefficient"].mean(),
                    predictor_coefficient_df["coefficient"].std()
                ),
                **style
            )

    for i, predictor in enumerate(predictors):
        error_df = errors_by_time_df[errors_by_time_df["predictors"] == predictor]
        validation_error_df = error_df[error_df["error_type"] == "validation"]
        test_error_df = error_df[error_df["error_type"] == "test"]

        coefficient_df = coefficients_by_time_df[coefficients_by_time_df["predictors"] == predictor]
        validation_coefficient_df = coefficient_df[coefficient_df["error_type"] == "validation"]
        test_coefficient_df = coefficient_df[coefficient_df["error_type"] == "test"]

        # Right column: distance to the future.
        distance_ax = plt.subplot(gs[i, 1])
        distance_ax.set_xlabel("Date")
        distance_ax.set_ylabel(distance_axis_label)
        distance_ax.axhline(y=0.0, color="#cccccc")

        _plot_distances(distance_ax, validation_error_df, "validation")
        if test_error_df.shape[0] > 0:
            _plot_distances(distance_ax, test_error_df, "test", fillstyle="none")

        distance_ax.legend(
            loc=(0.01, 0.92),
            frameon=False,
            fontsize=12,
            ncol=2
        )
        distance_ax.set_ylim(min_normal_error, max_normal_error)
        _format_date_axis(distance_ax)
        distance_ax.yaxis.set_major_locator(ticker.MultipleLocator(distance_tick_multiple))
        distance_ax.tick_params(which='major', width=1.00, length=5)

        # Left column: model coefficients.
        coefficient_ax = plt.subplot(gs[i, 0])
        coefficient_ax.set_xlabel("Date")
        coefficient_ax.set_ylabel(coefficient_axis_label)

        if share_y:
            coefficient_ax.set_ylim(min_coefficient - 1, max_coefficient)

        coefficient_ax.axhline(y=0.0, color="#999999")

        _plot_coefficients(coefficient_ax, validation_coefficient_df)

        # Build the legend before plotting test coefficients so it only lists
        # the validation summaries; skip it when there is nothing to list.
        number_of_coefficients = validation_coefficient_df["predictor"].nunique()
        if number_of_coefficients > 0:
            coefficient_legend = coefficient_ax.legend(
                loc=(0.01, _legend_y_position(number_of_coefficients)),
                frameon=False,
                fontsize=12
            )
            for legend_text in coefficient_legend.get_texts():
                legend_text.set_horizontalalignment("left")
                legend_text.set_verticalalignment("baseline")

        # Plot fixed coefficients for testing with open markers.
        _plot_coefficients(coefficient_ax, test_coefficient_df, fillstyle="none")
        _format_date_axis(coefficient_ax)

    fig.autofmt_xdate(rotation=rotation, ha="center")
    gs.tight_layout(fig, h_pad=hspace)

    return (fig, axes, gs)

In [None]:
def prepare_table(errors_df, coefficients_df, text_width=1.0, include_coefficients=True):
    error_metric = "validation_error"
    
    coefficient_columns = ["model", "predictor", "coefficient_mean", "coefficient_std"]
    model_selection_coefficients = coefficients_df.groupby(["model", "predictor"], sort=False).aggregate({
        "coefficient": ["mean", "std"]
    }).reset_index()
    model_selection_coefficients.columns = coefficient_columns
    
    model_selection_errors = errors_df.groupby("model").aggregate({
        error_metric: ["mean", "std"],
        "model_better_than_naive": ["sum", "mean"]
    }).sort_values((error_metric, "mean"), ascending=False)
    # .query("model != 'naive'")
    
    model_selection_errors.loc[:, ("model_better_than_naive", "sum")] = model_selection_errors[("model_better_than_naive", "sum")].astype(int)
    
    columns = [
        "model",
        "%s_mean" % error_metric,
        "%s_std" % error_metric,
        "model_better_count",
        "model_better_proportion"
    ]
    model_selection_errors = np.around(model_selection_errors, 2).reset_index()
    model_selection_errors.columns = columns
    model_selection_errors = model_selection_errors.sort_values("%s_mean" % error_metric, ascending=True)
        
    if include_coefficients:
        model_selection = model_selection_errors.merge(
            model_selection_coefficients,
            on=["model"]
        )

        model_selection["coefficients"] = model_selection.apply(
            lambda row: "%.2f +/- %.2f" % (row["coefficient_mean"], row["coefficient_std"]),
            axis=1
        )
        
        simple_model_selection_columns = ["model", "coefficients", error_metric, "model_better"]
    else:
        model_selection = model_selection_errors.copy()
        simple_model_selection_columns = ["model", error_metric, "model_better"]

    model_selection[error_metric] = model_selection.apply(
        lambda row: "%.2f +/- %.2f" % (row["%s_mean" % error_metric], row["%s_std" % error_metric]),
        axis=1
    )
    
    model_selection["model_better"] = model_selection.apply(
        lambda row: "%i (%i\%%)" % (row["model_better_count"], int(row["model_better_proportion"] * 100)),
        axis=1
    )
    
    simple_model_selection = []
    for model, model_df in model_selection.loc[:, simple_model_selection_columns].groupby("model", sort=False):
        new_model_predictors = model.split(" + ")
        
        if include_coefficients:
            new_coefficients = model_df["coefficients"].values
        
        for i in range(len(new_model_predictors)):
            if i == 0:
                if len(new_model_predictors) > 1:
                    new_model_predictor = new_model_predictors[i] + " +"
                else:
                    new_model_predictor = new_model_predictors[i]
                    
                new_validation_error = model_df[error_metric].values[0]
                new_model_better = model_df["model_better"].values[0]
            else:
                new_model_predictor = "\hspace{3mm}" + new_model_predictors[i]
                new_validation_error = ""
                new_model_better = ""
                
            record = {
                "model": new_model_predictor,
                error_metric: new_validation_error,
                "model_better": new_model_better
            }
            
            if include_coefficients:
                record["coefficients"] = new_coefficients[i]
                
            simple_model_selection.append(record)

    latex_columns = [
        "Model",
        "\makecell{Distance to \\\\ future (AAs)}",
        "\makecell[l]{Model $>$ naive \\\\ (N=%i)}" % errors_df["validation_timepoint"].unique().shape[0]
    ]
    column_format = "lrl"
    
    if include_coefficients:
        latex_columns.insert(1, "Coefficients")
        column_format = "lrrl"
        
    simple_model_selection = pd.DataFrame(simple_model_selection, columns=simple_model_selection_columns)
    simple_model_selection.columns = latex_columns
    
    # Update pandas options for maximum column width to display so longer cells do not get truncates in LaTeX.
    with pd.option_context("max_colwidth", 1000):
        simple_model_selection_table = simple_model_selection.to_latex(index=False, escape=False, column_format=column_format).replace(
            "tabular}",
            "tabular*}"
        ).replace(
            "{tabular*}{",
            "{tabular*}{%s\\textwidth}{" % text_width
        )
        
    return simple_model_selection_table

## Load data

In [None]:
# Load the tab-delimited bootstrap p-values for each model.
p_values = pd.read_csv(bootstrap_p_values_file, sep="\t")

In [None]:
# Preview the first rows of the p-values table.
p_values.head()

In [None]:
# Load tab-delimited per-timepoint model distances, parsing validation timepoints as dates.
errors_by_time_df = pd.read_csv(errors_file, sep="\t", parse_dates=["validation_timepoint"])

In [None]:
# Attach each model's bootstrap p-value to every one of its timepoint rows.
# The left join keeps rows for models without a p-value (their p_value is NaN).
errors_by_time_df = pd.merge(
    errors_by_time_df,
    p_values,
    how="left",
    on=["sample", "error_type", "predictors"],
)

In [None]:
# Preview the distances after merging p-values.
errors_by_time_df.head()

In [None]:
# Collect the distinct samples present in the distances table.
distinct_samples_with_errors = errors_by_time_df["sample"].unique()

In [None]:
# Display the distinct samples.
distinct_samples_with_errors

In [None]:
# Fail with a descriptive error (unlike assert, this check is not stripped under
# python -O) if the configured simulated sample has no model distances.
if simulated_sample not in distinct_samples_with_errors:
    raise ValueError(
        "Simulated sample %r not found in model distances; available samples: %s"
        % (simulated_sample, ", ".join(map(str, distinct_samples_with_errors)))
    )

In [None]:
# Fail with a descriptive error (unlike assert, this check is not stripped under
# python -O) if the configured natural sample has no model distances.
if natural_sample not in distinct_samples_with_errors:
    raise ValueError(
        "Natural sample %r not found in model distances; available samples: %s"
        % (natural_sample, ", ".join(map(str, distinct_samples_with_errors)))
    )

In [None]:
# Drop excluded predictors, then derive per-timepoint comparisons between each
# model's distance to the future and the naive (null) model's distance.
errors_by_time_df = errors_by_time_df.loc[
    ~errors_by_time_df["predictors"].isin(predictors_to_drop)
].assign(
    model_improvement=lambda df: df["null_validation_error"] - df["validation_error"],
    log2_model_improvement=lambda df: np.log2(df["null_validation_error"] / df["validation_error"]),
    relative_improvement=lambda df: (df["null_validation_error"] - df["validation_error"]) / df["null_validation_error"],
    proportion_by_model=lambda df: df["validation_error"] / df["null_validation_error"],
    proportion_explained=lambda df: 1 - (df["validation_error"] / df["null_validation_error"]),
    distance_from_future=lambda df: df["average_distance_to_future"] - df["average_diversity_in_future"],
)

# Split into the simulated and natural populations.
simulated_errors_by_time_df = errors_by_time_df.loc[errors_by_time_df["sample"].eq(simulated_sample)].copy()
natural_errors_by_time_df = errors_by_time_df.loc[errors_by_time_df["sample"].eq(natural_sample)].copy()

In [None]:
# Check the size of the simulated distances table.
simulated_errors_by_time_df.shape

In [None]:
# Check the size of the natural distances table.
natural_errors_by_time_df.shape

In [None]:
# Load tab-delimited per-timepoint model coefficients and drop excluded predictors.
coefficients_by_time_df = (
    pd.read_csv(coefficients_file, sep="\t", parse_dates=["validation_timepoint"])
    .loc[lambda df: ~df["predictors"].isin(predictors_to_drop)]
    .copy()
)

# Split into the simulated and natural populations.
simulated_coefficients_by_time_df = coefficients_by_time_df.loc[coefficients_by_time_df["sample"].eq(simulated_sample)].copy()
natural_coefficients_by_time_df = coefficients_by_time_df.loc[coefficients_by_time_df["sample"].eq(natural_sample)].copy()

In [None]:
# Check the size of the simulated coefficients table.
simulated_coefficients_by_time_df.shape

In [None]:
# Check the size of the natural coefficients table.
natural_coefficients_by_time_df.shape

In [None]:
# List the single-predictor models available for the simulated population.
get_individual_predictors_for_data_frame(simulated_errors_by_time_df)

In [None]:
# List the composite models available for the simulated population.
get_composite_predictors_for_data_frame(simulated_errors_by_time_df)

In [None]:
# List the single-predictor models available for the natural population.
get_individual_predictors_for_data_frame(natural_errors_by_time_df)

In [None]:
# List the composite models available for the natural population.
get_composite_predictors_for_data_frame(natural_errors_by_time_df)

In [None]:
# Preview the last rows of the simulated distances.
simulated_errors_by_time_df.tail()

In [None]:
# Inspect per-timepoint model and naive distances for the HI antigenic novelty (cTiter_x) model in the natural population.
natural_errors_by_time_df.query("predictors == 'cTiter_x'").loc[:, ["validation_timepoint", "validation_error", "null_validation_error"]]

## Summary of models for simulated populations

In [None]:
# Translate hyphen-joined predictor keys (e.g. "delta_frequency-ne_star") into
# readable model names joined by " + ", keeping any key without a display name as-is.
simulated_errors_by_time_df["model"] = simulated_errors_by_time_df["predictors"].map(
    lambda predictors: " + ".join(
        name_by_predictor.get(key, key) for key in predictors.split("-")
    )
)
simulated_coefficients_by_time_df["model"] = simulated_coefficients_by_time_df["predictors"].map(
    lambda predictors: " + ".join(
        name_by_predictor.get(key, key) for key in predictors.split("-")
    )
)

In [None]:
# Flag timepoints where the model's distance to the future beats the naive model's.
simulated_errors_by_time_df["model_better_than_naive"] = (simulated_errors_by_time_df["model_improvement"] > 0)
# NOTE(review): this overwrites the relative_improvement column computed for all
# samples earlier ((naive - model) / naive) with the opposite sign convention
# ((model - naive) / naive), so the simulated and natural frames disagree on its
# sign unless the natural frame is recomputed the same way later. Confirm which
# convention downstream plots expect.
simulated_errors_by_time_df["relative_improvement"] = (
    simulated_errors_by_time_df["validation_error"] / simulated_errors_by_time_df["null_validation_error"]
) - 1.0

In [None]:
# Preview the first rows of the simulated distances.
simulated_errors_by_time_df.head()

In [None]:
# Preview the last rows of the simulated distances.
simulated_errors_by_time_df.tail()

In [None]:
# Split simulated distances and coefficients into validation and test timepoints.
simulated_validation_errors_by_time_df = simulated_errors_by_time_df[simulated_errors_by_time_df["error_type"] == "validation"].copy()
simulated_validation_coefficients_by_time_df = simulated_coefficients_by_time_df[simulated_coefficients_by_time_df["error_type"] == "validation"].copy()

simulated_test_errors_by_time_df = simulated_errors_by_time_df[simulated_errors_by_time_df["error_type"] == "test"].copy()
simulated_test_coefficients_by_time_df = simulated_coefficients_by_time_df[simulated_coefficients_by_time_df["error_type"] == "test"].copy()

### Model validation table

In [None]:
# LaTeX templates for the complete model table. The header groups distance and
# "model > naive" columns into validation and test; each model contributes one
# first row (its first predictor's coefficient plus the model-level summaries)
# and one indented next row per additional predictor; the footer closes the table.
table_template_header = r"""
\begin{tabular*}{1.1\textwidth}{lrllrr}
\toprule
        &                 & \multicolumn{2}{c}{Distance to future (AAs)} & \multicolumn{2}{c}{Model $>$ naive} \\
  Model &    \makecell{Coefficients} & \makecell{Validation} & \makecell{Test} & \makecell{Validation} & \makecell{Test} \\
\midrule
"""

# Filled with str.format from one merged record of validation, test and coefficient summaries.
table_template_first_row = r"{predictor} & {coefficient_mean:.2f} +/- {coefficient_std:.2f} & {mean_error_validation:.2f} +/- {std_error_validation:.2f}{significance_mark_validation} & {mean_error_test:.2f} +/- {std_error_test:.2f}{significance_mark_test} & {model_better_count_validation} ({model_better_percentage_validation}\%) & {model_better_count_test} ({model_better_percentage_test}\%) \\"

# Doubled braces ("{{5mm}}") survive str.format as literal LaTeX braces.
table_template_next_row = r"\hspace{{5mm}} + {predictor} & {coefficient_mean:.2f} +/- {coefficient_std:.2f} & & & & \\"

table_template_footer = r"""
\bottomrule
\end{tabular*}
"""

In [None]:
def make_significance_mark(model_error_record, significance_level=0.05):
    r"""Return the LaTeX suffix that flags a model's bootstrap p-value.

    Parameters
    ----------
    model_error_record : mapping
        A row (e.g. passed by ``DataFrame.apply(..., axis=1)``) with ``model``
        and ``p_value`` entries.
    significance_level : float
        Threshold below which a p-value is marked as significant.

    Returns
    -------
    str
        ``""`` for the naive model or a non-significant p-value, ``"*"`` when
        ``p_value < significance_level``, and ``"\^"`` when the p-value is
        missing (NaN or None).
    """
    if model_error_record["model"] == "naive":
        return ""

    p_value = model_error_record["p_value"]

    # pd.isna handles both NaN and None; np.isnan(None) raises TypeError.
    if pd.isna(p_value):
        return r"\^"

    return "*" if p_value < significance_level else ""

In [None]:
def group_error_data_frame(errors_df):
    """Summarize per-timepoint distances for each model.

    Returns one row per model, sorted by mean distance to the future
    (ascending), with columns ``model``, ``mean_error``, ``std_error``,
    ``model_better_count``, ``model_better_proportion``, ``p_value`` (the
    model's first non-missing p-value), ``model_better_percentage`` (rounded
    to a whole number) and ``significance_mark`` (from ``make_significance_mark``).
    """
    summary = errors_df.groupby("model").agg(
        mean_error=("validation_error", "mean"),
        std_error=("validation_error", "std"),
        model_better_count=("model_better_than_naive", "sum"),
        model_better_proportion=("model_better_than_naive", "mean"),
        p_value=("p_value", "first"),
    )
    summary = summary.sort_values("mean_error", ascending=True).reset_index()

    summary["model_better_count"] = summary["model_better_count"].astype(int)
    summary["model_better_percentage"] = (summary["model_better_proportion"] * 100).round(0).astype(int)
    summary["significance_mark"] = summary.apply(make_significance_mark, axis=1)

    return summary

def group_coefficients_data_frame(coefficients_df, names_by_predictor=None):
    """Summarize coefficients per (model, predictor) as mean and standard deviation.

    Parameters
    ----------
    coefficients_df : pandas.DataFrame
        Per-timepoint rows with columns ``model``, ``predictor`` and ``coefficient``.
    names_by_predictor : dict, optional
        Display name for each predictor key; defaults to the module-level
        ``name_by_predictor``. Keys without a display name are kept as-is.

    Returns
    -------
    pandas.DataFrame
        Columns ``model``, ``predictor`` (display name), ``coefficient_mean``
        and ``coefficient_std``, in order of first appearance.
    """
    if names_by_predictor is None:
        names_by_predictor = name_by_predictor

    model_coefficients = coefficients_df.groupby(["model", "predictor"], sort=False).aggregate({
        "coefficient": ["mean", "std"]
    }).reset_index()
    model_coefficients.columns = ["model", "predictor", "coefficient_mean", "coefficient_std"]

    # Fall back to the raw key instead of NaN (which would render as "nan" in
    # the LaTeX table), matching how model names are built from predictor keys.
    model_coefficients["predictor"] = model_coefficients["predictor"].map(
        lambda predictor: names_by_predictor.get(predictor, predictor)
    )

    return model_coefficients

def prepare_complete_table(coefficients_df, errors_df, test_errors_df):
    """Render the LaTeX model summary table as a single string.

    Parameters
    ----------
    coefficients_df : pandas.DataFrame
        Per-timepoint coefficients with model, predictor and coefficient columns.
    errors_df, test_errors_df : pandas.DataFrame
        Per-timepoint validation and test errors in the form accepted by
        group_error_data_frame.

    Returns
    -------
    str
        table_template_header, one block of rows per model (in order of mean
        validation error), then table_template_footer, joined by newlines.
        Only models present in all three inputs appear (inner merges).
    """
    coefficient_summary = group_coefficients_data_frame(coefficients_df)
    validation_summary = group_error_data_frame(errors_df)
    test_summary = group_error_data_frame(test_errors_df)

    # Validation and test summaries side by side, one row per model.
    error_summary = validation_summary.merge(
        test_summary, on="model", suffixes=["_validation", "_test"]
    )
    # One row per (model, predictor); numeric columns rounded for display.
    table_data = np.around(error_summary.merge(coefficient_summary, on="model"), 2)

    lines = [table_template_header]
    for _, per_model in table_data.groupby("model", sort=False):
        first_record, *other_records = per_model.to_dict(orient="records")
        # The first predictor row carries the model's error and count columns;
        # extra predictors of a composite model only list their coefficients.
        lines.append(table_template_first_row.format(**first_record))
        lines.extend(
            table_template_next_row.format(**record) for record in other_records
        )
    lines.append(table_template_footer)

    return "\n".join(lines)

### Table 1. Simulated model performance table

In [None]:
# Table 1: coefficients and validation/test distances for simulated-population models.
simple_simulated_model_selection_table = prepare_complete_table(
    coefficients_df=simulated_validation_coefficients_by_time_df,
    errors_df=simulated_validation_errors_by_time_df,
    test_errors_df=simulated_test_errors_by_time_df,
)

In [None]:
# Show the rendered LaTeX for Table 1.
print(simple_simulated_model_selection_table)

In [None]:
# Save the Table 1 LaTeX to the snakemake output path.
with open(table_for_simulated_model_selection, "w") as oh:
    oh.write(simple_simulated_model_selection_table)

Build a clean data frame of the table's source data for export.

In [None]:
# Column maps for source-data exports: raw frame column (key) -> exported
# column name (value). Insertion order fixes the exported column order.
coefficients_source_file_columns = OrderedDict(
    model="model",
    predictor="predictor",
    validation_timepoint="timepoint",
    coefficient="coefficient",
)

distances_source_file_columns = OrderedDict(
    error_type="error_type",
    model="model",
    validation_timepoint="timepoint",
    validation_error="model_distance_to_future",
    null_validation_error="naive_distance_to_future",
    optimal_validation_error="optimal_distance_to_future",
)

In [None]:
# Start from a copy so the export cleanup never touches the validation coefficients frame.
simulated_coefficients_source_data = simulated_validation_coefficients_by_time_df.copy()

In [None]:
# Drop naive-model rows, keep only the exported columns under their public
# names, and round numeric values to 3 decimals.
simulated_coefficients_source_data = (
    simulated_coefficients_source_data
    .loc[lambda frame: frame["model"] != "naive", list(coefficients_source_file_columns)]
    .rename(columns=coefficients_source_file_columns)
    .round(3)
)

In [None]:
# Preview the simulated coefficient source data.
simulated_coefficients_source_data

In [None]:
# Export as CSV with a header row and without the pandas index.
simulated_coefficients_source_data.to_csv(
    source_data_for_simulated_model_coefficients,
    header=True,
    index=False
)

In [None]:
# Stack validation and test distances; the error_type column (exported below)
# tells them apart.
simulated_distances_source_data = pd.concat([
    simulated_validation_errors_by_time_df,
    simulated_test_errors_by_time_df
])

In [None]:
# Drop naive-model rows, keep only the exported distance columns under their
# public names, and round numeric values to 3 decimals.
simulated_distances_source_data = (
    simulated_distances_source_data
    .loc[lambda frame: frame["model"] != "naive", list(distances_source_file_columns)]
    .rename(columns=distances_source_file_columns)
    .round(3)
)

In [None]:
# Preview the simulated distance source data.
simulated_distances_source_data

In [None]:
# Export as CSV with a header row and without the pandas index.
simulated_distances_source_data.to_csv(
    source_data_for_simulated_model_distances,
    header=True,
    index=False
)

### Figure 2. Simulated model results for controls

In [None]:
# Re-express model and naive distances relative to the optimal distance, so
# the optimal model sits at zero.
_optimal_error = simulated_errors_by_time_df["optimal_validation_error"]
simulated_errors_by_time_df["centered_validation_error"] = simulated_errors_by_time_df["validation_error"].sub(_optimal_error)
simulated_errors_by_time_df["centered_null_validation_error"] = simulated_errors_by_time_df["null_validation_error"].sub(_optimal_error)
simulated_errors_by_time_df["centered_optimal_validation_error"] = 0.0

In [None]:
# Figure 2: model accuracy and coefficients over time for the
# normalized_fitness control model in simulated populations.
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    simulated_errors_by_time_df,
    simulated_coefficients_by_time_df,
    ["normalized_fitness"],
    rotation=0,
    years_fmt_string="%y",
    date_fmt_string="%y-%m",
    height=3,
    hspace=0.1,
    share_y=True,
    max_coefficient=13,
    max_normal_error=13
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.9, "A", **panel_labels_dict)
plt.figtext(0.49, 0.9, "B", **panel_labels_dict)

plt.savefig(figure_for_simulated_model_controls)

Summarize the optimal distance to the future achievable from the current population. These values are the lower bound for any model, set by the number of amino acid mutations that accumulate during one year of evolution.

In [None]:
# Mean and std of the optimal distance per error type. Only naive-model rows
# are used, presumably because the optimal distance is shared by all models at
# a timepoint and this avoids counting it once per model.
simulated_errors_by_time_df.query("predictors == 'naive'").groupby("error_type")["optimal_validation_error"].aggregate([
    "mean",
    "std"
])


### Figure 3. Simulated model results for individual predictors and best composite

In [None]:
# Figure 3: individual predictors plus the lbi-ne_star composite for
# simulated populations.
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    simulated_errors_by_time_df,
    simulated_coefficients_by_time_df,
    ["ep_x", "ne_star", "lbi", "delta_frequency", "lbi-ne_star"],
    rotation=0,
    years_fmt_string="%y",
    date_fmt_string="%y-%m",
    height=10,
    hspace=0.1,
    share_y=True,
    max_coefficient=7,
    max_normal_error=16
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.98, "A", **panel_labels_dict)
plt.figtext(0.49, 0.98, "B", **panel_labels_dict)

plt.savefig(figure_for_simulated_individual_models)

In [None]:
# All composite models found in the simulated results (as selected by
# get_composite_predictors_for_data_frame).
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    simulated_errors_by_time_df,
    simulated_coefficients_by_time_df,
    get_composite_predictors_for_data_frame(simulated_errors_by_time_df),
    rotation=0,
    years_fmt_string="%y",
    date_fmt_string="%y-%m",
    height=6,
    max_coefficient=5.0,
    max_normal_error=16
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.97, "A", **panel_labels_dict)
plt.figtext(0.49, 0.97, "B", **panel_labels_dict)

plt.savefig(figure_for_simulated_composite_models)

## Summary of models for natural populations

In [None]:
# Individual natural-population predictors shown in Table 2 and Figure 5.
subset_of_individual_predictors = [
    "ep_x",
    "cTiter_x",
    "ne_star",
    "dms_star",
    "lbi",
    "delta_frequency"
]

In [None]:
# Composite natural-population models (dash-joined predictors) shown in
# Table 2 and Figure 6.
composite_models = [
    "cTiter_x-ne_star",
    "ne_star-lbi",
    "cTiter_x-ne_star-lbi"
]

In [None]:
# Number of distinct timepoints in the natural results (validation and test
# rows share the validation_timepoint column).
natural_errors_by_time_df["validation_timepoint"].unique().shape

In [None]:
# Build a readable model label by splitting the dash-separated predictor
# string, replacing each name via name_by_predictor (when present), and
# joining the parts with " + ".
for _natural_frame in (natural_errors_by_time_df, natural_coefficients_by_time_df):
    _natural_frame["model"] = _natural_frame["predictors"].apply(
        lambda predictors: " + ".join(
            name_by_predictor.get(part, part) for part in predictors.split("-")
        )
    )

In [None]:
# Flag rows where the model beats the naive model, and express the model's
# distance relative to the naive distance (negative means closer than naive).
natural_errors_by_time_df["model_better_than_naive"] = natural_errors_by_time_df["model_improvement"].gt(0)
natural_errors_by_time_df["relative_improvement"] = (
    natural_errors_by_time_df["validation_error"]
    .div(natural_errors_by_time_df["null_validation_error"])
    .sub(1.0)
)

In [None]:
# Split natural results into validation and test subsets. The copies keep
# later column additions on a subset from affecting the combined frames.
natural_validation_errors_by_time_df = natural_errors_by_time_df.loc[
    lambda frame: frame["error_type"] == "validation"
].copy()
natural_validation_coefficients_by_time_df = natural_coefficients_by_time_df.loc[
    lambda frame: frame["error_type"] == "validation"
].copy()

natural_test_errors_by_time_df = natural_errors_by_time_df.loc[
    lambda frame: frame["error_type"] == "test"
].copy()
natural_test_coefficients_by_time_df = natural_coefficients_by_time_df.loc[
    lambda frame: frame["error_type"] == "test"
].copy()

In [None]:
# Restrict to the naive baseline plus the predictors discussed in the main text.
_manuscript_predictors = ["naive"] + subset_of_individual_predictors + composite_models

subset_of_natural_validation_errors_by_time_df = natural_validation_errors_by_time_df.loc[
    natural_validation_errors_by_time_df["predictors"].isin(_manuscript_predictors)
].copy()

subset_of_natural_validation_coefficients_by_time_df = natural_validation_coefficients_by_time_df.loc[
    natural_validation_coefficients_by_time_df["predictors"].isin(_manuscript_predictors)
].copy()

subset_of_natural_test_errors_by_time_df = natural_test_errors_by_time_df.loc[
    natural_test_errors_by_time_df["predictors"].isin(_manuscript_predictors)
].copy()

In [None]:
# Mean and std of the optimal distance per error type for natural populations.
# Only naive-model rows are used, presumably to count each timepoint once.
natural_errors_by_time_df.query("predictors == 'naive'").groupby("error_type")["optimal_validation_error"].aggregate([
    "mean",
    "std"
])

### Table 2. Natural model performance table

In [None]:
# Table 2: manuscript subset of natural-population models.
subset_natural_model_selection_table = prepare_complete_table(
    coefficients_df=subset_of_natural_validation_coefficients_by_time_df,
    errors_df=subset_of_natural_validation_errors_by_time_df,
    test_errors_df=subset_of_natural_test_errors_by_time_df,
)

In [None]:
# Show the rendered LaTeX for Table 2.
print(subset_natural_model_selection_table)

In [None]:
# Save the Table 2 LaTeX to the snakemake output path.
with open(table_for_natural_model_selection, "w") as oh:
    oh.write(subset_natural_model_selection_table)

### Table S3. Complete natural model performance table

Make a separate table with all models including those we do not discuss in the manuscript.

In [None]:
# Table S3 header. This reassigns the global table_template_header, which
# prepare_complete_table reads, so every later call uses this header.
table_template_header = r"""
\begin{tabular*}{1.1\textwidth}{lrllrr}
\toprule
        &                 & \multicolumn{2}{c}{Distance to future (AAs)} & \multicolumn{2}{c}{Model $>$ naive} \\
  Model &    \makecell{Coefficients} & \makecell{Validation} & \makecell{Test} & \makecell{Validation} & \makecell{Test} \\
\midrule
"""

# Table S3: every natural-population model, not just the manuscript subset.
complete_natural_model_selection_table = prepare_complete_table(
    natural_validation_coefficients_by_time_df,
    natural_validation_errors_by_time_df,
    natural_test_errors_by_time_df
)

with open(table_for_natural_model_complete_selection, "w") as oh:
    oh.write(complete_natural_model_selection_table)

Build a clean data frame of the table's source data for export.

In [None]:
# Start from a copy so the export cleanup never touches the validation coefficients frame.
natural_coefficients_source_data = natural_validation_coefficients_by_time_df.copy()

In [None]:
# Drop naive-model rows, keep only the exported columns under their public
# names, and round numeric values to 3 decimals.
natural_coefficients_source_data = (
    natural_coefficients_source_data
    .loc[lambda frame: frame["model"] != "naive", list(coefficients_source_file_columns)]
    .rename(columns=coefficients_source_file_columns)
    .round(3)
)

In [None]:
# Preview the natural coefficient source data.
natural_coefficients_source_data

In [None]:
# Export as CSV with a header row and without the pandas index.
natural_coefficients_source_data.to_csv(
    source_data_for_natural_model_coefficients,
    header=True,
    index=False
)

In [None]:
# Stack validation and test distances; the error_type column (exported below)
# tells them apart.
natural_distances_source_data = pd.concat([
    natural_validation_errors_by_time_df,
    natural_test_errors_by_time_df
])

In [None]:
# Drop naive-model rows, keep only the exported distance columns under their
# public names, and round numeric values to 3 decimals.
natural_distances_source_data = (
    natural_distances_source_data
    .loc[lambda frame: frame["model"] != "naive", list(distances_source_file_columns)]
    .rename(columns=distances_source_file_columns)
    .round(3)
)

In [None]:
# Preview the natural distance source data.
natural_distances_source_data

In [None]:
# Export as CSV with a header row and without the pandas index.
natural_distances_source_data.to_csv(
    source_data_for_natural_model_distances,
    header=True,
    index=False
)

### Inspection of epitope cross-immunity performance

Epitope cross-immunity has strong predictive support in the training data, as shown by its consistently high coefficient before October 2009.

In [None]:
# Mean and std of the ep_x coefficient over validation timepoints before October 2009.
subset_of_natural_validation_coefficients_by_time_df.query("predictors == 'ep_x' & validation_timepoint < '2009-10-01'")["coefficient"].aggregate(["mean", "std"])

From the October 2009 validation timepoint onward, the model's training data no longer contain more data from before 2006 than from 2006 onward. From this timepoint on, the mean coefficient drops to effectively zero.

In [None]:
# Mean and std of the ep_x coefficient from the October 2009 validation timepoint onward.
subset_of_natural_validation_coefficients_by_time_df.query("predictors == 'ep_x' & validation_timepoint >= '2009-10-01'")["coefficient"].aggregate(["mean", "std"])

Epitope cross-immunity does not overfit for the first few validation timepoints.

In [None]:
# ep_x model vs naive validation distances for the first five rows
# (presumably the earliest validation timepoints, if the frame is time-sorted).
subset_of_natural_validation_errors_by_time_df.query("predictors == 'ep_x'").loc[
    :, ["validation_timepoint", "validation_error", "null_validation_error"]
].head()

### Individual models

In [None]:
# Rows where the dms_star model improves on the naive model by more than 1 AA.
natural_errors_by_time_df.query("predictors == 'dms_star' & model_improvement > 1")

### Figure S6. Comparison of models based on epitope sites

Original epitope sites from Luksza and Lassig 2014 (`ep_x` or "epitope antigenic novelty") compared to comparable sites from a reanalysis of mutational sweeps up through 2015 (`oracle_x` or "oracle antigenic novelty").

In [None]:
# Figure S6: original epitope sites (ep_x) vs. oracle epitope sites (oracle_x).
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    natural_errors_by_time_df,
    natural_coefficients_by_time_df,
    ["ep_x", "oracle_x"],
    height=5,
    rotation=0,
    max_normal_error=12
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.96, "A", **panel_labels_dict)
plt.figtext(0.49, 0.96, "B", **panel_labels_dict)

plt.savefig(figure_for_natural_epitope_vs_oracle_models)

### Figure 5. Natural model results for individual predictors

In [None]:
# Figure 5: individual predictors for natural populations.
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    natural_errors_by_time_df,
    natural_coefficients_by_time_df,
    subset_of_individual_predictors,
    height=10,
    rotation=0,
    max_normal_error=17,
    distance_tick_multiple=3
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.98, "A", **panel_labels_dict)
plt.figtext(0.49, 0.98, "B", **panel_labels_dict)

plt.savefig(figure_for_natural_individual_models)

In [None]:
# Inspect the LBI rows at the 2014-10-01 timepoint.
natural_errors_by_time_df[
    (natural_errors_by_time_df["validation_timepoint"] == "2014-10-01") & (natural_errors_by_time_df["predictors"] == "lbi")
]

### Figure 6. Natural model results for composite predictors

In [None]:
# Figure 6: composite predictors for natural populations.
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    natural_errors_by_time_df,
    natural_coefficients_by_time_df,
    composite_models,
    height=7,
    text_vertical_padding=0.12,
    rotation=0,
    max_normal_error=17,
    distance_tick_multiple=3
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.96, "A", **panel_labels_dict)
plt.figtext(0.49, 0.96, "B", **panel_labels_dict)

plt.savefig(figure_for_natural_composite_models)

Calculate the sum of differences between the estimated distances from the naive model and each biological model. The higher the sum for a model, the more that biological model outperforms the naive model.

In [None]:
# Naive-model rows from all samples, used as the per-row baseline below.
naive_error_df = errors_by_time_df.query("predictors == 'naive'").copy()

In [None]:
# Pair each natural-population row with the naive row for the same timepoint,
# validation_n, type and sample; overlapping columns get _model/_naive suffixes.
# NOTE(review): "error_type" is not a merge key, so if a timepoint ever appeared
# as both a validation and a test timepoint, rows would be cross-matched; confirm.
natural_errors_by_time_with_naive_df = natural_errors_by_time_df.merge(
    naive_error_df,
    on=["validation_timepoint", "validation_n", "type", "sample"],
    suffixes=["_model", "_naive"]
)

In [None]:
# Gain over the naive model: naive distance minus model distance (positive = model closer).
natural_errors_by_time_with_naive_df["model_gain"] = natural_errors_by_time_with_naive_df["validation_error_naive"].sub(
    natural_errors_by_time_with_naive_df["validation_error_model"]
)

In [None]:
# Preview the merged model/naive rows.
natural_errors_by_time_with_naive_df.head()

In [None]:
# Total gain over the naive model per predictor set, largest first.
natural_errors_by_time_with_naive_df.groupby("predictors_model")["model_gain"].sum().sort_values(ascending=False)

## Cross-validation figures

In [None]:
def plot_cross_validation_times(data, ax, years_fmt_string, training_window_years=6, year_tick_interval=5):
    """Plot the training, validation and test timepoints used for cross-validation.

    Each distinct (validation_timepoint, error_type) pair in ``data`` gets its
    own row on the y-axis: validation timepoints first, then test timepoints,
    each in chronological order. Validation timepoints are drawn as filled
    black circles, each preceded by a grey line spanning its training window;
    test timepoints are drawn as open circles.

    Parameters
    ----------
    data : pandas.DataFrame
        Needs a datetime-like ``validation_timepoint`` column and an
        ``error_type`` column with values "validation" and (optionally) "test".
    ax : matplotlib.axes.Axes
        Axes to draw on.
    years_fmt_string : str
        strftime format for the major (year) tick labels, e.g. "%y" or "%Y".
    training_window_years : int, optional
        Length of each training window in years (default 6). Each window ends
        one year before its validation timepoint.
    year_tick_interval : int, optional
        Spacing in years between major x-axis ticks (default 5).

    Returns
    -------
    matplotlib.axes.Axes
        The same ``ax``, for chaining.
    """
    timepoints = data.loc[:, ["validation_timepoint", "error_type"]].drop_duplicates()

    # Sort each group chronologically and lay out y positions per group, so the
    # rows form a monotone staircase whatever the row order of ``data`` and
    # regardless of any other error types it contains.
    validation_times = timepoints.loc[
        timepoints["error_type"] == "validation", "validation_timepoint"
    ].sort_values()
    test_times = timepoints.loc[
        timepoints["error_type"] == "test", "validation_timepoint"
    ].sort_values()

    # Use matplotlib's own date conversion. Since matplotlib 3.3 the default
    # date epoch is 1970-01-01, so ``toordinal()`` values (days since
    # 0001-01-01) would be placed and labelled ~1969 years off by the date
    # locators/formatters. date2num is also correct on older matplotlib.
    validation_x = mdates.date2num(validation_times)
    test_x = mdates.date2num(test_times)

    n_validation = len(validation_x)
    validation_y_positions = list(range(n_validation))
    test_y_positions = list(range(n_validation, n_validation + len(test_x)))

    one_year = pd.DateOffset(years=1)
    training_window = pd.DateOffset(years=training_window_years)

    # Each training window ends one year before its validation timepoint.
    training_line_segments = [
        [
            (mdates.date2num(timepoint - one_year - training_window), y),
            (mdates.date2num(timepoint - one_year), y),
        ]
        for timepoint, y in zip(validation_times, validation_y_positions)
    ]

    markersize = 4
    years = mdates.YearLocator(year_tick_interval)
    years_fmt = mdates.DateFormatter(years_fmt_string)
    months = mdates.MonthLocator()

    training_lc = LineCollection(training_line_segments, zorder=9)
    training_lc.set_color("#999999")
    training_lc.set_linewidth(1)
    training_lc.set_label("Training")
    training_artist = ax.add_collection(training_lc)

    validation_artist, = ax.plot(
        validation_x,
        validation_y_positions,
        "o",
        label="Validation",
        markersize=markersize,
        color="#000000"
    )
    test_artist, = ax.plot(
        test_x,
        test_y_positions,
        "o",
        label="Test",
        markersize=markersize,
        color="#000000",
        fillstyle="none"
    )

    ax.xaxis.set_major_locator(years)
    ax.xaxis.set_major_formatter(years_fmt)
    ax.xaxis.set_minor_locator(months)
    ax.format_xdata = mdates.DateFormatter("%y-%m")

    # The y-axis is only a row index, so hide its spine, ticks and labels.
    ax.spines['left'].set_visible(False)
    ax.tick_params(axis='y', size=0)
    ax.set_yticklabels([])

    handles = [training_artist, validation_artist]
    labels = ["Training", "Validation"]

    # Only list "Test" in the legend when there are test timepoints to show.
    if len(test_x) > 0:
        handles.append(test_artist)
        labels.append("Test")

    ax.legend(
        handles,
        labels,
        frameon=False
    )

    ax.set_xlabel("Date")

    return ax

### Figure S1. Cross-validation of simulated populations

In [None]:
# Figure S1: cross-validation timeline for simulated populations (two-digit years).
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax = plot_cross_validation_times(simulated_errors_by_time_df, ax, years_fmt_string="%y")
# Keep x tick labels horizontal and centered.
fig.autofmt_xdate(rotation=0, ha="center")

plt.savefig(figure_for_simulated_cross_validation)

### Figure S5. Cross-validation of natural populations

In [None]:
# Figure S5: cross-validation timeline for natural populations (four-digit years).
# NOTE(review): confirm figure_for_natural_cross_validation is defined with the
# other snakemake outputs; it is not among those listed at the top of the notebook.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
ax = plot_cross_validation_times(natural_errors_by_time_df, ax, years_fmt_string="%Y")
# Keep x tick labels horizontal and centered.
fig.autofmt_xdate(rotation=0, ha="center")

plt.savefig(figure_for_natural_cross_validation)

### Figure S8. Natural model coefficients and distances refit across test timepoints

In [None]:
# Restrict errors and coefficients to the most recent natural sample.
latest_natural_sample = "natural_sample_20191001"
latest_natural_errors_by_time_df = errors_by_time_df.query("sample == @latest_natural_sample").copy()
latest_natural_coefficients_by_time_df = coefficients_by_time_df.query("sample == @latest_natural_sample").copy()

In [None]:
# Figure S8: composite models refit on the latest natural sample, including
# the fra_cTiter_x-ne_star variant.
fig, axes, gs = plot_model_accuracy_and_coefficients_for_build(
    latest_natural_errors_by_time_df,
    latest_natural_coefficients_by_time_df,
    ["cTiter_x-ne_star", "fra_cTiter_x-ne_star", "ne_star-lbi", "cTiter_x-ne_star-lbi"],
    rotation=0,
    years_fmt_string="%Y",
    date_fmt_string="%Y-%m",
    height=8,
    hspace=0.1,
    share_y=True,
    max_coefficient=7,
    max_normal_error=19,
    distance_tick_multiple=3
)

# Panel labels, positioned in figure-fraction coordinates.
plt.figtext(0.0, 0.97, "A", **panel_labels_dict)
plt.figtext(0.49, 0.97, "B", **panel_labels_dict)

plt.savefig(figure_for_natural_updated_models)

In [None]:
# Mean naive-model distance across all rows (validation and test) of the latest sample.
latest_natural_errors_by_time_df.query("predictors == 'naive'")["validation_error"].mean()

In [None]:
# Standard deviation of the naive-model distance across all rows of the latest sample.
latest_natural_errors_by_time_df.query("predictors == 'naive'")["validation_error"].std()

In [None]:
# Mean and std of the optimal distance in the latest sample (naive rows only,
# presumably to count each timepoint once).
latest_natural_errors_by_time_df.query("predictors == 'naive'")["optimal_validation_error"].aggregate([
    "mean",
    "std"
])

## Investigate distances to the future by Hemisphere

Plot the distributions of distances to the future by Hemisphere for the best model.
We make Northern Hemisphere predictions in October and Southern Hemisphere predictions in April.

In [None]:
# Display the natural sample ID used for the hemisphere analysis.
natural_sample

In [None]:
# Predictor set analyzed by hemisphere (plotted as "HI + mutational load").
best_natural_model = "cTiter_x-ne_star"

In [None]:
# Rows for the best natural model in the chosen natural sample. Use query's
# @-variable references instead of f-string interpolation, so a value that
# contains a quote cannot break (or rewrite) the query expression.
best_natural_model_df = errors_by_time_df.query(
    "(sample == @natural_sample) & (predictors == @best_natural_model)"
).copy()

Annotate forecast hemispheres by the month when the forecast was made.

In [None]:
# October forecasts are labelled "Northern"; forecasts from any other month
# (April in practice) are labelled "Southern".
best_natural_model_df["hemisphere"] = (
    best_natural_model_df["validation_timepoint"].dt.month
    .eq(10)
    .map({True: "Northern", False: "Southern"})
)

Investigate absolute distance to the future measured by the best model for natural populations. These distances do not account for seasonal variation in observed distance to the future that is measured by the naive model.

In [None]:
# Violin plot (no inner marks) of the best model's raw distance to the future by
# hemisphere, with individual timepoints overlaid as a black swarm.
ax = sns.violinplot(
    x="hemisphere",
    y="validation_error",
    data=best_natural_model_df,
    inner=None
)
ax = sns.swarmplot(
    x="hemisphere",
    y="validation_error",
    data=best_natural_model_df,
    ax=ax,
    color="black"
)

ax.set_xlabel("Hemisphere")
ax.set_ylabel("Distance to the future (AAs)\nfor HI + mutational load")
# Distances are anchored at zero on the y-axis.
ax.set_ylim(bottom=0)

Investigate adjusted distance to the future for the best model where seasonal variation measured by the naive model is accounted for.

In [None]:
# Same layout as the raw-distance plot, but for model_improvement
# (naive minus model distance), which adjusts for seasonal variation.
ax = sns.violinplot(
    x="hemisphere",
    y="model_improvement",
    data=best_natural_model_df,
    inner=None
)
ax = sns.swarmplot(
    x="hemisphere",
    y="model_improvement",
    data=best_natural_model_df,
    ax=ax,
    color="black"
)

ax.set_xlabel("Hemisphere")
ax.set_ylabel("Naive - model\ndistance to the future (AAs)")
# No lower y-limit here: model_improvement is <= 0 whenever the model does not beat naive.

Inspect the median adjusted distance to the future by hemisphere.

In [None]:
# Median model_improvement per hemisphere.
best_natural_model_df.groupby("hemisphere")["model_improvement"].median()

Count the number of timepoints in each hemisphere group.

In [None]:
# Number of non-missing model_improvement values (timepoints) per hemisphere.
best_natural_model_df.groupby("hemisphere")["model_improvement"].count()