# FloodNet Project

## Imports

In [6]:
from PIL import Image
import pandas as pd
import numpy as np
import os
import glob
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import albumentations as A
import cv2

## First Tries

## Dataset

In [2]:
# helper function for eyeballing augmentations
# adapted from https://albumentations.ai/docs/examples/example_kaggle_salt/
def visualize(image, mask, original_image=None, original_mask=None):
    """Plot an image/mask pair, optionally beside the untransformed originals.

    If both ``original_image`` and ``original_mask`` are None, a 2x1 grid is
    drawn (image on top, mask below, no titles). Otherwise a 2x2 grid is drawn:
    the originals fill the left column and the transformed image/mask fill
    the right column, each panel with a title.
    """
    title_size = 18

    if original_image is None and original_mask is None:
        _, axes = plt.subplots(2, 1, figsize=(8, 8))
        for axis, picture in zip(axes, (image, mask)):
            axis.imshow(picture)
        return

    _, axes = plt.subplots(2, 2, figsize=(8, 8))
    panels = (
        ((0, 0), original_image, 'Original image'),
        ((1, 0), original_mask, 'Original mask'),
        ((0, 1), image, 'Transformed image'),
        ((1, 1), mask, 'Transformed mask'),
    )
    for position, picture, title in panels:
        axes[position].imshow(picture)
        axes[position].set_title(title, fontsize=title_size)

In [51]:
class FloodData(Dataset):
    """FloodNet image/mask pairs for one split, listed in a split CSV.

    The CSV has no header; each row is ``image file name, mask file name,
    split name``. Only rows whose split name equals ``split`` are kept, so an
    unknown split yields an empty dataset.

    Args:
        transforms: Optional callable invoked as
            ``transforms(image=img, mask=labels)`` that returns a dict with
            ``'image'`` and ``'mask'`` entries (e.g. an ``albumentations.Compose``).
            When ``None``, samples are only cropped to at most 3000 rows x
            4000 columns.
        split: Split name to select from the third CSV column.
        csv_path: Path to the split CSV.
        data_root: Directory whose ``image/`` and ``mask/`` sub-folders hold
            the files named in the CSV.
    """

    # mapping between label class names and indices
    LABEL_CLASSES = {
        'background': 0,
        'building-flooded': 1,
        'building-non-flooded': 2,
        'road-flooded': 3,
        'road-non-flooded': 4,
        'water': 5,
        'tree': 6,
        'vehicle': 7,
        'pool': 8,
        'grass': 9,
    }

    def __init__(self, transforms=None, split='train',
                 csv_path="FloodNet_split_train_valid_test.csv",
                 data_root="Data"):
        self.transforms = transforms

        split_table = pd.read_csv(csv_path, sep=',', header=None,
                                  names=["Column1", "Column2", "Column3"])
        selected = split_table[split_table["Column3"] == split]

        image_dir = data_root + "/image/"
        mask_dir = data_root + "/mask/"
        # list of (image path, mask path) tuples, in CSV order
        self.data = [
            (image_dir + image_name, mask_dir + mask_name)
            for image_name, mask_name in zip(selected["Column1"], selected["Column2"])
        ]

    def __len__(self):
        """Return the number of image/mask pairs in this split."""
        return len(self.data)

    def __getitem__(self, x):
        """Load sample ``x`` and return it as an ``(image, mask)`` tensor pair.

        The image is returned as a float64 tensor in the channels-last layout
        produced by PIL / ``transforms`` (height, width, channels); the mask
        tensor keeps the dtype of the mask array.
        """
        img_path, mask_path = self.data[x]

        # Context managers close the underlying image files deterministically.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img)
        with Image.open(mask_path) as pil_mask:
            labels = np.array(pil_mask)

        if self.transforms is not None:
            transformed = self.transforms(image=img, mask=labels)
            # To inspect a sample, call here:
            # visualize(transformed["image"], transformed["mask"], img, labels)
            img = transformed['image']
            labels = transformed['mask']
        else:
            # Cap the spatial extent; presumably so untransformed samples share
            # one size and can be batched by the default collate -- confirm.
            img, labels = img[:3000, :4000, :], labels[:3000, :4000]

        return torch.tensor(img, dtype=torch.double), torch.tensor(labels)


In [52]:
# source of code in this cell:
# https://www.binarystudy.com/2021/04/how-to-calculate-mean-standard-deviation-images-pytorch.html

# Un-augmented training split, used only to estimate per-channel pixel
# statistics. NOTE(review): each sample is a float64 tensor of up to
# 3000x4000x3 values, so a batch of 16 can need several GB of RAM --
# lower batch_size if memory is tight.
train_not_transformed_set = FloodData(split='train')
train_not_transformed_loader = DataLoader(
    train_not_transformed_set,
    batch_size=16,
)

