In [1]:
from pathlib import Path

import earthkit.data
import earthkit.plots
import numcodecs
import numpy as np
import xarray as xr

from matplotlib import pyplot as plt
from numcodecs.abc import Codec
from numcodecs_combinators.stack import CodecStack
from numcodecs_safeguards import SafeguardsCodec
from numcodecs_safeguards.lossless import BytesCodec
from numcodecs_wasm_zfp import Zfp
from numcodecs_wasm_zfp_classic import ZfpClassic

In [2]:
# Open the NetCDF file, select the `t2m` variable, and chunk it so that each
# step along `valid_time` is its own chunk.
data_path = Path("data", "era5-daily-2tm-2024-01", "data.nc")
dataset = xr.open_dataset(data_path, engine="netcdf4", decode_timedelta=True)
t2m = dataset["t2m"].chunk({"valid_time": 1})
t2m

Unnamed: 0,Array,Chunk
Bytes,122.78 MiB,3.96 MiB
Shape,"(31, 721, 1440)","(1, 721, 1440)"
Dask graph,31 chunks in 2 graph layers,31 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 122.78 MiB 3.96 MiB Shape (31, 721, 1440) (1, 721, 1440) Dask graph 31 chunks in 2 graph layers Data type float32 numpy.ndarray",1440  721  31,

Unnamed: 0,Array,Chunk
Bytes,122.78 MiB,3.96 MiB
Shape,"(31, 721, 1440)","(1, 721, 1440)"
Dask graph,31 chunks in 2 graph layers,31 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [3]:
# Integer bounds enclosing every value of the field: the minimum rounded
# down and the maximum rounded up.
vmin = int(np.floor(np.amin(t2m)))
vmax = int(np.ceil(np.amax(t2m)))
vmin, vmax

(217, 316)

In [10]:
from numcodecs_observers import observe
from numcodecs_observers.bytesize import BytesizeObserver
from numcodecs_observers.hash import HashableCodec

In [7]:
# Single-codec stack: ZfpClassic in fixed-accuracy mode with tolerance 1.0.
zfp_classic_codec = ZfpClassic(mode="fixed-accuracy", tolerance=1.0)
zfp_classic = CodecStack(zfp_classic_codec)

# Encode and decode the field while the byte-size observer is attached; the
# computation is triggered inside the `with` block so the observer is active.
bytesize = BytesizeObserver()
with observe(zfp_classic, observers=[bytesize]) as codec:
    roundtrip_zfp_classic = codec.encode_decode_data_array(t2m)
    decoded_zfp_classic = roundtrip_zfp_classic.compute()

In [12]:
# Ratio of the summed `pre` sizes to the summed `post` sizes recorded for the
# first codec in the stack (presumably bytes before / after encoding).
sizes_zfp_classic = bytesize.encode_sizes[HashableCodec(zfp_classic[0])]
sum(bs.pre for bs in sizes_zfp_classic) / sum(bs.post for bs in sizes_zfp_classic)

9.092469660812577

In [None]:
# Signed error (decoded minus original), averaged over `valid_time`, with
# attributes kept on the result.
with xr.set_options(keep_attrs=True):
    residual_zfp_classic = decoded_zfp_classic - t2m
    error_zfp_classic = residual_zfp_classic.mean(dim="valid_time").compute()

# Mean of the error map relative to its standard deviation.
mean_error = np.mean(error_zfp_classic)
error_spread = np.std(error_zfp_classic)
float(mean_error / error_spread)

In [None]:
# Histogram (100 bins) of all values in the error map.
error_samples = error_zfp_classic.values.ravel()
plt.hist(error_samples, bins=100)
plt.show()

In [None]:
# Smallest and largest value of the error map, as plain Python floats.
emin = float(np.amin(error_zfp_classic))
emax = float(np.amax(error_zfp_classic))
emin, emax

In [None]:
# Map of the error with 11 evenly spaced levels from -0.25 to 0.25 and the
# "coolwarm" colours; the trailing semicolon suppresses the cell's repr output.
style_zfp_classic = earthkit.plots.Style(
    levels=np.linspace(-0.25, 0.25, 11),
    colors="coolwarm",
)
earthkit.plots.quickplot(error_zfp_classic, style=style_zfp_classic, methods="pcolormesh");

In [None]:
# Single-codec stack: Zfp in fixed-accuracy mode with tolerance 1.0.
zfp_codec = Zfp(mode="fixed-accuracy", tolerance=1.0)
zfp = CodecStack(zfp_codec)

# Encode and decode the field, then materialise the result.
roundtrip_zfp = zfp.encode_decode_data_array(t2m)
decoded_zfp = roundtrip_zfp.compute()

