In [None]:
# Imports 
import os
from os import listdir
import random
import itertools
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple, Optional

import numpy as np
import pandas as pd
import geopandas as gpd

import rasterio
from rasterio.windows import Window

# Pytorch
import torch
from torch.utils.data import Dataset, DataLoader

# Lightning
import pytorch_lightning as pl


In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm

# Function to load image data

First we need to define some functions to load data from image files with geographic information. The most frequently used library for geo raster data is rasterio, a Python wrapper around the GDAL library (C++). It can read a selected set of bands and a specified window (sub-part) of the image. The ability to load only part of an image is useful because geographic images (rasters) can be very large, e.g. more than 10,000 pixels wide.

In [None]:
# Simple loader: read the selected bands over the whole image extent.
def geoimage_simple_load(img_path, band_indices):
    """
    Load a geo image into a numpy array of shape C * H * W (bands, rows, columns).

    img_path : path of the raster, opened with rasterio
    band_indices : band index or list of band indices given to rasterio's
        `read(indexes=...)` (band indices start at 1 in rasterio)

    Returns a 3-D array; when a single int index is given, a leading band axis
    of size 1 is added.
    """
    with rasterio.open(img_path) as src:
        img_array = src.read(indexes=band_indices)

    # a single int index makes rasterio return a 2-D (rows, columns) array
    if img_array.ndim == 2:
        img_array = img_array[np.newaxis, ...]

    return img_array


def geoimage_load_tile(img_path, band_indices, window):
    """
    Load one rectangular window of a geo image as a C * H * W (bands, rows, columns) array.

    img_path : path of the raster, opened with rasterio
    band_indices : band index or list of band indices given to rasterio's
        `read(indexes=...)` (band indices start at 1 in rasterio)
    window : tuple (col_off, row_off, width, height) in pixels, unpacked into
        rasterio.windows.Window

    Returns a 3-D array; when a single int index is given, a leading band axis
    of size 1 is added.
    """
    with rasterio.open(img_path) as src:
        img_array = src.read(window=Window(*window), indexes=band_indices)

    # a single int index makes rasterio return a 2-D (rows, columns) array
    if img_array.ndim == 2:
        img_array = img_array[np.newaxis, ...]

    return img_array

# PreProcess or transform image data

## common transform

In [None]:
def format_to_dict(image: np.ndarray, mask:np.ndarray) -> Dict[str, Any] :
    """
    Pack an image and its mask into the sample dict used by every transform:
    {"image": image, "mask": mask}.
    """
    sample = dict(image=image, mask=mask)
    return sample
            

class BasicTransform():
    """
    Base class of the sample transforms; a transform is called as t(image=..., mask=...)
    and returns a new {"image", "mask"} dict.

    Subclasses implement apply_to_img / apply_to_mask. The flags restrict what is transformed:
    img_only -> only the image (takes precedence), mask_only -> only the mask,
    neither -> both (image first, then mask).
    """

    def __init__(self):
        self.params: Dict[Any, Any] = {}
        self.img_only: bool = False
        self.mask_only: bool = False

    def __call__(self, image: np.ndarray, mask:np.ndarray) -> Dict[str, Any]:
        transform_image = self.img_only or not self.mask_only
        transform_mask = not self.img_only

        out_image = self.apply_to_img(image) if transform_image else image
        out_mask = self.apply_to_mask(mask) if transform_mask else mask
        return {"image": out_image, "mask": out_mask}

    def apply_to_img(self, img: np.ndarray) -> np.ndarray:
        """Transform the image; to be implemented by subclasses."""
        raise NotImplementedError

    def apply_to_mask(self, mask: np.ndarray) -> np.ndarray:
        """Transform the mask; to be implemented by subclasses."""
        raise NotImplementedError


class HWC_to_CHW(BasicTransform):
    """
    Move the channel axis first: (rows, columns, bands) -> (bands, rows, columns).

    img_only / mask_only : restrict the transform to the image or to the mask.
    """

    def __init__(self, img_only: bool=False, mask_only : bool = False) :
        # name this class here: super(CHW_to_HWC, self) raised a TypeError because
        # self is not a CHW_to_HWC instance
        super(HWC_to_CHW, self).__init__()
        self.img_only = img_only
        self.mask_only = mask_only

    @staticmethod
    def swap_axes(array : np.ndarray) -> np.ndarray:
        # swap the axes order from (rows, columns, bands) to (band, rows, columns)
        array = np.ma.transpose(array, [2, 0, 1])
        return array

    def apply_to_img(self, img: np.ndarray) -> np.ndarray:
        return HWC_to_CHW.swap_axes(img)

    def apply_to_mask(self, mask: np.ndarray) -> np.ndarray:
        return HWC_to_CHW.swap_axes(mask)
        
    
