In [None]:
import os
import shutil
n_threads = 48
os.environ['USE_PYGEOS'] = '0'
os.environ['PROJ_LIB'] = '/opt/conda/share/proj/'
os.environ['NUMEXPR_MAX_THREADS'] = f'{n_threads}'
os.environ['NUMEXPR_NUM_THREADS'] = f'{n_threads}'
os.environ['OMP_THREAD_LIMIT'] = f'{n_threads}'
os.environ["OMP_NUM_THREADS"] = f'{n_threads}'
os.environ["OPENBLAS_NUM_THREADS"] = f'{n_threads}' # export OPENBLAS_NUM_THREADS=4 
os.environ["MKL_NUM_THREADS"] = f'{n_threads}' # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = f'{n_threads}' # export VECLIB_MAXIMUM_THREADS=4
import gc
from datetime import datetime
from osgeo import gdal, gdal_array
from pathlib import Path
from typing import Callable, Iterator, List,        Union
import bottleneck as bn
import geopandas as gpd
import numpy as np
import pandas as pd
import skmap_bindings
import tempfile
import time
import sys
import csv
from scipy.signal import savgol_coeffs
import numpy as np
from skmap.io import process
import matplotlib.pyplot as plt

# GDAL configuration options; passed as the `gdal_opts` argument of the
# skmap_bindings read calls further down in this file.
gdal_opts = dict(
    GDAL_HTTP_VERSION='1.0',
    CPL_VSIL_CURL_ALLOWED_EXTENSIONS='.tif',
)

# Creation-option strings in GDAL ``KEY=VALUE`` form (tiled, deflate-compressed
# BigTIFF with 1024x1024 blocks).
co = [
    'TILED=YES',
    'BIGTIFF=YES',
    'COMPRESS=DEFLATE',
    'BLOCKXSIZE=1024',
    'BLOCKYSIZE=1024',
]

# Placeholder for an executor object; nothing is assigned here.
executor = None

def ttprint(*args, **kwargs):
    """Print ``args`` prefixed with the current time as ``[HH:MM:SS] ``.

    Accepts the same keyword arguments as the built-in ``print``. The output
    is flushed by default; an explicit ``flush=`` from the caller is honoured.
    When ``file=`` is given, both the timestamp and the message go to it.
    """
    from datetime import datetime

    # Flush by default but let the caller override: hard-coding flush=True in
    # the call below would raise TypeError if the caller also passed flush=.
    kwargs.setdefault('flush', True)
    # Write the prefix to the same stream as the message (None -> sys.stdout).
    print(f'[{datetime.now():%H:%M:%S}] ', end='', file=kwargs.get('file'))
    print(*args, **kwargs)

def make_tempdir(basedir='skmap', make_subdir = True):
    """Create (if needed) and return a working directory under ``TMP_DIR``.

    Parameters
    ----------
    basedir
        Directory name appended to the module-level ``TMP_DIR`` (which is not
        defined in the visible part of this file).
    make_subdir
        When true, a uniquely named sub-directory is appended to ``basedir``.

    Returns
    -------
    pathlib.Path
        The directory path; it exists when the function returns.
    """
    tempdir = Path(TMP_DIR).joinpath(basedir)
    if make_subdir:
        # Only the unique file name is wanted; the context manager closes the
        # placeholder file deterministically instead of relying on garbage
        # collection to do it.
        with tempfile.NamedTemporaryFile() as placeholder:
            name = Path(placeholder.name).name
        tempdir = tempdir.joinpath(name)
    tempdir.mkdir(parents=True, exist_ok=True)
    return tempdir

def make_tempfile(basedir='skmap', prefix='', suffix='', make_subdir = False):
    """Return a unique file path inside a directory made by ``make_tempdir``.

    The directory is created, but the returned file itself is not: only its
    name (with the requested ``prefix`` and ``suffix``) is generated.

    Returns
    -------
    pathlib.Path
        ``<tempdir>/<unique name>``.
    """
    tempdir = make_tempdir(basedir, make_subdir=make_subdir)
    # Only the unique name is wanted; the context manager closes the
    # placeholder file deterministically instead of relying on garbage
    # collection to do it.
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix) as placeholder:
        name = Path(placeholder.name).name
    return tempdir.joinpath(name)
    
