# Introduction

**Mount drive**

In [None]:
# Mount Google Drive so the notebook can read the dataset and the saved models.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): the target directory is an empty string here — presumably the
# project folder path was removed; set it before running this cell.
%cd ''

**Install and import libraries**

In [None]:
pip install --quiet -U albumentations

In [None]:
pip install --quiet torchmetrics

In [None]:
pip install --quiet segmentation-models-pytorch

In [None]:
!pip install --quiet tqdm

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms.functional as tf
import albumentations as A
from PIL import Image
import cv2
from torchvision.transforms import ToTensor
from albumentations.pytorch import ToTensorV2
from albumentations.augmentations.transforms import Normalize
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import glob
from torchmetrics.classification import MulticlassJaccardIndex, MulticlassF1Score, MulticlassAccuracy
from torchmetrics import Accuracy
from tqdm.notebook import tqdm_notebook
import pickle
from sklearn import svm
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay

**Check GPU and setup device**

In [None]:
# Select the compute device: the first CUDA GPU when one is available, otherwise the CPU.
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    print('CUDA is available!  Training on GPU ...')
else:
    print('CUDA is not available.  Training on CPU ...')

# 'device' is used by the rest of the notebook to place the model and the batches
device = torch.device("cuda:0") if train_on_gpu else torch.device("cpu")
print(device)

## Define useful functions

**Functions for data preprocessing** \\
**1.** This function converts original 10 classes into 3 classes: background, building flooded, road flooded. Except the last two classes, everything becomes background. \\

**2.** This function can be used in the Dataset class to retrieve the mask corresponding to the image by using the image number in the filepath.

In [None]:
# Collapse the original mask classes into three classes:
#   0 = background        (every original class other than 1 and 3)
#   1 = building flooded  (original class 1)
#   2 = road flooded      (original class 3)
def three_classes(mask):
  is_building = mask == 1
  is_road = mask == 3
  # Keep only the two flooded classes, everything else becomes background ...
  mask_3 = np.where(is_building | is_road, mask, 0)
  # ... then renumber flooded road from 3 to 2 so the classes are contiguous
  return np.where(is_road, 2, mask_3)


#-----------------------------------------------------------
# This function reads only the image number from the filepath
def get_img_number(path):
  """Return the file name of ``path`` without its directory and extension.

  The extension is detected with ``os.path.splitext`` instead of cutting a
  fixed four characters, so '.jpeg' or '.tiff' files work as well as
  '.jpg' / '.png'.

  path: file path as a string, e.g. 'Flood_data/val/val-org-img/6336.jpg'
  returns: the name stem as a string, e.g. '6336'
  """
  import os  # local import: 'os' is not among the notebook's top-level imports

  filename = os.path.basename(path)
  number, _extension = os.path.splitext(filename)
  return number

**Functions for labeling** \\
**1.** This function assigns a binary label to the image (0 = not flooded, 1 = flooded), based on whether more or less than 25% of the image is flooded. \\
**2.** This function converts the binary label into string.

In [None]:
# Derive a binary image-level label from a segmentation mask.
# Label 1 = 'flooded': at least 25% of the pixels are non-zero (i.e. not background)
# Label 0 = 'non-flooded': fewer than 25% of the pixels are non-zero

def create_label(mask):
  tot_pixels = mask.flatten().shape[0]
  flood_pixels = np.count_nonzero(mask != 0)
  # Fraction of the image covered by any non-background class
  return 1 if flood_pixels / tot_pixels >= 0.25 else 0


#-------------------------------------------------------------------------------
# Translate the binary label into text: 0 -> 'non-flooded', anything else -> 'flooded'
def bin_to_label(label):
  return 'non-flooded' if label == 0 else 'flooded'

**Functions for data inspection** \\
This cell defines some useful variables for plotting.

In [None]:
# Class names, indexed by class id (used as tick labels in the plots)
classes = [
    'background',
    'building flooded',
    'road flooded',
]

# One colour per class id, wrapped in a discrete colormap for the mask plots
palette = ['yellow', 'cyan', 'red']
cmap_custom = mcolors.ListedColormap(palette)

The following functions can be used to inspect the data: \\
**1.** Creates a dictionary containing the following information about the original dataset: number of flooded and non-flooded samples, size and total number of pixels of every image, number of pixels per class for every image and the corresponding percentage of the total number of pixels, and the average percentage of background, flooded buildings and flooded roads in the dataset. \\

**2.** Creates a barplot to compare label distribution in different datasets.\\

**3.** Creates a histogram per class with the distribution of images per percentage of class pixels in the dataset.\\

**4.** Makes a histogram based on the mask's pixel distribution per class. \\

**5.** Plots one next to the other the original image, the mask with assigned label, the mask's histogram. \\




