<a href="https://colab.research.google.com/github/hilasha2/UDWTNet/blob/main/UDWTNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Initialization

Install PyTorch 1.6.0
---

In [None]:
# Install PyTorch 1.6.0 and torchvision 0.7.0 (the "+cu101" builds) from the PyTorch wheel index;
# "===" asks pip for exactly these version strings.
!pip install torch===1.6.0+cu101 torchvision===0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html

Install imgaug
---

In [None]:
# imaug - Library for augmentation
# https://imgaug.readthedocs.io/en/latest/source/installation.html
!pip uninstall imgaug             # remove old version of imagaug before installing the new version
!pip uninstall albumentations    # remove old version of albumentations before installing the new version
!pip install git+https://github.com/aleju/imgaug.git

Import libraries
---

In [None]:
# Import the necessary libraries
import torch                                      # main module that holds all the things you need for Tensor computation
import torch.nn as nn                             # fundamental building blocks of neural networks: models, all kinds of layers, activation functions, parameter classes, etc.
import torch.nn.functional as F                   # for binary_cross_entropy loss
import torch.optim as optim                       # optimizers like SGD, ADAM, etc

from torch.utils.data.dataset import Dataset      # for custom data-sets
import os                                         # for handling paths
from PIL import Image                             # for loading tif images
import natsort                                    # for sorting database by order

from imgaug import augmenters as iaa              # for image augmentation
import imgaug as ia                               # for image augmentation
import numpy as np                                # imgaug uses np images
from imgaug.augmentables.segmaps import SegmentationMapsOnImage  # classes dealing with segmentation maps.
import cv2                                        # image processing methods     
import skimage.morphology                         # image processing methods
from scipy import ndimage                         # image processing methods

import torchvision                                # consists popular datasets. 
import torchvision.transforms as transforms       # common image transformations
import random                                     # random methods
import math                                       # mathematical methods library

import pdb                                        # breakpoint for debugging: pdb.set_trace(),  n (next), s(step)
import matplotlib.pyplot as plt                   # for ploting images, figues
import matplotlib.lines 
from tabulate import tabulate                     # to print as table
import time                                       # to time the training
from google.colab import drive                    # to mount google drive to save the models

Mount Google Drive - to save/load models and load the dataset
---

In [None]:
# Mount Google Drive at /content/gdrive; the dataset and save paths used below live under it.
# https://medium.com/@ml_kid/how-to-save-our-model-to-google-drive-and-reuse-it-2c1028058cb2
drive.mount('/content/gdrive')

Set GPU device
---

In [None]:
# Debug switch: make CUDA kernel launches synchronous, so a CUDA runtime error is
# reported at the call that caused it (this may slow training; set to 0 for
# full-speed runs). CUDA only reads CUDA_LAUNCH_BLOCKING from the *environment*
# when it initializes, so the variable is exported here, before the first CUDA
# query below. Assigning a plain Python variable (as before) has no effect, and
# the setting cannot apply if CUDA was already initialized in this process.
CUDA_LAUNCH_BLOCKING = 1
if CUDA_LAUNCH_BLOCKING:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

print ('Available GPU devices ', torch.cuda.device_count())

# If a GPU is available, set `device` to it (otherwise the CPU);
# the rest of the notebook moves models and tensors to `device`.
is_cuda = torch.cuda.is_available()

if is_cuda:
    device = torch.device("cuda")
    print("GPU is set")
else:
    device = torch.device("cpu")

# Data Handling

Set dataset path / save path
---

In [None]:
# Root folders on the mounted Google Drive: the HeLa training data and the
# folder used for saving.
BaseSavePath = "/content/gdrive/My Drive/DeepLearningCourse/Project"
BaseDataPath = "/content/gdrive/My Drive/DeepLearningCourse/Project/Dataset/Training/DIC-C2DH-HeLa"
# Alternative Drive layout:
# BaseDataPath = "/content/gdrive/My Drive/Dataset/Training/DIC-C2DH-HeLa"
# BaseSavePath = "/content/gdrive/My Drive"

# Checkpoint file names of the three models
# (other UDWTNet variants used: 'UDWTNet_16.pt', 'UDWTNet_3.pt').
Unet_save_name, DWTNet_save_name, UDWTNet_save_name = 'UNet.pt', 'DWTNet.pt', 'UDWTNet.pt'

Load dataset / class CustomDataset
---

In [None]:
# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html
# https://discuss.pytorch.org/t/how-make-customised-dataset-for-semantic-segmentation/30881

class CustomDataset(Dataset):
    """Dataset of cell images paired with depth-label segmentation targets.

    Folder layout (derived from ``root_dir``):
      root_dir              - the input ``.tif`` images.
      <root_dir>_ST/SEG     - silver-truth (ST) instance label masks.
      <root_dir>_ST/DWTNET  - depth label maps derived from the SEG masks
                              (written by ``generate_segDepthmap`` when
                              ``generateSeg_depth`` is True).

    ``__getitem__(index)`` returns
    ``(imageAsFloatTensor, BinarySegAsLongTensor, segDepthAsLongTensor)``:
      imageAsFloatTensor    - float tensor (1, H, W): the image scaled to [0, 1]
                              and then z-normalized.
      BinarySegAsLongTensor - long tensor (H, W) derived from the depth labels:
                              0 where the depth label is <= 2, 1 where it is > 2.
      segDepthAsLongTensor  - long tensor (H, W) of depth ("energy level")
                              labels; serves as silver truth for DWTNet.

    Images and depth maps are paired by position in their naturally sorted
    ``.tif`` file lists, so both folders must hold matching files.
    """

    # Bin edges used by _depth_to_labels to quantize distance-transform values.
    DEPTH_BINS = np.array([0, 1, 2, 3, 4, 5, 7, 9, 12, 15, 19, 24, 30, 37, 45, 54, 1000])

    def __init__(self, root_dir, generateSeg_depth = False, transform = None):
        """
        root_dir          - folder with the input images.
        generateSeg_depth - if True, (re)generate the depth label maps from the
                            SEG masks before listing them.
        transform         - optional callable ``(image, seg) -> (image, seg)``
                            applied to the uint8 image and its depth label map.
        """
        self.root_dir          = root_dir
        self.seg_dir           = os.path.join(self.root_dir + "_ST", "SEG")     # silver truth
        self.seg_depth_dir     = os.path.join(self.root_dir + "_ST", "DWTNET")  # silver truth for DWTNet
        self.transform         = transform
        self.generateSeg_depth = generateSeg_depth

        self.image_paths = CustomDataset.sort_files(self.root_dir)
        self.seg_paths   = CustomDataset.sort_files(self.seg_dir)

        if self.generateSeg_depth:
            CustomDataset.generate_segDepthmap(self.seg_paths, self.seg_depth_dir)
        self.seg_depth_paths = CustomDataset.sort_files(self.seg_depth_dir)

    def __getitem__(self, index):
        # Read the image and its depth label map; the `with` blocks close the
        # files once the pixel data has been copied out.
        with Image.open(self.image_paths[index], 'r') as image:          # opens as [0-255]
            imageAsNPArray = CustomDataset.pil_image_to_np_array(image)
        with Image.open(self.seg_depth_paths[index], 'r') as seg_depth:
            segDepthAsNPArray = CustomDataset.pil_image_to_np_array(seg_depth)

        # Augment the image and its segmentation together (imgaug supports uint8 images).
        if self.transform:
            imageAsNPArray, segDepthAsNPArray = self.transform(imageAsNPArray, segDepthAsNPArray)

        # Binary mask: depth labels above 2 are foreground.
        BinarySegAsNPArray = (segDepthAsNPArray > 2).astype(np.uint8)

        BinarySegAsLongTensor = torch.LongTensor(BinarySegAsNPArray)
        segDepthAsLongTensor  = torch.LongTensor(segDepthAsNPArray.astype(np.uint8))

        # Scale to [0, 1] and add a leading channel dimension -> (1, H, W).
        # ascontiguousarray guards against views with negative strides (e.g. after flips),
        # which torch.from_numpy cannot wrap.
        imageAsFloatTensor = torch.from_numpy(np.ascontiguousarray(imageAsNPArray, dtype=np.float32))
        imageAsFloatTensor = (imageAsFloatTensor / 255).unsqueeze(0)

        # Z-normalization; the clamp avoids a division by zero (NaNs) on a constant image.
        std = imageAsFloatTensor.std().clamp_min(1e-8)
        imageAsFloatTensor = (imageAsFloatTensor - imageAsFloatTensor.mean()) / std

        return imageAsFloatTensor, BinarySegAsLongTensor, segDepthAsLongTensor

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def sort_files(dir):
        """Return the full paths of the ``.tif`` files in ``dir``, in natural filename order."""
        filelist = natsort.natsorted(os.listdir(dir), reverse = False)
        return [os.path.join(dir, file) for file in filelist if file.endswith(".tif")]

    @staticmethod
    def pil_image_to_np_array(pil_image):
        """Return a single-band PIL image as a (height, width) uint8 array.

        Pixel values are cast with ``astype('uint8')``: floats are truncated and
        values outside 0..255 wrap around.
        NOTE(review): instance ids above 255 in a 16-bit SEG mask would therefore wrap.
        """
        np_image = np.array(pil_image.getdata()).astype('uint8')
        # getdata() is row-major: `height` rows of `width` pixels.
        return np_image.reshape(pil_image.height, pil_image.width)

    @staticmethod
    def _instance_depth(segAsNPArray):
        """Sum of per-instance Euclidean distance transforms of a label mask.

        Every non-zero label is treated as one instance. Instances are disjoint, so
        each instance pixel gets its distance to the nearest pixel outside its own
        instance; background (label 0) stays 0. Returns a float array.
        """
        depth = np.zeros(segAsNPArray.shape, dtype = float)
        for label in np.unique(segAsNPArray):
            if label == 0:  # background
                continue
            depth += ndimage.distance_transform_edt(segAsNPArray == label)
        return depth

    @staticmethod
    def _depth_to_labels(depth):
        """Quantize distance values into depth labels using DEPTH_BINS.

        A value in (DEPTH_BINS[i], DEPTH_BINS[i + 1]] becomes max(i - 1, 0), i.e.
        (0, 2] -> 0, (2, 3] -> 1, (3, 4] -> 2, ..., (54, 1000] -> 14.
        0 (background) stays 0, and values above 1000 are left unchanged.
        Returns a float array.
        """
        depth = np.asarray(depth, dtype = float)
        labels = depth.copy()
        bins = CustomDataset.DEPTH_BINS
        for i in range(len(bins) - 1):
            # Bin membership is tested on the original distances, never on
            # already-relabelled values.
            in_bin = (depth > bins[i]) & (depth <= bins[i + 1])
            labels[in_bin] = max(i - 1, 0)
        return labels

    @staticmethod
    def generate_segDepthmap(seg_paths, seg_depth_dir):
        """Generate the DWTNet silver truth from the SEG masks.

        Each mask in ``seg_paths`` is converted to per-pixel instance depth
        (_instance_depth), quantized (_depth_to_labels) and saved as a float
        image with the same file name in ``seg_depth_dir`` (created if missing).
        """
        os.makedirs(seg_depth_dir, exist_ok = True)
        for segfile in seg_paths:
            with Image.open(segfile, 'r') as seg:
                segAsNPArray = CustomDataset.pil_image_to_np_array(seg)

            segDepthAsNPArray = CustomDataset._depth_to_labels(
                CustomDataset._instance_depth(segAsNPArray))

            save_pth = os.path.join(seg_depth_dir, os.path.basename(segfile))
            Image.fromarray(segDepthAsNPArray).save(save_pth)  # float array -> PIL image

Define show_dataset
---

