In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.models.layers import to_2tuple, trunc_normal_
import os
from torch.utils.data import DataLoader



class RLN(nn.Module):
	r"""Revised LayerNorm"""
	def __init__(self, dim, eps=1e-5, detach_grad=False):
		super(RLN, self).__init__()
		self.eps = eps
		self.detach_grad = detach_grad

		self.weight = nn.Parameter(torch.ones((1, dim, 1, 1)))
		self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1)))

		self.meta1 = nn.Conv2d(1, dim, 1)
		self.meta2 = nn.Conv2d(1, dim, 1)

		trunc_normal_(self.meta1.weight, std=.02)
		nn.init.constant_(self.meta1.bias, 1)

		trunc_normal_(self.meta2.weight, std=.02)
		nn.init.constant_(self.meta2.bias, 0)

	def forward(self, input):
		mean = torch.mean(input, dim=(1, 2, 3), keepdim=True)
		std = torch.sqrt((input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps)

		normalized_input = (input - mean) / std

		if self.detach_grad:
			rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach())
		else:
			rescale, rebias = self.meta1(std), self.meta2(mean)

		out = normalized_input * self.weight + self.bias
		return out, rescale, rebias


class Mlp(nn.Module):
	def __init__(self, network_depth, in_features, hidden_features=None, out_features=None):
		super().__init__()
		out_features = out_features or in_features
		hidden_features = hidden_features or in_features

		self.network_depth = network_depth

		self.mlp = nn.Sequential(
			nn.Conv2d(in_features, hidden_features, 1),
			nn.ReLU(True),
			nn.Conv2d(hidden_features, out_features, 1)
		)

		self.apply(self._init_weights)

	def _init_weights(self, m):
		if isinstance(m, nn.Conv2d):
			gain = (8 * self.network_depth) ** (-1/4)
			fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
			std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
			trunc_normal_(m.weight, std=std)
			if m.bias !=   None:
				nn.init.constant_(m.bias, 0)

	def forward(self, x):
		return self.mlp(x)


def window_partition(x, window_size):
	B, H, W, C = x.shape
	x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
	windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C)
	return windows


def window_reverse(windows, window_size, H, W):
	B = int(windows.shape[0] / (H * W / window_size / window_size))
	x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
	x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
	return x


def get_relative_positions(window_size):
	coords_h = torch.arange(window_size)
	coords_w = torch.arange(window_size)

	coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
	coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
	relative_positions = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

	relative_positions = relative_positions.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
	relative_positions_log  = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs())

	return relative_positions_log


class WindowAttention(nn.Module):
	def __init__(self, dim, window_size, num_heads):

		super().__init__()
		self.dim = dim
		self.window_size = window_size  # Wh, Ww
		self.num_heads = num_heads
		head_dim = dim // num_heads
		self.scale = head_dim ** -0.5

		relative_positions = get_relative_positions(self.window_size)
		self.register_buffer("relative_positions", relative_positions)
		self.meta = nn.Sequential(
			nn.Linear(2, 256, bias=True),
			nn.ReLU(True),
			nn.Linear(256, num_heads, bias=True)
		)

		self.softmax = nn.Softmax(dim=-1)

	def forward(self, qkv):
		B_, N, _ = qkv.shape

		qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4)

		q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

		q = q * self.scale
		attn = (q @ k.transpose(-2, -1))

		relative_position_bias = self.meta(self.relative_positions)
		relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
		attn = attn + relative_position_bias.unsqueeze(0)

		attn = self.softmax(attn)

		x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
		return x


class Attention(nn.Module):
	def __init__(self, network_depth, dim, num_heads, window_size, shift_size, use_attn=False, conv_type=None):
		super().__init__()
		self.dim = dim
		self.head_dim = int(dim // num_heads)
		self.num_heads = num_heads

		self.window_size = window_size
		self.shift_size = shift_size

		self.network_depth = network_depth
		self.use_attn = use_attn
		self.conv_type = conv_type

		if self.conv_type == 'Conv':
			self.conv = nn.Sequential(
				nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect'),
				nn.ReLU(True),
				nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect')
			)

		if self.conv_type == 'DWConv':
			self.conv = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode='reflect')

		if self.conv_type == 'DWConv' or self.use_attn:
			self.V = nn.Conv2d(dim, dim, 1)
			self.proj = nn.Conv2d(dim, dim, 1)

		if self.use_attn:
			self.QK = nn.Conv2d(dim, dim * 2, 1)
			self.attn = WindowAttention(dim, window_size, num_heads)

		self.apply(self._init_weights)

	def _init_weights(self, m):
		if isinstance(m, nn.Conv2d):
			w_shape = m.weight.shape
			
			if w_shape[0] == self.dim * 2:	# QK
				fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
				std = math.sqrt(2.0 / float(fan_in + fan_out))
				trunc_normal_(m.weight, std=std)		
			else:
				gain = (8 * self.network_depth) ** (-1/4)
				fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
				std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
				trunc_normal_(m.weight, std=std)

			if m.bias !=  None:
				nn.init.constant_(m.bias, 0)

	def check_size(self, x, shift=False):
		_, _, h, w = x.size()
		mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
		mod_pad_w = (self.window_size - w % self.window_size) % self.window_size

		if shift:
			x = F.pad(x, (self.shift_size, (self.window_size-self.shift_size+mod_pad_w) % self.window_size,
						  self.shift_size, (self.window_size-self.shift_size+mod_pad_h) % self.window_size), mode='reflect')
		else:
			x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
		return x

	def forward(self, X):
		B, C, H, W = X.shape

		if self.conv_type == 'DWConv' or self.use_attn:
			V = self.V(X)
		#print(self.use_attn)
		if self.use_attn:
			#print('attention')      
			QK = self.QK(X)
			QKV = torch.cat([QK, V], dim=1)

			# shift
			shifted_QKV = self.check_size(QKV, self.shift_size > 0)
			Ht, Wt = shifted_QKV.shape[2:]

			# partition windows
			shifted_QKV = shifted_QKV.permute(0, 2, 3, 1)
			qkv = window_partition(shifted_QKV, self.window_size)  # nW*B, window_size**2, C

			attn_windows = self.attn(qkv)

			# merge windows
			shifted_out = window_reverse(attn_windows, self.window_size, Ht, Wt)  # B H' W' C

			# reverse cyclic shift
			out = shifted_out[:, self.shift_size:(self.shift_size+H), self.shift_size:(self.shift_size+W), :]
			attn_out = out.permute(0, 3, 1, 2)

			if self.conv_type in ['Conv', 'DWConv']:
				conv_out = self.conv(V)
				out = self.proj(conv_out + attn_out)
			else:
				out = self.proj(attn_out)

		else:
			if self.conv_type == 'Conv':
				out = self.conv(X)				# no attention and use conv, no projection
			elif self.conv_type == 'DWConv':
				out = self.proj(self.conv(V))

		return out


class TransformerBlock(nn.Module):
	def __init__(self, network_depth, dim, num_heads, mlp_ratio=4.,
				 norm_layer=nn.LayerNorm, mlp_norm=False,
				 window_size=8, shift_size=0, use_attn=True, conv_type=None):
		super().__init__()
		self.use_attn = use_attn
		self.mlp_norm = mlp_norm

		self.norm1 = norm_layer(dim) if use_attn else nn.Identity()
		self.attn = Attention(network_depth, dim, num_heads=num_heads, window_size=window_size,
							  shift_size=shift_size, use_attn=use_attn, conv_type=conv_type)

		self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity()
		self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio))

	def forward(self, x):
		identity = x
		if self.use_attn: x, rescale, rebias = self.norm1(x)
		x = self.attn(x)
		if self.use_attn: x = x * rescale + rebias
		x = identity + x

		identity = x
		if self.use_attn and self.mlp_norm: x, rescale, rebias = self.norm2(x)
		x = self.mlp(x)
		if self.use_attn and self.mlp_norm: x = x * rescale + rebias
		x = identity + x
		return x


class BasicLayer(nn.Module):
	def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4.,
				 norm_layer=nn.LayerNorm, window_size=8,
				 attn_ratio=0., attn_loc='last', conv_type=None):

		super().__init__()
		self.dim = dim
		self.depth = depth

		attn_depth = attn_ratio * depth

		if attn_loc == 'last':
			use_attns = [i >= depth-attn_depth for i in range(depth)]
		elif attn_loc == 'first':
			use_attns = [i < attn_depth for i in range(depth)]
		elif attn_loc == 'middle':
			use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)]

		# build blocks
		self.blocks = nn.ModuleList([
			TransformerBlock(network_depth=network_depth,
							 dim=dim, 
							 num_heads=num_heads,
							 mlp_ratio=mlp_ratio,
							 norm_layer=norm_layer,
							 window_size=window_size,
							 shift_size=0 if (i % 2 == 0) else window_size // 2,
							 use_attn=use_attns[i], conv_type=conv_type)
			for i in range(depth)])

	def forward(self, x):
		for blk in self.blocks:
			x = blk(x)
		return x


