# Static visualizations of influenza serology measurements

The following notebook demonstrates different static visualizations of influenza serology measurements.
These visualizations were designed to aid decision-making during influenza vaccine strain selection.
The primary questions we want to answer with each visualization are:

 - Which reference serum best covers the other circulating clades?
 - Which clades are not covered by any sera?

## Setup

In [None]:
import matplotlib as mpl
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
# Apply seaborn's "ticks" axes style to the plots created below.
sns.set_style("ticks")

In [None]:
# Global matplotlib defaults: figure resolution and font sizes shared by every panel.
mpl.rcParams.update({
    'figure.dpi': 150,
    'font.size': 14,
    'axes.labelsize': 14,
    'legend.fontsize': 12,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})

In [None]:
# Mapping of clade name -> hex color; populated in the next cell.
color_by_clade = dict()

In [None]:
# Assign one fixed color per test clade (insertion order matches the original assignments).
color_by_clade.update({
    "158N/189K": "#D2B340",
    "144K": "#5AA5A8",
    "173Q": "#E68033",
})

In [None]:
color_by_clade

In [None]:
# Test clades to keep, in the order used for the heatmap columns and the plot hue levels.
clade_order = ["158N/189K", "144K", "173Q"]

In [None]:
# Reference strains whose measurements are kept; matched against the
# `reference_strain` column when filtering the titer table.
references = [
    "A/Wisconsin/67/2005",
    "A/Brisbane/10/2007",
    "A/Perth/16/2009",
]

In [None]:
# Display order of the reference strains (heatmap rows and categorical axis order).
# Each entry is matched exactly against the `reference_name` column, which is built
# below as "<reference_strain>\n(<clade_reference>)", so the clade written in
# parentheses here must agree with the clade recorded in the data.
reference_order = [
    'A/Perth/16/2009\n(144K)',
    'A/Brisbane/10/2007\n(140I)',
    'A/Wisconsin/67/2005\n(193F/225N)',
]

## Prepare titer data

Select titer measurements for recent strains (those from the last year), identify the reference strains with the most measurements, and visualize the antigenic distances between clades using these measurements.

In [None]:
# Load the tab-delimited table of antigenic distances between strains.
df = pd.read_csv("../../results/h3n2/antigenic_distances_between_strains.tsv", sep="\t")

In [None]:
df.head()

In [None]:
# Build a two-line display label per row: the reference strain name followed by its
# clade in parentheses on the next line. Vectorized string concatenation replaces the
# row-wise `apply`, which is slow on large tables and does not return a Series for an
# empty frame. `astype(str)` mirrors the f-string's stringification of each value.
df["reference_name"] = (
    df["reference_strain"].astype(str)
    + "\n("
    + df["clade_reference"].astype(str)
    + ")"
)

In [None]:
# Boolean row masks over `df`: rows measured against one of the selected reference
# strains, and rows whose test strain belongs to one of the selected clades.
is_selected_reference = df["reference_strain"].isin(references)
is_selected_test_clade = df["clade_test"].isin(clade_order)

In [None]:
# Keep rows satisfying both masks; copy so later edits do not touch `df`.
filtered_df = df.loc[is_selected_reference & is_selected_test_clade].copy()

In [None]:
filtered_df.shape

In [None]:
filtered_df.head()

## Plot mean titer distances by clade per reference strain

In [None]:
# Per (reference label, test clade): mean log2 titer and the number of test-strain
# measurements. Only reference labels listed in `reference_order` are retained.
grouped_df = (
    filtered_df[filtered_df["reference_name"].isin(reference_order)]
    .groupby(["reference_name", "clade_test"])
    .agg(
        log2_titer=("log2_titer", "mean"),
        test_strain=("test_strain", "count"),
    )
    .reset_index()
)

In [None]:
# Round the mean titers to two decimals for display.
grouped_df["log2_titer"] = grouped_df["log2_titer"].round(2)

In [None]:
grouped_df

In [None]:
# Reshape to a reference-by-clade matrix of mean titers, with rows and columns
# arranged in the configured display order.
pivot_df = (
    grouped_df
    .pivot_table(index="reference_name", columns="clade_test", values="log2_titer")
    .reindex(index=reference_order, columns=clade_order)
)

In [None]:
pivot_df

In [None]:
# NOTE(review): `font_size` is not referenced by any other cell visible in this
# notebook (font sizes are set through `mpl.rcParams` above) — confirm whether it
# is still needed.
font_size = 14

## Set up main figure

In [None]:
# Create the 2x2 grid of panels with independent axes; flatten so panels can be
# addressed as axes[0]..axes[3].
fig, all_axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(16, 10),
    dpi=200,
    sharex=False,
    sharey=False,
)
axes = all_axes.flatten()

# Panel A: pre-rendered image of titer distances between a single reference strain
# and tips on a phylogeny, shown as-is without axis decorations.
ax_a = axes[0]
figure_1a_img = mpimg.imread('../../manuscript/figures/figure-1a-titer-distance-in-phylogeny.png')
ax_a.imshow(figure_1a_img, aspect="equal", interpolation="nearest")
ax_a.set_axis_off()

