In [None]:
import matplotlib as mpl
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import linregress
import seaborn as sns

In [None]:
# Apply seaborn's "ticks" theme first so that every override below takes
# precedence over whatever the theme sets.
sns.set_style("ticks")

mpl.rcParams.update({
    # Hide the top and right spines.
    "axes.spines.top": False,
    "axes.spines.right": False,
    # Higher resolution for on-screen display and for saved figures
    # (presentations and manuscripts).
    "savefig.dpi": 300,
    "figure.dpi": 120,
    # Text sizes large enough to read in presentations and manuscripts.
    "font.weight": "normal",
    "axes.labelweight": "normal",
    "font.size": 14,
    "axes.labelsize": 14,
    "legend.fontsize": 10,
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "axes.titlesize": 14,
    # Do not shell out to LaTeX; math like $R^2$ uses matplotlib's mathtext.
    "text.usetex": False,
})

In [None]:
# Load the tab-separated distances table produced upstream by the workflow,
# keeping the raw frame under its own name so the cleaning step is explicit.
distances_raw = pd.read_csv(snakemake.input.distances, sep="\t")

# Downstream cells use `df`: the same table minus any row with a missing value.
df = distances_raw.dropna()

In [None]:
def format_regression_label(x, y):
    """Fit an ordinary least-squares line of ``y`` on ``x`` and describe it.

    Parameters
    ----------
    x, y : array-like
        Paired numeric observations of equal length.

    Returns
    -------
    str
        A two-line annotation, ``"y = <slope>x ± <intercept>"`` followed by
        ``"Pearson's $R^2$=<r squared>"``, each number rounded to two
        decimals. If a regression line is undefined (fewer than two points,
        or every ``x`` value identical), a placeholder label is returned
        instead of raising, so every group still gets an annotation.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # linregress cannot fit a line through <2 points or a vertical stack of
    # points (all x equal); don't let one such season abort the notebook.
    if x.size < 2 or np.unique(x).size < 2:
        return "Insufficient data\nfor regression"

    slope, intercept, r, p, se = linregress(x, y)

    # Round *before* choosing the sign, otherwise a tiny negative intercept
    # such as -0.004 renders as the misleading "- 0.00".
    rounded_intercept = round(float(intercept), 2)
    intercept_sign = "+" if rounded_intercept >= 0 else "-"

    return (
        f"y = {slope:.2f}x {intercept_sign} {abs(rounded_intercept):.2f}\n"
        f"Pearson's $R^2$={r**2:.2f}"
    )


# One annotation string per season, keyed by the season value, used to label
# each facet of the figure.
regression_by_season = {
    season: format_regression_label(
        season_df["welsh_escape"].to_numpy(),
        season_df["weighted_distance_to_observed_future"].to_numpy(),
    )
    for season, season_df in df.groupby("season")
}


In [None]:
# Inspect the per-season annotation strings (season -> "y = mx ± b" plus
# Pearson's R^2) before they are drawn onto the figure's facets.
regression_by_season

In [None]:
# Position of each facet's regression annotation, in axes-fraction
# coordinates ((0, 0) = lower-left corner, (1, 1) = upper-right corner).
REGRESSION_LABEL_XY = (0.75, 0.75)

# One panel per season (wrapped three per row), each showing a scatter of
# escape score vs. distance with its fitted regression line.
g = sns.lmplot(
    data=df,
    x="welsh_escape",
    y="weighted_distance_to_observed_future",
    hue="season",
    col="season",
    col_wrap=3,
)

g.set_axis_labels(
    "Welsh et al. escape score",
    "Weighted distance to observed future (AAs)",
)

# Annotate every facet with the fitted line and R^2 computed per season.
for season, season_ax in g.axes_dict.items():
    season_ax.text(
        *REGRESSION_LABEL_XY,
        regression_by_season[season],
        horizontalalignment="center",
        verticalalignment="center",
        transform=season_ax.transAxes,
    )

# Operate on the grid's own figure rather than pyplot's implicit "current
# figure", so layout and saving cannot silently target a different figure
# left over from another cell.
g.figure.tight_layout()
g.figure.savefig(snakemake.output.distances_figure, dpi=300)