In [None]:
#-------------------------------------------------------------------------------
# 1. This function builds a dictionary containing some statistics about the dataset
def stats_original(dir, convert_masks = False):
  """Collect label and per-class pixel statistics for every mask under ``dir``.

  Masks are read from ``{dir}/*label*/*``.

  dir: root folder of the dataset split
  convert_masks: when True, every mask is passed through ``three_classes``
      before counting. Use it for folders whose masks still carry the
      original class ids; without it, 'road' counts pixels equal to 2 and the
      flooded / non-flooded label counts every non-zero class. The default
      (False) keeps the previous behaviour.

  Returns a dict with:
    'non_flooded', 'flooded'            -> number of images per label
    'img_size'                          -> (n_masks, 2) array of mask shapes
    'tot_pixels'                        -> pixels per mask
    'background', 'building', 'road'    -> pixels equal to 0 / 1 / 2 per mask
    'perc_*'                            -> the same counts as percentages
    'ave_*'                             -> mean percentage over all masks
  """
  mask_paths = glob.glob(f'{dir}/*label*/*')
  n_masks = len(mask_paths)

  ds_stats = dict()

  #Variables to count how many flooded and non flooded images are present
  ds_stats['non_flooded'] = 0
  ds_stats['flooded'] = 0

  #Arrays to store image sizes
  ds_stats['img_size'] = np.zeros((n_masks, 2))
  ds_stats['tot_pixels'] = np.zeros(n_masks)

  #Arrays to store pixel distribution per class
  ds_stats['background'] = np.zeros(n_masks)
  ds_stats['building'] = np.zeros(n_masks)
  ds_stats['road'] = np.zeros(n_masks)

  for idx, mask_path in enumerate(mask_paths):
    # The context manager closes the file handle as soon as the pixels are copied
    with Image.open(mask_path) as mask_file:
      mask = np.array(mask_file)
    if convert_masks:
      mask = three_classes(mask)

    #Masks are expected to be 2-D (one class id per pixel)
    height, width = mask.shape
    ds_stats['img_size'][idx] = (height, width)
    ds_stats['tot_pixels'][idx] = height * width

    #Count flood or no flood
    if create_label(mask) == 0:
      ds_stats['non_flooded'] += 1
    else:
      ds_stats['flooded'] += 1

    #Count how many pixels per class
    ds_stats['background'][idx] = np.count_nonzero(mask == 0)
    ds_stats['building'][idx] = np.count_nonzero(mask == 1)
    ds_stats['road'][idx] = np.count_nonzero(mask == 2)

  #Compute percentage of each class per image and the average over the dataset
  for name in ('background', 'building', 'road'):
    ds_stats[f'perc_{name}'] = ds_stats[name] * 100 / ds_stats['tot_pixels']
  for name in ('background', 'building', 'road'):
    ds_stats[f'ave_{name}'] = np.mean(ds_stats[f'perc_{name}'])

  return ds_stats


#-------------------------------------------------------------------------------
# 2. Grouped bar chart of flooded vs. not flooded images per dataset
# flood_labels = list with the number of flooded images for each dataset
# noflood_labels = list with the number of non-flooded images for each dataset
# dfs = dataset names, used as tick labels on the x axis
def barplot(flood_labels, noflood_labels, dfs = ['Train', 'Validation', 'Test']):
  bar_width = 0.25
  positions = np.arange(len(dfs))

  # Two bars per dataset, drawn side by side
  series = (
      (positions, flood_labels, 'b', 'Flooded'),
      (positions + bar_width, noflood_labels, 'pink', 'Not flooded'),
  )
  for x, heights, colour, name in series:
    plt.bar(x, heights, color = colour,
            width = bar_width, edgecolor = 'black',
            label = name)

  plt.xlabel("Dataset", fontsize = 12)
  plt.ylabel("Number of images")
  plt.title("Number of flooded and not flooded images per dataset")

  # Centre each tick between the two bars of its group
  plt.xticks(positions + bar_width/2, dfs)
  plt.legend()
  plt.grid(alpha = 0.5)

  plt.show()


#-------------------------------------------------------------------------------
# 3. Draw one histogram per class with the per-image percentage of that class
# dict = statistics dictionary holding the 'perc_*' arrays
# df_name = dataset name shown in the titles
def class_hist(dict, df_name):
  # (key in the statistics dictionary, class name shown in the title)
  panels = (
      ('perc_background', 'Background'),
      ('perc_building', 'Flooded buildings'),
      ('perc_road', 'Flooded roads'),
  )

  plt.subplots(1, 3, figsize = (15,5))

  for position, (key, class_title) in enumerate(panels, start = 1):
    plt.subplot(1, 3, position)
    plt.hist(dict[key], bins = 10, range = (-5,105))
    plt.title(f'% {class_title} in {df_name} set')
    plt.xlabel('%')
    plt.ylabel('Occurrences')
    plt.xlim(-5,105)
    plt.grid(alpha = 0.5)

  plt.tight_layout()
  plt.show()


#-------------------------------------------------------------------------------
# 4. Histogram of the number of pixels per class in a mask
# Draws on the current axes, shows the figure and returns what plt.hist returns
def mask_hist(mask):
  pixel_values = mask.flatten()
  # Three bins centred on the class ids 0, 1 and 2
  hist = plt.hist(pixel_values, bins = 3, range = (-0.5, 2.5))

  plt.xticks(np.arange(0,3,1), labels = classes, rotation = 90)
  plt.xlabel('Class')
  plt.ylabel('N. of pixels')
  plt.title('Pixel distribution per class')
  plt.grid(alpha = 0.5)

  plt.show()
  return hist


#-------------------------------------------------------------------------------
# 5. This function plots the image, the mask with corresponding label, and the histogram of the mask
def plot_item(img, mask, label):
  """Show an image, its mask (titled with the text label) and the mask histogram.

  img: image as a NumPy array or torch tensor, channel-last (H, W, C) or
       channel-first (C, H, W); float images outside [0, 1] (e.g. normalised
       ones) are stretched to [0, 1] for display only
  mask: 2-D class mask (NumPy array or torch tensor) with values 0, 1, 2
  label: binary label passed to bin_to_label for the title
  """
  # Accept torch tensors as well as NumPy arrays
  img = img.detach().cpu().numpy() if hasattr(img, 'detach') else np.asarray(img)
  mask = mask.detach().cpu().numpy() if hasattr(mask, 'detach') else np.asarray(mask)

  # imshow needs channel-last images, tensors made by ToTensorV2 are channel-first
  if img.ndim == 3 and img.shape[0] in (1, 3) and img.shape[-1] not in (1, 3, 4):
    img = np.transpose(img, (1, 2, 0))

  # imshow clips float RGB data to [0, 1]; rescale so the picture stays readable
  if np.issubdtype(img.dtype, np.floating) and img.size > 0:
    low, high = img.min(), img.max()
    if (low < 0 or high > 1) and high > low:
      img = (img - low) / (high - low)

  # Drop a leading singleton dimension left over from batching
  if mask.ndim == 3 and mask.shape[0] == 1:
    mask = mask[0]

  plt.subplots(1,3,figsize =(20,5))
  plt.subplot(1,3,1)
  plt.title('Original image')
  plt.imshow(img)

  plt.subplot(1,3,2)
  # Fixed colour limits so that class ids 0, 1, 2 always map to the same colours
  plt.imshow(mask, vmin = -0.5, vmax = 2.5, cmap = cmap_custom)
  plt.title(f'Mask: {bin_to_label(label)}')
  plt.colorbar(orientation = 'vertical', ticks = range(0,3))

  # The histogram is drawn inline (instead of calling mask_hist, which shows the
  # figure itself) so that the layout can be adjusted before the figure is shown
  plt.subplot(1,3,3)
  plt.hist(mask.flatten(), bins = 3, range = (-0.5, 2.5))
  plt.xticks(np.arange(0,3,1), labels = classes, rotation = 90)
  plt.xlabel('Class')
  plt.ylabel('N. of pixels')
  plt.title('Pixel distribution per class')
  plt.grid(alpha = 0.5)

  plt.tight_layout()
  plt.show()

