In [1]:
from jetstream_hugo.jet_finding import *
from jetstream_hugo.definitions import *
from jetstream_hugo.definitions import _compute
from jetstream_hugo.plots import *
from jetstream_hugo.clustering import *
from jetstream_hugo.data import *
import colormaps
from IPython.display import clear_output
import polars as pl
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
# Build one DataHandler per wind variable ("u", "v", "s") from otherwise identical
# "ERA5" / "plev" / "6H" specs, then merge them into a single handler.
# NOTE(review): -80, 40, 15, 80 look like lon/lat bounds and the list like pressure
# levels — confirm against DataHandler.from_specs.
data_handlers = {}
for varname in ["u", "v", "s"]:
    dh = DataHandler.from_specs("ERA5", "plev", varname, "6H", "all", None, -80, 40, 15, 80, [175, 200, 225, 250, 300, 350], reduce_da=False)
    data_handlers[varname] = dh
# `varname` (now "s") and `dh` (the "s" handler) remain bound at notebook scope after the loop.
data_handler = DataHandler.from_several_dhs(data_handlers)
exp = JetFindingExperiment(data_handler)
ds = exp.ds
# track_jets returns three objects; all three are kept under the names below.
all_jets_one_df, all_jets_over_time, flags = exp.track_jets()
# Two variants of the jet-property table, differing only in the boolean flag.
# NOTE(review): the `_uncat` suffix suggests False means "uncategorized" — confirm.
props_as_df_uncat = exp.props_as_df(False)
props_as_df = exp.props_as_df(True)
all_props_over_time = exp.props_over_time(all_jets_over_time, props_as_df_uncat)
# NOTE(review): repeats the `ds = exp.ds` assignment made before tracking — confirm
# whether exp.ds can change in between or whether one of the two is redundant.
ds = exp.ds
da = exp.ds["s"]

# Jet positions as a DataArray, then two further helpers rebind `props_as_df`.
# NOTE(review): presumably these add NAO and double-jet-index columns — confirm.
jet_pos_da = jet_position_as_da(all_jets_one_df, exp.path)
props_as_df = get_nao(props_as_df)
props_as_df = get_double_jet_index(props_as_df, jet_pos_da)

In [3]:
# SOM configuration: distance metric and grid size (nx * ny nodes).
metric = "euclidean"
nx = 6
ny = 4
# The variable is spelled out as "s" instead of reading `varname`, a loop variable
# left behind by an earlier cell, so this cell no longer depends on hidden state.
dh = DataHandler.from_specs("ERA5", "plev", "s", "6H", "all", None, -80, 40, 15, 80, [175, 200, 225, 250, 300, 350], reduce_da=False)
# Pass the handler to the experiment: it was previously built and then never used,
# while `exp_s.path` and `exp_s.da` are read below.
# NOTE(review): the sixth positional argument of from_specs is None (presumably the
# season) while the centers below are computed on a JJA subset — confirm the handler
# specs match the cached SOM.
exp_s = Experiment(dh)
ds_center_path = exp_s.path.joinpath(f"som_{nx}_{ny}_pbc_{metric}_center.nc")
if not ds_center_path.is_file():
    # The JJA subset is only needed when the cluster centers must be recomputed.
    ds = exp.ds
    ds = ds.sel(time=ds.time.dt.season=="JJA")
kwargs_som = dict(
    nx=nx,
    ny=ny,
    metric=metric,
    return_type=RAW_REALSPACE,
    force=False,
    train_kwargs=dict(train_algo="batch", epochs=50, start_learning_rate=0.05)
)
net, centers, labels = exp_s.som_cluster(**kwargs_som)
populations = net.compute_populations()
coords = net.neighborhoods.coordinates
# Cache the per-cluster centers on disk; reuse the file when it already exists.
if not ds_center_path.is_file():
    ds_center = labels_to_centers(labels, ds)
    ds_center.to_netcdf(ds_center_path)
else:
    ds_center = xr.open_dataset(ds_center_path)
# Cluster-membership mask with one column per SOM node.
mask = labels_to_mask(labels)
mask_da = xr.DataArray(mask, coords={"time": exp_s.da.time, "cluster": np.arange(net.n_nodes)})

In [2]:
# Locations of the flattened CESM2 wind store and of the directory handed to DataHandler.
cesm_store = "/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind/ds.zarr"
cesm_results = "/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind/results"

# Open the zarr store and rechunk: one member and 100 time steps per chunk,
# with lat/lon each kept in a single chunk (-1).
ds_cesm = xr.open_dataset(cesm_store, engine="zarr").chunk(
    dict(member=1, time=100, lat=-1, lon=-1)
)

