<a href="https://colab.research.google.com/github/ndvaughn/dUnet/blob/main/project442.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchsummary
import shutil
import argparse
import zipfile
import hashlib
import pickle
import requests
from IPython.display import clear_output 
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import os
import time
import itertools
from matplotlib import image
import glob as glob
from PIL import Image

import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary

# Log the library versions in use; handy when reproducing results later.
print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Pick the compute device once and report which one was selected.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(
    "Using the GPU!"
    if torch.cuda.is_available()
    else "WARNING: Could not find GPU! Using CPU only. If you want to enable GPU, please to go Edit > Notebook Settings > Hardware Accelerator and select GPU."
)


In [None]:
# When True, Ade20k keeps only the first 500 image/mask files of each split.
DEBUG = False
# When True, Ade20k remaps every label id greater than TOP_K to 0 and the
# models build their final layer with TOP_K + 1 output channels.
DETECT_SBST = True 
# Largest label id that is kept as-is when DETECT_SBST is enabled.
TOP_K = 15

In [None]:
# Mount Google Drive at the relative path 'data' (only works inside Colab).
from google.colab import drive
drive.mount('data')

In [None]:
# some helper functions to download the dataset
# this code comes mainly from gluoncv.utils
def check_sha1(filename, sha1_hash):
    """Verify a file against an expected SHA-1 digest.

    Parameters
    ----------
    filename : str
        Path to the file whose content is hashed.
    sha1_hash : str
        Expected SHA-1 digest as hexadecimal text. It may be shorter than a
        full digest, in which case only the leading characters are compared.

    Returns
    -------
    bool
        True when the digests agree on their common prefix, else False.
    """
    digest = hashlib.sha1()
    with open(filename, 'rb') as stream:
        # Hash in 1 MiB pieces so large archives are never fully in memory.
        for piece in iter(lambda: stream.read(1048576), b''):
            digest.update(piece)

    actual = digest.hexdigest()
    # Compare only as many characters as both strings provide.
    common = min(len(actual), len(sha1_hash))
    return actual[:common] == sha1_hash[:common]

def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download a given URL.

    The download is skipped when the destination file already exists, unless
    ``overwrite`` is set or ``sha1_hash`` is given and does not match the
    existing file.

    Parameters
    ----------
    url : str
        URL to download
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If it names an existing
        directory, the file is stored inside it under the URL's last segment.
    overwrite : bool, optional
        Whether to overwrite destination file if already exists.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
        but doesn't match.

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If the downloaded content does not match ``sha1_hash``.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(dirname, exist_ok=True)

        print('Downloading %s from %s...'%(fname, url))
        # A streamed response keeps its connection open until it is closed,
        # so use it as a context manager to release it on every exit path.
        with requests.get(url, stream=True) as r:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s"%url)
            total_length = r.headers.get('content-length')
            chunks = r.iter_content(chunk_size=1024)
            if total_length is not None:
                # Only show a progress bar when the server reports the size.
                chunks = tqdm(chunks,
                              total=int(int(total_length) / 1024. + 0.5),
                              unit='KB', unit_scale=False, dynamic_ncols=True)
            with open(fname, 'wb') as f:
                for chunk in chunks:
                    if chunk: # filter out keep-alive new chunks
                        f.write(chunk)

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
                              'The repo may be outdated or download may be incomplete. ' \
                              'If the "repo_url" is overridden, consider switching to ' \
                              'the default repo.'.format(fname))

    return fname

def download_ade(path, overwrite=False):
    """Download ADE20K and extract it.

    Both archives are stored under ``<path>/downloads`` and then extracted
    into ``path``. Extraction runs on every call, even when the archives
    were already present.

    Parameters
    ----------
    path : str
      Location of the downloaded files.
    overwrite : bool, optional
      Whether to overwrite destination file if already exists.
    """
    _AUG_DOWNLOAD_URLS = [
      ('http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip', '219e1696abb36c8ba3a3afe7fb2f4b4606a897c7'),
      ('http://data.csail.mit.edu/places/ADEchallenge/release_test.zip', 'e05747892219d10e9243933371a497e905a4860c'),]
    download_dir = os.path.join(path, 'downloads')
    # makedirs(exist_ok=True) creates `path` and any missing parents too, and
    # does not fail if the directory appears between a check and the create.
    os.makedirs(download_dir, exist_ok=True)
    for url, checksum in _AUG_DOWNLOAD_URLS:
        filename = download(url, path=download_dir, overwrite=overwrite, sha1_hash=checksum)
        # extract
        with zipfile.ZipFile(filename,"r") as zip_ref:
            zip_ref.extractall(path=path)

