In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
!pip install gcsfs



In [3]:
# SIMPLIFIED EFFICIENT ZARR SAVING
# Focus: Proper parallelism, simple and reliable
# Removed complex auto-detection that might cause issues

import os
import time
import xarray as xr
from numcodecs import Blosc
import gcsfs

# Re-use a global filesystem client when possible
gcs = gcsfs.GCSFileSystem(project=os.getenv("GOOGLE_CLOUD_PROJECT"), token='/usr/src/app/user_id.json')
# Initialize GCS filesystem
fs = gcs # same same

def save_dataset_efficient_zarr(
    ds,
    zarr_path,
    chunk_sizes=None,
    compression='lz4',
    compression_level=1,
    overwrite=True,
    consolidated=True,
    storage='auto',
    gcs_project=None,
):
    """
    Simplified zarr saving ‚Äì focuses on reliable parallelism.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to save (lazy dask arrays or in-memory).
    zarr_path : str
        Destination path or GCS URI (e.g. gs://bucket/path.zarr).
    chunk_sizes : dict, optional
        Chunk sizes per dimension (e.g. {'time': 20, 'x': 256, 'y': 256}).
    compression : {'lz4','blosc','zstd',None} or dict
        Built-in compressor choice or explicit encoding dict.
    compression_level : int
        Compression level (1 fastest, 9 best compression).
    overwrite : bool
        Overwrite existing zarr store.
    consolidated : bool
        Create consolidated metadata (recommended).
    storage : {'auto','local','gcs'}
        Force storage backend or infer from path when 'auto'.
    gcs_project : str, optional
        Explicit GCP project for a fresh filesystem client.

    Returns
    -------
    str
        The zarr_path that was written.
    """
    def _format_size(num_bytes: int) -> str:
        size_mb = num_bytes / (1024 * 1024)
        size_gb = size_mb / 1024
        return f"{size_gb:.2f} GB" if size_gb >= 1 else f"{size_mb:.2f} MB"

    start_time = time.time()

    storage = storage.lower()
    if storage == 'auto':
        storage = 'gcs' if zarr_path.startswith('gs://') else 'local'
    if storage not in {'local', 'gcs'}:
        raise ValueError("storage must be one of {'auto', 'local', 'gcs'}")

    fs = None
    if storage == 'gcs':
        fs = gcs if gcs_project is None else gcsfs.GCSFileSystem(project=gcs_project)
    else:
        zarr_dir = os.path.dirname(zarr_path) if os.path.dirname(zarr_path) else '.'
        if zarr_dir and not os.path.exists(zarr_dir):
            os.makedirs(zarr_dir, exist_ok=True)

    # Handle overwrite
    if storage == 'gcs':
        if fs.exists(zarr_path):
            if not overwrite:
                raise FileExistsError(
                    f"Zarr store already exists on GCS: {zarr_path}\n"
                    "Set overwrite=True to replace it."
                )
            print(f"üóëÔ∏è  Removing existing GCS zarr store: {zarr_path}")
            fs.rm(zarr_path, recursive=True)
    else:
        if os.path.exists(zarr_path):
            if not overwrite:
                raise FileExistsError(
                    f"Zarr store already exists: {zarr_path}\n"
                    "Set overwrite=True to replace it."
                )
            import shutil
            print(f"üóëÔ∏è  Removing existing zarr store: {zarr_path}")
            shutil.rmtree(zarr_path)

    # Default chunk sizes
    if chunk_sizes is None:
        chunk_sizes = {}
        dims = ds.dims
        if 'time' in dims:
            chunk_sizes['time'] = min(20, dims['time'])
        if 'x' in dims:
            chunk_sizes['x'] = min(256, dims['x'])
        if 'y' in dims:
            chunk_sizes['y'] = min(256, dims['y'])
        for dim_name, dim_len in dims.items():
            chunk_sizes.setdefault(dim_name, min(100, dim_len))

    print(f"üì¶ Saving to zarr: {zarr_path}")
    print(f"   Dimensions: {dict(ds.dims)}")
    print(f"   Chunks: {chunk_sizes}")
    print(f"   Compression: {compression} (level {compression_level})")
    print(f"   Storage: {storage}")

    # Prepare compression
    if compression == 'lz4':
        compressor = Blosc(cname='lz4', clevel=compression_level, shuffle=Blosc.SHUFFLE, blocksize=0)
        encoding = {var: {'compressor': compressor} for var in ds.data_vars}
    elif compression == 'blosc':
        compressor = Blosc(cname='blosclz', clevel=compression_level, shuffle=Blosc.SHUFFLE, blocksize=0)
        encoding = {var: {'compressor': compressor} for var in ds.data_vars}
    elif compression == 'zstd':
        compressor = Blosc(cname='zstd', clevel=compression_level, shuffle=Blosc.SHUFFLE, blocksize=0)
        encoding = {var: {'compressor': compressor} for var in ds.data_vars}
    elif compression is None:
        encoding = {}
    else:
        encoding = compression  # assume dict supplied

    # Chunk and save
    ds_chunked = ds.chunk(chunk_sizes)
    print("üíæ Writing to zarr (with automatic parallelism)...")

    store = fs.get_mapper(zarr_path) if storage == 'gcs' else zarr_path
    try:
        from dask.diagnostics import ProgressBar
        with ProgressBar():
            ds_chunked.to_zarr(
                store,
                mode='w',
                encoding=encoding,
                consolidated=consolidated,
                compute=True,
                zarr_version=2,  # ADD THIS LINE
            )
    except ImportError:
        ds_chunked.to_zarr(
            store,
            mode='w',
            encoding=encoding,
            consolidated=consolidated,
            compute=True,
            zarr_version=2,  # ADD THIS LINE
        )

    elapsed = time.time() - start_time

    # Size reporting
    total_size = None
    if storage == 'gcs':
        try:
            size_info = fs.du(zarr_path)
            if isinstance(size_info, dict):
                total_size = sum(size_info.values())
            elif isinstance(size_info, (int, float)):
                total_size = size_info
        except Exception as exc:
            print(f"‚ö†Ô∏è  Could not compute GCS store size: {exc}")
    else:
        if os.path.exists(zarr_path):
            total_size = 0
            for dirpath, _, filenames in os.walk(zarr_path):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    total_size += os.path.getsize(fp)

    if total_size is not None:
        size_str = _format_size(total_size)
        write_speed = total_size / elapsed / (1024 * 1024)
        print("‚úÖ Dataset saved successfully!")
        print(f"   Store size: {size_str}")
        print(f"   Time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
        print(f"   Write speed: {write_speed:.1f} MB/s")
        print(f"   Path: {zarr_path}")
    else:
        print("‚úÖ Dataset saved successfully! (size unavailable)")
        print(f"   Time: {elapsed:.1f} seconds ({elapsed/60:.1f} minutes)")
        print(f"   Path: {zarr_path}")

    return zarr_path


def load_dataset_zarr(zarr_path, consolidated=True, storage='auto', gcs_project=None):
    """
    Load a dataset from a zarr store located locally or on GCS.
    """
    storage = storage.lower()
    if storage == 'auto':
        storage = 'gcs' if zarr_path.startswith('gs://') else 'local'
    if storage not in {'local', 'gcs'}:
        raise ValueError("storage must be one of {'auto', 'local', 'gcs'}")

    if storage == 'gcs':
        fs = gcs if gcs_project is None else gcsfs.GCSFileSystem(project=gcs_project)
        if not fs.exists(zarr_path):
            raise FileNotFoundError(f"Zarr store not found on GCS: {zarr_path}")
        mapper = fs.get_mapper(zarr_path)
        print(f"üìÇ Loading dataset from GCS zarr: {zarr_path}")
        ds = xr.open_zarr(mapper, consolidated=consolidated)
    else:
        if not os.path.exists(zarr_path):
            raise FileNotFoundError(f"Zarr store not found: {zarr_path}")
        print(f"üìÇ Loading dataset from zarr: {zarr_path}")
        ds = xr.open_zarr(zarr_path, consolidated=consolidated)

    print(f"‚úÖ Dataset loaded: {dict(ds.dims)}")
    return ds


print("‚úÖ Simplified zarr saving functions loaded!")
print("\nKey simplifications:")
print("  - No complex auto-detection")
print("  - Always uses compute=True (let dask handle parallelism)")
print("  - Simple, reliable, focuses on parallelism")
print("  - Works with both lazy and in-memory arrays")


‚úÖ Simplified zarr saving functions loaded!

Key simplifications:
  - No complex auto-detection
  - Always uses compute=True (let dask handle parallelism)
  - Simple, reliable, focuses on parallelism
  - Works with both lazy and in-memory arrays


In [9]:
zarr_path = os.getenv('GCS_ZARR_DIR') + '/ds_resampled.zarr'
# zarr_path = 'data/ds_resampled.zarr'
# storage = 'local'
storage = 'gcs'

ds_resampled = load_dataset_zarr(zarr_path, storage=storage)
ds_resampled


üìÇ Loading dataset from GCS zarr: gs://remote_sensing_saas/01-korindo/timeseries_zarr/ds_resampled.zarr


‚úÖ Dataset loaded: {'time': 81, 'x': 4489, 'y': 3213}


  print(f"‚úÖ Dataset loaded: {dict(ds.dims)}")


Unnamed: 0,Array,Chunk
Bytes,5.38 kiB,2.66 kiB
Shape,"(81,)","(40,)"
Dask graph,3 chunks in 2 graph layers,3 chunks in 2 graph layers
Data type,,
"Array Chunk Bytes 5.38 kiB 2.66 kiB Shape (81,) (40,) Dask graph 3 chunks in 2 graph layers Data type",81  1,

Unnamed: 0,Array,Chunk
Bytes,5.38 kiB,2.66 kiB
Shape,"(81,)","(40,)"
Dask graph,3 chunks in 2 graph layers,3 chunks in 2 graph layers
Data type,,

Unnamed: 0,Array,Chunk
Bytes,4.35 GiB,160.00 MiB
Shape,"(81, 4489, 3213)","(40, 1024, 1024)"
Dask graph,60 chunks in 2 graph layers,60 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.35 GiB 160.00 MiB Shape (81, 4489, 3213) (40, 1024, 1024) Dask graph 60 chunks in 2 graph layers Data type float32 numpy.ndarray",3213  4489  81,

Unnamed: 0,Array,Chunk
Bytes,4.35 GiB,160.00 MiB
Shape,"(81, 4489, 3213)","(40, 1024, 1024)"
Dask graph,60 chunks in 2 graph layers,60 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,4.35 GiB,160.00 MiB
Shape,"(81, 4489, 3213)","(40, 1024, 1024)"
Dask graph,60 chunks in 2 graph layers,60 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 4.35 GiB 160.00 MiB Shape (81, 4489, 3213) (40, 1024, 1024) Dask graph 60 chunks in 2 graph layers Data type float32 numpy.ndarray",3213  4489  81,

Unnamed: 0,Array,Chunk
Bytes,4.35 GiB,160.00 MiB
Shape,"(81, 4489, 3213)","(40, 1024, 1024)"
Dask graph,60 chunks in 2 graph layers,60 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
print(ds_resampled)

<xarray.Dataset> Size: 9GB
Dimensions:   (time: 81, x: 4489, y: 3213)
Coordinates:
    image_id  (time) <U17 6kB dask.array<chunksize=(40,), meta=np.ndarray>
  * time      (time) datetime64[ns] 648B 2018-02-15 2018-05-15 ... 2025-08-15
  * x         (x) float64 36kB 5.786e+05 5.786e+05 ... 6.235e+05 6.235e+05
  * y         (y) float64 26kB 9.949e+06 9.949e+06 ... 9.982e+06 9.982e+06
Data variables:
    EVI       (time, x, y) float32 5GB dask.array<chunksize=(40, 1024, 1024), meta=np.ndarray>
    NDVI      (time, x, y) float32 5GB dask.array<chunksize=(40, 1024, 1024), meta=np.ndarray>
Attributes:
    crs:      EPSG:32749


In [11]:
crs_ds = ds_resampled.attrs.get('crs')

In [6]:
# crs_ds = 'EPSG:32749'

In [29]:
#### DS TRAIN CHECK RESAMPLING
import os
import geopandas as gpd
import pandas as pd
import gcsfs

path_training = '00_input/training_shp/'

layers = [shp for shp in os.listdir(path_training) if shp.endswith('.shp') and shp.startswith('sample')]
gcs_path = 'gs://remote_sensing_saas/01-korindo/sample_tsfresh/20251112_training_gdf_col_filtered.parquet'

use_parquet_training = True

if use_parquet_training != True:
    gdf_list = []

    for lyr in layers:
        print(lyr)
        gdf = gpd.read_file(os.path.join(path_training,lyr))
        gdf['layer'] = lyr.replace('.shp', '')
        print(gdf.crs)
        print('check size, if too big, you need to recheck: ',gdf.shape)
        gdf_utm = gdf.to_crs(crs_ds)       
        print('transforming to crs: ',gdf_utm.crs)

        # data dissolve to clean, if the data is too big, use ArcGIS or QGIS instead
        list_columns_time = [i for i in gdf_utm.columns if i.startswith('t_') and not i.endswith('D')]
        gdf_utm = gdf_utm.dissolve(by=['layer']+list_columns_time)
        gdf_utm = gdf_utm.reset_index()
        gdf_list.append(gdf_utm)

    training_gdf = gpd.GeoDataFrame(pd.concat(gdf_list, ignore_index=True))
    # training_gdf = gpd.read_file('00_input/training_shp/sample_1.shp')
    # training_gdf.head()
    # training_gdf.head()

    # training_gdf.crs

    # training_gdf.geometry.head()
    # training_gdf.columns

    list_columns_time = [i for i in training_gdf.columns if i.startswith('t_') and not i.endswith('D')]
    list_columns_time = list(sorted(list_columns_time))
    # list_columns_time

    columns_filter = ['layer'] + list_columns_time + ['geometry']

    training_gdf_col_filtered = training_gdf[columns_filter]
    # training_gdf_col_filtered = training_gdf_col_filtered.dissolve(by=['layer']+list_columns_time)
    # training_gdf_col_filtered = training_gdf_col_filtered.reset_index()
    # training_gdf_col_filtered.head()
    # training_gdf_col_filtered.plot()

    # Save as GeoParquet (BEST option)
    gcs_path = gcs_path    
    training_gdf_col_filtered.to_parquet(gcs_path, filesystem=fs, compression='snappy')

else:
    training_gdf_col_filtered = gpd.read_parquet(gcs_path, filesystem=fs)
print('final shape: ',training_gdf_col_filtered.shape)

final shape:  (78, 71)


In [30]:
# for i in training_gdf.columns:
#     print(i)
# # training_gdf_col_filtered.head()

In [31]:
training_gdf_col_filtered.head()

Unnamed: 0,layer,t_201603,t_201606,t_201609,t_201612,t_201703,t_201706,t_201709,t_201712,t_201803,...,t_202501,t_202502,t_202503,t_202504,t_202505,t_202506,t_202507,t_202508,t_202509,geometry
0,sample_2,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,"MULTIPOLYGON (((592439.970 9950624.446, 592472..."
1,sample_2,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,"POLYGON ((591963.540 9951127.666, 591985.528 9..."
2,sample_2,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,"MULTIPOLYGON (((592062.841 9950226.187, 592071..."
3,sample_2,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,"MULTIPOLYGON (((590310.884 9951179.496, 590319..."
4,sample_2,,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,,"POLYGON ((590565.026 9950696.784, 590564.841 9..."


In [32]:
# ### TAKES TIME FOR THIS GEE since we will upload data to GEE
# # import os
# import ee
# # import geemap.foliumap as geemap

# service_account = os.getenv('SERVICE_ACCOUNT')
# key_path = '/usr/src/app/user_id.json'

# credentials = ee.ServiceAccountCredentials(service_account, key_path)
# ee.Initialize(credentials)

# # gdf_wgs84 = training_gdf_col_filtered.to_crs(epsg=4326)

# # centroid = gdf_wgs84.geometry.centroid.unary_union.centroid
# # m = geemap.Map(center=[centroid.y, centroid.x], zoom=9, ee_initialize=False)
# # m.add_gdf(gdf_wgs84, layer_name="training_gdf_col_filtered")
# # m

In [33]:
# training_gdf_col_filtered.to_file("00_input/training_shp/training_gdf_col_filtered_4326.shp") # local non gcs parquet

In [34]:
##no need to plot for the second time if the data is the same
# Quick plot (very fast)
# from matplotlib import pyplot as plt

# training_gdf_col_filtered.plot(figsize=(12, 10), markersize=0.5)
# plt.title(f"Training Data ({len(training_gdf_col_filtered):,} features)")
# plt.show()

In [35]:
# ## lets load in wmts from gdf, its LOADING TOO SLOW TO LOAD AS THE GDF IF LOAD FROM WFS ee.featureCollection
# gdf_wgs84 = training_gdf_col_filtered.to_crs(epsg=4326)

# #### wmts loading
# from wfs_manager import WFSManager
# import geemap

# fc = geemap.gdf_to_ee(gdf_wgs84)

# wfs = WFSManager(fastapi_url="http://fastapi:8000", wfs_base_url="http://localhost:8001")
# wfs.addLayer(fc, "Training data collected")
# wfs.publish()

In [36]:
# training_gdf_col_filtered.columns

In [38]:
import pandas as pd
import geopandas as gpd

use_existing_df_long = True
gcs_path_df_long = 'gs://remote_sensing_saas/01-korindo/sample_tsfresh/20251112_df_long.parquet'

if use_existing_df_long != True:
    # training_gdf_col_filtered

    t_cols = [col for col in training_gdf_col_filtered.columns if col.startswith('t_')]

    # Melt the dataframe with geometry as id_var
    df_long = pd.melt(
        training_gdf_col_filtered, 
        id_vars=['layer','geometry'],
        value_vars=t_cols,
        var_name='time_period',
        value_name='value'
    )

    # Set geometry as index
    # df_long = df_long.set_index('geometry')

    print(f"Original shape: {training_gdf_col_filtered.shape}")
    print(f"Long format shape: {df_long.shape}")
    print(f"Time columns found: {len(t_cols)}")
    # print(f"\nFirst few rows:")
    # df_long.head()
    # df_long.layer.unique()

    # df_long.reset_index(inplace=True)
    # df_long.head()

    ## reformating the column (date)
    df_long = df_long.rename(columns={'value': 'type'})
    df_long['date'] = df_long['time_period'].str[2:].astype(int)
    # df_long.head()

    # type(df_long)

    training_gdf = df_long.copy()
    # Remove multipolygons
    training_gdf = training_gdf.explode(index_parts=False)
    print(training_gdf.crs)

    ##### conversion
    # Drop rows where 'type' is NA
    training_gdf = training_gdf.copy()
    training_gdf = training_gdf.dropna(subset=['type'])
    training_gdf['date'] = pd.to_datetime(training_gdf['date'], format='%Y%m')

    # add year columnt
    training_gdf['year'] = training_gdf['date'].dt.year

    # convert column 'type' from string to int
    training_gdf['type'] = training_gdf['type'].astype(int)
    print('shape of training_gdf after drop NA:', training_gdf.shape)
    
    training_gdf.to_parquet(gcs_path_df_long, filesystem=fs, compression='snappy')

else:
    training_gdf = gpd.read_parquet(gcs_path_df_long, filesystem=fs)

In [20]:
training_gdf.head()

Unnamed: 0,layer,time_period,type,date,geometry,year
131213,sample_3,t_201603,1,2016-03-01,"POLYGON ((584180.083 9968585.299, 584157.366 9...",2016
131214,sample_3,t_201603,1,2016-03-01,"POLYGON ((583509.394 9968409.259, 583519.222 9...",2016
131215,sample_3,t_201603,1,2016-03-01,"POLYGON ((585435.167 9969315.480, 585465.247 9...",2016
131216,sample_3,t_201603,1,2016-03-01,"POLYGON ((585162.142 9969842.168, 585157.111 9...",2016
131217,sample_3,t_201603,1,2016-03-01,"POLYGON ((584315.521 9969957.392, 584316.046 9...",2016


In [21]:
# test = 't_202509'
# test[2:]

In [22]:
# training_gdf['date']

In [23]:
# training_gdf['year']

In [24]:
## to minimize the ram usage, we can delete the ram usage to variables
# del training_gdf_col_filtered
# del df_long


In [26]:
# ============================================================================
# STEP 1: IMPORT REQUIRED LIBRARIES
# ============================================================================
from multiprocessing import Pool  # For parallel processing (running multiple tasks at once)
from functools import partial     # For creating functions with pre-filled arguments
import xarray as xr              # For working with multi-dimensional arrays (like netCDF, zarr)
import numpy as np               # For numerical operations
import pandas as pd              # For data manipulation
import multiprocessing as mp     # For getting CPU count

# ============================================================================
# STEP 2: DEFINE FUNCTION TO CREATE MASK FOR ONE DATE-LAYER COMBINATION
# ============================================================================
def get_raster_mask_with_layer(date_layer_tuple, ds, gdf_1_dissolved, gdf_0_dissolved):
    """
    Creates a raster mask from vector polygons for a specific date and layer.
    
    What this does:
    - Takes polygons (shapes) from your training data
    - Converts them into a grid (raster) matching your satellite data
    - Assigns value 1 for tree areas, 0 for non-tree areas, NaN for no data
    
    Parameters:
    -----------
    date_layer_tuple : tuple (date, layer)
        The specific date and layer (e.g., 'sample_1') to process
    ds : xarray.Dataset
        Your satellite dataset (to match grid size and coordinates)
    gdf_1_dissolved : GeoDataFrame
        Pre-processed polygons for tree areas (type=1)
    gdf_0_dissolved : GeoDataFrame
        Pre-processed polygons for non-tree areas (type=0)
    
    Returns:
    --------
    mask_da : xarray.DataArray
        A grid with 1s (trees), 0s (non-trees), and NaN (no label)
    """
    import numpy as np
    import rasterio.features  # For converting vector shapes to raster grids
    from affine import Affine  # For coordinate transformations
    
    # --- Unpack the date and layer from input tuple ---
    date, layer = date_layer_tuple
    
    # --- Filter polygons for this specific date and layer ---
    # MultiIndex means the index has multiple levels (date AND layer)
    # We need to match BOTH date and layer to get the right polygons
    trees = gdf_1_dissolved[
        (gdf_1_dissolved.index.get_level_values('date') == date) & 
        (gdf_1_dissolved.index.get_level_values('layer') == layer)
    ]
    non_trees = gdf_0_dissolved[
        (gdf_0_dissolved.index.get_level_values('date') == date) & 
        (gdf_0_dissolved.index.get_level_values('layer') == layer)
    ]
    
    # --- Prepare features for rasterization ---
    # Features are tuples of (geometry, value)
    # All tree geometries get value 1, all non-tree geometries get value 0
    features = [(geom, 1) for geom in trees.geometry] + \
               [(geom, 0) for geom in non_trees.geometry]
    
    # --- Get the grid dimensions from your satellite dataset ---
    x = ds.coords['x'].values  # X coordinates (longitude-like)
    y = ds.coords['y'].values  # Y coordinates (latitude-like)
    
    # --- Handle case where there are no training labels for this date-layer ---
    if not features:
        # Create an empty grid filled with NaN (no data)
        mask_raster = np.full((len(y), len(x)), np.nan, dtype="float32")
    else:
        # --- Calculate pixel resolution (size of each grid cell) ---
        # Resolution is the distance between adjacent pixels
        res_x = (x[-1] - x[0]) / (len(x) - 1)  # Horizontal resolution
        res_y = (y[0] - y[-1]) / (len(y) - 1)  # Vertical resolution
        
        # --- Create affine transformation ---
        # This tells rasterio how to map real-world coordinates to pixel indices
        # Translation: moves origin to top-left corner of top-left pixel
        # Scale: defines pixel size (negative y because images start from top)
        transform = Affine.translation(x[0] - res_x / 2, y[0] + res_y / 2) * \
                   Affine.scale(res_x, -res_y)
        
        # --- Rasterize: convert vector polygons to raster grid ---
        mask_raster = rasterio.features.rasterize(
            features,              # List of (geometry, value) tuples
            out_shape=(len(y), len(x)),  # Output grid size
            transform=transform,   # Coordinate transformation
            fill=np.nan,          # Value for pixels outside all polygons
            dtype="float32"       # Data type (float to allow NaN)
        )
    
    # --- Wrap the numpy array in an xarray DataArray ---
    # This adds coordinate labels and metadata to the grid
    mask_da = xr.DataArray(
        mask_raster,              # The actual data (2D grid)
        dims=("y", "x"),         # Dimension names
        coords={
            "y": ds.coords["y"],  # Y coordinate values
            "x": ds.coords["x"],  # X coordinate values
            "date": date,         # Date as scalar coordinate
            "layer": layer        # Layer as scalar coordinate
        },
    )
    
    return mask_da


# ============================================================================
# STEP 3: DEFINE FUNCTION FOR PARALLEL PROCESSING
# ============================================================================
def parallel_rasterize_with_layer(date_layer_combinations, ds, gdf_1_dissolved, 
                                   gdf_0_dissolved, n_workers=4):
    """
    Process multiple date-layer combinations in parallel (at the same time).
    
    Why parallel?
    - Instead of processing dates one by one (slow), we process multiple at once
    - Uses multiple CPU cores to speed up computation
    - Example: 4 workers = ~4x faster for independent tasks
    
    Parameters:
    -----------
    date_layer_combinations : list of tuples
        All (date, layer) pairs to process
    ds : xarray.Dataset
        Your satellite dataset
    gdf_1_dissolved, gdf_0_dissolved : GeoDataFrames
        Pre-processed training polygons
    n_workers : int
        Number of parallel processes (usually = number of CPU cores)
    
    Returns:
    --------
    masks : list of xarray.DataArray
        One mask for each date-layer combination
    """
    # --- Create a partial function with pre-filled arguments ---
    # This is needed because pool.map() can only pass one argument
    # We "freeze" ds, gdf_1_dissolved, gdf_0_dissolved so only date_layer changes
    func = partial(
        get_raster_mask_with_layer, 
        ds=ds, 
        gdf_1_dissolved=gdf_1_dissolved,
        gdf_0_dissolved=gdf_0_dissolved
    )
    
    # --- Create a pool of worker processes and run in parallel ---
    with Pool(n_workers) as pool:
        masks = pool.map(func, date_layer_combinations)
    
    return masks

In [None]:
# ============================================================================
# STEP 4: PREPARE DATA AND RUN PARALLEL PROCESSING
# ============================================================================

# --- Get all unique date-layer combinations from training data ---
# drop_duplicates() removes repeated combinations
# itertuples() converts each row to a tuple (date, layer)
date_layer_combos = list(
    training_gdf[['date', 'layer']]
    .drop_duplicates()
    .itertuples(index=False, name=None)
)
print(f"Processing {len(date_layer_combos)} date-layer combinations")
print(f"Example: {date_layer_combos[:3]}")  # Show first 3 combinations

# --- Determine number of parallel workers ---
# Use all available CPU cores, but not more than we have combinations
n_workers = int(min(mp.cpu_count(), len(date_layer_combos))/2) #lets half
print(f"Using {n_workers} parallel workers")

# --- Pre-process: Dissolve geometries by date AND layer ---
# dissolve() merges overlapping/touching polygons into single shapes
# Why? Faster rasterization and avoids duplicate pixels
# This is done ONCE here instead of inside the loop (huge speed-up!)
# print("Dissolving geometries by date and layer...")
# gdf_1_dissolved = training_gdf[training_gdf['type'] == 1].dissolve(by=['date', 'layer'])
# gdf_0_dissolved = training_gdf[training_gdf['type'] == 0].dissolve(by=['date', 'layer'])
print('using dissolved gdf, dissolve is done in arcgis separately because its too slow in geopandas')

print(f"  Trees (type=1): {len(gdf_1_dissolved)} date-layer groups")
print(f"  Non-trees (type=0): {len(gdf_0_dissolved)} date-layer groups")

# --- Run parallel processing to create all masks ---
print("Creating masks in parallel...")
masks = parallel_rasterize_with_layer(
    date_layer_combos,     # All date-layer combinations
    ds_resampled,         # Your satellite dataset
    gdf_1_dissolved,      # Pre-dissolved tree polygons
    gdf_0_dissolved,      # Pre-dissolved non-tree polygons
    n_workers=n_workers   # Number of parallel processes
)

print(f"‚úì Created {len(masks)} masks")
print(f"  Each mask shape: {masks[0].shape}")

Processing 142 date-layer combinations
Example: [(Timestamp('2016-03-01 00:00:00'), 'sample_3'), (Timestamp('2016-06-01 00:00:00'), 'sample_2'), (Timestamp('2016-06-01 00:00:00'), 'sample_3')]
Using 2.0 parallel workers
Dissolving geometries by date and layer...
  Trees (type=1): 142 date-layer groups
  Non-trees (type=0): 140 date-layer groups
Creating masks in parallel...


TypeError: 'float' object cannot be interpreted as an integer

In [None]:
# ============================================================================
# STEP 5: ORGANIZE MASKS INTO A 3D DATASET WITH LAYER AS COORDINATE
# ============================================================================
def merge_all_masks_3d(masks):
    """
    Combine individual masks into a single 3D dataset.
    
    What this does:
    - Takes a list of 2D masks (y, x)
    - Organizes them into a 3D array (date, y, x)
    - Keeps 'layer' (plot) as a coordinate label for each date
    
    Result: Each date-layer combination is a separate time point
    Example: date=2020-01-01 with layer='sample_1', date=2020-01-01 with layer='sample_2'
    
    Parameters:
    -----------
    masks : list of xarray.DataArray
        Individual masks from parallel processing
    
    Returns:
    --------
    gt : xarray.Dataset
        3D dataset (date, y, x) with layer/plot as coordinate
    """
    import pandas as pd
    
    # --- Extract date and layer from each mask ---
    # We'll create unique date-layer combinations as separate time points
    dates_with_layer = []
    layers = []
    
    for mask in masks:
        date = pd.Timestamp(mask.coords['date'].values)
        layer = mask.coords['layer'].values
        dates_with_layer.append(date)
        layers.append(layer)
    
    print(f"Organizing into 3D dataset:")
    print(f"  {len(masks)} total date-layer combinations")
    print(f"  Will be stored as {len(masks)} time points")
    
    # --- Stack all masks along a new dimension ---
    # Each mask becomes one time slice
    # Remove the old scalar coordinates first to avoid conflicts
    masks_cleaned = []
    for mask in masks:
        # Keep only the spatial data, drop scalar coordinates temporarily
        mask_clean = mask.drop_vars(['date', 'layer'], errors='ignore')
        masks_cleaned.append(mask_clean)
    
    # Concatenate all masks along a new 'date' dimension
    combined = xr.concat(masks_cleaned, dim='date')
    
    # --- Add date and layer as coordinates ---
    # 'date' is the dimension (time axis)
    # 'layer' (or 'plot') is a coordinate that varies along date
    combined = combined.assign_coords({
        'date': dates_with_layer,  # Actual dates
        'plot': ('date', layers)   # Layer/plot name for each date
    })
    
    # --- Create Dataset with proper structure ---
    gt = xr.Dataset(
        {
            'ground_truth': combined
        }
    )
    
    # --- Add metadata attributes ---
    gt.attrs['description'] = 'Ground truth training masks'
    gt.attrs['values'] = '0=non-tree, 1=tree, NaN=no label'
    gt['ground_truth'].attrs['units'] = 'category'
    gt.coords['plot'].attrs['description'] = 'Training plot/layer identifier'
    
    # --- Show summary ---
    unique_plots = list(set(layers))
    print(f"  {len(unique_plots)} unique plots: {unique_plots}")
    print(f"  Final dimensions: date={len(dates_with_layer)}, y={len(gt.y)}, x={len(gt.x)}")
    
    return gt


# --- Run the merge function ---
print("\nMerging masks into 3D dataset...")
gt = merge_all_masks_3d(masks)

# --- Display the result ---
print("\n" + "="*60)
print("FINAL DATASET:")
print("="*60)
print(gt)
print("\nCoordinates:")
print(f"  date: {len(gt.date)} time points")
print(f"  plot: {len(gt.plot)} labels (one per date)")
print(f"  x: {len(gt.x)} pixels")
print(f"  y: {len(gt.y)} pixels")

# --- Examples of accessing the data ---
print("\n" + "="*60)
print("EXAMPLE USAGE:")
print("="*60)

# Example 1: Select by date (may have multiple plots for same date)
print("\n1. Get all data for a specific date:")
print("   gt.sel(date='2025-01-01')")
print("   Note: May return multiple entries if multiple plots have this date")

# Example 2: Filter by plot name
print("\n2. Get all dates for a specific plot:")
print("   gt.where(gt.plot == 'sample_1', drop=True)")

# Example 3: Select specific date AND plot
print("\n3. Get data for specific date and plot:")
print("   gt.sel(date=gt.date[gt.plot == 'sample_1'])")

# Example 4: Get plot name for each date
print("\n4. See which plot each date corresponds to:")
print("   gt.plot.values")

# Example 5: Group by plot
print("\n5. Work with one plot at a time:")
print("   for plot_name in gt.plot.values:")
print("       plot_data = gt.where(gt.plot == plot_name, drop=True)")


# ============================================================================
# STEP 6: CREATE VALIDITY MASK
# ============================================================================
print("\n" + "="*60)
print("CREATING VALIDITY MASK:")
print("="*60)

# Create validity mask: pixels that have non-NaN values for all dates
# We check across all date-layer combinations
gt['valid'] = gt['ground_truth'].notnull().all(dim='date')

print("‚úì Added 'valid' variable showing pixels with labels for ALL date-layer combinations")
print(f"  Shape: {gt['valid'].shape}")
print(f"  Valid pixels: {gt['valid'].sum().values:,}")

# If you want validity per plot, you can compute it separately:
print("\nValid pixels per plot:")
for plot_name in sorted(set(gt.plot.values)):
    # Get all dates for this plot
    plot_mask = gt.where(gt.plot == plot_name, drop=True)
    n_valid = plot_mask['ground_truth'].notnull().all(dim='date').sum().values
    n_dates = (gt.plot == plot_name).sum().values
    print(f"  {plot_name}: {n_valid:,} pixels (across {n_dates} dates)")


# ============================================================================
# HELPER FUNCTION: SELECT BY DATE AND PLOT
# ============================================================================
def select_by_date_and_plot(gt, date, plot_name):
    """
    Helper function to easily select data for a specific date and plot.
    
    Parameters:
    -----------
    gt : xarray.Dataset
        Your ground truth dataset
    date : str or pd.Timestamp
        The date to select
    plot_name : str
        The plot/layer name
    
    Returns:
    --------
    data : xarray.Dataset
        Subset of data matching the criteria
    """
    date = pd.Timestamp(date)  # Ensure date is in correct format
    
    # Find indices where both date and plot match
    mask = (gt.date == date) & (gt.plot == plot_name)
    
    if not mask.any():
        print(f"Warning: No data found for date={date} and plot={plot_name}")
        return None
    
    return gt.sel(date=gt.date[mask])


# --- Example usage of helper function ---
print("\n" + "="*60)
print("HELPER FUNCTION EXAMPLE:")
print("="*60)
print("# Get data for specific date and plot:")
print("result = select_by_date_and_plot(gt, '2025-01-01', 'sample_1')")
print("print(result)")

In [None]:
masks[0]

In [None]:
# ds_resampled

In [None]:
# training_gdf.date.unique()

In [None]:
# merge mask data
mask_da = xr.concat(masks, dim='date')

# create mask that will containt True if all pixles in year range have valid value (ie. not nan)
mask = mask_da.notnull().all(dim='date')

In [None]:
mask

In [None]:
# convert to dataset 
gt = mask_da.to_dataset(name='ground_truth')
gt['valid'] = mask

In [None]:
gt

In [None]:
gt.date

In [None]:
### lets check RAM usage
!free -h

In [None]:
### if RAM too much consumed, we can delete object
del mask_da
del masks
del gdf_1_dissolved
del gdf_0_dissolved
del training_gdf

# Force garbage collection
import gc
gc.collect()

# If you're using Dask (which you are with zarr data), clear the cache
from distributed import Client
try:
    client = Client.current()
    client.restart()  # Restart workers to clear memory
except ValueError:
    # No distributed client, use local cache clearing
    import dask
    dask.config.set(scheduler='synchronous')
    
# Or simply clear local Dask cache
try:
    from dask.cache import Cache
    Cache(2e9).clear()  # Clear cache
except:
    pass

In [None]:
!free -h

In [None]:
gt

In [None]:
### saving the xarray ds -> gt - sample groundtruth to the bucket
gcs_path_gt = 'gs://remote_sensing_saas/01-korindo/timeseries_zarr/20251112_gt.zarr'
# gt.to_zarr(gcs_path_gt, mode='w', compute=True, consolidated=True,)
# Save to GCS
save_dataset_efficient_zarr(
    gt,
    gcs_path_gt,
    chunk_sizes={'date':20, 'x':512, 'y':512},
    compression='lz4',
    compression_level=1,
    overwrite=True,
    consolidated=True,
    storage='gcs'
)

In [None]:
!free -h

In [None]:
### INPUT
resampling_freq = 'MS'
# gt.date is already datetime, so just convert to pandas Timestamp
start_date = pd.Timestamp(gt.date.min().values)
cut_off_date = pd.Timestamp(gt.date.max().values)
print(f"Start date: {start_date}")
print(f"Cut-off date: {cut_off_date}")
cut_off_date

In [None]:
# plot_gt_mask(gdf, gt, show=True, save=False)
# plot_gt_usable_data(gt, show=True, save=False)

# convert yeart in gt from int to datetime
# gt['date'] = pd.to_datetime(gt['date'].astype(str),  format='%Y%-m')

gt = gt.sortby('date').rename({'date': 'time'})
monthly_time = pd.date_range(start_date, cut_off_date, freq=resampling_freq)
# This will repeat the value from January 1 throughout the year until the next available time.
gt = gt.reindex(time=monthly_time, method='ffill').sortby('time')

In [None]:
gt

In [None]:
import numpy as np
np.unique(gt.sel(time='2025-9-01').valid.values)

In [None]:
gt.sel(time='2025-01-01').valid