class PatchEmbed(nn.Module):
	def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
		super().__init__()
		self.in_chans = in_chans
		self.embed_dim = embed_dim

		if kernel_size is None:
			kernel_size = patch_size

		self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size,
							  padding=(kernel_size-patch_size+1)//2, padding_mode='reflect')

	def forward(self, x):
		x = self.proj(x)
		return x


class PatchUnEmbed(nn.Module):
	def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
		super().__init__()
		self.out_chans = out_chans
		self.embed_dim = embed_dim

		if kernel_size is None:
			kernel_size = 1

		self.proj = nn.Sequential(
			nn.Conv2d(embed_dim, out_chans*patch_size**2, kernel_size=kernel_size,
					  padding=kernel_size//2, padding_mode='reflect'),
			nn.PixelShuffle(patch_size)
		)

	def forward(self, x):
		x = self.proj(x)
		return x


class SKFusion(nn.Module):
	def __init__(self, dim, height=2, reduction=8):
		super(SKFusion, self).__init__()
		
		self.height = height
		d = max(int(dim/reduction), 4)
		
		self.avg_pool = nn.AdaptiveAvgPool2d(1)
		self.mlp = nn.Sequential(
			nn.Conv2d(dim, d, 1, bias=False), 
			nn.ReLU(),
			nn.Conv2d(d, dim*height, 1, bias=False)
		)
		
		self.softmax = nn.Softmax(dim=1)

	def forward(self, in_feats):
		B, C, H, W = in_feats[0].shape
		
		in_feats = torch.cat(in_feats, dim=1)
		in_feats = in_feats.view(B, self.height, C, H, W)
		
		feats_sum = torch.sum(in_feats, dim=1)
		attn = self.mlp(self.avg_pool(feats_sum))
		attn = self.softmax(attn.view(B, self.height, C, 1, 1))

		out = torch.sum(in_feats*attn, dim=1)
		return out      


class DehazeFormer(nn.Module):
	def __init__(self, in_chans=3, out_chans=4, window_size=8,
				 embed_dims=[24, 48, 96, 48, 24],
				 mlp_ratios=[2., 4., 4., 2., 2.],
				 depths=[16, 16, 16, 8, 8],
				 num_heads=[2, 4, 6, 1, 1],
				 attn_ratio=[1/4, 1/2, 3/4, 0, 0],
				 conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
				 norm_layer=[RLN, RLN, RLN, RLN, RLN]):
		super(DehazeFormer, self).__init__()

		# setting
		self.patch_size = 4
		self.window_size = window_size
		self.mlp_ratios = mlp_ratios

		# split image into non-overlapping patches
		self.patch_embed = PatchEmbed(
			patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3)

		# backbone
		self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0],
					   			 num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
					   			 norm_layer=norm_layer[0], window_size=window_size,
					   			 attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0])

		self.patch_merge1 = PatchEmbed(
			patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])

		self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)

		self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1],
								 num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
								 norm_layer=norm_layer[1], window_size=window_size,
								 attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1])

		self.patch_merge2 = PatchEmbed(
			patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

		self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)

		self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2],
								 num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
								 norm_layer=norm_layer[2], window_size=window_size,
								 attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2])

		self.patch_split1 = PatchUnEmbed(
			patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])

		assert embed_dims[1] == embed_dims[3]
		self.fusion1 = SKFusion(embed_dims[3])

		self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3],
								 num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
								 norm_layer=norm_layer[3], window_size=window_size,
								 attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3])

		self.patch_split2 = PatchUnEmbed(
			patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])

		assert embed_dims[0] == embed_dims[4]
		self.fusion2 = SKFusion(embed_dims[4])			

		self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4],
					   			 num_heads=num_heads[4], mlp_ratio=mlp_ratios[4],
					   			 norm_layer=norm_layer[4], window_size=window_size,
					   			 attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4])

		# merge non-overlapping patches into image
		self.patch_unembed = PatchUnEmbed(
			patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3)


	def check_image_size(self, x):
		# NOTE: for I2I test
		_, _, h, w = x.size()
		mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
		mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
		x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
		return x

	def forward_features(self, x):
		x = self.patch_embed(x)
		x = self.layer1(x)
		skip1 = x

		x = self.patch_merge1(x)
		x = self.layer2(x)
		skip2 = x

		x = self.patch_merge2(x)
		x = self.layer3(x)
		x = self.patch_split1(x)

		x = self.fusion1([x, self.skip2(skip2)]) + x
		x = self.layer4(x)
		x = self.patch_split2(x)

		x = self.fusion2([x, self.skip1(skip1)]) + x
		x = self.layer5(x)
		x = self.patch_unembed(x)
		return x

	def forward(self, x):
		H, W = x.shape[2:]
		x = self.check_image_size(x)

		feat = self.forward_features(x)
		K, B = torch.split(feat, (1, 3), dim=1)

		x = K * x - B + x
		x = x[:, :, :H, :W]
		return x


def dehazeformer_t():
    return DehazeFormer(
		embed_dims=[24, 48, 96, 48, 24],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[4, 4, 4, 2, 2],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[0, 1/2, 1, 0, 0],
		conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])


def dehazeformer_s():
    return DehazeFormer(
		embed_dims=[24, 48, 96, 48, 24],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[8, 8, 8, 4, 4],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])


def dehazeformer_b():
    return DehazeFormer(
        embed_dims=[24, 48, 96, 48, 24],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[16, 16, 16, 8, 8],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])


def dehazeformer_d():
    return DehazeFormer(
        embed_dims=[24, 48, 96, 48, 24],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[32, 32, 32, 16, 16],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])


def dehazeformer_w():
    return DehazeFormer(
        embed_dims=[48, 96, 192, 96, 48],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[16, 16, 16, 8, 8],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'])


def dehazeformer_m():
    return DehazeFormer(
		embed_dims=[24, 48, 96, 48, 24],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[12, 12, 12, 6, 6],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv'])


def dehazeformer_l():
    return DehazeFormer(
		embed_dims=[48, 96, 192, 96, 48],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[16, 16, 16, 12, 12],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv'])



In [2]:

from torch.nn import functional as F

import torch.nn as nn

class BasicBlock(nn.Module):
	def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4.,
				 norm_layer=nn.LayerNorm, window_size=8,
				 attn_ratio=0., attn_loc='last', conv_type=None):

		super().__init__()
		self.dim = dim
		self.depth = depth
		self.gf=FastGuidedFilter(r=1)
		self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )
    
		depth_rate=24
		kernel_size=3
		in_channels=3
		self.conv_out = nn.Conv2d(depth_rate*2, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
		self.relu1=nn.ReLU(inplace=True)
		self.relu2=nn.ReLU(inplace=True)
		self.norm1=AdaptiveInstanceNorm(depth_rate)
		self.norm2=AdaptiveInstanceNorm(depth_rate) 
		attn_depth = attn_ratio * depth
		#print(attn_depth,attn_ratio,depth)
		if attn_loc == 'last':
			use_attns = [i >= depth-attn_depth for i in range(depth)]
		elif attn_loc == 'first':
			use_attns = [i < attn_depth for i in range(depth)]
		elif attn_loc == 'middle':
			use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)]

		# build blocks
		self.blocks = nn.ModuleList([
			TransformerBlock(network_depth=network_depth,
							 dim=dim, 
							 num_heads=num_heads,
							 mlp_ratio=mlp_ratio,
							 norm_layer=norm_layer,
							 window_size=window_size,
							 shift_size=0 if (i % 2 == 0) else window_size // 2,
							 use_attn=use_attns[i], conv_type=conv_type)
			for i in range(depth)])

	def forward(self, x_hr):
		x_lr = self.downsample(x_hr)
   
		x_lr_new=self.norm1(x_lr)
		x_lr_new=self.relu1( x_lr_new)
		for blc in self.blocks:
   
		    x_lr_new = blc(x_lr_new)
    
		g_hr= self.gf(x_lr, x_lr_new, x_hr)
		gx_cat=torch.cat([g_hr,x_hr],1)
		g_hr=self.conv_out(gx_cat)
		g_hr=self.norm2(g_hr)
		g_hr=self.relu2( g_hr)
		x=g_hr+x_hr
   
		return g_hr
   


