<a href="https://colab.research.google.com/github/johnnyff/light_breeding_resolution_dehazing_by_HARDGAN/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive (Colab only); the dataset CSV and
# images read further down live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Notebook-only command (not plain Python): installs tensorboardX, whose SummaryWriter is used below.
pip install tensorboardX

In [None]:
import torch.utils.data as data
from PIL import Image
from random import randrange
from torchvision.transforms import Compose, ToTensor, Normalize
import torchvision.transforms as tfs
from torchvision.transforms import functional as FF
import random
from glob import glob
import re
import pandas as pd
from torchvision.models import vgg16
import os
import time
import torch
import argparse
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from torch.nn import init
import cv2
import matplotlib.pyplot as plt
import numpy as np


In [None]:

# --- Training hyper-parameters --- #
learning_rate = 1e-3      # Adam learning rate
crop_size = [240, 240]    # NOTE(review): presumably [height, width] of training crops
train_phrase = 10         # NOTE(review): not referenced in this part of the notebook
train_batch_size = 6
network_height = 3        # grid rows of Generate_quarter
network_width = 6         # grid columns of Generate_quarter
num_dense_layer = 4       # dense layers per RDB stack
growth_rate = 16          # channels added by each dense layer
lambda_loss = 0.04        # NOTE(review): loss weight; its use is not visible here
category = 'lg'

print('--- Hyper-parameters for training ---')
print(
    'learning_rate: {}\ncrop_size: {}\ntrain_batch_size: {}\nnetwork_height: {}\nnetwork_width: {}\n'
    'num_dense_layer: {}\ngrowth_rate: {}\nlambda_loss: {}\ncategory: {}'.format(
        learning_rate, crop_size, train_batch_size,
        network_height, network_width,
        num_dense_layer, growth_rate,
        lambda_loss, category,
    )
)

In [None]:

# Indices of all visible CUDA devices (an empty list on a CPU-only runtime).
device_ids = list(range(torch.cuda.device_count()))
print(device_ids)
# get_device_name raises when no CUDA device is present, so only query it
# when one is available; the CPU fallback below then actually gets a chance.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TensorBoard writer, constructed with its default log directory.
writer = SummaryWriter()
print(device)

