In [1]:
from jetstream_hugo.definitions import *
from jetstream_hugo.plots import *
from jetstream_hugo.data import *
from jetstream_hugo.anyspell import *
from jetstream_hugo.jet_finding import *
from jetstream_hugo.clustering import *
import intake

import colormaps

%load_ext autoreload
%autoreload 2
%matplotlib inline

# from dask.distributed import Client, progress
# client = Client(**COMPUTE_KWARGS)

# arco-era5 tests

In [3]:
from dask.diagnostics import ProgressBar
def _compute(obj, progress: bool = False, **kwargs):
    kwargs = COMPUTE_KWARGS | kwargs
    try:
        if progress:
            with ProgressBar():
                return obj.compute(**kwargs)
        else:
            return obj.compute(**kwargs)
    except AttributeError:
        return obj

In [2]:
import xarray as xr

ds = xr.open_zarr(
    'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3',
    chunks=None,
    storage_options=dict(token='anon'),
)
ar_full_37_1h = ds.sel(time=slice(ds.attrs['valid_time_start'], ds.attrs['valid_time_stop']))


In [16]:
base_ds = ar_full_37_1h[["u_component_of_wind", "v_component_of_wind"]].sel(time=ar_full_37_1h.time.dt.hour % 6 == 0, latitude=ar_full_37_1h.latitude >= 0, level=[175,  200,  225,  250,  300,  350]).isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))