**Functions for the models** \\
This function takes the output of the model as input and, for each pixel, selects the channel with the highest softmax probability; since the channel index corresponds to the class (0, 1, or 2), the result is the segmentation mask. The function returns the mask both as a NumPy array (moved to the CPU) and as a tensor.

In [None]:
# This function transforms the output of the model into one channel by taking, per pixel, the channel with the maximum value
def output_to_mask(output):
  """Turn the raw model output into a class-index segmentation mask.

  output: tensor of shape (1, C, H, W) or (C, H, W); a batch (B, C, H, W)
          with B > 1 is also accepted and gives masks of shape (B, H, W)
  returns: (segm, segm_tensor) — the mask as a NumPy array on the CPU and as
           an integer tensor on the same device as ``output``

  Softmax is monotonic, so the arg-max over the raw channel scores is the
  same as the arg-max over the softmax probabilities; no activation is
  applied here.
  """
  output = output.detach()

  # A single-image batch is reduced to (C, H, W) so the mask comes out as (H, W)
  if output.dim() == 4 and output.shape[0] == 1:
    output = output.squeeze(0)

  # Channel axis is third from the end for both (C, H, W) and (B, C, H, W)
  segm_tensor = output.argmax(dim = -3)
  segm = segm_tensor.cpu().numpy()
  return segm, segm_tensor

# Prepare dataset and dataloader

In [None]:
# Root folders of the three dataset splits (relative to the working directory)
train_dir, val_dir, test_dir = (
    f'Flood_data/{split}' for split in ('train_new', 'val', 'test')
)

## Dataset

**Define transformations** \\
Different transformations are defined for train set and validation and testing sets, since training set may include data augmentation.
Normalization is defined separately because it is only applied to the image and not to the mask.

In [None]:
#--Train set transformations--
# Applied jointly to image and mask by the Dataset class:
#   1. with probability 0.5, a random 3600x2700 crop (slight zoom in)
#   2. resize to 256x256 pixels
#   3. conversion to torch tensors
train_transform = A.Compose([
                            A.RandomCrop(width= 3600, height= 2700, p = 0.5),   # presumably 0.9x of a 4000x3000 frame — TODO confirm the source image size
                            A.Resize(256, 256),
                            ToTensorV2()
                            ],is_check_shapes=False)


#--Test and validation transformations--
# No augmentation: only resize to 256x256 pixels and conversion to torch tensors
test_transform = A.Compose([
                            A.Resize(256, 256),
                            ToTensorV2()
                            ],is_check_shapes=False)

#--Normalize--
# Channel-wise normalization, kept separate because the Dataset class applies it
# to the image only (never to the mask).
# NOTE(review): these are the usual ImageNet mean/std values, matching the
# ImageNet-pretrained encoder — confirm if the encoder weights change.
normalize = A.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225])

**Dataset class** \\
* This class receives the root directory of the images and labels, optional transformations, and a flag stating whether it is used for a validation set (this also includes the test set).
* When val = True, the mask is converted into three classes. Augmented training masks already contain three classes. \\
* It contains a function which returns the length of the dataset and a function which returns the image, mask and label at a certain index.

In [None]:
#--Define dataset class--
class Dataset(torch.utils.data.Dataset):
  """Image / mask / label dataset read from a split folder.

  Images are taken from ``{root_dir}/*org*/*`` and the mask of each image from
  ``{root_dir}/*label*/``, matched through the image number in the file name.

  root_dir: root folder of the split
  transforms: optional albumentations pipeline applied to image and mask;
              when given, the image is normalised first (the mask is not)
  val: when True, masks are converted with three_classes (validation / test)
  """

  # Initialize dataset class, default transforms = None, default val = False
  def __init__(self, root_dir, transforms = None, val = False):
    #Save all image paths and transforms
    self.root_dir = root_dir
    self.img_paths = glob.glob(f'{self.root_dir}/*org*/*')
    self.transform = transforms
    self.val = val

  def __len__(self):
    # Number of samples = number of images found under the root folder
    return len(self.img_paths)

  @staticmethod
  def _find_mask_path(root_dir, img_number):
    """Return the mask file whose name is exactly ``img_number`` or starts with ``img_number_``.

    A plain ``{img_number}*`` pattern is only a prefix match: image '1' would
    also match '10_lab.png' or '100_lab.png' and glob gives no ordering
    guarantee, so the candidates are filtered on the full name stem.
    Raises FileNotFoundError when no mask matches.
    """
    import os  # local import: 'os' is not among the notebook's top-level imports

    candidates = glob.glob(f'{root_dir}/*label*/{img_number}*')
    for path in sorted(candidates):
      stem = os.path.splitext(os.path.basename(path))[0]
      if stem == img_number or stem.startswith(f'{img_number}_'):
        return path
    raise FileNotFoundError(f'No mask found for image {img_number} in {root_dir}')

  def __getitem__(self, idx):
    """Return (image, mask, label) for the sample at position ``idx``."""
    img_path = self.img_paths[idx]

    #Open image at index idx; the file is closed once the pixels are copied
    with Image.open(img_path) as img_file:
      img = np.array(img_file)

    #Get the image number from image path to find the corresponding mask
    img_number = get_img_number(img_path)
    mask_path = self._find_mask_path(self.root_dir, img_number)
    with Image.open(mask_path) as mask_file:
      mask = np.array(mask_file)

    #Convert to three classes if validation or test set is used
    if self.val:
      mask = three_classes(mask)

    #If any transformation is passed to the Dataset, apply transformations on image and mask
    if self.transform is not None:
      #First normalize the image (not the mask)
      img = normalize(image = img)['image']
      #Apply transformations
      transf = self.transform(image = img, mask = mask)
      img = transf['image']
      mask = transf['mask']

    #Create the label of the original or the transformed mask
    label = create_label(mask)

    return img, mask, label

