## Imports

In [1]:
from itertools import product
import numpy as np

import xarray as xr
import xbitinfo as xb

## Data loading

In [2]:
# Fetch the xarray tutorial air-temperature dataset
ds = xr.tutorial.load_dataset("air_temperature")

# Chunk sizes (along lat / lon) used for the reading/bitrounding/writing below
chunks = {"lat": 5, "lon": 10}

# Re-chunk the dataset accordingly
ds = ds.chunk(chunks)

In [3]:
# Display the chunked dataset to inspect its shape and chunk layout
ds

Unnamed: 0,Array,Chunk
Bytes,14.76 MiB,570.31 kiB
Shape,"(2920, 25, 53)","(2920, 5, 10)"
Dask graph,30 chunks in 1 graph layer,30 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 14.76 MiB 570.31 kiB Shape (2920, 25, 53) (2920, 5, 10) Dask graph 30 chunks in 1 graph layer Data type float32 numpy.ndarray",53  25  2920,

Unnamed: 0,Array,Chunk
Bytes,14.76 MiB,570.31 kiB
Shape,"(2920, 25, 53)","(2920, 5, 10)"
Dask graph,30 chunks in 1 graph layer,30 chunks in 1 graph layer
Data type,float32 numpy.ndarray,float32 numpy.ndarray


## Saving to file

In [4]:
# Reference file: the dataset written with xarray's plain to_netcdf,
# used as the baseline in the file-size comparison at the end
ds.to_netcdf("0.air_original.nc")

  ds.to_netcdf("0.air_original.nc")


## Compress with `to_compressed_netcdf`