In [None]:
"""
Show only num_images, every jump image. 
For example, if num_images = 5, jump = 10,
then we show only 5 images, every 10th image in the dataset.
"""

def show_dataset(dataset, num_images = 5, jump = 10):

  # Dataset must at least contain num_images * jump images.
  if len(dataset) < jump * num_images :
    print(str(len(dataset)) + 
           ' = len(dataset) < (jump * num_images) = ' 
           + str(jump * num_images))
    return;
  
  fig = plt.figure(figsize = (15,10))
  img_idx = 0

  for i in range(len(dataset)):
    img_idx = img_idx + jump 
    image, binary_seg, depth_seg  = dataset[img_idx]
    
    image = image.squeeze(0)

    plot_subplot(image, None, fig, num_images, 0, img_idx, 'gray', 1, i)
    plot_subplot(image, binary_seg, fig, num_images, 1, img_idx, 'gray', 0.7, i)
    plot_subplot(image, depth_seg, fig, num_images, 2, img_idx, 'jet', 0.4, i)

    if i == num_images - 1 :
      plt.show()
      break

def plot_subplot(image, seg, fig, num_images, row_idx, img_idx, cmp, alph, i):
  """Draw one panel of the 3 x num_images grid used by show_dataset.

  The panel goes to row ``row_idx`` and column ``i`` and is titled with the
  sample index ``img_idx``. ``image`` is drawn in grayscale. For every row
  except row 0, ``seg`` is overlaid with colormap ``cmp`` and alpha ``alph``.
  ``fig`` is accepted but not used; panels go to the current pyplot figure.
  """
  panel_number = row_idx * num_images + i + 1   # subplot numbers are 1-based, row-major
  ax = plt.subplot(3, num_images, panel_number)
  plt.tight_layout()
  ax.set_title('Sample #{}'.format(img_idx))
  ax.axis('off')
  plt.imshow(image, cmap='gray', aspect='equal')
  if row_idx == 0:
    return
  plt.imshow(seg, cmap=cmp, alpha=alph, aspect='equal')

View original dataset
---

In [None]:
# Browse a few training samples from sub-folder "01" without augmentation;
# the depth label maps are read from disk, not regenerated.
folder_data = BaseDataPath + "/01"
train_dataset = CustomDataset(root_dir=folder_data, generateSeg_depth=False)
show_dataset(dataset=train_dataset)

Define Class for Data Augmentation
---

In [None]:
# https://imgaug.readthedocs.io/en/latest/index.html
# https://imgaug.readthedocs.io/en/latest/source/examples_segmentation_maps.html#notebook
# https://colab.research.google.com/drive/109vu3F1LTzD1gdVV6cho9fKGx7lzbFll#scrollTo=r8PI3SKZqY8H

class ImgAugTransform:
  """Random imgaug augmentation applied jointly to an image and its segmentation map.

  Each call samples a fresh set of augmentation parameters. The same geometric
  transformation is applied to the image and to the segmentation map.
  Pixel-level effects (blur, sharpen) are applied to the image only.
  """
  def __init__(self):
    # Stochastic pipeline; augmenters run in random order.
    self.aug_seq = iaa.Sequential([
                                   # Small gaussian blur with random sigma between 0 and 0.5.
                                   iaa.Sometimes(0.7,iaa.GaussianBlur(sigma=(0, 0.5))),
                                   # random crops, max crop is 20% of the image size
                                   iaa.Sometimes(0.5,iaa.Crop(percent=(0, 0.2))),
                                   # sharpen the image
                                   iaa.Sometimes(0.5,iaa.Sharpen((0.0, 1.0))),
                                   iaa.Sometimes(0.01,
                                                  iaa.Affine(rotate=(-2, 2), # rotate by -2 to 2 degrees (affects segmaps)
                                                  shear = (-2,2), # shear by -2 to 2 degrees (affects segmaps)
                                                  translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)}, # translate_percent ("move" image on the x-/y-axis)
                                                  scale={"x": (0.8, 1.2), "y": (0.8, 1.2)}) # "zoom" in/out
                                                  ),
                                   # apply water effect (affects segmaps)
                                   iaa.Sometimes(0.2,iaa.ElasticTransformation(alpha=(0.0,10.0), sigma=(0.0,2.0), mode = 'nearest')),
                                   iaa.Sometimes(0.5,iaa.PerspectiveTransform(scale=(0.01, 0.15))),
                                   # Flip/mirror input images horizontally.
                                   iaa.Sometimes(0.3,iaa.Fliplr()),
                                   # Flip/mirror input images vertically.
                                   iaa.Sometimes(0.3,iaa.Flipud()),
                                   ], random_order=True)

  def __call__(self, img, seg):
    """Augment ``img`` and ``seg`` with one random draw of the pipeline; returns (img, seg) arrays."""
    # just in case, make sure it's an np.array
    img = np.array(img)
    seg = np.array(seg)

    # convert seg to class SegmentationMapsOnImage
    segmap = SegmentationMapsOnImage(seg, shape = seg.shape)

    # Freeze one random draw of the pipeline for this call only. A deterministic
    # copy created once in __init__ would replay the very same augmentation for
    # every sample and every epoch.
    aug_det = self.aug_seq.to_deterministic()
    aug_img, aug_segmap = aug_det(image = img, segmentation_maps = segmap)

    # convert augmented segmentation maps back to np array
    aug_seg = aug_segmap.get_arr()

    return aug_img, aug_seg

View augmented dataset
---

In [None]:
# Browse the same sub-folder "01", this time passed through the random augmentation.
folder_data = BaseDataPath + "/01"
data_transform = ImgAugTransform()
train_dataset = CustomDataset(root_dir=folder_data, transform=data_transform)
show_dataset(dataset=train_dataset)

# Models

UNet model
---

In [None]:
# sources:
# Understanding  Unet:
#  1. https://towardsdatascience.com/u-net-b229b32b4a71
#  2. https://towardsdatascience.com/unet-line-by-line-explanation-9b191c76baf5
# Kaggle Challenge source code for input of 512x512
# https://github.com/petrosgk/Kaggle-Carvana-Image-Masking-Challenge/blob/master/model/u_net.py
# https://github.com/ugent-korea/pytorch-unet-segmentation#model
# loss - Focal and Dice Loss
# 1.  https://www.kaggle.com/iafoss/unet34-dice-0-87
# 2.  https://becominghuman.ai/investigating-focal-and-dice-loss-for-the-kaggle-2018-data-science-bowl-65fb9af4f36c

class UNet(nn.Module):
  """U-Net with six 2x down-sampling stages and a 1024-channel bottleneck.

  forward(x) returns ``(logits, labels)``:
    logits - ``out_channel`` score maps at the input resolution (no activation),
    labels - per-pixel argmax over the channels of the softmax of ``logits``.
  The input height and width must be divisible by 64. There are six 2x
  max-pools, and each is undone by an exact 2x transposed convolution before
  the skip concatenation.
  With ``doPrint`` True, the size of every intermediate tensor is printed.
  """

  def _conv_bn_act(self, in_channels, out_channels):
    # One 3x3 convolution (padding 1 keeps H and W), then batch norm and LeakyReLU.
    return [
        torch.nn.Conv2d(in_channels = in_channels, out_channels = out_channels,
                        kernel_size = 3, padding = 1),
        torch.nn.BatchNorm2d(out_channels),
        torch.nn.LeakyReLU(),
    ]

  def contracting_block(self, in_channels, out_channels):
    """Encoder stage: two conv-BN-LeakyReLU layers (in -> out -> out)."""
    layers = self._conv_bn_act(in_channels, out_channels)
    layers += self._conv_bn_act(out_channels, out_channels)
    return torch.nn.Sequential(*layers)

  def expansive_block(self, in_channels, out_channels):
    """Decoder stage: three conv-BN-LeakyReLU layers (in -> out -> out -> out)."""
    layers = self._conv_bn_act(in_channels, out_channels)
    for _ in range(2):
      layers += self._conv_bn_act(out_channels, out_channels)
    return torch.nn.Sequential(*layers)

  def upsampling_block(self, in_channels, out_channels):
    """Stride-2 3x3 transposed convolution; output H and W are exactly twice the input."""
    return nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels,
                              kernel_size = 3, stride = 2,
                              padding = 1, output_padding = 1)

  def final_block(self, in_channels, out_channels):
    """1x1 convolution mapping features to per-class logits (no activation)."""
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 1))

  def __init__(self, in_channel, out_channel,doPrint):
    super(UNet, self).__init__()

    self.doPrint = doPrint

    # Encoder (modules are created in this fixed order; it determines parameter
    # initialization and state_dict layout).
    self.maxpool  = nn.MaxPool2d(kernel_size = 2)
    self.encode0a = self.contracting_block(in_channels = in_channel, out_channels = 16)
    self.encode0  = self.contracting_block(in_channels = 16,  out_channels = 32)
    self.encode1  = self.contracting_block(in_channels = 32,  out_channels = 64)
    self.encode2  = self.contracting_block(in_channels = 64,  out_channels = 128)
    self.encode3  = self.contracting_block(in_channels = 128, out_channels = 256)
    self.encode4  = self.contracting_block(in_channels = 256, out_channels = 512)

    # Bottleneck
    self.bottleneck = self.contracting_block(in_channels = 512, out_channels = 1024)

    # Decoder: each upsample halves the channels; each decode block takes the
    # concatenation with the matching encoder output (hence 2x channels).
    self.upsample4  = self.upsampling_block(in_channels = 1024, out_channels = 512)
    self.decode4    = self.expansive_block(in_channels = 1024, out_channels = 512)
    self.upsample3  = self.upsampling_block(in_channels = 512, out_channels = 256)
    self.decode3    = self.expansive_block(in_channels = 512, out_channels = 256)
    self.upsample2  = self.upsampling_block(in_channels = 256, out_channels = 128)
    self.decode2    = self.expansive_block(in_channels = 256, out_channels = 128)
    self.upsample1  = self.upsampling_block(in_channels = 128, out_channels = 64)
    self.decode1    = self.expansive_block(in_channels = 128, out_channels = 64)
    self.upsample0  = self.upsampling_block(in_channels = 64, out_channels = 32)
    self.decode0    = self.expansive_block(in_channels = 64, out_channels = 32)
    self.upsample0a = self.upsampling_block(in_channels = 32, out_channels = 16)
    self.decode0a   = self.expansive_block(in_channels = 32, out_channels = 16)

    # Final layer
    self.logits = self.final_block(in_channels = 16, out_channels = out_channel)
    self.softmax_layer = torch.nn.Softmax(dim = 1)

  def _report(self, name, tensor):
    # Debug trace: print "<name>.size() = ..." when doPrint is enabled.
    if self.doPrint:
      print(name + ".size() = " + str(tensor.size()))

  def _banner(self, title):
    # Debug trace: print a bold section title when doPrint is enabled.
    if self.doPrint:
      print("\n" + '\033[1m' + title + '\033[0m')

  def forward(self, x):
    self._report("x", x)

    # Encode: keep each stage's pre-pool output for the skip connections.
    self._banner(" Encode ")
    encoders = (("down0a", self.encode0a), ("down0", self.encode0),
                ("down1", self.encode1), ("down2", self.encode2),
                ("down3", self.encode3), ("down4", self.encode4))
    skips = []
    out = x
    for name, encoder in encoders:
      out = encoder(out)
      self._report(name, out)
      skips.append(out)
      out = self.maxpool(out)
      self._report(name + "_pool", out)

    # Center
    self._banner(" Center ")
    out = self.bottleneck(out)
    self._report("center ", out)

    # Decode: upsample, concatenate with the matching skip (deepest first), refine.
    self._banner(" Decode ")
    decoders = (("4", self.upsample4, self.decode4),
                ("3", self.upsample3, self.decode3),
                ("2", self.upsample2, self.decode2),
                ("1", self.upsample1, self.decode1),
                ("0", self.upsample0, self.decode0),
                ("0a", self.upsample0a, self.decode0a))
    for (tag, upsample, decode), skip in zip(decoders, reversed(skips)):
      out = upsample(out)
      self._report("decode_up" + tag, out)
      out = torch.cat((skip, out), dim = 1)
      self._report("decode_cat" + tag, out)
      out = decode(out)
      self._report("decode_block" + tag, out)

    # Final layer
    self._banner(" Final Layer ")
    logits = self.logits(out)
    self._report("logits", logits)
    probabilities = self.softmax_layer(logits)
    self._report("softmax_layer", probabilities)
    labels = torch.argmax(probabilities, dim = 1)
    self._report("labels", labels)

    return logits, labels

