In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import linregress
import seaborn as sns

In [None]:
# Apply the seaborn "ticks" style first; the matplotlib overrides below are
# applied afterwards so they are not replaced by the style's own settings.
sns.set_style("ticks")

mpl.rcParams.update({
    # Disable top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Display and save figures at higher resolution for presentations and
    # manuscripts.
    "savefig.dpi": 300,
    "figure.dpi": 120,
    # Display text at sizes large enough for presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,
    # Render text with matplotlib's built-in engine instead of LaTeX.
    "text.usetex": False,
})

In [None]:
# Load the tab-separated distances table provided by the Snakemake rule and
# drop every row that has a missing value in any column.
df = pd.read_csv(
    snakemake.input.distances,
    sep="\t",
).dropna()

In [None]:
# Distinct subclade labels in sorted order; this order is reused for both the
# palette lookup and the legend (hue order) of the figure.
subclades = (
    df["subclade"]
    .drop_duplicates()
    .sort_values()
    .values
)

In [None]:
# Number of distinct subclades; used to pick the matching line of the color
# schemes file.
n_subclades = subclades.size

In [None]:
# Display the number of distinct subclades as the cell output.
n_subclades

In [None]:
# Select the palette stored on line number `n_subclades` (1-based) of the
# color schemes file and split it on tabs into individual colors.
# NOTE(review): this presumably relies on line N of the file holding exactly
# N colors -- confirm against the color schemes file.
#
# `colors` is reset first so that a stale value from an earlier run of this
# cell can never be reused silently when no matching line is found.
colors = None
with open(snakemake.input.color_schemes, "r", encoding="utf-8") as fh:
    for line_number, line in enumerate(fh, start=1):
        if line_number == n_subclades:
            colors = line.strip().split("\t")
            break

# Fail here with a clear message instead of a NameError in a later cell.
if colors is None:
    raise ValueError(
        f"No color scheme found for {n_subclades} subclades in "
        f"{snakemake.input.color_schemes!r}: the file has no line number "
        f"{n_subclades}."
    )

In [None]:
# Pair each subclade (in sorted order) with the color at the same position in
# the selected palette.
color_by_clade = {
    clade: color
    for clade, color in zip(subclades, colors)
}

In [None]:
# Fit one ordinary least-squares line per season (escape score on x, weighted
# distance to the observed future on y) and keep three lookups keyed by
# season: the annotation text, the slope, and the intercept.
regression_by_season = {}
slope_by_season = {}
intercept_by_season = {}

for season, season_df in df.groupby("season"):
    fit = linregress(
        season_df["welsh_escape"].values,
        season_df["weighted_distance_to_observed_future"].values,
    )

    # Print the intercept as "+ value" or "- value" so the equation reads
    # naturally regardless of the intercept's sign.
    sign = "+" if fit.intercept >= 0 else "-"
    magnitude = abs(fit.intercept)
    r_squared = fit.rvalue ** 2

    slope_by_season[season] = fit.slope
    intercept_by_season[season] = fit.intercept
    regression_by_season[season] = (
        f"y = {fit.slope:.2f}x {sign} {magnitude:.2f}\nPearson's $R^2$={r_squared:.2f}"
    )

In [None]:
# Scatter plot of escape score against weighted distance to the observed
# future: one panel per season (three panels per row), points colored by
# subclade. seaborn's own regression fit is disabled (fit_reg=False); the
# per-season regression line computed earlier is drawn manually below.
g = sns.lmplot(
    data=df,
    x="welsh_escape",
    y="weighted_distance_to_observed_future",
    hue="subclade",
    hue_order=subclades,
    palette=color_by_clade,
    col="season",
    col_wrap=3,
    fit_reg=False,
    height=6,
)

g.set_axis_labels(
    "Welsh et al. escape score",
    "Weighted distance to observed future (AAs)",
)

# A straight line only needs its two endpoints. Span from 0 (or the smallest
# observed score, if that is negative) to the largest observed score so the
# line always covers the data. An integer-step range from 0 would stop short
# of the maximum and collapse to a single point for maxima of 1 or less.
x_values = np.array([
    min(0, df["welsh_escape"].min()),
    df["welsh_escape"].max(),
])

for season, season_ax in g.axes_dict.items():
    # Annotate the panel with the fitted equation and R^2; the position is
    # given in axes coordinates (lower-left region of the panel).
    season_ax.text(
        0.25,
        0.15,
        regression_by_season[season],
        horizontalalignment='center',
        verticalalignment='center',
        transform=season_ax.transAxes,
    )

    # Draw this season's regression line in gray, behind the points
    # (negative zorder).
    y_values = (slope_by_season[season] * x_values) + intercept_by_season[season]
    season_ax.plot(
        x_values,
        y_values,
        "-",
        color="#999999",
        zorder=-10,
    )

g.tight_layout()
plt.savefig(snakemake.output.distances_figure, dpi=300)