In [None]:
# Jupyter server: https://climcal4.giub.unibe.ch:8000/user/hugo/ (access token removed — never commit tokens)
import os
import sys
import numpy as np
import xarray as xr
import cartopy.crs as ccrs
import cartopy.feature as feat
import xrft
import pickle as pkl
from scipy import constants as co
from scipy.stats import gaussian_kde, norm
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.cluster import KMeans
# import kmedoids
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
from matplotlib.patches import Rectangle
from matplotlib.widgets import CheckButtons, Slider, Button, RadioButtons
from matplotlib.legend_handler import HandlerBase
from matplotlib.colors import ListedColormap, LinearSegmentedColormap, Normalize, CenteredNorm
import matplotlib.path as mpath
import IPython.display as disp
from hmmlearn.hmm import GaussianHMM
from metpy import calc as mcalc
from metpy import interpolate as minterpolate
from metpy.units import units
import pandas as pd
import time
from ipywidgets import IntProgress, HTML
import shutil
from cdo import Cdo
from nco import Nco
import hvplot.xarray # noqa
import panel.widgets as pnw
import panel as pn
from bokeh.resources import INLINE
# pn.extension(comms="vscode")
# from definitions import *

# Animations need an ffmpeg binary. Prefer the author's environment path; if it does
# not exist on this machine, fall back to the first ffmpeg found on PATH.
_FFMPEG_PATH = '/home/hugo/mambaforge-pypy3/envs/env/bin/ffmpeg'
if not os.path.exists(_FFMPEG_PATH):
    _FFMPEG_PATH = shutil.which("ffmpeg") or _FFMPEG_PATH
plt.rcParams['animation.ffmpeg_path'] = _FFMPEG_PATH
# cdo = Cdo()
# nco = Nco()
pn.extension(comms="default")
import warnings
# NOTE(review): this hides every FutureWarning (including library deprecation notices) for the session.
warnings.filterwarnings("ignore", category=FutureWarning)

# Definitions

### Utilities, platform specifics and constants

In [None]:
import platform

# Identify the machine we run on; data locations differ per system.
pf = platform.platform()
if "cray" in pf:
    NODE = "daint"
elif platform.node()[:4] == "clim":
    NODE = "CLIM"
else:  # TODO: find a more reliable way to recognise UBELIX
    NODE = "UBELIX"
DATADIR = "/scratch/snx3000/hbanderi/data/persistent" if NODE == "daint" else "/scratch2/hugo"
CLIMSTOR = "/mnt/climstor/ecmwf/era5/raw"


def filenamescm(y, m, d):
    """Model-level (ML) files on climstor for day y-m-d: one file per 6-hourly step (00, 06, 12, 18).

    The ML and PL products on climstor follow different naming conventions.
    """
    return [f"{CLIMSTOR}/ML/data/{str(y)}/P{str(y)}{str(m).zfill(2)}{str(d).zfill(2)}_{str(h).zfill(2)}" for h in range(0, 24, 6)]


def filenamecp(y, m, d):
    """Pressure-level (PL) daily file on climstor for day y-m-d.

    Returned as a one-element list so the return type matches filenamescm(y, m, d).
    """
    return [f"{CLIMSTOR}/PL/data/an_pl_ERA5_{str(y)}-{str(m).zfill(2)}-{str(d).zfill(2)}.nc"]


def filenamegeneric(y, m, folder):
    """Monthly file DATADIR/<folder>/YYYYMM.nc, as a one-element list."""
    return [f"{DATADIR}/{folder}/{y}{str(m).zfill(2)}.nc"]


def _fn(date, which):
    """Dispatch on `which`: "ML", "PL", or any other string taken as a DATADIR folder name."""
    if which == "ML":
        return filenamescm(date.year, date.month, date.day)
    elif which == "PL":
        return filenamecp(date.year, date.month, date.day)
    else:
        return filenamegeneric(date.year, date.month, which)


def fn(date, which):
    """Return the list of files that hold `date` for product `which`.

    Parameters
    ----------
    date : pandas.Timestamp or numpy.datetime64, or a list / tuple / numpy array /
        pandas.DatetimeIndex / pandas.Series of such values.
    which : str
        "ML", "PL", or the name of a folder of monthly files under DATADIR.

    Returns
    -------
    list of str
        File paths; for several dates the per-date lists are concatenated in order.

    Raises
    ------
    TypeError
        If `date` is of an unsupported type.
    """
    # The previous version referenced pd.DatedayIndex / pd.daystamp / RundayError
    # (a "time" -> "day" find-and-replace accident), so every call failed.
    if isinstance(date, (pd.Timestamp, np.datetime64)):
        return _fn(pd.Timestamp(date), which)
    if isinstance(date, (list, tuple, np.ndarray, pd.DatetimeIndex, pd.Series)):
        filenames = []
        for d in date:
            # pd.Timestamp() also accepts numpy.datetime64 elements (which lack .year)
            filenames.extend(_fn(pd.Timestamp(d), which))
        return filenames
    raise TypeError(f"Invalid type : {type(date)}")

# Physical constants
RADIUS = 6.371e6  # Earth radius, m
OMEGA = 7.2921e-5  # Earth's rotation rate, rad.s-1
KAPPA = 0.2854  # presumably R/c_p of dry air (dimensionless) — confirm
R_SPECIFIC_AIR = 287.0500676  # specific gas constant of air, presumably J.kg-1.K-1

def degcos(x):
    """Cosine of an angle given in degrees (elementwise for arrays)."""
    radians = x / 180 * np.pi
    return np.cos(radians)

def degsin(x):
    """Sine of an angle given in degrees (elementwise for arrays)."""
    radians = x / 180 * np.pi
    return np.sin(radians)

# Period covered by the pressure-level (PL) and model-level (ML) data
DATERANGEPL = pd.date_range("19590101", "20211231")
YEARSPL = np.unique(DATERANGEPL.year)
DATERANGEML = pd.date_range("19770101", "20211231")

# Histogram bin edges used for the jet indices
WINDBINS = np.arange(0, 25, 2)      # speeds (Int, Shar)
LATBINS = np.arange(15, 75, 2.5)    # latitudes (Lat, Lats, Latn)
LONBINS = np.arange(-90, 30, 3)     # longitudes (Lon, Lonw, Lone)
DEPBINS = np.arange(-25, 26, 1.5)   # Tilt and Dep

# Five-colour palette: https://coolors.co/palette/ef476f-ffd166-06d6a0-118ab2-073b4c
COLORS5 = [
    "#ef476f",  # pinky red
    "#ffd166",  # yellow
    "#06d6a0",  # cyany green
    "#118ab2",  # light blue
    "#073b4c",  # dark blue
]

# Ten-colour palette: https://coolors.co/palette/f94144-f3722c-f8961e-f9844a-f9c74f-90be6d-43aa8b-4d908e-577590-277da1
COLORS10 = [
    "#F94144",  # Vermilion
    "#F3722C",  # Orange
    "#F8961E",  # Atoll
    "#F9844A",  # Cadmium orange
    "#F9C74F",  # Caramel
    "#90BE6D",  # Lettuce green
    "#43AA8B",  # Bright Parrot Green
    "#4D908E",  # Abyss Green
    "#577590",  # Night Blue
    "#277DA1",  # Night Blue (second, lighter shade)
]

# Names of the jet indices ("zoo") computed in this notebook
ZOO = ['Lat', 'Int', 'Shar', 'Lats', 'Latn', 'Tilt', 'Lon', 'Lonw', 'Lone', 'Dep', "Mea"]

# Natural Earth vector overlays at 1:10m scale, added to every map below
COASTLINE = feat.NaturalEarthFeature(
    "physical",
    "coastline",
    "10m",
    edgecolor="black",
    facecolor="none",
)
BORDERS = feat.NaturalEarthFeature(
    "cultural", "admin_0_boundary_lines_land", "10m",
    edgecolor="grey", facecolor="none",
)

### Zoo / JLI computation and plotting functions:

In [None]:
def cdf(timeseries):
    """Empirical cumulative distribution function of a sample.

    Parameters
    ----------
    timeseries : array-like (numpy array, pandas Series, xarray DataArray, ...)
        Sample values; flattened before use. NaNs, if any, are sorted to the end.

    Returns
    -------
    x : numpy.ndarray
        Sample values in ascending order.
    y : numpy.ndarray
        Rank-based cumulative probability, y[k] = (k + 1) / n.
    """
    x = np.sort(np.asarray(timeseries).ravel())
    n = x.size
    # The previous version cumulated the *argsort indices*, which is not a CDF,
    # and also required a `.values` attribute (failing on plain numpy input).
    y = np.arange(1, n + 1) / n if n else np.zeros(0)
    return x, y

### Create histogram
def compute_hist(timeseries, season, bins):
    """Histogram of `timeseries`, optionally restricted to one season.

    Parameters
    ----------
    timeseries : array-like; must expose ``.time.dt.season`` and ``.isel`` when a
        season is requested (e.g. an xarray DataArray).
    season : None or "Annual" to keep every time step, otherwise a label compared
        with ``timeseries.time.dt.season`` (e.g. "DJF").
    bins : passed to ``numpy.histogram``.

    Returns
    -------
    ``(counts, bin_edges)`` as returned by ``numpy.histogram``.
    """
    keep_everything = season is None or season == "Annual"
    if not keep_everything:
        in_season = timeseries.time.dt.season == season
        timeseries = timeseries.isel(time=in_season)
    return np.histogram(timeseries, bins=bins)

    
def histogram(timeseries, ax, season=None, bins=LATBINS, **kwargs):
    """Draw a bar histogram of `timeseries` on `ax`.

    Parameters
    ----------
    timeseries, season, bins : see compute_hist.
    ax : matplotlib Axes to draw on.
    **kwargs : forwarded to ``ax.bar``.

    Returns
    -------
    ``(counts, bin_edges)`` from compute_hist.
    """
    hist = compute_hist(timeseries, season, bins)
    counts, edges = hist
    midpoints = (edges[1:] + edges[:-1]) / 2
    # one width per bar, so non-uniform bin edges are drawn correctly as well
    ax.bar(midpoints, counts, width=np.diff(edges), **kwargs)
    return hist


def kde(timeseries, season=None, bins=LATBINS, scaled=False, return_x=False, **kwargs):
    """Gaussian KDE of the binned `timeseries`, evaluated at the bin midpoints.

    The histogram counts are used as weights of the bin midpoints.

    Parameters
    ----------
    timeseries, season, bins : see compute_hist.
    scaled : bool
        If True, multiply the density by (first bin width * total count), i.e.
        express it as expected counts per bin (assuming uniform bin widths).
    return_x : bool
        If True, return ``(midpoints, density)`` instead of the density alone.
    **kwargs : forwarded to ``scipy.stats.gaussian_kde``.
    """
    counts, edges = compute_hist(timeseries, season, bins)
    centers = (edges[1:] + edges[:-1]) / 2
    count_scale = (edges[1] - edges[0]) * np.sum(counts)
    estimator = gaussian_kde(centers, weights=counts, **kwargs)
    density = estimator.evaluate(centers)
    if scaled:
        density *= count_scale
    if not return_x:
        return density
    return centers, density


def compute_anomaly(ds, return_clim=False, smooth_kmax=None):
    """Subtract the day-of-year climatology from `ds`.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data with a "time" dimension. Smoothing (`smooth_kmax`) is only supported
        for DataArrays.
    return_clim : bool, default False
        Also return the (possibly smoothed) climatology, indexed by "dayofyear".
    smooth_kmax : int, optional
        If set, low-pass the climatology in Fourier space, keeping the mean and the
        harmonics with |k| <= smooth_kmax (negative and positive frequencies alike).

    Returns
    -------
    anom, or ``(anom, clim)`` when `return_clim` is True. If `ds` has an empty time
    axis it is returned unchanged. With `return_clim` the result is then
    ``(ds, None)``, so the return arity does not change.
    """
    # needed to workaround xarray's check with zero dimensions
    # https://github.com/pydata/xarray/issues/3575
    if len(ds['time']) == 0:
        return (ds, None) if return_clim else ds
    gb = ds.groupby("time.dayofyear")
    clim = gb.mean(dim='time')
    if smooth_kmax:
        ft = xrft.fft(clim, dim="dayofyear")
        # Assumes xrft's default shifted spectrum: the zero frequency sits at index n // 2.
        # Zero every coefficient farther than smooth_kmax from it. The previous slicing
        # kept -kmax but dropped +kmax (asymmetric), and it indexed the first axis
        # instead of the frequency dimension.
        center = clim.sizes["dayofyear"] // 2
        ft[{"freq_dayofyear": slice(None, max(center - smooth_kmax, 0))}] = 0
        ft[{"freq_dayofyear": slice(center + smooth_kmax + 1, None)}] = 0
        clim = xrft.ifft(
            ft, dim="freq_dayofyear", true_phase=True, true_amplitude=True
        ).real.assign_coords(dayofyear=clim.dayofyear)
    anom = (gb - clim).reset_coords("dayofyear", drop=True)
    if return_clim:
        return anom, clim
    return anom


