In [1]:
import lmdb
import rasterio
import safetensors
import numpy as np
from pathlib import Path
from safetensors.numpy import deserialize, load_file, load

def read_single_band_raster(path):
    """Open the raster file at `path` with rasterio and return the data of band 1."""
    with rasterio.open(path) as dataset:
        first_band = dataset.read(1)
    return first_band

# Collect the band data of every GeoTIFF below the BigEarthNet directory, keyed by file path
p = Path("tiffs/BigEarthNet")
source_data = {file: read_single_band_raster(file) for file in p.glob("**/*.tif*")}
# Without this guard the comparison loop below would pass vacuously
# whenever the glob does not match any file (e.g. wrong working directory).
assert len(source_data) > 0, f"No GeoTIFF files found under: {p}"

# code to create the directory
# ./result/bin/encoder --bigearthnet-s1-root tiffs/BigEarthNet/S1/ --bigearthnet-s2-root tiffs/BigEarthNet/S2/ artifacts/
# The environment is used as a context manager so that the handle is closed
# once all records have been decoded; other cells open the same path again.
with lmdb.open("../artifacts/ben", readonly=True) as env, env.begin(write=False) as txn:
    decoded_lmdb_data = {k.decode("utf-8"): load(v) for (k, v) in txn.cursor()}

# The encoded data is nested inside of another safetensor dictionary, where the inner keys are derived from the band suffix
decoded_values = [v for outer_v in decoded_lmdb_data.values() for v in outer_v.values()]

# Simply check if the data remains identical, as this is the only _true_ thing I care about from the Python viewpoint
# If the keys/order or anything else is wrong, it isn't part of the integration test but should be handled separately as a unit test!
for (source_key, source_value) in source_data.items():
    assert any(np.array_equal(source_value, decoded_value) for decoded_value in decoded_values), f"Couldn't find data in the LMDB database that matches the data from: {source_key}"

In [None]:
def read_all_raster_bands(path):
    """
    Read every band of the GeoTIFF located at `path`.

    Returns a dictionary that maps the band index (counting from 1,
    converted to a plain string without any formatting) to the
    array data of that band.
    """
    bands = {}
    with rasterio.open(path) as dataset:
        for band_index in range(1, dataset.count + 1):
            bands[str(band_index)] = dataset.read(band_index)
    return bands

# Collect all bands of every spectral image below the HySpecNet directory, keyed by file path
p = Path("tiffs/HySpecNet-11k")
source_file_data = {file: read_all_raster_bands(file) for file in p.glob("**/*SPECTRAL_IMAGE.TIF")}
assert len(source_file_data) > 0

# code to create the directory
# ./result/bin/encoder --hyspecnet-11k <PATH> hyspec_artifacts/
# The environment is used as a context manager so that the handle is closed
# once all records have been decoded.
with lmdb.open("../artifacts/hyspecnet", readonly=True) as env, env.begin(write=False) as txn:
    decoded_lmdb_data = {k.decode("utf-8"): load(v) for (k, v) in txn.cursor()}

# The encoded data is nested inside of another safetensor dictionary, where the inner keys are derived from the band number as a string
decoded_dicts = list(decoded_lmdb_data.values())

# Simply check if the data remains identical, as this is the only _true_ thing I care about from the Python viewpoint
# Here I iterate over all file name and raster data as dictionaries pairs
# and then for each raster data dictionary iterate over all key-value pairs, where the key is the band name
# in the same style as the LMDB file and check if the LMDB file contained a matching array from
# a safetensors dictionary accessed via the shared band name as key.
for (source_file, source_data_dict) in source_file_data.items():
    for (source_key, source_band_data) in source_data_dict.items():
        # The membership test keeps a record that lacks this band from raising a KeyError,
        # so a mismatch is always reported through the assertion message below.
        assert any(
            source_key in decoded_dict and np.array_equal(source_band_data, decoded_dict[source_key])
            for decoded_dict in decoded_dicts
        ), f"Couldn't find data in the LMDB database that matches the data from: {source_file}:{source_key}"

## Optimizing access patterns

Strictly speaking, we are not taking advantage of the lazy-loading API for our BigEarthNet patches: we use `load`, which internally calls `deserialize` on the byte string,
iterates over all tensors, and adds each of them to the dictionary. Quick testing has shown, however, that this incurs no major performance penalty, especially since we load most of the data anyway.
Only for HySpecNet can we take advantage of this internal design and apply `np.stack` directly to the deserialized tensors.

