In [4]:
import datetime

import colormaps
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import xarray as xr
from jetutils.anyspell import get_persistent_jet_spells, mask_from_spells_pl, subset_around_onset
from jetutils.clustering import Experiment
from jetutils.data import DataHandler, open_da
from jetutils.definitions import (
    DATADIR,
    YEARS,
    PRETTIER_VARNAME,
    compute,
    get_region,
    infer_direction,
    polars_to_xarray,
    xarray_to_polars,
)
from jetutils.jet_finding import JetFindingExperiment, gather_normal_da_jets, iterate_over_year_maybe_member
from jetutils.plots import COLORS, Clusterplot, gather_normal_da_jets_wrapper, interp_jets_to_zero_one
from matplotlib.cm import ScalarMappable
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
from tqdm import tqdm

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# create jet relative climatologies

In [6]:
exp = JetFindingExperiment(DataHandler(f"{DATADIR}/ERA5/plev/high_wind/6H/results/1"))

Found config override file at  /storage/homefs/hb22g102/.jetutils.ini
Guessed N_WORKERS :  1
Guessed MEMORY_LIMIT :  153600


In [7]:
def create_jet_relative_clim(exp, da, suffix=""):
    all_jets_one_df = exp.find_jets()
    jets = all_jets_one_df.with_columns(pl.col("time").dt.round("1d"))
    jets = jets.with_columns(jets.group_by("time", maintain_order=True).agg(pl.col("jet ID").rle_id())["jet ID"].explode())
    indexer = iterate_over_year_maybe_member(jets, da)
    to_average = []
    for idx1, idx2 in tqdm(indexer, total=len(YEARS)):
        jets_ = jets.filter(*idx1)
        da_ = da.sel(**idx2)
        try:
            jets_with_interp = gather_normal_da_jets(jets_, da_, half_length=20)
        except (KeyError, ValueError):
            break
        varname = da_.name + "_interp"
        jets_with_interp = interp_jets_to_zero_one(jets_with_interp, [varname, "is_polar"])
        jets_with_interp = jets_with_interp.group_by("time", pl.col("is_polar") > 0.5, "norm_index", "n", maintain_order=True).agg(pl.col(varname).mean() )
        to_average.append(jets_with_interp)
    to_average = pl.concat(to_average)
    clim = to_average.group_by(pl.col("time").dt.ordinal_day().alias("dayofyear"), "is_polar", "norm_index", "n").agg(pl.col(varname).mean()).sort("dayofyear", "is_polar", "norm_index", "n")
    clim_ds = polars_to_xarray(clim, ["dayofyear", "is_polar", "n", "norm_index"])
    clim_ds.to_netcdf(exp.path.joinpath(f"{da.name}{suffix}_relative_clim.nc"))

In [None]:
da_T = open_da("ERA5", "surf", "t2m", "dailymean", "all", None, *get_region(exp.ds), "all", 'dayofyear', {'dayofyear': ('win', 15)}, None,)
da_T = compute(da_T)
create_jet_relative_clim(exp, da_T, "_anom")
del da_T
da_tp = open_da("ERA5", "surf", "tp", "dailysum", "all", None, *get_region(exp.ds), "all", 'dayofyear', {'dayofyear': ('win', 15)}, None,)
da_tp = compute(da_tp)
create_jet_relative_clim(exp, da_tp, "_anom")
del da_tp
da_apvs = open_da("ERA5", "thetalev", "apvs", "dailyany", "all", None, *get_region(exp.ds), "all", 'dayofyear', {'dayofyear': ('win', 15)}, None,)
da_apvs = compute(da_apvs)
create_jet_relative_clim(exp, da_apvs, "_anom")
del da_apvs
da_cpvs = open_da("ERA5", "thetalev", "cpvs", "dailyany", "all", None, *get_region(exp.ds), "all", 'dayofyear', {'dayofyear': ('win', 15)}, None,)
da_cpvs = compute(da_cpvs)
create_jet_relative_clim(exp, da_cpvs, "_anom")
del da_cpvs

100%|██████████| 64/64 [20:25<00:00, 19.14s/it]
100%|██████████| 64/64 [21:49<00:00, 20.46s/it]
 72%|███████▏  | 46/64 [16:45<06:40, 22.23s/it]

In [11]:
create_jet_relative_clim(exp, da_T, "_anom")
create_jet_relative_clim(exp, da_tp, "_anom")
create_jet_relative_clim(exp, da_apvs, "_anom")
create_jet_relative_clim(exp, da_cpvs, "_anom")

  0%|          | 0/64 [00:02<?, ?it/s]