### Lat and Int
def compute_JLI(da_Lat):
    """Jet latitude index: latitude (Lat) and value (Int) of the maximum over "lat".

    Parameters
    ----------
    da_Lat : xarray.DataArray with "time" and "lat" dimensions (presumably zonal
        wind at one longitude — confirm for new inputs).

    Returns
    -------
    Lat : DataArray over time, latitude of the maximum [degree_north]
    Int : DataArray over time, value at the maximum [m/s]
    """
    i_max = da_Lat.argmax(dim="lat", skipna=True)
    lat_of_max = da_Lat.lat[i_max.values.flatten()].values
    Lat = xr.DataArray(lat_of_max, coords={"time": da_Lat.time}).rename("Lat")
    Lat.attrs["units"] = "degree_north"
    at_max = da_Lat.isel(lat=i_max)
    Int = at_max.reset_coords("lat", drop=True).rename("Int")
    Int.attrs["units"] = "m/s"
    return Lat, Int
    
### Shar, Latn, Lats, 
def compute_shar(da_Lat, Int, Lat):
    """Jet sharpness (Shar) and southern/northern flank latitudes (Lats, Latn).

    Shar = Int - meridional mean of `da_Lat`. The flanks come from the grid points
    where `da_Lat - Shar / 2` changes sign between neighbouring latitudes. Lats is
    the first and Latn the last such point along the lat axis. A flank found on the
    wrong side of the jet axis `Lat` is replaced by the first / last grid latitude.

    NOTE(review): the threshold is Shar / 2 itself, not Int - Shar / 2 (halfway
    between mean and maximum) — confirm which one is intended.

    Parameters
    ----------
    da_Lat : xarray.DataArray
        Assumed (time, lat) in that order; the `[:, 1:]` indexing relies on it.
        NOTE(review): "first/last" and the lat[0] / lat[-1] fallbacks presumably
        assume an ascending (south-to-north) latitude axis — confirm.
    Int : DataArray over time with a "units" attribute (jet speed, from compute_JLI).
    Lat : DataArray over time (jet latitude, from compute_JLI).

    Returns
    -------
    Shar, Lats, Latn : DataArrays over time.
    """
    Shar = (Int - da_Lat.mean(dim="lat")).rename("Shar")
    Shar.attrs["units"] = Int.attrs["units"]
    this = da_Lat - Shar / 2
    # (time index, lat index i) of every sign change between lat i and lat i + 1;
    # np.where returns them sorted by time index, then by lat index
    ouais = np.where(this.values[:, 1:] * this.values[:, :-1] < 0)
    # number of sign changes per time step
    hist = np.histogram(ouais[0], bins=np.arange(len(da_Lat.time) + 1))[0]
    # position of each time step's first sign change in the flattened `ouais` arrays
    cumsumhist = np.append([0], np.cumsum(hist)[:-1])
    # NOTE(review): a time step with no sign change (hist == 0) picks up a neighbouring
    # time step's crossing in the two lines below — confirm this cannot happen.
    Lats = xr.DataArray(da_Lat.lat.values[ouais[1][cumsumhist]], coords={"time": da_Lat.time}, name="Lats")
    Latn = xr.DataArray(da_Lat.lat.values[ouais[1][cumsumhist + hist - 1]], coords={"time": da_Lat.time}, name="Latn")
    # flanks on the wrong side of the jet axis fall back to the domain edges
    Latn[Latn < Lat] = da_Lat.lat[-1]
    Lats[Lats > Lat] = da_Lat.lat[0]
    Latn.attrs["units"] = "degree_north"
    Lats.attrs["units"] = "degree_north"
    return Shar, Lats, Latn

### Tilt
def compute_Tilt(da, Lat):
    """Track the jet axis across longitudes and fit its tilt.

    Start from the jet latitude `Lat` at the middle longitude and march outwards, west
    and east in lock-step. At each longitude the tracked latitude is the argmax of `da`
    over latitude. The search is restricted to latitudes within two grid spacings of
    the latitude tracked at the neighbouring, already processed longitude.

    Parameters
    ----------
    da : xarray.DataArray
        Assumed (time, lat, lon) in that order; `da.shape[::2]` and the 2-D indexing
        below rely on it. TODO confirm for new inputs.
    Lat : array-like over time
        Jet latitude at the middle longitude (e.g. from compute_JLI).

    Returns
    -------
    trackedLats : DataArray (time, lon) of tracked latitudes [degree_north]
    Tilt : DataArray over time, slope of a degree-1 polyfit of trackedLats against lon
        [degree_north/degree_east]
    """
    trackedLats = da.isel(lat=0).copy(data=np.zeros(da.shape[::2])).reset_coords("lat", drop=True).rename("Tracked Latitudes")
    trackedLats.attrs["units"] = "degree_north"
    lats = da.lat.values
    # Search radius = two grid spacings. Take abs() so that a descending (north-to-south)
    # latitude axis does not give a negative radius, which would mask every latitude.
    twodelta = np.abs(lats[2] - lats[0])
    midpoint = int(len(da.lon) / 2)
    trackedLats[:, midpoint] = Lat
    iterator = zip(reversed(range(midpoint)), range(midpoint + 1, len(da.lon)))
    for lonw, lone in iterator:
        for k, thislon in enumerate((lonw, lone)):
            otherlon = thislon - (2 * k - 1) # previous step in the iterator for either east (k=1, otherlon=thislon-1) or west (k=0, otherlon=thislon+1)
            mask = np.abs(trackedLats[:, otherlon].values[:, None] - lats[None, :]) > twodelta
            # mask = where not to look for a maximum. The next step (forward for east or backward for west) needs to be within twodelta of the previous (otherlon)
            here = np.ma.argmax(np.ma.array(da.isel(lon=thislon).values, mask=mask), axis=1)
            trackedLats[:, thislon] = lats[here]
    Tilt = trackedLats.polyfit(dim="lon", deg=1).sel(degree=1)["polyfit_coefficients"].reset_coords("degree", drop=True).rename("Tilt")
    Tilt.attrs["units"] = "degree_north/degree_east"
    return trackedLats, Tilt

### Lon
def compute_Lon(da, trackedLats):
    """Jet longitude: intensity-squared-weighted mean longitude along the tracked axis.

    Parameters
    ----------
    da : xarray.DataArray with "time", "lat" and "lon" coordinates.
    trackedLats : DataArray (time, lon) of latitudes, e.g. from compute_Tilt.

    Returns
    -------
    Intlambda : DataArray, `da` sampled at the tracked latitude of every (time, lon)
    Lon : DataArray over time, sum(lon * I**2) / sum(I**2) over lon [degree_east]
    """
    along_axis = da.sel(lat=trackedLats).reset_coords("lat", drop=True)
    weights = along_axis * along_axis
    lon_grid = xr.DataArray(
        da.lon.values[None, :] * np.ones(len(da.time))[:, None],
        coords={"time": da.time, "lon": da.lon},
    )
    Lon = (lon_grid * weights).sum(dim="lon") / weights.sum(dim="lon")
    Lon.attrs["units"] = "degree_east"
    return along_axis, Lon.rename("Lon")

### Lonw, Lone
def compute_Lonew(da, Intlambda, Lon):
    """Western (Lonw) and eastern (Lone) ends of the jet along its tracked axis.

    `basearray` flags where the along-axis intensity is below its zonal mean.
    - Lonw: lowest-index longitude with lon <= Lon that is not below the mean, minus one index.
    - Lone: lowest-index longitude with lon > Lon that is below the mean, minus one index.

    Parameters
    ----------
    da : xarray.DataArray providing the "lon" and "time" coordinates.
    Intlambda : DataArray, assumed (time, lon) — e.g. from compute_Lon. TODO confirm.
    Lon : DataArray over time, jet longitude (e.g. from compute_Lon).

    Returns
    -------
    Lonw, Lone : DataArrays over time [degree_east].
    """
    Intlambda = Intlambda.values
    Mean = np.mean(Intlambda, axis=1)
    lon = da.lon.values
    basearray = Intlambda - Mean[:, None] < 0
    # the masks restrict the search to lon <= Lon (Lonw) and lon > Lon (Lone)
    iLonw = np.ma.argmin(np.ma.array(basearray, mask=lon[None, :] > Lon.values[:, None]), axis=1) - 1
    iLone = np.ma.argmax(np.ma.array(basearray, mask=lon[None, :] <= Lon.values[:, None]), axis=1) - 1
    # NOTE(review): an index of -1 (match at the first longitude, or no match at all)
    # wraps around to the last longitude — confirm this edge case is acceptable.
    Lonw = xr.DataArray(lon[iLonw], coords={"time": da.time}, name="Lonw")
    # Was coords={"day": da.day}, a leftover of a "time" -> "day" rename. Lone shares
    # Lonw's time axis.
    Lone = xr.DataArray(lon[iLone], coords={"time": da.time}, name="Lone")
    Lonw.attrs["units"] = "degree_east"
    Lone.attrs["units"] = "degree_east"
    return Lonw, Lone
    
### Dep
def compute_Dep(da, trackedLats):
    """Jet departure: sum over lon of |latitude of the maximum of da - tracked latitude|.

    Parameters
    ----------
    da : xarray.DataArray with "time", "lat" and "lon"; the argmax over "lat" is
        assumed to come out as (time, lon) to match the coords below — TODO confirm.
    trackedLats : DataArray (time, lon), e.g. from compute_Tilt.

    Returns
    -------
    Dep : DataArray over time [degree_north].
    """
    # latitude of the maximum of `da` at every (time, lon)
    phistarl = xr.DataArray(da.lat.values[da.argmax(dim="lat").values], coords={"time": da.time.values, "lon": da.lon.values})
    # |x| instead of the equivalent but roundabout sqrt(x ** 2)
    Dep = np.abs(phistarl - trackedLats).sum(dim="lon").rename("Dep")
    Dep.attrs["units"] = "degree_north"
    return Dep


def make_boundary_path(minlon,maxlon,minlat,maxlat,n=50):
    '''
    Return a matplotlib Path tracing the lon-lat box [minlon, maxlon] x [minlat, maxlat].

    Each edge is sampled with `n` points, so the outline can bend when it is drawn in
    a projected map. Vertex order:
    northern edge (west -> east), eastern edge (north -> south),
    southern edge (east -> west), western edge (south -> north).
    '''
    north = zip(np.linspace(minlon, maxlon, n), np.full(n, maxlat))
    east = zip(np.full(n, maxlon), np.linspace(maxlat, minlat, n))
    south = zip(np.linspace(maxlon, minlon, n), np.full(n, minlat))
    west = zip(np.full(n, minlon), np.linspace(minlat, maxlat, n))
    vertices = [[lon_, lat_] for edge in (north, east, south, west) for lon_, lat_ in edge]
    return mpath.Path(vertices)

# Barriopedro

### Package

In [None]:
dataset = "ERA5"  # alternative: "NCEP"
# directory of the "packaged" raw inputs for the Barriopedro indices
datadir = f"{DATADIR}/{dataset}/packaged/"

In [None]:
# Zonal wind: running mean in longitude (60 grid points for ERA5, 24 otherwise),
# restricted to 60W-0E, then a sharp spectral low-pass in time.
da = xr.open_dataset(f"{datadir}/BarriopedroRaw.nc")["u"]
if dataset == "ERA5":
    da = da.rename({"longitude": "lon", "latitude": "lat"})
da2 = da.rolling(lon=60 if dataset == "ERA5" else 24, center=True).mean().sel(lon=np.arange(-60, 0.1, 0.5))
# cutoff 1 / (10 days), assuming xrft expresses the time frequency in Hz
LOWPASS_CUTOFF = 1 / 10 / 24 / 3600
da_fft = xrft.fft(da2, dim="time")
da_fft[np.abs(da_fft.freq_time) > LOWPASS_CUTOFF] = 0
da3 = xrft.ifft(da_fft, dim="freq_time", true_phase=True, true_amplitude=True).real.assign_coords(time=da.time).rename("u")
# "units" (not "unit"), as for every other variable in this notebook (CF convention)
da2.attrs["units"] = "m/s"
da3.attrs["units"] = "m/s"
da3.to_netcdf(f"{datadir}/Barriopedro.nc")

### Compute