class CHW_to_HWC(BasicTransform):
    """
    Move the channel axis last: (bands, rows, columns) -> (rows, columns, bands),
    e.g. to display an image with matplotlib.

    img_only / mask_only : restrict the transform to the image or to the mask.
    """

    def __init__(self, img_only: bool=False, mask_only : bool = False) :
        super(CHW_to_HWC, self).__init__()
        self.img_only, self.mask_only = img_only, mask_only

    @staticmethod
    def swap_axes(array : np.ndarray) -> np.ndarray:
        # axis 0 (bands) goes to the end
        return np.ma.transpose(array, [1, 2, 0])

    def apply_to_img(self, img: np.ndarray) -> np.ndarray:
        swapped = CHW_to_HWC.swap_axes(img)
        return swapped

    def apply_to_mask(self, mask: np.ndarray) -> np.ndarray:
        swapped = CHW_to_HWC.swap_axes(mask)
        return swapped
        

class ToTorchTensor(BasicTransform):
    """
    Convert numpy arrays to float32 torch tensors (torch.from_numpy, then cast).

    img_only / mask_only : restrict the conversion to the image or to the mask.
    """

    def __init__(self, img_only: bool=False, mask_only : bool = False) :
        super(ToTorchTensor, self).__init__()
        self.img_only, self.mask_only = img_only, mask_only

    @staticmethod
    def _as_float_tensor(array: np.ndarray):
        tensor = torch.from_numpy(array)
        return tensor.type(torch.float32)

    def apply_to_img(self, img: np.ndarray) :
        return self._as_float_tensor(img)

    def apply_to_mask(self, mask: np.ndarray) :
        return self._as_float_tensor(mask)


class TensorToArray(BasicTransform):
    """
    Convert torch tensors back to numpy arrays (moving them to the CPU first).

    img_only / mask_only : restrict the conversion to the image or to the mask.
    """

    def __init__(self, img_only: bool=False, mask_only : bool = False) :
        super(TensorToArray, self).__init__()
        self.img_only, self.mask_only = img_only, mask_only

    @staticmethod
    def _as_numpy(tensor) -> np.ndarray:
        on_cpu = tensor.cpu()
        return on_cpu.numpy()

    def apply_to_img(self, img ) -> np.ndarray:
        return self._as_numpy(img)

    def apply_to_mask(self, mask ) -> np.ndarray:
        return self._as_numpy(mask)
        
        
class ScaleImageToFloat(BasicTransform):
    """
    Divide the image by `scale_factor` into a float32 array (image only, the mask
    is passed through). With clip=True the result is clipped to [0, 1].
    """

    def __init__(self, scale_factor : float = 255, clip : bool = False) :
        super(ScaleImageToFloat, self).__init__()
        self.img_only = True
        self.scale_factor = scale_factor
        self.clip = clip

    def apply_to_img(self, img: np.ndarray) -> np.ndarray:
        # multiply by the inverse with a float32 output dtype
        scaled = np.multiply(img, 1./self.scale_factor, dtype=np.float32)
        return np.clip(scaled, 0, 1) if self.clip else scaled

        
class FloatImageToByte(BasicTransform):
    """
    Scale an input image from [0-1] to [0-255] uint8, mainly for RGB display purpose
    (image only, the mask is passed through).

    clip : clip the scaled values to [0, 255] before the uint8 conversion, so that
        out-of-range inputs saturate instead of wrapping around.
    """

    def __init__(self, clip : bool = False) :
        super(FloatImageToByte, self).__init__()
        self.img_only = True
        self.scale_factor = 255
        self.clip = clip

    def apply_to_img(self, img: np.ndarray) -> np.ndarray:
        img = np.multiply(img, self.scale_factor, dtype=np.float32)
        if self.clip :
            # must happen before the cast: once converted to uint8, out-of-range
            # values have already wrapped and clipping would be a no-op
            img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

# Prepare Gers data

We have to preprocess some data in order to use the Gers dataset in our experiment:

 * load the label metadata: name, color (LUT), associated shapefile
 * load the image bounds and join them with their fold group
 * load / adapt the odeon CSV files of the patch dataset, to test and compare two dataset-loading strategies

In [None]:
# Main parameters: dataset root, fold/patch csv directory and band indices (1-based,
# as read by rasterio) for RVB images and label masks.
gers_dataset_root_dir = Path("/home/data/32_2019_prod")

path_data = gers_dataset_root_dir / "dataset_ocsng_gers_naf_fold"
image_bands = [1, 2, 3]
mask_bands = list(range(1, 16))

Then we define the NAF nomenclature used in the Gers dataset.

In [None]:
# NAF nomenclature: class names and their RGB display colour (one LUT row per class,
# same order). The last class, "autre", has a colour but no shapefile.
naf_label = [
    "batiment",
    "zone_permeable",
    "zone_impermeable",
    "piscine",
    "sol_nus",
    "surface-eau",
    "neige",
    "coupe",
    "feuillus",
    "conifere",
    "brousaille",
    "vigne",
    "culture",
    "terre_labouree",
    "autre",
]