In [None]:
# Signed error (decoded minus original), averaged over `valid_time`, with
# attributes kept on the result.
with xr.set_options(keep_attrs=True):
    residual_zfp = decoded_zfp - t2m
    error_zfp = residual_zfp.mean(dim="valid_time").compute()

# Mean of the error map relative to its standard deviation.
mean_error = np.mean(error_zfp)
error_spread = np.std(error_zfp)
float(mean_error / error_spread)

In [None]:
# Histogram (100 bins) of all values in the error map.
error_samples = error_zfp.values.ravel()
plt.hist(error_samples, bins=100)
plt.show()

In [None]:
# Smallest and largest value of the error map, as plain Python floats.
emin = float(np.amin(error_zfp))
emax = float(np.amax(error_zfp))
emin, emax

In [None]:
# Map of the error with 11 evenly spaced levels from -0.31 to 0.31 and the
# "coolwarm" colours; the trailing semicolon suppresses the cell's repr output.
style_zfp = earthkit.plots.Style(
    levels=np.linspace(-0.31, 0.31, 11),
    colors="coolwarm",
)
earthkit.plots.quickplot(error_zfp, style=style_zfp, methods="pcolormesh");

In [None]:
# ZfpClassic (fixed-accuracy, tolerance 1.0) wrapped in a SafeguardsCodec.
# NOTE(review): the "bias" safeguard is listed three times — confirm whether
# the repetition is intentional or a leftover from experimentation.
safeguards_zfp = [
    {"kind": "abs", "eb_abs": 1.0},
    {"kind": "bias"},
    {"kind": "bias"},
    {"kind": "bias"},
]
safeguarded_zfp = SafeguardsCodec(
    codec=ZfpClassic(mode="fixed-accuracy", tolerance=1.0),
    safeguards=safeguards_zfp,
)
zfp_sg = CodecStack(safeguarded_zfp)

# Encode and decode the field with a fresh byte-size observer attached; the
# computation is triggered inside the `with` block so the observer is active.
bytesize = BytesizeObserver()
with observe(zfp_sg, observers=[bytesize]) as codec:
    roundtrip_zfp_sg = codec.encode_decode_data_array(t2m)
    decoded_zfp_sg = roundtrip_zfp_sg.compute()

In [None]:
# Show the last key under which the observer recorded encode sizes.
observed_codecs = list(bytesize.encode_sizes.keys())
observed_codecs[-1]

In [None]:
# Ratio of the summed `pre` sizes to the summed `post` sizes for the last
# entry recorded by the observer (presumably bytes before / after encoding).
sizes_last = list(bytesize.encode_sizes.values())[-1]
sum(bs.pre for bs in sizes_last) / sum(bs.post for bs in sizes_last)

In [None]:
# Signed error (decoded minus original), averaged over `valid_time`, with
# attributes kept on the result.
with xr.set_options(keep_attrs=True):
    residual_zfp_sg = decoded_zfp_sg - t2m
    error_zfp_sg = residual_zfp_sg.mean(dim="valid_time").compute()

# Mean of the error map relative to its standard deviation.
mean_error = np.mean(error_zfp_sg)
error_spread = np.std(error_zfp_sg)
float(mean_error / error_spread)

In [None]:
# Histogram (100 bins) of all values in the error map.
error_samples = error_zfp_sg.values.ravel()
plt.hist(error_samples, bins=100)
plt.show()

In [None]:
# Smallest and largest value of the error map, as plain Python floats.
emin = float(np.amin(error_zfp_sg))
emax = float(np.amax(error_zfp_sg))
emin, emax

In [None]:
# Map of the error with 11 evenly spaced levels from -0.35 to 0.35 and the
# "coolwarm" colours; the trailing semicolon suppresses the cell's repr output.
style_zfp_sg = earthkit.plots.Style(
    levels=np.linspace(-0.35, 0.35, 11),
    colors="coolwarm",
)
earthkit.plots.quickplot(error_zfp_sg, style=style_zfp_sg, methods="pcolormesh");

