In [None]:
import rioxarray as riox
import xarray as xr
import os
import geopandas as gpd
from datetime import datetime, timedelta

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import find_peaks

In [None]:
from hlsstack.hls_funcs.indices import ndvi_func
from hlsstack.hls_funcs.predict import pred_cp
from hlsstack.models.load import load_model
from hlsstack.hls_funcs.smooth import smooth_xr, despike_ts_xr

In [None]:
#import sys
#sys.path.insert(1, '/project/cper_neon_aop/hls_nrt_utils/hlsstack/hls_funcs')
#sys.path.insert(1, '/project/cper_neon_aop/hls_nrt_utils/hlsstack/models')
#from predict import pred_bm, pred_bm_se
#from indices import ndvi_func
#from smooth import smooth_xr, despike_ts_xr
#from load import load_model

In [None]:
from tqdm import tqdm

In [None]:
prefix = 'cper'

# directory holding the yearly 'CPER_<year>.Landsat_MODIS_STARFM.nc' inputs
inDIR = '/90daydata/cper_neon_aop/cper_hls_veg_models/data/lmf_cper'
# directory that receives the yearly 'cper_lmf_cp_<year>.nc' outputs
outDIR = '/90daydata/cper_neon_aop/hls_nrt/cper/lmf_cp/'
# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail if the directory appears between an existence check and the creation
os.makedirs(outDIR, exist_ok=True)
# when False, years whose output file already exists are skipped
overwrite = False

# the path to a shapefile with CPER pasture boundaries
cper_f = '/project/cper_neon_aop/cper_hls_veg_models/data/ground/boundaries/cper_pastures_2017_dissolved.shp'

In [None]:
# Load the CPER pasture boundaries and reproject them to EPSG:32613.
# In the visible code only `cper.crs` and `cper.total_bounds` are used, to
# reproject and spatially subset each year's dataset in the yearly loop.
cper = gpd.read_file(cper_f).to_crs(32613)
# Disabled: per-pasture id -> name lookup and (geometry, id) pairs
# (presumably for rasterizing a pasture mask; kept for reference).
#cper_info = cper[['Pasture', 'geometry']].reset_index(drop=True).reset_index().rename(columns={'index': 'id'})
#past_dict = {row.id+1: row.Pasture for _, row in cper_info.iterrows()}
#past_dict[0] = 'UNK'
#cper_mask_shp = [(row.geometry, row.id+1) for _, row in cper_info.iterrows()]

In [None]:
from dask.distributed import LocalCluster, Client
import dask
from jupyter_server import serverapp
# Build the dask dashboard link so that it goes through the Open OnDemand
# proxy of the running Jupyter server.
try:
    jupServer = next(iter(serverapp.list_running_servers()))
    dashboard_base_url = jupServer['base_url']
except Exception:
    # Best effort: if no running server can be discovered, fall back to a
    # hard-coded node path. `except Exception` (rather than a bare `except`)
    # keeps KeyboardInterrupt / SystemExit from being swallowed.
    dashboard_base_url = '/node/atlas-0024/6142/'
dask.config.set({'distributed.dashboard.link': 'https://atlas-ood.hpc.msstate.edu' + dashboard_base_url + 'proxy/{port}/status'})

# start a local cluster and attach a client to it
cluster = LocalCluster(n_workers=48)
client = Client(cluster)
display(client)

In [None]:
#cluster.close()
#client.close()

In [None]:
# Defaults used by the synthetic-data demo (create_sample_data() and main()).

# Maximum allowed absolute distance from the forwards-backwards EWMA (FBEWMA);
# samples farther away than this are removed as outliers.
DELTA = 250

# values above this are clipped (replaced with NaN):
HIGH_CLIP = 10000

# values below this are clipped (replaced with NaN):
LOW_CLIP = 0

# in the demo data, random draws above this threshold add a positive spike:
RAND_HIGH = 0.98

# in the demo data, random draws below this threshold add a negative spike:
RAND_LOW = 0.02

# span passed to the EWMA when smoothing (pandas `ewm(span=...)`).
SPAN = 45

# amplitude of the synthetic spikes added to the demo data
SPIKE = 2