In [None]:
# Compute the jet "zoo" indices from the filtered zonal wind and save them.
da = xr.open_dataset(f"{datadir}/Barriopedro.nc")["u"]
# Lat / Int / Shar are evaluated at a single longitude (30W)
da_Lat = da.sel(lon=-30.).reset_coords("lon", drop=True)
# compute_JLI returns exactly (Lat, Int); unpacking four names raised a ValueError
Lat, Int = compute_JLI(da_Lat)
Shar, Lats, Latn = compute_shar(da_Lat, Int, Lat)
trackedLats, Tilt = compute_Tilt(da, Lat)
Intlambda, Lon = compute_Lon(da, trackedLats)
Lonw, Lone = compute_Lonew(da, Intlambda, Lon)
Dep = compute_Dep(da, trackedLats)
Zoo = xr.Dataset({
    "Lat": Lat,
    "Int": Int,
    "Shar": Shar,
    "Lats": Lats,
    "Latn": Latn,
    "Tilt": Tilt,
    "Lon": Lon,
    "Lonw": Lonw,
    "Lone": Lone,
    "Dep": Dep,
})
Zoo.to_netcdf(f"{DATADIR}/{dataset}/processed/BarriopedroZoo.nc")

### Plot

In [None]:
%matplotlib inline
# timeseries, ax, season=None, bins=LATBINS, **kwargs
dataset = "ERA5"
Zoo = xr.open_dataset(f"{DATADIR}/{dataset}/processed/BarriopedroZooDetrended.nc")
fig, axes = plt.subplots(2, 2, figsize=(14, 14))
axes = axes.flatten()
mapping = [["Int", "Shar"], ["Lat", "Lats", "Latn"], ["Lon", "Lone", "Lonw"], ["Tilt", "Dep"]]
bins = [WINDBINS, LATBINS, LONBINS, DEPBINS]
for i, group in enumerate(mapping):
    ax = axes[i]
    for j, key in enumerate(group):
        midpoints, gkde = kde(Zoo[key], "DJF", bins[i], scaled=True, return_x=True)
        ax.plot(
            midpoints, gkde, color=COLORS5[j], label=key
        )
    ax.legend()
plt.show()

## Meandering Index

In [None]:
# Daily-mean geopotential divided by g (presumably "z" is geopotential, so this is
# geopotential height in m).
# NOTE(review): isel(lat=180..360) keeps one hemisphere; which one depends on the
# file's latitude ordering — confirm.
dataset = "ERA5"
da = (
    xr.open_mfdataset(f"{DATADIR}/{dataset}/Geopotential/dailymean/*.nc")
    .rename({"longitude": "lon", "latitude": "lat"})
    .isel(lat=np.arange(180, 361))["z"]
    .load()
    / co.g
)

In [None]:
import contourpy
from joblib import delayed, Parallel, dump, load

def meandering(lines):
    """Total length of a set of polylines, divided by 360.

    The 360 is presumably the length of a full latitude circle in degrees of
    longitude — confirm.

    Parameters
    ----------
    lines : iterable of (n_points, 2) arrays of vertex coordinates, e.g. the list
        returned by contourpy's ``lines`` for one level.
    """
    def _length(line):
        # sum of the Euclidean lengths of the segments between consecutive vertices
        steps = np.diff(line, axis=0)
        return np.sum(np.sqrt(np.sum(steps**2, axis=1)))
    return sum(_length(line) / 360 for line in lines)

def one_ts(lon, lat, da):
    """Meandering index of one 2-D field: the maximum over contour levels.

    Contour levels run 4900, 4905, ..., 6200 (in the units of `da`, presumably m).
    For each level the summed contour length is computed with `meandering`, and the
    largest value is returned.
    """
    generator = contourpy.contour_generator(x=lon, y=lat, z=da)
    levels = range(4900, 6205, 5)
    return np.amax([meandering(generator.lines(level)) for level in levels])

# One meandering value per time step, computed in parallel with joblib.
lon = da.lon.values
lat = da.lat.values
jobs = (delayed(one_ts)(lon, lat, da.sel(time=t).values) for t in da.time[:])
M = Parallel(n_jobs=32, backend="loky", max_nbytes=1e5)(jobs)
daM = xr.DataArray(M, coords={"time": da.time})
daM.to_netcdf(f"{DATADIR}/{dataset}/processed/Meandering.nc")

### Plot like Di Capua et al. (2016)

In [None]:
# Seasonal KDEs of the meandering index before vs. from 2000 on (cf. Di Capua et al. 2016)
daM = xr.open_dataarray(f"{DATADIR}/{dataset}/processed/Meandering.nc")
daM_anom = compute_anomaly(daM)
fig, axes = plt.subplots(1, 5, figsize=(18, 5), sharey=True)
early = daM.isel(time=(daM.time.dt.year < 2000))
recent = daM.isel(time=(daM.time.dt.year >= 2000))
fig.subplots_adjust(wspace=0.04)
seasons = ["DJF", "MAM", "JJA", "SON", "Annual"]
for ax, season in zip(axes, seasons):
    x_early, y_early = kde(early, season, 50, scaled=False, return_x=True)
    x_recent, y_recent = kde(recent, season, 50, scaled=False, return_x=True)
    ax.plot(x_early, y_early, label="pre 2000", color=COLORS5[3], lw=2.5)
    ax.plot(x_recent, y_recent, label="post 2000", color=COLORS5[0], lw=2.5)
    ax.fill_between(x_early, 0, y_early, color=COLORS5[3], alpha=0.5)
    ax.fill_between(x_recent, 0, y_recent, color=COLORS5[0], alpha=0.5)
    ax.set_title(season)
    ax.set_xlabel("Meandering Index")
    ax.set_yticks([0, 0.5, 1, 1.5, 2.0])
    ax.set_xlim([1, 3.3])
    ax.set_ylim([0, 2])
axes[0].set_ylabel("Density")
axes[-1].legend()

## Combine, detrend

In [None]:
# timeseries, ax, season=None, bins=LATBINS, **kwargs
# Attach the meandering index, then store a smoothed day-of-year climatology and a
# linearly detrended anomaly for every index.
dataset = "ERA5"
Zoo = xr.open_dataset(f"{DATADIR}/{dataset}/processed/BarriopedroZoo.nc")
daM = xr.open_dataarray(f"{DATADIR}/{dataset}/processed/Meandering.nc")
# NOTE(review): presumably aligned on "time"; make sure both files share the same time stamps
Zoo["Mea"] = daM
# Iterate over a snapshot of the names: the loop body adds new variables to Zoo, so it
# must not iterate over the live data_vars mapping.
for key in list(Zoo.data_vars):
    noseason, Zoo[f"{key}_climatology"] = compute_anomaly(Zoo[key], return_clim=True, smooth_kmax=3)
    Zoo[f"{key}_anomaly"] = xrft.detrend(noseason, dim="time", detrend_type="linear")
Zoo.to_netcdf(f"{DATADIR}/{dataset}/processed/BarriopedroZooDetrended.nc")

In [None]:
# Distribution of the detrended jet-latitude anomaly, 1-degree bins from -30 to 30
Zoo["Lat_anomaly"].plot.hist(
    bins=np.arange(-30, 31, 1)
);

# EDG (see Barriopedro)

### Create EDG filtered datasets

In [None]:
from scipy.stats import gaussian_kde

In [None]:
### EDG
ds_EDG = (
    xr.open_dataset(f"{datadir}/EDG.nc")  # wind input for the EDG diagnostics
    .isel(level=0)  # keep only the first vertical level
)

In [None]:
# Split each wind component into a background part (frequencies <= 1/10 days) and a
# transient part (1/6 days <= frequency <= 1/2 days) with a sharp spectral filter in
# time. The frequencies are assumed to be in Hz.
BGRND_MAX_FREQ = 1 / 10 / 24 / 3600
TRANS_MAX_FREQ = 1 / 2 / 24 / 3600
TRANS_MIN_FREQ = 1 / 6 / 24 / 3600
for varname in ["u", "v"]:
    da_fft_bgrnd = xrft.fft(ds_EDG[f"{varname}wnd"], dim="time")
    da_fft_trans = da_fft_bgrnd.copy()
    # Use the spectrum of *this* field. The old code took `da_fft.freq_time`, a stale
    # variable from the Barriopedro packaging cell with a different time axis.
    freq = np.abs(da_fft_bgrnd.freq_time)
    da_fft_bgrnd[freq > BGRND_MAX_FREQ] = 0
    da_fft_trans[np.logical_or(freq > TRANS_MAX_FREQ, freq < TRANS_MIN_FREQ)] = 0
    ds_EDG[f"{varname}bgrnd"] = xrft.ifft(da_fft_bgrnd, dim="freq_time", true_phase=True, true_amplitude=True).real.assign_coords(time=ds_EDG.time).rename(f"{varname}bgrnd")
    ds_EDG[f"{varname}trans"] = xrft.ifft(da_fft_trans, dim="freq_time", true_phase=True, true_amplitude=True).real.assign_coords(time=ds_EDG.time).rename(f"{varname}trans")

In [None]:
# Persist the filtered EDG fields next to the raw input.
edg_filtered_path = f"{datadir}/EDG3.nc"
ds_EDG.to_netcdf(edg_filtered_path)

### EDG Computations

In [None]:
# The filtering cell writes EDG3.nc into `datadir`, but this cell used to read it from
# DATADIR directly. Prefer the location actually written; fall back to the old one.
_edg3_path = f"{datadir}/EDG3.nc"
if not os.path.exists(_edg3_path):
    _edg3_path = f"{DATADIR}/EDG3.nc"
ds_EDG = xr.open_dataset(_edg3_path)

In [None]:
# E vector of the transient eddies
ds_EDG["E1"] = (ds_EDG["vtrans"] ** 2 - ds_EDG["utrans"] ** 2) / 2
ds_EDG["E2"] = - ds_EDG["utrans"] * ds_EDG["vtrans"]
# xarray's differentiate() returns derivatives per unit of the coordinate, i.e. per
# degree here. The spherical formulas below need per-radian derivatives, hence this factor.
PER_RADIAN = 180 / np.pi
### D vector in spherical coordinates, see Obsidian page for this
# NOTE(review): the metric factors use degsin(lat). With `lat` being latitude (not
# colatitude) this presumably should be degcos(lat) — confirm against the derivation.
ds_EDG["D1"] = PER_RADIAN * ds_EDG["ubgrnd"].differentiate("lon") / RADIUS \
               - PER_RADIAN / degsin(ds_EDG["lat"]) / RADIUS * ds_EDG["vbgrnd"].differentiate("lat") \
               - ds_EDG["ubgrnd"] * degcos(ds_EDG["lat"]) / degsin(ds_EDG["lat"]) / RADIUS
ds_EDG["D2"] = 0.5 * PER_RADIAN * (degsin(ds_EDG["lat"]) / RADIUS * (ds_EDG["vbgrnd"] / degsin(ds_EDG["lat"])).differentiate("lon") \
                    + 1 / degsin(ds_EDG["lat"]) / RADIUS * ds_EDG["ubgrnd"].differentiate("lat"))
### Generation rate
ds_EDG["G"] = ds_EDG["E1"] * ds_EDG["D1"] + ds_EDG["E2"] * ds_EDG["D2"]
### Relative vorticity: (dv/dlambda - d(u cos(phi))/dphi) / (a cos(phi)), angles in radians
ds_EDG["omega"] = PER_RADIAN / RADIUS / degcos(ds_EDG["lat"]) * (ds_EDG["vwnd"].differentiate("lon") - \
                    (ds_EDG["uwnd"] * degcos(ds_EDG["lat"])).differentiate("lat"))
ds_EDG["EKE"] = 0.5 * (ds_EDG["utrans"] ** 2 + ds_EDG["vtrans"] ** 2)
for key in ds_EDG.data_vars:
    ds_EDG[key].to_netcdf(f"{DATADIR}/NCEP/processed/{key}.nc")

In [None]:
# Write every EDG variable to its own NetCDF file, named after the variable.
for name, variable in ds_EDG.data_vars.items():
    variable.to_netcdf(f"{DATADIR}/NCEP/processed/{name}.nc")

# Stationarity

### Zoo Autocorrelation:

In [None]:
# Autocorrelation at lags 0..howmany-1 (time steps, presumably days) for every
# non-climatology Zoo variable, with progress bars for variables and lags.
howmany = 50
f1 = IntProgress(value=0, max=len(Zoo.data_vars))
f2 = IntProgress(value=0, max=howmany)
display(f1, f2)
autocorrs = {}

for i, varname in enumerate(Zoo):
    if varname.split('_')[-1] == "climatology":
        continue
    f2.value = 0
    series = Zoo[varname]
    rho = np.empty(howmany)
    autocorrs[varname] = ("lag", rho)
    for lag in range(howmany):
        rho[lag] = xr.corr(series, series.shift(time=lag)).values
        f2.value = lag + 1
    f1.value = i + 1
