In [None]:
import os
import shutil
# Number of worker threads handed to skmap_bindings and to the numeric backends below.
n_threads = 96
# Allow Open MPI to start more processes than available slots: the tile loop
# spawns one child process per MODIS mosaic (MPI.COMM_SELF.Spawn with maxprocs=len(modis_mosaics)).
os.environ['OMPI_MCA_rmaps_base_oversubscribe'] = '1'
# Tell geopandas not to use the PyGEOS backend.
os.environ['USE_PYGEOS'] = '0'
# Location of the PROJ data files (conda install path).
os.environ['PROJ_LIB'] = '/opt/conda/share/proj/'
# Thread limits for numexpr / OpenMP / BLAS backends, all set to n_threads.
# They are set here, before numpy, bottleneck, osgeo, etc. are imported,
# because those libraries typically read them when their thread pools start.
os.environ['NUMEXPR_MAX_THREADS'] = f'{n_threads}'
os.environ['NUMEXPR_NUM_THREADS'] = f'{n_threads}'
os.environ['OMP_THREAD_LIMIT'] = f'{n_threads}'
os.environ["OMP_NUM_THREADS"] = f'{n_threads}'
os.environ["OPENBLAS_NUM_THREADS"] = f'{n_threads}' # shell equivalent: export OPENBLAS_NUM_THREADS=<n_threads>
os.environ["MKL_NUM_THREADS"] = f'{n_threads}' # shell equivalent: export MKL_NUM_THREADS=<n_threads>
os.environ["VECLIB_MAXIMUM_THREADS"] = f'{n_threads}' # shell equivalent: export VECLIB_MAXIMUM_THREADS=<n_threads>
import gc
from datetime import datetime
from osgeo import gdal, gdal_array
from pathlib import Path
from typing import Callable, Iterator, List,        Union
import bottleneck as bn
import geopandas as gpd
import numpy as np
import pandas as pd
import skmap_bindings
import tempfile
import time
import sys
import csv
from scipy.signal import savgol_coeffs
import numpy as np  # NOTE(review): duplicate of the numpy import above; harmless
from skmap.io import process
import matplotlib.pyplot as plt
import random
from mpi4py import MPI

# GDAL configuration options forwarded to skmap_bindings.readData / writeByteData.
gdal_opts = {
 # Use HTTP/1.0 for /vsicurl/ requests.
 'GDAL_HTTP_VERSION': '1.0',
 # Only let /vsicurl/ probe files with a .tif extension.
 'CPL_VSIL_CURL_ALLOWED_EXTENSIONS': '.tif',
}

# GeoTIFF creation options.
# NOTE(review): not referenced in the code shown; outputs are compressed via
# compression_command instead.
co = ['TILED=YES', 'BIGTIFF=YES', 'COMPRESS=DEFLATE', 'BLOCKXSIZE=1024', 'BLOCKYSIZE=1024']

# NOTE(review): not referenced in the code shown.
executor = None

def ttprint(*args, **kwargs):
    """Print ``args`` prefixed with a ``[HH:MM:SS]`` wall-clock timestamp.

    Keyword arguments are forwarded to :func:`print`. The timestamp is written
    to the same ``file`` as the message (stdout when none is given), and the
    output is flushed unless the caller explicitly passes ``flush=False``.
    """
    from datetime import datetime

    # Default to flushing, but let callers override instead of raising a
    # "multiple values for keyword argument 'flush'" TypeError.
    kwargs.setdefault('flush', True)
    # Send the prefix to the same stream as the message; otherwise a caller
    # passing file=sys.stderr gets the timestamp on stdout and the text elsewhere.
    print(f'[{datetime.now():%H:%M:%S}] ', end='', file=kwargs.get('file'))
    print(*args, **kwargs)

def make_tempdir(basedir='skmap', make_subdir = True):
    """Create (if needed) and return a scratch directory under ``TMP_DIR/basedir``.

    Args:
        basedir: Sub-directory of the temp root to use.
        make_subdir: When True, additionally create a new, uniquely named
            directory inside ``TMP_DIR/basedir`` and return that. It is created
            by :func:`tempfile.mkdtemp`, so it is private to the current user.

    Returns:
        pathlib.Path of an existing directory.

    The temp root is the module-level ``TMP_DIR`` if one is defined, otherwise
    the platform temp directory (``tempfile.gettempdir()``).
    """
    try:
        root = Path(TMP_DIR)
    except NameError:
        # TMP_DIR is not defined in this notebook; fall back to the system
        # temp directory instead of failing on every call.
        root = Path(tempfile.gettempdir())
    tempdir = root.joinpath(basedir)
    tempdir.mkdir(parents=True, exist_ok=True)
    if make_subdir:
        # mkdtemp creates a unique directory atomically, without opening a
        # throwaway temporary file just to borrow its random name.
        tempdir = Path(tempfile.mkdtemp(dir=tempdir))
    return tempdir

