In [2]:
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
import xarray as xr
import numpy as np

## Current usage

Our current usage pattern in ocf-data-sampler is to open multiple zarrs per input source and concatenate them. We do this using the dask backend in xarray.

In [3]:
%%time

# Build the sampler from the config that opens multiple zarrs per input source (the current usage)
dataset_all = PVNetUKRegionalDataset("data_config_all.yaml")

CPU times: user 35.8 s, sys: 2.44 s, total: 38.2 s
Wall time: 36.4 s


... and sampling from it is also quite slow:

In [4]:
# Time drawing one sample; the random index is re-drawn on each of the 3 runs
%timeit -n 1 -r 3 dataset_all[np.random.randint(0, len(dataset_all))]

3.91 s ± 200 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


## Minimise dask graph complexity by avoiding multiple zarrs

We can improve on this by using only single zarrs for each input source. This makes the dask graph simpler and makes sampling faster.

In this case we limit ourselves to input zarrs covering 2022 only.

In [5]:
%%time 

# Same sampler, but each input source points at a single (2022-only) zarr
dataset_all_min = PVNetUKRegionalDataset("data_config_all_minimal.yaml")

CPU times: user 3.94 s, sys: 783 ms, total: 4.73 s
Wall time: 3.2 s


In [6]:
# Time drawing one random sample from the single-zarr sampler
%timeit -n 1 -r 3 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

1.34 s ± 152 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


I think this is faster because the dask graph is simpler here. I don't think it is just because there are fewer timestamps, but because the dask graph no longer has to include the concatenation.

Check out the number of "graph layers" in the two outputs below.

In [7]:
# Satellite data built by concatenating several zarrs: the repr shows 12 dask graph layers
dataset_all.datasets_dict["sat"]

Unnamed: 0,Array,Chunk
Bytes,1.74 TiB,71.15 MiB
Shape,"(380028, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,42232 chunks in 12 graph layers,42232 chunks in 12 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray
"Array Chunk Bytes 1.74 TiB 71.15 MiB Shape (380028, 11, 614, 372) (36, 11, 314, 300) Dask graph 42232 chunks in 12 graph layers Data type float16 numpy.ndarray",380028  1  372  614  11,

Unnamed: 0,Array,Chunk
Bytes,1.74 TiB,71.15 MiB
Shape,"(380028, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,42232 chunks in 12 graph layers,42232 chunks in 12 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray


In [8]:
# Only 5 graph layers when a single zarr is opened and nothing is concatenated
dataset_all_min.datasets_dict["sat"]

Unnamed: 0,Array,Chunk
Bytes,419.56 GiB,71.15 MiB
Shape,"(89652, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,9964 chunks in 5 graph layers,9964 chunks in 5 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray
"Array Chunk Bytes 419.56 GiB 71.15 MiB Shape (89652, 11, 614, 372) (36, 11, 314, 300) Dask graph 9964 chunks in 5 graph layers Data type float16 numpy.ndarray",89652  1  372  614  11,

Unnamed: 0,Array,Chunk
Bytes,419.56 GiB,71.15 MiB
Shape,"(89652, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,9964 chunks in 5 graph layers,9964 chunks in 5 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray


## Replace the datasets with xarray tensorstore

Let's hack into these torch datasets and replace the xarray objects with `xarray-tensorstore` versions

In [9]:
"""These functions load the satellite and NWP zarrs using xarray-tensorstore

They carry out the necessary processing steps so that the returned xarray-tensorstore objects
are equivalent to the xarray objects in the PVNetUKRegionalDataset datasets_dict.
"""


import xarray_tensorstore as xrt
from ocf_data_sampler.load.utils import (
    check_time_unique_increasing,
    make_spatial_coords_increasing,
)


