# MODIS Snow Cover Data Processing with Dask and Azure Blob Storage

## Setup and Imports

In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import zarr
import easysnowdata
import modis_masking
import coiled
import tqdm
import logging
import traceback
import matplotlib.pyplot as plt
import adlfs
import pathlib
from dask.distributed import as_completed
#odc.stac.configure_rio(cloud_defaults=True)

import odc.stac
# Record the versions of the key libraries so runs can be reproduced.
for _label, _module in (("odc.stac", odc.stac), ("xarray", xr), ("zarr", zarr)):
    print(f'{_label} version: {_module.__version__}')

odc.stac version: 0.4.0
xarray version: 2025.4.0
zarr version: 2.18.7


## Configuration

In [2]:
WY_start = 2015  # first water year written to the store
WY_end = 2024  # last water year written to the store (inclusive)

# get token from https://github.com/egagli/azure_authentication/raw/main/sas_token.txt
# strip() guards against a trailing newline in the token file, which would
# otherwise be passed verbatim into the SAS credential.
sas_token = pathlib.Path("sas_token.txt").read_text().strip()

store = adlfs.AzureBlobFileSystem(
    account_name="snowmelt", credential=sas_token
).get_mapper("snowmelt/snow_cover/global_modis_snow_cover_4.zarr") # make sure this matches the store created in make_grid_and_zarr_store.ipynb

# created in make_grid_and_zarr_store.ipynb: one tile id (e.g. "h11_v2") per line.
# Surrounding whitespace is stripped and blank lines are ignored, so a stray empty
# line never reaches process_tile as a bogus tile id.
tile_processing_list = [
    line.strip()
    for line in pathlib.Path('modis_tile_processing_list.txt').read_text().splitlines()
    if line.strip()
]

## Define processing function


In [3]:
# Note to self (leaving 6/13): we started fresh; now we just have to run this processing. Do it for the polar tiles first, then view the results, then also view the MAD. If it looks good, run it everywhere.

In [4]:
def process_tile(tile, store):
    """Compute seasonal snow-cover metrics for one MODIS tile and write them into the Zarr store.

    Parameters
    ----------
    tile : str
        Tile id of the form ``"h<H>_v<V>"`` (e.g. ``"h11_v2"``); H and V are the
        horizontal / vertical MODIS tile indices.
    store : mapping
        Zarr store (here an adlfs mapper) that already holds the global grid.
        Results are written in place into the window covered by this tile.

    Returns
    -------
    bool
        ``True`` if the tile was processed and written, ``False`` if anything
        failed. Errors are logged (with traceback) rather than raised, so one bad
        tile does not abort a whole batch of futures.

    Notes
    -----
    Reads the module-level ``WY_start`` / ``WY_end`` water-year bounds.
    """
    h, v = (int(part[1:]) for part in tile.split('_'))

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    logging.getLogger('azure').setLevel(logging.WARNING)

    logger.info(f"Starting process for tile {tile}")

    try:
        hemisphere = "northern" if v < 9 else "southern"

        # Both branches request one more water year than strictly needed: the
        # first acquisition of a water year (October 1st in the NH) already holds
        # fill values, so the preceding water year is fetched as context.
        if hemisphere == "northern":
            # Northern water years run Oct 1 -> Sep 30.
            start_date = f"{WY_start-2}-10-01"
            end_date = f"{WY_end}-09-30"
        else:
            # Southern water years run Apr 1 -> Mar 31 of the following year.
            start_date = f"{WY_start-1}-04-01"
            end_date = f"{WY_end+1}-03-31"

        # Read one time step per chunk, then rechunk so each 600x600 spatial block
        # carries its full time series in a single chunk.
        modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
            vertical_tile=v,
            horizontal_tile=h,
            start_date=start_date,
            end_date=end_date,
            chunks={"time": 1, "y": 2400, "x": 2400},
        ).chunk({"time": -1, "y": 600, "x": 600})

        timestamps = pd.to_datetime(modis_snow_da.time)
        modis_snow_da.coords["water_year"] = (
            "time",
            timestamps.map(
                lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)
            ),
        )
        modis_snow_da.coords["DOWY"] = (
            "time",
            timestamps.map(
                lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)
            ),
        )

        if v >= 15 or v <= 2:
            # Polar tiles: scenes around the time of darkness show strange
            # "no snow" (25) artifacts. Flag scenes with a significant number of
            # no-decision (1) / night (11) pixels and change their 25s to fill (255).
            value25_da = modis_snow_da.where(lambda x: x == 25).count(dim=["x", "y"])
            value200_da = modis_snow_da.where(lambda x: x == 200).count(dim=["x", "y"])
            no_decision_and_night_counts = modis_snow_da.where(
                lambda x: (x == 1) | (x == 11)
            ).count(dim=["x", "y"])

            # Threshold = 5% of the largest per-scene count of 25/200 pixels, used
            # as a proxy for the tile's land area.
            land_area_da = value200_da + value25_da
            max_land_pixels = land_area_da.max(dim='time')
            bad_pixel_thresh = int(0.05 * int(max_land_pixels))

            scenes_with_polar_night = no_decision_and_night_counts > bad_pixel_thresh
            # Also flag the scenes immediately before and after each flagged scene.
            scenes_with_polar_night_buffered = (
                scenes_with_polar_night.shift(time=-1).fillna(0)
                | scenes_with_polar_night
                | scenes_with_polar_night.shift(time=1).fillna(0)
            ).astype(int)

            # Keep a flag only if it is part of a run of >= 4 consecutive flagged
            # scenes, checked with a trailing, a leading (reversed) and a centred window.
            backward_check = scenes_with_polar_night_buffered.rolling(time=4, center=False).sum() >= 4
            forward_check = scenes_with_polar_night_buffered[::-1].rolling(time=4, center=False).sum()[::-1] >= 4
            center_check = scenes_with_polar_night_buffered.rolling(time=4, center=True).sum() >= 4

            scenes_with_polar_night_buffered_filtered = scenes_with_polar_night_buffered.where(
                backward_check | forward_check | center_check,
                other=0,
            ).astype(bool).chunk(dict(time=-1))

            # Bridge gaps of up to 80 days between flagged scenes (nearest-neighbour fill).
            scenes_with_polar_night_buffered_filtered_complete = (
                scenes_with_polar_night_buffered_filtered.where(lambda x: x == 1)
                .interpolate_na(
                    dim="time", method="nearest", max_gap=pd.Timedelta(days=80)
                )
                .where(lambda x: x == 1, other=0)
                .astype(bool)
            )

            modis_snow_da = modis_snow_da.where(
                ~((modis_snow_da == 25) & scenes_with_polar_night_buffered_filtered_complete),
                other=255,
            )

        # Binarize with cloud filling, then align to water-year starts (modis_masking helpers).
        effective_snow_da = modis_masking.binarize_with_cloud_filling(modis_snow_da)
        effective_snow_complete_wys_da = modis_masking.align_wy_start(effective_snow_da, hemisphere=hemisphere)

        # Per-water-year metrics, trimmed to the requested (inclusive) WY range.
        seasonal_snow_cover_ds = (
            effective_snow_complete_wys_da.groupby('water_year')
            .apply(modis_masking.get_max_consec_snow_days_SAD_SDD_one_WY)
            .compute()
            .sel(water_year=slice(WY_start, WY_end))
        )

        # Tile (h, v) occupies this 2400 x 2400 index window of the global grid.
        y_slice = slice(v * 2400, (v + 1) * 2400)
        x_slice = slice(h * 2400, (h + 1) * 2400)

        with xr.open_zarr(store, consolidated=True) as existing_ds:
            y_coords = existing_ds.y[y_slice].values
            x_coords = existing_ds.x[x_slice].values

        # BOTH axes must agree with the store grid (within atol=0.1) before the
        # tile's coordinates are snapped to the store's exact values; otherwise the
        # data would be written into the wrong region.
        y_match = np.allclose(y_coords, seasonal_snow_cover_ds.y.values, atol=0.1)
        x_match = np.allclose(x_coords, seasonal_snow_cover_ds.x.values, atol=0.1)
        if not (y_match and x_match):
            logger.error(f"y or x coordinates do not match for tile {tile}")
            raise ValueError(f"y or x coordinates do not match for tile {tile}")
        seasonal_snow_cover_ds = seasonal_snow_cover_ds.assign_coords(y=y_coords, x=x_coords)

        # Drop attributes (e.g. _FillValue) from each data variable before writing
        # into the pre-existing store.
        for var in list(seasonal_snow_cover_ds.data_vars):
            seasonal_snow_cover_ds[var] = seasonal_snow_cover_ds[var].drop_attrs()

        seasonal_snow_cover_ds.drop_vars("spatial_ref").chunk(
            {"water_year": 1, "y": 2400, "x": 2400}
        ).to_zarr(store, region="auto", mode="r+", consolidated=True)

        return True

    except Exception as e:
        logger.error(f"(PT) Error processing tile {tile}: {str(e)}")
        logger.error(f"(PT) Traceback: {traceback.format_exc()}")
        return False