def get_SWAG_weights(att_env, att_seas, season_size, n_imag):
    """Return ``n_imag`` weights combining a seasonal and an envelope term.

    The weight at index ``i`` is ``10 ** (seasonal[i % season_size] + env[i])``
    where

    * ``seasonal`` is V-shaped over one season: 0 at index 0, falling linearly
      to ``-att_seas/10`` at ``season_size/2`` and rising linearly afterwards;
    * ``env`` falls linearly from 0 with slope ``att_env/10/n_imag`` so that
      temporally distant images are attenuated.

    Parameters
    ----------
    att_env
        Envelope attenuation (divided by 10 to form the exponent; presumably
        expressed in dB -- confirm with the callers).
    att_seas
        Seasonal attenuation, same convention as ``att_env``.
    season_size
        Number of images in one season (integer).
    n_imag
        Number of weights to generate (integer).

    Returns
    -------
    numpy.ndarray
        1-D float64 array of length ``n_imag``; element 0 is ``1.0``.
    """
    period_y = season_size/2.0
    slope_y = att_seas/10/period_y
    season_idx = np.arange(season_size)
    # Falling branch up to mid-season, rising branch after it.
    base_func = np.where(
        season_idx <= period_y,
        -slope_y*season_idx,
        slope_y*(season_idx - period_y) - att_seas/10,
    )
    # Compute the envelope to attenuate temporally far images
    slope_e = att_env/10/n_imag
    env_func = -slope_e*np.arange(n_imag)
    # The seasonal pattern is tiled to n_imag entries; the power is evaluated
    # once here instead of once per loop iteration.
    return 10.0**(np.resize(base_func, n_imag) + env_func)

def apllaySWAG(data_in, att_env, att_seas, season_size, future_scaling):
    """Reconstruct ``data_in`` with SWAG weights through ``skmap_bindings``.

    Parameters
    ----------
    data_in
        2-D array; ``data_in.shape[0]`` is used as the number of weights
        (assumes axis 0 is time and axis 1 is pixels -- TODO confirm).
    att_env, att_seas, season_size
        Forwarded to ``get_SWAG_weights``.
    future_scaling
        Multiplier applied to the weights passed as the "future" weights.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``data_in``.

    Notes
    -----
    Relies on the module-level ``n_threads`` and ``backend``.
    """
    n_s = data_in.shape[0]
    # Weight 0 is supplied separately as w_0, so only weights 1.. are used;
    # they are computed once and shared by the past and future kernels.
    weights = get_SWAG_weights(att_env, att_seas, season_size, n_s)[1:]
    w_p = (weights[::-1]).astype(np.float32)
    w_f = (weights).astype(np.float32)*future_scaling
    w_0 = 1.0
    out_index_offset = 0
    # The bindings are fed the transposed layout and write into preallocated
    # float32 buffers.
    data_in_t = np.empty((data_in.shape[1], data_in.shape[0]), dtype=np.float32)
    data_rec_t = np.empty((data_in.shape[1], data_in.shape[0]), dtype=np.float32)
    data_rec = np.empty(data_in.shape, dtype=np.float32)
    skmap_bindings.transposeArray(data_in, n_threads, data_in_t)
    skmap_bindings.applySircle(data_in_t, n_threads, data_rec_t, out_index_offset, w_0, w_p, w_f, True, "v2", backend)
    skmap_bindings.transposeArray(data_rec_t, n_threads, data_rec)
    return data_rec

