In [1]:
from jetstream_hugo.definitions import *
from jetstream_hugo.plots import *
from jetstream_hugo.data import *
# from jetstream_hugo.anyspell import *
from jetstream_hugo.jet_finding import *
from jetstream_hugo.clustering import *

import colormaps

%load_ext autoreload
%autoreload 2
%matplotlib inline

# arco-era5 tests

In [None]:
# Open the public ARCO-ERA5 store anonymously and keep only the time span the
# store itself advertises as valid.
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

# Horizontal wind components, restricted to hours divisible by 6, latitudes
# >= 0 and six pressure levels, then thinned to every second grid point in
# both longitude and latitude.
base_ds = (
    ar_full_37_1h[["u_component_of_wind", "v_component_of_wind"]]
    .sel(
        time=ar_full_37_1h.time.dt.hour % 6 == 0,
        latitude=ar_full_37_1h.latitude >= 0,
        level=[175, 200, 225, 250, 300, 350],
    )
    .isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))
)

# One NetCDF file per (year, month); months already on disk are skipped so the
# cell can be re-run after an interruption.
base_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean")
base_path.mkdir(parents=True, exist_ok=True)  # to_netcdf does not create parents
for year in YEARS:
    for month in range(1, 13):
        month_str = str(month).zfill(2)
        opath = base_path.joinpath(f"{year}{month_str}.nc")
        if opath.is_file():
            continue
        month_mask = (base_ds.time.dt.year == year) & (base_ds.time.dt.month == month)
        # `progress_flag` is the keyword the other `compute` calls in this
        # notebook use.
        month_ds = compute(base_ds.sel(time=month_mask), progress_flag=True)
        month_ds = standardize(month_ds)
        # Horizontal wind speed from the standardized components.
        month_ds["s"] = np.sqrt(month_ds["u"] ** 2 + month_ds["v"] ** 2)
        # NOTE(review): `flatten_by` is a project helper; presumably it collapses
        # the level dimension keeping the level of maximum "s" — confirm.
        month_ds = flatten_by(month_ds, "s")
        month_ds.to_netcdf(opath)
        print(f"Completed {year}{month_str}")

In [None]:
# Anonymous handle on the public ARCO-ERA5 store, clipped to the validity
# window recorded in the store attributes.
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    storage_options={"token": "anon"},
    chunks=None,
)
valid_window = slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
ar_full_37_1h = ds.sel(time=valid_window)

# Selection masks and the stride used to thin the horizontal grid.
is_six_hourly = ar_full_37_1h.time.dt.hour % 6 == 0
is_non_negative_lat = ar_full_37_1h.latitude >= 0
every_other = slice(None, None, 2)

# Temperature at hours divisible by 6, latitudes >= 0, on six pressure levels,
# keeping every second longitude and latitude.
temp_full = ar_full_37_1h["temperature"].sel(
    time=is_six_hourly,
    latitude=is_non_negative_lat,
    level=[175, 200, 225, 250, 300, 350],
)
temp_full = temp_full.isel(longitude=every_other, latitude=every_other)

# Apply the project's `standardize` helper to the selection.
temp_full = standardize(temp_full)

In [None]:
# Open the public ARCO-ERA5 store anonymously and keep only the time span the
# store itself advertises as valid.
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

# Temperature with the same time / latitude / level / stride selection that
# the wind download cell applies.
temp_full = (
    ar_full_37_1h["temperature"]
    .sel(
        time=ar_full_37_1h.time.dt.hour % 6 == 0,
        latitude=ar_full_37_1h.latitude >= 0,
        level=[175, 200, 225, 250, 300, 350],
    )
    .isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))
)

temp_full = standardize(temp_full)

# For every monthly flat-wind file, add a "theta" variable and write the result
# to a sibling directory; months already written are skipped.
orig_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean")
base_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean_2")
base_path.mkdir(parents=True, exist_ok=True)  # to_netcdf does not create parents
for year in tqdm(YEARS):
    for month in trange(1, 13, leave=False):
        month_str = str(month).zfill(2)
        opath = base_path.joinpath(f"{year}{month_str}.nc")
        if opath.is_file():
            continue
        ipath = orig_path.joinpath(f"{year}{month_str}.nc")
        # The context manager closes the input file once the month is written,
        # so file handles do not pile up over the year/month loop.
        with xr.open_dataset(ipath) as month_ds:
            # Pick temperature at the file's timestamps and at the level stored
            # in its "lev" variable.
            this_temp = temp_full.sel(time=month_ds.time.values, lev=month_ds["lev"])
            # T * (1000 / lev) ** KAPPA; presumably potential temperature with
            # `lev` in hPa — TODO confirm units and KAPPA definition.
            this_temp = this_temp * (1000 / this_temp.lev) ** KAPPA
            this_temp = this_temp.reset_coords("lev", drop=True)
            month_ds["theta"] = compute(this_temp, progress_flag=True)
            month_ds.to_netcdf(opath)