## Set Up Dask Cluster with Coiled

In [5]:
# cluster = coiled.Cluster(idle_timeout="10 minutes",
#                         n_workers=10, #30
#                         worker_memory="16 GB", #32 
#                         worker_cpu=4, # 4,
#                         #worker_options={"nthreads": 1},
#                         #scheduler_memory="32 GB",
#                         spot_policy="spot",
#                         environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
#                         workspace="uwtacolab",
#                         )

# Coiled cluster configuration for the full processing run.
# Other options tried: worker_options={"nthreads": 1}, scheduler_memory="32 GB".
cluster_options = dict(
    idle_timeout="10 minutes",
    n_workers=30,
    worker_memory="32 GB",
    worker_cpu=4,
    spot_policy="spot",
    environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
    workspace="uwtacolab",
)
cluster = coiled.Cluster(**cluster_options)

client = cluster.get_client()

Output()

Output()



Exception in callback None()
handle: <Handle cancelled>
Traceback (most recent call last):
  File "/home/eric/miniconda3/envs/new_global_snowmelt_runoff_onset/lib/python3.13/site-packages/tornado/iostream.py", line 1363, in _do_ssl_handshake
    self.socket.do_handshake()
    ~~~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/eric/miniconda3/envs/new_global_snowmelt_runoff_onset/lib/python3.13/ssl.py", line 1372, in do_handshake
    self._sslobj.do_handshake()
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^
ssl.SSLEOFError: [SSL: UNEXPECTED_EOF_WHILE_READING] EOF occurred in violation of protocol (_ssl.c:1028)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/eric/miniconda3/envs/new_global_snowmelt_runoff_onset/lib/python3.13/asyncio/events.py", line 89, in _run
    self._context.run(self._callback, *self._args)
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eric/miniconda3/envs/new_global_snowmelt_runoff_onset/lib/python3.1

## Process MODIS Tiles

In [None]:
# Debug cell: run a single tile on the cluster. The future returned by
# client.submit is this cell's last expression, so the notebook displays it;
# call .result() on it to get process_tile's True/False.
tile = "h11_v2"
client.submit(process_tile, tile, store)

In [None]:
# Restart all Dask workers (e.g. after the single-tile debug run) before the batch run.
client.restart()

In [None]:
# Replace the tile list with a filtered subset before the batch run.
# NOTE(review): `tile_processing_list_filtered` is not defined in any of the cells
# visible here, so running this cell raises NameError unless it was created
# elsewhere first — confirm where the filtered list is supposed to come from.
tile_processing_list = tile_processing_list_filtered

In [6]:
failed_tiles = []
completed_tiles_batch = []
# Tiles per submission batch. The store's 'processed_tiles' attribute is updated,
# and the workers are restarted, once per batch.
BATCH_SIZE = 30

# Tiles already recorded in the store's 'processed_tiles' attribute are skipped,
# so this cell can simply be re-run to retry failures. A missing attribute is
# treated as "nothing processed yet".
processed_tiles_list_initial = zarr.open(store).attrs.get('processed_tiles', [])
already_processed = set(processed_tiles_list_initial)  # O(1) membership tests

unprocessed_tiles = [tile for tile in tile_processing_list if tile not in already_processed]

print(f"Found {len(unprocessed_tiles)} unprocessed tiles")

total_batches = (len(unprocessed_tiles) - 1) // BATCH_SIZE + 1

# Process tiles in batches of BATCH_SIZE
for batch_num, i in enumerate(range(0, len(unprocessed_tiles), BATCH_SIZE), start=1):
    batch = unprocessed_tiles[i:i + BATCH_SIZE]

    print(f"\n=== Processing batch {batch_num}/{total_batches} ({len(batch)} tiles) ===")

    # Submit this batch, keeping a future -> tile lookup for reporting
    batch_futures = {client.submit(process_tile, tile, store): tile for tile in batch}

    print(f"Submitted {len(batch_futures)} tiles for batch {batch_num}")

    # Process this batch to completion, in whatever order tiles finish
    for future in tqdm.tqdm(as_completed(batch_futures), total=len(batch_futures),
                            desc=f"Batch {batch_num} progress"):
        tile = batch_futures[future]

        try:
            result = future.result()
        except Exception as e:
            # future.result() re-raises anything that escaped the task; errors
            # handled inside process_tile come back as a False result instead.
            print(f"Tile {tile} FAILED with exception: {str(e)}")
            failed_tiles.append(tile)
            continue

        if result:
            completed_tiles_batch.append(tile)
            print(f"Tile {tile} SUCCESS")
        else:
            print(f"Tile {tile} FAIL, adding to failed list")
            failed_tiles.append(tile)

    # Record this batch's successes in the store so a re-run skips them
    if completed_tiles_batch:
        with zarr.open(store) as zarr_store:
            processed_tile_list = zarr_store.attrs.get('processed_tiles', [])
            processed_tile_list.extend(completed_tiles_batch)
            zarr_store.attrs['processed_tiles'] = processed_tile_list
        print(f"Updated zarr store with {len(completed_tiles_batch)} tiles")
        completed_tiles_batch = []

    print(f"Batch {batch_num} completed. Failed so far: {len(failed_tiles)}")

    # Restart the workers after each batch so the next batch starts on fresh workers
    client.restart(wait_for_workers=True)
    print(f"Client restarted after batch {batch_num}")

# Final status report
print(f"\n=== FINAL RESULTS ===")
print(f"Total tiles processed: {len(unprocessed_tiles) - len(failed_tiles)}")
print(f"Failed tiles: {len(failed_tiles)}")

if failed_tiles:
    print("Run this cell again. The following tiles could not be processed:")
    for tile in failed_tiles:
        print(tile)
else:
    print("Now consolidating metadata...")
    zarr.consolidate_metadata(store)
    print("All tiles processed successfully!!!")

Found 3 unprocessed tiles

=== Processing batch 1/1 (3 tiles) ===
Submitted 3 tiles for batch 1


Batch 1 progress:  33%|███▎      | 1/3 [02:42<05:25, 162.98s/it]

Tile h23_v4 SUCCESS


Batch 1 progress: 100%|██████████| 3/3 [02:43<00:00, 54.46s/it] 

Tile h27_v5 SUCCESS
Tile h10_v3 SUCCESS





Updated zarr store with 3 tiles
Batch 1 completed. Failed so far: 0
Client restarted after batch 1

=== FINAL RESULTS ===
Total tiles processed: 3
Failed tiles: 0
Now consolidating metadata...
All tiles processed successfully!!!


## Test longest consecutive stretch function

