In [1]:
import cdsapi
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import time
import os

In [2]:
# Notebook-wide CDS API client; safe_retrieve() issues its downloads through c.retrieve().
c = cdsapi.Client()

In [3]:
# specifying sizes and thinnings
#
# Every table below is keyed by the same configuration names; the value of
# `detail` picks one entry from each.

_DETAIL_NAMES = ('full', 'small', 'slgt_small')

# Latitude / longitude windows (start, stop). Longitudes are written as
# 360 - <degrees>, and the latitude start is larger than the stop.
lat_dict = dict(
    full=slice(50, 25),
    small=slice(45, 30),
    slgt_small=slice(50, 25),
)

lon_dict = dict(
    full=slice(360 - 125, 360 - 66),
    small=slice(360 - 105, 360 - 85),
    slgt_small=slice(360 - 125, 360 - 66),
)

# Same pressure levels for every configuration (one fresh list per key).
levels_dict = {name: [925, 850, 700, 500, 300] for name in _DETAIL_NAMES}

# Step between requested hours of the day.
time_thin_dict = dict(full=1, small=6, slgt_small=6)

# Stride handed to Dataset.thin() for latitude and longitude.
space_thin_dict = dict(full=1, small=4, slgt_small=4)

# MAX_CAT labels that mark a day as "of interest".
risk_level_dict = {
    'full': ['MDT', 'HIGH'],
    'small': ['MDT', 'HIGH'],
    'slgt_small': ['SLGT', 'ENH', 'MDT', 'HIGH'],
}

# Same pressure-level variables for every configuration (one fresh list per key).
pressure_var_dict = {
    name: [
        "geopotential",
        "potential_vorticity",
        "specific_humidity",
        "temperature",
        "u_component_of_wind",
        "v_component_of_wind",
        "vertical_velocity",
    ]
    for name in _DETAIL_NAMES
}

surface_var_dict = {
    'full': [
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "2m_dewpoint_temperature",
        "2m_temperature",
        "geopotential_at_surface",
        "toa_incident_solar_radiation",
    ],
    'small': [
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "2m_dewpoint_temperature",
        "2m_temperature",
        "geopotential_at_surface",
        "toa_incident_solar_radiation",
    ],
    'slgt_small': [
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "2m_dewpoint_temperature",
        "2m_temperature",
        # NOTE(review): spelled "geopotenital" here, unlike the two entries
        # above. The request-building cell matches on this exact string, so
        # do not correct the spelling here without updating that cell too.
        "geopotenital_at_surface",
        "toa_incident_solar_radiation",
    ],
}

In [4]:
# Active configuration: must be a key of the *_dict tables defined above
# ('full', 'small' or 'slgt_small'); every later cell indexes them with it.
detail = 'slgt_small'

In [5]:
# --- risk days
# Pick the timestamps whose MAX_CAT label belongs to the active risk levels,
# keep only those after the cut-off, and remove the explicitly excluded ones.
pph = xr.load_dataset("data/raw_data/labelled_pph.nc")

# Timestamps (YYYYMMDDHHMM) that are excluded from the selection.
missing_dates = [
    '200204250000', '200208300000', '200304150000', '200304160000',
    '200306250000', '200307270000', '200307280000', '200312280000',
    '200404140000', '200408090000', '200905280000', '201105210000',
    '202005240000', '200510240000'
]

_is_risk_day = pph["MAX_CAT"].isin(risk_level_dict[detail])
dates_of_interest = pph["time"][_is_risk_day]

_after_cutoff = dates_of_interest > "200203310000"
dates_of_interest = dates_of_interest[_after_cutoff]

_is_excluded = dates_of_interest.isin(missing_dates)
dates_of_interest = dates_of_interest[~_is_excluded]

# Midnight-normalised calendar days, and the distinct years they fall in.
selected_days = pd.to_datetime(dates_of_interest.values, format="%Y%m%d%H%M").normalize()

years = np.unique(selected_days.year)

In [6]:
# Directory that receives the monthly downloads (created if it is missing).
out_dir = Path("/glade/work/milesep/era5_cds")
out_dir.mkdir(parents=True, exist_ok=True)

# Paths of the monthly pressure-level / single-level files, populated by the download loop.
pl_files, sfc_files = [], []

# --- derive requested hours directly from thin factor
_hour_step = time_thin_dict[detail]
hours = [f"{hour:02d}:00" for hour in range(0, 24, _hour_step)]

