In [None]:
import os
from pathlib import Path
from pydantic import BaseSettings
from matplotlib import  pyplot as plt

import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader

In [None]:
class StatsConfig(BaseSettings):
    """Settings for this notebook, populated by pydantic's ``BaseSettings``.

    Values come from environment variables or from the ``.env`` file named in
    the nested ``Config``; ``data_processed`` has no default and must be set.
    """

    # Location of the processed data on disk (required).
    data_processed: Path
    # Name of the dataset subset; falls back to "train" when not configured.
    subset: str = "train"

    class Config:
        # Dotenv file consulted by pydantic, and the encoding used to read it.
        env_file_encoding = "utf-8"
        env_file = ".env"

In [None]:
# Switch the working directory to the parent of the notebook's folder so that
# relative paths (e.g. the `.env` file) are looked up from there.
# The starting directory is captured only the first time this cell runs in a
# kernel; re-running the cell therefore stays put instead of climbing one more
# directory level on every execution.
if "cwd" not in globals():
    cwd = Path(os.getcwd())
os.chdir(str(cwd.parent))
print(os.getcwd())

In [None]:
# Build the settings object. StatsConfig derives from pydantic's BaseSettings,
# which fills the fields from environment variables and the `.env` file named
# in StatsConfig.Config.
# NOTE(review): the relative `.env` path is presumably resolved against the
# current working directory, so this should run after the chdir cell — confirm.
cfg = StatsConfig()
# Bare expression: lets the notebook render the loaded settings as cell output.
cfg

In [None]:
from floods.datasets.flood import FloodDataset
from floods.prepare import train_transforms_base, train_transforms_dem, train_transforms_sar, eval_transforms, inverse_transform
from floods.utils.gis import rgb_ratio

In [None]:
# Normalisation statistics exposed by FloodDataset; reused below both for the
# forward normalisation and for its inverse so the two always agree.
mean = FloodDataset.mean()
std = FloodDataset.std()

# Training-time transforms: shared base transform (512 px images) plus the
# SAR-specific and DEM-specific ones.
base_trf = train_transforms_base(image_size=512)
sar_trf = train_transforms_sar()
dem_trf = train_transforms_dem()
normalize = eval_transforms(mean=mean,
                            std=std,
                            clip_min=-30,
                            clip_max=30)

# Build the dataset for the configured subset (cfg.subset defaults to "train")
# and wrap it in a sequential, non-shuffled loader so batch indices are stable.
dataset = FloodDataset(path=cfg.data_processed,
                       subset=cfg.subset,
                       include_dem=True,
                       transform_base=base_trf,
                       transform_sar=sar_trf,
                       transform_dem=dem_trf,
                       normalization=normalize)
loader = DataLoader(dataset, batch_size=2, num_workers=4, pin_memory=False, shuffle=False)

# Inverse of the normalisation, used to bring images back to a displayable range.
invert = inverse_transform(mean=mean, std=std)

In [None]:
# Indices of the loader batches to preview. They are fixed (the first five
# batches) rather than drawn at random, so the same batches are selected on
# every run and no draw can fail when the loader has fewer than five batches.
indices = [0, 1, 2, 3, 4]


In [None]:
# Preview the selected batches: for every sample in a batch show channel 0
# (scaled by the batch-wide maximum), the last channel and the label.
# The loop stops right after the highest requested batch index, so it never
# walks the rest of the loader regardless of which indices were chosen.
last_wanted = max(indices, default=-1)
n_batches = min(len(loader), last_wanted + 1)
for i, (images, label) in enumerate(tqdm(loader, total=n_batches)):
    if i in indices:
        images = invert(images)
        # NOTE(review): the indexing below assumes `invert` returns a
        # channel-last batch (batch, height, width, channels) — confirm.
        vv = images[:, :, :, 0]
        vh = images[:, :, :, 1]
        vv_max = vv.max()  # loop-invariant: computed once per batch
        print('vv', vv.min(), vv_max)
        print('vh', vh.min(), vh.max())

        # Three panels per sample; the grid follows the actual batch size so a
        # different batch_size (or a short final batch) still fits.
        n_samples = images.shape[0]
        f, axes = plt.subplots(1, 3 * n_samples, figsize=(15 * n_samples, 10))
        for j in range(n_samples):
            axes[j * 3].imshow(vv[j] / vv_max, cmap="gray")
            axes[j * 3 + 1].imshow(images[j][:, :, -1])
            axes[j * 3 + 2].imshow(label[j])
        plt.show()
        # Release the figure explicitly so repeated runs do not accumulate them.
        plt.close(f)
    if i >= last_wanted:
        break

In [None]:
# Show a single sample (index 0): an RGB composite built by rgb_ratio, the
# last channel of the de-normalised image, and the label.
images, label = dataset[0]
images = invert(images)
rgb = rgb_ratio(images, weights=(5, 10, 0.1))

f, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
panels = (rgb, images[:, :, -1], label)
for ax, panel in zip(axes, panels):
    ax.imshow(panel)
plt.show()