In [None]:
# Comprehensive test of function equivalence with more and longer cases
def comprehensive_test_vectorized_function(seed=42):
    """Check that the numba-compiled longest-stretch finder matches the pure-Python original.

    Both implementations return ``(start, end, length)`` for the longest run of
    truthy values (``end`` exclusive), or ``(-32768, -32768, -32768)`` when the
    input contains no truthy value. Every test case is run through both and the
    tuples are compared.

    Parameters
    ----------
    seed : int or None, optional
        Seed for the randomly generated test patterns, so that a failing case can
        be reproduced. Pass ``None`` to draw fresh random patterns on each call.

    Returns
    -------
    bool
        ``True`` if every test case produced identical results.
    """
    import numpy as np
    from numba import jit

    # Local generator: reproducible random cases without touching NumPy's global RNG.
    rng = np.random.default_rng(seed)

    # Define both functions here for testing
    def get_longest_consec_stretch_original(arr):
        max_len = 0
        max_start = 0
        max_end = 0
        current_start = None

        for i, val in enumerate(arr):
            if val:
                if current_start is None:
                    current_start = i
            else:
                if current_start is not None:
                    length = i - current_start
                    # '>=' : among equal-length interior runs, the later one wins
                    if length >= max_len:
                        max_len = length
                        max_start = current_start
                        max_end = i
                    current_start = None

        # A run reaching the end of the array only replaces on a strictly longer length
        if current_start is not None:
            length = len(arr) - current_start
            if length > max_len:
                max_len = length
                max_start = current_start
                max_end = len(arr)

        if max_len == 0:
            return -32768, -32768, -32768  # fill_value

        return max_start, max_end, max_len

    @jit(nopython=True)
    def get_longest_consec_stretch_vectorized(arr):
        n = len(arr)
        if n == 0:
            return -32768, -32768, -32768

        max_len = 0
        max_start = 0
        max_end = 0
        current_start = -1  # -1 plays the role of None (numba-friendly)

        for i in range(n):
            if arr[i]:
                if current_start == -1:
                    current_start = i
            else:
                if current_start != -1:
                    length = i - current_start
                    if length >= max_len:
                        max_len = length
                        max_start = current_start
                        max_end = i
                    current_start = -1

        if current_start != -1:
            length = n - current_start
            if length > max_len:
                max_len = length
                max_start = current_start
                max_end = n

        if max_len == 0:
            return -32768, -32768, -32768

        return max_start, max_end, max_len

    # Comprehensive test cases
    test_cases = [
        # Basic cases
        [],
        [False],
        [True],
        [False, False, False],
        [True, True, True],

        # Simple patterns
        [True, True, False, True, True, True],
        [False, True, True, False, False, True],
        [True, False, True, False, True],

        # Edge cases - snow at boundaries
        [True, False, False, False],  # Snow at start only
        [False, False, False, True],  # Snow at end only
        [True, True, False, False],   # Snow at start, none at end
        [False, False, True, True],   # No snow at start, snow at end

        # Full coverage cases
        [True] * 10,    # All snow
        [False] * 10,   # No snow

        # Realistic MODIS water year patterns (46 time steps)
        # Early season snow pattern
        [False] * 5 + [True] * 15 + [False] * 26,

        # Mid-season with gaps
        [False] * 8 + [True] * 12 + [False] * 6 + [True] * 10 + [False] * 10,

        # Late season pattern
        [True] * 20 + [False] * 26,

        # Full season snow
        [True] * 46,

        # No snow season
        [False] * 46,

        # Complex multi-period pattern
        [False] * 3 + [True] * 8 + [False] * 5 + [True] * 15 + [False] * 7 + [True] * 8,

        # Very long arrays (simulate multiple years)
        # Long early season
        [False] * 20 + [True] * 50 + [False] * 30,

        # Multiple long periods
        [True] * 25 + [False] * 10 + [True] * 35 + [False] * 15 + [True] * 15,

        # Intermittent pattern (cloud gaps)
        ([True, False, True, True, False, False, True, True, True, False] * 10),

        # Very large array with complex pattern
        ([False] * 10 + [True] * 20 + [False] * 5 + [True] * 30 +
         [False] * 8 + [True] * 15 + [False] * 12 + [True] * 25 + [False] * 20),

        # Alternating single values (worst case for algorithm)
        [True, False] * 50,
        [False, True] * 50,

        # Random patterns of various lengths (seeded, see `seed`)
        rng.choice([True, False], size=20, p=[0.4, 0.6]).tolist(),
        rng.choice([True, False], size=50, p=[0.3, 0.7]).tolist(),
        rng.choice([True, False], size=100, p=[0.4, 0.6]).tolist(),
        rng.choice([True, False], size=200, p=[0.2, 0.8]).tolist(),

        # Edge case: very long consecutive period
        [False] * 50 + [True] * 150 + [False] * 50,

        # Multiple equal-length periods (test tie-breaking)
        [True] * 10 + [False] * 5 + [True] * 10 + [False] * 5 + [True] * 10,

        # Single snow day in long period
        [False] * 100 + [True] + [False] * 100,

        # Two equal periods at start and end
        [True] * 15 + [False] * 20 + [True] * 15,
    ]

    print("Comprehensive Function Test:")
    print("=" * 60)

    all_passed = True
    failed_cases = []

    for i, case in enumerate(test_cases):
        arr = np.array(case, dtype=bool)

        orig = get_longest_consec_stretch_original(arr)
        vect = get_longest_consec_stretch_vectorized(arr)

        match = orig == vect
        status = "✓ PASS" if match else "✗ FAIL"

        print(f"Test {i+1:2d}: {status} - Length: {len(case):3d}")

        if len(case) <= 20:  # Show pattern for short cases
            print(f"  Pattern: {case}")
        else:  # Show summary for long cases
            true_count = sum(case)
            false_count = len(case) - true_count
            print(f"  Summary: {true_count} True, {false_count} False")

        print(f"  Original:   start={orig[0]:4d}, end={orig[1]:4d}, length={orig[2]:4d}")
        print(f"  Vectorized: start={vect[0]:4d}, end={vect[1]:4d}, length={vect[2]:4d}")

        if not match:
            all_passed = False
            # Only the first 20 elements are kept for the failure report
            failed_cases.append((i+1, case[:20], orig, vect))
            print(f"  *** MISMATCH! ***")

        print()

    print("=" * 60)
    print(f"SUMMARY:")
    print(f"Total tests: {len(test_cases)}")
    print(f"Passed: {len(test_cases) - len(failed_cases)}")
    print(f"Failed: {len(failed_cases)}")

    if all_passed:
        print("\n🎉 ALL TESTS PASSED! Functions are equivalent.")
        print("✅ SAFE TO USE the vectorized version.")
    else:
        print(f"\n❌ {len(failed_cases)} TESTS FAILED! Functions are NOT equivalent.")
        print("❌ DO NOT USE the vectorized version yet.")
        print("\nFailed test details:")
        for test_num, pattern, orig, vect in failed_cases:
            print(f"  Test {test_num}: Pattern={pattern} | Orig={orig} | Vect={vect}")

    return all_passed