class DeepGuidedFilterFormer(nn.Module):
    def __init__(self,  radius=1):
        super().__init__()

        
        norm = AdaptiveInstanceNorm
        depth_rate=24
        kernel_size=3
        in_channels=3
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.relu1=nn.ReLU(inplace=True)
        self.block_num=3
        network_depth=50
        dim=depth_rate
        mlp_ratio=2.0
        norm_layer=RLN
        window_size=16
        conv_type='Conv'
        depth=4
        num_heads=4
        attn_ratio=1/4
        
        
        self.blocks = nn.ModuleList([
                   BasicBlock(network_depth=network_depth, dim=dim, depth=depth,
					   			 num_heads=num_heads, mlp_ratio=mlp_ratio,
					   			 norm_layer=norm_layer, window_size=window_size,
					   			 attn_ratio=attn_ratio, attn_loc='last', conv_type=conv_type)
			             for i in range(self.block_num)])

        

    def forward(self, x_hr):
        x_hr=self.conv_in(x_hr)
        #x_hr=self.relu1(x_hr)
        #pixelshuffle_ratio=2
        # Unpixelshuffle
        #x_lr_unpixelshuffled = unpixel_shuffle(x_lr, pixelshuffle_ratio)
        
        for blc in self.blocks:
            x_hr=blc(x_hr)
            
        x_hr=self.conv_out(x_hr)
        # Pixelshuffle
        #y_lr = F.pixel_shuffle(
           # self.lr(x_lr_unpixelshuffled), pixelshuffle_ratio
        #)

        return x_hr
           
   
   
class ConvGuidedFilter(nn.Module):
    """
    Adapted from https://github.com/wuhuikai/DeepGuidedFilter
    """
    def __init__(self, radius=1, norm=nn.BatchNorm2d, conv_a_kernel_size: int = 1):
        super(ConvGuidedFilter, self).__init__()

        self.box_filter = nn.Conv2d(
            3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3
        )
        self.conv_a = nn.Sequential(
            nn.Conv2d(
                6,
                32,
                kernel_size=conv_a_kernel_size,
                padding=conv_a_kernel_size // 2,
                bias=False,
            ),
            norm(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                32,
                32,
                kernel_size=conv_a_kernel_size,
                padding=conv_a_kernel_size // 2,
                bias=False,
            ),
            norm(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                32,
                3,
                kernel_size=conv_a_kernel_size,
                padding=conv_a_kernel_size // 2,
                bias=False,
            ),
        )
        self.box_filter.weight.data[...] = 1.0

    def forward(self, x_lr, y_lr, x_hr):
        _, _, h_lrx, w_lrx = x_lr.size()
        _, _, h_hrx, w_hrx = x_hr.size()

        N = self.box_filter(x_lr.data.new().resize_((1, 3, h_lrx, w_lrx)).fill_(1.0))
        ## mean_x
        mean_x = self.box_filter(x_lr) / N
        ## mean_y
        mean_y = self.box_filter(y_lr) / N
        ## cov_xy
        cov_xy = self.box_filter(x_lr * y_lr) / N - mean_x * mean_y
        ## var_x
        var_x = self.box_filter(x_lr * x_lr) / N - mean_x * mean_x

        ## A
        A = self.conv_a(torch.cat([cov_xy, var_x], dim=1))
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode="bilinear", align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode="bilinear", align_corners=True)

        return mean_A * x_hr + mean_b
        


        

class DeepGuideddetail(nn.Module):
    def __init__(self,  radius=1):
        super().__init__()

        
        norm = AdaptiveInstanceNorm

        
        

        #self.lr = dehazeformer_m()
        kernel_size=3
        depth_rate=16
        in_channels=3
        num_dense_layer=4
        growth_rate=16
        growth_rate=16
        
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.rdb1 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb2 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb3 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb4 = SRDB(depth_rate, num_dense_layer, growth_rate)

        self.gf = ConvGuidedFilter(radius, norm=norm)

        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )

    def forward(self, x_hr):
        x_lr = self.downsample(x_hr)
        y_lr=self.conv_in(x_lr)
        y_lr=self.rdb1(y_lr)
        y_lr=self.rdb2(y_lr)
        y_lr=self.rdb3(y_lr)
        y_lr=self.rdb4(y_lr)
        y_lr=self.conv_out(y_lr)
        
        #pixelshuffle_ratio=2
        # Unpixelshuffle
        #x_lr_unpixelshuffled = unpixel_shuffle(x_lr, pixelshuffle_ratio)
        
        #y_lr=self.lr(x_lr)
        # Pixelshuffle
        #y_lr = F.pixel_shuffle(
           # self.lr(x_lr_unpixelshuffled), pixelshuffle_ratio
        #)

        return F.tanh( self.gf(x_lr, y_lr, x_hr))
                
        

class DeepGuidedall(nn.Module):
    def __init__(self,  radius=1):
        super().__init__()

        
        norm = AdaptiveInstanceNorm

        
        

        #self.lr = dehazeformer_m()
        kernel_size=3
        depth_rate=16
        in_channels=3
        num_dense_layer=4
        growth_rate=16
        growth_rate=16
        
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.rdb1 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb2 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb3 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb4 = SRDB(depth_rate, num_dense_layer, growth_rate)

        self.gf = ConvGuidedFilter(radius, norm=norm)
        self.lr = dehazeformer_m()

        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

    def forward(self, x_hr):
        x_lr = self.downsample(x_hr)
        y_lr=self.conv_in(x_lr)
        y_lr=self.rdb1(y_lr)
        y_lr=self.rdb2(y_lr)
        y_lr=self.rdb3(y_lr)
        y_lr=self.rdb4(y_lr)
        y_detail=self.conv_out(y_lr)
        y_base=self.lr(x_lr)
        y_lr=y_base+y_detail
        y_base=self.upsample(y_base)
        
        #pixelshuffle_ratio=2
        # Unpixelshuffle
        #x_lr_unpixelshuffled = unpixel_shuffle(x_lr, pixelshuffle_ratio)
        
        #y_lr=self.lr(x_lr)
        # Pixelshuffle
        #y_lr = F.pixel_shuffle(
           # self.lr(x_lr_unpixelshuffled), pixelshuffle_ratio
        #)

        return F.tanh( self.gf(x_lr, y_lr, x_hr)), y_base   




class DeepGuidednew(nn.Module):
    def __init__(self,  radius=1):
        super().__init__()

        
        norm = AdaptiveInstanceNorm

        
        

        #self.lr = dehazeformer_m()
        kernel_size=3
        depth_rate=16
        in_channels=3
        num_dense_layer=4
        growth_rate=16
        growth_rate=16
        
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.rdb1 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb2 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb3 = SRDB(depth_rate, num_dense_layer, growth_rate)
        self.rdb4 = SRDB(depth_rate, num_dense_layer, growth_rate)

        self.gf = ConvGuidedFilter(radius, norm=norm)
        self.lr = dehazeformer_m()

        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

    def forward(self, x_hr):
        x_lr = self.downsample(x_hr)
        y_lr=self.conv_in(x_lr)
        y_lr=self.rdb1(y_lr)
        y_lr=self.rdb2(y_lr)
        y_lr=self.rdb3(y_lr)
        y_lr=self.rdb4(y_lr)
        y_detail=self.conv_out(y_lr)
        y_base=self.lr(x_lr)
        y_lr=y_base+y_detail
        y_base=self.upsample(y_base)
        
        #pixelshuffle_ratio=2
        # Unpixelshuffle
        #x_lr_unpixelshuffled = unpixel_shuffle(x_lr, pixelshuffle_ratio)
        
        #y_lr=self.lr(x_lr)
        # Pixelshuffle
        #y_lr = F.pixel_shuffle(
           # self.lr(x_lr_unpixelshuffled), pixelshuffle_ratio
        #)

        return  self.gf(x_lr, y_lr, x_hr), y_base               
        
        

class DeepAtrousGuidedFilter(nn.Module):
    def __init__(self,  radius=1):
        super().__init__()

        
        norm = AdaptiveInstanceNorm

        
        

        self.lr = dehazeformer_m()

        self.gf = ConvGuidedFilter(radius, norm=norm)

        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )

    def forward(self, x_hr):
        x_lr = self.downsample(x_hr)
        #pixelshuffle_ratio=2
        # Unpixelshuffle
        #x_lr_unpixelshuffled = unpixel_shuffle(x_lr, pixelshuffle_ratio)
        y_lr=self.lr(x_lr)
        # Pixelshuffle
        #y_lr = F.pixel_shuffle(
           # self.lr(x_lr_unpixelshuffled), pixelshuffle_ratio
        #)

        return F.tanh( self.gf(x_lr, y_lr, x_hr))

