In [2]:
import spires
import xarray
import h5py
import numpy
import scipy
import netCDF4
import dask
import geopandas
import pandas

In [3]:
import importlib
import spires

# Re-execute the already imported `spires` module in the running kernel so that
# edits made to the package on disk are picked up without restarting the kernel.
importlib.reload(spires)

# Todo
- [x] Read Solar zenith
- [x] Interpolate solar zenith, maybe
- [x] Read Sentinel LUT
- [x] Create LUT interpolator
- [x] Call speedy_invert on a single observation
- [ ] Call speedy_invert on a timestep (dask delayed)
- [ ] Call speedy_invert on cube (Probably iterate)

I wouldn't bother with 8a. Maybe run spires on bands 2,3,4,8 & resample 11-12 (& maybe 5-7) to 10 m

# Corrections
- [ ] correct spectral distortion: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8321035/
- [ ] canopy cover
- [ ] temporal smoothing
- [ ] spatial interpolations

## Performance
- [ ] adjust chunking
- [ ] Filter locations with NDSI < -0.5
- [ ] numba ... probably won't work with scipy.interpolate._rgi.RegularGridInterpolator
- [ ] How much LUT can reside in memory
- [ ] compute for uniques only
    - [ ] implement uniquetol
        - https://www.mathworks.com/help/matlab/ref/uniquetol.html
        - https://github.com/edwardbair/SPIRES/blob/master/core/run_spires.m#L98
    - [ ] find uniques and label them
    - [ ] compute for uniques and broadcast back
- [ ] How much can we cache?
    - 21 dimensions (10 in R, 10 in R0, 1 solar_z)
    - Discretize to 8 bit: 256 ** 21 = '3.74E+50' values/bytes = '3.74E+38' TB ... infeasible (obviously)

## Loading Observations 
- interpolate the viewing and sun angles
- Calculate the NDVI and NDSI

In [1]:
# Open the observation store.
zarr_store = '/tablespace/sentinel2/ucsb_sharpend.zarr/'
ds = xarray.open_zarr(zarr_store)

# Look up every angle grid at the pixel coordinates (ds.y, ds.x) with
# nearest-neighbour interpolation, then squeeze away any length-1 dimensions.
for angle_name in ('sun_zenith_grid', 'sun_azimuth_grid',
                   'viewing_zenith_grid', 'viewing_azimuth_grid'):
    angle_on_pixels = ds[angle_name].interp(y_angles=ds.y, x_angles=ds.x, method='nearest')
    ds[angle_name] = angle_on_pixels.squeeze()

# The angle grids now live on (y, x); the original angle dimensions can go.
ds = ds.drop_dims(('x_angles', 'y_angles'))

# Keep only the listed bands, in exactly this order.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ds = ds.sel(band=['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B11', 'B12', 'B8'])

In [14]:
reflectance = ds['reflectance']
b3, b4, b8, b11 = (reflectance.sel(band=name) for name in ('B3', 'B4', 'B8', 'B11'))

# NDVI = (B8 - B4) / (B8 + B4); a zero denominator is masked to NaN first.
b8_b4_sum = b8 + b4
b8_b4 = b8_b4_sum.where(b8_b4_sum != 0)
ndvi = (b8 - b4) / b8_b4

# NDSI = (B3 - B11) / (B3 + B11); a zero denominator is masked to NaN first.
b3_b11_sum = b3 + b11
b3_b11 = b3_b11_sum.where(b3_b11_sum != 0)
ndsi = (b3 - b11) / b3_b11

# Keep only values strictly inside (-1, 1); everything else becomes NaN.
ndsi = ndsi.where((ndsi < 1) & (ndsi > -1))
ndvi = ndvi.where((ndvi < 1) & (ndvi > -1))

ds['ndvi'] = ndvi
ds['ndsi'] = ndsi

# Load background reflectances

In [16]:
# Open the background reflectance (R0) store.
zarr_store = '/tablespace/sentinel2/ucsb_r0_cut.zarr/'
r0_all_bands = xarray.open_zarr(zarr_store)

# Same band selection and band order as used for the observations.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    r0_selected = r0_all_bands.sel(band=['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B11', 'B12', 'B8'])

# Drop the angle-grid dimensions together with every variable that uses them.
ds_r0 = r0_selected.drop_dims(('x_angles', 'y_angles'))

# Subset to ROI

In [17]:
# ROIs from the tristate project
path = '/tablespace/sentinel2/rois.gpkg'
rois = geopandas.read_file(path)