# Performance benchmark with longer arrays
def benchmark_functions():
    """Time the pure-Python and numba-compiled longest-stretch finders.

    For several array sizes, a random boolean array (fixed seed 42, ~30% True)
    is passed repeatedly to each implementation; total time, per-call time and
    the resulting speedup are printed. The two implementations are copies of
    the ones checked for equivalence in comprehensive_test_vectorized_function.
    """
    import time
    import numpy as np
    from numba import jit

    # Redefine functions for benchmarking
    def get_longest_consec_stretch_original(arr):
        max_len = 0
        max_start = 0
        max_end = 0
        current_start = None

        for i, val in enumerate(arr):
            if val:
                if current_start is None:
                    current_start = i
            else:
                if current_start is not None:
                    length = i - current_start
                    if length >= max_len:
                        max_len = length
                        max_start = current_start
                        max_end = i
                    current_start = None

        if current_start is not None:
            length = len(arr) - current_start
            if length > max_len:
                max_len = length
                max_start = current_start
                max_end = len(arr)

        if max_len == 0:
            return -32768, -32768, -32768

        return max_start, max_end, max_len

    @jit(nopython=True)
    def get_longest_consec_stretch_vectorized(arr):
        n = len(arr)
        if n == 0:
            return -32768, -32768, -32768

        max_len = 0
        max_start = 0
        max_end = 0
        current_start = -1

        for i in range(n):
            if arr[i]:
                if current_start == -1:
                    current_start = i
            else:
                if current_start != -1:
                    length = i - current_start
                    if length >= max_len:
                        max_len = length
                        max_start = current_start
                        max_end = i
                    current_start = -1

        if current_start != -1:
            length = n - current_start
            if length > max_len:
                max_len = length
                max_start = current_start
                max_end = n

        if max_len == 0:
            return -32768, -32768, -32768

        return max_start, max_end, max_len

    print("\nPerformance Benchmark:")
    print("=" * 40)

    # Test different array sizes
    sizes = [46, 100, 500, 1000]  # Realistic MODIS sizes
    iterations = [50000, 20000, 5000, 2000]  # Fewer iterations for larger arrays

    for size, iters in zip(sizes, iterations):
        # Local generator: reproducible input without reseeding NumPy's global RNG
        rng = np.random.default_rng(42)
        test_array = rng.choice([True, False], size=size, p=[0.3, 0.7])

        # Warm up numba (first call triggers JIT compilation)
        _ = get_longest_consec_stretch_vectorized(test_array)

        # perf_counter is the monotonic, high-resolution clock intended for benchmarks
        # Benchmark original
        start_time = time.perf_counter()
        for _ in range(iters):
            _ = get_longest_consec_stretch_original(test_array)
        orig_time = time.perf_counter() - start_time

        # Benchmark vectorized
        start_time = time.perf_counter()
        for _ in range(iters):
            _ = get_longest_consec_stretch_vectorized(test_array)
        vect_time = time.perf_counter() - start_time

        speedup = orig_time / vect_time if vect_time > 0 else float('inf')

        print(f"Array size {size:4d} ({iters:5d} iterations):")
        print(f"  Original:   {orig_time:.4f}s ({orig_time/iters*1000:.3f}ms per call)")
        print(f"  Vectorized: {vect_time:.4f}s ({vect_time/iters*1000:.3f}ms per call)")
        print(f"  Speedup:    {speedup:.2f}x")
        print()

# Run the equivalence test; benchmark only if the two implementations agree.
print("Running comprehensive test...")
test_passed = comprehensive_test_vectorized_function()

if not test_passed:
    print("\nSkipping benchmark due to failed tests.")
else:
    print("\n" + "=" * 60)
    benchmark_functions()

separator = "=" * 60
verdict = '✅ SAFE TO USE' if test_passed else '❌ NEEDS FIXING'
print(f"\n{separator}")
print(f"FINAL RESULT: {verdict}")
print(separator)

## Individual tile investigation

In [None]:
# To figure out: why doesn't SAD start early in Antarctica? Is filling happening correctly?
# # And then, what is going on near the North Pole during the start of the WY as well?
# v=16
# h=16
# #v=0
# WY_start = 2014
# WY_end = 2024

# hemisphere = "northern" if v < 9 else "southern"


# if hemisphere == "northern":
#     modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
#         vertical_tile=v,
#         horizontal_tile=h,
#         start_date=f"{WY_start-2}-10-01",
#         end_date=f"{WY_end}-09-30",
#         #chunks={},
#         #chunks={"time": -1, "y": 600, "x": 600},
#         chunks={"time": 1, "y": 2400, "x": 2400},

#     ).chunk({"time": -1, "y": 600, "x": 600})

# else:
#     modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
#         vertical_tile=v,
#         horizontal_tile=h,
#         start_date=f"{WY_start-1}-04-01",
#         end_date=f"{WY_end+1}-03-31",
#         #chunks={},
#         #chunks={"time": -1, "y": 600, "x": 600},
#         chunks={"time": 1, "y": 2400, "x": 2400},

#     ).chunk({"time": -1, "y": 600, "x": 600})

# modis_snow_da.coords["water_year"] = (
#     "time",
#     pd.to_datetime(modis_snow_da.time).map(
#         lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)
#     ),
# )
# modis_snow_da.coords["DOWY"] = (
#     "time",
#     pd.to_datetime(modis_snow_da.time).map(
#         lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)
#     ),
# )

# modis_snow_da

# modis_snow_da.sel(time='2020-03-29').plot.imshow()

# no_decision_and_night_counts = modis_snow_da.where(lambda x: (x == 1) | (x == 11)).count(dim=["x","y"]).compute()

# scenes_with_polar_night = no_decision_and_night_counts > 50
# # then for every true value, change the value that comes directly before and after it to 1 as well
# scenes_with_polar_night_buffered = (scenes_with_polar_night.shift(time=-1).fillna(0) | scenes_with_polar_night | scenes_with_polar_night.shift(time=1).fillna(0)).astype(int)
# backward_check = scenes_with_polar_night_buffered.rolling(time=4, center=False).sum() >= 4  # forward-looking
# forward_check = scenes_with_polar_night_buffered[::-1].rolling(time=4, center=False).sum()[::-1] >= 4  # backward-looking
# center_check = scenes_with_polar_night_buffered.rolling(time=4, center=True).sum() >= 4  # center-looking

# # A position should be kept if it's part of a sequence of 4+ when looking in any direction
# scenes_with_polar_night_buffered_filtered = scenes_with_polar_night_buffered.where(
#     backward_check | forward_check | center_check,
#     other=0
# ).astype(bool)
# scenes_with_polar_night_buffered_filtered
# # count the occurance of 1s and 11s in modis_snow_da which is dask array


# # f,ax=plt.subplots(figsize=(20,7))
# # no_decision_and_night_counts.plot()
# # ax.axvline(x=pd.to_datetime('2020-09-05'), color='red', linestyle='--')
# # binary_ax = ax.twinx()
# # scenes_with_polar_night_buffered.plot(ax=binary_ax, color='orange')
# # scenes_with_polar_night_buffered_filtered.plot(ax=binary_ax, color='green')
# # (scenes_with_polar_night_buffered_filtered-scenes_with_polar_night_buffered).plot(ax=binary_ax, color='purple', linestyle='--')

# scenes_with_polar_night_buffered_filtered.sel(time=slice('2020-03-13', '2020-09-30'))


# modis_snow_da # not working as intended? CHECK LOGIC

# values,counts=np.unique(modis_snow_da.sel(time='2020-03-29'), return_counts=True)
# for v,c in zip(values,counts):
#     print(f"{v}: {c}")

# # now for time steps where no_decision_and_night_counts is greater than 10, change the value of modis_snow_da from 25 to 255
# # modis_snow_da = modis_snow_da.where(
# #     ~((modis_snow_da == 25) & (no_decision_and_night_counts > 10)), other=255
# # )
# # modis_snow_da

# modis_snow_da.sel(time='2020-03-29').plot.imshow()

# values,counts=np.unique(modis_snow_da.sel(time='2020-03-29'), return_counts=True)
# for v,c in zip(values,counts):
#     print(f"{v}: {c}")


# effective_snow_da = modis_masking.binarize_with_cloud_filling(modis_snow_da)
# effective_snow_da

# effective_snow_da.sel(time=slice('2020-03-29', '2020-09-30')).plot.imshow(col='time',col_wrap=5,figsize=(10, 10))

# effective_snow_complete_wys_da = modis_masking.align_wy_start(effective_snow_da, hemisphere=hemisphere)
# effective_snow_complete_wys_da


# seasonal_snow_cover_ds = effective_snow_complete_wys_da.groupby('water_year').apply(modis_masking.get_max_consec_snow_days_SAD_SDD_one_WY).compute()
# seasonal_snow_cover_ds

# seasonal_snow_cover_ds['max_consec_snow_days'].sel(water_year=2020).plot.imshow(vmin=0, vmax=365, cmap='viridis')
# seasonal_snow_cover_ds['SAD_DOWY'].sel(water_year=2020).plot.imshow(vmin=0, vmax=365, cmap='viridis')

# tile_processing_list = [
#     "h15_v15",
#     "h15_v16",
#     "h15_v17",
#     "h16_v16",
#     "h16_v17",
#     "h17_v16",
#     "h17_v17",
#     "h16_v0",
#     "h16_v1",
#     "h16_v2",
#     "h17_v0",
#     "h17_v1",
#     "h17_v2",
#     "h17_v3",
#     "h14_v1",
#     "h14_v2",
#     "h14_v3",
#     "h14_v4",
#     "h15_v1",
#     "h15_v2",]
# tile_processing_list

