In [23]:
import xarray as xr
import random
import glob
import os
import h5py
import numpy as np
import itertools
import pandas as pd

H = 128
W = 128
START_YEAR = 2017
END_YEAR = 2023
T = (END_YEAR - START_YEAR + 1) * 73

save_dir = "/data_1/scratch_1/dbrueggemann"
cubes = glob.glob("/data_1/scratch/ntinner/cubes/*_raw.nc")

def cartesian_product(a, b):
    return np.array(list(itertools.product(a, b)))

def get_doy(dates):
    return np.array(
            [
                pd.to_datetime(d).to_pydatetime().timetuple().tm_yday
                for d in dates
            ],
            dtype=int,
    )

def create_h5(file, key, data, shape, dtype, chunk_size=None):
    file.create_dataset(
        key,
        data=data,
        maxshape=(None, *shape),
        dtype=dtype,
        compression="lzf",
        chunks=chunk_size,
    )

def append_h5(file, key, data):
    file[key].resize((file[key].shape[0] + data.shape[0]), axis=0)
    file[key][-data.shape[0] :] = data

def check_missing_timestamps(cube, max_conseq_dates=2):
    """Check for missing timestamps in cube.

    Args:
        cube (xr.Dataset): Cube to check for missing timestamps.
        max_conseq_dates (int): Maximum number of consecutive missing timestamps to allow.

    Returns:
        missing_dates (list): List of missing timestamps
    """
    timestamps = cube.time.values
    missing_dates = []

    # beginning of 2017
    current_timestamp = timestamps[0]
    while (current_timestamp - np.timedelta64(5, "D")).astype("datetime64[Y]").astype(
        int
    ) + 1970 >= START_YEAR:
        current_timestamp -= np.timedelta64(5, "D")
        missing_dates.append(current_timestamp)

    # end of 2023
    current_timestamp = timestamps[-1]
    while (current_timestamp + np.timedelta64(5, "D")).astype("datetime64[Y]").astype(
        int
    ) + 1970 <= END_YEAR:
        current_timestamp += np.timedelta64(5, "D")
        missing_dates.append(current_timestamp)

    current_timestamp = timestamps[0]
    last_timestamp = timestamps[-1]
    nr_conseq_dates_max = 0
    while current_timestamp < last_timestamp:
        # Check for presence of next timestamp at 5 days interval
        expected_date = current_timestamp + np.timedelta64(5, "D")
        if expected_date not in timestamps:
            missing_dates.append(expected_date)
            # Record number of consecutive missing timestamps
            if len(missing_dates) > 1 and (
                missing_dates[-1] - missing_dates[-2]
            ) == np.timedelta64(5, "D"):
                nr_conseq_dates_max += 1
            else:
                nr_conseq_dates_max = 1
        current_timestamp = expected_date

    if nr_conseq_dates_max > max_conseq_dates:
        print(f"Warning: Too many consecutive missing dates ({nr_conseq_dates_max})")

    return missing_dates

with h5py.File(os.path.join(save_dir, "qforest_dataset.h5"), "w") as h5_file:
    i_max = 200
    i = 0
    for c in cubes:

        try:
            minicube = xr.open_dataset(c, engine="h5netcdf")
        except OSError:
            continue

        missing_dates = check_missing_timestamps(minicube)
        if missing_dates:
            minicube = minicube.reindex(
                time=np.sort(np.concatenate([minicube.time.values, missing_dates]))
            )

        try:
            s2_cube = minicube.s2_ndvi.where((minicube.s2_mask == 0) & minicube.s2_SCL.isin([1, 2, 4, 5, 6, 7])).values
        except AttributeError:
            continue
        s2_mask = (minicube.FOREST_MASK.values > 0.8)
        pixels = s2_cube[:, s2_mask].transpose(1, 0)

        longitude = np.array(minicube.lon.values, dtype=np.float32)
        latitude = np.array(minicube.lat.values, dtype=np.float32)
        lon_lat = cartesian_product(longitude, latitude).reshape(len(longitude), len(latitude), 2)
        lon_lat = lon_lat[s2_mask, :]

        dem = np.expand_dims(minicube.DEM.values[s2_mask], axis=1)

        N = pixels.shape[0]

        doy = np.expand_dims(get_doy(minicube.time.values), axis=0).repeat(N, axis=0)

        if not "ndvi" in h5_file.keys():
            create_h5(
                h5_file,
                "ndvi",
                pixels,
                (T,),
                "float32",
            )
            create_h5(
                h5_file,
                "lon_lat",
                lon_lat,
                (2,),
                "float32",
            )
            create_h5(
                h5_file,
                "dem",
                dem,
                (1,),
                "float32",
            )
            create_h5(
                h5_file,
                "doy",
                doy,
                (T,),
                "uint16",
            )
        else:
            append_h5(h5_file, "ndvi", pixels)
            append_h5(h5_file, "lon_lat", lon_lat)
            append_h5(h5_file, "dem", dem)
            append_h5(h5_file, "doy", doy)

        i += 1
        if i > i_max:
            break

AttributeError: 'Dataset' object has no attribute 's2_ndvi'

In [24]:
i

133

In [19]:
T

511

In [13]:
minicube