In [3]:
import torch
import torch.nn as nn


class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, n):
        super(AdaptiveInstanceNorm, self).__init__()

        self.w_0 = nn.Parameter(torch.Tensor([1.0]))
        self.w_1 = nn.Parameter(torch.Tensor([0.0]))

        self.ins_norm = nn.InstanceNorm2d(n, momentum=0.999, eps=0.001, affine=True)

    def forward(self, x):
        return self.w_0 * x + self.w_1 * self.ins_norm(x)


class PALayer(nn.Module):
    def __init__(self, channel: int):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.pa(x)
        return x * y


class CALayer(nn.Module):
    def __init__(self, channel: int):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y


In [4]:
import torch
from torch.nn import functional as F


def unpixel_shuffle(feature, r: int = 1):
    b, c, h, w = feature.shape
    out_channel = c * (r ** 2)
    out_h = h // r
    out_w = w // r
    feature_view = feature.contiguous().view(b, c, out_h, r, out_w, r)
    feature_prime = (
        feature_view.permute(0, 1, 3, 5, 2, 4)
        .contiguous()
        .view(b, out_channel, out_h, out_w)
    )
    return feature_prime


def sample_patches(
    inputs: torch.Tensor, patch_size: int = 3, stride: int = 2
) -> torch.Tensor:
    """

    :param inputs: the input feature maps, shape: (n, c, h, w).
    :param patch_size: the spatial size of sampled patches
    :param stride: the stride of sampling.
    :return: extracted patches, shape: (n, c, patch_size, patch_size, n_patches).
    """
    """
    Patch sampler for feature maps.
    Parameters
    ---
    inputs : torch.Tensor
        
    patch_size : int, optional
       
    stride : int, optional
        
    Returns
    ---
    patches : torch.Tensor
        
    """

    n, c, h, w = inputs.shape
    patches = (
        inputs.unfold(2, patch_size, stride)
        .unfold(3, patch_size, stride)
        .reshape(n, c, -1, patch_size, patch_size)
        .permute(0, 1, 3, 4, 2)
    )
    return patches


def chop_patches(
    img: torch.Tensor, patch_size_h: int = 256, patch_size_w: int = 512
) -> torch.Tensor:
    """

    :param inputs: the input feature maps, shape: (n, c, h, w).
    :param patch_size: the spatial size of sampled patches
    :param stride: the stride of sampling.
    :return: extracted patches, shape: (n, c, patch_size, patch_size, n_patches).
    """
    """
    Patch sampler for feature maps.
    Parameters
    ---
    inputs : torch.Tensor

    patch_size : int, optional

    stride : int, optional

    Returns
    ---
    patches : torch.Tensor

    """
    patches = (
        img.unfold(2, patch_size_h, patch_size_h)
        .unfold(3, patch_size_w, patch_size_w)
        .contiguous()
        .permute(2, 3, 0, 1, 4, 5)
        .flatten(start_dim=0, end_dim=2)
        # .reshape(-1, c, patch_size_h, patch_size_w)
    )
    return patches


def unchop_patches(
    patches: torch.Tensor, img_h: int = 1024, img_w: int = 2048, n: int = 1
) -> torch.Tensor:
    """
    Assumes non-overlapping patches

    See: https://discuss.pytorch.org/t/reshaping-windows-into-image/19805
    """
    _, c, patch_size_h, patch_size_w = patches.shape
    num_h = img_h // patch_size_h
    num_w = img_w // patch_size_w

    img = patches.reshape(n, num_h * num_w, patch_size_h * patch_size_w * c).permute(
        0, 2, 1
    )
    img = F.fold(
        img,
        (img_h, img_w),
        (patch_size_h, patch_size_w),
        1,
        0,
        (patch_size_h, patch_size_w),
    )
    return img.reshape(n, c, img_h, img_w)

def roll_n(X, axis, n):
    f_idx = tuple(
        slice(None, None, None) if i != axis else slice(0, n, None)
        for i in range(X.dim())
    )
    b_idx = tuple(
        slice(None, None, None) if i != axis else slice(n, None, None)
        for i in range(X.dim())
    )
    front = X[f_idx]
    back = X[b_idx]
    return torch.cat([back, front], axis)

In [5]:


# --- Imports --- #
import torch
import torch.nn.functional as F


# --- Perceptual loss network  --- #
class LossNetwork(torch.nn.Module):
    def __init__(self, vgg_model):
        super(LossNetwork, self).__init__()
        self.vgg_layers = vgg_model
        self.layer_name_mapping = {
            '3': "relu1_2",
            '8': "relu2_2",
            '15': "relu3_3"
        }

    def output_features(self, x):
        output = {}
        for name, module in self.vgg_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, dehaze, gt):
        loss = []
        dehaze_features = self.output_features(dehaze)
        gt_features = self.output_features(gt)
        for dehaze_feature, gt_feature in zip(dehaze_features, gt_features):
            loss.append(F.mse_loss(dehaze_feature, gt_feature))

        return sum(loss)/len(loss)



In [6]:


# --- Imports --- #
import torch
import torch.nn as nn
import torch.nn.functional as F



    
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size=3):
        super().__init__()

        self.conv1 =nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2) #ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        return out    
    




        




        