## Other approaches (code graveyard)

In [None]:
# failed_tiles = []
# results = []

# #processed_tiles_list_initial = zarr.open(store).attrs['processed_tiles']
# for tile in tqdm.tqdm(tile_processing_list[0:20]):
#     result = client.submit(process_tile, tile, store)
#     results.append(result)

# results[0].result()
# tile_processing_list

# result = client.submit(process_tile, tile_processing_list[0], store)
# result

# failed_tiles = []
# futures = {}

# # Get initially processed tiles
# processed_tiles_list_initial = zarr.open(store).attrs['processed_tiles']

# # Submit all unprocessed tiles to the cluster
# for tile in tqdm.tqdm(tile_processing_list, desc="Submitting tasks"):
#     if tile in processed_tiles_list_initial:
#         print(f"Tile {tile} already processed, skipping")
#         continue
    
#     future = client.submit(process_tile, tile, store)
#     futures[future] = tile

# print(f"Submitted {len(futures)} tiles for processing")

# # Process completed futures as they finish
# from dask.distributed import as_completed

# for future in tqdm.tqdm(as_completed(futures), total=len(futures), desc="Processing tiles"):
#     tile = futures[future]
    
#     try:
#         result = future.result()
        
#         if result == True:
#             # Update processed tiles list in zarr store
#             with zarr.open(store) as zarr_store:
#                 processed_tile_list = zarr_store.attrs['processed_tiles']
#                 processed_tile_list.append(tile)
#                 zarr_store.attrs['processed_tiles'] = processed_tile_list
            
#             print(f"Tile {tile} SUCCESS, added to processed_list attribute")
#         else:
#             print(f"Tile {tile} FAIL, adding to failed list")
#             failed_tiles.append(tile)
            
#     except Exception as e:
#         print(f"Tile {tile} FAILED with exception: {str(e)}")
#         failed_tiles.append(tile)

# # Restart cluster after all processing is complete
# if len(futures) > 0:
#     client.restart(wait_for_workers=True)

# # Final status report
# if failed_tiles:
#     print("Run this cell again. The following tiles could not be processed:")
#     for tile in failed_tiles:
#         print(tile)
# else:
#     print("Now consolidating metadata...")
#     zarr.consolidate_metadata(store)
#     print("All tiles processed successfully!!!")


# failed_tiles = []

# processed_tiles_list_initial = zarr.open(store).attrs['processed_tiles']
# for tile in tqdm.tqdm(tile_processing_list):
    
#     if tile in processed_tiles_list_initial:
#         print(f"Tile {tile} already processed, skipping")
#         continue
        
#     result = process_tile(tile, store)

#     if result == True:

#         with zarr.open(store) as zarr_store:
#             processed_tile_list = zarr_store.attrs['processed_tiles']
#             processed_tile_list.append(tile)
#             zarr_store.attrs['processed_tiles'] = processed_tile_list

#         print(f"Tile {tile} SUCCESS, added to processed_list attribute")

#         client.restart(wait_for_workers=True)

#     else:
#         print(f"Tile {tile} FAIL, adding to failed list")
#         failed_tiles.append(tile)


# if failed_tiles:
#     print("Run this cell again. The following tiles could not be processed:")
#     for tile in failed_tiles:
#         print(tile)
# else:
#     print("Now consolidating metadata...")
#     zarr.consolidate_metadata(store)
#     print("All tiles processed successfully!!!")

In [None]:
def test_tile(tile, bad_pixel_thresh=None, bad_pixel_frac=0.05):
    """Compute per-scene polar-night diagnostics for a single MODIS tile.

    Fetches the MOD10A2 maximum snow extent time series for ``tile`` via
    ``modis_masking.get_modis_MOD10A2_max_snow_extent`` for water years
    ``WY_start``..``WY_end`` (plus one extra water year of lead-in). For each
    scene it counts pixels of selected class values, then flags scenes that
    look affected by polar night (many no-decision / night pixels).

    Parameters
    ----------
    tile : str
        Tile identifier of the form ``"h<H>_v<V>"`` (e.g. ``"h11_v2"``).
        ``v < 9`` is treated as northern hemisphere, otherwise southern.
    bad_pixel_thresh : int or None, optional
        A scene is flagged when its count of value 1 (no decision) plus
        value 11 (night) pixels exceeds this number. If None (default), it is
        derived as ``int(bad_pixel_frac * max_land_pixels)``, where
        ``max_land_pixels`` is the largest per-scene count of values 25 and
        200 over the time series. Previously this argument was silently
        ignored and the derived value was always used.
    bad_pixel_frac : float, optional
        Fraction used to derive ``bad_pixel_thresh`` when it is None.

    Returns
    -------
    xarray.Dataset
        Per-scene counts and polar-night flags, with scalar coordinates
        ``tile`` and ``bad_pixel_thresh`` (the threshold actually used).
        Results are dask-backed; call ``.compute()`` to evaluate. When the
        threshold is derived, the land-pixel counts are evaluated eagerly.
    """
    h, v = (int(part[1:]) for part in tile.split("_"))

    hemisphere = "northern" if v < 9 else "southern"

    # Both hemispheres request one extra water year before WY_start, because
    # the first acquisition of a water year can already contain fill values
    # (observed for the October 1st acquisition in the NH).
    if hemisphere == "northern":
        # NH water year: October 1st - September 30th.
        start_date = f"{WY_start-2}-10-01"  # would normally be WY_start-1
        end_date = f"{WY_end}-09-30"
    else:
        # SH water year: April 1st - March 31st.
        start_date = f"{WY_start-1}-04-01"  # would normally be WY_start
        end_date = f"{WY_end+1}-03-31"

    # Read one full 2400x2400 tile per time step, then rechunk to a single
    # time chunk with 600x600 spatial chunks.
    modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
        vertical_tile=v,
        horizontal_tile=h,
        start_date=start_date,
        end_date=end_date,
        chunks={"time": 1, "y": 2400, "x": 2400},
    ).chunk({"time": -1, "y": 600, "x": 600})

    times = pd.to_datetime(modis_snow_da.time)
    modis_snow_da.coords["water_year"] = (
        "time",
        times.map(lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)),
    )
    modis_snow_da.coords["DOWY"] = (
        "time",
        times.map(lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)),
    )

    def _count_where(condition):
        # Number of pixels per scene (reduced over x and y) matching `condition`.
        return modis_snow_da.where(condition).count(dim=["x", "y"])

    value1_da = _count_where(lambda x: x == 1)  # no decision
    value11_da = _count_where(lambda x: x == 11)  # night
    value25_da = _count_where(lambda x: x == 25)  # no snow
    no_decision_and_night_counts = _count_where(lambda x: (x == 1) | (x == 11))

    if bad_pixel_thresh is None:
        # Land area per scene = count of value 25 (no snow) + value 200
        # (presumably snow, per the MOD10A2 convention). int() forces the
        # dask graph for these counts to be evaluated here.
        land_area_da = _count_where(lambda x: x == 200) + value25_da
        max_land_pixels = land_area_da.max(dim="time")
        bad_pixel_thresh = int(bad_pixel_frac * int(max_land_pixels))

    scenes_with_polar_night = no_decision_and_night_counts > bad_pixel_thresh

    # Also flag the scene immediately before and after each flagged scene.
    # fill_value=False keeps the shifted arrays boolean (the default NaN fill
    # cannot be stored in a bool array).
    scenes_with_polar_night_buffered = (
        scenes_with_polar_night.shift(time=-1, fill_value=False)
        | scenes_with_polar_night
        | scenes_with_polar_night.shift(time=1, fill_value=False)
    ).astype(int)

    # Keep a flagged scene only if it is part of a run of >= `window`
    # consecutive flagged scenes, checking windows that end at, start at,
    # and are centred on it.
    window = 4
    reverse_time = slice(None, None, -1)
    # Trailing window: this scene and the (window - 1) scenes before it.
    backward_check = (
        scenes_with_polar_night_buffered.rolling(time=window, center=False).sum() >= window
    )
    # Leading window: this scene and the (window - 1) scenes after it,
    # computed as a trailing window on the time-reversed series.
    forward_check = (
        scenes_with_polar_night_buffered.isel(time=reverse_time)
        .rolling(time=window, center=False)
        .sum()
        .isel(time=reverse_time)
        >= window
    )
    center_check = (
        scenes_with_polar_night_buffered.rolling(time=window, center=True).sum() >= window
    )

    scenes_with_polar_night_buffered_filtered = (
        scenes_with_polar_night_buffered.where(
            backward_check | forward_check | center_check,
            other=0,
        )
        .astype(bool)
        # Single time chunk so interpolate_na below sees the whole time axis.
        .chunk(dict(time=-1))
    )

    # Fill gaps of at most 80 days between flagged scenes, so that one
    # polar-night period is not split by short runs of unflagged scenes.
    scenes_with_polar_night_buffered_filtered_complete = (
        scenes_with_polar_night_buffered_filtered.where(lambda x: x == 1)
        .interpolate_na(dim="time", method="nearest", max_gap=pd.Timedelta(days=80))
        .where(lambda x: x == 1, other=0)
        .astype(bool)
    )

    all_variables = xr.Dataset({
        "value1_da": value1_da,
        "value11_da": value11_da,
        "value25_da": value25_da,
        "no_decision_and_night_counts": no_decision_and_night_counts,
        "scenes_with_polar_night": scenes_with_polar_night,
        "scenes_with_polar_night_buffered": scenes_with_polar_night_buffered,
        "scenes_with_polar_night_buffered_filtered": scenes_with_polar_night_buffered_filtered,
        "scenes_with_polar_night_buffered_filtered_complete": scenes_with_polar_night_buffered_filtered_complete,
    })

    all_variables.coords["tile"] = tile
    all_variables.coords["bad_pixel_thresh"] = bad_pixel_thresh

    return all_variables