# Take the rows named 'UCSB', reproject them to CRS 32611, buffer by 20 (CRS units)
# and use the bounding box of the first buffered geometry.
is_ucsb = rois['name'] == 'UCSB'
ucsb_buffered = rois[is_ucsb].to_crs(32611).buffer(20)
bounds = ucsb_buffered.iloc[0].bounds
minx, miny, maxx, maxy = bounds

# Earliest timestep to keep when subsetting the observations.
start_time = pandas.Timestamp('2024-02-25')

In [18]:
# Spatial window shared by both datasets. The y slice runs from maxy down to miny,
# which presumably matches a descending y coordinate — confirm against the store.
roi_window = {'x': slice(minx, maxx), 'y': slice(maxy, miny)}

# Observations are additionally limited to timesteps at or after start_time.
ds = ds.sel(time=slice(start_time, None), **roi_window)
ds_r0 = ds_r0.sel(**roi_window)

In [20]:
# Write the observation subset and the R0 subset to NetCDF files.
for subset, target in ((ds, '/home/griessbaum/sentinel_r.nc'),
                       (ds_r0, '/home/griessbaum/sentinel_r0.nc')):
    subset.to_netcdf(target)

## Lut file

In [None]:
lut_file = '/tablespace/sentinel2/lut_sentinel2b_b2to12_3um_dust.mat'

# Keep a handle on the interpolator: the inversion cells further down pass
# `lut_interpolator` to spires.speedy_invert, so the object must be bound to
# that name instead of being created and discarded.
lut_interpolator = spires.Interpolator(lut_file=lut_file)

# Last expression of the cell, so the interpolator is still displayed as before.
lut_interpolator

# Subset to single timestep

In [None]:
# `ds` has already been cut to time >= start_time (2024-02-25), so the previously
# hard-coded date '2022-03-02' is no longer part of the time index and cannot be
# selected. Use the first timestep of the remaining window instead.
# NOTE(review): swap in an explicit date inside the window if a specific scene is wanted.
ts = ds.isel(time=0).squeeze().drop_vars('time')

# Invert one

In [None]:
%%time
# Reflectance vectors of one pixel (positional index x=600, y=500), taken from
# the selected timestep and from the background dataset, as plain numpy arrays.
pixel = {'x': 600, 'y': 500}
r = ts['reflectance'].isel(**pixel).values
r0 = ds_r0['reflectance'].isel(**pixel).values

In [None]:
# Solar zenith read from the dataset attribute 'sun_zenith_mean', wrapped as an ndarray.
sun_zenith_mean = ts.attrs['sun_zenith_mean']
solar_z = numpy.array(sun_zenith_mean)

# Nine zeros — the length matches the nine selected bands.
# NOTE(review): presumably the shade endmember reflectance per band — confirm.
shade = numpy.zeros(shape=9)

In [None]:
##%%timeit
# Invert a single pixel.
# NOTE(review): bounds_grain and bounds_dust are not defined in any cell visible in
# this notebook; they have to exist in the kernel before this cell can run.
res, model_ref = spires.speedy_invert(
    lut_interpolator, r, r0, solar_z, shade,
    mode=4,
    bounds_grain=bounds_grain,
    bounds_dust=bounds_dust,
)

# Invert an array

In [None]:
%%time
# Materialise both reflectance arrays in memory (still DataArrays, not raw ndarrays).
r, r0 = (source['reflectance'].compute() for source in (ts, ds_r0))

In [None]:
size = 10

# Positional crop: everything along the first axis, the first `size` entries along
# the second axis and the first two entries along the third axis.
small_window = (slice(None), slice(0, size), slice(0, 2))
r = r[small_window]
r0 = r0[small_window]

In [None]:
%%time
# Reverse the dimension order so that [x, y, :] below addresses one pixel's vector
# (presumably this puts the band axis last — confirm against the dataset layout).
# NOTE: r / r0 are rebound here, so re-running this cell transposes them back.
r = r.T
r0 = r0.T

# Pull the raw arrays once instead of indexing a DataArray for every pixel.
r_values = r.values
r0_values = r0.values

# One 4-element result vector (res.x) per pixel.
properties = numpy.zeros([r.shape[0], r.shape[1], 4])
for x in range(r.shape[0]):
    for y in range(r.shape[1]):
        r_ = r_values[x, y, :]
        r0_ = r0_values[x, y, :]
        # Use `lut_interpolator`, the name every other inversion call in this
        # notebook uses; `lut_interpolator_f` is not defined anywhere.
        res, model_refl = spires.speedy_invert(lut_interpolator, r_, r0_, solar_z, shade, mode=4,
                                               bounds_grain=bounds_grain,
                                               bounds_dust=bounds_dust)
        properties[x, y, :] = res.x