In [None]:
class BiasedLinearQuantizer(Codec):
    """Lossy codec that rounds values *down* onto a grid of spacing ``precision``.

    Encoding replaces each value with ``floor(value / precision) * precision``.
    Because the rounding is always downward, the error (decoded minus
    original) is never positive, up to floating-point rounding in the
    division and multiplication — hence "biased". Decoding passes the stored
    values through unchanged.

    Parameters
    ----------
    precision : float
        Spacing of the quantization grid. Must be strictly positive.

    Raises
    ------
    ValueError
        If ``precision`` is zero, negative, or NaN.
    """

    codec_id = "biased-linear-quantizer"

    def __init__(self, precision):
        # A zero step divides by zero in `encode` (yielding inf/NaN), a NaN
        # step poisons every value, and a negative step would silently flip
        # the rounding direction. `not precision > 0` rejects all three.
        if not precision > 0:
            raise ValueError(
                f"precision must be strictly positive, got {precision!r}"
            )
        self.precision = precision

    def encode(self, buf):
        """Return ``buf`` with every element rounded down to a multiple of ``precision``."""
        return np.floor(buf / self.precision) * self.precision

    def decode(self, buf, out=None):
        """Return the encoded values unchanged.

        Delegates to ``numcodecs.compat.ndarray_copy`` so that an ``out``
        buffer, when supplied, is honoured.
        """
        return numcodecs.compat.ndarray_copy(buf, out)

In [None]:
# Single-codec stack: the downward-rounding quantizer with a grid spacing of 1.0.
quantizer = BiasedLinearQuantizer(precision=1.0)
blq = CodecStack(quantizer)

# Encode and decode the field, then materialise the result.
roundtrip_blq = blq.encode_decode_data_array(t2m)
decoded_blq = roundtrip_blq.compute()

In [None]:
# Signed error (decoded minus original), averaged over `valid_time`, with
# attributes kept on the result.
with xr.set_options(keep_attrs=True):
    residual_blq = decoded_blq - t2m
    error_blq = residual_blq.mean(dim="valid_time").compute()

# Mean of the error map relative to its standard deviation.
mean_error = np.mean(error_blq)
error_spread = np.std(error_blq)
float(mean_error / error_spread)

In [None]:
# Histogram (100 bins) of all values in the error map.
error_samples = error_blq.values.ravel()
plt.hist(error_samples, bins=100)
plt.show()

In [None]:
# Smallest and largest value of the error map, as plain Python floats.
emin = float(np.amin(error_blq))
emax = float(np.amax(error_blq))
emin, emax

In [None]:
# Map of the error with 11 evenly spaced levels from -0.83 to 0.83 and the
# "coolwarm" colours; the trailing semicolon suppresses the cell's repr output.
style_blq = earthkit.plots.Style(
    levels=np.linspace(-0.83, 0.83, 11),
    colors="coolwarm",
)
earthkit.plots.quickplot(error_blq, style=style_blq, methods="pcolormesh");

In [None]:
# Inner pipeline: the downward-rounding quantizer followed by BytesCodec.
quantizer_pipeline = CodecStack(
    BiasedLinearQuantizer(precision=1.0),
    BytesCodec(),
)

# Wrap the pipeline in a SafeguardsCodec with an "abs" safeguard
# (eb_abs=1.0) and a "bias" safeguard.
safeguarded_blq = SafeguardsCodec(
    codec=quantizer_pipeline,
    safeguards=[
        {"kind": "abs", "eb_abs": 1.0},
        {"kind": "bias"},
    ],
)
blq_sg = CodecStack(safeguarded_blq)

# Encode and decode the field, then materialise the result.
roundtrip_blq_sg = blq_sg.encode_decode_data_array(t2m)
decoded_blq_sg = roundtrip_blq_sg.compute()

In [None]:
# Signed error (decoded minus original), averaged over `valid_time`, with
# attributes kept on the result.
with xr.set_options(keep_attrs=True):
    residual_blq_sg = decoded_blq_sg - t2m
    error_blq_sg = residual_blq_sg.mean(dim="valid_time").compute()

# Mean of the error map relative to its standard deviation.
mean_error = np.mean(error_blq_sg)
error_spread = np.std(error_blq_sg)
float(mean_error / error_spread)

In [None]:
# Histogram (100 bins) of all values in the error map.
error_samples = error_blq_sg.values.ravel()
plt.hist(error_samples, bins=100)
plt.show()

In [None]:
# Smallest and largest value of the error map, as plain Python floats.
emin = float(np.amin(error_blq_sg))
emax = float(np.amax(error_blq_sg))
emin, emax

In [None]:
# Levels are symmetric about zero so that the midpoint of the diverging
# "coolwarm" colours sits at zero error (presumably how the style maps colours
# onto levels — confirm against earthkit-plots). The +/-0.83 range matches the
# plot of the unsafeguarded quantizer, which keeps the two maps directly
# comparable. The trailing semicolon suppresses the cell's repr output.
earthkit.plots.quickplot(
    error_blq_sg,
    style=earthkit.plots.Style(levels=np.linspace(-0.83, 0.83, 11), colors="coolwarm"),
    methods="pcolormesh",
);