In [None]:
# Imports 
import os
from os import listdir
import random
import itertools
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple, Optional

import pandas as pd
import numpy as np

import rasterio
from rasterio.windows import Window
import geopandas as gpd

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm

# Prepare Gers data

We have to preprocess some data in order to use the Gers dataset in our experiment.

 * load and define label metadata: name, color (LUT), associated shapefile
 * load image bounds and join them with their fold group
 * load/adapt the odeon CSV files of the patch dataset, to test and compare the two dataset loading strategies

In [None]:
# Main parameters
# Dataset root. Override it with the GERS_DATASET_ROOT environment variable to run
# the notebook on another machine without editing this hard-coded default.
gers_dataset_root_dir = Path(os.environ.get("GERS_DATASET_ROOT", "/media/dlsupport/DATA1/32_2019_prod"))

# Directory holding the fold split CSV files of the patch dataset.
path_data = gers_dataset_root_dir.joinpath("dataset_ocsng_gers_naf_fold")

# 1-based raster band indexes to read: the 3 RVB (RGB) image bands, and 15 mask
# bands for the patch dataset. The large-image datasets further down use
# mask_bands[:-1] (14 bands).
image_bands = [1, 2, 3]
mask_bands = list(range(1, 16))

Then we define the NAF nomenclature used in the Gers dataset.

In [None]:
# NAF nomenclature: 15 classes, the last one being "autre" (other).
naf_label = [
    "batiment",
    "zone_permeable",
    "zone_impermeable",
    "piscine",
    "sol_nus",
    "surface-eau",
    "neige",
    "coupe",
    "feuillus",
    "conifere",
    "brousaille",
    "vigne",
    "culture",
    "terre_labouree",
    "autre"
]

# One RGB color per naf_label entry (same order); used to colorize masks.
naf_lut = np.array([
 [219,  14, 154],
 [114, 113, 112],
 [248,  12,   0],
 [ 61, 230, 235],
 [169, 113,   1],
 [ 21,  83, 174],
 [255, 255, 255],
 [138, 179, 160],
 [ 70, 228, 131],
 [ 25,  74,  38],
 [243, 166,  13],
 [102,   0, 130],
 [255, 243,  13],
 [228, 223, 124],
 [  0,   0,   0]
])

# Labels rasterized from shapefiles: every NAF class except the trailing "autre",
# in the same order. Derived from naf_label so the two lists cannot drift apart.
label_names = naf_label[:-1]
# Source shapefile of each label in label_names (same order).
label_shp = [
    "mask_32_2019_01-batiment.shp",
    "mask_32_2019_02-zone-permeable.shp",
    "mask_32_2019_03-zone-impermeable.shp",
    "mask_32_2019_04-piscine.shp",
    "mask_32_2019_05-sol-nu.shp",
    "mask_32_2019_06-surface-eau.shp",
    "mask_32_2019_07-neige.shp",
    "mask_32_2019_08-naf_coupe.shp",
    "mask_32_2019_09-naf_feuillus.shp",
    "mask_32_2019_10-naf_conifere.shp",
    "mask_32_2019_11-naf_landes-ligneuses.shp",
    "mask_32_2019_12-naf_vignes.shp",
    "mask_32_2019_13-naf_cultures.shp",
    "mask_32_2019_14-naf-terre_labouree.shp"
]
# 0-based output band of each label in the rasterized mask.
label_channel = list(range(len(label_names)))

assert len(naf_lut) == len(naf_label), "expected one LUT color per NAF class"
assert len(label_shp) == len(label_names), "expected one shapefile per rasterized label"

### Loading train/val/test patches for set/fold 4

As the image paths in the CSV files are absolute, we need to replace the old root directory with the current one.

In [None]:
# Split "4" of the patch dataset: train folds 2-4, val fold 1, test fold 5.
df_train = pd.read_csv(path_data.joinpath('train_4_fold_2-fold_3-fold_4.csv'), names=['img', 'msk'])
df_val = pd.read_csv(path_data.joinpath('val_4_fold_1.csv'), names=['img', 'msk'])
df_test = pd.read_csv(path_data.joinpath('test_4_fold_5.csv'), names=['img', 'msk'])

# Root directory the CSV paths were generated under.
old_root_dir = Path("/home/ign.fr/ndavid/test_odeon_ocsng_32")