def batch_mean_and_sd(loader, *, progress=True):
    """Return the per-channel mean and standard deviation over all pixels.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs where ``images`` is
            a channels-last tensor of shape ``(batch, height, width, channels)``;
            labels are ignored. Batches may differ in height and width.
        progress: Wrap ``loader`` in a ``tqdm`` progress bar when true.

    Returns:
        ``(mean, std)`` float64 tensors of shape ``(channels,)``. ``std`` is the
        population standard deviation (divides by the pixel count).

    Raises:
        ValueError: If ``loader`` yields no pixels.
    """
    pixel_count = 0
    # Exact-zero running sums (never uninitialised memory); the int 0
    # broadcasts to a per-channel tensor on the first addition.
    channel_sum = 0
    channel_sq_sum = 0

    batches = tqdm(loader) if progress else loader
    for images, _ in batches:
        # Promote first: squaring integer (e.g. uint8) pixels in their own
        # dtype would overflow.
        images = images.to(torch.float64)
        b, h, w, _channels = images.shape  # batch, height, width, color
        channel_sum = channel_sum + torch.sum(images, dim=(0, 1, 2))
        channel_sq_sum = channel_sq_sum + torch.sum(torch.square(images), dim=(0, 1, 2))
        pixel_count += b * h * w

    if pixel_count == 0:
        raise ValueError("loader yielded no pixels; cannot compute mean/std")

    mean = channel_sum / pixel_count
    # E[x^2] - E[x]^2 can dip just below zero through rounding; clamp so the
    # square root never produces NaN.
    variance = torch.clamp(channel_sq_sum / pixel_count - mean ** 2, min=0)
    return mean, torch.sqrt(variance)

# Per-channel mean / std of the raw training pixels; both feed the
# A.Normalize steps of the augmentation pipelines.
mean, std = batch_mean_and_sd(loader=train_not_transformed_loader)
print(mean, std)

100%|██████████████████████████████████████████████████████████████████████████████████| 18/18 [04:22<00:00, 14.57s/it]


tensor([106.5385, 116.1601,  87.6059], dtype=torch.float64) tensor([53.1838, 49.5204, 53.5829], dtype=torch.float64)


In [78]:
# Training-time augmentation: random crop (height 1000-2500 px) resized to
# 713x713, random flips / shift-scale-rotate / 90-degree rotations and random
# brightness-contrast, followed by normalization.
#
# A.Normalize computes img = (img - mean * max_pixel_value) / (std * max_pixel_value);
# with max_pixel_value=1 that is plain per-channel standardization with the
# statistics measured on the raw training set.
transform_train = A.Compose(
    [
        A.RandomSizedCrop(min_max_height=[1000, 2500], height=713, width=713),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.2),
        A.ShiftScaleRotate(p=0.5),
        A.RandomRotate90(),
        A.Normalize(mean=mean, std=std, max_pixel_value=1),
    ]
)

# Pipeline shared by the validation and test splits.
# NOTE(review): RandomSizedCrop is stochastic, so validation/test metrics will
# vary from run to run; consider a deterministic crop/resize for evaluation.
transform_val = A.Compose(
    [
        A.RandomSizedCrop(min_max_height=[500, 2500], height=713, width=713),
        A.Normalize(mean=mean, std=std, max_pixel_value=1),
    ]
)

In [79]:
# One dataset per split; validation and test share the evaluation pipeline.
train_dataset = FloodData(split='train', transforms=transform_train)
val_dataset = FloodData(split='valid', transforms=transform_val)
test_dataset = FloodData(split='test', transforms=transform_val)

In [83]:
# Shuffle the training split so sample order changes every epoch
# (DataLoader defaults to shuffle=False).
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
# Peek at one transformed, normalized image batch (channels-last float64).
next(iter(train_loader))[0]

tensor([[[[ 0.8736, -0.0234, -0.0860],
          [ 0.7796, -0.1244, -0.1793],
          [ 0.7796, -0.1244, -0.1979],
          ...,
          [-0.6306, -1.1139, -0.9631],
          [-0.5742, -1.0533, -0.8885],
          [-0.5742, -1.0129, -0.8698]],

         [[ 0.0275, -1.0129, -0.9631],
          [ 0.6104, -0.3667, -0.3846],
          [ 0.7608, -0.2052, -0.2726],
          ...,
          [-0.6118, -1.0937, -0.9258],
          [-0.5930, -1.0735, -0.9071],
          [-0.6118, -1.0735, -0.9258]],

         [[-0.2546, -1.3562, -1.2617],
          [ 0.6292, -0.3667, -0.4032],
          [ 0.5728, -0.4475, -0.4779],
          ...,
          [-0.4990, -0.9725, -0.8138],
          [-0.5554, -1.0331, -0.8698],
          [-0.5366, -1.0331, -0.8698]],

         ...,

         [[ 0.3283, -0.4879, -0.6458],
          [ 0.4599, -0.3667, -0.5712],
          [ 0.4599, -0.3667, -0.5712],
          ...,
          [-0.0665, -0.9321, -0.9444],
          [-0.0853, -0.9523, -1.0004],
          [-0.0289, -0