## Data augmentation

💀 **Dangerous function: do not set to 'True'** 💀 \\

* When ext_augment = True, this function will create a new training set with
augmented images.
* Seven augmentations are applied to images with label 1 (flooded) to balance the dataset in terms of flooded and not flooded images. The first transformations are geometric and are applied to both image and mask; the second set of transformations are color-related so they should be applied only to images. \\
* In this function, all training data is saved to a new folder, but only flooded images are augmented. \\

🕓 Estimated running time is one hour.

In [None]:
# Offline augmentation: copies every training sample to a new folder and adds
# seven augmented copies (3 geometric + 4 colour) of each 'flooded' sample.
ext_augment = False

if ext_augment:
  import os  # local import: 'os' is not among the notebook's top-level imports

  # Running number used as file name for the augmented samples.
  # NOTE(review): starting at 1 assumes no original image number falls in the
  # range used by the augmented files — otherwise files get overwritten. Verify.
  n = 1

  # Geometric transformations: applied to image AND mask
  transforms_all = A.Compose([
                          A.HorizontalFlip(p=1),
                          A.Compose([A.HorizontalFlip(p=1),A.VerticalFlip(p=1)]),
                          A.VerticalFlip(p=1)])

  # Colour transformations: applied to the image only, the mask is saved unchanged
  transforms_img = A.Compose([A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.5, p =1),
                              A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=20, val_shift_limit = 10, p =1),
                              A.GaussianBlur(blur_limit = (9,19), p =1),
                              A.Compose([A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.5, p =1),
                                         A.HueSaturationValue(hue_shift_limit=0, sat_shift_limit=20, val_shift_limit = 10, p =1),
                                         A.GaussianBlur(blur_limit = (9,19), p =1)])])

  #Source directory where training images to be augmented are found
  source_train_dir = 'Flood_data_old/train'
  to_augment = Dataset(source_train_dir)

  # NOTE(review): the output folder is derived from train_dir with a '_new'
  # suffix; if train_dir already ends in '_new' this writes to '..._new_new'.
  # Confirm the intended destination before running.
  out_img_dir = f'{train_dir}_new/train-org-img'
  out_mask_dir = f'{train_dir}_new/train-label-img'
  os.makedirs(out_img_dir, exist_ok = True)
  os.makedirs(out_mask_dir, exist_ok = True)

  #progress bar
  tqdm_images = tqdm_notebook(total = len(to_augment), desc ='Image')

  # Valid indices are 0 .. len-1 (the previous upper bound of len+1 raised IndexError)
  for idx in range(len(to_augment)):

    image, mask, label = to_augment[idx]

    img_number = get_img_number(to_augment.img_paths[idx])
    # Save original image and mask with original number
    (Image.fromarray(image)).save(f'{out_img_dir}/{img_number}.jpg')
    (Image.fromarray(mask)).save(f'{out_mask_dir}/{img_number}_lab.png')

    # When images are assigned to label = 'flooded' they are saved with img_number = n (dynamic)
    if label == 1:

      #Apply geometric transformations
      for transf in transforms_all:
        t = transf(image = image, mask = mask)
        img = tf.to_pil_image(t['image'])
        msk = tf.to_pil_image(t['mask'])
        img.save(f'{out_img_dir}/{n}.jpg')
        msk.save(f'{out_mask_dir}/{n}_lab.png')
        n += 1

      #Apply color transformations
      for transf in transforms_img:
        t = transf(image = image)
        img = tf.to_pil_image(t['image'])
        msk = tf.to_pil_image(mask)
        img.save(f'{out_img_dir}/{n}.jpg')
        msk.save(f'{out_mask_dir}/{n}_lab.png')
        n += 1

    tqdm_images.update(1)

  tqdm_images.close()

## Dataloaders

In [None]:
# Datasets used for training and evaluation (transformations applied on the fly).
# Validation and test masks are converted to three classes (val = True).
trainset = Dataset(train_dir, transforms=train_transform)
valset = Dataset(val_dir, transforms=test_transform, val=True)
testset = Dataset(test_dir, transforms=test_transform, val=True)

In [None]:
#--Create dataloaders--
# Training: shuffled mini-batches, the last incomplete batch is dropped
batch_size = 40

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                         drop_last=True, num_workers=0)

# Evaluation: one image at a time, in a fixed order
testloader = DataLoader(testset, batch_size=1, shuffle=False,
                        drop_last=False, num_workers=0)
valloader = DataLoader(valset, batch_size=1, shuffle=False,
                       drop_last=False, num_workers=0)

# Explore data

**Visualize sample image**

In [None]:
# Show the first 10 raw training samples (no transformations applied).
# islice stops the iteration after 10 items; the previous counter was never
# incremented, so the loop plotted the whole dataset.
from itertools import islice

for img, mask, label in islice(Dataset(train_dir), 10):
  plot_item(img, mask, label)

In [None]:
# Show the first 2 transformed test samples.
# islice stops the iteration after 2 items; the previous counter was never
# incremented, so the loop plotted the whole dataset.
from itertools import islice

for img, mask, label in islice(testset, 2):
  plot_item(img, mask, label)

**Compute dataset statistics** \\
* This cell computes a dictionary with statistics of each dataset or loads it if already computed.
* It is possible to compute the dictionary, save it and load it.
* Contents of the dictionary are described in section 'Define useful functions'


In [None]:
# Switches: compute the statistics from the masks, save them to Drive, load them from Drive
compute_stats = False
save = False
load = True