def rebase_path(path_str, old_root: Path, new_root: Path):
    """Move ``path_str`` from ``old_root`` to ``new_root``.

    Paths are compared component by component, so only a genuine leading
    ``old_root`` is rewritten (``.../test_odeon_ocsng_32_bis/x.tif`` is left
    untouched, unlike with a plain substring replace). Paths outside ``old_root``
    and non-string values (e.g. NaN for an empty CSV cell) are returned unchanged.
    """
    if not isinstance(path_str, str):
        return path_str
    try:
        return str(new_root.joinpath(Path(path_str).relative_to(old_root)))
    except ValueError:
        return path_str


# change abs dir to new root path
for df in (df_train, df_val, df_test):
    for col in ("img", "msk"):
        df[col] = df[col].map(lambda p: rebase_path(p, old_root_dir, gers_dataset_root_dir))

In [None]:
# Preview the first rows of the rebased training split (head() shows 5 by default).
df_train.head()

In [None]:
# import for use in train code
from eotorchloader.transform.scale import ScaleImageToFloat
from eotorchloader.transform.tensor import ToTorchTensor
from eotorchloader.dataset.patch_dataset import PatchDataset

In [None]:
# display import
from eotorchloader.transform.scale import FloatImageToByte
from eotorchloader.transform.tensor import TensorToArray, CHW_to_HWC
from eotorchloader.transform.display import ToRgbDisplay

from eotorchloader.display.matplotlib import view_patch

In [None]:
# Patch datasets for the three splits. They share one transform pipeline
# (scale to float with scale_factor=255, then convert to torch tensors) and the
# same band selection; only the image/mask file columns differ.
in_transforms = [
    ScaleImageToFloat(scale_factor=255, clip=True),
    ToTorchTensor(),
]


def _patch_dataset(split_df):
    """Build a PatchDataset from a split frame with ``img`` / ``msk`` columns."""
    return PatchDataset(
        image_files=split_df['img'],
        mask_files=split_df['msk'],
        transforms=in_transforms,
        image_bands=image_bands,
        mask_bands=mask_bands,
    )


train_dataset = _patch_dataset(df_train)
val_dataset = _patch_dataset(df_val)
test_dataset = _patch_dataset(df_test)

In [None]:
# Fetch one training sample and check the keys and tensor types it carries.
test_idx = 170
test_data = train_dataset[test_idx]
print(f" keys : {test_data.keys()}")
img_type, msk_type = test_data['image'].type(), test_data['mask'].type()
print(f" image type : {img_type}, mask type : {msk_type}")

# Reverse the training transforms for display: tensor -> array, float -> byte,
# channels-last image, then ToRgbDisplay with the NAF LUT (presumably colorizes
# the mask — confirm in eotorchloader).
display_transforms = [
    TensorToArray(),
    FloatImageToByte(clip=True),
    CHW_to_HWC(img_only=True),
    ToRgbDisplay(lut=naf_lut),
]
view_patch(test_data, transforms=display_transforms)

In [None]:
# Time one full pass over the patch-based training set (data loading only, no model).
from torch.utils.data import DataLoader

patch_train_dataloader = DataLoader(train_dataset, batch_size=4, num_workers=4)
progress_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]'

with tqdm(total=len(patch_train_dataloader), desc="patch Image loader", bar_format=progress_format) as pbar:
    for sample in patch_train_dataloader:
        # Touch both entries so the batch is fully materialized.
        images, masks = sample['image'], sample['mask']
        pbar.update(1)

### Loading the image list with the corresponding fold

