# 04 - dataset balancing

In [None]:
# Importing libraries

import os
import random
import numpy as np
import rasterio
import matplotlib.pyplot as plt
from shutil import copy2, rmtree
from scipy.ndimage import zoom
from typing import Dict, List, Tuple, Union

In [None]:
# Utility functions

def get_water_percentage(dir_name: str) -> int:
    """
    Parse the water percentage encoded in a sample directory name.

    The name is split on underscores and the third field (index 2) is
    converted to an integer, e.g. ``"scene_patch_37"`` -> ``37``.

    Args:
        dir_name (str): Underscore-separated directory name.

    Returns:
        int: The integer value of the third underscore-separated field.

    Raises:
        IndexError: If the name has fewer than three underscore-separated fields.
        ValueError: If the third field is not a valid integer.
    """
    fields = dir_name.split('_')
    percentage_field = fields[2]
    return int(percentage_field)

def load_dataset(base_dir: str) -> Dict[str, List[Tuple[str, int]]]:
    """
    Load dataset from the base directory.

    Each sub-directory of ``base_dir`` is treated as a scene; each
    sub-directory of a scene is treated as a sample whose name encodes a
    water percentage (parsed by ``get_water_percentage``). Plain files at
    either level (e.g. ``.DS_Store``) are ignored instead of being parsed.

    Args:
        base_dir (str): The base directory containing the dataset.

    Returns:
        Dict[str, List[Tuple[str, int]]]: A dictionary where keys are scene names and values are lists of tuples
                                          with directory names and corresponding water percentages. Scenes and
                                          samples are listed in sorted (lexicographic) order.
    """
    dataset = {}
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # deterministic, which matters if it feeds random sampling (such as
    # balance_dataset) under a fixed seed.
    for scene in sorted(os.listdir(base_dir)):
        scene_path = os.path.join(base_dir, scene)
        if not os.path.isdir(scene_path):
            continue
        samples = []
        for dir_name in sorted(os.listdir(scene_path)):
            # Only sample directories follow the naming scheme; stray files
            # would otherwise make get_water_percentage raise.
            if not os.path.isdir(os.path.join(scene_path, dir_name)):
                continue
            samples.append((dir_name, get_water_percentage(dir_name)))
        dataset[scene] = samples
    return dataset

def bin_dataset(dataset: Dict[str, List[Tuple[str, int]]], bin_size: int) -> Dict[str, Dict[int, List[str]]]:
    """
    Group each scene's samples into water-percentage bins.

    A sample with percentage ``p`` lands in bin ``p // bin_size``. Bins are
    created lazily in the order they are first encountered, so empty bins
    never appear in the output.

    Args:
        dataset (Dict[str, List[Tuple[str, int]]]): Scene name -> list of
            ``(directory name, water percentage)`` pairs.
        bin_size (int): Width of each water-percentage bin.

    Returns:
        Dict[str, Dict[int, List[str]]]: Scene name -> {bin index -> list of
            directory names in that bin}.
    """
    binned = {}
    for scene_name, entries in dataset.items():
        scene_bins: Dict[int, List[str]] = {}
        for name, pct in entries:
            scene_bins.setdefault(pct // bin_size, []).append(name)
        binned[scene_name] = scene_bins
    return binned

def balance_dataset(
    binned_dataset: Dict[str, Dict[int, List[str]]],
    rng: Union[random.Random, None] = None,
) -> Dict[str, Dict[int, List[str]]]:
    """
    Balance each scene by capping over-populated water-percentage bins.

    For every scene the threshold is ``mean + std / 2`` of that scene's bin
    sizes (population standard deviation, as computed by ``np.std``). Any
    bin strictly larger than the threshold is randomly down-sampled to
    ``int(threshold)`` entries; all other bins are kept unchanged.

    Args:
        binned_dataset (Dict[str, Dict[int, List[str]]]): The binned dataset to balance.
        rng (random.Random | None): Optional random generator used for
            sampling. When ``None`` (default) the module-level ``random``
            functions are used, so ``random.seed`` still controls the result.
            Pass a seeded ``random.Random`` for reproducibility independent of
            global state.

    Returns:
        Dict[str, Dict[int, List[str]]]: A new dataset with the same scenes
            and bin indices. Every bin list is a fresh list, so mutating the
            result never modifies ``binned_dataset``.
    """
    sampler = random if rng is None else rng
    balanced_dataset = {}
    for scene, bins in binned_dataset.items():
        if not bins:
            # np.mean/np.std of an empty list emit RuntimeWarnings and return
            # nan; an empty scene has nothing to balance.
            balanced_dataset[scene] = {}
            continue

        counts = [len(dirs) for dirs in bins.values()]
        threshold = np.mean(counts) + np.std(counts) / 2
        cap = int(threshold)  # cap <= threshold < len(dirs) whenever we sample

        balanced_dataset[scene] = {
            bin_index: sampler.sample(dirs, cap) if len(dirs) > threshold else list(dirs)
            for bin_index, dirs in bins.items()
        }
    return balanced_dataset

def apply_augmentations(
    data: np.ndarray,
    augmentation: str,
    scale_factor: float = 1.2,
    shift: int = 10,
) -> np.ndarray:
    """
    Apply the specified augmentation to the given data.

    All augmentations act on axes 0 and 1 of ``data`` (``fliplr`` flips
    axis 1, ``flipud`` axis 0, the roll and ``rot90`` use axes (0, 1)); any
    further axes are carried along unchanged.

    NOTE(review): if ``data`` comes straight from rasterio's ``read()`` it
    may be band-first; confirm callers pass arrays whose first two axes are
    the spatial ones.

    Args:
        data (np.ndarray): The data to augment (at least 2-D).
        augmentation (str): The type of augmentation to apply:
            'hf' horizontal flip, 'vf' vertical flip, 'sc' zoom-in by
            ``scale_factor`` followed by a centre crop back to the input's
            size, 'tr' cyclic shift by ``shift`` along axes 0 and 1,
            'ro' 180-degree rotation.
        scale_factor (float): Zoom factor for 'sc'; must be >= 1 so the
            centre crop can restore the original size. Defaults to 1.2.
        shift (int): Number of elements to roll along axes 0 and 1 for 'tr'
            (values wrap around). Defaults to 10.

    Returns:
        np.ndarray: The augmented data, with the same shape as ``data``.

    Raises:
        ValueError: If ``augmentation`` is unknown, or if ``scale_factor`` < 1
            for 'sc'.
    """
    if augmentation == 'hf':
        return np.fliplr(data)
    elif augmentation == 'vf':
        return np.flipud(data)
    elif augmentation == 'sc':
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        height, width = data.shape[0], data.shape[1]
        # Zoom only the two axes the other augmentations treat as spatial;
        # trailing axes (if any) keep a factor of 1.
        zoom_factor = (scale_factor, scale_factor) + (1,) * (data.ndim - 2)
        scaled_data = zoom(data, zoom_factor, order=1)  # Bilinear interpolation
        # Centre-crop back to the input's own size rather than a fixed 256.
        start_x = (scaled_data.shape[0] - height) // 2
        start_y = (scaled_data.shape[1] - width) // 2
        return scaled_data[start_x:start_x + height, start_y:start_y + width]
    elif augmentation == 'tr':
        return np.roll(data, shift, axis=(0, 1))
    elif augmentation == 'ro':
        return np.rot90(data, 2)
    else:
        raise ValueError(f"Unknown augmentation type: {augmentation}")