Debug class UNet
---

In [None]:
# Smoke test for UNet: run one random batch (batch_size = 2) through the network
# with doPrint=True so every intermediate size is printed, then free the memory.
batch_size = 2
labels = 200  # requested number of output channels (segments); rebound to the predicted labels below

z = torch.randn(batch_size, 1, 512, 512).to(device)  # random single-channel 512x512 inputs on `device`
unet = UNet(in_channel=1, out_channel=labels, doPrint=True).to(device)
logits, labels = unet(z)

del z, logits, labels, unet


DWTNet Model
---

In [None]:
# Sources:
# https://arxiv.org/pdf/1611.08303.pdf
# https://github.com/min2209/dwt
# About weights initialization:
# https://stackoverflow.com/questions/49433936/how-to-initialize-weights-in-pytorch
# https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79

class DWTNet(nn.Module):
  """Deep Watershed Transform network (https://arxiv.org/pdf/1611.08303.pdf).

  Two sub-networks run back to back:
    * Direction Net: (image, segmentation map) -> 2-channel direction field,
      L2-normalised over the channel dimension.
    * Watershed Transform Net: direction field -> `num_labels` per-pixel logits.

  Args:
    doPrint:    print the size of every intermediate tensor in forward().
    doPrintMax: print the (min, max) value of every intermediate tensor.
    num_labels: number of output classes of the Watershed Transform Net.

  Channel widths are reduced; the `# 64`-style comments next to the conv specs
  record the author's larger reference widths.
  """

  def __init__(self, doPrint = False, doPrintMax = False, num_labels = 16):
    super(DWTNet, self).__init__()

    self.doPrint    = doPrint
    self.doPrintMax = doPrintMax
    self.num_labels = num_labels

    # Conv specs below are (in_channels, out_channels, kernel_size, padding).
    units = self._conv_units

    ########  Direction Net ######

    self.direction_layer1 = torch.nn.Sequential(
        *units([(2, 16, 3, 1),                # out 64
                (16, 16, 3, 1)]),             # 64 -> 64
        nn.MaxPool2d(kernel_size = (2, 2), stride = (2, 2)))

    self.direction_layer2 = torch.nn.Sequential(
        *units([(16, 32, 3, 1),               # 64 -> 128
                (32, 32, 3, 1)]),             # 128 -> 128
        nn.MaxPool2d(kernel_size = (2, 2), stride = (2, 2)))

    self.direction_layer3 = torch.nn.Sequential(
        *units([(32, 64, 3, 1)]               # 128 -> 256
               + [(64, 64, 3, 1)] * 2))       # 256 -> 256

    self.direction_layer4 = torch.nn.Sequential(
        *units([(64, 128, 3, 1)]              # 256 -> 512
               + [(128, 128, 3, 1)] * 2))     # 512 -> 512

    self.direction_layer5 = torch.nn.Sequential(
        *units([(128, 128, 3, 1)] * 3))       # 512 -> 512

    self.direction_fc5 = torch.nn.Sequential(
        *units([(128, 128, 5, 2),             # 512 -> 512
                (128, 128, 1, 0),             # 512 -> 512
                (128, 64, 1, 0)]))            # 512 -> 256

    self.direction_fc4 = torch.nn.Sequential(
        *units([(128, 128, 5, 2),             # 512 -> 512
                (128, 128, 1, 0),             # 512 -> 512
                (128, 64, 1, 0)]))            # 512 -> 256

    self.direction_fc3 = torch.nn.Sequential(
        *units([(64, 64, 5, 2),               # 256 -> 256
                (64, 64, 1, 0),               # 256 -> 256
                (64, 64, 1, 0)]))             # 256 -> 256

    # ConvTranspose2d output size is (n - 1) * stride - 2 * padding + kernel_size,
    # i.e. exactly 4n for upscore5 / direction / outputData and 2n for upscore4.
    self.direction_upscore5 = nn.ConvTranspose2d(in_channels = 64,   # 256
                                                 out_channels = 64,  # 256
                                                 kernel_size = 8,
                                                 stride = 4,
                                                 padding = 2)

    self.direction_upscore4 = nn.ConvTranspose2d(in_channels = 64,   # 256
                                                 out_channels = 64,  # 256
                                                 kernel_size = 4,
                                                 stride = 2,
                                                 padding = 1)

    # Fuses fc3 with the two upsampled branches (3 x 64 channels) into 2 channels;
    # the last conv has no activation/BN.
    self.direction_fuse3 = torch.nn.Sequential(
        *units([(64 * 3, 128, 1, 0),          # 3*256 -> 512
                (128, 128, 1, 0)]),           # 512 -> 512
        torch.nn.Conv2d(in_channels = 128,    # 512
                        out_channels = 2,
                        kernel_size = 1))

    self.direction = nn.ConvTranspose2d(in_channels = 2,
                                        out_channels = 2,
                                        kernel_size = 8,
                                        stride = 4,
                                        padding = 2)

    self.avgPool2d = nn.AvgPool2d(kernel_size = (2, 2), stride = (2, 2))

    ########  END Direction Net ######

    ########  Watershed Transform Net ######

    self.watershed_layer1 = torch.nn.Sequential(
        *units([(2, 16, 5, 2),                # out 64
                (16, 32, 5, 2)]))             # 64 -> 128

    self.watershed_layer2 = torch.nn.Sequential(
        *units([(32, 32, 5, 2)] * 4))         # 128 -> 128

    self.watershed_fc1 = torch.nn.Sequential(
        *units([(32, 32, 1, 0)]),             # 128 -> 128
        torch.nn.Dropout(p = 0.7))

    self.watershed_fc2 = torch.nn.Sequential(
        *units([(32, self.num_labels, 1, 0)]),  # 128 -> num_labels
        torch.nn.Dropout(p = 0.7))

    self.outputData = nn.ConvTranspose2d(in_channels = self.num_labels,
                                         out_channels = self.num_labels,
                                         kernel_size = 8,
                                         stride = 4,
                                         padding = 2)

    self.softmax_layer = torch.nn.Softmax(dim = 1)

  @staticmethod
  def _conv_units(specs):
    """Return [Conv2d, LeakyReLU(0.01), BatchNorm2d] * len(specs) as a flat list.

    `specs` holds (in_channels, out_channels, kernel_size, padding) tuples.
    Within a Sequential the modules land at indices conv=0, act=1, bn=2, conv=3, ...
    That keeps state_dict keys such as 'direction_layer1.2.running_mean' stable
    for saved checkpoints.
    """
    layers = []
    for in_channels, out_channels, kernel_size, padding in specs:
      layers += [torch.nn.Conv2d(in_channels = in_channels,
                                 out_channels = out_channels,
                                 kernel_size = kernel_size,
                                 padding = padding),
                 torch.nn.LeakyReLU(negative_slope = 0.01),
                 torch.nn.BatchNorm2d(num_features = out_channels)]
    return layers

  def init_weights(self, m):
    """Re-initialise `m` if it is an nn.Conv2d: Xavier-normal weights, bias = 0.5.

    Other module types are left untouched. The method is not called inside this
    class; presumably it is meant for `model.apply(model.init_weights)`.
    """
    if isinstance(m, nn.Conv2d):
      torch.nn.init.xavier_normal_(m.weight, gain = 1.0)
      if m.bias is not None:
        nn.init.constant_(m.bias, 0.5)

  def _log_tensor(self, name, tensor, size_prefix = " "):
    """Print `tensor`'s size (if doPrint) and its (min, max) (if doPrintMax)."""
    if self.doPrint:
      print(size_prefix + name + ".size() = " + str(tensor.size()))
    if self.doPrintMax:
      print(name + " (min,max) = (" + str(tensor.min().item()) + "," + str(tensor.max().item()) + ")")

  def forward(self, x, seg):
    """Run the Direction Net and then the Watershed Transform Net.

    Args:
      x:   input images, (batch, 1, H, W), e.g. (batch_size x 1 x 512 x 512).
      seg: segmentation map, (batch, H, W). It may be an integer label map (an
           argmax output); it is cast to x.dtype before concatenation.

    Returns:
      logits: (batch, num_labels, H, W) output of the final ConvTranspose2d.
      labels: (batch, H, W) argmax over the label dimension.

    H and W should be multiples of 16 so that the upsampled skip branches line
    up with direction_fc3.
    """
    if self.doPrint:
      print("x.size() = " + str(x.size()))
      print("seg.size() = " + str(seg.size()))
    if self.doPrintMax:
      print("x (min,max) = (" + str(x.min().item()) + "," + str(x.max().item()) + ")")

    # Image and segmentation map become the 2 input channels. seg is cast
    # explicitly because integer label maps can't feed the float conv weights.
    x = torch.cat((x, seg.unsqueeze(1).to(x.dtype)), dim = 1)
    self._log_tensor("concat x", x, size_prefix = "")

    if self.doPrint: print("\n" + '\033[1m' + " Direction Net " + '\033[0m')

    direction_layer1 = self.direction_layer1(x)                               # H/2
    self._log_tensor("direction_layer1", direction_layer1)
    direction_layer2 = self.direction_layer2(direction_layer1)                # H/4
    self._log_tensor("direction_layer2", direction_layer2)
    direction_layer3 = self.direction_layer3(direction_layer2)                # H/4
    self._log_tensor("direction_layer3", direction_layer3)
    direction_layer3_avgpool = self.avgPool2d(direction_layer3)               # H/8
    self._log_tensor("direction_layer3_avgpool", direction_layer3_avgpool)
    direction_layer4 = self.direction_layer4(direction_layer3_avgpool)        # H/8
    self._log_tensor("direction_layer4", direction_layer4)
    direction_layer4_avgpool = self.avgPool2d(direction_layer4)               # H/16
    self._log_tensor("direction_layer4_avgpool", direction_layer4_avgpool)
    direction_layer5 = self.direction_layer5(direction_layer4_avgpool)        # H/16
    self._log_tensor("direction_layer5", direction_layer5)

    # Score branches from the H/16, H/8 and H/4 feature maps.
    direction_fc5 = self.direction_fc5(direction_layer5)
    self._log_tensor("direction_fc5", direction_fc5)
    direction_fc4 = self.direction_fc4(direction_layer4)
    self._log_tensor("direction_fc4", direction_fc4)
    direction_fc3 = self.direction_fc3(direction_layer3)
    self._log_tensor("direction_fc3", direction_fc3)

    # Upsample the coarser branches to H/4 (x4 and x2) and fuse with fc3.
    direction_upscore5 = self.direction_upscore5(direction_fc5)
    self._log_tensor("direction_upscore5", direction_upscore5)
    direction_upscore4 = self.direction_upscore4(direction_fc4)
    self._log_tensor("direction_upscore4", direction_upscore4)
    direction_cat_fuse3 = torch.cat((direction_fc3, direction_upscore5, direction_upscore4), dim = 1)
    self._log_tensor("direction_cat_fuse3", direction_cat_fuse3)
    direction_fuse3 = self.direction_fuse3(direction_cat_fuse3)
    self._log_tensor("direction_fuse3", direction_fuse3)

    direction = self.direction(direction_fuse3)                               # back to H
    self._log_tensor("direction", direction)
    # Unit-length 2-D vector per pixel (tiny eps avoids division by zero).
    direction_normalized = F.normalize(direction, p = 2, dim = 1, eps = 1e-20)
    self._log_tensor("direction_normalized", direction_normalized)

    if self.doPrint: print("\n" + '\033[1m' + " Watershed Transform Net " + '\033[0m')

    watershed_layer1 = self.watershed_layer1(direction_normalized)
    self._log_tensor("watershed_layer1", watershed_layer1)
    watershed_layer1_avgpool = self.avgPool2d(watershed_layer1)              # H/2
    self._log_tensor("watershed_layer1_avgpool", watershed_layer1_avgpool)
    watershed_layer2 = self.watershed_layer2(watershed_layer1_avgpool)
    self._log_tensor("watershed_layer2", watershed_layer2)
    watershed_layer2_avgpool = self.avgPool2d(watershed_layer2)              # H/4
    self._log_tensor("watershed_layer2_avgpool", watershed_layer2_avgpool)
    watershed_fc1 = self.watershed_fc1(watershed_layer2_avgpool)
    self._log_tensor("watershed_fc1", watershed_fc1)
    watershed_fc2 = self.watershed_fc2(watershed_fc1)
    self._log_tensor("watershed_fc2", watershed_fc2)

    logits = self.outputData(watershed_fc2)                                   # back to H
    self._log_tensor("logits", logits)

    softmax_layer = self.softmax_layer(logits)
    self._log_tensor("softmax_layer", softmax_layer)

    labels = torch.argmax(softmax_layer, dim = 1)
    self._log_tensor("labels", labels)

    return logits, labels