naf_lut = np.array([
    [219,  14, 154],
    [114, 113, 112],
    [248,  12,   0],
    [ 61, 230, 235],
    [169, 113,   1],
    [ 21,  83, 174],
    [255, 255, 255],
    [138, 179, 160],
    [ 70, 228, 131],
    [ 25,  74,  38],
    [243, 166,  13],
    [102,   0, 130],
    [255, 243,  13],
    [228, 223, 124],
    [  0,   0,   0],
])

# The 14 rasterized labels: every NAF class except the trailing "autre".
label_names = naf_label[:-1]
# Source shapefile of each label, in the same order as label_names.
label_shp = [
    "mask_32_2019_01-batiment.shp",
    "mask_32_2019_02-zone-permeable.shp",
    "mask_32_2019_03-zone-impermeable.shp",
    "mask_32_2019_04-piscine.shp",
    "mask_32_2019_05-sol-nu.shp",
    "mask_32_2019_06-surface-eau.shp",
    "mask_32_2019_07-neige.shp",
    "mask_32_2019_08-naf_coupe.shp",
    "mask_32_2019_09-naf_feuillus.shp",
    "mask_32_2019_10-naf_conifere.shp",
    "mask_32_2019_11-naf_landes-ligneuses.shp",
    "mask_32_2019_12-naf_vignes.shp",
    "mask_32_2019_13-naf_cultures.shp",
    "mask_32_2019_14-naf-terre_labouree.shp",
]
# 0-based output mask channel of each label.
label_channel = list(range(len(label_names)))


### Loading the train/val/test patches of set/fold 4

As the image paths in the CSV files are absolute, we need to replace the old root directory with the current one.

In [None]:
# Patch lists of set/fold 4: header-less csv files with (image path, mask path) rows.
split_csv_names = ('train_4_fold_2-fold_3-fold_4.csv', 'val_4_fold_1.csv', 'test_4_fold_5.csv')
df_train, df_val, df_test = [
    pd.read_csv(path_data / csv_name, names=['img', 'msk']) for csv_name in split_csv_names]

# The csv files store absolute paths under the old root: re-root them to the current one.
old_root_dir = Path("/home/ign.fr/ndavid/test_odeon_ocsng_32")
for df in (df_train, df_val, df_test):
    for col in ("img", "msk"):
        df[col] = df[col].str.replace(str(old_root_dir), str(gers_dataset_root_dir), regex=False)

In [None]:
df_train.head()

### loading images list with corresponding fold

In [None]:
def load_img_dir(img_dir : Path):
    """
    List the geo images (.tif / .jp2, case-insensitive suffix) found directly in
    `img_dir` and return one DataFrame row per image.

    Columns: name (file stem), width / height (pixels), res_x / res_y and ul_x / ul_y
    (terms 0, 4, 2 and 5 of the rasterio affine transform), path (str) and transform.
    Files are listed in sorted order so the table is the same on every run; an
    empty directory gives an empty frame that still has these columns.
    """
    img_dir = Path(img_dir)
    img_files = sorted(
        f for f in img_dir.iterdir()
        if f.is_file() and f.suffix.lower() in (".tif", ".jp2"))

    img_rows = []
    for img_path in img_files:
        with rasterio.open(img_path) as ds:
            transform = ds.transform
            img_rows.append({
                "name": img_path.stem,
                "width": ds.width,  # x axis
                "height": ds.height,  # y axis
                "res_x": transform[0],
                "res_y": transform[4],
                "ul_x": transform[2],
                "ul_y": transform[5],
                "path": str(img_path),
                "transform": transform,
            })

    columns = ["name", "width", "height", "res_x", "res_y", "ul_x", "ul_y", "path", "transform"]
    return pd.DataFrame(img_rows, columns=columns)

# Image table: one row per RVB image, indexed by zone name, joined with the k-fold
# group of each zone read from the ROI shapefile.
rvb_dir = gers_dataset_root_dir.joinpath("IMAGES_RVB")
roi_shp_path = gers_dataset_root_dir.joinpath("kfold_32", "zones_vt_32.shp")

gers_rvb_df = load_img_dir(rvb_dir)
# drop the last 4 characters of the stem (the "_RVB" suffix) to get the zone name
gers_rvb_df["roi_name"] = gers_rvb_df["name"].str[0:-4]
gers_rvb_df = gers_rvb_df.set_index("roi_name")