# tile_processing_list_filtered[0]
# test_ds = test_tile(tile_processing_list_filtered[0], bad_pixel_thresh=500).compute()
# test_ds



# f,ax= plt.subplots(figsize=(12, 7))
# test_ds['value1_da'].plot(ax=ax,label='Value 1 (no decision) count', color='blue')
# test_ds['value11_da'].plot(ax=ax,label='Value 11 (night) count', color='orange')
# test_ds['value25_da'].plot(ax=ax,label='Value 25 (no snow) count', color='green')
# test_ds['no_decision_and_night_counts'].plot(ax=ax,label='No decision and night counts', color='purple')
# binary_ax = ax.twinx()
# test_ds['scenes_with_polar_night_buffered_filtered'].plot(ax=binary_ax, label='Scenes with polar night',color='black')
# ax.legend()
# f.tight_layout()



# def create_test_dataset(tile_processing_list_filtered):
#     futures = [client.submit(test_tile, tile) for tile in tile_processing_list_filtered]
#     results = []
#     for future in tqdm.tqdm(as_completed(futures), total=len(futures)):
#         try:
#             result = future.result()
#             results.append(result)
#         except Exception as e:
#             logging.error(f"Error processing tile {future.key}: {str(e)}")
#             logging.error(f"Traceback: {traceback.format_exc()}")
    
#     return xr.concat(results, dim="tile")


# full_test_ds = create_test_dataset(tile_processing_list_filtered).compute()
# full_test_ds


# # sort full_test_ds by tile, but specifically by increasing v value. remember you CANNOT use .map with a dataarray bc of error: AttributeError: 'DataArray' object has no attribute 'map'
# v_values = []
# for tile in full_test_ds.tile.values:
#     # Split the tile string and extract v value
#     parts = tile.split('_')
#     v_part = parts[1]  # e.g., 'v17'
#     v_value = int(v_part[1:])  # Remove 'v' and convert to int
#     v_values.append(v_value)

# # Convert to numpy array for sorting
# v_values = np.array(v_values)

# full_test_sorted_ds = full_test_ds.isel(tile=np.argsort(v_values)[::-1])
# full_test_sorted_ds


# # investigate h9_v2 -> no polar night ever
# # # 
# # f,ax=plt.subplots(figsize=(12,7))
# # for tile in full_test_ds.tile.values[0:3]:
# #     print(f"Tile: {tile}")
# #     full_test_ds.sel(tile=tile)['scenes_with_polar_night_buffered_filtered'].plot(ax=ax, label=tile)

# # f,ax= plt.subplots(figsize=(12, 7))
# # test_ds['value1_da'].plot(ax=ax,label='Value 1 (no decision) count', color='blue')
# # test_ds['value11_da'].plot(ax=ax,label='Value 11 (night) count', color='orange')
# # test_ds['value25_da'].plot(ax=ax,label='Value 25 (no snow) count', color='green')
# # test_ds['no_decision_and_night_counts'].plot(ax=ax,label='No decision and night counts', color='purple')
# # binary_ax = ax.twinx()
# # test_ds['scenes_with_polar_night_buffered_filtered'].plot(ax=binary_ax, label='Scenes with polar night',color='black')
# # ax.legend()
# # f.tight_layout()

# # h=11
# # v=2

# # hemisphere = "northern" if v < 9 else "southern"

# # # logger.info(f"Fetching MODIS data for tile {tile}")
# # if hemisphere == "northern":
# #     modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
# #         vertical_tile=v,
# #         horizontal_tile=h,
# #         start_date=f"{WY_start-2}-10-01", # normally should be WY_start-1, but we want to get data from WY before, as October 1st acquistion in NH already has fill values
# #         end_date=f"{WY_end}-09-30",
# #         #chunks={},
# #         #chunks={"time": -1, "y": 600, "x": 600},
# #         chunks={"time": 1, "y": 2400, "x": 2400},

# #     ).chunk({"time": -1, "y": 600, "x": 600})

# # else:
# #     modis_snow_da = modis_masking.get_modis_MOD10A2_max_snow_extent(
# #         vertical_tile=v,
# #         horizontal_tile=h,
# #         start_date=f"{WY_start-1}-04-01", # normally should be WY_start, but we want to get data from WY before, as October 1st acquistion in NH already has fill values
# #         end_date=f"{WY_end+1}-03-31",
# #         #chunks={},
# #         #chunks={"time": -1, "y": 600, "x": 600},
# #         chunks={"time": 1, "y": 2400, "x": 2400},

# #     ).chunk({"time": -1, "y": 600, "x": 600})

# # # logger.info(f"Processing MODIS data for tile {tile}")
# # modis_snow_da.coords["water_year"] = (
# #     "time",
# #     pd.to_datetime(modis_snow_da.time).map(
# #         lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)
# #     ),
# # )
# # modis_snow_da.coords["DOWY"] = (
# #     "time",
# #     pd.to_datetime(modis_snow_da.time).map(
# #         lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)
# #     ),
# # )

# # modis_snow_da

# # value1_da = modis_snow_da.where(lambda x: x == 1).count(dim=["x","y"]).compute()
# # value1_da

# # value11_da = modis_snow_da.where(lambda x: x == 11).count(dim=["x","y"]).compute()
# # value11_da

# # no_decision_and_night_counts = modis_snow_da.where(lambda x: (x == 1) | (x == 11)).count(dim=["x","y"]).compute()
# # no_decision_and_night_counts


# # f,ax=plt.subplots(figsize=(20,7))
# # no_decision_and_night_counts.plot(ax=ax, color='orange', label='No Decision and Night Count')
# # #ax.axvline(x=pd.to_datetime('2020-09-05'), color='red', linestyle='--')
# # small_ax = ax.twinx()
# # value1_da.plot(ax=ax, color='green', label='Value 1 Count')
# # value11_da.plot(ax=ax, color='blue', label='Value 11 Count')


# # scenes_with_polar_night = no_decision_and_night_counts > 500