In [5]:
# Same (unrounded) data, written through the to_compressed_netcdf accessor
ds.to_compressed_netcdf("1.air_compressed_all.nc")

  self._obj.to_netcdf(


## Compress with bitrounding

In [6]:
# Analyse the bitwise information content of the whole dataset along "lon"
info_per_bit = xb.get_bitinformation(ds, dim="lon", implementation="python")
# Derive how many bits to keep; 0.99 is presumably the fraction of information
# to preserve (second positional argument of get_keepbits) — see xbitinfo docs
keepbits = xb.get_keepbits(info_per_bit, 0.99)
# Round the data using the keepbits computed above
ds_bitrounded = xb.xr_bitround(ds, keepbits)

  0%|          | 0/1 [00:00<?, ?it/s]

In [7]:
# Write the globally bitrounded dataset through the compressed-netCDF accessor
ds_bitrounded.to_compressed_netcdf("2.air_bitrounded_compressed.nc")

## Zarr chunking and compressing

In [8]:
def bitrounding(chunk, var="lat", inflevel=0.99):
    """
    Bitround one chunk based on its own bit-information analysis.

    Bundles the three xbitinfo calls (bit information -> keepbits -> bitround)
    so they can be applied to each chunk independently.

    Parameters
    ----------
    chunk : xr.Dataset
        Data to analyse and round; handed as-is to the xbitinfo functions.
    var : str, default "lat"
        Name of the dimension along which the bit information is computed
        (forwarded as ``dim`` to ``xb.get_bitinformation``).
    inflevel : float, default 0.99
        Level forwarded to ``xb.get_keepbits``. The default reproduces the
        value that was previously hard-coded here.

    Returns
    -------
    The result of ``xb.xr_bitround(chunk, keepbits)``.
    """
    bitinfo = xb.get_bitinformation(chunk, dim=var, implementation="python")
    keepbits = xb.get_keepbits(bitinfo, inflevel)
    return xb.xr_bitround(chunk, keepbits)


def slices_from_chunks(chunks):
    """Translate chunks tuple to a set of slices in product order

    Each element of ``chunks`` lists the chunk sizes along one dimension.
    The result contains one tuple of slices per chunk, ordered like
    ``itertools.product`` over the dimensions (last dimension varies fastest).

    Slice bounds are plain Python ``int`` objects (not NumPy scalars), so
    their repr is stable across NumPy versions and they can be passed on to
    APIs expecting built-in integers.

    >>> slices_from_chunks(((2, 2), (3, 3, 3)))  # doctest: +NORMALIZE_WHITESPACE
     [(slice(0, 2, None), slice(0, 3, None)),
      (slice(0, 2, None), slice(3, 6, None)),
      (slice(0, 2, None), slice(6, 9, None)),
      (slice(2, 4, None), slice(0, 3, None)),
      (slice(2, 4, None), slice(3, 6, None)),
      (slice(2, 4, None), slice(6, 9, None))]
    """
    slices = []
    for bds in chunks:
        dim_slices = []
        start = 0  # running offset of the current chunk along this dimension
        for size in bds:
            stop = start + int(size)  # int() strips NumPy scalar types
            dim_slices.append(slice(start, stop))
            start = stop
        slices.append(dim_slices)
    return list(product(*slices))

In [9]:
fn = "air.zarr"  # Output filename
# compute=False with mode="w": only the store layout is set up here; the chunk
# data itself is written region by region in the loop further below
ds.to_compressed_zarr(fn, compute=False, mode="w")  # Creates empty file structure

In [10]:
dims = ds.air.dims  # dimension names of the "air" variable
len_dims = len(dims)  # number of dimensions of "air"
# One tuple of slices (one slice per dimension) for every dask chunk of "air"
slices = slices_from_chunks(ds.air.chunks)

In [11]:
%%capture
for b, block in enumerate(ds.air.data.to_delayed().ravel()):  # Loop over each chunk
    # slices = {d:s for (d,s) in zip(dims, block.key[1:1+len_dims])}
    ds_block = xr.Dataset(
        {"air": (dims, block.compute())}
    )  # Conversion of dask.delayed array to Dataset (as xbitinfo wants type xr.Dataset)
    rounded_ds = bitrounding(ds_block)  # Apply bitrounding
    rounded_ds.to_zarr(
        fn, region={dims[d]: s for (d, s) in enumerate(slices[b])}
    )  # Write individual chunk to disk

## Creating smaller datasets as chunks and compressing

In [12]:
%%capture

at_least_zero = lambda x: max(x, 0)

chunk_long, chunk_lat = [10, 5]  # for int division
var = "lat"

dss = []
dss_bitrounded = []
dss_kbits = []

long_c = int(ds.lon.size / chunk_long)
lat_c = int(ds.lat.size / chunk_lat)

for i in range(long_c):
    for j in range(lat_c):
        temp_ds = ds.isel(
            lon=slice(i * chunk_long, (i + 1) * chunk_long),
            lat=slice(j * chunk_lat, (j + 1) * chunk_lat),
        )
        dss.append(temp_ds)
        temp_info_pbit = xb.get_bitinformation(
            temp_ds, dim=var, implementation="python"
        )
        temp_keepbits = xb.get_keepbits(temp_info_pbit, 0.99)
        # temp_keepbits = temp_keepbits.map(at_least_zero)
        dss_kbits.append(temp_keepbits)
        temp_ds_bitrounded = xb.xr_bitround(temp_ds, temp_keepbits)
        dss_bitrounded.append(temp_ds_bitrounded)

        if i == 0 and j == 0:
            MERGED_ds_bitr = temp_ds_bitrounded
        else:
            MERGED_ds_bitr = xr.merge([MERGED_ds_bitr, temp_ds_bitrounded])

In [13]:
# Write the dataset reassembled from the individually bitrounded chunks
MERGED_ds_bitr.to_compressed_netcdf("3.air_chunked_bitr_compressed.nc")

## Comparing the file sizes of all outputs

In [14]:
# Show the on-disk size of every netCDF file and zarr store in the working directory
!du -hs *.nc *.zarr

7.5M	0.air_original.nc
1.7M	1.air_compressed_all.nc
1.3M	2.air_bitrounded_compressed.nc
776K	3.air_chunked_bitr_compressed.nc
1.1M	air.zarr
