In [1]:
from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKRegionalDataset
import xarray as xr
import numpy as np

## Current usage

Our current usage pattern in data-sampler is to open multiple zarrs per input source and concatenate them. We do this using xarray's dask backend.

In [2]:
%%time

# Construct the sampler from the multi-zarr config (the current usage pattern,
# where several zarrs per input source are concatenated); %%time reports setup cost.
dataset_all = PVNetUKRegionalDataset("data_config_all.yaml")

CPU times: user 29.2 s, sys: 1.7 s, total: 30.9 s
Wall time: 29.8 s


... and sampling with this setup is quite slow:

In [3]:
# Time loading one randomly chosen sample (1 loop per run, 3 runs)
%timeit -n 1 -r 3 dataset_all[np.random.randint(0, len(dataset_all))]

2.93 s ± 274 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


## Minimise dask graph complexity by avoiding multiple zarrs

We can improve on this by using only single zarrs for each input source. This makes the dask graph simpler and makes sampling faster.

In this case we limit ourselves to input zarrs covering 2022 only.

In [4]:
%%time 

# Construct the sampler from the config that uses a single zarr per input source
# (2022 only); %%time reports setup cost for comparison with the multi-zarr config.
dataset_all_min = PVNetUKRegionalDataset("data_config_all_minimal.yaml")

CPU times: user 2.78 s, sys: 4.99 s, total: 7.77 s
Wall time: 4.09 s


In [5]:
# Time loading one randomly chosen sample from the single-zarr dataset (1 loop per run, 3 runs)
%timeit -n 1 -r 3 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

1.21 s ± 420 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


I think this is faster because the dask graph is simpler here. I don't think it is just because there are fewer timestamps, but because the dask graph no longer has to include the concatenation.

Check out the number of "graph layers" in the two outputs below.

In [6]:
# Satellite input from the multi-zarr config: 12 graph layers when the zarrs are concatenated
dataset_all.datasets_dict["sat"]

Unnamed: 0,Array,Chunk
Bytes,1.74 TiB,71.15 MiB
Shape,"(380028, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,42232 chunks in 12 graph layers,42232 chunks in 12 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray
"Array Chunk Bytes 1.74 TiB 71.15 MiB Shape (380028, 11, 614, 372) (36, 11, 314, 300) Dask graph 42232 chunks in 12 graph layers Data type float16 numpy.ndarray",380028  1  372  614  11,

Unnamed: 0,Array,Chunk
Bytes,1.74 TiB,71.15 MiB
Shape,"(380028, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,42232 chunks in 12 graph layers,42232 chunks in 12 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray


In [7]:
# Satellite input from the single-zarr config: only 5 graph layers when we don't concatenate
dataset_all_min.datasets_dict["sat"]

Unnamed: 0,Array,Chunk
Bytes,419.56 GiB,71.15 MiB
Shape,"(89652, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,9964 chunks in 5 graph layers,9964 chunks in 5 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray
"Array Chunk Bytes 419.56 GiB 71.15 MiB Shape (89652, 11, 614, 372) (36, 11, 314, 300) Dask graph 9964 chunks in 5 graph layers Data type float16 numpy.ndarray",89652  1  372  614  11,

Unnamed: 0,Array,Chunk
Bytes,419.56 GiB,71.15 MiB
Shape,"(89652, 11, 614, 372)","(36, 11, 314, 300)"
Dask graph,9964 chunks in 5 graph layers,9964 chunks in 5 graph layers
Data type,float16 numpy.ndarray,float16 numpy.ndarray


## Replace the datasets with xarray tensorstore

Let's hack into these torch datasets and replace the xarray objects with `xarray-tensorstore` versions

In [8]:
"""These functions load the satellite and NWP zarrs using xarray-tensorstore

They carry out the necessary processing steps so that the returned xarray-tensorstore objects
are equivalent to the xarray objects in the PVNetUKRegionalDataset datasets_dict.
"""


import xarray_tensorstore as xrt
from ocf_data_sampler.load.utils import (
    check_time_unique_increasing,
    make_spatial_coords_increasing,
)


def xr_tensorstore_open_sat(zarr_path, channels):
    """Open the satellite zarr with xarray-tensorstore and apply the sampler's processing.

    Args:
        zarr_path: Path to the satellite zarr store.
        channels: Channel names to keep, selected along the "channel" dimension.

    Returns:
        The "data" variable after renaming, validating the time coordinate, making the
        spatial coords increasing, transposing the dataset to
        (time_utc, channel, x_geostationary, y_geostationary) and selecting `channels`.
    """
    sat = xrt.open_zarr(zarr_path).rename({"variable": "channel", "time": "time_utc"})

    # Validation / coordinate normalisation helpers from ocf_data_sampler.load.utils
    check_time_unique_increasing(sat.time_utc)
    sat = make_spatial_coords_increasing(
        sat, x_coord="x_geostationary", y_coord="y_geostationary"
    )

    ordered = sat.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")
    return ordered.sel(channel=channels)["data"]


def xr_tensorstore_open_ukv(zarr_path, channels):
    """Open the UKV NWP zarr with xarray-tensorstore and apply the sampler's processing.

    Args:
        zarr_path: Path to the UKV zarr store.
        channels: Channel names to keep, selected along the "channel" dimension.

    Returns:
        The "UKV" variable after renaming to OSGB / UTC names, validating the init-time
        coordinate, making the spatial coords increasing, transposing the dataset to
        (init_time_utc, step, channel, x_osgb, y_osgb) and selecting `channels`.
    """
    renames = {
        "init_time": "init_time_utc",
        "variable": "channel",
        "x": "x_osgb",
        "y": "y_osgb",
    }
    ukv = xrt.open_zarr(zarr_path).rename(renames)

    # Validation / coordinate normalisation helpers from ocf_data_sampler.load.utils
    check_time_unique_increasing(ukv.init_time_utc)
    ukv = make_spatial_coords_increasing(ukv, x_coord="x_osgb", y_coord="y_osgb")

    dim_order = ("init_time_utc", "step", "channel", "x_osgb", "y_osgb")
    return ukv.transpose(*dim_order).sel(channel=channels)["UKV"]


def xr_tensorstore_open_ecmwf(zarr_path, channels):
    """Open the ECMWF NWP zarr with xarray-tensorstore and apply the sampler's processing.

    Args:
        zarr_path: Path to the ECMWF zarr store.
        channels: Channel names to keep, selected along the "channel" dimension.

    Returns:
        The "ECMWF_UK" variable after renaming, validating the init-time coordinate,
        making longitude/latitude increasing, transposing the dataset to
        (init_time_utc, step, channel, longitude, latitude) and selecting `channels`.
    """
    ecmwf = xrt.open_zarr(zarr_path).rename(
        {"init_time": "init_time_utc", "variable": "channel"}
    )

    # Validation / coordinate normalisation helpers from ocf_data_sampler.load.utils
    check_time_unique_increasing(ecmwf.init_time_utc)
    ecmwf = make_spatial_coords_increasing(ecmwf, x_coord="longitude", y_coord="latitude")

    dim_order = ("init_time_utc", "step", "channel", "longitude", "latitude")
    return ecmwf.transpose(*dim_order).sel(channel=channels)["ECMWF_UK"]

Open the xarray-tensorstore objects

In [9]:
input_config = dataset_all_min.config.input_data

# Open xarray-tensorstore versions of the three input sources, using the same
# zarr paths and channel lists as the sampler's config
dst_sat = xr_tensorstore_open_sat(
    zarr_path=input_config.satellite.zarr_path,
    channels=input_config.satellite.channels,
)
dst_ukv = xr_tensorstore_open_ukv(
    zarr_path=input_config.nwp.ukv.zarr_path,
    channels=input_config.nwp.ukv.channels,
)
dst_ecmwf = xr_tensorstore_open_ecmwf(
    zarr_path=input_config.nwp.ecmwf.zarr_path,
    channels=input_config.nwp.ecmwf.channels,
)

  ds = xrt.open_zarr(zarr_path)
  ds = xrt.open_zarr(zarr_path)


In [10]:
# Keep handles on the sampler's original (xarray + dask backend) datasets so they can
# be put back after each tensorstore / no-dask experiment
ds_sat, ds_ukv, ds_ecmwf = (
    dataset_all_min.datasets_dict["sat"],
    dataset_all_min.datasets_dict["nwp"]["ukv"],
    dataset_all_min.datasets_dict["nwp"]["ecmwf"],
)

Time (again) the PVNetUKRegionalDataset with the dask xarray objects

In [11]:
# Baseline: time one random sample with the dask-backed datasets (1 loop per run, 8 runs)
%timeit -n 1 -r 8 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

1.03 s ± 179 ms per loop (mean ± std. dev. of 8 runs, 1 loop each)


Time the PVNetUKRegionalDataset with the xarray-tensorstore objects

In [12]:
# Replace the datasets in the PVNet sampler with tensorstore versions
dataset_all_min.datasets_dict["nwp"]["ukv"] = dst_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = dst_ecmwf
dataset_all_min.datasets_dict["sat"] = dst_sat

# Speed test
%timeit -n 1 -r 8 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

# Put the original datasets back into the sampler
dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_ecmwf
dataset_all_min.datasets_dict["sat"] = ds_sat

72.6 ms ± 9.03 ms per loop (mean ± std. dev. of 8 runs, 1 loop each)


The `xarray-tensorstore` version is much, much faster (~73 ms vs ~1.03 s per sample, roughly 14×)!

## Check the generated samples are exactly the same

In [13]:
# Fixed sample index so both backends are asked for exactly the same sample
index = 238234
assert 0 <= index < len(dataset_all_min), "index out of range for dataset_all_min"

# Get a sample from the original (dask-backed) dataset
sample_dask = dataset_all_min[index]

# Swap in the tensorstore datasets; `finally` guarantees the originals are restored
# even if sampling raises, so later cells never run against the wrong backend.
try:
    # Replace the datasets in the PVNet sampler with tensorstore versions
    dataset_all_min.datasets_dict["nwp"]["ukv"] = dst_ukv
    dataset_all_min.datasets_dict["nwp"]["ecmwf"] = dst_ecmwf
    dataset_all_min.datasets_dict["sat"] = dst_sat

    # Get the same sample using the tensorstore datasets
    sample_tensorstore = dataset_all_min[index]
finally:
    # Put the original datasets back into the sampler
    dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_ukv
    dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_ecmwf
    dataset_all_min.datasets_dict["sat"] = ds_sat

In [14]:
def arrays_match(a, b):
    """Return True if `a` and `b` have identical shape and values.

    Unlike `(a == b).all()`, this does not broadcast mismatched shapes into a
    spurious match, and NaNs in the same positions count as equal (NaN == NaN is
    False elementwise). Assumes numeric arrays, since `equal_nan` needs `isnan`.
    """
    return np.array_equal(np.asarray(a), np.asarray(b), equal_nan=True)


# Each label must compare the matching input source in both samples
print("satellite same:", arrays_match(sample_dask["satellite_actual"], sample_tensorstore["satellite_actual"]))
print("UKV same:", arrays_match(sample_dask["nwp"]["ukv"]["nwp"], sample_tensorstore["nwp"]["ukv"]["nwp"]))
print("ECMWF same:", arrays_match(sample_dask["nwp"]["ecmwf"]["nwp"], sample_tensorstore["nwp"]["ecmwf"]["nwp"]))

satellite same: True
UKV same: True
ECMWF same: True


## Repeat without dask (i.e. plain xarray with `xr.open_zarr(..., chunks=None)`)

In [15]:
import xarray_tensorstore as xrt
from ocf_data_sampler.load.utils import (
    check_time_unique_increasing,
    make_spatial_coords_increasing,
)


def xr_nodask_open_sat(zarr_path, channels):
    """Open the satellite zarr with plain xarray and no dask (``chunks=None``).

    Applies the same rename / validation / transpose steps as ``xr_tensorstore_open_sat``
    but skips the channel selection (see NOTE below).

    Args:
        zarr_path: Path to the satellite zarr store.
        channels: Accepted for signature parity with the other loaders; currently unused.

    Returns:
        The "data" variable after transposing the dataset to
        (time_utc, channel, x_geostationary, y_geostationary), with all channels present.
    """
    sat = xr.open_zarr(zarr_path, chunks=None).rename({"variable": "channel", "time": "time_utc"})

    check_time_unique_increasing(sat.time_utc)
    sat = make_spatial_coords_increasing(
        sat, x_coord="x_geostationary", y_coord="y_geostationary"
    )

    # NOTE: channel selection is deliberately skipped. The original author observed that
    # `sat.sel(channel=channels)` here caused the whole array to be loaded into memory.
    # As a result, the returned array holds every channel, not only `channels`.
    return sat.transpose("time_utc", "channel", "x_geostationary", "y_geostationary")["data"]


def xr_nodask_open_ukv(zarr_path, channels):
    """Open the UKV NWP zarr with plain xarray and no dask (``chunks=None``).

    Applies the same rename / validation / transpose steps as ``xr_tensorstore_open_ukv``
    but skips the channel selection (see NOTE below).

    Args:
        zarr_path: Path to the UKV zarr store.
        channels: Accepted for signature parity with the other loaders; currently unused.

    Returns:
        The "UKV" variable after transposing the dataset to
        (init_time_utc, step, channel, x_osgb, y_osgb), with all channels present.
    """
    renames = {
        "init_time": "init_time_utc",
        "variable": "channel",
        "x": "x_osgb",
        "y": "y_osgb",
    }
    ukv = xr.open_zarr(zarr_path, chunks=None).rename(renames)

    check_time_unique_increasing(ukv.init_time_utc)
    ukv = make_spatial_coords_increasing(ukv, x_coord="x_osgb", y_coord="y_osgb")

    # NOTE: channel selection is deliberately skipped. The original author observed that
    # `ukv.sel(channel=channels)` here caused the whole array to be loaded into memory.
    dim_order = ("init_time_utc", "step", "channel", "x_osgb", "y_osgb")
    return ukv.transpose(*dim_order)["UKV"]


def xr_nodask_open_ecmwf(zarr_path, channels):
    """Open the ECMWF NWP zarr with plain xarray and no dask (``chunks=None``).

    Applies the same rename / validation / transpose steps as ``xr_tensorstore_open_ecmwf``
    but skips the channel selection (see NOTE below).

    Args:
        zarr_path: Path to the ECMWF zarr store.
        channels: Accepted for signature parity with the other loaders; currently unused.

    Returns:
        The "ECMWF_UK" variable after transposing the dataset to
        (init_time_utc, step, channel, longitude, latitude), with all channels present.
    """
    ecmwf = xr.open_zarr(zarr_path, chunks=None).rename(
        {"init_time": "init_time_utc", "variable": "channel"}
    )

    check_time_unique_increasing(ecmwf.init_time_utc)
    ecmwf = make_spatial_coords_increasing(ecmwf, x_coord="longitude", y_coord="latitude")

    # NOTE: channel selection is deliberately skipped. The original author observed that
    # `ecmwf.sel(channel=channels)` here caused the whole array to be loaded into memory.
    dim_order = ("init_time_utc", "step", "channel", "longitude", "latitude")
    return ecmwf.transpose(*dim_order)["ECMWF_UK"]


input_config = dataset_all_min.config.input_data

# Open the no-dask (chunks=None) versions of the three input sources. The channel
# lists are passed for symmetry but are ignored by these loaders.
ds_nodask_sat = xr_nodask_open_sat(
    zarr_path=input_config.satellite.zarr_path,
    channels=input_config.satellite.channels,
)
ds_nodask_ukv = xr_nodask_open_ukv(
    zarr_path=input_config.nwp.ukv.zarr_path,
    channels=input_config.nwp.ukv.channels,
)
ds_nodask_ecmwf = xr_nodask_open_ecmwf(
    zarr_path=input_config.nwp.ecmwf.zarr_path,
    channels=input_config.nwp.ecmwf.channels,
)

  ds = xr.open_zarr(zarr_path, chunks=None)
  ds = xr.open_zarr(zarr_path, chunks=None)


In [16]:
# Replace the datasets in the PVNet sampler with tensorstore versions
dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_nodask_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_nodask_ecmwf
dataset_all_min.datasets_dict["sat"] = ds_nodask_sat

# Speed test
%timeit -n 1 -r 1 dataset_all_min[np.random.randint(0, len(dataset_all_min))]

# Put the original datasets back into the sampler
dataset_all_min.datasets_dict["nwp"]["ukv"] = ds_ukv
dataset_all_min.datasets_dict["nwp"]["ecmwf"] = ds_ecmwf
dataset_all_min.datasets_dict["sat"] = ds_sat

5min 34s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


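This was incredibly slow (over 5 minutes for a single sample). This suggests that some step in data-sampler is non-lazy when we don't use dask, causing a lot of unneeded data to be loaded from disk. Note also that these no-dask loaders skip the channel selection, so they return all channels rather than only the configured ones.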