# Folder on Drive where the statistics dictionaries are pickled
_stats_dir = '/content/drive/MyDrive/ACT/Final project/Useful_items'

if compute_stats:
  # NOTE(review): stats_original reads the mask files as they are, without the
  # three_classes conversion that Dataset applies when val = True — verify that
  # the val / test masks on disk already use class ids 0, 1, 2.
  dict_val = stats_original(val_dir)
  dict_test = stats_original(test_dir)
  dict_train = stats_original(train_dir)

  if save:
    for _split, _stats in (('val', dict_val), ('test', dict_test), ('train', dict_train)):
      with open(f'{_stats_dir}/{_split}_dict.pkl', 'wb') as f:
        pickle.dump(_stats, f)

# Loading comes last, so when both switches are on the loaded dictionaries win
if load:
  with open(f'{_stats_dir}/val_dict.pkl', 'rb') as f:
    dict_val = pickle.load(f)

  with open(f'{_stats_dir}/test_dict.pkl', 'rb') as f:
    dict_test = pickle.load(f)

  with open(f'{_stats_dir}/train_dict.pkl', 'rb') as f:
    dict_train = pickle.load(f)

**Label distribution** \\
The following plot shows the distribution of labels in each dataset.

In [None]:
# Number of flooded / non-flooded images per dataset, in the order train, validation, test
_all_stats = (dict_train, dict_val, dict_test)
flooded = [stats['flooded'] for stats in _all_stats]
notflooded = [stats['non_flooded'] for stats in _all_stats]

# Grouped bar chart: one pair of bars per dataset
barplot(flooded, notflooded, ['Train', 'Validation', 'Test'])

# Same numbers as text
for _set_name, _n_flooded, _n_notflooded in zip(('train', 'validation', 'test'), flooded, notflooded):
  print('Number of flooded images in ' + _set_name + ' set: ', _n_flooded)
  print('Number of not flooded images in ' + _set_name + ' set: ', _n_notflooded)

**Image distribution per percentage of class pixels** \\
These plots show the distribution of images in each dataset as a function of percentage of pixels for each class: 1) background, 2) flooded buildings, 3) flooded roads.

In [None]:
# One row of three histograms (background, flooded buildings, flooded roads) per dataset
for _stats, _set_name in ((dict_train, 'training'), (dict_val, 'validation'), (dict_test, 'test')):
  class_hist(_stats, _set_name)

# UNet model

**Define UNet model** \\
This cell defines the segmentation model by using the smp library, where:
* the encoder is EfficientNet B2
* the encoder uses weights pretrained on ImageNet
* the number of output classes of the net is 3

In [None]:
# U-Net with an ImageNet-pretrained EfficientNet-B2 encoder and 3 output classes.
# No final activation: the network returns raw logits, which is what
# nn.CrossEntropyLoss expects (it applies log-softmax itself). A sigmoid in
# front of it squashes the scores into [0, 1] and weakens the training signal.
# The arg-max used to build the segmentation mask is unaffected, because
# sigmoid is monotonic, and the activation holds no weights, so existing
# checkpoints still load.
model = smp.Unet(
        encoder_name = "efficientnet-b2",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights = "imagenet",     # use `imagenet` pre-trained weights for encoder initialization
        encoder_depth = 5,
        in_channels = 3,                  # model input channels (1 for gray-scale images, 3 for RGB)
        classes = 3,                     # model output channels (number of classes in your dataset)
        activation = None                 # raw logits
        )


# Net is moved to device (can be cpu or gpu/cuda)
model.to(device)

# Train and evaluate model

**Upload model if already trained** \\
This cell allows to:
1. Import an already trained version of the model.
2. **Transfer learning**: freeze layers of the encoder to update only the decoder weights during training.

In [None]:
# Set 'load_model' to True to start from the weights of the previous best model
load_model = True

if load_model:
  checkpoint = torch.load("Models/FloodSegmentation3_dataaugm_best.pth", map_location = device)
  model.load_state_dict(checkpoint['model'])
  model.to(device)

#---------------------------------------------------------------------------------
# Transfer learning: with 'freeze_layers' = True only the parts of the model
# outside the encoder receive gradient updates
freeze_layers = True

if freeze_layers:
  for name, child in model.named_children():
    # Children whose name contains "encoder" are frozen, all others stay trainable
    trainable = "encoder" not in name
    for param in child.parameters():
      param.requires_grad = trainable

**Define function to validate the model** \\
This function is used to validate the model by computing IoU score and Dice score for segmentation image, and micro and macro accuracy for predicted label.
The predicted label is computed based on the number of flooded pixels in the image. \\

* When return_labels = True, ground truth and predicted labels are returned (use for testing).

* **N.B.** The dataloader passed to this function must have batch_size = 1.

In [None]:
# Create validation routine