# # # then for every true value, change the value that comes directly before and after it to 1 as well
# # scenes_with_polar_night_buffered = (scenes_with_polar_night.shift(time=-1).fillna(0) | scenes_with_polar_night | scenes_with_polar_night.shift(time=1).fillna(0)).astype(int)
# # backward_check = scenes_with_polar_night_buffered.rolling(time=4, center=False).sum() >= 4  # backward-looking (trailing window)
# # forward_check = scenes_with_polar_night_buffered[::-1].rolling(time=4, center=False).sum()[::-1] >= 4 # forward-looking (leading window)
# # center_check = scenes_with_polar_night_buffered.rolling(time=4, center=True).sum() >= 4  # center-looking

# # # A position should be kept if it's part of a sequence of 4+ when looking in any direction
# # scenes_with_polar_night_buffered_filtered = scenes_with_polar_night_buffered.where(
# #     backward_check | forward_check | center_check,
# #     other=0
# # ).astype(bool)

# # f,ax=plt.subplots(figsize=(20,7))
# # no_decision_and_night_counts.plot()
# # #ax.axvline(x=pd.to_datetime('2020-09-05'), color='red', linestyle='--')
# # binary_ax = ax.twinx()
# # scenes_with_polar_night_buffered.plot(ax=binary_ax, color='orange')
# # scenes_with_polar_night_buffered_filtered.plot(ax=binary_ax, color='green')
# # (scenes_with_polar_night_buffered_filtered-scenes_with_polar_night_buffered).plot(ax=binary_ax, color='purple', linestyle='--')


# # if (v>=15) | (v<=2):# if north or south pole, remove strange no snow artifacts around time of darkness (check to see if darkness in the scene?)
# #         # if a scene contains significant no decision (1) / night (11) values, change no snow (25) to fill (255)?
            
# #             no_decision_and_night_counts = modis_snow_da.where(lambda x: (x == 1) | (x == 11)).count(dim=["x","y"]).compute()

# #             scenes_with_polar_night = no_decision_and_night_counts > 50
# #             # then for every true value, change the value that comes directly before and after it to 1 as well
# #             scenes_with_polar_night_buffered = (scenes_with_polar_night.shift(time=-1).fillna(0) | scenes_with_polar_night | scenes_with_polar_night.shift(time=1).fillna(0)).astype(int)
# #             backward_check = scenes_with_polar_night_buffered.rolling(time=4, center=False).sum() >= 4  # backward-looking (trailing window)
# #             forward_check = scenes_with_polar_night_buffered[::-1].rolling(time=4, center=False).sum()[::-1] >= 4  # forward-looking (leading window)
# #             center_check = scenes_with_polar_night_buffered.rolling(time=4, center=True).sum() >= 4  # center-looking

# #             # A position should be kept if it's part of a sequence of 4+ when looking in any direction
# #             scenes_with_polar_night_buffered_filtered = scenes_with_polar_night_buffered.where(
# #                 backward_check | forward_check | center_check,
# #                 other=0
# #             ).astype(bool)

# #             modis_snow_da = modis_snow_da.where(
# #                 ~((modis_snow_da == 25) & (scenes_with_polar_night_buffered_filtered)), other=255
# #             )



### Serverless approach (this got close, but I couldn't push it across the finish line)

In [None]:
# https://docs.coiled.io/user_guide/functions.html
# inspired by: https://github.com/earth-mover/serverless-datacube-demo/blob/main/src/lib.py

# maybe another option: https://xarray.dev/blog/cubed-xarray
# @coiled.function(
#     n_workers=50,
#     cpu=4,
#     #threads_per_worker=8,
#     memory="16GiB",
#     spot_policy="spot",
#     region="westeurope",
#     environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
#     keepalive="5m",
#     workspace="azure"
# )
# # def process_chunks(tile_list, store):
# #     odc.stac.configure_rio(cloud_defaults=True)
# #     results = []
# #     for _, tile in tile_list:
# #         h = tile['h']
# #         v = tile['v']
# #         result = process_and_write_tile(h, v, store, serverless=False)
# #         results.append(result)
# #     return results
# def process_chunk(tile, store):
#     odc.stac.configure_rio(cloud_defaults=True)
#     #with dask.config.set(pool=concurrent.futures.ThreadPoolExecutor(16), scheduler="threads"):
#     process = process_and_write_tile(tile, store, serverless=False)
#     return process


# def spawn_coiled_jobs(
#     modis_grid_land_list, store):
#     h_list = [tile['h'] for _, tile in modis_grid_land_list]
#     v_list = [tile['v'] for _, tile in modis_grid_land_list]
#     results = list(
#         tqdm.tqdm(
#             process_chunk.map(
#                 h_list, 
#                 v_list,
#                 store=store,
#                 retries=5
#             ),
#             total=len(h_list),
#             desc="Jobs Completed",
#         )
#     )
#     return results

# # def spawn_coiled_jobs(modis_grid_land_list, store, batch_size=10):
# #     batches = [modis_grid_land_list[i:i+batch_size] for i in range(0, len(modis_grid_land_list), batch_size)]
# #     results = list(
# #         tqdm.tqdm(
# #             process_chunks.map(
# #                 batches,
# #                 store=store,
# #                 retries=5
# #             ),
# #             total=len(batches),
# #             desc="Batch completed",
# #         )
# #     )
# #     return [item for sublist in results for item in sublist]

# #results = spawn_coiled_jobs(modis_grid_land_list, store)
# #results

In [None]:
#futures = []
# # for _, tile in tqdm.tqdm(modis_grid_land_list):
# #     h = tile['h']
# #     v = tile['v']
# #     try:
# #         process_and_write_tile(h, v, store)
# #         print(f"Tile h{h}_v{v} processed and written")
# #     except Exception as e:
# #         print(f"Error processing tile h{h}_v{v}: {str(e)}")
# #         print(f"Traceback: {traceback.format_exc()}")
# #         # maybe append to a list of all tiles that need to be rerun
# #     #future = client.submit(process_and_write_tile, h, v, store)
# #     #futures.append(future)
# # #results = wait(futures)
# # 
# # for future in futures:
#     try:
#         result = future.result()
#         print(result)
#     except Exception as e:
#         print(f"Task failed: {str(e)}")
#         print(f"Traceback: {future.traceback()}")
# 
# 
# # client.close()
# cluster.close()
# 
#     #seasonal_snow_presence.drop_vars('spatial_ref').chunk({'water_year':1,'y':2400,'x':2400}).to_zarr(store, region={'water_year':water_year_slice,'y':y_slice,'x':x_slice}, mode="r+")

# if serverless:
#     print(f'running serverless mode, using threadpoolexecutor...')
#     with dask.config.set(pool=concurrent.futures.ThreadPoolExecutor(16), scheduler="threads"):
#         for var in ['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']:
#             data = seasonal_snow_presence[var].values
#             root[var][:,y_start:y_end,x_start:x_end] = data
# else:
#     for var in ['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']:
#         data = seasonal_snow_presence[var].values
#         root[var][:,y_start:y_end,x_start:x_end] = data

# root[:, time_slice, y_slice, x_slice] = data


    #root[var][time_slice, y_slice, x_slice] = data

# if data.shape[0] == 9 and data.shape[1] == 2400 and data.shape[2] == 2400:
#    print(f'transpose necessary h{h}_v{v}')
#    data = np.transpose(data, (1, 2, 0))


# store.flush()



# with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor:  # Adjust number as needed
#     futures = [executor.submit(process_and_write_tile, h, v, azure_zarr_path) 
#                for h, v in modis_grid_land_list]

# def process_batch(batch):
#     results = []
#     for h, v in batch:
#         results.append(process_and_write_tile(h, v, azure_zarr_path))
#     return results

# batch_size = 10  # Adjust based on your workload
# batches = [modis_grid_land_list[i:i+batch_size] for i in range(0, len(modis_grid_land_list), batch_size)]
# futures = client.map(process_batch, batches)



