In [1]:
from __future__ import print_function
import argparse
import os 
import sys
import random

import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils as utils
import torchvision.utils as vutils    

from new_model_final import bn_init_as_tf, weights_init_xavier,encoder, decoder_T, decoder_A, refinement_final
from data_final import UnNormalize#,getTestLoader_v2,getValLoader, UnNormalize
from utils import AverageMeter
# from tqdm import tqdm
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
from skimage.measure import compare_psnr, compare_ssim
import cv2
from os.path import join

In [2]:
import torch.nn.functional as F
import math
from skimage.util.shape import view_as_blocks
from os import listdir
from PIL import Image,ImageEnhance
from torch.utils.data.dataset import Dataset
from torchvision import transforms

In [3]:
def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG','.bmp'])
def _is_pil_image(img):
    return isinstance(img, Image.Image)
def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})

In [17]:
class Param:
    def __init__(self):
        self.test_dir = './real_haze'
        self.checkpoint_path = './exp3_epoch92.pth'
        self.image_size=320
        self.blocks=3
        self.retrain=False
        
opt= Param()
print(opt.checkpoint_path)
print(opt.test_dir)
print(opt.blocks)

./exp3_epoch92.pth
./real_haze
3


# loading and testing

In [5]:
def gcd(x,y):
    while(y):
        x , y = y , x % y
    return x

def lcm(x,y):
    
    result = (x*y)//gcd(x,y)
    
    return result

In [6]:
# original testset + overlap
class CreateTestDataSet(Dataset):
    def __init__(self, test_dir, block_num = 3,transform=None, overlap=False):
        self.dir= test_dir
        self.transform=transform
        self.data_files = {'haze':[]}
        self.img_ratio= lcm(32,block_num)
        self.block_num = block_num
        self.overlap=overlap
#         self.data_files = {'haze':[],'GT':[]}

        for key in self.data_files.keys():
            subdir = join(self.dir, key)
            self.data_files[key] += [join(subdir,x) for x in listdir(subdir) if is_image_file(x) ]
            # self.data_files[key].sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    def __getitem__(self, index):
        haze_name = self.data_files['haze'][index]
        haze = Image.open(haze_name).convert('RGB')

        if self.transform:
            # apply transform to each sample in data_files
            haze = self.transform(haze)
            
        c,h,w = haze.shape
        new_h, new_w = self.img_ratio*math.ceil(h/self.img_ratio), self.img_ratio*math.ceil(w/self.img_ratio)
        h_block,w_block = (new_h//self.block_num), (new_w//self.block_num)
        if self.overlap:
            new_h, new_w = new_h+h_block, new_w+w_block
        
        h_pad1, h_pad2 = 0,0
        w_pad1, w_pad2= 0,0
        if new_h != h:
            pad= new_h-h
            h_pad1, h_pad2 = pad//2, (pad-pad//2)
        if new_w != w:
            pad= new_w-w
            w_pad1, w_pad2 = pad//2, (pad-pad//2)
            
        haze_pad = F.pad(haze.unsqueeze(0), (w_pad1, w_pad2, h_pad1, h_pad2), mode='reflect').squeeze(0)
        all_blocks = view_as_blocks(haze_pad.numpy(), block_shape=(c,h_block,w_block)).reshape(-1,c,h_block,w_block)
#         print("all blocks size", all_blocks.shape)
        # all_blocks: (total_block_num, channel, h, w)
        all_blocks = torch.from_numpy(all_blocks)
        
        info = (h,w,h_pad1,w_pad1)
        block_info = (h_block, w_block)      
        sample={'haze': all_blocks, 'info': info, 'block_info': block_info}       
        return sample
    def __len__(self):
        return len(self.data_files['haze'])

In [7]:
def getTestLoader(test_dir, blocks=3, overlap=False):
    # val_test_transform = transforms.Resize((image_size,image_size))
    test_transform = transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
                                             # Scale((image_size, image_size), use_trans_atmos=trans_atmos),
#                                              ToTensor(use_trans_atmos = trans_atmos),
#                                              Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],use_trans_atmos = trans_atmos)
                                        ])
    test_set = CreateTestDataSet(test_dir, blocks,test_transform, overlap)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)
    return test_loader, test_set.__len__()

def getTestLoader_OverLap(test_dir, blocks=3, overlap=True):
    # val_test_transform = transforms.Resize((image_size,image_size))
    test_transform = transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
                                             # Scale((image_size, image_size), use_trans_atmos=trans_atmos),
#                                              ToTensor(use_trans_atmos = trans_atmos),
#                                              Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],use_trans_atmos = trans_atmos)
                                        ])
    test_set = CreateTestDataSet(test_dir, blocks,test_transform, overlap)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)
    return test_loader, test_set.__len__()


In [8]:
def normalize_to_0_1(x):
    """
    normalize input tensor x to [0,1]
    x: tensor, (B,C,H,W)
    """
    b,c,h,w = x.shape
    x = x.view(b,-1)
    x -= x.min(1, keepdim=True)[0]
    x /= x.max(1, keepdim=True)[0]
    if torch.any(torch.isnan(x)):
        # avoid NaNs caused by dividing 0, should not happen
        x[torch.isnan(x)]=1.0
        print("divide by 0")
    x = x.view(b,c,h,w)
    return x

In [9]:
class Model_Test(nn.Module):
    def __init__(self):
        super(Model_Test, self).__init__()
        self.encoder = encoder()
        self.decoder_T = decoder_T(1,self.encoder.feat_out_channels)
        self.decoder_A = decoder_A(3,self.encoder.feat_out_channels, return_value=True)
        self.generate_dehaze = refinement_final()
        self.unnormalize_fun = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    def forward(self, x):
        resized_x = F.interpolate(x, size = (320,320), mode='bilinear',align_corners=True)
        skip_feat = self.encoder(x)
        trans_map = self.decoder_T(skip_feat)
        print("In Model Test: trans_map shape:", trans_map.shape)
        trans_map = trans_map.repeat(1,3,1,1)
        
        atmos_skip_feat = self.encoder(resized_x)
        atmos_light = self.decoder_A(atmos_skip_feat)
        atmos_light = atmos_light.unsqueeze(-1).unsqueeze(-1)
        atmos_light = atmos_light.repeat(1,1,x.shape[-2],x.shape[-1])
        
        # Unnormalize haze images
        hazes_unnormalized = self.unnormalize_fun(x)
        # Reconstruct clean images
        nonhaze_rec = (hazes_unnormalized-atmos_light*(1-trans_map))
        nonhaze_rec = nonhaze_rec/trans_map
        nonhaze_rec = torch.clamp(nonhaze_rec, 0.0, 1.0)
            
        # Refinement Module
        nonhaze_refinement = self.generate_dehaze(nonhaze_rec, hazes_unnormalized, trans_map, atmos_light)
        
        return trans_map, atmos_light, nonhaze_rec, nonhaze_refinement

In [10]:
img_list=['aerial','castle','cityscape','cliff','forest','highquality13','img33','img54','img69','landscape','lviv','manhattan1','manhattan2','redbrickshouse','road','swans','yosemite1']
# img_list=['cityscape','landscape']
print(img_list)

['aerial', 'castle', 'cityscape', 'cliff', 'forest', 'highquality13', 'img33', 'img54', 'img69', 'landscape', 'lviv', 'manhattan1', 'manhattan2', 'redbrickshouse', 'road', 'swans', 'yosemite1']


In [11]:
print(opt.checkpoint_path)
print(opt.test_dir)
print(opt.blocks)

./exp3_epoch92.pth
./real_haze1
3