def sat_preprocessing(ds, channels):
    """Put a raw satellite zarr dataset into the layout the PVNet sampler expects.

    Steps:
      * rename ``variable`` -> ``channel`` and ``time`` -> ``time_utc``
      * validate the time coordinate with ``check_time_unique_increasing``
      * make the geostationary x/y coordinates increasing
      * order dims as (time_utc, channel, x_geostationary, y_geostationary)
      * keep only ``channels``

    Args:
        ds: Dataset opened from a satellite zarr; it must hold a ``data`` variable.
        channels: Channel names to keep.

    Returns:
        The preprocessed ``data`` DataArray.
    """
    renamed = ds.rename({"variable": "channel", "time": "time_utc"})
    check_time_unique_increasing(renamed.time_utc)

    ordered = make_spatial_coords_increasing(
        renamed, x_coord="x_geostationary", y_coord="y_geostationary"
    ).transpose("time_utc", "channel", "x_geostationary", "y_geostationary")

    return ordered.sel(channel=channels).data

def xr_tensorstore_open_sat(zarr_path, channels):
    """Open one satellite zarr with xarray-tensorstore and run ``sat_preprocessing`` on it."""
    raw = xrt.open_zarr(zarr_path)
    return sat_preprocessing(raw, channels)


def ukv_preprocessing(ds, channels):
    """Put a raw UKV NWP zarr dataset into the layout the PVNet sampler expects.

    Steps:
      * rename init_time/variable/x/y to init_time_utc/channel/x_osgb/y_osgb
      * validate ``init_time_utc`` with ``check_time_unique_increasing``
      * make the OSGB x/y coordinates increasing
      * order dims as (init_time_utc, step, channel, x_osgb, y_osgb)
      * keep only ``channels``

    Args:
        ds: Dataset opened from a UKV zarr; it must hold a ``UKV`` variable.
        channels: Channel names to keep.

    Returns:
        The preprocessed ``UKV`` DataArray.
    """
    coord_renames = {
        "init_time": "init_time_utc",
        "variable": "channel",
        "x": "x_osgb",
        "y": "y_osgb",
    }
    renamed = ds.rename(coord_renames)
    check_time_unique_increasing(renamed.init_time_utc)

    ordered = make_spatial_coords_increasing(renamed, x_coord="x_osgb", y_coord="y_osgb")
    ordered = ordered.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb")

    return ordered.sel(channel=channels).UKV

def xr_tensorstore_open_ukv(zarr_path, channels):
    """Open one UKV zarr with xarray-tensorstore and run ``ukv_preprocessing`` on it."""
    raw = xrt.open_zarr(zarr_path)
    return ukv_preprocessing(raw, channels)


def ecmwf_preprocessing(ds, channels):
    """Put a raw ECMWF NWP zarr dataset into the layout the PVNet sampler expects.

    Steps:
      * rename ``init_time`` -> ``init_time_utc`` and ``variable`` -> ``channel``
      * validate ``init_time_utc`` with ``check_time_unique_increasing``
      * make longitude/latitude increasing
      * order dims as (init_time_utc, step, channel, longitude, latitude)
      * keep only ``channels``

    Args:
        ds: Dataset opened from an ECMWF zarr; it must hold an ``ECMWF_UK`` variable.
        channels: Channel names to keep.

    Returns:
        The preprocessed ``ECMWF_UK`` DataArray.
    """
    renamed = ds.rename({"init_time": "init_time_utc", "variable": "channel"})
    check_time_unique_increasing(renamed.init_time_utc)

    ordered = make_spatial_coords_increasing(renamed, x_coord="longitude", y_coord="latitude")
    ordered = ordered.transpose("init_time_utc", "step", "channel", "longitude", "latitude")

    return ordered.sel(channel=channels).ECMWF_UK

def xr_tensorstore_open_ecmwf(zarr_path, channels):
    """Open one ECMWF zarr with xarray-tensorstore and run ``ecmwf_preprocessing`` on it."""
    raw = xrt.open_zarr(zarr_path)
    return ecmwf_preprocessing(raw, channels)



Open the xarray-tensorstore objects

In [10]:
# Reuse the single-zarr (2022-only) paths and channels from the minimal config
input_config = dataset_all_min.config.input_data

