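### ZID Network

Per-image dehazing: for every hazy image, three networks (clear image *J*, transmission map *t*, atmospheric light *A*) are created from scratch and optimized on that single image so that $I(x) = J(x)\,t(x) + A\,(1 - t(x))$ reconstructs the hazy input. The restored image is saved and, when a ground-truth image is available, scored with SSIM and PSNR.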

In [1]:
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from cv2.ximgproc import guidedFilter
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from tqdm import tqdm

sys.path.insert(0, '..')
from utils.collection import HazeCollection
from utils.network import VaeNet
from utils.skip import Skip
from utils.loss import StdLoss
from utils.utils import get_atmosphere, torch_to_np, np_to_torch, save_image, get_information
from utils.imresize import np_imresize

#### Collection

In [2]:
# Root of the RESIDE "Standard" datasets. Set the RESIDE_ROOT environment
# variable to run the notebook on another machine; the default is the
# location used when the notebook was originally executed.
RESIDE_ROOT = os.environ.get(
    "RESIDE_ROOT", "/home/maldonadoq/Datasets/Reside/Standard"
)

# D1: HSTS synthetic pairs (hazy image + ground truth).
haze_path_dt1 = os.path.join(RESIDE_ROOT, "HSTS", "synthetic", "synthetic")
image_path_dt1 = os.path.join(RESIDE_ROOT, "HSTS", "synthetic", "original")

# D2: HSTS real-world hazy images (no ground-truth directory).
haze_path_dt2 = os.path.join(RESIDE_ROOT, "HSTS", "real-world")

# D3: SOTS indoor validation pairs (hazy image + ground truth).
haze_path_dt3 = os.path.join(RESIDE_ROOT, "SOTS", "indoor", "hazy_val")
image_path_dt3 = os.path.join(RESIDE_ROOT, "SOTS", "indoor", "gt_val")

In [3]:
# Spatial size and channel count of the dummy tensor used to sanity-check
# the Skip network further down.
size, channels = 256, 3

# One collection per dataset; D2 is built from hazy images only, so it has
# no ground-truth path.
collection_dt1 = HazeCollection(haze_path_dt1, image_path_dt1)
collection_dt2 = HazeCollection(haze_path_dt2)
collection_dt3 = HazeCollection(haze_path_dt3, image_path_dt3)

#### Model

In [4]:
# Sanity check of the Skip architecture: feed an all-zero batch of one image
# through a freshly built network and display the output shape.
un = Skip(
    channels,
    3,
    num_channels_down=[8 * 2 ** level for level in range(5)],
    num_channels_up=[8 * 2 ** level for level in range(5)],
    num_channels_skip=[0] * 3 + [4] * 2,
    upsample_mode='bilinear',
    need_sigmoid=True,
    need_bias=True,
    pad='reflection',
)

x = torch.zeros(1, channels, size, size)
a = un(x)
a.shape

torch.Size([1, 3, 256, 256])

#### Training

In [5]:
class ZidTrainer:
    """Per-image optimizer for the ZID dehazing model.

    For every hazy image three networks are created from scratch and fitted
    to that single image:

    * ``netJ`` - Skip network with 3 output channels (haze-free image J),
    * ``netT`` - Skip network with 1 output channel (transmission map t),
    * ``netA`` - ``VaeNet`` built from the image shape (atmospheric light A).

    They are trained jointly so that ``t * J + (1 - t) * A`` reconstructs the
    hazy input. ``fit`` runs this procedure over a whole collection.
    """

    def __init__(
        self,
        metrics=None
    ):
        """Create a trainer.

        Args:
            metrics: optional list stored as ``self.metrics``. No method of
                this class reads it. ``None`` (the default) yields a fresh
                empty list per instance.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # A mutable default argument would be shared by every trainer, so the
        # empty list is created here instead of in the signature.
        self.metrics = [] if metrics is None else metrics
        # Same defaults as fit(), so that train() also works when it is
        # called directly without going through fit() first.
        self.dt_number = ''
        self.learning_rate = 1e-3

    def _build_skip(self, out_channels, input_depth=3, pad='reflection'):
        """Build one Skip network with the layout shared by netJ and netT."""
        return Skip(
            input_depth, out_channels,
            num_channels_down=[8, 16, 32, 64, 128],
            num_channels_up=[8, 16, 32, 64, 128],
            num_channels_skip=[0, 0, 0, 4, 4],
            upsample_mode='bilinear',
            need_sigmoid=True, need_bias=True, pad=pad
        )

    def init(self, image):
        """Create fresh networks, losses and optimizer for one hazy image.

        Args:
            image: hazy image; its ``shape`` is passed to ``VaeNet`` and the
                image itself to ``get_atmosphere``.
                (presumably a channel-first numpy array - TODO confirm
                against HazeCollection)

        Sets ``self.model``, ``self.loss``, ``self.valA`` and
        ``self.optimizer``.
        """
        # Creation order (netJ, netT, netA) is kept stable because it
        # determines the order in which random initial weights are drawn.
        self.model = {
            "netJ": self._build_skip(3).to(self.device),
            "netT": self._build_skip(1).to(self.device),
            "netA": VaeNet(image.shape).to(self.device)
        }
        self.loss = {
            "mse": nn.MSELoss().to(self.device),
            "std": StdLoss().to(self.device)
        }

        # Atmospheric-light estimate used as a fixed regression target for
        # netA. It is created on self.device so the trainer also runs on
        # CPU-only hosts (torch.cuda.FloatTensor would fail there).
        atmosphere = np.asarray(get_atmosphere(image), dtype=np.float32)
        self.valA = nn.Parameter(
            data=torch.tensor(
                atmosphere.reshape((1, 3, 1, 1)), device=self.device),
            requires_grad=False
        )

        # A single optimizer updates all three networks jointly.
        parameters = [
            p
            for key in ("netJ", "netT", "netA")
            for p in self.model[key].parameters()
        ]
        self.optimizer = optim.Adam(
            parameters,
            lr=self.learning_rate
        )

    def t_matting(self, mask_out_np, original):
        """Refine a transmission map with a guided filter.

        Args:
            mask_out_np: transmission estimate; only ``mask_out_np[0]`` is
                filtered.
            original: guide image, transposed with ``(1, 2, 0)`` before
                filtering.

        Returns:
            Array holding one map: the filtered transmission clipped to
            ``[0.1, 1]``. The lower bound keeps the later division by the
            transmission away from zero.
        """
        guide = original.transpose(1, 2, 0).astype(np.float32)
        source = mask_out_np[0].astype(np.float32)
        # guidedFilter(guide, src, radius, eps)
        refine_t = guidedFilter(guide, source, 50, 1e-4)
        return np.array([np.clip(refine_t, 0.1, 1)])

    def _restore(self, haze, predT, predA):
        """Invert the haze model I = J*t + A*(1 - t) with the refined t."""
        imageT = np.clip(torch_to_np(predT), 0, 1)
        imageA = np.clip(torch_to_np(predA), 0, 1)
        imageT = self.t_matting(imageT, haze)

        # I(x) = J(x)t(x) + A(1 - t(x))  =>  J = (I - (1 - t) * A) / t
        post = np.clip((haze - ((1 - imageT) * imageA)) / imageT, 0, 1)
        return np_imresize(post, output_shape=haze.shape[1:])

    def train(self, information, epochs):
        """Optimize the three networks on a single image and evaluate it.

        Args:
            information: ``(haze, image, name)`` triple. ``image`` is the
                ground truth or ``None``; ``name[0]`` is the file name used
                to build the output file name.
            epochs: number of optimization steps, at least 1.

        Returns:
            ``[mean_loss, ssim, psnr]``; both metrics are 0 when no ground
            truth is available.

        Raises:
            ValueError: if ``epochs`` is smaller than 1 (no restored image
                and no loss history would exist).
        """
        if epochs < 1:
            raise ValueError(
                "epochs must be >= 1, got {}".format(epochs))

        (haze, image, name) = information
        self.init(haze)
        loop = tqdm(range(epochs))

        losses = []
        post = None

        hazeTorch = np_to_torch(haze).to(device=self.device)
        for i in loop:
            # forward
            predT = self.model["netT"](hazeTorch)
            predJ = self.model["netJ"](hazeTorch)
            predA = self.model["netA"](hazeTorch)

            lossT = self.loss["std"](predT)
            # Reconstruction of the hazy input from J, t and A.
            lossJ = self.loss["mse"](
                predT * predJ + (1 - predT) * predA, hazeTorch)
            lossA = self.model["netA"].getLoss()

            # Per-pixel minimum over the channel axis of J, pushed towards
            # zero. The constant -0.05 only shifts the reported value (which
            # is why the displayed loss can be negative); it has no gradient.
            dcp_prior = torch.min(predJ.permute(0, 2, 3, 1), 3)[0]
            dcp_loss = self.loss["mse"](
                dcp_prior, torch.zeros_like(dcp_prior)) - 0.05

            # Terms are summed in the same left-to-right order as before so
            # the floating-point result is unchanged.
            lossFinal = (
                lossJ + lossA
                + 0.005 * lossT
                + dcp_loss
                + 0.1 * self.loss["std"](predA)
                + self.loss["mse"](predA, self.valA * torch.ones_like(predA))
            )

            # The restored image is produced from the predictions of the
            # last iteration only.
            if i == epochs - 1:
                post = self._restore(haze, predT, predA)
                save_image(
                    '{}_{}'.format(name[0].split('.')[0], i+1),
                    post,
                    '../images/zid/' + self.dt_number
                )

            # backward
            self.optimizer.zero_grad()
            # NOTE(review): the graph is rebuilt on every iteration, so
            # retain_graph=True looks unnecessary - TODO confirm against
            # VaeNet.getLoss before dropping it.
            lossFinal.backward(retain_graph=True)
            self.optimizer.step()

            # update tqdm
            current = lossFinal.item()
            loop.set_postfix(loss=current)
            losses.append(current)

        # metrics
        finalSSIM, finalPSNR = 0, 0
        if image is not None:
            finalSSIM = ssim(post, image, channel_axis=0, data_range=1)
            # data_range is explicit so PSNR uses the same [0, 1] scale as
            # SSIM instead of a range inferred from the array dtype.
            finalPSNR = psnr(image, post, data_range=1)
            print("Ssim: {:.4f}, Psnr: {:.4f}\n".format(finalSSIM, finalPSNR))

        return [np.sum(losses)/len(losses), finalSSIM, finalPSNR]

    def fit(
        self,
        collection,
        epochs=1,
        dt_number='',
        learning_rate=1e-3
    ):
        """Run ``train`` on every item of a collection.

        Args:
            collection: sized iterable of ``(haze, image, name)`` items.
            epochs: optimization steps per image.
            dt_number: sub-directory appended to ``'../images/zid/'`` when
                saving restored images.
            learning_rate: Adam learning rate.

        Returns:
            One ``[mean_loss, ssim, psnr]`` entry per image, in order.
        """
        historial = []
        self.dt_number = dt_number
        self.learning_rate = learning_rate
        for i, info in enumerate(collection):
            print('Image {}/{} [{}]'.format(i+1, len(collection), info[2][0]))
            # Local names deliberately differ from the imported ssim/psnr
            # functions.
            mean_loss, ssim_value, psnr_value = self.train(info, epochs)
            historial.append([mean_loss, ssim_value, psnr_value])
        return historial

In [6]:
# Single trainer instance reused for the three datasets below; fit() rebuilds
# the networks for every image, so no weights carry over between runs.
dehazing = ZidTrainer()

#### Testing

In [7]:
# Optimization steps per image and Adam step size, shared by D1, D2 and D3.
epochs = 512
learning_rate = 0.0001

##### D1

In [8]:
# D1: restored images are written under ../images/zid/d1/.
historial_dt1 = dehazing.fit(
    collection_dt1,
    epochs=epochs,
    dt_number='d1/',
    learning_rate=learning_rate,
)
get_information(historial_dt1, collection_dt1)

Image 1/10 [1381.jpg]


100%|██████████| 512/512 [00:34<00:00, 14.71it/s, loss=0.000815]


Ssim: 0.6752, Psnr: 13.0089

Image 2/10 [5576.jpg]


100%|██████████| 512/512 [00:26<00:00, 19.12it/s, loss=-.0013]  


Ssim: 0.7034, Psnr: 17.0090

Image 3/10 [7471.jpg]


100%|██████████| 512/512 [00:29<00:00, 17.21it/s, loss=0.0382]


Ssim: 0.8683, Psnr: 19.6908

Image 4/10 [0586.jpg]


100%|██████████| 512/512 [00:35<00:00, 14.27it/s, loss=0.135]


Ssim: 0.7487, Psnr: 12.3101

Image 5/10 [5920.jpg]


100%|██████████| 512/512 [00:35<00:00, 14.34it/s, loss=0.0365]


Ssim: 0.9342, Psnr: 23.9943

Image 6/10 [3146.jpg]


100%|██████████| 512/512 [00:27<00:00, 18.85it/s, loss=0.0412]


Ssim: 0.7321, Psnr: 18.8292

Image 7/10 [4184.jpg]


100%|██████████| 512/512 [00:40<00:00, 12.51it/s, loss=0.0222]


Ssim: 0.8059, Psnr: 18.0182

Image 8/10 [8180.jpg]


100%|██████████| 512/512 [00:30<00:00, 16.99it/s, loss=0.0105]


Ssim: 0.8711, Psnr: 20.4244

Image 9/10 [1352.jpg]


100%|██████████| 512/512 [01:08<00:00,  7.50it/s, loss=0.0114]


Ssim: 0.9512, Psnr: 23.2605

Image 10/10 [4561.jpg]


100%|██████████| 512/512 [00:27<00:00, 18.43it/s, loss=0.0332]

Ssim: 0.7164, Psnr: 19.4504

Mean SSIM: 0.8006464540958405
Mean PSNR: 18.59956848950961
Best SSIM: 1352.jpg
Best PSNR: 5920.jpg





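##### D2

This collection has no ground-truth images, so SSIM/PSNR are not computed: the trainer reports 0 for both, and the "Mean"/"Best" lines in the output below are placeholders rather than results.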

In [9]:
# D2: restored images are written under ../images/zid/d2/.
historial_dt2 = dehazing.fit(
    collection_dt2,
    epochs=epochs,
    dt_number='d2/',
    learning_rate=learning_rate,
)
get_information(historial_dt2, collection_dt2)

Image 1/10 [SFC_Google_197.jpeg]


100%|██████████| 512/512 [00:51<00:00,  9.87it/s, loss=0.0192]


Image 2/10 [MLS_Bing_117.jpeg]


100%|██████████| 512/512 [01:01<00:00,  8.35it/s, loss=0.000864]


Image 3/10 [HazyDr_Google_396.jpeg]


100%|██████████| 512/512 [00:56<00:00,  9.05it/s, loss=0.0133]


Image 4/10 [SGP_Bing_085.jpeg]


100%|██████████| 512/512 [02:30<00:00,  3.41it/s, loss=0.0463]


Image 5/10 [NW_Google_837.jpeg]


100%|██████████| 512/512 [00:51<00:00,  9.90it/s, loss=0.0279]


Image 6/10 [YST_Bing_667.jpeg]


100%|██████████| 512/512 [01:01<00:00,  8.33it/s, loss=0.0949]


Image 7/10 [MLS_Google_585.png]


100%|██████████| 512/512 [01:50<00:00,  4.63it/s, loss=0.0476]


Image 8/10 [SGP_Bing_588.png]


100%|██████████| 512/512 [01:01<00:00,  8.33it/s, loss=0.0174]


Image 9/10 [KRO_Google_143.jpeg]


100%|██████████| 512/512 [01:01<00:00,  8.34it/s, loss=0.0241]


Image 10/10 [HazeDr_Google_404.jpeg]


100%|██████████| 512/512 [02:11<00:00,  3.88it/s, loss=0.143]

Mean SSIM: 0.0
Mean PSNR: 0.0
Best SSIM: SFC_Google_197.jpeg
Best PSNR: SFC_Google_197.jpeg





##### D3

In [10]:
# D3: restored images are written under ../images/zid/d3/.
historial_dt3 = dehazing.fit(
    collection_dt3,
    epochs=epochs,
    dt_number='d3/',
    learning_rate=learning_rate,
)
get_information(historial_dt3, collection_dt3)

Image 1/10 [1410_5.png]


100%|██████████| 512/512 [00:46<00:00, 11.03it/s, loss=0.0255]


Ssim: 0.8560, Psnr: 17.9021

Image 2/10 [1430_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.13it/s, loss=0.108]


Ssim: 0.4818, Psnr: 16.1625

Image 3/10 [1440_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.26it/s, loss=0.0217]


Ssim: 0.8391, Psnr: 16.6056

Image 4/10 [1405_5.png]


100%|██████████| 512/512 [00:44<00:00, 11.39it/s, loss=0.0402]


Ssim: 0.8805, Psnr: 20.1874

Image 5/10 [1400_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.33it/s, loss=0.0391]


Ssim: 0.7983, Psnr: 17.6596

Image 6/10 [1415_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.27it/s, loss=0.0397]


Ssim: 0.8443, Psnr: 20.1377

Image 7/10 [1445_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.29it/s, loss=0.0556]


Ssim: 0.8280, Psnr: 21.2616

Image 8/10 [1435_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.31it/s, loss=0.013] 


Ssim: 0.8399, Psnr: 19.2470

Image 9/10 [1425_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.30it/s, loss=0.0282]


Ssim: 0.7318, Psnr: 14.1651

Image 10/10 [1420_5.png]


100%|██████████| 512/512 [00:45<00:00, 11.32it/s, loss=0.0405]


Ssim: 0.8119, Psnr: 16.4296

Mean SSIM: 0.791169810295105
Mean PSNR: 17.975829057953593
Best SSIM: 1405_5.png
Best PSNR: 1445_5.png