Debug DWTNet Model
---

In [None]:
# Smoke-test DWTNet on one random image + random segmentation map; doPrint=True
# prints the size of every intermediate tensor.
batch_size = 1

# input to the network
x = torch.randn(batch_size, 1, 512, 512)  # random input image
seg = torch.randn(batch_size, 512, 512)   # random stand-in for a segmentation map

x = x.to(device)      # move to GPU
seg = seg.to(device)  # move to GPU

# model (num_labels defaults to 16 output classes)
dwtnet = DWTNet(doPrint = True, doPrintMax = False)
dwtnet = dwtnet.to(device)  # move to GPU

# Forward pass only: no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
  out = dwtnet(x, seg)

del x, seg
del out
del dwtnet

UDWTNet Model
---

In [None]:
class UDWTNet(nn.Module):
  """Chain a UNet segmenter with a DWTNet.

  The UNet's predicted label map (its second output) is passed to the DWTNet
  together with the original input image.
  """

  def __init__(self, unet, dwtnet):
    super(UDWTNet, self).__init__()
    # Stored as attributes; nn.Module registers them as sub-modules.
    self.unet = unet
    self.dwtnet = dwtnet

  def forward(self, x):
    """Return (DWTNet logits, UNet labels, DWTNet labels) for input batch `x`."""
    _unet_logits, estimate_labels = self.unet(x)
    dwt_logits, estimate_depth_maps = self.dwtnet(x, estimate_labels)
    return dwt_logits, estimate_labels, estimate_depth_maps

Debug UDWTNet model
---

In [None]:
# Smoke-test the chained UNet -> DWTNet model on one random image; doPrint=True
# prints the size of every intermediate tensor in both sub-networks.
batch_size = 1

# input to the network
x = torch.randn(batch_size, 1, 512, 512)  # random input image
x = x.to(device)  # move to GPU

# model
unet = UNet(in_channel = 1, out_channel = 2, doPrint = True)  # out_channel = number of segments desired
unet = unet.to(device)  # move to GPU
dwtnet = DWTNet(doPrint = True, doPrintMax = False)  # num_labels defaults to 16
dwtnet = dwtnet.to(device)  # move to GPU

udwtnet = UDWTNet(unet, dwtnet)
udwtnet = udwtnet.to(device)  # move to GPU

# Forward pass only: no_grad avoids building an autograd graph that is never used.
with torch.no_grad():
  logits, estimate_labels, estimate_depth_maps = udwtnet(x)

# Release every reference so the GPU memory held by both sub-networks is freed
# (udwtnet alone keeps unet and dwtnet alive).
del x
del logits, estimate_labels, estimate_depth_maps
del udwtnet, unet, dwtnet

# Train supporting functions

Define Load Model
---

In [None]:
def _load_saved_weights(model, model_save_name):
  """Load the state_dict stored at BaseSavePath/<model_save_name> into `model`.

  Tensors are mapped to the CPU. Returns `model` for convenience.
  NOTE: torch.load unpickles the file, so only load checkpoints you trust.
  """
  loadpath = f"{BaseSavePath}/{model_save_name}"
  print('loading model in ' + loadpath)
  model.load_state_dict(torch.load(loadpath, map_location='cpu'))
  return model


def LoadUnetModel(model_save_name):
  """Build a 1-channel-in / 2-segment-out UNet and load its saved weights."""
  model = UNet(in_channel = 1, out_channel = 2, doPrint = False)  # out_channel = number of segments desired
  return _load_saved_weights(model, model_save_name)


def LoadDWTModel(model_save_name, num_labels = 16):
  """Build a DWTNet with `num_labels` output classes and load its saved weights."""
  return _load_saved_weights(DWTNet(num_labels = num_labels), model_save_name)


def LoadUDWTModel(model_save_name, num_labels = 16):
  """Build a UNet -> DWTNet chain and load the weights saved for the whole chain."""
  unet = UNet(in_channel = 1, out_channel = 2, doPrint = False)  # out_channel = number of segments desired
  dwtnet = DWTNet(num_labels = num_labels)
  return _load_saved_weights(UDWTNet(unet, dwtnet), model_save_name)


Define plotModelResult
---

In [None]:
def plotModelResultUnet(image, true_label, estimate_label, i):  
  # to cpu
  image = image.cpu()
  true_label = true_label.cpu()
  estimate_label = estimate_label.cpu()
  
  fig = plt.figure(figsize = (20,10))
  
  ax = plt.subplot(1, 3, 1)
  plt.tight_layout()
  ax.set_title('Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray',aspect='equal')
  
  ax = plt.subplot(1, 3, 2)
  plt.tight_layout()
  ax.set_title('True label Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray' ,aspect='equal')
  plt.imshow(true_label,cmap='gray', alpha=0.7, aspect='equal')

  ax = plt.subplot(1, 3, 3)
  plt.tight_layout()
  ax.set_title('Estimate label Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray' ,aspect='equal')
  plt.imshow(estimate_label,cmap='gray', alpha=0.7, aspect='equal') 
   
  plt.show()

def plotModelResultDWTNet(image, true_label, true_depth,estimate_depth, i):  
  # to cpu
  image = image.cpu()
  true_label = true_label.cpu()
  true_depth= true_depth.cpu()
  estimate_depth = estimate_depth.cpu()
  
  fig = plt.figure(figsize = (20,10))
  
  ax = plt.subplot(1, 4, 1)
  plt.tight_layout()
  ax.set_title('Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray',aspect='equal')
  
  ax = plt.subplot(1, 4, 2)
  plt.tight_layout()
  ax.set_title('True label Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray' ,aspect='equal')
  plt.imshow(true_label,cmap='gray', alpha=0.7, aspect='equal')

  ax = plt.subplot(1, 4, 3)
  plt.tight_layout()
  ax.set_title('True depth Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray' ,aspect='equal')
  plt.imshow(true_depth,cmap='jet', alpha=0.6, aspect='equal')

  ax = plt.subplot(1, 4, 4)
  plt.tight_layout()
  ax.set_title('Estimate depth Sample #{}'.format(i))
  ax.axis('off')
  plt.imshow(image,cmap='gray' ,aspect='equal')
  plt.imshow(estimate_depth,cmap='jet', alpha=0.6, aspect='equal')
   
  plt.show()


def plotModelResult(image, true_label, true_depth, estimate_label,estimate_depth, i):
  """Plot a sample with true/estimated labels and true/estimated depth maps.

  Draws one row of five panels: the grayscale image alone, then the image
  overlaid with the true label (gray, alpha 0.7), the true depth map (jet,
  alpha 0.6), the estimated label (gray, alpha 0.7) and the estimated depth
  map (jet, alpha 0.6). Tensors are copied to the CPU first; ``i`` only
  appears in the panel titles.
  """
  image = image.cpu()
  true_label = true_label.cpu()
  true_depth = true_depth.cpu()
  estimate_label = estimate_label.cpu()
  estimate_depth = estimate_depth.cpu()

  panels = (
      # (title template, overlay or None, overlay colormap, overlay alpha)
      ('Sample #{}', None, None, None),
      ('True label Sample #{}', true_label, 'gray', 0.7),
      ('True depth Sample #{}', true_depth, 'jet', 0.6),
      ('Estimate label Sample #{}', estimate_label, 'gray', 0.7),
      ('Estimate depth Sample #{}', estimate_depth, 'jet', 0.6),
  )

  plt.figure(figsize = (20,10))
  for column, (title, overlay, colormap, opacity) in enumerate(panels, start=1):
    ax = plt.subplot(1, len(panels), column)
    plt.tight_layout()
    ax.set_title(title.format(i))
    ax.axis('off')
    plt.imshow(image, cmap='gray', aspect='equal')
    if overlay is not None:
      plt.imshow(overlay, cmap=colormap, alpha=opacity, aspect='equal')

  plt.show()


Dice Loss
---

In [None]:
# https://www.jeremyjordan.me/semantic-segmentation/
# https://forums.fast.ai/t/understanding-the-dice-coefficient/5838
# https://stackoverflow.com/questions/61488732/how-calculate-the-dice-coefficient-for-multi-class-segmentation-task-using-pytho
# https://github.com/kevinzakka/pytorch-goodies/blob/master/losses.py

# Good article comparing Dice, cross-entropy and other segmentation losses:
# https://arxiv.org/pdf/1911.01685.pdf

def calculate_dice_loss(logits,true, eps = 1e-7):
    """Compute the soft Sørensen–Dice loss ``1 - mean_c(Dice_c)``.

    The logits are turned into per-pixel class probabilities with a softmax
    over dim 1 and ``true`` is one-hot encoded. A soft Dice coefficient is
    then computed per class, summing over the batch and all spatial axes.
    The class-averaged coefficient is subtracted from 1, so minimising the
    returned value (as PyTorch optimizers do) maximises the overlap.

    Args:
        true: integer class-index tensor of shape [Batch size x 512 x 512]
            (any number of spatial dims is accepted); every value must lie in
            [0, numLabels).
        logits: a tensor of shape [Batch size x numLabels x 512 x 512]. Corresponds to
            the raw output or logits of the model.
        eps: added to the denominator for numerical stability.
    Returns:
        A scalar tensor: 1 minus the class-averaged soft Dice coefficient.
    """
    num_classes = logits.shape[1]
    spatial_ndim = true.ndimension() - 1

    # One-hot encode directly on the logits' device and dtype, so a CUDA label
    # tensor never indexes a CPU identity matrix. Then move the class axis to
    # dim 1: [B, *S] -> [B, *S, C] -> [B, C, *S].
    eye = torch.eye(num_classes, dtype=logits.dtype, device=logits.device)
    permute_order = (0, spatial_ndim + 1) + tuple(range(1, spatial_ndim + 1))
    true_1_hot = eye[true.long()].permute(*permute_order).contiguous()

    probas = F.softmax(logits, dim = 1)

    # Reduce over the batch and every spatial axis, keeping one value per class.
    dims = (0,) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    dice_coefficient = (2. * intersection / (cardinality + eps)).mean()
    return 1 - dice_coefficient

