# Figure 8. Compare vaccine strains to estimated and observed closest strains to the future

Observed distance to natural H3N2 populations one year into the future for each vaccine strain (green) and for the observed (blue) and estimated (orange) closest strains to the future at the corresponding timepoints. Vaccine strains were assigned to the validation or test timepoint closest to the date they were selected by the WHO. The distance to the future of each vaccine strain was calculated from its amino acid sequence and the frequencies and sequences of the corresponding population one year in the future. The estimated closest strain to the future was identified by either the best model in the validation period (mutational load and LBI) or the best model in the test period (HI antigenic novelty and mutational load).

In [None]:
# Input and output paths provided by Snakemake when this notebook runs inside
# the workflow (`snakemake` is injected into the notebook's namespace).
validation_tip_attributes_path, validation_forecasts_path = (
    snakemake.input.validation_tip_attributes,
    snakemake.input.validation_forecasts_path,
)

test_tip_attributes_path, test_forecasts_path = (
    snakemake.input.test_tip_attributes,
    snakemake.input.test_forecasts_path,
)

vaccines_json_path = snakemake.input.vaccines_json_path

output_figure = snakemake.output.figure

In [None]:
# Local paths for running this notebook interactively, outside Snakemake.
# NOTE: everything below is a bare string literal, so none of these
# assignments execute. To use them, move them out of the string and skip the
# cell that reads paths from `snakemake`.
"""
validation_tip_attributes_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding/tip_attributes_with_weighted_distances.tsv"
validation_forecasts_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding/forecasts.tsv"

test_tip_attributes_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding_test_tree/tip_attributes_with_weighted_distances.tsv"
test_forecasts_path = "../results/builds/natural/natural_sample_1_with_90_vpm_sliding_test_tree/forecasts.tsv"

vaccines_json_path = "../config/vaccines_h3n2.json"

output_figure = "../manuscript/figures/vaccine-comparison.pdf"
"""

In [None]:
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
import pandas as pd
from pandas.plotting import register_matplotlib_converters
from scipy.stats import pearsonr, spearmanr, probplot
import seaborn as sns
import statsmodels.api as sm

%matplotlib inline

In [None]:
# Register pandas' datetime converters with matplotlib so datetime columns can
# be passed directly as plot coordinates.
register_matplotlib_converters()

In [None]:
# Validation forecasts before this timepoint are discarded when loading.
first_validation_timepoint = "2003-10-01"

# NOTE(review): the constants below are not referenced anywhere else in this
# notebook; they may be shared boilerplate from related analysis notebooks.
min_clade_frequency = 0.15
precision = 4
pseudofrequency = 0.001
number_of_bootstrap_samples = 200

In [None]:
# Apply seaborn's "ticks" style to all subsequent figures.
sns.set_style("ticks")

In [None]:
# Global figure defaults:
# - a moderate default figure size;
# - no top/right spines;
# - 200 dpi for both on-screen display and saved files;
# - font sizes legible in presentations and manuscripts;
# - plain (non-LaTeX) text rendering.
mpl.rcParams.update({
    "figure.figsize": (6, 4),
    "axes.spines.top": False,
    "axes.spines.right": False,
    "savefig.dpi": 200,
    "figure.dpi": 200,
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 18,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "text.usetex": False,
})

## Load data

In [None]:
validation_forecasts = pd.read_csv(
    validation_forecasts_path,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"]
)

In [None]:
validation_forecasts = validation_forecasts.query("timepoint >= '%s'" % first_validation_timepoint).dropna().copy()

In [None]:
test_forecasts = pd.read_csv(
    test_forecasts_path,
    sep="\t",
    parse_dates=["timepoint", "future_timepoint"]
).dropna()

In [None]:
# Inspect the start of the test forecasts and the end of the validation forecasts.
test_forecasts.head()

In [None]:
validation_forecasts.tail()

In [None]:
# Stack validation and test forecasts into a single table. The original index
# labels of both inputs are kept (no `ignore_index`), so labels may repeat.
forecasts = pd.concat([validation_forecasts, test_forecasts])

In [None]:
forecasts.head()

In [None]:
# Sanity checks: the combined row count should equal the sum of the two
# inputs, and dropping missing values should not remove any rows (both inputs
# were already passed through `dropna()`).
forecasts.shape

In [None]:
validation_forecasts.shape

In [None]:
test_forecasts.shape

In [None]:
validation_forecasts.shape[0] + test_forecasts.shape[0]

In [None]:
forecasts.dropna().shape

## Estimated and observed closest strains per timepoint

In [None]:
sorted_df = forecasts.sort_values(
    ["timepoint"]
).copy()

In [None]:
sorted_df.head()