In [13]:
# NOTE(review): this cell raised the TypeError recorded in its output ("byte indices must be
# integers or slices, not str"), i.e. `safetensor_dict` was a bytes object when the cell ran --
# presumably the raw LMDB value before it was passed through `load`; confirm or remove the cell.
safetensor_dict["B01"]

TypeError: byte indices must be integers or slices, not str

In [None]:
%%timeit
# LMDB key of the patch whose bands are accessed by the timing cells.
# NOTE(review): names assigned inside a `%%timeit` cell are presumably not kept in the
# kernel namespace -- confirm that `key` is also defined outside of a timed cell.
key = 'ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X03110438'

# Look up each of the bands 1 through 224 individually; the results are discarded on purpose
for i in range(1, 225):
    decoded_lmdb_data[key][str(i)]

In [None]:
%%timeit
# Look up only the bands 1 through 21 individually; the results are discarded on purpose.
# NOTE(review): relies on `key` and `decoded_lmdb_data` already existing in the kernel namespace.
for i in range(1, 22):
    decoded_lmdb_data[key][str(i)]

In [None]:
%%timeit
# Stack the bands 1 through 224 of the patch along a new leading axis.
# NOTE(review): relies on `key` and `decoded_lmdb_data` already existing in the kernel namespace.
np.stack([decoded_lmdb_data[key][str(i)] for i in range(1, 225)], axis=0)

In [None]:
%%timeit
# Stack only the bands 1 through 21 of the patch along a new leading axis.
# NOTE(review): relies on `key` and `decoded_lmdb_data` already existing in the kernel namespace.
np.stack([decoded_lmdb_data[key][str(i)] for i in range(1, 22)], axis=0)

In [None]:
%%timeit
# NOTE(review): this only binds the function `np.zeros` itself to `a` without calling it,
# so no array is allocated -- looks like an unfinished cell; confirm or remove.
a = np.zeros

In [None]:
# Band indices 0..4; only those that also fall into the iterated range 1..21 are selected below
SUPPORTED_BANDS = list(range(5))
selected_bands = [
    decoded_lmdb_data[key][str(band)]
    for band in range(1, 22)
    if band in SUPPORTED_BANDS
]
# Display the shape of the stacked selection as the cell output
np.stack(selected_bands, axis=0).shape

In [None]:
%%timeit
# Preallocate the output array and copy the bands into it one by one.
# assumes every band of the patch is a 128x128 array -- TODO confirm
a = np.zeros((224, 128, 128))
# The band names start at 1 and run up to and including 224, so the upper bound
# has to be 225; with 224 the last slice of `a` would silently stay all-zero.
for i in range(1, 225):
    a[i - 1] = decoded_lmdb_data[key][str(i)]

In [None]:
%%timeit
# Bounds that the stacked band values are clipped to
minimum_value = 0
maximum_value = 10000

# Stack the bands 1 through 224 along a new leading axis and limit every value to
# the closed interval [minimum_value, maximum_value].
# NOTE(review): relies on `key` and `decoded_lmdb_data` already existing in the kernel namespace.
clipped = np.stack([decoded_lmdb_data[key][str(i)] for i in range(1, 225)], axis=0).clip(min=minimum_value, max=maximum_value)

In [None]:
%%timeit
# Min-max scaling: shift by the lower bound and divide by the width of the interval.
# NOTE(review): `clipped`, `minimum_value` and `maximum_value` are not defined in this cell;
# they are presumably not exported by a `%%timeit` cell either -- confirm where they come from.
out_data = (clipped - minimum_value) / (maximum_value - minimum_value)

In [None]:
%%timeit
# Convert the scaled data to single precision.
# NOTE(review): `out_data` is not defined in this cell and has to exist in the kernel namespace.
out_dataf = out_data.astype(np.float32)

In [None]:
%%timeit
# Scaling and float32 conversion in a single expression: calling `astype` directly on the
# expression is just as fast as first storing the scaled result in an intermediate variable.
# NOTE(review): `clipped`, `minimum_value` and `maximum_value` have to exist in the kernel namespace.
out_data = ((clipped - minimum_value) / (maximum_value - minimum_value)).astype(np.float32)

In [None]:
# This takes around 10 ms per patch

In [None]:
# 0.72 batches / sec on Martin's setup for the entire training

In [32]:
## Example

