In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import linregress
import seaborn as sns

In [None]:
# Global plotting style for every figure produced by this notebook.
sns.set_style("ticks")

mpl.rcParams.update({
    # Disable top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Display and save figures at higher resolution for presentations and manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 120,
    # Display text at sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,
    # Render text with matplotlib's own engine rather than an external LaTeX.
    "text.usetex": False,
})

In [None]:
# Read the tab-delimited distances table provided by Snakemake and discard
# every row that has at least one missing value.
df = (
    pd.read_csv(snakemake.input.distances, sep="\t")
    .dropna()
)

In [None]:
# Preview the first rows of the loaded table.
df.head()

In [None]:
# (rows, columns) of the table before the frequency filter is applied.
df.shape

In [None]:
# Keep only rows with a strictly positive frequency. The explicit copy makes
# the result an independent frame rather than a view of the unfiltered one.
has_positive_frequency = df["frequency"] > 0
df = df[has_positive_frequency].copy()

In [None]:
# (rows, columns) of the table after keeping only rows with frequency > 0.
df.shape

In [None]:
# Sorted array of the distinct subclade labels. It is used below as the hue
# order of the subclade figures and to choose a palette of matching size.
subclades = (
    df["subclade"]
    .drop_duplicates()
    .sort_values()
    .values
)

In [None]:
# Number of distinct subclades; selects which line of the color schemes file to use.
n_subclades = len(subclades)

In [None]:
# Display the number of distinct subclades.
n_subclades

In [None]:
# Select the palette for the subclades: the line of the tab-delimited color
# schemes file whose 1-based line number equals the number of subclades.
# NOTE(review): this presumes line N of the file lists N colors -- confirm
# against the file that Snakemake provides.
#
# Reset first so that a stale `colors` from a previous run of this cell can
# never be picked up silently when no matching line exists.
colors = None
with open(snakemake.input.color_schemes, "r", encoding="utf-8") as fh:
    for line_number, line in enumerate(fh, start=1):
        if line_number == n_subclades:
            colors = line.strip().split("\t")
            break

# Fail here with a clear message instead of a NameError in a later cell.
if colors is None:
    raise ValueError(
        f"The color schemes file has no line {n_subclades}, "
        f"so no palette is available for {n_subclades} subclades."
    )

# zip() in the next cell would silently drop subclades without a color.
if len(colors) < n_subclades:
    raise ValueError(
        f"Line {n_subclades} of the color schemes file lists only {len(colors)} "
        f"colors for {n_subclades} subclades."
    )

In [None]:
# Pair each subclade (in sorted order) with the color at the same position.
color_by_clade = {
    clade: color
    for clade, color in zip(subclades, colors)
}

In [None]:
# Sorted array of the distinct historical clade labels present in the table.
historical_clades = (
    df["historicalclade"]
    .drop_duplicates()
    .sort_values()
    .values
)

In [None]:
# Number of distinct historical clades; selects which line of the color schemes file to use.
n_historical_clades = len(historical_clades)

In [None]:
# Select the palette for the historical clades: the line of the tab-delimited
# color schemes file whose 1-based line number equals the number of clades.
# NOTE(review): this presumes line N of the file lists N colors -- confirm
# against the file that Snakemake provides.
#
# Reset first so that a stale `historical_colors` from a previous run of this
# cell can never be picked up silently when no matching line exists.
historical_colors = None
with open(snakemake.input.color_schemes, "r", encoding="utf-8") as fh:
    for line_number, line in enumerate(fh, start=1):
        if line_number == n_historical_clades:
            historical_colors = line.strip().split("\t")
            break

# Fail here with a clear message instead of a NameError in a later cell.
if historical_colors is None:
    raise ValueError(
        f"The color schemes file has no line {n_historical_clades}, "
        f"so no palette is available for {n_historical_clades} historical clades."
    )

# zip() in the next cell would silently drop clades without a color.
if len(historical_colors) < n_historical_clades:
    raise ValueError(
        f"Line {n_historical_clades} of the color schemes file lists only "
        f"{len(historical_colors)} colors for {n_historical_clades} historical clades."
    )