In [None]:
chunksize = 200

# Keep all bands of a pixel in one chunk and tile x / y in blocks of `chunksize`.
chunk_spec = {'band': -1, 'x': chunksize, 'y': chunksize}
r0 = r0.chunk(**chunk_spec)
r = r.chunk(**chunk_spec)

In [None]:
def run(r, r0):
    """Invert one observation and return the solution vector.

    Calls ``spires.speedy_invert`` with the notebook-level ``lut_interpolator``,
    ``solar_z`` and ``shade`` and ``mode=4``.

    Parameters
    ----------
    r : observed reflectance, passed through to ``speedy_invert`` unchanged.
    r0 : background reflectance, passed through to ``speedy_invert`` unchanged.

    Returns
    -------
    The ``x`` attribute of the first value returned by ``speedy_invert``.
    """
    fit, _model_refl = spires.speedy_invert(lut_interpolator, r, r0, solar_z, shade, mode=4)
    return fit.x

In [None]:
def run_vectorized(r, r0):
    """Invert every pixel of a block, one ``speedy_invert`` call per pixel.

    The first two axes of ``r`` / ``r0`` enumerate pixels; the remaining trailing
    slice ``[i, j, :]`` is handed to ``spires.speedy_invert`` together with the
    notebook-level ``lut_interpolator``, ``solar_z``, ``shade``, ``bounds_grain``
    and ``bounds_dust`` (``mode=4``).

    Parameters
    ----------
    r : array-like, observed reflectance indexed as ``[i, j, :]``.
    r0 : array-like, background reflectance with the same layout as ``r``.

    Returns
    -------
    numpy.ndarray of shape ``(r.shape[0], r.shape[1], 4)`` holding the solution
    vector (``res.x``) of each pixel; all zeros where no pixel was processed.
    """
    n_first, n_second = r.shape[0], r.shape[1]
    properties = numpy.zeros([n_first, n_second, 4])

    # ndindex walks the pixel grid in row-major order, like two nested range loops.
    for i, j in numpy.ndindex(n_first, n_second):
        fit, _model_refl = spires.speedy_invert(
            lut_interpolator,
            r[i, j, :],
            r0[i, j, :],
            solar_z,
            shade,
            mode=4,
            bounds_grain=bounds_grain,
            bounds_dust=bounds_dust,
        )
        properties[i, j, :] = fit.x

    return properties

# ssh -N -L 8001:localhost:8787 schiss.eri.ucsb.edu

In [None]:
from dask.distributed import LocalCluster
import logging

# Local cluster: 20 single-threaded worker processes limited to 5GB each; the
# dashboard listens on localhost:8787 and only ERROR-level log messages are shown.
cluster = LocalCluster(
    n_workers=20,
    threads_per_worker=1,
    processes=True,
    memory_limit='5GB',
    dashboard_address='localhost:8787',
    silence_logs=logging.ERROR,
)

In [None]:
# Shut down the local dask cluster.
# NOTE(review): in notebook order this cell runs before the cell that opens
# `dask.distributed.Client(cluster)`; presumably it is meant to be run by hand once
# that computation has finished — confirm, and consider moving it to the end.
cluster.close()

In [None]:
import dask

# Build the lazy computation: run_vectorized is applied to each dask block, consuming
# the 'band' dimension of both inputs and producing a new 'property' dimension of size 4.
gufunc_options = {'allow_rechunk': False, 'output_sizes': {'property': 4}}

res = xarray.apply_ufunc(
    run_vectorized,
    r,
    r0,
    input_core_dims=[['band'], ['band']],
    output_core_dims=[['property']],
    vectorize=False,
    dask='parallelized',
    dask_gufunc_kwargs=gufunc_options,
    output_dtypes=[float],
)
res

In [None]:
%%time
# Execute the graph on the local cluster; the client is closed when the block exits.
with dask.distributed.Client(cluster) as client:
    res = res.compute()

# Label the four entries of the 'property' dimension and split them into
# separate data variables.
properties = ['fsca', 'fshade', 'rg', 'dust']
res = res.assign_coords(coords={'property': properties}).to_dataset(dim='property')
res

In [None]:
# Modis benchmark:
# pixels of a 2400 x 2400 grid divided by 60 * 60 * 5
# NOTE(review): presumably the pixels-per-second rate needed to finish one tile in
# five hours — confirm the intended units.
2400 ** 2 / (60 * 60 * 5)