autocorrsda = xr.Dataset(autocorrs, coords={"lag": np.arange(howmany)})
autocorrsda.to_netcdf(f"{DATADIR}/{dataset}/processed/Zoo_autocorrs.nc")

In [None]:
# Autocorrelation functions of the eleven anomaly series, with three decorrelation times.
fig, axes = plt.subplots(3, 4, figsize=(20, 16), tight_layout=True)
axes = axes.flatten(order="F")
dataset = "ERA5"
datadir = f"{DATADIR}/{dataset}/processed" # daint
autocorrs = xr.open_dataset(f"{datadir}/Zoo_autocorrs.nc")
howmany = len(autocorrs.coords["lag"])
lags = np.arange(howmany)
# Anomaly series of the indices named in ZOO. This does not depend on the variable
# order of whichever Zoo dataset is currently in memory.
newlist = [f"{key}_anomaly" for key in ZOO]
telatex = r"$T^e_{\rho}$"
tdlatex = r"$T^d_{\rho}$"
tclatex = r"$T^c_{\rho}$"
lw = 2
for i, varname in enumerate(newlist):
    rho = autocorrs[varname].values  # rho[k] = autocorrelation at lag k, rho[0] = 1
    # e-folding time: first lag with rho <= 1/e (0 if it is never reached)
    te = np.argmax(rho <= 1 / np.exp(1))
    # T_d = 1 + 2 * sum_{k >= 1} rho_k. Lag 0 was previously included, adding 2.
    td = 1 + 2 * np.sum(rho[1:])
    # NOTE(review): the weights 1 - k/(N+1), k = 1..N, are applied to lags 0..N-1 —
    # confirm this alignment is intended.
    tc = 1 + np.sum(rho * (1 - np.arange(1, howmany + 1) / (howmany + 1)))
    axes[i].plot(lags, rho, color=COLORS5[0], lw=lw)
    axes[i].plot([te, te], [0, 1], label=telatex, color=COLORS5[2], lw=lw)
    axes[i].plot([tc, tc], [0, 1], label=tclatex, color=COLORS5[3], lw=lw)
    axes[i].plot([td, td], [0, 1], label=tdlatex, color=COLORS5[4], lw=lw)
    axes[i].grid()
    axes[i].legend()
    axes[i].set_title(f"{varname}, {telatex}={te}, {tdlatex}={td:.3f}, {tclatex}={tc:.3f}")
    axes[i].set_ylabel("Autocorrelation")
    axes[i].set_xlabel("Lag time [days]")

### Zoo Hurst exponent (rescaled-range analysis)

In [None]:
# Rescaled-range (R/S) analysis:
#  - for k = 0..10, cut each series into 2**k consecutive, non-overlapping chunks of
#    length n = N // 2**k and average R/S over the chunks;
#  - fit log(R/S) linearly against log(n); the slope is the Hurst exponent.
fig, ax = plt.subplots()
subdivs = [2**n for n in range(11)]
lengths = np.array([len(Zoo.time) // n for n in subdivs])
Hurst = {}
# "_climatology" variables live on "dayofyear" (no "time" axis); skip them, as the
# autocorrelation cell does.
varnames = [name for name in Zoo.data_vars if not name.endswith("_climatology")]
for i, varname in enumerate(varnames):
    values = Zoo[varname].values  # assumed 1-D over time — TODO confirm for new variables
    adjusted_ranges = []
    for n_chunks, n in zip(subdivs, lengths):
        # Previously `start` was never advanced, so every "chunk" was the first one.
        chunks = values[: n_chunks * n].reshape(n_chunks, n)
        profile = np.cumsum(chunks - chunks.mean(axis=1, keepdims=True), axis=1)
        raw_range = profile.max(axis=1) - profile.min(axis=1)
        adjusted_ranges.append(np.mean(raw_range / chunks.std(axis=1)))
    print(varname)
    color = COLORS10[i % 10]
    ax.loglog(lengths, adjusted_ranges, color=color)
    coeffs = np.polyfit(np.log(lengths), np.log(adjusted_ranges), deg=1)
    Hurst[varname] = [coeffs[0], np.exp(coeffs[1])]
    # `lengths` is an array: the old Python list could not be raised to a float power
    ax.loglog(lengths, np.exp(coeffs[1]) * lengths ** coeffs[0], color=color, ls="--")
ax.set_xlabel("Chunk length n [time steps]")
ax.set_ylabel("Mean rescaled range R/S")
# explicit path instead of relying on `datadir` as left by another cell
with open(f"{DATADIR}/{dataset}/processed/Hurst_{dataset}.pkl", "wb") as handle:
    pkl.dump(Hurst, handle)

In [None]:
# The Hurst cell writes Hurst_ERA5.pkl; fall back to the older Hurst.pkl if it is absent.
# NOTE: pickle can execute arbitrary code — only load files produced by this notebook.
_hurst_path = f"{DATADIR}/ERA5/processed/Hurst_ERA5.pkl"
if not os.path.exists(_hurst_path):
    _hurst_path = f"{DATADIR}/ERA5/processed/Hurst.pkl"
with open(_hurst_path, "rb") as handle:
    Hurst = pkl.load(handle)

In [None]:
# Display the Hurst results. As written by the Hurst cell, each entry is
# {variable: [slope (Hurst exponent), exp(intercept)]} of the log-log fit.
Hurst

# Clustering

### Z500

In [None]:
# North Atlantic Z500: divide by g, then compute the day-of-year anomaly and its linear
# detrending lazily, block by block. time is a single chunk (-1) because both steps
# operate along time; lon is split into chunks of 121.
ds = (
    xr.open_dataset(f"{DATADIR}/ERA5/Geopotential/north_atlantic/full.nc")
    .rename({"longitude": "lon", "latitude": "lat"})
    .isel(lon=np.arange(241), lat=np.arange(30, 131))
)
ds["z"] /= co.g
da = ds["z"].chunk({"time": -1, "lon": 121})
na_dir = f"{DATADIR}/ERA5/Geopotential/north_atlantic"
anomaly = xr.map_blocks(compute_anomaly, da, template=da)
detrended = xr.map_blocks(xrft.detrend, anomaly, args=("time", "linear"), template=da)
anomaly.to_netcdf(f"{na_dir}/anomaly.nc")
detrended.to_netcdf(f"{na_dir}/detrended.nc")
# Alternative NCEP input (kept for reference):
# ds = xr.open_dataset(
#     f"{DATADIR}/NCEP/packaged/Z500NA.nc"
# ).rename({"hgt": "z"}).reset_index("level", drop=True).squeeze().isel(lat=np.arange(4, 25))

In [None]:
# K-means weather regimes of the cos(lat)-weighted, detrended Z500 anomalies.
# NOTE(review): the PCA cell weights by sqrt(cos(lat)), the usual area weighting for
# squared distances; cos(lat) here weights squared distances by cos(lat)**2 — confirm.
n_clu = 4
KMEANS_SEED = 0  # fixed seed: reproducible centroids and regime numbering between runs
thisda = xr.open_dataarray(f"{DATADIR}/ERA5/Geopotential/north_atlantic/detrended.nc")
tbt = (thisda * degcos(thisda.lat)).values.reshape(len(thisda.time), -1)
results = KMeans(n_clu, n_init="auto", random_state=KMEANS_SEED).fit(tbt)
with open("kmeans_ERA5_detrended.pkl", "wb") as handle:
    pkl.dump(results, handle)
# Alternative k-medoids clustering (kept for reference):
# distmatrix = euclidean_distances(tbt)
# results = kmedoids.fasterpam(distmatrix, n_clu)
# centers = thisda.isel(time=results.medoids).rename({"time": "cluster"}).assign_coords({"cluster": np.arange(n_clu)}).compute()

In [None]:
# Plot the K-means regime centroids (Z500 anomaly, m) on a Lambert conformal map.
with open("kmeans_ERA5_detrended.pkl", "rb") as handle:
    results = pkl.load(handle)
thisda = xr.open_dataarray(f"{DATADIR}/ERA5/Geopotential/north_atlantic/detrended.nc")
# Take the number of regimes and the grid from the model and data themselves rather
# than from `n_clu`, `ds` and `da` left over from other cells.
n_clu = results.cluster_centers_.shape[0]
lon = thisda["lon"].values
lat = thisda["lat"].values
# the centroids were fitted on cos(lat)-weighted fields: undo the weighting for display
centers = xr.DataArray(
    results.cluster_centers_.reshape(n_clu, *thisda.shape[1:]),
    coords={"cluster": np.arange(n_clu), "lat": lat, "lon": lon},
) / degcos(thisda.lat)
projection = ccrs.LambertConformal(
    central_longitude=np.mean(lon),
)
fig, axes = plt.subplots(2, 2, figsize=(10, 7.5), subplot_kw={"projection": projection}, constrained_layout=True)
extent = [np.amin(lon), np.amax(lon), np.amin(lat), np.amax(lat)]
boundary = make_boundary_path(*extent)
levels = 11
lower, upper = -150, 150
# Two linspaces meeting at 0, with both copies of 0 removed, so the central interval
# spans one step on either side of 0.
levels = np.delete(np.append(np.linspace(lower, 0, levels), np.linspace(0, upper, levels)), [levels - 1, levels])
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the colormap registry
cmap = mpl.colormaps["seismic"]
norm = mpl.colors.BoundaryNorm(levels, cmap.N, extend='both')
im = cm.ScalarMappable(norm=norm, cmap=cmap)
axes = axes.flatten()
for i in range(n_clu):
    axes[i].contourf(
        lon,
        lat,
        centers.isel(cluster=i),
        transform=ccrs.PlateCarree(),
        levels=levels, cmap=cmap, norm=norm,
        extend="both",
    )
    axes[i].contour(
        lon,
        lat,
        centers.isel(cluster=i),
        transform=ccrs.PlateCarree(),
        levels=levels, colors="k",
    )
    axes[i].set_boundary(boundary, transform=ccrs.PlateCarree())
    axes[i].add_feature(COASTLINE)
    axes[i].add_feature(BORDERS)
    axes[i].set_title(f"Regime {i + 1}, {np.sum(results.labels_ == i) / len(results.labels_) * 100:.2f}%")
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), spacing="proportional")
cbar.ax.set_ylabel("Z500 [m]")
_ = cbar.ax.set_yticks(np.concatenate([np.arange(-150, 20, 30), np.arange(30, 151, 30)]))
plt.show()

#### Predict hot spells

In [None]:
# Reload the fitted K-means model (pickle: only load files produced by this notebook).
kmeans_file = "kmeans_ERA5_detrended.pkl"
with open(kmeans_file, "rb") as fh:
    results = pkl.load(fh)

In [None]:
# For each hot-spell region, assign every day from `minus` days before to `plus` days
# after each hot-spell date to a K-means regime.
list_of_dates = np.loadtxt("hotspells.csv", delimiter=",", dtype=np.datetime64)
hotspells_clusters = {}
keys = ["South", "West", "Balkans", "Scandinavia", "Russia", "Arctic"]
# Load the anomalies from disk instead of relying on the lazy `detrended` from the Z500
# cell. Apply the same cos(lat) weighting used to fit the K-means model; predicting on
# unweighted fields put the inputs in a different feature space.
detrended = xr.open_dataarray(f"{DATADIR}/ERA5/Geopotential/north_atlantic/detrended.nc")
weighted = (detrended * degcos(detrended.lat)).compute()
minus = 21
plus = 5
for j, key in enumerate(keys):
    dates = np.sort(list_of_dates[:, j])
    # drop missing entries and 2022 (NOTE(review): presumably outside the data record)
    dates = dates[~(np.isnat(dates) | (np.datetime_as_string(dates, unit="Y") == "2022"))]
    predictions = []
    for date in dates:
        # +9 h presumably matches the time stamp of the daily fields — confirm
        tsta = date - np.timedelta64(minus, "D") + np.timedelta64(9, "h")
        tend = date + np.timedelta64(plus, "D") + np.timedelta64(9, "h")
        to_predict = weighted.sel(time=pd.date_range(tsta, tend, freq="1D"))
        to_predict = to_predict.values.reshape(len(to_predict.time), -1)
        predictions.append(results.predict(to_predict))
    hotspells_clusters[key] = xr.DataArray(
        np.stack(predictions).transpose(),
        coords={"time": np.arange(-minus, plus + 1), "hotspell": dates}
    )
with open("hotspells_clusters.pkl", "wb") as handle:
    pkl.dump(hotspells_clusters, handle)

