In [None]:
import shutil
import rasterio
import numpy as np
import rasterio.warp
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import tqdm

In [None]:
import sys
# Add the parent directory to the module search path so that modules located
# there can be imported (presumably the project-local modules imported in the
# following cell -- confirm).
sys.path.append('..')

In [None]:
import mysecrets
import fetch_from_cluster
import create_stack
import rsutils.modify_images as modify_images
import rsutils.utils

In [None]:
# Identifier of the grid under comparison; interpolated into the paths below.
test_grid_id = 'grid25'

# Zipped stack on the remote cluster, and the local folder the archive is
# unpacked into before the stack is loaded.
remotepath = f'/gpfs/data1/cmongp2/sasirajann/fetch_satdata/data/gee_compare/{test_grid_id}.zip'
stack_folderpath = f'../data/{test_grid_id}_stack'

In [None]:
# Fetch the zipped stack from the cluster; the helper returns the local path
# of the downloaded file.
zip_filepath = fetch_from_cluster.download_file_from_cluster(
    sshcreds = mysecrets.SSH_UMD_SASIRAJANN,
    remotepath = remotepath,
    download_folderpath = '../data/cluster_files',
    # overwrite = True,
)

# Extract the archive into the stack folder (runs on every execution of the cell).
shutil.unpack_archive(zip_filepath, stack_folderpath)

# `bands` is later indexed as bands[ts, :, :, band]; `metadata` provides at
# least the 'bands' and 'geotiff_metadata' entries used further down.
bands, metadata = create_stack.load_stack(stack_folderpath)

# Band name -> position of that band in metadata['bands'] (used as the index
# into the last axis of the stacks).
band_indices = {bandname:index for index, bandname in enumerate(metadata['bands'])}

In [None]:
def gee_tif_to_ndarray(
    gee_tif_filepath:str,
    bands:list[str],
):
    """
    Load a multi-channel GEE GeoTIFF and rearrange it into a time-major stack.

    Every channel description in the file is expected to look like
    '<timestep>_<band>' (exactly one underscore, integer time step), e.g.
    '0_B8' or '3_B8A'. Band names that are not three characters long are
    rewritten as 'B0' + their last character (so 'B8' becomes 'B08').

    Parameters
    ----------
    gee_tif_filepath : str
        Path of the GeoTIFF to read.
    bands : list[str]
        Normalised band names (e.g. 'B08') to extract, in the order they
        should appear along the last axis of the result.

    Returns
    -------
    (np.ndarray, dict)
        The stacked array and a copy of the dataset's `meta`. The array is
        built by stacking the requested bands on a new last axis and the time
        steps on a new first axis, i.e. (time, rows, cols, band) assuming each
        channel read from the file is 2-D. Time steps are ordered by
        ascending integer value.

    Raises
    ------
    KeyError
        If a requested band is missing for one of the time steps.
    ValueError
        If a channel description does not split into exactly two parts, or
        the file has no channels.
    """
    with rasterio.open(gee_tif_filepath) as src:
        raster_ndarray = src.read()
        raster_meta = src.meta.copy()
        raster_desc = src.descriptions

    # bandname -> {timestep -> channel array}
    band_ts_array = dict()
    unique_ts = set()
    for index, channel_name in enumerate(raster_desc):
        ts, bandname = channel_name.split('_')
        ts = int(ts)
        bandname = bandname if len(bandname) == 3 else 'B0' + bandname[-1]
        unique_ts.add(ts)
        band_ts_array.setdefault(bandname, {})[ts] = raster_ndarray[index]

    raster_ndarray_reformatted = []

    # Iterate over the time steps that are actually present (sorted) instead
    # of assuming they are exactly 0..n-1; identical for 0-based contiguous
    # time steps, and no KeyError for any other numbering.
    for ts in sorted(unique_ts):
        bandstack = []
        for bandname in bands:
            bandstack.append(band_ts_array[bandname][ts])
        raster_ndarray_reformatted.append(np.stack(bandstack, axis=-1))
        del bandstack

    raster_ndarray_reformatted = np.stack(raster_ndarray_reformatted, axis=0)

    return raster_ndarray_reformatted, raster_meta

In [None]:
def write_all_tifs_out(
    bands:np.ndarray,
    bandnames:list[str],
    out_meta:dict,
    folderpath:str,
):
    """
    Write every (time step, band) slice of a 4-D stack to its own GeoTIFF.

    Parameters
    ----------
    bands : np.ndarray
        4-D array indexed as bands[ts, :, :, band].
    bandnames : list[str]
        Band names; the position of a name is the index used on the last axis.
    out_meta : dict
        Keyword arguments forwarded unchanged to `rasterio.open(..., 'w')`
        for every file.
    folderpath : str
        Output folder, created if it does not exist.

    Returns
    -------
    dict
        Nested mapping {ts: {bandname: filepath}} of the files written, with
        files named '<bandname>_<ts zero-padded to 3 digits>.tif'.
    """
    n_timesteps, _rows, _cols, _channels = bands.shape
    position_of = {name: position for position, name in enumerate(bandnames)}

    written = {}

    os.makedirs(folderpath, exist_ok=True)

    for ts in range(n_timesteps):
        per_band = {}
        written[ts] = per_band
        ts_suffix = str(ts).zfill(3)

        for name, position in position_of.items():
            target = os.path.join(folderpath, f'{name}_{ts_suffix}.tif')

            with rasterio.open(target, 'w', **out_meta) as dst:
                # Add a leading axis of length 1 so a single band is written.
                dst.write(bands[ts, :, :, position][np.newaxis, ...])

            per_band[name] = target

    return written

In [None]:
# Dump the S3 stack as one GeoTIFF per (time step, band). The returned mapping
# {ts: {bandname: filepath}} is used below to pick a reference raster for
# resampling the GEE exports.
s3_filepaths = write_all_tifs_out(
    bands = bands, 
    bandnames = metadata['bands'],
    out_meta = metadata['geotiff_metadata'],
    folderpath = f'../data/gee_compare/s3_output/{test_grid_id}',
)

In [None]:
# GEE export name -> stack rearranged by gee_tif_to_ndarray, one entry per export.
raster_ndarrays = {}

for gee_filename in tqdm.tqdm([
    'S2L1C_onlyCloudProb_grid_25',
    'S2L1C_defaultSettings_grid_25',
    'S2L2A_onlyCloudProb_grid_25',
    'S2L2A_defaultSettings_grid_25',
]):
    gee_filepath = f'../data/gee_compare/{gee_filename}.tif'

    # Path of the resampled copy, derived from the source path with a
    # 'resampled_' prefix.
    resampled_raster_filepath = rsutils.utils.modify_filepath(
        filepath=gee_filepath,
        prefix='resampled_',
    )

    # Resample the GEE export using the first S3 B08 raster as reference
    # (presumably onto the same grid as the S3 stack -- confirm in rsutils).
    modify_images.resample_by_ref(
        src_filepath=gee_filepath,
        dst_filepath=resampled_raster_filepath,
        ref_filepath=s3_filepaths[0]['B08'],
    )

    # Store the stack directly under its export name; the metadata of the
    # last export read stays available as `raster_meta`.
    raster_ndarrays[gee_filename], raster_meta = gee_tif_to_ndarray(
        gee_tif_filepath=resampled_raster_filepath,
        bands=metadata['bands'],
    )

In [None]:
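# Disabled sanity check: per-band RMSE / MAE between the S3 stack and one GEE raster.
# NOTE: `raster_ndarray` is not available after the loading loop, so the raster is
# looked up in `raster_ndarrays`; the 1000 offset appears to mirror the -1000 bias
# applied to the S3 values when plotting -- confirm before re-enabling.
# print(gee_filename)
# print('---')
# for index, band in enumerate(metadata['bands']):
#     rmse = np.sqrt(np.nanmean((bands[:,:,:,index] - raster_ndarrays[gee_filename][:,:,:,index] - 1000) ** 2))
#     mae = np.nanmean(np.abs(bands[:,:,:,index] - raster_ndarrays[gee_filename][:,:,:,index] - 1000))
#     print(band, f'rmse={round(rmse, 2)}, mae={round(mae, 2)}')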

In [None]:
def plot_band_hist(
    bands_list:list[np.ndarray],
    bands_names:list[str],
    bands_nodata:list,
    bands_biases:list,
    id_name:str,
    val_name:str,
    scale:float = 5,
    aspect_ratio:int = 1,
    title:str = None,
    bins = 'auto',
):
    """
    Plot overlaid histograms (with KDE) of several arrays on a single axis.

    For each array the nodata values are removed, the bias is added and the
    result is cast to int; all arrays are then drawn with `sns.histplot`,
    coloured by their name. The input list and its arrays are not modified.

    Parameters
    ----------
    bands_list : list[np.ndarray]
        Arrays whose value distributions are compared (any shape).
    bands_names : list[str]
        Legend label for each array.
    bands_nodata : list
        Nodata marker per array: None keeps every value, NaN drops NaN
        values, any other value drops the entries equal to it.
    bands_biases : list
        Offset added to each array after nodata removal.
    id_name : str
        Name of the label column (used as `hue`, i.e. the legend title).
    val_name : str
        Name of the value column (used as `x`, i.e. the x-axis label).
    scale : float
        Figure height; the width is `scale * aspect_ratio`.
    aspect_ratio : int
        Width / height ratio of the figure.
    title : str, optional
        Axis title; omitted when None.
    bins
        Forwarded to `sns.histplot`.

    Returns
    -------
    The figure created by `plt.subplots`; closing it is left to the caller.
    """
    fig, ax = plt.subplots(figsize=(scale*aspect_ratio, scale))

    value_chunks = []
    label_chunks = []

    for i in range(len(bands_list)):
        # Work on a local reference: filtering must not overwrite the
        # caller's list entries.
        values = bands_list[i]
        nodata = bands_nodata[i]

        if nodata is not None:
            if np.isnan(nodata):
                values = values[~np.isnan(values)]
            else:
                values = values[values != nodata]

        values = (values.ravel() + bands_biases[i]).astype(int)

        # Keep everything as numpy arrays rather than Python lists with one
        # object per pixel.
        value_chunks.append(values)
        label_chunks.append(np.repeat(bands_names[i], values.size))

    data = {
        val_name: np.concatenate(value_chunks) if value_chunks else np.array([], dtype=int),
        id_name: np.concatenate(label_chunks) if label_chunks else np.array([], dtype=str),
    }

    g = sns.histplot(
        ax = ax,
        data = data,
        x = val_name,
        hue = id_name,
        stat = 'count',
        bins = bins,
        kde = True,
    )

    if title is not None:
        _ = g.set_title(title)

    return fig

In [None]:
# When False, bands whose plot file already exists are skipped.
overwrite = True

plots_folderpath = '../data/gee_compare/plots/'
os.makedirs(plots_folderpath, exist_ok=True)

for band_name in tqdm.tqdm(metadata['bands']):
    plot_filepath = os.path.join(plots_folderpath, f'{band_name}.png')

    if os.path.exists(plot_filepath) and not overwrite:
        continue

    # Position of this band on the last axis of every stack.
    band_index = band_indices[band_name]

    fig = plot_band_hist(
        bands_list = [
            bands[:,:,:,band_index],
            raster_ndarrays['S2L1C_onlyCloudProb_grid_25'][:,:,:,band_index],
            raster_ndarrays['S2L1C_defaultSettings_grid_25'][:,:,:,band_index],
            raster_ndarrays['S2L2A_onlyCloudProb_grid_25'][:,:,:,band_index],
            raster_ndarrays['S2L2A_defaultSettings_grid_25'][:,:,:,band_index],
        ],
        bands_names = [
            'S3, S2L1C',
            'GEE, S2L1C, onlyCloudProb',
            'GEE, S2L1C, defaultSettings',
            'GEE, S2L2A, onlyCloudProb',
            'GEE, S2L2A, defaultSettings',
        ],
        # The S3 stack marks nodata with 0; the GEE rasters use NaN.
        bands_nodata = [
            0, np.nan, np.nan, np.nan, np.nan,
        ],
        # Only the S3 values are shifted (by -1000) before comparison.
        bands_biases =  [
            -1000, 0, 0, 0, 0
        ],

        val_name = band_name,
        id_name = 'source',

        bins = 100,

        title = f'GEE - S3 comparison, grid25, {band_name}',

        aspect_ratio = 2,
    )

    fig.savefig(plot_filepath, bbox_inches='tight')

    # Release the figure once it is on disk; otherwise one open figure per
    # band accumulates in memory for the lifetime of the kernel. Note that a
    # closed figure is no longer rendered inline by the notebook.
    plt.close(fig)
