In [None]:
import os
from pathlib import Path
from pydantic import BaseSettings
from matplotlib import  pyplot as plt

import numpy as np
from tqdm import tqdm
from scipy.stats import iqr
from torch.utils.data import DataLoader

In [None]:
class StatsConfig(BaseSettings):
    """Settings for this statistics notebook, loaded by pydantic ``BaseSettings``.

    Attributes:
        data_processed: path handed to the dataset as its ``path`` argument;
            it has no default, so it must be supplied externally.
        subset: name of the dataset split, ``"train"`` unless overridden.
    """

    data_processed: Path
    subset: str = "train"

    class Config:
        # Values may also be read from a UTF-8 encoded ".env" file.
        env_file_encoding = "utf-8"
        env_file = ".env"

In [None]:
# Switch the working directory to the parent of the directory the kernel was
# started in (presumably the project root, so that relative paths such as the
# ".env" file resolve from there -- confirm against the repository layout).
#
# The starting directory is remembered the first time this cell runs, which
# makes the cell idempotent: re-running it keeps pointing at the same parent
# instead of climbing one directory higher on every execution.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path(os.getcwd())
cwd = NOTEBOOK_DIR
os.chdir(str(cwd.parent))
print(os.getcwd())

In [None]:
# Build the settings object; no arguments are passed, so every field value
# comes from the settings sources pydantic consults (environment / ".env").
cfg = StatsConfig()
# Bare expression: shows the loaded configuration as the cell output.
cfg

In [None]:
import importlib
from floods.datasets import flood
from floods import prepare

# Reload the project modules so that edits made to them on disk are picked up
# without restarting the kernel.
importlib.reload(flood)
# Last expression of the cell: the reloaded module object is displayed as output.
importlib.reload(prepare)


In [None]:
# Dataset without a base transform, used to measure the raw value statistics.
# The split comes from the settings (`cfg.subset`, "train" by default) instead
# of being hard-coded, so the otherwise unused config field actually takes effect.
dataset = flood.FloodDataset(path=cfg.data_processed, subset=cfg.subset, include_dem=True)
# No shuffling: the loader is only iterated to accumulate statistics.
loader = DataLoader(dataset, batch_size=64, num_workers=4, pin_memory=True, shuffle=False)
# Guards the expensive full passes over `loader` in the cells below.
compute_stats = False

In [None]:

# Robust per-channel statistics: measure the 25th, 50th and 75th percentiles of
# the valid pixels of every batch, average them over the batches, and turn the
# interquartile range (75th - 25th) into a standard deviation estimate.
# NOTE: averaging per-batch percentiles approximates, but is not identical to,
# the percentiles of the whole dataset.
if compute_stats:
    p25s = list()
    p75s = list()
    p50s = list()
    # Running per-channel extrema, initialised so that the first batch replaces them.
    minval = np.full(3, np.finfo(np.float32).max)
    maxval = np.full(3, np.finfo(np.float32).min)

    for image, label in tqdm(loader):
        # One row per pixel, one column per channel.
        # Assumes the channel axis is last (the reshape relies on it) -- TODO confirm.
        image = image.numpy().reshape(-1, 3)
        # Pixels labelled 255 are excluded from the statistics.
        valid = label.numpy().flatten() != 255
        image = image[valid]
        if image.shape[0] == 0:
            # Nothing valid in this batch: min/max/percentile would fail on an empty array.
            continue
        minval = np.minimum(minval, np.min(image, axis=0))
        maxval = np.maximum(maxval, np.max(image, axis=0))

        batch_p75, batch_p25 = np.percentile(image, (75, 25), axis=0)
        p25s.append(batch_p25)
        p75s.append(batch_p75)
        p50s.append(np.median(image, axis=0))

    if not p50s:
        raise ValueError("no valid pixels found: cannot compute statistics")

    p25 = np.stack(p25s).mean(axis=0)
    p75 = np.stack(p75s).mean(axis=0)
    p50 = np.stack(p50s).mean(axis=0)

    # Named `iqr_range` so that the imported `scipy.stats.iqr` function is not shadowed.
    iqr_range = p75 - p25
    # For a normal distribution the IQR spans roughly 1.349 standard deviations.
    sigma = iqr_range / 1.34896

In [None]:
# Cached per-channel median and robust standard deviation from an earlier run
# (channel order presumably vv, vh, dem -- confirm against the dataset).
# They are only used when the statistics were not recomputed above; assigning
# them unconditionally would overwrite the freshly computed `p50` / `sigma`.
if not compute_stats:
    p50 = np.array([4.9329374e-02, 1.1776519e-02, 1.4241237e+02])
    sigma = np.array([3.91287043e-02, 1.03687926e-02, 8.11010422e+01])
print(f"std: {sigma}")
print(f"median: {p50}")
# `minval` / `maxval` only exist after a run with `compute_stats = True`.
# print(f"min: {[f'{v:.4f}' for v in minval]}")
# print(f"max: {[f'{v:.4f}' for v in maxval]}")