# Open xarray-tensorstore versions of the datasets with the same preprocessing as the sampler
dst_sat = xr_tensorstore_open_sat(input_config.satellite.zarr_path, input_config.satellite.channels)
dst_ukv = xr_tensorstore_open_ukv(input_config.nwp.ukv.zarr_path, input_config.nwp.ukv.channels)
dst_ecmwf = xr_tensorstore_open_ecmwf(input_config.nwp.ecmwf.zarr_path, input_config.nwp.ecmwf.channels)

  return ukv_preprocessing(xrt.open_zarr(zarr_path), channels)
  return ecmwf_preprocessing(xrt.open_zarr(zarr_path), channels)


In [11]:
# Store the original datasets, which were opened via xarray with the dask backend,
# so they can be put back after each tensorstore experiment
ds_sat = dataset_all_min.datasets_dict["sat"]
ds_ukv = dataset_all_min.datasets_dict["nwp"]["ukv"]
ds_ecmwf = dataset_all_min.datasets_dict["nwp"]["ecmwf"]

Time (again) the PVNetUKRegionalDataset with the dask xarray objects

In [12]:
# Baseline: the single-zarr sampler with its dask-backed datasets
%timeit -n 1 -r 3 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

1.34 s ± 158 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


Time the PVNetUKRegionalDataset with the xarray-tensorstore objects

In [13]:
# Replace the datasets in the PVNet sampler with tensorstore versions
dataset_all_min.datasets_dict["nwp"]["ukv"] = dst_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = dst_ecmwf
dataset_all_min.datasets_dict["sat"] = dst_sat

# Speed test
%timeit -n 1 -r 8 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

# Put the original datasets back into the sampler
dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_ecmwf
dataset_all_min.datasets_dict["sat"] = ds_sat

106 ms ± 14.6 ms per loop (mean ± std. dev. of 8 runs, 1 loop each)


The `xarray-tensorstore` version is much, much faster (~106 ms vs ~1.34 s per sample)!

### Check the generated samples are exactly the same

In [14]:
# A fixed index so both backends produce exactly the same sample
index = 238234

# Get a sample from the original (dask-backed) dataset
sample_dask = dataset_all_min[index]

# Temporarily replace the datasets in the PVNet sampler with tensorstore versions.
# try/finally guarantees the dask-backed originals are restored even if sampling fails.
dataset_all_min.datasets_dict["nwp"]["ukv"] = dst_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = dst_ecmwf
dataset_all_min.datasets_dict["sat"] = dst_sat
try:
    # Get the same sample using the tensorstore datasets
    sample_tensorstore = dataset_all_min[index]
finally:
    # Put the original datasets back into the sampler
    dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_ukv
    dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_ecmwf
    dataset_all_min.datasets_dict["sat"] = ds_sat

In [15]:
# Compare each input modality between the two backends. Each label is paired with
# its own arrays (previously the UKV/ECMWF labels were swapped). np.array_equal also
# checks shapes, and with equal_nan=True it treats NaNs in matching positions as equal;
# plain `==` would report NaN != NaN as a mismatch.
comparisons = {
    "satellite": (sample_dask["satellite_actual"], sample_tensorstore["satellite_actual"]),
    "UKV": (sample_dask["nwp"]["ukv"]["nwp"], sample_tensorstore["nwp"]["ukv"]["nwp"]),
    "ECMWF": (sample_dask["nwp"]["ecmwf"]["nwp"], sample_tensorstore["nwp"]["ecmwf"]["nwp"]),
}
for label, (from_dask, from_tensorstore) in comparisons.items():
    print(f"{label} same:", np.array_equal(from_dask, from_tensorstore, equal_nan=True))

satellite same: True
UKV same: True
ECMWF same: True


## Try with multi-zarr


Now let's extend `xarray-tensorstore` to allow us to use multiple zarrs per dataset.


We will compare this to the dask version using multiple zarrs - i.e. the "Current usage" section above

