In [1]:
import datetime

import colormaps
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import xarray as xr
from jetutils.anyspell import get_persistent_jet_spells, mask_from_spells_pl, subset_around_onset
from jetutils.clustering import Experiment
from jetutils.data import *
from jetutils.definitions import *
from jetutils.jet_finding import JetFindingExperiment, iterate_over_year_maybe_member, average_jet_categories
from jetutils.plots import COLORS, Clusterplot
from jetutils.geospatial import *
from matplotlib.cm import ScalarMappable
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
from tqdm import tqdm

%load_ext IPython.extensions.autoreload
%autoreload 2
%matplotlib inline

IPython could not be loaded!


Found config override file at  /storage/homefs/hb22g102/.jetutils.ini
Guessed N_WORKERS :  10
Guessed MEMORY_LIMIT :  102400


# 2023-24 wind data

In [4]:
# Merge the twelve monthly raw wind/temperature files of each listed year into
# a single yearly file `<year>.nc` stored next to the monthly files.
basepath = Path(DATADIR, "ERA5/plev/high_wind/6H")
for year in [2024]:
    to_concat = []
    yearstr = str(year).zfill(4)
    for month in range(1, 13):
        monthstr = str(month).zfill(2)
        ds = standardize(xr.open_dataset(basepath.joinpath(f"{yearstr}{monthstr}_raw.nc")))
        ds["t"] = standardize(xr.open_dataarray(basepath.joinpath(f"{yearstr}{monthstr}_raw_t.nc")))
        # Wind speed from the horizontal components.
        ds["s"] = np.sqrt(ds["u"] ** 2 + ds["v"] ** 2)
        # NOTE(review): presumably collapses the level dimension where "s" is
        # largest; confirm against jetutils.data.flatten_by.
        ds = flatten_by(ds, "s")
        # theta = T * (1000 / p) ** kappa; assumes "lev" is in hPa -- TODO confirm.
        ds["theta"] = ds["t"] * (1000 / ds["lev"]) ** KAPPA
        # Temperature was only needed to derive theta.
        ds = ds.drop_vars("t")
        to_concat.append(ds)
    # Keep the list and the merged dataset under different names.
    year_ds = xr.concat(to_concat, dim="time")
    year_ds.to_netcdf(basepath.joinpath(f"{yearstr}.nc"))