In [None]:
def load_img_dir(img_dir: Path, suffixes: Sequence[str] = (".tif", ".jp2")) -> pd.DataFrame:
    """List the raster images of a directory with their size and geotransform.

    Args:
        img_dir: directory to scan (not recursive). A ``str`` is accepted too.
        suffixes: file extensions to keep, compared case-insensitively.

    Returns:
        One row per image file, sorted by path so the order is reproducible
        (``os.listdir`` order is arbitrary). Columns: ``name`` (file stem),
        ``width`` / ``height`` (pixels), ``res_x`` / ``res_y`` and ``ul_x`` / ``ul_y``
        (taken from the affine transform), ``path`` and ``transform``.
        A directory without matching files yields an empty frame that still
        has these columns, so downstream column access does not fail.
    """
    columns = ["name", "width", "height", "res_x", "res_y", "ul_x", "ul_y", "path", "transform"]
    img_dir = Path(img_dir)
    wanted = {s.lower() for s in suffixes}
    img_files = sorted(
        p for p in img_dir.iterdir() if p.is_file() and p.suffix.lower() in wanted
    )

    rows = []
    for img_path in img_files:
        with rasterio.open(img_path) as ds:
            transform = ds.transform
            # Affine(a, b, c, d, e, f): a / e are the pixel sizes, c / f the
            # upper-left origin (same as transform[0], [4], [2], [5]).
            rows.append({
                "name": img_path.stem,
                "width": ds.width,    # x axis
                "height": ds.height,  # y axis
                "res_x": transform.a,
                "res_y": transform.e,
                "ul_x": transform.c,
                "ul_y": transform.f,
                "path": str(img_path),
                "transform": transform,
            })
    return pd.DataFrame(rows, columns=columns)

# Image list (one row per RVB image) joined with its k-fold group from the ROI shapefile.
rvb_dir = gers_dataset_root_dir.joinpath("IMAGES_RVB")
roi_shp_path = gers_dataset_root_dir.joinpath("kfold_32", "zones_vt_32.shp")

# Image names end with a 4-character suffix ("_RVB"); dropping it gives the ROI
# name used as the join key (e.g. "FR_032_2019_U-17").
gers_rvb_df = load_img_dir(rvb_dir)
gers_rvb_df = gers_rvb_df.assign(roi_name=gers_rvb_df["name"].str[:-4]).set_index("roi_name")

# Build the same key from the ROI id: first character upper-cased, the rest
# zero-padded to 2 digits. Using .assign avoids setting a column on a
# column-subset frame (SettingWithCopyWarning).
roi_gdf = gpd.read_file(roi_shp_path)
roi_gdf = (
    roi_gdf[["id", "kfold"]]
    .assign(roi_name=lambda d: d["id"].map(lambda x: f"FR_032_2019_{x[0].upper()}-{x[1:].zfill(2)}"))
    .set_index("roi_name")
)

gers_rvb_df = gers_rvb_df.join(roi_gdf, lsuffix='img', rsuffix='shp')

# An image without a matching ROI gets NaN in "kfold" and presumably ends up in
# no train/val/test split; report it instead of dropping it silently.
unmatched_rois = gers_rvb_df.index[gers_rvb_df["kfold"].isna()]
if len(unmatched_rois):
    print(f"WARNING: {len(unmatched_rois)} image(s) without kfold: {list(unmatched_rois)}")

gers_rvb_df.head(6)

In [None]:
# One row per rasterized label: display name, source shapefile and output mask band.
label_columns = {"names": label_names, "shp": label_shp, "channel": label_channel}
label_df = pd.DataFrame(label_columns)
label_df.head(10)

In [None]:
# Add the mask-image column. The rasterization step writes each mask as
# IMAGES_MASK/<image stem without its 4-char "_RVB" suffix>-MASK.tif, so derive the
# path the same way. String-replacing inside the RVB path only matched files
# literally ending in "_RVB.tif" (e.g. not ".TIF" or ".jp2").
msk_dir = gers_dataset_root_dir.joinpath("IMAGES_MASK")
gers_rvb_df = gers_rvb_df.rename(columns={"path": "rvb_path"})
gers_rvb_df["msk_path"] = [
    str(msk_dir.joinpath(f"{Path(rvb_path).stem[:-4]}-MASK.tif"))
    for rvb_path in gers_rvb_df["rvb_path"]
]

#### Rasterize data

If not done yet, we need to rasterize the shapefile data by zone/image bounds before using it in a Torch dataset.

In [None]:
# Rasterize the label shapefiles into one multi-band uint8 mask GeoTIFF per RVB image:
# band (channel + 1) of IMAGES_MASK/<roi>-MASK.tif is 1 where a polygon of that
# label intersects the pixel, 0 elsewhere.
from rasterio import features

img_rvb_list = gers_rvb_df["rvb_path"].values
in_shp_dir = gers_dataset_root_dir.joinpath("MASK_SHP", "SAISIE")
out_mask_dir = gers_dataset_root_dir.joinpath("IMAGES_MASK")
out_mask_dir.mkdir(parents=True, exist_ok=True)
# Existing masks are kept ("if not done yet"); set to True to regenerate them all.
overwrite_masks = False