# Build the zone name from the ROI id ("<letter><number>" -> "FR_032_2019_<LETTER>-<NN>").
# assign() works on a new frame instead of writing into a column slice of the GeoDataFrame.
roi_gdf = (
    gpd.read_file(roi_shp_path)[["id", "kfold"]]
    .assign(roi_name=lambda d: d["id"].apply(lambda x: f"FR_032_2019_{x[0].upper()}-{x[1:].zfill(2)}"))
    .set_index("roi_name")
)

gers_rvb_df = gers_rvb_df.join(roi_gdf, lsuffix='img', rsuffix='shp')
# images whose zone is missing from the ROI shapefile get no fold and are silently
# excluded from every train/val/test split: report them
images_without_fold = gers_rvb_df.index[gers_rvb_df["kfold"].isna()].tolist()
if images_without_fold:
    print(f"warning: no kfold for {len(images_without_fold)} image(s): {images_without_fold}")
gers_rvb_df.head(6)

In [None]:
# One row per rasterized label: display name, source shapefile, output channel.
label_columns = {"names": label_names, "shp": label_shp, "channel": label_channel}
label_df = pd.DataFrame(label_columns)
label_df.head(10)

In [None]:
# add mask image columns
# The mask of "<zone>_RVB.<ext>" is "IMAGES_MASK/<zone>-MASK.tif" under the dataset root,
# the same name the rasterization step writes. It is derived from the file stem so it
# also holds for .jp2 images (a "_RVB.tif" text replacement left those paths unchanged).
gers_rvb_df = gers_rvb_df.rename(columns={"path": "rvb_path"})
gers_rvb_df["msk_path"] = gers_rvb_df["rvb_path"].map(
    lambda rvb_path: str(gers_dataset_root_dir.joinpath("IMAGES_MASK", f"{Path(rvb_path).stem[:-4]}-MASK.tif")))

#### rasterize data

If not done yet, we need to rasterize the shapefile data by zone (image bounds) before using it in a Torch dataset.

In [None]:
# rasterize image
# Burn each label shapefile onto the pixel grid of every RVB image. The output
# "<zone>-MASK.tif" has one uint8 band per label (band = channel + 1): 1 inside the
# label polygons, 0 elsewhere.
from rasterio import features

img_rvb_list = gers_rvb_df["rvb_path"].values
in_shp_dir = gers_dataset_root_dir.joinpath("MASK_SHP", "SAISIE")
out_mask_dir = gers_dataset_root_dir.joinpath("IMAGES_MASK")
out_mask_dir.mkdir(parents=True, exist_ok=True)
# set to False to keep the masks already written by a previous run
overwrite_masks = True

# Read each label shapefile once instead of once per image. Loop names differ from
# the module-level `label_shp` list so re-running the label_df cell still works.
# NOTE(review): the bounds filter and the rasterization assume the shapefiles use the
# same CRS as the images (nothing is reprojected) — confirm.
label_layers = [
    (int(label_row.channel), gpd.read_file(in_shp_dir.joinpath(label_row.shp)))
    for label_row in label_df.itertuples(index=False)
]

for img_filename in img_rvb_list :
    img_path = Path(img_filename)
    print(img_path.stem)
    out_mask = out_mask_dir.joinpath(f"{img_path.stem[:-4]}-MASK.tif")
    if out_mask.exists() and not overwrite_masks:
        continue

    with rasterio.open(img_path) as src_dataset:
        kwds = src_dataset.profile

    # Same grid and georeferencing as the image, but one uint8 band per label.
    # Drop the image's photometric tag and force lossless compression: a lossy
    # (e.g. JPEG) profile inherited from the image would alter the 0/1 label values.
    kwds.update(driver='GTiff', count=len(label_df), dtype=rasterio.uint8, compress='lzw')
    kwds.pop('photometric', None)

    with rasterio.open(out_mask, 'w', **kwds) as dst_dataset:
        out_transform = dst_dataset.transform
        xmin, ymin, xmax, ymax = dst_dataset.bounds
        out_shape = (dst_dataset.height, dst_dataset.width)
        for out_channel, label_gdf in label_layers:
            # keep only the features intersecting the image bounding box
            zone_gdf = label_gdf.cx[xmin:xmax, ymin:ymax]
            if len(zone_gdf) != 0 :
                shapes = ((geom, 1) for geom in zone_gdf.geometry)
                burned = features.rasterize(
                    shapes=shapes, out_shape=out_shape, fill=0,
                    transform=out_transform, dtype=rasterio.uint8)
            else :
                # no feature in this zone: empty band (rasterize is not called without shapes)
                burned = np.zeros(out_shape, dtype=np.uint8)
            dst_dataset.write(burned.astype(rasterio.uint8), out_channel + 1)

### test rasterization display


In [None]:
# Visual check of one rasterized mask: each pixel gets the LUT colour of its label band.
test_mask = gers_dataset_root_dir.joinpath("IMAGES_MASK", "FR_032_2019_U-17-MASK.tif")