In [None]:
%matplotlib inline
with open("hotspells_clusters.pkl", "rb") as handle:
    hotspells_clusters = pkl.load(handle)
hot_time = hotspells_clusters["South"].time.values
to_plot = np.zeros((n_clu // 2, len(keys), 2 * len(hot_time)))
cmaps = ["Blues", "Greens", "Reds", "Purples"]
for i in range(4):
    for j, (key, value) in enumerate(hotspells_clusters.items()):
        to_plot[int(i <= 1), j, (i % 2)::2] = i + (value == i).mean(dim="hotspell").values
class HandlerColormap(HandlerBase): # https://stackoverflow.com/questions/55501860/how-to-put-multiple-colormap-patches-in-a-matplotlib-legend
    """Legend handler drawing a handle as `num_stripes` adjacent colour patches.

    The stripe colours are sampled from `cmap` at the stripe centres.
    """

    def __init__(self, cmap, num_stripes=8, **kw):
        super().__init__(**kw)
        self.cmap = cmap
        self.num_stripes = num_stripes

    def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
        """Return one Rectangle per stripe, tiling the handle box from left to right."""
        return [
            Rectangle(
                [xdescent + k * width / self.num_stripes, ydescent],
                width / self.num_stripes,
                height,
                fc=self.cmap((2 * k + 1) / (2 * self.num_stripes)),
                transform=trans,
            )
            for k in range(self.num_stripes)
        ]

# Stack four sequential colormaps into one: values in [i, i + 1) use the i-th colormap.
cmaps = [mpl.colormaps[name].resampled(256) for name in ["Blues", "Greens", "Reds", "Purples"]]
cmap = ListedColormap(np.concatenate([single(np.linspace(0, 1, 256)) for single in cmaps]))

# region names from the loaded data rather than `keys` from another cell
region_names = list(hotspells_clusters.keys())
fig, axes = plt.subplots(len(region_names), 1, figsize=(15, 5))
norm = Normalize(0, len(cmaps))
x_edges = np.arange(hot_time[0], hot_time[-1] + 1.1, .5)  # half-day column edges
for j, ax in enumerate(axes):
    ax.set_yticks([1])
    ax.set_yticklabels([region_names[j]])
    if j == len(axes) - 1:
        ax.set_xticks(np.arange(hot_time[0] + 0.5, hot_time[-1] + 1.5))
        ax.set_xticklabels(np.arange(hot_time[0], hot_time[-1] + 1))
        ax.set_xlabel("Days around center")
    else:
        ax.set_xticks([])
    ax.set_frame_on(False)
    # drawn once; the previous `for i in range(2)` loop drew the identical mesh twice
    ax.pcolormesh(
        x_edges,
        np.arange(3), to_plot[:, j].reshape(2, len(hot_time) * 2),
        cmap=cmap, norm=norm
    )
    ax.vlines(np.arange(hot_time[0] + 1, hot_time[-1] + 1), 0, 2, color="white", lw=4)
cmap_handles = [Rectangle((0, 0), 3, 1) for _ in cmaps]
handler_map = dict(zip(cmap_handles, [HandlerColormap(single, num_stripes=20) for single in cmaps]))
cmap_labels = [f"regime {k + 1}" for k in range(len(cmaps))]
axes[-1].legend(
    handles=cmap_handles,
    labels=cmap_labels,
    handler_map=handler_map,
    fontsize=12,
    bbox_to_anchor=(1.13, 4.2),
    loc="upper right",
)
plt.subplots_adjust(hspace=0.1)

In [None]:
%matplotlib inline
with open("hotspells_clusters.pkl", "rb") as handle:
    hotspells_clusters = pkl.load(handle)
with open("kmeans_ERA5_detrended.pkl", "rb") as handle:
    results = pkl.load(handle)
n_clu = 4
abs_freq = [np.mean(results.labels_ == i) for i in range(n_clu)]
hot_time = hotspells_clusters["South"].time.values
to_plot = np.zeros((n_clu, len(hotspells_clusters), len(hot_time)))
cmaps = ["Blues", "Greens", "Reds", "Purples"]
for i in range(n_clu):
    for j, (key, value) in enumerate(hotspells_clusters.items()):
        to_plot[i, j] = abs_freq[i] - (value == i).mean(dim="hotspell").values

fig, axes = plt.subplots(len(hotspells_clusters), 1, figsize=(15, 8), sharex=True)
fig.subplots_adjust(hspace=0, wspace=0, left=0.06)
for j, key in enumerate(hotspells_clusters):
    ax = axes[j]
#     ax.spines[["left", "right"]].set_visible(False)
    for i in range(n_clu):
        ax.plot(hot_time, to_plot[i, j], lw=2, color=COLORS5[i], label=f"Regime {i + 1}")
    ax.set_ylim([-0.5, 0.5])
    ax.set_yticks([-0.25, 0, 0.25])
    ax.set_yticklabels([-25, 0, 25])
    ax.text(-20.8, 0.21, key, fontweight="bold")
    ax.grid()
fig.supylabel('Regime relative occurence [%]')
ax.set_xlabel("Time around center [Days]")
ax.set_xlim([-21, 5])
ax.legend(bbox_to_anchor=(1.12, 3.53),)

# PCA

In [None]:
# EOF analysis: fit a 20-mode PCA to the detrended North-Atlantic geopotential.
from sklearn.decomposition import PCA as pca
n_components = 20
thisda = xr.open_dataarray(f"{DATADIR}/ERA5/Geopotential/north_atlantic/detrended.nc")
# Weight every grid point by sqrt(degcos(lat)) (project helper, presumably cos of an angle in
# degrees — confirm), then flatten each time step to one row: (n_time, n_gridpoints).
# Assumes time is the first dimension of `thisda`.
tbt = (thisda * np.sqrt(degcos(thisda.lat))).values.reshape(len(thisda.time), -1)
# whiten=True: the PC scores returned by transform() are scaled to unit variance
pca_results = pca(n_components=n_components, whiten=True).fit(tbt)

In [None]:
# Map the leading EOFs. The PCA was fitted on sqrt(cos(lat))-weighted fields, so that
# weighting is divided back out. Coordinates come from `thisda` (the array that was
# decomposed) rather than from whatever `ds` / `da` currently hold in the kernel.
centers = xr.DataArray(
    pca_results.components_.reshape(n_components, *thisda.shape[1:]),
    coords={"component": np.arange(n_components), "lat": thisda.lat.values, "lon": thisda.lon.values},
) / np.sqrt(degcos(thisda.lat))
projection = ccrs.LambertConformal(
    central_longitude=np.mean(thisda.lon.values),
)
lon = thisda["lon"].values
lat = thisda["lat"].values
fig, axes = plt.subplots(5, 4, figsize=(20, 25), subplot_kw={"projection": projection}, constrained_layout=True)
extent = [np.amin(lon), np.amax(lon), np.amin(lat), np.amax(lat)]
boundary = make_boundary_path(*extent)
n_levels = 10
lower, upper = -0.045, 0.045
# Symmetric levels without 0: the zero ending the first linspace and the zero starting
# the second are both deleted, leaving one bin that straddles zero.
levels = np.delete(
    np.append(np.linspace(lower, 0, n_levels), np.linspace(0, upper, n_levels)),
    [n_levels - 1, n_levels],
)
# plt.get_cmap works on every matplotlib version (matplotlib.cm.get_cmap was removed in 3.9)
cmap = plt.get_cmap("seismic")
# named `bnorm` so the scipy.stats `norm` imported at the top is not shadowed
bnorm = mpl.colors.BoundaryNorm(levels, cmap.N, extend='both')
im = cm.ScalarMappable(norm=bnorm, cmap=cmap)
axes = axes.flatten()
for i in range(n_components):
    axes[i].contourf(
        lon,
        lat,
        centers.isel(component=i),
        transform=ccrs.PlateCarree(),
        levels=levels,
        cmap=cmap,
        norm=bnorm,
        extend="both",
    )
    axes[i].contour(
        lon,
        lat,
        centers.isel(component=i),
        transform=ccrs.PlateCarree(),
        levels=levels,
        colors="k",
    )
    axes[i].set_boundary(boundary, transform=ccrs.PlateCarree())
    axes[i].add_feature(COASTLINE)
    axes[i].add_feature(BORDERS)
    # title: share of variance explained by this EOF
    axes[i].set_title(f"{pca_results.explained_variance_ratio_[i] * 100:.2f} %")
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), spacing="proportional")
# NOTE(review): EOF loadings are unit-norm vectors, not geopotential heights — label likely a leftover
cbar.ax.set_ylabel("Z500 [m]")
plt.show()

In [None]:
# Persist the fitted PCA so later sessions can reuse it without refitting
# (pickle's default protocol, spelled out explicitly).
with open("pca_ERA5_detrended.pkl", mode="wb") as handle:
    pkl.dump(pca_results, handle, protocol=pkl.DEFAULT_PROTOCOL)

In [None]:
# Project the weighted fields onto the EOFs: PC scores of shape (n_time, n_components),
# whitened because the PCA was fitted with whiten=True.
reduced = pca_results.transform(tbt)

# K-means on EOFs

In [None]:
# k-means on the whitened PC scores. The fixed seed makes the centroids and,
# importantly, the regime numbering used by the following cells reproducible.
pca_kmeans_results = KMeans(n_clu, n_init="auto", random_state=0).fit(reduced)

In [None]:
# Centroids of the k-means run in EOF space, mapped back to grid space.
thisda = xr.open_dataarray(f"{DATADIR}/ERA5/Geopotential/north_atlantic/detrended.nc")
# inverse_transform undoes the PCA (including whitening); then remove the sqrt(cos(lat))
# weighting applied before the fit. Coordinates come from `thisda`, not from `ds`/`da`.
centers = xr.DataArray(
    pca_results.inverse_transform(pca_kmeans_results.cluster_centers_).reshape(n_clu, *thisda.shape[1:]),
    coords={"cluster": np.arange(n_clu), "lat": thisda.lat.values, "lon": thisda.lon.values},
) / np.sqrt(degcos(thisda.lat))
# Regime frequencies must come from the SAME clustering whose centroids are drawn
# (not from the separately pickled k-means `results`).
cluster_freq = np.bincount(pca_kmeans_results.labels_, minlength=n_clu) / len(pca_kmeans_results.labels_)
projection = ccrs.LambertConformal(
    central_longitude=np.mean(thisda.lon.values),
)
lon = thisda["lon"].values
lat = thisda["lat"].values
fig, axes = plt.subplots(2, 2, figsize=(10, 7.5), subplot_kw={"projection": projection}, constrained_layout=True)
extent = [np.amin(lon), np.amax(lon), np.amin(lat), np.amax(lat)]
boundary = make_boundary_path(*extent)
n_levels = 11
lower, upper = -150, 150
# symmetric levels with the two zeros removed -> one bin straddling zero
levels = np.delete(
    np.append(np.linspace(lower, 0, n_levels), np.linspace(0, upper, n_levels)),
    [n_levels - 1, n_levels],
)
cmap = plt.get_cmap("seismic")  # cm.get_cmap was removed in matplotlib 3.9
bnorm = mpl.colors.BoundaryNorm(levels, cmap.N, extend='both')  # avoids shadowing scipy.stats.norm
im = cm.ScalarMappable(norm=bnorm, cmap=cmap)
axes = axes.flatten()
for i in range(n_clu):
    axes[i].contourf(
        lon,
        lat,
        centers.isel(cluster=i),
        transform=ccrs.PlateCarree(),
        levels=levels, cmap=cmap, norm=bnorm,
        extend="both",
    )
    axes[i].contour(
        lon,
        lat,
        centers.isel(cluster=i),
        transform=ccrs.PlateCarree(),
        levels=levels, colors="k",
    )
    axes[i].set_boundary(boundary, transform=ccrs.PlateCarree())
    axes[i].add_feature(COASTLINE)
    axes[i].add_feature(BORDERS)
    axes[i].set_title(f"Regime {i + 1}, {cluster_freq[i] * 100:.2f}%")
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), spacing="proportional")
cbar.ax.set_ylabel("Z500 [m]")
_ = cbar.ax.set_yticks(np.concatenate([np.arange(-150, 20, 30), np.arange(30, 151, 30)]))
plt.show()

# Optimally persistent patterns (OPP) from EOFs (maximizing $T_1$; TODO: $T_2$)

In [None]:
# Optimally persistent patterns in EOF space: maximize q^T M q / q^T C0 q, where C0 is the
# lag-0 covariance of the PC scores and M the sum of lagged covariances (integral time scale).
from scipy.linalg import eigh

da_reduced = xr.DataArray(reduced, coords={"time": thisda.time.values, "projection": np.arange(n_components)})

