In [1]:
import glob
import numpy as np
import numpy.ma as ma
import os
import rasterio
import rioxarray
from scipy.ndimage import maximum_filter
import xarray as xr


%run _constants.ipynb



def _buffer_mask(mask, radius=12):
    """Dilate a boolean mask with a disk-shaped footprint.

    Every pixel within the footprint of a True pixel becomes True; cells
    beyond the array edge count as False, so the border itself does not grow
    the mask.

    NOTE(review): ``radius`` is handed to ``_get_circular_mask`` as its
    ``size`` (kernel width), which derives its own radius as
    ``int(size / 2)``. The effective buffer radius is therefore about half of
    ``radius`` (6 pixels for the default of 12). Confirm that this is intended.

    NOTE(review): scipy filters work on in-memory arrays, so a dask-backed
    ``mask`` is presumably computed in full here. Confirm memory is acceptable.

    Returns a new DataArray with the same coords and dims as ``mask``,
    re-chunked to ``mask.chunks``.
    """
    footprint = _get_circular_mask(radius)
    dilated = maximum_filter(mask, footprint=footprint, mode='constant', cval=0)
    buffered = xr.DataArray(dilated, coords=mask.coords, dims=mask.dims)
    return buffered.chunk(chunks=mask.chunks)


def _get_circular_mask(size):
    
    radius = int(size/2)
    center = (radius, radius)
    Y, X = np.ogrid[:size, :size]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y-center[1])**2)

    mask = dist_from_center <= radius
    return np.array([mask])


def _get_shadow(cloud_height, azimuth_rad, zenith_rad, cloud_mask, scale=10):
    """Project a cloud mask onto the ground for one assumed cloud height.

    The horizontal shadow length is ``tan(zenith_rad) * cloud_height``
    (rounded). It is converted to whole pixels by dividing by ``scale`` and
    split into x/y components with ``cos``/``sin`` of ``azimuth_rad``.

    Parameters
    ----------
    cloud_height : float
        Assumed cloud height, in the same length unit as ``scale``.
    azimuth_rad, zenith_rad : float
        Projection angles in radians, already adjusted by the caller.
    cloud_mask : xarray.DataArray
        Boolean mask with ``x`` and ``y`` dimensions (any other dimensions,
        e.g. bands, are carried along unchanged).
    scale : float, default 10
        Ground distance covered by one pixel.

    Returns
    -------
    xarray.DataArray
        ``cloud_mask`` with its data shifted by the computed pixel offsets.
        Coordinates stay in place, and cells uncovered by the shift are False.
        Nothing is wrapped around from the opposite edge.
    """
    shadow_vector = round(np.tan(zenith_rad) * cloud_height)

    # int(): slice/shift amounts must be Python ints whatever round() returns
    # for NumPy scalars.
    x_shift = int(round(np.cos(azimuth_rad) * shadow_vector / scale))
    y_shift = int(round(np.sin(azimuth_rad) * shadow_vector / scale))

    print('\t\tx_shift:', x_shift, ', y_shift:', y_shift)

    # shift() moves the data (not the coordinates) and fills vacated cells
    # with False. This equals roll() + zeroing the wrapped edge, but works for
    # every band and needs no in-place writes into a (possibly dask) array.
    return cloud_mask.shift(x=x_shift, y=y_shift, fill_value=False)


def _get_cloud_shadow_mask(cloud_mask, azimuth, zenith, nir_da, scl_da):
    """Estimate cloud shadows by projecting clouds at several heights.

    ``cloud_mask`` is projected with ``_get_shadow`` for cloud heights 400,
    600, 800 and 1000. The union of those projections is the potential shadow
    area. Only pixels in that area that are also "dark" are kept: NIR below
    1500 and SCL class not equal to 6 (the value the code treats as water).

    NOTE(review): ``np.arange(400, 1200, 200)`` stops before 1200. Confirm
    that 1200 is meant to be excluded.

    ``azimuth`` and ``zenith`` are in degrees (converted with ``np.deg2rad``).
    Returns a boolean DataArray with a leading ``band`` dimension.
    """
    # Solar azimuth is opposite of the illumination direction, plus another
    # 90 degrees for the S2 instrument.
    azimuth_rad = np.deg2rad(azimuth - 270)
    zenith_rad = np.deg2rad(zenith)

    projections = [
        _get_shadow(h, azimuth_rad, zenith_rad, cloud_mask)
        for h in np.arange(400, 1200, 200)
    ]
    stacked = xr.concat(projections, dim="band")
    any_projection = stacked.any(dim='band').expand_dims(dim='band', axis=0)

    is_water = scl_da == 6
    is_dark = (nir_da < 1500) & ~is_water
    return any_projection & is_dark

    
