In [1]:
import os

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from pyproj import Transformer, CRS
from skimage.util import view_as_windows
from tqdm import tqdm

In [2]:
# Root directory holding one Zarr store per date folder.
base_path = "/work/FAC/FGSE/IDYST/tbeucler/downscaling/raw_data/OPERA"
date_folders = sorted(os.listdir(base_path))

# Open every store lazily (chunks={} triggers lazy loading), in sorted
# folder order, then stitch the pieces together along the time axis.
datasets = [
    xr.open_zarr(os.path.join(base_path, d), chunks={})
    for d in tqdm(date_folders)
]

ds = xr.concat(datasets, dim="time")

100%|██████████| 435/435 [02:24<00:00,  3.01it/s]


In [3]:
# Bare expression: renders the repr of the concatenated dataset as the cell output.
ds

Unnamed: 0,Array,Chunk
Bytes,1.26 TiB,31.89 MiB
Shape,"(41575, 2200, 1900)","(1, 2200, 1900)"
Dask graph,41575 chunks in 871 graph layers,41575 chunks in 871 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 1.26 TiB 31.89 MiB Shape (41575, 2200, 1900) (1, 2200, 1900) Dask graph 41575 chunks in 871 graph layers Data type float64 numpy.ndarray",1900  2200  41575,

Unnamed: 0,Array,Chunk
Bytes,1.26 TiB,31.89 MiB
Shape,"(41575, 2200, 1900)","(1, 2200, 1900)"
Dask graph,41575 chunks in 871 graph layers,41575 chunks in 871 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [4]:
def plot_prec(ds, i, title):
    """Plot the natural log of the precipitation field at time index ``i`` on a map.

    The grid coordinates ``ds['x']`` / ``ds['y']`` are reprojected from the
    custom LAEA projection to Web Mercator (EPSG:3857) and drawn with cartopy
    together with coastlines and country borders.

    Parameters
    ----------
    ds : mapping-like dataset
        Must provide ``'x'``, ``'y'`` and ``'TOT_PREC'``; ``ds['TOT_PREC'][i]``
        is used as a 2-D (y, x) field.
    i : int
        Index along the first axis of ``ds['TOT_PREC']``.
    title : str
        Title of the figure.

    Notes
    -----
    A small offset (``eps``) is added before taking the log so that dry
    pixels (0) map to a finite value instead of ``-inf``.
    """
    # Custom LAEA projection string (used in OPERA, close to EPSG:3035)
    laea_proj = CRS.from_proj4("+proj=laea +lat_0=55 +lon_0=10 +x_0=1950000 +y_0=-2100000 +units=m +datum=WGS84 +no_defs")

    # Target CRS: Web Mercator (EPSG:3857). Its coordinates are metres, not lon/lat.
    target = CRS.from_epsg(3857)

    # cartopy counterpart of EPSG:3857. The default ccrs.Mercator() is built on
    # the WGS84 ellipsoid, whereas EPSG:3857 is spherical, so declaring the
    # transformed coordinates as ccrs.Mercator() would shift the field
    # relative to the coastlines.
    web_mercator = ccrs.Mercator.GOOGLE

    # Set up transformer
    transformer = Transformer.from_crs(laea_proj, target, always_xy=True)

    # Grid arrays
    x = ds['x'].to_numpy()
    y = ds['y'].to_numpy()[::-1]  # flip y to match north-up orientation
    precip = ds['TOT_PREC'][i].to_numpy()[::-1, :]  # rows flipped consistently with y

    X, Y = np.meshgrid(x, y)

    # Project the LAEA grid to Web Mercator x/y (metres)
    x_merc, y_merc = transformer.transform(X, Y)

    # Offset that keeps log() finite where precipitation is exactly zero
    eps = 1e-4
    log_precip = np.log(precip + eps)

    _, ax = plt.subplots(figsize=(15, 15), subplot_kw={'projection': web_mercator})
    # Only the log field is drawn; a raw-precip mesh would just be overdrawn by it.
    im = ax.pcolormesh(x_merc, y_merc, log_precip, transform=web_mercator, cmap='viridis', shading='auto')

    ax.coastlines(resolution='10m')
    ax.add_feature(cfeature.BORDERS)
    ax.grid(True)
    ax.set_title(f"{title}")

    cbar = plt.colorbar(im, ax=ax, shrink=.6, pad=.02)
    cbar.set_label("Log Precipitation [kg/m²]")

    plt.tight_layout()
    plt.show()

