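This notebook prepares the model input datasets from ERA5 (read from the public ARCO ERA5 Zarr store on Google Cloud) and saves them to /glade/work/milesep/convective_outlook_ml (2 TB limit).
The ERA5 data are opened lazily and processed with dask, so the full dataset should never need to be loaded into memory; only the labelled PPH file is loaded eagerly.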

In [1]:
import warnings

import numpy as np
import pandas as pd
import xarray as xr

In [2]:
# --- Per-configuration settings, keyed by `detail` (chosen in the next cell) ---
# lat_dict / lon_dict : label slices selecting the spatial domain
# levels_dict         : pressure levels kept for pressure-level variables
# time_thin_dict      : keep every n-th time-of-day ('tod') step
# space_thin_dict     : keep every n-th latitude / longitude point
# risk_level_dict     : MAX_CAT categories that mark a day of interest
# pressure_var_dict / surface_var_dict : ERA5 variable names to keep

# Every configuration currently shares these lists. Each dict entry receives its
# own copy, so editing one configuration in place cannot change another.
_SHARED_LEVELS = [925, 850, 700, 500, 300]
_SHARED_PRESSURE_VARS = [
    "geopotential",
    "potential_vorticity",
    "specific_humidity",
    "temperature",
    "u_component_of_wind",
    "v_component_of_wind",
    "vertical_velocity",
]
_SHARED_SURFACE_VARS = [
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "2m_dewpoint_temperature",
    "2m_temperature",
    "geopotential_at_surface",
    "toa_incident_solar_radiation",
]

# Wide (US) domain. Latitude runs north -> south; longitude is written as
# 360 - <degrees west> (a 0-360 longitude axis is assumed).
_US_LAT = slice(50, 25)
_US_LON = slice(360 - 125, 360 - 66)

lat_dict = {'full': _US_LAT, 'small': slice(45, 30), 'slgt_small': _US_LAT}
lon_dict = {'full': _US_LON, 'small': slice(360 - 105, 360 - 85), 'slgt_small': _US_LON}

_DETAIL_NAMES = tuple(lat_dict)  # ('full', 'small', 'slgt_small')

levels_dict = {name: list(_SHARED_LEVELS) for name in _DETAIL_NAMES}

time_thin_dict = {'full': 1, 'small': 6, 'slgt_small': 6}
space_thin_dict = {'full': 1, 'small': 4, 'slgt_small': 4}

risk_level_dict = {
    'full': ['MDT', 'HIGH'],
    'small': ['MDT', 'HIGH'],
    'slgt_small': ['SLGT', 'ENH', 'MDT', 'HIGH'],
}

pressure_var_dict = {name: list(_SHARED_PRESSURE_VARS) for name in _DETAIL_NAMES}
surface_var_dict = {name: list(_SHARED_SURFACE_VARS) for name in _DETAIL_NAMES}

In [3]:
# Configuration to build: one of the keys of the dicts above
# ('full', 'small' or 'slgt_small').
detail = "slgt_small"

In [4]:
# Public ARCO ERA5 store (hourly, 0.25 degree, 37 pressure levels per its name),
# read anonymously. chunks=None opens the arrays lazily without creating dask arrays.
ERA5_ZARR_URL = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_ds = xr.open_zarr(
    ERA5_ZARR_URL,
    chunks=None,
    storage_options={'token': 'anon'},
)

Select only the days of interest (days whose MAX_CAT risk category is one of `risk_level_dict[detail]`)

In [5]:
# Labelled PPH data. xr.load_dataset reads it fully into memory and closes the file.
PPH_PATH = 'data/raw_data/labelled_pph.nc'
pph = xr.load_dataset(PPH_PATH)

In [6]:
# Dates dropped from the selection. NOTE(review): judging by the name these have
# missing data in one of the sources -- confirm which.
missing_dates = ['200204250000', '200208300000', '200304150000', '200304160000', '200306250000', '200307270000', '200307280000', '200312280000', '200404140000', '200408090000', '200905280000', '201105210000', '202005240000', '200510240000']

# PPH times appear to be fixed-width 'YYYYMMDDHHMM' strings (see the parse format
# below), so comparing them as strings orders them chronologically.
pph_times = pph['time']
is_selected = (
    pph['MAX_CAT'].isin(risk_level_dict[detail])  # MAX_CAT is a configured risk level
    & (pph_times > '200203310000')                # strictly after 2002-03-31 00:00
    & ~pph_times.isin(missing_dates)
)
dates_of_interest = pph_times[is_selected]

# Midnight timestamps of the selected days, used to pick ERA5 timesteps below.
selected_days = pd.to_datetime(dates_of_interest.values, format="%Y%m%d%H%M").normalize()

In [7]:
# Keep only the hourly timesteps whose calendar day is one of `selected_days`.
time_days = full_ds.time.dt.floor('D')
keep_mask = np.isin(time_days.values, selected_days.values)
full_ds = full_ds.isel(time=np.flatnonzero(keep_mask))

Select lat/lon domain

In [8]:
# Restrict to the configured domain using label-based slices.
# NOTE(review): lat_dict slices run high -> low, which assumes the store's latitude
# coordinate is descending; with an ascending axis this selection would be empty.
domain = {'latitude': lat_dict[detail], 'longitude': lon_dict[detail]}
ds = full_ds.sel(domain)

In [9]:
# Convert to dask, chunked along time only. Assuming hourly data with complete
# selected days, 240 steps = 10 whole days, so chunk boundaries coincide with day
# boundaries for the day/tod reshape below. Dimensions not listed are not split.
# NOTE(review): each chunk therefore spans every level and the whole lat/lon
# domain; the dask large-chunk warnings emitted by the unstack cell likely stem
# from this -- a smaller multiple of 24 would shrink them (at the cost of more
# chunks in the output store).
TIME_CHUNK = 240
ds = ds.chunk({'time': TIME_CHUNK})

Reshape time into separate day and time-of-day (tod) dimensions (the MDT+ / SLGT+ day selection was already applied above)

In [10]:
# Replace the flat hourly `time` axis with two dimensions: calendar day (`day`)
# and hour of day (`tod`), by indexing time with a (day, tod) MultiIndex and
# unstacking it. Any missing hour would become NaN after the unstack.
# NOTE(review): this step emits dask "large chunk" warnings from the reshape (see
# output); the warning text points at dask's `array.slicing.split_large_chunks`
# option if the chunk sizes need controlling.
time_coord = ds.time
ds = (
    ds.assign_coords(day=time_coord.dt.floor('D'), tod=time_coord.dt.hour)
    .set_index(time=['day', 'tod'])
    .unstack('time')
)

    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  return func(*args, **kwargs)
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  return func(*args, **kwargs)
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  return func(*args, **kwargs)
    >>> with dask.config.set(**{'array.slicing.split_la

In [11]:
# Subsample: every n-th latitude/longitude point and every n-th hour of day.
# thin() always keeps index 0, so e.g. a tod step of 6 keeps tod 0, 6, 12, 18
# (presuming tod holds hours 0-23 in order).
space_step = space_thin_dict[detail]
tod_step = time_thin_dict[detail]
ds = ds.thin(latitude=space_step, longitude=space_step, tod=tod_step)

Select only the desired variables (and, for pressure-level variables, the desired levels)

In [12]:
# Pressure-level variables, restricted to the configured levels.
# A requested variable that is not present in `ds` is still skipped (best effort),
# but it is now reported instead of silently shrinking the input set. The surface
# selection below, by contrast, raises KeyError for a missing variable.
wanted_levels = levels_dict[detail]
missing_pl_vars = [var for var in pressure_var_dict[detail] if var not in ds]
if missing_pl_vars:
    warnings.warn(f"Pressure-level variables not found and skipped: {missing_pl_vars}")

ds_pl = xr.Dataset({
    var: ds[var].sel(level=wanted_levels)
    for var in pressure_var_dict[detail]
    if var in ds
})

# Surface variables are taken as-is (no level selection).
ds_sfc = ds[surface_var_dict[detail]]

ds_final = xr.merge([ds_pl, ds_sfc])

Save inputs

In [None]:
# Write the final inputs to a Zarr store, overwriting any existing one, with
# consolidated metadata.
# NOTE(review): this cell was not executed in the saved notebook (In [None]);
# run the size estimate below first to check the output fits the quota.
OUTPUT_DIR = "/glade/work/milesep/convective_outlook_ml"
output_path = f"{OUTPUT_DIR}/inputs_raw_{detail}.zarr"
ds_final.to_zarr(output_path, mode="w", consolidated=True)

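Estimate the uncompressed size of the output (best run before the save cell above, to check it fits within the 2 TB limit)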

In [13]:
def _count_elements(shape):
    """Return the number of elements of an array with ``shape`` as an exact Python int."""
    count = 1
    for dim_len in shape:
        count *= int(dim_len)
    return count


def estimate_dataset_size_bytes(ds, include_coords=False):
    """Estimate the uncompressed size of a dataset, in bytes.

    Each variable contributes ``number of elements * dtype itemsize``. Only
    ``shape`` and ``dtype`` are read, so lazy and dask-backed variables are sized
    without loading or computing anything. On-disk compression is not accounted for.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to size. ``ds.data_vars`` (and ``ds.coords`` when requested) are
        inspected; each entry needs ``shape`` and ``dtype`` attributes.
    include_coords : bool, default False
        Also count coordinate variables. Off by default, which matches the
        data-variables-only estimate this function has always produced.

    Returns
    -------
    int
        Estimated size in bytes.
    """
    variables = list(ds.data_vars.values())
    if include_coords:
        variables.extend(ds.coords.values())
    # Exact Python-int arithmetic: np.prod(shape, dtype=np.int64) would silently
    # wrap past 2**63 - 1 elements, so it did not really "avoid overflow".
    return sum(
        _count_elements(var.shape) * np.dtype(var.dtype).itemsize
        for var in variables
    )


# Report the estimate in (decimal) gigabytes.
size_bytes = estimate_dataset_size_bytes(ds_final)
size_gb = size_bytes / 1e9
print(f"Estimated uncompressed size: {size_gb:.2f} GB")

Estimated uncompressed size: 4.17 GB