IOU
---

In [None]:
def calculate_iou(estimate_labels, true_labels, eps = 1e-7):
    """Return the batch-averaged, threshold-quantised IoU of two mask batches.

    Intersection and union are computed with bitwise ``&`` / ``|`` per sample,
    so both inputs must be integer or bool tensors. A singleton dim 1 in
    ``estimate_labels`` (BATCH x 1 x H x W) is squeezed away first.

    Each sample's IoU is not returned directly. It is mapped to the fraction
    of the thresholds 0.5, 0.55, ..., 0.95 that it strictly exceeds, which
    gives a value in {0.0, 0.1, ..., 1.0}. The mean over the batch is returned.
    """
    # BATCH x 1 x H x W => BATCH x H x W (no-op when dim 1 is not of size 1)
    predicted = estimate_labels.squeeze(1)

    overlap = (predicted & true_labels).float().sum((1, 2))   # 0 if either mask is empty
    combined = (predicted | true_labels).float().sum((1, 2))  # 0 only if both are empty

    # eps on both sides: two empty masks give IoU 1 instead of 0/0
    iou_per_sample = (overlap + eps) / (combined + eps)

    # quantise: fraction of the thresholds 0.5:0.05:0.95 that are exceeded
    score = (20 * (iou_per_sample - 0.5)).clamp(0, 10).ceil() / 10

    return score.mean()

Average of list
---

In [None]:
def Average(lst):
  """Return the arithmetic mean of ``lst``; an empty sequence yields 0."""
  count = len(lst)
  return sum(lst) / count if count else 0

Post-Processing of Estimated Labels
---

In [None]:
Unet_LABELS = {"cell":1}#, "cell_border":2}
Unet_THRESHOLD = {"cell":0}

def PostProcessingUnet(estimate_labels, doPlot = False):
  """Clean a batch of UNet label maps with binary morphology.

  Each label in ``Unet_LABELS`` is processed separately for every sample:
    1. Threshold the map with ``label_map > Unet_THRESHOLD[label]``.
    2. Remove connected objects smaller than 20 pixels.
    3. Fill holes smaller than 2500 pixels.
    4. Write the surviving pixels into the output with that label's code.
  Pixels that belong to no label stay 0.

  Args:
    estimate_labels: tensor of per-pixel labels, (BATCH_SIZE x 512 x 512) as
      produced by the UNet forward pass.
    doPlot: if True, plot the intermediate masks of every sample/label and a
      before/after comparison of the first sample.

  Returns:
    An integer tensor with the shape of ``estimate_labels``, moved to the
    global ``device``.
  """
  if doPlot : print("################### Post Processing Unet #######################")
  output = np.zeros(estimate_labels.size(), dtype = np.int_)

  for i in range(output.shape[0]):  # batch loop
    # detach() so that a map still attached to an autograd graph can be converted
    current_estimate_labels = estimate_labels[i].detach().cpu().numpy()
    for semLabel, labelCode in Unet_LABELS.items():  # label loop
      mask = current_estimate_labels > Unet_THRESHOLD[semLabel]
      if doPlot:
        plt.figure(figsize = (20,10))
        ax = plt.subplot(1, 3, 1)
        plt.imshow(mask, cmap='gray', aspect='equal')
        ax.set_title(' before ')

      mask = skimage.morphology.remove_small_objects(mask, min_size = 20)
      if doPlot:
        ax = plt.subplot(1, 3, 2)
        plt.imshow(mask, cmap='gray', aspect='equal')
        ax.set_title(' after remove_small_objects ')

      mask = skimage.morphology.remove_small_holes(mask, area_threshold = 2500)
      if doPlot:
        ax = plt.subplot(1, 3, 3)
        plt.imshow(mask, cmap='gray', aspect='equal')
        ax.set_title(' after remove_small_holes ')
      # NOTE: an optional binary_dilation step is currently disabled here.

      # Write the label code only into this label's pixels instead of
      # overwriting the whole slice, so several labels can share one map.
      output[i][mask] = labelCode

  if doPlot:
    plt.figure(figsize = (20,10))
    ax = plt.subplot(1, 2, 1)
    plt.tight_layout()
    ax.set_title(' before PostProcessing ')
    ax.axis('off')
    plt.imshow(estimate_labels[0].detach().cpu(), cmap='gray', aspect='equal')

    ax = plt.subplot(1, 2, 2)
    plt.tight_layout()
    ax.set_title(' after PostProcessing ')
    ax.axis('off')
    plt.imshow(output[0], cmap='gray', aspect='equal')

  return torch.from_numpy(output).to(device)

##############################################
CLASS_TO_SS = {"cell":1}#, "border":2}
#LABELS = {"cell":1, "boundary":2}
THRESHOLD = {"cell":2, "border": 0}
SELEM = {0: (np.ones((5,5))).astype(np.bool), # kernels 
         6: (np.ones((3,3))).astype(np.bool)}

def PostProcessingDWT(depthImage_input, ssMask_input, doPlot = False):
  if doPlot : print("################### Post Processing DWT #######################")
  resultImage = np.zeros(shape=ssMask_input.shape, dtype=np.uint8) #np.float32
  for i in range(depthImage_input.size()[0]): # batch loop 
    depthImage = depthImage_input[i,:,:].cpu().numpy()
    ssMask = ssMask_input[i,:,:].cpu().numpy().astype(np.int32)

    for semClass in CLASS_TO_SS.keys():
      ssCode = CLASS_TO_SS[semClass]
      ssMaskClass = (ssMask == ssCode)

      ccImage = (depthImage > THRESHOLD[semClass]) * ssMaskClass
      
      if doPlot :
        fig = plt.figure(figsize = (20,10)) 
        ax = plt.subplot(1, 5, 1)
        plt.imshow(ccImage,cmap='gray',aspect='equal')
        ax.set_title(' before ')
 
      # to smooth the cell edges, effect created by the thresholding  
     # ccImage =skimage.morphology.erosion(ccImage)#, selem = disk(6))
      ccImage= skimage.filters.median(ccImage)#,selem = np.ones((6,6)).astype(np.bool), mode='nearest') 
      if doPlot :
        fig = plt.figure(figsize = (20,10)) 
        ax = plt.subplot(1, 5, 2)
        plt.imshow(ccImage,cmap='gray',aspect='equal')
        ax.set_title(' after median blur ')

      ccImage = skimage.morphology.remove_small_objects(ccImage, min_size=20)
      
      if doPlot :
        ax = plt.subplot(1, 5, 3)
        plt.imshow(ccImage,cmap='gray',aspect='equal')
        ax.set_title(' after remove_small_objects ')

      ccImage = skimage.morphology.remove_small_holes(ccImage,area_threshold = 2500)
      
      if doPlot :
        ax = plt.subplot(1, 5, 4)
        plt.imshow(ccImage,cmap='gray',aspect='equal')
        ax.set_title(' after remove_small_holes ')
        
      #ccImage =skimage.morphology.binary_dilation(ccImage, (np.ones((5,5))).astype(np.bool))
      if doPlot :
        ax = plt.subplot(1, 5, 5)
        plt.imshow(ccImage,cmap='gray',aspect='equal')
        ax.set_title(' after binary_dilation ')

      resultImage[i,:,:] = ccImage
      """ 
      ccLabels = skimage.morphology.label(ccImage)
      
      if doPlot :
        ax = plt.subplot(1, 6, 6)
        plt.imshow(ccLabels,cmap='jet',aspect='equal')
        ax.set_title(' after labels ')
      
       
      resultImage[i,:,:] = ccLabels
     
      ccIDs = np.unique(ccLabels)[1:] 
      
      for ccID in ccIDs:
        ccIDMask = (ccLabels == ccID)
        #ccIDMask = skimage.morphology.binary_dilation(ccIDMask, SELEM[THRESHOLD[semClass]])
        # ccID max size is the maximum, cells in the image. say it 100 cells 
        # ssCode is the maximum label 
        instanceID = 1000 * ssCode + ccID
        resultImage[ccIDMask] = instanceID
      """

  #resultImage_tot = resultImage_tot.astype(np.int16)

  if doPlot:
    fig = plt.figure(figsize = (20,10)) 
    ax = plt.subplot(1, 2, 1)
    plt.tight_layout()
    ax.set_title(' before PostProcessing ')
    ax.axis('off')
    plt.imshow(depthImage_input[0,:,:].cpu(),cmap='jet',aspect='equal')
   
    ax = plt.subplot(1, 2, 2)
    plt.tight_layout()
    ax.set_title(' after PostProcessing ')
    ax.axis('off')
    plt.imshow(resultImage[0,:,:],cmap='gray' ,aspect='equal')

  resultImage = torch.from_numpy(resultImage).to(device)
  return resultImage


Debug Post-Processing
---

In [None]:
# Visual check of the post-processing on a few samples of sequence 01.
folder_data = BaseDataPath +"/01"
test_dataset = CustomDataset(folder_data)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size= 1 , shuffle=True, num_workers=2 )

# load the trained UNet and move it to the GPU
model = LoadUnetModel(Unet_save_name)
model = model.to(device)

# To debug the DWT post-processing instead, load the combined model:
#model = LoadUDWTModel(UDWTNet_save_name)
#model=model.to(device)

# declare that we are in evaluation mode
model = model.eval()

# inference only: no autograd graph needs to be recorded
with torch.no_grad():
  for i, (images, _, _) in enumerate(test_loader):
    _, estimate_labels = model(images.to(device))
    estimate_labels_output = PostProcessingUnet(estimate_labels, doPlot = True)
    del estimate_labels, estimate_labels_output
    #_,estimate_labels,estimate_map=model(images.to(device))
    #depthImage_output = PostProcessingDWT(estimate_map, estimate_labels, doPlot = True)
    #del estimate_labels, estimate_map

    if i > 2:  # process the first 4 batches only
      break

del model

Define Gradient-Flow Plot
---