# (source image, output mask) pairs that still need to be written.
todo = []
for img_filename in img_rvb_list:
    img_path = Path(img_filename)
    out_mask = out_mask_dir.joinpath(f"{img_path.stem[:-4]}-MASK.tif")
    if overwrite_masks or not out_mask.exists():
        todo.append((img_path, out_mask))

# The shapefiles do not depend on the image, so each one is read once instead of
# once per image. Distinct names are used here so the module-level `label_shp`
# list (used to build label_df) is not overwritten.
# NOTE(review): features are selected with the raster bounds, which assumes the
# shapefiles share the imagery CRS — TODO confirm.
label_layers = [
    (int(row.channel), gpd.read_file(in_shp_dir.joinpath(row.shp)))
    for row in label_df.itertuples(index=False)
] if todo else []


def rasterize_mask(img_path: Path, out_mask: Path, layers) -> None:
    """Write ``out_mask`` on the grid of ``img_path`` with one uint8 band per (channel, GeoDataFrame) layer."""
    with rasterio.open(img_path) as src_dataset:
        kwds = src_dataset.profile.copy()
    # Keep the image grid (size, CRS, transform) but force a lossless uint8 GeoTIFF.
    # Dtype, nodata, photometric or (possibly lossy) compression settings inherited
    # from the imagery would otherwise alter or mask out the 0/1 label values.
    kwds.update(driver="GTiff", count=len(layers), dtype="uint8", nodata=None, compress="lzw")
    kwds.pop("photometric", None)

    with rasterio.open(out_mask, "w", **kwds) as dst_dataset:
        xmin, ymin, xmax, ymax = dst_dataset.bounds
        out_shape = (dst_dataset.height, dst_dataset.width)
        for out_channel, layer_gdf in layers:
            # Only the features intersecting the image extent are burned.
            clipped = layer_gdf.cx[xmin:xmax, ymin:ymax]
            if len(clipped):
                burned = features.rasterize(
                    ((geom, 1) for geom in clipped.geometry),
                    out_shape=out_shape,
                    fill=0,
                    transform=dst_dataset.transform,
                    dtype="uint8",
                )
            else:
                # Nothing to burn for this label: write an all-zero band.
                burned = np.zeros(out_shape, dtype=np.uint8)
            dst_dataset.write(burned.astype(rasterio.uint8), out_channel + 1)


for img_path, out_mask in tqdm(todo, desc="rasterize masks"):
    rasterize_mask(img_path, out_mask, label_layers)
print(f"{len(todo)} mask(s) written, {len(img_rvb_list) - len(todo)} already present")

### Test rasterization display


In [None]:
# Sanity check: colorize one rasterized mask with the NAF LUT and display a crop.
out_mask_dir = gers_dataset_root_dir.joinpath("IMAGES_MASK")
test_mask = out_mask_dir.joinpath("FR_032_2019_U-17-MASK.tif")
# Crop window (rows, cols) to display, in pixels.
crop_rows, crop_cols = slice(1024, 2048), slice(2048, 3072)

with rasterio.open(test_mask) as mask_ds:
    mask_stack = mask_ds.read()  # (bands, rows, cols)

label_index = np.argmax(mask_stack, axis=0)
# argmax returns 0 ("batiment") wherever no band is set. Map those unlabeled
# pixels to the last LUT entry ("autre", black) so background is not drawn as buildings.
label_index[mask_stack.max(axis=0) == 0] = len(naf_lut) - 1
mask_array = np.take(naf_lut, label_index, axis=0)

print(mask_array.shape)
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(mask_array[crop_rows, crop_cols])
ax.set_title(
    f"{test_mask.stem} rows {crop_rows.start}:{crop_rows.stop}, cols {crop_cols.start}:{crop_cols.stop}"
)
ax.set_axis_off()
plt.show()

## Dataset with chop/clip when reading images

Instead of first preprocessing the imagery into a DL dataset made of a list of small patch image files, we could also tile the imagery on the fly, directly from the original large aerial/satellite images.

Why could we be interested in such functionality:
 
 * first, to test speed/data-efficiency management. A large dataset could be simpler to manage and compress than a lot of small files, and reading could be just as efficient.
 * the decisions made when splitting could be changed on the fly, e.g. switching from 256 to 512 pixel patches.
 * it could act as a form of data augmentation without duplicating memory, by sampling tiles at different overlapping positions.
 * it could help to change/test sampling by class.

