In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import glob
import os
from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage
from PIL import Image
from random import randrange
from torch.optim.lr_scheduler import CosineAnnealingLR
import matplotlib.pyplot as plt
import math
from math import log10
import numpy as np
from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.layers import to_2tuple, trunc_normal_
import torchvision.utils as utils
import torch.utils.data as data
from torchvision.models import vgg16
from torch.utils.data import Dataset
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import time
from skimage import measure
import ipywidgets as widgets
from IPython.display import display


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Basenet

In [3]:
class RLN(nn.Module):
	r"""Revised LayerNorm"""
	def __init__(self, dim, eps=1e-5, detach_grad=False):
		super(RLN, self).__init__()
		self.eps = eps
		self.detach_grad = detach_grad

		self.weight = nn.Parameter(torch.ones((1, dim, 1, 1)))
		self.bias = nn.Parameter(torch.zeros((1, dim, 1, 1)))

		self.meta1 = nn.Conv2d(1, dim, 1)
		self.meta2 = nn.Conv2d(1, dim, 1)

		trunc_normal_(self.meta1.weight, std=.02)
		nn.init.constant_(self.meta1.bias, 1)

		trunc_normal_(self.meta2.weight, std=.02)
		nn.init.constant_(self.meta2.bias, 0)

	def forward(self, input):
		mean = torch.mean(input, dim=(1, 2, 3), keepdim=True)
		std = torch.sqrt((input - mean).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.eps)

		normalized_input = (input - mean) / std

		if self.detach_grad:
			rescale, rebias = self.meta1(std.detach()), self.meta2(mean.detach())
		else:
			rescale, rebias = self.meta1(std), self.meta2(mean)

		out = normalized_input * self.weight + self.bias
		return out, rescale, rebias


class Mlp(nn.Module):
	def __init__(self, network_depth, in_features, hidden_features=None, out_features=None):
		super().__init__()
		out_features = out_features or in_features
		hidden_features = hidden_features or in_features

		self.network_depth = network_depth

		self.mlp = nn.Sequential(
			nn.Conv2d(in_features, hidden_features, 1),
			nn.ReLU(True),
			nn.Conv2d(hidden_features, out_features, 1)
		)

		self.apply(self._init_weights)

	def _init_weights(self, m):
		if isinstance(m, nn.Conv2d):
			gain = (8 * self.network_depth) ** (-1/4)
			fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
			std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
			trunc_normal_(m.weight, std=std)
			if m.bias !=   None:
				nn.init.constant_(m.bias, 0)

	def forward(self, x):
		return self.mlp(x)


def window_partition(x, window_size):
	B, H, W, C = x.shape
	x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
	windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, C)
	return windows


def window_reverse(windows, window_size, H, W):
	B = int(windows.shape[0] / (H * W / window_size / window_size))
	x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
	x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
	return x


def get_relative_positions(window_size):
	coords_h = torch.arange(window_size)
	coords_w = torch.arange(window_size)

	coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
	coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
	relative_positions = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww

	relative_positions = relative_positions.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
	relative_positions_log  = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs())

	return relative_positions_log