In [None]:
# Filesystem layout: archives are extracted under `root`; the image splits
# are addressed as dataset_path + training_data / val_data.
root = "/content/"
dataset_path = root + "ADEChallengeData2016/images/"
training_data = "training/"
val_data = "validation/"

In [None]:
# Fetch the archives into root/downloads (existing files that pass the SHA-1
# check are reused) and extract them under `root`.
download_ade(root, overwrite=False)

In [None]:
# Count the .jpg files of each split directly on disk.
TRAINSET_SIZE = len(glob.glob(dataset_path + training_data + "*.jpg"))
print(f"The Training Dataset contains {TRAINSET_SIZE} images.")

VALSET_SIZE = len(glob.glob(dataset_path + val_data + "*.jpg"))
print(f"The Validation Dataset contains {VALSET_SIZE} images.")

# Matches the 151 output channels the models use when DETECT_SBST is False.
# NOTE(review): presumably 150 object classes plus label 0 - confirm.
N_CLASSES = 151

In [None]:
def normal_init(m, mean, std):
    """
    Helper function. Initialize model parameter with given mean and std.

    Only ``nn.Conv2d`` and ``nn.ConvTranspose2d`` modules are touched; any
    other module is left unchanged. Weights are drawn in place from a normal
    distribution and the bias, when the layer has one, is set to zero.

    Args:
        m: the module to (possibly) initialize.
        mean: mean of the normal distribution used for the weights.
        std: standard deviation of that distribution.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(mean, std)
        # Layers built with bias=False expose bias as None.
        if m.bias is not None:
            m.bias.data.zero_()

In [None]:
class Ade20k(Dataset):
    """ADE20K image/mask pairs resized to 256x256.

    Each item is ``(img, label)``: ``img`` is a float tensor produced by
    ``ToTensor`` and normalized with mean 0.5 / std 0.5 per channel, and
    ``label`` is a ``torch.long`` tensor of per-pixel label ids. When the
    module-level flag ``DETECT_SBST`` is set, ids greater than ``TOP_K`` are
    remapped to 0.
    """

    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir: the directory of the dataset
            split: "train" or "val" (any value other than "train" selects
                the validation folder)
            transform: pytorch transformations. NOTE: stored on the instance
                but not applied by this class.
        """

        self.transform = transform
        # Nearest-neighbour resizing keeps mask values valid label ids; the
        # same resize object is used for images and masks.
        self.resize = transforms.Resize((256,256), interpolation=Image.NEAREST)
        self.ToTensor = transforms.ToTensor()
        self.normalize = transforms.Compose([
                                    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        pattern = '/training/*.jpg' if split == 'train' else '/validation/*.jpg'
        # glob returns files in arbitrary order; sort so that indices (and
        # the DEBUG subset) are reproducible between runs.
        self.img_files = sorted(glob.glob(root_dir + pattern))
        self.mask_files = [self._mask_path(p) for p in self.img_files]

        if DEBUG:
            self.img_files = self.img_files[:500]
            self.mask_files = self.mask_files[:500]

    @staticmethod
    def _mask_path(img_path):
        """Map an image path to its annotation path.

        Swaps the last ``images`` occurrence in the directory part for
        ``annotations`` and the file extension for ``.png``, leaving the
        rest of the path untouched.
        """
        directory, filename = os.path.split(img_path)
        head, found, tail = directory.rpartition('images')
        if found:
            directory = head + 'annotations' + tail
        stem, _ = os.path.splitext(filename)
        return os.path.join(directory, stem + '.png')

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        # convert('RGB') replicates grayscale into three channels and also
        # covers other modes, so Normalize always sees exactly 3 channels.
        with Image.open(self.img_files[idx]) as raw:
            img = self.resize(raw.convert('RGB'))
        img = self.normalize(self.ToTensor(img))

        # np.array makes a writable copy, which the in-place remap needs.
        with Image.open(self.mask_files[idx]) as raw:
            label = np.array(self.resize(raw))
        if DETECT_SBST:
            label[label > TOP_K] = 0
        label = torch.from_numpy(label).to(torch.long)
        # To move to BCE loss, one-hot encode the label instead, e.g.
        # F.one_hot(label, num_classes=151).permute((2,0,1)).float()

        return img, label


# Build the train/validation datasets and their loaders. Only the training
# loader shuffles.
tr_dt = Ade20k(dataset_path, split='train')
te_dt = Ade20k(dataset_path, split='val')
train_loader = DataLoader(tr_dt, batch_size=4, shuffle=True)
test_loader = DataLoader(te_dt, batch_size=5, shuffle=False)

# Sanity check: report the split sizes (each is capped at 500 when DEBUG is True)
print('Number of training images {}, number of testing images {}'.format(len(tr_dt), len(te_dt)))

In [None]:
#Sample Output used for visualization
test = next(iter(test_loader))
img_size = 256
# Use the device selected at the top of the notebook instead of a hard-coded
# .cuda(), so this cell also runs on CPU-only runtimes.
fixed_y_ = test[1].to(device)
fixed_x_ = test[0].to(device)
print(len(train_loader))
print(len(test_loader))
print(fixed_y_.shape)

# plot sample image
fig, axes = plt.subplots(2, 2)
axes = np.reshape(axes, (4, ))
# Constants that undo Normalize(mean=0.5, std=0.5) for display.
mean = np.array([0.5, 0.5, 0.5])
std = np.array([0.5, 0.5, 0.5])
# Load ONE training batch and show its images, rather than building a new
# shuffled iterator (and loading a whole batch) for every subplot.
sample_batch = next(iter(train_loader))[0]
for i in range(4):
    axes[i].axis('off')
    if i >= len(sample_batch):
        # Fewer images than subplots (tiny dataset): leave the panel empty.
        continue
    example = sample_batch[i].numpy().transpose((1, 2, 0))
    example = std * example + mean
    axes[i].imshow(example)
plt.show()

In [None]:
class Unet(nn.Module):
    """Encoder/decoder network with skip connections.

    Eight stride-2 convolutions (conv1..conv8) are followed by eight stride-2
    transposed convolutions (conv9..conv16) and a 1x1 transposed convolution
    (conv17) that produces the class channels: 151 channels, or TOP_K + 1
    when DETECT_SBST is set. Each decoder layer after conv9 receives the
    previous decoder output concatenated with an encoder output.

    The modules ReLU1..ReLU8, BN8 and sigmoid are registered but are not
    called by forward().
    """

    # initializers
    def __init__(self):
        super().__init__()

        # Encoder. Layers are registered in the order conv, BN (from level 2
        # on), ReLU for each level.
        down_channels = (3, 64, 128, 256, 512, 512, 512, 512, 512)
        for level in range(1, 9):
            c_in, c_out = down_channels[level - 1], down_channels[level]
            setattr(self, 'conv%d' % level,
                    nn.Conv2d(c_in, c_out, kernel_size=(4, 4), stride=2, padding=1))
            if level > 1:
                setattr(self, 'BN%d' % level, nn.BatchNorm2d(c_out))
            setattr(self, 'ReLU%d' % level, nn.LeakyReLU(negative_slope=0.2))

        # Decoder. Input widths from conv10 on include the concatenated skip.
        up_specs = (
            (9, 512, 512),
            (10, 1024, 512),
            (11, 1024, 512),
            (12, 1024, 512),
            (13, 1024, 256),
            (14, 512, 128),
            (15, 256, 64),
        )
        for level, c_in, c_out in up_specs:
            setattr(self, 'conv%d' % level,
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1)))
            setattr(self, 'BN%d' % level, nn.BatchNorm2d(c_out))

        self.conv16 = nn.ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        out_channels = TOP_K + 1 if DETECT_SBST else 151
        self.conv17 = nn.ConvTranspose2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
        self.sigmoid = nn.Sigmoid()

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to every direct child module."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    # forward method
    def forward(self, input):
        """Return the raw class scores (no softmax/sigmoid applied)."""
        leaky_relu = nn.functional.leaky_relu
        relu = torch.nn.functional.relu

        # Encoder: keep each level's raw convolution output (before
        # BatchNorm and activation) as the skip tensor for the decoder.
        first = self.conv1(input)
        skips = [first]
        o = leaky_relu(first, negative_slope=0.2)
        for level in range(2, 8):
            conv_out = getattr(self, 'conv%d' % level)(o)
            skips.append(conv_out)
            o = getattr(self, 'BN%d' % level)(conv_out)
            o = leaky_relu(o, negative_slope=0.2)

        # Innermost layer: conv8 + LeakyReLU + tanh, without BatchNorm.
        o = leaky_relu(self.conv8(o), negative_slope=0.2)
        o = torch.tanh(o)

        # Decoder.
        o = self.BN9(self.conv9(o))
        o = relu(o, inplace=True)

        # conv10 is the only layer that concatenates decoder-first.
        o = self.BN10(self.conv10(torch.cat((o, skips.pop()), dim=1)))
        o = relu(o, inplace=True)

        # conv11..conv15 consume the remaining skips from deep to shallow.
        for level in range(11, 16):
            merged = torch.cat((skips.pop(), o), dim=1)
            o = getattr(self, 'BN%d' % level)(getattr(self, 'conv%d' % level)(merged))
            o = relu(o, inplace=True)

        o = self.conv16(torch.cat((skips.pop(), o), dim=1))
        return self.conv17(o)

