In [None]:
import numpy as np
import h5py

from matplotlib import pyplot as plt

from scipy.interpolate import interp1d, splrep, splev

In [None]:
TRAIN_DATA_PATH = '/work/ka1176/shared_data/2021-ai4food/dev_data/south-africa/sentinel-2/extracted/train_data.h5'
# Read-only handle; intentionally not closed here because later cells keep
# slicing `image_stack`, which is taken from this file.
train_data = h5py.File(TRAIN_DATA_PATH, mode='r')

In [None]:
# Dataset holding the stacked time series; judging from the indexing further
# down, its axes are (sample, time step, band, pixel) — TODO confirm.
image_stack = train_data['image_stack']
stack_shape = image_stack.shape
print(stack_shape)

In [None]:
# Last channel (cloud probability) for every sample, time step and pixel.
clp = image_stack[:, :, -1, :]

In [None]:
np.shape(clp)  # displayed as the cell output

In [None]:
# Let's select the time series of a single pixel to explore.

In [None]:
# Which sample, pixel and band of `image_stack` to inspect below.
s_ix, p_ix, b_ix = 2000, 43, 2  # sample, pixel, band

In [None]:
# Time series of one pixel (sample `s_ix`, pixel `p_ix`) for the bands used below.
# Band indices assume the 12-band Sentinel-2 ordering without B10
# (B01..B08, B8A, B09, B11, B12) — NOTE(review): confirm against the extraction script.
selected_band = image_stack[s_ix, :, b_ix, p_ix]
# Last channel is the cloud probability, rescaled so the thresholds used later
# are fractions — presumably stored as CLP / 1e4 with CLP in 0..255; TODO confirm.
cloud = image_stack[s_ix, :, -1, p_ix] * 1e4 / 255
print(selected_band.shape)


def normalized_difference(a, b):
    """Return the element-wise normalized difference ``(a - b) / (a + b)`` in float64.

    Both inputs are cast to float64 first so the arithmetic cannot wrap around
    if the bands are stored as unsigned integers (on-disk dtype not verified).
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return (a - b) / (a + b)


# Cast on extraction so intermediate arithmetic such as B11 - B12 is done in float.
band_2 = image_stack[s_ix, :, 1, p_ix].astype(np.float64)    # Blue
band_3 = image_stack[s_ix, :, 2, p_ix].astype(np.float64)    # Green
band_4 = image_stack[s_ix, :, 3, p_ix].astype(np.float64)    # Red
band_8 = image_stack[s_ix, :, 7, p_ix].astype(np.float64)    # NIR
band_8A = image_stack[s_ix, :, 8, p_ix].astype(np.float64)   # narrow NIR
band_11 = image_stack[s_ix, :, 10, p_ix].astype(np.float64)  # SWIR 1
band_12 = image_stack[s_ix, :, 11, p_ix].astype(np.float64)  # SWIR 2

ndvi = normalized_difference(band_8, band_4)       # NDVI
ndvi_8a = normalized_difference(band_8A, band_4)   # NDVI using the narrow NIR band
# Moisture index: the denominator is the SUM; (B8A - B11) / (B11 - B8A) is always -1.
ndmi = normalized_difference(band_8A, band_11)
# NMDI, normalized multi-band drought index: (B8A - (B11 - B12)) / (B8A + (B11 - B12)).
nmdi = normalized_difference(band_8A, band_11 - band_12)

# Index analysed in the rest of the notebook (plots below are labelled NMDI).
bands = nmdi

In [None]:
# Raw index time series next to the rescaled cloud probability of the same pixel.
fig_raw, ax_raw = plt.subplots()
ax_raw.plot(bands, label='spectral index (NMDI)')
ax_raw.plot(cloud, label='cloud probability')
ax_raw.set_xlabel('Time step in season')
ax_raw.legend();

In [None]:
# Gap-fill the index time series. Observations whose cloud probability is not
# below the threshold are treated as missing. For each threshold, three
# reconstructions are compared: linear interpolation, a cloud-weighted
# smoothing spline and an interpolating cubic spline.
#clp_thresholds = [0.05, 0.075, 0.1, 0.125, 0.15, 0.2, 1.1]
clp_thresholds = [0.05, 0.1, 0.125, 1.1]  # 1.1 presumably keeps every observation
CLOUD_WEIGHT_EXPONENT = 20  # larger -> cloudier points pull less on the weighted spline
SPLINE_DEGREE = 3
MIN_CLEAR_FRACTION = 1e-6   # floor for (1 - cloud): splrep requires strictly positive weights

timesteps = np.arange(len(bands))
last_t = len(bands) - 1

# squeeze=False keeps `axs` 2-D, so the zip below also works for a single threshold.
fig, axs = plt.subplots(1, len(clp_thresholds), figsize=(22, 4),
                        sharex=True, sharey=True, squeeze=False)

for ax, clpt in zip(axs[0], clp_thresholds):
    good_ix = list(np.flatnonzero(cloud < clpt))

    # The interpolants are evaluated on every time step, so both end points must
    # be present. A cloudy end point is added anyway and receives the value of
    # its nearest clear neighbour (constant extrapolation).
    first_is_bad = 0 not in good_ix
    if first_is_bad:
        good_ix = [0] + good_ix
    last_is_bad = last_t not in good_ix
    if last_is_bad:
        good_ix = good_ix + [last_t]

    good_x = timesteps[good_ix]
    good_y = bands[good_ix]  # fancy indexing returns a copy; `bands` stays untouched
    if first_is_bad:
        good_y[0] = good_y[1]
    if last_is_bad:
        good_y[-1] = good_y[-2]

    linear_y = interp1d(good_x, good_y, kind='linear')(timesteps)

    # Down-weight cloudy points. Clipping the base keeps a cloud probability of 1
    # from producing a zero weight, which splrep rejects.
    weights = np.clip(1.0 - cloud[good_ix], MIN_CLEAR_FRACTION, None) ** CLOUD_WEIGHT_EXPONENT

    # Plot the values actually used for fitting (filled end points included).
    ax.plot(good_x, good_y, 'o--', label='points used')
    ax.plot(timesteps, linear_y, ':', label='linear')

    # splrep needs more data points than the spline degree (m > k).
    if len(good_x) > SPLINE_DEGREE:
        tck_weighted = splrep(good_x, good_y, k=SPLINE_DEGREE, w=weights)
        ax.plot(timesteps, splev(timesteps, tck_weighted), ':', label='weighted spline')
        tck_interp = splrep(good_x, good_y, k=SPLINE_DEGREE)
        ax.plot(timesteps, splev(timesteps, tck_interp), ':', label='interpolating spline')

    ax.set_title(f'cloud probability < {clpt:.1%}')
    ax.set_xlabel('Time step in season')
    ax.set_ylabel('NMDI')  # `bands` holds the NMDI computed earlier, not NDVI

axs[0, 0].legend()
fig.tight_layout()

In [None]:
# Spline weights against cloud probability for the LAST threshold of the loop
# above (`good_ix`, `weights` and `clpt` are left over from its final iteration).
fig_w, ax_w = plt.subplots()
ax_w.scatter(cloud[good_ix], weights)
ax_w.set_xlabel('cloud probability')
ax_w.set_ylabel('spline weight')
ax_w.set_title(f'Spline weights (threshold {clpt})');