# Script to perform the gap filling of the stacks

In [None]:
import numpy as np
from scipy import interpolate
from scipy.ndimage.morphology import binary_dilation 
import rasterio as r
from tqdm.notebook import tqdm
import os
import gdal
import multiprocessing as mp
from functools import partial
import datetime
import time

In [None]:
def CubicSpline(ts, days, bc_t, extrap):
    """Fill the invalid entries of a time series with a cubic spline.

    Entries of ``ts`` equal to -9999 are treated as invalid. A
    ``scipy.interpolate.CubicSpline`` is fitted to the remaining entries and
    evaluated at the days of the invalid ones.

    Parameters
    ----------
    ts : array_like
        1-D time series in which invalid observations are marked with -9999.
    days : array_like
        Day associated with each entry of ``ts``. Must be strictly
        increasing (a requirement of ``scipy.interpolate.CubicSpline``).
    bc_t : str or tuple
        Forwarded as ``bc_type`` (boundary condition type), e.g. 'natural'.
    extrap : bool or 'periodic' or None
        Forwarded as ``extrapolate``.

    Returns
    -------
    numpy.ndarray
        A copy of ``ts`` (the input is never modified) whose invalid entries
        are replaced by the spline values, cast to the dtype of ``ts``. When
        fewer than two valid observations exist no spline can be fitted and
        an array of the same shape and dtype filled with -9999 is returned.
    """
    ts = np.asarray(ts)
    days = np.asarray(days)

    # True where the observation is invalid.
    invalid = ts == -9999
    valid = ~invalid

    # Not enough support points for a spline: flag the whole series as
    # invalid. An ndarray is returned so that both branches share one type.
    if np.count_nonzero(valid) < 2:
        return np.full_like(ts, -9999)

    result = ts.copy()

    # Only build the interpolator when there is something to fill.
    if invalid.any():
        spline = interpolate.CubicSpline(days[valid], ts[valid],
                                         bc_type=bc_t, extrapolate=extrap)
        result[invalid] = spline(days[invalid])

    return result

In [None]:
def update_log(string, log_path="/home/bruno.matosak/log.txt"):
    """Append ``string`` to the run log.

    Parameters
    ----------
    string : str
        Text to append. It is written as-is; no newline is added.
    log_path : str, optional
        File to append to; it is created if it does not exist. Defaults to
        the location this script has always logged to.
    """
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(string)

In [None]:
# Stamp the log with the start time of this run.
update_log(time.asctime()+'\n')

# Every path below is relative to the folder that holds the cubes.
os.chdir('/home/bruno.matosak/IGARSS2023/Pantanal/Landsat SR all/cubes')

# One multi-band cube per spectral band (band2.tif ... band7.tif) plus the
# quality raster used to build the mask.
cubes = [f'band{band_number}.tif' for band_number in range(2, 8)]
mask_path = 'pixel_qa.tif'

################
# DAYS
################
# The time axis handed to the spline is simply the band index of the first
# cube (0 .. RasterCount-1).
# NOTE(review): this treats the bands as equally spaced in time - confirm.
first_cube = gdal.Open(cubes[0])
days = np.arange(first_cube.RasterCount)
del first_cube

################
# LOADING FMASK
################
print('Loading mask...')
update_log('Loading mask...\n')

# Read every band of the quality raster, then close the dataset right away
# instead of leaving the file handle open for the rest of the run.
with r.open(mask_path) as mask_src:
    fmask = mask_src.read()

# cloud masking
# Quality codes accepted as valid observations; any other code is masked.
# NOTE(review): these look like the Landsat 8 pixel_qa codes for clear
# land/water - confirm against the product guide.
vals = [322, 386, 834, 898, 1346, 324, 388, 836, 900, 1348]

# True where the pixel must be discarded, i.e. its code is not in `vals`.
mask = np.isin(fmask, vals, invert=True)

del fmask