lag_max = 15  # days; lags 0 .. lag_max - 1 enter the sum below
autocorrs = []
for j in range(lag_max):
    # Lag-j covariance: block [n:, :n] of the joint covariance of x(t) and x(t - j),
    # dropping the first j rows that the shift fills with NaN.
    autocorrs.append(np.cov(da_reduced.values[j:], da_reduced.shift(time=j).values[j:], rowvar=False)[n_components:, :n_components])

autocorrs = np.asarray(autocorrs)
# symmetric by construction (C0 + sum of C_tau + C_tau^T)
M = autocorrs[0] + np.sum([autocorrs[i] + autocorrs[i].transpose() for i in range(1, lag_max)], axis=0)

# Generalized symmetric eigenproblem M q = lambda C0 q. C0 has to be inverted as a matrix;
# an element-wise reciprocal (1 / C0) is not its inverse. eigh returns real eigenvalues in
# ascending order with eigenvectors as COLUMNS; reorder to most persistent first.
eigenvals, eigenvecs = eigh(M, autocorrs[0])
order = np.argsort(eigenvals)[::-1]
eigenvals, eigenvecs = eigenvals[order], eigenvecs[:, order]

# Row i = OPP i expressed on the (weighted) grid via the EOFs.
OPPs_realspace = eigenvecs.T @ pca_results.components_

In [None]:
# Map the OPPs. They are combinations of EOFs fitted on sqrt(cos(lat))-weighted fields, so the
# weighting is divided out as for the EOF and centroid maps; coordinates come from `thisda`.
centers = xr.DataArray(
    OPPs_realspace.reshape(n_components, *thisda.shape[1:]),
    coords={"component": np.arange(n_components), "lat": thisda.lat.values, "lon": thisda.lon.values},
) / np.sqrt(degcos(thisda.lat))
projection = ccrs.LambertConformal(
    central_longitude=np.mean(thisda.lon.values),
)
lon = thisda["lon"].values
lat = thisda["lat"].values
fig, axes = plt.subplots(5, 4, figsize=(20, 25), subplot_kw={"projection": projection}, constrained_layout=True)
extent = [np.amin(lon), np.amax(lon), np.amin(lat), np.amax(lat)]
boundary = make_boundary_path(*extent)
cmap = "seismic"
# One symmetric set of levels for all panels, so the shared colorbar is valid for every
# panel (with per-panel automatic levels it only described the last one drawn).
vmax = float(np.abs(centers).max()) or 1.0
levels = np.linspace(-vmax, vmax, 21)
axes = axes.flatten()
for i in range(n_components):
    im = axes[i].contourf(
        lon,
        lat,
        centers.isel(component=i),
        transform=ccrs.PlateCarree(),
        levels=levels,
        cmap=cmap,
        extend="both",
    )
    axes[i].contour(
        lon,
        lat,
        centers.isel(component=i),
        transform=ccrs.PlateCarree(),
        levels=levels,
        colors="k",
    )
    axes[i].set_boundary(boundary, transform=ccrs.PlateCarree())
    axes[i].add_feature(COASTLINE)
    axes[i].add_feature(BORDERS)
    axes[i].set_title(f"OPP {i + 1}: {eigenvals[i]:.2f}")
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), spacing="proportional")
# NOTE(review): OPP patterns are normalized combinations of EOFs, not heights — label likely a leftover
cbar.ax.set_ylabel("Z500 [m]")
plt.show()

## Hidden Markov Model

In [None]:
# Detrended indices from the "BarriopedroZoo" dataset (presumably jet indices — confirm);
# keep the DJF latitude anomaly as a single-feature column (n_samples, 1), the layout hmmlearn expects.
Zoo = xr.open_dataset(f"{DATADIR}/ERA5/processed/BarriopedroZooDetrended.nc")
Y = Zoo["Lat_anomaly"].isel(time=Zoo.time.dt.season=="DJF").values[:, None]

In [None]:
# Three-state Gaussian HMM on the DJF latitude anomaly. The fixed seed makes the fitted
# states (and therefore their numbering in the plots below) reproducible.
n_components = 3
ghmm = GaussianHMM(n_components=n_components, random_state=0).fit(Y)

In [None]:
# The map-plotting cells rebind the global `norm` to a matplotlib BoundaryNorm, so import the
# Gaussian distribution under its own name here; otherwise `norm(loc=..., scale=...)` fails.
from scipy.stats import norm as gaussian

# Emission density of each HMM state, overlaid on a KDE of the data.
for i in range(n_components):
    thisnorm = gaussian(loc=ghmm.means_[i][0], scale=np.sqrt(ghmm.covars_[i][0][0]))
    X = np.linspace(thisnorm.ppf(0.005), thisnorm.ppf(0.995), 100)
    # NOTE(review): 1 / n_components assumes equally likely states; the HMM's own weights
    # (e.g. its stationary distribution) would make the sum comparable to the KDE — confirm.
    plt.plot(X, thisnorm.pdf(X) / n_components)
plt.plot(*kde(Y, season=None, bins=np.arange(-30, 30.1, 0.25), scaled=False, return_x=True, bw_method=0.2))

In [None]:
# Composite Z500 anomaly maps of the DJF days assigned to each HMM state.
thisda = xr.open_dataarray(f"{DATADIR}/ERA5/Geopotential/north_atlantic/detrended.nc")
# Select DJF with the array's own time axis rather than Zoo's, then check that the
# selected days line up one-to-one with the HMM input Y.
thisda = thisda.isel(time=thisda.time.dt.season == "DJF")
states = ghmm.predict(Y)
assert len(states) == thisda.sizes["time"], "HMM states and Z500 DJF days do not line up"
projection = ccrs.LambertConformal(
    central_longitude=np.mean(thisda.lon.values),
)
fig, axes = plt.subplots(1, n_components, figsize=(20, 6), subplot_kw={"projection": projection}, constrained_layout=True)
to_plot = [thisda.isel(time=states == i).mean(dim="time") for i in range(n_components)]
lon = thisda["lon"].values
lat = thisda["lat"].values
extent = [np.amin(lon), np.amax(lon), np.amin(lat), np.amax(lat)]
boundary = make_boundary_path(*extent)
n_levels = 11
lower, upper = -150, 150
# symmetric levels with the two zeros removed -> one bin straddling zero
levels = np.delete(
    np.append(np.linspace(lower, 0, n_levels), np.linspace(0, upper, n_levels)),
    [n_levels - 1, n_levels],
)
cmap = plt.get_cmap("seismic")  # cm.get_cmap was removed in matplotlib 3.9
bnorm = mpl.colors.BoundaryNorm(levels, cmap.N, extend='both')  # avoids shadowing scipy.stats.norm
im = cm.ScalarMappable(norm=bnorm, cmap=cmap)
axes = axes.flatten()
for i in range(n_components):
    axes[i].contourf(
        lon,
        lat,
        to_plot[i],
        transform=ccrs.PlateCarree(),
        levels=levels, cmap=cmap, norm=bnorm,
        extend="both",
    )
    axes[i].contour(
        lon,
        lat,
        to_plot[i],
        transform=ccrs.PlateCarree(),
        levels=levels, colors="k",
    )
    axes[i].set_boundary(boundary, transform=ccrs.PlateCarree())
    axes[i].add_feature(COASTLINE)
    axes[i].add_feature(BORDERS)
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), spacing="proportional")
cbar.ax.set_ylabel("Z500 [m]")
_ = cbar.ax.set_yticks(np.concatenate([np.arange(-150, 20, 30), np.arange(30, 151, 30)]))
plt.show()

### Length of events

In [None]:
def runs_of_ones_array(bits): # https://stackoverflow.com/questions/1066758/find-length-of-sequences-of-identical-values-in-a-numpy-array-run-length-encodi
    """Locate the runs of consecutive ones (True values) in a 1-D sequence.

    Returns
    -------
    starts, lengths : numpy int arrays
        Index at which each run begins, and how many elements it spans.
    """
    # Pad a zero on both ends so every run has a rising edge (+) and a falling edge (-).
    edges = np.diff(np.hstack(([0], bits, [0])))
    starts = np.flatnonzero(edges > 0)
    stops = np.flatnonzero(edges < 0)
    return starts, stops - starts

# Distribution of consecutive-run lengths of each HMM state, normalized to its most common length.
# NOTE(review): `states` covers concatenated DJF seasons, so a run may span the gap between
# one February and the next December — confirm whether runs should be split at season gaps.
fig, ax = plt.subplots()
for i in range(ghmm.n_components):
    thisseq = states == i
    _, durations = runs_of_ones_array(thisseq)
    if durations.size == 0:
        continue  # state never assigned: nothing to plot (np.amax would fail on an empty array)
    # np.unique already returns the distinct durations sorted ascending
    dur, occu = np.unique(durations, return_counts=True)
    occu = occu / np.amax(occu)
    ax.plot(dur, occu, label=f"State {i}")
ax.set_xlabel("Run length [time steps]")
ax.set_ylabel("Occurrences (normalized to max)")
ax.legend()

## Quasi stationary states

## Dynamical systems theory

In [None]:
import CDSK as ck

In [None]:
# Dynamical-systems local indices (CDSK) of the 300 hPa zonal wind, computed year by year.
local_indices = []
for year in YEARSPL:
    print(year)  # progress
    # NOTE(review): rebinds the global `da`; later cells that use `da` will see the last year's u300.
    da = xr.open_dataset(f"{DATADIR}/ERA5/Wind/300/north_atlantic/{year}.nc")["u"]
    # each time step flattened to one state vector: input shaped (n_time, n_gridpoints, 1)
    local_indices.append(ck.dynamical_local_indexes(da.values.reshape(len(da.time), -1, 1)))
# todo : classify wind using Lachmy & Harnik 2016

In [None]:
# Stitch the yearly CDSK outputs into two daily series over DATERANGEPL:
# element 0 of each per-year result is `ld`, element 1 is `theta`.
ld, theta = (
    xr.DataArray(
        np.concatenate([li[pos].flatten() for li in local_indices]),
        coords={"time": DATERANGEPL},
        name=label,
    )
    for pos, label in enumerate(("ld", "theta"))
)

In [None]:
# Save the daily u300-based indices to the working directory.
# NOTE(review): the hot-spell cell further down reads `{datadir}/ld.nc` and `theta.nc` instead —
# confirm whether those are different products or these files should be moved there.
ld.to_netcdf("ld_u300_NA.nc")
theta.to_netcdf("theta_u300_NA.nc")

In [None]:
# Time series of the two CDSK indices; the 5 % most persistent days (largest 1 / theta)
# are marked on the theta panel.
fig, axes = plt.subplots(2, 1, figsize=(10, 12), sharex=True)
fig.subplots_adjust(hspace=0)
ld.plot(ax=axes[0])
persistence = 1 / theta
q95 = np.quantile(persistence, 0.95)
theta.plot(ax=axes[1])
# explicit axes instead of relying on pyplot's implicit "current axes"
theta.isel(time=persistence >= q95).plot(ax=axes[1], color="red", ls="", marker="x")

## HMM on $\theta$ (its time series looks step-like)

In [None]:
# Four-state Gaussian HMM on theta; each state's days are overplotted on the series.
Y_theta = theta.values.reshape((-1, 1))  # (n_samples, 1) as hmmlearn expects
n_components = 4
# fixed seed: reproducible states and state numbering between runs
ghmm_theta = GaussianHMM(n_components=n_components, random_state=0).fit(Y_theta)
belongs = ghmm_theta.predict(Y_theta)
fig, ax = plt.subplots()
theta.plot(ax=ax)
for i in range(n_components):
    theta.isel(time=belongs == i).plot(ax=ax, marker="x", color=COLORS10[i % 10], ls="", label=f"State {i}")
ax.legend()

In [None]:
# ld vs theta phase-space scatter, one colour per meteorological season
# (the season of each day is taken from ld's time axis for both series).
fig, ax = plt.subplots()
seasons = ld.time.dt.season
for i, s in enumerate(["DJF", "MAM", "JJA", "SON"]):
    in_season = seasons == s
    ax.scatter(ld.isel(time=in_season), theta.isel(time=in_season), s=2, label=s, c=COLORS5[i])
ax.set_xlim([0, 25])
ax.legend()

In [None]:
# NCEP "EKE" field (presumably eddy kinetic energy — confirm), restricted to the dates in DATERANGEPL.
EKE = xr.open_dataarray(f"{DATADIR}/NCEP/processed/EDG/EKE.nc").sel(time=DATERANGEPL)

In [None]:
# Inspect the latitude coordinate of EKE (values / ordering) before choosing the averaging box.
EKE.lat