In [16]:
"""This is an extension of `xarray_tensorstore` that opens multiple zarr and concatentates them
"""

from xarray_tensorstore import (
    _zarr_spec_from_path, 
    _TensorStoreAdapter, 
    _raise_if_mask_and_scale_used_for_data_vars,
)
import tensorstore as ts
import xarray as xr
import os


def tensorstore_open_multi_zarrs(
    paths: list[str],
    data_vars: list[str],
    concat_axes: list[int],
    context: ts.Context,
    write: bool,
) -> dict[str, ts.TensorStore]:
    """Open each data variable from several zarr stores and concatenate them with tensorstore.

    Args:
        paths: Zarr store paths, in the order they should be concatenated. A single
            path (str or os.PathLike) is treated as a one-element list.
        data_vars: Names of the data variables (array sub-directories) to open in
            every store.
        concat_axes: For each entry of ``data_vars``, the axis to concatenate along.
        context: Tensorstore context shared by every open.
        write: Whether to also open the arrays for writing (they are always opened
            for reading).

    Returns:
        Mapping from data variable name to the ``ts.concat`` of that variable across
        all stores.

    Raises:
        ValueError: If ``paths`` is empty or ``data_vars`` and ``concat_axes`` have
            different lengths.
    """
    if isinstance(paths, (str, os.PathLike)):
        # Iterating a bare string would yield single characters, not paths.
        paths = [paths]
    else:
        paths = list(paths)
    if not paths:
        raise ValueError("At least one zarr path is required")
    if len(data_vars) != len(concat_axes):
        # zip() would otherwise silently drop the unmatched variables.
        raise ValueError(
            f"Got {len(data_vars)} data_vars but {len(concat_axes)} concat_axes"
        )

    # Start every open before waiting on any of them, so tensorstore can open all
    # the stores concurrently rather than one path at a time.
    futures_per_path = [
        {
            name: ts.open(
                _zarr_spec_from_path(os.path.join(path, name)),
                read=True,
                write=write,
                context=context,
            )
            for name in data_vars
        }
        for path in paths
    ]
    arrays_per_path = [
        {name: future.result() for name, future in futures.items()}
        for futures in futures_per_path
    ]

    return {
        name: ts.concat([arrays[name] for arrays in arrays_per_path], axis=axis)
        for name, axis in zip(data_vars, concat_axes)
    }


def open_zarrs(
    paths: list[str],
    concat_dim: str,
    *,
    context: ts.Context | None = None,
    mask_and_scale: bool = True,
    write: bool = False,
) -> xr.Dataset:
    """Open several zarr stores as one Dataset whose data variables are tensorstore-backed.

    xarray (``open_mfdataset``) builds the combined coordinates and metadata. Every
    data variable is then replaced by a ``_TensorStoreAdapter`` wrapping a tensorstore
    array concatenated over the same stores.

    Args:
        paths: Zarr store paths in concatenation order. A single path (str or
            os.PathLike) is treated as a one-element list. Any other iterable is
            materialised once, because it is consumed twice: by xarray and by tensorstore.
        concat_dim: Dimension to concatenate along. NOTE: ``ts.concat`` can only join
            along an existing axis, so this presumably must already be a dimension of
            every data variable in each store.
        context: Tensorstore context; a fresh ``ts.Context()`` is created if omitted.
        mask_and_scale: Passed to xarray. If True, raise when any data variable uses
            mask/scale encoding, since the tensorstore arrays would not apply it.
        write: Also open the tensorstore arrays for writing.

    Returns:
        The combined Dataset with tensorstore-backed data variables.

    Raises:
        ValueError: If ``paths`` is empty.
    """
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]
    else:
        # Materialise iterators/generators so xarray does not exhaust them before
        # tensorstore sees the same paths.
        paths = list(paths)
    if not paths:
        raise ValueError("At least one zarr path is required")

    if context is None:
        context = ts.Context()

    ds = xr.open_mfdataset(
        paths,
        # Explicit engine: xarray's autodetection only recognises a ".zarr" suffix,
        # so it fails on paths like "store.zarr/" or stores with no suffix.
        engine="zarr",
        concat_dim=concat_dim,
        combine="nested",
        mask_and_scale=mask_and_scale,
        decode_timedelta=True,
    )

    if mask_and_scale:
        # Data variables get replaced below with _TensorStoreAdapter arrays, which
        # don't get masked or scaled. Raising an error avoids surprising users with
        # incorrect data values.
        _raise_if_mask_and_scale_used_for_data_vars(ds)

    data_vars = list(ds.data_vars)
    # Position of concat_dim within each variable's dims. This assumes the combined
    # dim order matches the on-disk axis order of each store.
    concat_axes = [ds[name].dims.index(concat_dim) for name in data_vars]

    arrays = tensorstore_open_multi_zarrs(paths, data_vars, concat_axes, context, write)
    new_data = {name: _TensorStoreAdapter(array) for name, array in arrays.items()}

    return ds.copy(data=new_data)