In [7]:
def safe_retrieve(dataset, request, target, max_retries=5, wait=30, client=None):
    """
    Download a CDS request to ``target``, retrying on failure.

    The payload is written to ``<target>.part`` first and moved onto
    ``target`` only after the retrieve call has returned, so a failed or
    interrupted attempt never leaves a file at ``target``.

    Parameters
    ----------
    dataset : str
        Dataset name passed through to ``client.retrieve``.
    request : dict
        Request body passed through to ``client.retrieve``.
    target : str or pathlib.Path
        Final location of the downloaded file.
    max_retries : int, default 5
        Total number of attempts; must be at least 1.
    wait : float, default 30
        Base delay in seconds. The pause after failed attempt ``k`` is
        ``wait * k`` (the delay grows linearly, not exponentially).
    client : object, optional
        Any object with a ``retrieve(dataset, request, path)`` method.
        Defaults to the notebook-level client ``c``.

    Returns
    -------
    pathlib.Path
        ``target`` as a ``Path``.

    Raises
    ------
    ValueError
        If ``max_retries`` is smaller than 1 (no attempt would be made).
    RuntimeError
        If every attempt failed; chained to the last underlying error.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    target = Path(target)  # accept plain strings as well as Path objects
    tmp_target = target.with_name(target.name + ".part")

    if client is None:
        client = c  # fall back to the shared notebook client

    for attempt in range(1, max_retries + 1):
        try:
            client.retrieve(dataset, request, str(tmp_target))
            tmp_target.replace(target)  # move into place only after success
            print(f"✅ Downloaded: {target}")
            return target
        except Exception as e:
            print(f"⚠️ Attempt {attempt} failed for {target}: {e}")
            if tmp_target.exists():
                tmp_target.unlink()  # clean up bad partials
            if attempt < max_retries:
                sleep_time = wait * attempt
                print(f"Retrying in {sleep_time}s...")
                time.sleep(sleep_time)
            else:
                # Keep the original error attached as the explicit cause.
                raise RuntimeError(
                    f"Failed to download {target} after {max_retries} attempts."
                ) from e

In [8]:
# Bounding box for the request: [lat start, lon start, lat stop, lon stop].
# With the slices configured above this is north, west, south, east, and the
# longitudes are shifted by -360 from the values stored in lon_dict.
area = [lat_dict[detail].start,
        lon_dict[detail].start - 360,
        lat_dict[detail].stop,
        lon_dict[detail].stop - 360]

# Pressure levels are sent as strings.
levels = [str(level) for level in levels_dict[detail]]

all_sfc_vars = surface_var_dict[detail]
possible_accum_vars = ['toa_incident_solar_radiation']

# Names used in surface_var_dict that have to be requested under a different
# name. Both the correct spelling (used by 'full' and 'small') and the
# misspelled one (used by 'slgt_small') are translated; previously only the
# misspelling was handled, so the other configurations sent an unmapped name.
_cds_name_for = {
    'geopotential_at_surface': 'geopotential',
    'geopotenital_at_surface': 'geopotential',
}

# Instantaneous and accumulated variables are requested separately.
sfc_inst_vars = [_cds_name_for.get(x, x) for x in all_sfc_vars if x not in possible_accum_vars]
sfc_accum_vars = [x for x in all_sfc_vars if x in possible_accum_vars]

In [9]:
def _ensure_downloaded(dataset, variables, target, request_fields, collected):
    """Make sure ``target`` is a readable download and record it in ``collected``.

    An existing file is opened once with xarray as a sanity check: if that
    works the download is skipped, otherwise the file is deleted and fetched
    again. Nothing is requested when ``variables`` is empty.

    Parameters
    ----------
    dataset : str
        CDS dataset name.
    variables : list of str
        Variable names to request.
    target : pathlib.Path
        Destination file.
    request_fields : dict
        Remaining request keys (dates, hours, area, levels, ...).
    collected : list
        List that receives ``target`` once the file is known to be present.
    """
    if target.exists():
        try:
            xr.open_dataset(target).close()
        except Exception:
            # Unreadable file: treat it as a broken earlier download.
            print(f"Corrupt file detected, redownloading: {target}")
            target.unlink()
        else:
            print(f"Skipping (exists): {target}")
            collected.append(target)
            return

    if not variables:
        return  # nothing to ask for in this group

    safe_retrieve(
        dataset,
        {
            "product_type": "reanalysis",
            "format": "netcdf",
            "variable": variables,
            **request_fields,
        },
        target,
    )
    collected.append(target)


# Start from empty lists so that re-running this cell does not register
# every monthly file a second time.
pl_files = []
sfc_files = []

for year in years:
    days_this_year = selected_days[selected_days.year == year]

    for month in sorted(set(days_this_year.month)):
        days_this_month = days_this_year[days_this_year.month == month]
        days = sorted({f"{d.day:02d}" for d in days_this_month})
        month_str = f"{month:02d}"

        print(year, month, len(days_this_month))

        # Request keys shared by all three downloads of this month.
        when_where = {
            "year": str(year),
            "month": month_str,
            "day": days,
            "time": hours,
            "area": area,
        }

        # ------------------ Pressure levels ------------------
        pl_file = out_dir / f"era5_pl_{year}_{month_str}.nc"
        _ensure_downloaded(
            "reanalysis-era5-pressure-levels",
            pressure_var_dict[detail],
            pl_file,
            {"pressure_level": levels, **when_where},
            pl_files,
        )

        # ------------------ Single levels: instantaneous ------------------
        sfc_inst_file = out_dir / f"era5_sfc_inst_{year}_{month_str}.nc"
        _ensure_downloaded(
            "reanalysis-era5-single-levels",
            sfc_inst_vars,
            sfc_inst_file,
            when_where,
            sfc_files,
        )

        # ------------------ Single levels: accumulated ------------------
        sfc_accum_file = out_dir / f"era5_sfc_accum_{year}_{month_str}.nc"
        _ensure_downloaded(
            "reanalysis-era5-single-levels",
            sfc_accum_vars,
            sfc_accum_file,
            when_where,
            sfc_files,
        )

2002 4 23
Skipping (exists): /glade/work/milesep/era5_cds/era5_pl_2002_04.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_inst_2002_04.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_accum_2002_04.nc
2002 5 28
Skipping (exists): /glade/work/milesep/era5_cds/era5_pl_2002_05.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_inst_2002_05.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_accum_2002_05.nc
2002 6 30
Skipping (exists): /glade/work/milesep/era5_cds/era5_pl_2002_06.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_inst_2002_06.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_accum_2002_06.nc
2002 7 30
Skipping (exists): /glade/work/milesep/era5_cds/era5_pl_2002_07.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_inst_2002_07.nc
Skipping (exists): /glade/work/milesep/era5_cds/era5_sfc_accum_2002_07.nc
2002 8 30
Skipping (exists): /glade/work/milesep/era5_cds/era5_pl_2002_08.nc
Skipping (exists): /glade/w

In [10]:
# Open each group of monthly files as one dataset, combined by coordinates.
pl = xr.open_mfdataset(pl_files, combine="by_coords")

sfc = xr.open_mfdataset(sfc_files, combine="by_coords")
# The single-level "z" gets its own name before the two groups are merged.
sfc = sfc.rename({"z": "z_sfc"})

ds = xr.merge([pl, sfc])

# Datasets that carry "valid_time" get it renamed to "time" and lose the
# "number" / "expver" entries when those are present.
if "valid_time" in ds:
    ds = (
        ds.rename(valid_time="time")
        .drop_vars(["number", "expver"], errors="ignore")
    )

In [None]:
# subset exactly the selected days again (to be safe)
time_days = ds.time.dt.floor("D")
_on_selected_day = np.isin(time_days, selected_days)
ds = ds.sel(time=ds.time[_on_selected_day])

# Replace the single time axis by a (day, tod) pair: attach both as
# coordinates, index time by them, unstack, then drop any leftover "time".
ds = (
    ds.assign_coords(day=ds.time.dt.floor("D"), tod=ds.time.dt.hour)
    .set_index(time=["day", "tod"])
    .unstack("time")
    .drop_vars(["time"], errors="ignore")
)

# Keep every n-th grid point in both horizontal directions.
_stride = space_thin_dict[detail]
ds = ds.thin({"latitude": _stride, "longitude": _stride})

# Rechunk so that day has uniform chunks
_chunking = {"day": 30, "latitude": -1, "longitude": -1, "tod": -1}
ds = ds.chunk(_chunking)

# Write the result, overwriting any existing store at this path.
_zarr_path = f"/glade/work/milesep/convective_outlook_ml/inputs_raw_{detail}_cds.zarr"
ds.to_zarr(_zarr_path, mode="w", consolidated=True)


In [None]:
def estimate_dataset_size_bytes(ds):
    """Return the summed in-memory size, in bytes, of all data variables in ``ds``.

    Variables that report chunks contribute their ``nbytes``; for the others
    the size is the element count times the dtype's item size.
    """
    def _bytes_of(var):
        if var.chunks is None:
            # 64-bit product so large shapes cannot overflow
            count = np.prod(var.shape, dtype=np.int64)
            return int(count * np.dtype(var.dtype).itemsize)
        return var.nbytes

    return sum(_bytes_of(var) for var in ds.data_vars.values())


# Report the size of the dataset currently bound to `ds`, in decimal gigabytes (1e9 bytes).
size_bytes = estimate_dataset_size_bytes(ds)
print(f"Estimated uncompressed size: {size_bytes / 1e9:.2f} GB")