In [None]:
# Box mean of EKE over 20-80N, -30..90E, area-weighted by cos(latitude): on a regular lat-lon
# grid an unweighted mean over-counts the small high-latitude cells.
# NOTE(review): if EKE.lon runs 0..360 (the usual NCEP convention) the -30 bound selects nothing
# west of 0 — confirm the intended box and longitude convention.
in_box = EKE.isel(lat=(EKE.lat >= 20) & (EKE.lat <= 80), lon=(EKE.lon >= -30) & (EKE.lon <= 90))
mean_EKE_NA = in_box.weighted(np.cos(np.deg2rad(in_box.lat))).mean(dim=["lon", "lat"])
mean_EKE_NA.plot()

In [None]:
# ld vs theta scatter, coloured and sized by the regional EKE mean. Points are drawn in
# increasing EKE order so the most energetic days land on top.
fig, ax = plt.subplots()
idx = np.argsort(mean_EKE_NA).values
eke_sorted = mean_EKE_NA[idx]
hi = ax.scatter(ld[idx], theta[idx], c=eke_sorted, s=eke_sorted, cmap="cool", vmin=0, vmax=140)
fig.colorbar(hi)
ax.set_xlim([-1, 200])

In [None]:
# Domain-mean time derivative of `da`. NOTE(review): hidden state — `da` is whatever was bound
# last; after the CDSK loop above it is the u300 field of the final year in YEARSPL.
da.differentiate(coord="time").mean(dim=["longitude", "latitude"]).plot()

In [None]:
# Inspect the currently bound `da` (depends on which cell last assigned it).
da

### Residence times

# Recurrence

## Window counts

## Dispersion metric

## Ripley K

## Recurrence plots

# Duncan's hotspells

### Create hotspells file

In [None]:
# For every region, cut the daily-mean 300 hPa wind in a -21 .. +5 day window around each
# hot-spell center date, stack the events along "hotspell" and pickle the result.
list_of_dates = np.loadtxt("hotspells.csv", delimiter=",", dtype=np.datetime64)
hotspells = {}
dataset = "ERA5"
keys = ["South", "West", "Balkans", "Scandinavia", "Russia", "Arctic"]
minus = 21
plus = 5
for j, key in enumerate(keys):
    hotspells[key] = []
    dates = np.sort(list_of_dates[:, j])
    # drop empty cells and the 2022 events
    dates = dates[~(np.isnat(dates) | (np.datetime_as_string(dates, unit="Y") == "2022"))]
    print(key, len(dates))
    for i, date in enumerate(dates):
        # +9 h: presumably the daily means are stamped at 09:00 — confirm
        tsta = date - np.timedelta64(minus, "D") + np.timedelta64(9, "h")
        tend = date + np.timedelta64(plus, "D") + np.timedelta64(9, "h")
        # NOTE(review): only the center date's yearly file is opened, so windows crossing a
        # year boundary would fail; fine for summer events.
        path = f"{DATADIR}/{dataset}/Wind/300/dailymean/{np.datetime_as_string(date, unit='Y')}.nc"
        # Load the window and close the file right away, instead of leaving one open
        # netCDF handle per hot spell.
        with xr.open_dataset(path) as yearly:
            thisds = yearly.sel(time=pd.date_range(tsta, tend, freq="1D")).load()
        thisds = thisds.assign_coords({"time": np.arange(-minus, plus + 1)})
        hotspells[key].append(thisds)
    hotspells[key] = xr.concat(hotspells[key], dim="hotspell").assign_coords({"hotspell": dates})
with open(f"{DATADIR}/{dataset}/processed/hotspells_uv300.pkl", "wb") as handle:
    pkl.dump(hotspells, handle)

- Replace the time coordinate with day offsets around the center date (now `-21 … +5`, originally planned as `range(21)`) and keep track of the center date (or keep the full time as an index but not as a coordinate): done
- Create aggregates (mean, standard deviation, correlation with T) and look at individual events for all clusters
- Compare with Duncan's plot; also get the boundaries of the clusters
- Cluster the wind myself and compare backwards

### Study hotspells

In [None]:
# Reload the pickled hot-spell wind windows and average over events: one mean field per
# region, stacked along a new "region" dimension labelled by the region names.
dataset = "ERA5"
with open(f"{DATADIR}/{dataset}/processed/hotspells_uv300.pkl", "rb") as handle:
    hotspells = pkl.load(handle)

meanwind_hotspell = xr.concat([hotspell.mean(dim="hotspell") for hotspell in hotspells.values()], dim="region").assign_coords({"region": list(hotspells.keys())})

In [None]:
# Restrict to 30N-89.5N and load into memory. Exact-value selection: assumes the latitude grid
# contains every value of np.arange(30, 90, 0.5) (sel raises KeyError otherwise); 90N is excluded.
meanwind_hotspell_NA = meanwind_hotspell.sel(latitude=np.arange(30, 90, 0.5)).load()
lon_NA = meanwind_hotspell_NA.longitude.values
lat_NA = meanwind_hotspell_NA.latitude.values

In [None]:
# Interactive panel/hvplot explorer of the regional mean 300 hPa wind composites:
# widgets pick the region, wind component, plot kind and day offset around the center date.
hvplot.extension('bokeh')

ticker_region = pnw.Select(name="Region", options=meanwind_hotspell.region.values.tolist())
ticker_variable = pnw.Select(name="Component", options=list(meanwind_hotspell.data_vars.keys()))
ticker_kind= pnw.Select(name="Kind", options=["contour", "contourf", "quadmesh"])

# slider bounds = first and last day offset present in the composites
tsta, tend = int(np.amin(meanwind_hotspell.time.values)), int(np.amax(meanwind_hotspell.time.values))
slider = pnw.IntSlider(name="Day around center", start=tsta, end=tend)

extent = [np.amin(lon_NA), np.amax(lon_NA), np.amin(lat_NA), np.amax(lat_NA)]

meanwind_hotspell_NA.interactive.sel(region=ticker_region, time=slider).hvplot(
    kind=ticker_kind, x="longitude", y="latitude", z=ticker_variable, 
    title="Wind at 300hPa", cmap="seismic", symmetric=True, line_width=1.5, projection="NorthPolarStereo",
    geo=True, coastline='110m', levels=11, xlim=extent[:2], ylim=extent[2:4],
)

## Zoo during hotspells

### Create hotspells_zoo file

In [None]:
# For every region, cut the detrended Zoo indices in a -21 .. +5 day window around each
# hot-spell center date, stack the events along "hotspell" and pickle the result.
list_of_dates = np.loadtxt("hotspells.csv", delimiter=",", dtype=np.datetime64)
hotspells_Zoo = {}
keys = ["South", "West", "Balkans", "Scandinavia", "Russia", "Arctic"]
dataset = "ERA5"
datadir = f"{DATADIR}/{dataset}/processed"
Zoo = xr.open_dataset(f"{datadir}/BarriopedroZooDetrended.nc")
Zookeys = list(Zoo.data_vars.keys()) # copy and not view !
minus = 21
plus = 5
# drop every variable whose name ends in "climatology"
Zoo = Zoo.drop_vars([varname for varname in Zookeys if varname.endswith("climatology")])
for j, key in enumerate(keys):
    hotspells_Zoo[key] = []
    dates = np.sort(list_of_dates[:, j])
    # drop empty cells and the 2022 events
    dates = dates[~(np.isnat(dates) | (np.datetime_as_string(dates, unit="Y") == "2022"))]
    for i, date in enumerate(dates):
        # +9 h: presumably the Zoo time stamps are at 09:00 — confirm
        tsta = date - np.timedelta64(minus, "D") + np.timedelta64(9, "h")
        tend = date + np.timedelta64(plus, "D") + np.timedelta64(9, "h")
        thisds = Zoo.sel(time=pd.date_range(tsta, tend, freq="1D"))
        thisds.attrs["center_date"] = date
        thisds = thisds.assign_coords({"time": np.arange(-minus, plus + 1)}).reset_index("dayofyear", drop=True)
        hotspells_Zoo[key].append(thisds)
    hotspells_Zoo[key] = xr.concat(hotspells_Zoo[key], dim="hotspell").assign_coords({"hotspell": dates})
with open(f"{DATADIR}/{dataset}/processed/hotspells_Zoo.pkl", "wb") as handle:
    pkl.dump(hotspells_Zoo, handle)

### Plot

In [None]:
# Reload the full Zoo dataset (used for normalization below) and the pickled per-region
# hot-spell windows of the Zoo indices.
dataset = "ERA5"
datadir = f"{DATADIR}/{dataset}/processed"
Zoo = xr.open_dataset(f"{datadir}/BarriopedroZooDetrended.nc")
with open(f"{DATADIR}/{dataset}/processed/hotspells_Zoo.pkl", "rb") as handle:
    hotspells_Zoo = pkl.load(handle)

In [None]:
# One panel per Zoo index (names from `ZOO`, defined elsewhere, with "_anomaly" appended):
# the event-mean anomaly of every region around the center date, each curve divided by its
# own maximum |value| so that shapes rather than amplitudes are compared.
fig, axes = plt.subplots(4, 3, figsize=(20, 25), tight_layout=True)
axes = axes.flatten()
for k, key in enumerate([f"{key}_anomaly" for key in ZOO]):
    ax = axes[k]
    for i, regionkey in enumerate(hotspells_Zoo):
        to_plot = hotspells_Zoo[regionkey][key].mean(dim="hotspell")
        # colours: every other entry of COLORS10, wrapping modulo 9
        (to_plot / np.amax(np.abs(to_plot))).plot(ax=ax, label=regionkey, color=COLORS10[(2 * i) % 9], lw=2)
        # ax.fill_between(
        #     to_plot.time, 
        #     *(np.quantile(hotspells_Zoo[regionkey][key], [0.05, 0.95], axis=0) / np.amax(np.abs(to_plot)).values), 
        #     color=COLORS10[(2 * i) % 9],
        #     alpha=0.1,
        # )
    ax.set_title(key)
    ax.set_xlabel("Time around center")
    ax.set_ylabel("Normalized anomaly")
    # legend drawn once, on the 10th panel (NOTE(review): assumes len(ZOO) >= 10)
    if k==9:
        ax.legend(ncol=2)

In [None]:
# One panel per region: event-mean anomaly of five Zoo indices around the center date, each
# divided by the index's maximum |anomaly| over the full record (shared by all regions).
fig, axes = plt.subplots(2, 3, figsize=(15, 9), tight_layout=True)
axes = axes.flatten()
tsta, tend = hotspells_Zoo["South"].time.values[[0, -1]]
anomaly_keys = [f"{name}_anomaly" for name in ["Lat", "Int", "Tilt", "Lon", "Mea"]]
scales = {name: np.amax(np.abs(Zoo[name])) for name in anomaly_keys}
for k, regionkey in enumerate(hotspells_Zoo):
    ax = axes[k]
    for i, key in enumerate(anomaly_keys):
        to_plot = hotspells_Zoo[regionkey][key].mean(dim="hotspell")
        (to_plot / scales[key]).plot(ax=ax, label=key.split("_")[0], color=COLORS5[i % 5], lw=2)
    ax.set_title(regionkey)
    ax.set_xlabel("Days around center")
    ax.set_ylabel("Normalized anomaly")
    ax.set_xticks(np.arange(tsta, tend + 1, 2))
    if k == 5:
        ax.legend(ncol=2)

## Dynamical indices during hotspells

In [None]:
# For every region, cut the dynamical indices (ld, theta) in a -21 .. +5 day window around
# each hot-spell center date, stack the events along "hotspell" and pickle the result.
list_of_dates = np.loadtxt("hotspells.csv", delimiter=",", dtype=np.datetime64)
hotspells_dynind = {}
dataset = "ERA5"
keys = ["South", "West", "Balkans", "Scandinavia", "Russia", "Arctic"]
dataset = "ERA5"
datadir = f"{DATADIR}/{dataset}/processed"
ld = xr.open_dataarray(f"{datadir}/ld.nc")
theta = xr.open_dataarray(f"{datadir}/theta.nc")
indices = xr.Dataset({"ld": ld, "theta": theta})
indikeys = list(indices.data_vars.keys()) # copy and not view ! (not used below)
minus = 21
plus = 5
for j, key in enumerate(keys):
    hotspells_dynind[key] = []
    dates = np.sort(list_of_dates[:, j])
    # drop empty cells and the 2022 events
    dates = dates[~(np.isnat(dates) | (np.datetime_as_string(dates, unit="Y") == "2022"))]
    for i, date in enumerate(dates):
        # no +9 h offset here, unlike the wind and Zoo builders: presumably ld/theta are
        # stamped at 00:00 — confirm
        tsta = date - np.timedelta64(minus, "D")
        tend = date + np.timedelta64(plus, "D")
        thisds = indices.sel(time=pd.date_range(tsta, tend, freq="1D"))
        thisds.attrs["center_date"] = date
        thisds = thisds.assign_coords({"time": np.arange(-minus, plus + 1)})
        hotspells_dynind[key].append(thisds)
    hotspells_dynind[key] = xr.concat(hotspells_dynind[key], dim="hotspell").assign_coords({"hotspell": dates})
