# Global MODIS snow presence processing

### Installs and imports

In [None]:
#!pip install -q -e 'git+https://github.com/egagli/easysnowdata.git#egg=easysnowdata'
#!pip install -q geodatasets

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import geodatasets
import xarray as xr
import easysnowdata
import modis_masking
#import dask_gateway
import tqdm
import glob
import os
import shutil
import tempfile

In [None]:
import dask
import dask.distributed
from dask.diagnostics import ProgressBar

import coiled


# coiled.create_software_environment(
#     name="sar_snowmelt_timing",
#     conda="../sar_snowmelt_timing/environment.yml",
#     workspace="azure",
#     #force_rebuild=True
# )

# Coiled-managed Dask cluster in the "azure" workspace: 30 spot workers with
# 8 GiB of memory each, shut down after 5 idle minutes (idle_timeout).
# GDAL_DISABLE_READDIR_ON_OPEN=EMPTY_DIR is exported to the workers so GDAL does
# not try to list the containing directory whenever it opens a remote raster.
cluster = coiled.Cluster(
    workspace="azure",
    n_workers=30,
    worker_memory="8 GiB",
    spot_policy="spot",  # spot usually
    idle_timeout="5 minutes",
    environ={"GDAL_DISABLE_READDIR_ON_OPEN": "EMPTY_DIR"},
    # software="sar_snowmelt_timing",
    # container="mcr.microsoft.com/planetary-computer/python:latest",
)

# Dask client connected to the cluster above.
client = cluster.get_client()

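### Alternative: use a Dask Gateway cluster on the Planetary Computer (currently commented out; the Coiled cluster above is used instead), since we've got lots of computations to do...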

In [None]:
# cluster = dask_gateway.GatewayCluster()
# client = cluster.get_client()
# client.upload_file('modis_masking.py')
# cluster.scale(50)
# print(cluster.dashboard_link)

### To save time, let's only process MODIS tiles with land in them! Check out [MODIS's grid system](https://modis-land.gsfc.nasa.gov/MODLAND_grid.html).

In [None]:
# MODIS sinusoidal tile polygons; each row carries the 'h'/'v' tile indices used later.
modis_grid = gpd.read_file('zip+http://book.ecosens.org/wp-content/uploads/2016/06/modis_grid.zip!modis_sinusoidal_grid_world.shp')

# Natural Earth land polygons, dissolved into one geometry and reprojected into
# the grid's CRS. Take the CRS declared by the land file itself; fall back to
# EPSG:4326 only if the file does not declare one.
land = gpd.read_file(geodatasets.get_url('naturalearth land'))
land_crs = land.crs if land.crs is not None else 'EPSG:4326'
land_modis_crs = gpd.GeoSeries(land.union_all(), crs=land_crs).to_crs(modis_grid.crs)

# Boolean mask over modis_grid rows: True where the tile intersects land.
modis_grid_land_idx = modis_grid.intersects(land_modis_crs.union_all())

# Manually drop the tile at row label 600 from the land set.
# NOTE(review): the reason for excluding this tile is not recorded; consider
# selecting it by its (h, v) values instead of a row label.
# Fail loudly if the label is missing: `series[600] = False` on a missing label
# would silently append a new entry instead of excluding a tile.
EXCLUDED_TILE_ROW = 600
if EXCLUDED_TILE_ROW not in modis_grid_land_idx.index:
    raise KeyError(f'row {EXCLUDED_TILE_ROW} not found in modis_grid; cannot exclude it')
modis_grid_land_idx.loc[EXCLUDED_TILE_ROW] = False

modis_grid_land = modis_grid[modis_grid_land_idx]
modis_grid_not_land = modis_grid[~modis_grid_land_idx]
modis_grid_land

In [None]:
# Overview map: land in green, non-land tile outlines thin and gray, and land
# tile outlines thick and blue (drawn last so they sit on top).
f, ax = plt.subplots(figsize=(15, 15))
land_modis_crs.plot(ax=ax, color='green')
for tiles, edge_color, edge_width in (
    (modis_grid_not_land, 'gray', 0.5),
    (modis_grid_land, 'blue', 2),
):
    tiles.geometry.boundary.plot(ax=ax, color=edge_color, linewidth=edge_width)
ax.set_title('MODIS grid system\nland tiles in blue')

### Use easysnowdata to load the MODIS MOD10A2 product and select the 'Maximum_Snow_Extent' band. Add water year (WY) and day of water year (DOWY) coordinates to the time dimension. Binarize and cloud-fill the data. Group by water year and create a snow presence raster for each water year. Then save each tile to zarr, to be stitched together later.

In [None]:
# Water years to process, inclusive on both ends.
WY_start, WY_end = 2015, 2023

# Materialize the (index, row) pairs: unlike the iterrows generator, a list has
# a length, which lets tqdm display a total.
modis_grid_land_list = [*modis_grid_land.iterrows()]
output_dir = 'output/global'

# Snapshot of the .zarr stores already present in output_dir.
file_list = glob.glob(output_dir + '/*.zarr')
file_list

In [None]:
# Process every land tile: load MOD10A2 Maximum_Snow_Extent for the requested
# water years, binarize + cloud-fill it, compute per-water-year snow presence,
# and write one zarr store per tile into output_dir. Tiles whose store already
# exists are skipped, so the cell can be re-run to resume an interrupted job.
for _index, tile in tqdm.tqdm(modis_grid_land_list):

    h = tile['h']
    v = tile['v']

    filepath = f'{output_dir}/tile_h{h}_v{v}.zarr'

    # Tiles with v < 9 are treated as northern hemisphere, the rest as southern.
    hemisphere = 'northern' if v < 9 else 'southern'

    # Check the disk directly. `file_list` is a snapshot from an earlier cell and
    # goes stale as soon as this loop writes tiles; relying on it alone would
    # re-process finished tiles when this cell is re-run.
    if os.path.exists(filepath):
        tqdm.tqdm.write(f'{filepath} already processed. skipping...')
        continue

    # Date window covering WY_start..WY_end: a northern water year N runs
    # Oct 1 (N-1) - Sep 30 N; a southern water year N runs Apr 1 N - Mar 31 (N+1).
    if hemisphere == 'northern':
        start_date, end_date = f'{WY_start-1}-10-01', f'{WY_end}-09-30'
    else:
        start_date, end_date = f'{WY_start}-04-01', f'{WY_end+1}-03-31'

    modis_snow_da = easysnowdata.remote_sensing.MODIS_snow(
        vertical_tile=v,
        horizontal_tile=h,
        clip_to_bbox=False,
        start_date=start_date,
        end_date=end_date,
        data_product="MOD10A2",
        bands='Maximum_Snow_Extent',
        mute=True,
    ).data['Maximum_Snow_Extent']

    # Label each time step with its water year (WY) and day of water year (DOWY).
    times = pd.to_datetime(modis_snow_da.time)
    modis_snow_da.coords['WY'] = ("time", times.map(lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)))
    modis_snow_da.coords['DOWY'] = ("time", times.map(lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)))

    # Drop time steps outside the requested water years. Select along 'time'
    # explicitly; `da[mask]` applies the mask to whichever dimension is first.
    in_wy_range = ((modis_snow_da.WY >= WY_start) & (modis_snow_da.WY <= WY_end)).values
    modis_snow_da = modis_snow_da.isel(time=in_wy_range)

    effective_snow_da = modis_masking.binarize_with_cloud_filling(modis_snow_da)

    seasonal_snow_presence = effective_snow_da.groupby('WY').apply(modis_masking.get_max_consec_snow_days_SAD_SDD_one_WY).compute()

    # Write to a scratch directory created INSIDE output_dir, then rename into
    # place. Staying on the same filesystem makes the final step a single atomic
    # os.replace, so an interrupt leaves either no store or a complete one (the
    # scratch dir is cleaned up by the context manager). The scratch dir name has
    # no .zarr suffix, so it never matches the '*.zarr' glob.
    os.makedirs(output_dir, exist_ok=True)
    with tempfile.TemporaryDirectory(dir=output_dir) as tmpdirname:
        temp_filepath = os.path.join(tmpdirname, 'temp.zarr')
        seasonal_snow_presence.to_zarr(temp_filepath, mode='w')
        os.replace(temp_filepath, filepath)
    # tqdm.write keeps the progress bar intact while printing.
    tqdm.tqdm.write(f'{filepath} complete')

In [None]:
# def process_tile(tile, output_dir, file_list, WY_start, WY_end):
    
#     h = tile['h']
#     v = tile['v']

#     hemisphere = 'northern' if v < 9 else 'southern'

#     filepath = f'{output_dir}/tile_h{h}_v{v}.zarr'
#     print(f'the filepath is {filepath}')
    
    
#     if filepath not in file_list:
#         if hemisphere == 'northern':
#             modis_snow_da = easysnowdata.remote_sensing.MODIS_snow(vertical_tile=v, horizontal_tile=h, clip_to_bbox=False, start_date=f'{WY_start-1}-10-01', end_date=f'{WY_end}-09-30', data_product="MOD10A2", bands='Maximum_Snow_Extent', mute=True).data['Maximum_Snow_Extent']
#         else:
#             modis_snow_da = easysnowdata.remote_sensing.MODIS_snow(vertical_tile=v, horizontal_tile=h, clip_to_bbox=False, start_date=f'{WY_start}-04-01', end_date=f'{WY_end+1}-03-31', data_product="MOD10A2", bands='Maximum_Snow_Extent', mute=True).data['Maximum_Snow_Extent']
        
#         modis_snow_da.coords['WY'] = ("time", pd.to_datetime(modis_snow_da.time).map(lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)))
#         modis_snow_da.coords['DOWY'] = ("time", pd.to_datetime(modis_snow_da.time).map(lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)))
    
#         modis_snow_da = modis_snow_da[(modis_snow_da.WY >= WY_start) & (modis_snow_da.WY <= WY_end)]
        
#         effective_snow_da = modis_masking.binarize_with_cloud_filling(modis_snow_da)
    
#         seasonal_snow_presence = effective_snow_da.groupby('WY').apply(modis_masking.get_max_consec_snow_days_SAD_SDD_one_WY)
    
#         seasonal_snow_presence.to_zarr(filepath, mode='w-') #compute False so we can use delayed
    
#         print(f'tile h={h} v={v} complete')
#     else:
#         print(f'{filepath} already processed. skipping...')
# 
# for index, tile in tqdm.tqdm(modis_grid_land_list):
#     process_tile(tile, output_dir, file_list, WY_start, WY_end)
# 
# 
# 
# # import logging

# # Setup basic configuration for logging
# logging.basicConfig(level=logging.INFO)

# # Replace print statements with logging

# # Define a function to process each tile
# def process_tile(tile, output_dir, file_list, WY_start, WY_end):
#     h = tile['h']
#     v = tile['v']
#     filepath = f'{output_dir}/tile_h{h}_v{v}.zarr'
#     hemisphere = 'northern' if v < 9 else 'southern'
#     logging.info(filepath)
    
#     if filepath not in file_list:
#         if hemisphere == 'northern':
#             modis_snow_da = easysnowdata.remote_sensing.MODIS_snow(vertical_tile=v, horizontal_tile=h, clip_to_bbox=False, start_date=f'{WY_start-1}-10-01', end_date=f'{WY_end}-09-30', data_product="MOD10A2", bands='Maximum_Snow_Extent', mute=True).data['Maximum_Snow_Extent']
#         else:
#             modis_snow_da = easysnowdata.remote_sensing.MODIS_snow(vertical_tile=v, horizontal_tile=h, clip_to_bbox=False, start_date=f'{WY_start}-04-01', end_date=f'{WY_end+1}-03-31', data_product="MOD10A2", bands='Maximum_Snow_Extent', mute=True).data['Maximum_Snow_Extent']
        
#         modis_snow_da.coords['WY'] = ("time", pd.to_datetime(modis_snow_da.time).map(lambda x: easysnowdata.utils.datetime_to_WY(x, hemisphere=hemisphere)))
#         modis_snow_da.coords['DOWY'] = ("time", pd.to_datetime(modis_snow_da.time).map(lambda x: easysnowdata.utils.datetime_to_DOWY(x, hemisphere=hemisphere)))
    
#         modis_snow_da = modis_snow_da[(modis_snow_da.WY >= WY_start) & (modis_snow_da.WY <= WY_end)]
        
#         effective_snow_da = modis_masking.binarize_with_cloud_filling(modis_snow_da)
    
#         seasonal_snow_presence = effective_snow_da.groupby('WY').apply(modis_masking.get_max_consec_snow_days_SAD_SDD_one_WY)
    
#         seasonal_snow_presence.to_zarr(filepath, mode='w-') #compute False so we can use delayed
    
#         return f'tile h={h} v={v} complete'
#     else:
#         return f'{filepath} already processed. skipping...'

# modis_grid_land_list[0]

# # Assuming modis_grid_land_list, output_dir, file_list, WY_start, and WY_end are defined
# tasks = [dask.delayed(process_tile)(tile, output_dir, file_list, WY_start, WY_end) for index,tile in modis_grid_land_list[0:1]]
# tasks


# # Compute tasks in parallel
# with ProgressBar():
#     results = dask.compute(*tasks)

# results