# Todo
- [x] Read Solar zenith
- [x] Interpolate solar zenith, maybe
- [x] Read Sentinel LUT
- [x] Create LUT interpolator
- [x] Call speedy_invert on a single observation
- [x] Call speedy_invert on a timestep (dask delayed)
- [x] Call speedy_invert on cube (Probably iterate)

I wouldn't bother with 8a. Maybe run spires on bands 2,3,4,8 & resample 11-12 (& maybe 5-7) to 10 m

# Corrections
- [ ] correct spectral distortion: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8321035/
- [ ] canopy cover
- [ ] temporal smoothing
- [ ] spatial interpolations

## Performance
- [ ] adjust chunking
- [ ] Filter locations with NDSI < -0.5
- [x] How much LUT can reside in memory (all of it)
- [ ] compute for uniques only
    - [x] implement uniquetol
        - https://www.mathworks.com/help/matlab/ref/uniquetol.html
        - https://github.com/edwardbair/SPIRES/blob/master/core/run_spires.m#L98
    - [ ] find uniques and label them
    - [ ] compute for uniques and broadcast back
- [x] How much can we cache?
    - 21 dimensions (10 in R, 10 in R0, 1 solar_z)
    - Discretize to 8 bit: 256 ** 21 = '3.74E+50' values/bytes = '3.74E+38' TB ... infeasible (obviously)

# Bands in LUT

We need to drop band 8A. 

| band id   | band name | Res |
| --:       |       --: | --: |
| 1         | 1         | 60  |
| 2         | 2         | 10  |
| 3         | 3         | 10  |
| 4         | 4         | 10  |
| 5         | 5         | 20  |
| 6         | 6         | 20  |
| 7         | 7         | 20  |
| 8         | 8         | 10  |
| 9         | 8a        | 20  |
| 10        |  9        | 60  |
| 11        |  10       | 60  |
| 12        |  11       | 20  |
| 13        |  12       | 20  |

sentinel_lut.mat contains an array called SensorTableBandOrder. It contains the values `[2, 3, 4, 5, 6, 7, 12, 13, 9]`.

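We conclude that the bands are in the following order: `['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B11', 'B12', 'B8']` and therefore need to subset and order r and r0 accordingly. (Note: according to the table above, band id 9 is band 8a, not band 8 — verify that the last LUT band really is B8.)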


In [1]:
import spires
import xarray
import h5py
import numpy as np
import scipy
import netCDF4
import dask
import geopandas
import pandas
import pyproj
import matplotlib.pyplot as plt

## Loading Observations 
- interpolate the viewing and sun angles
- Calculate the NDVI and NDSI

In [2]:
region = 'BSU'

In [3]:
# Open the sharpened Sentinel-2 cube for the selected region.
zarr_store = f'/tablespace/sentinel2/{region}_sharpend.zarr/'
ds = xarray.open_zarr(zarr_store)

# The sun/viewing angle grids are indexed by (y_angles, x_angles). Resample
# each of them onto the pixel coordinates (nearest neighbour) and drop any
# length-1 dimensions that remain.
angle_grids = (
    'sun_zenith_grid',
    'sun_azimuth_grid',
    'viewing_zenith_grid',
    'viewing_azimuth_grid',
)
for grid_name in angle_grids:
    on_pixel_grid = ds[grid_name].interp(y_angles=ds.y, x_angles=ds.x, method='nearest')
    ds[grid_name] = on_pixel_grid.squeeze()

# The angle dimensions are no longer needed; chunk spatially for dask.
ds = ds.drop_dims(('x_angles', 'y_angles')).chunk(x=500, y=500)

# The band order has to match the band order of the LUT (see "Bands in LUT").
lut_band_order = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B11', 'B12', 'B8']
ds = ds.sel(band=lut_band_order)

In [None]:
def _normalized_difference(first, second):
    """Return (first - second) / (first + second); zero denominators become NaN."""
    denominator = first + second
    denominator = denominator.where(denominator != 0)
    return (first - second) / denominator


def _mask_outside_open_unit_range(index):
    """Keep only values strictly between -1 and 1; everything else becomes NaN."""
    return index.where((index < 1) & (index > -1))


reflectance = ds['reflectance']

# NDVI = (B8 - B4) / (B8 + B4)
ndvi = _normalized_difference(reflectance.sel(band='B8'), reflectance.sel(band='B4'))
# NDSI = (B3 - B11) / (B3 + B11)
ndsi = _normalized_difference(reflectance.sel(band='B3'), reflectance.sel(band='B11'))

ndsi = _mask_outside_open_unit_range(ndsi)
ndvi = _mask_outside_open_unit_range(ndvi)

ds['ndvi'] = ndvi
ds['ndsi'] = ndsi

# Load background reflectances

In [None]:
# Open the background reflectance (r0) store for the region and put its bands
# in the same order as the LUT / the observation cube.
zarr_store = f'/tablespace/sentinel2/{region}_r0.zarr/'
lut_band_order = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B11', 'B12', 'B8']
ds_r0 = xarray.open_zarr(zarr_store).sel(band=lut_band_order)

#ds_r0 = ds_r0.drop_dims(('x_angles', 'y_angles'))
#ds_r0 = ds_r0.chunk(band=9)
ds_r0

## Lut file

In [None]:
lut_file = '/tablespace/sentinel2/lut_sentinel2b_b2to12_3um_dust.mat' 
lut_interpolator = spires.LutInterpolator(lut_file=lut_file)

# Subset time

In [None]:
#ds = ds.sortby('time')
#ds = ds.sel(time=slice('2024-02-10', None))

# Subset to single timestep

In [None]:
ts = ds.isel(time=0).squeeze()

# Invert one

In [None]:
# Array indices of the pixel used for the single-observation test.
x = 500
y = 900

# Observed spectrum and background (r0) spectrum at that pixel.
spectrum_target = ts.isel(x=x, y=y)['reflectance'].values
spectrum_background = ds_r0.isel(x=x, y=y)['reflectance'].values
# Shade endmember: all zeros, one value per band.
spectrum_shade = np.zeros_like(spectrum_target)

# Per-pixel solar zenith angle. The scene-wide value
# ts.attrs['sun_zenith_mean'] used to be assigned here first, but it was
# overwritten immediately, so that dead assignment has been removed.
solar_angle = float(ts['sun_zenith_grid'].isel(x=x, y=y).values)

# Initial guess handed to the solver.
# NOTE(review): presumably ordered (fsca, fshade, dust, grain), matching the
# order of the inversion outputs used later in this notebook - confirm.
x0 = np.array([0.5, 0.05, 10, 250])

In [None]:
#%%timeit
# Invert the single pixel prepared above, using the LUT interpolator.
single_pixel_inputs = dict(
    spectrum_target=spectrum_target,
    spectrum_background=spectrum_background,
    spectrum_shade=spectrum_shade,
    solar_angle=solar_angle,
)
res = spires.speedy_invert(
    interpolator=lut_interpolator,
    x0=x0,
    max_eval=500,
    algorithm=2,
    **single_pixel_inputs,
)
res

# Invert an array (single timestep)

In [None]:
# Collapse the (y, x) grid into one 'yx' pixel dimension so that every input
# is a flat list of pixels; spectra end up shaped (yx, band).
pixel_dim = {'yx': ('y', 'x')}

spectra_targets = ts['reflectance'].stack(**pixel_dim).transpose('yx', 'band')
spectra_backgrounds = ds_r0['reflectance'].stack(**pixel_dim).transpose('yx', 'band')
obs_solar_angles = ts['sun_zenith_grid'].stack(**pixel_dim)

# Shade endmember: zeros with the shape of the single-pixel spectrum
# (`spectrum_target` comes from the "Invert one" cell).
spectrum_shade = np.zeros_like(spectrum_target)

In [None]:
%%time
# Invert all pixels of the single timestep as one flat (yx, band) array.
flat_inputs = dict(
    spectra_targets=spectra_targets,
    spectra_backgrounds=spectra_backgrounds,
    obs_solar_angles=obs_solar_angles,
    spectrum_shade=spectrum_shade,
)
results = spires.speedy_invert_array1d(
    interpolator=lut_interpolator,
    x0=x0,
    max_eval=100,
    algorithm=2,
    **flat_inputs,
)

In [None]:
results

In [None]:
%%time
# Invert the single timestep directly on its 2-D grid (no stacking needed).
grid_inputs = dict(
    spectra_targets=ts['reflectance'],
    spectra_backgrounds=ds_r0['reflectance'],
    obs_solar_angles=ts['sun_zenith_grid'],
)
results = spires.speedy_invert_array2d(
    interpolator=lut_interpolator,
    x0=x0,
    max_eval=100,
    algorithm=2,
    **grid_inputs,
)

## Dask
- [x] parallelize in space
- [x] parallelize in time
- [x] time AND space
    -  might be worth it because too few timesteps (load balancing)
    -  will run out of memory for big ROIs
     
lsof -i :9895

Port forward to get to the dashboard

    ssh -N -L 8001:localhost:8787 schiss.eri.ucsb.edu

- should we try multiple threads per worker? Maybe 8?

In [None]:
from dask.distributed import LocalCluster
import dask.distributed
import logging

# Spill directory and generous comm timeouts for the local cluster.
dask.config.set({
    'temporary-directory': '/tablespace/dask',
    'distributed.comm.timeouts.tcp': '1200s',
    'distributed.comm.timeouts.connect': '1200s',
})
dask.config.get('distributed.comm.timeouts')

# One single-threaded process per worker; the dashboard is served on
# localhost:8787 (see the ssh port-forward note above).
cluster_options = dict(
    n_workers=64,
    threads_per_worker=1,  # Good question ...
    memory_limit='4GB',
    processes=True,  # Probably a good idea here
    dashboard_address='localhost:8787',
    silence_logs=logging.ERROR,
)
cluster = dask.distributed.LocalCluster(**cluster_options)

client = dask.distributed.Client(cluster)

In [None]:
client.close()
cluster.close()

## Scatter the reflectance
I honestly have no idea what is happening here. Copy-pasted from https://github.com/pydata/xarray/issues/6803 . This is what sucks about dask.
- What's up with client.replicate()?

In [None]:
%%time
import dask.array
# Goal: hand the LUT reflectance table to the workers once, instead of
# shipping it with every task.
# 1. Wrap the in-memory LUT reflectance table in a dask array.
a = dask.array.from_array(lut_interpolator.reflectances)
# 2. Scatter that array's task graph (as a plain dict) to the cluster.
#    broadcast=True asks for the data to be sent to all workers.
dsk = client.scatter(dict(a.dask), broadcast=True)
# 3. Rebuild a dask array with the same name/chunks/dtype/shape, but backed by
#    what client.scatter returned rather than by the local graph.
a = dask.array.Array(dsk, name=a.name, chunks=a.chunks, dtype=a.dtype, meta=a._meta, shape=a.shape)
# 4. Label the four LUT axes so apply_ufunc can match them as core dimensions.
#    NOTE: the variable name is misspelled ("refletance"); the apply_ufunc
#    cells below use the same spelling, so it must be renamed everywhere at once.
refletance_scattered = xarray.DataArray(a, dims=['bands', 'sz', 'dust', 'grain'])

## Parallelize in space, 2D

In [None]:
def invert_xarray(spectra_targets, spectra_backgrounds, obs_solar_angles, bands, solar_angles, dust, grain, reflectances):
    """Invert one block of a single timestep with ``spires.speedy_invert_array2d``.

    Adapter meant to be passed to ``xarray.apply_ufunc``. The LUT is supplied
    as its raw pieces (axis values plus the reflectance table) instead of an
    interpolator object.

    Args:
        spectra_targets: Observed spectra of the block.
        spectra_backgrounds: Background (r0) spectra of the block.
        obs_solar_angles: Solar zenith angles of the block.
        bands: LUT band axis.
        solar_angles: LUT solar-zenith axis.
        dust: LUT dust-concentration axis.
        grain: LUT grain-size axis.
        reflectances: LUT reflectance table.

    Returns:
        Whatever ``spires.speedy_invert_array2d`` returns for the block.

    Note:
        ``max_eval=100`` and ``algorithm=2`` are fixed here; the initial guess
        ``x0`` is read from the notebook's global namespace.
    """
    lut = {
        'bands': bands,
        'solar_angles': solar_angles,
        'dust_concentrations': dust,
        'grain_sizes': grain,
        'reflectances': reflectances,
    }
    return spires.speedy_invert_array2d(
        spectra_targets=spectra_targets,
        spectra_backgrounds=spectra_backgrounds,
        obs_solar_angles=obs_solar_angles,
        max_eval=100,
        x0=x0,
        algorithm=2,
        **lut,
    )

In [None]:
spectra_targets = ts['reflectance']
spectra_backgrounds = ds_r0['reflectance']
obs_solar_angles = ts['sun_zenith_grid']

In [None]:
%%time
# Core dimensions, one entry per positional input of invert_xarray:
# targets and backgrounds consume 'band', the solar angle is a per-pixel
# scalar, each LUT axis consumes its own dimension, and the LUT table
# consumes all four.
lut_dims = ['bands', 'sz', 'dust', 'grain']
core_dims = [['band'], ['band'], []] + [[dim] for dim in lut_dims] + [lut_dims]

# Lazily map the inversion over the dask chunks; each pixel yields a length-4
# 'property' vector.
results = xarray.apply_ufunc(
    invert_xarray,
    spectra_targets,
    spectra_backgrounds,
    obs_solar_angles,
    lut_interpolator.bands,
    lut_interpolator.solar_angles,
    lut_interpolator.dust_concentrations,
    lut_interpolator.grain_sizes,
    refletance_scattered,
    input_core_dims=core_dims,
    output_core_dims=[['property']],
    dask='parallelized',
    dask_gufunc_kwargs=dict(allow_rechunk=False, output_sizes=dict(property=4)),
    output_dtypes=[float],
    vectorize=False,
)

In [None]:
results = results.to_dataset(dim='property').rename({0: 'fsca', 1: 'fshade', 2: 'dust', 3: 'grain'})

In [None]:
results

In [None]:
%%time
results = results.compute()

## Parallelize In Space+time

In [None]:
spectra_targets = ds['reflectance']
spectra_backgrounds = ds_r0['reflectance']
obs_solar_angles = ds['sun_zenith_grid']
x0 = np.array([0.5, 0.05, 10, 250])

In [None]:
def invert_xarray(spectra_targets, spectra_backgrounds, obs_solar_angles, bands, solar_angles, dust, grain, reflectances):
    """Invert one space+time block with ``spires.speedy_invert_array2d``.

    Adapter meant to be passed to ``xarray.apply_ufunc`` on the full cube.
    Length-1 axes are squeezed out of the targets and solar angles before the
    2-D inversion, and a leading length-1 axis is added back to the result.

    Args:
        spectra_targets: Observed spectra of the block.
        spectra_backgrounds: Background (r0) spectra (passed through unchanged).
        obs_solar_angles: Solar zenith angles of the block.
        bands: LUT band axis.
        solar_angles: LUT solar-zenith axis.
        dust: LUT dust-concentration axis.
        grain: LUT grain-size axis.
        reflectances: LUT reflectance table.

    Returns:
        The inversion result with an extra axis of length 1 in front.

    Note:
        ``max_eval=100`` and ``algorithm=2`` are fixed here; the initial guess
        ``x0`` is read from the notebook's global namespace.
        NOTE(review): the bare ``squeeze()`` removes every length-1 axis. This
        presumably assumes blocks hold exactly one timestep and never have a
        spatial extent of 1 - confirm against the chunking.
    """
    inverted = spires.speedy_invert_array2d(
        spectra_targets=spectra_targets.squeeze(),
        spectra_backgrounds=spectra_backgrounds,
        obs_solar_angles=obs_solar_angles.squeeze(),
        bands=bands,
        solar_angles=solar_angles,
        dust_concentrations=dust,
        grain_sizes=grain,
        reflectances=reflectances,
        max_eval=100,
        x0=x0,
        algorithm=2,
    )
    # Put back the leading axis that was squeezed away above.
    return np.expand_dims(inverted, axis=0)

In [None]:
# Positional inputs of invert_xarray, each paired with the core dimensions it
# consumes. The solar angle has no core dimension; the LUT table consumes all
# four LUT axes.
lut_dims = ['bands', 'sz', 'dust', 'grain']
inputs_and_core_dims = [
    (spectra_targets, ['band']),
    (spectra_backgrounds, ['band']),
    (obs_solar_angles, []),
    (lut_interpolator.bands, ['bands']),
    (lut_interpolator.solar_angles, ['sz']),
    (lut_interpolator.dust_concentrations, ['dust']),
    (lut_interpolator.grain_sizes, ['grain']),
    (refletance_scattered, lut_dims),
]

# Lazily map the inversion over the whole cube; each pixel yields a length-4
# 'property' vector.
results = xarray.apply_ufunc(
    invert_xarray,
    *[array for array, _ in inputs_and_core_dims],
    input_core_dims=[dims for _, dims in inputs_and_core_dims],
    output_core_dims=[['property']],
    dask='parallelized',
    dask_gufunc_kwargs={'allow_rechunk': False, 'output_sizes': {'property': 4}},
    output_dtypes=[float],  # np.float32
    vectorize=False,
)

In [None]:
results = results.to_dataset(dim='property').rename({0: 'fsca', 1: 'fshade', 2: 'dust_concentration', 3: 'grain_size'})

In [None]:
# Store the results compactly as unsigned integers.
# NOTE(review): astype truncates rather than rounds, and casting NaN to an
# unsigned integer is not well defined - confirm neither matters downstream.

# Fractions are scaled to percent and stored as uint8.
for fraction_name in ('fsca', 'fshade'):
    results[fraction_name] = (results[fraction_name] * 100).astype(np.uint8)

# Dust concentration and grain size are stored unscaled as uint16.
for quantity_name in ('dust_concentration', 'grain_size'):
    results[quantity_name] = results[quantity_name].astype(np.uint16)

In [None]:
results

# Compute

In [None]:
%%time
results = results.compute()

In [None]:
%%time
zarr_store = f'/tablespace/sentinel2/{region}_results.zarr'
results.to_zarr(zarr_store, mode='w', compute=True)    

In [None]:
client.close()
cluster.close()

# Write to NC

In [None]:
zarr_store = f'/tablespace/sentinel2/{region}_results.zarr'
results = xarray.open_zarr(zarr_store)

In [None]:
%%time
# Not written in parallel. Can do on one worker

nc_file = f'/tablespace/sentinel2/{region}_results.nc'
#results.to_netcdf(nc_file, mode='w')

# Apply the same zlib compression settings to every variable in the dataset
# (coordinates included).
compression_opts = {"zlib": True, "complevel": 5}
encoding = dict.fromkeys(results.variables, compression_opts)
results.to_netcdf(nc_file, mode='w', encoding=encoding)

# Plots

In [None]:
# True-colour composite of the background reflectance: B4 (red), B3 (green),
# B2 (blue). The band list used to be ['B4', 'B4', 'B3'], which repeated red
# and dropped blue; it now matches the other RGB plots in this notebook.
ds_r0['reflectance'].sel(band=['B4', 'B3', 'B2']).plot.imshow()

In [None]:
# Side by side for the first timestep: RGB composite (left), fsca (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
time = ds.time.isel(time=0)

rgb = ds['reflectance'].sel(band=['B4', 'B3', 'B2'], time=time)
rgb.plot.imshow(ax=axes[0])
results.sel(time=time)['fsca'].plot.imshow(interpolation=None, ax=axes[1])

In [None]:
# Side by side for the second timestep: RGB composite (left), fsca (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
time = ds.time.isel(time=1)

rgb = ds['reflectance'].sel(band=['B4', 'B3', 'B2'], time=time)
rgb.plot.imshow(ax=axes[0])
results.sel(time=time)['fsca'].plot.imshow(interpolation=None, ax=axes[1])

In [None]:
1

In [None]:
can you hear me?