In [1]:
import os
import logging
import argparse
import json
from typing import Union

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim import Optimizer
                              
from models import UNet
from utils import *
from losses import PSNR
from SynthTrainer import SynthTrainer
import arg_parser

We start by preparing the models. The first is the model described in our target paper, "Desmoking Laparoscopy Surgery Images Using an Image-to-Image Translation Guided by an Embedded Dark Channel".
This requires defining a discriminator and a generator. Note that the tables here reflect processing a 512x512 image, rather than the 256x256 images used in the original paper.

**TODO:** Add the architecture tables here as images, because KaTeX does not implement the LaTeX `tabular` environment.

In [None]:
# DEFINE THE PAPER DISCRIMINATOR
class Discriminator(torch.nn.Module):
    """Convolutional discriminator that scores an image with a spatial map of logits.

    The network applies three stride-2 4x4 convolutions (64 -> 128 -> 256
    channels, BatchNorm on all but the first), then two zero-padded stride-1
    4x4 convolutions (256 -> 512 -> 1). For a 512x512 input the output map is
    66x66; for a 32x32 input it is 6x6.

    Args:
        input_channels: number of channels in the input tensor (e.g. 3 for RGB,
            more when extra channels such as a dark channel are concatenated).
    """

    def __init__(self, input_channels: int = 3):
        super().__init__()

        def downsample(c_in: int, c_out: int, with_norm: bool) -> list:
            # Stride-2 4x4 conv halves the spatial size; optional BN; LeakyReLU(0.2, inplace).
            stage = [torch.nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1)]
            if with_norm:
                stage.append(torch.nn.BatchNorm2d(c_out))
            stage.append(torch.nn.LeakyReLU(0.2, True))
            return stage

        layers = downsample(input_channels, 64, with_norm=False)
        layers.extend(downsample(64, 128, with_norm=True))
        layers.extend(downsample(128, 256, with_norm=True))

        # Pad by 2 then use an unpadded stride-1 4x4 conv: each such stage grows the map by 1 pixel.
        layers.extend([
            torch.nn.ZeroPad2d(2),
            torch.nn.Conv2d(256, 512, kernel_size=4, stride=1),
            torch.nn.BatchNorm2d(512),
            torch.nn.LeakyReLU(0.2, True),
            torch.nn.ZeroPad2d(2),
            torch.nn.Conv2d(512, 1, kernel_size=4, stride=1),
        ])

        # No final Sigmoid: it appears the paper did not use one, so raw scores are
        # returned. The impact of adding a Sigmoid could be examined later.
        self.model = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the discriminator's score map for the batch ``x``."""
        return self.model(x)



Next we define the first U-Net, which follows the paper's architecture.

In [None]:
# DEFINE THE PAPER UNET
class UNETplus(torch.nn.Module):
    """Generator network intended to follow the paper's U-Net.

    NOTE(review): this is currently a placeholder. The layer stack below is a
    copy of the Discriminator (strided 4x4 convolutions ending in a 1-channel
    conv) with no decoder or skip connections, so it is not yet a U-Net and
    does not output an image. TODO: replace with the paper's generator
    architecture.

    Args:
        input_channels: number of channels in the input tensor.
    """

    def __init__(self, input_channels: int = 3):
        super().__init__()

        # The former `sequence = [torch.nn.Conv2d()]` line was removed: Conv2d
        # requires in/out channels and kernel_size, so it raised TypeError
        # and made this class impossible to construct.
        sequence = [torch.nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
                    torch.nn.LeakyReLU(0.2, True)]

        sequence += [torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                     torch.nn.BatchNorm2d(128),
                     torch.nn.LeakyReLU(0.2, True)]

        sequence += [torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                     torch.nn.BatchNorm2d(256),
                     torch.nn.LeakyReLU(0.2, True)]

        # Pad by 2 then use an unpadded stride-1 4x4 conv: each such stage grows the map by 1 pixel.
        sequence += [torch.nn.ZeroPad2d(2)]

        sequence += [torch.nn.Conv2d(256, 512, kernel_size=4, stride=1),
                     torch.nn.BatchNorm2d(512),
                     torch.nn.LeakyReLU(0.2, True)]

        sequence += [torch.nn.ZeroPad2d(2)]

        sequence += [torch.nn.Conv2d(512, 1, kernel_size=4, stride=1)]

        self.model = torch.nn.Sequential(*sequence)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        return self.model(x)

Next we define the ablated U-Net, in which layers xxxx are removed (**TODO:** specify which layers).

In [None]:
#DEFINE ABLATED UNET

Now we write the training code. Loading differs for each of the four variations we intend to run: the full and the ablated U-Net, each with and without the dark channel. In all cases the discriminator is the same apart from its number of input channels, which is set through the constructor.

In [None]:
#WRITE TRAINING CODE

Now we write code to run our models on a test dataset. Strictly speaking, this is not necessary because we are not tuning hyperparameters. However, it shows how the model performs on a video from which no frames were used for training.

In [None]:
#WRITE TESTING CODE FOR TRAINED MODEL

Run training and validation for scenario 1: the full U-Net without the dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 1

Run training and validation for scenario 2: the ablated U-Net without the dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 2

Run training and validation for scenario 3: the full U-Net with the dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 3

Run training and validation for scenario 4: the ablated U-Net with the dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 4

Finally, we run testing on all four scenarios, loading the saved model from each epoch so that metrics can be plotted over training. **TODO:** consider integrating this step into the train/validate loop.