# Global snowmelt runoff onset processing

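This notebook implements the core processing pipeline for detecting snowmelt runoff onset timing from Sentinel-1 SAR data at a global scale. The method uses the timing of the minimum backscatter within the snow season as a proxy for snowmelt runoff onset: liquid water in a melting snowpack strongly reduces C-band radar backscatter, and the backscatter minimum is taken to mark runoff onset.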
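As a minimal illustration of the detection idea, the sketch below finds the acquisition date of the lowest backscatter in a single synthetic VV time series. The values and dates are made up. It ignores the masking, per-orbit/polarization handling, and aggregation that the pipeline performs in `processing.calculate_runoff_onset`.

```python
import pandas as pd
import xarray as xr

# Synthetic VV backscatter (dB) for one hypothetical pixel, one acquisition every 12 days
times = pd.date_range("2021-03-01", periods=10, freq="12D")
vv_db = xr.DataArray(
    [-9.0, -9.5, -10.2, -12.8, -17.4, -15.1, -11.0, -9.8, -9.6, -9.4],
    coords={"time": times},
    dims="time",
)

# Position of the lowest backscatter value, then look up its acquisition date
min_index = int(vv_db.argmin(dim="time"))
onset_date = pd.Timestamp(vv_db["time"].values[min_index])
print(onset_date.date())  # 2021-04-18
```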


## Processing pipeline
1. **Data acquisition**: Acquire Sentinel-1 RTC data from Microsoft Planetary Computer
1. **Snow masking**: Apply spatiotemporal snow cover constraints
1. **Quality filtering**: Remove bad scenes and border noise, and drop pixels with insufficient temporal sampling
1. **Temporal resolution**: Calculate the temporal resolution of the filtered dataset
1. **Runoff detection**: Calculate minimum backscatter timing per orbit
1. **Aggregate statistics**: Compute the 10-year median and median absolute deviation (MAD) of runoff onset, and the 10-year median temporal resolution
1. **Output generation**: Write results to global zarr store
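The snow masking step can be pictured with the small sketch below. It assumes per-pixel snow appearance and disappearance dates plus a maximum consecutive snow-day count (as in the MODIS-derived dataset), drops pixels without enough continuous snow, and keeps only acquisitions between snow appearance and snow disappearance plus an extension. The thresholds and array layout are hypothetical; this is not the implementation of `processing.get_spatiotemporal_snow_cover_mask`.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical inputs for one water year: constant backscatter on a (time, pixel) grid
pixels = [0, 1, 2]
times = pd.date_range("2020-10-01", "2021-09-30", freq="12D")
vv = xr.DataArray(
    np.full((times.size, len(pixels)), -10.0),
    coords={"time": times, "pixel": pixels},
    dims=("time", "pixel"),
)

# Hypothetical per-pixel snow metrics (stand-ins for the MODIS-derived seasonal snow dataset)
snow_appearance = xr.DataArray(
    pd.to_datetime(["2020-11-01", "2020-12-15", "2020-11-20"]), coords={"pixel": pixels}, dims="pixel"
)
snow_disappearance = xr.DataArray(
    pd.to_datetime(["2021-06-01", "2021-05-10", "2021-07-04"]), coords={"pixel": pixels}, dims="pixel"
)
max_consec_snow_days = xr.DataArray([180, 40, 200], coords={"pixel": pixels}, dims="pixel")

# Hypothetical thresholds, analogous to the config parameters used in process_tile
min_consec_snow_days_for_seasonal_snow = 56
extend_search_window_beyond_SDD_days = 30

# Seasonal snow: long enough continuous snow cover
has_seasonal_snow = max_consec_snow_days >= min_consec_snow_days_for_seasonal_snow
# Search window: snow appearance through snow disappearance plus an extension
window_end = snow_disappearance + np.timedelta64(extend_search_window_beyond_SDD_days, "D")
in_window = (vv["time"] >= snow_appearance) & (vv["time"] <= window_end)

# Broadcasting yields a (time, pixel) mask; acquisitions outside it become NaN
vv_masked = vv.where(in_window & has_seasonal_snow)
print(vv_masked.count(dim="time").values)  # [20  0 21]
```

Pixel 1 is dropped entirely because its longest snow spell is shorter than the seasonal-snow threshold.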

## Last checks!!

Before large-scale processing...

1. use [egagli/MODIS_seasonal_snow_mask](https://github.com/egagli/MODIS_seasonal_snow_mask) to create the seasonal snow cover dataset (per water year: snow appearance date, snow disappearance date, and maximum number of consecutive snow cover days)
1. check that the config file looks good, and make sure to update paths with the version number!!!!
1. cloud credentials are updated if needed (e.g. `config/sas_token.txt` and `ee_key.json`) (note to self: use [egagli/azure_authentication](https://github.com/egagli/azure_authentication) to get a new sas_token weekly)
1. `select_tiles_to_process.ipynb` has been run and `tile_data/global_tiles_with_seasonal_snow.geojson` created
1. `create_zarr_store.ipynb` has been run, and the zarr store exists on cloud storage and can be read
1. Coiled is working, using spot instances, and the price isn't too high; check usage stats too!!
1. check that tiles are being written to `tile_data/tile_results_vX.csv` and are showing success
1. check processed tiles with `view_maps.ipynb`, including all variables
1. validate on test tiles (see the automatic weather station tile subset below)
1. check failed tiles, and adjust cluster settings if needed
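For the results-CSV check, a quick summary like the following may help. It assumes the file begins with a header row (per the `header=True if starting over` note in the batch loop) and uses hypothetical columns `row`, `col`, `success`, `total_time`; the real columns are `config.fields`, and the inline CSV stands in for `config.tile_results_path`.

```python
import io

import pandas as pd

# Hypothetical results file contents; illustration only
csv_text = """row,col,success,total_time
23,39,True,251.2
23,40,False,88.0
23,40,True,240.5
24,39,False,1600.0
"""
results = pd.read_csv(io.StringIO(csv_text))

# A row is appended after every attempt, so keep only the most recent row per tile
latest = results.drop_duplicates(subset=["row", "col"], keep="last")
n_ok = int(latest["success"].sum())
print(f"{n_ok}/{len(latest)} tiles succeeded ({n_ok / len(latest):.0%})")
# Tiles whose latest attempt failed
print(latest.loc[~latest["success"], ["row", "col"]])
```

If the file was written without a header, `pd.read_csv(path, header=None, names=config.fields)` would be the equivalent read.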

In [None]:
import easysnowdata
import pystac_client
import tqdm
import planetary_computer
import numpy as np
import pandas as pd
import geopandas as gpd
import xarray as xr
import odc.stac
import time
import dask
import dask.distributed
import coiled
import matplotlib.pyplot as plt
import traceback
from global_snowmelt_runoff_onset.config import Config, Tile
import global_snowmelt_runoff_onset.processing as processing
import flox

## Configuration overview

In [None]:
config = Config('../config/global_config_v7.txt')

## Start up the Coiled cluster!

In [None]:
cluster = coiled.Cluster(
    idle_timeout="10 minutes",
    n_workers=6,
    worker_memory="32 GB",  # options: coiled.list_instance_types(backend="azure")
    worker_cpu=4,
    # worker_options={"nthreads": 1},
    # worker_options={"nthreads": 32},  # also tried 16, 8, 4 (oversubscribe?)
    # scheduler_memory="128 GB",
    # scheduler_memory="64 GB",
    spot_policy="spot",  # spot usually
    environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
    workspace="uwtacolab",  # azure
)

client = cluster.get_client()

# For the problem tiles, use the following instead
# (otherwise 4 CPUs / 32 GB, or 8 CPUs / 32 GB):
#     worker_memory="64 GB",
#     worker_cpu=8,

odc.stac.configure_rio(cloud_defaults=True, client=client)

## Tile processing function

The `process_tile` function implements the complete processing pipeline for a single spatial tile:

### Key processing steps:

1. **Sentinel-1 data retrieval**
   - Retrieves Sentinel-1 RTC data from Microsoft Planetary Computer
   - Organizes by satellite orbit and adds water year coordinates
   - Reads with `config.chunks_s1_read` chunks, then rechunks the VV band to `config.chunks_s1_process` and persists it in cluster memory

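2. **Snow cover masking**
   - Retrieves [custom MODIS-derived seasonal snow data](https://github.com/egagli/MODIS_seasonal_snow_mask) per water year: snow appearance date, snow disappearance date, and maximum number of consecutive snow cover days
   - Identifies pixels with seasonal snow cover (based on `config.min_consec_snow_days_for_seasonal_snow`)
   - Sets each pixel's detection window from snow appearance to snow disappearance, extended by `config.extend_search_window_beyond_SDD_days`
   - Optionally restricts processing to mountain areas using the GMBA mountain inventory (`config.mountain_snow_only`)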

3. **Quality filtering**
   - Removes bad scenes and border noise artifacts
   - Filters out pixels with insufficient temporal sampling (`config.min_monthly_acquisitions`, `config.max_allowed_days_gap_per_orbit`)
   - Calculates maximum temporal gaps per orbit

4. **Runoff onset detection**
   - Identifies minimum backscatter timing per orbit/polarization
   - Aggregates the per-orbit/polarization timings using the median for robustness
   - Converts to day-of-water-year format

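5. **Aggregations**
   - Computes median and MAD across water years, requiring at least `config.min_years_for_median_std` valid years
   - Calculates temporal resolution metrics: per-year values, their median across years, and per-year tile summaries stored as `tr_<year>` and `pix_ct_<year>`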

6. **Data output**
   - Aligns the tile to the global grid and writes results to the global zarr store
   - Records run time and success/error status on the `Tile` object
   - Deletes intermediate objects to free memory
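Steps 1 and 4 rely on water-year bookkeeping. The sketch below assumes a water year that starts on October 1 (the common Northern Hemisphere convention), is labeled by the calendar year in which it ends, and counts October 1 as day 1. The pipeline's own conventions, including any hemisphere-specific start month, live in `processing` and may differ; both function names are hypothetical.

```python
import numpy as np
import pandas as pd


def water_year(dates: pd.DatetimeIndex, start_month: int = 10) -> np.ndarray:
    """Label each date with the calendar year in which its water year ends."""
    if not 2 <= start_month <= 12:
        raise ValueError("start_month must be between 2 and 12")
    # Dates in or after the start month belong to the next year's water year
    return dates.year.to_numpy() + (dates.month.to_numpy() >= start_month)


def day_of_water_year(dates: pd.DatetimeIndex, start_month: int = 10) -> np.ndarray:
    """1-based day count since the first day of each date's water year."""
    wy = water_year(dates, start_month)
    wy_starts = pd.to_datetime([f"{y - 1}-{start_month:02d}-01" for y in wy])
    return (dates - wy_starts).days.to_numpy() + 1


dates = pd.to_datetime(["2020-10-01", "2021-04-18", "2021-09-30", "2021-10-01"])
print(water_year(dates))         # [2021 2021 2021 2022]
print(day_of_water_year(dates))  # [  1 200 365   1]
```

Expressing onset as day of water year keeps a snow season that spans the new year on one continuous axis, which is what makes medians across years meaningful.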

In [None]:
def process_tile(tile: Tile):
    odc.stac.configure_rio(cloud_defaults=True)
    tile.start_time = time.time()

    try:
        print(f"Getting data for tile ({tile.row},{tile.col}).")

        s1_rtc_ds = processing.get_sentinel1_rtc(
            geobox=tile.geobox,
            bands=config.bands,
            start_date=config.start_date,
            end_date=config.end_date,
            chunks_read=config.chunks_s1_read,
            fail_on_error=True,
        )

        s1_rtc_ds["vv"] = s1_rtc_ds["vv"].chunk(config.chunks_s1_process).persist()
        print("Data retrieved.")

        tile.s1_rtc_ds_dims = dict(s1_rtc_ds.sizes)

        spatiotemporal_snow_cover_mask_ds = processing.get_spatiotemporal_snow_cover_mask(
            ds=s1_rtc_ds,
            bbox_gdf=tile.bbox_gdf,
            seasonal_snow_mask_store=config.seasonal_snow_mask_store,
            extend_search_window_beyond_SDD_days=config.extend_search_window_beyond_SDD_days,
            min_consec_snow_days_for_seasonal_snow=config.min_consec_snow_days_for_seasonal_snow,
            reproject_method="rasterio", #rasterio
        ).persist()

        if config.mountain_snow_only:
            gmba_clipped_gdf = processing.get_gmba_mountain_inventory(tile.bbox_gdf)
        else:
            gmba_clipped_gdf = None

        print("Applying masks...")
        s1_rtc_masked_ds = processing.apply_all_masks(
            s1_rtc_ds=s1_rtc_ds,
            gmba_clipped_gdf=gmba_clipped_gdf,
            spatiotemporal_snow_cover_mask_ds=spatiotemporal_snow_cover_mask_ds,
            water_years=config.water_years,
        )

        print("Removing bad scenes and border noise...")
        s1_rtc_masked_ds = processing.remove_bad_scenes_and_border_noise(
            s1_rtc_masked_ds, config.low_backscatter_threshold
        )
        print("Bad scenes and border noise removed.")

        print("Filtering by acquisitions and gaps...")
        s1_rtc_masked_filtered_ds = s1_rtc_masked_ds.groupby("water_year").map(
            lambda group: processing.filter_insufficient_pixels_per_orbit(
                s1_rtc_masked_ds=group,
                spatiotemporal_snow_cover_mask_ds=spatiotemporal_snow_cover_mask_ds,
                min_monthly_acquisitions=config.min_monthly_acquisitions,
                max_allowed_days_gap_per_orbit=config.max_allowed_days_gap_per_orbit,
            )
        ).persist()
        print("Filtering completed.")

        print("Calculating temporal resolution...")
        temporal_resolution_da = processing.get_temporal_resolution(
            s1_rtc_masked_filtered_ds, spatiotemporal_snow_cover_mask_ds
        ).persist()

        tile_median_temporal_resolution = temporal_resolution_da.median(
            dim=["latitude", "longitude"]
        )
        tile_pixel_count = temporal_resolution_da.count(dim=["latitude", "longitude"])

        for water_year in config.water_years:
            if water_year in tile_median_temporal_resolution.water_year:
                temporal_resolution = tile_median_temporal_resolution.sel(
                    water_year=water_year
                ).values
                setattr(tile, f"tr_{water_year}", round(float(temporal_resolution), 3))

            if water_year in tile_pixel_count.water_year:
                pixel_count = tile_pixel_count.sel(water_year=water_year).values
                setattr(tile, f"pix_ct_{water_year}", int(pixel_count))

        print("Temporal resolution calculated.")

        print("Calculating runoff onsets...")
        runoff_onsets_da = s1_rtc_masked_filtered_ds.groupby("water_year").map(
            processing.calculate_runoff_onset,
            returned_dates_format="dowy",
            return_constituent_runoff_onsets=False,
        )
        print("Runoff onsets calculated.")

        tile.runoff_onsets_dims = dict(runoff_onsets_da.sizes)

        # Calculate median and MAD
        median_da, mad_da = processing.median_and_mad_with_min_obs(
            da=runoff_onsets_da,
            dim="water_year",
            min_count=config.min_years_for_median_std
        )

        # Calculate median temporal resolution
        median_temporal_resolution_da = processing.median_with_min_obs(
            da=temporal_resolution_da,
            dim="water_year",
            min_count=config.min_years_for_median_std
        )

        # Create dataset
        runoff_onsets_ds = processing.dataarrays_to_dataset(
            runoff_onsets_da=runoff_onsets_da,
            median_da=median_da,
            mad_da=mad_da,
            water_years=config.water_years,
            temporal_resolution_da=temporal_resolution_da,
            median_temporal_resolution_da=median_temporal_resolution_da,
        )

        print("Median and MAD calculated, converted to dataset.")

        # Reindex to global coordinates
        global_ds = xr.open_zarr(config.global_runoff_store, consolidated=True)
        print("Global dataset opened.")
        global_subset_ds = global_ds.sel(
            latitude=runoff_onsets_ds.latitude,
            longitude=runoff_onsets_ds.longitude,
            method="nearest",
        )
        print("Global dataset subsetted.")
        runoff_onsets_reindexed_ds = runoff_onsets_ds.assign_coords(
            latitude=global_subset_ds.latitude, longitude=global_subset_ds.longitude
        )
        print("Dataset reindexed.")

        # Write to Zarr
        runoff_onsets_reindexed_ds.drop_vars("spatial_ref").chunk(
            config.chunks_zarr_output
        ).to_zarr(
            config.global_runoff_store, region="auto", mode="r+", consolidated=True
        )
        print("Dataset written to Zarr.")

        tile.total_time = time.time() - tile.start_time
        tile.success = True

        # Clean up memory
        del (
            s1_rtc_ds,
            spatiotemporal_snow_cover_mask_ds,
            s1_rtc_masked_ds,
            s1_rtc_masked_filtered_ds,
            temporal_resolution_da,
            runoff_onsets_da,
            runoff_onsets_ds,
            global_subset_ds,
            runoff_onsets_reindexed_ds,
            median_da,
            mad_da,
            median_temporal_resolution_da,
            tile_median_temporal_resolution,
            tile_pixel_count,
            gmba_clipped_gdf,
            global_ds,
        )

    except Exception as e:
        tile.error_messages.append(str(e))
        tile.error_messages.append(traceback.format_exc())
        tile.total_time = time.time() - tile.start_time
        tile.success = False

    return tile
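`process_tile` delegates the across-year statistics to `processing.median_and_mad_with_min_obs`, whose internals are not shown here. A minimal sketch of what such a function could compute, assuming an unscaled MAD (no 1.4826 normal-consistency factor) and NaN output wherever fewer than `min_count` years are valid; `median_and_mad_sketch` is a hypothetical name.

```python
import numpy as np
import xarray as xr


def median_and_mad_sketch(da: xr.DataArray, dim: str, min_count: int):
    """Median and unscaled median absolute deviation along dim, NaN where too few valid values."""
    median = da.median(dim=dim, skipna=True)
    mad = abs(da - median).median(dim=dim, skipna=True)
    enough_obs = da.count(dim=dim) >= min_count
    return median.where(enough_obs), mad.where(enough_obs)


# Hypothetical runoff onset (day of water year) for two pixels over five water years
onsets = xr.DataArray(
    [[100.0, 90.0], [110.0, np.nan], [95.0, np.nan], [120.0, np.nan], [105.0, 100.0]],
    coords={"water_year": [2017, 2018, 2019, 2020, 2021], "pixel": [0, 1]},
    dims=("water_year", "pixel"),
)
median_da, mad_da = median_and_mad_sketch(onsets, dim="water_year", min_count=3)
print(median_da.values, mad_da.values)
```

Pixel 0 gets a median of 105 and a MAD of 5; pixel 1 has only two valid years and is set to NaN.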

## Test on a single tile

For testing, individual tiles can be processed to verify the pipeline before large-scale deployment.

In [None]:
# tiles = config.get_list_of_tiles(which='all')
# tile=tiles[0]
tile = config.get_tile(23,39)

In [None]:
future = client.submit(process_tile, tile)

In [None]:
future.status

In [None]:
future.result().success

In [None]:
computed_result = future.result()

In [None]:
df = pd.DataFrame(
    [[getattr(computed_result, f) for f in config.fields]],
    columns=config.fields,
)
df
# runtime note: ~250 s with reproject_method="rasterio"

In [None]:
# Print the per-water-year temporal resolutions
for col in df.columns:
    if col.startswith('tr_'):
        print(f"Temporal resolution for {col}: {df[col].values}")
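The `tr_<year>` values are tile medians of the per-pixel temporal resolution from `processing.get_temporal_resolution`. The processing description mentions maximum temporal gaps per orbit; as an assumption-labeled illustration of that idea (not the pipeline's exact definition), the sketch below computes the longest interval between consecutive valid acquisitions in one orbit's series. `max_gap_days` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def max_gap_days(times: pd.DatetimeIndex, values: np.ndarray) -> float:
    """Longest interval (days) between consecutive valid (non-NaN) acquisitions."""
    valid_times = times[~np.isnan(values)]
    if valid_times.size < 2:
        return float("nan")  # not enough valid acquisitions to define a gap
    gaps = np.diff(valid_times.to_numpy())  # timedelta64 between neighbouring acquisitions
    return float(gaps.max() / np.timedelta64(1, "D"))


# Hypothetical single-orbit series: 12-day repeat with some masked acquisitions
times = pd.date_range("2021-03-01", periods=6, freq="12D")
vv = np.array([-10.0, np.nan, -11.0, np.nan, np.nan, -12.0])
print(max_gap_days(times, vv))  # 36.0
```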

## Tile selection and testing

Select tiles for processing based on different criteria:
- **`'all'`**: all tiles (full global coverage)
- **`'processed'`**: successfully completed tiles
- **`'failed'`**: tiles that encountered errors
- **`'unprocessed'`**: tiles not yet attempted
- **`'unprocessed_and_failed'`**: tiles needing processing or reprocessing. Unless you have a specific need (e.g. debugging), this is probably the one to use.
- **`'unprocessed_and_failed_weather_stations'`**: unprocessed/failed tiles that contain automatic weather stations; useful for validation.

Pass one of these values as the `which` argument of `config.get_list_of_tiles()`.
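To make the categories concrete, here is a hypothetical sketch of how they could be derived from a results table, with the latest attempt per tile taking precedence. `select_tiles_sketch` and the `(row, col)` tuple representation are illustrative assumptions, not how `config.get_list_of_tiles` is implemented; the weather-station variant is omitted.

```python
import pandas as pd


def select_tiles_sketch(all_tiles, results: pd.DataFrame, which: str):
    """Hypothetical re-creation of the tile categories from a results table."""
    # The most recent attempt for each tile decides its status
    latest = results.drop_duplicates(subset=["row", "col"], keep="last")

    def as_set(df):
        return {(int(r), int(c)) for r, c in zip(df["row"], df["col"])}

    processed = as_set(latest[latest["success"]])
    failed = as_set(latest[~latest["success"]])
    unprocessed = set(all_tiles) - processed - failed
    categories = {
        "all": set(all_tiles),
        "processed": processed,
        "failed": failed,
        "unprocessed": unprocessed,
        "unprocessed_and_failed": unprocessed | failed,
    }
    return sorted(categories[which])


all_tiles = [(23, 39), (23, 40), (24, 39), (24, 40)]
results = pd.DataFrame(
    {"row": [23, 23, 24, 24], "col": [39, 40, 39, 39], "success": [True, False, False, True]}
)
print(select_tiles_sketch(all_tiles, results, "unprocessed_and_failed"))  # [(23, 40), (24, 40)]
```

Tile (24, 39) failed once and then succeeded, so it counts as processed.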

In [None]:
#tiles = config.get_list_of_tiles(which='unprocessed_and_failed')
tiles = config.get_list_of_tiles(which='unprocessed_and_failed_weather_stations') # run this to process the tiles with weather stations

In [None]:
batch_size = 20
tile_batches = [tiles[i:i + batch_size] for i in range(0, len(tiles), batch_size)]


def write_tile_result(tile):
    """Append one row of tile metadata to the results CSV."""
    df = pd.DataFrame([[getattr(tile, f) for f in config.fields]], columns=config.fields)
    df.to_csv(config.tile_results_path, mode='a', header=False, index=False)  # header=True if starting over


for tile_batch in tqdm.tqdm(tile_batches, total=len(tile_batches)):

    futures = [client.submit(process_tile, tile) for tile in tile_batch]  # , retries=0

    # Indices of tiles returned by workers; process_tile catches its own errors,
    # so returned tiles may still have success=False
    completed_indices = set()

    try:
        for future, computed_result in dask.distributed.as_completed(futures, with_results=True, timeout=1600):
            completed_indices.add(computed_result.index)
            write_tile_result(computed_result)
            status = 'completed' if computed_result.success else 'failed'
            print(f'Tile ({computed_result.row},{computed_result.col}) {status}')
    except Exception as e:
        # Usually a timeout: record every tile in the batch that never returned.
        # These are the local Tile objects, so they carry no results from the workers.
        print(e)
        print(traceback.format_exc())
        for tile in tile_batch:
            if tile.index not in completed_indices:
                tile.success = False
                tile.error_messages.append(str(e))
                write_tile_result(tile)
                print(f'Tile ({tile.row},{tile.col}) failed')

    client.restart()