################
# PROCESSING EACH CUBE
################
# Number of pixel time series handed to the worker pool per batch.
n_to_process = 5000000
# Keep 10 cores free, as before, but never request fewer than one worker:
# mp.Pool raises ValueError for a non-positive process count.
n_workers = max(1, mp.cpu_count()-10)

for cube_name in cubes:
    t1 = time.time()
    print('--------------------------------------------------------------')
    print('                  PROCESSING '+cube_name)
    update_log('--------------------------------------------------------------\n'+
               f'                  PROCESSING {cube_name}\n')

    ##################
    # PATHS
    ##################
    cube_path = cube_name
    save_path = cube_name.replace('.tif', '_filled.tif')

    ################
    # LOADING CUBE
    ################
    print('Loading Cube '+cube_name+'...')
    update_log('Loading Cube '+cube_name+'...\n')

    # Read the whole cube (bands, rows, cols) and close the dataset at once.
    with r.open(cube_path) as cube_src:
        cube = np.asarray(cube_src.read())
    # Flag every masked observation with the no-data value the spline fills.
    cube[mask]=-9999

    ################
    # PROCESSING
    ################
    print('Processing...')
    update_log('Processing...\n')

    n_rows, n_cols = cube.shape[1], cube.shape[2]
    series_to_process = []
    ij = []
    # A single pool serves every batch of this cube instead of being
    # re-created for each batch.
    with mp.Pool(processes=n_workers) as pool:
        for i in tqdm(range(n_rows)):
            for j in range(n_cols):
                series_to_process.append(cube[:,i,j])
                ij.append((i,j))

                # Flush when the batch is full or the last pixel was queued.
                is_last_pixel = i==n_rows-1 and j==n_cols-1
                if len(ij) == n_to_process or is_last_pixel:
                    # The original passed the `bool` type itself as `extrap`,
                    # which is truthy; True states that intent explicitly.
                    # NOTE(review): confirm extrapolation past the first and
                    # last valid observation is wanted.
                    result_series = pool.starmap(
                        CubicSpline,
                        [(series, days, 'natural', True) for series in series_to_process])
                    for (row, col), filled in zip(ij, result_series):
                        cube[:,row,col] = filled

                    series_to_process = []
                    ij = []

    ################
    # SAVING
    ################
    print('Saving...')
    update_log('Saving...\n')

    # The input cube supplies size, data type, projection and geotransform.
    ref2 = gdal.Open(cube_path)
    in_band = ref2.GetRasterBand(1)

    gtiff_driver = gdal.GetDriverByName('GTiff')
    path_result = save_path

    print('File Location: '+path_result)
    update_log('File Location: '+path_result+'\n')

    # The GTiff creation option is spelled COMPRESS; the previous
    # 'COMPRESSION=LZW' is not a GTiff option, so files were written
    # uncompressed. BIGTIFF=IF_SAFER keeps large compressed outputs from
    # hitting the 4 GB classic-TIFF limit, which GDAL cannot predict once
    # compression is enabled.
    creation_options = ['COMPRESS=LZW', 'BIGTIFF=IF_SAFER']
    out_ds = gtiff_driver.Create(path_result, in_band.XSize, in_band.YSize, cube.shape[0], in_band.DataType, creation_options)
    out_ds.SetProjection(ref2.GetProjection())
    out_ds.SetGeoTransform(ref2.GetGeoTransform())

    # GDAL band numbers start at 1.
    for i in range(1, cube.shape[0]+1, 1):
        band = out_ds.GetRasterBand(i)
        band.SetNoDataValue(-9999)
        band.WriteArray(cube[i-1,:,:])
        band.FlushCache()

    # Dropping the references closes the datasets and flushes them to disk.
    out_ds = None
    ref2 = None
    t2 = time.time()
    print('Elapsed time: %.3f minutes\nDone!'%((t2-t1)/60))
    update_log('Elapsed time: %.3f minutes\nDone!\n'%((t2-t1)/60))

print('---------------------------------------------------------\nAll done!')
update_log('---------------------------------------------------------\nAll done!\n')

update_log('==============================================================\nPROCESSING FINISHED!\n')