def make_tempfile(basedir='skmap', prefix='', suffix='', make_subdir = False):
    """Return a unique file path inside a scratch directory; the file is not created.

    Args:
        basedir: Scratch sub-directory, passed to :func:`make_tempdir`.
        prefix: Text placed before the random part of the file name.
        suffix: Text placed after it (e.g. ``'.tif'``).
        make_subdir: Forwarded to :func:`make_tempdir`.

    Returns:
        pathlib.Path ``<scratch dir>/<prefix><random><suffix>``.
    """
    tempdir = make_tempdir(basedir, make_subdir=make_subdir)
    # Borrow a random name from NamedTemporaryFile. The placeholder lives in the
    # default temp directory and the context manager closes (and deletes) it
    # right away, instead of leaving that to garbage collection.
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=suffix) as placeholder:
        name = Path(placeholder.name).name
    return tempdir.joinpath(name)
    
def get_SWAG_weights(att_env, att_seas, season_size, n_imag):
    """Build the SWAG weights for lags ``0 .. n_imag-1``.

    The base-10 logarithm of each weight is the sum of two linear ramps:

    * a seasonal triangle with period ``season_size``: it falls from 0 to
      ``-att_seas/10`` at half a season and climbs back to 0 at a full season.
      Lags at the same phase of the season therefore weigh the most.
    * an envelope falling from 0 by ``att_env/10`` over ``n_imag`` lags, which
      attenuates temporally distant images.

    The weights are ``10 ** (seasonal + envelope)``. The ``/10`` and ``10**``
    match a decibel convention for ``att_env`` / ``att_seas`` (assumption —
    confirm).

    Args:
        att_env: Total envelope attenuation.
        att_seas: Peak seasonal attenuation.
        season_size: Number of images per season (period of the triangle).
        n_imag: Number of lags / length of the result.

    Returns:
        float64 numpy array of shape ``(n_imag,)``; element 0 is 1.0.

    Raises:
        ZeroDivisionError: if ``season_size`` or ``n_imag`` is 0.
    """
    half_period = season_size / 2.0
    seas_slope = att_seas / 10 / half_period
    lags = np.arange(season_size)
    seasonal = np.where(lags <= half_period,
                        -seas_slope * lags,
                        seas_slope * (lags - half_period) - att_seas / 10)
    env_slope = att_env / 10 / n_imag
    envelope = -env_slope * np.arange(n_imag)
    # Tile the seasonal pattern over all lags and combine once (the previous
    # version recomputed this inside the envelope loop, n_imag times).
    return 10.0 ** (np.resize(seasonal, n_imag) + envelope)


# --- Processing configuration ----------------------------------------------
# Years of Landsat composites to process (1997..2022 inclusive).
years = range(1997,2023)
# Do not change the sizes: with the current setting the warping needs the size of the entire tile
x_size, y_size = (4004, 4004) 
x_off, y_off = (0,0)
# Pixels per tile; each band/time step is stored as one row of n_pix values.
n_pix = x_size * y_size
# Nodata value used when reading the inputs and when writing the byte outputs.
no_data = 255
# Band(s) read from each input file by skmap_bindings.readData.
bands_list = [1,]
# Six bimonthly composites per year (see month_start / month_end below).
n_imag_per_year = 6
# NOTE(review): not referenced in the code shown.
out_index_offset = 0
n_years = len(years)
# Time steps per band: band b occupies rows [n_s*b, n_s*(b+1)) of agg_data.
n_s = n_years*n_imag_per_year
# Backend name passed to skmap_bindings.applySircle.
backend = 'Matrix'

# Landsat band file prefixes. Their order fixes the row group of each band in
# agg_data (red = index 2, nir = index 3).
bands_prefix = ['blue_glad',
                'green_glad',
                'red_glad',
                'nir_glad',
                'swir1_glad',
                'swir2_glad',
                'thermal_glad']