def apllaySG(data_in, w_l, p_o):
    """Smooth ``data_in`` with a Savitzky-Golay kernel via ``skmap_bindings``.

    Parameters
    ----------
    data_in
        2-D array (assumes axis 0 is time and axis 1 is pixels -- TODO
        confirm).
    w_l
        Half window length; the kernel has ``2*w_l + 1`` coefficients.
    p_o
        Polynomial order handed to ``scipy.signal.savgol_coeffs``.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``data_in``.

    Notes
    -----
    Relies on the module-level ``n_threads`` and ``backend``.
    """
    kernel = savgol_coeffs(w_l*2+1, p_o, use='dot')
    # Split the kernel into its centre coefficient and the two halves.
    centre_w = kernel[w_l].astype(np.float32)
    past_w = kernel[0:w_l].astype(np.float32)
    future_w = kernel[w_l+1:].astype(np.float32)
    offset = 0

    # Preallocated float32 buffers: transposed input, transposed result, and
    # the result in the original layout.
    n_rows, n_cols = data_in.shape[0], data_in.shape[1]
    src_t = np.empty((n_cols, n_rows), dtype=np.float32)
    dst_t = np.empty((n_cols, n_rows), dtype=np.float32)
    smoothed = np.empty(data_in.shape, dtype=np.float32)

    skmap_bindings.transposeArray(data_in, n_threads, src_t)
    skmap_bindings.applySircle(src_t, n_threads, dst_t, offset, centre_w, past_w, future_w, False, "v1", backend)
    skmap_bindings.transposeArray(dst_t, n_threads, smoothed)
    return smoothed

# ---- Time axis: years covered and images per year --------------------------
years = range(1997, 2023)
n_imag_per_year = 6
n_years = len(years)
n_s = n_years * n_imag_per_year

# Start/end (MMDD) of each of the six bimonthly periods, as used in file names.
month_start = ['0101', '0301', '0501', '0701', '0901', '1101']
month_end = ['0228', '0430', '0630', '0831', '1031', '1231']

# ---- Raster window and reading parameters ----------------------------------
x_size, y_size = 4004, 4004
x_off, y_off = 0, 0
no_data = 255
bands_list = [1]
out_index_offset = 0
backend = 'Matrix'

# ---- Input file naming -------------------------------------------------------
file_ending = '_go_epsg.4326_v20230908.tif'

# bands_prefix = ['blue_glad.ard2_m_30m_s_',
#                 'green_glad.ard2_m_30m_s_',
#                 'red_glad.ard2_m_30m_s_',
#                 'nir_glad.ard2_m_30m_s_',
#                 'swir1_glad.ard2_m_30m_s_',
#                 'swir2_glad.ard2_m_30m_s_',
#                 'thermal_glad.ard2_m_30m_s_']

# Only the two bands needed for NDVI are read, red first and then NIR.
bands_prefix = [
    'red_glad.ard2_m_30m_s_',
    'nir_glad.ard2_m_30m_s_',
]

# ---- Output compression ------------------------------------------------------
compression_command = f"gdal_translate -a_nodata {no_data} -co COMPRESS=deflate -co ZLEVEL=9 -co TILED=TRUE -co BLOCKXSIZE=1024 -co BLOCKYSIZE=1024"

# ---- Reconstruction / smoothing / masking parameters ------------------------
# SWAG attenuations and the scaling applied to the future weights.
att_env = 20.0
att_seas = 40.0
future_scaling = 0.1
# Savitzky-Golay half window length and polynomial order.
w_l = 2
p_o = 2
# Thresholds passed to skmap_bindings.maskDifference.
diff_th = 35
count_th = n_s // 4

