In [1]:
import os
import logging
import argparse
import json
from typing import Union

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim import Optimizer
                              
from models import UNet
from utils import *
from losses import PSNR
from SynthTrainer import SynthTrainer
import arg_parser

We start by preparing our various models, beginning with the model described in our target paper, "Desmoking Laparoscopy Surgery Images Using an Image-to-Image Translation Guided by an Embedded Dark Channel".
We need to define a discriminator and a generator. Note that the tables here reflect processing a 512x512 image, as opposed to the 256x256 images used in the original paper. Also note that, after inspecting the paper's repository, the tables in the paper appear to be inaccurate, so I recreated these networks to match the repository by reading the code that builds them.

Since the paper does not specify the arguments used, I made a best estimate for the optional arguments based on what seemed sensible. I also modified where dropout is applied, based on advice from Dr. Florian Richter. In the repository, dropout appears to be applied to all the decoder layers of the U-Net, but depending on the arguments passed at runtime it could equally well not have been applied at all. My assumption is that dropout will be helpful and should be applied to the innermost layers of the decoder.

**TODO:** add the architecture tables here as images — KaTeX does not implement the `tabular` environment.

In [None]:
# DEFINE THE PAPER DISCRIMINATOR
class Discriminator(torch.nn.Module):
    """Convolutional discriminator recreated from the target paper's repository.

    Three stride-2 4x4 convolutions (-> 64 -> 128 -> 256 channels) are followed by
    two zero-padded stride-1 4x4 convolutions (-> 512 -> 1 channel). The result is
    a one-channel map of unnormalized scores (no Sigmoid is applied). For an input
    whose height/width H is divisible by 8 the map is (H/8 + 2) wide, e.g. 66x66
    for a 512x512 input.

    Args:
        input_channels: Number of channels in the input tensor.
    """

    def __init__(self, input_channels: int = 3):
        super().__init__()

        # Built as a local list and wrapped in Sequential at the end so every
        # layer is registered as a submodule.
        sequence = [torch.nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
                    torch.nn.LeakyReLU(0.2, True)]

        sequence += [torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                     torch.nn.BatchNorm2d(128),
                     torch.nn.LeakyReLU(0.2, True)]

        sequence += [torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                     torch.nn.BatchNorm2d(256),
                     torch.nn.LeakyReLU(0.2, True)]

        # Stride-1 stages: pad 2 on every side, then an unpadded 4x4 conv,
        # so each of these stages grows height and width by 1.
        sequence += [torch.nn.ZeroPad2d(2)]

        sequence += [torch.nn.Conv2d(256, 512, kernel_size=4, stride=1),
                     torch.nn.BatchNorm2d(512),
                     torch.nn.LeakyReLU(0.2, True)]

        sequence += [torch.nn.ZeroPad2d(2)]

        sequence += [torch.nn.Conv2d(512, 1, kernel_size=4, stride=1)]

        # sequence += [torch.nn.Sigmoid()] #IT APPEARS THIS WAS NOT USED IN THE PAPER. COULD LOOK AT ITS IMPACT LATER POSSIBLY

        self.model = torch.nn.Sequential(*sequence)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of shape (N, input_channels, H, W); returns (N, 1, H', W')."""
        return self.model(x)



Now we define the first unet which follows the paper.

In [None]:
# DEFINE THE PAPER UNET
class UNET(torch.nn.Module):
    """U-Net generator recreated from the target paper's repository.

    Eight stride-2 4x4 convolutions downsample the input. Eight stride-2 4x4
    transposed convolutions upsample it back. After each of the first seven
    upsamplings, the encoder output of the same resolution is concatenated along
    the channel axis. Input height and width must therefore be divisible by
    2**8 = 256 (e.g. 512x512).

    All layers live in ``self.sequence`` (a ``torch.nn.ModuleList``), so their
    parameters are registered. They appear in ``parameters()`` (and therefore in
    the optimizer), follow ``.to(device)`` and are saved by ``state_dict()``.
    The trailing ``# N`` comments are indices into that list. A ``-> N`` next to
    encoder index M, and a ``<- M`` next to decoder index N, mark a skip
    connection: the encoder output at M, after its in-place LeakyReLU, is
    concatenated onto the decoder tensor immediately before layer N.

    Args:
        input_channels: Channels of the input tensor.
        dropout: Probability used by ``self.drp``. It is applied after the
            BatchNorm of the ``_DROPOUT_STAGES`` decoder stages nearest the
            bottleneck (indices 22-23, 25-26, 28-29). Pass 0.0 to disable.
            NOTE(review): this placement follows the notebook's stated
            assumption ("innermost layers of the decoder"); confirm it.
    """

    # (conv, batchnorm, activation) indices into self.sequence.
    _ENCODER_STAGES = ((2, 3, 4), (5, 6, 7), (8, 9, 10), (11, 12, 13), (14, 15, 16), (17, 18, 19))
    # (transposed conv, batchnorm, activation) indices; the activation runs after the skip concat.
    _DECODER_STAGES = ((22, 23, 24), (25, 26, 27), (28, 29, 30), (31, 32, 33), (34, 35, 36), (37, 38, 39), (40, 41, 42))
    # How many of the leading (innermost) decoder stages get dropout.
    _DROPOUT_STAGES = 3
    # Eight stride-2 downsamplings.
    _SIZE_MULTIPLE = 256

    def __init__(self, input_channels: int = 3, dropout: float = 0.5):
        super().__init__()

        self.drp = torch.nn.Dropout(dropout)  # applied in forward() to the innermost decoder stages

        # Note that all convolutions with a BatchNorm after them have bias false as BatchNorm contains a bias term itself.
        # It would therefore be redundant and add nothing.

        sequence = [torch.nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),                # 0         -> 42
                    torch.nn.LeakyReLU(0.2, True)]                                                          # 1

        sequence += [torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),              # 2
                     torch.nn.BatchNorm2d(128),                                                             # 3         -> 39
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 4
        sequence += [torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),             # 5
                     torch.nn.BatchNorm2d(256),                                                             # 6         -> 36
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 7
        sequence += [torch.nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 8
                     torch.nn.BatchNorm2d(512),                                                             # 9         -> 33
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 10

        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 11
                     torch.nn.BatchNorm2d(512),                                                             # 12        -> 30
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 13
        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 14
                     torch.nn.BatchNorm2d(512),                                                             # 15        -> 27
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 16
        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 17
                     torch.nn.BatchNorm2d(512),                                                             # 18        -> 24
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 19

        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),                         # 20
                     torch.nn.ReLU(True)]                                                                   # 21
        sequence += [torch.nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),    # 22
                     torch.nn.BatchNorm2d(512),                                                             # 23
                     torch.nn.ReLU(True)]                                                                   # 24        <- 18

        sequence += [torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),   # 25
                     torch.nn.BatchNorm2d(512),                                                             # 26
                     torch.nn.ReLU(True)]                                                                   # 27        <- 15
        sequence += [torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),   # 28
                     torch.nn.BatchNorm2d(512),                                                             # 29
                     torch.nn.ReLU(True)]                                                                   # 30        <- 12
        sequence += [torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),   # 31
                     torch.nn.BatchNorm2d(512),                                                             # 32
                     torch.nn.ReLU(True)]                                                                   # 33        <- 9

        sequence += [torch.nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1, bias=False),   # 34
                     torch.nn.BatchNorm2d(256),                                                             # 35
                     torch.nn.ReLU(True)]                                                                   # 36        <- 6
        sequence += [torch.nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1, bias=False),    # 37
                     torch.nn.BatchNorm2d(128),                                                             # 38
                     torch.nn.ReLU(True)]                                                                   # 39        <- 3
        # Input here is 128 (from 37/38) + 128 (skip from 3) = 256 channels.
        sequence += [torch.nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1, bias=False),     # 40
                     torch.nn.BatchNorm2d(64),                                                              # 41
                     torch.nn.ReLU(True)]                                                                   # 42        <- 0

        sequence += [torch.nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),                  # 43
                     torch.nn.Tanh()]                                                                       # 44

        # Kaiming init matched to the following nonlinearity: LeakyReLU(0.2) for
        # the encoder (indices < 20), ReLU for the bottleneck and decoder.
        # BatchNorm scale ~ N(1, 0.02), shift = 0.
        for idx, layer in enumerate(sequence):
            if isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                if idx < 20:
                    torch.nn.init.kaiming_uniform_(layer.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
                else:
                    torch.nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(layer, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 1.0, 0.02)

        # ModuleList (not a plain list) so the layers' parameters are registered.
        self.sequence = torch.nn.ModuleList(sequence)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to an output image batch of the same spatial size.

        Args:
            x: Tensor of shape (N, input_channels, H, W) with H and W divisible by 256.

        Returns:
            Tensor of shape (N, 3, H, W) with values in [-1, 1] (Tanh output).

        Raises:
            ValueError: if H or W is not a multiple of 256. Otherwise the skip
                concatenations would fail on mismatched spatial sizes.
        """
        if x.shape[-2] % self._SIZE_MULTIPLE or x.shape[-1] % self._SIZE_MULTIPLE:
            raise ValueError(
                f"{type(self).__name__} expects height and width divisible by "
                f"{self._SIZE_MULTIPLE}, got {tuple(x.shape[-2:])}"
            )
        seq = self.sequence

        # Encoder: keep each stage's activated output for its skip connection.
        x = seq[1](seq[0](x))
        skips = [x]
        for conv, norm, act in self._ENCODER_STAGES:
            x = seq[act](seq[norm](seq[conv](x)))
            skips.append(x)

        # Bottleneck (no BatchNorm).
        x = seq[21](seq[20](x))

        # Decoder: upsample, normalize, optional dropout, then concatenate the
        # matching encoder output (deepest first) and activate.
        for stage, (up, norm, act) in enumerate(self._DECODER_STAGES):
            x = seq[norm](seq[up](x))
            if stage < self._DROPOUT_STAGES:
                x = self.drp(x)
            x = seq[act](torch.cat((x, skips.pop()), dim=1))

        return seq[44](seq[43](x))