In [None]:
validation_tips_with_sequence = pd.read_csv(
    validation_tip_attributes_path,
    sep="\t",
    parse_dates=["timepoint"],
    usecols=["strain", "timepoint", "frequency", "aa_sequence"]
)

In [None]:
test_tips_with_sequence = pd.read_csv(
    test_tip_attributes_path,
    sep="\t",
    parse_dates=["timepoint"],
    usecols=["strain", "timepoint", "frequency", "aa_sequence"]
)

In [None]:
validation_tips_with_sequence.shape

In [None]:
validation_tips_with_sequence = validation_tips_with_sequence[
    validation_tips_with_sequence["timepoint"].isin(validation_forecasts["timepoint"])
].copy()

In [None]:
validation_tips_with_sequence.shape

In [None]:
distinct_validation_tips_with_sequence = validation_tips_with_sequence.groupby(["timepoint", "aa_sequence"]).first().reset_index()

In [None]:
distinct_test_tips_with_sequence = test_tips_with_sequence.groupby(["timepoint", "aa_sequence"]).first().reset_index()

In [None]:
distinct_validation_tips_with_sequence.head()

In [None]:
distinct_validation_tips_with_sequence.tail()

In [None]:
distinct_test_tips_with_sequence.head()

In [None]:
distinct_test_tips_with_sequence.tail()

In [None]:
tips_with_sequence = pd.concat([validation_tips_with_sequence, test_tips_with_sequence])

In [None]:
distinct_tips_with_sequence = pd.concat([distinct_validation_tips_with_sequence, distinct_test_tips_with_sequence])

In [None]:
sorted_df = sorted_df.merge(
    distinct_tips_with_sequence,
    on=["timepoint", "strain", "frequency"]
)

In [None]:
sorted_df.head()

In [None]:
sorted_df.tail()

In [None]:
sorted_df["timepoint_rank"] = sorted_df.groupby("timepoint")["weighted_distance_to_future"].rank(pct=True)
sorted_df["timepoint_estimated_rank"] = sorted_df.groupby("timepoint")["y"].rank(pct=True)

In [None]:
spearmanr(
    sorted_df["timepoint_rank"],
    sorted_df["timepoint_estimated_rank"]
)

In [None]:
observed_closest_strains = sorted_df.sort_values(
    ["timepoint", "weighted_distance_to_future"],
    ascending=True
).groupby("timepoint").first().reset_index().loc[:, ["timepoint", "strain", "weighted_distance_to_future", "aa_sequence"]]

In [None]:
estimated_closest_strains = sorted_df.sort_values(
    ["timepoint", "y"],
    ascending=True
).groupby("timepoint").first().reset_index().loc[:, ["timepoint", "strain", "weighted_distance_to_future", "y", "timepoint_rank", "aa_sequence"]]

In [None]:
estimated_closest_strains.head()

In [None]:
closest_strains = observed_closest_strains.merge(
    estimated_closest_strains,
    on=["timepoint"],
    suffixes=["_observed", "_estimated"]
)

In [None]:
closest_strains.head()

In [None]:
with open(vaccines_json_path, "r") as fh:
    vaccines_json = json.load(fh)

In [None]:
vaccine_df = pd.DataFrame([
    {
        "vaccine_strain": vaccine,
        "vaccine_timepoint": vaccine_data["vaccine"]["timepoint"],
        "aa_sequence": vaccine_data["aa_sequence"]
    }
    for vaccine, vaccine_data in vaccines_json["nodes"].items()
])
vaccine_df["vaccine_timepoint"] = pd.to_datetime(vaccine_df["vaccine_timepoint"])
vaccine_df["future_vaccine_timepoint"] = vaccine_df["vaccine_timepoint"] + pd.DateOffset(months=12)

In [None]:
vaccine_df

In [None]:
forecasts.head()

In [None]:
tips_for_vaccines_df = vaccine_df.merge(
    tips_with_sequence,
    left_on=["future_vaccine_timepoint"],
    right_on=["timepoint"],
    suffixes=["_vaccine", ""]
)

In [None]:
tips_for_vaccines_df.head()

In [None]:
def calculate_weighted_distance_between_pairs(row):
    """Return the frequency-weighted amino acid distance between a vaccine and a tip.

    The distance is the number of positions at which the vaccine sequence
    (``row["aa_sequence_vaccine"]``) and the tip sequence (``row["aa_sequence"]``)
    differ, multiplied by the tip frequency ``row["frequency"]``.

    Raises:
        ValueError: if the two sequences have different lengths, because a
            position-by-position comparison is undefined in that case.
    """
    # View each encoded sequence as an array of single-byte characters so
    # positions can be compared elementwise.
    vaccine_residues = np.frombuffer(row["aa_sequence_vaccine"].encode(), dtype="S1")
    tip_residues = np.frombuffer(row["aa_sequence"].encode(), dtype="S1")
    if vaccine_residues.shape != tip_residues.shape:
        raise ValueError(
            "Cannot compare amino acid sequences of different lengths "
            f"({vaccine_residues.size} vs {tip_residues.size})"
        )
    distance = (vaccine_residues != tip_residues).sum()
    return row["frequency"] * distance