class CAB(nn.Module):
    def __init__(self, features):
        super(CAB, self).__init__()
        #new_features=features//2
        features=features//2
        self.reduce_fature=nn.Conv2d(features*2, features, kernel_size=1, bias=False)
        self.delta_gen1 = nn.Sequential(
                        nn.Conv2d(features*2, features, kernel_size=1, bias=False),
                        nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
                        )

        self.delta_gen2 = nn.Sequential(
                        nn.Conv2d(features*2, features, kernel_size=1, bias=False),
                        nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
                        )


        #self.delta_gen1.weight.data.zero_()
        #self.delta_gen2.weight.data.zero_()

    # https://github.com/speedinghzl/AlignSeg/issues/7
    # the normlization item is set to [w/s, h/s] rather than [h/s, w/s]
    # the function bilinear_interpolate_torch_gridsample2 is standard implementation, please use bilinear_interpolate_torch_gridsample2 for training.
    def bilinear_interpolate_torch_gridsample(self, input, size, delta=0):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 1.0
        norm = torch.tensor([[[[w/s, h/s]]]]).type_as(input).to(input.device)
        w_list = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        h_list = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.cat((h_list.unsqueeze(2), w_list.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
        grid = grid + delta.permute(0, 2, 3, 1) / norm

        output = F.grid_sample(input, grid)
        return output

    def bilinear_interpolate_torch_gridsample2(self, input, size, delta=0):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 2.0
        norm = torch.tensor([[[[(out_w-1)/s, (out_h-1)/s]]]]).type_as(input).to(input.device) # not [h/s, w/s]
        w_list = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        h_list = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.cat((h_list.unsqueeze(2), w_list.unsqueeze(2)), 2)
        grid = grid.repeat(n, 1, 1, 1).type_as(input).to(input.device)
        grid = grid + delta.permute(0, 2, 3, 1) / norm

        output = F.grid_sample(input, grid, align_corners=True)
        return output

    def forward(self, low_stage, high_stage):
        h, w = low_stage.size(2), low_stage.size(3)
        high_stage=self.reduce_fature(high_stage)
        high_stage = F.interpolate(input=high_stage, size=(h, w), mode='bilinear', align_corners=True)
        
        concat = torch.cat((low_stage, high_stage), 1)
        delta1 = self.delta_gen1(concat)
        delta2 = self.delta_gen2(concat)
        high_stage = self.bilinear_interpolate_torch_gridsample2(high_stage, (h, w), delta1)
        low_stage = self.bilinear_interpolate_torch_gridsample2(low_stage, (h, w), delta2)

        high_stage += low_stage
        return high_stage

class MakeDense(nn.Module):
    def __init__(self, in_channels, growth_rate, kernel_size=3):
        super(MakeDense, self).__init__()
        self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=kernel_size, padding=(kernel_size-1)//2)

    def forward(self, x):
        out = F.relu(self.conv(x))
        out = torch.cat((x, out), 1)
        return out

class PALayer(nn.Module):
    def __init__(self, channel):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
                nn.Sigmoid()
        )
    def forward(self, x):
        y = self.pa(x)
        return x * y

class CALayer(nn.Module):
    def __init__(self, channel):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y

class SRDBDK(nn.Module):
    def __init__(self, in_channels, num_dense_layer, growth_rate):
        super(SRDBDK, self).__init__()
        
        modules = []
        self.split_channel=in_channels//8
        kernel_size=3
        dilation=1
        self.conv1 = nn.Conv2d(self.split_channel*1, self.split_channel, kernel_size=9, padding=4, dilation=1)
        dilation=2
        self.conv2 = nn.Conv2d(self.split_channel*2, self.split_channel*1, kernel_size=7, padding=3, dilation=1)
        dilation=4
        self.conv3 = nn.Conv2d(self.split_channel*4, self.split_channel*2, kernel_size=5,  padding=2, dilation=1)
        dilation=8
        self.conv4 = nn.Conv2d(self.split_channel*8, self.split_channel*4, kernel_size=3, padding=1, dilation=1)

            
        #self.residual_dense_layers = nn.Sequential(*modules)
        _in_channels=in_channels
        self.calayer=CALayer(in_channels)
        self.palayer=PALayer(in_channels)
        self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        splited = torch.split(x, [self.split_channel,self.split_channel*1,self.split_channel*2,self.split_channel*4], dim=1)
        x0=F.relu(self.conv1(splited[0]))
        tmp= torch.cat((splited[1], x0), 1)
        x1=F.relu(self.conv2(tmp))
        tmp= torch.cat((splited[2], x0, x1), 1)
        x2=F.relu(self.conv3(tmp))
        tmp= torch.cat((splited[3], x0, x1, x2), 1)
        x3=F.relu(self.conv4(tmp))
        tmp= torch.cat(( x0, x1, x2, x3), 1)
        
        out = self.conv_1x1(tmp)
        out=self.calayer(out)
        out=self.palayer(out)
        out=out+x
        return out
        
        
class SRDB(nn.Module):
    def __init__(self, in_channels, num_dense_layer, growth_rate):
        super(SRDB, self).__init__()
        
        modules = []
        self.split_channel=in_channels//4
        kernel_size=3
        dilation=1
        self.conv1 = nn.Conv2d(self.split_channel*1, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        dilation=2
        self.conv2 = nn.Conv2d(self.split_channel*2, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        dilation=4
        self.conv3 = nn.Conv2d(self.split_channel*3, self.split_channel, kernel_size=kernel_size,  padding=dilation, dilation=dilation)
        dilation=8
        self.conv4 = nn.Conv2d(self.split_channel*4, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)

            
        #self.residual_dense_layers = nn.Sequential(*modules)
        _in_channels=in_channels
        self.calayer=CALayer(in_channels)
        self.palayer=PALayer(in_channels)
        self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        splited = torch.split(x, self.split_channel, dim=1)
        x0=F.relu(self.conv1(splited[0]))
        tmp= torch.cat((splited[1], x0), 1)
        x1=F.relu(self.conv2(tmp))
        tmp= torch.cat((splited[2], x0, x1), 1)
        x2=F.relu(self.conv3(tmp))
        tmp= torch.cat((splited[3], x0, x1, x2), 1)
        x3=F.relu(self.conv4(tmp))
        tmp= torch.cat(( x0, x1, x2, x3), 1)
        
        out = self.conv_1x1(tmp)
        out=self.calayer(out)
        out=self.palayer(out)
        #print(out.shape, x.shape)
        out=out+x
        return out


class SRDBN(nn.Module):
    def __init__(self, in_channels, num_dense_layer, growth_rate):
        super(SRDBN, self).__init__()
        modules = []
        self.split_channel=in_channels//8
        kernel_size=3
        dilation=1
        self.conv1 = nn.Conv2d(self.split_channel*1, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        dilation=2
        self.conv2 = nn.Conv2d(self.split_channel*2, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        dilation=4
        self.conv3 = nn.Conv2d(self.split_channel*3, self.split_channel, kernel_size=kernel_size,  padding=dilation, dilation=dilation)
        dilation=8
        self.conv4 = nn.Conv2d(self.split_channel*4, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        
        dilation=8
        self.conv5 = nn.Conv2d(self.split_channel*5, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        
        dilation=4
        self.conv6 = nn.Conv2d(self.split_channel*6, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        
        dilation=2
        self.conv7 = nn.Conv2d(self.split_channel*7, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)
        
        dilation=1
        self.conv8 = nn.Conv2d(self.split_channel*8, self.split_channel, kernel_size=kernel_size, padding=dilation, dilation=dilation)

            
        #self.residual_dense_layers = nn.Sequential(*modules)
        _in_channels=in_channels
        self.calayer=CALayer(in_channels)
        self.palayer=PALayer(in_channels)
        self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        splited = torch.split(x, self.split_channel, dim=1)
        x0=F.relu(self.conv1(splited[0]))
        tmp= torch.cat((splited[1], x0), 1)
        x1=F.relu(self.conv2(tmp))
        tmp= torch.cat((splited[2], x0, x1), 1)
        x2=F.relu(self.conv3(tmp))
        tmp= torch.cat((splited[3], x0, x1, x2), 1)
        x3=F.relu(self.conv4(tmp))
        tmp= torch.cat(( splited[4],x0, x1, x2, x3), 1)
        x4=F.relu(self.conv5(tmp))
        
        tmp= torch.cat(( splited[5],x0, x1, x2, x3,x4), 1)
        x5=F.relu(self.conv6(tmp))
        
        tmp= torch.cat(( splited[6],x0, x1, x2, x3,x4,x5), 1)
        x6=F.relu(self.conv7(tmp))
        
        tmp= torch.cat(( splited[7],x0, x1, x2, x3,x4,x5,x6), 1)
        x7=F.relu(self.conv8(tmp))
        
       
        tmp= torch.cat(( x0, x1, x2, x3,x4,x5,x6,x7), 1)
        out = self.conv_1x1(tmp)
        out=self.calayer(out)
        out=self.palayer(out)
        out=out+x
        return out


        
                

class RDB(nn.Module):
    def __init__(self, in_channels, num_dense_layer, growth_rate):
        super(RDB, self).__init__()
        _in_channels = in_channels
        modules = []
        for i in range(num_dense_layer):
            modules.append(MakeDense(_in_channels, growth_rate))
            _in_channels += growth_rate
        self.residual_dense_layers = nn.Sequential(*modules)
        self.conv_1x1 = nn.Conv2d(_in_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        out = self.residual_dense_layers(x)
        out = self.conv_1x1(out)
        out = out + x
        return out


In [7]:

# --- Imports --- #
import torch.utils.data as data
from PIL import Image
from random import randrange
from torchvision.transforms import Compose, ToTensor, Normalize

import glob
# --- Training dataset --- #

class TrainData512(data.Dataset):
    def __init__(self, crop_size, train_data_dir):
        super().__init__()
        hazeeffected_images_dir='/home/zsd/data/dehazing/ITS_v2/indoor/reside/hazy/'
        
        hazy_data = glob.glob(hazeeffected_images_dir + "*.png")
        hazefree_images_dir='/home/zsd/data/dehazing/ITS_v2/indoor/reside/clear/'
        haze_names=[]
        gt_names=[]
        for h_image in hazy_data:
		        h_image = h_image.split("/")[-1]
		        id_ = h_image.split("_")[0]  + ".png"
		        haze_names.append(hazeeffected_images_dir+h_image)
		        gt_names.append(hazefree_images_dir+id_)
        self.haze_names = haze_names
        self.gt_names = gt_names
        self.crop_size = crop_size
        self.train_data_dir = train_data_dir

    def get_images(self, index):
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        haze_img = Image.open(haze_name)

        try:
            gt_img = Image.open(gt_name)
        except:
            gt_img = Image.open(gt_name).convert('RGB')

        

        

        # --- x,y coordinate of left-top corner --- #
        
        haze_crop_img = haze_img.resize((512, 512),Image.ANTIALIAS)
        gt_crop_img = gt_img.resize((512, 512),Image.ANTIALIAS)

        # --- Transform to tensor --- #
        transform_all = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
       
        haze = transform_all(haze_crop_img)
        gt = transform_all(gt_crop_img)


        return haze, gt

    def __getitem__(self, index):
        res = self.get_images(index)
        return res

    def __len__(self):
        return len(self.haze_names)
        
        

class TrainDataNew(data.Dataset):
    def __init__(self, crop_size, train_data_dir):
        super().__init__()
        # hazeeffected_images_dir='/home/zsd/data/dehazing/ITS_v2/indoor/light_dehazenet/data/'
        hazeeffected_images_dir = '/kaggle/input/o-haze/O-HAZY/hazy'
        
        
        hazy_data = glob.glob(hazeeffected_images_dir + "*.jpg")
        # hazefree_images_dir='/home/zsd/data/dehazing/ITS_v2/indoor/light_dehazenet/image/'
        hazefree_images_dir = '/kaggle/input/o-haze/O-HAZY/hazy'
        haze_names=[]
        gt_names=[]
        for h_image in hazy_data:
		        h_image = h_image.split("/")[-1]
		        id_ = h_image.split("_")[0] + "_" + h_image.split("_")[1] + ".jpg"
		        haze_names.append(hazeeffected_images_dir+h_image)
		        gt_names.append(hazefree_images_dir+id_)
        self.haze_names = haze_names
        self.gt_names = gt_names
        self.crop_size = crop_size
        self.train_data_dir = train_data_dir

    def get_images(self, index):
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        haze_img = Image.open(haze_name)

        try:
            gt_img = Image.open(gt_name)
        except:
            gt_img = Image.open(gt_name).convert('RGB')

        width, height = haze_img.size

        if width < crop_width or height < crop_height:
            raise Exception('Bad image size: {}'.format(gt_name))

        # --- x,y coordinate of left-top corner --- #
        x, y = randrange(0, width - crop_width + 1), randrange(0, height - crop_height + 1)
        haze_crop_img = haze_img.crop((x, y, x + crop_width, y + crop_height))
        gt_crop_img = gt_img.crop((x, y, x + crop_width, y + crop_height))

        # --- Transform to tensor --- #
         #transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        haze = transform_gt(haze_crop_img)
        gt = transform_gt(gt_crop_img)

        # --- Check the channel is 3 or not --- #
        if list(haze.shape)[0] !=  3 or list(gt.shape)[0] !=  3:
            raise Exception('Bad image channel: {}'.format(gt_name))

        return haze, gt

    def __getitem__(self, index):
        res = self.get_images(index)
        return res

    def __len__(self):
        return len(self.haze_names)
        
        
        
class TrainData(data.Dataset):
    def __init__(self, crop_size, train_data_dir,istrain = True):
        super().__init__()
        if istrain:
            # hazeeffected_images_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/hazy'
            # hazefree_images_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/GT'
            # hazeeffected_images_dir = '/kaggle/input/o-haze/O-HAZY/hazy'
            # hazefree_images_dir = '/kaggle/input/o-haze/O-HAZY/GT'
            hazeeffected_images_dir = '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN'
            hazefree_images_dir = '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/GT'
        else:
            hazeeffected_images_dir = '/kaggle/input/nh-dense-haze/NH-HAZE-V/NH-HAZE-V/IN'
            hazefree_images_dir = '/kaggle/input/nh-dense-haze/NH-HAZE-V/NH-HAZE-V/GT'

        hazy_data = glob.glob(os.path.join(hazeeffected_images_dir, "*.*"))
        
        haze_names=[]
        gt_names=[]
        print(hazy_data)
        print(len(hazy_data))
        for h_image in hazy_data:
            # h_image = h_image.split("/")[-1]
            # id_ = h_image.split("_")[0] + "_" + h_image.split("_")[1] + ".jpg"
            # print("id",id_)
            # haze_names.append(hazeeffected_images_dir+h_image)
            # gt_names.append(hazefree_images_dir+id_)
            h_image = h_image.split("/")[-1]  # Extract filename
            haze_names.append(os.path.join(hazeeffected_images_dir, h_image))
            gt_names.append(os.path.join(hazefree_images_dir, h_image)) 

        self.haze_names = haze_names
        self.gt_names = gt_names
        self.crop_size = crop_size
        self.train_data_dir = train_data_dir

    def get_images(self, index):
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        haze_img = Image.open(haze_name)

        try:
            gt_img = Image.open(gt_name)
        except:
            gt_img = Image.open(gt_name).convert('RGB')

        width, height = haze_img.size

        if width < crop_width or height < crop_height:
            raise Exception('Bad image size: {}'.format(gt_name))

        # --- x,y coordinate of left-top corner --- #
        x, y = randrange(0, width - crop_width + 1), randrange(0, height - crop_height + 1)
        haze_crop_img = haze_img.crop((x, y, x + crop_width, y + crop_height))
        gt_crop_img = gt_img.crop((x, y, x + crop_width, y + crop_height))

        # --- Transform to tensor --- #
        transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        haze = transform_haze(haze_crop_img)
        gt = transform_gt(gt_crop_img)

        # --- Check the channel is 3 or not --- #
        if list(haze.shape)[0] !=  3 or list(gt.shape)[0] !=  3:
            raise Exception('Bad image channel: {}'.format(gt_name))

        return haze, gt

    def __getitem__(self, index):
        res = self.get_images(index)
        return res

    def __len__(self):
        return len(self.haze_names)



In [8]:
"""
paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
file: utils.py
about: all utilities
author: Xiaohong Liu
date: 01/08/19
"""

# --- Imports --- #
import time
import torch
import torch.nn.functional as F
import torchvision.utils as utils
from math import log10
from skimage import measure
from skimage.metrics import structural_similarity as ssim


def to_psnr(dehaze, gt):
    mse = F.mse_loss(dehaze, gt, reduction='none')
    #print (mse)
    mse_split = torch.split(mse, 1, dim=0)
    mse_list = [torch.mean(torch.squeeze(mse_split[ind])).item() for ind in range(len(mse_split))]

    intensity_max = 1.0
    psnr_list = [10.0 * log10(intensity_max / min(max(mse,0.000001),1000)) for mse in mse_list]
    return psnr_list


def to_ssim_skimage(dehaze, gt):
    dehaze_list = torch.split(dehaze, 1, dim=0)
    gt_list = torch.split(gt, 1, dim=0)

    dehaze_list_np = [dehaze_list[ind].permute(0, 2, 3, 1).data.cpu().numpy().squeeze() for ind in range(len(dehaze_list))]
    gt_list_np = [gt_list[ind].permute(0, 2, 3, 1).data.cpu().numpy().squeeze() for ind in range(len(dehaze_list))]
    # ssim_list = [measure.compare_ssim(dehaze_list_np[ind],  gt_list_np[ind], data_range=1, multichannel=True) for ind in range(len(dehaze_list))]
    ssim_list = [ssim(dehaze_list_np[ind], gt_list_np[ind], data_range=1, channel_axis=-1) for ind in range(len(dehaze_list))]

    return ssim_list




def validationStlyle(net, val_data_loader, device, category, save_tag=False):
    """
    :param net: GateDehazeNet
    :param val_data_loader: validation loader
    :param device: The GPU that loads the network
    :param category: indoor or outdoor test dataset
    :param save_tag: tag of saving image or not
    :return: average PSNR value
    """
    psnr_list = []
    ssim_list = []

    for batch_id, val_data in enumerate(val_data_loader):

        with torch.no_grad():
            haze, gt, image_name = val_data
            haze = haze.to(device)
            gt = gt.to(device)
            hazing = net(gt,haze)

        # --- Calculate the average PSNR --- #
        psnr_list.extend(to_psnr(hazing, haze))

        # --- Calculate the average SSIM --- #
        ssim_list.extend(to_ssim_skimage(hazing, haze))

        # --- Save image --- #
        if save_tag:
            save_image(dehaze, image_name, category)
  
    avr_psnr = sum(psnr_list) / len(psnr_list)
    
    avr_ssim = sum(ssim_list) / len(ssim_list)
    return avr_psnr, avr_ssim
    
    
def validationB(net, val_data_loader, device, category, save_tag=False):
    """
    :param net: GateDehazeNet
    :param val_data_loader: validation loader
    :param device: The GPU that loads the network
    :param category: indoor or outdoor test dataset
    :param save_tag: tag of saving image or not
    :return: average PSNR value
    """
    psnr_list = []
    ssim_list = []
    i=0
    for batch_id, val_data in enumerate(val_data_loader):

        with torch.no_grad():
            # haze, gt, image_name = val_data
            haze, gt = val_data
            haze = haze.to(device)
            gt = gt.to(device)
            dehaze, _ = net(haze)

        # --- Calculate the average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        # --- Calculate the average SSIM --- #
        ssim_list.extend(to_ssim_skimage(dehaze, gt))

        # --- Save image --- #
        if save_tag:
            i+=1
            save_image(dehaze, i, category)
  
    avr_psnr = sum(psnr_list) / len(psnr_list)
    
    avr_ssim = sum(ssim_list) / len(ssim_list)
    return avr_psnr, avr_ssim
        
    

def validation(net, val_data_loader, device, category, save_tag=False):
    """
    :param net: GateDehazeNet
    :param val_data_loader: validation loader
    :param device: The GPU that loads the network
    :param category: indoor or outdoor test dataset
    :param save_tag: tag of saving image or not
    :return: average PSNR value
    """
    psnr_list = []
    ssim_list = []

    for batch_id, val_data in enumerate(val_data_loader):

        with torch.no_grad():
            haze, gt, image_name = val_data
            haze = haze.to(device)
            gt = gt.to(device)
            dehaze = net(haze)

        # --- Calculate the average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        # --- Calculate the average SSIM --- #
        ssim_list.extend(to_ssim_skimage(dehaze, gt))

        # --- Save image --- #
        if save_tag:
            save_image(dehaze, image_name, category)
  
    avr_psnr = sum(psnr_list) / len(psnr_list)
    
    avr_ssim = sum(ssim_list) / len(ssim_list)
    return avr_psnr, avr_ssim
def validationN(net, val_data_loader, device, category, save_tag=False):
    """
    :param net: GateDehazeNet
    :param val_data_loader: validation loader
    :param device: The GPU that loads the network
    :param category: indoor or outdoor test dataset
    :param save_tag: tag of saving image or not
    :return: average PSNR value
    """
    psnr_list = []
    ssim_list = []

    for batch_id, val_data in enumerate(val_data_loader):

        with torch.no_grad():
            haze, gt, image_name = val_data
            haze = haze.to(device)
            gt = gt.to(device)
            dehaze,_,_ = net(haze)

        # --- Calculate the average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        # --- Calculate the average SSIM --- #
        ssim_list.extend(to_ssim_skimage(dehaze, gt))

        # --- Save image --- #
        if save_tag:
            save_image(dehaze, image_name, category)

    avr_psnr = sum(psnr_list) / len(psnr_list)
    avr_ssim = sum(ssim_list) / len(ssim_list)
    
    return avr_psnr, avr_ssim


def save_image(dehaze, image_name, category):
    dehaze_images = torch.split(dehaze, 1, dim=0)
    batch_num = len(dehaze_images)

    for ind in range(batch_num):
        utils.save_image(dehaze_images[ind], './{}_results/{}'.format(category, image_name[ind][:-3] + 'png'))


import os
import time

def print_log(epoch, num_epochs, one_epoch_time, train_psnr, val_psnr, val_ssim, category):
    log_dir = "./training_log"
    os.makedirs(log_dir, exist_ok=True)  # Ensure the directory exists

    log_path = os.path.join(log_dir, f"{category}_log.txt")

    print('({0:.0f}s) Epoch [{1}/{2}], Train_PSNR:{3:.2f}, Val_PSNR:{4:.2f}, Val_SSIM:{5:.4f}'
          .format(one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, val_ssim))

    # --- Write the training log --- #
    with open(log_path, 'a') as f:
        print('Date: {0}, Time_Cost: {1:.0f}s, Epoch: [{2}/{3}], Train_PSNR: {4:.2f}, Val_PSNR: {5:.2f}, Val_SSIM: {6:.4f}'
              .format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                      one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, val_ssim), file=f)



def adjust_learning_rate_step(optimizer, category, lr_decay=0.95):

    # --- Decay learning rate --- #

    for param_group in optimizer.param_groups:
       param_group['lr'] *= lr_decay
       print('Learning rate sets to {}.'.format(param_group['lr']))

            
            
            
def adjust_learning_rate(optimizer, epoch, category, lr_decay=0.90):

    # --- Decay learning rate --- #
    step = 18 if category == 'indoor' else 3
    if category == 'NH':
       step = 20
    #if not category == 'indoor':
       #for param_group in optimizer.param_groups:
            #param_group['lr'] *= 0.99
            #print('Learning rate sets to {}.'.format(param_group['lr']))
    if not epoch % step and epoch > 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
            print('Learning rate sets to {}.'.format(param_group['lr']))
    else:
        for param_group in optimizer.param_groups:
            print('Learning rate sets to {}.'.format(param_group['lr']))


In [9]:
# import ipywidgets as widgets
# from IPython.display import display

# # Create widgets for each hyper-parameter
# learning_rate_widget = widgets.FloatText(value=1e-4, description='Learning Rate:')
# crop_size_widget = widgets.Text(value='360,360', description='Crop Size:')
# train_batch_size_widget = widgets.IntText(value=6, description='Train Batch Size:')
# network_height_widget = widgets.IntText(value=3, description='Network Height:')
# network_width_widget = widgets.IntText(value=6, description='Network Width:')
# num_dense_layer_widget = widgets.IntText(value=4, description='Num Dense Layer:')
# growth_rate_widget = widgets.IntText(value=16, description='Growth Rate:')
# lambda_loss_widget = widgets.FloatText(value=0.04, description='Lambda Loss:')
# val_batch_size_widget = widgets.IntText(value=1, description='Val Batch Size:')
# category_widget = widgets.Dropdown(options=['indoor', 'outdoor'], value='indoor', description='Category:')

# # Display the widgets
# display(learning_rate_widget, crop_size_widget, train_batch_size_widget, network_height_widget, network_width_widget, num_dense_layer_widget, growth_rate_widget, lambda_loss_widget, val_batch_size_widget, category_widget)

# # Function to parse the crop size
# def parse_crop_size(crop_size_str):
#     return [int(x) for x in crop_size_str.split(',')]

# # Assign the widget values to variables
# learning_rate = learning_rate_widget.value
# crop_size = parse_crop_size(crop_size_widget.value)
# train_batch_size = train_batch_size_widget.value
# network_height = network_height_widget.value
# network_width = network_width_widget.value
# num_dense_layer = num_dense_layer_widget.value
# growth_rate = growth_rate_widget.value
# lambda_loss = lambda_loss_widget.value
# val_batch_size = val_batch_size_widget.value
# category = category_widget.value

# print('Hyper-parameters set:')
# print(f'learning_rate: {learning_rate}')
# print(f'crop_size: {crop_size}')
# print(f'train_batch_size: {train_batch_size}')
# print(f'network_height: {network_height}')
# print(f'network_width: {network_width}')
# print(f'num_dense_layer: {num_dense_layer}')
# print(f'growth_rate: {growth_rate}')
# print(f'lambda_loss: {lambda_loss}')
# print(f'val_batch_size: {val_batch_size}')
# print(f'category: {category}')

In [10]:
import ipywidgets as widgets
from IPython.display import display

# --- Create widgets for each hyper-parameter ---
learning_rate_widget = widgets.FloatText(value=1e-4, description='Learning Rate:')
crop_size_widget = widgets.Text(value='360,360', description='Crop Size:')
train_batch_size_widget = widgets.IntText(value=6, description='Train Batch Size:')
network_height_widget = widgets.IntText(value=3, description='Network Height:')
network_width_widget = widgets.IntText(value=6, description='Network Width:')
num_dense_layer_widget = widgets.IntText(value=4, description='Num Dense Layer:')
growth_rate_widget = widgets.IntText(value=16, description='Growth Rate:')
lambda_loss_widget = widgets.FloatText(value=0.04, description='Lambda Loss:')
val_batch_size_widget = widgets.IntText(value=1, description='Val Batch Size:')
category_widget = widgets.Dropdown(options=['indoor', 'outdoor', 'nh'], value='nh', description='Category:')
execution_env_widget = widgets.Dropdown(options=['local', 'kaggle'], value='local', description='Execution Env:')

# --- Display the widgets ---
display(
    learning_rate_widget, crop_size_widget, train_batch_size_widget, network_height_widget, 
    network_width_widget, num_dense_layer_widget, growth_rate_widget, lambda_loss_widget, 
    val_batch_size_widget, category_widget, execution_env_widget
)

# --- Function to parse crop size ---
def parse_crop_size(crop_size_str):
    return [int(x) for x in crop_size_str.split(',')]

# --- Assign the widget values to variables ---
learning_rate = learning_rate_widget.value
crop_size = parse_crop_size(crop_size_widget.value)
train_batch_size = train_batch_size_widget.value
network_height = network_height_widget.value
network_width = network_width_widget.value
num_dense_layer = num_dense_layer_widget.value
growth_rate = growth_rate_widget.value
lambda_loss = lambda_loss_widget.value
val_batch_size = val_batch_size_widget.value
category = category_widget.value
execution_env = execution_env_widget.value  # Local or Kaggle

print('\nHyper-parameters set:')
print(f'learning_rate: {learning_rate}')
print(f'crop_size: {crop_size}')
print(f'train_batch_size: {train_batch_size}')
print(f'network_height: {network_height}')
print(f'network_width: {network_width}')
print(f'num_dense_layer: {num_dense_layer}')
print(f'growth_rate: {growth_rate}')
print(f'lambda_loss: {lambda_loss}')
print(f'val_batch_size: {val_batch_size}')
print(f'category: {category}')
print(f'execution_env: {execution_env}')

# --- Set category-specific hyper-parameters ---
if category == 'indoor':
    num_epochs = 1500
    train_data_dir = './data/train/indoor/'
    val_data_dir = './data/test/SOTS/indoor/'
elif category == 'outdoor':
    num_epochs = 10
    train_data_dir = './data/train/outdoor/'
    val_data_dir = './data/test/SOTS/outdoor/'
elif category == 'nh':
    num_epochs = 1000
    train_data_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/hazy'
    val_data_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/GT'
else:
    raise Exception('Wrong image category. Set it to indoor or outdoor for RESIDE dataset.')

# --- Adjust paths based on execution environment ---
if execution_env == 'kaggle':
    # train_data_dir = '/kaggle/input/reside-dataset/' + train_data_dir.strip('./')
    # val_data_dir = '/kaggle/input/reside-dataset/' + val_data_dir.strip('./')
    train_data_dir = '/kaggle/input/o-haze/O-HAZY/hazy'
    val_data_dir = '/kaggle/input/o-haze/O-HAZY/GT' 
print('\nFinal dataset paths:')
print(f'Training directory: {train_data_dir}')
print(f'Validation directory: {val_data_dir}')
print(f'Number of epochs: {num_epochs}')


FloatText(value=0.0001, description='Learning Rate:')

Text(value='360,360', description='Crop Size:')

IntText(value=6, description='Train Batch Size:')

IntText(value=3, description='Network Height:')

IntText(value=6, description='Network Width:')

IntText(value=4, description='Num Dense Layer:')

IntText(value=16, description='Growth Rate:')

FloatText(value=0.04, description='Lambda Loss:')

IntText(value=1, description='Val Batch Size:')

Dropdown(description='Category:', index=2, options=('indoor', 'outdoor', 'nh'), value='nh')

Dropdown(description='Execution Env:', options=('local', 'kaggle'), value='local')


Hyper-parameters set:
learning_rate: 0.0001
crop_size: [360, 360]
train_batch_size: 6
network_height: 3
network_width: 6
num_dense_layer: 4
growth_rate: 16
lambda_loss: 0.04
val_batch_size: 1
category: nh
execution_env: local

Final dataset paths:
Training directory: /Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/hazy
Validation directory: /Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/GT
Number of epochs: 1000


In [11]:
# --- Imports --- #
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
#import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
#from dehazeformer import *
from torchvision.models import vgg16
#plt.switch_backend('agg')






# --- Gpu device --- #
device_ids = [Id for Id in range(torch.cuda.device_count())]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# --- Define the network --- #
net = DeepGuidednew() #GridDehazeNet(height=network_height, width=network_width, num_dense_layer=num_dense_layer, growth_rate=growth_rate)

#net =FFA(3,19)
# --- Build optimizer --- #
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)


# --- Multi-GPU --- #
net = net.to(device)
net = nn.DataParallel(net, device_ids=device_ids)


# --- Define the perceptual loss network --- #
vgg_model = vgg16(pretrained=True).features[:16]
vgg_model = vgg_model.to(device)
for param in vgg_model.parameters():
    param.requires_grad = False

loss_network = LossNetwork(vgg_model)
loss_network.eval()
models='formernew'

# --- Load the network weight --- #
try:
    net.load_state_dict(torch.load(models+'{}_haze_best_{}_{}'.format(category, network_height, network_width)))
    print('--- weight loaded ---')
except:
    print('--- no weight loaded ---')


# --- Calculate all trainable parameters in network --- #
pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Total_params: {}".format(pytorch_total_params))


# --- Load training data and validation/test data --- #
train_data_loader = DataLoader(TrainData(crop_size, train_data_dir), batch_size=train_batch_size, shuffle=True)
val_data_loader = DataLoader(TrainData(crop_size, train_data_dir, istrain=False), batch_size=val_batch_size, shuffle=False)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 197MB/s]


--- no weight loaded ---
Total_params: 4645694
['/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/37.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/35.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/11.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/31.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/03.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/40.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/33.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/09.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/02.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/14.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/08.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/39.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/20.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/38.png', '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN/10.png', '/kaggle/input/nh-dense-haze/NH-HAZE-

  net.load_state_dict(torch.load(models+'{}_haze_best_{}_{}'.format(category, network_height, network_width)))


In [12]:



# --- Previous PSNR and SSIM in testing --- #
old_val_psnr, old_val_ssim = validationB(net, val_data_loader, device, category)
print('old_val_psnr: {0:.2f}, old_val_ssim: {1:.4f}'.format(old_val_psnr, old_val_ssim))
train_psnrold=0

for epoch in range(num_epochs):
    psnr_list = []
    start_time = time.time()
    adjust_learning_rate(optimizer, epoch, category=category)

    for batch_id, train_data in enumerate(train_data_loader):

        haze, gt = train_data
        haze = haze.to(device)
        gt = gt.to(device)

        # --- Zero the parameter gradients --- #
        optimizer.zero_grad()

        # --- Forward + Backward + Optimize --- #
        net.train()
        dehaze,base = net(haze)
        base_loss = F.smooth_l1_loss(base, gt)

        smooth_loss = F.smooth_l1_loss(dehaze, gt)
        perceptual_loss = loss_network(dehaze, gt)
        loss = smooth_loss + lambda_loss*perceptual_loss+base_loss

        loss.backward()
        optimizer.step()

        # --- To calculate average PSNR --- #
        psnr_list.extend(to_psnr(dehaze, gt))

        if not (batch_id % 100):
            print('Epoch: {0}, Iteration: {1}'.format(epoch, batch_id))

    # --- Calculate the average training PSNR in one epoch --- #
    train_psnr = sum(psnr_list) / len(psnr_list)

    # --- Save the network parameters --- #
    torch.save(net.state_dict(), models+'{}_haze_{}_{}'.format(category, network_height, network_width))

    # --- Use the evaluation model in testing --- #
    net.eval()

    val_psnr, val_ssim = validationB(net, val_data_loader, device, category)
    one_epoch_time = time.time() - start_time
    print_log(epoch+1, num_epochs, one_epoch_time, train_psnr, val_psnr, val_ssim, models+category)
    
    
    if train_psnr< train_psnrold:
        adjust_learning_rate_step(optimizer, category=category)            

    # --- update the network weight --- #
    if val_psnr >= old_val_psnr:
        torch.save(net.state_dict(), models+'{}_haze_best_{}_{}'.format(category, network_height, network_width))
        old_val_psnr = val_psnr


old_val_psnr: 10.68, old_val_ssim: 0.2251
Learning rate sets to 0.0001.
Epoch: 0, Iteration: 0
(12s) Epoch [1/1000], Train_PSNR:11.43, Val_PSNR:13.44, Val_SSIM:0.2976
Learning rate sets to 0.0001.
Epoch: 1, Iteration: 0
(10s) Epoch [2/1000], Train_PSNR:12.62, Val_PSNR:13.98, Val_SSIM:0.2611
Learning rate sets to 0.0001.
Epoch: 2, Iteration: 0
(10s) Epoch [3/1000], Train_PSNR:14.05, Val_PSNR:13.53, Val_SSIM:0.3102
Learning rate sets to 9e-05.
Epoch: 3, Iteration: 0
(10s) Epoch [4/1000], Train_PSNR:14.16, Val_PSNR:14.49, Val_SSIM:0.3417
Learning rate sets to 9e-05.
Epoch: 4, Iteration: 0
(10s) Epoch [5/1000], Train_PSNR:14.41, Val_PSNR:15.97, Val_SSIM:0.4199
Learning rate sets to 9e-05.
Epoch: 5, Iteration: 0
(10s) Epoch [6/1000], Train_PSNR:13.40, Val_PSNR:15.00, Val_SSIM:0.2982
Learning rate sets to 8.1e-05.
Epoch: 6, Iteration: 0
(10s) Epoch [7/1000], Train_PSNR:14.17, Val_PSNR:14.53, Val_SSIM:0.3435
Learning rate sets to 8.1e-05.
Epoch: 7, Iteration: 0
(10s) Epoch [8/1000], Train_PSN