def validate(net, valloader, device, return_labels = False):
    """Evaluate ``net`` on every item of ``valloader``.

    net: segmentation model
    valloader: dataloader yielding (image, mask, label) with batch_size = 1
    device: device on which the batches and the metrics are placed
    return_labels: when True, the ground truth and predicted image labels are
                   returned as two extra lists

    Returns (iou, dice, accuracy_micro, accuracy_macro) and, when
    return_labels is True, also (gt_labels, pred_labels). IoU and Dice compare
    the predicted mask with the ground truth mask over the 3 classes; the
    accuracies compare the binary label derived from the predicted mask with
    the ground truth label.
    """
    # Segmentation metrics over the three mask classes
    iou_score = MulticlassJaccardIndex(num_classes=3, average = 'macro').to(device)
    dice_score = MulticlassF1Score(num_classes=3, average = 'macro').to(device)
    # Image labels are binary (0 = non-flooded, 1 = flooded), hence two classes
    accuracy_micro = Accuracy(task = 'multiclass', num_classes = 2, average = 'micro').to(device)
    accuracy_macro = Accuracy(task = 'multiclass', num_classes = 2, average = 'macro').to(device)

    # Remember the current mode so it can be restored, then switch to eval mode
    was_training = net.training
    net.eval()

    #Lists to store ground truth labels and predicted labels
    gt_labels = []
    pred_labels = []

    # The total comes from the loader that is actually passed in (validation or test)
    tqdm_val = tqdm_notebook(total = len(valloader), desc='Validation progress', unit='Image')

    for inp, mask, label in valloader:

        # Move batch to device
        inp = inp.to(device)
        mask = mask.to(device)
        label = label.to(device)

        # Get output mask
        with torch.no_grad():
            outmask = net(inp)

        # Create a segmentation mask as numpy array and one as tensor
        segm_mask, segm_mask_tensor = output_to_mask(outmask)

        # Label derived from the predicted mask, as plain int and as a 1-element tensor
        segm_label = create_label(segm_mask)
        segm_label_tensor = torch.tensor(segm_label).unsqueeze(0).to(device)

        # Add ground truth label (taken from tensor) and segmentation result to lists
        gt_labels.append(label.item())
        pred_labels.append(segm_label)

        # Update metrics for each item
        target_mask = mask.squeeze(0)
        iou_score.update(segm_mask_tensor, target_mask)
        dice_score.update(segm_mask_tensor, target_mask)
        accuracy_micro.update(segm_label_tensor, label)
        accuracy_macro.update(segm_label_tensor, label)

        tqdm_val.update(1)

    tqdm_val.close()

    # Compute metrics accumulated over all items
    iou_score = iou_score.compute()
    dice_score = dice_score.compute()
    accuracy_micro = accuracy_micro.compute()
    accuracy_macro = accuracy_macro.compute()
    print(iou_score, dice_score, accuracy_micro, accuracy_macro)

    # Put the network back in the mode it had when the function was called
    net.train(was_training)

    if return_labels:
      return iou_score, dice_score, accuracy_micro, accuracy_macro, gt_labels, pred_labels

    return iou_score, dice_score, accuracy_micro, accuracy_macro

## Training

**Setup tensorboard**
* Define name of the experiment to store model training.
* Launch tensorboard

In [None]:
# Name of the experiment: used as the tensorboard log folder and as the prefix of the saved models
experiment_name = "FloodSegmentation3_dataaugm"

In [None]:
# NOTE(review): shutil is imported here but not used in the visible part of the notebook.
import shutil
# Use %load_ext the first time the extension is loaded; %reload_ext also works on re-runs
#%load_ext tensorboard
%reload_ext tensorboard
# Launch tensorboard on the log folder named after the experiment
%tensorboard --logdir={experiment_name}

**Launch training** \\
* When train = True, this cell trains the network. \\
* The loss function used for training is the Cross Entropy Loss. The model is trained and optimized only based on the segmentation prediction (**not on label**). \\
* Validation is run at the end of each epoch and the model with best IoU and Dice score is saved as best model.


In [None]:
# Training loop: cross-entropy on the segmentation output, validation after
# every epoch, checkpoints for the best and for the last model.
train = False

if train:
  import os  # local import: 'os' is not among the notebook's top-level imports

  # Loss is created once and reused for every batch
  cross_entropy = nn.CrossEntropyLoss().to(device)
  model = model.to(device)
  learning_rate = 0.001

  # define Adam optimizer
  optimizer = torch.optim.Adam(params=model.parameters(), lr= learning_rate)

  # define summary writer
  writer = SummaryWriter(experiment_name)

  # make sure the checkpoint folder exists before the first save
  os.makedirs('Models', exist_ok = True)

  # initialize iteration number
  n_iter = 0

  # define best validation value
  best_val_dice = 0
  best_val_iou = 0

  # number of epoch
  n_epoch = 10
  total_batches = len(trainloader)

  #Progress bar for epochs
  tqdm_epochs = tqdm_notebook(total=n_epoch, desc='Epochs')

  for cur_epoch in range(n_epoch):
      # Make sure the model is in training mode, whatever earlier cells left it in
      model.train()

      # plot current epoch
      writer.add_scalar("epoch", cur_epoch, n_iter)

      # Progress bar for batches
      tqdm_batches = tqdm_notebook(total= total_batches, desc=f'Epoch {cur_epoch}')

      for inp, mask, _ in trainloader:
          # move batch to device
          inp = inp.to(device)
          mask = mask.to(device)

          # reset gradients
          optimizer.zero_grad()
          # get output
          outmask = model(inp)

          # compute loss (the target must hold class indices as int64)
          loss = cross_entropy(outmask, mask.long())
          loss.backward()

          # update weights
          optimizer.step()

          #Plot
          writer.add_scalar("Loss",loss.item(), n_iter)

          #Update progress bar
          tqdm_batches.update(1)
          n_iter += 1

      tqdm_batches.close()

      # At the end of the epoch, validate the model with IoU, Dice, micro and macro accuracy
      cur_iou, cur_dice, cur_accuracy_micro, cur_accuracy_macro = validate(model, valloader, device)

      # plot validation scores
      writer.add_scalar("Dice", cur_dice, cur_epoch)
      writer.add_scalar("IoU", cur_iou, cur_epoch)
      writer.add_scalar("Micro Accuracy", cur_accuracy_micro, cur_epoch)
      writer.add_scalar("Macro Accuracy", cur_accuracy_macro, cur_epoch)

      state = {
          'model': model.state_dict(),
          'opt': optimizer.state_dict(),
          'epoch': cur_epoch
      }

      # The model is the best so far only when BOTH Dice and IoU are at least as good as before
      if cur_dice >= best_val_dice and cur_iou >= best_val_iou:
          best_val_dice = cur_dice
          best_val_iou = cur_iou
          torch.save(state, 'Models/' + experiment_name + '_best.pth')

      # save last model
      torch.save(state, 'Models/' + experiment_name + '_last.pth')

      tqdm_epochs.update(1)

  tqdm_epochs.close()

  # Flush pending events to disk
  writer.close()


# Test model on test set

**Compute metrics for test set** \\
The validation function is also used to validate the model with the test set. return_labels = True to compute confusion matrix.