# New PV-streamer (PVS) DataArrays

In [None]:
import polars_st as st

In [None]:
def to_xarray(events: st.GeoDataFrame, dummy_da: xr.DataArray, varname: str):
    """Rasterize event geometries onto the grid of ``dummy_da``.

    Each event geometry is buffered by 0.25 (in the units of the geometry
    coordinates) and spatially joined with the ``(lon, lat)`` grid points of
    ``dummy_da``. Every grid point contained in a buffered geometry receives
    the event's ``varname`` value at the event's time step.

    Parameters
    ----------
    events
        Events with at least the columns ``"date"``, ``"geometry"`` and
        ``varname``. Rows whose date is not a time step of ``dummy_da`` are
        ignored.
    dummy_da
        Template array with ``time``, ``lat`` and ``lon`` coordinates. Only its
        coordinates are used; its values are never read.
    varname
        Name of the event column to paint onto the grid (cast to float32).

    Returns
    -------
    xr.DataArray
        A float32 array shaped like ``dummy_da`` and named ``"dummy"``, zero
        wherever no event covers the grid point.

    Notes
    -----
    When several events cover the same ``(time, lon, lat)`` cell only one row
    is kept; which one is left to the defaults of polars' ``unique``.
    """
    dummy_da = dummy_da.rename("dummy")
    da_df = pl.from_pandas(dummy_da.to_dataframe().reset_index())
    # The flattened frame repeats every time step once per grid point; only the
    # distinct values are needed for the dtype and the membership filter below.
    orig_times = da_df["time"].unique()
    timedtype = orig_times.dtype
    da_df = da_df.drop("time", "dummy")
    # One point geometry per distinct grid cell.
    da_df = da_df.unique(["lat", "lon"]).sort(["lat", "lon"]).with_columns(geometry=st.from_xy("lon", "lat"))
    da_df = st.GeoDataFrame(da_df)
    events = events.with_columns(pl.col("geometry").st.buffer(0.25))
    events = (
        events.select(["date", "geometry", varname])
        .cast({varname: pl.Float32})
        .rename({"date": "time"})
    )
    # Align the event time dtype with the grid's before comparing.
    events = events.cast({"time": timedtype})
    events = events.filter(pl.col("time").is_in(orig_times))
    dummy_da = xr.zeros_like(dummy_da, dtype=np.float32)
    events = events.st.sjoin(da_df, on="geometry", how="inner", predicate="contains")
    events = events.unique(["time", "lon", "lat"])
    if events.height == 0:
        # Nothing to paint: return the all-zero grid instead of building an
        # empty series and indexing with empty label arrays.
        return dummy_da
    # Dense (time, lat, lon) block spanning the labels touched by events;
    # combinations without an event become 0.
    events_da = xr.DataArray.from_series(
        events[["time", varname, "lat", "lon"]]
        .to_pandas()
        .set_index(["time", "lat", "lon"])[varname]
    ).fillna(0)
    dummy_da.loc[
        {
            "time": events_da.time.values,
            "lat": events_da.lat.values,
            "lon": events_da.lon.values,
        }
    ] = events_da
    return dummy_da

In [None]:
# Load the PV-streamer events of every level from 310 to 360 in steps of 5 and
# split them into two groups per level.
all_events = {}
for level in trange(310, 365, 5):
    events = st.from_geopandas(
        gpd.read_parquet(
            f"/storage/workspaces/giub_meteo_impacts/ci01/ERA5/RWB_index/era5_pv_streamers_{level}K_1959-2022.parquet"
        )
    )

    level_col = pl.col("level")
    intensity_col = pl.col("intensity")

    # Keep events whose "mean_var" lies below their "level" value, then split
    # them by how "intensity" compares with "level".
    # NOTE(review): the tropospheric / anticyclonic / cyclonic labels follow the
    # variable names; confirm the physical meaning of these columns.
    tropospheric = events.filter(pl.col("mean_var") < level_col)
    anticyclonic = tropospheric.filter(intensity_col >= level_col)
    cyclonic = tropospheric.filter(intensity_col < level_col)

    all_events[level] = dict(anti=anticyclonic, cycl=cyclonic)

