This notebook prepares the input datasets (loaded from ERA5) and saves them to /work/milesep/convective_outlook_ml (2 TB limit).
It is designed to work without ever loading the full dataset into memory.

In [1]:
import numpy as np
import xarray as xr
import pandas as pd

In [2]:
# Open the public ARCO-ERA5 store anonymously (no credentials needed).
# chunks=None keeps every variable as a lazily-indexed, non-dask array, so the
# subsetting cells below stay cheap; dask chunking is applied later on the subset.
full_ds = xr.open_zarr(
    'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3',
    storage_options={'token': 'anon'},
    chunks=None,
)

Select just the US

In [3]:
# select just US
ds = full_ds.sel(latitude = slice(50, 25), longitude = slice(360-125, 360-66))

Select just MDT+ days (or, alternatively, SLGT+ days)

In [4]:
# Load the labelled outlook file; the next cell reads its 'time' labels and
# 'MAX_CAT' category from it.
# NOTE(review): relative path — assumes the notebook is run from the repo root.
pph = xr.load_dataset('data/raw_data/labelled_pph.nc')

In [5]:
missing_dates = ['200204250000', '200208300000', '200304150000', '200304160000', '200306250000', '200307270000', '200307280000', '200312280000', '200404140000', '200408090000', '200905280000', '201105210000', '202005240000', '200510240000']
dates_of_interest = pph['time'][pph['MAX_CAT'].isin(['MDT', 'HIGH'])]
dates_of_interest = dates_of_interest[dates_of_interest > '200203310000']
dates_of_interest = dates_of_interest[~(dates_of_interest.isin(missing_dates))]

In [6]:
# select dates in ds_subset
dates = pd.to_datetime(dates_of_interest.str.slice(0, 8).values, format='%Y%m%d')

# Get the dates of each time value, dropping hours/minutes/seconds
time_dates = ds['time'].dt.floor('D')

# Subset to only the days in the date_list
ds = ds.sel(time=ds['time'].where(time_dates.isin(dates), drop=True))

Select just the desired variables

In [7]:
# Pressure levels to keep; these values must match the store's `level` coordinate.
desired_levels = [925, 850, 700, 500, 300]
pressure_vars = ["geopotential", "potential_vorticity", "specific_humidity", "temperature", "u_component_of_wind", "v_component_of_wind", "vertical_velocity"]
surface_vars = ["10m_u_component_of_wind", "10m_v_component_of_wind", "2m_dewpoint_temperature", "2m_temperature", "geopotential_at_surface", "toa_incident_solar_radiation"]

# Pressure-level variables absent from the store are skipped (best effort),
# but reported so they do not vanish silently. Surface variables are selected
# strictly and raise KeyError if any is missing.
missing_pressure_vars = [var for var in pressure_vars if var not in ds]
if missing_pressure_vars:
    print(f"Skipping pressure-level variables not found in the store: {missing_pressure_vars}")

# Subset each available pressure-level variable to the desired levels.
ds_pl = xr.Dataset({
    var: ds[var].sel(level=desired_levels)
    for var in pressure_vars
    if var in ds
})
ds_sfc = ds[surface_vars]

ds_final = xr.merge([ds_pl, ds_sfc])

In [8]:
# Rechunk into dask-backed blocks of 24 hourly steps x 40 x 40 grid points
# (all levels in one chunk). Dataset.chunk returns a NEW dataset instead of
# modifying ds_final, so the result must be assigned. Otherwise later cells
# (e.g. the dtype cast) keep operating on the unchunked, lazily-indexed arrays.
ds_final = ds_final.chunk({'time': 24, 'latitude': 40, 'longitude': 40})
ds_final

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.85 GiB 750.00 kiB Shape (8640, 5, 101, 237) (24, 5, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",8640  1  237  101  5,

Unnamed: 0,Array,Chunk
Bytes,3.85 GiB,750.00 kiB
Shape,"(8640, 5, 101, 237)","(24, 5, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 788.94 MiB 150.00 kiB Shape (8640, 101, 237) (24, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",237  101  8640,

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 788.94 MiB 150.00 kiB Shape (8640, 101, 237) (24, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",237  101  8640,

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 788.94 MiB 150.00 kiB Shape (8640, 101, 237) (24, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",237  101  8640,

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 788.94 MiB 150.00 kiB Shape (8640, 101, 237) (24, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",237  101  8640,

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 788.94 MiB 150.00 kiB Shape (8640, 101, 237) (24, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",237  101  8640,

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 788.94 MiB 150.00 kiB Shape (8640, 101, 237) (24, 40, 40) Dask graph 6480 chunks in 2 graph layers Data type float32 numpy.ndarray",237  101  8640,

Unnamed: 0,Array,Chunk
Bytes,788.94 MiB,150.00 kiB
Shape,"(8640, 101, 237)","(24, 40, 40)"
Dask graph,6480 chunks in 2 graph layers,6480 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [12]:
# Downcast float32 variables to float16 to halve storage, except variables whose
# values fall outside float16's usable range. The float16 maximum is about 65504,
# and magnitudes below about 6.1e-5 become low-precision subnormals.
# These variables stay float32:
#   - geopotential: roughly 9e4 m^2 s^-2 at 300 hPa, which overflows to inf
#   - toa_incident_solar_radiation: hourly accumulation in J m^-2 (order 1e6), overflows
#   - potential_vorticity: typical magnitudes around 1e-6, subnormal in float16
# NOTE(review): the cast is only lazy when the variables are dask-backed. On the
# unchunked arrays from open_zarr(chunks=None), astype reads the data into memory.
# NOTE(review): the float32 size estimate below is ~34 GB, far under the 2 TB
# quota, so this downcast may not be needed at all.
FLOAT16_UNSAFE_VARS = {"geopotential", "toa_incident_solar_radiation", "potential_vorticity"}

# Iterate over a snapshot of the names because the loop reassigns variables.
for var in list(ds_final.data_vars):
    if ds_final[var].dtype == 'float32' and var not in FLOAT16_UNSAFE_VARS:
        ds_final[var] = ds_final[var].astype('float16')

KeyboardInterrupt: 

Estimate the uncompressed size

In [9]:
def estimate_dataset_size_bytes(ds):
    """Return the total uncompressed size, in bytes, of ``ds``'s data variables.

    Coordinates are not counted. Chunked (dask-backed) variables report their
    own ``nbytes``. For unchunked variables, the size is computed from shape
    and dtype only, so nothing is read from the backing store.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset whose ``data_vars`` are summed.

    Returns
    -------
    int
        Sum over data variables of element count times item size.
    """
    def _bytes_of(arr):
        # Chunked arrays: trust the array's own byte count.
        if arr.chunks is not None:
            return arr.nbytes
        # Unchunked: int64 element count (avoids overflow) times item size.
        count = np.prod(arr.shape, dtype=np.int64)
        return int(count * np.dtype(arr.dtype).itemsize)

    return sum(_bytes_of(arr) for arr in ds.data_vars.values())

# Report the uncompressed size of ds_final's data variables in decimal gigabytes (1 GB = 1e9 bytes).
size_bytes = estimate_dataset_size_bytes(ds_final)
print("Estimated uncompressed size: {:.2f} GB".format(size_bytes / 1e9))

Estimated uncompressed size: 33.92 GB