In [None]:
# Evaluate on the test set; the image-level labels are kept for the confusion matrix
(iou, dice, acc_micro, acc_macro,
 gt_labels, pred_labels) = validate(model, testloader, device, return_labels = True)

In [None]:
# Compute and plot the confusion matrix of true vs. predicted image labels.
# 'labels' fixes the matrix to 2x2 (0 = non-flooded, 1 = flooded) even when one
# of the two labels never occurs, so it always matches 'display_labels'.
label_values = [0, 1]
cm = confusion_matrix(gt_labels, pred_labels, labels = label_values)

disp = ConfusionMatrixDisplay(confusion_matrix = cm,
                              display_labels = label_values)

disp.plot()

plt.show()

**Visualize test results** \\
This cell iterates over the test loader and computes the segmented mask and the predicted label. It also prints the following:
* Plot of the original image, original mask and label
* Plot of the original image, segmented mask and predicted label
\\

**N.B.** The function breaks at i == 30.

In [None]:
# Plot ground truth and prediction for the first 30 test images
from itertools import islice

# Mean / std passed to A.Normalize, needed to undo the normalisation for display
_norm_mean = np.array([0.485, 0.456, 0.406])
_norm_std = np.array([0.229, 0.224, 0.225])

model.eval()
with torch.no_grad():   # inference only, no gradients needed
  for img, msk, lab in islice(testloader, 30):
    # The input must live on the same device as the model
    outmask = model(img.to(device))
    segm_mask, _ = output_to_mask(outmask)

    # Channel-first normalised tensor -> channel-last image with values in [0, 1]
    shown_img = img.squeeze(0).cpu().numpy().transpose(1, 2, 0)
    shown_img = np.clip(shown_img * _norm_std + _norm_mean, 0, 1)

    # Original image, original mask and label
    plot_item(shown_img, msk.squeeze(0).cpu(), create_label(msk))
    plt.show()

    # Original image, segmented mask and predicted label
    plot_item(shown_img, segm_mask, create_label(segm_mask))
    plt.show()

# Encoded features extraction

**Place a hook at the end of the encoder**

In [None]:
# This class registers a forward hook on a specified module of the model
class Hook():
    """Keep the input and output of the most recent forward pass of ``module``.

    input / output are None until the module has run once; afterwards they
    hold the positional inputs (a tuple) and the output of the last call.
    Call close() to detach the hook from the module.
    """
    def __init__(self, module):
        # Defined up front so the attributes exist before the first forward pass
        self.input = None
        self.output = None
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Called by torch after every forward pass of the hooked module
        self.input = input
        self.output = output

    def close(self):
        # Remove the hook; the last captured input/output stay available
        self.hook.remove()

In [None]:
# Use this cell to load the model for feature extraction
load_model = True
if load_model:
  # Restore the weights stored under the 'model' key of the checkpoint
  checkpoint = torch.load("Models/FloodSegmentation3_dataaugm_best.pth", map_location = device)
  model.load_state_dict(checkpoint['model'])
  model.to(device)

# Feature extraction is inference: eval mode keeps layers such as dropout or
# batch normalisation (if the model has any) from behaving as in training
model.eval()

# Re-running this cell would otherwise stack a new forward hook on top of the
# old one, so detach the hook left over from a previous run first
previous_hook = getattr(model, 'hook', None)
if previous_hook is not None:
  previous_hook.hook.remove()

# Register a hook on the last layer of the encoder
model.hook = Hook(model.encoder._blocks[21])

**Test feature extraction** \\
**N.B.** It is necessary to run this cell to define 'features_shape'.

In [None]:
# Run a single training image through the model to find out the shape of the
# features captured by the hook
sample_image = trainset[100][0]

# No gradients are needed here; without no_grad the captured features would
# keep the whole autograd graph alive
with torch.no_grad():
  outmask = model(sample_image.to(device).unsqueeze(0))

features = model.hook.output
# Shape of the features of one item (batch dimension removed)
features_shape = np.array(features.squeeze(0).shape)
flat_output_size = features.flatten().shape[0]
print(features_shape)

**Load features** \\
If load_features = True, previously extracted features and the corresponding labels are loaded from disk instead of being recomputed.

In [None]:
load_features = True

if load_features:
  # NOTE: pickle can execute arbitrary code while loading, so only files that
  # were written by this notebook should be placed in this folder
  features_dir = '/content/drive/MyDrive/ACT/Final project/Features/'

  # The 'with' statement closes each file, no explicit close() is required
  with open(features_dir + 'train_labels.pkl', 'rb') as f:
    train_labels = pickle.load(f)

  with open(features_dir + 'train_features.pkl', 'rb') as f:
    train_features = pickle.load(f)

  with open(features_dir + 'test_labels.pkl', 'rb') as f:
    test_labels = pickle.load(f)

  with open(features_dir + 'test_features.pkl', 'rb') as f:
    test_features = pickle.load(f)

**Compute features** \\
The following cells can be run to extract the features at the end of the encoder (where the hook is placed) for the train set and the test set. The corresponding ground-truth labels are collected as well. Once features and labels have been extracted, they are automatically saved to disk.

In [None]:
if not load_features:
  # Number of full batches in the train set.
  # NOTE(review): the arrays below only have room for full batches, so this
  # assumes trainloader drops the last incomplete batch - confirm
  n_train_batches = len(trainset) // batch_size

  # Create an empty array to allocate all trainset labels
  # Create an empty array to allocate all trainset features
  train_labels = np.zeros((n_train_batches, batch_size))
  train_features = np.zeros((n_train_batches, batch_size, features_shape[0], features_shape[1], features_shape[2]))

  tqdm_batches = tqdm_notebook(total= n_train_batches, desc='Training set', unit='batch')

  # Inference only: eval mode and no autograd graph
  model.eval()
  with torch.no_grad():
    # The mask is not needed here; the label is only copied to numpy,
    # so neither has to be moved to the device
    for n, (inp, _, label) in enumerate(trainloader):
      # Allocate label
      train_labels[n] = label.cpu().numpy()

      # Run the model with input image and extract features where hook is placed
      model(inp.to(device))
      train_features[n] = (model.hook.output).cpu().numpy()

      tqdm_batches.update(1)

  tqdm_batches.close()

  #---------------------------------------------------------------------------------------------

  # The 'with' statement closes each file, no explicit close() is required
  with open('/content/drive/MyDrive/ACT/Final project/Features/train_labels.pkl', 'wb') as f:
    pickle.dump(train_labels, f)

  with open('/content/drive/MyDrive/ACT/Final project/Features/train_features.pkl', 'wb') as f:
    pickle.dump(train_features, f)