In [None]:
# Rasterize the anticyclonic and cyclonic events of every level onto a regular
# grid, one NetCDF file per year and per event type. Years for which both
# output files already exist are skipped.
for year in YEARS:
    ofile_anti = Path(f"{DATADIR}/ERA5/thetalev/apvs/6H/{year}.nc")
    ofile_cycl = Path(f"{DATADIR}/ERA5/thetalev/cpvs/6H/{year}.nc")
    if ofile_cycl.is_file() and ofile_anti.is_file():
        continue
    for ofile in (ofile_anti, ofile_cycl):
        # to_netcdf does not create missing parent directories.
        ofile.parent.mkdir(parents=True, exist_ok=True)
    time_mask = TIMERANGE.year == year
    # Template grid: this year's time steps, lat 0..90 and lon -180..179.5 in
    # steps of 0.5.
    coords = {
        "time": TIMERANGE[time_mask],
        "lat": np.arange(0, 90.5, .5),
        "lon": np.arange(-180, 180, .5),
    }
    shape = [len(co) for co in coords.values()]
    dummy_da = xr.DataArray(np.zeros(shape), coords=coords)
    anti_by_lev = {}
    cycl_by_lev = {}
    # Derive the progress-bar total from the dict instead of hard-coding the
    # number of levels.
    for lev, events in tqdm(all_events.items(), total=len(all_events)):
        anti_by_lev[lev] = to_xarray(events["anti"], dummy_da, "intensity")
        cycl_by_lev[lev] = to_xarray(events["cycl"], dummy_da, "intensity")
    # Stack the per-level grids along a new "lev" dimension labelled with the
    # dict keys.
    anti_all_levs = xr.concat(list(anti_by_lev.values()), dim="lev").assign_coords(lev=list(anti_by_lev))
    cycl_all_levs = xr.concat(list(cycl_by_lev.values()), dim="lev").assign_coords(lev=list(cycl_by_lev))
    anti_all_levs.to_netcdf(ofile_anti)
    cycl_all_levs.to_netcdf(ofile_cycl)

# CESM

In [8]:
from jetstream_hugo.definitions import *
from jetstream_hugo.data import *
import intake
# --- Configuration -----------------------------------------------------------
url = 'https://raw.githubusercontent.com/NCAR/cesm2-le-aws/main/intake-catalogs/aws-cesm2-le.json'
varname = "TS"
varname_to_search = varname
component = "atm"
# experiment = "historical"
experiment = "ssp370"
frequency = "daily"
forcing_variant = "cmip6"
# period = (1980, 2009)
period = (2070, 2099)
season = None
minlon, maxlon = -180, 180
minlat, maxlat = 0, 90
levels = list(range(13, 19))
members = "all"
reduce_da = True

# --- Catalog keys ------------------------------------------------------------
# Every combination of the four search facets, each facet coerced to a list of
# strings, joined as 'component.experiment.frequency.forcing_variant'.
facets = (component, experiment, frequency, forcing_variant)
indexers = list(
    product(*([str(value) for value in np.atleast_1d(facet)] for facet in facets))
)
ensemble_keys = [".".join(combo) for combo in indexers]

# --- Open the catalog and pick the first matching ensemble -------------------
basepath = Path(f"{DATADIR}/CESM2/{varname}")
catalog = intake.open_esm_datastore(url)
search_query = dict(
    variable=varname_to_search,
    component=component,
    experiment=experiment,
    frequency=frequency,
    forcing_variant=forcing_variant,
)
catalog_subset = catalog.search(**search_query)
dsets = catalog_subset.to_dataset_dict(
    storage_options={"anon": True},
    xarray_open_kwargs={"consolidated": False},
)
ds = standardize(dsets[ensemble_keys[0]])

# --- Subset in time and space, then keep the requested variable --------------
extract_kwargs = dict(
    period=period,
    season=season,
    minlon=minlon,
    maxlon=maxlon,
    minlat=minlat,
    maxlat=maxlat,
    # levels=levels,
    # members=members,
)
da = extract(ds, **extract_kwargs)[varname]


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


  da["time"] = da.indexes["time"].to_datetimeindex()


In [9]:
# Evaluate `da` through the project's `compute` helper with its progress display
# switched on; the recorded run of this cell took about 12 minutes.
da = compute(da, progress_flag=True)