In [None]:
def initialize_weights(m):
    """Initialise Conv2d / Linear layers in place; meant for ``module.apply``.

    Weights are drawn from N(0, 0.02) and biases (when present) are zeroed.
    Other module types are left untouched.

    No Xavier step is performed: a Xavier draw followed by ``normal_`` would
    be overwritten immediately, so it only wasted random numbers.

    :param m: module visited by ``nn.Module.apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # init.* functions run under no_grad, so they are safe on Parameters.
        init.normal_(m.weight, 0.0, 0.02)
        # Layers built with bias=False have ``bias is None``.
        if m.bias is not None:
            init.zeros_(m.bias)

# Phase 1


In [None]:

# --- Build dense --- #
class MakeDense(nn.Module):
    """One densely-connected layer: conv + ReLU, concatenated onto its input.

    The output has ``in_channels + growth_rate`` channels and, for odd
    ``kernel_size``, the same spatial size as the input.

    :param in_channels: channels of the incoming feature map.
    :param growth_rate: channels produced by the convolution.
    :param kernel_size: odd convolution kernel size.
    :param dilation: convolution dilation.
    """
    def __init__(self, in_channels, growth_rate, kernel_size=3, dilation = 1):
        super(MakeDense, self).__init__()
        # "Same" padding for an odd kernel of any size/dilation; equals
        # ``dilation`` for the default 3x3 kernel, as before.
        padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=kernel_size, padding=padding, dilation=dilation)

    def forward(self, x):
        out = F.relu(self.conv(x))
        # Dense connectivity: keep the input channels and append the new ones.
        return torch.cat((x, out), 1)

class AdaIn(nn.Module):
    """Per-channel affine modulation: ``std_style * x + mean_style``.

    Despite the name, no statistics of ``x`` are computed here; callers are
    expected to normalise ``x`` beforehand.  ``eps`` is stored but unused.
    """
    def __init__(self):
        super(AdaIn, self).__init__()
        self.eps = 1e-5

    def forward(self, x, mean_style, std_style):
        """Scale and shift every channel of ``x``.

        :param x: 4-D tensor (B, C, H, W); must be viewable (contiguous).
        :param mean_style: B*C shift values, reshaped to (B, C, 1).
        :param std_style: B*C scale values, reshaped to (B, C, 1).
        :returns: tensor with the same shape as ``x``.
        """
        batch, channels, height, width = x.shape
        flat = x.view(batch, channels, -1)
        scale = std_style.view(batch, channels, 1)
        shift = mean_style.view(batch, channels, 1)
        modulated = scale * (flat) + shift
        return modulated.view(batch, channels, height, width)

class CALayer(nn.Module):
    """Channel attention: rescale each channel by a learned gate in (0, 1).

    The gate comes from the per-channel spatial average, passed through a
    ``channel // 4`` bottleneck of 1x1 convolutions and a sigmoid.

    :param channel: number of input (and output) channels.
    """
    def __init__(self, channel):
        super(CALayer, self).__init__()
        hidden = channel // 4
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        pooled = self.avg_pool(x)   # (B, C, 1, 1)
        gate = self.ca(pooled)      # (B, C, 1, 1), broadcast over H and W
        return x * gate

class PALayer(nn.Module):
    """Pixel attention: rescale every spatial position by a learned gate in (0, 1).

    A ``channel // 4`` bottleneck of 1x1 convolutions produces one gate value
    per pixel, shared by all channels.

    :param channel: number of input (and output) channels.
    """
    def __init__(self, channel):
        super(PALayer, self).__init__()
        hidden = channel // 4
        self.pa = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.pa(x)   # (B, 1, H, W), broadcast over channels
        return x * gate

class ApplyNoise(nn.Module):
    """Add noise scaled by a learned per-channel weight.

    :param channels: number of channels; the weights start at zero, so the
        layer is initially an identity.
    """
    def __init__(self, channels):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(channels))

    def forward(self, x, noise):
        """Return ``x + weight[c] * noise``.

        :param x: feature map, indexed as (B, C, H, W).
        :param noise: tensor broadcastable to ``x``; when ``None`` a single
            Gaussian map of shape (B, 1, H, W) is drawn and shared by all channels.
        """
        if noise is None:
            batch, height, width = x.size(0), x.size(2), x.size(3)
            noise = torch.randn(batch, 1, height, width, device=x.device, dtype=x.dtype)
        scale = self.weight.view(1, -1, 1, 1)
        return x + scale * noise.to(x.device)

# --- Build the Residual Dense Block --- #
class RDB(nn.Module):
    """Residual dense block with a self-guided modulation branch.

    Two paths are computed from the input ``x`` (``in_channels`` channels):

    1. Residual path: a dense stack (``residual_dense_layers``) fused by a
       1x1 conv and added back to ``x``.
    2. Modulation path: ``x`` is instance-normalised, passed through a second
       dense stack with a residual connection, and instance-normalised again.
       It is then modulated in two ways, using values derived from the
       residual-path output:
         * per-channel scale/shift (``AdaIn``) from a globally pooled vector;
         * spatial ``gamma`` / ``beta`` maps: ``f * (1 + gamma) + beta``.
       The two results are blended by a one-channel sigmoid attention map.

    Both paths are concatenated, fused by a 1x1 conv, re-weighted by channel
    attention (``CALayer``) and added to ``x``, so the output shape equals the
    input shape.

    :param in_channels: input (and output) channel size
    :param num_dense_layer: the number of dense layers in each dense stack
    :param growth_rate: channels added by each dense layer
    :param dilations: per-layer dilations for the first dense stack only; it
        needs at least ``num_dense_layer`` entries.  Defaults to all ones.
    """
    def __init__(self, in_channels, num_dense_layer, growth_rate, dilations = None):
        super(RDB, self).__init__()
        # A fresh default sized to num_dense_layer; the previous fixed
        # [1, 1, 1, 1] default raised IndexError for more than 4 layers.
        if dilations is None:
            dilations = [1] * num_dense_layer

        # --- residual path: dense stack + 1x1 fusion --- #
        _in_channels = in_channels
        modules = []
        for i in range(num_dense_layer):
            modules.append(MakeDense(_in_channels, growth_rate, dilation = dilations[i]))
            _in_channels += growth_rate
        self.residual_dense_layers = nn.Sequential(*modules)

        self.conv_1x1_a = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0)

        # --- modulation path: second dense stack (dilation 1) + 1x1 fusion --- #
        _in_channels_no_style = in_channels
        no_style_modules = []
        for i in range(num_dense_layer):
            no_style_modules.append(MakeDense(_in_channels_no_style, growth_rate))
            _in_channels_no_style += growth_rate

        self.residual_dense_layers_no_style = nn.Sequential(*no_style_modules)
        self.conv_1x1_b = nn.Conv2d(_in_channels_no_style, in_channels, kernel_size=1, padding=0)

        self.norm = nn.InstanceNorm2d(in_channels)
        self.norm2 = nn.InstanceNorm2d(in_channels)
        self.adaIn = AdaIn()

        # Style vector: conv -> global average -> linear to 2 * in_channels
        # values (first half used as mean, second half as std).
        self.global_feat = nn.AdaptiveAvgPool2d((1, 1))
        self.style = nn.Linear(in_channels // 2, in_channels * 2)
        # NOTE: despite its name this is a 3x3 convolution.
        self.conv_1x1_style = nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1)

        # Spatial modulation maps.
        self.conv_gamma = nn.Conv2d(in_channels // 2, in_channels, kernel_size=3, padding=1)
        self.conv_beta = nn.Conv2d(in_channels // 2, in_channels, kernel_size=3, padding=1)

        # One-channel map that blends the two modulation results.
        self.conv_att = nn.Conv2d(in_channels, 1, kernel_size=3, padding=1)

        self.in_channels = in_channels

        self.conv_1x1_final = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1, padding=0)

        # ``coefficient`` and ``pool`` are not used by forward(); they are kept
        # so existing checkpoints (state_dict keys) still load unchanged.
        self.coefficient = nn.Parameter(torch.Tensor(np.ones((1, 2))), requires_grad=True)
        self.ca = CALayer(in_channels)
        self.pool = nn.AvgPool2d((7, 7), stride=(1, 1), padding=(3, 3))

    def forward(self, x):
        # Residual path.
        bottle_feat = self.residual_dense_layers(x)
        out = self.conv_1x1_a(bottle_feat)
        out = out + x

        # From the residual output, self-guided learning of the mean and std
        # (global, per channel) and of gamma and beta (spatial maps).
        style_feat_1 = F.relu(self.conv_1x1_style(out))
        style_feat = self.global_feat(style_feat_1)
        style_feat = torch.flatten(style_feat, start_dim = 1)
        style_feat = self.style(style_feat)
        style_mean = style_feat[:, :self.in_channels]
        style_std = style_feat[:, self.in_channels:]

        gamma = self.conv_gamma(style_feat_1)
        beta = self.conv_beta(style_feat_1)

        # Modulation path on the instance-normalised input.
        y = self.norm(x)
        out_no_style = self.residual_dense_layers_no_style(y)
        out_no_style = self.conv_1x1_b(out_no_style)
        out_no_style = y + out_no_style
        out_no_style = self.norm2(out_no_style)
        out_att = torch.sigmoid(self.conv_att(out_no_style))

        # Blend per-channel (AdaIn) and spatial (gamma/beta) modulation.
        out_new_style = self.adaIn(out_no_style, style_mean , style_std)
        out_new_gamma = out_no_style * (1 + gamma) + beta
        out_new = out_att * out_new_style + (1 - out_att) * out_new_gamma

        # Fuse both paths, apply channel attention, add the block input back.
        out = self.conv_1x1_final(torch.cat([out, out_new], dim = 1))
        out = self.ca(out)
        out = out + x
        return out

In [None]:
class Generate_quarter(nn.Module):
    """Grid of residual dense blocks followed by a dilated-context tail.

    The body is a ``height`` x ``width`` grid of RDBs.  Row ``i`` carries
    ``depth_rate * stride**i`` channels.  Columns in the left half also
    receive input from ``DownSample`` blocks (row above); columns in the right
    half from ``UpSample`` blocks (row below).  Where two edges meet, they are
    mixed with learned per-channel weights from ``self.coefficient``.

    The tail (``conv1`` ... ``rdb_1_1``) is built for ``height=3`` and
    ``stride=2``: it expects ``4 * depth_rate`` channels on the bottom row and
    upsamples twice, halving the channels each time.  ``DownSample`` /
    ``UpSample`` are also constructed with their default factor.

    :param in_channels: channels of the input and of the output image.
    :param depth_rate: channels of the first grid row.
    :param kernel_size: kernel size of ``conv_in`` / ``conv_out`` / tail convs.
    :param stride: channel multiplier between grid rows.
    :param height: number of grid rows.
    :param width: number of grid columns.
    :param num_dense_layer: dense layers per RDB.
    :param growth_rate: growth rate of each RDB.
    :param attention: whether the mixing coefficients are trainable.
    """
    def __init__(self, in_channels=3, depth_rate=16, kernel_size=3, stride=2, height=3, width=6, num_dense_layer=4, growth_rate=16, attention=True):
        super(Generate_quarter, self).__init__()
        self.rdb_module = nn.ModuleDict()
        self.upsample_module = nn.ModuleDict()
        self.downsample_module = nn.ModuleDict()
        self.height = height
        self.width = width
        self.stride = stride
        self.depth_rate = depth_rate
        # Mixing weights, shape (height, width, 2, channels of the widest row):
        # [..., 0, :] weights the edge from the left, [..., 1, :] the edge from
        # above/below.  Column 0 is not used by the grid and is reused by the tail.
        self.coefficient = nn.Parameter(torch.Tensor(np.ones((height, width, 2, depth_rate*stride**(height-1)))), requires_grad=attention)
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.rdb_in = RDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb_out = RDB(depth_rate, num_dense_layer, growth_rate)

        # Horizontal edges: one RDB per (row, column) step.
        rdb_in_channels = depth_rate
        for i in range(height):
            for j in range(width - 1):
                self.rdb_module.update({'{}_{}'.format(i, j): RDB(rdb_in_channels, num_dense_layer, growth_rate)})
            rdb_in_channels *= stride

        # Vertical edges: DownSample in the left half ...
        _in_channels = depth_rate
        for i in range(height - 1):
            for j in range(width // 2):
                self.downsample_module.update({'{}_{}'.format(i, j): DownSample(_in_channels)})
            _in_channels *= stride

        # ... and UpSample in the right half.
        for i in range(height - 2, -1, -1):
            for j in range(width // 2, width):
                self.upsample_module.update({'{}_{}'.format(i, j): UpSample(_in_channels)})
            _in_channels //= stride

        # Tail: dilated convolutions on the bottom row (4 * depth_rate channels).
        self.conv1 = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv2_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=2, dilation=2)
        self.conv3_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=4, dilation=4)
        self.conv4_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=8, dilation=8)
        self.conv5_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=16, dilation=16)
        self.conv6 = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        # offset_conv1/2 fed a deformable convolution (DCN) that is disabled;
        # forward() no longer evaluates them.  Kept for checkpoint compatibility.
        self.offset_conv1 = nn.Conv2d(depth_rate*4, depth_rate*2, 3, 1, 1, bias=True)
        self.offset_conv2 = nn.Conv2d(depth_rate*2, depth_rate*4, 3, 1, 1, bias=True)
        self.upsamle1 = UpSample(depth_rate*4)
        self.upsamle2 = UpSample(depth_rate*2)

        self.rdb_2_1 = RDB(depth_rate * 2, num_dense_layer, growth_rate)
        self.rdb_1_1 = RDB(depth_rate, num_dense_layer, growth_rate)

    def forward(self, x):
        """Run the grid and the tail.

        :param x: input batch with ``in_channels`` channels; presumably
            (B, C, H, W) — TODO confirm against the data loader.
        :returns: ``(out, feat_extra)``.  ``out`` is the ReLU'd
            ``in_channels`` prediction.  ``feat_extra`` is the fused
            ``depth_rate``-channel feature that is fed to ``rdb_1_1``.
        """
        inp = self.conv_in(x)

        # x_index[i][j] holds the feature map at grid row i, column j.
        x_index = [[0 for _ in range(self.width)] for _ in range(self.height)]
        i, j = 0, 0

        x_index[0][0] = self.rdb_in(inp)

        # Top row, left half: plain RDB chain.
        for j in range(1, self.width // 2):
            x_index[0][j] = self.rdb_module['{}_{}'.format(0, j-1)](x_index[0][j-1])

        # First column: successive downsampling.
        for i in range(1, self.height):
            x_index[i][0] = self.downsample_module['{}_{}'.format(i-1, 0)](x_index[i-1][0])

        # Remaining left half: mix the RDB from the left with the DownSample from above.
        for i in range(1, self.height):
            for j in range(1, self.width // 2):
                channel_num = self.depth_rate * self.stride ** i
                x_index[i][j] = self.coefficient[i, j, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1]) + \
                                self.coefficient[i, j, 1, :channel_num][None, :, None, None] * self.downsample_module['{}_{}'.format(i-1, j)](x_index[i-1][j])

        # The loop variables leak on purpose: here i == height - 1 and
        # j == width // 2 - 1 (the last column of the left half).
        x_index[i][j+1] = self.rdb_module['{}_{}'.format(i, j)](x_index[i][j])
        k = j

        # Bottom row, right half: plain RDB chain.
        for j in range(self.width // 2 + 1, self.width):
            x_index[i][j] = self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1])

        # Middle column (k + 1), bottom-up: mix the RDB from the left with the UpSample from below.
        for i in range(self.height - 2, -1, -1):
            channel_num = self.depth_rate * self.stride ** i
            x_index[i][k+1] = self.coefficient[i, k+1, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, k)](x_index[i][k]) + \
                              self.coefficient[i, k+1, 1, :channel_num][None, :, None, None] * self.upsample_module['{}_{}'.format(i, k+1)](x_index[i+1][k+1], x_index[i][k].size())

        # Remaining right half, bottom-up.
        for i in range(self.height - 2, -1, -1):
            for j in range(self.width // 2 + 1, self.width):
                channel_num = self.depth_rate * self.stride ** i
                x_index[i][j] = self.coefficient[i, j, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1]) + \
                                self.coefficient[i, j, 1, :channel_num][None, :, None, None] * self.upsample_module['{}_{}'.format(i, j)](x_index[i+1][j], x_index[i][j-1].size())

        # With the default sizes, i == 0 and j == width - 1 here: top-right node.
        out = self.rdb_out(x_index[i][j])

        # Tail: dilated context on the bottom-right node.
        feat_extra = F.relu(self.conv1(x_index[-1][j]))
        feat_extra = F.relu(self.conv2_atrous(feat_extra))
        feat_extra = F.relu(self.conv3_atrous(feat_extra))
        feat_extra = F.relu(self.conv4_atrous(feat_extra))
        feat_extra = F.relu(self.conv5_atrous(feat_extra))
        feat_extra = F.relu(self.conv6(feat_extra))

        # Channel widths of grid rows 0 and 1 (tail assumes height=3, stride=2);
        # previously hard-coded as 16 / 32, which only matched depth_rate=16.
        c_top = self.depth_rate
        c_mid = 2 * self.depth_rate
        feat_extra = self.upsamle1(feat_extra, x_index[-2][j].size())
        feat_extra = self.coefficient[-2, 0, 0, :c_mid][None, :, None, None] * x_index[-2][j] + self.coefficient[-2, 0, 0, c_mid:2 * c_mid][None, :, None, None] * feat_extra
        feat_extra = self.rdb_2_1(feat_extra)
        feat_extra = self.upsamle2(feat_extra, x_index[0][j].size())
        feat_extra = self.coefficient[0, 0, 0, :c_top][None, :, None, None] * out + self.coefficient[0, 0, 0, c_top:2 * c_top][None, :, None, None] * feat_extra
        out = self.rdb_1_1(feat_extra)
        out = F.relu(self.conv_out(out))
        return out, feat_extra

class Generate_quarter_refine(nn.Module):
    """Grid of residual dense blocks plus a dilated-context tail (refinement variant).

    Same architecture as ``Generate_quarter``; ``forward`` additionally
    returns the pre-``conv_out`` feature.

    The body is a ``height`` x ``width`` grid of RDBs.  Row ``i`` carries
    ``depth_rate * stride**i`` channels.  Columns in the left half also
    receive input from ``DownSample`` blocks (row above); columns in the right
    half from ``UpSample`` blocks (row below).  Where two edges meet, they are
    mixed with learned per-channel weights from ``self.coefficient``.

    The tail (``conv1`` ... ``rdb_1_1``) is built for ``height=3`` and
    ``stride=2``: it expects ``4 * depth_rate`` channels on the bottom row and
    upsamples twice, halving the channels each time.

    :param in_channels: channels of the input and of the output image.
    :param depth_rate: channels of the first grid row.
    :param kernel_size: kernel size of ``conv_in`` / ``conv_out`` / tail convs.
    :param stride: channel multiplier between grid rows.
    :param height: number of grid rows.
    :param width: number of grid columns.
    :param num_dense_layer: dense layers per RDB.
    :param growth_rate: growth rate of each RDB.
    :param attention: whether the mixing coefficients are trainable.
    """
    def __init__(self, in_channels=3, depth_rate=16, kernel_size=3, stride=2, height=3, width=6, num_dense_layer=4, growth_rate=16, attention=True):
        super(Generate_quarter_refine, self).__init__()
        self.rdb_module = nn.ModuleDict()
        self.upsample_module = nn.ModuleDict()
        self.downsample_module = nn.ModuleDict()
        self.height = height
        self.width = width
        self.stride = stride
        self.depth_rate = depth_rate
        # Mixing weights, shape (height, width, 2, channels of the widest row):
        # [..., 0, :] weights the edge from the left, [..., 1, :] the edge from
        # above/below.  Column 0 is not used by the grid and is reused by the tail.
        self.coefficient = nn.Parameter(torch.Tensor(np.ones((height, width, 2, depth_rate*stride**(height-1)))), requires_grad=attention)
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.rdb_in = RDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb_out = RDB(depth_rate, num_dense_layer, growth_rate)

        # Horizontal edges: one RDB per (row, column) step.
        rdb_in_channels = depth_rate
        for i in range(height):
            for j in range(width - 1):
                self.rdb_module.update({'{}_{}'.format(i, j): RDB(rdb_in_channels, num_dense_layer, growth_rate)})
            rdb_in_channels *= stride

        # Vertical edges: DownSample in the left half ...
        _in_channels = depth_rate
        for i in range(height - 1):
            for j in range(width // 2):
                self.downsample_module.update({'{}_{}'.format(i, j): DownSample(_in_channels)})
            _in_channels *= stride

        # ... and UpSample in the right half.
        for i in range(height - 2, -1, -1):
            for j in range(width // 2, width):
                self.upsample_module.update({'{}_{}'.format(i, j): UpSample(_in_channels)})
            _in_channels //= stride

        # Tail: dilated convolutions on the bottom row (4 * depth_rate channels).
        self.conv1 = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv2_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=2, dilation=2)
        self.conv3_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=4, dilation=4)
        self.conv4_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=8, dilation=8)
        self.conv5_atrous = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=16, dilation=16)
        self.conv6 = nn.Conv2d(depth_rate * 4, depth_rate * 4, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        # offset_conv1/2 fed a deformable convolution (DCN) that is disabled;
        # forward() no longer evaluates them.  Kept for checkpoint compatibility.
        self.offset_conv1 = nn.Conv2d(depth_rate*4, depth_rate*2, 3, 1, 1, bias=True)
        self.offset_conv2 = nn.Conv2d(depth_rate*2, depth_rate*4, 3, 1, 1, bias=True)
        self.upsamle1 = UpSample(depth_rate*4)
        self.upsamle2 = UpSample(depth_rate*2)

        self.rdb_2_1 = RDB(depth_rate * 2, num_dense_layer, growth_rate)
        self.rdb_1_1 = RDB(depth_rate, num_dense_layer, growth_rate)

    def forward(self, x):
        """Run the grid and the tail.

        :param x: input batch with ``in_channels`` channels; presumably
            (B, C, H, W) — TODO confirm against the caller.
        :returns: ``(out, feat, feat_extra)``:
            - ``out``: the ReLU'd ``in_channels`` prediction;
            - ``feat``: the ``rdb_1_1`` output that feeds ``conv_out``;
            - ``feat_extra``: the fused ``depth_rate``-channel input of ``rdb_1_1``.
        """
        inp = self.conv_in(x)

        # x_index[i][j] holds the feature map at grid row i, column j.
        x_index = [[0 for _ in range(self.width)] for _ in range(self.height)]
        i, j = 0, 0

        x_index[0][0] = self.rdb_in(inp)

        # Top row, left half: plain RDB chain.
        for j in range(1, self.width // 2):
            x_index[0][j] = self.rdb_module['{}_{}'.format(0, j-1)](x_index[0][j-1])

        # First column: successive downsampling.
        for i in range(1, self.height):
            x_index[i][0] = self.downsample_module['{}_{}'.format(i-1, 0)](x_index[i-1][0])

        # Remaining left half: mix the RDB from the left with the DownSample from above.
        for i in range(1, self.height):
            for j in range(1, self.width // 2):
                channel_num = self.depth_rate * self.stride ** i
                x_index[i][j] = self.coefficient[i, j, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1]) + \
                                self.coefficient[i, j, 1, :channel_num][None, :, None, None] * self.downsample_module['{}_{}'.format(i-1, j)](x_index[i-1][j])

        # The loop variables leak on purpose: here i == height - 1 and
        # j == width // 2 - 1 (the last column of the left half).
        x_index[i][j+1] = self.rdb_module['{}_{}'.format(i, j)](x_index[i][j])
        k = j

        # Bottom row, right half: plain RDB chain.
        for j in range(self.width // 2 + 1, self.width):
            x_index[i][j] = self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1])

        # Middle column (k + 1), bottom-up: mix the RDB from the left with the UpSample from below.
        for i in range(self.height - 2, -1, -1):
            channel_num = self.depth_rate * self.stride ** i
            x_index[i][k+1] = self.coefficient[i, k+1, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, k)](x_index[i][k]) + \
                              self.coefficient[i, k+1, 1, :channel_num][None, :, None, None] * self.upsample_module['{}_{}'.format(i, k+1)](x_index[i+1][k+1], x_index[i][k].size())

        # Remaining right half, bottom-up.
        for i in range(self.height - 2, -1, -1):
            for j in range(self.width // 2 + 1, self.width):
                channel_num = self.depth_rate * self.stride ** i
                x_index[i][j] = self.coefficient[i, j, 0, :channel_num][None, :, None, None] * self.rdb_module['{}_{}'.format(i, j-1)](x_index[i][j-1]) + \
                                self.coefficient[i, j, 1, :channel_num][None, :, None, None] * self.upsample_module['{}_{}'.format(i, j)](x_index[i+1][j], x_index[i][j-1].size())

        # With the default sizes, i == 0 and j == width - 1 here: top-right node.
        out = self.rdb_out(x_index[i][j])

        # Tail: dilated context on the bottom-right node.
        feat_extra = F.relu(self.conv1(x_index[-1][j]))
        feat_extra = F.relu(self.conv2_atrous(feat_extra))
        feat_extra = F.relu(self.conv3_atrous(feat_extra))
        feat_extra = F.relu(self.conv4_atrous(feat_extra))
        feat_extra = F.relu(self.conv5_atrous(feat_extra))
        feat_extra = F.relu(self.conv6(feat_extra))

        # Channel widths of grid rows 0 and 1 (tail assumes height=3, stride=2);
        # previously hard-coded as 16 / 32, which only matched depth_rate=16.
        c_top = self.depth_rate
        c_mid = 2 * self.depth_rate
        feat_extra = self.upsamle1(feat_extra, x_index[-2][j].size())
        feat_extra = self.coefficient[-2, 0, 0, :c_mid][None, :, None, None] * x_index[-2][j] + self.coefficient[-2, 0, 0, c_mid:2 * c_mid][None, :, None, None] * feat_extra
        feat_extra = self.rdb_2_1(feat_extra)
        feat_extra = self.upsamle2(feat_extra, x_index[0][j].size())
        feat_extra = self.coefficient[0, 0, 0, :c_top][None, :, None, None] * out + self.coefficient[0, 0, 0, c_top:2 * c_top][None, :, None, None] * feat_extra
        out = self.rdb_1_1(feat_extra)
        feat = out
        out = F.relu(self.conv_out(out))
        return out, feat, feat_extra

class Generate(nn.Module):
    """Three-level encoder/decoder built from RDBs, with an injected feature.

    ``x1`` and ``x2`` are encoded jointly and downsampled twice by stride-2
    convolutions.  An external feature map ``feat`` is added at the lowest
    level.  The decoder upsamples back to ``x1``'s size, and the result is
    added to ``x2`` (global residual) before a final ReLU.

    ``stride``, ``height``, ``width`` and ``attention`` are stored or accepted
    for interface compatibility; the layers below use a fixed factor of 2.
    """
    def __init__(self, in_channels=3, depth_rate=16, kernel_size=3, stride=2, height=3, width=6, num_dense_layer=4, growth_rate=16, attention=True):
        super(Generate, self).__init__()
        self.height = height
        self.width = width
        self.stride = stride
        self.depth_rate = depth_rate

        pad = (kernel_size - 1) // 2
        ch1, ch2, ch4 = depth_rate, depth_rate * 2, depth_rate * 4

        # Encoder (level 1 -> level 3).
        self.conv_in_1 = nn.Conv2d(in_channels, ch1, kernel_size=kernel_size, padding=pad)
        self.conv_in_2 = nn.Conv2d(in_channels, ch1, kernel_size=kernel_size, padding=pad)
        self.conv_1_downsample = nn.Conv2d(ch2, ch2, kernel_size=kernel_size, padding=pad, stride = 2)

        self.conv_2 = nn.Conv2d(ch2, ch2, kernel_size=kernel_size, padding=pad)
        self.conv_2_downsample = nn.Conv2d(ch2, ch4, kernel_size=kernel_size, padding=pad, stride = 2)

        self.conv_3 = nn.Conv2d(ch4, ch4, kernel_size=kernel_size, padding=pad)
        self.rdb_3_1 = RDB(ch4, num_dense_layer, growth_rate)
        self.rdb_3_2 = RDB(ch4, num_dense_layer, growth_rate)

        # Projects the external feature (depth_rate channels) to level-3 width.
        self.feat_pass = nn.Conv2d(ch1, ch4, kernel_size=kernel_size, padding=pad)

        self.rdb_3_3 = RDB(ch4, num_dense_layer, growth_rate)
        self.rdb_3_4 = RDB(ch4, num_dense_layer, growth_rate)
        self.rdb_3_5 = RDB(ch4, num_dense_layer, growth_rate)
        self.rdb_3_6 = RDB(ch4, num_dense_layer, growth_rate)

        self.conv_out = nn.Conv2d(ch1, in_channels, kernel_size=kernel_size, padding=pad)

        # Decoder (level 3 -> level 1).
        self.upsample_L3 = UpSample(ch4)

        self.rdb_2_1 = RDB(ch2, num_dense_layer, growth_rate)
        self.rdb_2_2 = RDB(ch2, num_dense_layer, growth_rate)
        self.rdb_2_3 = RDB(ch2, num_dense_layer, growth_rate)
        self.rdb_2_4 = RDB(ch2, num_dense_layer, growth_rate)

        self.upsample_L2 = UpSample(ch2)

        self.rdb_1_1 = RDB(ch1, num_dense_layer, growth_rate)
        self.rdb_1_2 = RDB(ch1, num_dense_layer, growth_rate)


    def forward(self, x1, x2, feat):
        """Encode ``x1``/``x2``, inject ``feat``, decode, add ``x2``.

        :param x1: first input image; its size is the output size.
        :param x2: second input image, also used as the global residual.
        :param feat: external feature with ``depth_rate`` channels; it must
            match the level-3 spatial size — TODO confirm against the caller.
        """
        # Encoder.
        fused = torch.cat([F.relu(self.conv_in_1(x1)), F.relu(self.conv_in_2(x2))], 1)
        level2 = F.relu(self.conv_2(F.relu(self.conv_1_downsample(fused))))
        level3 = F.relu(self.conv_3(F.relu(self.conv_2_downsample(level2))))
        for block in (self.rdb_3_1, self.rdb_3_2):
            level3 = block(level3)

        # Inject the external feature at the lowest level.
        level3 = level3 + self.feat_pass(feat)
        for block in (self.rdb_3_3, self.rdb_3_4, self.rdb_3_5, self.rdb_3_6):
            level3 = block(level3)

        # Decoder.
        up2 = self.upsample_L3(level3, level2.size())
        for block in (self.rdb_2_1, self.rdb_2_2, self.rdb_2_3, self.rdb_2_4):
            up2 = block(up2)
        up1 = self.upsample_L2(up2, x1.size())
        for block in (self.rdb_1_1, self.rdb_1_2):
            up1 = block(up1)

        # Global residual on x2.
        return F.relu(self.conv_out(up1) + x2)

class LossD(nn.Module):
    """Hinge loss for the discriminator.

    Computes ``mean(relu(1 + r_x_hat) + relu(1 - r_x))`` and returns it as a
    1-element tensor.
    """
    def __init__(self):
        super(LossD, self).__init__()

    def forward(self, r_x, r_x_hat):
        """:param r_x: scores for real samples.
        :param r_x_hat: scores for generated samples."""
        fake_term = F.relu(1 + r_x_hat)
        real_term = F.relu(1 - r_x)
        return (fake_term + real_term).mean().reshape(1)

class LossFeat(nn.Module):
    """Feature-matching loss: mean over layers of the MSE between paired features.

    ``zip`` stops at the shorter list; empty inputs raise ZeroDivisionError.
    """
    def __init__(self):
        super(LossFeat, self).__init__()

    def forward(self, feats1, feats2):
        per_layer = [F.mse_loss(a, b) for a, b in zip(feats1, feats2)]
        return sum(per_layer) / len(per_layer)

class Lap(nn.Module):
    """Laplacian-edge loss between a prediction and its target.

    Each channel is filtered independently with the 4-neighbour kernel
    ``[[0, 1, 0], [1, -4, 1], [0, 1, 0]]`` (zero padding 1).  The MSE of the
    filtered maps is computed per sample and averaged over the batch.

    :param channels: number of image channels (depthwise filtering).
    """
    def __init__(self, channels=3):
        super(Lap, self).__init__()
        self.channels = channels
        kernel = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]])
        kernel = kernel.expand(int(channels), 1, 3, 3).contiguous()   # (C, 1, 3, 3)
        # A non-persistent buffer follows .to(device) and stays out of
        # state_dict.  The previous ``nn.Parameter(...).cuda()`` produced an
        # unregistered plain tensor and crashed on machines without CUDA.
        self.register_buffer('weight', kernel, persistent=False)

    def forward(self, dehaze, gt):
        """Return the batch-averaged MSE between the Laplacians of ``dehaze`` and ``gt``.

        :param dehaze: predicted images, (B, channels, H, W).
        :param gt: target images, same shape as ``dehaze``.
        """
        # Move the kernel to the inputs so the loss works even if this module
        # itself was never moved to the inputs' device.
        weight = self.weight.to(device=dehaze.device, dtype=dehaze.dtype)
        dehaze = F.conv2d(dehaze, weight, padding=1, groups=self.channels)
        gt = F.conv2d(gt, weight, padding=1, groups=self.channels)
        loss = [F.mse_loss(d, g) for d, g in zip(dehaze, gt)]
        return sum(loss) / len(loss)

In [None]:
class Self_Attn(nn.Module):
    """Self-attention over all spatial positions, added residually to the input.

    Query/key projections use ``in_dim // 8`` channels and the value
    projection keeps ``in_dim``.  The attended result is scaled by a learned
    ``gamma`` initialised to 0, so the layer starts as an identity.
    ``activation`` is stored but not used.
    """
    def __init__(self,in_dim,activation,with_attn=False):
        super(Self_Attn,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        self.with_attn = with_attn
        reduced = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self,x):
        """
        :param x: feature map (B, C, W, H).
        :returns: ``gamma * attended + x``; with ``with_attn`` also the
            (B, N, N) attention matrix, N = W * H.
        """
        batch, channels, dim_a, dim_b = x.size()
        positions = dim_a * dim_b
        query = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)  # B x N x C'
        key = self.key_conv(x).view(batch, -1, positions)                       # B x C' x N
        attention = self.softmax(torch.bmm(query, key))                         # B x N x N
        value = self.value_conv(x).view(batch, -1, positions)                   # B x C x N

        attended = torch.bmm(value, attention.permute(0, 2, 1)).view(batch, channels, dim_a, dim_b)
        out = self.gamma * attended + x
        if not self.with_attn:
            return out
        return out, attention

class Discriminator(nn.Module):
    """Spectral-normalised convolutional discriminator with one self-attention layer.

    Six 5x5 / stride-2 / padding-2 conv stages, each followed by LeakyReLU(0.2).
    ``Self_Attn`` sits between the fifth and sixth stage.

    :param input_nc: channels of the input image.
    :param ndf: base channel width.
    """
    def __init__(self, input_nc, ndf=64):
        super(Discriminator, self).__init__()

        def stage(cin, cout):
            # Each stage maps spatial size H -> ceil(H / 2).
            return nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(cin, cout, 5, 2, 2)), nn.LeakyReLU(0.2, True))

        self.layer1_new = stage(input_nc, ndf)      # 1/2
        self.layer2 = stage(ndf, ndf * 2)           # 1/4
        self.layer3 = stage(ndf * 2, ndf * 4)       # 1/8
        self.layer4 = stage(ndf * 4, ndf * 4)       # 1/16
        self.layer5 = stage(ndf * 4, ndf * 4)       # 1/32
        self.att = Self_Attn(ndf * 4, 'relu')
        self.layer6 = stage(ndf * 4, ndf * 4)       # 1/64

    def forward(self, input):
        """Return ``(scores flattened per sample, [output of each of the six stages])``."""
        feats = []
        out = input
        for layer in (self.layer1_new, self.layer2, self.layer3, self.layer4, self.layer5):
            out = layer(out)
            feats.append(out)
        out = self.layer6(self.att(out))
        feats.append(out)
        return out.view(out.size(0), -1), feats

# --- Downsampling block in GridDehazeNet  --- #
class DownSample(nn.Module):
    """Strided conv to reduce spatial size, then a conv that multiplies channels by ``stride``.

    :param in_channels: input channels; output has ``stride * in_channels``.
    :param kernel_size: kernel size of both convolutions.
    :param stride: stride of the first convolution and channel multiplier.
    """
    def __init__(self, in_channels, kernel_size=3, stride=2):
        super(DownSample, self).__init__()
        same_pad = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=same_pad)
        self.conv2 = nn.Conv2d(in_channels, stride * in_channels, kernel_size, stride=1, padding=same_pad)

    def forward(self, x):
        return F.relu(self.conv2(F.relu(self.conv1(x))))


# --- Upsampling block in GridDehazeNet  --- #
class UpSample(nn.Module):
    """Transposed conv (stride ``stride``), then a conv dividing channels by ``stride``.

    :param in_channels: input channels; output has ``in_channels // stride``.
    :param kernel_size: kernel size of both convolutions.
    :param stride: stride of the transposed convolution and channel divisor.
    """
    def __init__(self, in_channels, kernel_size=3, stride=2):
        super(UpSample, self).__init__()
        self.deconv = nn.ConvTranspose2d(in_channels, in_channels, kernel_size, stride=stride, padding=1)
        reduced = in_channels // stride
        self.conv = nn.Conv2d(in_channels, reduced, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x, output_size):
        """:param output_size: target size passed to the transposed conv to pin the exact H, W."""
        upsampled = F.relu(self.deconv(x, output_size=output_size))
        return F.relu(self.conv(upsampled))


In [None]:
# --- Perceptual loss network  --- #
class LossNetworkF(torch.nn.Module):
    """Frequency-domain perceptual loss on tapped VGG features.

    For each tapped layer, the 2-D power spectrum ``re**2 + im**2`` of
    ``fft2(feature)`` is computed for prediction and target.  The two spectra
    are compared with MSE, and the per-layer losses are averaged.

    The previous implementation could not run: it imported a non-existent
    module ``ff``, used the removed ``torch.fft(x, 2)`` function, and called
    ``torch.from_numpy`` on a tensor.

    :param vgg_model: module whose children are named '0', '1', ...; outputs of
        children '3', '8' and '15' are tapped (e.g. ``vgg16().features[:16]``).
    """
    def __init__(self, vgg_model):
        super(LossNetworkF, self).__init__()
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        """Run ``x`` through every child and return the tapped outputs in order."""
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, dehaze, gt):
        loss = []
        dehaze_features = self.output_features(dehaze)
        gt_features = self.output_features(gt)
        for dehaze_feature, gt_feature in zip(dehaze_features, gt_features):
            # fft2 transforms the last two (spatial) dims; re**2 + im**2 keeps
            # gradients well-defined where the spectrum is zero.
            f1 = torch.fft.fft2(dehaze_feature)
            f2 = torch.fft.fft2(gt_feature)
            power1 = f1.real ** 2 + f1.imag ** 2
            power2 = f2.real ** 2 + f2.imag ** 2
            loss.append(F.mse_loss(power1, power2))

        return sum(loss) / len(loss)

class LossNetwork(torch.nn.Module):
    """Perceptual loss: mean over tapped layers of the MSE between features.

    :param vgg_model: module whose children are named '0', '1', ...; outputs of
        children '3', '8' and '15' are tapped (e.g. ``vgg16().features[:16]``).
    """
    def __init__(self, vgg_model):
        super(LossNetwork, self).__init__()
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {'3': "relu1_2", '8': "relu2_2", '15': "relu3_3"}

    def output_features(self, x):
        """Run ``x`` through every child and return the tapped outputs in order."""
        tapped = {}
        for layer_name, layer in self.vgg_layers._modules.items():
            x = layer(x)
            if layer_name in self.layer_name_mapping:
                tapped[self.layer_name_mapping[layer_name]] = x
        return list(tapped.values())

    def forward(self, dehaze, gt):
        pairs = zip(self.output_features(dehaze), self.output_features(gt))
        per_layer = [F.mse_loss(d, g) for d, g in pairs]
        return sum(per_layer) / len(per_layer)

class LossNetworkL1(torch.nn.Module):
    """Perceptual loss: average L1 distance between selected activations of ``vgg_model``.

    Same layer taps as ``LossNetwork`` (children '3', '8', '15', labelled
    relu1_2 / relu2_2 / relu3_3), but each layer is compared with ``nn.L1Loss``.
    """

    def __init__(self, vgg_model):
        super(LossNetworkL1, self).__init__()
        self.vgg_layers = vgg_model
        # child name inside ``vgg_layers`` -> label of the tapped activation
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }
        self.loss_rec = nn.L1Loss()

    def output_features(self, x):
        """Feed ``x`` through every child and return the tapped activations in order."""
        tapped = []
        for child_name, child in self.vgg_layers._modules.items():
            x = child(x)
            if child_name in self.layer_name_mapping:
                tapped.append(x)
        return tapped

    def forward(self, dehaze, gt):
        """Mean over tapped layers of the L1 distance between the two activations."""
        pairs = zip(self.output_features(dehaze), self.output_features(gt))
        per_layer = [self.loss_rec(pred_feat, ref_feat) for pred_feat, ref_feat in pairs]
        return sum(per_layer) / len(per_layer)


In [None]:
# Build the phase-1 generator, initialise its weights and create its Adam optimiser.
net = Generate_quarter(height=network_height, width=network_width, num_dense_layer=num_dense_layer, growth_rate=growth_rate)
net.apply(initialize_weights)  # applied to every submodule
print('First Phrase Init!')
# NOTE(review): the optimiser is built before .to(device); this relies on Module.to()
# moving parameters in place so the optimiser still holds the same Parameter objects.
optimizer = torch.optim.Adam(list(net.parameters()), lr=learning_rate, betas=(0.5, 0.999))
net = net.to(device)
# Checkpoints saved later from this wrapper carry DataParallel's key layout.
net = nn.DataParallel(net, device_ids=device_ids)
pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Total_params: {}".format(pytorch_total_params))

In [None]:
# --- Define the perceptual loss network --- #
# First 16 modules (indices 0-15) of a pretrained VGG16 feature extractor.
vgg_model = vgg16(pretrained=True).features[:16]
vgg_model = vgg_model.to(device)
# Freeze VGG: it only supplies features to LossNetwork below and is never trained.
for param in vgg_model.parameters():
    param.requires_grad = False
loss_network = LossNetwork(vgg_model)
loss_network.eval()
loss_lap = Lap()  # Lap is defined earlier in the notebook (not visible in this section)

In [None]:
start_epoch = 1  # first epoch index of the training loop: range(start_epoch, num_epochs)


In [None]:
# --- Training dataset --- #
class TrainData(data.Dataset):
    """Paired (hazy, clear) training crops from the first rows of ``train.csv``.

    Every selected image is decoded once in ``__init__`` and kept in memory.
    ``__getitem__`` then applies the same random transforms to both images, in
    this order:

    - a rescale by a factor in {0.8, ..., 1.5};
    - a random aligned crop;
    - an optional horizontal flip;
    - a rotation by a random multiple of 90 degrees.

    Args:
        crop_size: ``[width, height]`` of the random crop.
        data_root: directory holding ``train.csv``, ``train_input_img/`` and
            ``train_label_img/``.
        num_train: number of leading CSV rows used for training. The remaining
            rows are left for ``TestData``.
    """

    def __init__(self, crop_size, data_root='/content/drive/MyDrive/lg_vision/data', num_train=480):
        super().__init__()
        train_csv = pd.read_csv(data_root + '/train.csv')
        haze_names = list(data_root + '/train_input_img/' + train_csv['input_img'])[:num_train]
        gt_names = list(data_root + '/train_label_img/' + train_csv['label_img'])[:num_train]

        self.haze_names = haze_names
        self.gt_names = gt_names
        self.crop_size = crop_size

        # Decode each distinct file once; repeated paths share one cached image.
        self.haze_cache = {}
        self.gt_cache = {}
        for name in haze_names:
            if name not in self.haze_cache:
                self.haze_cache[name] = self._load_rgb(name)
        for name in gt_names:
            if name not in self.gt_cache:
                self.gt_cache[name] = self._load_rgb(name)

        print ('use cache')

    @staticmethod
    def _load_rgb(path):
        """Decode ``path`` into an RGB image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def generate_scale_label(self, haze, gt):
        """Resize both images by the same random factor 0.8 + k/10, k in 0..7 (bicubic)."""
        f_scale = 0.8 + random.randint(0, 7) / 10.0
        width, height = haze.size
        new_size = (int(width * f_scale), int(height * f_scale))
        haze = haze.resize(new_size, resample=Image.BICUBIC)
        gt = gt.resize(new_size, resample=Image.BICUBIC)
        return haze, gt

    def get_images(self, index):
        """Return ``(haze, gt, haze_gt)`` tensors for one randomly augmented crop.

        ``haze`` is normalised with mean/std 0.5 per channel; ``gt`` and
        ``haze_gt`` come from ``ToTensor`` only and hold identical values.

        Raises:
            Exception: if the rescaled image is smaller than the crop, or a
                tensor does not have 3 channels.
        """
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        haze_img, gt_img = self.generate_scale_label(self.haze_cache[haze_name], self.gt_cache[gt_name])
        width, height = haze_img.size
        if width < crop_width or height < crop_height:
            raise Exception('Bad image size: {}'.format(gt_name))

        # --- x,y coordinate of left-top corner --- #
        x, y = randrange(0, width - crop_width + 1), randrange(0, height - crop_height + 1)
        box = (x, y, x + crop_width, y + crop_height)
        haze_crop_img = haze_img.crop(box)
        gt_crop_img = gt_img.crop(box)

        # Identical flip/rotation for both images keeps the pair pixel-aligned.
        rand_hor = random.randint(0, 1)
        rand_rot = random.randint(0, 3)
        if rand_hor:
            haze_crop_img = FF.hflip(haze_crop_img)
            gt_crop_img = FF.hflip(gt_crop_img)
        if rand_rot:
            haze_crop_img = FF.rotate(haze_crop_img, 90 * rand_rot)
            gt_crop_img = FF.rotate(gt_crop_img, 90 * rand_rot)

        # --- Transform to tensor --- #
        transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        haze = transform_haze(haze_crop_img)
        gt = transform_gt(gt_crop_img)
        haze_gt = gt.clone()

        # --- Check the channel is 3 or not --- #
        if haze.shape[0] != 3 or gt.shape[0] != 3:
            raise Exception('Bad image channel: {}'.format(gt_name))

        return haze, gt, haze_gt

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.haze_names)