In [None]:
class UnetPP(nn.Module):
    """Dilated encoder/decoder with nested (UNet++-style) skip paths.

    Every encoder/decoder stage is a chain of five 3x3 convolutions (one
    plain, then dilations 3, 6, 9 and 12) whose five outputs are concatenated
    along the channel axis. Intermediate skip nodes (skip21, skip11, skip12,
    skip01, skip02, skip03) combine upsampled deeper features with shallower
    ones before they are fed to the decoder.

    The output has 151 channels, or TOP_K + 1 when DETECT_SBST is set, and
    is returned without any final activation. The modules up4 and sigmoid
    are registered but not called by forward().
    """

    def __init__(self):
        super().__init__()

        self.ReLU = nn.LeakyReLU(negative_slope=0.2)

        self.enc00 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=1, padding=1)  # keeps spatial size

        # ENCODE BLOCKS 1-4: each max-pool halves the spatial size.
        self.max1 = nn.MaxPool2d((2, 2), stride=2)
        self._add_stack('enc1', 64, 128, transposed=False)    # -> 128 channels
        self.max2 = nn.MaxPool2d(2, stride=2)
        self._add_stack('enc2', 128, 256, transposed=False)   # -> 256 channels
        self.max3 = nn.MaxPool2d(2, stride=2)
        self._add_stack('enc3', 256, 512, transposed=False)   # -> 512 channels
        self.max4 = nn.MaxPool2d(2, stride=2)
        self._add_stack('enc4', 512, 512, transposed=False)   # -> 512 channels

        # SKIP LAYER (2,1): up(enc4) + enc3
        self.skip21 = nn.Conv2d(1024, 1024, kernel_size=(3, 3), stride=1, padding=1)

        # SKIP LAYER (1,1) and (1,2)
        # (1,1): up(enc3) + enc2 = 512 + 256
        self.skip11 = nn.Conv2d(768, 768, kernel_size=(3, 3), stride=1, padding=1)
        # (1,2): up(x21) + enc2 + x11 = 1024 + 256 + 768
        self.skip12 = nn.Conv2d(2048, 2048, kernel_size=(3, 3), stride=1, padding=1)

        # SKIP LAYER (0,1), (0,2), (0,3)
        # (0,1): up(enc2) + enc1 = 256 + 128
        self.skip01 = nn.Conv2d(384, 384, kernel_size=(3, 3), stride=1, padding=1)
        # (0,2): up(x11) + enc1 + x01 = 768 + 128 + 384
        self.skip02 = nn.Conv2d(1280, 1280, kernel_size=(3, 3), stride=1, padding=1)
        # (0,3): up(x12) + enc1 + x01 + x02 = 2048 + 128 + 384 + 1280
        self.skip03 = nn.Conv2d(3840, 3840, kernel_size=(3, 3), stride=1, padding=1)

        # DECODE BLOCKS 1-4: input widths are the channel sums of the
        # tensors concatenated in forward().
        self.up0 = nn.Upsample(scale_factor=(2, 2))
        self._add_stack('dec0', 2048, 256, transposed=True)   # enc3 + x21 + up(enc4)
        self.up1 = nn.Upsample(scale_factor=(2, 2))
        self._add_stack('dec1', 3328, 128, transposed=True)   # enc2 + x11 + x12 + up(dec0)
        self.up2 = nn.Upsample(scale_factor=(2, 2))
        self._add_stack('dec2', 5760, 64, transposed=True)    # enc1 + x01 + x02 + x03 + up(dec1)
        self.up3 = nn.Upsample(scale_factor=(2, 2))
        self._add_stack('dec3', 128, 64, transposed=True)     # enc0 + up(dec2)

        # CLASSIFICATION LAYER
        self.up4 = nn.Upsample(scale_factor=(2, 2))
        if not DETECT_SBST:
            self.dec4 = nn.ConvTranspose2d(64, 151, kernel_size=(1, 1))
        else:
            self.dec4 = nn.Conv2d(64, TOP_K + 1, kernel_size=(1, 1))

        self.sigmoid = nn.Sigmoid()

    def _add_stack(self, prefix, in_channels, width, transposed):
        """Register the five 3x3 convolutions ``<prefix>0`` .. ``<prefix>4``.

        Output widths are width//2, //4, //8, //16 and //16, which add up to
        ``width`` once concatenated. Layer 0 is a plain padded convolution;
        layers 1-4 use dilations 3, 6, 9 and 12. Transposed layers carry
        padding equal to their dilation; the non-transposed ones have no
        padding because forward() pads their input explicitly.
        """
        conv_cls = nn.ConvTranspose2d if transposed else nn.Conv2d
        widths = (width // 2, width // 4, width // 8, width // 16, width // 16)
        dilations = (None, 3, 6, 9, 12)
        c_in = in_channels
        for idx, (c_out, dilation) in enumerate(zip(widths, dilations)):
            if dilation is None:
                layer = conv_cls(c_in, c_out, kernel_size=(3, 3), stride=1, padding=1)
            elif transposed:
                layer = conv_cls(c_in, c_out, kernel_size=(3, 3), dilation=dilation, padding=dilation)
            else:
                layer = conv_cls(c_in, c_out, kernel_size=(3, 3), dilation=dilation)
            setattr(self, '%s%d' % (prefix, idx), layer)
            c_in = c_out

    def _encode(self, x, stage):
        """Run encoder stage ``stage`` and concatenate its five outputs."""
        feats = [self.ReLU(getattr(self, 'enc%d0' % stage)(x))]
        for idx, pad in enumerate((3, 6, 9, 12), start=1):
            # Zero-pad by the dilation so the dilated conv keeps the size.
            padded = F.pad(feats[-1], (pad, pad, pad, pad))
            feats.append(self.ReLU(getattr(self, 'enc%d%d' % (stage, idx))(padded)))
        return torch.cat(feats, dim=1)

    def _decode(self, x, stage):
        """Run decoder stage ``stage`` and concatenate its five outputs."""
        feats = []
        for idx in range(5):
            x = self.ReLU(getattr(self, 'dec%d%d' % (stage, idx))(x))
            feats.append(x)
        return torch.cat(feats, dim=1)

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to every direct child module."""
        for m in self._modules:
            normal_init(self._modules[m], mean, std)

    def forward(self, input):
        """Return the raw per-pixel class scores for ``input``."""
        # ENCODE
        out0_x = self.ReLU(self.enc00(input))
        out1_x = self._encode(self.max1(out0_x), 1)
        out2_x = self._encode(self.max2(out1_x), 2)
        out3_x = self._encode(self.max3(out2_x), 3)
        # max4 is configured identically to max3, so using it here gives the
        # same result while letting stage 4 use its own pooling layer.
        out4_x = self._encode(self.max4(out3_x), 4)

        # SKIP CONNECTIONS (all upsampling here goes through up0)
        up = self.up0
        x21 = self.ReLU(self.skip21(torch.cat((up(out4_x), out3_x), dim=1)))

        x11 = self.ReLU(self.skip11(torch.cat((up(out3_x), out2_x), dim=1)))
        x12 = self.ReLU(self.skip12(torch.cat((up(x21), out2_x, x11), dim=1)))

        x01 = self.ReLU(self.skip01(torch.cat((up(out2_x), out1_x), dim=1)))
        x02 = self.ReLU(self.skip02(torch.cat((up(x11), out1_x, x01), dim=1)))
        x03 = self.ReLU(self.skip03(torch.cat((up(x12), out1_x, x01, x02), dim=1)))

        # DECODE
        x = self._decode(torch.cat((out3_x, x21, self.up0(out4_x)), dim=1), 0)
        x = self._decode(torch.cat((out2_x, x11, x12, self.up1(x)), dim=1), 1)
        x = self._decode(torch.cat((out1_x, x01, x02, x03, self.up2(x)), dim=1), 2)
        x = self._decode(torch.cat((out0_x, self.up3(x)), dim=1), 3)

        return self.dec4(x)

          


In [None]:
class DialatedUnet(nn.Module):
    """U-Net variant whose stages are chains of dilated convolutions.

    Every encoder/decoder stage applies five 3x3 convolutions in sequence
    (one plain, then dilations 3, 6, 9 and 12) and concatenates the five
    outputs along the channel axis. Each decoder stage takes the matching
    encoder output concatenated with the upsampled previous decoder output.

    The output has 151 channels, or TOP_K + 1 when DETECT_SBST is set, and
    is returned without a final activation. The modules max4, up4 and
    sigmoid are registered but not called by forward().
    """

    def __init__(self):
        super().__init__()

        self.ReLU = nn.LeakyReLU(negative_slope=0.2)

        self.enc00 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=1, padding=1)  # keeps spatial size

        # Encoder stages; each max-pool halves the spatial size.
        self.max1 = nn.MaxPool2d((2, 2), stride=2)
        self._register_stack('enc1', 64, 128, False)
        self.max2 = nn.MaxPool2d(2, stride=2)
        self._register_stack('enc2', 128, 256, False)
        self.max3 = nn.MaxPool2d(2, stride=2)
        self._register_stack('enc3', 256, 512, False)
        self.max4 = nn.MaxPool2d(2, stride=2)
        self._register_stack('enc4', 512, 512, False)

        # Decoder stages; input width = skip channels + upsampled channels.
        self.up0 = nn.Upsample(scale_factor=(2, 2))
        self._register_stack('dec0', 1024, 256, True)
        self.up1 = nn.Upsample(scale_factor=(2, 2))
        self._register_stack('dec1', 512, 128, True)
        self.up2 = nn.Upsample(scale_factor=(2, 2))
        self._register_stack('dec2', 256, 64, True)
        self.up3 = nn.Upsample(scale_factor=(2, 2))
        self._register_stack('dec3', 128, 64, True)

        self.up4 = nn.Upsample(scale_factor=(2, 2))

        # 1x1 classifier head.
        if DETECT_SBST:
            self.dec4 = nn.Conv2d(64, TOP_K + 1, kernel_size=(1, 1))
        else:
            self.dec4 = nn.ConvTranspose2d(64, 151, kernel_size=(1, 1))

        self.sigmoid = nn.Sigmoid()

    def _register_stack(self, prefix, in_channels, width, transposed):
        """Create the layers ``<prefix>0`` .. ``<prefix>4`` of one stage.

        Output widths are width//2, //4, //8, //16 and //16. Layer 0 is a
        plain padded 3x3 convolution; layers 1-4 use dilations 3, 6, 9, 12.
        Transposed layers are padded by their dilation; regular ones are
        unpadded because forward() pads their input explicitly.
        """
        layer_cls = nn.ConvTranspose2d if transposed else nn.Conv2d
        out_widths = (width // 2, width // 4, width // 8, width // 16, width // 16)
        rates = (None, 3, 6, 9, 12)
        previous = in_channels
        for position, (current, rate) in enumerate(zip(out_widths, rates)):
            if rate is None:
                layer = layer_cls(previous, current, kernel_size=(3, 3), stride=1, padding=1)
            elif transposed:
                layer = layer_cls(previous, current, kernel_size=(3, 3), dilation=rate, padding=rate)
            else:
                layer = layer_cls(previous, current, kernel_size=(3, 3), dilation=rate)
            setattr(self, '%s%d' % (prefix, position), layer)
            previous = current

    def _encode_stage(self, x, stage):
        """Apply encoder stage ``stage`` and concatenate its five outputs."""
        outputs = [self.ReLU(getattr(self, 'enc%d0' % stage)(x))]
        for position, margin in enumerate((3, 6, 9, 12), start=1):
            # Zero-pad by the dilation so the dilated conv keeps the size.
            padded = F.pad(outputs[-1], (margin, margin, margin, margin))
            outputs.append(self.ReLU(getattr(self, 'enc%d%d' % (stage, position))(padded)))
        return torch.cat(outputs, dim=1)

    def _decode_stage(self, x, stage):
        """Apply decoder stage ``stage`` and concatenate its five outputs."""
        outputs = []
        for position in range(5):
            x = self.ReLU(getattr(self, 'dec%d%d' % (stage, position))(x))
            outputs.append(x)
        return torch.cat(outputs, dim=1)

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init(mean, std) to every direct child module."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    def forward(self, input):
        """Return the raw per-pixel class scores for ``input``."""
        out0_x = self.ReLU(self.enc00(input))

        out1_x = self._encode_stage(self.max1(out0_x), 1)
        out2_x = self._encode_stage(self.max2(out1_x), 2)
        out3_x = self._encode_stage(self.max3(out2_x), 3)
        # Stage 4 pools with max3 as well (max4 is never called).
        out4_x = self._encode_stage(self.max3(out3_x), 4)

        x = self._decode_stage(torch.cat((out3_x, self.up0(out4_x)), dim=1), 0)
        x = self._decode_stage(torch.cat((out2_x, self.up1(x)), dim=1), 1)
        x = self._decode_stage(torch.cat((out1_x, self.up2(x)), dim=1), 2)
        x = self._decode_stage(torch.cat((out0_x, self.up3(x)), dim=1), 3)

        return self.dec4(x)

          


`dice_loss` was experimented with but ultimately left unused; no code currently calls it.

In [None]:
def dice_loss(pred, target):
    """Soft Dice loss that generalizes to real-valued pred and target tensors.

    Both tensors are flattened, so a single loss value is computed over the
    whole batch:

        1 - (2 * sum(p * t) + smooth) / (sum(p * p) + sum(t * t) + smooth)

    The result is differentiable with respect to ``pred``.

    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch; must hold the same number
        of elements as ``pred``

    Returns a 0-dim tensor: 0 for a perfect match, growing as the overlap
    between ``pred`` and ``target`` shrinks.
    """
    # Keeps the ratio defined when both tensors are all zeros.
    smooth = 1.

    # reshape copes with non-contiguous inputs (e.g. results of a transpose),
    # copying only when it has to.
    iflat = pred.reshape(-1)
    tflat = target.reshape(-1)
    intersection = (iflat * tflat).sum()

    # Squared magnitude of each operand. The prediction term must use
    # iflat * iflat; using tflat * iflat here would merely repeat the
    # intersection and leave false positives unpenalized.
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth))


In [None]:

# Loss criterion used by train(): multi-class cross-entropy.
# Criteria tried earlier and left disabled:
# BCE_loss = nn.BCELoss().cuda()
# L1_loss = nn.L1Loss().cuda()
CE_loss = nn.CrossEntropyLoss()
CE_loss = CE_loss.cuda()



def train(G, num_epochs = 20):
    """Optimize ``G`` with Adam and return the loss of every training batch.

    Relies on names defined elsewhere in the notebook: ``train_loader``
    (yields ``(input, target)`` batches), ``CE_loss``, ``fixed_x_`` /
    ``fixed_y_`` (the batch used for the per-epoch preview), ``show_result``,
    ``mIOU`` and ``TOP_K``.

    G: network to train; its mode (train/eval) is left as the caller set it.
    num_epochs: number of full passes over ``train_loader``.

    Each batch is moved to the GPU with ``.cuda()``. After every epoch the
    epoch duration and mean batch loss are printed, the preview figure is
    drawn and the IoU on the fixed batch is printed.

    Returns a flat list of float losses, one per batch, across all epochs.
    """
    loss_history = []
    optimizer = optim.Adam(G.parameters(), lr=0.0003, betas=(0.5, 0.999))

    print('training start!')
    for epoch in range(num_epochs):
        epoch_no = epoch + 1
        print('Start training epoch %d' % epoch_no)
        epoch_losses = []
        tic = time.time()

        for inputs, targets in tqdm(train_loader):
            inputs = inputs.cuda()
            targets = targets.cuda()

            # One optimization step: clear old gradients, forward, backward.
            G.zero_grad()
            loss = CE_loss(G(inputs), targets)
            # loss = dice_loss(G(inputs), targets)
            loss.backward()
            optimizer.step()

            batch_loss = loss.detach().item()
            epoch_losses.append(batch_loss)
            loss_history.append(batch_loss)

        # Validation pass, currently disabled:
        # val_losses = []
        # for x_, y_ in tqdm(test_loader):
        #   with torch.no_grad():
        #     x_, y_ = x_.cuda(), y_.cuda()
        #     val_losses.append(CE_loss(G(x_), y_).detach().item())
        # print('Validation loss: %.3f' % (torch.mean(torch.FloatTensor(val_losses))))

        elapsed = time.time() - tic
        print('[%d/%d] - using time: %.2f seconds' % (epoch_no, num_epochs, elapsed))
        print('loss of generator G: %.3f' % (torch.mean(torch.FloatTensor(epoch_losses))))

        # Preview and metric on the fixed batch; no gradients are needed here.
        with torch.no_grad():
            show_result(G, fixed_x_, fixed_y_, epoch_no)
            print("IOU after epochs ", mIOU(fixed_y_, G(fixed_x_), TOP_K+1))

    return loss_history

In [None]:
# Uncomment below networks to work with each model one at a time.
# The training and test cells further down must reference the same variable
# that is enabled here.

# Define network
# unet = Unet()
# unet.weight_init(mean=0.0, std=0.02)
# unet.cuda()
# unet.train()

# Active model. DialatedUnet is defined outside this cell.
unetD = DialatedUnet()
# weight_init is a method of the project class; presumably it draws the initial
# weights from a normal distribution with this mean/std -- TODO confirm against
# the class definition.
unetD.weight_init(mean=(0.0), std=0.02)
# Assuming DialatedUnet is an nn.Module: move it to the GPU (train() calls
# .cuda() on every batch) and switch it to training mode.
unetD.cuda()
unetD.train()

# unetPP = UnetPP()
# unetPP.weight_init(mean=(0.0), std=0.02)
# unetPP.cuda()
# unetPP.train()

In [None]:
def process_image(img):
    """Convert a tensor into a NumPy array that ``imshow`` can display.

    The tensor is detached from autograd, copied to the CPU and mapped with
    ``(x + 1) / 2`` (presumably undoing a [-1, 1] input normalization --
    confirm against the dataset transform).

    img: tensor to display. A 3-D tensor whose first dimension has size 3 is
        treated as channel-first and transposed from (3, H, W) to (H, W, 3);
        every other tensor keeps its layout.

    Returns the rescaled NumPy array.
    """
    array = img.detach().cpu().numpy()
    # Check the rank as well as the size of dim 0: a 2-D mask that happens to
    # have 3 rows is not a channel-first image and cannot take a 3-axis transpose.
    if array.ndim == 3 and array.shape[0] == 3:
        array = array.transpose(1, 2, 0)
    return (array + 1) / 2

def show_result(G, x_, y_, num_epoch):
    """Plot inputs, predictions and ground truth side by side.

    Draws one row per sample in ``x_`` with three columns: the input, the
    argmax over dim 1 of ``G(x_)``, and the matching entry of ``y_``. Every
    image goes through ``process_image`` before being shown.

    G: network to evaluate; called once on the whole batch.
    x_: batch of inputs; its first dimension sets the number of rows.
    y_: ground-truth batch, indexed in step with ``x_``.
    num_epoch: epoch number written in the caption below the figure.
    """
    # Only the predicted class indices are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        predict_images = torch.argmax(G(x_), dim=1)

    num_samples = x_.size()[0]
    # squeeze=False keeps `ax` two-dimensional even for a single sample;
    # otherwise ax[i, col] raises for a batch of size 1.
    fig, ax = plt.subplots(num_samples, 3, figsize=(6,10), squeeze=False)

    for i in range(num_samples):
        columns = (x_[i], predict_images[i], y_[i])
        for col, tensor in enumerate(columns):
            panel = ax[i, col]
            panel.cla()
            panel.imshow(process_image(tensor))
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)

    plt.tight_layout()
    # Captions are placed in figure coordinates: epoch below, titles on top.
    fig.text(0.5, 0, 'Epoch {0}'.format(num_epoch), ha='center')
    fig.text(0.18, 1, 'Input', ha='center')
    fig.text(0.5, 1, 'Output', ha='center')
    fig.text(0.81, 1, 'Ground truth', ha='center')

    plt.show()

# Not referenced by mIOU below; kept in case another cell relies on it.
SMOOTH = 1e-6

def mIOU(label, pred, num_classes=TOP_K+1):
    """Mean intersection-over-union over the classes that occur in ``label``.

    label: tensor of class indices; flattened before comparison.
    pred: tensor of per-class scores with the class axis at dim 1. It is
        reduced with argmax over that axis and flattened, and must then hold
        as many elements as ``label``.
    num_classes: classes ``0 .. num_classes - 1`` are evaluated.

    Label values outside that range (e.g. an ignore index equal to
    ``num_classes``) never count as targets, although predictions made at
    those positions still contribute to a class's union.

    Returns the mean IoU over the classes present in ``label``, or NaN when
    none of the evaluated classes occurs.
    """
    # argmax over raw scores picks the same class as argmax over their softmax,
    # so no normalization is needed. reshape also accepts non-contiguous input.
    pred = torch.argmax(pred, dim=1).reshape(-1)
    label = label.reshape(-1)

    present_iou_list = []
    for sem_class in range(num_classes):
        pred_inds = (pred == sem_class)
        target_inds = (label == sem_class)
        target_count = target_inds.long().sum().item()
        if target_count == 0:
            # Class absent from the ground truth: excluded from the mean.
            continue
        intersection_now = pred_inds[target_inds].long().sum().item()
        # target_count > 0 guarantees a non-zero union.
        union_now = pred_inds.long().sum().item() + target_count - intersection_now
        present_iou_list.append(float(intersection_now) / float(union_now))

    if not present_iou_list:
        return float('nan')
    return np.mean(present_iou_list)


**Train over dataset**

In [None]:
# Uncomment below lines to train specific models (keep exactly one enabled,
# matching the network instantiated above).
# NOTE: despite the "L1" in the variable name, train() returns the per-batch
# cross-entropy (CE_loss) values.

# hist_G_L1_losses = train(unetPP, num_epochs = 3)
hist_G_L1_losses = train(unetD, num_epochs = 15)
# hist_G_L1_losses = train(unet, num_epochs = 10)

**Test**

In [None]:
# Uncomment below lines to test specific models.
# `test` is defined elsewhere in the notebook (presumably a batch of test
# inputs -- confirm).
# NOTE(review): the forward pass runs in whatever mode the model is currently
# in and with autograd enabled; consider model.eval() and torch.no_grad() for
# inference -- confirm this is not intentional.
# results = unetPP(test)
results = unetD(test)
# results = unet(test)