class WindowAttention(nn.Module):
	def __init__(self, dim, window_size, num_heads):

		super().__init__()
		self.dim = dim
		self.window_size = window_size  # Wh, Ww
		self.num_heads = num_heads
		head_dim = dim // num_heads
		self.scale = head_dim ** -0.5

		relative_positions = get_relative_positions(self.window_size)
		self.register_buffer("relative_positions", relative_positions)
		self.meta = nn.Sequential(
			nn.Linear(2, 256, bias=True),
			nn.ReLU(True),
			nn.Linear(256, num_heads, bias=True)
		)

		self.softmax = nn.Softmax(dim=-1)

	def forward(self, qkv):
		B_, N, _ = qkv.shape

		qkv = qkv.reshape(B_, N, 3, self.num_heads, self.dim // self.num_heads).permute(2, 0, 3, 1, 4)

		q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

		q = q * self.scale
		attn = (q @ k.transpose(-2, -1))

		relative_position_bias = self.meta(self.relative_positions)
		relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
		attn = attn + relative_position_bias.unsqueeze(0)

		attn = self.softmax(attn)

		x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
		return x


class Attention(nn.Module):
	def __init__(self, network_depth, dim, num_heads, window_size, shift_size, use_attn=False, conv_type=None):
		super().__init__()
		self.dim = dim
		self.head_dim = int(dim // num_heads)
		self.num_heads = num_heads

		self.window_size = window_size
		self.shift_size = shift_size

		self.network_depth = network_depth
		self.use_attn = use_attn
		self.conv_type = conv_type

		if self.conv_type == 'Conv':
			self.conv = nn.Sequential(
				nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect'),
				nn.ReLU(True),
				nn.Conv2d(dim, dim, kernel_size=3, padding=1, padding_mode='reflect')
			)

		if self.conv_type == 'DWConv':
			self.conv = nn.Conv2d(dim, dim, kernel_size=5, padding=2, groups=dim, padding_mode='reflect')

		if self.conv_type == 'DWConv' or self.use_attn:
			self.V = nn.Conv2d(dim, dim, 1)
			self.proj = nn.Conv2d(dim, dim, 1)

		if self.use_attn:
			self.QK = nn.Conv2d(dim, dim * 2, 1)
			self.attn = WindowAttention(dim, window_size, num_heads)

		self.apply(self._init_weights)

	def _init_weights(self, m):
		if isinstance(m, nn.Conv2d):
			w_shape = m.weight.shape
			
			if w_shape[0] == self.dim * 2:	# QK
				fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
				std = math.sqrt(2.0 / float(fan_in + fan_out))
				trunc_normal_(m.weight, std=std)		
			else:
				gain = (8 * self.network_depth) ** (-1/4)
				fan_in, fan_out = _calculate_fan_in_and_fan_out(m.weight)
				std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
				trunc_normal_(m.weight, std=std)

			if m.bias !=  None:
				nn.init.constant_(m.bias, 0)

	def check_size(self, x, shift=False):
		_, _, h, w = x.size()
		mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
		mod_pad_w = (self.window_size - w % self.window_size) % self.window_size

		if shift:
			x = F.pad(x, (self.shift_size, (self.window_size-self.shift_size+mod_pad_w) % self.window_size,
						  self.shift_size, (self.window_size-self.shift_size+mod_pad_h) % self.window_size), mode='reflect')
		else:
			x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
		return x

	def forward(self, X):
		B, C, H, W = X.shape

		if self.conv_type == 'DWConv' or self.use_attn:
			V = self.V(X)
		#print(self.use_attn)
		if self.use_attn:
			#print('attention')      
			QK = self.QK(X)
			QKV = torch.cat([QK, V], dim=1)

			# shift
			shifted_QKV = self.check_size(QKV, self.shift_size > 0)
			Ht, Wt = shifted_QKV.shape[2:]

			# partition windows
			shifted_QKV = shifted_QKV.permute(0, 2, 3, 1)
			qkv = window_partition(shifted_QKV, self.window_size)  # nW*B, window_size**2, C

			attn_windows = self.attn(qkv)

			# merge windows
			shifted_out = window_reverse(attn_windows, self.window_size, Ht, Wt)  # B H' W' C

			# reverse cyclic shift
			out = shifted_out[:, self.shift_size:(self.shift_size+H), self.shift_size:(self.shift_size+W), :]
			attn_out = out.permute(0, 3, 1, 2)

			if self.conv_type in ['Conv', 'DWConv']:
				conv_out = self.conv(V)
				out = self.proj(conv_out + attn_out)
			else:
				out = self.proj(attn_out)

		else:
			if self.conv_type == 'Conv':
				out = self.conv(X)				# no attention and use conv, no projection
			elif self.conv_type == 'DWConv':
				out = self.proj(self.conv(V))

		return out


class TransformerBlock(nn.Module):
	def __init__(self, network_depth, dim, num_heads, mlp_ratio=4.,
				 norm_layer=nn.LayerNorm, mlp_norm=False,
				 window_size=8, shift_size=0, use_attn=True, conv_type=None):
		super().__init__()
		self.use_attn = use_attn
		self.mlp_norm = mlp_norm

		self.norm1 = norm_layer(dim) if use_attn else nn.Identity()
		self.attn = Attention(network_depth, dim, num_heads=num_heads, window_size=window_size,
							  shift_size=shift_size, use_attn=use_attn, conv_type=conv_type)

		self.norm2 = norm_layer(dim) if use_attn and mlp_norm else nn.Identity()
		self.mlp = Mlp(network_depth, dim, hidden_features=int(dim * mlp_ratio))

	def forward(self, x):
		identity = x
		if self.use_attn: x, rescale, rebias = self.norm1(x)
		x = self.attn(x)
		if self.use_attn: x = x * rescale + rebias
		x = identity + x

		identity = x
		if self.use_attn and self.mlp_norm: x, rescale, rebias = self.norm2(x)
		x = self.mlp(x)
		if self.use_attn and self.mlp_norm: x = x * rescale + rebias
		x = identity + x
		return x


class BasicLayer(nn.Module):
	def __init__(self, network_depth, dim, depth, num_heads, mlp_ratio=4.,
				 norm_layer=nn.LayerNorm, window_size=8,
				 attn_ratio=0., attn_loc='last', conv_type=None):

		super().__init__()
		self.dim = dim
		self.depth = depth

		attn_depth = attn_ratio * depth

		if attn_loc == 'last':
			use_attns = [i >= depth-attn_depth for i in range(depth)]
		elif attn_loc == 'first':
			use_attns = [i < attn_depth for i in range(depth)]
		elif attn_loc == 'middle':
			use_attns = [i >= (depth-attn_depth)//2 and i < (depth+attn_depth)//2 for i in range(depth)]

		# build blocks
		self.blocks = nn.ModuleList([
			TransformerBlock(network_depth=network_depth,
							 dim=dim, 
							 num_heads=num_heads,
							 mlp_ratio=mlp_ratio,
							 norm_layer=norm_layer,
							 window_size=window_size,
							 shift_size=0 if (i % 2 == 0) else window_size // 2,
							 use_attn=use_attns[i], conv_type=conv_type)
			for i in range(depth)])

	def forward(self, x):
		for blk in self.blocks:
			x = blk(x)
		return x


class PatchEmbed(nn.Module):
	def __init__(self, patch_size=4, in_chans=3, embed_dim=96, kernel_size=None):
		super().__init__()
		self.in_chans = in_chans
		self.embed_dim = embed_dim

		if kernel_size is None:
			kernel_size = patch_size

		self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=patch_size,
							  padding=(kernel_size-patch_size+1)//2, padding_mode='reflect')

	def forward(self, x):
		x = self.proj(x)
		return x


class PatchUnEmbed(nn.Module):
	def __init__(self, patch_size=4, out_chans=3, embed_dim=96, kernel_size=None):
		super().__init__()
		self.out_chans = out_chans
		self.embed_dim = embed_dim

		if kernel_size is None:
			kernel_size = 1

		self.proj = nn.Sequential(
			nn.Conv2d(embed_dim, out_chans*patch_size**2, kernel_size=kernel_size,
					  padding=kernel_size//2, padding_mode='reflect'),
			nn.PixelShuffle(patch_size)
		)

	def forward(self, x):
		x = self.proj(x)
		return x


class SKFusion(nn.Module):
	def __init__(self, dim, height=2, reduction=8):
		super(SKFusion, self).__init__()
		
		self.height = height
		d = max(int(dim/reduction), 4)
		
		self.avg_pool = nn.AdaptiveAvgPool2d(1)
		self.mlp = nn.Sequential(
			nn.Conv2d(dim, d, 1, bias=False), 
			nn.ReLU(),
			nn.Conv2d(d, dim*height, 1, bias=False)
		)
		
		self.softmax = nn.Softmax(dim=1)

	def forward(self, in_feats):
		B, C, H, W = in_feats[0].shape
		
		in_feats = torch.cat(in_feats, dim=1)
		in_feats = in_feats.view(B, self.height, C, H, W)
		
		feats_sum = torch.sum(in_feats, dim=1)
		attn = self.mlp(self.avg_pool(feats_sum))
		attn = self.softmax(attn.view(B, self.height, C, 1, 1))

		out = torch.sum(in_feats*attn, dim=1)
		return out      

In [4]:
class DehazeFormer(nn.Module):
	def __init__(self, in_chans=3, out_chans=4, window_size=8,
				 embed_dims=[24, 48, 96, 48, 24],
				 mlp_ratios=[2., 4., 4., 2., 2.],
				 depths=[16, 16, 16, 8, 8],
				 num_heads=[2, 4, 6, 1, 1],
				 attn_ratio=[1/4, 1/2, 3/4, 0, 0],
				 conv_type=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
				 norm_layer=[RLN, RLN, RLN, RLN, RLN]):
		super(DehazeFormer, self).__init__()

		# setting
		self.patch_size = 4
		self.window_size = window_size
		self.mlp_ratios = mlp_ratios

		# split image into non-overlapping patches
		self.patch_embed = PatchEmbed(
			patch_size=1, in_chans=in_chans, embed_dim=embed_dims[0], kernel_size=3)

		# backbone
		self.layer1 = BasicLayer(network_depth=sum(depths), dim=embed_dims[0], depth=depths[0],
					   			 num_heads=num_heads[0], mlp_ratio=mlp_ratios[0],
					   			 norm_layer=norm_layer[0], window_size=window_size,
					   			 attn_ratio=attn_ratio[0], attn_loc='last', conv_type=conv_type[0])

		self.patch_merge1 = PatchEmbed(
			patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])

		self.skip1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)

		self.layer2 = BasicLayer(network_depth=sum(depths), dim=embed_dims[1], depth=depths[1],
								 num_heads=num_heads[1], mlp_ratio=mlp_ratios[1],
								 norm_layer=norm_layer[1], window_size=window_size,
								 attn_ratio=attn_ratio[1], attn_loc='last', conv_type=conv_type[1])

		self.patch_merge2 = PatchEmbed(
			patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

		self.skip2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)

		self.layer3 = BasicLayer(network_depth=sum(depths), dim=embed_dims[2], depth=depths[2],
								 num_heads=num_heads[2], mlp_ratio=mlp_ratios[2],
								 norm_layer=norm_layer[2], window_size=window_size,
								 attn_ratio=attn_ratio[2], attn_loc='last', conv_type=conv_type[2])

		self.patch_split1 = PatchUnEmbed(
			patch_size=2, out_chans=embed_dims[3], embed_dim=embed_dims[2])

		assert embed_dims[1] == embed_dims[3]
		self.fusion1 = SKFusion(embed_dims[3])

		self.layer4 = BasicLayer(network_depth=sum(depths), dim=embed_dims[3], depth=depths[3],
								 num_heads=num_heads[3], mlp_ratio=mlp_ratios[3],
								 norm_layer=norm_layer[3], window_size=window_size,
								 attn_ratio=attn_ratio[3], attn_loc='last', conv_type=conv_type[3])

		self.patch_split2 = PatchUnEmbed(
			patch_size=2, out_chans=embed_dims[4], embed_dim=embed_dims[3])

		assert embed_dims[0] == embed_dims[4]
		self.fusion2 = SKFusion(embed_dims[4])			

		self.layer5 = BasicLayer(network_depth=sum(depths), dim=embed_dims[4], depth=depths[4],
					   			 num_heads=num_heads[4], mlp_ratio=mlp_ratios[4],
					   			 norm_layer=norm_layer[4], window_size=window_size,
					   			 attn_ratio=attn_ratio[4], attn_loc='last', conv_type=conv_type[4])

		# merge non-overlapping patches into image
		self.patch_unembed = PatchUnEmbed(
			patch_size=1, out_chans=out_chans, embed_dim=embed_dims[4], kernel_size=3)


	def check_image_size(self, x):
		# NOTE: for I2I test
		_, _, h, w = x.size()
		mod_pad_h = (self.patch_size - h % self.patch_size) % self.patch_size
		mod_pad_w = (self.patch_size - w % self.patch_size) % self.patch_size
		x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
		return x

	def forward_features(self, x):
		x = self.patch_embed(x)
		x = self.layer1(x)
		skip1 = x

		x = self.patch_merge1(x)
		x = self.layer2(x)
		skip2 = x

		x = self.patch_merge2(x)
		x = self.layer3(x)
		x = self.patch_split1(x)

		x = self.fusion1([x, self.skip2(skip2)]) + x
		x = self.layer4(x)
		x = self.patch_split2(x)

		x = self.fusion2([x, self.skip1(skip1)]) + x
		x = self.layer5(x)
		x = self.patch_unembed(x)
		return x

	def forward(self, x):
		H, W = x.shape[2:]
		x = self.check_image_size(x)

		feat = self.forward_features(x)
		K, B = torch.split(feat, (1, 3), dim=1)

		x = K * x - B + x
		x = x[:, :, :H, :W]
		return x

In [5]:
def dehazeformer_m():
    return DehazeFormer(
		embed_dims=[24, 48, 96, 48, 24],
		mlp_ratios=[2., 4., 4., 2., 2.],
		depths=[12, 12, 12, 6, 6],
		num_heads=[2, 4, 6, 1, 1],
		attn_ratio=[1/4, 1/2, 3/4, 0, 0],
		conv_type=['Conv', 'Conv', 'Conv', 'Conv', 'Conv'])

# DetailNet

In [6]:
# -----------------------------
# TEACHER MODEL
# -----------------------------
class SR_model(nn.Module):
    def __init__(self, upscale_factor=1):
        super(SR_model, self).__init__()
        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(3, 56, kernel_size=5, padding=2),
            nn.PReLU()
        )
        self.shrinking = nn.Sequential(
            nn.Conv2d(56, 24, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.mapping = nn.Sequential(
            nn.Conv2d(24, 24, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(24, 24, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(24, 24, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.Conv2d(24, 24, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.expanding = nn.Sequential(
            nn.Conv2d(24, 56, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.deconvolution = nn.Sequential(
            nn.Conv2d(56, 3, kernel_size=3, padding=1),
            nn.Tanh()  # Ensure output is normalized
        )

    def forward(self, x):
        x = self.downsample(x)
        residual = x  # Store input for residual connection
        x = self.feature_extraction(x)
        x = self.shrinking(x)
        x = self.mapping(x)
        x = self.expanding(x)
        x = self.deconvolution(x)
        return x + residual  # Add residual for stability


In [7]:
# -----------------------------
# STUDENT MODEL
# -----------------------------
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        )
    
    def forward(self, x):
        return x + self.block(x)  # Residual connection

class SpectralAttention(nn.Module):
    def __init__(self, channels):
        super(SpectralAttention, self).__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(channels // 16, channels, kernel_size=1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return x * self.fc(x)

class SpatialAttention(nn.Module):
    def __init__(self, channels):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        return x * self.sigmoid(self.conv(x))

class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, channels):
        super(MultiScaleFeatureFusion, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=5, padding=2)
        self.conv5 = nn.Conv2d(channels, channels, kernel_size=7, padding=3)
    
    def forward(self, x):
        return self.conv1(x) + self.conv3(x) + self.conv5(x)

class DehazingNet(nn.Module):
    def __init__(self):
        super(DehazingNet, self).__init__()
        self.initial_conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.residual_blocks = nn.Sequential(
            ResidualBlock(64),
            ResidualBlock(64),
            ResidualBlock(64)
        )
        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )
        self.spectral_attention = SpectralAttention(64)
        self.spatial_attention = SpatialAttention(64)
        self.multi_scale_fusion = MultiScaleFeatureFusion(64)
        self.final_conv = nn.Conv2d(64, 3, kernel_size=3, padding=1)
    
    def forward(self, x):
        x = self.downsample(x)
        x = F.relu(self.initial_conv(x))
        x = self.residual_blocks(x)
        x = self.spectral_attention(x)
        x = self.spatial_attention(x)
        x = self.multi_scale_fusion(x)
        x = self.final_conv(x)
        return x 


In [8]:
# -----------------------------
# FEATURE AFFINITY MODULE (FAM) USING KL DIVERGENCE
# -----------------------------
class FeatureAffinityModule(nn.Module):
    def __init__(self):
        super(FeatureAffinityModule, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d((16, 16))

    def forward(self, features_a, features_b):
        # Pool and flatten features
        feat_a = self.pool(features_a).view(features_a.size(0), -1)
        feat_b = self.pool(features_b).view(features_b.size(0), -1)

        # Normalize features (important for stable KL divergence)
        feat_a = F.normalize(feat_a, p=2, dim=-1)
        feat_b = F.normalize(feat_b, p=2, dim=-1)

        # Compute normalized affinity matrices
        affinity_a = torch.mm(feat_a, feat_a.T) / feat_a.size(1)
        affinity_b = torch.mm(feat_b, feat_b.T) / feat_b.size(1)

        # Compute symmetric KL divergence loss
        loss = 0.5 * (F.kl_div(F.log_softmax(affinity_a, dim=-1), F.softmax(affinity_b, dim=-1), reduction='batchmean') +
                      F.kl_div(F.log_softmax(affinity_b, dim=-1), F.softmax(affinity_a, dim=-1), reduction='batchmean'))

        return loss


# Guided Filter

In [9]:
class ConvGuidedFilter(nn.Module):
    """
    Adapted from https://github.com/wuhuikai/DeepGuidedFilter
    """
    def __init__(self, radius=1, norm=nn.BatchNorm2d, conv_a_kernel_size: int = 1):
        super(ConvGuidedFilter, self).__init__()

        self.box_filter = nn.Conv2d(
            3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3
        )
        self.conv_a = nn.Sequential(
            nn.Conv2d(
                6,
                32,
                kernel_size=conv_a_kernel_size,
                padding=conv_a_kernel_size // 2,
                bias=False,
            ),
            norm(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                32,
                32,
                kernel_size=conv_a_kernel_size,
                padding=conv_a_kernel_size // 2,
                bias=False,
            ),
            norm(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                32,
                3,
                kernel_size=conv_a_kernel_size,
                padding=conv_a_kernel_size // 2,
                bias=False,
            ),
        )
        self.box_filter.weight.data[...] = 1.0

    def forward(self, x_lr, y_lr, x_hr):
        _, _, h_lrx, w_lrx = x_lr.size()
        _, _, h_hrx, w_hrx = x_hr.size()

        N = self.box_filter(x_lr.data.new().resize_((1, 3, h_lrx, w_lrx)).fill_(1.0))
        ## mean_x
        mean_x = self.box_filter(x_lr) / N
        ## mean_y
        mean_y = self.box_filter(y_lr) / N
        ## cov_xy
        cov_xy = self.box_filter(x_lr * y_lr) / N - mean_x * mean_y
        ## var_x
        var_x = self.box_filter(x_lr * x_lr) / N - mean_x * mean_x

        ## A
        A = self.conv_a(torch.cat([cov_xy, var_x], dim=1))
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode="bilinear", align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode="bilinear", align_corners=True)

        return mean_A * x_hr + mean_b

In [10]:
class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, n):
        super(AdaptiveInstanceNorm, self).__init__()

        self.w_0 = nn.Parameter(torch.Tensor([1.0]))
        self.w_1 = nn.Parameter(torch.Tensor([0.0]))

        self.ins_norm = nn.InstanceNorm2d(n, momentum=0.999, eps=0.001, affine=True)

    def forward(self, x):
        return self.w_0 * x + self.w_1 * self.ins_norm(x)

In [11]:
class DeepGuidednew(nn.Module):
    def __init__(self, radius=1):
        super().__init__()
        norm = AdaptiveInstanceNorm
        kernel_size=3
        depth_rate=16
        in_channels=3
        num_dense_layer=4
        growth_rate=16
        growth_rate=16

        # self.local = local
        
        # self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        # self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

        self.gf = ConvGuidedFilter(radius, norm=norm)
        self.lr = dehazeformer_m()

        self.downsample = nn.Upsample(
            scale_factor=0.5, mode="bilinear", align_corners=True
        )
        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=True
        )

    def forward(self, x_hr, y_detail):
        x_lr = self.downsample(x_hr)
        # y_lr=self.conv_in(x_lr)
        # y_lr= self.local(y_lr)
        # y_detail=self.conv_out(y_lr)
        y_base=self.lr(x_lr)
        # print(y_base.shape, y_detail.shape)
        y_lr=y_base+ y_detail
        y_base=self.upsample(y_base)
        return  self.gf(x_lr, y_lr, x_hr), y_base      

# Data Loading

In [12]:
# -----------------------------
# CUSTOM DATASET LOADER
# -----------------------------
class TrainData(Dataset):
    def __init__(self, crop_size, hazeeffected_images_dir, hazefree_images_dir):
        super().__init__()
        hazy_data = glob.glob(os.path.join(hazeeffected_images_dir, "*.*"))
        self.haze_names = [os.path.join(hazeeffected_images_dir, os.path.basename(h)) for h in hazy_data]
        self.gt_names = [os.path.join(hazefree_images_dir, os.path.basename(h)) for h in hazy_data]
        self.crop_size = crop_size
    
    def get_images(self, index):
        crop_width, crop_height = self.crop_size
        haze_img = Image.open(self.haze_names[index]).convert('RGB')
        gt_img = Image.open(self.gt_names[index]).convert('RGB')
        
        width, height = haze_img.size
        x, y = randrange(0, width - crop_width + 1), randrange(0, height - crop_height + 1)
        haze_crop_img = haze_img.crop((x, y, x + crop_width, y + crop_height))
        gt_crop_img = gt_img.crop((x, y, x + crop_width, y + crop_height))
        
        transform = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        return transform(haze_crop_img), transform(gt_crop_img)

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.haze_names)

In [13]:
crop_size = (360, 360)
# train_data = TrainData(crop_size, '/kaggle/input/nh-dense-haze/Dense-Haze-T/Dense-Haze-T/IN', '/kaggle/input/ nh-dense-haze/Dense-Haze-T/Dense-Haze-T/GT')
train_data = TrainData(crop_size, '/kaggle/input/reside6k/RESIDE-6K/train/hazy', '/kaggle/input/reside6k/RESIDE-6K/train/GT')
dataloader = DataLoader(train_data, batch_size=4, shuffle=True)

In [14]:
# Get a single batch from the dataloader
for hazy_images, clear_images in dataloader:
    print(f"Hazy Images Shape: {hazy_images.shape}")
    print(f"Clear Images Shape: {clear_images.shape}")
    break  # Only check one batch

Hazy Images Shape: torch.Size([4, 3, 360, 360])
Clear Images Shape: torch.Size([4, 3, 360, 360])


# Perceptual Loss

In [15]:
from torchvision.models import vgg16
loss_model = vgg16(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 182MB/s]  


In [16]:
loss_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [17]:
loss_model = loss_model.features
loss_model = loss_model.to(device)
for param in loss_model.parameters():
    param.requires_grad = False

In [18]:
class FeatureLossNetwork(torch.nn.Module):
    def __init__(self, feature_extractor):
        super(FeatureLossNetwork, self).__init__()
        self.feature_layers = feature_extractor
        self.layer_name_mapping = {
            '1': "relu1_1",
            # '3': "relu1_2",
            # '6': "relu2_1",
            # '8': "relu2_2",
            # '11': "relu3_1",
            # '13': "relu3_2",
            # '15': "relu3_3",
            '18': "relu4_1",
            # '20': "relu4_2",
            # '22': "relu4_3",
            # '25': "relu5_1",
            # '27': "relu5_2",
            '29': "relu5_3"
        }

    def extract_features(self, x):
        output = {}
        for name, module in self.feature_layers._modules.items():
            x = module(x)
            if name in self.layer_name_mapping:
                output[self.layer_name_mapping[name]] = x
        return list(output.values())

    def forward(self, predicted, ground_truth):
        loss = []
        scale_factor = 1000  
        predicted_features = self.extract_features(predicted)
        ground_truth_features = self.extract_features(ground_truth)
        for pred_feature, gt_feature in zip(predicted_features, ground_truth_features):
            loss.append(F.mse_loss(pred_feature, gt_feature))

        return sum(loss) / (len(loss) * scale_factor)

In [19]:
loss_network = FeatureLossNetwork(loss_model)
loss_network.eval()

FeatureLossNetwork(
  (feature_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, 

In [20]:
# # Test loss network with generated image and original image
# if isinstance(dehazed_image, np.ndarray):
#     dehazed_image = torch.from_numpy(dehazed_image).float().permute(2, 0, 1).unsqueeze(0).to(device)

# if isinstance(image, np.ndarray):
#     image = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0).to(device)

# # Compute loss
# dehazed_image = dehazed_image.to(device)
# image.to(device)
# loss = loss_network(dehazed_image, image)
# print("Feature Loss:", loss.item())

In [21]:
learning_rate = 1e-4

In [22]:
# --- GPU device --- #
device_ids = list(range(torch.cuda.device_count()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Define the network --- #
net = DeepGuidednew()

# --- Multi-GPU (correct order) --- #
net = nn.DataParallel(net, device_ids=device_ids).to(device)

# --- Build optimizer --- #
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# # --- Define the perceptual loss network --- #
# vgg_model = vgg16(pretrained=True).features[:16].to(device)
# for param in vgg_model.parameters():
#     param.requires_grad = False

# loss_network = LossNetwork(vgg_model)
# loss_network.eval()

# models = 'formernew'

# --- Load the network weight --- #
# weight_path = "{}_{}_haze_best_{}".format(models, category, version)
# try:
#     net.load_state_dict(torch.load(weight_path))
#     print('--- weight loaded ---')
# except FileNotFoundError:
#     print('--- no weight loaded ---')

# --- Calculate all trainable parameters in network --- #
pytorch_total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print("Total_params: {}".format(pytorch_total_params))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Total_params: 4637423


In [23]:
lambda_loss = 0.84
print(f'lambda_loss: {lambda_loss}')

lambda_loss: 0.84


In [24]:
# -----------------------------
# CO-DISTILLATION TRAINING
# -----------------------------
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import csv

def calculate_psnr(output, target, max_pixel_value=1.0):
    mse = F.mse_loss(output, target)
    # print(f"MSE: {mse.item()}")
    if mse == 0:
        return 100  # Avoid log(0) case, return max PSNR
    psnr = 20 * math.log10(max_pixel_value) - 10 * math.log10(mse.item())
    # print(f"PSNR: {psnr}")
    return psnr

# -----------------------------
# CO-DISTILLATION TRAINING
# -----------------------------
def train(net, teacher, student, fam, dataloader, num_epochs=10, lambda_fam=0.25, log_file="training_log.csv"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # net.to(device).train()
    teacher.to(device).train()
    student.to(device).train()
    fam.to(device)

    optimizer_t = torch.optim.Adam(teacher.parameters(), lr=1e-2)
    optimizer_s = torch.optim.Adam(student.parameters(), lr=1e-2)
    optimizer_d = torch.optim.Adam(net.parameters(), lr=1e-2)
    
    scheduler_t = CosineAnnealingLR(optimizer_t, T_max=num_epochs, eta_min=1e-3)
    scheduler_s = CosineAnnealingLR(optimizer_s, T_max=num_epochs, eta_min=1e-3)
    scheduler_d = CosineAnnealingLR(optimizer_d, T_max=num_epochs, eta_min=1e-3)
    
    best_loss = float('inf')
    best_psnr = 0

    with open(log_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["Epoch", "Loss", "Teacher PSNR", "Student PSNR"])
        
        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            total_loss = 0
            total_psnr_t = 0  # Teacher PSNR
            total_psnr_s = 0  # Student PSNR
            num_batches = len(dataloader)
            teacher_output = 0
            student_output = 0
            
            for hazy_images, clear_images in dataloader:
                hazy_images, clear_images = hazy_images.to(device), clear_images.to(device)
                print(f"Hazy images shape: {hazy_images.shape}, Clear images shape: {clear_images.shape}")
                
                teacher_output = teacher(clear_images)
                student_output = student(hazy_images)
                print(f"Teacher output shape: {teacher_output.shape}, Student output shape: {student_output.shape}")

                dehaze,base = net(hazy_images, student_output)
                
                base_loss = F.smooth_l1_loss(base, clear_images)
                smooth_loss = F.smooth_l1_loss(dehaze, clear_images)
                perceptual_loss = loss_network(dehaze, clear_images)

                # print("Type: Teacher:", type(teacher_output))
                # print("Type: student_output:", type(student_output))
                # print("Type: detail_output:", type(detail_output))
                downsample = nn.Upsample(
                    scale_factor=0.5, mode="bilinear", align_corners=True
                )
                clear_images = downsample(clear_images)

                # mse_loss_d = F.mse_loss(detail_output, clear_images)
                mse_loss_t = F.mse_loss(teacher_output, clear_images)
                mse_loss_s = F.mse_loss(student_output, clear_images)
                # print(f"MSE Loss - Teacher: {mse_loss_t.item()}, Student: {mse_loss_s.item()}")
                
                fam_loss = fam(teacher_output, student_output) 
                # print(f"FAM Loss: {fam_loss.item()}")
                print("Base Loss:", base_loss)
                print("Smooth Loss:", smooth_loss)
                print("Lambda Loss * Perceptual Loss:", lambda_loss * perceptual_loss)
                print("FAM Loss:", fam_loss)

                
                loss = base_loss + smooth_loss + lambda_loss * perceptual_loss + fam_loss 
                
                # loss = base_loss + smooth_loss + lambda_loss * perceptual_loss+ fam_loss + mse_loss_t + mse_loss_s
                print(f"Total Loss: {loss.item()}")
                
                optimizer_t.zero_grad()
                optimizer_s.zero_grad()
                optimizer_d.zero_grad()
                loss.backward()
                optimizer_t.step()
                optimizer_s.step()
                optimizer_d.step()
                
                # print(f"PSNR - Teacher: {psnr_t}, Student: {psnr_s}")
                
                total_loss += loss.item()
            # Compute PSNR for teacher and student
            psnr_t = calculate_psnr(teacher_output, clear_images)
            psnr_s = calculate_psnr(student_output, clear_images)
            total_psnr_t += psnr_t
            total_psnr_s += psnr_s

            avg_loss = total_loss / num_batches
            avg_psnr_t = total_psnr_t / num_batches
            avg_psnr_s = total_psnr_s / num_batches
            print(f"Epoch {epoch + 1} - Avg Loss: {avg_loss}, Avg PSNR (Teacher): {avg_psnr_t}, Avg PSNR (Student): {avg_psnr_s}")


            log_entry = f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.6f}, Avg Teacher PSNR: {avg_psnr_t:.2f}, Avg Student PSNR: {avg_psnr_s:.2f}"
            print(log_entry)

            # Write log to file
            writer.writerow([epoch + 1, avg_loss, avg_psnr_t, avg_psnr_s])

            # Update schedulers
            scheduler_t.step()
            scheduler_s.step()
            scheduler_d.step()

            # Save only if the model improves
            # if avg_loss < best_loss or epoch %50 ==0:
            best_loss = avg_loss
            best_psnr = avg_psnr_s
            torch.save(student.state_dict(), str(epoch)+"best_dehazing_student.pth")
            torch.save(teacher.state_dict(), str(epoch)+"best_sr_teacher.pth")
            print(f"Saved Best Model (Loss: {best_loss:.6f}, Student PSNR: {best_psnr:.2f})")

    print("Training complete. Logs saved in", log_file)


In [25]:
# -----------------------------
# TRAINING SETUP 
# -----------------------------
teacher_model = SR_model(upscale_factor=1)
student_model = DehazingNet()
fam_module = FeatureAffinityModule()

In [26]:
train(net,teacher_model, student_model, fam_module, dataloader, num_epochs=10)

Using device: cpu
Epoch 1/10
Hazy images shape: torch.Size([4, 3, 360, 360]), Clear images shape: torch.Size([4, 3, 360, 360])
Teacher output shape: torch.Size([4, 3, 180, 180]), Student output shape: torch.Size([4, 3, 180, 180])
Base Loss: tensor(0.1466, grad_fn=<SmoothL1LossBackward0>)
Smooth Loss: tensor(0.1338, grad_fn=<SmoothL1LossBackward0>)
Lambda Loss * Perceptual Loss: tensor(0.0011, grad_fn=<MulBackward0>)
FAM Loss: tensor(2.0489e-07, grad_fn=<MulBackward0>)
Total Loss: 0.281588077545166
Hazy images shape: torch.Size([4, 3, 360, 360]), Clear images shape: torch.Size([4, 3, 360, 360])
Teacher output shape: torch.Size([4, 3, 180, 180]), Student output shape: torch.Size([4, 3, 180, 180])
Base Loss: tensor(3166731.2500, grad_fn=<SmoothL1LossBackward0>)
Smooth Loss: tensor(3141624., grad_fn=<SmoothL1LossBackward0>)
Lambda Loss * Perceptual Loss: tensor(1.0751e+10, grad_fn=<MulBackward0>)
FAM Loss: tensor(7.4506e-08, grad_fn=<MulBackward0>)
Total Loss: 10757286912.0
Hazy images sha

KeyboardInterrupt: 

In [None]:
# -----------------------------
# LOAD MODEL
# -----------------------------
# model_path = "/kaggle/input/dehazing_sr/pytorch/default/1/best_dehazing_student.pth"
model_path = "/kaggle/input/dehazing_sr/pytorch/default/3/9best_sr_teacher.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Initialize model
# model = DehazingNet().to(device)
model = SR_model(upscale_factor=1).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

In [None]:
# -----------------------------
# LOAD TEST DATA
# -----------------------------
test_hazy_dir = "/kaggle/input/reside6k/RESIDE-6K/test/hazy"
test_gt_dir = "/kaggle/input/reside6k/RESIDE-6K/test/GT"
# test_hazy_dir = "/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN"
# test_gt_dir = "/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/GT"

hazy_images = sorted(glob.glob(os.path.join(test_hazy_dir, "*.*")))
gt_images = sorted(glob.glob(os.path.join(test_gt_dir, "*.*")))

transform = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

to_pil = ToPILImage()

In [None]:
.. block execution

In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION
# -----------------------------
num_samples = 20  # Change as needed
plt.figure(figsize=(10, num_samples * 5))

for i in range(num_samples):
    hazy_img = Image.open(hazy_images[i])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = model(input_tensor).cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(num_samples, 2, 2 * i + 1)
    plt.imshow(hazy_img)
    plt.title("Hazy Input")
    plt.axis("off")

    plt.subplot(num_samples, 2, 2 * i + 2)
    plt.imshow(output_img)
    plt.title("Output")
    plt.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION FOR SPECIFIC IMAGES
# -----------------------------
image_indices = [70, 75, 89, 100]  # Indices of images to visualize

plt.figure(figsize=(10, len(image_indices) * 5))

for idx, i in enumerate(image_indices):
    hazy_img = Image.open(hazy_images[i+1])
    gt_img = Image.open(gt_images[i+1])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = model(input_tensor).cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(len(image_indices), 3, 3 * idx + 1)
    plt.imshow(hazy_img)
    plt.title(f"Hazy Input {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 2)
    plt.imshow(output_img)
    plt.title(f"Dehazed Output {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 3)
    plt.imshow(gt_img)
    plt.title(f"Ground Truth {i}")
    plt.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION
# -----------------------------
num_samples = 5  # Change as needed
plt.figure(figsize=(10, num_samples * 5))

for i in range(num_samples):
    hazy_img = Image.open(hazy_images[i]).convert('RGB')
    gt_img = Image.open(gt_images[i]).convert('RGB')

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = teacher_model(input_tensor).cpu().squeeze(0)
    
    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(num_samples, 3, 3 * i + 1)
    plt.imshow(hazy_img)
    plt.title("Hazy Input")
    plt.axis("off")

    plt.subplot(num_samples, 3, 3 * i + 2)
    plt.imshow(output_img)
    plt.title("Dehazed Output")
    plt.axis("off")

    plt.subplot(num_samples, 3, 3 * i + 3)
    plt.imshow(gt_img)
    plt.title("Ground Truth")
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
model_path = "/kaggle/input/dehazing_sr/pytorch/default/3/9best_dehazing_student.pth"
model = DehazingNet().to(device)
# model = SR_model(upscale_factor=1).to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()


In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION
# -----------------------------
num_samples = 20  # Change as needed
plt.figure(figsize=(10, num_samples * 5))

for i in range(num_samples):
    hazy_img = Image.open(hazy_images[i])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = model(input_tensor).cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(num_samples, 2, 2 * i + 1)
    plt.imshow(hazy_img)
    plt.title("Hazy Input")
    plt.axis("off")

    plt.subplot(num_samples, 2, 2 * i + 2)
    plt.imshow(output_img)
    plt.title("Output")
    plt.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION
# -----------------------------
num_samples = 5  # Change as needed
plt.figure(figsize=(10, num_samples * 5))

for i in range(num_samples):
    hazy_img = Image.open(hazy_images[i]).convert('RGB')
    gt_img = Image.open(gt_images[i]).convert('RGB')

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = teacher_model(input_tensor).cpu().squeeze(0)
    
    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(num_samples, 3, 3 * i + 1)
    plt.imshow(hazy_img)
    plt.title("Hazy Input")
    plt.axis("off")

    plt.subplot(num_samples, 3, 3 * i + 2)
    plt.imshow(output_img)
    plt.title("Dehazed Output")
    plt.axis("off")

    plt.subplot(num_samples, 3, 3 * i + 3)
    plt.imshow(gt_img)
    plt.title("Ground Truth")
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION FOR SPECIFIC IMAGES
# -----------------------------
image_indices = [70, 75, 89, 100]  # Indices of images to visualize

plt.figure(figsize=(10, len(image_indices) * 5))

for idx, i in enumerate(image_indices):
    hazy_img = Image.open(hazy_images[i+1])
    gt_img = Image.open(gt_images[i+1])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        output_tensor = model(input_tensor).cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(len(image_indices), 3, 3 * idx + 1)
    plt.imshow(hazy_img)
    plt.title(f"Hazy Input {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 2)
    plt.imshow(output_img)
    plt.title(f"Dehazed Output {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 3)
    plt.imshow(gt_img)
    plt.title(f"Ground Truth {i}")
    plt.axis("off")

plt.tight_layout()
plt.show()