# ---- Tiles to process --------------------------------------------------------
# tiles = ['009E_04N', '009E_51N', '013E_61N', '050W_07S', '085W_52N', '091W_37N', '115E_03S', '127E_42N']
tiles = ['085W_52N']
for tile in tiles:
    # Per-tile pipeline: read the Landsat bimonthly aggregates and the MODIS
    # monthly NDVI, compute the Landsat NDVI, reconstruct it with SWAG (with
    # and without MODIS-based masking), smooth it with Savitzky-Golay and plot
    # a few selected years.

    ttprint(f"Processing tile {tile}")

    # One Landsat raster per (band, year, bimonth), band-major: the first n_s
    # entries belong to bands_prefix[0], the next n_s to bands_prefix[1].
    in_files = [
        f'data/{tile}/landsat_agg/{b}{year}{month_start[bimonth]}_{year}{month_end[bimonth]}{file_ending}'
        for b in bands_prefix
        for year in years
        for bimonth in range(n_imag_per_year)
    ]
    # Twelve monthly MODIS rasters per year; years outside 2000-2021 reuse the
    # nearest year inside that range.
    modis_files = [
        f'data/{tile}/modis/modis_ndvi_{min(max(year, 2000), 2021)}{str(m+1).zfill(2)}01.tif'
        for year in years
        for m in range(12)
    ]

    # agg_data has 3*n_s rows: two blocks of n_s rows for the input bands plus
    # one block of n_s rows that receives the computed NDVI.
    agg_data = np.empty((n_s*3, x_size * y_size), dtype=np.float32)
    modis_month_data = np.empty((n_s*2,agg_data.shape[1]), dtype=np.float32)

    start = time.time()
    skmap_bindings.readData(agg_data, n_threads, in_files, range(len(in_files)), x_off, y_off, x_size, y_size, bands_list, gdal_opts, no_data, np.nan)
    skmap_bindings.readData(modis_month_data, n_threads, modis_files, range(len(modis_files)), x_off, y_off, x_size, y_size, bands_list, gdal_opts, -32768, np.nan)
    # NOTE(review): presumably maps MODIS NDVI from [-10000, 10000] onto the
    # same [0, 250] range used for the Landsat NDVI -- confirm against
    # skmap_bindings.offsetAndScale.
    offset = 10000
    scaling = 250 / 20000
    skmap_bindings.offsetAndScale(modis_month_data, n_threads, offset, scaling)
    modis_month_data_t = np.empty((modis_month_data.shape[1],modis_month_data.shape[0]), dtype=np.float32)
    skmap_bindings.transposeArray(modis_month_data, n_threads, modis_month_data_t)
    # Aggregate the monthly MODIS series by a factor of 2 so that it has one
    # entry per bimonthly period, like the Landsat series.
    agg_factor = 2
    n_aggr = int(np.ceil(float(modis_month_data_t.shape[1])/float(agg_factor)))
    modis_NDVI_t = np.empty((modis_month_data_t.shape[0], n_aggr), dtype=np.float32)
    modis_NDVI = np.empty((modis_NDVI_t.shape[1],modis_NDVI_t.shape[0]), dtype=np.float32)
    skmap_bindings.averageAggregate(modis_month_data_t, n_threads, modis_NDVI_t, agg_factor)
    skmap_bindings.transposeArray(modis_NDVI_t, n_threads, modis_NDVI)
    ttprint(f"Tile {tile} - Reading data: {(time.time() - start):.2f} segs")

    start = time.time()
    band_scaling = 0.004
    result_scaling = 125.
    result_offset = 125.
    # NOTE(review): presumably computes (nir - red) / (nir + red) from rows
    # [n_s, 2*n_s) and [0, n_s), writes it to rows [2*n_s, 3*n_s) after
    # applying result_scaling/result_offset and clips to [0, 250] -- confirm
    # against skmap_bindings.computeNormalizedDifference.
    skmap_bindings.computeNormalizedDifference(agg_data, n_threads,
                                range(n_s*1, n_s*2), range(n_s*0, n_s*1), range(n_s*2, n_s*3),
                                band_scaling, band_scaling, result_scaling, result_offset, [0., 250.])

    agg_NDVI = np.empty((n_s, x_size * y_size), dtype=np.float32)
    skmap_bindings.extractArrayRows(agg_data, n_threads, agg_NDVI, range(n_s*2, n_s*3))

    # del agg_data
    # gc.collect()
    ttprint(f"Tile {tile} - Compute NDVI: {(time.time() - start):.2f} segs")

    start = time.time()
    rec_NDVI = apllaySWAG(agg_NDVI, att_env, att_seas, n_imag_per_year, future_scaling)
    # Work on a copy of the NDVI rows so that the unmasked series stays
    # available for plotting.
    agg_NDVI_clean = np.empty((n_s, x_size * y_size), dtype=np.float32)
    skmap_bindings.extractArrayRows(agg_NDVI, n_threads, agg_NDVI_clean, range(n_s))
    # NOTE(review): presumably masks Landsat values that differ from MODIS by
    # more than diff_th, subject to count_th -- confirm against
    # skmap_bindings.maskDifference.
    skmap_bindings.maskDifference(agg_NDVI_clean, n_threads, diff_th, count_th, modis_NDVI)
    rec_NDVI_clean = apllaySWAG(agg_NDVI_clean, att_env, att_seas, n_imag_per_year, future_scaling)
    smoooth_NDVI_clean = apllaySG(rec_NDVI_clean, w_l, p_o)
    ttprint(f"Tile {tile} - Reconstructing with SWAG: {(time.time() - start):.2f} segs")

    # start = time.time()
    # data_to_save = (rec_NDVI, rec_NDVI_clean, smoooth_NDVI_clean, modis_NDVI)
    # data_names = ("SWAG", "SWAG_clean", "SWAG_SG_clean", "MODIS")
    # for data_tmp, data_name in zip(data_to_save, data_names):
    #     out_files = []
    #     for year in years:
    #         for bimonth in range(n_imag_per_year):
    #                 out_files.append(f'NDVI_{data_name}.{year}{month_start[bimonth]}_{year}{month_end[bimonth]}')
    #     skmap_bindings.writeByteData(data_tmp, n_threads, gdal_opts, in_files[0:len(out_files)], f'data/{tile}/out', out_files, range(len(out_files)),
    #         x_off, y_off, x_size, y_size, 255, compression_command)
    # ttprint(f"Tile {tile} - Saving data: {(time.time() - start):.2f} segs")

    data_to_plot = (agg_NDVI, rec_NDVI, agg_NDVI_clean, rec_NDVI_clean, smoooth_NDVI_clean, modis_NDVI)
    var_names = ("Aggr.", "Rec.", "Agg. clean", "Rec. clean", "Rec. SG clean", "MODIS")
    # for year in years:
    for year in [2022, 2018, 2017, 2013]:
    # crappy years: 2022, 2018, 2017, 2013
        # Row index of the first bimonthly image of `year`, derived from the
        # first configured year rather than a hard-coded final year.
        start_index = n_imag_per_year * (year - years[0])
        fig, axes = plt.subplots(len(data_to_plot), n_imag_per_year, figsize=(12, 2*len(data_to_plot)+0.5))
        for i, tmp_data in enumerate(data_to_plot):
            # View each row as an image; reshaped once per data set.
            images = tmp_data.reshape(tmp_data.shape[0], x_size, y_size)
            for j in range(n_imag_per_year):
                axes[i, j].imshow(images[start_index + j], cmap='gnuplot2', vmin=0, vmax=250)
                axes[i, j].set_title(f"{var_names[i]} - {month_start[j]}")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])

        fig.suptitle(f"NDVI {year}")
        plt.tight_layout()
        # fig.savefig(f'{tile}_NDVI_{year}.png', bbox_inches='tight')
        plt.show()
    # del agg_NDVI, rec_NDVI, agg_NDVI_clean, rec_NDVI_clean, smoooth_NDVI_clean, modis_NDVI, tmp_data
    # gc.collect()
    

[11:23:53] Processing tile 085W_52N
[11:23:53] Tile 085W_52N - Reading data: 0.10 segs
[11:23:53] Tile 085W_52N - Compute NDVI: 0.01 segs
[11:23:53] T1 segs