def _get_scl_bad_pixel_mask(scl_da):
    
    bad_values = [0, 1, 11]
    mask = scl_da.isin(bad_values)
    return mask
    

def _get_scl_cloud_mask(scl_da):

    cloud_values = [8, 9, 10]
    mask = scl_da.isin(cloud_values)
    return mask

    
def _get_bcy_cloud_mask(green_da, red_da):
    
    # (green > 0.175 AND NDGR > 0) OR (green > 0.39)
       
    ndgr = (green_da.astype(np.float32) - red_da.astype(np.float32)) / (green_da + red_da)
    
    cond1 = (green_da > 1750) & (ndgr > 0) 
    cond2 = green_da > 3900
    mask = cond1 | cond2
    
    return mask



def _build_scene_mask(scene_dict, chunks=(1, 1000, 1000)):
    """Return the buffered mask of pixels to discard for one scene (True = discard).

    The mask is the union of:
    - the green/red brightness cloud test,
    - the SCL cloud classes,
    - the SCL "bad pixel" classes,
    - the projected cloud-shadow mask.
    The union is then dilated with ``_buffer_mask``.
    """
    def _open(key):
        return rioxarray.open_rasterio(scene_dict[key], chunks=chunks, mask_and_scale=True)

    green_da = _open('B03')
    red_da = _open('B04')
    nir_da = _open('B08')
    scl_da = _open('SCL')

    cloud_mask = _get_bcy_cloud_mask(green_da, red_da) | _get_scl_cloud_mask(scl_da)
    bad_mask = _get_scl_bad_pixel_mask(scl_da)

    meta = scene_dict['meta']
    cloud_shadow_mask = _get_cloud_shadow_mask(
        cloud_mask, meta["AZIMUTH_ANGLE"], meta["ZENITH_ANGLE"], nir_da, scl_da
    )

    return _buffer_mask(cloud_mask | bad_mask | cloud_shadow_mask)


def save_cloud_masked_images_dask(scene_dict, dst_dir, overwrite=True):
    """Write a cloud/shadow-masked GeoTIFF for every band of a scene.

    Parameters
    ----------
    scene_dict : dict
        Maps band names to raster paths. It must contain 'B03', 'B04', 'B08'
        and 'SCL', plus a 'meta' entry holding at least 'AZIMUTH_ANGLE' and
        'ZENITH_ANGLE'.
    dst_dir : str
        Directory that receives ``<band>_masked.tif`` files. It is not created
        here.
    overwrite : bool, default True
        When False and ``B08_masked.tif`` already exists in ``dst_dir``,
        nothing is recomputed and the expected output paths are returned.

    Returns
    -------
    dict
        Band name -> output path, for every entry of ``scene_dict`` except
        'meta' and 'SCL'. The same keys are returned whether or not the
        outputs were recomputed.
    """
    chunks = (1, 1000, 1000)
    band_names = [name for name in scene_dict if name not in ('meta', 'SCL')]
    masked_dict = {name: f'{dst_dir}/{name}_masked.tif' for name in band_names}

    # B08's output serves as the marker that this scene was already processed.
    if os.path.exists(f'{dst_dir}/B08_masked.tif') and not overwrite:
        return masked_dict

    mask = _build_scene_mask(scene_dict, chunks)

    for band_name in band_names:
        band_da = rioxarray.open_rasterio(
            scene_dict[band_name], chunks=chunks, mask_and_scale=True
        )
        # ``mask`` is True for pixels to discard, so keep its complement.
        # Discarded pixels become NaN.
        masked_band = band_da.where(~mask)
        # NOTE(review): how NaN is encoded when writing uint16 depends on
        # rioxarray's nodata/encoding handling. Confirm the output nodata value.
        masked_band.rio.to_raster(masked_dict[band_name], dtype='uint16', nodata=0)

    return masked_dict