with rasterio.open(test_mask) as mask_ds:
    mask_stack = mask_ds.read()

label_index = np.argmax(mask_stack, axis=0)
# argmax returns 0 where every band is 0, which would paint unlabeled pixels with the
# colour of class 0 ("batiment"): use the last LUT entry ("autre", black) for them instead
label_index[mask_stack.max(axis=0) == 0] = len(naf_lut) - 1
mask_array = np.take(naf_lut, label_index, axis=0)

print(mask_array.shape)
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(mask_array[1024:2048, 2048:3072])
ax.set_title(f"{test_mask.stem} - rows 1024:2048, cols 2048:3072")
ax.set_axis_off()

# Torch Dataset

A Torch dataset is a class responsible for loading the items of a deep learning dataset. It is used to iterate over the train and val parts of the dataset in the training loop, and over the test part when evaluating the model.


### utils to display patch/batch

In [None]:
# Display helpers used to check the datasets and the transforms.

def view_patch(data, transforms=None):
    """
    Plot the image and the mask of one sample side by side.

    data : sample dict with "image" and "mask" entries (as returned by the datasets)
    transforms : optional list of transforms, applied in order as t(image=..., mask=...),
        that turn the sample into arrays `imshow` can display

    Returns nothing; the figure is shown with plt.show().
    """
    if transforms is not None:
        for t in transforms:
            data = t(**data)

    _, (ax_tile, ax_gt) = plt.subplots(nrows=1, ncols=2, figsize=(10, 6))
    for ax, array, title in ((ax_tile, data["image"], 'Raster Tile'),
                             (ax_gt, data["mask"], 'Raster Gt')):
        ax.imshow(array)
        ax.set_title(title)
        ax.set_axis_off()

    plt.tight_layout()
    plt.show()

    
def view_batch(batch_data, transforms=None, size = None, ncols = None):
    """
    Plot the first samples of a batch: images on the top row, masks on the bottom row.

    batch_data : dict with batched "image" and "mask" entries (first axis = sample)
    transforms : optional list of transforms applied, in order, to each sample
        (called as t(image=..., mask=...)) before display
    size : number of samples to show; defaults to the whole batch
    ncols : alias of `size`, used only when `size` is None

    The number of samples shown is capped at the batch size.
    """
    raster_tiles = batch_data["image"]
    raster_gts = batch_data["mask"]

    batch_size = raster_tiles.shape[0]
    n_show = size if size is not None else ncols
    n_show = batch_size if n_show is None else min(n_show, batch_size)
    if n_show < 1:
        return

    # squeeze=False keeps a 2-D axes grid even when a single column is shown
    _, ax = plt.subplots(nrows=2, ncols=n_show, figsize=(20, 8), squeeze=False)

    for idx in range(n_show):
        data = {"image": raster_tiles[idx], "mask": raster_gts[idx]}
        for t in (transforms or []):
            data = t(**data)

        ax[0][idx].imshow(data["image"])
        ax[0][idx].set_axis_off()

        ax[1][idx].imshow(data["mask"])
        ax[1][idx].set_axis_off()

    plt.tight_layout()
    plt.show()
    
    
class ToRgbDisplay(BasicTransform):
    """
    Turn a sample into arrays that matplotlib can display.

    - image: expected channel-last (H, W, C); keeps the channels listed in
      `channels_display` (3 channels -> (H, W, 3), 1 channel -> a 2-D (H, W) band).
    - mask: expected channel-first (C, H, W); reduced to a label index with argmax
      over the channels, then mapped to colours through `lut` when one is given.
    """

    def __init__(self, color_compo : str = None, channels_display : List[int] = None, lut : np.ndarray = None) :
        """
        color_compo : name of the colour composition (stored only), "rgb" by default
        channels_display : list of channels to display, in r,g,b order if 3 channels,
            or the single channel to display; defaults to the first three channels
        lut : optional (n_classes, 3) colour table indexed by label
        """
        super(ToRgbDisplay, self).__init__()
        # the arguments used to be ignored when given, leaving these attributes unset
        self.color_compo = "rgb" if color_compo is None else color_compo
        self.channels_display = [0, 1, 2] if channels_display is None else list(channels_display)
        self.lut = lut

        if len(self.channels_display) != 1 and len(self.channels_display) !=  3 :
            raise ValueError("number of channel to display should be 1 or 3 ")

    def apply_to_img(self, img: np.ndarray) -> np.ndarray:
        if len(self.channels_display) == 1:
            # imshow accepts (H, W) but not (H, W, 1)
            return img[:, :, self.channels_display[0]]
        return img[:, :, self.channels_display]

    def apply_to_mask(self, mask: np.ndarray) -> np.ndarray:
        mask = np.argmax(mask, axis=0)
        if self.lut is not None:
            mask = np.take(self.lut, mask, axis=0)
        return mask

