In [1]:
import gc,os,cv2
from glob import glob
import numpy as np
import pandas as pd 
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import pywt
import shutil,time

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as tt 
from joblib import Parallel,delayed

In [2]:
# Preprocessing for the hazy input image: resize to 256x256, convert to a
# float tensor in [0, 1], then normalize each channel.
# NOTE(review): the mean/std below presumably are statistics of the hazy
# training images — confirm their provenance.
input_transforms_rgb=tt.Compose(
    [
        tt.Resize(size=(256,256),antialias=True),
        tt.ToTensor(),
        tt.Normalize((0.6344,0.5955,0.5857),(0.1742,0.1798,0.1871)),
    ]
)




## Model Definitions and Initialization

In [3]:
class PixelAttention(nn.Module):
    """Spatial (per-pixel) attention gate.

    A two-layer 1x1-conv bottleneck maps the ``channel``-dim feature map to a
    single-channel sigmoid mask (N, 1, H, W) that is broadcast-multiplied with
    the input feature map.
    """
    def __init__(self,channel,reduct_ratio=8):
        super().__init__()
        reduced_channel=max(1,channel//reduct_ratio)
        # NOTE(review): the hidden width is channel//reduced_channel, not
        # reduced_channel. It is kept as-is because the checkpoints loaded in
        # this notebook were saved with these exact layer shapes.
        hidden=channel//reduced_channel
        layers=[
            nn.Conv2d(channel,hidden,kernel_size=1,padding=0,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden,1,kernel_size=1,padding=0,bias=True),
            nn.Sigmoid(),
        ]
        # Attribute name and layer order define the state_dict keys.
        self.pixel_attention=nn.Sequential(*layers)

    def forward(self,feature):
        mask=self.pixel_attention(feature)
        return feature*mask

class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Global average pooling -> Linear(C, C//reduct_ratio) -> ReLU ->
    Linear(back to C) -> sigmoid gives one weight per channel, which rescales
    the input feature map.
    """
    def __init__(self,input_channels,reduct_ratio=8):
        super().__init__()
        reduced_channel=max(1,input_channels//reduct_ratio)
        self.avg_pooler=nn.AdaptiveAvgPool2d(1)
        # Attribute names / layer order define the state_dict keys.
        self.fcn=nn.Sequential(
            nn.Linear(input_channels,reduced_channel),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channel,input_channels),
        )

    def forward(self,input_feature):
        batch,channels,_,_=input_feature.shape
        pooled=self.avg_pooler(input_feature).view(batch,channels)
        weights=torch.sigmoid(self.fcn(pooled)).view(batch,channels,1,1)
        return input_feature*weights

class AttentionBlock(nn.Module):
    """Residual conv block followed by channel and pixel attention.

    forward(img): relu(conv1(img)) + img -> relu(conv1(.)) -> ChannelAttention
    -> PixelAttention -> + img. The output has the same shape as ``img``.
    """
    def __init__(self,dims: int,kernel_size: int=1):
        super(AttentionBlock,self).__init__()
        # Both convs keep the channel count; with stride 1 and padding
        # kernel_size//2 they also keep H, W for odd kernel sizes.
        self.conv1=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        # NOTE(review): conv2 is never used in forward() (conv1 is applied twice).
        # Its parameters are still in state_dict and therefore in the saved
        # checkpoints, but presumably remain at their initial values if the model
        # was trained with this forward(). Switching forward() to conv2 would
        # therefore run untrained weights — only change it together with retraining.
        self.conv2=nn.Conv2d(dims,dims,kernel_size,padding=(kernel_size//2),bias=True)
        self.ca=ChannelAttention(dims)
        self.pa=PixelAttention(dims)
    def forward(self,img):
        feat=F.relu(self.conv1(img),inplace=True)
        feat=feat+img  # first residual connection
        # NOTE(review): conv1 again, not conv2 — see the note in __init__.
        feat=F.relu(self.conv1(feat),inplace=True)
        feat=self.ca(feat)
        feat=self.pa(feat)
        feat+=img  # in-place add on the fresh tensor returned by self.pa
        return feat
class DWT_DehazingNet(nn.Module):
    """Dehazing CNN with an extra discrete-wavelet-transform (DWT) input branch.

    The input ``x`` (assumed (N, 3, H, W)) is decomposed with a single-level
    2-D Daubechies-4 DWT (``pywt.dwt2``). The four sub-bands (LL, LH, HL, HH) of
    each channel are stacked (4 x 3 = 12 channels), resized back to x's spatial
    size and projected to 3 channels by ``conv_dwt``. That projection is
    concatenated into a densely connected conv trunk (conv1..conv5 with two
    AttentionBlocks), which predicts a map ``k``. The output is
    ``relu(k*x - k + b)`` with ``b = 1``.
    """
    def __init__(self):
        super(DWT_DehazingNet,self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(in_channels=3,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.attn1=AttentionBlock(3)
        self.conv3=nn.Conv2d(in_channels=9,out_channels=3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(in_channels=6,out_channels=3,kernel_size=7,stride=1,padding=3)
        self.attn2=AttentionBlock(3)
        self.conv5=nn.Conv2d(in_channels=15,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.conv_dwt=nn.Conv2d(in_channels=12,out_channels=3,kernel_size=3,stride=1,padding=1)
        self.b=1

    def forward(self,x):
        # pywt works on numpy arrays, so the DWT branch is computed on the CPU
        # outside autograd. detach() lets this also work when x requires grad.
        LL,(LH,HL,HH)=pywt.dwt2(x.detach().cpu().numpy(),wavelet='db4')
        dwt_out=torch.cat([torch.from_numpy(band) for band in (LL,LH,HL,HH)],dim=1)
        # Move the coefficients to x's device/dtype, because the convs may live
        # on a GPU. Resize them to x's spatial size instead of a fixed 256x256
        # so the concatenations below line up for any input resolution.
        dwt_out=dwt_out.to(device=x.device,dtype=x.dtype)
        dwt_out=tt.Resize(tuple(x.shape[-2:]))(dwt_out)
        dwt_in=self.conv_dwt(dwt_out)

        x1=F.relu(self.conv1(x))
        x2=F.relu(self.conv2(x1))
        x2=self.attn1(x2)
        cat1=torch.cat((x1,x2,dwt_in),1)
        x3=F.relu(self.conv3(cat1))
        cat2=torch.cat((x2,x3),1)
        x4=F.relu(self.conv4(cat2))
        x4=self.attn2(x4)
        cat3=torch.cat((x1,x2,x3,x4,dwt_in),1)
        k=F.relu(self.conv5(cat3))
        return F.relu(k*x-k+self.b)
class DehazingNet(nn.Module):
    """Dehazing CNN without the wavelet branch.

    Five convolutions (kernels 1/3/5/7/3) are densely concatenated, two of them
    followed by an AttentionBlock, to predict a 3-channel map ``k``. The output
    is ``relu(k*x - k + b)`` with ``b = 1``.
    NOTE(review): this layout/formula appears to follow AOD-Net.
    """
    def __init__(self):
        super().__init__()
        # Creation order and attribute names are kept: they define parameter
        # order and the state_dict keys used by the saved checkpoint.
        self.conv1=nn.Conv2d(3,3,kernel_size=1,stride=1,padding=0)
        self.conv2=nn.Conv2d(3,3,kernel_size=3,stride=1,padding=1)
        self.attn1=AttentionBlock(3)
        self.conv3=nn.Conv2d(6,3,kernel_size=5,stride=1,padding=2)
        self.conv4=nn.Conv2d(6,3,kernel_size=7,stride=1,padding=3)
        self.attn2=AttentionBlock(3)
        self.conv5=nn.Conv2d(12,3,kernel_size=3,stride=1,padding=1)
        # conv_dwt is not used by forward(). It is presumably kept so the
        # state_dict still matches the saved checkpoint.
        self.conv_dwt=nn.Conv2d(6,3,kernel_size=3,stride=1,padding=1)
        self.b=1

    def forward(self,x):
        f1=F.relu(self.conv1(x))
        f2=self.attn1(F.relu(self.conv2(f1)))
        f3=F.relu(self.conv3(torch.cat((f1,f2),dim=1)))
        f4=self.attn2(F.relu(self.conv4(torch.cat((f2,f3),dim=1))))
        k=F.relu(self.conv5(torch.cat((f1,f2,f3,f4),dim=1)))
        return F.relu(k*x-k+self.b)
class FinalCNN(nn.Module):
    """Three-layer refinement CNN: 3 -> 16 -> 32 -> 3 channels.

    Each 3x3 convolution (stride 1, padding 1, so H and W are preserved) is
    followed by a ReLU, so the output is non-negative.
    """
    def __init__(self):
        super().__init__()
        # Attribute names define the state_dict keys of the saved checkpoint.
        self.conv1=nn.Conv2d(3,16,kernel_size=3,stride=1,padding=1)
        self.conv2=nn.Conv2d(16,32,kernel_size=3,stride=1,padding=1)
        self.conv3=nn.Conv2d(32,3,kernel_size=3,stride=1,padding=1)
        self.relu1=nn.ReLU()
        self.relu2=nn.ReLU()
        self.relu3=nn.ReLU()

    def forward(self,x):
        stages=(
            (self.conv1,self.relu1),
            (self.conv2,self.relu2),
            (self.conv3,self.relu3),
        )
        out=x
        for conv,act in stages:
            out=act(conv(out))
        return out




In [4]:
# Build the three networks and load their pretrained weights for CPU inference.
# The two dehazing nets are wrapped in nn.DataParallel *before* loading.
# Presumably their checkpoints were saved from DataParallel models, so the keys
# carry the wrapper's "module." prefix.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
CHECKPOINT_DIR='/kaggle/input/dehazing-models-ct5129'

dwt_dehazenet_rgb=nn.DataParallel(DWT_DehazingNet())
dehazenet_rgb=nn.DataParallel(DehazingNet())
final_cnn=FinalCNN()

for _model,_ckpt_file in ((dwt_dehazenet_rgb,'dehazing-rgb-dwt-2l.pth'),
                          (dehazenet_rgb,'dehazenet-rgb-2l.pth'),
                          (final_cnn,'end_cnn.pth')):
    _state_dict=torch.load(os.path.join(CHECKPOINT_DIR,_ckpt_file),map_location=torch.device('cpu'))
    _model.load_state_dict(_state_dict)  # strict: raises on missing/unexpected keys
    _model.eval()  # these models are only used for inference in this notebook

<All keys matched successfully>

In [5]:
def save_image_rgb(img_tensor,file_path):
    """Denormalize a (3, H, W) network output and write it as an 8-bit RGB image.

    The tensor is mapped back with ``tensor_denormalize_rgb`` (default
    statistics), clipped to [0, 1], scaled by 255 (truncated, not rounded) and
    saved with PIL.

    Raises:
        ValueError: if the first dimension of ``img_tensor`` is not 3.
    """
    if img_tensor.shape[0]==3:
        denorm=tensor_denormalize_rgb(img_tensor)
        hwc=denorm.permute(1,2,0).cpu().detach().numpy()
        pixels=(np.clip(hwc,0,1)*255).astype(np.uint8)
        Image.fromarray(pixels,mode='RGB').save(file_path)
        return
    raise ValueError("Input tensor must have 3 channels only...")

def tensor_denormalize_rgb(out_tensor,mean=(0.4556,0.3837,0.3642),std=(0.2689,0.2691,0.2828)):
    """Undo per-channel normalization: ``out_tensor * std + mean``.

    Args:
        out_tensor: tensor of shape (C, H, W) or (N, C, H, W).
        mean, std: per-channel statistics (one value per channel).
            NOTE(review): these differ from the statistics used for the input
            normalization; presumably they describe the clear/target images — confirm.

    Returns:
        The denormalized tensor, on the same device as ``out_tensor``. A leading
        batch dimension of size 1 (including the one added for 3-D input) is
        squeezed away.
    """
    if out_tensor.dim()==3:
        out_tensor=out_tensor.unsqueeze(0)
    # Build the statistics on the input's device so GPU tensors work too. The
    # result dtype follows PyTorch's normal type promotion.
    mean_t=torch.as_tensor(mean,device=out_tensor.device).view(1,-1,1,1)
    std_t=torch.as_tensor(std,device=out_tensor.device).view(1,-1,1,1)
    return (out_tensor*std_t+mean_t).squeeze(0)
def unsharp_mask(image,kernel_size=(5,5),sigma=0.4,amount=1.0,threshold=1):
    """Sharpen an 8-bit image by unsharp masking.

    sharpened = (amount + 1) * image - amount * GaussianBlur(image), clipped to
    [0, 255] and rounded to uint8. Pixels where |image - blurred| < threshold
    are copied unchanged from ``image``, so flat regions are not amplified.

    Args:
        image: uint8 image array (as returned by cv2.imread / cvtColor).
        kernel_size: Gaussian kernel size passed to cv2.GaussianBlur.
        sigma: Gaussian sigma passed to cv2.GaussianBlur.
        amount: sharpening strength.
        threshold: minimum absolute difference from the blurred image for a
            pixel to be sharpened; 0 disables the check.

    Returns:
        uint8 array with the same shape as ``image``.
    """
    blurred=cv2.GaussianBlur(image,kernel_size,sigma)
    # Work in float: subtracting uint8 arrays would wrap around
    # (e.g. 100 - 110 -> 246), which corrupts the low-contrast mask below.
    image_f=image.astype(np.float64)
    blurred_f=blurred.astype(np.float64)
    sharpened=(amount+1.0)*image_f-amount*blurred_f
    sharpened=np.clip(sharpened,0,255).round().astype(np.uint8)
    if threshold>0:
        low_contrast_mask=np.abs(image_f-blurred_f)<threshold
        np.copyto(sharpened,image,where=low_contrast_mask)
    return sharpened

def clahe(image,clip_limit=1.0,tile_grid_size=(2,2)):
    """Apply CLAHE to the lightness channel of a 3-channel image.

    The image is converted with COLOR_BGR2LAB and back with COLOR_LAB2BGR, so it
    is interpreted as BGR and the channel order is preserved on output.
    NOTE(review): enhance_image / enhance_image_merged pass RGB arrays here, so
    L is computed with R and B swapped — confirm whether that is intended.

    Args:
        image: 3-channel image, assumed uint8.
        clip_limit: CLAHE contrast clip limit (default matches the previous
            hard-coded value).
        tile_grid_size: CLAHE tile grid (default matches the previous
            hard-coded value).

    Returns:
        The equalized image, same shape and channel order as ``image``.
    """
    # Named "equalizer" so the local does not shadow this function's name.
    equalizer=cv2.createCLAHE(clipLimit=clip_limit,tileGridSize=tile_grid_size)
    lab=cv2.cvtColor(image,cv2.COLOR_BGR2LAB)
    lab[:,:,0]=equalizer.apply(lab[:,:,0])
    return cv2.cvtColor(lab,cv2.COLOR_LAB2BGR)

def enhance_image(image_path):
    """Load an image file, sharpen it, apply CLAHE and return it as a tensor.

    Pipeline: cv2.imread (BGR) -> RGB -> unsharp_mask -> clahe -> ToTensor.

    Args:
        image_path: path of the image to read.

    Returns:
        (3, H, W) float tensor in [0, 1] (via ToTensor), RGB channel order.

    Raises:
        FileNotFoundError: if the file cannot be read as an image.
    """
    image=cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    image_sharpened=unsharp_mask(img)
    # NOTE(review): clahe() converts with COLOR_BGR2LAB, but img is RGB here.
    image_clahe=clahe(image_sharpened)
    return tt.ToTensor()(image_clahe)
def alpha_blending(image1, image2, alpha=0.6, gamma=4):
    """Blend two images: ``alpha*image1 + (1 - alpha)*image2 + gamma``.

    Thin wrapper around cv2.addWeighted, which saturates the result to the
    input dtype's range.

    Args:
        image1, image2: arrays of identical shape and dtype.
        alpha: weight of ``image1``; ``image2`` gets ``1 - alpha``.
        gamma: scalar added to every pixel after blending (default 4, a small
            brightness lift).

    Returns:
        The blended array.
    """
    blended = cv2.addWeighted(image1, alpha, image2, 1 - alpha, gamma)
    return blended
def image_addition(coeff,img_path1,img_path2):
    """Alpha-blend two image files and return the result as a float tensor.

    Args:
        coeff: blending weight of the first image; the second image gets
            ``1 - coeff``. alpha_blending also adds a constant offset.
        img_path1, img_path2: paths of two images of identical size.

    Returns:
        (3, H, W) float tensor in [0, 1] (via ToTensor), RGB channel order.

    Raises:
        FileNotFoundError: if either file cannot be read as an image.
    """
    rgb_images=[]
    for path in (img_path1,img_path2):
        bgr=cv2.imread(path)
        if bgr is None:
            # cv2.imread returns None instead of raising on failure.
            raise FileNotFoundError(f"Could not read image: {path}")
        rgb_images.append(cv2.cvtColor(bgr,cv2.COLOR_BGR2RGB))
    # Pass coeff through so the caller's weight is actually used.
    return tt.ToTensor()(alpha_blending(rgb_images[0],rgb_images[1],alpha=coeff))
def save_image_final(img_tensor,file_path):
    """Write a (3, H, W) tensor with values in [0, 1] to disk as an 8-bit RGB image.

    No denormalization is applied. Values are clipped to [0, 1] and scaled by
    255 (truncated, not rounded).

    Raises:
        ValueError: if the first dimension of ``img_tensor`` is not 3.
    """
    if img_tensor.shape[0]==3:
        hwc=img_tensor.detach().cpu().permute(1,2,0).numpy()
        pixels=(np.clip(hwc,0,1)*255).astype(np.uint8)
        Image.fromarray(pixels,mode='RGB').save(file_path)
        return
    raise ValueError("Input tensor must have 3 channels only...")
def enhance_image_merged(image_path):
    """Load an image file, denoise, sharpen and apply CLAHE; return it as a tensor.

    Pipeline: cv2.imread (BGR) -> RGB -> fastNlMeansDenoisingColored ->
    unsharp_mask -> clahe -> ToTensor.

    Args:
        image_path: path of the image to read.

    Returns:
        (3, H, W) float tensor in [0, 1] (via ToTensor), RGB channel order.

    Raises:
        FileNotFoundError: if the file cannot be read as an image.
    """
    image=cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
    # Colored non-local-means denoising: h=10, hColor=10,
    # templateWindowSize=3, searchWindowSize=21.
    img2=cv2.fastNlMeansDenoisingColored(img,None,10,10,3,21)
    image_sharpened=unsharp_mask(img2)
    # NOTE(review): clahe() converts with COLOR_BGR2LAB, but the image is RGB here.
    image_clahe=clahe(image_sharpened)
    return tt.ToTensor()(image_clahe)


In [6]:
def model1_image_pass(img_tensor,folder_name,i):
    """Dehaze with the DWT network and save the raw and post-processed outputs.

    Writes ``output_image_{i}_1.png`` (denormalized network output) and
    ``processed_image_{i}_1.png`` (after unsharp mask + CLAHE) into
    ``/kaggle/working/{folder_name}``.

    Args:
        img_tensor: normalized input, assumed (3, H, W); a batch dim is added here.
        folder_name: sub-directory of /kaggle/working to write into.
        i: index used in the output file names.

    Returns:
        The post-processed image as an (H, W, 3) float array in [0, 1].
    """
    out_dir=f'/kaggle/working/{folder_name}'
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        model_output1=dwt_dehazenet_rgb(img_tensor.unsqueeze(0)).cpu()
    save_image_rgb(model_output1.squeeze(),f'{out_dir}/output_image_{i}_1.png')
    proc_out_tensor1=enhance_image(f'{out_dir}/output_image_{i}_1.png')
    proc_arr1=proc_out_tensor1.permute(1,2,0).cpu().detach().numpy()
    proc_image1=Image.fromarray((np.clip(proc_arr1,0,1)*255).astype(np.uint8),mode='RGB')
    proc_image1.save(f'{out_dir}/processed_image_{i}_1.png')
    return proc_arr1
    
    
def model2_image_pass(img_tensor,folder_name,i):
    """Dehaze with the plain (non-DWT) network and save the raw and post-processed outputs.

    Writes ``output_image_{i}_2.png`` (denormalized network output) and
    ``processed_image_{i}_2.png`` (after unsharp mask + CLAHE) into
    ``/kaggle/working/{folder_name}``.

    Args:
        img_tensor: normalized input, assumed (3, H, W); a batch dim is added here.
        folder_name: sub-directory of /kaggle/working to write into.
        i: index used in the output file names.

    Returns:
        The post-processed image as an (H, W, 3) float array in [0, 1].
    """
    out_dir=f'/kaggle/working/{folder_name}'
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        model_output2=dehazenet_rgb(img_tensor.unsqueeze(0)).cpu()
    save_image_rgb(model_output2.squeeze(),f'{out_dir}/output_image_{i}_2.png')
    proc_out_tensor2=enhance_image(f'{out_dir}/output_image_{i}_2.png')
    proc_arr2=proc_out_tensor2.permute(1,2,0).cpu().detach().numpy()
    proc_image2=Image.fromarray((np.clip(proc_arr2,0,1)*255).astype(np.uint8),mode='RGB')
    proc_image2.save(f'{out_dir}/processed_image_{i}_2.png')
    return proc_arr2

In [7]:
def model_run_parallel(haze_img_path,i=0,folder_name='Inference_Results'):
    """Run the full dehazing pipeline on one hazy image and write every stage to disk.

    All files go to ``/kaggle/working/{folder_name}`` and carry the index ``i``:
      1. both dehazing networks run via joblib (n_jobs=2) and write
         ``output_image_{i}_{1,2}.png`` and ``processed_image_{i}_{1,2}.png``;
      2. the two processed outputs are alpha-blended -> ``merged_image_{i}.png``;
      3. the blend is denoised/sharpened/CLAHE'd -> ``proc_merged_image_{i}.png``;
      4. ``final_cnn`` refines that image -> ``cnn_processed_image_{i}.png``.

    Args:
        haze_img_path: path of the hazy input image.
        i: index used in all output file names.
        folder_name: sub-directory of /kaggle/working to write into.

    Raises:
        FileNotFoundError: if the processed merged image cannot be read back.
    """
    out_dir=f'/kaggle/working/{folder_name}'
    os.makedirs(out_dir,exist_ok=True)
    # Inference only: put every network in eval mode.
    for model in (dwt_dehazenet_rgb,dehazenet_rgb,final_cnn):
        model.eval()
    # Close the file handle after loading. Convert to RGB so palette/RGBA/
    # grayscale files still give the 3 channels the networks expect.
    with Image.open(haze_img_path) as haze_img:
        img_tensor=input_transforms_rgb(haze_img.convert('RGB')).cpu()

    # Each branch writes processed_image_{i}_{1,2}.png, which is read back below.
    Parallel(n_jobs=2)(delayed(func)(img_tensor,folder_name,i) for func in [model1_image_pass,model2_image_pass])
    blended_image=image_addition(0.6,f'{out_dir}/processed_image_{i}_2.png',
                                 f'{out_dir}/processed_image_{i}_1.png')
    save_image_final(blended_image.squeeze(),f'{out_dir}/merged_image_{i}.png')
    proc_blended=enhance_image_merged(f'{out_dir}/merged_image_{i}.png')
    # Save the *processed* blend so the enhancement step actually feeds final_cnn.
    # NOTE(review): confirm this matches the inputs end_cnn.pth was trained on.
    save_image_final(proc_blended.squeeze(),f'{out_dir}/proc_merged_image_{i}.png')

    proc_merged_bgr=cv2.imread(f'{out_dir}/proc_merged_image_{i}.png')
    if proc_merged_bgr is None:
        raise FileNotFoundError(f"Could not read image: {out_dir}/proc_merged_image_{i}.png")
    with torch.no_grad():
        cnn_processed=final_cnn(tt.ToTensor()(cv2.cvtColor(proc_merged_bgr,cv2.COLOR_BGR2RGB)))
    save_image_final(cnn_processed.squeeze(),f'{out_dir}/cnn_processed_image_{i}.png')

# Smoke test: run the whole pipeline once on a single NH-HAZE image
# (outputs are written to /kaggle/working/Inference_Results with index 0).
model_run_parallel('/kaggle/input/dehazing-dataset-thesis/NH-HAZE/NH-HAZE/01_hazy.png')



In [8]:
# Run the full pipeline on every hazy image listed in the test split and
# report the mean wall-clock time per image (seconds).
test_data=pd.read_csv('/kaggle/input/dehazing-dataset-thesis/dehazing_dataset_test.csv')
time_data=[]
for idx,path in enumerate(test_data['Hazy'].values):
    # perf_counter is monotonic and high-resolution, so it suits interval timing.
    start_time=time.perf_counter()
    model_run_parallel(path,idx)
    time_data.append(time.perf_counter()-start_time)
print('Processed....')
if time_data:
    print(f"Average workflow processing time: {sum(time_data)/len(time_data)}")
else:
    print('No images listed in the test CSV.')




Processed....
Average workflow processing time: 0.6446901701963865