# Jet tracking on CESM; only the first of the three values returned by track_jets is kept.
dh = DataHandler(ds_cesm, cesm_results)
exp_cesm = JetFindingExperiment(dh)
jets_cesm, _, _ = exp_cesm.track_jets()
props_cesm = exp_cesm.props_as_df(True)

# Disabled: projection of the JJA CESM wind speed onto the ERA5 SOM.
# da_cesm = ds_cesm["s"].sel(time=ds_cesm.time.dt.season=="JJA")
# dh = DataHandler(da_cesm, "/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind/results")
# exp_s_cesm = Experiment(dh)
# net, centers, labels = exp_s_cesm.project_on_other_som(exp_s, **kwargs_som)

In [30]:
from matplotlib.dates import DateFormatter, MonthLocator
from matplotlib.lines import Line2D 
def periodic_rolling(df: pl.DataFrame, winsize: int, data_vars: list):
    """Centered rolling mean of ``data_vars`` along a cyclic integer ``time`` axis.

    The frame is sorted by ``time`` and by whichever of ``member`` / ``jet``
    ``get_index_columns`` reports, padded at both ends with rows copied from the
    opposite end (``time`` shifted by -366 / +366), smoothed per index group with a
    window of ``winsize`` index steps, and finally cut back to 366 time steps.

    Parameters
    ----------
    df : pl.DataFrame
        Needs a ``time`` column plus every column named in ``data_vars``.
        NOTE(review): the padding and final slice assume every time step holds the
        same number of rows (the product of unique index values) and that there are
        366 steps — confirm with callers.
    winsize : int
        Window length in index steps; ``winsize // 2`` steps are padded on each side.
    data_vars : list
        Columns to average inside each window.

    Returns
    -------
    pl.DataFrame
        The index columns, ``time`` and the smoothed ``data_vars``, sorted by time.
    """
    half = winsize // 2
    index_cols = get_index_columns(df, ("member", "jet"))
    sort_cols = ["time", *index_cols]
    # Ascending in time; only the "jet" column is sorted in descending order.
    sort_desc = [False] + [name == "jet" for name in index_cols]
    rows_per_step = np.prod([df[name].unique().len() for name in index_cols])
    pad_rows = half * rows_per_step

    ordered = df.sort(sort_cols, descending=sort_desc)
    # Wrap around: the last steps go in front, the first steps go at the back.
    leading = ordered.tail(pad_rows).with_columns(pl.col("time") - 366)
    trailing = ordered.head(pad_rows).with_columns(pl.col("time") + 366)
    padded = pl.concat([leading, ordered, trailing])

    window_means = [pl.col(name).mean() for name in data_vars]
    smoothed = padded.rolling(
        pl.col("time"),
        period=f"{winsize}i",
        offset=f"-{half + 1}i",
        group_by=index_cols,
    ).agg(*window_means)

    smoothed = smoothed.sort(sort_cols, descending=sort_desc)
    # Drop the leading padding and keep exactly 366 steps.
    return smoothed.slice(pad_rows, 366 * rows_per_step)