import lmdb
import safetensors
from safetensors.numpy import _getdtype
from pathlib import Path

# Location of the encoded dataset, i.e. the output of rico-hdl
encoded_path = Path("../artifacts/ben/")

# Open the environment a single time and keep the handle around
env = lmdb.open(str(encoded_path), readonly=True)

# The patch name is encoded because it is passed to `txn.get` as an LMDB key
patch_key = 'S2A_MSIL2A_20180526T100031_N9999_R122_T34WFU_14_23'.encode()
with env.begin() as txn:
    # Despite its name this holds the value exactly as returned by `txn.get`, not a decoded dictionary
    tensor_dict = txn.get(patch_key)

In [33]:
# Interactive introspection only: the trailing `??` is IPython syntax that shows the
# signature/docstring of `safetensors.deserialize` (see the output below); it is not plain Python.
# load(tensor_dict)
safetensors.deserialize??

Signature: safetensors.deserialize(bytes)
Docstring:
Opens a safetensors lazily and returns tensors as asked

Args:
    data (`bytes`):
        The byte content of a file

Returns:
    (`List[str, Dict[str, Dict[str, any]]]`):
        The deserialized content is like:
            [("tensor_name", {"shape": [2, 3], "dtype": "F32", "data": b"\0\0.." }), (...)]
Type:      builtin_function_or_method

In [39]:
%%timeit
# Measured timings when filtering for different sets of bands:
# 31 us only for RGB
# 39 us for everything
# 28 us for B01 -> almost no difference whatsoever
# Decode only the selected tensors by hand instead of calling `load` on the whole byte string.
# NOTE(review): `tensor_dict` has to hold the raw safetensors bytes and comes from another cell.
result = {}
for k, v in safetensors.deserialize(tensor_dict):
    # Skip every tensor whose name is not in the list of wanted bands
    if k not in ["B01"]:
        continue
    # `_getdtype` is a private helper imported from `safetensors.numpy`; it is given the dtype string of the entry
    dtype = _getdtype(v["dtype"])
    # Interpret the raw buffer with that dtype and bring it into the recorded shape
    arr = np.frombuffer(v["data"], dtype=dtype).reshape(v["shape"])
    result[k] = arr

28.9 µs ± 903 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [47]:
# Location of the encoded BigEarthNet LMDB environment
encoded_path = "../artifacts/ben"
env = lmdb.open(str(encoded_path), readonly=True)

patch_name = "S2A_MSIL2A_20180526T100031_N9999_R122_T34WFU_14_23"
with env.begin() as txn:
    # The patch name has to be encoded before it can be used as an LMDB key
    safetensor_dict = load(txn.get(patch_name.encode()))

# Names of the bands that make up the RGB tensor, in stacking order
rgb_bands = ["B04", "B03", "B02"]
rgb_layers = [safetensor_dict[band] for band in rgb_bands]
rgb_tensor = np.stack(rgb_layers)
assert rgb_tensor.shape == (3, 120, 120)


(3, 120, 120)

In [2]:
import lmdb
import numpy as np
# import desired deep learning library:
# numpy, torch, tensorflow, paddle, flax, mlx
from safetensors.numpy import load
from pathlib import Path

# Open the environment a single time up front
# instead of reopening it for every item access.
encoded_path = "../artifacts/hyspecnet"
env = lmdb.open(str(encoded_path), readonly=True)

patch_name = "ENMAP01-____L2A-DT0000004950_20221103T162438Z_001_V010110_20221118T145147Z-Y01460273_X04390566"
with env.begin() as txn:
    # The patch name has to be encoded before it can be used as an LMDB key
    safetensor_dict = load(txn.get(patch_name.encode()))

hyspecnet_bands = range(1, 225)
# recommendation from HySpecNet-11k paper: leave out the bands 126-140 and 160-166
skip_bands = [*range(126, 141), *range(160, 167)]
kept_bands = [safetensor_dict[f"B{k}"] for k in hyspecnet_bands if k not in skip_bands]
tensor = np.stack(kept_bands)
assert tensor.shape == (202, 128, 128)

In [4]:
# Walk over every record of the open environment and decode it:
# the keys become UTF-8 strings, the values are passed through `load`.
with env.begin() as txn:
    records = txn.cursor()
    decoded_lmdb_data = {
        raw_key.decode("utf-8"): load(raw_value)
        for raw_key, raw_value in records
    }