In [12]:
checkpoint = torch.load(opt.checkpoint_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
print("device:", device)

device: cuda


In [13]:
model = Model_Test()
model = model.to(device)

In [14]:
# Loading dehaze model ( encoder + decoder_D + decoder_A)
if opt.checkpoint_path != '' and os.path.isfile(opt.checkpoint_path):
    print("Loading dehaze checkpoint '{}'".format(opt.checkpoint_path))
    checkpoint = torch.load(opt.checkpoint_path)

    start_epoch = checkpoint['epoch']+1
    iteration = checkpoint['iteration']+1
    model.load_state_dict(checkpoint['model_state_dict'])
    print('Loading dehaze checkpoint path:', opt.checkpoint_path)
    # Retrain the loading model
    if opt.retrain:
        start_epoch =1
        iteration=0
        print('Restart Training from epoch %d!' % start_epoch)
    else:
        print('Continuing training at epoch %d' % start_epoch)
else:
    start_epoch = 1
    iteration = 0

Loading dehaze checkpoint './exp3_epoch92.pth'
Loading dehaze checkpoint path: ./exp3_epoch92.pth
Continuing training at epoch 93


In [15]:
model.eval()

Model_Test(
  (encoder): encoder(
    (base_model): Sequential(
      (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu0): ReLU(inplace=True)
      (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (denseblock1): _DenseBlock(
        (denselayer1): _DenseLayer(
          (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu1): ReLU(inplace=True)
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu2): ReLU(inplace=True)
          (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        )
        (denselayer2): _DenseLayer(
          (norm1): BatchNorm2d(96, eps=1e-05, momentum

# full pass 3x3

In [18]:
print("block num:",opt.blocks)
test_loader, test_num = getTestLoader( opt.test_dir,opt.blocks)
print("Testing Images Number:", test_num)

block num: 3
Testing Images Number: 17


In [19]:
# original
img_dir= join('./result0728_exp3_real_haze','full_pass3x3')
block_num= opt.blocks

with torch.no_grad():

        for i, test_batched in enumerate(test_loader,0):
            
            print("Image:",img_list[i])
            # Prepare sample and target
            # haze: (b,c,h,w) with (r,g,b) order
            haze = test_batched['haze'][0]
            h_origin,w_origin,h_pad,w_pad = test_batched['info']
            h_origin,w_origin,h_pad,w_pad = h_origin.item(),w_origin.item(),h_pad.item(),w_pad.item()
            h_block, w_block = test_batched['block_info']
            h_block, w_block = h_block.item(), w_block.item()
            print("haze shape:", haze.shape)
            print("haze info:",h_origin,w_origin,h_pad,w_pad)
            print("block info:",h_block, w_block)
            print()
            
            total_block_num,c,h,w = haze.shape
            
            trans_map, atmos_light,nonhaze_rec, dehaze = model(haze.to(device))
            print("trans_map shape:", trans_map.shape)
                
            trans_map_np = np.zeros(shape=(h*block_num,w*block_num,3))
            atmos_map_np = np.zeros(shape=(h*block_num,w*block_num,3))
            nonhaze_rec_np = np.zeros(shape=(h*block_num,w*block_num,3))
            dehaze_np = np.zeros(shape=(h*block_num,w*block_num,3))
            start_index=0
            
            for h_idx in range(block_num):
                for w_idx in range(block_num):
                    block = trans_map[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    trans_map_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block
                    
                    block = atmos_light[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    atmos_map_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block
                    
                    block = nonhaze_rec[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    nonhaze_rec_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block

                    block = dehaze[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    dehaze_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block
                    
                    start_index += 1

    
            trans_map_np = trans_map_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            atmos_map_np = atmos_map_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            nonhaze_rec_np = nonhaze_rec_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            dehaze_np = dehaze_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            
            trans_map_np = (255*trans_map_np).astype(np.uint8)[:,:,::-1]
            atmos_map_np = (255*atmos_map_np).astype(np.uint8)[:,:,::-1]
            nonhaze_rec_np = (255*nonhaze_rec_np).astype(np.uint8)[:,:,::-1]
            dehaze_np = (255*dehaze_np).astype(np.uint8)[:,:,::-1]
            
#             print("trans dtype:", trans_map_np.dtype)
#             print("atmos dtype:", atmos_map_np.dtype)
#             print("nonohaze rec dtype:", nonhaze_rec_np.dtype)
#             print("dehaze dtype:", dehaze_np.dtype)
            
#             cv2.imshow('Trans', trans_map_np)
#             cv2.imshow('Atmos', atmos_map_np)
#             cv2.imshow('Nonhaze', nonhaze_rec_np)
#             cv2.imshow('Dehaze', dehaze_np)
#             cv2.waitKey(0)
#             cv2.destroyAllWindows()
            
            atmos_intensity_np = cv2.cvtColor(atmos_map_np, cv2.COLOR_RGB2GRAY)
    
            # because opencv is BGR order, we need to change RGB to BGR
            cv2.imwrite(join(img_dir,'dehaze',img_list[i]+'_dehaze.png'),dehaze_np)
            cv2.imwrite(join(img_dir,'nonhaze_rec',img_list[i]+'_nonhaze_rec.png'),nonhaze_rec_np)
            cv2.imwrite(join(img_dir,'trans',img_list[i]+'_trans.png'),trans_map_np)
            cv2.imwrite(join(img_dir,'atmos',img_list[i]+'_atmos.png'),atmos_map_np)
            cv2.imwrite(join(img_dir,'atmos_intensity',img_list[i] +'_atmos_intensity.png'),atmos_intensity_np)
            print()
#             break

Image: aerial
haze shape: torch.Size([9, 3, 160, 224])
haze info: 442 622 19 25
block info: 160 224

In Model Test: trans_map shape: torch.Size([9, 1, 160, 224])
trans_map shape: torch.Size([9, 3, 160, 224])

Image: castle
haze shape: torch.Size([9, 3, 224, 224])
haze info: 611 619 30 26
block info: 224 224

In Model Test: trans_map shape: torch.Size([9, 1, 224, 224])
trans_map shape: torch.Size([9, 3, 224, 224])

Image: cityscape
haze shape: torch.Size([9, 3, 224, 160])
haze info: 600 400 36 40
block info: 224 160

In Model Test: trans_map shape: torch.Size([9, 1, 224, 160])
trans_map shape: torch.Size([9, 3, 224, 160])

Image: cliff
haze shape: torch.Size([9, 3, 128, 192])
haze info: 384 576 0 0
block info: 128 192

In Model Test: trans_map shape: torch.Size([9, 1, 128, 192])
trans_map shape: torch.Size([9, 3, 128, 192])

Image: forest
haze shape: torch.Size([9, 3, 256, 352])
haze info: 768 1024 0 16
block info: 256 352

In Model Test: trans_map shape: torch.Size([9, 1, 256, 352])
tr

# full pass 1x1

In [20]:
opt.blocks=1
print(opt.checkpoint_path)
print(opt.test_dir)
print("block number:",opt.blocks)

./exp3_epoch92.pth
./real_haze
block number: 1


In [21]:
test_loader, test_num = getTestLoader( opt.test_dir,opt.blocks)
print("Testing Images Number:", test_num)

Testing Images Number: 17


In [22]:
# original
img_dir= join('./result0728_exp3_real_haze','full_pass1x1')
block_num= opt.blocks

with torch.no_grad():

        for i, test_batched in enumerate(test_loader,0):
            
            print("Image:",img_list[i])
            # Prepare sample and target
            # haze: (b,c,h,w) with (r,g,b) order
            haze = test_batched['haze'][0]
            h_origin,w_origin,h_pad,w_pad = test_batched['info']
            h_origin,w_origin,h_pad,w_pad = h_origin.item(),w_origin.item(),h_pad.item(),w_pad.item()
            h_block, w_block = test_batched['block_info']
            h_block, w_block = h_block.item(), w_block.item()
            print("haze shape:", haze.shape)
            print("haze info:",h_origin,w_origin,h_pad,w_pad)
            print("block info:",h_block, w_block)
            print()
            
            total_block_num,c,h,w = haze.shape
            
            trans_map, atmos_light,nonhaze_rec, dehaze = model(haze.to(device))
            print("trans_map shape:", trans_map.shape)
                
            trans_map_np = np.zeros(shape=(h*block_num,w*block_num,3))
            atmos_map_np = np.zeros(shape=(h*block_num,w*block_num,3))
            nonhaze_rec_np = np.zeros(shape=(h*block_num,w*block_num,3))
            dehaze_np = np.zeros(shape=(h*block_num,w*block_num,3))
            start_index=0
            
            for h_idx in range(block_num):
                for w_idx in range(block_num):
                    block = trans_map[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    trans_map_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block
                    
                    block = atmos_light[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    atmos_map_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block
                    
                    block = nonhaze_rec[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    nonhaze_rec_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block

                    block = dehaze[start_index].cpu().numpy()
                    block = np.transpose(block, (1,2,0))
                    dehaze_np[h_idx*h_block:(h_idx+1)*h_block,w_idx*w_block:(w_idx+1)*w_block,:] = block
                    
                    start_index += 1

    
            trans_map_np = trans_map_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            atmos_map_np = atmos_map_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            nonhaze_rec_np = nonhaze_rec_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            dehaze_np = dehaze_np[h_pad:h_pad+h_origin,w_pad:w_pad+w_origin,:]
            
            trans_map_np = (255*trans_map_np).astype(np.uint8)[:,:,::-1]
            atmos_map_np = (255*atmos_map_np).astype(np.uint8)[:,:,::-1]
            nonhaze_rec_np = (255*nonhaze_rec_np).astype(np.uint8)[:,:,::-1]
            dehaze_np = (255*dehaze_np).astype(np.uint8)[:,:,::-1]
            
#             print("trans dtype:", trans_map_np.dtype)
#             print("atmos dtype:", atmos_map_np.dtype)
#             print("nonohaze rec dtype:", nonhaze_rec_np.dtype)
#             print("dehaze dtype:", dehaze_np.dtype)
            
#             cv2.imshow('Trans', trans_map_np)
#             cv2.imshow('Atmos', atmos_map_np)
#             cv2.imshow('Nonhaze', nonhaze_rec_np)
#             cv2.imshow('Dehaze', dehaze_np)
#             cv2.waitKey(0)
#             cv2.destroyAllWindows()
            
            atmos_intensity_np = cv2.cvtColor(atmos_map_np, cv2.COLOR_RGB2GRAY)
    
            # because opencv is BGR order, we need to change RGB to BGR
            cv2.imwrite(join(img_dir,'dehaze',img_list[i]+'_dehaze.png'),dehaze_np)
            cv2.imwrite(join(img_dir,'nonhaze_rec',img_list[i]+'_nonhaze_rec.png'),nonhaze_rec_np)
            cv2.imwrite(join(img_dir,'trans',img_list[i]+'_trans.png'),trans_map_np)
            cv2.imwrite(join(img_dir,'atmos',img_list[i]+'_atmos.png'),atmos_map_np)
            cv2.imwrite(join(img_dir,'atmos_intensity',img_list[i] +'_atmos_intensity.png'),atmos_intensity_np)
            print()
#             break

Image: aerial
haze shape: torch.Size([1, 3, 448, 640])
haze info: 442 622 3 9
block info: 448 640

In Model Test: trans_map shape: torch.Size([1, 1, 448, 640])
trans_map shape: torch.Size([1, 3, 448, 640])

Image: castle
haze shape: torch.Size([1, 3, 640, 640])
haze info: 611 619 14 10
block info: 640 640

In Model Test: trans_map shape: torch.Size([1, 1, 640, 640])
trans_map shape: torch.Size([1, 3, 640, 640])

Image: cityscape
haze shape: torch.Size([1, 3, 608, 416])
haze info: 600 400 4 8
block info: 608 416

In Model Test: trans_map shape: torch.Size([1, 1, 608, 416])
trans_map shape: torch.Size([1, 3, 608, 416])

Image: cliff
haze shape: torch.Size([1, 3, 384, 576])
haze info: 384 576 0 0
block info: 384 576

In Model Test: trans_map shape: torch.Size([1, 1, 384, 576])
trans_map shape: torch.Size([1, 3, 384, 576])

Image: forest
haze shape: torch.Size([1, 3, 768, 1024])
haze info: 768 1024 0 0
block info: 768 1024

In Model Test: trans_map shape: torch.Size([1, 1, 768, 1024])
tran

# blur atmos

In [23]:
# final atmos blending function version
def blending_row(left,right, percent=0.5, max_value=1.0):
    h,w,c = left.shape
    width = int(w*percent)
    width_max = (w-width)//2
    width_min = w - width - width_max
#     print("h,w,c:", h,w,c)
#     print("width:", width)
    left_mask = np.concatenate((np.linspace(max_value, max_value, num=width_max),np.linspace(max_value, 0., num=width)))
    left_mask = np.concatenate((left_mask,np.linspace(0.0,0.0,num=width_min))).reshape((1,-1))
    left_mask = np.repeat(left_mask, repeats=h, axis=0).reshape((h,-1,1))
    right_mask = np.concatenate((np.linspace(0.0,0.0,num=width_min),np.linspace(0., max_value, num=width)))
    right_mask = np.concatenate((right_mask,np.linspace(max_value, max_value, num=width_max))).reshape((1,-1))
    right_mask = np.repeat(right_mask, repeats=h, axis=0).reshape((h,-1,1))
#     print(left_mask.shape)
#     print(right_mask.shape)
    
#     left_mask = cv2.GaussianBlur(left_mask,(51,51),25).reshape(left_mask.shape)
#     right_mask = cv2.GaussianBlur(right_mask,(51,51),25).reshape(right_mask.shape)

    new_left = left_mask*left
    new_right = right_mask*right
    result = new_left + new_right
#     cv2.imshow('new left block', np.hstack((left, new_left)))
#     cv2.imshow('new right block', np.hstack((right, new_right)))
#     cv2.imshow('new result', result )
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()
    if np.max(result) > 1.0:
        print("exceed boundary")
        result = np.clip(result,0.0,1.0)
    return result   

In [24]:
# final atmos blending function version
def blending_col(up,down, percent=0.5, max_value=1.0):
    h,w,c = up.shape
    height = int(h*percent)
    height_max= (h-height)//2
    height_min = h - height - height_max
#     print("h,w,c:", h,w,c)
#     print("width:", width)
    up_mask = np.concatenate((np.linspace(max_value, max_value, num=height_max),np.linspace(max_value, 0., num=height)))
    up_mask = np.concatenate((up_mask,np.linspace(0.0,0.0,num=height_min))).reshape((-1,1))
    up_mask = np.repeat(up_mask, repeats=w, axis=1).reshape((-1,w,1))
    down_mask = np.concatenate((np.linspace(0.0,0.0,num=height_min),np.linspace(0., max_value, num=height)))
    down_mask = np.concatenate((down_mask,np.linspace(max_value, max_value, num=height_max))).reshape((-1,1))
    down_mask = np.repeat(down_mask, repeats=w, axis=1).reshape((-1,w,1))
    
        
#     up_mask = cv2.GaussianBlur(up_mask,(51,51),25).reshape(up_mask.shape)
#     down_mask = cv2.GaussianBlur(down_mask,(51,51),25).reshape(down_mask.shape)
    
#     print(up_mask.shape)
#     print(down_mask.shape)
    new_up = up_mask*up
    new_down = down_mask*down
    result = new_up + new_down
    
#     cv2.imshow('new up block', np.hstack((up, new_up)))
#     cv2.imshow('new down block', np.hstack((down, new_down)))
#     cv2.imshow('new result', result )
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()
    
    if np.max(result) > 1.0:
        print("exceed boundary")
        result = np.clip(result,0.0,1.0)
    return result   

In [25]:
haze_dir = './real_haze/haze/'
haze_name = []
for root, dirs, files in os.walk(haze_dir):
    for file in files:
#         print(file)
        haze_name.append(join(haze_dir,file))

print(len(haze_name))
print(haze_name)

17
['./real_haze/haze/aerial_input.bmp', './real_haze/haze/castle_input.png', './real_haze/haze/cityscape_input.png', './real_haze/haze/cliff_input.jpg', './real_haze/haze/forest_input.png', './real_haze/haze/highquality13.png', './real_haze/haze/img33.png', './real_haze/haze/img54.jpg', './real_haze/haze/img69.jpg', './real_haze/haze/landscape_input.jpg', './real_haze/haze/lviv_input.png', './real_haze/haze/manhattan1_input.jpg', './real_haze/haze/manhattan2_input.png', './real_haze/haze/redbrickshouse_input.bmp', './real_haze/haze/road_input.png', './real_haze/haze/swans_input.png', './real_haze/haze/yosemite1_input.png']


In [26]:
atmos_dir = './result0728_exp3_real_haze/full_pass3x3/atmos/'
atmos_name = []
for root, dirs, files in os.walk(atmos_dir):
    for file in files:
#         print(file)
        atmos_name.append(join(atmos_dir,file))

print(len(atmos_name))
print(atmos_name)

17
['./result0728_exp3_real_haze/full_pass3x3/atmos/aerial_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/castle_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/cityscape_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/cliff_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/forest_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/highquality13_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/img33_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/img54_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/img69_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/landscape_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/lviv_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/manhattan1_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/manhattan2_atmos.png', './result0728_exp3_real_haze/full_pass3x3/atmos/redbrickshouse_atmos.png', './result0728_exp3_real_haze/full_p

In [27]:
# original method
dst_dir = './result0728_exp3_real_haze/full_pass3x3/'
# img_ratio=160 # 5x5 blocks
img_ratio = 32*3 # 3x3 blocks
# num_block=5
num_block=3 # 3x3 blocks
for k in range(len(haze_name)):
    haze = cv2.imread(haze_name[k])
#     print(haze.shape)
    h,w,c = haze.shape
    h_pad1, h_pad2 = 0,0
    w_pad1, w_pad2= 0,0
    if h%img_ratio != 0:
        pad= img_ratio *(h//img_ratio+1)-h
        h_pad1, h_pad2 = pad//2, (pad-pad//2)
    if w%img_ratio != 0:
        pad= img_ratio*(w//img_ratio+1)-w
        w_pad1, w_pad2 = pad//2, (pad-pad//2)
    info = (h,w,h_pad1,w_pad1)
    print("info:", info)
    # (b,g,r) order with value [0,1]
    atmos = cv2.imread(atmos_name[k])
#     atmos = cv2.GaussianBlur(atmos,(51,51),25).reshape(atmos.shape)
    atmos = atmos/255.
    atmos_pad = np.pad(atmos, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
    print(atmos_pad.shape)
    # h_pad,w_pad,c_pad = atmos_pad.shape
    h_block, w_block = atmos_pad.shape[0]//num_block, atmos_pad.shape[1]//num_block
    print("h_block, w_block:", h_block, w_block)
    
    
    new_atmos = atmos_pad.copy()
#     print("new atmos:", new_atmos.shape)
#     cv2.imshow('new_atmos', new_atmos)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()   
    
    for i in range(num_block):
        for j in range(num_block-1):
            left = atmos_pad[i*h_block:(i+1)*h_block, j*w_block:(j+1)*w_block,:]
            right = atmos_pad[i*h_block:(i+1)*h_block, (j+1)*w_block:(j+2)*w_block,:]
            new_atmos[i*h_block:(i+1)*h_block,w_block//2+j*w_block:w_block//2+(j+1)*w_block,:] = blending_row(left,right,percent=0.5, max_value=1.0)
            
    
    
    final_atmos = new_atmos.copy()
    print("final atmos:", final_atmos.shape)
    for i in range(num_block-1):
        up = new_atmos[i*h_block:(i+1)*h_block,:,:]
        down = new_atmos[(i+1)*h_block:(i+2)*h_block,:,:]
        final_atmos[h_block//2+i*h_block:h_block//2+(i+1)*h_block,:,:] = blending_col(up,down,percent=0.5,max_value=1.0)

    final_atmos_blur = cv2.GaussianBlur(final_atmos,(51,51),25).reshape(final_atmos.shape)
#     final_atmos_blur = cv2.GaussianBlur(final_atmos_blur,(51,51),25).reshape(final_atmos.shape)
            
#     cv2.imshow('original', atmos_pad)
#     cv2.imshow('new', new_atmos)
#     cv2.imshow('zero', zeros)
#     cv2.imshow('atmos', atmos_pad)
#     cv2.imshow('blending row', new_atmos)
#     cv2.imshow('final', final_atmos)
#     cv2.imshow('final blur', final_atmos_blur)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()
    final_atmos = (255*final_atmos[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
    final_atmos_blur = (255*final_atmos_blur[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
         
    cv2.imwrite(join(dst_dir,'atmos_linear_blur',img_list[k]+'_atmos_blur.png'),final_atmos)
    cv2.imwrite(join(dst_dir,'atmos_linear_blur_gaussian',img_list[k]+'_atmos_blur.png'),final_atmos_blur)
#     break
    

info: (442, 622, 19, 25)
(480, 672, 3)
h_block, w_block: 160 224
final atmos: (480, 672, 3)
info: (611, 619, 30, 26)
(672, 672, 3)
h_block, w_block: 224 224
final atmos: (672, 672, 3)
info: (600, 400, 36, 40)
(672, 480, 3)
h_block, w_block: 224 160
final atmos: (672, 480, 3)
info: (384, 576, 0, 0)
(384, 576, 3)
h_block, w_block: 128 192
final atmos: (384, 576, 3)
info: (768, 1024, 0, 16)
(768, 1056, 3)
h_block, w_block: 256 352
final atmos: (768, 1056, 3)
info: (576, 768, 0, 0)
(576, 768, 3)
h_block, w_block: 192 256
final atmos: (576, 768, 3)
info: (768, 576, 0, 0)
(768, 576, 3)
h_block, w_block: 256 192
final atmos: (768, 576, 3)
info: (763, 1180, 2, 34)
(768, 1248, 3)
h_block, w_block: 256 416
final atmos: (768, 1248, 3)
info: (1328, 1999, 8, 8)
(1344, 2016, 3)
h_block, w_block: 448 672
final atmos: (1344, 2016, 3)
info: (525, 600, 25, 36)
(576, 672, 3)
h_block, w_block: 192 224
final atmos: (576, 672, 3)
info: (1044, 1428, 6, 6)
(1056, 1440, 3)
h_block, w_block: 352 480
final atmos

In [None]:
# atmos_dir = './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur/'
# atmos_name = []
# for root, dirs, files in os.walk(atmos_dir):
#     for file in files:
# #         print(file)
#         atmos_name.append(join(atmos_dir,file))

# print(len(atmos_name))
# print(atmos_name[:3])


# atmos_origin_dir = './result0728_exp3_real_haze/full_pass3x3/atmos/'
# atmos_origin_name = []
# for root, dirs, files in os.walk(atmos_origin_dir):
#     for file in files:
# #         print(file)
#         atmos_origin_name.append(join(atmos_origin_dir,file))

# print(len(atmos_origin_name))
# print(atmos_origin_name[:3])

In [None]:
# img_ratio = 32*3 # 3x3 blocks
# block_num=3 # 3x3 blocks
# dst_dir = './result0728_exp3_real_haze/full_pass3x3/'

# for k in range(len(atmos_name)):
#     # (b,g,r) order with value [0,1]
#     atmos = cv2.imread(atmos_name[k])/255.
#     atmos_origin = cv2.imread(atmos_origin_name[k])/255.
#     h,w,c = atmos.shape
#     new_h, new_w = img_ratio*math.ceil(h/img_ratio), img_ratio*math.ceil(w/img_ratio)
#     h_block, w_block = (new_h//block_num), (new_w//block_num)
    
#     h_pad1, h_pad2 = 0,0
#     w_pad1, w_pad2= 0,0
#     if new_h != h:
#         pad= new_h-h
#         h_pad1, h_pad2 = pad//2, (pad-pad//2)
#     if new_w != w:
#         pad= new_w-w
#         w_pad1, w_pad2 = pad//2, (pad-pad//2)
    
#     info = (h,w,h_pad1,w_pad1)
#     block_info = (h_block, w_block)
#     print("info:", info)
#     print("block info:", block_info)
#     atmos_pad = np.pad(atmos, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
#     atmos_origin_pad = np.pad(atmos_origin, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
#     print("amos_pad:", atmos_pad.shape)
    
    
#     new_atmos = atmos_origin_pad.copy()
#     width = 50
#     start_idx = 5//2
#     # for each row
#     for i in range(1,num_block):
#             new_atmos[i*h_block-start_idx:i*h_block-start_idx+width,:,:] = atmos_pad[i*h_block-start_idx:i*h_block-start_idx+width,:,:]
#     # for each col
#     for i in range(1,num_block):
#             new_atmos[:,i*w_block-start_idx:i*w_block-start_idx+width,:] = atmos_pad[:,i*w_block-start_idx:i*w_block-start_idx+width,:]

    
#     cv2.imshow('original atmos', atmos_origin_pad)
#     cv2.imshow('new atmos', new_atmos)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()
    
#     new_atmos = (255*new_atmos[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
#     cv2.imwrite(join(dst_dir,'atmos_final',img_list[k]+'_atmos_blur.png'),new_atmos)
    
# #     break

# 3x3 + blending atmos (failure)

In [None]:
# def CreateMask_atmos(h_block,w_block,num_block=3, max_value=0.9, overlap=False):

#     new_h, new_w = h_block*num_block, w_block*num_block
#     repeats = num_block-1
#     height, width = h_block//2, w_block//2
#     if overlap:
#         new_h, new_w = new_h+h_block, new_w + w_block
#         repeats = num_block
        
#     trans_mask = np.zeros((new_h,new_w,1))
    
# #     max_mask = np.linspace(max_value,max_value, num=16)
#     min_mask = np.linspace(0.0,0.0, num=16)
    
#    # 3x3 mask
#     col_mask = np.concatenate((np.linspace(max_value, 0.,num=w_block//2),np.linspace(0., max_value,num=w_block//2)))
#     col_mask = np.tile(col_mask,reps=repeats)
#     col_mask = np.concatenate((np.linspace(max_value,max_value, num=width),col_mask))
#     col_mask = np.concatenate((col_mask,np.linspace(max_value,max_value, num=width)))
#     print("col_mask:", col_mask.shape)
    
#     row_mask = np.concatenate((np.linspace(max_value, 0.,num=h_block//2),np.linspace(0., max_value,num=h_block//2)))
#     row_mask = np.tile(row_mask,reps=repeats)
#     row_mask = np.concatenate((np.linspace(max_value,max_value, num=height),row_mask))
#     row_mask = np.concatenate((row_mask,np.linspace(max_value,max_value, num=height)))
#     print("row mask:", row_mask.shape)
                              
#     for i in range(trans_mask.shape[0]):
#         for j in range(trans_mask.shape[1]):
#             trans_mask[i,j] = min(row_mask[i], col_mask[j])
    
#     col_mask = np.repeat(col_mask.reshape((-1,new_w,1)), repeats=new_h, axis=0)
#     row_mask = np.repeat(row_mask.reshape((new_h,-1,1)), repeats=new_w, axis=1)
#     trans_mask_blur = cv2.GaussianBlur(trans_mask,(51,51),25).reshape((new_h,new_w,1))
    
# #     cv2.imshow('col mask', col_mask)
# #     cv2.imshow('row mask', row_mask)
# #     cv2.imshow('mask', trans_mask)
# #     cv2.imshow('mask blur', trans_mask_blur)
# #     cv2.waitKey(0)
# #     cv2.destroyAllWindows()
#     return trans_mask_blur

In [None]:
# # trans_dir = './result0728_exp3_real_haze/full_pass3x3/trans/'
# trans_dir = './result0728_exp3_real_haze/full_pass3x3/atmos/'
# trans_name = []
# for root, dirs, files in os.walk(trans_dir):
#     for file in files:
# #         print(file)
#         trans_name.append(join(trans_dir,file))

# print(len(trans_name))
# print(trans_name[:3])

# # self.data_files[key].sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
# # trans_name.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
# # print(trans_name[:3])

# # trans_origin_dir = './result0728_exp3_real_haze/full_pass1x1/trans/'
# trans_origin_dir = './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur/'
# # trans_origin_dir = './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur/'
# trans_origin_name = []
# for root, dirs, files in os.walk(trans_origin_dir):
#     for file in files:
# #         print(file)
#         trans_origin_name.append(join(trans_origin_dir,file))

# print(len(trans_origin_name))
# print(trans_origin_name[:3])

In [None]:
# # img_ratio=160 # 5x5 blocks
# img_ratio = 32*3 # 3x3 blocks
# # num_block=5
# block_num=3 # 3x3 blocks
# dst_dir = './result0728_exp3_real_haze/full_pass3x3/'

# for k in range(len(trans_name)):
#     # (b,g,r) order with value [0,1]
#     trans = cv2.imread(trans_name[k])/255.
#     trans_origin = cv2.imread(trans_origin_name[k])/255.
#     h,w,c = trans.shape
#     new_h, new_w = img_ratio*math.ceil(h/img_ratio), img_ratio*math.ceil(w/img_ratio)
#     h_block, w_block = (new_h//block_num), (new_w//block_num)
    
#     h_pad1, h_pad2 = 0,0
#     w_pad1, w_pad2= 0,0
#     if new_h != h:
#         pad= new_h-h
#         h_pad1, h_pad2 = pad//2, (pad-pad//2)
#     if new_w != w:
#         pad= new_w-w
#         w_pad1, w_pad2 = pad//2, (pad-pad//2)
    
#     info = (h,w,h_pad1,w_pad1)
#     block_info = (h_block, w_block)
#     print("info:", info)
#     print("block info:", block_info)
#     trans_pad = np.pad(trans, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
#     trans_origin_pad = np.pad(trans_origin, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
#     print("trans_pad:", trans_pad.shape)
    
    
#     new_trans = trans_pad.copy()
#     trans_mask = np.zeros((new_h,new_w,1))
#     print("trans mask:", trans_mask.shape)
            

#     trans_mask_blur = CreateMask_atmos(h_block, w_block, num_block=3, max_value=1.0)
#     new_result = trans_mask_blur*trans_pad + (1-trans_mask_blur)*trans_origin_pad
    
# #     cv2.imshow('mask blur', trans_mask_blur)
# #     cv2.imshow('new_result', new_result)
# #     cv2.waitKey(0)
# #     cv2.destroyAllWindows()
#     new_result = (255*new_result[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
#     mask = (255*trans_mask_blur[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
#     reverse_mask = 255-mask
#     cv2.imwrite(join(dst_dir,'atmos_blur_final',img_list[k]+'_atmos_blur.png'),new_result)
#     cv2.imwrite(join(dst_dir,'atmos_blur_final_mask',img_list[k]+'_mask.png'),mask)
#     cv2.imwrite(join(dst_dir,'atmos_blur_final_reverse_mask',img_list[k]+'_reverse_mask.png'),reverse_mask)
    
# #     break

# 3x3 + 1x1 trans blending

In [28]:
def CreateMask(h_block,w_block,num_block=3, max_value=0.9, overlap=False):

    new_h, new_w = h_block*num_block, w_block*num_block
    repeats = num_block-1
    height, width = h_block//2, w_block//2
    if overlap:
        new_h, new_w = new_h+h_block, new_w + w_block
        repeats = num_block
        
    trans_mask = np.zeros((new_h,new_w,1))
    
#     max_mask = np.linspace(max_value,max_value, num=16)
    min_mask = np.linspace(0.0,0.0, num=16)
    
   # 3x3 mask
    col_mask = np.concatenate((np.linspace(max_value, 0.,num=w_block//2),np.linspace(0., max_value,num=w_block//2)))
    col_mask = np.tile(col_mask,reps=repeats)
    col_mask = np.concatenate((np.linspace(max_value,max_value, num=width),col_mask))
    col_mask = np.concatenate((col_mask,np.linspace(max_value,max_value, num=width)))
    print("col_mask:", col_mask.shape)
    
    row_mask = np.concatenate((np.linspace(max_value, 0.,num=h_block//2),np.linspace(0., max_value,num=h_block//2)))
    row_mask = np.tile(row_mask,reps=repeats)
    row_mask = np.concatenate((np.linspace(max_value,max_value, num=height),row_mask))
    row_mask = np.concatenate((row_mask,np.linspace(max_value,max_value, num=height)))
    print("row mask:", row_mask.shape)
                              
    for i in range(trans_mask.shape[0]):
        for j in range(trans_mask.shape[1]):
            trans_mask[i,j] = min(row_mask[i], col_mask[j])
    
    col_mask = np.repeat(col_mask.reshape((-1,new_w,1)), repeats=new_h, axis=0)
    row_mask = np.repeat(row_mask.reshape((new_h,-1,1)), repeats=new_w, axis=1)
    trans_mask_blur = cv2.GaussianBlur(trans_mask,(51,51),25).reshape((new_h,new_w,1))
    
#     cv2.imshow('col mask', col_mask)
#     cv2.imshow('row mask', row_mask)
#     cv2.imshow('mask', trans_mask)
#     cv2.imshow('mask blur', trans_mask_blur)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()
    return trans_mask_blur

In [29]:
trans_dir = './result0728_exp3_real_haze/full_pass3x3/trans/'
trans_name = []
for root, dirs, files in os.walk(trans_dir):
    for file in files:
#         print(file)
        trans_name.append(join(trans_dir,file))

print(len(trans_name))
print(trans_name[:3])

# self.data_files[key].sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
# trans_name.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))
# print(trans_name[:3])

trans_origin_dir = './result0728_exp3_real_haze/full_pass1x1/trans/'
trans_origin_name = []
for root, dirs, files in os.walk(trans_origin_dir):
    for file in files:
#         print(file)
        trans_origin_name.append(join(trans_origin_dir,file))

print(len(trans_origin_name))
print(trans_origin_name[:3])

17
['./result0728_exp3_real_haze/full_pass3x3/trans/aerial_trans.png', './result0728_exp3_real_haze/full_pass3x3/trans/castle_trans.png', './result0728_exp3_real_haze/full_pass3x3/trans/cityscape_trans.png']
17
['./result0728_exp3_real_haze/full_pass1x1/trans/aerial_trans.png', './result0728_exp3_real_haze/full_pass1x1/trans/castle_trans.png', './result0728_exp3_real_haze/full_pass1x1/trans/cityscape_trans.png']


In [30]:
# img_ratio=160 # 5x5 blocks
img_ratio = 32*3 # 3x3 blocks
# num_block=5
block_num=3 # 3x3 blocks
dst_dir = './result0728_exp3_real_haze/full_pass3x3/'

for k in range(len(trans_name)):
    # (b,g,r) order with value [0,1]
    trans = cv2.imread(trans_name[k])/255.
    trans_origin = cv2.imread(trans_origin_name[k])/255.
    h,w,c = trans.shape
    new_h, new_w = img_ratio*math.ceil(h/img_ratio), img_ratio*math.ceil(w/img_ratio)
    h_block, w_block = (new_h//block_num), (new_w//block_num)
    
    h_pad1, h_pad2 = 0,0
    w_pad1, w_pad2= 0,0
    if new_h != h:
        pad= new_h-h
        h_pad1, h_pad2 = pad//2, (pad-pad//2)
    if new_w != w:
        pad= new_w-w
        w_pad1, w_pad2 = pad//2, (pad-pad//2)
    
    info = (h,w,h_pad1,w_pad1)
    block_info = (h_block, w_block)
    print("info:", info)
    print("block info:", block_info)
    trans_pad = np.pad(trans, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
    trans_origin_pad = np.pad(trans_origin, ((h_pad1, h_pad2),(w_pad1, w_pad2),(0,0)), mode='symmetric')
    print("trans_pad:", trans_pad.shape)
    
    
    new_trans = trans_pad.copy()
    trans_mask = np.zeros((new_h,new_w,1))
    print("trans mask:", trans_mask.shape)
            

    trans_mask_blur = CreateMask(h_block, w_block, num_block=3, max_value=0.7)
    new_result = trans_mask_blur*trans_pad + (1-trans_mask_blur)*trans_origin_pad
    
#     cv2.imshow('mask blur', trans_mask_blur)
#     cv2.imshow('new_result', new_result)
#     cv2.waitKey(0)
#     cv2.destroyAllWindows()
    new_result = (255*new_result[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
    mask = (255*trans_mask_blur[h_pad1:h_pad1+h,w_pad1:w_pad1+w,:]).astype(np.uint8)
    reverse_mask = 255-mask
    cv2.imwrite(join(dst_dir,'trans_blur_final',img_list[k]+'_trans_blur.png'),new_result)
    cv2.imwrite(join(dst_dir,'trans_blur_final_mask',img_list[k]+'_mask.png'),mask)
    cv2.imwrite(join(dst_dir,'trans_blur_final_reverse_mask',img_list[k]+'_reverse_mask.png'),reverse_mask)
    
#     break

info: (442, 622, 19, 25)
block info: (160, 224)
trans_pad: (480, 672, 3)
trans mask: (480, 672, 1)
col_mask: (672,)
row mask: (480,)
info: (611, 619, 30, 26)
block info: (224, 224)
trans_pad: (672, 672, 3)
trans mask: (672, 672, 1)
col_mask: (672,)
row mask: (672,)
info: (600, 400, 36, 40)
block info: (224, 160)
trans_pad: (672, 480, 3)
trans mask: (672, 480, 1)
col_mask: (480,)
row mask: (672,)
info: (384, 576, 0, 0)
block info: (128, 192)
trans_pad: (384, 576, 3)
trans mask: (384, 576, 1)
col_mask: (576,)
row mask: (384,)
info: (768, 1024, 0, 16)
block info: (256, 352)
trans_pad: (768, 1056, 3)
trans mask: (768, 1056, 1)
col_mask: (1056,)
row mask: (768,)
info: (576, 768, 0, 0)
block info: (192, 256)
trans_pad: (576, 768, 3)
trans mask: (576, 768, 1)
col_mask: (768,)
row mask: (576,)
info: (768, 576, 0, 0)
block info: (256, 192)
trans_pad: (768, 576, 3)
trans mask: (768, 576, 1)
col_mask: (576,)
row mask: (768,)
info: (763, 1180, 2, 34)
block info: (256, 416)
trans_pad: (768, 1248, 3

# global trans + 3x3 local atmos blur

In [31]:
# original testset + overlap
class CreateTestDataSet_Dehaze(Dataset):
    def __init__(self, test_dir, block_num = 3,transform=None, overlap=False):
        self.dir= test_dir
        self.transform=transform
        self.data_files = {'haze':[]}
        self.img_ratio= lcm(32,block_num)
        self.block_num = block_num
        self.overlap=overlap
#         self.data_files = {'haze':[],'GT':[]}

        for key in self.data_files.keys():
            subdir = join(self.dir, key)
            self.data_files[key] += [join(subdir,x) for x in listdir(subdir) if is_image_file(x) ]
            # self.data_files[key].sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    def __getitem__(self, index):
        haze_name = self.data_files['haze'][index]
        haze = Image.open(haze_name).convert('RGB')

        if self.transform:
            # apply transform to each sample in data_files
            haze = self.transform(haze)
            
        c,h,w = haze.shape
        new_h, new_w = self.img_ratio*math.ceil(h/self.img_ratio), self.img_ratio*math.ceil(w/self.img_ratio)
        h_block,w_block = (new_h//self.block_num), (new_w//self.block_num)
        if self.overlap:
            new_h, new_w = new_h+h_block, new_w+w_block
        
        h_pad1, h_pad2 = 0,0
        w_pad1, w_pad2= 0,0
        if new_h != h:
            pad= new_h-h
            h_pad1, h_pad2 = pad//2, (pad-pad//2)
        if new_w != w:
            pad= new_w-w
            w_pad1, w_pad2 = pad//2, (pad-pad//2)
            
        haze_pad = F.pad(haze.unsqueeze(0), (w_pad1, w_pad2, h_pad1, h_pad2), mode='reflect').squeeze(0)
#         all_blocks = view_as_blocks(haze_pad.numpy(), block_shape=(c,h_block,w_block)).reshape(-1,c,h_block,w_block)
#         print("all blocks size", all_blocks.shape)
        # all_blocks: (total_block_num, channel, h, w)
#         all_blocks = torch.from_numpy(all_blocks)
        
        info = (h,w,h_pad1,w_pad1)
        block_info = (h_block, w_block)      
        sample={'haze': haze_pad, 'info': info, 'block_info': block_info}       
        return sample
    def __len__(self):
        return len(self.data_files['haze'])

In [32]:
def getTestLoader_Dehaze(test_dir, blocks=3, overlap=False):
    # val_test_transform = transforms.Resize((image_size,image_size))
    test_transform = transforms.Compose([
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
                                             # Scale((image_size, image_size), use_trans_atmos=trans_atmos),
#                                              ToTensor(use_trans_atmos = trans_atmos),
#                                              Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],use_trans_atmos = trans_atmos)
                                        ])
    test_set = CreateTestDataSet_Dehaze(test_dir, blocks,test_transform, overlap)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)
    return test_loader, test_set.__len__()

In [33]:
class Model_dehaze(nn.Module):
    def __init__(self):
        super(Model_dehaze, self).__init__()
        self.generate_dehaze = refinement_final()
        self.unnormalize_fun = UnNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

    def forward(self, x, trans_map, atmos_light):
        # Unnormalize haze images
        hazes_unnormalized = self.unnormalize_fun(x)
        # Reconstruct clean images
        nonhaze_rec = (hazes_unnormalized-atmos_light*(1-trans_map))
        # nonhaze_rec = torch.clamp(nonhaze_rec, 0.0, 1.0)
        nonhaze_rec = nonhaze_rec/trans_map
        nonhaze_rec = torch.clamp(nonhaze_rec, 0.0, 1.0)
        # Refinement Module
        dehaze = self.generate_dehaze(nonhaze_rec, hazes_unnormalized, trans_map, atmos_light)
        
        return nonhaze_rec, dehaze

In [34]:
# read tans and atmos images
trans_dir = './result0728_exp3_real_haze/full_pass1x1/trans/'
trans_name = []
for root, dirs, files in os.walk(trans_dir):
    for file in files:
#         print(file)
        trans_name.append(join(trans_dir,file))

print(len(trans_name))
print(trans_name[:3])

atmos_dir = './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/'
atmos_name = []
for root, dirs, files in os.walk(atmos_dir):
    for file in files:
#         print(file)
        atmos_name.append(join(atmos_dir,file))

print(len(atmos_name))
print(atmos_name[:3])

17
['./result0728_exp3_real_haze/full_pass1x1/trans/aerial_trans.png', './result0728_exp3_real_haze/full_pass1x1/trans/castle_trans.png', './result0728_exp3_real_haze/full_pass1x1/trans/cityscape_trans.png']
17
['./result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/aerial_atmos_blur.png', './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/castle_atmos_blur.png', './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/cityscape_atmos_blur.png']


In [35]:
test_loader, test_num = getTestLoader_Dehaze( './real_haze',1)
print("Testing Images Number:", test_num)

Testing Images Number: 17


In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = "cpu"
print("device:", device)
checkpoint = torch.load('./exp3_epoch92_separate.pth')
model_dehaze = Model_dehaze()
model_dehaze = model_dehaze.to(device)
model_dehaze.generate_dehaze.load_state_dict(checkpoint['dehaze'])
model_dehaze.eval()

device: cuda


Model_dehaze(
  (generate_dehaze): refinement_final(
    (conv1): Sequential(
      (0): Conv2d(6, 58, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): PReLU(num_parameters=1)
    )
    (bn1): BatchNorm2d(64, eps=1.1e-05, momentum=0.01, affine=True, track_running_stats=True)
    (dense_block1): BottleneckBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(192, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (conv2): Sequential(
      (0): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): PReLU(num_parameters=1)
    )
    (residual_block11): ResidualBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (prelu)

In [37]:
img_dir= join('./result0728_exp3_real_haze','global_trans_local_atmos_blur')
nonhaze_rec_list = []
dehaze_list = []
trans_map_list=[]
atmos_map_list=[]

with torch.no_grad():

        for i, test_batched in enumerate(test_loader,0):
            print("Image:",img_list[i])
            # haze: (b,c,h,w) with (r,g,b) order
            haze = test_batched['haze']
            h_origin,w_origin,h_pad,w_pad = test_batched['info']
            h_origin,w_origin,h_pad,w_pad = h_origin.item(),w_origin.item(),h_pad.item(),w_pad.item()
#             print("reflected haze shape:", haze.shape)
#             print("haze info:",h_origin,w_origin,h_pad,w_pad)
            haze = haze[:,:,h_pad:h_pad+h_origin,w_pad:w_pad+w_origin]
            print("original haze shape:", haze.shape)
            # trans_map: (b,c,h,w) with (r,g,b) order
            trans_map = cv2.imread(trans_name[i])[:,:,::-1]
            trans_map = torch.from_numpy(trans_map/255.).float().unsqueeze(0)
            trans_map = trans_map.permute(0, 3, 1, 2)
            print("trans_map shape:", trans_map.shape)
#             trans_map = trans_map.to(device)
                      
            atmos_light = cv2.imread(atmos_name[i])[:,:,::-1]
            atmos_light = torch.from_numpy(atmos_light/255.).float().unsqueeze(0)
            atmos_light = atmos_light.permute(0, 3, 1, 2)
            print("atmos_light shape:", atmos_light.shape)
#             atmos_light = atmos_light.to(device)
            
#             print("haze datatype:", haze.dtype)
#             print("trans_map datatype:", trans_map.dtype)
#             print("atmos_light datatype:", atmos_light.dtype)
            
            # Predict
            nonhaze_rec, dehaze = model_dehaze(haze.to(device), trans_map.to(device), atmos_light.to(device))

            for j in range(dehaze.size(0)):
                # (C, H, W)
                # nonhaze_rec_img = nonhaze_rec[j].detach().cpu().clone().numpy()
                nonhaze_rec_img = nonhaze_rec[j].detach().cpu().clone().numpy()
                dehaze_rec_img = dehaze[j].detach().cpu().clone().numpy()
                trans_map_img = trans_map[j].detach().cpu().clone().numpy()
                atmos_map_img = atmos_light[j].detach().cpu().clone().numpy()

                # (H, W, C)
                nonhaze_rec_img = (255*np.transpose(nonhaze_rec_img, (1,2,0))).astype(np.uint8)
                dehaze_rec_img = (255*np.transpose(dehaze_rec_img, (1,2,0))).astype(np.uint8)
                trans_map_img = (255*np.transpose(trans_map_img, (1,2,0))).astype(np.uint8)
                atmos_map_img = (255*np.transpose(atmos_map_img, (1,2,0))).astype(np.uint8)
                
                # add for saving image results
                nonhaze_rec_list.append(nonhaze_rec_img)
                dehaze_list.append(dehaze_rec_img)
                trans_map_list.append(trans_map_img)
                atmos_map_list.append(atmos_map_img)
                
#                 cv2.imshow('Nonhaze Rec', nonhaze_rec_img[:,:,::-1])
#                 cv2.imshow('Dehaze Result', dehaze_rec_img[:,:,::-1])
#                 cv2.waitKey(0)
#                 cv2.destroyAllWindows()
            
            atmos_intensity_list=[]
            for idx in range(len(atmos_map_list)):
                atmos_intensity_list.append(cv2.cvtColor(atmos_map_list[idx], cv2.COLOR_RGB2GRAY))

            # because opencv is BGR order, we need to change RGB to BGR
            for i in range(len(nonhaze_rec_list)):
                cv2.imwrite(join(img_dir,'nonhaze_rec',img_list[i]+'_nonhaze_rec.png'),nonhaze_rec_list[i][:,:,::-1])
            for i in range(len(dehaze_list)):
                cv2.imwrite(join(img_dir,'dehaze',img_list[i]+'_dehaze.png'),dehaze_list[i][:,:,::-1])
            for i in range(len(trans_map_list)):
                cv2.imwrite(join(img_dir,'trans',img_list[i]+'_trans.png'),trans_map_list[i][:,:,::-1])
            for i in range(len(atmos_map_list)):
                cv2.imwrite(join(img_dir,'atmos',img_list[i]+'_atmos.png'),atmos_map_list[i][:,:,::-1])
            for i in range(len(atmos_intensity_list)):
                cv2.imwrite(join(img_dir,'atmos_intensity',img_list[i] +'_atmos_intensity.png'),atmos_intensity_list[i])
#             for i in range(len(gt_list)):
#                 cv2.imwrite(join(img_dir,'gt','gt'+str(i+1)+'.png'),gt_list[i][:,:,::-1])
            
#             break

Image: aerial
original haze shape: torch.Size([1, 3, 442, 622])
trans_map shape: torch.Size([1, 3, 442, 622])
atmos_light shape: torch.Size([1, 3, 442, 622])
Image: castle
original haze shape: torch.Size([1, 3, 611, 619])
trans_map shape: torch.Size([1, 3, 611, 619])
atmos_light shape: torch.Size([1, 3, 611, 619])
Image: cityscape
original haze shape: torch.Size([1, 3, 600, 400])
trans_map shape: torch.Size([1, 3, 600, 400])
atmos_light shape: torch.Size([1, 3, 600, 400])
Image: cliff
original haze shape: torch.Size([1, 3, 384, 576])
trans_map shape: torch.Size([1, 3, 384, 576])
atmos_light shape: torch.Size([1, 3, 384, 576])
Image: forest
original haze shape: torch.Size([1, 3, 768, 1024])
trans_map shape: torch.Size([1, 3, 768, 1024])
atmos_light shape: torch.Size([1, 3, 768, 1024])
Image: highquality13
original haze shape: torch.Size([1, 3, 576, 768])
trans_map shape: torch.Size([1, 3, 576, 768])
atmos_light shape: torch.Size([1, 3, 576, 768])
Image: img33
original haze shape: torch.

# local trans + global atmos

In [38]:
# read tans and atmos images
trans_dir = './result0728_exp3_real_haze/full_pass3x3/trans_blur_final/'
trans_name = []
for root, dirs, files in os.walk(trans_dir):
    for file in files:
#         print(file)
        trans_name.append(join(trans_dir,file))

print(len(trans_name))
print(trans_name[:3])

atmos_dir = './result0728_exp3_real_haze/full_pass1x1/atmos/'
atmos_name = []
for root, dirs, files in os.walk(atmos_dir):
    for file in files:
#         print(file)
        atmos_name.append(join(atmos_dir,file))

print(len(atmos_name))
print(atmos_name[:3])

17
['./result0728_exp3_real_haze/full_pass3x3/trans_blur_final/aerial_trans_blur.png', './result0728_exp3_real_haze/full_pass3x3/trans_blur_final/castle_trans_blur.png', './result0728_exp3_real_haze/full_pass3x3/trans_blur_final/cityscape_trans_blur.png']
17
['./result0728_exp3_real_haze/full_pass1x1/atmos/aerial_atmos.png', './result0728_exp3_real_haze/full_pass1x1/atmos/castle_atmos.png', './result0728_exp3_real_haze/full_pass1x1/atmos/cityscape_atmos.png']


In [39]:
img_dir= join('./result0728_exp3_real_haze','local_trans_global_atmos')
nonhaze_rec_list = []
dehaze_list = []
trans_map_list=[]
atmos_map_list=[]

with torch.no_grad():

        for i, test_batched in enumerate(test_loader,0):
            print("Image:",img_list[i])
            # haze: (b,c,h,w) with (r,g,b) order
            haze = test_batched['haze']
            h_origin,w_origin,h_pad,w_pad = test_batched['info']
            h_origin,w_origin,h_pad,w_pad = h_origin.item(),w_origin.item(),h_pad.item(),w_pad.item()
#             print("reflected haze shape:", haze.shape)
#             print("haze info:",h_origin,w_origin,h_pad,w_pad)
            haze = haze[:,:,h_pad:h_pad+h_origin,w_pad:w_pad+w_origin]
            print("original haze shape:", haze.shape)
            # trans_map: (b,c,h,w) with (r,g,b) order
            trans_map = cv2.imread(trans_name[i])[:,:,::-1]
            trans_map = torch.from_numpy(trans_map/255.).float().unsqueeze(0)
            trans_map = trans_map.permute(0, 3, 1, 2)
            print("trans_map shape:", trans_map.shape)
#             trans_map = trans_map.to(device)
                      
            atmos_light = cv2.imread(atmos_name[i])[:,:,::-1]
            atmos_light = torch.from_numpy(atmos_light/255.).float().unsqueeze(0)
            atmos_light = atmos_light.permute(0, 3, 1, 2)
            print("atmos_light shape:", atmos_light.shape)
#             atmos_light = atmos_light.to(device)
            
#             print("haze datatype:", haze.dtype)
#             print("trans_map datatype:", trans_map.dtype)
#             print("atmos_light datatype:", atmos_light.dtype)
            
            # Predict
            nonhaze_rec, dehaze = model_dehaze(haze.to(device), trans_map.to(device), atmos_light.to(device))

            for j in range(dehaze.size(0)):
                # (C, H, W)
                # nonhaze_rec_img = nonhaze_rec[j].detach().cpu().clone().numpy()
                nonhaze_rec_img = nonhaze_rec[j].detach().cpu().clone().numpy()
                dehaze_rec_img = dehaze[j].detach().cpu().clone().numpy()
                trans_map_img = trans_map[j].detach().cpu().clone().numpy()
                atmos_map_img = atmos_light[j].detach().cpu().clone().numpy()

                # (H, W, C)
                nonhaze_rec_img = (255*np.transpose(nonhaze_rec_img, (1,2,0))).astype(np.uint8)
                dehaze_rec_img = (255*np.transpose(dehaze_rec_img, (1,2,0))).astype(np.uint8)
                trans_map_img = (255*np.transpose(trans_map_img, (1,2,0))).astype(np.uint8)
                atmos_map_img = (255*np.transpose(atmos_map_img, (1,2,0))).astype(np.uint8)
                
                # add for saving image results
                nonhaze_rec_list.append(nonhaze_rec_img)
                dehaze_list.append(dehaze_rec_img)
                trans_map_list.append(trans_map_img)
                atmos_map_list.append(atmos_map_img)
                
#                 cv2.imshow('Nonhaze Rec', nonhaze_rec_img[:,:,::-1])
#                 cv2.imshow('Dehaze Result', dehaze_rec_img[:,:,::-1])
#                 cv2.waitKey(0)
#                 cv2.destroyAllWindows()
#                 break
            
            atmos_intensity_list=[]
            for idx in range(len(atmos_map_list)):
                atmos_intensity_list.append(cv2.cvtColor(atmos_map_list[idx], cv2.COLOR_RGB2GRAY))

            # because opencv is BGR order, we need to change RGB to BGR
            for i in range(len(nonhaze_rec_list)):
                cv2.imwrite(join(img_dir,'nonhaze_rec',img_list[i]+'_nonhaze_rec.png'),nonhaze_rec_list[i][:,:,::-1])
            for i in range(len(dehaze_list)):
                cv2.imwrite(join(img_dir,'dehaze',img_list[i]+'_dehaze.png'),dehaze_list[i][:,:,::-1])
            for i in range(len(trans_map_list)):
                cv2.imwrite(join(img_dir,'trans',img_list[i]+'_trans.png'),trans_map_list[i][:,:,::-1])
            for i in range(len(atmos_map_list)):
                cv2.imwrite(join(img_dir,'atmos',img_list[i]+'_atmos.png'),atmos_map_list[i][:,:,::-1])
            for i in range(len(atmos_intensity_list)):
                cv2.imwrite(join(img_dir,'atmos_intensity',img_list[i] +'_atmos_intensity.png'),atmos_intensity_list[i])
#             for i in range(len(gt_list)):
#                 cv2.imwrite(join(img_dir,'gt','gt'+str(i+1)+'.png'),gt_list[i][:,:,::-1])
            
#             break

Image: aerial
original haze shape: torch.Size([1, 3, 442, 622])
trans_map shape: torch.Size([1, 3, 442, 622])
atmos_light shape: torch.Size([1, 3, 442, 622])
Image: castle
original haze shape: torch.Size([1, 3, 611, 619])
trans_map shape: torch.Size([1, 3, 611, 619])
atmos_light shape: torch.Size([1, 3, 611, 619])
Image: cityscape
original haze shape: torch.Size([1, 3, 600, 400])
trans_map shape: torch.Size([1, 3, 600, 400])
atmos_light shape: torch.Size([1, 3, 600, 400])
Image: cliff
original haze shape: torch.Size([1, 3, 384, 576])
trans_map shape: torch.Size([1, 3, 384, 576])
atmos_light shape: torch.Size([1, 3, 384, 576])
Image: forest
original haze shape: torch.Size([1, 3, 768, 1024])
trans_map shape: torch.Size([1, 3, 768, 1024])
atmos_light shape: torch.Size([1, 3, 768, 1024])
Image: highquality13
original haze shape: torch.Size([1, 3, 576, 768])
trans_map shape: torch.Size([1, 3, 576, 768])
atmos_light shape: torch.Size([1, 3, 576, 768])
Image: img33
original haze shape: torch.

# local trans + local atmos

In [40]:
# read tans and atmos images
trans_dir = './result0728_exp3_real_haze/full_pass3x3/trans_blur_final/'
trans_name = []
for root, dirs, files in os.walk(trans_dir):
    for file in files:
#         print(file)
        trans_name.append(join(trans_dir,file))

print(len(trans_name))
print(trans_name[:3])

atmos_dir = './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/'
atmos_name = []
for root, dirs, files in os.walk(atmos_dir):
    for file in files:
#         print(file)
        atmos_name.append(join(atmos_dir,file))

print(len(atmos_name))
print(atmos_name[:3])

17
['./result0728_exp3_real_haze/full_pass3x3/trans_blur_final/aerial_trans_blur.png', './result0728_exp3_real_haze/full_pass3x3/trans_blur_final/castle_trans_blur.png', './result0728_exp3_real_haze/full_pass3x3/trans_blur_final/cityscape_trans_blur.png']
17
['./result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/aerial_atmos_blur.png', './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/castle_atmos_blur.png', './result0728_exp3_real_haze/full_pass3x3/atmos_linear_blur_gaussian/cityscape_atmos_blur.png']


In [41]:
img_dir= join('./result0728_exp3_real_haze','local_trans_local_atmos')
nonhaze_rec_list = []
dehaze_list = []
trans_map_list=[]
atmos_map_list=[]

with torch.no_grad():

        for i, test_batched in enumerate(test_loader,0):
            print("Image:",img_list[i])
            # haze: (b,c,h,w) with (r,g,b) order
            haze = test_batched['haze']
            h_origin,w_origin,h_pad,w_pad = test_batched['info']
            h_origin,w_origin,h_pad,w_pad = h_origin.item(),w_origin.item(),h_pad.item(),w_pad.item()
#             print("reflected haze shape:", haze.shape)
#             print("haze info:",h_origin,w_origin,h_pad,w_pad)
            haze = haze[:,:,h_pad:h_pad+h_origin,w_pad:w_pad+w_origin]
            print("original haze shape:", haze.shape)
            # trans_map: (b,c,h,w) with (r,g,b) order
            trans_map = cv2.imread(trans_name[i])[:,:,::-1]
            trans_map = torch.from_numpy(trans_map/255.).float().unsqueeze(0)
            trans_map = trans_map.permute(0, 3, 1, 2)
            print("trans_map shape:", trans_map.shape)
#             trans_map = trans_map.to(device)
                      
            atmos_light = cv2.imread(atmos_name[i])[:,:,::-1]
            atmos_light = torch.from_numpy(atmos_light/255.).float().unsqueeze(0)
            atmos_light = atmos_light.permute(0, 3, 1, 2)
            print("atmos_light shape:", atmos_light.shape)
#             atmos_light = atmos_light.to(device)
            
#             print("haze datatype:", haze.dtype)
#             print("trans_map datatype:", trans_map.dtype)
#             print("atmos_light datatype:", atmos_light.dtype)
            
            # Predict
            nonhaze_rec, dehaze = model_dehaze(haze.to(device), trans_map.to(device), atmos_light.to(device))

            for j in range(dehaze.size(0)):
                # (C, H, W)
                # nonhaze_rec_img = nonhaze_rec[j].detach().cpu().clone().numpy()
                nonhaze_rec_img = nonhaze_rec[j].detach().cpu().clone().numpy()
                dehaze_rec_img = dehaze[j].detach().cpu().clone().numpy()
                trans_map_img = trans_map[j].detach().cpu().clone().numpy()
                atmos_map_img = atmos_light[j].detach().cpu().clone().numpy()

                # (H, W, C)
                nonhaze_rec_img = (255*np.transpose(nonhaze_rec_img, (1,2,0))).astype(np.uint8)
                dehaze_rec_img = (255*np.transpose(dehaze_rec_img, (1,2,0))).astype(np.uint8)
                trans_map_img = (255*np.transpose(trans_map_img, (1,2,0))).astype(np.uint8)
                atmos_map_img = (255*np.transpose(atmos_map_img, (1,2,0))).astype(np.uint8)
                
                # add for saving image results
                nonhaze_rec_list.append(nonhaze_rec_img)
                dehaze_list.append(dehaze_rec_img)
                trans_map_list.append(trans_map_img)
                atmos_map_list.append(atmos_map_img)
                
#                 cv2.imshow('Nonhaze Rec', nonhaze_rec_img[:,:,::-1])
#                 cv2.imshow('Dehaze Result', dehaze_rec_img[:,:,::-1])
#                 cv2.waitKey(0)
#                 cv2.destroyAllWindows()
#                 break
            
            atmos_intensity_list=[]
            for idx in range(len(atmos_map_list)):
                atmos_intensity_list.append(cv2.cvtColor(atmos_map_list[idx], cv2.COLOR_RGB2GRAY))

            # because opencv is BGR order, we need to change RGB to BGR
            for i in range(len(nonhaze_rec_list)):
                cv2.imwrite(join(img_dir,'nonhaze_rec',img_list[i]+'_nonhaze_rec.png'),nonhaze_rec_list[i][:,:,::-1])
            for i in range(len(dehaze_list)):
                cv2.imwrite(join(img_dir,'dehaze',img_list[i]+'_dehaze.png'),dehaze_list[i][:,:,::-1])
            for i in range(len(trans_map_list)):
                cv2.imwrite(join(img_dir,'trans',img_list[i]+'_trans.png'),trans_map_list[i][:,:,::-1])
            for i in range(len(atmos_map_list)):
                cv2.imwrite(join(img_dir,'atmos',img_list[i]+'_atmos.png'),atmos_map_list[i][:,:,::-1])
            for i in range(len(atmos_intensity_list)):
                cv2.imwrite(join(img_dir,'atmos_intensity',img_list[i] +'_atmos_intensity.png'),atmos_intensity_list[i])
#             for i in range(len(gt_list)):
#                 cv2.imwrite(join(img_dir,'gt','gt'+str(i+1)+'.png'),gt_list[i][:,:,::-1])
            
#             break

Image: aerial
original haze shape: torch.Size([1, 3, 442, 622])
trans_map shape: torch.Size([1, 3, 442, 622])
atmos_light shape: torch.Size([1, 3, 442, 622])
Image: castle
original haze shape: torch.Size([1, 3, 611, 619])
trans_map shape: torch.Size([1, 3, 611, 619])
atmos_light shape: torch.Size([1, 3, 611, 619])
Image: cityscape
original haze shape: torch.Size([1, 3, 600, 400])
trans_map shape: torch.Size([1, 3, 600, 400])
atmos_light shape: torch.Size([1, 3, 600, 400])
Image: cliff
original haze shape: torch.Size([1, 3, 384, 576])
trans_map shape: torch.Size([1, 3, 384, 576])
atmos_light shape: torch.Size([1, 3, 384, 576])
Image: forest
original haze shape: torch.Size([1, 3, 768, 1024])
trans_map shape: torch.Size([1, 3, 768, 1024])
atmos_light shape: torch.Size([1, 3, 768, 1024])
Image: highquality13
original haze shape: torch.Size([1, 3, 576, 768])
trans_map shape: torch.Size([1, 3, 576, 768])
atmos_light shape: torch.Size([1, 3, 576, 768])
Image: img33
original haze shape: torch.