In [None]:
def plot_grad_flow(named_parameters):
    '''Plot the gradients flowing through the layers of a network.

    Useful for spotting vanishing / exploding gradients. Call it after
    ``loss.backward()`` and before the gradients are zeroed, e.g.
    ``plot_grad_flow(model.named_parameters())``.

    For every trainable, non-bias parameter that has a gradient, the mean
    and the max of ``|grad|`` are drawn as bars. The y-axis is zoomed into
    [-0.001, 0.02] to show the low-gradient region.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, as
            returned by ``nn.Module.named_parameters()``.
    '''
    ave_grads = []
    max_grads = []
    layers = []
    for name, param in named_parameters:
        # Skip frozen parameters and biases, plus parameters that received no
        # gradient in this backward pass (their .grad is None).
        if not param.requires_grad or "bias" in name or param.grad is None:
            continue
        abs_grad = param.grad.detach().abs()
        layers.append(name)
        # .item() yields Python floats; matplotlib cannot plot CUDA tensors
        ave_grads.append(abs_grad.mean().item())
        max_grads.append(abs_grad.max().item())

    positions = np.arange(len(max_grads))
    plt.bar(positions, max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(positions, ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k" )
    plt.xticks(range(0,len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom = -0.001, top=0.02) # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([matplotlib.lines.Line2D([0], [0], color="c", lw=4),
                matplotlib.lines.Line2D([0], [0], color="b", lw=4),
                matplotlib.lines.Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'], loc = 'upper left')

# Training

Train UNet
---

In [None]:
# train dataset
PlotDuringTraining  = False

# NETWORK PARAMS
BATCH_SIZE = 2
NUM_LABELS = 2
NUM_IMAGE_CHANNELS = 1
NUM_EPOCHS = 100
TRAIN_FRACTION = 0.7  # share of the samples used for training; the rest is validation

# Train loader
folder_data = BaseDataPath +"/01"
# divide to train and validation (random split of the sample indices)
dataset = CustomDataset(folder_data)
dataset_len = len(dataset)
picked = set(random.sample(range(dataset_len), math.floor(TRAIN_FRACTION*dataset_len)))
train_indices = [idx for idx in range(dataset_len) if idx in picked]
valid_indices = [idx for idx in range(dataset_len) if idx not in picked]

# Validation uses the un-augmented dataset so its metrics are not distorted by
# random augmentation and stay comparable from epoch to epoch.
valid_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset, valid_indices), batch_size= BATCH_SIZE , shuffle=False, num_workers=2 )

unet = UNet(in_channel= NUM_IMAGE_CHANNELS, out_channel = NUM_LABELS, doPrint = False) #out_channel represents number of segments desired
unet = unet.to(device)  # move to GPU

#loss
CrossEntropy_criterion = torch.nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.SGD(unet.parameters(), lr = 0.01, momentum=0.99)
optimizer.zero_grad()

row_list = []
for epoch in range(NUM_EPOCHS):
  train_loss_list = []
  train_iou_list  = []

  valid_loss_list = []
  valid_iou_list  = []

  # augmented training data, rebuilt every epoch
  data_transform = ImgAugTransform()
  train_dataset = CustomDataset(folder_data, transform =  data_transform)
  train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train_dataset, train_indices), batch_size= BATCH_SIZE , shuffle=True, num_workers=2 )

  epoch_time_start = time.time()
  ## training loop
  unet = unet.train() # declare training
  for i, (images, true_labels, _) in enumerate(train_loader):
    images = images.to(device)           #move to gpu   # images size (BATCH_SIZE x 1 x 512 x512)
    true_labels = true_labels.to(device) #move to gpu   # true_labels size(BATCH_SIZE x 512 x512)

    logits, estimate_labels = unet(images)  # logits size (BATCH_SIZE x NUM_LABELS x 512 x512)

    # alternative loss: calculate_dice_loss(logits, true_labels)
    loss = CrossEntropy_criterion(logits, true_labels)
    iou = calculate_iou(estimate_labels, true_labels)

    # Back propagation
    unet.zero_grad()
    loss.backward()
    #plot_grad_flow(unet.named_parameters())
    optimizer.step()

    train_loss_list.append(loss.item())
    train_iou_list.append(iou.item())

    del images, true_labels, logits, estimate_labels, loss, iou
  ## end train_loop

  ## validation loop
  unet = unet.eval() # declare validation
  with torch.no_grad():
    for i, (images, true_labels, _) in enumerate(valid_loader):
      images = images.to(device)           #move to gpu   # images size (BATCH_SIZE x 1 x 512 x512)
      true_labels = true_labels.to(device) #move to gpu   # true_labels size(BATCH_SIZE x 512 x512)

      logits, estimate_labels = unet(images)

      loss = CrossEntropy_criterion(logits, true_labels)
      iou = calculate_iou(estimate_labels, true_labels)

      if PlotDuringTraining and i == 0:
        # show the first validation sample; the epoch number goes in the titles
        plotModelResultUnet(images[0].squeeze(0), true_labels[0], estimate_labels[0], epoch)

      valid_loss_list.append(loss.item())
      valid_iou_list.append(iou.item())

      del images, true_labels, logits, estimate_labels, loss, iou
  ## end validation loop

  avrg_train_loss =  Average(train_loss_list)
  avrg_train_iou  =  Average(train_iou_list)

  avrg_valid_loss =  Average(valid_loss_list)
  avrg_valid_iou  =  Average(valid_iou_list)

  time_end =  time.time() - epoch_time_start # elapsed epoch time
  print('Epoch {}/{} | '.format(epoch + 1, NUM_EPOCHS),end =' ')
  print('Time {:.1f} sec | '.format(time_end),end =' ')
  print('train_loss {:.3f} | '.format(avrg_train_loss),end =' ')
  print('train_iou {:.3f} | '.format(avrg_train_iou),end =' ')
  print('validation_loss {:.3f} | '.format(avrg_valid_loss),end =' ')
  print('validation_iou {:.3f} | '.format(avrg_valid_iou))

  #print summary in table
  row_list.append([time_end, epoch+1, avrg_train_loss, avrg_train_iou, avrg_valid_loss, avrg_valid_iou])
#end epoch loop

savepath = BaseSavePath + F"/{Unet_save_name}"
print ('saving model in '+ savepath)
torch.save(unet.state_dict(), savepath)


#Graph : Train, validation
fig = plt.figure(figsize = (20,10))

ax = plt.subplot(1, 2, 1)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.plot([row[2] for row in row_list], label = ' TRAIN loss' , color = 'r')
plt.plot([row[4] for row in row_list], label = ' VALID loss' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)

ax = plt.subplot(1, 2, 2)
plt.ylabel('IOU')
plt.xlabel('epoch')
plt.plot([row[3] for row in row_list], label = ' TRAIN iou' , color = 'r')
plt.plot([row[5] for row in row_list], label = ' VALID iou' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)
plt.show()

del unet

Test UNet
---

In [None]:
folder_data = BaseDataPath +"/02"
test_dataset = CustomDataset(folder_data)

# NETWORK PARAMS
BATCH_SIZE = 2
NUM_LABELS = 2
NUM_IMAGE_CHANNELS = 1
NUM_PLOTTED_BATCHES = 10  # how many of the first batches are displayed

#loss
CrossEntropy_criterion = torch.nn.CrossEntropyLoss()

# fixed order so the displayed samples are reproducible between runs
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = BATCH_SIZE , shuffle=False, num_workers=2)

# load model
model = LoadUnetModel(Unet_save_name)
model = model.to(device)  # move to GPU

# per-batch test metrics (Python floats)
losses_test = []
iou_test = []

# declare that we are in evaluation mode
model = model.eval()

# inference only: no autograd graph is needed
with torch.no_grad():
  for i, (images, true_labels, _) in enumerate(test_loader):
    images = images.to(device)           #move to gpu   # images size (BATCH_SIZE x 1 x 512 x512)
    true_labels = true_labels.to(device) #move to gpu   # true_labels size(BATCH_SIZE x 512 x512)

    logits, estimate_labels = model(images)  # logits size (BATCH_SIZE x NUM_LABELS x 512 x512)

    #post processing estimate_labels
    estimate_labels = PostProcessingUnet(estimate_labels)

    # alternative loss: calculate_dice_loss(logits, true_labels)
    test_loss = CrossEntropy_criterion(logits, true_labels)
    iou = calculate_iou(estimate_labels, true_labels)

    losses_test.append(test_loss.item())
    iou_test.append(iou.item())

    # display the first results
    if i < NUM_PLOTTED_BATCHES:
      plotModelResultUnet(images[0].squeeze(0), true_labels[0], estimate_labels[0], i)

    del images, true_labels, logits, estimate_labels, test_loss, iou

loss_average = Average (losses_test)
iou_average  = Average (iou_test)
print(' Average Loss on the test data = {:.3f}  '.format(loss_average))
print(' Average IOU measurement on the test data = {:.3f}  '.format(iou_average))

del model

Train DWTNet
---

In [None]:
# train dataset
PlotDuringTraining  = True

# NETWORK PARAMS
BATCH_SIZE = 1
NUM_EPOCHS = 100
NUM_LABELS = 16
TRAIN_FRACTION = 0.7          # share of the samples used for training; the rest is validation
BINARY_DEPTH_THRESHOLD = 2    # depth classes above this value count as foreground for the binary IoU

# Train loader
folder_data = BaseDataPath +"/01"
# divide to train and validation (random split of the sample indices)
dataset = CustomDataset(folder_data)
dataset_len = len(dataset)
picked = set(random.sample(range(dataset_len), math.floor(TRAIN_FRACTION*dataset_len)))
train_indices = [idx for idx in range(dataset_len) if idx in picked]
valid_indices = [idx for idx in range(dataset_len) if idx not in picked]

# Validation uses the un-augmented dataset so its metrics are comparable across epochs.
valid_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(dataset, valid_indices), batch_size= BATCH_SIZE , shuffle=False, num_workers=2 )

dwtnet = DWTNet(doPrint = False, doPrintMax = False, num_labels = NUM_LABELS)
dwtnet = dwtnet.to(device)  # move to GPU

# Load the pre-trained U-net; it is only used for inference here
unet = LoadUnetModel(Unet_save_name)
unet = unet.to(device)
unet.eval() # declare evaluation mode for unet

# use the module's apply function to recursively apply the initialization
dwtnet.apply(dwtnet.init_weights)

# Loss function (one weight per depth class)
weights = torch.ones(NUM_LABELS, dtype = torch.float)
#weights[1:5] = 3.0
CrossEntropy_criterion = torch.nn.CrossEntropyLoss(weight= weights.to(device))

# Optimizer
optimizer = torch.optim.SGD(dwtnet.parameters(), lr =0.1, momentum=0.5) #0.001
#optimizer = torch.optim.Adam(dwtnet.parameters(), lr=5e-6)
optimizer.zero_grad()

row_list = []