# Panel B: annotated heatmap of the mean titer distance for each
# (reference strain, test clade) pair, on a diverging color scale centered at zero.
ax_b = sns.heatmap(
    pivot_df,
    annot=True,
    cmap="coolwarm",
    center=0.0,
    vmin=-3.0,
    vmax=7.0,
    ax=axes[1],
)
ax_b.set(xlabel="Test virus clade", ylabel="Reference strain")

# Panel C: Point plot of the mean titer distance, with an 89% confidence interval,
# between each reference strain and the test strains in each clade.
ax_c = axes[2]

# Plot conditional means. `errorbar=("ci", 89)` draws a confidence interval around
# each mean (not a standard deviation).
# NOTE(review): `join` and `scale` are reported as deprecated in seaborn 0.13 —
# verify against the pinned seaborn version before migrating them.
ax_c = sns.pointplot(
    x="log2_titer",
    y="reference_name",
    hue="clade_test",
    order=reference_order,
    hue_order=clade_order,
    data=filtered_df,
    dodge=0.55,
    join=False,
    palette=color_by_clade,
    markers="d",
    scale=0.75,
    errorbar=("ci", 89),
    ax=ax_c
)

# Draw a line at the traditional threshold used to denote antigenic drift.
ax_c.axvline(
    x=2.0,
    color="#000000",
    alpha=0.25,
    zorder=-10
)

# Draw a dashed guide line at zero titer distance.
ax_c.axvline(
    x=0.0,
    color="#000000",
    alpha=0.25,
    zorder=-10,
    linestyle="dashed",
)

# Raw string: "\l" is not a valid Python escape, so the backslash must be kept
# literally for the mathtext renderer.
ax_c.set_xlabel(r"$\log_{2}$ normalized titer")
ax_c.set_ylabel("Reference strain")

# Span the full range of the raw measurements, rounded outward to whole log2 units.
ax_c.set_xlim(np.floor(filtered_df["log2_titer"].min()), np.ceil(filtered_df["log2_titer"].max()))

# Improve the legend: keep one entry per clade present in the data.
number_of_clades = filtered_df["clade_test"].drop_duplicates().shape[0]
handles, labels = ax_c.get_legend_handles_labels()
ax_c.legend(
    handles[:number_of_clades],
    labels[:number_of_clades],
    title="clade (test strains)",
    handletextpad=0,
    columnspacing=1,
    loc="upper right",
    ncol=1,
    frameon=False
)

# Panel D: Raw titer distances (jittered points) overlaid with the mean and its 89%
# confidence interval, for each reference strain and test clade.
ax_d = axes[3]

# Show each observation with a scatterplot
ax_d = sns.stripplot(
    x="log2_titer",
    y="reference_name",
    hue="clade_test",
    order=reference_order,
    hue_order=clade_order,
    data=filtered_df,
    palette=color_by_clade,
    dodge=True,
    alpha=0.5,
    jitter=0.2,
    zorder=1,
    ax=ax_d,
)

# Plot conditional means on top of the raw points. `errorbar=("ci", 89)` draws a
# confidence interval around each mean (not a standard deviation).
# NOTE(review): `join` and `scale` are reported as deprecated in seaborn 0.13 —
# verify against the pinned seaborn version before migrating them.
ax_d = sns.pointplot(
    x="log2_titer",
    y="reference_name",
    hue="clade_test",
    order=reference_order,
    hue_order=clade_order,
    data=filtered_df,
    dodge=0.55,
    join=False,
    palette=color_by_clade,
    markers="d",
    scale=0.75,
    errorbar=("ci", 89),
    ax=ax_d,
)

# Draw a line at the traditional threshold used to denote antigenic drift.
ax_d.axvline(
    x=2.0,
    color="#000000",
    alpha=0.25,
    zorder=-10
)

# Draw a dashed guide line at zero titer distance.
ax_d.axvline(
    x=0.0,
    color="#000000",
    alpha=0.25,
    zorder=-10,
    linestyle="dashed",
)

# Raw string: "\l" is not a valid Python escape, so the backslash must be kept
# literally for the mathtext renderer.
ax_d.set_xlabel(r"$\log_{2}$ normalized titer")
ax_d.set_ylabel("Reference strain")

# Improve the legend: both plot calls above register legend handles, so keep only
# the first entry per clade present in the data.
number_of_clades = filtered_df["clade_test"].drop_duplicates().shape[0]
handles, labels = ax_d.get_legend_handles_labels()
ax_d.legend(
    handles[:number_of_clades],
    labels[:number_of_clades],
    title="clade (test strains)",
    handletextpad=0,
    columnspacing=1,
    loc="upper right",
    ncol=1,
    frameon=False
)

# Stamp a bold letter near the top-left corner of each panel (figure coordinates).
panel_labels_dict = dict(weight="bold", size=18)
for label_x, label_y, panel_letter in [
    (0.1, 0.95, "A"),
    (0.5, 0.95, "B"),
    (0.1, 0.47, "C"),
    (0.5, 0.47, "D"),
]:
    plt.figtext(label_x, label_y, panel_letter, **panel_labels_dict)

# Remove the top/right spines, tighten the layout, and write the figure to disk.
sns.despine()
plt.tight_layout()
plt.savefig("../../manuscript/figures/figure-1-static-titer-visualizations.pdf", dpi=200)