ValueError: cannot concat empty list

In [6]:
da_cpvs = open_da(
    "ERA5", "thetalev", "cpvs", "dailyany", "all", None, -100, 60, 0, 90, "all",
)
da_cpvs = compute(da_cpvs)
create_jet_relative_clim(exp, da_cpvs)

100%|██████████| 64/64 [19:06<00:00, 17.91s/it]


In [4]:
da_t2m = open_da("ERA5", "surf", "t2m", "dailymean", "all")
create_jet_relative_clim(exp, da_t2m)

100%|██████████| 64/64 [16:37<00:00, 15.59s/it]


In [10]:
da_tp = open_da("ERA5", "surf", "tp", "dailysum", "all")
create_jet_relative_clim(exp, da_tp)

100%|██████████| 64/64 [15:14<00:00, 14.28s/it]


In [11]:
da_apvs = open_da("ERA5", "thetalev", "apvs", "dailymean", "all", levels=350)
create_jet_relative_clim(exp, da_apvs)

100%|██████████| 64/64 [21:16<00:00, 19.94s/it]


In [12]:
da_cpvs = open_da("ERA5", "thetalev", "cpvs", "dailymean", "all", levels=350)
create_jet_relative_clim(exp, da_cpvs)

100%|██████████| 64/64 [19:48<00:00, 18.56s/it]


# arco-era5 tests

In [None]:
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

base_ds = standardize(ar_full_37_1h["total_precipitation"].chunk("auto"))
base_ds = (
    base_ds
    .sel(
        lat=base_ds.lat >= 0,
        time=np.isin(base_ds.time.dt.year, YEARS)
    )
    .isel(lon=slice(None, None, 2), lat=slice(None, None, 2))
)

six_hourly = base_ds.resample(time="6h").sum()
daily = six_hourly.resample(time="1d").sum()
six_hourly = six_hourly * 4

base_path_1 = Path(f"{DATADIR}/ERA5/surf/tp/6H")
base_path_1.mkdir(exist_ok=True, parents=True)
base_path_2 = Path(f"{DATADIR}/ERA5/surf/tp/dailysum")
base_path_2.mkdir(exist_ok=True, parents=True)
for year in YEARS:
    opath_1 = base_path_1.joinpath(f"{year}.nc")
    opath_2 = base_path_2.joinpath(f"{year}.nc")
    if not opath_1.is_file():
        six_hourly_ = compute(six_hourly.sel(time=six_hourly.time.dt.year == year), progress_flag=True)
        six_hourly_.to_netcdf(opath_1)
    if not opath_2.is_file():
        daily_ = compute(daily.sel(time=daily.time.dt.year == year), progress_flag=True)
        daily_.to_netcdf(opath_2)
    print(f"Completed {year}")

In [17]:
six_hourly = base_ds.resample(time="6h").sum()
daily = six_hourly.resample(time="1d").sum()
six_hourly = six_hourly * 4

In [None]:
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

temp_full = (
    ar_full_37_1h["temperature"]
    .sel(
        time=ar_full_37_1h.time.dt.hour % 6 == 0,
        latitude=ar_full_37_1h.latitude >= 0,
        level=[175, 200, 225, 250, 300, 350],
    )
    .isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))
)

temp_full = standardize(temp_full)

In [None]:
ds = xr.open_zarr(
    "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
    chunks=None,
    storage_options=dict(token="anon"),
)
ar_full_37_1h = ds.sel(
    time=slice(ds.attrs["valid_time_start"], ds.attrs["valid_time_stop"])
)

temp_full = (
    ar_full_37_1h["temperature"]
    .sel(
        time=ar_full_37_1h.time.dt.hour % 6 == 0,
        latitude=ar_full_37_1h.latitude >= 0,
        level=[175, 200, 225, 250, 300, 350],
    )
    .isel(longitude=slice(None, None, 2), latitude=slice(None, None, 2))
)

temp_full = standardize(temp_full)

orig_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean")
base_path = Path(f"{DATADIR}/ERA5/plev/flat_wind/dailymean_2")
for year in tqdm(YEARS):
    for month in trange(1, 13, leave=False):
        month_str = str(month).zfill(2)
        opath = base_path.joinpath(f"{year}{month_str}.nc")
        if opath.is_file():
            continue
        ipath = orig_path.joinpath(f"{year}{month_str}.nc")
        ds = xr.open_dataset(ipath)
        this_temp = temp_full.sel(time=ds.time.values, lev=ds["lev"])
        this_temp = this_temp * (1000 / this_temp.lev) ** KAPPA
        this_temp = this_temp.reset_coords("lev", drop=True)
        ds["theta"] = compute(this_temp, progress_flag=True)
        ds.to_netcdf(opath)