with open(f"{DATADIR}/{dataset}/processed/hotspells_dynind.pkl", "wb") as handle:
    pkl.dump(hotspells_dynind, handle)

In [None]:
# One panel per region: event-mean ld and theta around the center date, each divided by its
# own maximum. NOTE(review): this scaling assumes the indices are positive.
fig, axes = plt.subplots(2, 3, figsize=(15, 9), tight_layout=True)
axes = axes.flatten()
tsta, tend = hotspells_dynind["South"].time.values[[0, -1]]
for k, regionkey in enumerate(hotspells_dynind):
    ax = axes[k]
    for i, key in enumerate(hotspells_dynind["South"].data_vars):
        to_plot = hotspells_dynind[regionkey][key].mean(dim="hotspell")
        (to_plot / np.amax(to_plot)).plot(ax=ax, label=key.split("_")[0], color=COLORS5[i % 5], lw=2)
    ax.set_title(regionkey)
    ax.set_xlabel("Day around center")
    ax.set_ylabel("Normalized index")
    ax.set_xticks(np.arange(tsta, tend + 1, 2))
    # legend on the sixth (last) panel only
    if k==5:
        ax.legend(ncol=2)

# Misc

### Create_plot

In [None]:
def create_plot(to_plot, titles, levels, twolevel=False, startindex=-1):
    """Draw one filled-contour map per field and return a callback that animates them.

    Parameters
    ----------
    to_plot : sequence of xarray.DataArray
        Fields with ``lon``, ``lat`` and ``time`` coordinates. Each time slice is transposed
        onto a (lon, lat) grid, so presumably stored as (..., lat, lon) — TODO confirm.
    titles : sequence of str
        Title fragment for each field.
    levels : sequence
        ``levels`` argument passed to ``contourf`` for each field.
    twolevel : bool, optional
        Arrange the maps on two rows instead of one.
    startindex : int, optional
        Time index drawn initially (default: the last one).

    Returns
    -------
    fig, axes, plt_rej, animate_all
        ``plt_rej`` lists the current ContourSet of every panel and is updated in place by
        ``animate_all(i)``, which redraws all panels at time index ``i`` and returns
        ``plt_rej`` (usable as the ``func`` of ``matplotlib.animation.FuncAnimation``).
    """
    # Figure
    transform = ccrs.PlateCarree()
    projection = transform
    if twolevel:
        # ceil so an odd number of fields still gets one axis each
        ncols = int(np.ceil(len(to_plot) / 2))
        fig, axes = plt.subplots(
            2, ncols, subplot_kw={"projection": projection}, constrained_layout=True
        )
    else:
        fig, axes = plt.subplots(
            1, len(to_plot), subplot_kw={"projection": projection}, constrained_layout=True, figsize=(3.5 * len(to_plot), 6)
        )
    axes = np.atleast_1d(axes).flatten()
    for unused_ax in axes[len(to_plot):]:
        unused_ax.set_visible(False)

    # Add coastline and borders
    coastline = feat.NaturalEarthFeature(
        "physical", "coastline", "10m", edgecolor="black", facecolor="none"
    )
    borders = feat.NaturalEarthFeature(
        "cultural",
        "admin_0_boundary_lines_land",
        "10m",
        edgecolor="grey",
        facecolor="none",
    )

    def _draw(j, i):
        """Filled contours of field ``j`` at time index ``i`` on ``axes[j]``."""
        field = to_plot[j]
        return axes[j].contourf(
            field["lon"].values[:, None] * np.ones(len(field["lat"])),
            field["lat"].values[None, :] * np.ones(len(field["lon"]))[:, None],
            field.isel(time=i).transpose(),
            levels=levels[j],
            transform=transform,
            transform_first=True,
            zorder=0,
        )

    def _remove(contour_set):
        """Remove a ContourSet from its axes on both new and old matplotlib versions."""
        try:
            contour_set.remove()  # matplotlib >= 3.8: the ContourSet is a single artist
        except (AttributeError, NotImplementedError):
            # older matplotlib: remove the individual collections instead
            for coll in contour_set.collections:
                coll.remove()

    plt_rej = []
    cbar = [None] * len(to_plot)
    for j in range(len(to_plot)):
        ax = axes[j]
        plt_rej.append(_draw(j, startindex))
        ax.add_feature(coastline)
        ax.add_feature(borders)
        ax.set_xmargin(0)
        ax.set_ymargin(0)
        # "g.a": presumably the global average of the field at that time step
        ax.set_title(f"Day {startindex}, {titles[j]}, g.a : {np.mean(to_plot[j][startindex]):.2f}")
        cbar[j] = fig.colorbar(plt_rej[j], ax=ax, fraction=0.046, pad=0.04)

    def animate_all(i):
        # plt_rej and cbar are closure lists mutated in place: no `global` declaration needed
        # (a `global` would look up a module-level name instead of this function's list).
        for j in range(len(to_plot)):
            _remove(plt_rej[j])
            plt_rej[j] = _draw(j, i)
            axes[j].set_title(f"Day {i + 1}, {titles[j]}, g.a : {np.mean(to_plot[j][i]):.2f}")
            # redraw into the colorbar axes created above rather than indexing fig.axes
            cbar[j] = fig.colorbar(plt_rej[j], cax=cbar[j].ax)
        return plt_rej

    return fig, axes, plt_rej, animate_all

### Fetch

In [None]:
# Human-readable names of the ERA5 short variable names.
longname = dict(
    u="U-component of wind",
    v="V-component of wind",
    w="W-component of wind",
    z="Geopotential",
    t="Temperature",
    vo="Relative vorticity",
    q="Specific humidity",
    r="relative humidity",
)
# key "<var><level>" -> [short name, level type, level in hPa, description]
variablemap = {
    f"{name}{level}": [name, "PL", level, f"{longname[name]} at {level} hPa"]
    for name in ("u", "v", "vo")
    for level in range(700, 901, 50)
}
var = "z"
vm2 = {f"{var}{level}": [var, "PL", level, f"{longname[var]} at {level} hPa"] for level in (300, 500)}
variablemap.update(vm2)
variablemap["t850"] = ["t", "PL", 850, f"{longname['t']} at 850 hPa"]

### PV calculations

In [None]:
# Potential vorticity on model levels from the first model-level file returned by `fn` for
# the first date of DATERANGEML.
ds = xr.open_mfdataset(fn(DATERANGEML[0], which="ML")[:1], combine="nested", concat_dim="time")
# full-level pressure from the hybrid coefficients, p = hyam + hybm * PS; the helper dim
# `lev_2` is dropped and `nhym` renamed to `lev` (NOTE: `.drop` is deprecated, `drop_vars` replaces it)
ds["P"] = (ds["hybm"] * ds["PS"] + ds["hyam"]).isel(lev_2=0).drop("lev_2").rename({"nhym": "lev"})
ds["P"].attrs["units"] = "Pa"
ds["T"].attrs["units"] = "celsius"  # NOTE(review): model-level T is often stored in K — confirm
# drop the first and last latitude rows (presumably the poles — confirm), then attach pint units
ds = ds.isel(lat=range(1, len(ds.lat) -1)).metpy.quantify()
ds["THETA"] = mcalc.potential_temperature(ds["P"], ds["T"])
# dimensions given by axis position: 1 = vertical, 2 = y, 3 = x (axis 0 presumably time)
ds["PV"] = mcalc.potential_vorticity_baroclinic(ds["THETA"], ds["P"], ds["U"], ds["V"], x_dim=3, y_dim=2, vertical_dim=1)

In [None]:
# Quick look: PV at the first time step on 20 model levels, every 7th of the 137
# (int(137 / 20) + 1 = 7, i.e. levels 0, 7, ..., 133).
fig, axes = plt.subplots(5, 4, figsize=[20, 20])
axes = axes.flatten()
for l, k in enumerate(range(0, 137, int(137/20) + 1)):
    ds["PV"].isel(lev=k, time=0).plot(ax=axes[l])

In [None]:
# Interpolate U and V onto the 2 PVU surface (PV = 2e-6 K m2 kg-1 s-1), one time step at a time.
dims = ["time", "lat", "lon"]
# uninitialised (time, lat, lon) buffers; every time slice is overwritten in the loop below
ds["U_2PVU"] = xr.DataArray(np.empty([len(ds["T"].coords[dim]) for dim in dims]), dims=dims, coords={dim: ds["T"].coords[dim] for dim in dims})
ds["V_2PVU"] = ds["U_2PVU"].copy()
for ti, t in enumerate(ds.time):
    this_ds = ds.isel(time=ti)
    for w in ["U", "V"]:
        # metpy expects the vertical axis first, presumably (lev, lat, lon) here; `.values`
        # strips the pint units. NOTE(review): +2 PVU only marks the tropopause where PV > 0
        # (Northern Hemisphere) — use -2 PVU south of the equator if the domain includes it.
        ds[f"{w}_2PVU"][ti, :, :] = minterpolate.interpolate_to_isosurface(this_ds["PV"].values, this_ds[w].values, 2e-6)

In [None]:
# `U_2PVU` is 3-D (time, lat, lon), so xarray's default `.plot()` draws a histogram of all
# values rather than a map; `.isel(time=0).plot()` would give a map of one time step.
ds["U_2PVU"].plot()

In [None]:
# `V_2PVU` is 3-D (time, lat, lon), so xarray's default `.plot()` draws a histogram of all
# values rather than a map; `.isel(time=0).plot()` would give a map of one time step.
ds["V_2PVU"].plot()

### Matplotlib widgets example (prefer panel + hvplot; it is much faster)

In [None]:
# Interactive matplotlib version of the hot-spell wind explorer: radio buttons choose the
# region and wind component, a slider chooses the day offset around the center date.
transform = ccrs.PlateCarree()
projection = ccrs.LambertConformal(central_longitude=np.mean(lon_NA))
fig, ax = plt.subplots(figsize=(15, 7), subplot_kw={"projection": projection})
REGION = "South"
DAY = 0
VARIABLE = "u"
mesh = ax.pcolormesh(
    lon_NA, lat_NA, meanwind_hotspell_NA[VARIABLE].sel(region=REGION, time=DAY),
    shading="nearest", cmap="bwr", transform=transform, norm=CenteredNorm())
fig.subplots_adjust(left=0.3)
cbar = fig.colorbar(mesh)
ax.add_feature(COASTLINE)
ax.add_feature(BORDERS)
extent = [np.amin(lon_NA), np.amax(lon_NA), np.amin(lat_NA), np.amax(lat_NA)]
boundary = make_boundary_path(*extent)
ax.set_boundary(boundary, transform=ccrs.PlateCarree())


def update_mesh():
    """Redraw the mesh and rescale the colorbar for the current REGION / DAY / VARIABLE."""
    mesh.set_array(meanwind_hotspell_NA[VARIABLE].sel(region=REGION, time=DAY).values.flatten())
    mesh.autoscale()
    cbar.update_normal(mesh)
    plt.draw()


def change_region(region):
    """RadioButtons callback: select another region."""
    global REGION
    REGION = region
    update_mesh()


def change_day(day):
    """Slider callback: select another day offset."""
    global DAY
    # Slider values are floats; the composites are indexed by integer day offsets.
    DAY = int(day)
    update_mesh()


def change_variable(variable):
    """RadioButtons callback: select another wind component."""
    global VARIABLE
    VARIABLE = variable
    update_mesh()


rax = fig.add_axes([0.05, 0.4, 0.2, 0.4])
radio_region = RadioButtons(rax, meanwind_hotspell[VARIABLE].region.values)
radio_region.on_clicked(change_region)

rax = fig.add_axes([0.05, 0.25, 0.2, 0.17])
radio_variable = RadioButtons(rax, list(meanwind_hotspell.data_vars.keys()))
radio_variable.on_clicked(change_variable)

# Slider bounds come from the composites' own day offsets; a fixed range can include
# offsets that are absent (the windows were built as -21 .. +5 days), which makes sel fail.
day_min = int(meanwind_hotspell_NA.time.min())
day_max = int(meanwind_hotspell_NA.time.max())
sax = fig.add_axes([0.32, 0.05, 0.4, 0.03])
slider_day = Slider(
    ax=sax,
    label='Day',
    valmin=day_min,
    valstep=1,
    valmax=day_max,
    valinit=0,
)
slider_day.on_changed(change_day)
plt.show()