# def create_azure_zarr_store(connection_string, container_name, zarr_store_path):
#     blob_service_client = BlobServiceClient.from_connection_string(connection_string)
#     container_client = blob_service_client.get_container_client(container_name)

#     class AzureBlobStore(zarr.ABSStore):
#         def __init__(self, container_client, prefix):
#             self.container_client = container_client
#             self.prefix = prefix
#             self.client = container_client  # Add this line

#         def __getitem__(self, key):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             return blob_client.download_blob().readall()

#         def __setitem__(self, key, value):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             blob_client.upload_blob(value, overwrite=True)

#         def __contains__(self, key):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             return blob_client.exists()

#         def __delitem__(self, key):
#             blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#             blob_client.delete_blob()

#         def rmdir(self, path):
#             dir_path = self.prefix
#             if path:
#                 dir_path += "/" + path
#             dir_path += "/"
#             blobs_to_delete = self.container_client.list_blobs(
#                 name_starts_with=dir_path
#             )
#             for blob in blobs_to_delete:
#                 self.container_client.delete_blob(blob)

#     store = AzureBlobStore(container_client, zarr_store_path)

#     root = zarr.open(store, mode="w")
#     # root.create_dataset('SAD_DOWY', shape=(36 * 2400, 18 * 2400, len(range(WY_start, WY_end + 1))), chunks=(2400, 2400, 1), dtype='i2')
#     # root.create_dataset('SDD_DOWY', shape=(36 * 2400, 18 * 2400, len(range(WY_start, WY_end + 1))), chunks=(2400, 2400, 1), dtype='i2')
#     # root.create_dataset('max_consec_snow_days', shape=(36 * 2400, 18 * 2400, len(range(WY_start, WY_end + 1))), chunks=(2400, 2400, 1), dtype='i2')
#     water_years = list(range(WY_start, WY_end + 1))

#     num_years = len(water_years)

#     compressor = numcodecs.Blosc(
#         cname="zstd", clevel=3, shuffle=numcodecs.Blosc.SHUFFLE
#     )

#     # Create datasets
#     for var in ['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']:
#         dataset = root.create_dataset(
#             var,
#             shape=(num_years, 18 * 2400, 36 * 2400),
#             chunks=(1, 2400, 2400),
#             dtype="i2",
#             compressor=compressor,
#         )



#         # Add dimension names as attributes

#     #root.create_dataset("water_year", data=water_years, shape=(num_years,), dtype="i2")
#     # root["time"].attrs[
#     #     "description"
#     # ] = "Water year. In northern hemisphere, water year starts on October 1st and ends on September 30th. For the southern hemisphere, water year starts on April 1st and ends on March 31st. For example, in the northern hemisphere water year 2015 starts on October 1st, 2014 and ends on September 30th, 2015, and in the southern hemisphere water year 2015 starts on April 1st, 2015 and ends on March 31st, 2016."


#     return f"azure://{container_name}/{zarr_store_path}"

# from azure.core.exceptions import ResourceNotFoundError

# class AzureBlobStore(zarr.ABSStore):
#     def __init__(self, container_client, prefix):
#         self.container_client = container_client
#         self.prefix = prefix
#         self.client = container_client  # Add this line

#     def __getitem__(self, key):
#         blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#         return blob_client.download_blob().readall()

#     def __setitem__(self, key, value):
#         blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#         blob_client.upload_blob(value, overwrite=True)

#     def __contains__(self, key):
#         blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#         return blob_client.exists()

#     def __delitem__(self, key):
#         blob_client = self.container_client.get_blob_client(f"{self.prefix}/{key}")
#         blob_client.delete_blob()

#     def rmdir(self, path):
#         dir_path = self.prefix
#         if path:
#             dir_path += "/" + path
#         dir_path += "/"
#         blobs_to_delete = self.container_client.list_blobs(
#             name_starts_with=dir_path
#         )
#         for blob in blobs_to_delete:
#             self.container_client.delete_blob(blob)


#blob_service_client = BlobServiceClient.from_connection_string(connection_string)
#container_client = blob_service_client.get_container_client(container_name)

#store = AzureBlobStore(container_client, zarr_store_path)
#root = zarr.open(store, mode="w")

#y = np.arange(0, 18 * 2400)
#x = np.arange(0, 36 * 2400)
    #connection_string = os.environ["azure-storage-connection-string"]
#parts = azure_zarr_path.split("/")

#container_name = parts[2]
#zarr_store_path = "/".join(parts[3:])

# blob_service_client = BlobServiceClient.from_connection_string(
#     connection_string
# )
#container_client = blob_service_client.get_container_client(container_name)

#store = AzureBlobStore(connection_string,container_client, zarr_store_path)
#root = zarr.open(store, mode="a")


# x_start, x_end = h * 2400, (h + 1) * 2400
# y_start, y_end = v * 2400, (v + 1) * 2400
        # with dask.config.set(pool=concurrent.futures.ThreadPoolExecutor(16), scheduler="threads"):
# data = seasonal_snow_presence[['SAD_DOWY', 'SDD_DOWY', 'max_consec_snow_days']].to_array().values

#'water_year':water_years,time_slice = slice(0, data.shape[0])
                #seasonal_snow_presence.drop_vars('spatial_ref').chunk({'water_year':num_years,'y':2400,'x':2400}).to_zarr(store, region={'water_year':water_year_slice,'y':y_slice,'x':x_slice}, mode="r+")


# def check_environment():
#     import sys
#     import os
#     result = {
#         "sys.path": sys.path,
#         "current_dir": os.getcwd(),
#         "list_dir": os.listdir(),
#         "env_vars": dict(os.environ),
#     }
#     try:
#         import easysnowdata
#         result["easysnowdata_version"] = easysnowdata.__version__
#     except ImportError as e:
#         result["easysnowdata_error"] = str(e)
#     try:
#         import modis_masking
#         result["modis_masking_file"] = modis_masking.__file__
#     except ImportError as e:
#         result["modis_masking_error"] = str(e)
#     return result

# # Run this on all workers
# environment_info = client.run(check_environment)

# # Print the results
# for worker, info in environment_info.items():
#     print(f"Worker {worker}:")
#     for key, value in info.items():
#         print(f"  {key}: {value}")
#     print()


# Set the Azure Blob Storage path for the zarr store
#container_name = "snowmelt"
#zarr_store_path = "modis_mask/global_modis_snow_mask.zarr"
#azure_zarr_path = f"azure://{container_name}/{zarr_store_path}"



# # Load progress
# progress = load_progress()
# processed_tiles = set(progress['processed'])
# failed_tiles = set(progress['failed'])

# # Load processed tiles from zarr
# zarr_store = zarr.open(store, mode='r')
# zarr_processed_tiles = set(zarr_store.attrs['processed_tiles'])



# failed_tiles = []

# def process_tile(tile, store):
#     result = process_and_write_tile(tile, store)
#     client.restart()  # Restart workers to clear memory
#     return result

# # First pass: process all tiles
# for tile in tqdm.tqdm(tile_processing_list):

#     try:
#         result = process_tile(tile, store)
#         print(f"Tile {tile} processed and written")
#     except Exception as e:
#         print(f"Error processing tile {tile}: {str(e)}")
#         print(f"Traceback: {traceback.format_exc()}")
#         failed_tiles.append(tile)

# # Second pass: retry failed tiles
# max_retries = 3
# retry_count = 0

# while failed_tiles and retry_count < max_retries:
#     retry_count += 1
#     print(f"Retry attempt {retry_count} for failed tiles")
#     still_failed = []
    
#     for tile in tqdm.tqdm(failed_tiles):
#         try:
#             result = process_tile(tile, store)
#             print(f"Tile {tile} processed and written on retry")
#         except Exception as e:
#             print(f"Error processing tile {tile} on retry: {str(e)}")
#             print(f"Traceback: {traceback.format_exc()}")
#             still_failed.append(tile)
    
#     failed_tiles = still_failed

# if failed_tiles:
#     print("The following tiles could not be processed after all retries:")
#     for tile in failed_tiles:
#         print(f"{tile}")

# client.close()
# cluster.close()

# fixed