def clip_data(unclipped, high_clip, low_clip):
    ''' Replace values outside [low_clip, high_clip] with NaN.

    unclipped: a single column (sequence) of numeric data.
    high_clip: values strictly above this become NaN.
    low_clip:  values strictly below this become NaN.

    Returns a plain Python list of the same length. NaN inputs stay NaN
    because comparisons against NaN are False.
    '''
    # convert to np.array to get vectorized comparisons
    values = np.array(unclipped)
    # use the thresholds passed by the caller, not the module-level defaults,
    # so per-band limits (e.g. 1.0 for NDVI) actually take effect
    out_of_range = (values > high_clip) | (values < low_clip)
    return np.where(out_of_range, np.nan, values).tolist()


def create_sample_data():
    ''' Build a demo frame: one sine period (amplitude 2) with random +/- spikes added.

    Columns: x, y, rand, spike_high, spike_low, y_spikey.
    '''
    xs = np.linspace(0, 2*np.pi, 1000)
    frame = pd.DataFrame({'x': xs, 'y': 2 * np.sin(xs)})
    # one uniform random draw per sample decides whether a spike is added
    frame['rand'] = np.random.random_sample(len(xs),)
    # draws in the upper tail add +SPIKE, draws in the lower tail add -SPIKE
    frame['spike_high'] = np.where(frame['rand'] > RAND_HIGH, SPIKE, 0)
    frame['spike_low'] = np.where(frame['rand'] < RAND_LOW, -SPIKE, 0)
    frame['y_spikey'] = frame['y'] + frame['spike_high'] + frame['spike_low']
    return frame


def ewma_fb(df_column, span):
    ''' Apply a forwards-backwards exponential weighted moving average (EWMA).

    df_column: pandas Series to smooth.
    span:      EWMA span, used for BOTH the forward and the backward pass.

    Returns a 1-D numpy array: the element-wise mean of the forward EWMA and
    the (re-reversed) backward EWMA.
    '''
    # Forwards EWMA.
    fwd = df_column.ewm(span=span).mean()
    # Backwards EWMA: run over the reversed series with the same span
    # (previously hard-coded to 10, which ignored the caller's setting).
    bwd = df_column[::-1].ewm(span=span).mean()
    # Stack by position (not by index label) and average the two passes.
    stacked_ewma = np.vstack((fwd, bwd[::-1]))
    return np.mean(stacked_ewma, axis=0)
    
    
def remove_outliers(spikey, fbewma, delta):
    ''' Blank out (NaN) every sample that lies more than delta away from fbewma.

    Returns a numpy array the same length as spikey.
    '''
    observed = np.array(spikey)
    smoothed = np.array(fbewma)
    residual = observed - smoothed
    too_far = np.abs(residual) > delta
    return np.where(too_far, np.nan, observed)

    
def main():
    ''' Demo: despike the synthetic sine data and plot raw vs. cleaned series. '''
    demo = create_sample_data()

    # clip -> smooth -> drop points far from the smooth curve -> fill the gaps
    spikey_values = demo['y_spikey'].tolist()
    demo['y_clipped'] = clip_data(spikey_values, HIGH_CLIP, LOW_CLIP)
    demo['y_ewma_fb'] = ewma_fb(demo['y_clipped'], SPAN)
    demo['y_remove_outliers'] = remove_outliers(demo['y_clipped'].tolist(),
                                                demo['y_ewma_fb'].tolist(),
                                                DELTA)
    demo['y_interpolated'] = demo['y_remove_outliers'].interpolate()

    # raw series in blue, cleaned series in black on the same axes
    axes = demo.plot(x='x', y='y_spikey', color='blue', alpha=0.5)
    demo.plot(x='x', y='y_interpolated', color='black', ax=axes)
    
def remove_spikes(ts, HIGH_CLIP, LOW_CLIP, SPAN, DELTA):
    ''' Build a 0/1 mask flagging samples of ts that do not survive despiking.

    ts is clipped with clip_data, smoothed with ewma_fb, and filtered with
    remove_outliers; the result is 1 wherever the filtered series is NaN
    (including samples that were already NaN) and 0 elsewhere.
    '''
    raw_values = pd.Series(ts).tolist()
    clipped = pd.Series(clip_data(raw_values, HIGH_CLIP, LOW_CLIP))
    smoothed = ewma_fb(clipped, SPAN)
    filtered = remove_outliers(clipped.tolist(), smoothed.tolist(), DELTA)
    return pd.Series(filtered).isnull().astype(int).values