n_band = len(bands_prefix)
nir_range = range(n_s*3, n_s*4)
red_range = range(n_s*2, n_s*3)
# Extra row group (index n_band == 7) where the Landsat NDVI is written.
NDVI_range = range(n_s*7, n_s*8)
# Passed to skmap_bindings.offsetAndScale for the aggregated MODIS NDVI.
# Presumably (x + 10000) * 250/20000, mapping [-10000, 10000] onto [0, 250] —
# confirm against skmap_bindings.
MODIS_offset = 10000.
MODIS_scaling = 250. / 20000.
# Number of consecutive monthly MODIS layers averaged per output step
# (12 months/year -> 6 steps/year).
MODIS_agg_factor = 2
# Landsat NDVI parameters for computeNormalizedDifference. Presumably the bands
# are scaled by 0.004 and NDVI is mapped to 125*NDVI + 125, then clipped to
# [0, 250] — confirm.
band_scaling = 0.004
result_scaling = 125.
result_offset = 125.
# Suffix of the input Landsat composite file names.
file_ending = '_go_epsg.4326_v20230908.tif'
# Last day of each month (February fixed at 28), used in MODIS mosaic file names.
m_end = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
# MMDD start/end of the six bimonthly composite periods, used in Landsat input
# and SWAG output file names.
month_start = ['0101','0301','0501','0701','0901','1101']
month_end = ['0228' ,'0430' ,'0630' ,'0831' ,'1031' ,'1231']
# gdal_translate command handed to skmap_bindings.writeByteData for the outputs.
compression_command = f"gdal_translate -a_nodata {no_data} -co COMPRESS=deflate -co ZLEVEL=9 -co TILED=TRUE -co BLOCKXSIZE=1024 -co BLOCKYSIZE=1024"
# SWAG weight parameters: envelope attenuation, seasonal attenuation, and the
# multiplier applied to the future-side weights.
att_env, att_seas, future_scaling = (20.0, 40.0, 0.1)
# NOTE(review): not referenced in the code shown (savgol_coeffs is imported, so
# possibly a Savitzky-Golay window length / polynomial order).
w_l, p_o = (2, 2)
# Thresholds for skmap_bindings.maskDifference (Landsat vs MODIS NDVI); presumably
# an NDVI difference threshold and a maximum count of differing time steps — confirm.
diff_th, count_th = (35, int(n_s/5))
# tiles = ['009E_04N', '009E_51N', '013E_61N', '050W_07S', '085W_52N', '091W_37N', '115E_03S', '127E_42N']
# tiles = ['013E_61N', '013E_60N', '013E_62N', '012E_61N', '012E_60N', '012E_62N', '014E_61N', '014E_60N', '014E_62N'] # Norway 
tiles = ['013E_61N', '013E_60N',]
# --- Per-tile processing -----------------------------------------------------
# This process is the single parent that spawns the MODIS-warping children for
# every tile, so it must run as a one-rank MPI job. Check once, before any
# expensive reads.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
if size != 1:
    print("This example requires exactly one parent process.")
    comm.Abort(1)

# The SWAG weights depend only on the configuration, so build them once.
swag_weights = get_SWAG_weights(att_env, att_seas, n_imag_per_year, n_s)
w_p = (swag_weights[1:][::-1]).astype(np.float32)  # lag weights 1.., reversed
w_f = (swag_weights[1:]).astype(np.float32)*future_scaling
w_0 = 1.0