# new pvs das: any() over levels

In [13]:
for year in tqdm(YEARS):
    opath = Path("/storage/workspaces/giub_meteo_impacts/ci01/ERA5/thetalev/apvs/dailyany", f"{year}.nc")
    if opath.is_file():
        continue
    da = open_da("ERA5", "thetalev", "apvs", "6H", [year], None, None, None, None, None, "all").astype(np.int8).any("lev").resample(time="1D").any()
    da = compute(da)
    to_netcdf(da, opath)

  0%|          | 0/64 [00:09<?, ?it/s]


KeyboardInterrupt: 

# CESM

### new download with urls

In [None]:
import datetime
from itertools import pairwise
from pathlib import Path
from jetstream_hugo.definitions import DATADIR, compute
from jetstream_hugo.data import standardize
import numpy as np 
import xarray as xr

experiment_dict = {
    "past": "BHISTcmip6",
    "future": "BSSP370cmip6",
}
yearbounds = {
    "past": np.arange(1960, 2021, 10),
    "future": np.arange(2045, 2106, 10),
}
yearbounds["past"][-1] = yearbounds["past"][-1] - 5
yearbounds["future"][-1] = yearbounds["future"][-1] - 4
timebounds = {key: [f"{year1}0101-{year2 - 1}1231" for year1, year2 in pairwise(val)] for key, val in yearbounds.items()}

members = [f"{year}.{str(number).zfill(3)}" for year, number in zip(range(1001, 1201, 20), range(1, 11))]
for startyear in [1231, 1251, 1281, 1301]:
    members.extend(f"{startyear}.{str(number).zfill(3)}" for number in range(1, 11))
    
members2 = [f"r{number}i{year}p1f1" for year, number in zip(range(1001, 1201, 20), range(1, 11))]
for startyear in [1231, 1251, 1281, 1301]:
    members2.extend(f"r{number}i{startyear}p1f1" for number in range(1, 11))
    
season = None
minlon = -180
maxlon = 180
minlat = 0
maxlat = 90
    
    
def get_url(varname: str, period: str, member: str, timebounds: str):
    experiment = experiment_dict[period]
    h = 6 if varname in ["U", "V", "T"] else 1

    return fr"https://tds.ucar.edu/thredds/fileServer/datazone/campaign/cgd/cesm/CESM2-LE/atm/proc/tseries/day_1/{varname}/b.e21.{experiment}.f09_g17.LE2-{member}.cam.h{h}.{varname}.{timebounds}.nc?api-token=ayhBFVYTOtGi2LM2cHDn6DjFCoKeCAqt69z8Ezt4#mode=bytes"

basepath = Path(f"{DATADIR}/CESM2/high_wind/ssp370")
var = "T"
period = "future"
for member1, member2 in zip(members, members2):
    opath = basepath.joinpath(f"{member2}.nc")
    if opath.is_file():
        print(member1, "already there")
        continue
    da = []
    for tb in timebounds["future"]:
        da.append(
            standardize(xr.open_dataset(
                get_url(var, period, member1,tb),            
                engine="h5netcdf"
            )[var])
        )
    da = xr.concat(da, "time")
    ds = xr.open_mfdataset(basepath.glob(f"{member2}-????.nc"))

    for coord in ["lon", "lat", "lev"]:
        da[coord] = da[coord].astype(np.float32)

    da["time"] = da.indexes["time"].to_datetimeindex(time_unit="us") + datetime.timedelta(hours=12)
    ds["time"] = ds.indexes["time"].to_datetimeindex(time_unit="us")
    da_ = da.sel(time=ds.time.values, lon=ds.lon.values, lat=ds.lat.values)
    da_ = compute(da_.sel(lev=ds["lev"]), progress_flag=True)
    da_.to_netcdf(opath)
    print(member1, "done")

## newnew merger script: download then postprocess:

In [17]:
os.environ["TMPDIR"]

'/scratch/local/17548418'

In [None]:
import datetime
import os
from itertools import pairwise
from pathlib import Path
from jetstream_hugo.definitions import DATADIR, compute
from jetstream_hugo.data import standardize
import numpy as np 
import xarray as xr
from tqdm import tqdm
from urllib.request import urlretrieve

experiment_dict = {
    "past": "BHISTcmip6",
    "future": "BSSP370cmip6",
}
yearbounds = {
    "past": np.arange(1960, 2021, 10),
    "future": np.arange(2045, 2106, 10),
}
yearbounds["past"][-1] = yearbounds["past"][-1] - 5
yearbounds["future"][-1] = yearbounds["future"][-1] - 4
timebounds = {key: [f"{year1}0101-{year2 - 1}1231" for year1, year2 in pairwise(val)] for key, val in yearbounds.items()}