In [17]:
def xr_tensorstore_mf_open_sat(zarr_paths, channels):
    """Concatenate satellite zarrs along ``time`` via tensorstore, then run ``sat_preprocessing``."""
    # concat_dim uses the raw on-disk name; the rename to time_utc happens in preprocessing
    combined = open_zarrs(zarr_paths, concat_dim="time")
    return sat_preprocessing(combined, channels)

def xr_tensorstore_mf_open_ukv(zarr_paths, channels):
    """Concatenate UKV zarrs along ``init_time`` via tensorstore, then run ``ukv_preprocessing``."""
    combined = open_zarrs(zarr_paths, concat_dim="init_time")
    return ukv_preprocessing(combined, channels)

def xr_tensorstore_mf_open_ecmwf(zarr_paths, channels):
    """Concatenate ECMWF zarrs along ``init_time`` via tensorstore, then run ``ecmwf_preprocessing``."""
    combined = open_zarrs(zarr_paths, concat_dim="init_time")
    return ecmwf_preprocessing(combined, channels)

In [18]:
# Input config listing multiple zarrs per source (the "Current usage" sampler)
multizarr_input_config = dataset_all.config.input_data

In [19]:
%%time 

# Open a tensorstore version of the multi-file satellite dataset
dst_mf_sat = xr_tensorstore_mf_open_sat(
    multizarr_input_config.satellite.zarr_path,
    multizarr_input_config.satellite.channels,
)


CPU times: user 1.01 s, sys: 99.1 ms, total: 1.1 s
Wall time: 1.11 s


In [20]:
%%time 

# Open a tensorstore version of the multi-file ECMWF dataset
dst_mf_ecmwf = xr_tensorstore_mf_open_ecmwf(
    multizarr_input_config.nwp.ecmwf.zarr_path,
    multizarr_input_config.nwp.ecmwf.channels,
)

CPU times: user 491 ms, sys: 31.3 ms, total: 522 ms
Wall time: 521 ms


In [21]:
%%time 

# Open a tensorstore version of the multi-file UKV dataset
# (in the recorded run this open was much slower than satellite/ECMWF, ~22 s wall time)
dst_mf_ukv = xr_tensorstore_mf_open_ukv(
    multizarr_input_config.nwp.ukv.zarr_path,
    multizarr_input_config.nwp.ukv.channels,
)

CPU times: user 19.1 s, sys: 3.1 s, total: 22.2 s
Wall time: 22.1 s


In [22]:
# Store the original multi-file datasets, which were opened via xarray with the dask
# backend, so they can be put back after each tensorstore experiment
ds_mf_ukv = dataset_all.datasets_dict["nwp"]["ukv"]
ds_mf_ecmwf = dataset_all.datasets_dict["nwp"]["ecmwf"]
ds_mf_sat = dataset_all.datasets_dict["sat"]

Time the original multi-file version which uses dask

In [23]:
# Baseline: the multi-zarr sampler with its dask-backed datasets
%timeit -n 1 -r 3 dataset_all[np.random.randint(0, len(dataset_all))]

