# Analyze Samples
## Imports

In [None]:
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics
from tqdm.auto import tqdm

from tlib import tgeo, tutils
from asos import settings, utils

%load_ext autoreload
%autoreload 2

## Setup Files

In [None]:
# Map each band name to its channel index; the indices follow the order of the names listed here.
channel_indices = {
    band: position
    for position, band in enumerate(
        ('B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12')
    )
}

Either use the files from the selected dataset(s)...

In [None]:
# Build the list of raw image files that belong to the selected dataset(s).
datasets = ['test']

csv = settings.load_file_infos()
file_infos = csv.df
file_infos = file_infos[file_infos['dataset'].isin(datasets)]
print(len(file_infos))

# edit file names (the index entries are turned back into raw file names)
files = [
    file.partition('/')[2]  # remove True / False folder in the beginning
    for file in file_infos.index
]
files = [file.rpartition('_')[0] + '.tif' for file in files]  # remove tile number in the end
print(len(files))

# Remove doubles. dict.fromkeys keeps the first occurrence of every name in its
# original order; a set would return the names in an arbitrary order that can
# change between runs, making the processing/plot order non-reproducible.
files = list(dict.fromkeys(files))
print(f'# files: {len(files)}')
files = [os.path.join(settings.data_folder_raw, file) for file in files]  # add file directory in front of files

batch_size = 1

... or use specific file(s) from a folder:

In [None]:
# Alternative file selection: take the files of the investigative data folder
# that pass the filter below (filter semantics are defined by files_from_folder).
regex_filter = '.tif'  # e.g. '.tif', 'wdpa-Ia', 'wdpa-V_6926.tif'
folder = settings.data_folder_investigative

files = tutils.files.files_from_folder(
    folder,
    whole_path=True,
    regex_filter=regex_filter,
)
print(f'# files: {len(files)}')

batch_size = 1

## Predict

In [None]:
# Predict sensitivity maps batch by batch, turn them into boolean segmentation
# maps (sign of the sensitivity) and collect them together with the WDPA masks.
threshold = 0  # 2e-6

plot = False
val_range = (0, 2**11)

asos = utils.load_asos()

all_seg_maps = []
all_wdpa_masks = []
files_sublists = tutils.lists.create_sublists(files, size=batch_size)
for files_ in tqdm(files_sublists, desc='batch', disable=plot):

    # sensitivity maps
    unet_maps = utils.predict(*files_, disable_tqdm=True)
    sens_maps = asos.predict_sensitivities(unet_maps, disable_tqdm=True)

    # Split into plain data and an explicit boolean mask. Indexing/assigning with
    # masked boolean arrays depends on how NumPy treats the values hidden under
    # the mask; plain arrays make the result explicit. getmaskarray also works
    # when sens_maps carries no mask at all.
    sens_data = np.ma.getdata(sens_maps)
    sens_mask = np.ma.getmaskarray(sens_maps)

    # segmentation maps: True where the sensitivity is positive, False otherwise;
    # masked where the sensitivity is masked or does not exceed the threshold
    seg_mask = sens_mask | (np.abs(sens_data) <= np.abs(threshold))
    seg_maps = np.ma.masked_array(sens_data > 0, mask=seg_mask)

    # get wdpa masks
    wdpa_masks = [tgeo.geotif.get_array(file=file) for file in utils.get_corresponding_files(files_, 'mask')]
    wdpa_masks = np.vstack(wdpa_masks)
    wdpa_masks = wdpa_masks.astype(bool)
    wdpa_masks = np.ma.masked_array(wdpa_masks, mask=seg_mask.copy())  # same mask as seg_maps

    # plot
    if plot:
        channels = (channel_indices['B2'], channel_indices['B3'], channel_indices['B4'])
        for index, file in enumerate(files_):
            fig, axs = plt.subplots(1, 4, figsize=(12, 3))

            # s2 image
            tgeo.geotif.plot_image(file=file, channels=channels, val_range=val_range, ax=axs[0])

            # wdpa mask
            axs[1].imshow(wdpa_masks.data[index], cmap=asos.cmap, clim=(0, 1))

            # sensitivity map; the color range is taken from unmasked values only
            # because np.quantile does not honor the mask of a masked array
            visible = np.abs(sens_data[index][~sens_mask[index]])
            cmax = np.quantile(visible, 0.98) if visible.size else 0.0
            axs[2].imshow(sens_maps[index], cmap=asos.cmap, clim=(-cmax, cmax))

            # segmentation map
            axs[3].imshow(seg_maps[index], cmap=asos.cmap, clim=(0, 1))

            for ax in axs:
                ax.axis(False)
            fig.tight_layout()
            plt.show()

    all_seg_maps.append(seg_maps)
    all_wdpa_masks.append(wdpa_masks)

# stacking an empty list fails with an unspecific message, so fail early and clearly
if not all_seg_maps:
    raise ValueError('no maps were predicted: `files` is empty')

all_seg_maps = np.ma.vstack(all_seg_maps)
all_wdpa_masks = np.ma.vstack(all_wdpa_masks)

seg_maps = all_seg_maps
wdpa_masks = all_wdpa_masks

## Metrics

In [None]:
# Number and share of pixels that are not masked in the segmentation maps.
# getmaskarray always yields a full boolean array, even if no mask is set.
n_unmasked_pixels = np.count_nonzero(~np.ma.getmaskarray(seg_maps))
print(f'unmasked pixels: {n_unmasked_pixels / seg_maps.size * 100:.1f} %')

In [None]:
# accuracy in percent: share of unmasked pixels whose predicted label equals the WDPA label
predicted_labels = seg_maps.data[~seg_maps.mask]
reference_labels = wdpa_masks.data[~wdpa_masks.mask]
np.count_nonzero(predicted_labels == reference_labels) / n_unmasked_pixels * 100

In [None]:
# confusion matrix over the unmasked pixels, normalized over all entries and scaled to percent
sklearn.metrics.confusion_matrix(
    y_pred=seg_maps.data[np.logical_not(seg_maps.mask)],
    y_true=wdpa_masks.data[np.logical_not(wdpa_masks.mask)],
    normalize='all',
) * 100

In [None]:
# intersection over union (per class, in percent), computed on unmasked pixels only
def _iou_percent(prediction, reference):
    """Return the intersection over union of two boolean arrays in percent.

    Returns nan when neither array contains a True value, i.e. when the class
    does not occur at all and the ratio is undefined.
    """
    union = np.count_nonzero(prediction | reference)
    if union == 0:
        return float('nan')
    return np.count_nonzero(prediction & reference) / union * 100


# np.count_nonzero looks at the raw data of a masked array, so select the valid
# pixels explicitly instead of relying on what is stored under the mask.
valid = ~(np.ma.getmaskarray(seg_maps) | np.ma.getmaskarray(wdpa_masks))
predicted = np.ma.getdata(seg_maps)[valid]
reference = np.ma.getdata(wdpa_masks)[valid]

iou_0 = _iou_percent(predicted == 0, reference == 0)
iou_1 = _iou_percent(predicted == 1, reference == 1)

print(
    f'IoU 0: {iou_0:.1f} %\n'
    f'IoU 1: {iou_1:.1f} %'
)