In [None]:
def despike_lmf_mask(band_ts, mask_pos_spikes=False, min_prominence=0.2, low_prominence=0.05, max_width=None):
    ''' Build a 0/1 mask flagging spike samples in a 1-D time series.

    band_ts:         1-D numeric array (NaNs are linearly interpolated first).
    mask_pos_spikes: if True look for positive spikes; otherwise the series
                     is negated so that negative spikes become peaks.
    min_prominence:  only peaks with prominence above this are masked.
    low_prominence:  prominence threshold passed to scipy's find_peaks; the
                     less prominent peaks it finds are used to tighten the
                     bases of neighbouring prominent peaks.
    max_width:       if given, upper width limit for find_peaks and the
                     maximum base-to-base span of a peak that may be masked.

    Returns an int array, 1 where a sample is masked (or still NaN after
    interpolation) and 0 elsewhere. If a KeyError/IndexError occurs the whole
    series is masked (all ones).
    '''
    try:
        if mask_pos_spikes:
            ts = pd.Series(band_ts).interpolate()
        else:
            ts = pd.Series(band_ts*-1.0).interpolate()
        peaks, props = find_peaks(ts, prominence=low_prominence, width=[0, max_width])
        prominences = props['prominences']
        left_bases = props['left_bases']
        right_bases = props['right_bases']
        n_peaks = len(peaks)
        for i in range(n_peaks):
            if not prominences[i] > min_prominence:
                continue
            # if the previous peak is a minor one, do not reach back past its right base
            if i > 0 and prominences[i-1] < min_prominence:
                left_b = max(left_bases[i], right_bases[i-1])
            else:
                left_b = left_bases[i]
            # same on the right side; compare against the number of peaks
            # (not the length of the find_peaks result tuple) so the last
            # peak never indexes past the end of the arrays
            if i < n_peaks - 1 and prominences[i+1] < min_prominence:
                right_b = min(right_bases[i], left_bases[i+1])
            else:
                right_b = right_bases[i]

            # skip peaks whose base-to-base span exceeds the allowed width
            if max_width is not None and (right_b - left_b) > max_width:
                continue

            # samples from the left base up to (not including) the right base
            ts_sub = ts.iloc[left_b:right_b]

            # NOTE(review): the x endpoints are (left_b, right_b) but the y
            # values are taken at left_b and right_b-1; kept as-is - confirm
            # whether the right base sample itself was intended.
            slp, _ = np.polyfit([left_b, right_b], ts_sub.loc[[left_b, right_b-1]], 1)
            trend = ts_sub.iloc[0] + slp*np.arange(0, len(ts_sub))

            # mask every sample that sits above the base-to-base trend line
            trend_diff = ts_sub - trend
            ts.loc[trend_diff[trend_diff > 0].index] = np.nan
        return ts.isnull().astype(int).values
    except (KeyError, IndexError):
        # fall back to masking the entire series, as an int mask like the normal return
        return np.ones_like(band_ts, dtype=int)
        

def despike_lmf_mask_xr(dat, dims, kwargs):
    ''' Apply remove_spikes along `dims` of `dat` and return the resulting mask.

    dat:    DataArray with 'time', 'y' and 'x' dimensions.
    dims:   core dimension(s) handed to remove_spikes (e.g. ['time']).
    kwargs: keyword arguments forwarded to remove_spikes.
    '''
    # flatten space into one 'z' dimension; keep the whole time axis in each chunk
    stacked = dat.stack(z=['y', 'x']).chunk({'time': -1, 'z': 20})
    mask_stacked = xr.apply_ufunc(
        remove_spikes,
        stacked,
        input_core_dims=[dims],
        output_core_dims=[dims],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[int],
        kwargs=kwargs,
    )
    # restore the (time, z) order, then unstack z back into y/x
    mask_time_first = mask_stacked.transpose('time', 'z')
    return mask_time_first.unstack()