3.62 s ± 192 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


Time the multi-file version which uses tensorstore

In [25]:
# Replace the datasets in the PVNet sampler with tensorstore versions
dataset_all.datasets_dict["nwp"]["ukv"] = dst_mf_ukv
dataset_all.datasets_dict["nwp"]["ecmwf"] = dst_mf_ecmwf
dataset_all.datasets_dict["sat"] = dst_mf_sat

# Speed test
%timeit -n 1 -r 8 dataset_all[np.random.randint(0, len(dataset_all))]

# Put the original datasets back into the sampler
dataset_all.datasets_dict["nwp"]["ukv"] = ds_mf_ukv
dataset_all.datasets_dict["nwp"]["ecmwf"] = ds_mf_ecmwf
dataset_all.datasets_dict["sat"] = ds_mf_sat

118 ms ± 10.2 ms per loop (mean ± std. dev. of 8 runs, 1 loop each)


It is much, much faster here too (~118 ms vs ~3.62 s per sample)!

### Check the generated samples are exactly the same

In [26]:
# A fixed index so both backends produce exactly the same sample
index = 238234

# Get a sample from the original (dask-backed) dataset
sample_dask = dataset_all[index]

# Temporarily swap in the tensorstore datasets. try/finally guarantees the
# dask-backed originals are restored even if sampling fails.
dataset_all.datasets_dict["nwp"]["ukv"] = dst_mf_ukv
dataset_all.datasets_dict["nwp"]["ecmwf"] = dst_mf_ecmwf
dataset_all.datasets_dict["sat"] = dst_mf_sat
try:
    # Get the same sample using the tensorstore datasets
    sample_tensorstore = dataset_all[index]
finally:
    # Put the original datasets back into the sampler
    dataset_all.datasets_dict["nwp"]["ukv"] = ds_mf_ukv
    dataset_all.datasets_dict["nwp"]["ecmwf"] = ds_mf_ecmwf
    dataset_all.datasets_dict["sat"] = ds_mf_sat

In [27]:
# Compare each input modality between the two backends. Each label is paired with
# its own arrays (previously the UKV/ECMWF labels were swapped). np.array_equal also
# checks shapes, and with equal_nan=True it treats NaNs in matching positions as equal;
# plain `==` would report NaN != NaN as a mismatch.
comparisons = {
    "satellite": (sample_dask["satellite_actual"], sample_tensorstore["satellite_actual"]),
    "UKV": (sample_dask["nwp"]["ukv"]["nwp"], sample_tensorstore["nwp"]["ukv"]["nwp"]),
    "ECMWF": (sample_dask["nwp"]["ecmwf"]["nwp"], sample_tensorstore["nwp"]["ecmwf"]["nwp"]),
}
for label, (from_dask, from_tensorstore) in comparisons.items():
    print(f"{label} same:", np.array_equal(from_dask, from_tensorstore, equal_nan=True))

satellite same: True
UKV same: True
ECMWF same: True


## Repeat with no-dask (i.e. python only version)

**Note: this is incredibly slow and, as shown below, it runs out of memory!**

In [28]:
import xarray_tensorstore as xrt
from ocf_data_sampler.load.utils import (
    check_time_unique_increasing,
    make_spatial_coords_increasing,
)


def xr_nodask_open_sat(zarr_path, channels):
    """Open a satellite zarr with plain xarray (``chunks=None``, no dask) and preprocess it.

    Applies the same rename, time check, spatial ordering and transpose as the
    tensorstore path, but does NOT select channels.

    Args:
        zarr_path: Path to the satellite zarr store.
        channels: Accepted for signature parity with the other openers; currently unused.

    Returns:
        The ``data`` DataArray with all channels.
    """
    opened = xr.open_zarr(zarr_path, chunks=None)
    opened = opened.rename({"variable": "channel", "time": "time_utc"})
    check_time_unique_increasing(opened.time_utc)

    opened = make_spatial_coords_increasing(
        opened, x_coord="x_geostationary", y_coord="y_geostationary"
    )
    opened = opened.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")

    # NOTE: `opened.sel(channel=channels)` is deliberately skipped. Without dask,
    # slicing the channels this way was found to load the whole array into memory.
    return opened.data