In [None]:
# Optional: reload a previously pickled TrainData instead of decoding every image again.
# NOTE: only unpickle files you created yourself; pickle.load can execute arbitrary code.
# import pickle
# file = open("/content/drive/MyDrive/lg_vision/train_data_",'rb')
# train_data = pickle.load(file)  

train_data = TrainData(crop_size)

In [None]:
# import pickle
# filehandler = open("/content/drive/MyDrive/lg_vision/train_data_480","wb")
# pickle.dump(train_data,filehandler)

In [None]:

# --- Validation/test dataset --- #
class TestData(data.Dataset):
    """Full-resolution validation pairs: the ``train.csv`` rows after the first ``num_train``.

    Items are ``(haze, gt, haze_name)``. ``haze`` is normalised with mean/std
    0.5 per channel, ``gt`` is the plain ``ToTensor`` output, and ``haze_name``
    is the full input path.

    Args:
        data_root: directory holding ``train.csv``, ``train_input_img/`` and
            ``train_label_img/``.
        num_train: number of leading rows reserved for ``TrainData``.
    """

    def __init__(self, data_root='/content/drive/MyDrive/lg_vision/data', num_train=480):
        super().__init__()
        train_csv = pd.read_csv(data_root + '/train.csv')
        self.haze_names = list(data_root + '/train_input_img/' + train_csv['input_img'])[num_train:]
        self.gt_names = list(data_root + '/train_label_img/' + train_csv['label_img'])[num_train:]

    @staticmethod
    def _load_rgb(path):
        """Decode ``path`` into an RGB image and close the underlying file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def get_images(self, index):
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]
        # RGB conversion (as in TrainData) guarantees 3 channels for Normalize,
        # even for files stored with alpha or as grayscale.
        haze_img = self._load_rgb(haze_name)
        gt_img = self._load_rgb(gt_name)

        # --- Transform to tensor --- #
        transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        haze = transform_haze(haze_img)
        gt = transform_gt(gt_img)

        return haze, gt, haze_name

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.haze_names)


In [None]:


def main(test_phrase, test_epoch):
    """Evaluate a saved phase-1 checkpoint on TestData and return its mean PSNR.

    The function runs these steps:

    - load ``checkpoint/1_<test_epoch>.tar`` into a fresh ``Generate_quarter``;
    - run it over every validation pair;
    - save each prediction as a PNG under ``after/NH_results/``;
    - print per-image PSNR and the mean forward time.

    Args:
        test_phrase: evaluation phase. Only ``1`` is implemented.
        test_epoch: epoch number embedded in the checkpoint file name.

    Returns:
        Mean PSNR (dB) over the validation images.

    Raises:
        ValueError: if ``test_phrase`` is not 1, or TestData is empty. Both
            cases previously failed with UnboundLocalError or ZeroDivisionError.
    """
    if test_phrase != 1:
        # The phase-2/3 refinement generators (G2/G3) are not available in this notebook.
        raise ValueError('Unsupported test_phrase: {}'.format(test_phrase))

    device_ids = [Id for Id in range(torch.cuda.device_count())]
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    test_batch_size = 1
    network_height = 3
    network_width = 6
    num_dense_layer = 4
    growth_rate = 16
    test_data_loader = DataLoader(TestData(), batch_size=test_batch_size)

    def save_image(dehaze, image_name, category):
        """Save each image of the batch as ``after/<category>_results/<basename>.png``."""
        for ind in range(len(dehaze)):
            utils.save_image(dehaze[ind], '/content/drive/MyDrive/lg_vision/after/{}_results/{}'.format(category, image_name[ind].split('/')[-1][:-3] + 'png'))

    G1 = Generate_quarter(height=network_height, width=network_width, num_dense_layer=num_dense_layer, growth_rate=growth_rate)
    G1 = G1.to(device)
    G1 = nn.DataParallel(G1, device_ids=device_ids)
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    G1.load_state_dict(torch.load('/content/drive/MyDrive/lg_vision/checkpoint/1_' + str(test_epoch) + '.tar', map_location=device))
    G1.eval()

    psnr = []
    net_time = 0.
    net_count = 0
    for haze, gt, image_name in test_data_loader:
        with torch.no_grad():
            haze = haze.to(device)
            gt = gt.to(device)
            start_time = time.time()
            dehaze, _ = G1(haze)
            if device.type == 'cuda':
                # CUDA kernels are asynchronous; wait so the timing covers the forward pass.
                torch.cuda.synchronize(device)
            net_time += time.time() - start_time
            net_count += 1
            test_info = to_psnr_test(dehaze, gt)
            psnr.append(sum(test_info) / len(test_info))
            print ("test : ", sum(test_info) / len(test_info))
        # --- Save image --- #
        save_image(dehaze, image_name, "NH")

    if not psnr:
        raise ValueError('TestData is empty; cannot compute PSNR')
    test_psnr = sum(psnr) / len(psnr)
    print ('Test PSNR:' + str(test_psnr))
    print('net time is {0:.4f}'.format(net_time / net_count))
    return test_psnr








In [None]:
# torch.device(...) also accepts a plain string, in case ``device`` was rebound to one.
if torch.device(device).type == 'cuda':
    torch.cuda.empty_cache()

    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    # memory_cached() is the deprecated alias of memory_reserved().
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

In [None]:
# Shuffled training batches; two worker processes load and augment samples.
train_data_loader = DataLoader(train_data, batch_size=train_batch_size, shuffle=True, num_workers = 2)


In [None]:
# torch.device(...) also accepts a plain string, in case ``device`` was rebound to one.
if torch.device(device).type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    # memory_cached() is the deprecated alias of memory_reserved().
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

In [None]:
import time
import torch
import torch.nn.functional as F
import torchvision.utils as utils
from math import log10
from skimage import measure
import torch.nn as nn


def to_psnr(dehaze, gt, intensity_max=1.0):
    """Return the per-sample PSNR (dB) between two batches.

    Args:
        dehaze: predicted batch; dimension 0 is the batch.
        gt: reference batch with the same shape.
        intensity_max: peak signal value (1.0 for images in [0, 1]).

    Returns:
        List of floats, one per sample, computed as
        10 * log10(intensity_max**2 / MSE). A sample with zero MSE gives
        ``inf`` instead of raising ZeroDivisionError.
    """
    mse = F.mse_loss(dehaze, gt, reduction='none')
    # Mean over every non-batch element of each sample.
    mse_list = mse.reshape(mse.shape[0], -1).mean(dim=1).tolist()
    peak_sq = intensity_max ** 2
    return [10.0 * log10(peak_sq / m) if m > 0 else float('inf') for m in mse_list]

def to_psnr_test(dehaze, gt, intensity_max=1.0):
    """Return the per-sample PSNR (dB) between two batches (used during evaluation).

    Args:
        dehaze: predicted batch; dimension 0 is the batch.
        gt: reference batch with the same shape.
        intensity_max: peak signal value (1.0 for images in [0, 1]).

    Returns:
        List of floats, one per sample, computed as
        10 * log10(intensity_max**2 / MSE). A perfect reconstruction gives
        ``inf`` instead of raising ZeroDivisionError.
    """
    mse = F.mse_loss(dehaze, gt, reduction='none')
    # Mean over every non-batch element of each sample.
    mse_list = mse.reshape(mse.shape[0], -1).mean(dim=1).tolist()
    peak_sq = intensity_max ** 2
    return [10.0 * log10(peak_sq / m) if m > 0 else float('inf') for m in mse_list]


def to_ssim_skimage(dehaze, gt):
    """Return the per-sample SSIM between two (N, C, H, W) batches, averaged over channels.

    Each channel is scored separately with ``data_range=1`` and the channel
    scores are averaged. This is what scikit-image's ``multichannel=True`` mode
    computes, but it does not depend on the ``multichannel``/``channel_axis``
    keyword, whose name changed across versions. It also avoids the
    ``measure.compare_ssim`` function, which newer scikit-image removed.
    Single-channel inputs are handled correctly too; the old ``squeeze`` turned
    them into 2-D arrays whose last axis was then read as channels.
    """
    try:
        from skimage.metrics import structural_similarity
    except ImportError:  # scikit-image < 0.16 only offers the legacy name
        structural_similarity = measure.compare_ssim

    dehaze_np = dehaze.detach().permute(0, 2, 3, 1).cpu().numpy()
    gt_np = gt.detach().permute(0, 2, 3, 1).cpu().numpy()

    ssim_list = []
    for pred, ref in zip(dehaze_np, gt_np):
        per_channel = [
            structural_similarity(pred[..., ch], ref[..., ch], data_range=1)
            for ch in range(pred.shape[-1])
        ]
        ssim_list.append(sum(per_channel) / len(per_channel))
    return ssim_list


def validation(net, val_data_loader, device, category, save_tag=False):
    """
    Compute the average PSNR and SSIM of ``net`` over a validation loader.

    :param net: dehazing generator; when it returns a tuple or list (the
        generators in this notebook return ``(dehaze, features)``), the first
        element is used as the prediction
    :param val_data_loader: loader yielding ``(haze, gt, image_name)`` batches
    :param device: device the network's inputs are moved to
    :param category: name prefix passed to ``save_image``
    :param save_tag: whether to save each prediction through the module-level
        3-argument ``save_image``
    :return: ``(average PSNR, average SSIM)``
    :raises ValueError: if the loader yields no batches
    """
    psnr_list = []
    ssim_list = []

    for haze, gt, image_name in val_data_loader:
        with torch.no_grad():
            haze = haze.to(device)
            gt = gt.to(device)
            dehaze = net(haze)
            if isinstance(dehaze, (tuple, list)):
                dehaze = dehaze[0]

        # --- Calculate the average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        # --- Calculate the average SSIM --- #
        ssim_list.extend(to_ssim_skimage(dehaze, gt))

        # --- Save image --- #
        if save_tag:
            save_image(dehaze, image_name, category)

    if not psnr_list:
        raise ValueError('validation loader yielded no batches')
    avr_psnr = sum(psnr_list) / len(psnr_list)
    avr_ssim = sum(ssim_list) / len(ssim_list)
    return avr_psnr, avr_ssim

def test_net(G2, G1, G3, test_data_loader, device, save_tag=False):
    net_time = 0.
    net_count = 0.
    for batch_id, test_data in enumerate(test_data_loader):
        with torch.no_grad():
            haze, image_name = test_data
            haze = haze.to(device)
            start_time = time.time()
            dehaze_1, feat1 = G1(F.interpolate(haze, scale_factor = 0.25))
            dehaze_2, feat, feat2 = G2(dehaze_1)
            dehaze, _, _ = G3(haze, F.interpolate(dehaze_2, scale_factor = 4), feat, feat1, feat2)
            end_time = time.time() - start_time
            net_time += end_time
            net_count += 1

        # --- Save image --- #
        if save_tag:
            save_image(dehaze, image_name, 'NH')

    print('net time is {0:.4f}'.format(net_time / net_count))

def save_image(dehaze, image_name, category):
    """Save each image of a batch as ``<category>_results/<basename>.png``.

    Only the base name of ``image_name[ind]`` is used. The loaders in this
    notebook yield full paths, which previously produced nested, non-existent
    output paths. The base name's last three characters (the extension) are
    replaced by ``png``, and the output directory is created if missing.
    """
    out_dir = '{}_results'.format(category)
    os.makedirs(out_dir, exist_ok=True)
    for ind, img in enumerate(torch.split(dehaze, 1, dim=0)):
        file_name = os.path.basename(image_name[ind])[:-3] + 'png'
        utils.save_image(img, os.path.join(out_dir, file_name))


def print_log(epoch, train_psnr, category, time_cost=0, num_epochs=0, val_psnr=0, val_ssim=0,
              log_dir='/content/drive/MyDrive/lg_vision/training_log'):
    """Append one training-log line to ``<log_dir>/<category>_log.txt``.

    Fields the training loop does not compute default to 0, which matches the
    placeholders written previously. ``log_dir`` is created if it does not
    exist. The stray ``s`` after the date has been removed.
    """
    os.makedirs(log_dir, exist_ok=True)
    line = ('Date: {0}, Time_Cost: {1:.0f}s, Epoch: [{2}/{3}], Train_PSNR: {4:.2f}, '
            'Val_PSNR: {5:.2f}, Val_SSIM: {6:.4f}').format(
        time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        time_cost, epoch, num_epochs, train_psnr, val_psnr, val_ssim)
    # --- Write the training log --- #
    with open(os.path.join(log_dir, '{}_log.txt'.format(category)), 'a') as f:
        print(line, file=f)


def adjust_learning_rate(optimizer, epoch, category, lr_decay=0.5, step=200):
    """Multiply every param group's learning rate by ``lr_decay`` every ``step`` epochs.

    The decay fires when ``epoch`` is a positive multiple of ``step``. The
    current (possibly updated) rate of each group is printed in every case.
    ``category`` is accepted for backward compatibility and is unused.
    """
    decay_now = epoch > 0 and epoch % step == 0
    for param_group in optimizer.param_groups:
        if decay_now:
            param_group['lr'] *= lr_decay
        print('Learning rate sets to {}.'.format(param_group['lr']))

def positiivate_weights(x):
    """Soft positivity mask: relu(x) / (relu(x) + 1e-10).

    Non-positive entries map to 0. Positive entries much larger than 1e-10 map
    to (nearly) 1. The epsilon keeps the denominator non-zero.
    """
    rectified = F.relu(x)
    return rectified / (rectified + 1e-10)


In [None]:
# Phase-1 training loop: reconstruction + perceptual loss. Every 5 epochs a checkpoint
# is saved and evaluated with main().
loss_rec1 = nn.SmoothL1Loss()
loss_rec2 = nn.MSELoss()  # only used by the phase-2/3 variants, which are not implemented here
num = 0  # global TensorBoard step
train_phrase =1
avg = nn.AvgPool2d(3, stride = 2, padding = 1)
num_epochs = 500
for epoch in range(start_epoch, num_epochs):
    psnr_list = []
    adjust_learning_rate(optimizer, epoch, category=category)

    # Unpack directly so the global dataset ``train_data`` is not overwritten by a batch.
    for batch_id, (haze, gt, haze_gt) in enumerate(train_data_loader):

        optimizer.zero_grad()
        haze = haze.to(device)
        gt = gt.to(device)

        # --- Forward + Backward + Optimize --- #
        dehaze_1, feat_extra_1 = net(haze)
        rec_loss1 = loss_rec1(dehaze_1, gt)
        perceptual_loss = loss_network(dehaze_1, gt)
        # The Laplacian term is disabled in the objective, so it is not computed:
        # lap_loss = loss_lap(dehaze_1, gt)

        # PSNR is computed once per batch on a detached tensor and reused below.
        psnr = to_psnr(dehaze_1.detach(), gt)
        psnr_list.extend(psnr)

        loss = rec_loss1 * 1.2 + lambda_loss * perceptual_loss  # + 0.5 * lap_loss

        loss.backward()
        optimizer.step()

        batch_psnr = sum(psnr) / len(psnr)
        print('Epoch: {0}, Iteration: {1}'.format(epoch, batch_id))
        print (batch_psnr)
        writer.add_scalar('scalar/loss_w/ IN', loss.item(), num)
        writer.add_scalar('scalar/psnr_w/ IN', batch_psnr, num)
        num = num + 1

    # --- Calculate the average training PSNR in one epoch --- #
    train_psnr = sum(psnr_list) / len(psnr_list)
    # Log with the same epoch number that is used for the checkpoint file name.
    print_log(epoch, train_psnr, category)

    if epoch % 5==0:
        torch.save(net.state_dict(), '/content/drive/MyDrive/lg_vision/checkpoint/'+str(int(train_phrase))+'_'+str(epoch)+'.tar')
        test_psnr = main(1, epoch)
        writer.add_scalar('scalar/psnr_test_w/ IN', test_psnr, epoch)

Result


In [None]:

# --- Submission (test) dataset --- #
class SubmissionData(data.Dataset):
    """Unlabelled test images listed in ``test.csv``.

    Items are ``(input_tensor, input_name)``. ``input_tensor`` is the RGB image
    normalised with mean/std 0.5 per channel, and ``input_name`` is its full
    path.

    Args:
        data_root: directory holding ``test.csv`` and ``test_input_img/``.
    """

    def __init__(self, data_root='/content/drive/MyDrive/lg_vision/data'):
        super().__init__()
        test_csv = pd.read_csv(data_root + '/test.csv')
        self.input_names = list(data_root + '/test_input_img/' + test_csv['input_img'])
        # Stored for reference only; get_images does not use it.
        self.submission_names = list(data_root + '/test_input_img/' + test_csv['submission_name'])

    def get_images(self, index):
        input_name = self.input_names[index]
        # Convert to RGB (as TrainData does) so Normalize always sees 3 channels.
        with Image.open(input_name) as img:
            input_img = img.convert('RGB')

        # --- Transform to tensor --- #
        transform_input = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        input_tensor = transform_input(input_img)

        return input_tensor, input_name

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.input_names)


In [None]:



# Generate submission images with the phase-1 checkpoint on the notebook's ``device``.
test_batch_size = 1
network_height = 3
network_width = 6
num_dense_layer = 4
growth_rate = 16
test_phrase = 1

# NOTE: the global training ``crop_size`` is intentionally left untouched; the whole
# image is processed here, so no crop size is needed.
test_data = SubmissionData()
test_data_loader = DataLoader(test_data, batch_size=test_batch_size)

test_epoch = 495


def save_submission_image(dehaze, image_name):
    """Write each image of ``dehaze`` to the submission folder as ``<basename>.png``.

    Uses its own name so the module-level 3-argument ``save_image`` (called by
    ``validation`` and ``test_net``) is not replaced.
    """
    for ind in range(len(dehaze)):
        utils.save_image(dehaze[ind], '/content/drive/MyDrive/lg_vision/submission/{}'.format(image_name[ind].split('/')[-1][:-3]+'png'))


if test_phrase == 1:
    G1 = Generate_quarter(height=network_height, width=network_width, num_dense_layer=num_dense_layer, growth_rate=growth_rate)
    G1 = G1.to(device)
    G1 = nn.DataParallel(G1, device_ids=device_ids)
    # map_location lets a checkpoint saved on GPU load on whatever ``device`` is.
    G1.load_state_dict(torch.load('/content/drive/MyDrive/lg_vision/checkpoint/1_'+str(test_epoch)+'.tar', map_location=device))
    G1.eval()

    # Unpack directly so the dataset ``test_data`` is not rebound to a batch.
    for haze, image_name in test_data_loader:
        with torch.no_grad():
            haze = haze.to(device)
            dehaze, _ = G1(haze)
        # --- Save image --- #
        save_submission_image(dehaze, image_name)

    print('success')


In [None]:
# Generate submission images on the CPU (presumably because full-resolution inputs do not
# fit in GPU memory; TODO confirm). A dedicated device name is used so the notebook-wide
# ``device`` (used by training and the CUDA memory cells) is not overwritten.
import gc

gc.collect()

torch.cuda.empty_cache()

cpu_device = torch.device('cpu')

test_batch_size = 1
network_height = 3
network_width = 6
num_dense_layer = 4
growth_rate = 16
test_phrase = 1

test_data = SubmissionData()
test_data_loader = DataLoader(test_data, batch_size=test_batch_size, num_workers =2)

test_epoch = 495


def save_submission_image(dehaze, image_name):
    """Write each image of ``dehaze`` to the submission folder as ``<basename>.png``.

    Uses its own name so the module-level 3-argument ``save_image`` is not replaced.
    """
    for ind in range(len(dehaze)):
        utils.save_image(dehaze[ind], '/content/drive/MyDrive/lg_vision/submission/{}'.format(image_name[ind].split('/')[-1][:-3]+'png'))


def _strip_data_parallel_prefix(state_dict):
    """Return ``state_dict`` with any leading 'module.' removed from its keys."""
    prefix = 'module.'
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


with torch.no_grad():
    if test_phrase == 1:
        G1 = Generate_quarter(height=network_height, width=network_width, num_dense_layer=num_dense_layer, growth_rate=growth_rate)
        # The checkpoint is the state_dict of the nn.DataParallel-wrapped training ``net``.
        # Its weights are loaded into the bare module, because DataParallel with CUDA
        # device ids cannot run a module that lives on the CPU.
        state = torch.load('/content/drive/MyDrive/lg_vision/checkpoint/1_'+str(test_epoch)+'.tar', map_location=cpu_device)
        G1.load_state_dict(_strip_data_parallel_prefix(state))
        G1 = G1.to(cpu_device)
        G1.eval()

        for haze, image_name in test_data_loader:
            haze = haze.to(cpu_device)
            dehaze, _ = G1(haze)
            save_submission_image(dehaze, image_name)
        print('success')


In [None]:
# Total memory (bytes) of ``device``; assumes ``device`` is a CUDA device (TODO confirm before running).
torch.cuda.get_device_properties(device).total_memory


In [None]:
# Bytes currently allocated by tensors on the current CUDA device.
torch.cuda.memory_allocated()

In [None]:
# Sanity check of the "last path component" extraction used when naming saved images.
st = '/content/drive/MyDrive/lg_vision/data/train_input_img/train_input_10500.png'
print(st.rsplit('/', 1)[-1])