for tile in tiles:

    ttprint(f"Processing tile {tile}")
    # Remote Landsat composites, ordered band -> year -> bimonth, so band b
    # fills rows [n_s*b, n_s*(b+1)) of agg_data. A random host in
    # 192.168.49.30-43 is picked per file (presumably to spread the load).
    landsat_files_in = []
    for b in bands_prefix:
        for year in years:
            for bimonth in range(n_imag_per_year):
                landsat_files_in.append(f'/vsicurl/http://192.168.49.{random.randint(30,43)}:8333/prod-landsat-ard2/{tile}/agg/' + \
                                f'{b}.ard2_m_30m_s_{year}{month_start[bimonth]}_{year}{month_end[bimonth]}{file_ending}')
    # Monthly MODIS NDVI mosaics, 12 per year; years before 2000 reuse the 2000 mosaics.
    modis_mosaics = []
    for year in years:
        modis_year = max(year, 2000)
        for m in range(12):
            modis_mosaics.append(f'/vsicurl/http://192.168.49.{random.randint(30,43)}:8333/global/veg/' +
                    f'veg_ndvi_mod13q1.v061_p50_250m_s0..0cm_{modis_year}.{str(m+1).zfill(2)}.01..{modis_year}.{str(m+1).zfill(2)}.{m_end[m]}_v2.tif' )

    start = time.time()
    # n_band row groups of Landsat data plus one spare group for NDVI.
    agg_data = np.empty((n_s*(n_band+1), x_size * y_size), dtype=np.float32)
    skmap_bindings.readData(agg_data, n_threads, landsat_files_in, range(len(landsat_files_in)), x_off, y_off, x_size, y_size, bands_list, gdal_opts, no_data, np.nan)
    ttprint(f"Tile {tile} - Reading Landsat data: {(time.time() - start):.2f} segs")

    start = time.time()
    modis_month_data = np.empty((n_s*2,agg_data.shape[1]), dtype=np.float32)
    n_proc = len(modis_mosaics)
    # One child per MODIS mosaic; the parent expects the layer for mosaic i
    # with message tag i.
    intercomm = MPI.COMM_SELF.Spawn(
        sys.executable, args=['child.py', ",".join(landsat_files_in[0:n_proc]), ",".join(modis_mosaics), str(n_threads), str(n_pix)], maxprocs=n_proc)
    received_array = np.empty(n_pix, dtype=np.float32)
    for i in range(n_proc):
        intercomm.Recv(received_array, source=MPI.ANY_SOURCE, tag=i)
        skmap_bindings.copyVecInMatrixRow(modis_month_data, n_threads, received_array, i)
    # Release this tile's children now, so one intercommunicator is not left
    # connected per tile until Finalize. Like the single Disconnect at the end
    # of the script did, this relies on child.py disconnecting from its parent too.
    intercomm.Disconnect()
    ttprint(f"Tile {tile} - Warping and scaling MODIS data: {(time.time() - start):.2f} segs")

    start = time.time()
    modis_month_data_t = np.empty((modis_month_data.shape[1], modis_month_data.shape[0]), dtype=np.float32)
    # Replace the -3000 value (presumably the MODIS fill value) with NaN.
    skmap_bindings.swapRowsValues(modis_month_data, n_threads, range(modis_month_data.shape[0]), -3000, np.nan)
    skmap_bindings.transposeArray(modis_month_data, n_threads, modis_month_data_t)
    n_aggr = int(np.ceil(float(modis_month_data_t.shape[1])/float(MODIS_agg_factor)))
    modis_NDVI_t = np.empty((modis_month_data_t.shape[0], n_aggr), dtype=np.float32)
    skmap_bindings.averageAggregate(modis_month_data_t, n_threads, modis_NDVI_t, MODIS_agg_factor)
    skmap_bindings.offsetAndScale(modis_NDVI_t, n_threads, MODIS_offset, MODIS_scaling)
    # The raw monthly buffers are not used past this point; free them before
    # the large reconstruction buffers below are allocated.
    del modis_month_data, modis_month_data_t
    ttprint(f"Tile {tile} - Masking, aggregating and scaling MODIS data: {(time.time() - start):.2f} segs")

    start = time.time()
    # Write the Landsat NDVI into the spare row group NDVI_range of agg_data.
    skmap_bindings.computeNormalizedDifference(agg_data, n_threads,
                                nir_range, red_range, NDVI_range,
                                band_scaling, band_scaling, result_scaling, result_offset, [0., 250.])
    agg_NDVI = np.empty((n_s, x_size * y_size), dtype=np.float32)
    skmap_bindings.extractArrayRows(agg_data, n_threads, agg_NDVI, NDVI_range)
    ttprint(f"Tile {tile} - Compute Landsat NDVI: {(time.time() - start):.2f} segs")

    start = time.time()
    agg_NDVI_t = np.empty((agg_NDVI.shape[1], agg_NDVI.shape[0]), dtype=np.float32)
    skmap_bindings.transposeArray(agg_NDVI, n_threads, agg_NDVI_t)
    mask_t = np.empty(agg_NDVI_t.shape, dtype=np.float32)
    skmap_bindings.maskDifference(agg_NDVI_t, n_threads, diff_th, count_th, modis_NDVI_t, mask_t)
    ttprint(f"Tile {tile} - Get artifact mask from NDVI: {(time.time() - start):.2f} segs")

    start = time.time()
    clean_data = np.empty((n_s, x_size * y_size), dtype=np.float32)
    clean_data_t = np.empty((x_size * y_size, n_s), dtype=np.float32)
    rec_data = np.empty((n_s, x_size * y_size), dtype=np.float32)
    rec_data_t = np.empty((x_size * y_size, n_s), dtype=np.float32)
    out_dir = f'data/{tile}'
    os.makedirs(out_dir, exist_ok=True)
    # Row groups 0..n_band-1 are the Landsat bands; b == n_band is the NDVI group.
    for b in range(n_band+1):
        band_label = "ndvi_glad" if b == n_band else bands_prefix[b]
        skmap_bindings.extractArrayRows(agg_data, n_threads, clean_data, range(n_s*b, n_s*(b+1)))
        skmap_bindings.transposeArray(clean_data, n_threads, clean_data_t)
        # Set samples selected by the artifact mask to NaN before reconstruction.
        skmap_bindings.maskData(clean_data_t, n_threads, range(clean_data_t.shape[0]), mask_t, 1., np.nan)
        skmap_bindings.applySircle(clean_data_t, n_threads, rec_data_t, 0, w_0, w_p, w_f, True, "v2", backend)
        skmap_bindings.transposeArray(rec_data_t, n_threads, rec_data)
        out_files = [
            f'{band_label}.SWAG.ard2_m_30m_s_{year}{month_start[bimonth]}_{year}{month_end[bimonth]}_go_epsg.4326_v20240621'
            for year in years
            for bimonth in range(n_imag_per_year)
        ]
        # NOTE(review): no S3 destinations are passed. The list of S3 targets was
        # previously built from out_files while it was still empty, so it was
        # always [] (and its f-string referenced an undefined `s3_prefix`).
        # To enable uploads, define s3_prefix and build one entry per file, e.g.
        # [f'gaia/prod-landsat-ard2/{s3_prefix}/{tile}/swag' for _ in out_files].
        out_s3 = []
        skmap_bindings.writeByteData(rec_data, n_threads, gdal_opts, landsat_files_in[0:len(out_files)], out_dir, out_files, range(len(out_files)),
            x_off, y_off, x_size, y_size, no_data, compression_command, out_s3)

    ttprint(f"Tile {tile} - Removing artefacts, reconstructing with SWAG and saving: {(time.time() - start):.2f} segs")

    # Quick-look plot for the last reconstructed row group (NDVI).
    data_to_plot = (agg_NDVI_t.transpose(),  modis_NDVI_t.transpose(),  mask_t.transpose(), rec_data)
    var_names = ("Agg. NDVI", "MODIS", "Mask", "Rec NDVI")
    last_year = years[-1]
    # Years with visible artifacts: 2022, 2018, 2017, 2013.
    for year in [last_year,]:
        fig, axes = plt.subplots(len(data_to_plot), n_imag_per_year, figsize=(12, 2.3*len(data_to_plot)+0.5))
        # Index of the first time step of `year` (the series ends at last_year).
        start_index = n_s - n_imag_per_year - n_imag_per_year * (last_year - year)
        for i, tmp_data in enumerate(data_to_plot):
            # The mask is shown on [0, 1]; the NDVI-like layers on [0, 250].
            vmax = 1 if i == 2 else 250
            for j in range(n_imag_per_year):
                # Reshape only the selected row: reshaping the whole (transposed,
                # non-contiguous) array would copy all of it for every frame.
                frame = tmp_data[start_index + j].reshape(x_size, y_size)
                axes[i, j].imshow(frame, cmap='gnuplot2', vmin=0, vmax=vmax)
                axes[i, j].set_title(f"{var_names[i]} - {month_start[j]}")
                axes[i, j].set_xticks([])
                axes[i, j].set_yticks([])

        fig.suptitle(f"NDVI {year}")
        plt.tight_layout()
        # fig.savefig(f'{tile}_NDVI_{year}.png', bbox_inches='tight')
        plt.show()

MPI.Finalize()
    

[10:40:09] Processing tile 013E_61N
[10:40:16] Tile 013E_61N - Reading Landsat data: 6.93 segs
[10:41:37] Tile 013E_61N - Warping and scaling MODIS data: 81.93 segs
[10:41:42] Tile 013E_61N - Masking, aggregating and scaling MODIS data: 4.99 segs
[10:41:44] Tile 013E_61N - Compute Landsat NDVI: 1.90 segs
[10:41:54] Tile 013E_61N - Get artifact mask from NDVI: 9.15 segs