[########################################] | 100% Completed | 10.32 s
[########################################] | 100% Completed | 9.39 ss
[########################################] | 100% Completed | 11.00 s
[########################################] | 100% Completed | 9.99 ss
[########################################] | 100% Completed | 10.57 s
[########################################] | 100% Completed | 9.66 ss
[########################################] | 100% Completed | 10.57 s
[########################################] | 100% Completed | 9.67 ss
[########################################] | 100% Completed | 10.48 s
[########################################] | 100% Completed | 10.07 s
[########################################] | 100% Completed | 9.86 ss
[########################################] | 100% Completed | 10.39 s


In [31]:
# Scratch check: display `ds` with the variable "t" dropped (the result is not assigned).
ds.drop_vars("t")

In [36]:
# Scratch check: open and display the file "2022.nc" located under `basepath`.
xr.open_dataset(basepath.joinpath("2022.nc"))

# new pvs

In [2]:
from pathlib import Path
from tqdm import tqdm
import numpy as np
from jetutils.definitions import compute, YEARS, DATADIR, TIMERANGE
from jetutils.data import open_da, to_netcdf
import polars as pl
import geopandas as gpd
import polars_st as st
import xarray as xr
import gc

def to_xarray(events: st.GeoDataFrame, dummy_da: xr.DataArray, varname: str):
    """Rasterize event geometries onto the (time, lat, lon) grid of ``dummy_da``.

    Each geometry is buffered by a quarter of the summed lon/lat grid spacing,
    spatially joined (``intersects``) with the grid points, and reduced to one
    row per (time, lon, lat).

    Parameters
    ----------
    events
        Frame with at least ``time`` and ``geometry`` columns, plus ``varname``
        unless ``varname == "flag"`` (in which case a constant 1 is used).
    dummy_da
        Template array; only its ``time``, ``lat`` and ``lon`` coordinates are
        read. Grid spacing is taken from the first two values of each axis.
    varname
        Column to rasterize. ``"flag"`` yields uint32, anything else float32.

    Returns
    -------
    xr.DataArray
        Zero-initialized array shaped like ``dummy_da`` with event values
        written at the matched grid points.
    """
    is_flag = varname == "flag"
    pl_dtype = pl.UInt32 if is_flag else pl.Float32
    np_dtype = np.uint32 if is_flag else np.float32

    # Grid axes as polars objects.
    grid_times = pl.Series("time", dummy_da["time"].values)
    lon_frame = pl.Series("lon", dummy_da["lon"].values).to_frame()
    lat_frame = pl.Series("lat", dummy_da["lat"].values).to_frame()
    dlo = lon_frame["lon"][1] - lon_frame["lon"][0]
    dla = lat_frame["lat"][1] - lat_frame["lat"][0]

    # One point geometry per grid cell.
    grid_points = st.GeoDataFrame(
        lat_frame.join(lon_frame, how="cross")
        .cast({"lon": pl.Float32, "lat": pl.Float32})
        .with_columns(geometry=st.point(pl.concat_list("lon", "lat")))
    )

    # Grow the geometries slightly before matching them against grid points.
    buffered = events.with_columns(pl.col("geometry").st.buffer((dlo + dla) / 4))
    if is_flag:
        buffered = buffered.with_columns(flag=pl.lit(1))

    matched = (
        buffered.select(["time", "geometry", varname])
        .cast({varname: pl_dtype})
        .cast({"time": grid_times.dtype})
        # Drop events whose time stamp is not on the template's time axis.
        .filter(pl.col("time").is_in(grid_times.implode()))
        .st.sjoin(grid_points, on="geometry", predicate="intersects")
        .unique(["time", "lon", "lat"])
    )

    raster = xr.zeros_like(dummy_da, dtype=np_dtype)
    matched_series = (
        matched[["time", varname, "lat", "lon"]]
        .to_pandas()
        .set_index(["time", "lat", "lon"])[varname]
        .astype(np_dtype)
    )
    # Index combinations absent from the series become NaN, then 0.
    events_da = xr.DataArray.from_series(matched_series).fillna(0)
    raster.loc[
        {
            "time": events_da.time.values,
            "lat": events_da.lat.values,
            "lon": events_da.lon.values,
        }
    ] = events_da
    return raster

In [None]:
# Rasterize PV overturning events per isentropic level and orientation, then
# stack the levels and write one file per orientation.
all_events = {}
basepath = Path(DATADIR, "ERA5/RWB_index/pv")
levs = list(range(310, 365, 5))
for level in levs:
    events = st.GeoDataFrame(pl.read_parquet(basepath.joinpath(f"era5_pv_overturnings_{level}K_1959-2022.parquet")))
    events = events.rename({"date": "time"}).cast({"time": pl.Datetime("ms")})

    # Insertion order (anti, then cycl, per level) matches the original cell.
    for prefix, orientation in [("anti", "anticyclonic"), ("cycl", "cyclonic")]:
        all_events[f"{prefix}_{level}"] = events.filter(pl.col("orientation") == pl.lit(orientation))


varname = "flag"
# NOTE: `to_xarray` re-creates the template with its own dtype, so this dtype
# only affects the in-memory template built here.
dtype = {"flag": np.uint8, "intensity": np.float32, "mean_var": np.float32}[varname]
coords = {
    "time": TIMERANGE,
    "lat": np.arange(15, 80.5, 1),
    "lon": np.arange(-80, 40.5, 1),
}
shape = [len(co) for co in coords.values()]
dummy_da = xr.DataArray(np.zeros(shape, dtype=dtype), coords=coords)
all_events_xr = {}
for name, events in tqdm(all_events.items()):
    # Use the configured variable instead of a second hard-coded "flag".
    all_events_xr[name] = to_xarray(events, dummy_da, varname)

for prefix in ["anti", "cycl"]:
    stacked = xr.concat([all_events_xr[f"{prefix}_{lev}"] for lev in levs], dim="lev").assign_coords(lev=levs)
    stacked.to_netcdf(basepath.joinpath(f"overturnings_{prefix}_natl.nc"))

100%|██████████| 22/22 [06:37<00:00, 18.08s/it]


In [19]:
from wavebreaking import to_xarray as to_xarray_orig

# Rasterize PV streamer events, split by mean_var (>= 2 "strato", < 2 "tropo")
# and by the sign of intensity (>= 0 "anti", < 0 "cycl"), for every level.
all_events = {}
basepath = Path(DATADIR, "ERA5/RWB_index/pv")
levs = list(range(310, 365, 5))
for level in levs:
    events = gpd.read_parquet(basepath.joinpath(f"era5_pv_streamers_{level}K_1959-2022.parquet"))
    stratospheric = events[events.mean_var >= 2]
    tropospheric = events[events.mean_var < 2]

    all_events[f"anti_strato_{level}"] = stratospheric[stratospheric.intensity >= 0]
    all_events[f"anti_tropo_{level}"] = tropospheric[tropospheric.intensity >= 0]
    all_events[f"cycl_strato_{level}"] = stratospheric[stratospheric.intensity < 0]
    all_events[f"cycl_tropo_{level}"] = tropospheric[tropospheric.intensity < 0]


varname = "flag"
dtype = {"flag": np.uint32, "intensity": np.float32, "mean_var": np.float32}[varname]
coords = {
    "time": TIMERANGE,
    "lat": np.arange(15, 80.5, 1),
    "lon": np.arange(-80, 40.5, 1),
}
shape = [len(co) for co in coords.values()]
dummy_da = xr.DataArray(np.zeros(shape, dtype=dtype), coords=coords)
all_events_xr = {}
for name, events in tqdm(all_events.items()):
    all_events_xr[name] = to_xarray_orig(dummy_da, events)

# Stack the per-level rasters along "lev". The key must be built from the
# comprehension variable `lev`: the original used the stale outer-loop `level`,
# which made every slice of each file a copy of the last level.
for kind in ["anti_strato", "anti_tropo", "cycl_strato", "cycl_tropo"]:
    stacked = xr.concat([all_events_xr[f"{kind}_{lev}"] for lev in levs], dim="lev").assign_coords(lev=levs)
    stacked.to_netcdf(basepath.joinpath(f"streamers_{kind}_natl.nc"))

100%|██████████| 44/44 [05:52<00:00,  8.00s/it]


In [16]:
# NOTE(review): `f` is not defined in this cell; presumably a loop variable left over from another cell.
f.stem

'overturnings_anti_natl'

In [None]:
from itertools import product
import shutil

# Copy the rasterized wave-breaking files into per-variable folders
# `<DATADIR>/ERA5/thetalev/<SHORTHAND>/<freq>/full.nc`.
streamers_subtypes = ["_".join(both) for both in product(["stratospheric", "tropospheric"], ["anticyclonic", "cyclonic"])]
types = {
    "overturnings": ["anticyclonic", "cyclonic"],
    "streamers": streamers_subtypes,
}
source_dir = Path(DATADIR, "ERA5", "RWB_index", "pv")
basepath = Path(DATADIR, "ERA5/thetalev")
for type_, subtypes in types.items():
    for subtype in subtypes:
        if "_" in subtype:
            # e.g. "stratospheric_anticyclonic" -> "SA" and "anti_strato"
            layer, rotation = subtype.split("_")
            shorthand_1 = layer[0].upper() + rotation[0].upper()
            # removesuffix drops exactly "spheric"; rstrip("spheric") strips a
            # character set and only gave the right answer by coincidence.
            shorthand_2 = rotation[:4] + "_" + layer.removesuffix("spheric")
        else:
            # e.g. "anticyclonic" -> "A" and "anti"
            shorthand_1 = subtype[0].upper()
            shorthand_2 = subtype[:4]
        # e.g. "SA" + "PV" + "S" -> "SAPVS"; "A" + "PV" + "O" -> "APVO"
        shorthand = shorthand_1 + "PV" + type_[0].upper()
        this_path = basepath.joinpath(shorthand)
        for freq in ["6H", "dailyany"]:
            file_spec = "" if freq == "6H" else f"_{freq}"
            file_stem_ = f"{type_}_{shorthand_2}_natl{file_spec}.nc"
            this_path_ = this_path.joinpath(freq)
            # parents=True also creates `this_path` (and `basepath`) when missing.
            this_path_.mkdir(parents=True, exist_ok=True)
            source = source_dir.joinpath(f"{file_stem_}")
            dest = this_path_.joinpath("full.nc")
            print(source, dest)
            shutil.copy(source, dest)
    

In [None]:
# Collapse each 6-hourly, per-level raster into a daily "any event on any of
# the selected levels" flag and write it next to the input as
# `<stem>_dailyany.nc`.
basepath = Path(DATADIR, "ERA5/RWB_index/pv")
# Outputs land in the directory being listed, so list the inputs up front and
# skip files this notebook derives, which keeps a re-run from processing its
# own output.
derived_suffixes = ("_dailyany", "_dailyany_clim", "_dailyany_anom")
files_to_treat = sorted(
    path for path in basepath.glob("*.nc") if not path.stem.endswith(derived_suffixes)
)
# NOTE(review): stops at 350 K although the rasters are built up to 360 K --
# confirm this is intended.
levels = list(range(310, 355, 5))
for f in tqdm(files_to_treat):
    da = xr.open_dataarray(f)
    da = da.sel(lev=levels).any("lev").resample(time="1D").any().astype(np.uint8)
    da.to_netcdf(f.parent.joinpath(f"{f.stem}_dailyany.nc"))

6it [00:45,  7.53s/it]


In [28]:
from jetutils.data import *

# For every daily "any" file: smoothed day-of-year climatology and the
# anomaly with respect to it, both stored as float32 next to the input.
basepath = Path(DATADIR, "ERA5/RWB_index/pv")
files_to_treat = basepath.glob("*dailyany.nc")
for f in tqdm(files_to_treat):
    out_dir = f.parent
    da = xr.open_dataarray(f)
    raw_clim = compute_clim(da, "dayofyear")
    clim = smooth(raw_clim, {"dayofyear": ("win", 15)}).astype(np.float32)
    da_float = da.astype(np.float32)
    anom = da_float.groupby("time.dayofyear") - clim
    clim.astype(np.float32).to_netcdf(out_dir.joinpath(f"{f.stem}_clim.nc"))
    anom.astype(np.float32).to_netcdf(out_dir.joinpath(f"{f.stem}_anom.nc"))

6it [00:17,  2.86s/it]


In [None]:
# Smoke test: make sure every wave-breaking variable can be opened and loaded.
for rwb_type in ["APVO", "CPVO", "SAPVS", "SCPVS", "TAPVS", "TCPVS"]:
    # Open the variable of this iteration; the original always opened "APVO".
    dh = DataHandler.from_specs("ERA5", "thetalev", rwb_type, "dailyany", "all", None, -80, 40, 15, 80)
    da = compute(dh.da)
    

[########################################] | 100% Completed | 706.27 ms


<jetutils.data.DataHandler at 0x7f9815b09040>

# CESM clims

In [6]:
# Day-of-year climatology of CESM2 PRECL (past period), smoothed with a
# 15-step window along dayofyear, written to its own zarr store.
da_tp = xr.open_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/PRECL/past.zarr")

clim = compute(
    smooth(
        da_tp.groupby("time.dayofyear").mean(),
        {'dayofyear': ('win', 15)},
    ),
    progress_flag=True,
)
clim.to_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/PRECL/past_clim.zarr")

[########################################] | 100% Completed | 126.37 s


<xarray.backends.zarr.ZarrStore at 0x7ff12afe3d00>

In [None]:
# Anomalies of PRECL with respect to the day-of-year climatology; uses
# `da_tp` and `clim` as left by the climatology cell.
anom = compute(da_tp.groupby("time.dayofyear") - clim, progress_flag=True)
anom.to_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/PRECL/past_anom.zarr")

In [None]:
# Day-of-year climatology of CESM2 TS (past period), smoothed with a 15-step
# window along dayofyear.
da_T = xr.open_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/TS/past.zarr")

# Build the climatology from the TS data opened above and store it under TS.
# The original copy-paste read `da_tp` and targeted the PRECL climatology store.
clim = da_T.groupby("time.dayofyear").mean()
clim = smooth(clim, {'dayofyear': ('win', 15)})
clim = compute(clim, progress_flag=True)
clim.to_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/TS/past_clim.zarr")

In [7]:
# da_tp = xr.open_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/PRECL/future.zarr")

# clim = da_tp.groupby("time.dayofyear").mean()
# clim = smooth(clim, {'dayofyear': ('win', 15)})
# clim = compute(clim, progress_flag=True)
# clim.to_zarr("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/PRECL/future_clim.zarr")

# create jet relative climatologies

In [2]:
# Find jets in the North Atlantic high-wind data, categorize them, and build
# the filtered / re-labelled jet tables used by the jet-relative cells.
dh = DataHandler.from_specs("ERA5", "plev", "high_wind", "6H", "all", None, -80, 40, 15, 80)
exp = JetFindingExperiment(dh)
all_jets_one_df = exp.find_jets(force=False, alignment_thresh=0.6, base_s_thresh=0.55, int_thresh_factor=0.35, hole_size=10)
all_jets_one_df = exp.categorize_jets(
    None, ["s", "theta"], force=False, n_init=5, init_params="k-means++", mode="week"
).cast({"time": pl.Datetime("ms")})

# Per-jet (time, jet ID) window expressions shared by the filter and the relabelling.
per_jet = ["time", "jet ID"]
polar_fraction = pl.col("is_polar").mean().over(per_jet)
modal_int = pl.col("int").mode().first().over(per_jet)

# Keep jets that are mostly non-polar, or mostly polar with a modal "int" above 1.5e8.
keep_mostly_nonpolar = polar_fraction < 0.5
keep_strong_polar = (polar_fraction > 0.5) & (modal_int > 1.5e8)
phat_jets = all_jets_one_df.filter(keep_mostly_nonpolar | keep_strong_polar)

# Replace "jet ID" by the category: 1 for mostly-polar jets, 0 otherwise.
phat_jets_catd = phat_jets.with_columns(**{"jet ID": (polar_fraction > 0.5).cast(pl.UInt32())})

In [3]:
def create_jet_relative_clim(jets, path, da, suffix="", half_length: float = 20., std: bool = False, n_interp: int = 30):
    """Write a day-of-year climatology of ``da`` in jet-relative coordinates.

    Jet times are rounded to the day and ``"jet ID"`` is replaced by a
    run-length id within each time step. For every chunk yielded by
    ``iterate_over_year_maybe_member`` the field is sampled around the jets
    (``gather_normal_da_jets``), interpolated (``interp_jets_to_zero_one``)
    and averaged per (time, mostly-polar flag, ``norm_index``, ``n``). The
    chunks are then aggregated per day of year and written to
    ``path / f"{da.name}{suffix}[_std]_relative_clim.nc"``.

    Parameters
    ----------
    jets : polars DataFrame
        Jet points with at least ``time``, ``jet ID`` and ``is_polar`` columns.
    path : pathlib.Path
        Output directory.
    da : xr.DataArray
        Field to sample; must have a name (used for column and file names).
    suffix : str
        Inserted after the variable name in the output file name.
    half_length : float
        Passed to ``gather_normal_da_jets`` (presumably the half-length of the
        jet-normal segments -- confirm in jetutils).
    std : bool
        Aggregate with the standard deviation instead of the mean and add
        ``_std`` to the file name.
    n_interp : int
        Passed to ``interp_jets_to_zero_one``; defaults to the previously
        hard-coded 30.

    Raises
    ------
    ValueError
        If no chunk could be processed, so there is nothing to aggregate.
    """
    jets = jets.with_columns(pl.col("time").dt.round("1d"))
    jets = jets.with_columns(jets.group_by("time", maintain_order=True).agg(pl.col("jet ID").rle_id())["jet ID"].explode())
    indexer = iterate_over_year_maybe_member(jets, da)
    to_average = []
    varname = da.name + "_interp"
    for idx1, idx2 in tqdm(indexer, total=len(YEARS)):
        jets_ = jets.filter(*idx1)
        da_ = da.sel(**idx2)
        try:
            jets_with_interp = gather_normal_da_jets(jets_, da_, half_length=half_length)
        except (KeyError, ValueError) as e:
            # Best effort: report the problem and keep what was gathered so far
            # (typically the jets extend past the time axis of `da`).
            print(e)
            break
        jets_with_interp = interp_jets_to_zero_one(jets_with_interp, [varname, "is_polar"], n_interp=n_interp)
        jets_with_interp = jets_with_interp.group_by("time", pl.col("is_polar").mean().over(["time", "jet ID"]) > 0.5, "norm_index", "n", maintain_order=True).agg(pl.col(varname).mean())
        to_average.append(jets_with_interp)
    if not to_average:
        # pl.concat would also fail on an empty list, with a less helpful message.
        raise ValueError(f"No jet-relative data could be gathered for {da.name!r}; nothing to average.")
    agg = pl.col(varname).std() if std else pl.col(varname).mean()
    extra_suffix = "_std" if std else ""
    clim = (
        pl.concat(to_average)
        .group_by(
            pl.col("time").dt.ordinal_day().alias("dayofyear"), "is_polar", "norm_index", "n"
        ).agg(agg)
        .sort("dayofyear", "is_polar", "norm_index", "n")
    )
    clim_ds = polars_to_xarray(clim, ["dayofyear", "is_polar", "n", "norm_index"])
    clim_ds.to_netcdf(path.joinpath(f"{da.name}{suffix}{extra_suffix}_relative_clim.nc"))

In [23]:
from jetutils.definitions import DATERANGE, TIMERANGE

# Turn the blocking "flag" field into a 0/1 int8 mask restricted to YEARS and
# put it on the TIMERANGE time axis.
block_files = xr.open_mfdataset("/storage/workspaces/giub_meteo_impacts/ci01/ERA5/blocks/*.nc")
in_years = np.isin(block_files.time.dt.year, YEARS)
is_blocked = (block_files["flag"] != 0).astype(np.int8)
ds = is_blocked.sel(time=in_years).assign_coords(time=TIMERANGE).chunk("auto")

In [25]:
from dask.diagnostics import ProgressBar

# Schedule the write lazily so that the actual computation runs under a progress bar.
delayed_write = ds.to_netcdf("/storage/workspaces/giub_meteo_impacts/ci01/ERA5/surf/blocks/6H/full.nc", compute=False)
with ProgressBar():
    delayed_write.compute()

[########################################] | 100% Completed | 157.01 s


In [26]:
# NOTE(review): presumably derives the daily-mean files for ERA5/surf/blocks from the 6H data -- confirm in jetutils.data.
compute_all_dailymeans("ERA5", "surf", "blocks")

  0%|          | 0/1 [00:45<?, ?it/s]


In [None]:
def create_jet_relative_dataset(jets, path, da, suffix="", half_length: float = 20., n_interp: int = 30):
    """Write ``da`` sampled in jet-relative coordinates, without time aggregation.

    Jet times are rounded to the day and ``"jet ID"`` is replaced by a
    run-length id within each time step. For every chunk yielded by
    ``iterate_over_year_maybe_member`` the field is sampled around the jets
    (``gather_normal_da_jets``), interpolated (``interp_jets_to_zero_one``)
    and averaged per (time, mostly-polar flag, ``norm_index``, ``n``). All
    chunks are concatenated and written to
    ``path / f"{da.name}{suffix}_relative.parquet"``.

    Parameters
    ----------
    jets : polars DataFrame
        Jet points with at least ``time``, ``jet ID`` and ``is_polar`` columns.
    path : pathlib.Path
        Output directory.
    da : xr.DataArray
        Field to sample; must have a name (used for column and file names).
    suffix : str
        Inserted after the variable name in the output file name.
    half_length : float
        Passed to ``gather_normal_da_jets`` (presumably the half-length of the
        jet-normal segments -- confirm in jetutils).
    n_interp : int
        Passed to ``interp_jets_to_zero_one``; defaults to the previously
        hard-coded 30.

    Raises
    ------
    ValueError
        If no chunk could be processed, so there is nothing to write.
    """
    jets = jets.with_columns(pl.col("time").dt.round("1d"))
    jets = jets.with_columns(jets.group_by("time", maintain_order=True).agg(pl.col("jet ID").rle_id())["jet ID"].explode())
    indexer = iterate_over_year_maybe_member(jets, da)
    to_average = []
    varname = da.name + "_interp"
    for idx1, idx2 in tqdm(indexer, total=len(YEARS)):
        jets_ = jets.filter(*idx1)
        da_ = da.sel(**idx2)
        try:
            jets_with_interp = gather_normal_da_jets(jets_, da_, half_length=half_length)
        except (KeyError, ValueError) as e:
            # Best effort: report the problem and keep what was gathered so far
            # (typically the jets extend past the time axis of `da`).
            print(e)
            break
        jets_with_interp = interp_jets_to_zero_one(jets_with_interp, [varname, "is_polar"], n_interp=n_interp)
        jets_with_interp = jets_with_interp.group_by("time", pl.col("is_polar").mean().over(["time", "jet ID"]) > 0.5, "norm_index", "n", maintain_order=True).agg(pl.col(varname).mean())
        to_average.append(jets_with_interp)
    if not to_average:
        # pl.concat would also fail on an empty list, with a less helpful message.
        raise ValueError(f"No jet-relative data could be gathered for {da.name!r}; nothing to write.")
    pl.concat(to_average).write_parquet(path.joinpath(f"{da.name}{suffix}_relative.parquet"))
    
    
# Jet-relative blocking dataset on the same region as the jet-finding data.
# (A hard-coded region assigned to `args` just before was dead code: it was
# overwritten before being used.)
args = ["all", None, *get_region(exp.ds)]
da_blocks = open_da("ERA5", "surf", "blocks", "dailymean", *args)
da_blocks = compute(da_blocks)
create_jet_relative_dataset(phat_jets_catd, exp.path, da_blocks, suffix="_phat_catd")
del da_blocks

100%|██████████| 64/64 [03:02<00:00,  2.84s/it]


"not all values found in index 'time'"


In [8]:
# Jet-relative standard-deviation climatologies for a list of fields, all on
# the region of the jet-finding data.
args = ["all", None, *get_region(exp.ds)]

# (level type, variable, time aggregation, name to force on the array or None)
fields = [
    ("surf", "t2m", "dailymean", None),
    ("surf", "tp", "dailysum", None),
    # NOTE(review): only PV330 was renamed explicitly in the original cell;
    # kept as is -- confirm whether PV350 needs the same treatment.
    ("thetalev", "PV330", "dailymean", "PV330"),
    ("thetalev", "PV350", "dailymean", None),
    ("surf", "theta2PVU", "dailymean", None),
]
varnames_rwb = ["APVO", "CPVO"]
# The original loop iterated over the misspelled `varnamres_rwb` (NameError).
fields += [("thetalev", rwb_var, "dailyany", None) for rwb_var in varnames_rwb]

for levtype, field_var, freq, forced_name in fields:
    da_field = compute(open_da("ERA5", levtype, field_var, freq, *args))
    if forced_name is not None:
        da_field = da_field.rename(forced_name)
    create_jet_relative_clim(phat_jets_catd, exp.path, da_field, suffix="_phat_catd", std=True)
    # Free the loaded field before opening the next one.
    del da_field

100%|██████████| 64/64 [05:07<00:00,  4.81s/it]


"not all values found in index 'time'"


100%|██████████| 64/64 [05:05<00:00,  4.77s/it]


"not all values found in index 'time'"


100%|██████████| 64/64 [05:20<00:00,  5.00s/it]


"not all values found in index 'time'"


100%|██████████| 64/64 [05:19<00:00,  5.00s/it]


"not all values found in index 'time'"


100%|██████████| 64/64 [05:23<00:00,  5.05s/it]


"not all values found in index 'time'"


100%|██████████| 64/64 [03:29<00:00,  3.27s/it]


"not all values found in index 'time'"


100%|██████████| 64/64 [03:28<00:00,  3.25s/it]


"not all values found in index 'time'"


In [22]:
# Mean jet-relative climatologies for the wave-breaking variables.
# Full list: ["APVO", "CPVO", "SAPVS", "SCPVS", "TAPVS", "TCPVS"]
natl_region = (-80, 40, 15, 80)
rwb_types_to_process = ["CPVO", "SAPVS", "SCPVS", "TAPVS", "TCPVS"]
for rwb_type in rwb_types_to_process:
    dh = DataHandler.from_specs("ERA5", "thetalev", rwb_type, "dailyany", "all", None, *natl_region)
    da = compute(dh.da)
    print(da.name)
    create_jet_relative_clim(phat_jets_catd, exp.path, da, suffix="_phat_catd")
    del da

[########################################] | 100% Completed | 504.78 ms
CPVO


100%|██████████| 64/64 [05:36<00:00,  5.26s/it]


[########################################] | 100% Completed | 505.01 ms
SAPVS


100%|██████████| 64/64 [05:42<00:00,  5.35s/it]


[########################################] | 100% Completed | 605.23 ms
SCPVS


100%|██████████| 64/64 [05:42<00:00,  5.35s/it]


[########################################] | 100% Completed | 805.89 ms
TAPVS


100%|██████████| 64/64 [05:40<00:00,  5.32s/it]


[########################################] | 100% Completed | 706.41 ms
TCPVS


100%|██████████| 64/64 [05:39<00:00,  5.31s/it]


In [19]:
# Display the directory that is passed as output `path` to the jet-relative functions.
exp.path

PosixPath('/storage/workspaces/giub_meteo_impacts/ci01/ERA5/plev/high_wind/6H/results/7')

In [2]:
from pathlib import Path

# Daily means of the 6-hourly theta2PVU files, one output file per year.
base_path_1 = Path(f"{DATADIR}/ERA5/surf/theta2PVU/6H")
base_path_2 = Path(f"{DATADIR}/ERA5/surf/theta2PVU/dailymean")
# Create the output folder idempotently instead of relying on a manual,
# commented-out mkdir.
base_path_2.mkdir(parents=True, exist_ok=True)
for year in YEARS:
    print(year, end="\r")
    opath_1 = base_path_1.joinpath(f"{year}.nc")
    opath_2 = base_path_2.joinpath(f"{year}.nc")

    # Years that were already processed are skipped.
    if opath_2.is_file():
        continue

    this_pv = standardize(open_dataarray(opath_1))
    this_pv = compute(this_pv, progress_flag=False)
    this_pv = this_pv.resample(time="1d").mean()
    this_pv.to_netcdf(opath_2)

2022

In [5]:
# Jet-relative climatology of t2m opened with extra climatology arguments.
# NOTE(review): the trailing open_da arguments presumably request anomalies
# w.r.t. a smoothed day-of-year climatology -- confirm in jetutils.data.
args = ["all", None, *get_region(exp.ds), "all", "dayofyear", {"dayofyear": ("win", 15)}]
da_T = open_da("ERA5", "surf", "t2m", "dailymean", *args)
da_T = compute(da_T)
# Call updated to the signature defined in this notebook (jets, path, da,
# suffix); the former call passed `exp` as jets and the array as path.
# NOTE(review): jets and path chosen to match the other calls -- confirm.
create_jet_relative_clim(phat_jets_catd, exp.path, da_T, suffix="_anom")
del da_T
# da_T = open_da("ERA5", "plev", "t300", "dailymean", *args)
# da_T = compute(da_T)
# create_jet_relative_clim(exp, da_T, "_anom")
# del da_T
# da_tp = open_da("ERA5", "surf", "tp", "dailysum", *args)
# da_tp = compute(da_tp)
# create_jet_relative_clim(exp, da_tp, "_anom")
# del da_tp
# da_apvs = open_da("ERA5", "thetalev", "apvs", "dailyany", *args)
# da_apvs = compute(da_apvs)
# create_jet_relative_clim(exp, da_apvs, "_anom")
# del da_apvs
# da_cpvs = open_da("ERA5", "thetalev", "cpvs", "dailyany", *args)
# da_cpvs = compute(da_cpvs)
# create_jet_relative_clim(exp, da_cpvs, "_anom")
# del da_cpvs

100%|██████████| 64/64 [05:34<00:00,  5.23s/it]


# arco-era5 tests

In [None]:
from pathlib import Path

# Download 200 hPa temperature from the public ARCO-ERA5 store and save it per
# year, both at 6-hourly resolution and as daily means.
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
# Restrict to the time range advertised as valid by the store's attributes.
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

temp_full = (
    ar_full_37_1h["temperature"]
    .sel(
        # Every sixth hour, latitudes >= 0, level 200.
        time=ar_full_37_1h.time.dt.hour % 6 == 0,
        latitude=ar_full_37_1h.latitude >= 0,
        level=200,
    )
    # Every second grid point in both horizontal directions.
    .isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))
)

temp_full = standardize(temp_full).chunk("auto")

base_path_1 = Path(f"{DATADIR}/ERA5/plev/t200/6H")
base_path_2 = Path(f"{DATADIR}/ERA5/plev/t200/dailymean")
# Create the output folders idempotently instead of relying on manual,
# commented-out mkdir calls.
base_path_1.mkdir(parents=True, exist_ok=True)
base_path_2.mkdir(parents=True, exist_ok=True)
for year in YEARS:
    print(year)
    opath_1 = base_path_1.joinpath(f"{year}.nc")
    opath_2 = base_path_2.joinpath(f"{year}.nc")

    # A year counts as done once its daily-mean file exists.
    if opath_2.is_file():
        continue
    this_temp = temp_full.sel(time=temp_full.time.dt.year == year)
    this_temp = this_temp.reset_coords("lev", drop=True)
    this_temp = compute(this_temp, progress_flag=True)
    this_temp.to_netcdf(opath_1)

    this_temp = this_temp.resample(time="1d").mean()
    this_temp.to_netcdf(opath_2)

In [5]:
# Display the lazily selected temperature array built in the ARCO-ERA5 cell.
temp_full

In [None]:
# Add potential temperature, taken from ARCO-ERA5 temperature at the level
# stored in each monthly flat-wind file, and write the result to a second folder.
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
# Restrict to the time range advertised as valid by the store's attributes.
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

temp_full = (
    ar_full_37_1h["temperature"]
    .sel(
        time=ar_full_37_1h.time.dt.hour % 6 == 0,
        latitude=ar_full_37_1h.latitude >= 0,
        level=[175, 200, 225, 250, 300, 350],
    )
    .isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))
)

temp_full = standardize(temp_full)

orig_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean")
base_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean_2")
# The output folder may not exist yet.
base_path.mkdir(parents=True, exist_ok=True)
for year in tqdm(YEARS):
    # tqdm(range(...)) is what trange does; `trange` itself is never imported
    # in this notebook.
    for month in tqdm(range(1, 13), leave=False):
        month_str = str(month).zfill(2)
        opath = base_path.joinpath(f"{year}{month_str}.nc")
        if opath.is_file():
            continue
        ipath = orig_path.joinpath(f"{year}{month_str}.nc")
        # Context manager closes the monthly file after writing; a separate
        # name keeps the ARCO dataset `ds` from being overwritten.
        with xr.open_dataset(ipath) as ds_month:
            # Pick the temperature at the times and levels of the monthly file.
            this_temp = temp_full.sel(time=ds_month.time.values, lev=ds_month["lev"])
            # theta = T * (1000 / p) ** kappa; assumes "lev" is in hPa -- TODO confirm.
            this_temp = this_temp * (1000 / this_temp.lev) ** KAPPA
            this_temp = this_temp.reset_coords("lev", drop=True)
            ds_month["theta"] = compute(this_temp, progress_flag=True)
            ds_month.to_netcdf(opath)

# CESM

### new download with urls

## newnew merger script: download then postprocess:

In [45]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from jetutils.definitions import DATADIR, KAPPA, compute
from jetutils.data import standardize, flatten_by, extract
import intake
import numpy as np
import xarray as xr
from pathlib import Path
from dask.diagnostics import ProgressBar

# --- configuration -----------------------------------------------------------
varname = "PRECL"
component = "atm"  # "lnd" for land variables like RAIN, "atm" for atmospheric variables like wind, and "ocn" for ocean variables
forcing_variant = "cmip6"  # other option is "smbb", which stands for "SMoothed Biomass Burning"
out_path = Path(DATADIR, "CESM2", varname)
minlon, maxlon, minlat, maxlat = None, None, 0, 90
levels = None
years = {
    "past": np.arange(1970, 2010),
    "future": np.arange(2060, 2100),
}

# --- open the catalog and pick the two experiments -----------------------------
col_url = (
    "https://ncar-cesm2-lens.s3-us-west-2.amazonaws.com/catalogs/aws-cesm2-le.json"
)
catalog = intake.open_esm_datastore(col_url)

catalog_subset = catalog.search(variable=varname, frequency='daily', forcing_variant=forcing_variant)
dsets = catalog_subset.to_dataset_dict(storage_options={'anon':True})

ds_past = dsets[f"{component}.historical.daily.{forcing_variant}"]
ds_future = dsets[f"{component}.ssp370.daily.{forcing_variant}"]


def _subset_cesm(ds, keep_years):
    """Standardize ``ds``, drop ``time_bnds`` and keep the configured years and lon/lat box."""
    return (
        standardize(ds)
        .reset_coords("time_bnds", drop=True)
        .squeeze()
        .isel(time=np.isin(ds.time.dt.year, keep_years))
        .sel(lon=slice(minlon, maxlon))
        .sel(lat=slice(minlat, maxlat))
    )


def _write_zarr(ds, experiment):
    """Write ``ds`` to ``out_path/<experiment>/ds.zarr`` (overwriting) under a progress bar."""
    opath = out_path.joinpath(experiment)
    opath.mkdir(parents=True, exist_ok=True)
    # A dedicated loop name keeps the configured `varname` from being overwritten.
    for data_var in ds.data_vars:
        ds[data_var] = ds[data_var].drop_encoding()
    saved = ds.to_zarr(opath.joinpath("ds.zarr"), compute=False, mode="w")
    with ProgressBar():
        saved.compute()


ds_past_ns = _subset_cesm(ds_past, years["past"])
ds_future_ns = _subset_cesm(ds_future, years["future"])
if levels is not None and "lev" in ds_past_ns.dims:
    ds_past_ns = ds_past_ns.isel(lev=levels)
    ds_future_ns = ds_future_ns.isel(lev=levels)

_write_zarr(ds_past_ns, "historical")
_write_zarr(ds_future_ns, "ssp370")
#     ds_past_ns = ds_past_ns.load()
# ds_past_ns.to_netcdf(out_path.joinpath(out_name_past))
# del ds_past_ns # free up memory

# with ProgressBar():
#     ds_future_ns = ds_future_ns.load()
# ds_future_ns.to_netcdf(out_path.joinpath(out_name_future))


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency.forcing_variant'


[##################                      ] | 45% Completed | 10m 57ss


KeyboardInterrupt: 

### new cesm zarrification

In [3]:
# Scratch cell for converting per-member CESM2 netCDF files into a single zarr store.
basepath = Path("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/high_wind/ssp370")
paths = list(basepath.glob("*.nc"))
# File stems are split on "-": the first part is used as member, the second as year.
names = [path.stem.split("-") for path in paths]
members = [name[0] for name in names]
years = [name[1] for name in names]
for i, member in enumerate(tqdm(np.unique(members))):
    # All files of this member are opened as one dataset.
    da = xr.open_mfdataset(basepath.joinpath(f"{member}-*.nc").as_posix())
    # The first member would create the store; later ones would be appended along "member".
    kwargs = {"mode": "w"} if i == 0 else {"mode": "a", "append_dim": "member"}
    da["member"] = da["member"].astype("<U15")
    da = da.expand_dims("member").copy(deep=True)
    # NOTE(review): the loop stops after the first member and the write below is disabled.
    break
    # da.to_zarr(basepath.joinpath("ds.zarr"), **kwargs)

  0%|          | 0/50 [00:03<?, ?it/s]
