### XYZ Network

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(0)

from tqdm import tqdm
from cv2.ximgproc import guidedFilter
from skimage.color import rgb2hsv
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

import numpy as np

from utils.collection import HazeCollection
from utils.network import VaeNet, HazyNet
from utils.loss import StdLoss
from utils.utils import get_atmosphere, torch_to_np, np_to_torch, save_image, get_information
from utils.imresize import np_imresize

#### Collection

In [2]:
# Root of the local RESIDE "Standard" dataset tree. Every dataset path below is
# derived from it, so relocating the data only requires editing this one line.
RESIDE_ROOT = "/home/maldonadoq/Datasets/Reside/Standard"

# D0: "OWN" subset (hazy images + ground truth).
haze_path_dt0 = RESIDE_ROOT + "/OWN/hazy_only"
image_path_dt0 = RESIDE_ROOT + "/OWN/gt_only"

# D1: HSTS synthetic subset (hazy images + ground truth).
haze_path_dt1 = RESIDE_ROOT + "/HSTS/synthetic/synthetic"
image_path_dt1 = RESIDE_ROOT + "/HSTS/synthetic/original"

# D2: HSTS real-world subset (hazy images only, no ground-truth folder).
haze_path_dt2 = RESIDE_ROOT + "/HSTS/real-world"

# D3: SOTS indoor validation subset (hazy images + ground truth).
haze_path_dt3 = RESIDE_ROOT + "/SOTS/indoor/hazy_val"
image_path_dt3 = RESIDE_ROOT + "/SOTS/indoor/gt_val"

In [3]:
# Spatial size and channel count of the dummy tensor used by the model
# sanity check further down.
size, channels = 256, 3

# One collection per dataset. D0, D1 and D3 are built from a hazy folder and a
# ground-truth folder; D2 is built from a hazy folder only.
collection_dt0 = HazeCollection(haze_path_dt0, image_path_dt0)
collection_dt1 = HazeCollection(haze_path_dt1, image_path_dt1)
collection_dt2 = HazeCollection(haze_path_dt2)
collection_dt3 = HazeCollection(haze_path_dt3, image_path_dt3)

#### Model

In [4]:
# UNet smoke test: build a HazyNet, feed it an all-zero batch of one image and
# display the shape of what comes back.
un = HazyNet(channels)

x = torch.zeros(1, channels, size, size)
a = un(x)
a.shape

torch.Size([1, 3, 256, 256])

#### Training

