# Figure 8. Compare vaccine strains to estimated and observed closest strains to the future

Observed distance to natural H3N2 populations one year into the future for each vaccine strain (green) and for the observed (blue) and estimated (orange and purple) closest strains to the future at the corresponding timepoints. Vaccine strains were assigned to the validation or test timepoint closest to the date they were selected by the WHO. The distance to the future of each vaccine strain was calculated from its amino acid sequence and the frequencies and sequences of the corresponding population one year in the future. The estimated closest strain to the future was identified by either the best model in the validation period (mutational load and LBI; orange) or the best model in the test period (HI antigenic novelty and mutational load; purple).

In [None]:
# Per-tip attribute tables for the validation and test builds. All paths in
# this cell are read from the `snakemake` object that runs the notebook.
validation_tip_attributes_path = snakemake.input.validation_tip_attributes
test_tip_attributes_path = snakemake.input.test_tip_attributes

# Validation-period forecast tables, one per model.
cTiter_x_ne_star_validation_forecasts_path = snakemake.input.cTiter_x_ne_star_validation_forecasts_path
ne_star_lbi_validation_forecasts_path = snakemake.input.ne_star_lbi_validation_forecasts_path

# Test-period forecast tables, one per model.
cTiter_x_ne_star_test_forecasts_path = snakemake.input.cTiter_x_ne_star_test_forecasts_path
ne_star_lbi_test_forecasts_path = snakemake.input.ne_star_lbi_test_forecasts_path

# JSON file describing the vaccine strains (parsed below with `json.load`).
vaccines_json_path = snakemake.input.vaccines_json_path

# Destination passed to `plt.savefig` for the final figure.
output_figure = snakemake.output.figure

In [None]:
# Hard-coded paths kept inside a string literal, so none of these assignments
# execute; presumably a template for running the notebook without Snakemake.
# NOTE(review): the forecast entries define `validation_forecasts_path` and
# `test_forecasts_path`, but the loading cells below read the per-model
# `cTiter_x_ne_star_*_forecasts_path` and `ne_star_lbi_*_forecasts_path`
# variables, so this template looks stale and needs updating before use.
"""
validation_tip_attributes_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding/tip_attributes_with_weighted_distances.tsv"
validation_forecasts_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding/forecasts.tsv"

test_tip_attributes_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding_test_tree/tip_attributes_with_weighted_distances.tsv"
test_forecasts_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding_test_tree/forecasts.tsv"

vaccines_json_path = "../config/vaccines_h3n2.json"

output_figure = "../manuscript/figures/vaccine-comparison.pdf"
"""

In [None]:
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from pandas.plotting import register_matplotlib_converters
from scipy.stats import pearsonr, spearmanr, probplot
import seaborn as sns
import statsmodels.api as sm

%matplotlib inline

In [None]:
# Register pandas' datetime converters with matplotlib so the datetime
# `timepoint` columns can be used directly as plot coordinates.
register_matplotlib_converters()

In [None]:
# Lower bound applied to validation tips and validation forecasts below;
# earlier timepoints are excluded from the analysis.
first_validation_timepoint = "2003-10-01"

In [None]:
# Use seaborn's "ticks" theme for all figures.
sns.set_style("ticks")

In [None]:
# Shared matplotlib defaults for every figure drawn in this notebook.
mpl.rcParams.update({
    # Reasonable default display size.
    "figure.figsize": (6, 4),

    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,

    # Higher resolution for presentations and manuscripts, both on screen
    # and in saved files.
    "savefig.dpi": 200,
    "figure.dpi": 200,

    # Text large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 18,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
})

# Render text with matplotlib's own engine instead of LaTeX.
mpl.rc("text", usetex=False)

## Define functions

In [None]:
def calculate_weighted_distance_between_pairs(row):
    """Return the frequency-weighted Hamming distance for one sequence pair.

    Parameters
    ----------
    row : mapping or pandas.Series
        Record providing an ``aa_sequence`` string, an ``aa_sequence_future``
        string, and a numeric ``frequency`` that weights the distance.

    Returns
    -------
    numeric
        ``frequency`` multiplied by the number of positions at which the two
        encoded sequences differ.

    Raises
    ------
    ValueError
        If the two sequences do not have the same encoded length.
    """
    # View each sequence as an array of single bytes so the comparison is
    # done position by position in NumPy.
    current = np.frombuffer(row["aa_sequence"].encode(), dtype="S1")
    future = np.frombuffer(row["aa_sequence_future"].encode(), dtype="S1")

    # A positional comparison only makes sense for equal-length sequences.
    # Without this check a length-1 sequence is silently broadcast against
    # the other one and yields a meaningless distance.
    if current.shape != future.shape:
        raise ValueError(
            "Cannot compare amino acid sequences of different lengths "
            "(%d and %d)." % (current.size, future.size)
        )

    distance = (current != future).sum()
    return row["frequency"] * distance

def calculate_weighted_distance_by_group(group_df):
    """Sum the frequency-weighted pairwise distances over all rows of a group.

    Each row of ``group_df`` is passed to
    ``calculate_weighted_distance_between_pairs`` and the per-row results
    are added together.
    """
    per_row_distances = group_df.apply(
        calculate_weighted_distance_between_pairs,
        axis=1,
    )
    return per_row_distances.sum()

## Load tip attributes with sequences

For each timepoint, find distinct sequences that will be used to calculate distances to the future.

In [None]:
validation_tips_with_sequence = pd.read_csv(
    validation_tip_attributes_path,
    sep="\t",
    parse_dates=["timepoint"],
    usecols=["strain", "timepoint", "frequency", "aa_sequence"]
)

In [None]:
test_tips_with_sequence = pd.read_csv(
    test_tip_attributes_path,
    sep="\t",
    parse_dates=["timepoint"],
    usecols=["strain", "timepoint", "frequency", "aa_sequence"]
)

In [None]:
# Retain only validation tips whose timepoints occur prior to the first
# test timepoint. This prevents us from using too many tips from overlapping
# timepoints between validation and test periods.
validation_tips_with_sequence = validation_tips_with_sequence[
    (validation_tips_with_sequence["timepoint"] < test_tips_with_sequence["timepoint"].min()) &
    (validation_tips_with_sequence["timepoint"] >= first_validation_timepoint)
].copy()

In [None]:
last_validation_timepoint = validation_tips_with_sequence["timepoint"].max()

In [None]:
last_validation_timepoint

In [None]:
distinct_validation_tips_with_sequence = validation_tips_with_sequence.groupby(
    ["timepoint", "aa_sequence"]
).first().reset_index()

In [None]:
distinct_test_tips_with_sequence = test_tips_with_sequence.groupby(
    ["timepoint", "aa_sequence"]
).first().reset_index()

In [None]:
tips_with_sequence = pd.concat([validation_tips_with_sequence, test_tips_with_sequence])

In [None]:
distinct_tips_with_sequence = pd.concat([distinct_validation_tips_with_sequence, distinct_test_tips_with_sequence])

## Load vaccine strain data

Load information about vaccine strains including their names, amino acid sequences, and the timepoint in our analysis when they were selected for the vaccine. These latter timepoints constrain the timepoints we consider in the analyses that follow.

In [None]:
with open(vaccines_json_path, "r") as fh:
    vaccines_json = json.load(fh)

In [None]:
vaccine_df = pd.DataFrame([
    {
        "strain_type": "vaccine",
        "strain": vaccine,
        "timepoint": vaccine_data["vaccine"]["timepoint"],
        "aa_sequence": vaccine_data["aa_sequence"]
    }
    for vaccine, vaccine_data in vaccines_json["nodes"].items()
])
vaccine_df["timepoint"] = pd.to_datetime(vaccine_df["timepoint"])
vaccine_df["future_timepoint"] = vaccine_df["timepoint"] + pd.DateOffset(months=12)

In [None]:
# Find all tips with sequences at the future timepoint for each vaccine strain.
tips_for_vaccines_df = vaccine_df.merge(
    tips_with_sequence,
    left_on=["future_timepoint"],
    right_on=["timepoint"],
    suffixes=["", "_future"]
)

In [None]:
vaccine_distance_to_future = tips_for_vaccines_df.groupby("timepoint").apply(
    calculate_weighted_distance_by_group
).reset_index(name="distance_to_future")

In [None]:
vaccine_forecasts = vaccine_df.merge(
    vaccine_distance_to_future,
    on="timepoint"
).loc[:, ["strain_type", "strain", "timepoint", "future_timepoint", "distance_to_future"]]

In [None]:
vaccine_forecasts

## Load forecasts from models

### Load forecasts from mutational load and LBI model

In [None]:
ne_star_lbi_validation_forecasts = pd.read_csv(
    ne_star_lbi_validation_forecasts_path,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"]
)

In [None]:
ne_star_lbi_validation_forecasts = ne_star_lbi_validation_forecasts.query(
    "timepoint >= '%s'" % first_validation_timepoint
).dropna().copy()

In [None]:
ne_star_lbi_test_forecasts = pd.read_csv(
    ne_star_lbi_test_forecasts_path,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"]
).dropna()

In [None]:
ne_star_lbi_forecasts = pd.concat([
    ne_star_lbi_validation_forecasts,
    ne_star_lbi_test_forecasts
])
ne_star_lbi_forecasts = ne_star_lbi_forecasts[
    ne_star_lbi_forecasts["timepoint"].isin(vaccine_forecasts["timepoint"])
].copy()

### Load forecasts from HI antigenic novelty and mutational load model

In [None]:
cTiter_x_ne_star_validation_forecasts = pd.read_csv(
    cTiter_x_ne_star_validation_forecasts_path,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"]
)

In [None]:
cTiter_x_ne_star_validation_forecasts = cTiter_x_ne_star_validation_forecasts.query(
    "timepoint >= '%s'" % first_validation_timepoint
).dropna().copy()

In [None]:
cTiter_x_ne_star_test_forecasts = pd.read_csv(
    cTiter_x_ne_star_test_forecasts_path,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"]
).dropna()

In [None]:
cTiter_x_ne_star_forecasts = pd.concat([
    cTiter_x_ne_star_validation_forecasts,
    cTiter_x_ne_star_test_forecasts
])
cTiter_x_ne_star_forecasts = cTiter_x_ne_star_forecasts[
    cTiter_x_ne_star_forecasts["timepoint"].isin(vaccine_forecasts["timepoint"])
].copy()

## Identify observed closest strains to the future

Use strains from forecast models with annotated weighted distances to the future to find observed closest strains to the future per timepoint with a vaccine.

In [None]:
combined_forecasts = pd.concat([
    ne_star_lbi_forecasts,
    cTiter_x_ne_star_forecasts
])

In [None]:
observed_closest_strains = combined_forecasts.sort_values(["timepoint", "weighted_distance_to_future"]).groupby(
    "timepoint"
).first().reset_index().loc[:, ["timepoint", "future_timepoint", "strain", "weighted_distance_to_future"]]
observed_closest_strains["strain_type"] = "observed closest"
observed_closest_strains = observed_closest_strains.rename(
    columns={"weighted_distance_to_future": "distance_to_future"}
)

In [None]:
observed_closest_strains

## Identify estimated closest strains to the future by model

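For each timepoint with a vaccine, rank each model's strains by the model's forecast column (`y`) and take the top-ranked strain as the estimated closest strain to the future. The value reported for that strain is its annotated (observed) weighted distance to the future.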

### Identify estimated closest strains to the future by mutational load and LBI

In [None]:
estimated_closest_strains_by_ne_star_lbi = ne_star_lbi_forecasts.sort_values(
    ["timepoint", "y"]
).groupby(
    "timepoint"
).first().reset_index().loc[:, ["timepoint", "future_timepoint", "strain", "weighted_distance_to_future"]]
estimated_closest_strains_by_ne_star_lbi["strain_type"] = "estimated closest by ne_star-lbi"
estimated_closest_strains_by_ne_star_lbi = estimated_closest_strains_by_ne_star_lbi.rename(
    columns={"weighted_distance_to_future": "distance_to_future"}
)

In [None]:
estimated_closest_strains_by_ne_star_lbi

### Identify estimated closest strains to the future by HI antigenic novelty and mutational load

In [None]:
estimated_closest_strains_by_cTiter_x_ne_star = cTiter_x_ne_star_forecasts.sort_values(
    ["timepoint", "y"]
).groupby(
    "timepoint"
).first().reset_index().loc[:, ["timepoint", "future_timepoint", "strain", "weighted_distance_to_future"]]
estimated_closest_strains_by_cTiter_x_ne_star["strain_type"] = "estimated closest by cTiter_x-ne_star"
estimated_closest_strains_by_cTiter_x_ne_star = estimated_closest_strains_by_cTiter_x_ne_star.rename(
    columns={"weighted_distance_to_future": "distance_to_future"}
)

In [None]:
estimated_closest_strains_by_cTiter_x_ne_star

## Plot distance to the future by timepoint and strain type

Compare distances to the future for selected vaccine strains and the observed and estimated closest strains to the future.

In [None]:
# Line color for each plotted series, looked up by key in the plotting cell.
# NOTE: these keys ("vaccine strain", "observed best", "estimated best by ...")
# are not the values stored in the `strain_type` columns ("vaccine",
# "observed closest", "estimated closest by ..."), so the dictionary cannot
# be indexed with a `strain_type` value directly.
colors_by_strain_type = {
    "vaccine strain": "#2ca02c",
    "observed best": "#1f77b4",
    "estimated best by ne_star-lbi": "#ff7f0e",
    "estimated best by cTiter_x-ne_star": "#9467bd",
}

Split data frames into validation and test periods.

In [None]:
observed_closest_strains_for_validation = observed_closest_strains[
    observed_closest_strains["timepoint"] <= last_validation_timepoint
]
estimated_closest_strains_by_cTiter_x_ne_star_for_validation = estimated_closest_strains_by_cTiter_x_ne_star[
    estimated_closest_strains_by_cTiter_x_ne_star["timepoint"] <= last_validation_timepoint
]
estimated_closest_strains_by_ne_star_lbi_for_validation = estimated_closest_strains_by_ne_star_lbi[
    estimated_closest_strains_by_ne_star_lbi["timepoint"] <= last_validation_timepoint
]
vaccine_forecasts_for_validation = vaccine_forecasts[
    vaccine_forecasts["timepoint"] <= last_validation_timepoint
]

In [None]:
observed_closest_strains_for_test = observed_closest_strains[
    observed_closest_strains["timepoint"] > last_validation_timepoint
]
estimated_closest_strains_by_cTiter_x_ne_star_for_test = estimated_closest_strains_by_cTiter_x_ne_star[
    estimated_closest_strains_by_cTiter_x_ne_star["timepoint"] > last_validation_timepoint
]
estimated_closest_strains_by_ne_star_lbi_for_test = estimated_closest_strains_by_ne_star_lbi[
    estimated_closest_strains_by_ne_star_lbi["timepoint"] > last_validation_timepoint
]
vaccine_forecasts_for_test = vaccine_forecasts[
    vaccine_forecasts["timepoint"] > last_validation_timepoint
]

In [None]:
max_distance_to_future = max(
    observed_closest_strains["distance_to_future"].max(),
    estimated_closest_strains_by_ne_star_lbi["distance_to_future"].max(),
    estimated_closest_strains_by_cTiter_x_ne_star["distance_to_future"].max(),
    vaccine_forecasts["distance_to_future"].max()
)

In [None]:
max_distance_to_future

In [None]:
distance_ticks = np.arange(int(np.ceil(max_distance_to_future)) + 2, step=2)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 3.5))

strain_alpha = 0.5

# One entry per plotted series, in drawing (and legend) order:
# (color key, legend label, validation records, test records).
series_to_plot = [
    (
        "observed best",
        "observed best",
        observed_closest_strains_for_validation,
        observed_closest_strains_for_test,
    ),
    (
        "estimated best by cTiter_x-ne_star",
        "estimated best by HI + mutational load",
        estimated_closest_strains_by_cTiter_x_ne_star_for_validation,
        estimated_closest_strains_by_cTiter_x_ne_star_for_test,
    ),
    (
        "estimated best by ne_star-lbi",
        "estimated best by mutational load + LBI",
        estimated_closest_strains_by_ne_star_lbi_for_validation,
        estimated_closest_strains_by_ne_star_lbi_for_test,
    ),
    (
        "vaccine strain",
        "vaccine strain",
        vaccine_forecasts_for_validation,
        vaccine_forecasts_for_test,
    ),
]

# Plot validation results. Only these lines carry a label, so each series
# appears exactly once in the legend.
for color_key, legend_label, validation_records, _ in series_to_plot:
    ax.plot(
        validation_records["timepoint"],
        validation_records["distance_to_future"],
        "o-",
        color=colors_by_strain_type[color_key],
        alpha=strain_alpha,
        label=legend_label
    )

# Plot test results as separate, unlabeled lines in the same colors.
for color_key, _, _, test_records in series_to_plot:
    ax.plot(
        test_records["timepoint"],
        test_records["distance_to_future"],
        "o-",
        color=colors_by_strain_type[color_key],
        alpha=strain_alpha
    )

# Write each vaccine strain's name next to its point.
for _, record in vaccine_forecasts.iterrows():
    ax.text(
        record["timepoint"],
        record["distance_to_future"],
        record["strain"],
        fontsize=8
    )

# Faint vertical line behind the data marking the end of the validation period.
ax.axvline(
    last_validation_timepoint,
    zorder=-10,
    alpha=0.2,
    color="#000000",
    linestyle="-",
    label="Last validation timepoint"
)

ax.legend(
    loc=(0.01, 0.65),
    frameon=False,
    fontsize=10
)

ax.set_yticks(distance_ticks)
ax.set_ylim(bottom=0)

ax.set_xlabel("Date")
ax.set_ylabel("Weighted distance\nto the future (AAs)")

plt.tight_layout()
plt.savefig(output_figure, bbox_inches="tight")

In [None]:
merged_model_and_vaccine_ne_star_lbi = vaccine_forecasts.merge(
    estimated_closest_strains_by_ne_star_lbi,
    on=["timepoint"],
    suffixes=["_vaccine", "_model"]
)

In [None]:
merged_model_and_vaccine_ne_star_lbi["vaccine_minus_model"] = (
    merged_model_and_vaccine_ne_star_lbi["distance_to_future_vaccine"] - 
    merged_model_and_vaccine_ne_star_lbi["distance_to_future_model"]
)

In [None]:
merged_model_and_vaccine_ne_star_lbi["vaccine_minus_model"]

In [None]:
(merged_model_and_vaccine_ne_star_lbi["vaccine_minus_model"] < 0).sum()

In [None]:
merged_model_and_vaccine_ne_star_lbi[
    merged_model_and_vaccine_ne_star_lbi["vaccine_minus_model"] < 0
]["vaccine_minus_model"].mean()

In [None]:
merged_model_and_vaccine_ne_star_lbi[
    merged_model_and_vaccine_ne_star_lbi["vaccine_minus_model"] >= 0
]["vaccine_minus_model"].mean()

In [None]:
merged_model_and_vaccine_cTiter_x_ne_star = vaccine_forecasts.merge(
    estimated_closest_strains_by_cTiter_x_ne_star,
    on=["timepoint"],
    suffixes=["_vaccine", "_model"]
)

In [None]:
merged_model_and_vaccine_cTiter_x_ne_star["vaccine_minus_model"] = (
    merged_model_and_vaccine_cTiter_x_ne_star["distance_to_future_vaccine"] -
    merged_model_and_vaccine_cTiter_x_ne_star["distance_to_future_model"]
)

In [None]:
merged_model_and_vaccine_cTiter_x_ne_star["vaccine_minus_model"]

In [None]:
(merged_model_and_vaccine_cTiter_x_ne_star["vaccine_minus_model"] < 0).sum()

In [None]:
merged_model_and_vaccine_cTiter_x_ne_star[
    merged_model_and_vaccine_cTiter_x_ne_star["vaccine_minus_model"] < 0
]["vaccine_minus_model"]

In [None]:
merged_model_and_vaccine_cTiter_x_ne_star[
    merged_model_and_vaccine_cTiter_x_ne_star["vaccine_minus_model"] >= 0
]["vaccine_minus_model"].mean()