Now we define our U-Net with ablations. **TODO:** specify which layers are taken out (currently the placeholder "xxxx").

In [None]:
# DEFINE THE ABLATED UNET
class UNETsmolr(torch.nn.Module):
    """Ablated variant of the paper U-Net generator.

    NOTE(review): no layers are actually ablated yet. The layer list and forward
    pass below are currently identical to the full paper U-Net.

    Eight stride-2 4x4 convolutions downsample the input. Eight stride-2 4x4
    transposed convolutions upsample it back. After each of the first seven
    upsamplings, the encoder output of the same resolution is concatenated along
    the channel axis. Input height and width must therefore be divisible by
    2**8 = 256.

    All layers live in ``self.sequence`` (a ``torch.nn.ModuleList``), so their
    parameters are registered (``parameters()``, ``.to(device)``,
    ``state_dict()``). The trailing ``# N`` comments are indices into that list.
    A ``-> N`` next to encoder index M, and a ``<- M`` next to decoder index N,
    mark a skip connection: the encoder output at M, after its in-place
    LeakyReLU, is concatenated onto the decoder tensor immediately before layer N.

    Args:
        input_channels: Channels of the input tensor.
        dropout: Probability used by ``self.drp``. It is applied after the
            BatchNorm of the ``_DROPOUT_STAGES`` decoder stages nearest the
            bottleneck (indices 22-23, 25-26, 28-29). Pass 0.0 to disable.
            NOTE(review): placement follows the notebook's stated assumption
            ("innermost layers of the decoder"); confirm it.
    """

    # (conv, batchnorm, activation) indices into self.sequence.
    _ENCODER_STAGES = ((2, 3, 4), (5, 6, 7), (8, 9, 10), (11, 12, 13), (14, 15, 16), (17, 18, 19))
    # (transposed conv, batchnorm, activation) indices; the activation runs after the skip concat.
    _DECODER_STAGES = ((22, 23, 24), (25, 26, 27), (28, 29, 30), (31, 32, 33), (34, 35, 36), (37, 38, 39), (40, 41, 42))
    # How many of the leading (innermost) decoder stages get dropout.
    _DROPOUT_STAGES = 3
    # Eight stride-2 downsamplings.
    _SIZE_MULTIPLE = 256

    def __init__(self, input_channels: int = 3, dropout: float = 0.5):
        super().__init__()

        self.drp = torch.nn.Dropout(dropout)  # applied in forward() to the innermost decoder stages

        # Note that instead of changing the number of entries, I substituted the parts of the net I planned to ablate with None
        # and removed them from forward. This means I don't have to bother as much with changing indices.

        # Note that all convolutions with a BatchNorm after them have bias false as BatchNorm contains a bias term itself.
        # It would therefore be redundant and add nothing.

        sequence = [torch.nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),                # 0         -> 42
                    torch.nn.LeakyReLU(0.2, True)]                                                          # 1

        sequence += [torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),              # 2
                     torch.nn.BatchNorm2d(128),                                                             # 3         -> 39
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 4
        sequence += [torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),             # 5
                     torch.nn.BatchNorm2d(256),                                                             # 6         -> 36
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 7
        sequence += [torch.nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 8
                     torch.nn.BatchNorm2d(512),                                                             # 9         -> 33
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 10

        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 11
                     torch.nn.BatchNorm2d(512),                                                             # 12        -> 30
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 13
        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 14
                     torch.nn.BatchNorm2d(512),                                                             # 15        -> 27
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 16
        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),             # 17
                     torch.nn.BatchNorm2d(512),                                                             # 18        -> 24
                     torch.nn.LeakyReLU(0.2, True)]                                                         # 19

        sequence += [torch.nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),                         # 20
                     torch.nn.ReLU(True)]                                                                   # 21
        sequence += [torch.nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1, bias=False),    # 22
                     torch.nn.BatchNorm2d(512),                                                             # 23
                     torch.nn.ReLU(True)]                                                                   # 24        <- 18

        sequence += [torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),   # 25
                     torch.nn.BatchNorm2d(512),                                                             # 26
                     torch.nn.ReLU(True)]                                                                   # 27        <- 15
        sequence += [torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),   # 28
                     torch.nn.BatchNorm2d(512),                                                             # 29
                     torch.nn.ReLU(True)]                                                                   # 30        <- 12
        sequence += [torch.nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1, bias=False),   # 31
                     torch.nn.BatchNorm2d(512),                                                             # 32
                     torch.nn.ReLU(True)]                                                                   # 33        <- 9

        sequence += [torch.nn.ConvTranspose2d(1024, 256, kernel_size=4, stride=2, padding=1, bias=False),   # 34
                     torch.nn.BatchNorm2d(256),                                                             # 35
                     torch.nn.ReLU(True)]                                                                   # 36        <- 6
        sequence += [torch.nn.ConvTranspose2d(512, 128, kernel_size=4, stride=2, padding=1, bias=False),    # 37
                     torch.nn.BatchNorm2d(128),                                                             # 38
                     torch.nn.ReLU(True)]                                                                   # 39        <- 3
        # Input here is 128 (from 37/38) + 128 (skip from 3) = 256 channels.
        sequence += [torch.nn.ConvTranspose2d(256, 64, kernel_size=4, stride=2, padding=1, bias=False),     # 40
                     torch.nn.BatchNorm2d(64),                                                              # 41
                     torch.nn.ReLU(True)]                                                                   # 42        <- 0

        sequence += [torch.nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),                  # 43
                     torch.nn.Tanh()]                                                                       # 44

        # Kaiming init matched to the following nonlinearity: LeakyReLU(0.2) for
        # the encoder (indices < 20), ReLU for the bottleneck and decoder.
        # BatchNorm scale ~ N(1, 0.02), shift = 0. isinstance() skips None placeholders.
        for idx, layer in enumerate(sequence):
            if isinstance(layer, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                if idx < 20:
                    torch.nn.init.kaiming_uniform_(layer.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu')
                else:
                    torch.nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
            elif isinstance(layer, torch.nn.BatchNorm2d):
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 1.0, 0.02)

        # ModuleList (not a plain list) so the layers' parameters are registered.
        self.sequence = torch.nn.ModuleList(sequence)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to an output image batch of the same spatial size.

        Args:
            x: Tensor of shape (N, input_channels, H, W) with H and W divisible by 256.

        Returns:
            Tensor of shape (N, 3, H, W) with values in [-1, 1] (Tanh output).

        Raises:
            ValueError: if H or W is not a multiple of 256.
        """
        if x.shape[-2] % self._SIZE_MULTIPLE or x.shape[-1] % self._SIZE_MULTIPLE:
            raise ValueError(
                f"{type(self).__name__} expects height and width divisible by "
                f"{self._SIZE_MULTIPLE}, got {tuple(x.shape[-2:])}"
            )
        seq = self.sequence

        # Encoder: keep each stage's activated output for its skip connection.
        x = seq[1](seq[0](x))
        skips = [x]
        for conv, norm, act in self._ENCODER_STAGES:
            x = seq[act](seq[norm](seq[conv](x)))
            skips.append(x)

        # Bottleneck (no BatchNorm).
        x = seq[21](seq[20](x))

        # Decoder: upsample, normalize, optional dropout, then concatenate the
        # matching encoder output (deepest first) and activate.
        for stage, (up, norm, act) in enumerate(self._DECODER_STAGES):
            x = seq[norm](seq[up](x))
            if stage < self._DROPOUT_STAGES:
                x = self.drp(x)
            x = seq[act](torch.cat((x, skips.pop()), dim=1))

        return seq[44](seq[43](x))


Now we need to write our training code. It will differ in how each of the variations we intend to run is loaded (4 in all): the regular and the ablated U-Net, each with and without the dark channel. In all cases the discriminator stays the same, apart from its number of input channels, which is set by its constructor.

In [None]:
#WRITE TRAINING CODE

Next we write code to run our models on a test dataset. Strictly speaking this is not necessary, since we are not doing any parameter tuning, but it does show how each model performs on a video none of which was used for training.

In [None]:
#WRITE TESTING CODE FOR TRAINED MODEL

Run training and validation for scenario 1. This is full u-net and no dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 1

Run training and validation for scenario 2. This is ablated u-net and no dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 2

Run training and validation for scenario 3. This is full u-net and dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 3

Run training and validation for scenario 4. This is ablated u-net and dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 4

Now we run testing on all 4 scenarios, loading the model saved at each epoch so we can plot metrics. **TODO:** consider integrating this into the train/validate step.

In [None]:
# YES?