In [None]:
# Pair each historical clade (in sorted order) with the color at the same position.
color_by_historical_clade = {
    clade: color
    for clade, color in zip(historical_clades, historical_colors)
}

In [None]:
# Scatter of distance to the observed future against the escape score, one
# panel per season, colored by subclade. seaborn's own regression is disabled
# (fit_reg=False); a single per-season fit is computed and drawn below instead.
escape_column = "welsh_escape_per_ha1"
distance_column = "weighted_distance_to_observed_future"

g = sns.lmplot(
    df,
    x=escape_column,
    y=distance_column,
    hue="subclade",
    hue_order=subclades,
    palette=color_by_clade,
    col="season",
    col_wrap=3,
    fit_reg=False,
    height=6,
    scatter_kws={"alpha": 0.5},
)
g.set_axis_labels(
    "Welsh et al. escape score per HA1 substitutions",
    "Weighted distance to observed future (AAs)",
)

# Shared x grid for every regression line, from 0 up to the largest score
# observed across all seasons.
x_values = np.arange(0, df[escape_column].max(), 0.001)

for season, season_ax in g.axes_dict.items():
    in_season = df["season"] == season
    fit = linregress(
        df.loc[in_season, escape_column].values,
        df.loc[in_season, distance_column].values,
    )
    slope, intercept, r = fit.slope, fit.intercept, fit.rvalue

    # Print the intercept as "+ value" or "- value" rather than "+ -value".
    abs_intercept = np.abs(intercept)
    intercept_sign = "-" if intercept < 0 else "+"
    regression = f"y = {slope:.2f}x {intercept_sign} {abs_intercept:.2f}\nPearson's $R^2$={r**2:.2f}"

    # Annotation position is given as a fraction of the panel (axes coordinates).
    season_ax.text(
        0.25,
        0.15,
        regression,
        ha="center",
        va="center",
        transform=season_ax.transAxes,
    )

    # Draw the fitted line behind the points.
    season_ax.plot(
        x_values,
        slope * x_values + intercept,
        "-",
        color="#999999",
        zorder=-10,
    )
    season_ax.set_ylim(bottom=0.0)

g.tight_layout()
plt.savefig(snakemake.output.distances_by_subclade_and_escape_score, dpi=300)

In [None]:
# Same layout as the escape-score figure, but with the upper 80th quantile
# escape score on the x axis: one panel per season, colored by subclade, with
# a per-season linear fit computed and drawn below.
quantile_column = "welsh_escape_upper_80th_quantile_per_ha1"
distance_column = "weighted_distance_to_observed_future"

g = sns.lmplot(
    df,
    x=quantile_column,
    y=distance_column,
    hue="subclade",
    hue_order=subclades,
    palette=color_by_clade,
    col="season",
    col_wrap=3,
    fit_reg=False,
    height=6,
    scatter_kws={"alpha": 0.5},
)
g.set_axis_labels(
    "Upper 80th quantile Welsh et al. escape score\nper HA1 substitutions",
    "Weighted distance to observed future (AAs)",
)

# Shared x grid for every regression line, from 0 up to the largest score
# observed across all seasons.
x_values = np.arange(0, df[quantile_column].max(), 0.001)

for season, season_ax in g.axes_dict.items():
    in_season = df["season"] == season
    fit = linregress(
        df.loc[in_season, quantile_column].values,
        df.loc[in_season, distance_column].values,
    )
    slope, intercept, r = fit.slope, fit.intercept, fit.rvalue

    # Print the intercept as "+ value" or "- value" rather than "+ -value".
    abs_intercept = np.abs(intercept)
    intercept_sign = "-" if intercept < 0 else "+"
    regression = f"y = {slope:.2f}x {intercept_sign} {abs_intercept:.2f}\nPearson's $R^2$={r**2:.2f}"

    # Annotation position is given as a fraction of the panel (axes coordinates).
    season_ax.text(
        0.25,
        0.15,
        regression,
        ha="center",
        va="center",
        transform=season_ax.transAxes,
    )

    # Draw the fitted line behind the points.
    season_ax.plot(
        x_values,
        slope * x_values + intercept,
        "-",
        color="#999999",
        zorder=-10,
    )
    season_ax.set_ylim(bottom=0.0)