In [5]:
# plot_prec(ds, i=8800, title="Total precipitation (Ciarán storm, Nov 2023)")

# 1. Extract data and generate dataset

In [6]:
def extract_valid_patches_3d(data, patch_size=128, stride=128, min_valid_ratio=0.5):
    """Collect square patches with enough non-NaN pixels from every frame.

    Parameters
    ----------
    data : numpy.ndarray
        3-D array; the first axis is iterated frame by frame and each frame
        is tiled over its last two axes.
    patch_size : int
        Side length of the square patches.
    stride : int
        Step between consecutive patch origins along both frame axes.
    min_valid_ratio : float
        Minimum fraction of non-NaN pixels a patch needs in order to be kept.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_patches, patch_size, patch_size)`` holding the
        kept patches, ordered by frame and then row-major within each frame.
        Has zero patches (rather than raising) when ``data`` has no frames.
    """
    time_dim, _, _ = data.shape
    min_valid = int(min_valid_ratio * patch_size * patch_size)

    all_patches = []
    for t in tqdm(range(time_dim), desc="Extracting patches"):
        # 4-D strided view (rows, cols, patch_size, patch_size) onto the frame.
        windows = view_as_windows(data[t], (patch_size, patch_size), step=stride)

        # Count non-NaN pixels per window directly on the view.
        non_nan_counts = np.count_nonzero(~np.isnan(windows), axis=(2, 3))

        # Boolean indexing over the two grid axes copies only the kept patches
        # and yields them in row-major order, i.e. the same order as flattening
        # the grid first and masking afterwards.
        all_patches.append(windows[non_nan_counts >= min_valid])

    if not all_patches:
        # np.concatenate refuses an empty list; return an empty stack instead.
        return np.empty((0, patch_size, patch_size), dtype=data.dtype)

    return np.concatenate(all_patches, axis=0)

In [7]:
# data = ds["TOT_PREC"].to_numpy()

# patches = extract_valid_patches_3d(data[:1])

In [None]:
# Index the lazy DataArray *before* converting to numpy so only the first
# frame is read; calling .to_numpy() first would materialise the whole array.
data_array = ds["TOT_PREC"][0].to_numpy()

# Add a leading length-1 axis so the array rank matches the 3-D chunk spec
# and darr[t] below is a full 2-D frame.
darr = da.from_array(data_array[np.newaxis], chunks=(1, 512, 512))  # or from Zarr/NetCDF

def extract_valid_patches_dask(frame, patch_size=128, stride=128, min_valid=1):
    """Return the patches of one 2-D frame that contain enough non-NaN pixels.

    Parameters
    ----------
    frame : array-like
        2-D frame; dask arrays are accepted and are computed here.
    patch_size : int
        Side length of the square patches.
    stride : int
        Step between consecutive patch origins along both axes.
    min_valid : float
        Minimum *fraction* of non-NaN pixels required to keep a patch; the
        default of 1 keeps only patches without any NaN.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_patches, patch_size, patch_size)``.
    """
    # view_as_windows only accepts numpy ndarrays, so materialise the frame.
    frame = np.asarray(frame)
    patches = view_as_windows(frame, (patch_size, patch_size), step=stride)
    patches = patches.reshape(-1, patch_size, patch_size)
    valid_counts = np.count_nonzero(~np.isnan(patches), axis=(1, 2))
    valid_mask = valid_counts >= min_valid * patch_size * patch_size
    return patches[valid_mask]

results = [extract_valid_patches_dask(darr[t]) for t in range(darr.shape[0])]