for epoch in range(NUM_EPOCHS):
  train_loss_list = []
  train_iou_list  = []
  train_iou_binary_list = []

  valid_loss_list = []
  valid_iou_list  = []
  valid_iou_binary_list = []

  # augmented training data, rebuilt every epoch
  data_transform = ImgAugTransform()
  train_dataset = CustomDataset(folder_data, transform =  data_transform)
  train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train_dataset, train_indices), batch_size= BATCH_SIZE , shuffle=True, num_workers=2 )

  epoch_time_start = time.time()
  ## training loop
  dwtnet = dwtnet.train() # declare training
  for i, (images, binary_true_labels, depth_true_maps) in enumerate(train_loader):
    images = images.to(device)                          # images size (BATCH_SIZE x 1 x 512 x512)
    binary_true_labels = binary_true_labels.to(device)  # true_labels size(BATCH_SIZE x 512 x512)
    depth_true_maps = depth_true_maps.to(device)        # depth_true_map size(BATCH_SIZE x 512 x512)

    # The frozen UNet only supplies the cell mask; no gradient is needed through it.
    with torch.no_grad():
      _, unet_estimate_labels = unet(images)
    unet_estimate_labels = PostProcessingUnet(unet_estimate_labels, doPlot = False)

    logits, depth_estimate_maps = dwtnet(images, unet_estimate_labels)
    #logits, depth_estimate_maps = dwtnet(images, binary_true_labels)

    loss = CrossEntropy_criterion(logits, depth_true_maps)
    iou = calculate_iou(depth_estimate_maps, depth_true_maps)

    binary_estimate_maps = (depth_estimate_maps > BINARY_DEPTH_THRESHOLD).to(depth_estimate_maps.dtype)
    iou_binary = calculate_iou(binary_estimate_maps, binary_true_labels)

    # Back propagation
    dwtnet.zero_grad()
    loss.backward()
    #plot_grad_flow(dwtnet.named_parameters())
    optimizer.step()

    train_loss_list.append(loss.item())
    train_iou_list.append(iou.item())
    train_iou_binary_list.append(iou_binary.item())

    del images, unet_estimate_labels, depth_true_maps
    del logits, binary_estimate_maps, depth_estimate_maps
    del loss, iou, iou_binary
  ## end train_loop

  ## validation loop
  dwtnet = dwtnet.eval() # declare validation
  with torch.no_grad():
    for i, (images, binary_true_labels, depth_true_maps) in enumerate(valid_loader):
      images = images.to(device)
      binary_true_labels = binary_true_labels.to(device)
      depth_true_maps = depth_true_maps.to(device)

      _, unet_estimate_labels = unet(images)
      unet_estimate_labels = PostProcessingUnet(unet_estimate_labels, doPlot = False)

      logits, depth_estimate_maps = dwtnet(images, unet_estimate_labels)

      loss = CrossEntropy_criterion(logits, depth_true_maps)
      iou = calculate_iou(depth_estimate_maps, depth_true_maps)

      binary_estimate_maps = (depth_estimate_maps > BINARY_DEPTH_THRESHOLD).to(depth_estimate_maps.dtype)
      iou_binary = calculate_iou(binary_estimate_maps, binary_true_labels)

      if PlotDuringTraining and i == 0:
        # show the first validation sample; the epoch number goes in the titles
        plotModelResult(images[0].squeeze(0), binary_true_labels[0], depth_true_maps[0],
                        unet_estimate_labels[0], depth_estimate_maps[0], epoch)

      valid_loss_list.append(loss.item())
      valid_iou_list.append(iou.item())
      valid_iou_binary_list.append(iou_binary.item())

      del images, binary_true_labels, depth_true_maps, unet_estimate_labels
      del logits, binary_estimate_maps, depth_estimate_maps
      del loss, iou, iou_binary
  ## end validation loop

  avrg_train_loss =  Average(train_loss_list)
  avrg_train_iou  =  Average(train_iou_list)
  avrg_train_binary_iou = Average(train_iou_binary_list)

  avrg_valid_loss =  Average(valid_loss_list)
  avrg_valid_iou  =  Average(valid_iou_list)
  avrg_valid_binary_iou = Average(valid_iou_binary_list)

  time_end =  time.time() - epoch_time_start # elapsed epoch time
  print('Epoch {}/{} | '.format(epoch + 1, NUM_EPOCHS),end =' ')
  print('Time {:.1f} sec | '.format(time_end),end =' ')
  print('train_loss {:.3f} | '.format(avrg_train_loss),end =' ')
  print('train_depth_iou {:.3f} | '.format(avrg_train_iou),end =' ')
  print('train_binary_iou {:.3f} | '.format(avrg_train_binary_iou),end =' ')
  print('validation_loss {:.3f} | '.format(avrg_valid_loss),end =' ')
  print('validation_depth_iou {:.3f} | '.format(avrg_valid_iou),end =' ')
  print('validation_binary_iou {:.3f} | '.format(avrg_valid_binary_iou))

  #print summary in table
  row_list.append([time_end, epoch+1, avrg_train_loss, avrg_train_iou,avrg_train_binary_iou, avrg_valid_loss, avrg_valid_iou,avrg_valid_binary_iou])
#end epoch loop

savepath = BaseSavePath + F"/{DWTNet_save_name}"
print ('saving model in '+ savepath)
torch.save(dwtnet.state_dict(), savepath)


#Graph : Train, validation
fig = plt.figure(figsize = (20,10))

ax = plt.subplot(1, 3, 1)
plt.ylabel(' Loss')
plt.xlabel('epoch')
plt.plot([row[2] for row in row_list], label = ' TRAIN loss' , color = 'r')
plt.plot([row[5] for row in row_list], label = ' VALID loss' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)

ax = plt.subplot(1, 3, 2)
plt.ylabel(' IOU')
plt.xlabel('epoch')
plt.plot([row[3] for row in row_list], label = ' TRAIN depth iou' , color = 'r')
plt.plot([row[6] for row in row_list], label = ' VALID depth iou' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)

ax = plt.subplot(1, 3, 3)
plt.ylabel('binary IOU')
plt.xlabel('epoch')
plt.plot([row[4] for row in row_list], label = ' TRAIN binary iou' , color = 'r')
plt.plot([row[7] for row in row_list], label = ' VALID binary iou' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)
plt.show()

del dwtnet

Test DWTNet
---

In [None]:
folder_data = BaseDataPath +"/02"
test_dataset = CustomDataset(folder_data)

# NETWORK PARAMS
BATCH_SIZE = 1
NUM_PLOTTED_BATCHES = 10      # how many of the first batches are displayed
BINARY_DEPTH_THRESHOLD = 2    # depth classes above this value count as foreground for the binary IoU

# Loss function
CrossEntropy_criterion = torch.nn.CrossEntropyLoss()

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size = BATCH_SIZE , shuffle=False, num_workers=2)

# load model
model = LoadDWTModel(DWTNet_save_name)
model = model.to(device)  # move to GPU

# per-batch test metrics (Python floats)
losses_test = []
iou_test = []
binary_iou_test = []

# declare that we are in evaluation mode
model = model.eval()

# inference only: no autograd graph is needed
with torch.no_grad():
  for i, (images, binary_true_labels, depth_true_maps) in enumerate(test_loader):
    images = images.to(device)                          # images size (BATCH_SIZE x 1 x 512 x512)
    binary_true_labels = binary_true_labels.to(device)  # true_labels size(BATCH_SIZE x 512 x512)
    depth_true_maps = depth_true_maps.to(device)        # depth_true_map size(BATCH_SIZE x 512 x512)

    # DWTNet is evaluated on its own here: the ground-truth binary mask is its mask input
    logits, depth_estimate_maps = model(images, binary_true_labels)

    #post processing estimate_labels
    #depth_estimate_maps = PostProcessingDWT(depth_estimate_maps, binary_true_labels, doPlot = True)

    test_loss = CrossEntropy_criterion(logits, depth_true_maps)
    iou = calculate_iou(depth_estimate_maps, depth_true_maps)

    binary_estimate_maps = (depth_estimate_maps > BINARY_DEPTH_THRESHOLD).to(depth_estimate_maps.dtype)
    iou_binary = calculate_iou(binary_estimate_maps, binary_true_labels)

    losses_test.append(test_loss.item())
    iou_test.append(iou.item())
    binary_iou_test.append(iou_binary.item())

    # display the first results
    if i < NUM_PLOTTED_BATCHES:
      plotModelResultDWTNet(images[0].squeeze(0), binary_true_labels[0], depth_true_maps[0],
                            depth_estimate_maps[0], i)

    del images, binary_true_labels, depth_true_maps
    del logits, binary_estimate_maps, depth_estimate_maps
    del test_loss, iou, iou_binary

loss_average =  Average(losses_test)
iou_average  =  Average(iou_test)
binary_iou_average  =  Average(binary_iou_test)

print(' Average Loss on the test data = {:.3f}  '.format(loss_average))
print(' Average IOU measurement on the test data = {:.3f}  '.format(iou_average))
print(' Average binary IOU measurement on the test data = {:.3f}  '.format(binary_iou_average))

del model


Train UDWTNet
---

In [None]:
# train dataset
PlotDuringTraining  = False

# NETWORK PARAMS
BATCH_SIZE = 2
NUM_EPOCHS = 200
NUM_DEPTH_LABELS = 16  # number of depth classes predicted by DWTNet
TRAIN_FRACTION = 0.7   # share of the samples used for training; the rest is validation

# Train loader
folder_data = BaseDataPath +"/01"
# divide to train and validation (random split of the sample indices)
dataset = CustomDataset(folder_data)
dataset_len = len(dataset)
picked = set(random.sample(range(dataset_len), math.floor(TRAIN_FRACTION*dataset_len)))
train_indices = [idx for idx in range(dataset_len) if idx in picked]
valid_indices = [idx for idx in range(dataset_len) if idx not in picked]

# models
unet = UNet(in_channel= 1, out_channel = 2, doPrint = False) #out_channel represents number of segments desired

dwtnet = DWTNet(doPrint = False, doPrintMax = False, num_labels = NUM_DEPTH_LABELS)
dwtnet.apply(dwtnet.init_weights)

udwtnet = UDWTNet(unet.to(device), dwtnet.to(device))
udwtnet = udwtnet.to(device)  # move to GPU

# Loss function (one weight per depth class)
weights = torch.ones(NUM_DEPTH_LABELS, dtype = torch.float)
CrossEntropy_criterion = torch.nn.CrossEntropyLoss(weight= weights.to(device))

# Optimizer over BOTH stages. The UNet above is freshly initialised, so
# optimising only dwtnet.parameters() would leave it random for the whole run.
# NOTE(review): if UDWTNet cuts the autograd graph between the two stages
# (e.g. argmax + numpy post-processing), the UNet still receives no gradient;
# in that case load a pre-trained UNet with LoadUnetModel instead.
optimizer = torch.optim.SGD(list(unet.parameters()) + list(dwtnet.parameters()), lr =0.1, momentum=0.5) #0.001
#optimizer = torch.optim.Adam(dwtnet.parameters(), lr=5e-6)
optimizer.zero_grad()

row_list = []