g.tight_layout()
plt.savefig(snakemake.output.distances_by_subclade_and_upper_80th_quantile_escape_score, dpi=300)

In [None]:
# Sorted array of the distinct season labels; the grid figures below draw one
# panel per entry.
seasons = (
    df["season"]
    .drop_duplicates()
    .sort_values()
    .values
)

In [None]:
# Hand-picked (x, y) position of the regression annotation for each season's
# panel in the historical-clade escape-score figure. The plotting cell passes
# these to `ax.text` with `transform=ax.transAxes`, so they are fractions of
# the panel width and height, not data values.
regression_placement_by_season = {
    "2020-10-01": (0.35, 0.15),
    "2021-02-01": (0.75, 0.6),
    "2021-10-01": (0.75, 0.75),
    "2022-02-01": (0.35, 0.5),
    "2022-10-01": (0.35, 0.15),
    "2023-02-01": (0.75, 0.75),
}

In [None]:
# 2x3 grid of per-season scatterplots of distance to the observed future
# against the escape score, colored by historical clade, each with its own
# linear fit and regression annotation.
fig, axes = plt.subplots(2, 3, figsize=(12, 8), dpi=300)

# Shared x grid for every regression line, from 0 up to the largest score
# observed across all seasons.
x_values = np.arange(0, df["welsh_escape_per_ha1"].max(), 0.001)

# Annotation position (axes fraction) for any season that has no hand-picked
# entry in `regression_placement_by_season`, so that new seasons in the input
# do not abort the figure with a KeyError.
default_regression_placement = (0.35, 0.15)

for season, ax in zip(seasons, axes.flatten()):
    # Boolean mask instead of an f-string query: same rows, and no reliance on
    # the season label being safe to embed in a query expression.
    df_ax = df[df["season"] == season]

    fit = linregress(
        df_ax["welsh_escape_per_ha1"].values,
        df_ax["weighted_distance_to_observed_future"].values,
    )
    slope, intercept, r = fit.slope, fit.intercept, fit.rvalue

    # Print the intercept as "+ value" or "- value" rather than "+ -value".
    intercept_sign = "+" if intercept >= 0 else "-"
    abs_intercept = np.abs(intercept)

    regression = f"y = {slope:.2f}x {intercept_sign} {abs_intercept:.2f}\nPearson's $R^2$={r**2:.2f}"

    # Use the shared clade palette so a clade keeps the same color in every
    # panel, and list only the clades present this season (in sorted order)
    # so the legend does not show clades that have no points.
    clades_in_season = set(df_ax["historicalclade"])
    season_hue_order = [
        clade for clade in historical_clades if clade in clades_in_season
    ]

    ax = sns.scatterplot(
        data=df_ax,
        x="welsh_escape_per_ha1",
        y="weighted_distance_to_observed_future",
        hue="historicalclade",
        hue_order=season_hue_order,
        palette=color_by_historical_clade,
        ax=ax,
        legend="brief",
        alpha=0.5,
    )
    ax.set_xlabel("Welsh et al. escape score\nper HA1 substitutions")
    ax.set_ylabel("Weighted distance to\nobserved future (AAs)")

    text_x, text_y = regression_placement_by_season.get(
        season,
        default_regression_placement,
    )
    ax.text(
        text_x,
        text_y,
        regression,
        horizontalalignment='center',
        verticalalignment='center',
        transform=ax.transAxes,
        fontsize=12,
    )

    # Draw the fitted line behind the points.
    y_values = (slope * x_values) + intercept
    ax.plot(
        x_values,
        y_values,
        "-",
        color="#999999",
        zorder=-10,
    )

    ax.set_ylim(bottom=0.0)

    ax.legend(
        frameon=False,
        title="clade",
    )

    ax.set_title(season)

plt.tight_layout()
plt.savefig(snakemake.output.distances_by_historical_clade, dpi=300)

In [None]:
# Escape score of every row grouped by historical clade, one panel per season.
# Each panel has its own y axis (sharey=False).
catplot_options = dict(
    x="welsh_escape_per_ha1",
    y="historicalclade",
    col="season",
    col_wrap=3,
    sharey=False,
    height=6,
    alpha=0.5,
)
g = sns.catplot(df, **catplot_options)

