# Document for the retrieval of the correct Interquartile range for the dataset

In this document we provide the code to identify the correct interquartile range for both the S1 SAR and the DEM images downloaded for the dataset.   
Since the values can be calculated once and for all, the code has been saved in notebook format, so that the charts and the results are easy to retrieve.

In [None]:
from pathlib import Path

import numpy as np
from plotille import histogram
from tqdm import tqdm
from glob import glob
import rasterio
from typing import Tuple

In [None]:
def imread(path: Path, channels_first: bool = True) -> np.ndarray:
    """Load the full content of a GeoTIFF into memory, closing the file afterwards.

    Args:
        path (Path): location of the geoTIFF image on disk
        channels_first (bool, optional): when True the array is returned exactly as
            rasterio reads it; when False its axes are reordered with
            ``transpose(1, 2, 0)``. Defaults to True.

    Returns:
        np.ndarray: the pixel data of the image
    """
    # The context manager guarantees the dataset handle is released before returning.
    with rasterio.open(str(path), mode="r", driver="GTiff") as dataset:
        pixels = dataset.read()

    if channels_first:
        return pixels
    return pixels.transpose(1, 2, 0)

def getStats(data: np.ndarray) -> np.ndarray:
    """Print a text histogram of the data and return its first and third quartiles.

    Args:
        data (np.ndarray): values to analyse

    Returns:
        np.ndarray: two-element array holding the 25th and 75th percentiles, in
            that order (it can be unpacked like a tuple)
    """
    # plotille renders the histogram as text, so it shows up directly in the cell output.
    print(histogram(data))
    return np.percentile(data, [25, 75])

def find_IQR(dataset_path: str, sample_step: int = 100):
    """Print a histogram and the 25th/75th percentiles of the VV, VH and DEM values.

    One training tile every ``sample_step`` is read together with its DEM and mask;
    pixels whose mask value is 255 are discarded and the remaining values are
    pooled per channel before the statistics are computed.

    Args:
        dataset_path (str): root folder of the dataset; ``.tif`` files are looked up
            in its ``train/sar``, ``train/dem`` and ``train/mask`` sub-folders
        sample_step (int, optional): distance between two sampled tiles in the
            sorted file list. Defaults to 100.

    Raises:
        AssertionError: if the three folders do not hold the same number of files.
        ValueError: if ``sample_step`` is smaller than 1 or no valid pixel is found.
    """
    if sample_step < 1:
        raise ValueError(f'sample_step must be >= 1, got {sample_step}')

    # glob gives no ordering guarantee, so the lists are sorted to make the
    # index-based pairing below deterministic.
    # NOTE(review): this assumes matching tiles share the same file name in the
    # three folders - confirm against the dataset layout.
    sar_files = sorted(glob(dataset_path + '/train/sar/*.tif'))
    dem_files = sorted(glob(dataset_path + '/train/dem/*.tif'))
    mask_files = sorted(glob(dataset_path + '/train/mask/*.tif'))

    assert len(sar_files) == len(dem_files), f'Number of files not matching, SAR: {len(sar_files)}, DEM: {len(dem_files)}'
    assert len(sar_files) == len(mask_files), f'Number of files not matching, SAR: {len(sar_files)}, MASK: {len(mask_files)}'

    # Collect the per-tile values in lists and concatenate once at the end: this
    # avoids re-copying the growing arrays at every step and avoids seeding the
    # sample with an artificial 0 value that would bias the percentiles.
    vv_chunks = []
    vh_chunks = []
    dem_chunks = []

    # Take only a sample
    for i in tqdm(range(0, len(sar_files), sample_step)):
        sar = imread(Path(sar_files[i]))
        dem = imread(Path(dem_files[i]))
        mask = imread(Path(mask_files[i]))

        # 255 marks the pixels to ignore; the mask is read as (1, H, W)
        valid = mask.squeeze(0) != 255

        # presumably SAR band 0 is VV and band 1 is VH - verify against the data
        vv_chunks.append(sar[0][valid])
        vh_chunks.append(sar[1][valid])
        dem_chunks.append(dem[0][valid])

    if not vv_chunks or sum(chunk.size for chunk in vv_chunks) == 0:
        raise ValueError(f'No valid pixels found in the sampled tiles of {dataset_path}')

    samples = (
        ('VV', np.concatenate(vv_chunks, axis=0)),
        ('VH', np.concatenate(vh_chunks, axis=0)),
        ('DEM', np.concatenate(dem_chunks, axis=0)),
    )
    for name, values in samples:
        print(f'Histogram {name}:')
        # getStats returns the 25th and 75th percentiles, so label them as such
        print('25,75 percentiles:')
        print(*getStats(values))