for epoch in range(NUM_EPOCHS):
  train_loss_list = []
  train_iou_list  = []
  train_iou_binary_list = []
  
  valid_loss_list = []
  valid_iou_list  = []
  valid_iou_binary_list = []
  
  data_transform = ImgAugTransform()
  train_dataset = CustomDataset(folder_data, transform =  data_transform)
  train_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train_dataset, train_indices), batch_size= BATCH_SIZE , shuffle=True, num_workers=2 )
  valid_loader = torch.utils.data.DataLoader(torch.utils.data.Subset(train_dataset, valid_indices), batch_size= BATCH_SIZE , shuffle=True, num_workers=2 )

  epoch_time_start = time.time();
  ## training loop
  dwtnet = dwtnet.train() # declare training
  for i, (images, binary_true_labels ,depth_true_maps) in enumerate(train_loader):

    images = images.to(device)           # move to gpu   # images size (BATCH_SIZE x 1 x 512 x512)

    binary_true_labels  = binary_true_labels.to(device) # move to gpu   # depth_true_map size(BATCH_SIZE x 512 x512)

    depth_true_maps  = depth_true_maps.to(device) # move to gpu   # depth_true_map size(BATCH_SIZE x 512 x512)

    logits, estimate_labels,depth_estimate_maps = udwtnet(images)  # logits size (BATCH_SIZE x NUM_LABELS x 512 x512) ,  depth_estimate_map size(BATCH_SIZE x 512 x512)
  
    # loss values
    CrossEntropy_Loss = CrossEntropy_criterion(logits, depth_true_maps)
    
    iou = calculate_iou(depth_estimate_maps, depth_true_maps)
    #print(iou)
    
    binary_estimate_maps = depth_estimate_maps.clone()
    binary_estimate_maps[binary_estimate_maps <= 2] = 0
    binary_estimate_maps[binary_estimate_maps > 2] = 1
      
    iou_binary = calculate_iou(binary_estimate_maps, binary_true_labels)

    loss =  CrossEntropy_Loss
    #print(" loss = " + str(loss.item())) 

    # Back propagation
    udwtnet.zero_grad()
    loss.backward()
    #plot_grad_flow(dwtnet.named_parameters())
    optimizer.step()

    train_loss_list.append(loss.item())
    train_iou_list.append(iou.item())
    train_iou_binary_list.append(iou_binary.item())
    
    del images , binary_true_labels,estimate_labels, depth_true_maps
    del logits, binary_estimate_maps,depth_estimate_maps
    del loss , iou,iou_binary, CrossEntropy_Loss

  ## end train_loop

  ## validation loop
  dwtnet = dwtnet.eval() # declare validation
  with torch.no_grad():
    for i, (images, binary_true_labels ,depth_true_maps) in enumerate(valid_loader):
      
      images = images.to(device)           # move to gpu   # images size (BATCH_SIZE x 1 x 512 x512)
      
      binary_true_labels  = binary_true_labels.to(device) # move to gpu   # true_labels size(BATCH_SIZE x 512 x512)
      
      depth_true_maps  = depth_true_maps.to(device) # move to gpu   # depth_true_map size(BATCH_SIZE x 512 x512)
      
      logits,estimate_labels, depth_estimate_maps = udwtnet(images)  # logits size (BATCH_SIZE x NUM_LABELS x 512 x512) ,  depth_estimate_map size(BATCH_SIZE x 512 x512)

      # loss values
      CrossEntropy_Loss = CrossEntropy_criterion(logits, depth_true_maps)
      loss =  CrossEntropy_Loss

      iou = calculate_iou(depth_estimate_maps, depth_true_maps)
      
      binary_estimate_maps = depth_estimate_maps.clone()
      binary_estimate_maps[binary_estimate_maps <= 2] = 0
      binary_estimate_maps[binary_estimate_maps > 2] = 1
      
      iou_binary = calculate_iou(binary_estimate_maps, binary_true_labels)

      if PlotDuringTraining: 
        if i == 0:
          #show validation process on the validation data
          image = images[0,:,:,:]
          image = image.squeeze(0)
          binary_true_label = binary_true_labels[0,:,:]
          depth_true_map = depth_true_maps[0,:,:]
          estimate_label =  estimate_labels[0,:,:]
          depth_estimate_map= depth_estimate_maps[0,:,:]
          
          plotModelResult(image, binary_true_label, depth_true_map, estimate_label,depth_estimate_map, i) 

          del image, binary_true_label, depth_true_map, estimate_label,depth_estimate_map
              
      valid_loss_list.append(loss.item())
      valid_iou_list.append(iou.item())
      valid_iou_binary_list.append(iou_binary.item())

      del images , binary_true_labels,depth_true_maps
      del logits, estimate_labels,binary_estimate_maps, depth_estimate_maps
      del loss, iou, iou_binary, CrossEntropy_Loss

    # end train_loop

 ##end validation loop 

  avrg_train_loss =  Average(train_loss_list)
  avrg_train_iou  =  Average(train_iou_list)
  avrg_train_binary_iou  =  Average(train_iou_binary_list)

  
  avrg_valid_loss =  Average(valid_loss_list)
  avrg_valid_iou  =  Average(valid_iou_list)
  avrg_valid_binary_iou  =  Average(valid_iou_binary_list)

  time_end =  time.time() - epoch_time_start # get end time
  print('Epoch {}/{} | '.format(epoch + 1, NUM_EPOCHS),end =' ')
  print('Time {:.3} sec | '.format(time_end),end =' ')
  print('train_loss {:.3f} | '.format(avrg_train_loss),end =' ')
  print('train_depth_iou {:.3f} | '.format(avrg_train_iou),end =' ')
  print('train_binary_iou {:.3f} | '.format(avrg_train_binary_iou),end =' ')
  print('validation_loss {:.3f} | '.format(avrg_valid_loss),end =' ')
  print('validation_depth_iou {:.3f} | '.format(avrg_valid_iou),end =' ')
  print('validation_binary_iou {:.3f} | '.format(avrg_valid_binary_iou))

  #print summary in table
  row_list.append([time_end, epoch+1, avrg_train_loss, avrg_train_iou,avrg_train_binary_iou, avrg_valid_loss, avrg_valid_iou,avrg_valid_binary_iou])

#end epoch loop
savepath = BaseSavePath + F"/{UDWTNet_save_name}"
print ('saving model in '+ savepath)
torch.save(udwtnet.state_dict(), savepath)


#Graph : Train, validation
fig = plt.figure(figsize = (20,10))

ax = plt.subplot(1, 3, 1)
plt.ylabel(' Loss')
plt.xlabel('epoch')
train_list = [item[2] for item in row_list]
valid_list = [item[5] for item in row_list]
plt.plot(train_list, label = ' TRAIN loss' , color = 'r') #linestyle='dashed'
plt.plot(valid_list, label = ' VALID loss' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)

ax = plt.subplot(1, 3, 2)
plt.ylabel(' IOU')
plt.xlabel('epoch')
train_list = [item[3] for item in row_list]
valid_list = [item[6] for item in row_list]
plt.plot(train_list, label = ' TRAIN depth iou' , color = 'r') #linestyle='dashed'
plt.plot(valid_list, label = ' VALID depth iou' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)
plt.show

ax = plt.subplot(1, 3, 3)
plt.ylabel('binary IOU')
plt.xlabel('epoch')
train_list = [item[4] for item in row_list]
valid_list = [item[7] for item in row_list]
plt.plot(train_list, label = ' TRAIN binary iou' , color = 'r') #linestyle='dashed'
plt.plot(valid_list, label = ' VALID binary iou' , color = 'b')
plt.legend(bbox_to_anchor=(1.05, 1), loc='best', borderaxespad=0.)
plt.grid(True)
plt.show

del dwtnet

Test UDWTNet
---

In [None]:
# Evaluate the saved UDWTNet on folder "02": average CE loss, depth IoU and
# binary IoU, and plot the first 10 test batches.
folder_data = BaseDataPath + "/02"
test_dataset = CustomDataset(folder_data)

# NETWORK PARAMS
BATCH_SIZE = 1

# Loss function (unweighted CE)
CrossEntropy_criterion = torch.nn.CrossEntropyLoss()

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# load model
model = LoadUDWTModel(UDWTNet_save_name)
model = model.to(device)  # move to GPU

# per-batch metrics (plain Python floats)
losses_test = []
iou_test = []
iou_binary_test = []

# declare that we are in evaluation mode
model = model.eval()

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
  for i, (images, true_binary_labels, depth_true_maps) in enumerate(test_loader):
    images = images.to(device)                          # images size (BATCH_SIZE x 1 x 512 x 512)
    true_binary_labels = true_binary_labels.to(device)  # size (BATCH_SIZE x 512 x 512)
    depth_true_maps = depth_true_maps.to(device)        # size (BATCH_SIZE x 512 x 512)

    # logits size (BATCH_SIZE x NUM_LABELS x 512 x 512), depth_estimate_maps size (BATCH_SIZE x 512 x 512)
    logits, estimate_labels, depth_estimate_maps = model(images)

    #post processing estimate_labels
    #estimate_labels = PostProcessingDWT(depth_estimate_maps, estimate_labels, doPlot = False)

    # calculate the loss
    CrossEntropy_Loss = CrossEntropy_criterion(logits, depth_true_maps)
    iou = calculate_iou(depth_estimate_maps, depth_true_maps)

    # depth labels <= 2 map to 0, labels above 2 map to 1 (same dtype as the estimate)
    binary_estimate_maps = (depth_estimate_maps > 2).to(depth_estimate_maps.dtype)
    iou_binary = calculate_iou(binary_estimate_maps, true_binary_labels)

    test_loss = CrossEntropy_Loss

    # store floats, not tensors, so device memory is not kept alive by the lists
    losses_test.append(test_loss.item())
    iou_test.append(iou.item())
    iou_binary_test.append(iou_binary.item())

    # display the first 10 results
    if i < 10:
      image = images[0, :, :, :].squeeze(0)
      true_binary_label = true_binary_labels[0, :, :]
      depth_true_map = depth_true_maps[0, :, :]
      estimate_label = estimate_labels[0, :, :]
      depth_estimate_map = depth_estimate_maps[0, :, :]

      plotModelResult(image, true_binary_label, depth_true_map, estimate_label, depth_estimate_map, i)

      del image, true_binary_label, depth_true_map, estimate_label, depth_estimate_map

    del images, true_binary_labels, estimate_labels, depth_true_maps
    del logits, binary_estimate_maps, depth_estimate_maps
    del test_loss, iou, iou_binary, CrossEntropy_Loss

loss_average = Average(losses_test)
iou_average = Average(iou_test)
iou_binary_average = Average(iou_binary_test)

print(' Average Loss on the test data = {:.3f}  '.format(loss_average))
print(' Average IOU measurement on the test data = {:.3f}  '.format(iou_average))
print(' Average binary IOU measurement on the test data = {:.3f}  '.format(iou_binary_average))

del model

# Test Trained Unet -> Trained DWTNet

Test Trained Unet -> Trained DWTNet
---

In [None]:
# Evaluate the separately trained UNet followed by the separately trained DWTNet
# on folder "02": depth IoU and binary IoU, plotting the first 10 batches.

# UNET NETWORK PARAMS
BATCH_SIZE = 2
NUM_LABELS = 2           # not referenced in this cell
NUM_IMAGE_CHANNELS = 1   # not referenced in this cell

# handle data
folder_data = BaseDataPath + "/02"
test_dataset = CustomDataset(folder_data)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# load Unet model
unet = LoadUnetModel(Unet_save_name)
unet = unet.to(device)  # move to GPU

# load DWT model
dwtnet = LoadDWTModel(DWTNet_save_name)
dwtnet = dwtnet.to(device)  # move to GPU

# declare that we are in evaluation mode
unet = unet.eval()
dwtnet = dwtnet.eval()

iou_list = []
binary_iou_list = []

# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
  for i, (images, binary_true_labels, depth_true_maps) in enumerate(test_loader):
    images = images.to(device)                          # images size (BATCH_SIZE x 1 x 512 x 512)
    binary_true_labels = binary_true_labels.to(device)  # size (BATCH_SIZE x 512 x 512)
    depth_true_maps = depth_true_maps.to(device)        # size (BATCH_SIZE x 512 x 512)

    # Forward pass: UNet logits size (BATCH_SIZE x NUM_LABELS x 512 x 512), estimate_labels size (BATCH_SIZE x 512 x 512)
    unet_logits, estimate_labels = unet(images)

    #post processing Unet
    estimate_labels = PostProcessingUnet(estimate_labels, doPlot=False)

    # DWTNet consumes the image plus the post-processed UNet labels
    dwtnet_logits, depth_estimate_maps = dwtnet(images, estimate_labels)

    #post processing DWT
    # depth_estimate_maps = PostProcessingDWT(depth_estimate_maps, estimate_labels,doPlot = True)

    iou = calculate_iou(depth_estimate_maps, depth_true_maps)

    # depth labels <= 2 map to 0, labels above 2 map to 1 (same dtype as the estimate)
    binary_estimate_maps = (depth_estimate_maps > 2).to(depth_estimate_maps.dtype)
    iou_binary = calculate_iou(binary_estimate_maps, binary_true_labels)

    iou_list.append(iou.item())
    binary_iou_list.append(iou_binary.item())

    # display the first 10 results
    if i < 10:
      image = images[0, :, :, :].squeeze(0)
      binary_true_label = binary_true_labels[0, :, :]
      estimate_label = estimate_labels[0, :, :]
      depth_true_map = depth_true_maps[0, :, :]
      depth_estimate_map = depth_estimate_maps[0, :, :]

      plotModelResult(image, binary_true_label, depth_true_map, estimate_label, depth_estimate_map, i)
      #plotModelResultDWTNet(image, binary_true_label, depth_true_map,depth_estimate_map, i)

      del image, binary_true_label, estimate_label, depth_estimate_map, depth_true_map

    del images, binary_true_labels, depth_true_maps  # delete data
    del unet_logits, estimate_labels  # delete unet result
    del dwtnet_logits, binary_estimate_maps, depth_estimate_maps  # delete dwtnet result
    del iou, iou_binary  # delete metrics

print(' Average IOU measurement on the test data = {:.3f}  '.format(Average(iou_list)))
print(' Average binary IOU measurement on the test data = {:.3f}  '.format(Average(binary_iou_list)))

del unet
del dwtnet