# ERA5 quantiles


Task:
1. Read data from [ERA5 PDS Zarr](https://registry.opendata.aws/ecmwf-era5/) dataset using `fsspec` and `open_mfdataset`
2. Calculate `q=0.9` quantile for each year.

Challenge:
1. The quantile can only be calculated with all the data for a year in a single chunk, so we must rechunk to an annual frequency.
2. The input dataset for each variable has dimensions `time: 392256, lat: 721, lon: 1440` with chunk sizes `(372, 150, 150)`.
3. We need to rechunk from `(372, 150, 150)` to approximately `(8784, 75, 75)`, where 8784 is the number of hours in a leap year.
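
The chunk shapes in the list above translate into memory roughly as follows. This is a back-of-the-envelope sketch in plain Python, assuming 4-byte `float32` values; the helper name `chunk_mib` is made up for illustration.

```python
from math import prod

BYTES_PER_VALUE = 4  # assumption: float32 data


def chunk_mib(shape):
    """Size of one chunk of the given shape in MiB."""
    return prod(shape) * BYTES_PER_VALUE / 2**20


source = (372, 150, 150)
target = (8784, 75, 75)

print(f"source chunk: {chunk_mib(source):.2f} MiB")
print(f"target chunk: {chunk_mib(target):.2f} MiB")

# Spatially, each 150 x 150 source tile is split across this many 75 x 75 tiles.
spatial_split = (source[1] // target[1]) * (source[2] // target[2])
print("target tiles per source tile:", spatial_split)
```

A full leap-year target chunk is roughly six times larger than a source chunk, and spatially each source chunk feeds four different target tiles.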

Outcome:
Dask attempts to load all the data and fails.

In [1]:
%load_ext watermark

import itertools

import coiled
import dask
import distributed
import flox.xarray
import fsspec
import numpy as np
import pandas as pd
import xarray as xr

%watermark -iv

distributed: 2023.10.1
coiled     : 0.9.34
pandas     : 2.1.1
flox       : 0.8.1
dask       : 2023.10.1
fsspec     : 2023.9.2
xarray     : 2023.10.1
numpy      : 1.24.4



## Setup to read data

In [2]:
fs = fsspec.filesystem("s3")
prefix = "s3://era5-pds/zarr"

In [3]:
variables = [
    store.split("/")[-1] for store in fs.glob("era5-pds/zarr/1979/01/data/*.zarr")
]
years = [path.split("/")[-1] for path in fs.glob("era5-pds/zarr/*")]
months = [f"{m:02d}" for m in range(1, 13)]
last_months = [path.split("/")[-1] for path in fs.glob(f"era5-pds/zarr/{years[-1]}/*")]

Reading the precipitation variable is currently failing, so for now just work with any variable.

In [4]:
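# Intended variable (currently fails to read):
var = ("precipitation_amount_1hour_Accumulation.zarr",)
# Overridden for now with the first available variable:
var = (variables[0],)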

In [5]:
all_stores = [
    "/".join(t)
    for t in itertools.chain(
        itertools.product((prefix,), years[:-1], months, ("data",), var),
        itertools.product((prefix,), years[-2:], last_months, ("data",), var),
    )
]
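
The path construction above can be checked offline. The sketch below reuses the same `itertools` pattern with toy values (the years, months and variable name are made up, not the real bucket listing) and compares `years[-2:]`, as written above, with `years[-1:]` for the trailing partial year.

```python
import itertools
from collections import Counter

# Toy stand-ins for the globbed values (assumptions, not the real listing).
prefix = "s3://era5-pds/zarr"
years = ["2020", "2021", "2022"]
months = [f"{m:02d}" for m in range(1, 13)]
last_months = ["01", "02", "03"]  # months available in the final, partial year
var = ("some_variable.zarr",)


def build_stores(tail_years):
    """Full months for all but the last year, then `last_months` for `tail_years`."""
    return [
        "/".join(t)
        for t in itertools.chain(
            itertools.product((prefix,), years[:-1], months, ("data",), var),
            itertools.product((prefix,), tail_years, last_months, ("data",), var),
        )
    ]


as_written = build_stores(years[-2:])
tail_only = build_stores(years[-1:])

# Paths that appear more than once in each list.
dup_as_written = sorted(p for p, n in Counter(as_written).items() if n > 1)
dup_tail_only = sorted(p for p, n in Counter(tail_only).items() if n > 1)

print(len(as_written), "stores,", len(dup_as_written), "listed twice")
print(len(tail_only), "stores,", len(dup_tail_only), "listed twice")
```

With `years[-2:]`, the second-to-last year is covered by both products, so its first `len(last_months)` months are listed twice.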

In [12]:
ds = xr.open_zarr(all_stores[2])
display(ds)


def preprocess(ds):
    """Edit the dataset so it combines nicely."""
    (time_dim,) = [dim for dim in ds.dims if "time" in dim]
    ds = ds.rename({time_dim: "time"})
    bounds_var = [var for var in ds.data_vars if "bounds" in var]
    ds = ds.drop_vars(bounds_var)
    return ds


preprocess(ds)

| | Array | Chunk |
|---|---|---|
| Bytes | 2.88 GiB | 31.93 MiB |
| Shape | (744, 721, 1440) | (372, 150, 150) |
| Dask graph | 100 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |


| | Array | Chunk |
|---|---|---|
| Bytes | 2.88 GiB | 31.93 MiB |
| Shape | (744, 721, 1440) | (372, 150, 150) |
| Dask graph | 100 chunks in 2 graph layers | |
| Data type | float32 numpy.ndarray | |
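
The single-element unpacking in `preprocess` doubles as a check: it fails unless exactly one dimension name contains `"time"`. A minimal sketch of that behaviour on plain lists, with hypothetical dimension names such as `time0` and a helper `pick_time_dim` that is not part of the notebook:

```python
def pick_time_dim(dims):
    """Return the single dimension whose name contains 'time'."""
    (time_dim,) = [dim for dim in dims if "time" in dim]
    return time_dim


print(pick_time_dim(["lat", "lon", "time0"]))

# Zero matches (or more than one) makes the unpacking raise ValueError.
try:
    pick_time_dim(["lat", "lon"])
except ValueError as err:
    print("unpacking failed:", err)
```

Failing loudly here is arguably preferable to silently renaming the wrong dimension before the stores are concatenated.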


## Setup cluster

In [8]:
cluster = coiled.Cluster()

Using adaptive scaling. To manually control the size of your cluster, use n_workers=.

2023-11-03 23:20:31,259 - distributed.deploy.adaptive - INFO - Adaptive scaling started: minimum=4 maximum=20


In [9]:
import distributed

client = distributed.Client(cluster)

In [10]:
client

Client: connection method "Cluster object", cluster type `coiled.Cluster`
Dashboard: https://cluster-alorb.dask.host/l0-wizWvisraylQO/status
Cluster: 2 workers, 8 threads in total, 29.69 GiB memory in total

| Component | Comm | Dashboard | Threads | Memory | Nanny | Local directory |
|---|---|---|---|---|---|---|
| Scheduler | tls://10.0.175.59:8786 | http://10.0.175.59:8787/status | 8 (total) | 29.69 GiB (total) | | |
| Worker | tls://10.0.170.222:38011 | http://10.0.170.222:8787/status | 4 | 14.83 GiB | tls://10.0.170.222:44523 | /scratch/dask-scratch-space/worker-wit5r84a |
| Worker | tls://10.0.173.51:38411 | http://10.0.173.51:8787/status | 4 | 14.85 GiB | tls://10.0.173.51:34363 | /scratch/dask-scratch-space/worker-781bdgk5 |


## Read the data in parallel

In [13]:
ds = xr.open_mfdataset(
    all_stores,
    engine="zarr",
    combine="nested",
    concat_dim="time",
    preprocess=preprocess,
    join="override",
    parallel=True,
)
ds

| | Array | Chunk |
|---|---|---|
| Bytes | 1.51 TiB | 31.93 MiB |
| Shape | (398808, 721, 1440) | (372, 150, 150) |
| Dask graph | 54600 chunks in 1075 graph layers | |
| Data type | float32 numpy.ndarray | |
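
The figures in this repr can be reproduced with simple arithmetic. A small sketch, assuming 4 bytes per `float32` value and taking the shape, chunk shape and chunk count directly from the output above:

```python
from math import ceil, prod

shape = (398808, 721, 1440)  # (time, lat, lon) from the repr
chunks = (372, 150, 150)
reported_chunks = 54600

# Total size in TiB for float32 data.
total_tib = prod(shape) * 4 / 2**40

# Number of spatial tiles, counting the partial tile at the lat edge.
spatial_tiles = ceil(shape[1] / chunks[1]) * ceil(shape[2] / chunks[2])

# Implied number of chunks along time.
time_chunks = reported_chunks // spatial_tiles

print(f"{total_tib:.2f} TiB, {spatial_tiles} spatial tiles, {time_chunks} time chunks")
```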


## Rechunk to annual frequency

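Something is wrong in 2022: there are too many timestamps. The 2022 chunk below has 15312 steps, which is 8760 + 6552, and 6552 is also the length of the partial final year. This suggests the store list includes the `last_months` stores for 2022 a second time, probably because it uses `years[-2:]` where `years[-1:]` was intended.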

In [14]:
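def rechunk_to_frequency(ds, **kwargs):
    """Return chunk sizes that align each dimension with a pandas frequency.

    Each keyword maps a dimension name to a resampling frequency, e.g. time="A".
    """
    newchunks = {}
    for dim, freq in kwargs.items():
        # Replace the coordinate values with ones, then sum within each
        # resampling period to count the points (the chunk length) per period.
        newchunks[dim] = tuple(
            ds[dim]
            .copy(data=np.ones(ds[dim].shape, dtype=np.int64))
            .to_dataframe()
            .resample(freq)
            .sum()[dim]
            .values
        )
    return newchunks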


newchunks = rechunk_to_frequency(ds, time="A")
print(newchunks)
rechunked = ds.chunk(**newchunks, lat=75, lon=75)
rechunked

{'time': (8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 8760, 8760, 8784, 8760, 15312, 6552)}
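
The printed chunk sizes can be sanity-checked against the calendar using only the standard library. The sketch assumes the 45 chunks correspond to the years 1979 to 2023 (an assumption based on the first store path) and takes the last two chunk sizes and the total length from the output in this notebook.

```python
import calendar


def hours_in_year(year):
    """Number of hourly time steps in a complete calendar year."""
    return (366 if calendar.isleap(year) else 365) * 24


# Values copied from the notebook output.
observed_2022 = 15312
observed_2023 = 6552  # partial final year
observed_total = 398808

# Assumption: complete years run from 1979 to 2022.
expected_full_years = sum(hours_in_year(y) for y in range(1979, 2023))
expected_total = expected_full_years + observed_2023

print("expected 2022 length:", hours_in_year(2022))
print("excess in 2022:", observed_2022 - hours_in_year(2022))
print("expected total:", expected_total)
print("excess overall:", observed_total - expected_total)
```

Under these assumptions the excess in 2022 equals the length of the partial final year, and the expected total equals the `time: 392256` quoted in the task description.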


| | Array | Chunk |
|---|---|---|
| Bytes | 1.51 TiB | 328.56 MiB |
| Shape | (398808, 721, 1440) | (15312, 75, 75) |
| Dask graph | 9000 chunks in 1077 graph layers | |
| Data type | float32 numpy.ndarray | |
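
The chunk count and largest chunk size reported after rechunking follow from the chunk grid. A short arithmetic sketch, assuming `float32` values and the 45 annual time chunks printed above:

```python
from math import ceil

n_time = 45                # one chunk per entry in the printed tuple
n_lat = ceil(721 / 75)     # includes the partial tile at the edge
n_lon = ceil(1440 / 75)    # includes the partial tile at the edge

n_chunks = n_time * n_lat * n_lon

# The largest chunk is the oversized 2022 block of 15312 time steps.
largest_mib = 15312 * 75 * 75 * 4 / 2**20

print(n_chunks, "chunks; largest chunk", f"{largest_mib:.2f} MiB")
```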


## Calculate quantile per year

In [18]:
result = flox.xarray.xarray_reduce(
    rechunked,
    rechunked.time.dt.year,
    func="quantile",
    skipna=False,
    q=0.9,
    method="blockwise",
)
result

| | Array | Chunk |
|---|---|---|
| Bytes | 356.45 MiB | 43.95 kiB |
| Shape | (45, 721, 1440) | (1, 75, 75) |
| Dask graph | 9000 chunks in 1082 graph layers | |
| Data type | float64 numpy.ndarray | |
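
To illustrate why each year needs to sit in a single chunk, here is a toy version of the per-block computation in plain NumPy rather than flox. It is only a sketch of the idea: the array, the two "years" of lengths 4 and 6, and the variable names are all made up.

```python
import numpy as np

# Toy data: 10 time steps on a 2 x 3 grid, with values equal to the time index.
data = np.arange(10, dtype=np.float64)[:, None, None] * np.ones((1, 2, 3))

# Hypothetical annual chunk lengths along time.
time_chunks = (4, 6)

# Split along time at the chunk boundaries so each block holds one whole year.
edges = np.cumsum(time_chunks)[:-1]
blocks = np.split(data, edges, axis=0)

# Each block is reduced independently: one quantile field per year.
per_year = np.stack([np.quantile(block, 0.9, axis=0) for block in blocks])

print(per_year.shape)
print(per_year[:, 0, 0])
```

Because every block already contains a complete year, no data has to move between blocks during the reduction, which is what makes the expensive step the rechunking rather than the quantile itself.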


## Compute

This computes only the first two years. Delete the `.isel` to run the full calculation.

In [None]:
result.isel(year=slice(2)).compute()