def xr_nodask_open_ukv(zarr_path, channels):
    """Open a UKV zarr with plain xarray (``chunks=None``, no dask) and preprocess it.

    Applies the same renames, time check, spatial ordering and transpose as the
    tensorstore path, but does NOT select channels.

    Args:
        zarr_path: Path to the UKV zarr store.
        channels: Accepted for signature parity with the other openers; currently unused.

    Returns:
        The ``UKV`` DataArray with all channels.
    """
    opened = xr.open_zarr(zarr_path, chunks=None)

    coord_renames = {
        "init_time": "init_time_utc",
        "variable": "channel",
        "x": "x_osgb",
        "y": "y_osgb",
    }
    opened = opened.rename(coord_renames)
    check_time_unique_increasing(opened.init_time_utc)

    opened = make_spatial_coords_increasing(opened, x_coord="x_osgb", y_coord="y_osgb")
    opened = opened.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb")

    # NOTE: `opened.sel(channel=channels)` is deliberately skipped. Without dask,
    # slicing the channels this way was found to load the whole array into memory.
    return opened.UKV


def xr_nodask_open_ecmwf(zarr_path, channels):
    """Open an ECMWF zarr with plain xarray (``chunks=None``, no dask) and preprocess it.

    Applies the same renames, time check, spatial ordering and transpose as the
    tensorstore path, but does NOT select channels.

    Args:
        zarr_path: Path to the ECMWF zarr store.
        channels: Accepted for signature parity with the other openers; currently unused.

    Returns:
        The ``ECMWF_UK`` DataArray with all channels.
    """
    opened = xr.open_zarr(zarr_path, chunks=None)
    opened = opened.rename({"init_time": "init_time_utc", "variable": "channel"})
    check_time_unique_increasing(opened.init_time_utc)

    opened = make_spatial_coords_increasing(opened, x_coord="longitude", y_coord="latitude")
    opened = opened.transpose("init_time_utc", "step", "channel", "longitude", "latitude")

    # NOTE: `opened.sel(channel=channels)` is deliberately skipped. Without dask,
    # slicing the channels this way was found to load the whole array into memory.
    return opened.ECMWF_UK


# Reuse the single-zarr (2022-only) paths from the minimal config
input_config = dataset_all_min.config.input_data

# Open plain-xarray (no-dask) versions of the datasets.
# The channels are passed for signature parity but the no-dask openers do not select them.
ds_nodask_sat = xr_nodask_open_sat(input_config.satellite.zarr_path, input_config.satellite.channels)
ds_nodask_ukv = xr_nodask_open_ukv(input_config.nwp.ukv.zarr_path, input_config.nwp.ukv.channels)
ds_nodask_ecmwf = xr_nodask_open_ecmwf(input_config.nwp.ecmwf.zarr_path, input_config.nwp.ecmwf.channels)

  ds = xr.open_zarr(zarr_path, chunks=None)
  ds = xr.open_zarr(zarr_path, chunks=None)


In [29]:
# Replace the datasets in the PVNet sampler with tensorstore versions
dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_nodask_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_nodask_ecmwf
dataset_all_min.datasets_dict["sat"] = ds_nodask_sat

# Speed test
%timeit -n 1 -r 1 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

# Put the original datasets back into the sampler
dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_ecmwf
dataset_all_min.datasets_dict["sat"] = ds_sat

MemoryError: Unable to allocate 6.71 GiB for an array with shape (2485, 37, 17, 24, 24) and data type int64

The array being allocated is huge: shape `(2485, 37, 17, 24, 24)` of int64, i.e. 6.71 GiB! In this non-dask version, the spatial slice is taken eagerly rather than lazily, so a huge array is loaded into memory.