[########################################] | 100% Completed | 12m 3ss


In [10]:
# Name the store after the configured experiment: with mode="w" a hard-coded
# "past.zarr" would be overwritten by the ssp370 / 2070-2099 data set up in the
# configuration cell.
# NOTE(review): "future" is an assumed label for non-historical runs — confirm
# against the naming the rest of the project expects.
period_label = "past" if experiment == "historical" else "future"
da.to_zarr(
    f"/storage/workspaces/giub_meteo_impacts/ci01/CESM2/TS/{period_label}.zarr",
    mode="w",
)

<xarray.backends.zarr.ZarrStore at 0x7f243135eac0>

In [None]:
from itertools import groupby
from pathlib import Path

import numpy as np
import xarray as xr

# Inventory of the per-member, per-year NetCDF files named "<member>-<year>.nc"
# (the aggregated "ds.nc" is excluded).
basepath = Path("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind")
paths = [
    path
    for path in basepath.iterdir()
    if path.suffix == ".nc" and path.name != "ds.nc"
]
parts = np.asarray([path.name.split(".")[0].split("-") for path in paths])

# Sort by member first and by year second. Sorting on the member alone leaves
# the year order within a member at the mercy of the directory listing, and the
# files of a member are later concatenated along time in exactly this order.
# NOTE(review): the "r10" -> "r0" substitution is kept from the original member
# sort key; its intent is not visible here — confirm.
sorted_order = sorted(
    range(len(paths)),
    key=lambda i: (parts[i, 0].replace("r10", "r0"), parts[i, 1]),
)
parts = parts[sorted_order]
paths = [paths[i] for i in sorted_order]
all_members = np.unique(parts[:, 0])
all_years = np.unique(parts[:, 1])

# Check every (year, member) combination for presence on disk.
not_here = []
here = []
for year in all_years:
    for member in all_members:
        potential_path = basepath.joinpath(f"{member}-{year}.nc")
        if potential_path.is_file():
            here.append(potential_path)
        else:
            not_here.append(potential_path)
# A bare `len(here)` in the middle of a cell displays nothing, so print instead.
print(f"{len(here)} files present, {len(not_here)} missing")

# Keep only members with the full set of yearly files; report the others.
n_expected_years = 60
paths_to_load = []
valid_ensembles = []
# groupby needs equal keys to be adjacent, which the sort above guarantees.
for key, indices in groupby(range(len(parts)), lambda i: parts[i][0]):
    indices = list(indices)
    group = parts[indices]
    these_paths = [paths[i] for i in indices]
    years = np.asarray([g[1] for g in group])
    if len(years) == n_expected_years:
        paths_to_load.append(these_paths)
        valid_ensembles.append(key)
    else:
        print(key, len(years))

In [None]:
from tqdm import tqdm
def _open_without_bounds(path):
    """Open one NetCDF file and strip ``time_bnds`` plus the ``nbnd`` dimension."""
    opened = xr.open_dataset(path)
    return opened.reset_coords("time_bnds", drop=True).drop_dims("nbnd")


# One dataset per member (its yearly files joined along "time"), then all
# members stacked along a new "member" dimension.
member_datasets = [
    xr.concat([_open_without_bounds(path) for path in member_paths], dim="time")
    for member_paths in tqdm(paths_to_load)
]
ds = xr.concat(member_datasets, dim="member")

In [None]:
import dask
from dask.distributed import progress, Client
from jetstream_hugo.definitions import COMPUTE_KWARGS
# Start a distributed client configured with the project's COMPUTE_KWARGS.
client = Client(**COMPUTE_KWARGS)
# dask.persist returns new collections backed by the running computation; the
# return value must be kept, otherwise `progress` has nothing to track and
# `dask.compute` starts over from the original graph.
(ds_persisted,) = dask.persist(ds)
progress(ds_persisted, notebook=False)
# dask.compute returns a tuple; `ds` is deliberately left as that 1-tuple
# because the following cell unpacks it with `ds[0]`.
ds = dask.compute(ds_persisted)

In [None]:
# `dask.compute` in the previous cell leaves a 1-tuple; unpack it only when
# needed so that re-running this cell does not index the dataset itself.
if isinstance(ds, tuple):
    ds = ds[0]
# Build the Zarr write lazily, with the same chunk spec for every data variable.
to_comp = ds.to_zarr(
    "/storage/workspaces/giub_meteo_impacts/ci01/CESM2/flat_wind/ds.zarr",
    compute=False,
    encoding={var: {"chunks": (-1, 100, -1, -1)} for var in ds.data_vars},
    mode="w",
)
# Keep the persisted handle: it is what `progress` can track and what
# `dask.compute` should wait on, instead of re-submitting the original graph.
(to_comp_persisted,) = dask.persist(to_comp)
progress(to_comp_persisted, notebook=False)
dask.compute(to_comp_persisted)