members = [f"{year}.{str(number).zfill(3)}" for year, number in zip(range(1001, 1201, 20), range(1, 11))]
for startyear in [1231, 1251, 1281, 1301]:
    members.extend(f"{startyear}.{str(number).zfill(3)}" for number in range(1, 11))
    
members2 = [f"r{number}i{year}p1f1" for year, number in zip(range(1001, 1201, 20), range(1, 11))]
for startyear in [1231, 1251, 1281, 1301]:
    members2.extend(f"r{number}i{startyear}p1f1" for number in range(1, 11))
    
season = None
minlon = -180
maxlon = 180
minlat = 0
maxlat = 90
    
    
def get_url(varname: str, period: str, member: str, timebounds: str):
    experiment = experiment_dict[period]
    h = 6 if varname in ["U", "V", "T"] else 1

    return fr"https://tds.ucar.edu/thredds/fileServer/datazone/campaign/cgd/cesm/CESM2-LE/atm/proc/tseries/day_1/{varname}/b.e21.{experiment}.f09_g17.LE2-{member}.cam.h{h}.{varname}.{timebounds}.nc?api-token=ayhBFVYTOtGi2LM2cHDn6DjFCoKeCAqt69z8Ezt4#mode=bytes"


class DownloadProgressBar(tqdm):
    def update_to(self, b=1, bsize=1, tsize=None):
        if tsize is not None:
            self.total = tsize
        self.update(b * bsize - self.n)
        
        
def download_url(url, output_path):
    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
        urlretrieve(url, filename=output_path, reporthook=t.update_to)


basepath = Path(f"{DATADIR}/CESM2/high_wind/ssp370")
scratchdir = Path(os.environ["TMPDIR"], "tmp_T_cesm_downloads_hbanderi")
scratchdir.mkdir(exist_ok=True)
var = "T"
period = "future"
for member1, member2 in zip(members, members2):
    opath = basepath.joinpath(f"{member2}.nc")
    if opath.is_file():
        print(member1, "already there")
        continue
    scratchpaths = []
    for tb in timebounds["future"]:
        url = get_url(var, period, member1,tb)
        scratchpath = scratchdir.joinpath(url.split("/")[-1].split("?")[0])
        scratchpaths.append(scratchpath)
        if scratchpath.is_file():
            continue
        download_url(url, scratchpath)
    da = []
    for scratchpath in scratchpaths:
        da.append(
            standardize(xr.open_dataset(
                scratchpath,            
                engine="h5netcdf"
            )[var])
        )
    da = xr.concat(da, "time")
    ds = xr.open_mfdataset(basepath.glob(f"{member2}-????.nc"))

    for coord in ["lon", "lat", "lev"]:
        da[coord] = da[coord].astype(np.float32)

    da["time"] = da.indexes["time"].to_datetimeindex(time_unit="us") + datetime.timedelta(hours=12)
    ds["time"] = ds.indexes["time"].to_datetimeindex(time_unit="us")
    da_ = da.sel(time=ds.time.values, lon=ds.lon.values, lat=ds.lat.values)
    da_ = compute(da_.sel(lev=ds["lev"]), progress_flag=True)
    da_.to_netcdf(opath)
    for scratchpath in scratchpaths:
        os.remove(scratchpath)
    print(member1, "done")

In [12]:
opath

PosixPath('/scratch/tmp_T_cesm_downloads/b.e21.BSSP370cmip6.f09_g17.LE2-1061.004.cam.h6.T.20450101-20541231.nc')

### new cesm zarrification

In [3]:
basepath = Path("/storage/workspaces/giub_meteo_impacts/ci01/CESM2/high_wind/ssp370")
paths = list(basepath.glob("*.nc"))
names = [path.stem.split("-") for path in paths]
members = [name[0] for name in names]
years = [name[1] for name in names]
for i, member in enumerate(tqdm(np.unique(members))):
    da = xr.open_mfdataset(basepath.joinpath(f"{member}-*.nc").as_posix())
    kwargs = {"mode": "w"} if i == 0 else {"mode": "a", "append_dim": "member"}
    da["member"] = da["member"].astype("<U15")
    da = da.expand_dims("member").copy(deep=True)
    break
    # da.to_zarr(basepath.joinpath("ds.zarr"), **kwargs)

  0%|          | 0/50 [00:03<?, ?it/s]