def plot_seasonal(data_vars: list, props_as_df: pl.DataFrame, nrows: int = 3, ncols: int = 4, clear: bool = True, suffix: str = ""):
    """Draw the smoothed seasonal cycle of jet properties, one panel per variable.

    Rows of ``props_as_df`` are split at the year 2025 into an earlier and a later
    period. For each non-empty period the properties are averaged per ordinal day
    and jet (and per member when a ``member`` column exists), smoothed with a 15-step
    ``periodic_rolling`` window, averaged over members, and drawn with a solid
    (earlier) or dashed (later) line.

    Parameters
    ----------
    data_vars : list
        Column names to plot; paired one-to-one with the panels in row-major order.
    props_as_df : pl.DataFrame
        Needs a datetime ``time`` column, a ``jet`` column and all of ``data_vars``.
        NOTE(review): the reshape below assumes exactly two jets and 366 ordinal
        days per period — confirm for new datasets.
    nrows, ncols : int
        Shape of the panel grid.
    clear : bool
        If True, plot non-interactively, then close the figure and clear the cell
        output. If False, the figure is left open (e.g. for a following savefig).
    suffix : str
        Currently unused; only referenced by the disabled savefig call below.

    Returns
    -------
    None
    """
    if clear:
        plt.ioff()
    else:
        plt.ion()
        plt.show()
        clear_output()
    # squeeze=False keeps `axes` an array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(ncols*3.5, nrows*2.4), tight_layout=True, sharex="all", squeeze=False
    )
    axes = axes.flatten()
    has_member = "member" in props_as_df.columns
    member = [pl.col("member")] if has_member else []
    year = pl.col("time").dt.year()
    periods = [
        ("1980-2009", "solid", props_as_df.filter(year < 2025)),
        ("2070-2099", "dashed", props_as_df.filter(year >= 2025)),
    ]
    period_handles = []
    period_labels = []
    for label, ls, df in periods:
        if df.is_empty():
            # Nothing to draw for this period (e.g. a purely historical dataset).
            continue
        first_period = not period_handles
        gb = df.group_by([*member, pl.col("time").dt.ordinal_day(), pl.col("jet")], maintain_order=True)
        means = gb.mean().cast({"time": pl.Int32})
        means = periodic_rolling(means, 15, data_vars)
        # unique() does not guarantee order; sort so x lines up with the time-sorted rows.
        x = means["time"].unique().sort()
        if has_member:
            means = means.group_by(["time", "jet"], maintain_order=True).mean().drop("member")
        for varname, ax in zip(data_vars, axes):
            dji = (varname == "double_jet_index")
            # One column per jet.
            ys = means[varname].to_numpy().reshape(366, 2)
            if varname == "width":
                pre = "k"
                ys = ys / 1000
            else:
                pre = ""
            for i in range(2):
                color = "black" if dji else COLORS[2 - i]
                ax.plot(x, ys[:, i], lw=3, color=color, zorder=10, ls=ls)
                if dji:
                    # The double-jet index is drawn once, in black.
                    break
            ax.set_title(f"{PRETTIER_VARNAME.get(varname, varname)} [{pre}{UNITS.get(varname, '')}]")
            ax.xaxis.set_major_locator(MonthLocator(range(0, 13, 3)))
            ax.xaxis.set_major_formatter(DateFormatter("%b"))
            ax.set_xlim(min(x), max(x))
            # Invert only once, otherwise the second period would flip the axis back.
            if varname == "mean_lev" and first_period:
                ax.invert_yaxis()
            # ylim = ax.get_ylim()
            # wherex = np.isin(x.month, [6, 7, 8])
            # ax.fill_between(x, *ylim, where=wherex, alpha=0.1, color="black", zorder=-10)
            # ax.set_ylim(ylim)
        period_handles.append(Line2D([0], [0], color="black", lw=2, ls=ls))
        period_labels.append(label)
    if period_handles:
        jet_handles = [
            Line2D([0], [0], color=COLORS[2], lw=2),
            Line2D([0], [0], color=COLORS[1], lw=2),
        ]
        period_ax = axes[0]
        # Third panel when it exists, otherwise the last available one.
        jet_ax = axes[min(2, axes.size - 1)]
        period_legend = period_ax.legend(period_handles, period_labels, ncol=1, framealpha=1, loc="center left")
        period_legend.set_zorder(102)
        if jet_ax is period_ax:
            # A second legend() call would replace the first one on the same axes.
            period_ax.add_artist(period_legend)
        jet_ax.legend(jet_handles, ["STJ", "EDJ"], ncol=2, framealpha=1, loc="upper right").set_zorder(102)
    # plt.savefig(f"{FIGURES}/jet_props_misc/jet_props_seasonal{suffix}.png")
    if clear:
        # Close this figure explicitly rather than whichever one happens to be current.
        plt.close(fig)
        clear_output()

In [None]:
# Jet properties drawn in the CESM seasonal-cycle figure, one panel each on a 2x2 grid.
data_vars = ["mean_lat", "mean_lev", "mean_s", "width"]
plot_seasonal(
    data_vars,
    props_cesm,
    nrows=2,
    ncols=2,
    clear=False,
    suffix="_subset",
)
# With clear=False plot_seasonal does not close its figure, so it is still open here.
plt.savefig(f"{FIGURES}/jet_props_cesm/seasonal_subset.pdf")

In [None]:
# Seasonal-cycle figure on a 2x3 panel grid for the variables listed in `data_vars`.
# NOTE(review): in the cells visible here `past_props` is only ever assigned as a local
# inside `plot_seasonal`; unless it is defined elsewhere this call raises NameError on a
# fresh kernel — confirm which property frame was intended.
plot_seasonal(data_vars, past_props, nrows=2, ncols=3, clear=False, suffix="_subset")

