In [9]:
import os
import logging
import argparse
import json
import datetime
from typing import Union

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim import Optimizer
                              

from utils import *
from losses import PSNR
from SynthTrainer import SynthTrainer
import gan
import time

from synth_data_source import loadData

We start by preparing our models, beginning with the model described in our target paper, "Desmoking Laparoscopy Surgery Images Using an Image-to-Image Translation Guided by an Embedded Dark Channel".
We have to define a discriminator and a generator. Note that the tables here reflect processing a 512x512 image, as opposed to the 256x256 images used in the original paper. Also note that, on inspecting the paper's repository, the tables in the paper appear to be inaccurate. I recreated these networks to match the repository by following the logic that generates the networks used there.

Since the paper does not specify the arguments used, I made a best estimate of the optional arguments based on what seemed sensible. I also changed where dropout is applied, following advice from Dr. Florian Richter. In the repository, dropout appears to be applied on all the 'decoder' layers of the U-Net, but depending on the arguments passed at runtime it could equally well not have been applied at all. My assumption is that dropout is helpful and should be applied at the innermost layers of the decoder.

ADD TABLES HERE AS IMAGES. KATEX DOES NOT IMPLEMENT TABULAR...

I am only showing the discriminator here. Please see models.py for all the models we will be using; they take up a lot of screen space because I wrote them out explicitly.

In [10]:
# DEFINE THE PAPER DISCRIMINATOR
class Discriminator_demo(torch.nn.Module):
    def __init__(self, input_channels:int = 3):
        super().__init__()

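        # Build the layers in a local list; it is wrapped in nn.Sequential below.
        # (Assigning to self.sequence here would leave `sequence` undefined for the += lines.)
        sequence = [torch.nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1), 
                    torch.nn.LeakyReLU(0.2, True)]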
        
        sequence += [torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), 
                     torch.nn.BatchNorm2d(128),
                     torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), 
                     torch.nn.BatchNorm2d(256),
                     torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.ZeroPad2d(2)]
        
        sequence += [torch.nn.Conv2d(256, 512, kernel_size=4, stride=1), 
                     torch.nn.BatchNorm2d(512),
                     torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.ZeroPad2d(2)]
        
        sequence += [torch.nn.Conv2d(512, 1, kernel_size=4, stride=1)]
        
        self.model = torch.nn.Sequential(*sequence)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.model(x)
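
To see what the 512x512 input size means for this discriminator, the spatial size can be traced layer by layer with the standard convolution output-size formula. This is a minimal sketch in plain Python that mirrors the layer sequence above; the helper names `conv_out` and `discriminator_output_size` are made up for illustration and do not appear in models.py.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (dilation 1), as used by Conv2d."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_output_size(size):
    """Trace the height/width of a square input through Discriminator_demo."""
    # Three stride-2 convolutions (kernel 4, padding 1) each halve the size.
    for _ in range(3):
        size = conv_out(size, kernel=4, stride=2, padding=1)
    # Two blocks of ZeroPad2d(2) followed by an unpadded stride-1 convolution.
    for _ in range(2):
        size += 4  # ZeroPad2d(2) adds 2 pixels on each side
        size = conv_out(size, kernel=4, stride=1)
    return size


print(discriminator_output_size(512))  # 66
print(discriminator_output_size(256))  # 34
```

Under these assumptions the discriminator produces a 66x66 map of patch scores for a 512x512 input, compared with 34x34 for the 256x256 input size used in the paper.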



Now we import the models. Note that UNET is the paper implementation and UNETsmolr is the ablated version with layers x-y taken out. Please refer to the code and the numbers to the right of the layers. These layer numbers differ from the numbering used in the report: in the code, every individual nn.Module has its own layer number instead of being grouped together.

In [11]:
import models

Note that our trainer is located in gan.py. Refer to that file, along with the class it inherits from in data_augmentation, to understand the training, validation, and testing. I import it rather than implementing it here to keep this notebook short.

Also refer to losses.py for PSNR and the other image-quality metrics.
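
For readers unfamiliar with PSNR, here is a minimal sketch of the standard definition, 10 * log10(MAX^2 / MSE). It is not the implementation in losses.py: the function name `psnr`, the flat-sequence input, and the default `max_value=1.0` (pixel values scaled to [0, 1]) are all assumptions made for illustration.

```python
import math


def psnr(reference, estimate, max_value=1.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    if len(reference) != len(estimate) or not reference:
        raise ValueError("inputs must be non-empty and the same length")
    # Mean squared error over all pixels.
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs: PSNR is unbounded
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher values mean the estimate is closer to the reference; a uniform error of 0.1 on a [0, 1] scale gives roughly 20 dB.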

In [12]:
#from gan import GANTrainer
from losses import PSNR
from losses import SSIMLoss
from torchmetrics import UniversalImageQualityIndex
from libs.pytorch_fsim import fsim


Now we will set up some arguments which we will use to run our trainer.

In [13]:
args = {      
        'save' : 'C:/Users/Karol/Documents/DL4H/runs/{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')),
        'batch' : 8,
        'lr' : 0.0002,
        'epochs' : 60,
        'gpu' : 0,
        'use_dark_channel' : False,
        'load_net' : None,
        'run_val_and_test_every_steps' : 1718
}

args_data = {
        'pre_path' : 'C:/Users/Karol/Documents/DL4H',
        'cache_subfolder' : '/datasets/cholec80/cache',
        'cache_subfolder_test' : '/datasets/cholec80/cache_testset',
        'syn_smoke_subfolder' : '/datasets/cholec80/synthetic_smoke/',
        'dataset_subfolder' : '/datasets/cholec80/input_formatted',
        'dataset_subfolder_test' : '/datasets/cholec80/input_formatted_test'
}
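
The 'save' path is rebuilt with a fresh timestamp for every scenario later in this notebook, so it can be convenient to wrap that pattern in a small helper. The sketch below is only a suggestion; `make_run_dir`, its `tag` parameter, and the optional `now` argument are hypothetical names and are not part of gan.py or utils.

```python
import datetime


def make_run_dir(runs_root, tag="", now=None):
    """Return '<runs_root>/<tag><timestamp>' using the notebook's timestamp format."""
    now = now or datetime.datetime.now()
    stamp = now.strftime('%Y-%m-%d_%H-%M-%S')
    # Forward slashes are used to match the paths written elsewhere in this notebook.
    return '{}/{}{}'.format(runs_root.rstrip('/'), tag, stamp)


# Example: a run directory for the ablated U-Net without the dark channel.
print(make_run_dir('C:/Users/Karol/Documents/DL4H/runs', 'smol_nodark_60ep'))
```

Passing an explicit `now` makes the result reproducible, which helps when several runs need to share one timestamp.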

Now we will start TensorBoard. This lets us watch the model train and inspect all the losses, as well as the output images from the validation and testing steps. PLEASE UPDATE THE COMMAND BELOW WITH THE CORRECT PATH TO YOUR 'RUNS' DIRECTORY!!!

In [14]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Running the next line twice may be necessary if you want to see TensorBoard in this notebook. Alternatively, scroll down for a link to open it in your browser. Note that if this step takes 0.0s, TensorBoard did not actually launch; try changing the port.

In [15]:

%tensorboard --logdir C:/Users/Karol/Documents/DL4H/runs --port 6006

Great, now open tensorboard here: http://localhost:6006

Let's load all our data in.

In [16]:
train_data, val_data, test_data = loadData(args_data)

CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 43325 but video length is 43326. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video01\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 70975 but video length is 70976. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video02\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 145700 but video length is 145701. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video03\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 38050 but video length is 38051. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video04\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 58600 but video length is 58601. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video05\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 53825 but video length is 53826. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video06\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 113925 but video length is 113926. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video07\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 37975 but video length is 37976. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video08\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 67550 but video length is 67551. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video09\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 43725 but video length is 43726. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video10\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 80500 but video length is 80501. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video11\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 27250 but video length is 27251. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video12\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 24525 but video length is 24526. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video13\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 42700 but video length is 42701. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video14\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 51450 but video length is 51451. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video15\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 73925 but video length is 73926. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video16\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 32600 but video length is 32601. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video17\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 48550 but video length is 48551. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video18\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 60600 but video length is 60601. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video19\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 36225 but video length is 36226. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted\video20\quality_metrics.csv.


NUM ROWS IN CSV WRONG
43326
43325
70976
70975
145701
145700
38051
38050
58601
58600
53826
53825
113926
113925
37976
37975
67551
67550
43726
43725
80501
80500
27251
27250
24526
24525
42701
42700
51451
51450
73926
73925
32601
32600
48551
48550
60601
60600
36226
36225
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 31450 but video length is 31451. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted_test\video21\quality_metrics.csv.


NUM ROWS IN CSV WRONG
CSV IS STILL GOOD SO GOOD
CSV is good?


Number of rows in csv file is 38300 but video length is 38301. Recommend deleting csv file so it is recomputed: C:/Users/Karol/Documents/DL4H/datasets/cholec80/input_formatted_test\video22\quality_metrics.csv.


NUM ROWS IN CSV WRONG
31451
31450
38301
38300


Now we set up our models and run our GAN trainer! This will train and output to TensorBoard as it trains. Please set 'save' appropriately in args for the log and TensorBoard output. This step is super easy.

For the inner workings, please refer to models.py for the models and gan.py for the trainer.
The trainer also uses SynthTrainer.py and Trainer.py from the data_augmentation folder.

Most of the basic training is done in gan.py; however, much of the metric logging occurs in SynthTrainer.py.

Trainer.py contains framework code to generically train various implementations of SynthTrainer. For the purposes of this project, it's best to stick mainly to gan.py for intuition on the actual models.

In [17]:
import importlib
importlib.reload(models)

myUNET = models.UNETsmolr()
myDisc = models.Discriminator()


Let's take a look at our UNET.

In [18]:
myUNET.parameters

<bound method Module.parameters of UNETsmolr(
  (drp): Dropout(p=0.5, inplace=False)
  (sequence): ParameterList(
      (0): Object of type: Conv2d
      (1): Object of type: LeakyReLU
      (2): Object of type: Conv2d
      (3): Object of type: BatchNorm2d
      (4): Object of type: LeakyReLU
      (5): Object of type: Conv2d
      (6): Object of type: BatchNorm2d
      (7): Object of type: LeakyReLU
      (8): Object of type: Conv2d
      (9): Object of type: BatchNorm2d
      (10): Object of type: LeakyReLU
      (11): Object of type: Conv2d
      (12): Object of type: BatchNorm2d
      (13): Object of type: LeakyReLU
      (14): Object of type: Conv2d
      (15): Object of type: BatchNorm2d
      (16): Object of type: LeakyReLU
      (17): Object of type: Conv2d
      (18): Object of type: BatchNorm2d
      (19): Object of type: LeakyReLU
      (20): Object of type: Conv2d
      (21): Object of type: ReLU
      (22): Object of type: ConvTranspose2d
      (23): Object of type: Batch

Now we run training, validation, and test for our first scenario. This is the ablated UNET with no dark channel. On a 3070 + 5800X3D this took 8.5 hours. Expect up to 12 hours for the dark channel models to train.

In [19]:
#CODE TO TRAIN SCENARIO 1
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/smol_nodark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
gan.run(args, myUNET, myDisc, train_data, val_data, test_data)

Epoch 1/60:   4%|▎         | 1024/27461 [00:40<07:47, 56.55batch/s]

Run training, validation, and test for scenario 2. This is the full U-Net with no dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 2
myUNET2 = models.UNET()
myDisc2 = models.Discriminator()
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/full_nodark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
gan.run(args, myUNET2, myDisc2, train_data, val_data, test_data)


Run training, validation, and test for scenario 3. This is the ablated U-Net with the dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 3
myUNET3 = models.UNETsmolr(input_channels=4)
myDisc3 = models.Discriminator()
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/smol_dark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
args['use_dark_channel'] = True
gan.run(args, myUNET3, myDisc3, train_data, val_data, test_data)
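
The dark channel scenarios use `input_channels=4`, presumably because a one-channel dark channel map is concatenated to the RGB image. How the trainer actually computes it is not shown in this notebook, so the following is only a sketch of the usual dark channel prior definition: a per-pixel minimum over the colour channels, followed by a minimum over a local patch. The function name `dark_channel`, the nested-list image layout, and the default `patch=3` are assumptions for illustration.

```python
def dark_channel(rgb, patch=3):
    """Dark channel of an image given as 3 channels of H x W nested lists."""
    height, width = len(rgb[0]), len(rgb[0][0])
    # Step 1: per-pixel minimum across the colour channels.
    min_rgb = [[min(channel[y][x] for channel in rgb) for x in range(width)]
               for y in range(height)]
    # Step 2: minimum over a square neighbourhood, clipped at the image borders.
    radius = patch // 2
    dark = []
    for y in range(height):
        row = []
        for x in range(width):
            window = [min_rgb[j][i]
                      for j in range(max(0, y - radius), min(height, y + radius + 1))
                      for i in range(max(0, x - radius), min(width, x + radius + 1))]
            row.append(min(window))
        dark.append(row)
    return dark
```

Smoke tends to brighten all channels at once, so smoky regions have a high dark channel value while clear regions stay close to zero, which is what makes the extra channel a useful hint for the generator.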

Run training, validation, and test for scenario 4. This is the full U-Net with the dark channel.

In [None]:
importlib.reload(gan)
#CODE TO TRAIN SCENARIO 4
myUNET4 = models.UNET(input_channels=4)
myDisc4 = models.Discriminator()
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/full_dark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
args['use_dark_channel'] = True
gan.run(args, myUNET4, myDisc4, train_data, val_data, test_data)

Now we must run testing on all 4 scenarios. Load in all the models at each epoch so we can plot metrics. (INTEGRATE THIS INTO TRAIN AND VALIDATE??)

In [None]:
# YES?
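
One possible starting point for the per-epoch evaluation is to collect the saved checkpoints in epoch order before loading them. The sketch below assumes a hypothetical file naming scheme of the form `epoch_<n>.pt`; the names that gan.py actually writes are not shown in this notebook, so the pattern would need to be adjusted to match.

```python
import re

# Hypothetical naming scheme; adjust to whatever the trainer actually saves.
_EPOCH_PATTERN = re.compile(r"epoch_(\d+)\.pt$")


def checkpoints_by_epoch(filenames):
    """Return (epoch, filename) pairs sorted by epoch number, skipping other files."""
    found = []
    for name in filenames:
        match = _EPOCH_PATTERN.search(name)
        if match:
            found.append((int(match.group(1)), name))
    # Sorting on the integer epoch avoids the "epoch_10 before epoch_2" string ordering.
    return sorted(found)


print(checkpoints_by_epoch(["epoch_10.pt", "epoch_2.pt", "notes.txt"]))
```

Each checkpoint in the resulting list could then be loaded and run against the test set, logging one point per epoch for each metric.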