In [None]:
import os
from pathlib import Path
from pydantic import BaseSettings
from matplotlib import  pyplot as plt

import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

In [None]:
class StatsConfig(BaseSettings):
    """Settings for this dataset statistics / visualisation notebook.

    Values are loaded by pydantic ``BaseSettings`` from the environment and
    from the dotenv file declared in the nested ``Config`` class.
    """

    # Location of the processed dataset, passed to ``FloodDataset(path=...)``.
    # No default is given, so it must be provided by the environment / .env.
    data_processed: Path
    # Name of the dataset split; falls back to the training split.
    subset: str = "train"

    class Config:
        # Dotenv file (relative path) to read settings from, and its encoding.
        env_file = ".env"
        env_file_encoding = "utf-8"

In [None]:
# Move to the notebook's parent directory, presumably the repository root, so
# that the relative `.env` file and the `floods` package imported further down
# resolve from there.
# The notebook's own directory is remembered on the first run (an existing
# `cwd` is reused). Re-executing this cell therefore does not keep climbing
# one more level up the directory tree.
cwd = globals().get("cwd", Path(os.getcwd()))
os.chdir(str(cwd.parent))
print(os.getcwd())

In [None]:
# Load the notebook settings (environment variables / the `.env` file).
# Fail fast with a clear message if the processed-data path is wrong, rather
# than with a less obvious error when the dataset is constructed later.
cfg = StatsConfig()
assert cfg.data_processed.exists(), f"data_processed does not exist: {cfg.data_processed}"
cfg

In [None]:
from floods.datasets.flood import FloodDataset
from floods.prepare import train_transforms, inverse_transform
from floods.utils.gis import rgb_ratio

In [None]:
# Transform pipeline built from the dataset's own normalisation (mean/std)
# and clipping statistics.
transform = train_transforms(
    image_size=512,
    mean=FloodDataset.mean(),
    std=FloodDataset.std(),
    clip_min=FloodDataset.clip_min(),
    clip_max=FloodDataset.clip_max(),
)
# Use the configured subset rather than a hard-coded "train", so that the
# `subset` setting of StatsConfig actually takes effect (its default is "train").
dataset = FloodDataset(path=cfg.data_processed, subset=cfg.subset, include_dem=True, transform=transform)
# shuffle=False keeps the batch-index -> samples mapping stable, so the batch
# indices sampled later always refer to the same images.
loader = DataLoader(dataset, batch_size=5, num_workers=4, pin_memory=True, shuffle=False)
# Presumably undoes the mean/std normalisation for display — confirm in floods.prepare.
invert = inverse_transform(mean=FloodDataset.mean(), std=FloodDataset.std())

In [None]:
# Draw a reproducible random sample of batch indices to visualise.
# - The fixed seed makes Restart & Run All show the same batches every time.
# - The sample size is capped at the number of batches, because sampling
#   without replacement cannot exceed the population size.
# - Sorting lets the plotting loop stop after the last sampled batch.
SEED = 42
N_SAMPLE_BATCHES = 50
sample_rng = np.random.default_rng(SEED)
indices = np.sort(
    sample_rng.choice(len(loader), size=min(N_SAMPLE_BATCHES, len(loader)), replace=False)
)
indices

In [None]:
# Weights passed to `rgb_ratio`. NOTE(review): the meaning of each value is not
# documented in this notebook — confirm against floods.utils.gis.rgb_ratio.
RGB_RATIO_WEIGHTS = (0.6, 1.1, 0.005)

# Only the sampled batches are plotted. A set gives O(1) membership checks, and
# iteration stops once the last sampled batch has been seen instead of loading
# (and transforming) the rest of the dataset.
wanted = {int(k) for k in indices}
last_wanted = max(wanted, default=-1)

for i, (images, label) in enumerate(tqdm(loader, desc="batches")):
    if i > last_wanted:
        break
    if i not in wanted:
        continue
    images = invert(images)
    n_images = images.shape[0]
    # Size the figure from the actual batch: the final batch may be smaller
    # than batch_size. squeeze=False keeps `axes` 2-D even for one image.
    fig, axes = plt.subplots(1, n_images, figsize=(4 * n_images, 5), squeeze=False)
    for j in range(n_images):
        # Index with j rather than a hard-coded `4 - j`, which assumed exactly
        # 5 images per batch and raised IndexError on a shorter final batch.
        rgb = rgb_ratio(images[j], weights=RGB_RATIO_WEIGHTS)
        ax = axes[0, j]
        ax.imshow(rgb)
        ax.set_title(f"batch {i}, image {j}")
        ax.axis("off")
    plt.show()
    # Release the figure explicitly; this loop can create many of them.
    plt.close(fig)