In [24]:
# SOM node populations per block of `years_per_group` years, one map per block.
# Node ids come from the trained network instead of a hard-coded 24, so the cell
# keeps working when the SOM grid size changes.
unique_labels = np.arange(net.n_nodes)[None, :]
years_per_group = 10
labels_gb = labels.resample(time=f"{years_per_group}YS")
nyears = len(labels_gb)  # number of resampled groups
# At least the 2x3 grid; add rows when there are more than six groups so that
# zip() below does not silently drop the trailing ones.
ncols = 3
nrows = max(2, -(-nyears // ncols))
fig, axes = plt.subplots(nrows, ncols, subplot_kw={"aspect": "equal"}, figsize=(6, 2 * nrows))
axes = axes.flatten()
pops = []
for (year, theselabels), ax in zip(labels_gb, axes):
    # Group key is a numpy datetime; convert it to a calendar year for the title.
    year_int = year.astype('datetime64[Y]').astype(int) + 1970
    # Count how many time steps fall on each node.
    pop_thisyear = (theselabels.values.flatten()[:, None] == unique_labels).sum(axis=0)
    pops.append(pop_thisyear)
    net.plot_on_map(
        pop_thisyear,
        fig=fig,
        ax=ax,
    )
    ax.set_title(str(year_int))

In [None]:
# For every complete run of 7 consecutive day-of-year groups, count how often each
# SOM label occurs. A trailing run shorter than 7 days is discarded.
timestepwise = []
group = []
day_groups = labels.groupby(labels.time.dt.dayofyear).groups
for i, group_ in enumerate(day_groups.values()):
    group.append(group_)
    if i % 7 != 6:
        continue
    # `group` stays a list throughout; the concatenated indices get their own name.
    # The former `coords = ...` assignment was unused here and only overwrote the
    # notebook-level node coordinates, so it has been removed.
    week_labels = labels[np.concatenate(group)]
    unique, count = np.unique(week_labels, return_counts=True)
    timestepwise.append((unique, count))
    group = []

In [None]:
# Weekly SOM population maps: one panel per entry of `timestepwise`, shared colour scale.
fig, axes = plt.subplots(4, 3, figsize=(TEXTWIDTH_IN, 4.4), tight_layout=False, subplot_kw={"aspect": "equal"})
cmap = colormaps.bubblegum_r
norm = BoundaryNorm(np.arange(0, 2200, 200), cmap.N)
im = ScalarMappable(norm, cmap)
coords = net.neighborhoods.coordinates
fig.colorbar(im, ax=axes)
# zip() stops at the shorter input, so having fewer weeks than panels leaves the
# remaining panels empty instead of raising IndexError.
for step, (ax, (unique, counts)) in enumerate(zip(axes.ravel(), timestepwise), start=1):
    # Dense population vector: zero for nodes that never occur in this week.
    to_plot = np.zeros(net.n_nodes)
    to_plot[unique] = counts
    fig, ax = net.plot_on_map(
        to_plot,
        smooth_sigma=0,
        show=False,
        cmap=cmap,
        norm=norm,
        fig=fig,
        ax=ax,
        draw_cbar=False,
    )
    # ax.errorbar(*com[step], *com_std[step][[1, 0]])
    ax.set_title(f"Week {step}", pad=2)

    # Label every node; a dedicated index name avoids shadowing the panel counter.
    for node, (x, y) in enumerate(coords):
        ax.text(x, y, f'${to_prettier_order(node)}$', va='center', ha='center', color="white", fontsize=8)
fig.set_tight_layout(False)
plt.savefig(f"{FIGURES}/som_props/weekly_pathway.png")

In [None]:
# Settings for a larger (8x6) SOM using the same metric as the ERA5 SOM.
kwargs_som_2 = dict(
    nx=8,
    ny=6,
    metric=metric,
    return_type=RAW_REALSPACE,
    force=False,
)
# Build the CESM wind-speed experiment here so the cell does not rely on a variable
# that no active cell defines (it used to raise NameError on a fresh kernel).
# NOTE(review): JJA subset of "s" and the results directory follow the disabled recipe
# in the CESM loading cell — confirm this is still the intended input.
da_cesm = ds_cesm["s"].sel(time=ds_cesm.time.dt.season == "JJA")
dh_s_cesm = DataHandler(da_cesm, "/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind/results")
exp_s_cesm = Experiment(dh_s_cesm)
exp_s_cesm.som_cluster(**kwargs_som_2)

In [None]:
# Displays the node-population vector left over from the last iteration of the
# per-decade loop (inspection cell; depends on that loop having run).
pop_thisyear