In [None]:
# Second pass: mean and standard deviation of the valid pixels after clipping
# them to a window of `factor` robust standard deviations around the median.
if compute_stats:
    factor = 2

    # Per-channel clipping window.
    clip_min = p50 - factor * sigma
    clip_max = p50 + factor * sigma
    # Per-batch results, averaged over the batches at the end.
    means = list()
    stds = list()
    for image, label in tqdm(loader):
        # Convert to NumPy first, like the percentile pass does. Otherwise the
        # clip/mean/std below operate on a torch.Tensor, whose `std` defaults to
        # the unbiased estimator, unlike `ndarray.std`.
        image = image.numpy().reshape(-1, 3)
        # Pixels labelled 255 are excluded from the statistics.
        valid = label.numpy().flatten() != 255
        image = image[valid]
        if image.shape[0] == 0:
            # Nothing valid in this batch: mean/std of an empty array would be NaN.
            continue

        image = np.clip(image, clip_min, clip_max)
        means.append(image.mean(axis=0))
        stds.append(image.std(axis=0))

    if not means:
        raise ValueError("no valid pixels found: cannot compute statistics")

    means = np.stack(means).mean(axis=0)
    stds = np.stack(stds).mean(axis=0)
    print(f"avg: {means}")
    print(f"std: {stds}")


In [None]:
# Fixed per-channel clipping bounds (one entry per channel).
# The commented-out lines are the alternative that derives the bounds from the
# median and sigma using `factor`.
factor = 10
# clip_min = p50 - factor * sigma
# clip_max = p50 + factor * sigma
# print(clip_max, clip_min)
clip_max = np.array([50.0, 1.0, 6000.0])
clip_min = np.full(3, -50.0)


In [None]:
import albumentations as alb
from albumentations.pytorch import ToTensorV2

In [None]:
class ClipNormalize(alb.Normalize):
    """``alb.Normalize`` followed by clipping of the normalised values.

    The image is first normalised by the parent transform and the result is then
    limited to ``[clip_min, clip_max]`` with ``np.clip``. The bounds therefore
    apply to the normalised values, not to the raw input.

    Args:
        mean: per-channel mean forwarded to ``alb.Normalize``.
        std: per-channel standard deviation forwarded to ``alb.Normalize``.
        clip_min: lower clipping bound; anything ``np.clip`` accepts (a scalar
            or a per-channel sequence).
        clip_max: upper clipping bound, same conventions as ``clip_min``.
        max_pixel_value: forwarded to ``alb.Normalize``.
        always_apply: forwarded to ``alb.Normalize``.
        p: forwarded to ``alb.Normalize``.
    """

    def __init__(self,
                 mean: tuple,
                 std: tuple,
                 clip_min: tuple,
                 clip_max: tuple,
                 max_pixel_value: float = 1.0,
                 always_apply: bool = False,
                 p: float = 1.0):
        super().__init__(mean=mean, std=std, max_pixel_value=max_pixel_value, always_apply=always_apply, p=p)
        # Stored under the constructor argument names: every name returned by
        # get_transform_init_args_names() is looked up as an attribute when
        # albumentations builds the repr / serialised form of the transform.
        self.clip_min = clip_min
        self.clip_max = clip_max
        # Previous attribute names, kept for code that still reads them.
        self.min = clip_min
        self.max = clip_max

    def apply(self, image, **params):
        """Normalise ``image`` with the parent transform, then clip the result."""
        # Positional call, so it does not depend on the name of the parent's
        # first parameter.
        result = super().apply(image, **params)
        return np.clip(result, self.clip_min, self.clip_max)

    def get_transform_init_args_names(self):
        """Return the parent's init argument names plus the two clipping bounds."""
        parent = list(super().get_transform_init_args_names())
        return tuple(parent + ["clip_min", "clip_max"])


In [None]:
# Base transform: normalise with the median as mean and the robust sigma as
# std, clip the normalised values to [-30, 30], then convert to a torch tensor.
base_trf = alb.Compose(
    [
        ClipNormalize(clip_min=-30, clip_max=30, mean=p50, std=sigma),
        ToTensorV2(),
    ]
)

In [None]:
# Same dataset as before, but with the clip-normalise transform applied, to
# inspect the distribution of the normalised values.
# The split comes from the settings (`cfg.subset`, "train" by default) instead
# of being hard-coded.
dataset_cls = flood.FloodDataset
dataset2 = dataset_cls(path=cfg.data_processed, subset=cfg.subset, include_dem=True, transform_base=base_trf)
# Shuffled, so the few batches plotted below are a random sample of the split.
loader2 = DataLoader(dataset2, batch_size=32, num_workers=4, pin_memory=True, shuffle=True)
# Loader consumed by the plotting cell.
plot_loader = loader2

In [None]:
from itertools import islice

import seaborn as sns


# Plot a histogram of the values of each channel for the first few batches of
# `plot_loader`.
n_batches = 5
channel_names = ("vv", "vh", "dem")
# islice stops after `n_batches` items, so no extra batch is fetched from the
# loader only to be thrown away by a `break`.
for batch, label in tqdm(islice(plot_loader, n_batches), total=min(n_batches, len(plot_loader))):
    batch = batch.numpy()
    # Swap the first two axes so that indexing the first axis selects a single
    # channel across the whole batch (assumes a batch-first, channel-second
    # layout -- TODO confirm).
    batch = np.swapaxes(batch, 0, 1)
    print(batch.shape)

    for index, name in enumerate(channel_names):
        sns.histplot(batch[index].flatten(), bins=500)
        plt.title(name)
        plt.show()


In [None]:
# means = list()
# stds = list()
# # sanity check: per-channel mean and std of the transformed data served by `loader2`
# for image, label in tqdm(loader2):
#     image = image.reshape(-1, 3)
#     valid = label.flatten() != 255
#     image = image[valid]

#     means.append(image.mean(axis=0))
#     stds.append(image.std(axis=0))

# means = np.stack(means).mean(axis=0)
# stds = np.stack(stds).mean(axis=0)
# print(f"avg: {means}")
# print(f"std: {stds}")