In [None]:
if not load_features:
  # One entry per test item: the arrays below are indexed once per iteration
  # of testloader.
  # NOTE(review): the shapes assume testloader yields one item per batch - confirm
  n_test_items = len(testset)

  # Create an empty array to allocate all testset labels
  # Create an empty array to allocate all testset features
  test_labels = np.zeros((n_test_items, 1))
  test_features = np.zeros((n_test_items, 1, features_shape[0], features_shape[1], features_shape[2]))

  tqdm_batches = tqdm_notebook(total= n_test_items, desc='Test set', unit='item')

  # Inference only: eval mode and no autograd graph
  model.eval()
  with torch.no_grad():
    # The mask is not needed here; the label is only copied to numpy,
    # so neither has to be moved to the device
    for n, (inp, _, label) in enumerate(testloader):
      # Allocate label
      test_labels[n] = label.cpu().numpy()

      # Run the model with input image and extract features where hook is placed
      model(inp.to(device))
      test_features[n] = (model.hook.output).cpu().numpy()

      tqdm_batches.update(1)

  tqdm_batches.close()

  #----------------------------------------------------------------------------------------
  # The 'with' statement closes each file, no explicit close() is required
  with open('/content/drive/MyDrive/ACT/Final project/Features/test_labels.pkl', 'wb') as f:
    pickle.dump(test_labels, f)

  with open('/content/drive/MyDrive/ACT/Final project/Features/test_features.pkl', 'wb') as f:
    pickle.dump(test_features, f)

# SVM for Classification

**Reshape features** \\
Before they can be fed to the SVM, training and test features and labels must be reshaped: the division into batches is removed and the features of each sample are flattened into a single vector.
* Features are reshaped to arrays of size = (len(dataset), len(flattened features array)).
* Labels are reshaped to arrays of size = (len(dataset))



In [None]:
# Length of one flattened feature vector
n_flat = features_shape[0] * features_shape[1] * features_shape[2]

# One row per sample for the features, one entry per sample for the labels
train_features_flat = np.reshape(train_features, (-1, n_flat))
test_features_flat = np.reshape(test_features, (-1, n_flat))

train_labels_flat = np.reshape(train_labels, -1)
test_labels_flat = np.reshape(test_labels, -1)

**SVM for classification** \\
The following cell defines an SVM for classification, trains it with the flattened train features and labels, and saves it to disk. \\
When compute_svm = False, an already trained SVM is loaded.

In [None]:
compute_svm = False

# Single location used both to save and to load the classifier
svc_path = '/content/drive/MyDrive/ACT/Final project/Features/svc2006.pkl'

if compute_svm:
  #Define SVM
  svc = svm.SVC(kernel = 'linear')

  #Train SVM
  svc.fit(train_features_flat, train_labels_flat)

  # The 'with' statement closes the file, no explicit close() is required
  with open(svc_path, 'wb') as f:
    pickle.dump(svc, f)

else:
  # NOTE: pickle can execute arbitrary code while loading, so only a file
  # written by this notebook should be loaded here
  with open(svc_path, 'rb') as f:
    svc = pickle.load(f)

**Predict labels** \\
Labels are predicted using the test set.

In [None]:
# Predict a label for every flattened test feature vector with the SVM.
# Note: this rebinds 'pred_labels', which earlier held the labels returned by validate().
pred_labels = svc.predict(test_features_flat)

**Compute metrics** \\
The following function computes the macro accuracy of the classification prediction by averaging the accuracy scores for each class.


In [None]:
def macro_accuracy_score(gt, pred):
  """Return the macro accuracy: the mean of the per-class accuracies.

  For every class that occurs in ``gt`` the fraction of its samples that was
  predicted correctly is computed; the result is the unweighted mean of these
  fractions. For binary labels 0/1 with both classes present this is the
  average of the class-0 and class-1 accuracies.

  Args:
    gt: ground-truth labels (array-like, any shape; it is flattened).
    pred: predicted labels, with the same number of elements as ``gt``.

  Returns:
    The macro accuracy as a Python float in [0, 1].

  Raises:
    ValueError: if ``gt`` is empty or ``gt`` and ``pred`` differ in size.
  """
  gt = np.asarray(gt).ravel()
  pred = np.asarray(pred).ravel()

  if gt.size == 0:
    raise ValueError('gt must contain at least one label')
  if gt.size != pred.size:
    raise ValueError('gt and pred must have the same number of elements')

  # Only classes that actually occur in gt are scored, so an absent class
  # never leads to an accuracy computed on an empty selection
  per_class_accuracy = [np.mean(pred[gt == cls] == cls) for cls in np.unique(gt)]

  return float(np.mean(per_class_accuracy))

Micro and macro accuracy and the confusion matrix are computed for the test predictions.

In [None]:
# Confusion matrix plus micro and macro accuracy of the SVM predictions on the test set
conf_matr = confusion_matrix(y_true=test_labels_flat, y_pred=pred_labels)
micro_accuracy = accuracy_score(y_true=test_labels_flat, y_pred=pred_labels)
macro_accuracy = macro_accuracy_score(test_labels_flat, pred_labels)

In [None]:
# Report both accuracies with four decimals
print(f'SVC micro accuracy: {micro_accuracy:.4f}')
print(f'SVC macro accuracy: {macro_accuracy:.4f}')

In [None]:
# Draw the confusion matrix, using the classes known to the SVM as tick labels
disp = ConfusionMatrixDisplay(conf_matr, display_labels=svc.classes_)
disp.plot()
plt.show()