## simple patch dataset

In [None]:
class PatchDataset(Dataset):
    """
    Dataset over pre-cut patch files.

    Item i is read from image_files[i] and mask_files[i] with geoimage_simple_load,
    packed as {"image", "mask"} and passed through `transforms` in order.

    image_files / mask_files : paired sequences of raster paths (any iterable,
        e.g. a pandas Series; stored as lists and indexed by position)
    image_bands / mask_bands : band indices to read, given to the loader
    """

    def __init__(self, image_files, mask_files, transforms=None, image_bands=None,
                 mask_bands=None):
        # positional lists: a pandas Series is indexed by label, which breaks as soon
        # as the frame was filtered or reordered (index no longer 0..n-1)
        self.image_files = list(image_files)
        self.mask_files = list(mask_files)
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"got {len(self.image_files)} image files but {len(self.mask_files)} mask files")
        self.image_bands = image_bands
        self.mask_bands = mask_bands
        self.transforms = transforms

        self.load_array = geoimage_simple_load
        self.format_data = format_to_dict

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        img = self.load_array(self.image_files[index], band_indices=self.image_bands)
        msk = self.load_array(self.mask_files[index], band_indices=self.mask_bands)

        data = self.format_data(image=img, mask=msk)

        if self.transforms is not None:
            for t in self.transforms:
                data = t(**data)

        return data
    

In [None]:
# transforms for test: scale the image to [0, 1] floats, then convert both members to tensors
in_transforms = [
    ScaleImageToFloat(scale_factor=255, clip=True),
    ToTorchTensor(),
]

# Data: one PatchDataset per split of the set/fold-4 csv files, with shared settings
patch_kwargs = dict(transforms=in_transforms, image_bands=image_bands, mask_bands=mask_bands)
train_dataset = PatchDataset(image_files=df_train['img'], mask_files=df_train['msk'], **patch_kwargs)
val_dataset = PatchDataset(image_files=df_val['img'], mask_files=df_val['msk'], **patch_kwargs)
test_dataset = PatchDataset(image_files=df_test['img'], mask_files=df_test['msk'], **patch_kwargs)

In [None]:
test_idx = 170
test_data = train_dataset[test_idx]
print(f" keys : {test_data.keys()}")
img_type, msk_type = test_data['image'].type(), test_data['mask'].type()
print(f" image type : {img_type}, mask type : {msk_type}")

# back to displayable arrays: numpy, 0-255 bytes, channel-last image, LUT-coloured mask
display_transforms = [
    TensorToArray(),
    FloatImageToByte(clip=True),
    CHW_to_HWC(img_only=True),
    ToRgbDisplay(lut=naf_lut),
]
view_patch(test_data, transforms=display_transforms)

## dataset with chop/clip when reading image

Instead of first preprocessing the imagery into a DL dataset made of many small patch image files, we could also tile the imagery on the fly from the original large aerial/satellite images.

Why could we be interested in such functionality:

 * first, to test speed and data-management efficiency. A large dataset can be simpler to manage and compress than many small files, and reading can be just as efficient.
 * the decisions made when splitting can be changed on the fly, e.g. switching from 256 to 512 pixel patches.
 * it can act as a form of data augmentation, without duplicating data, by sampling tiles at different overlapping positions.
 * it can help to change/test sampling by class.