In [10]:
# For each year: despike the fused reflectance time series, smooth NDVI,
# predict CP from it and write the result to outDIR as NetCDF.
for yr in tqdm(range(2000, 2015)):
    # input and output file for this year (built once, reused below)
    nc_f = os.path.join(inDIR, 'CPER_'+str(yr)+'.Landsat_MODIS_STARFM.nc')
    out_f = os.path.join(outDIR, 'cper_lmf_cp_'+str(yr)+'.nc')

    # skip years already processed (unless overwriting) or without input data
    if not overwrite and os.path.exists(out_f):
        continue
    if not os.path.exists(nc_f):
        continue

    print(yr)
    hls_ds = xr.open_dataset(nc_f, chunks={'DOY': -1, 'y': 20, 'x': 20})

    if hls_ds.rio.crs != cper.crs:
        hls_ds = hls_ds.rio.reproject(cper.crs)
    # convert the day-of-year coordinate to dates
    hls_ds['DOY'] = [datetime(yr, 1, 1) + timedelta(days=int(x)-1) for x in hls_ds['DOY'].values]
    # rename the day-of-year coordinate to time
    hls_ds = hls_ds.rename({'DOY': 'time'})
    # subset to the CPER bounding box (30-unit pad on the max-x / min-y sides)
    hls_ds = hls_ds.sel(x=slice(cper.total_bounds[0], cper.total_bounds[2] + 30),
                        y=slice(cper.total_bounds[3], cper.total_bounds[1] - 30))

    # treat -9999 as missing
    hls_ds = hls_ds.where(hls_ds != -9999)

    # NDVI, keeping only positive values
    hls_ds['NDVI'] = ndvi_func(hls_ds)
    hls_ds['NDVI'] = hls_ds['NDVI'].where(hls_ds['NDVI'] > 0)

    # flag NDVI spikes and drop those samples from every variable
    ndvi_mask = despike_lmf_mask_xr(hls_ds['NDVI'], dims=['time'], kwargs={'HIGH_CLIP': 1.0,
                                                                           'LOW_CLIP': 0.0,
                                                                           'SPAN': 2,
                                                                           'DELTA': 0.07})

    ndvi_mask_out = ndvi_mask.compute()

    hls_ds = hls_ds.where(ndvi_mask_out==0)

    # flag spikes in the BLUE and SWIR2 bands (computed on the NDVI-masked data)
    blue_mask = despike_lmf_mask_xr(hls_ds['BLUE'], dims=['time'], kwargs={'HIGH_CLIP': 10000,
                                                                           'LOW_CLIP': 0,
                                                                           'SPAN': 20,
                                                                           'DELTA': 100})

    blue_mask_out = blue_mask.compute()

    swir2_mask = despike_lmf_mask_xr(hls_ds['SWIR2'], dims=['time'], kwargs={'HIGH_CLIP': 10000,
                                                                             'LOW_CLIP': 0,
                                                                             'SPAN': 20,
                                                                             'DELTA': 200})

    swir2_mask_out = swir2_mask.compute()

    hls_ds = hls_ds.where(blue_mask_out==0)
    hls_ds = hls_ds.where(swir2_mask_out==0)

    # second despiking pass on NDVI, then temporal smoothing
    dat_out_ndvi_ds = despike_ts_xr(hls_ds['NDVI'],
                                    dat_thresh=0.07,
                                    mask_outliers=False,
                                    iters=2,
                                    dims=['time']).persist()

    dat_out_ndvi = smooth_xr(dat_out_ndvi_ds,
                             dims=['time'],
                             kwargs={'double': True, 'limit': 91})

    mod_cp = load_model('cper_cp')
    # create the CP array from the smoothed NDVI
    # (the original referenced the undefined name `dat_out_da_ndvi`)
    dat_out_da = pred_cp(dat_out_ndvi, model=mod_cp)
    dat_out_da.name = 'CP'

    dat_out_da = dat_out_da.rio.write_crs(hls_ds.rio.crs)
    # keep only positive predictions
    dat_out_da = dat_out_da.where(dat_out_da > 0)
    dat_out_da = dat_out_da.astype('float32')
    # ascending time/x, descending y
    dat_out_da = dat_out_da.sortby(["time", "y", "x"]).sortby('y', ascending=False)

    display(dat_out_da)

    dat_out_da.to_netcdf(out_f)
    # NOTE(review): 0o777 makes the file world-writable - presumably for
    # shared project access; confirm this is intended.
    os.chmod(out_f, 0o777)

100%|██████████| 15/15 [00:00<00:00, 174.58it/s]