In [None]:
base_path = Path(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/plev/flat_wind/dailymean")
from concurrent.futures import ThreadPoolExecutor, as_completed

def downloader(base_ds, base_path, month, year):
    month_str = str(month).zfill(2)
    opath = base_path.joinpath(f"{year}{month_str}.nc")
    if opath.is_file():
        return f"Already had {year}{month}"
    ds = _compute(base_ds.sel(time=(base_ds.time.dt.year==year) & (base_ds.time.dt.month==month)), progress=True)
    ds = standardize(ds)
    ds["s"] = np.sqrt(ds["u"] ** 2 + ds["v"] ** 2)
    ds = flatten_by(ds, "s")
    ds.to_netcdf(opath)
    return f"Completed {year}{month}"

with ThreadPoolExecutor(max_workers=6) as executor:
    futures = [
        executor.submit(downloader, base_ds.copy(), base_path, month, year) for year in YEARS for month in range(1, 13)
    ]
    for f in as_completed(futures):
        try:
            print(f.result())
        except:
            print("could not retrieve")

In [4]:
varnames = {
    "t2m": "2m_temperature",
    "mslp": "mean_sea_level_pressure",
    "sst": "sea_surface_temperature",
    "tp": "total_precipitation",
}
for key, val in varnames.items():
    base_da = ar_full_37_1h[val].sel(time=ar_full_37_1h.time.dt.hour % 6 == 0, latitude=ar_full_37_1h.latitude >= 0)[:, ::2, ::2]
    base_path = Path(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/{key}/dailymean")
    print(base_path)
    for year in tqdm(YEARS):
        opath = base_path.joinpath(f"{year}.nc")
        if opath.is_file():
            da = xr.open_dataarray(opath)
            da = standardize(da)
        else:
            da = _compute(base_da.sel(time=base_da.time.dt.year==year), progress=True).resample(time="1D").mean()
        da.to_netcdf(opath)

/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/t2m/dailymean


100%|██████████| 64/64 [00:29<00:00,  2.19it/s]


/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/mslp/dailymean


100%|██████████| 64/64 [00:29<00:00,  2.20it/s]


/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/sst/dailymean


100%|██████████| 64/64 [00:34<00:00,  1.87it/s]


/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/tp/dailymean


100%|██████████| 64/64 [00:59<00:00,  1.08it/s]


# Climatologies, datahandlers of new data

In [8]:
compute_all_smoothed_anomalies("ERA5", "surf", "t2m", "dailymean", 'dayofyear', {'dayofyear': ('win', 15)}, None)

[########################################] | 100.00% Completed | 30.55 s


100%|██████████| 64/64 [00:25<00:00,  2.52it/s]


In [9]:
compute_all_smoothed_anomalies("ERA5", "surf", "tp", "dailymean", 'dayofyear', {'dayofyear': ('win', 15)}, None)

[########################################] | 100.00% Completed | 38.06 s


100%|██████████| 64/64 [00:57<00:00,  1.11it/s]


In [5]:
compute_all_smoothed_anomalies("ERA5", "surf", "sst", "dailymean", 'dayofyear', {'dayofyear': ('win', 15)}, None)

[########################################] | 100.00% Completed | 30.18 s


100%|██████████| 64/64 [00:22<00:00,  2.82it/s]


In [6]:
compute_all_smoothed_anomalies("ERA5", "surf", "mslp", "dailymean", 'dayofyear', {'dayofyear': ('win', 15)}, None)

[########################################] | 100.00% Completed | 31.71 s


100%|██████████| 64/64 [00:22<00:00,  2.81it/s]


# new pvs das

In [2]:
from jetstream_hugo.definitions import TIMERANGE
import numpy as np
import xarray as xr
import geopandas as gpd
from tqdm import trange
from wavebreaking import to_xarray
da = xr.open_dataarray(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/t2m/dailymean/1999.nc")
coords = {
    "time": TIMERANGE,
    "lat": da.lat.values,
    "lon": da.lon.values,
}
shape = [len(co) for co in coords.values()]
dummy_da = xr.DataArray(np.zeros(shape), coords=coords)
das_ones = []
das_int = []
events = gpd.read_file(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/RWB_index/era5_pv_streamers_310K_1959-2022.parquet")

ERROR 1: PROJ: proj_create_from_database: Open of /storage/homefs/hb22g102/miniforge3/envs/env11_2/share/proj failed


In [3]:
def one_level_events(level: int):
    
    
    return da_ones, da_int

: 

In [45]:
from jetstream_hugo.definitions import TIMERANGE
import numpy as np
import xarray as xr
import geopandas as gpd
from tqdm import trange
from wavebreaking import to_xarray
da = xr.open_dataarray(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/t2m/dailymean/1999.nc")
coords = {
    "time": TIMERANGE,
    "lat": da.lat.values,
    "lon": da.lon.values,
}
shape = [len(co) for co in coords.values()]
dummy_da = xr.DataArray(np.zeros(shape), coords=coords)
das_ones = []
das_int = []

da = xr.open_dataarray(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/t2m/dailymean/1999.nc")
coords = {
    "time": TIMERANGE,
    "lat": da.lat.values,
    "lon": da.lon.values,
}
shape = [len(co) for co in coords.values()]
dummy_da = xr.DataArray(np.zeros(shape), coords=coords)

for level in trange(310, 365, 5):
    events = gpd.read_file(f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/RWB_index/era5_pv_streamers_{level}K_1959-2022.parquet")
    
    tropospheric = events[events.mean_var < events.level]

    anticyclonic = tropospheric[tropospheric.intensity >= 0].reset_index(drop=True)
    cyclonic = tropospheric[tropospheric.intensity < 0].reset_index(drop=True)
    
    # da_anti = to_xarray(dummy_da, anticyclonic)
    # da_cycl = to_xarray(dummy_da, cyclonic)
    
    # da_ones = xr.concat([da_anti, da_cycl], dim="type").assign_coords({"type": ["anti", "cycl"]})
    
    # da_anti = to_xarray(dummy_da, anticyclonic, "intensity", "intensity")
    # da_cycl = to_xarray(dummy_da, cyclonic, "intensity", "intensity")
    
    # da_int = xr.concat([da_anti, da_cycl], dim="type").assign_coords({"type": ["anti", "cycl"]})
    break

# res = map_maybe_parallel(levels, one_level_events, len(levels), processes=1)
# da_ones, da_int = list(zip(*res))
# da_ones = xr.concat(da_ones, dim="level").assign_coords(level=levels)
# da_int = xr.concat(da_int, dim="level").assign_coords(level=levels)

  0%|          | 0/11 [00:10<?, ?it/s]


In [60]:
cyclonic

Unnamed: 0,date,level,mean_var,event_area,intensity,__index_level_0__,geometry
0,1959-01-01 00:00:00,-2.0,-2.40,387642.21,-56.03,0,"MULTIPOLYGON (((35 -48, 36 -48, 37 -48, 38 -48..."
1,1959-01-01 00:00:00,2.0,1.69,1507990.88,-98.13,4,"MULTIPOLYGON (((179 70, 179 66, 178 66, 177 66..."
2,1959-01-01 06:00:00,-2.0,-2.55,622487.10,-98.13,5,"MULTIPOLYGON (((39 -48, 40 -48, 41 -48, 42 -48..."
3,1959-01-01 06:00:00,2.0,1.88,901350.71,-116.81,9,"MULTIPOLYGON (((179 71, 179 68, 178 68, 177 68..."
4,1959-01-01 12:00:00,-2.0,-2.43,1030645.39,-59.38,10,"MULTIPOLYGON (((40 -50, 41 -50, 42 -50, 43 -50..."
...,...,...,...,...,...,...,...
191464,2022-12-31 18:00:00,-2.0,-2.63,462634.04,-121.72,663852,"MULTIPOLYGON (((106 -59, 106 -58, 107 -57, 108..."
191465,2022-12-31 18:00:00,2.0,1.59,439534.20,-19.05,663854,"MULTIPOLYGON (((-105 52, -106 53, -106 54, -10..."
191466,2022-12-31 18:00:00,2.0,1.42,1222794.74,-25.11,663856,"MULTIPOLYGON (((-89 52, -88 53, -87 54, -86 55..."
191467,2022-12-31 18:00:00,2.0,1.48,1049064.84,-64.59,663860,"MULTIPOLYGON (((32 62, 33 63, 33 64, 33 65, 32..."


In [64]:
geoms = cyclonic["geometry"].get_coordinates(index_parts=True).reset_index()

In [68]:
geoms["time"] = cyclonic.loc[geoms["level_0"].values, "date"].values
geoms["intensity"] = cyclonic.loc[geoms["level_0"].values, "intensity"].values
geoms = geoms.drop(["level_0", "level_1"], axis=1)

In [69]:
geoms

Unnamed: 0,x,y,time,intensity
0,35.0,-48.0,1959-01-01 00:00:00,-56.03
1,36.0,-48.0,1959-01-01 00:00:00,-56.03
2,37.0,-48.0,1959-01-01 00:00:00,-56.03
3,38.0,-48.0,1959-01-01 00:00:00,-56.03
4,39.0,-48.0,1959-01-01 00:00:00,-56.03
...,...,...,...,...
8167195,62.0,58.0,2022-12-31 18:00:00,-22.23
8167196,62.0,57.0,2022-12-31 18:00:00,-22.23
8167197,63.0,56.0,2022-12-31 18:00:00,-22.23
8167198,63.0,55.0,2022-12-31 18:00:00,-22.23


In [14]:
da_int = da_int.astype(np.float16)

In [19]:
93504 * 181 * 720 * 2 * 2

48.74176512

In [16]:
da_int.nbytes / 10 ** 9

48.74176512

# CESM

In [1]:
from pathlib import Path
import numpy as np
import xarray as xr
basepath = Path("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind")
paths = list(basepath.iterdir())
paths = [path for path in paths if path.suffix == ".nc" and path.name != "ds.nc"]
parts = [path.name.split(".")[0].split("-") for path in paths]
parts = np.asarray(parts)
sorted_order = np.argsort([memb.replace("r10", "r0") for memb in parts[:, 0]])
parts = parts[sorted_order]
paths = [paths[i] for i in sorted_order]
all_members = np.unique(parts[:, 0])
all_years = np.unique(parts[:, 1])

not_here = []
here = []
for year in all_years:
    for member in all_members:
        potential_path = basepath.joinpath(f"{member}-{year}.nc")
        if potential_path.is_file():
            here.append(potential_path)
        else:
            not_here.append(potential_path)
len(here)

from itertools import groupby
paths_to_load = []
valid_ensembles = []
for key, indices in groupby(range(len(parts)), lambda i: parts[i][0]):
    indices = list(indices)
    group = parts[indices]
    these_paths = [paths[i] for i in indices]
    years = np.asarray(list([g[1] for g in group]))
    if len(years) == 60:
        paths_to_load.append(these_paths)
        valid_ensembles.append(key)
    else:
        print(key, len(years))

In [11]:
from tqdm import tqdm
ds = []
for ptl in tqdm(paths_to_load):
    ds_ = []
    for p in ptl:
        this = xr.open_dataset(p)
        this = this.reset_coords("time_bnds", drop=True).drop_dims("nbnd")
        ds_.append(this)
    ds.append(xr.concat(ds_, dim="time"))
ds = xr.concat(ds, dim="member")
# ds = xr.concat([xr.concat([xr.open_dataset[ptl_] for ptl_ in ptl], dim="time") for ptl in paths_to_load], dim="member")

In [None]:
import dask
from dask.distributed import progress, Client
from jetstream_hugo.definitions import COMPUTE_KWARGS
client = Client(**COMPUTE_KWARGS)
dask.persist(ds)
progress(ds, notebook=False)
ds = dask.compute(ds)

In [16]:
ds = ds[0]
to_comp = ds.to_zarr(f"/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind/ds.zarr", compute=False, encoding={var: {"chunks": (-1, 100, -1, -1)} for var in ds.data_vars}, mode="w")
dask.persist(to_comp)
progress(to_comp, notebook=False)
dask.compute(to_comp)

# extreme cesm clim

In [16]:
import numpy as np
from jetstream_hugo.data import *
quantiles = ds["s"].quantile(np.arange(0.6, 1, 0.05), ["member", "lon", "lat"]).compute()
quantiles = smooth(quantiles, {"time": ("win", 15)}).load()
quantiles.to_netcdf(f"{DATADIR}/CESM2/flat_wind/results/s_q.nc")


In [19]:
from matplotlib.dates import DateFormatter, MonthLocator
from jetstream_hugo.definitions import *
from scipy.stats import linregress

def get_trend(da):
    years = np.unique(da.time.dt.year)
    if "jet" not in da.dims:
        result = linregress(years, da.values)
        return xr.Dataset({"slope": result.slope, "p": result.pvalue})
    jets = da.jet.values
    slopes = xr.DataArray(np.zeros(len(jets)), coords={"jet": jets})
    pvalues = slopes.copy()
    for j, jet in enumerate(jets):
        result = linregress(years, da.isel(jet=j).values)
        slopes[j] = result.slope
        pvalues[j] = result.pvalue
    return xr.Dataset({"slope": slopes, "p": pvalues})

winsize = 15
halfwinsize = int(np.ceil(winsize / 2))

fig, ax = plt.subplots()
for q, qval in zip(quantiles[::2], np.arange(0.6, 1, 0.05 * 2)):
    gb = q.groupby("time.dayofyear")
    x = list(gb.groups)
    x = DATERANGE[x]
    ys = gb.map(get_trend) 
    ps = ys["p"]
    ys = ys["slope"]
    ys = ys.pad({"dayofyear": halfwinsize}, mode="wrap")
    ys = ys.rolling(dayofyear=winsize, center=True).mean()
    ys = ys.isel({"dayofyear": slice(halfwinsize, -halfwinsize)})
    ax.plot(x, ys, label=f"$q={qval:.1f}$", lw=2)
ax.grid(True)
ax.xaxis.set_major_locator(MonthLocator(range(0, 13, 3)))
ax.xaxis.set_major_formatter(DateFormatter("1 %b"))
ax.set_xlim(min(x), max(x))
ax.legend()

In [10]:
q_clim = compute_clim(quantiles, "dayofyear")
q_clim = smooth(q_clim, {"dayofyear": ("win", 61)}).load()
q_clim.to_netcdf(f"{DATADIR}/CESM2/flat_wind/results/s_q_clim.nc")

In [19]:
ds = ds.reset_coords("time_bnds", drop=True)

# Extreme experiment

In [None]:
exp_tp = ExtremeExperiment(
    DataHandler("ERA5", "surf", "tp", "6H", "all", "JJA", -30, 40, 30, 75, 250, 'hourofyear', {'hourofyear': ('win', 4 * 15)}, None),
    q = 0.95,
)
da_tp = exp_tp.da.load()

data_handlers = {}
for varname in ["u", "v", "s"]:
    dh = DataHandler("ERA5", "plev", varname, "6H", "all", None, -80, 40, 15, 80, [175, 200, 225, 250, 300, 350], reduce_da=False)
    data_handlers[varname] = dh
exp = MultiVarExperiment(data_handlers)

all_jets_one_df, where_are_jets, all_jets_one_array, all_jets_over_time, flags = exp.track_jets()
props_as_ds = exp.props_as_ds(True)

In [None]:
from deepdiff import DeepHash
DeepHash(load_pickle("/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/tp/6H/hourofyear_hourofyearwin60/results/1/predictions/1/metadata.pkl"))

In [None]:
subset = ["mean_lon", "mean_lat", "mean_lev", "spe_star", "width", "wavinessR16", "persistence", "com_speed", "int"]
predictors = prepare_predictors(
    props_as_ds,
    subset=subset,
    anomalize=True,
    normalize=True,
    detrend=True,
    nan_method="nearest",
    season="JJA",
)
time_before = pd.Timedelta(0, "D")
n_clu = 22
clusters_da = exp_tp.spatial_clusters_as_da(n_clu)
targets, length_targets, all_spells_ts, all_spells = exp_tp.create_targets(n_clu, 0.95, minlen=np.timedelta64(1, "D"))
binary_targets = length_targets > 0
masked_predictors = mask_from_spells_multi_region(predictors, targets, all_spells_ts, all_spells, time_before=time_before)

In [None]:
clu = Clusterplot(1, 1, exp_tp.region)
cmap = colormaps.BlAqGrYeOrReVi200
ax = clu.axes[0]
unique_clusters = np.arange(n_clu)
norm = BoundaryNorm(np.concatenate([[-1], unique_clusters]) + 0.5, cmap.N)
clusters_da.unstack().plot(
    ax=ax,
    cmap=cmap,
    norm=norm,
    add_colorbar=False,
    add_labels=False
)
for j in range(n_clu):
    lo = clusters_da.lon.where(clusters_da==j).mean().item()
    la = clusters_da.lat.where(clusters_da==j).mean().item()
    ax.text(lo, la, f"${j}$", ha="center", va="center", fontweight="bold")

In [None]:
compute_all_smoothed_anomalies("ERA5", "plev", "s", "6H", 'hourofyear', {'hourofyear': ('win', 4 * 15)}, None)

In [None]:
compute_all_smoothed_anomalies("ERA5", "surf", "tp", "6H", 'hourofyear', {'hourofyear': ('win', 4 * 15)}, None)

In [None]:
basepath = Path(f"{DATADIR}/ERA5/surf")
varnames = ["u10", "v10", "s10"]
for year, month in tqdm(product(YEARS, range(1, 13)), total=len(YEARS) * 12):
    month_str = str(month).zfill(2)
    ofiles = {varname: basepath.joinpath(f"{varname}/6H/{year}{month_str}.nc") for varname in varnames}
    if all([ofile.is_file() for ofile in ofiles.values()]):
        continue
    ds = xr.open_dataset(basepath.joinpath(f"raw/{year}{month_str}.nc"))
    ds = ds.rename(longitude="lon", latitude="lat")
    ds = ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180))
    ds = ds.sortby("lon")
    ds = ds.sortby("lat")
    ds["s10"] = np.sqrt(ds["u10"] ** 2 + ds["v10"] ** 2)
    for varname in varnames:
        da = ds[varname]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=xr.SerializationWarning)
            da.to_netcdf(ofiles[varname])