In [None]:
# import for use in train code
from eotorchloader.dataset.scene_dataset import LargeImageDataset

In [None]:
# Note: despite the "_train" suffix, these arrays hold every image of the dataset
# (all folds); no fold filtering is applied here.
image_files_train, mask_files_train = (gers_rvb_df[col].values for col in ("rvb_path", "msk_path"))
print(image_files_train[:5])
print(mask_files_train[:5])

# On-the-fly tiling of the large images with tile_size=512 and no transforms.
# The rasterized masks have one band per label in label_df (14), hence the
# last of the 15 mask bands is dropped.
train_dataset_tile = LargeImageDataset(
    image_files=image_files_train,
    mask_files=mask_files_train,
    tile_size=512,
    transforms=None,
    image_bands=image_bands,
    mask_bands=mask_bands[:-1],
)

In [None]:
# display import
from eotorchloader.transform.tensor import CHW_to_HWC
from eotorchloader.transform.display import ToRgbDisplay

from eotorchloader.display.matplotlib import view_patch, view_batch

In [None]:
# The tile dataset has no training transforms, so display only needs channels
# moved last and the NAF LUT applied.
channel_last_transform, display_patch_transform = CHW_to_HWC(img_only=True), ToRgbDisplay(lut=naf_lut)

test_data = train_dataset_tile[142]
view_patch(
    test_data,
    transforms=[channel_last_transform, display_patch_transform],
)

# Lightning DataModule

We use a datamodule to manage the fold loading setup and some other utilities.

In [None]:
from eotorchloader.datamodule.terria import TerriaDataModule

In [None]:
# Five rotating cross-validation sets over folds 1..5: set i tests on fold i,
# validates on the next fold and trains on the three following ones, wrapping
# around after fold 5. (The patch-dataset CSV split "4" above, with test fold 5
# and val fold 1, corresponds to "set5" here.)
n_folds = 5
gers_set_config = {}
for set_idx in range(1, n_folds + 1):
    rotation = [(set_idx - 1 + k) % n_folds + 1 for k in range(n_folds)]
    gers_set_config[f"set{set_idx}"] = {
        "train": rotation[2:],
        "val": [rotation[1]],
        "test": [rotation[0]],
    }

In [None]:
# Fold-aware datamodule over the large images, using cross-validation set 5
# (train folds 2-4, val fold 1, test fold 5) and the same float/tensor transforms
# as the patch datasets.
in_transforms = [
    ScaleImageToFloat(scale_factor=255, clip=True),
    ToTorchTensor(),
]
datamodule_kwargs = dict(
    transforms=in_transforms,
    img_col="rvb_path",
    img_bands=image_bands,
    mask_col="msk_path",
    mask_bands=mask_bands[:-1],
    group_col="kfold",
    set_config=gers_set_config["set5"],
)
gers_data_module_set5 = TerriaDataModule(gers_rvb_df, **datamodule_kwargs)

gers_data_module_set5.setup()
train_dataloader = gers_data_module_set5.train_dataloader()

In [None]:
# Display pipeline: tensor -> array, float -> byte, channels last, NAF LUT.
to_array, to_byte = TensorToArray(), FloatImageToByte(clip=True)
to_hwc, to_rgb = CHW_to_HWC(img_only=True), ToRgbDisplay(lut=naf_lut)
display_transforms = [to_array, to_byte, to_hwc, to_rgb]

# Pull a single batch from the set-5 training loader.
test_batch = next(iter(train_dataloader))

In [None]:
# Render the batch; size=4 is presumably the number of samples shown — confirm in eotorchloader.
view_batch(
    test_batch,
    size=4,
    transforms=display_transforms,
)

In [None]:
# Time one full pass over the set-5 training loader of the large-image datamodule,
# for comparison with the patch-based loader timed earlier in the notebook.
img_train_dataloader = gers_data_module_set5.train_dataloader()
progress_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]'

with tqdm(total=len(img_train_dataloader), desc="Large Image loader", bar_format=progress_format) as pbar:
    for sample in img_train_dataloader:
        # Touch both entries so the batch is fully materialized.
        images, masks = sample['image'], sample['mask']
        pbar.update(1)