def calculate_weighted_distance_by_group(group_df):
    """Return the sum of weighted vaccine-to-tip distances over the rows of ``group_df``.

    Each row must provide ``aa_sequence_vaccine``, ``aa_sequence`` and
    ``frequency``. An empty frame yields 0.0. Without this check,
    ``DataFrame.apply(..., axis=1).sum()`` returns a Series, not a number,
    for an empty frame.
    """
    if group_df.empty:
        return 0.0
    return group_df.apply(calculate_weighted_distance_between_pairs, axis=1).sum()

In [None]:
vaccine_distance_to_future = tips_for_vaccines_df.groupby("vaccine_timepoint").apply(
    calculate_weighted_distance_by_group
).reset_index(name="vaccine_distance_to_future")

In [None]:
vaccine_forecasts = vaccine_df.merge(
    vaccine_distance_to_future,
    on="vaccine_timepoint"
)

In [None]:
vaccine_forecasts

In [None]:
closest_strains_with_vaccine_timepoints = closest_strains[closest_strains["timepoint"].isin(vaccine_forecasts["vaccine_timepoint"])]

In [None]:
vaccine_forecasts = vaccine_forecasts[vaccine_forecasts["vaccine_timepoint"].isin(closest_strains_with_vaccine_timepoints["timepoint"])].copy()

In [None]:
closest_strains_with_vaccine_timepoints.shape

In [None]:
vaccine_forecasts.shape

In [None]:
for index, record in vaccine_forecasts.iterrows():
    print(record["vaccine_strain"])

In [None]:
max_analysis_distance_to_future = closest_strains_with_vaccine_timepoints.loc[
    :, ["weighted_distance_to_future_observed", "weighted_distance_to_future_estimated"]
].max().max()

max_vaccine_distance_to_future = vaccine_forecasts["vaccine_distance_to_future"].max()
max_distance_to_future = max(max_analysis_distance_to_future, max_vaccine_distance_to_future)

In [None]:
distance_ticks = np.arange(int(np.ceil(max_distance_to_future)) + 2, step=2)

In [None]:
distance_ticks

In [None]:
vaccine_forecasts["vaccine_distance_to_future"].max()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
strain_alpha = 0.5

# Draw the three distance series. Creation order determines their colours in
# the colour cycle. Assuming matplotlib's default cycle, that is observed:
# blue, estimated: orange, vaccine: green, matching the figure caption.
for x_values, y_values, series_label in (
    (
        closest_strains_with_vaccine_timepoints["timepoint"],
        closest_strains_with_vaccine_timepoints["weighted_distance_to_future_observed"],
        "Observed best",
    ),
    (
        closest_strains_with_vaccine_timepoints["timepoint"],
        closest_strains_with_vaccine_timepoints["weighted_distance_to_future_estimated"],
        "Estimated best",
    ),
    (
        vaccine_forecasts["vaccine_timepoint"],
        vaccine_forecasts["vaccine_distance_to_future"],
        "Vaccine strain",
    ),
):
    ax.plot(x_values, y_values, "o-", label=series_label, alpha=strain_alpha)

# Label each vaccine point with its strain name.
for vaccine_timepoint, vaccine_distance, vaccine_strain in zip(
    vaccine_forecasts["vaccine_timepoint"],
    vaccine_forecasts["vaccine_distance_to_future"],
    vaccine_forecasts["vaccine_strain"],
):
    ax.text(vaccine_timepoint, vaccine_distance, vaccine_strain, fontsize=8)

# Faint dashed vertical line at the last validation timepoint.
ax.axvline(
    validation_forecasts["timepoint"].max(),
    color="#000000",
    linestyle="--",
    alpha=0.2,
    zorder=-10,
    label="Last validation timepoint",
)

ax.legend(loc=(0.01, 0.7), frameon=False)

ax.set_xlabel("Date")
ax.set_ylabel("Weighted distance\nto the future (AAs)")
ax.set_yticks(distance_ticks)
ax.set_ylim(bottom=0)

plt.tight_layout()
plt.savefig(output_figure, bbox_inches="tight")

In [None]:
vaccine_forecasts

In [None]:
merged_model_and_vaccine_df = closest_strains_with_vaccine_timepoints.merge(
    vaccine_forecasts,
    left_on=["timepoint"],
    right_on=["vaccine_timepoint"]
)

In [None]:
merged_model_and_vaccine_df["vaccine_distance_to_future"] - merged_model_and_vaccine_df["weighted_distance_to_future_estimated"]