In [5]:
class EnsembleTrainer:
    """Per-image dehazing trainer based on the atmospheric scattering model.

    For every hazy image three networks are created from scratch and optimised
    on that single image:

    * ``netJ`` (``HazyNet``, 3 output channels) - the haze-free image J,
    * ``netT`` (``HazyNet``, 1 output channel)  - the transmission map t,
    * ``netA`` (``VaeNet``)                     - the atmospheric light A.

    They are trained so that ``t * J + (1 - t) * A`` reproduces the hazy input.
    The final dehazed image is obtained by inverting that model with the
    guided-filter-refined transmission map.
    """

    def __init__(
        self,
        metrics=None
    ):
        """Create a trainer.

        Args:
            metrics: optional list stored on ``self.metrics``. ``None`` (the
                default) gives each instance its own fresh empty list instead
                of a default list shared between all instances.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.metrics = [] if metrics is None else metrics
        # Same defaults as ``fit``; defined here so ``init``/``train`` can also
        # be used without going through ``fit`` first.
        self.learning_rate = 1e-3
        self.dt_number = ''

    def init(self, image):
        """Build fresh networks, losses, atmosphere reference and optimizer.

        Args:
            image: hazy image handed to ``get_atmosphere`` and whose ``.shape``
                configures ``VaeNet``.
        """
        self.model = {
            "netJ": HazyNet(out_channel=3).to(self.device),
            "netT": HazyNet(out_channel=1).to(self.device),
            "netA": VaeNet(image.shape).to(self.device)
        }
        self.loss = {
            "mse": nn.MSELoss().to(self.device),
            "std": StdLoss().to(self.device)
        }

        # Fixed (non-trainable) atmospheric-light estimate used as a target for
        # netA. It is created on ``self.device`` so the CPU fallback selected
        # in ``__init__`` actually works (no hard-coded CUDA tensor type).
        atmosphere = get_atmosphere(image)
        self.valA = nn.Parameter(
            data=torch.tensor(
                np.asarray(atmosphere).reshape((1, 3, 1, 1)),
                dtype=torch.float32,
                device=self.device
            ),
            requires_grad=False
        )

        # A single optimizer updates the three networks jointly.
        parameters = [
            p
            for net in ("netJ", "netT", "netA")
            for p in self.model[net].parameters()
        ]

        self.optimizer = optim.Adam(
            parameters,
            lr=self.learning_rate
        )

    def t_matting(self, mask_out_np, original):
        """Refine a transmission map with a guided filter.

        Args:
            mask_out_np: transmission estimate; only ``mask_out_np[0]`` is
                filtered.
            original: guide image, transposed with ``(1, 2, 0)`` before
                filtering (i.e. expected channel-first).

        Returns:
            The filtered map clipped to ``[0.1, 1]`` and wrapped in a new
            leading axis. The 0.1 floor keeps the later division by the
            transmission away from zero.
        """
        refine_t = guidedFilter(original.transpose(1, 2, 0).astype(
            np.float32), mask_out_np[0].astype(np.float32), 50, 1e-4)
        return np.array([np.clip(refine_t, 0.1, 1)])

    def train(self, information, epochs):
        """Optimise the three networks on one image and save the dehazed result.

        Args:
            information: ``(haze, image, name)`` triple. ``image`` is the
                ground truth or ``None``; ``name[0]`` is the file name.
            epochs: number of optimisation steps; must be at least 1.

        Returns:
            ``[mean_loss, ssim, psnr]``. Both metrics are 0 when no ground
            truth is available.

        Raises:
            ValueError: if ``epochs`` is smaller than 1 (there would be neither
                a result image nor a loss to average).
        """
        (haze, image, name) = information
        if epochs < 1:
            raise ValueError(
                "epochs must be >= 1, got {}".format(epochs))

        self.init(haze)
        loop = tqdm(range(epochs))

        losses = []

        hazeTorch = np_to_torch(haze)
        hazeTorch = hazeTorch.to(device=self.device)
        for i in loop:
            # forward
            predT = self.model["netT"](hazeTorch)
            predJ = self.model["netJ"](hazeTorch)
            predA = self.model["netA"](hazeTorch)

            lossA = self.model["netA"].getLoss()
            lossT = self.loss["std"](predT)

            # Xhot: the re-composed hazy image must match the input.
            mse_loss = self.loss["mse"](predT * predJ + (1 - predT) * predA, hazeTorch)
            XLossJ = mse_loss

            # Yoly: colour-attenuation prior, value minus saturation.
            # NOTE(review): predJ is round-tripped through NumPy (rgb2hsv), so
            # this term is cut off from autograd: it changes the reported loss
            # value but sends no gradient to netJ. Confirm whether that is
            # intended or whether a differentiable HSV conversion is needed.
            hsv = np_to_torch(rgb2hsv(torch_to_np(predJ).transpose(1, 2, 0)))
            cap_prior = hsv[:, :, :, 2] - hsv[:, :, :, 1]
            cap_loss = self.loss["mse"](cap_prior, torch.zeros_like(cap_prior))
            YLossJ = 1.0 * cap_loss

            # Zid: dark-channel prior, per-pixel minimum over the channel axis.
            dcp_prior = torch.min(predJ.permute(0, 2, 3, 1), 3)[0]
            dcp_loss = self.loss["mse"](dcp_prior, torch.zeros_like(dcp_prior)) - 0.05
            ZLossJ = dcp_loss

            # Final
            lossFinalJ = 0.4*XLossJ + 0.4*YLossJ + 0.2*ZLossJ
            lossFinalT = 0.005 * lossT
            # Broadcast explicitly: same value as MSELoss' implicit
            # broadcasting, without the size-mismatch UserWarning it emits when
            # input and target shapes differ.
            predA_b, valA_b = torch.broadcast_tensors(predA, self.valA)
            lossFinalA = self.loss["mse"](predA_b, valA_b) + lossA
            lossFinal = lossFinalJ + lossFinalT + lossFinalA

            if i == epochs-1:
                imageT = np.clip(torch_to_np(predT), 0, 1)
                imageA = np.clip(torch_to_np(predA), 0, 1)
                imageT = self.t_matting(imageT, haze)

                # I(x) = J(x)t(x) + A(1 − t(x))
                # J(x) = (I(X) - A(1 − t(x))) / t(x)
                post = np.clip((haze - ((1 - imageT) * imageA)) / imageT, 0, 1)
                post = np_imresize(post, output_shape=haze.shape[1:])
                save_image('{}_{}'.format(name[0].split('.')[
                           0], i+1), post, './images/xyz/' + self.dt_number)

            # backward
            self.optimizer.zero_grad()
            # NOTE(review): the forward pass is recomputed every iteration, so
            # retain_graph=True looks unnecessary and only keeps memory alive.
            # Left unchanged until VaeNet.getLoss() is confirmed not to reuse a
            # graph across steps.
            lossFinal.backward(retain_graph=True)
            self.optimizer.step()

            # update tqdm (one .item() call, i.e. one host sync, per step)
            loss_value = lossFinal.item()
            loop.set_postfix(loss=loss_value)
            losses.append(loss_value)

        # metrics
        finalSSIM, finalPSNR = 0, 0
        if image is not None:
            finalSSIM = ssim(post, image, channel_axis=0, data_range=1)
            # data_range is explicit so PSNR uses the same [0, 1] scale as SSIM
            # instead of a range inferred from the array dtype and minimum.
            finalPSNR = psnr(post, image, data_range=1)
            print("Ssim: {:.4f}, Psnr: {:.4f}\n".format(finalSSIM, finalPSNR))

        return [np.sum(losses)/len(losses), finalSSIM, finalPSNR]

    def fit(
        self,
        collection,
        epochs=1,
        dt_number='',
        learning_rate=1e-3
    ):
        """Run ``train`` on every item of ``collection``.

        Args:
            collection: sized iterable yielding ``(haze, image, name)`` items.
            epochs: optimisation steps per image.
            dt_number: suffix appended to ``'./images/xyz/'`` for saved images.
            learning_rate: Adam learning rate used for every image.

        Returns:
            A list with one ``[mean_loss, ssim, psnr]`` entry per image.
        """
        historial = []
        self.dt_number = dt_number
        self.learning_rate = learning_rate
        for i, info in enumerate(collection):
            print('Image {}/{} [{}]'.format(i+1, len(collection), info[2][0]))
            # Local names deliberately differ from the imported ssim/psnr
            # functions to avoid shadowing them.
            mean_loss, image_ssim, image_psnr = self.train(info, epochs)
            historial.append([mean_loss, image_ssim, image_psnr])
        return historial

In [6]:
# One trainer object is reused for all datasets below; train() calls init() for
# every image, so networks and optimizer are rebuilt per image.
dehazing = EnsembleTrainer()

#### Testing

In [7]:
# Shared run settings for every dataset below: optimisation steps per image
# and the Adam learning rate passed to fit().
epochs, learning_rate = 512, 1e-4

##### D0

In [8]:
# D0: train on every image of the collection, then summarise the per-image
# history. Results are saved with the './images/xyz/' + 'd0/' destination.
historial_dt0 = dehazing.fit(
    collection_dt0,
    epochs=epochs,
    dt_number='d0/',
    learning_rate=learning_rate,
)
get_information(historial_dt0, collection_dt0)

Image 1/5 [8180_1.png]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:21<00:00, 23.49it/s, loss=0.0189]


Ssim: 0.9168, Psnr: 24.0159

Image 2/5 [4561_1.png]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:19<00:00, 26.94it/s, loss=0.0187]


Ssim: 0.5637, Psnr: 16.6411

Image 3/5 [1436_5.png]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:30<00:00, 16.58it/s, loss=0.0123]


Ssim: 0.8884, Psnr: 23.0519

Image 4/5 [3146_1.png]


100%|██████████| 512/512 [00:19<00:00, 25.97it/s, loss=0.0247]


Ssim: 0.5425, Psnr: 20.0960

Image 5/5 [1401_5.png]


100%|██████████| 512/512 [00:31<00:00, 16.22it/s, loss=0.0188]

Ssim: 0.7975, Psnr: 20.2498

Mean SSIM: 0.74178067445755
Mean PSNR: 20.810930718462565
Best SSIM: 8180_1.png
Best PSNR: 8180_1.png





##### D1

In [9]:
# D1: train on every image of the collection, then summarise the per-image
# history. Results are saved with the './images/xyz/' + 'd1/' destination.
historial_dt1 = dehazing.fit(
    collection_dt1,
    epochs=epochs,
    dt_number='d1/',
    learning_rate=learning_rate,
)
get_information(historial_dt1, collection_dt1)

Image 1/10 [1381.jpg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:25<00:00, 19.95it/s, loss=0.0077] 


Ssim: 0.8148, Psnr: 18.1931

Image 2/10 [5576.jpg]


100%|██████████| 512/512 [00:19<00:00, 26.00it/s, loss=0.00993]


Ssim: 0.8261, Psnr: 22.5090

Image 3/10 [7471.jpg]


100%|██████████| 512/512 [00:21<00:00, 23.85it/s, loss=0.0226]


Ssim: 0.8852, Psnr: 20.9043

Image 4/10 [0586.jpg]


100%|██████████| 512/512 [00:25<00:00, 20.29it/s, loss=0.0234]


Ssim: 0.8512, Psnr: 17.0770

Image 5/10 [5920.jpg]


100%|██████████| 512/512 [00:26<00:00, 19.69it/s, loss=0.0155]


Ssim: 0.9587, Psnr: 26.7712

Image 6/10 [3146.jpg]


100%|██████████| 512/512 [00:19<00:00, 26.56it/s, loss=0.0226]


Ssim: 0.8673, Psnr: 23.1697

Image 7/10 [4184.jpg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:28<00:00, 17.74it/s, loss=0.0247]


Ssim: 0.8963, Psnr: 21.8985

Image 8/10 [8180.jpg]


100%|██████████| 512/512 [00:21<00:00, 23.75it/s, loss=0.0229]


Ssim: 0.6638, Psnr: 14.7711

Image 9/10 [1352.jpg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:46<00:00, 10.99it/s, loss=0.0153]


Ssim: 0.9643, Psnr: 25.7760

Image 10/10 [4561.jpg]


100%|██████████| 512/512 [00:19<00:00, 26.04it/s, loss=0.0174]


Ssim: 0.4350, Psnr: 14.4302

Mean SSIM: 0.8162753313779831
Mean PSNR: 20.55001385959589
Best SSIM: 1352.jpg
Best PSNR: 5920.jpg


##### D2

In [10]:
# D2: this collection was built without a ground-truth folder. train() leaves
# SSIM and PSNR at 0 when there is no ground truth, so the summary printed for
# this run is a placeholder rather than a measured score.
historial_dt2 = dehazing.fit(
    collection_dt2,
    epochs=epochs,
    dt_number='d2/',
    learning_rate=learning_rate,
)
get_information(historial_dt2, collection_dt2)

Image 1/10 [SFC_Google_197.jpeg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:36<00:00, 13.89it/s, loss=0.0228] 


Image 2/10 [MLS_Bing_117.jpeg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:41<00:00, 12.30it/s, loss=0.0194]


Image 3/10 [HazyDr_Google_396.jpeg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [00:39<00:00, 12.96it/s, loss=0.0365]


Image 4/10 [SGP_Bing_085.jpeg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [01:58<00:00,  4.32it/s, loss=0.0309]


Image 5/10 [NW_Google_837.jpeg]


100%|██████████| 512/512 [00:36<00:00, 14.14it/s, loss=0.0134]


Image 6/10 [YST_Bing_667.jpeg]


100%|██████████| 512/512 [00:41<00:00, 12.48it/s, loss=0.018]  


Image 7/10 [MLS_Google_585.png]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [01:16<00:00,  6.68it/s, loss=0.0344]


Image 8/10 [SGP_Bing_588.png]


100%|██████████| 512/512 [00:41<00:00, 12.48it/s, loss=0.0178]


Image 9/10 [KRO_Google_143.jpeg]


100%|██████████| 512/512 [00:40<00:00, 12.51it/s, loss=0.0159]


Image 10/10 [HazeDr_Google_404.jpeg]


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 512/512 [01:42<00:00,  5.00it/s, loss=0.0491]

Mean SSIM: 0.0
Mean PSNR: 0.0
Best SSIM: SFC_Google_197.jpeg
Best PSNR: SFC_Google_197.jpeg





##### D3

In [11]:
# D3: train on every image of the collection, then summarise the per-image
# history. Results are saved with the './images/xyz/' + 'd3/' destination.
historial_dt3 = dehazing.fit(
    collection_dt3,
    epochs=epochs,
    dt_number='d3/',
    learning_rate=learning_rate,
)
get_information(historial_dt3, collection_dt3)

Image 1/10 [1410_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.67it/s, loss=0.0391]


Ssim: 0.7548, Psnr: 17.6932

Image 2/10 [1430_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.74it/s, loss=0.0283] 


Ssim: 0.6245, Psnr: 16.3687

Image 3/10 [1440_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.73it/s, loss=0.0385]


Ssim: 0.7830, Psnr: 17.7673

Image 4/10 [1405_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.68it/s, loss=0.0304]


Ssim: 0.8462, Psnr: 20.9826

Image 5/10 [1400_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.68it/s, loss=0.0195]


Ssim: 0.7698, Psnr: 17.2342

Image 6/10 [1415_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.72it/s, loss=0.0249]


Ssim: 0.8827, Psnr: 22.7774

Image 7/10 [1445_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.67it/s, loss=0.0351]


Ssim: 0.4723, Psnr: 15.2513

Image 8/10 [1435_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.72it/s, loss=0.0221]


Ssim: 0.8723, Psnr: 21.9887

Image 9/10 [1425_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.68it/s, loss=0.0221]


Ssim: 0.5283, Psnr: 14.9922

Image 10/10 [1420_5.png]


100%|██████████| 512/512 [00:30<00:00, 16.68it/s, loss=0.0283]


Ssim: 0.7826, Psnr: 19.0827

Mean SSIM: 0.7316538900136947
Mean PSNR: 18.413821582554505
Best SSIM: 1415_5.png
Best PSNR: 1415_5.png