In [None]:
def get_nb_tile_from_img(img_shape: Tuple[int, int], tile_size: int) -> int :
    """
    Number of complete tile_size x tile_size tiles that fit in an image of shape
    (height, width); partial tiles on the right/bottom borders are not counted.
    """
    n_tile_rows, n_tile_cols = (dim // tile_size for dim in img_shape[:2])
    return n_tile_rows * n_tile_cols

def get_img_windows_list(img_shape: Tuple[int, int], tile_size: int) :
    """
    Cut an image of shape (height, width) into a grid of windows of at most
    tile_size x tile_size pixels.

    Returns a list of (col_off, row_off, width, height) tuples, row by row (top to
    bottom, left to right inside a row). Windows on the right/bottom borders are
    smaller when the image size is not a multiple of tile_size.
    """
    height, width = img_shape[0], img_shape[1]
    y_edges = list(range(0, height, tile_size)) + [height]
    x_edges = list(range(0, width, tile_size)) + [width]

    y_spans = list(zip(y_edges[:-1], y_edges[1:]))
    x_spans = list(zip(x_edges[:-1], x_edges[1:]))
    return [
        (x_start, y_start, x_stop - x_start, y_stop - y_start)
        for (y_start, y_stop) in y_spans
        for (x_start, x_stop) in x_spans
    ]

    
class LargeImageDataset(Dataset):
    """
    Dataset that cuts large geo images into tiles on the fly.

    Every image is split into a grid of `tile_size` windows (get_img_windows_list);
    item i reads the i-th (shuffled) window from the image and from its mask with
    geoimage_load_tile, packs them as {"image", "mask"} and applies `transforms` in order.

    image_files / mask_files : paired sequences of raster paths; the same pixel window
        is read from both, so mask i is expected to share the grid of image i
    tile_size : tile side in pixels
    image_bands / mask_bands : band indices to read, given to the loader
    """

    def __init__(
        self, image_files, mask_files, tile_size=512, transforms=None, image_bands=None, mask_bands=None):
        # positional lists: tiles refer to their image by position
        self.image_files = list(image_files)
        self.mask_files = list(mask_files)
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"got {len(self.image_files)} image files but {len(self.mask_files)} mask files")
        self.image_bands = image_bands
        self.tile_size = tile_size
        self.mask_bands = mask_bands
        self.transforms = transforms
        self.format_data = format_to_dict
        self.load_array = geoimage_load_tile

        # (image position, (col_off, row_off, width, height)) for every tile of every image
        self.tiles_list = []
        for img_id, img_path in enumerate(self.image_files):
            with rasterio.open(img_path) as img_ds:
                img_shape = img_ds.shape  # (height, width)
            self.tiles_list.extend(
                (img_id, self._fit_window(window, img_shape, self.tile_size))
                for window in get_img_windows_list(img_shape, self.tile_size))

        # shuffled once at construction (global `random` state)
        random.shuffle(self.tiles_list)

    @staticmethod
    def _fit_window(window, img_shape, tile_size):
        """
        Shift a border window back inside the image so that it is tile_size wide/high
        whenever the image is large enough. Border windows would otherwise be smaller
        than the others, and the default DataLoader collate cannot stack tiles of
        different sizes into one batch.
        """
        col_off, row_off, width, height = window
        img_height, img_width = img_shape[0], img_shape[1]
        if width < tile_size <= img_width:
            col_off, width = img_width - tile_size, tile_size
        if height < tile_size <= img_height:
            row_off, height = img_height - tile_size, tile_size
        return (col_off, row_off, width, height)

    def __len__(self):
        return len(self.tiles_list)

    def __getitem__(self, index):
        idx, window = self.tiles_list[index]

        img = self.load_array(
            self.image_files[idx],
            band_indices=self.image_bands,
            window=window)

        msk = self.load_array(
            self.mask_files[idx],
            band_indices=self.mask_bands,
            window=window)

        data = self.format_data(image=img, mask=msk)

        if self.transforms is not None:
            for t in self.transforms:
                data = t(**data)

        return data

In [None]:
image_files_train, mask_files_train = (gers_rvb_df[col].values for col in ("rvb_path", "msk_path"))
print(image_files_train[:5])
print(mask_files_train[:5])

# Raw tiles (no transforms) over every image. The rasterized masks have one band per
# label in label_df (14), one fewer than mask_bands, hence mask_bands[:-1].
train_dataset_tile = LargeImageDataset(
    image_files=image_files_train,
    mask_files=mask_files_train,
    tile_size=512,
    transforms=None,
    image_bands=image_bands,
    mask_bands=mask_bands[:-1],
)

In [None]:
# raw uint8 tile: move the image channels last and colour the mask with the NAF LUT
channel_last_transform, display_patch_transform = CHW_to_HWC(img_only=True), ToRgbDisplay(lut=naf_lut)

test_data = train_dataset_tile[142]
view_patch(test_data, transforms=[channel_last_transform, display_patch_transform])

# Lightning DataModule

We use a LightningDataModule to manage fold loading, dataset setup and some other utilities.