g.set_axis_labels(
    x_var="Welsh et al. escape score per HA1 substitutions",
    y_var="Clade",
)
g.tight_layout()

plt.savefig(snakemake.output.escape_scores_by_historical_clade, dpi=300)

In [None]:
# Mean and median escape score for each (season, historical clade) pair.
df.groupby(["season", "historicalclade"]).agg({"welsh_escape_per_ha1": ["mean", "median"]})

In [None]:
# Preview the filtered table again; the cells below plot its `lbi` column.
df.head()

In [None]:
# Hand-picked (x, y) position of the regression annotation for each season's
# panel in the LBI figure. The plotting cell passes these to `ax.text` with
# `transform=ax.transAxes`, so they are fractions of the panel width and
# height, not data values.
lbi_regression_placement_by_season = {
    "2020-10-01": (0.35, 0.1),
    "2021-02-01": (0.75, 0.5),
    "2021-10-01": (0.75, 0.5),
    "2022-02-01": (0.35, 0.1),
    "2022-10-01": (0.75, 0.1),
    "2023-02-01": (0.75, 0.1),
}

In [None]:
# 2x3 grid of per-season scatterplots of distance to the observed future
# against the local branching index (LBI), colored by historical clade, each
# with its own linear fit and regression annotation.
fig, axes = plt.subplots(2, 3, figsize=(12, 8), dpi=300)

# Shared x grid for every regression line, from 0 up to the largest LBI
# observed across all seasons.
x_values = np.arange(0, df["lbi"].max(), 0.001)

# Annotation position (axes fraction) for any season that has no hand-picked
# entry in `lbi_regression_placement_by_season`, so that new seasons in the
# input do not abort the figure with a KeyError.
default_lbi_regression_placement = (0.35, 0.1)

for season, ax in zip(seasons, axes.flatten()):
    # Boolean mask instead of an f-string query: same rows, and no reliance on
    # the season label being safe to embed in a query expression.
    df_ax = df[df["season"] == season]

    fit = linregress(
        df_ax["lbi"].values,
        df_ax["weighted_distance_to_observed_future"].values,
    )
    slope, intercept, r = fit.slope, fit.intercept, fit.rvalue

    # Print the intercept as "+ value" or "- value" rather than "+ -value".
    intercept_sign = "+" if intercept >= 0 else "-"
    abs_intercept = np.abs(intercept)

    regression = f"y = {slope:.2f}x {intercept_sign} {abs_intercept:.2f}\nPearson's $R^2$={r**2:.2f}"

    # Use the shared clade palette so a clade keeps the same color in every
    # panel, and list only the clades present this season (in sorted order)
    # so the legend does not show clades that have no points.
    clades_in_season = set(df_ax["historicalclade"])
    season_hue_order = [
        clade for clade in historical_clades if clade in clades_in_season
    ]

    ax = sns.scatterplot(
        data=df_ax,
        x="lbi",
        y="weighted_distance_to_observed_future",
        hue="historicalclade",
        hue_order=season_hue_order,
        palette=color_by_historical_clade,
        ax=ax,
        legend="brief",
        alpha=0.5,
    )
    ax.set_xlabel("Local branching index (LBI)")
    ax.set_ylabel("Weighted distance to\nobserved future (AAs)")

    text_x, text_y = lbi_regression_placement_by_season.get(
        season,
        default_lbi_regression_placement,
    )
    ax.text(
        text_x,
        text_y,
        regression,
        horizontalalignment='center',
        verticalalignment='center',
        transform=ax.transAxes,
        fontsize=12,
    )

    # Draw the fitted line behind the points.
    y_values = (slope * x_values) + intercept
    ax.plot(
        x_values,
        y_values,
        "-",
        color="#999999",
        zorder=-10,
    )

    ax.set_ylim(bottom=0.0)

    ax.legend(
        frameon=False,
        title="clade",
    )

    ax.set_title(season)

plt.tight_layout()
plt.savefig(snakemake.output.distances_by_historical_clade_and_lbi, dpi=300)