In [1]:
import os
import logging
import argparse
import json
import datetime
from typing import Union

import torch
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim import Optimizer
                              

from utils import *
from losses import PSNR
from SynthTrainer import SynthTrainer
import gan

from synth_data_source import loadData

We start by preparing our models, beginning with the model described in our target paper, "Desmoking Laparoscopy Surgery Images Using an Image-to-Image Translation Guided by an Embedded Dark Channel".
We need to define a discriminator and a generator. Note that the tables here reflect processing a 512x512 image, whereas the original paper used 256x256. Also note that, on inspecting the paper's repository, the tables in the paper appear to be inaccurate. I recreated these networks to match the repository by following the logic that generates the networks used there.

Since the paper does not specify the arguments used, I made a best estimate of the optional arguments based on what I thought would make sense. I also changed where dropout is applied, following advice from Dr. Florian Richter. In the repository, dropout appears to be applied on all the 'decoder' layers in the U-Net, but depending on the arguments passed at runtime it could equally well not have been applied at all. My assumption is that dropout is helpful and should be applied at the innermost layers of the decoder.

ADD TABLES HERE AS IMAGES. KATEX DOES NOT IMPLEMENT TABULAR...

I am only showing the discriminator here. Please see models.py for all the models we will be using; they take up a lot of screen space because I wrote them out explicitly.

In [2]:
# DEFINE THE PAPER DISCRIMINATOR
class Discriminator_demo(torch.nn.Module):
    def __init__(self, input_channels:int = 3):
        super().__init__()

        # Build the layers in a local list; it is wrapped in torch.nn.Sequential below.
        sequence = [torch.nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
                    torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), 
                     torch.nn.BatchNorm2d(128),
                     torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), 
                     torch.nn.BatchNorm2d(256),
                     torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.ZeroPad2d(2)]
        
        sequence += [torch.nn.Conv2d(256, 512, kernel_size=4, stride=1), 
                     torch.nn.BatchNorm2d(512),
                     torch.nn.LeakyReLU(0.2, True)]
        
        sequence += [torch.nn.ZeroPad2d(2)]
        
        sequence += [torch.nn.Conv2d(512, 1, kernel_size=4, stride=1)]
        
        self.model = torch.nn.Sequential(*sequence)

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        return self.model(x)
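To sanity-check the layer tables against the code, the spatial size of the discriminator output can be worked out by hand. The sketch below is plain arithmetic on the kernel, stride and padding values of Discriminator_demo above; the helper names conv_out and discriminator_out are illustrative and do not come from models.py.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_out(size):
    """Spatial size of the Discriminator_demo output for a square input."""
    # Three convolutions with kernel 4, stride 2, padding 1: each halves the size.
    for _ in range(3):
        size = conv_out(size, kernel=4, stride=2, padding=1)
    # Two blocks of ZeroPad2d(2) (2 pixels on every side) followed by a
    # kernel-4, stride-1 convolution without padding of its own.
    for _ in range(2):
        size = conv_out(size + 2 * 2, kernel=4, stride=1)
    return size


print(discriminator_out(512))  # 66
print(discriminator_out(256))  # 34
```

So a 512x512 input yields a 66x66 map of patch scores, and a 256x256 input yields 34x34.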



Now we import the models. Note that UNET is the paper implementation and UNETsmolr is the ablated version with layers x-y taken out. Please refer to the code and to the numbers to the right of the layers. These layer numbers differ from the ones used in the report: in the code, every individual nn.Module has its own layer number instead of being grouped together.

In [3]:
import models

Note that our trainer is located in gan.py. To understand the training, validation and testing, you will want to read that file together with the class it inherits from in data_augmentation. I import it rather than implementing it here to keep this notebook short.

You will also want to refer to losses.py for PSNR and the other image-quality metrics being recorded.
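As a reminder of what PSNR measures, here is a minimal sketch of the standard definition, PSNR = 10 * log10(MAX^2 / MSE). It is not the code in losses.py, which is not shown in this notebook; the function name psnr, the flat-list inputs and the default max_value of 1.0 are assumptions made for illustration.

```python
import math


def psnr(pred, target, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel lists."""
    if len(pred) != len(target) or not pred:
        raise ValueError("inputs must be non-empty and of equal length")
    # Mean squared error over all pixels.
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


print(round(psnr([0.0, 0.0], [0.5, 0.5]), 2))  # 6.02
```

Higher values mean the desmoked output is closer to the smoke-free reference.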

In [4]:
#from gan import GANTrainer
from losses import PSNR
from losses import SSIMLoss
from torchmetrics import UniversalImageQualityIndex
from libs.pytorch_fsim import fsim


Now we will set up some arguments which we will use to run our trainer.

In [5]:
args = {      
        'save' : 'C:/Users/Karol/Documents/DL4H/runs/{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')),
        'batch' : 8,
        'lr' : 0.0002,
        'epochs' : 60,
        'gpu' : 0,
        'use_dark_channel' : False,
        'load_net' : None,
        'run_val_and_test_every_steps' : 1718
}

args_data = {
        'pre_path' : 'C:/Users/Karol/Documents/DL4H',
        'cache_subfolder' : '/datasets/cholec80/cache',
        'cache_subfolder_test' : '/datasets/cholec80/cache_testset',
        'syn_smoke_subfolder' : '/datasets/cholec80/synthetic_smoke/',
        'dataset_subfolder' : '/datasets/cholec80/input_formatted',
        'dataset_subfolder_test' : '/datasets/cholec80/input_formatted_test'
}

Now let us load all our data in.

In [6]:
train_data, val_data, test_data = loadData(args_data)

KeyboardInterrupt: 

Now we set up our models and run our GAN trainer. It writes to TensorBoard as it trains, so please set 'save' in args to the directory where the log and TensorBoard files should go. This step is straightforward.

For the inner workings, please refer to models.py for the models and gan.py for the trainer.
The trainer also uses SynthTrainer.py and Trainer.py from the data_augmentation folder.

Most of the basic training is done in gan.py; however, much of the metric logging occurs in SynthTrainer.py.

Trainer.py contains framework code for generically training various implementations of SynthTrainer. For the purposes of this project, it's best to stick mainly to gan.py for intuition on the actual models.
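The 'run_val_and_test_every_steps' argument controls how often validation and test runs interrupt training. How gan.py and Trainer.py use it is not shown here, so the following is only a sketch of one common scheme: it assumes the trainer keeps a global step counter and evaluates whenever that counter is a multiple of the setting. The function name evaluation_steps is hypothetical.

```python
def evaluation_steps(total_steps, every):
    """Return the 1-based global steps after which evaluation would run."""
    if every <= 0:
        raise ValueError("every must be a positive number of steps")
    return [step for step in range(1, total_steps + 1) if step % every == 0]


# With the value used in args (1718) and a hypothetical run of 5000 steps:
print(evaluation_steps(5000, 1718))  # [1718, 3436]
```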

In [None]:
import importlib
importlib.reload(models)

myUNET = models.UNETsmolr()
myDisc = models.Discriminator()


In [None]:
myUNET.parameters

<bound method Module.parameters of UNETsmolr(
  (drp): Dropout(p=0.5, inplace=False)
  (sequence): ParameterList(
      (0): Object of type: Conv2d
      (1): Object of type: LeakyReLU
      (2): Object of type: Conv2d
      (3): Object of type: BatchNorm2d
      (4): Object of type: LeakyReLU
      (5): Object of type: Conv2d
      (6): Object of type: BatchNorm2d
      (7): Object of type: LeakyReLU
      (8): Object of type: Conv2d
      (9): Object of type: BatchNorm2d
      (10): Object of type: LeakyReLU
      (11): Object of type: Conv2d
      (12): Object of type: BatchNorm2d
      (13): Object of type: LeakyReLU
      (14): Object of type: Conv2d
      (15): Object of type: BatchNorm2d
      (16): Object of type: LeakyReLU
      (17): Object of type: Conv2d
      (18): Object of type: BatchNorm2d
      (19): Object of type: LeakyReLU
      (20): Object of type: Conv2d
      (21): Object of type: ReLU
      (22): Object of type: ConvTranspose2d
      (23): Object of type: Batch

In [12]:
#importlib.reload(gan)
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/smol_nodark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
gan.run(args, myUNET, myDisc, train_data, val_data, test_data)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Running Validation: 100%|██████████| 137/137 [00:22<00:00,  6.17batch/s]
Running Test: 100%|██████████| 290/290 [00:22<00:00, 13.09batch/s]
Epoch 1/60: 27464batch [08:41, 52.64batch/s]                          
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.52batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.23batch/s]
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.26batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.52batch/s]
Epoch 2/60: 27464batch [08:24, 54.46batch/s]                          
Running Validation: 100%|██████████| 137/137 [00:19<00:00,  6.92batch/s]
Running Test: 100%|██████████| 290/290 [00:20<00:00, 14.01batch/s]
Running Validation: 100%|██████████| 137/137 [00:20<00:00,  6.82batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.72batch/s]
Epoch 3/60: 27464batch [07:58, 57.39batch/s]                          
Running Validati

<gan.GANTrainer at 0x1a281443610>

In [None]:
import torch
torch.cuda.empty_cache()



In [None]:
dir()


In [None]:
hold = dir()
for x in hold:
    if x != 'torch' and x != '_oh':
        xc = 'del '+x
        print(xc)
        exec(xc)
del compose_smoke_and_vid_img
del VideoLoader
del SynthTrainer
del SynthSmokeLoader

del getDatasets
del generateVideo

del models
del myDisc
del myUNET
del gan
del test_data
del val_data
del train_data
del synthDatafromVidData
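The cell above frees memory by building `del` statements as strings and passing them to exec. A minimal sketch of the same idea without exec is given below; it assumes the goal is simply to drop every user-defined name except a short keep-list so that the garbage collector can release the models and datasets. The function name clear_names and the default keep-list are illustrative.

```python
import gc


def clear_names(namespace,
                keep=('torch', 'In', 'Out', 'get_ipython', 'exit', 'quit')):
    """Delete all names from a namespace dict except private ones and `keep`."""
    removed = []
    # Iterate over a copy of the keys because the dict is modified in the loop.
    for name in list(namespace):
        if name.startswith('_') or name in keep:
            continue
        del namespace[name]
        removed.append(name)
    gc.collect()  # release objects that are no longer referenced
    return removed


demo = {'torch': object(), '_oh': {}, 'myUNET': object(), 'train_data': [1, 2]}
print(clear_names(demo))  # ['myUNET', 'train_data']
```

In the notebook this would be called as clear_names(globals()), followed by torch.cuda.empty_cache() to return cached GPU memory.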


Run training, validation, and test for scenario 2: the full U-Net without the dark channel.

In [18]:
#CODE TO TRAIN SCENARIO 2
myUNET2 = models.UNET()
myDisc2 = models.Discriminator()
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/full_nodark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
gan.run(args, myUNET2, myDisc2, train_data, val_data, test_data)


Running Validation: 100%|██████████| 137/137 [00:19<00:00,  6.97batch/s]
Running Test: 100%|██████████| 290/290 [00:20<00:00, 13.86batch/s]
Epoch 1/60: 27464batch [08:57, 51.11batch/s]                          
Running Validation: 100%|██████████| 137/137 [00:19<00:00,  7.09batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.69batch/s]
Running Validation: 100%|██████████| 137/137 [00:19<00:00,  6.92batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.25batch/s]
Epoch 2/60: 27464batch [08:47, 52.06batch/s]                          
Running Validation: 100%|██████████| 137/137 [00:19<00:00,  7.09batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.41batch/s]
Running Validation: 100%|██████████| 137/137 [00:19<00:00,  7.06batch/s]
Running Test: 100%|██████████| 290/290 [00:21<00:00, 13.65batch/s]
Epoch 3/60: 27464batch [08:38, 52.92batch/s]                          
Running Validation: 100%|██████████| 137/137 [00:19<00:00,  7.13batch/s]
Running Test: 

<gan.GANTrainer at 0x1a29dfcea60>

Run training, validation, and test for scenario 3: the ablated U-Net with the dark channel.
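For readers unfamiliar with the dark channel, here is a minimal pure-Python sketch of the usual dark channel prior computation: the per-pixel minimum over the colour channels, followed by a minimum filter over a local patch. How this project computes it and feeds it to the network is not shown in this notebook. The function name dark_channel and the patch size are assumptions, and the use of input_channels=4 only suggests that the result is appended to the RGB image as a fourth channel.

```python
def dark_channel(image, patch=3):
    """Dark channel of an image given as an H x W nested list of (r, g, b)."""
    height, width = len(image), len(image[0])
    # Step 1: minimum over the colour channels at every pixel.
    min_rgb = [[min(pixel) for pixel in row] for row in image]
    # Step 2: minimum over a patch x patch window, clipped at the borders.
    radius = patch // 2
    dark = []
    for y in range(height):
        row = []
        for x in range(width):
            window = [min_rgb[j][i]
                      for j in range(max(0, y - radius), min(height, y + radius + 1))
                      for i in range(max(0, x - radius), min(width, x + radius + 1))]
            row.append(min(window))
        dark.append(row)
    return dark


image = [[(0.9, 0.8, 0.7), (0.2, 0.5, 0.6)],
         [(0.4, 0.3, 0.9), (1.0, 1.0, 1.0)]]
print(dark_channel(image, patch=1))  # [[0.7, 0.2], [0.3, 1.0]]
```

Smoke tends to brighten all colour channels at once, which raises the dark channel in smoky regions and makes it a useful hint for the generator.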

In [None]:
#CODE TO TRAIN SCENARIO 3
myUNET3 = models.UNETsmolr(input_channels=4)
myDisc3 = models.Discriminator()
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/smol_dark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
args['use_dark_channel'] = True
gan.run(args, myUNET3, myDisc3, train_data, val_data, test_data)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


torch.Size([3, 256, 256])
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.24batch/s]
Running Test: 100%|██████████| 290/290 [00:24<00:00, 11.88batch/s]
Epoch 1/60:   6%|▌         | 1608/27461 [01:42<08:33, 50.36batch/s]
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.41batch/s]
Running Test: 100%|██████████| 290/290 [00:23<00:00, 12.23batch/s]
Epoch 1/60:   9%|▉         | 2408/27461 [03:00<08:16, 50.51batch/s]
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.48batch/s]
Running Test: 100%|██████████| 290/290 [00:23<00:00, 12.12batch/s]
Epoch 1/60:  12%|█▏        | 3208/27461 [04:00<08:01, 50.35batch/s]
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.43batch/s]
Running Test: 100%|██████████| 290/290 [00:23<00:00, 12.48batch/s]
Epoch 1/60:  15%|█▍        | 4008/27461 [05:00<07:47, 50.18batch/s]
Running Validation: 100%|██████████| 137/137 [00:21<00:00,  6.39batch/s]
Running Test: 100%|██████████| 290/290 [00:23<00:00, 12.25batch/s]
Epoch 1/60:  18%|█▊        | 4808/27461 [05:53<07:38, 49.45batch/s]
Running Validation: 100%|██████████| 137/137 [00:20<00:00,  6.65batch/s]
Running Test:   0%|          | 0/290 [00:09<?, ?batch/s]
Epoch 1/60:  18%|█▊        | 4808/27461 [06:25<30:17, 12.46batch/s]


AttributeError: 'tuple' object has no attribute 'tb_frame'

Run training, validation, and test for scenario 4: the full U-Net with the dark channel.

In [None]:
#CODE TO TRAIN SCENARIO 4
myUNET4 = models.UNET(input_channels=4)
myDisc4 = models.Discriminator()
args['save'] = 'C:/Users/Karol/Documents/DL4H/runs/full_dark_60ep{}'.format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
args['use_dark_channel'] = True
gan.run(args, myUNET4, myDisc4, train_data, val_data, test_data)
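The scenario cells all modify the same args dictionary in place, so 'use_dark_channel' stays True once a dark-channel scenario has run, and re-running an earlier scenario afterwards would silently inherit it. One way to avoid this is to build a fresh copy of the arguments for each scenario. The sketch below assumes the four run tags used in this notebook; BASE_ARGS, SCENARIOS, scenario_args and the 'runs' directory are illustrative stand-ins rather than part of gan.py.

```python
import copy
import datetime

# Hypothetical base configuration mirroring the 'args' dictionary defined earlier.
BASE_ARGS = {'batch': 8, 'lr': 0.0002, 'epochs': 60, 'gpu': 0,
             'use_dark_channel': False, 'load_net': None}

# (run tag, generator class name in models.py, use the dark channel?)
SCENARIOS = [
    ('smol_nodark_60ep', 'UNETsmolr', False),
    ('full_nodark_60ep', 'UNET', False),
    ('smol_dark_60ep', 'UNETsmolr', True),
    ('full_dark_60ep', 'UNET', True),
]


def scenario_args(base_args, runs_dir, tag, use_dark_channel, now=None):
    """Return a per-scenario copy of the arguments; base_args is left untouched."""
    now = now or datetime.datetime.now()
    args = copy.deepcopy(base_args)
    args['use_dark_channel'] = use_dark_channel
    args['save'] = '{}/{}{}'.format(runs_dir, tag,
                                    now.strftime('%Y-%m-%d_%H-%M-%S'))
    return args


for tag, model_name, dark in SCENARIOS:
    print(model_name, scenario_args(BASE_ARGS, 'runs', tag, dark)['save'])
```

Each dictionary returned this way could then be passed to gan.run together with a freshly constructed generator and discriminator.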

Now we must run testing on all 4 scenarios, loading the models saved at each epoch so that we can plot metrics. (INTEGRATE THIS INTO TRAIN AND VALIDATE??)

In [None]:
# YES?

http://localhost:6006