In [None]:
class TerriaDataModule(pl.LightningDataModule):
    """
    Lightning DataModule serving LargeImageDataset tiles for the train/val/test
    splits of `data_df`, selected by group (fold) values.

    data_df : one row per image, with image path, mask path and group columns
    transforms : list of transforms given to every dataset
    img_col / mask_col : names of the image / mask path columns
    img_bands / mask_bands : band indices read from images / masks
    group_col : column holding the group (fold) of each image
    set_config : {"train": [...], "val": [...], "test": [...]} group values of each
        split; defaults to the literal groups "train", "val" and "test"
    batch_size, tile_size, num_workers : dataloader batch size, tile side in pixels
        and number of workers (defaults are the previous hard-coded values)
    """

    def __init__(
        self, data_df, transforms= None , img_col = "images", img_bands = None, mask_col = "mask", mask_bands = None,
        group_col = "fold", set_config = None, batch_size = 4, tile_size = 512, num_workers = 2):

        super().__init__()
        self.batch_size = batch_size
        self.tile_size = tile_size
        self.num_workers = num_workers

        self.transform = transforms
        # use the constructor argument (this used to read the notebook-global `image_bands`)
        self.image_bands = img_bands
        self.mask_band = mask_bands

        self.data_df = data_df
        self.img_col = img_col
        self.mask_col = mask_col
        self.group_col = group_col
        if set_config is None:
            self.set_config = {
                "train": ["train"],
                "val": ["val"],
                "test": ["test"]}
        else:
            self.set_config = set_config

    def prepare_data(self):
        # rasterize from vector ?
        pass

    def _build_dataset(self, split: str):
        """Tile dataset over the images whose group is listed in set_config[split]."""
        split_df = self.data_df[self.data_df[self.group_col].isin(self.set_config[split])]
        return LargeImageDataset(
            image_files=split_df[self.img_col].values,
            mask_files=split_df[self.mask_col].values,
            tile_size=self.tile_size,
            transforms=self.transform,
            image_bands=self.image_bands,
            mask_bands=self.mask_band)

    def setup(self, stage: Optional[str] = None):
        # stage None builds every split
        if stage == "fit" or stage is None:
            self.train_dataset = self._build_dataset("train")
        if stage in ("fit", "validate") or stage is None:
            self.val_dataset = self._build_dataset("val")
        if stage == "test" or stage is None:
            # was stored in self.val_dataset, overwriting the validation set and
            # leaving test_dataloader without a dataset
            self.test_dataset = self._build_dataset("test")

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

In [None]:
# 5-fold rotation: in set k, fold k is the test fold, the next fold (cyclically) is the
# validation fold and the three following ones are the training folds.
gers_set_config = {
    f"set{k}": {
        "train": [(k + shift) % 5 + 1 for shift in (1, 2, 3)],
        "val": [k % 5 + 1],
        "test": [k],
    }
    for k in range(1, 6)
}

In [None]:
# transforms for test
in_transforms = [ScaleImageToFloat(scale_factor=255, clip=True), ToTorchTensor()]

# Training folds of set5, tiled on the fly. The rasterized masks have one band per
# label in label_df, one fewer than mask_bands, hence mask_bands[:-1].
train_config = gers_set_config["set5"]["train"]
train_df = gers_rvb_df.loc[gers_rvb_df["kfold"].isin(train_config)]
image_files_train, mask_files_train = train_df["rvb_path"].values, train_df["msk_path"].values
train_dataset = LargeImageDataset(
    image_files=image_files_train,
    mask_files=mask_files_train,
    tile_size=512,
    transforms=in_transforms,
    image_bands=image_bands,
    mask_bands=mask_bands[:-1],
)
train_dataloader = DataLoader(train_dataset, batch_size=4)

# back to displayable arrays: numpy, 0-255 bytes, channel-last image, LUT-coloured mask
display_transforms = [
    TensorToArray(),
    FloatImageToByte(clip=True),
    CHW_to_HWC(img_only=True),
    ToRgbDisplay(lut=naf_lut),
]

test_batch = next(iter(train_dataloader))
view_batch(test_batch, size=4, transforms=display_transforms)


In [None]:
# transforms for test
in_transforms = [ScaleImageToFloat(scale_factor=255, clip=True), ToTorchTensor()]

# DataModule for set5: images are grouped into splits by their "kfold" column
gers_data_module_set5 = TerriaDataModule(
    gers_rvb_df,
    transforms=in_transforms,
    img_col="rvb_path", img_bands=image_bands,
    mask_col="msk_path", mask_bands=mask_bands[:-1],
    group_col="kfold",
    set_config=gers_set_config["set5"],
)

gers_data_module_set5.setup()
train_dataloader = gers_data_module_set5.train_dataloader()

# back to displayable arrays: numpy, 0-255 bytes, channel-last image, LUT-coloured mask
display_transforms = [
    TensorToArray(),
    FloatImageToByte(clip=True),
    CHW_to_HWC(img_only=True),
    ToRgbDisplay(lut=naf_lut),
]

test_batch = next(iter(train_dataloader))
view_batch(test_batch, size=4, transforms=display_transforms)

## Test iteration dataloader

In [None]:
img_train_dataloader = gers_data_module_set5.train_dataloader()

# one full pass over the on-the-fly tiling loader, to time it
for sample in tqdm(
        img_train_dataloader, total=len(img_train_dataloader), desc="Large Image loader",
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]'):
    images, masks = sample['image'], sample['mask']

In [None]:
patch_train_dataset = PatchDataset(
    image_files=df_train['img'], mask_files=df_train['msk'],
    transforms=in_transforms, image_bands=image_bands, mask_bands=mask_bands)
patch_train_dataloader = DataLoader(patch_train_dataset, batch_size=4, num_workers=4)

# one full pass over the pre-cut patch loader, to compare with the tiling loader
for sample in tqdm(
        patch_train_dataloader, total=len(patch_train_dataloader), desc="patch Image loader",
        bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]'):
    images, masks = sample['image'], sample['mask']