# Static visualizations of influenza serology measurements

The following notebook demonstrates different static visualizations of influenza serology measurements.
These visualizations were designed to aid decision-making during influenza vaccine strain selection.
The primary questions we want to answer with each visualization are:

 - Which reference serum best covers the other circulating clades?
 - Which clades are not covered by any sera?

## Setup

In [None]:
import altair as alt
from altair_saver import save
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

%matplotlib inline

In [None]:
# Use seaborn's "ticks" style for the matplotlib figures drawn in this notebook.
sns.set_style("ticks")

In [None]:
# Global matplotlib defaults: figure resolution plus font sizes for text,
# axis labels, legend entries, and tick labels.
mpl.rcParams.update({
    "figure.dpi": 150,
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 12,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
})

In [None]:
# Maps a test virus clade name to the hex color used for that clade in the
# seaborn plots below (passed as `palette`).
color_by_clade = {}

In [None]:
# Assign one hex color per test virus clade.
color_by_clade.update({
    "158N/189K": "#9ebe5a",
    "144K": "#5aa4a8",
    "173Q": "#e0a23a",
})

In [None]:
# Display the clade-to-color mapping as a sanity check.
color_by_clade

In [None]:
# Test virus clades to plot, in the order they appear on axes and in legends.
# Also used to filter `clade_test` values when preparing the titer data.
clade_order = [
    "158N/189K",
    "144K",
    "173Q",
]

In [None]:
# Reference strains whose measurements are kept (matched against the
# `reference_strain` column when filtering the titer data).
references = ["A/Wisconsin/67/2005", "A/Brisbane/10/2007", "A/Perth/16/2009"]

In [None]:
# Display order of the reference strains in the plots. Each label has the form
# "<strain>\n(<clade>)", matching the `reference_name` column built from the
# titer table.
reference_order = [
    "A/Perth/16/2009\n(144K)",
    "A/Brisbane/10/2007\n(140I)",
    "A/Wisconsin/67/2005\n(193F/225N)",
]

## Prepare titer data

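Select titer measurements for recent strains (those from the last year), identify the reference strains with the most measurements, and visualize the antigenic distances between clades using these measurements. In this notebook, the measurements are loaded from a precomputed table of antigenic distances between strains and then restricted to the reference strains and test virus clades listed in the Setup section.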

In [None]:
# Load the tab-separated table of antigenic distances between strains. Later
# cells read its `reference_strain`, `clade_reference`, `clade_test`,
# `test_strain`, and `log2_titer` columns.
df = pd.read_csv(
    filepath_or_buffer="../../results/h3n2/antigenic_distances_between_strains.tsv",
    delimiter="\t",
)

In [None]:
# Preview the first rows of the loaded table.
df.head()

In [None]:
# Label each measurement's reference as "<strain>\n(<clade>)"; these labels are
# what `reference_order` lists and what the plots use on the reference axis.
# A comprehension over the two columns replaces a row-wise `DataFrame.apply`,
# which is slow on large tables and does not yield an assignable column when
# `df` has no rows.
df["reference_name"] = [
    f"{strain}\n({clade})"
    for strain, clade in zip(df["reference_strain"], df["clade_reference"])
]

In [None]:
# Boolean row masks over `df`: one for measurements against the selected test
# clades, one for measurements from the selected reference strains.
is_selected_test_clade = df["clade_test"].isin(clade_order)
is_selected_reference = df["reference_strain"].isin(references)

In [None]:
# Keep rows that satisfy both masks. Copy the result so it is independent of `df`.
filtered_df = df.loc[is_selected_reference & is_selected_test_clade].copy()

In [None]:
# Display (rows, columns) of the filtered table.
filtered_df.shape

In [None]:
# Preview the first rows of the filtered table.
filtered_df.head()

## Plot mean titer distances by clade per reference strain

In [None]:
# For each (reference label, test clade) pair among the plotted references,
# compute the mean log2 titer and the number of non-null `test_strain` entries.
grouped_df = (
    filtered_df[filtered_df["reference_name"].isin(reference_order)]
    .groupby(["reference_name", "clade_test"])
    .agg({"log2_titer": "mean", "test_strain": "count"})
    .reset_index()
)

In [None]:
# Round the mean titers to two decimals so the heatmap annotations stay short.
grouped_df["log2_titer"] = np.round(grouped_df["log2_titer"], decimals=2)

In [None]:
# Display the per-reference, per-clade summary table.
grouped_df

In [None]:
# Reshape to a reference-by-clade matrix of mean log2 titers for the heatmap:
# one row per reference label, one column per test clade.
pivot_df = grouped_df.pivot_table(
    index="reference_name",
    columns="clade_test",
    values="log2_titer",
)

In [None]:
# Base font size used when configuring the Altair heatmap's text, axes, and legend.
font_size = 14

In [None]:
# Draw the reference-by-clade matrix as an annotated heatmap.
# The diverging "vlag" colormap is centered on 0 and spans -2 to 7.
fig, ax = plt.subplots(figsize=(4, 4), dpi=100)
sns.heatmap(
    pivot_df,
    ax=ax,
    annot=True,
    cmap="vlag",
    vmin=-2.0,
    center=0.0,
    vmax=7.0,
)

ax.set_xlabel("Test virus clade")
ax.set_ylabel("Reference strain")

In [None]:
# Shared encodings for both layers: test clades on the x-axis and reference
# strains on the y-axis, each in the explicit order defined during setup.
base = alt.Chart(grouped_df).encode(
    x=alt.X(
        "clade_test:N",
        sort=clade_order,
        title="Test virus clade",
        axis=alt.Axis(labelAngle=0),
    ),
    y=alt.Y(
        "reference_name:N",
        sort=reference_order,
        title="Reference strain",
    )
).properties(
    width=400,
    height=400,
)

# Colored cells: a diverging scale over [-2, 7] with its midpoint at 0.
heatmap = base.mark_rect().encode(
    color=alt.Color(
        "log2_titer:Q",
        scale=alt.Scale(
            scheme="blueorange",
            domain=[-2.0, 7.0],
            domainMid=0.0,
        ),
        legend=alt.Legend(
            direction="vertical",
            title="log2 titer",
        )
    )
)

# Print the mean log2 titer in black on top of each cell.
text = base.mark_text(baseline="middle").encode(
    text="log2_titer:Q",
    color=alt.value("black"),
)

# The `configure_*` methods return a new chart instead of modifying the chart
# in place, so the configured result must be assigned. Otherwise the font
# sizes never reach the saved figure or the chart displayed afterwards.
chart = (heatmap + text).configure_text(
    fontSize=font_size,
).configure_axis(
    titleFontSize=font_size,
    labelFontSize=font_size - 2,
).configure_legend(
    titleFontSize=font_size,
    labelFontSize=font_size,
)

save(chart, "../../manuscript/figures/figure-1b-titer-heatmap.pdf")

In [None]:
# Render the layered heatmap inline.
chart

## Plot data by reference strain

In [None]:
# Initialize the figure
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
sns.despine()

# Plot conditional means.
sns.pointplot(
    x="log2_titer",
    y="reference_name",
    hue="clade_test",
    order=reference_order,
    hue_order=clade_order,
    data=filtered_df,
    dodge=0.55,
    join=False,
    palette=color_by_clade,
    markers="d",
    scale=0.75,
    errorbar=("ci", 89),
)

# Draw a line at the traditional threshold used to denote antigenic drift.
ax.axvline(
    x=2.0,
    color="#000000",
    alpha=0.25,
    zorder=-10
)

# Draw a line at the y-axis as a guide.
ax.axvline(
    x=0.0,
    color="#000000",
    alpha=0.25,
    zorder=-10,
    linestyle="dashed",
)

ax.set_xlabel("$\log_{2}$ normalized titer")
ax.set_ylabel("Reference strain")

# Improve the legend
number_of_clades = filtered_df["clade_test"].drop_duplicates().shape[0]
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles[:number_of_clades],
    labels[:number_of_clades],
    title="clade (test strains)",
    handletextpad=0,
    columnspacing=1,
    loc="upper right",
    ncol=1,
    frameon=False
)

plt.tight_layout()
plt.savefig("../../manuscript/figures/figure-1c-titer-distributions.pdf", dpi=200)

In [None]:
# Initialize the figure
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
sns.despine()

# Show each observation with a scatterplot
sns.stripplot(
    x="log2_titer",
    y="reference_name",
    hue="clade_test",
    order=reference_order,
    hue_order=clade_order,
    data=filtered_df,
    palette=color_by_clade,
    dodge=True,
    alpha=0.5,
    jitter=0.2,
    zorder=1
)

# Plot conditional means.
sns.pointplot(
    x="log2_titer",
    y="reference_name",
    hue="clade_test",
    order=reference_order,
    hue_order=clade_order,
    data=filtered_df,
    dodge=0.55,
    join=False,
    palette=color_by_clade,
    markers="d",
    scale=0.75,
    errorbar=("ci", 89),
)

# Draw a line at the traditional threshold used to denote antigenic drift.
ax.axvline(
    x=2.0,
    color="#000000",
    alpha=0.25,
    zorder=-10
)

# Draw a line at the y-axis as a guide.
ax.axvline(
    x=0.0,
    color="#000000",
    alpha=0.25,
    zorder=-10,
    linestyle="dashed",
)

ax.set_xlabel("$\log_{2}$ normalized titer")
ax.set_ylabel("Reference strain")

# Improve the legend
number_of_clades = filtered_df["clade_test"].drop_duplicates().shape[0]
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles[:number_of_clades],
    labels[:number_of_clades],
    title="clade (test strains)",
    handletextpad=0,
    columnspacing=1,
    loc="upper right",
    ncol=1,
    frameon=False
)

plt.tight_layout()
plt.savefig("../../manuscript/figures/figure-1d-titer-distributions-and-raw-data.pdf", dpi=200)