In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.nn.init import _calculate_fan_in_and_fan_out
from timm.layers import to_2tuple, trunc_normal_
import os
import torchvision.utils as utils
import torch.utils.data as data
from torch.utils.data import DataLoader
import glob
from torchvision.transforms import ToTensor, Normalize, Compose, ToPILImage
from torchvision.models import vgg16
from torch.utils.data import Dataset
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
from random import randrange
import time
from math import log10
from skimage import measure
import ipywidgets as widgets
from IPython.display import display
from torch.nn.init import trunc_normal_
import matplotlib.pyplot as plt

In [2]:
from torchvision import transforms

## RevisedLayerNorm

In [3]:
class RevisedLayerNorm(nn.Module):
    """Revised LayerNorm"""
    def __init__(self, embed_dim, epsilon=1e-5, detach_gradient=False):
        super(RevisedLayerNorm, self).__init__()
        self.epsilon = epsilon
        self.detach_gradient = detach_gradient

        self.scale = nn.Parameter(torch.ones((1, embed_dim, 1, 1)))
        self.shift = nn.Parameter(torch.zeros((1, embed_dim, 1, 1)))

        self.scale_mlp = nn.Conv2d(1, embed_dim, 1)
        self.shift_mlp = nn.Conv2d(1, embed_dim, 1)

        trunc_normal_(self.scale_mlp.weight, std=.02)
        nn.init.constant_(self.scale_mlp.bias, 1)

        trunc_normal_(self.shift_mlp.weight, std=.02)
        nn.init.constant_(self.shift_mlp.bias, 0)

    def forward(self, input_tensor):
        mean_value = torch.mean(input_tensor, dim=(1, 2, 3), keepdim=True)
        std_value = torch.sqrt((input_tensor - mean_value).pow(2).mean(dim=(1, 2, 3), keepdim=True) + self.epsilon)

        normalized_tensor = (input_tensor - mean_value) / std_value

        if self.detach_gradient:
            rescale, rebias = self.scale_mlp(std_value.detach()), self.shift_mlp(mean_value.detach())
        else:
            rescale, rebias = self.scale_mlp(std_value), self.shift_mlp(mean_value)

        output = normalized_tensor * self.scale + self.shift
        return output, rescale, rebias



In [4]:
class MultiLayerPerceptron(nn.Module):
    def __init__(self, depth, input_channels, hidden_channels=None, output_channels=None):
        super().__init__()
        output_channels = output_channels or input_channels
        hidden_channels = hidden_channels or input_channels

        self.depth = depth

        self.mlp_layers = nn.Sequential(
            nn.Conv2d(input_channels, hidden_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, output_channels, kernel_size=1)
        )

        self.apply(self._initialize_weights)

    def _initialize_weights(self, layer):
        if isinstance(layer, nn.Conv2d):
            gain = (8 * self.depth) ** (-1 / 4)
            fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(layer.weight)
            std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
            trunc_normal_(layer.weight, std=std)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        return self.mlp_layers(x)


def partition_into_windows(tensor, window_size):
    """Splits the input tensor into non-overlapping windows."""
    batch_size, height, width, channels = tensor.shape
    assert height % window_size == 0 and width % window_size == 0, "Height and width must be divisible by window_size"

    tensor = tensor.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, channels
    )
    windows = tensor.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size**2, channels)
    return windows


def merge_windows(windows, window_size, height, width):
    """Reconstructs the original tensor from partitioned windows."""
    batch_size = windows.shape[0] // ((height * width) // (window_size**2))
    tensor = windows.view(
        batch_size, height // window_size, width // window_size, window_size, window_size, -1
    )
    tensor = tensor.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch_size, height, width, -1)
    return tensor


### test

In [5]:
import torch

# Initialize the MultiLayerPerceptron with sample parameters
depth = 4
input_channels = 64
hidden_channels = 128
output_channels = 64

mlp = MultiLayerPerceptron(depth, input_channels, hidden_channels, output_channels)

# Create a random tensor to test MLP (batch_size=2, channels=64, height=16, width=16)
input_tensor = torch.randn(2, 64, 16, 16)
output_tensor = mlp(input_tensor)

# Check output shape
mlp_output_shape = output_tensor.shape

# Test window partition and merging
batch_size, height, width, channels = 2, 16, 16, 64
window_size = 4

# Create a random tensor for window functions (B, H, W, C) format
input_window_tensor = torch.randn(batch_size, height, width, channels)

# Apply partitioning and merging
windows = partition_into_windows(input_window_tensor, window_size)
reconstructed_tensor = merge_windows(windows, window_size, height, width)

# Check shapes
windows_shape = windows.shape
reconstructed_shape = reconstructed_tensor.shape

# Validate if the reconstruction matches the original input shape
is_shape_correct = reconstructed_shape == input_window_tensor.shape

# Output results
mlp_output_shape, windows_shape, reconstructed_shape, is_shape_correct


(torch.Size([2, 64, 16, 16]),
 torch.Size([32, 16, 64]),
 torch.Size([2, 16, 16, 64]),
 True)

In [6]:
class LocalWindowAttention(nn.Module):
    def __init__(self, embed_dim, window_size, num_heads):
        """Self-attention mechanism within local windows."""
        super().__init__()
        self.embed_dim = embed_dim
        self.window_size = window_size  # (height, width)
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scaling_factor = head_dim ** -0.5  # Scaled dot-product attention

        # Compute and store relative positional encodings
        relative_positional_encodings = compute_log_relative_positions(self.window_size)
        self.register_buffer("relative_positional_encodings", relative_positional_encodings)

        # Learnable transformation of relative position embeddings
        self.relative_mlp = nn.Sequential(
            nn.Linear(2, 256, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_heads, bias=True)
        )

        self.attention_softmax = nn.Softmax(dim=-1)

    def forward(self, qkv):
        """Computes attention scores and applies self-attention within a window."""
        batch_size, num_tokens, _ = qkv.shape

        # Reshape qkv into separate query, key, and value tensors
        qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # Unpacking query, key, and value

        # Scale query for stable attention computation
        q = q * self.scaling_factor
        attention_scores = q @ k.transpose(-2, -1)

        # Compute relative position bias
        relative_bias = self.relative_mlp(self.relative_positional_encodings)
        relative_bias = relative_bias.permute(2, 0, 1).contiguous()  # Shape: (num_heads, window_size², window_size²)
        attention_scores = attention_scores + relative_bias.unsqueeze(0)

        # Apply softmax and compute weighted values
        attention_weights = self.attention_softmax(attention_scores)
        output = (attention_weights @ v).transpose(1, 2).reshape(batch_size, num_tokens, self.embed_dim)

        return output


In [7]:
def compute_log_relative_positions(window_size):
    """Computes log-scaled relative position embeddings for a given window size."""
    coord_range = torch.arange(window_size)

    # Create coordinate grid
    coord_grid = torch.stack(torch.meshgrid([coord_range, coord_range]))  # Shape: (2, window_size, window_size)
    
    # Flatten coordinates
    flattened_coords = torch.flatten(coord_grid, 1)  # Shape: (2, window_size * window_size)

    # Compute relative positions
    relative_positions = flattened_coords[:, :, None] - flattened_coords[:, None, :]  # Shape: (2, window_size^2, window_size^2)

    # Format and apply log transformation
    relative_positions = relative_positions.permute(1, 2, 0).contiguous()  # Shape: (window_size^2, window_size^2, 2)
    log_relative_positions = torch.sign(relative_positions) * torch.log(1. + relative_positions.abs())

    return log_relative_positions


In [8]:
class AdaptiveAttention(nn.Module):
    def __init__(self, network_depth, embed_dim, num_heads, window_size, shift_size, enable_attention=False, conv_mode=None):
        """Hybrid attention-convolution module with optional window-based attention."""
        super().__init__()
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.num_heads = num_heads

        self.window_size = window_size
        self.shift_size = shift_size

        self.network_depth = network_depth
        self.enable_attention = enable_attention
        self.conv_mode = conv_mode

        # Define convolutional processing based on mode
        if self.conv_mode == 'Conv':
            self.conv_layer = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, padding_mode='reflect'),
                nn.ReLU(inplace=True),
                nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1, padding_mode='reflect')
            )

        if self.conv_mode == 'DWConv':
            self.conv_layer = nn.Conv2d(embed_dim, embed_dim, kernel_size=5, padding=2, groups=embed_dim, padding_mode='reflect')

        if self.conv_mode == 'DWConv' or self.enable_attention:
            self.value_projection = nn.Conv2d(embed_dim, embed_dim, 1)
            self.output_projection = nn.Conv2d(embed_dim, embed_dim, 1)

        if self.enable_attention:
            self.query_key_projection = nn.Conv2d(embed_dim, embed_dim * 2, 1)
            self.window_attention = LocalWindowAttention(embed_dim, window_size, num_heads)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        """Custom weight initialization."""
        if isinstance(module, nn.Conv2d):
            weight_shape = module.weight.shape

            if weight_shape[0] == self.embed_dim * 2:  # Query-Key projection
                fan_in, fan_out = _calculate_fan_in_and_fan_out(module.weight)
                std = math.sqrt(2.0 / float(fan_in + fan_out))
                trunc_normal_(module.weight, std=std)
            else:
                gain = (8 * self.network_depth) ** (-1/4)
                fan_in, fan_out = _calculate_fan_in_and_fan_out(module.weight)
                std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
                trunc_normal_(module.weight, std=std)

            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def pad_for_window_processing(self, x, shift=False):
        """Pads the input tensor to fit window processing requirements."""
        _, _, height, width = x.size()
        pad_h = (self.window_size - height % self.window_size) % self.window_size
        pad_w = (self.window_size - width % self.window_size) % self.window_size

        if shift:
            x = F.pad(x, (self.shift_size, (self.window_size - self.shift_size + pad_w) % self.window_size,
                          self.shift_size, (self.window_size - self.shift_size + pad_h) % self.window_size), mode='reflect')
        else:
            x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
        return x

    def forward(self, x):
        """Computes the output with optional attention and convolution."""
        batch_size, channels, height, width = x.shape

        if self.conv_mode == 'DWConv' or self.enable_attention:
            v_proj = self.value_projection(x)

        if self.enable_attention:
            qk_proj = self.query_key_projection(x)
            qkv = torch.cat([qk_proj, v_proj], dim=1)

            # Apply padding for shifted window processing
            padded_qkv = self.pad_for_window_processing(qkv, self.shift_size > 0)
            padded_height, padded_width = padded_qkv.shape[2:]

            # Partition into windows
            padded_qkv = padded_qkv.permute(0, 2, 3, 1)
            qkv_windows = partition_into_windows(padded_qkv, self.window_size)  # (num_windows * batch, window_size², channels)

            # Apply window-based attention
            attn_windows = self.window_attention(qkv_windows)

            # Merge back to original spatial dimensions
            merged_output = merge_windows(attn_windows, self.window_size, padded_height, padded_width)

            # Reverse the cyclic shift
            attn_output = merged_output[:, self.shift_size:(self.shift_size + height), self.shift_size:(self.shift_size + width), :]
            attn_output = attn_output.permute(0, 3, 1, 2)

            if self.conv_mode in ['Conv', 'DWConv']:
                conv_output = self.conv_layer(v_proj)
                output = self.output_projection(conv_output + attn_output)
            else:
                output = self.output_projection(attn_output)

        else:
            if self.conv_mode == 'Conv':
                output = self.conv_layer(x)  # No attention, using convolution only
            elif self.conv_mode == 'DWConv':
                output = self.output_projection(self.conv_layer(v_proj))

        return output

In [9]:
class VisionTransformerBlock(nn.Module):
    def __init__(self, network_depth, embed_dim, num_heads, mlp_ratio=4.0,
                 norm_layer=nn.LayerNorm, enable_mlp_norm=False,
                 window_size=8, shift_size=0, enable_attention=True, conv_mode=None):
        """
        A transformer block that includes attention (optional) and MLP layers.
        """
        super().__init__()
        self.enable_attention = enable_attention
        self.enable_mlp_norm = enable_mlp_norm

        self.pre_norm = norm_layer(embed_dim) if enable_attention else nn.Identity()
        self.attention_layer = AdaptiveAttention(
            network_depth, embed_dim, num_heads=num_heads, window_size=window_size,
            shift_size=shift_size, enable_attention=enable_attention, conv_mode=conv_mode
        )

        self.post_norm = norm_layer(embed_dim) if enable_attention and enable_mlp_norm else nn.Identity()
        self.mlp_layer = MultiLayerPerceptron(network_depth, embed_dim, hidden_channels=int(embed_dim * mlp_ratio))

    def forward(self, x):
        """
        Forward pass through the transformer block.
        """
        residual = x
        if self.enable_attention:
            x, rescale, rebias = self.pre_norm(x)
        x = self.attention_layer(x)
        if self.enable_attention:
            x = x * rescale + rebias
        x = residual + x  # Residual connection

        residual = x
        if self.enable_attention and self.enable_mlp_norm:
            x, rescale, rebias = self.post_norm(x)
        x = self.mlp_layer(x)
        if self.enable_attention and self.enable_mlp_norm:
            x = x * rescale + rebias
        x = residual + x  # Residual connection

        return x


class TransformerStage(nn.Module):
    def __init__(self, network_depth, embed_dim, num_layers, num_heads, mlp_ratio=4.0,
                 norm_layer=nn.LayerNorm, window_size=8,
                 attention_ratio=0.0, attention_placement='last', conv_mode=None):
        """
        A stage of transformer blocks with configurable attention placement.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers

        attention_layers = int(attention_ratio * num_layers)

        if attention_placement == 'last':
            enable_attentions = [i >= num_layers - attention_layers for i in range(num_layers)]
        elif attention_placement == 'first':
            enable_attentions = [i < attention_layers for i in range(num_layers)]
        elif attention_placement == 'middle':
            enable_attentions = [
                (i >= (num_layers - attention_layers) // 2) and (i < (num_layers + attention_layers) // 2)
                for i in range(num_layers)
            ]

        # Build transformer blocks
        self.blocks = nn.ModuleList([
            VisionTransformerBlock(
                network_depth=network_depth,
                embed_dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                enable_attention=enable_attentions[i],
                conv_mode=conv_mode
            ) for i in range(num_layers)
        ])

    def forward(self, x):
        """
        Forward pass through the transformer stage.
        """
        for block in self.blocks:
            x = block(x)
        return x


In [10]:
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size=4, input_channels=3, embedding_dim=96, kernel_size=None):
        """
        Patch embedding module that projects input images into token embeddings.
        """
        super().__init__()
        self.input_channels = input_channels
        self.embedding_dim = embedding_dim

        if kernel_size is None:
            kernel_size = patch_size

        self.projection = nn.Conv2d(
            input_channels, embedding_dim, kernel_size=kernel_size, stride=patch_size,
            padding=(kernel_size - patch_size + 1) // 2, padding_mode='reflect'
        )

    def forward(self, x):
        """
        Forward pass to generate patch embeddings.
        """
        return self.projection(x)


class PatchReconstruction(nn.Module):
    def __init__(self, patch_size=4, output_channels=3, embedding_dim=96, kernel_size=None):
        """
        Patch reconstruction module that converts token embeddings back to image patches.
        """
        super().__init__()
        self.output_channels = output_channels
        self.embedding_dim = embedding_dim

        if kernel_size is None:
            kernel_size = 1

        self.projection = nn.Sequential(
            nn.Conv2d(
                embedding_dim, output_channels * patch_size ** 2, kernel_size=kernel_size,
                padding=kernel_size // 2, padding_mode='reflect'
            ),
            nn.PixelShuffle(patch_size)
        )

    def forward(self, x):
        """
        Forward pass to reconstruct image from embeddings.
        """
        return self.projection(x)


In [11]:
class SelectiveKernelFusion(nn.Module):
    def __init__(self, channels, num_branches=2, reduction_ratio=8):
        """
        Selective Kernel Fusion (SKFusion) module for adaptive feature selection.

        Args:
            channels (int): Number of input channels.
            num_branches (int): Number of feature branches to fuse.
            reduction_ratio (int): Reduction ratio for the attention mechanism.
        """
        super().__init__()
        
        self.num_branches = num_branches
        reduced_channels = max(int(channels / reduction_ratio), 4)
        
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.channel_attention = nn.Sequential(
            nn.Conv2d(channels, reduced_channels, kernel_size=1, bias=False), 
            nn.ReLU(),
            nn.Conv2d(reduced_channels, channels * num_branches, kernel_size=1, bias=False)
        )
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, feature_maps):
        """
        Forward pass for selective kernel fusion.

        Args:
            feature_maps (list of tensors): A list of feature maps to be fused.

        Returns:
            torch.Tensor: The adaptively fused feature map.
        """
        batch_size, channels, height, width = feature_maps[0].shape
        
        # Concatenate feature maps along a new dimension (num_branches)
        stacked_features = torch.cat(feature_maps, dim=1).view(batch_size, self.num_branches, channels, height, width)
        
        # Compute attention weights
        aggregated_features = torch.sum(stacked_features, dim=1)
        attention_weights = self.channel_attention(self.global_avg_pool(aggregated_features))
        attention_weights = self.softmax(attention_weights.view(batch_size, self.num_branches, channels, 1, 1))

        # Weighted sum of input feature maps
        fused_output = torch.sum(stacked_features * attention_weights, dim=1)
        return fused_output


In [12]:
class DehazingTransformer(nn.Module):
    def __init__(self, input_channels=3, output_channels=4, window_size=8,
                 embed_dims=[24, 48, 96, 48, 24],
                 mlp_ratios=[2., 4., 4., 2., 2.],
                 layer_depths=[16, 16, 16, 8, 8],
                 num_heads=[2, 4, 6, 1, 1],
                 attention_ratios=[1/4, 1/2, 3/4, 0, 0],
                 conv_types=['DWConv', 'DWConv', 'DWConv', 'DWConv', 'DWConv'],
                 norm_layers=[RevisedLayerNorm, RevisedLayerNorm, RevisedLayerNorm, RevisedLayerNorm, RevisedLayerNorm]):
        super().__init__()

        # Patch embedding settings
        self.patch_size = 4
        self.window_size = window_size

        # Initial patch embedding
        self.patch_embed = PatchEmbedding(
            patch_size=1, input_channels=input_channels, embedding_dim=embed_dims[0], kernel_size=3)

        # Backbone layers
        self.encoder_stage1 = TransformerStage(
            network_depth=sum(layer_depths),
            embed_dim=embed_dims[0],
            num_layers=layer_depths[0],
            num_heads=num_heads[0],
            mlp_ratio=mlp_ratios[0],
            norm_layer=norm_layers[0],
            window_size=window_size,
            attention_ratio=attention_ratios[0],
            attention_placement='last',
            conv_mode=conv_types[0]
        )
        
        self.downsample1 = PatchEmbedding(
            patch_size=2, input_channels=embed_dims[0], embedding_dim=embed_dims[1]
        )
        
        self.skip_connection1 = nn.Conv2d(embed_dims[0], embed_dims[0], 1)
        
        self.encoder_stage2 = TransformerStage(
            network_depth=sum(layer_depths),
            embed_dim=embed_dims[1],
            num_layers=layer_depths[1],
            num_heads=num_heads[1],
            mlp_ratio=mlp_ratios[1],
            norm_layer=norm_layers[1],
            window_size=window_size,
            attention_ratio=attention_ratios[1],
            attention_placement='last',
            conv_mode=conv_types[1]
        )
        
        self.downsample2 = PatchEmbedding(
            patch_size=2, input_channels=embed_dims[1], embedding_dim=embed_dims[2]
        )
        
        self.skip_connection2 = nn.Conv2d(embed_dims[1], embed_dims[1], 1)
        
        self.encoder_stage3 = TransformerStage(
            network_depth=sum(layer_depths),
            embed_dim=embed_dims[2],
            num_layers=layer_depths[2],
            num_heads=num_heads[2],
            mlp_ratio=mlp_ratios[2],
            norm_layer=norm_layers[2],
            window_size=window_size,
            attention_ratio=attention_ratios[2],
            attention_placement='last',
            conv_mode=conv_types[2]
        )
        
        self.upsample1 = PatchReconstruction(
            patch_size=2, output_channels=embed_dims[3], embedding_dim=embed_dims[2]
        )
        
        assert embed_dims[1] == embed_dims[3]
        self.fusion_layer1 = SelectiveKernelFusion(embed_dims[3])
        
        self.decoder_stage1 = TransformerStage(
            network_depth=sum(layer_depths),
            embed_dim=embed_dims[3],
            num_layers=layer_depths[3],
            num_heads=num_heads[3],
            mlp_ratio=mlp_ratios[3],
            norm_layer=norm_layers[3],
            window_size=window_size,
            attention_ratio=attention_ratios[3],
            attention_placement='last',
            conv_mode=conv_types[3]
        )
        
        self.upsample2 = PatchReconstruction(
            patch_size=2, output_channels=embed_dims[4], embedding_dim=embed_dims[3]
        )
        
        assert embed_dims[0] == embed_dims[4]
        self.fusion_layer2 = SelectiveKernelFusion(embed_dims[4])
        
        self.decoder_stage2 = TransformerStage(
            network_depth=sum(layer_depths),
            embed_dim=embed_dims[4],
            num_layers=layer_depths[4],
            num_heads=num_heads[4],
            mlp_ratio=mlp_ratios[4],
            norm_layer=norm_layers[4],
            window_size=window_size,
            attention_ratio=attention_ratios[4],
            attention_placement='last',
            conv_mode=conv_types[4]
        )

        # Final patch reconstruction
        self.patch_reconstruction = PatchReconstruction(
            patch_size=1, output_channels=output_channels, embedding_dim=embed_dims[4], kernel_size=3)

    def adjust_image_size(self, x):
        # Ensures the input image size is compatible with the patch size
        _, _, height, width = x.size()
        pad_height = (self.patch_size - height % self.patch_size) % self.patch_size
        pad_width = (self.patch_size - width % self.patch_size) % self.patch_size
        x = F.pad(x, (0, pad_width, 0, pad_height), 'reflect')
        return x

    def extract_features(self, x):
        x = self.patch_embed(x)
        x = self.encoder_stage1(x)
        skip1 = x

        x = self.downsample1(x)
        x = self.encoder_stage2(x)
        skip2 = x

        x = self.downsample2(x)
        x = self.encoder_stage3(x)
        x = self.upsample1(x)

        x = self.fusion_layer1([x, self.skip_connection2(skip2)]) + x
        x = self.decoder_stage1(x)
        x = self.upsample2(x)

        x = self.fusion_layer2([x, self.skip_connection1(skip1)]) + x
        x = self.decoder_stage2(x)
        x = self.patch_reconstruction(x)
        return x

    def forward(self, x):
        original_height, original_width = x.shape[2:]
        x = self.adjust_image_size(x)

        features = self.extract_features(x)
        transmission_map, atmospheric_light = torch.split(features, (1, 3), dim=1)

        # Dehazing formula: I = J * t + A * (1 - t)
        x = transmission_map * x - atmospheric_light + x
        x = x[:, :, :original_height, :original_width]
        return x

In [13]:
def build_dehazing_transformer():
    return DehazingTransformer(
        embed_dims=[24, 48, 96, 48, 24],
        mlp_ratios=[2., 4., 4., 2., 2.],
        layer_depths=[12, 12, 12, 6, 6],
        num_heads=[2, 4, 6, 1, 1],
        attention_ratios=[1/4, 1/2, 3/4, 0, 0],
        conv_types=['Conv', 'Conv', 'Conv', 'Conv', 'Conv']
    )

In [14]:
class ConvolutionalGuidedFilter(nn.Module):
    def __init__(self, radius=1, norm_layer=nn.BatchNorm2d, conv_kernel_size: int = 1):
        super(ConvolutionalGuidedFilter, self).__init__()

        self.box_filter = nn.Conv2d(
            3, 3, kernel_size=3, padding=radius, dilation=radius, bias=False, groups=3
        )
        self.conv_a = nn.Sequential(
            nn.Conv2d(
                6,
                32,
                kernel_size=conv_kernel_size,
                padding=conv_kernel_size // 2,
                bias=False,
            ),
            norm_layer(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                32,
                32,
                kernel_size=conv_kernel_size,
                padding=conv_kernel_size // 2,
                bias=False,
            ),
            norm_layer(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                32,
                3,
                kernel_size=conv_kernel_size,
                padding=conv_kernel_size // 2,
                bias=False,
            ),
        )
        self.box_filter.weight.data[...] = 1.0

    def forward(self, x_low_res, y_low_res, x_high_res):
        _, _, h_lr, w_lr = x_low_res.size()
        _, _, h_hr, w_hr = x_high_res.size()

        N = self.box_filter(x_low_res.data.new().resize_((1, 3, h_lr, w_lr)).fill_(1.0))
        ## mean_x
        mean_x = self.box_filter(x_low_res) / N
        ## mean_y
        mean_y = self.box_filter(y_low_res) / N
        ## cov_xy
        cov_xy = self.box_filter(x_low_res * y_low_res) / N - mean_x * mean_y
        ## var_x
        var_x = self.box_filter(x_low_res * x_low_res) / N - mean_x * mean_x

        ## A
        A = self.conv_a(torch.cat([cov_xy, var_x], dim=1))
        ## b
        b = mean_y - A * mean_x

        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hr, w_hr), mode="bilinear", align_corners=True)
        mean_b = F.interpolate(b, (h_hr, w_hr), mode="bilinear", align_corners=True)

        return mean_A * x_high_res + mean_b

In [15]:
class PixelAttentionLayer(nn.Module):
    def __init__(self, channels):
        super(PixelAttentionLayer, self).__init__()
        self.attention = nn.Sequential(
                nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels // 8, 1, kernel_size=1, padding=0, bias=True),
                nn.Sigmoid()
        )
    
    def forward(self, x):
        attention_map = self.attention(x)
        return x * attention_map

class ChannelAttentionLayer(nn.Module):
    def __init__(self, channels):
        super(ChannelAttentionLayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.attention = nn.Sequential(
                nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels // 8, channels, kernel_size=1, padding=0, bias=True),
                nn.Sigmoid()
        )

    def forward(self, x):
        pooled = self.avg_pool(x)
        attention_map = self.attention(pooled)
        return x * attention_map


In [16]:
class SuperResolutionDilationBlock(nn.Module):
    def __init__(self, in_channels, num_dense_layers, growth_rate):
        super(SuperResolutionDilationBlock, self).__init__()

        self.split_channels = in_channels // 4
        kernel_size = 3

        # Dilated convolutions with increasing dilation rates
        self.conv1 = nn.Conv2d(self.split_channels, self.split_channels, kernel_size=kernel_size, padding=1, dilation=1)
        self.conv2 = nn.Conv2d(self.split_channels * 2, self.split_channels, kernel_size=kernel_size, padding=2, dilation=2)
        self.conv3 = nn.Conv2d(self.split_channels * 3, self.split_channels, kernel_size=kernel_size, padding=4, dilation=4)
        self.conv4 = nn.Conv2d(self.split_channels * 4, self.split_channels, kernel_size=kernel_size, padding=8, dilation=8)

        # Attention mechanisms
        self.channel_attention = ChannelAttentionLayer(in_channels)
        self.pixel_attention = PixelAttentionLayer(in_channels)

        # Final 1x1 convolution for feature fusion
        self.conv_1x1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        # Split input into 4 equal parts along channel dimension
        split_features = torch.split(x, self.split_channels, dim=1)

        x0 = F.relu(self.conv1(split_features[0]))
        tmp = torch.cat((split_features[1], x0), dim=1)
        x1 = F.relu(self.conv2(tmp))

        tmp = torch.cat((split_features[2], x0, x1), dim=1)
        x2 = F.relu(self.conv3(tmp))

        tmp = torch.cat((split_features[3], x0, x1, x2), dim=1)
        x3 = F.relu(self.conv4(tmp))

        # Concatenate all outputs
        merged_features = torch.cat((x0, x1, x2, x3), dim=1)

        # Apply 1x1 convolution for feature refinement
        out = self.conv_1x1(merged_features)

        # Apply attention mechanisms
        out = self.channel_attention(out)
        out = self.pixel_attention(out)

        # Residual connection
        return out + x

In [17]:
class AdaptiveInstanceNormalization(nn.Module):
    def __init__(self, num_channels):
        super(AdaptiveInstanceNormalization, self).__init__()

        # Learnable scaling factors
        self.scale_x = nn.Parameter(torch.tensor(1.0))  # Identity scaling
        self.scale_norm = nn.Parameter(torch.tensor(0.0))  # Initially no effect

        # Instance normalization layer with affine transformation enabled
        self.instance_norm = nn.InstanceNorm2d(num_channels, momentum=0.999, eps=0.001, affine=True)

    def forward(self, x):
        normalized_x = self.instance_norm(x)
        return self.scale_x * x + self.scale_norm * normalized_x

In [None]:
class DeepGuidedNetwork(nn.Module):
    def __init__(self, radius=1):
        super().__init__()

        # Adaptive Normalization for Guided Filtering
        norm = AdaptiveInstanceNormalization
        kernel_size = 3
        depth_rate = 16
        in_channels = 3
        num_dense_layer = 4
        growth_rate = 16

        # Initial convolution layers
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

        # Residual Dense Blocks (RDBs)
        self.rdb1 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)
        self.rdb2 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)
        self.rdb3 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)
        self.rdb4 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)

        # Guided Filter & Dehazing Transformer
        self.guided_filter = ConvolutionalGuidedFilter(radius, norm_layer=norm)
        self.dehaze_network = build_dehazing_transformer()

        # Downsampling & Upsampling Layers
        self.downsample = nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=True)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, x_hr):
        # Low-resolution processing
        x_lr = self.downsample(x_hr)

        # Detail extraction through Residual Dense Blocks
        y_features = self.conv_in(x_lr)
        y_features = self.rdb1(y_features)
        y_features = self.rdb2(y_features)
        y_features = self.rdb3(y_features)
        y_features = self.rdb4(y_features)
        y_detail = self.conv_out(y_features)

        y_base_hr = self.upsample(y_detail)
        y_lr = y_base_hr
        # Final guided filtering refinement
        refined_output = self.guided_filter(x_lr, y_lr, x_hr)
        
        return refined_output, y_base_hr

class DeepGuidedNetwork(nn.Module):
    def __init__(self, radius=1):
        super().__init__()

        # Adaptive Normalization for Guided Filtering
        norm = AdaptiveInstanceNormalization
        kernel_size = 3
        depth_rate = 16
        in_channels = 3
        num_dense_layer = 4
        growth_rate = 16

        # Initial convolution layers
        self.conv_in = nn.Conv2d(in_channels, depth_rate, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)
        self.conv_out = nn.Conv2d(depth_rate, in_channels, kernel_size=kernel_size, padding=(kernel_size - 1) // 2)

        # Residual Dense Blocks (RDBs)
        self.rdb1 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)
        self.rdb2 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)
        self.rdb3 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)
        self.rdb4 = SuperResolutionDilationBlock(depth_rate, num_dense_layer, growth_rate)

        # Guided Filter & Dehazing Transformer
        self.guided_filter = ConvolutionalGuidedFilter(radius, norm_layer=norm)
        self.dehaze_network = build_dehazing_transformer()

        # Downsampling & Upsampling Layers
        self.downsample = nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=True)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    def forward(self, x_hr):
        x_lr = self.downsample(x_hr)
    
        # Initial conv
        y_features = self.conv_in(x_lr)
    
        # RDBs + collect features
        feat1 = self.rdb1(y_features)
        feat2 = self.rdb2(feat1)
        feat3 = self.rdb3(feat2)
        feat4 = self.rdb4(feat3)
        y_detail = self.conv_out(feat4)
    
        # Base image
        y_base = self.dehaze_network(x_lr)
    
        # Combine
        y_lr = y_base + y_detail
        y_base_hr = self.upsample(y_base)
    
        # Guided output
        refined_output = self.guided_filter(x_lr, y_lr, x_hr)
    
        return refined_output, y_base_hr, [feat1, feat2, feat3, feat4]


In [20]:
def parse_crop_size(crop_size_str):
    try:
        return [int(x.strip()) for x in crop_size_str.split(',')]
    except ValueError:
        raise ValueError(f"Invalid crop size format: '{crop_size_str}'. Expected comma-separated integers.")

In [21]:
import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor, Normalize
from PIL import Image, UnidentifiedImageError
from random import randrange

class TrainData(Dataset):
    def __init__(self, crop_size, hazeeffected_images_dir, hazefree_images_dir):
        super().__init__()
        
        # --- Ensure valid file extensions --- #
        valid_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")
        hazy_data = [
            f for f in glob.glob(os.path.join(hazeeffected_images_dir, "*.*"))
            if f.lower().endswith(valid_extensions)
        ]

        if not hazy_data:
            raise ValueError(f"No valid images found in {hazeeffected_images_dir}")

        self.hazeeffected_images_dir = hazeeffected_images_dir
        self.hazefree_images_dir = hazefree_images_dir

        self.haze_names = []
        self.gt_names = []
        
        for h_image in hazy_data:
            filename = os.path.basename(h_image)
            haze_path = os.path.join(self.hazeeffected_images_dir, filename)
            gt_path = os.path.join(self.hazefree_images_dir, filename)

            if not os.path.exists(gt_path):
                print(f"Warning: Ground-truth missing for {filename}, skipping.")
                continue

            self.haze_names.append(haze_path)
            self.gt_names.append(gt_path)

        if not self.haze_names:
            raise ValueError("No matching ground-truth images found.")

        self.crop_size = crop_size

    def get_images(self, index):
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        try:
            haze_img = Image.open(haze_name).convert('RGB')
            gt_img = Image.open(gt_name).convert('RGB')
        except UnidentifiedImageError:
            raise ValueError(f"Invalid image format: {haze_name} or {gt_name}")

        width, height = haze_img.size

        # --- Handle small images --- #
        if width < crop_width or height < crop_height:
            raise ValueError(f"Image too small for cropping: {haze_name}")

        # --- Random crop --- #
        x, y = randrange(0, width - crop_width + 1), randrange(0, height - crop_height + 1)
        haze_crop_img = haze_img.crop((x, y, x + crop_width, y + crop_height))
        gt_crop_img = gt_img.crop((x, y, x + crop_width, y + crop_height))

        # --- Transform to tensor --- #
        transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        haze = transform_haze(haze_crop_img)
        gt = transform_gt(gt_crop_img)

        # --- Check channels --- #
        if haze.shape[0] != 3 or gt.shape[0] != 3:
            raise ValueError(f"Invalid image channels: {haze_name}")

        return haze, gt

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.haze_names)


In [22]:
import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose, ToTensor, Normalize
from PIL import Image, UnidentifiedImageError
from random import randrange, shuffle

class HazeDataset(Dataset):
    def __init__(self, crop_size, hazeeffected_images_dir, hazefree_images_dir, split="train", split_ratio=0.8):
        """
        Dataset class for handling both training and validation dynamically.
        
        Args:
            crop_size (tuple): (width, height) of the random crop.
            hazeeffected_images_dir (str): Directory for hazy images.
            hazefree_images_dir (str): Directory for ground-truth images.
            split (str): "train" or "valid" (determines data split).
            split_ratio (float): Percentage of images to use for training (default 80% train, 20% validation).
        """
        super().__init__()
        
        # --- Ensure valid file extensions --- #
        valid_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")
        hazy_data = [
            f for f in glob.glob(os.path.join(hazeeffected_images_dir, "*.*"))
            if f.lower().endswith(valid_extensions)
        ]

        if not hazy_data:
            raise ValueError(f"No valid images found in {hazeeffected_images_dir}")

        # # --- Sort and shuffle to ensure random split --- #
        hazy_data.sort()
        # shuffle(hazy_data)  

        # --- Split into train and validation --- #
        split_idx = int(len(hazy_data) * split_ratio)
        if split == "train":
            hazy_data = hazy_data[:split_idx]
        else:  # "valid"
            hazy_data = hazy_data[split_idx:]

        self.haze_names = []
        self.gt_names = []
        
        for h_image in hazy_data:
            filename = os.path.basename(h_image)
            haze_path = os.path.join(hazeeffected_images_dir, filename)
            gt_path = os.path.join(hazefree_images_dir, filename)

            if not os.path.exists(gt_path):
                print(f"Warning: Ground-truth missing for {filename}, skipping.")
                continue

            self.haze_names.append(haze_path)
            self.gt_names.append(gt_path)

        if not self.haze_names:
            raise ValueError("No matching ground-truth images found.")

        self.crop_size = crop_size

    def get_images(self, index):
        crop_width, crop_height = self.crop_size
        haze_name = self.haze_names[index]
        gt_name = self.gt_names[index]

        try:
            haze_img = Image.open(haze_name).convert('RGB')
            gt_img = Image.open(gt_name).convert('RGB')
        except UnidentifiedImageError:
            raise ValueError(f"Invalid image format: {haze_name} or {gt_name}")

        width, height = haze_img.size

        # --- Handle small images --- #
        if width < crop_width or height < crop_height:
            raise ValueError(f"Image too small for cropping: {haze_name}")

        # --- Random crop --- #
        x, y = randrange(0, width - crop_width + 1), randrange(0, height - crop_height + 1)
        haze_crop_img = haze_img.crop((x, y, x + crop_width, y + crop_height))
        gt_crop_img = gt_img.crop((x, y, x + crop_width, y + crop_height))

        # --- Transform to tensor --- #
        transform_haze = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        transform_gt = Compose([ToTensor()])
        haze = transform_haze(haze_crop_img)
        gt = transform_gt(gt_crop_img)

        # --- Check channels --- #
        if haze.shape[0] != 3 or gt.shape[0] != 3:
            raise ValueError(f"Invalid image channels: {haze_name}")

        return haze, gt

    def __getitem__(self, index):
        return self.get_images(index)

    def __len__(self):
        return len(self.haze_names)


## Validation

In [23]:
def to_psnr(dehaze, gt):
    """
    Compute PSNR (Peak Signal-to-Noise Ratio) between dehazed and ground truth images.

    Args:
        dehaze (torch.Tensor): Dehazed image tensor (B, C, H, W)
        gt (torch.Tensor): Ground truth image tensor (B, C, H, W)

    Returns:
        List[float]: PSNR values for each image in the batch.
    """
    mse = F.mse_loss(dehaze, gt, reduction='none').mean(dim=[1, 2, 3])  # Compute MSE per image
    intensity_max = 1.0

    # Compute PSNR safely, avoiding division by zero and extreme values
    psnr_list = [10.0 * log10(intensity_max / max(mse_val.item(), 1e-6)) for mse_val in mse]

    return psnr_list


In [24]:
from torchmetrics.image import StructuralSimilarityIndexMeasure

# Define SSIM metric
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0, reduction='none')

def to_ssim(dehaze: torch.Tensor, gt: torch.Tensor):
    """
    Compute SSIM directly on the GPU using torchmetrics.

    Args:
        dehaze (torch.Tensor): Dehazed image tensor (B, C, H, W)
        gt (torch.Tensor): Ground truth image tensor (B, C, H, W)

    Returns:
        List[float]: SSIM values for each image in the batch.
    """
    ssim_values = ssim_metric(dehaze, gt)  # Shape: [B]
    # print("1",ssim_values)
    # print("2",[ssim_values])
    ssim_values = ssim_values.tolist() 
    # print(type(ssim_values))
    if isinstance(ssim_values, float):  # Correct way to check for a float
        return [ssim_values]  # Convert single float to a list
    return ssim_values  # Otherwise, return as is


In [25]:
# Test with a dummy tensor
dehaze = torch.rand(1, 3, 360, 360)  # Random batch of images
gt = torch.rand(1, 3, 360, 360)  # Random ground truth images

ssim_scores = to_ssim(dehaze, gt)
print(ssim_scores)  # Should print a list of 6 SSIM values

[0.0026210546493530273]


In [26]:
def validationB(net, val_data_loader, device, category, save_tag=False):
    """
    :param net: Your deep learning model
    :param val_data_loader: validation loader
    :param device: GPU/CPU device
    :param category: dataset type (indoor/outdoor)
    :param save_tag: whether to save images
    :return: average PSNR & SSIM values
    """
    psnr_list = []
    ssim_list = []
    
    for batch_id, val_data in enumerate(val_data_loader):
        with torch.no_grad():
            haze, gt = val_data
            haze, gt = haze.to(device), gt.to(device)
            dehaze, _ = net(haze)

        # --- Compute PSNR & SSIM --- #
        batch_psnr = to_psnr(dehaze, gt)  # This returns a list
        # print(batch_psnr)
        batch_ssim = to_ssim(dehaze, gt)  # This returns a list
        # print(batch_ssim)

        psnr_list.extend(batch_psnr)  # Flatten the list
        ssim_list.extend(batch_ssim)  # Flatten the list

    # --- Ensure lists are not empty to avoid division by zero --- #
    avr_psnr = sum(psnr_list) / len(psnr_list) if psnr_list else 0.0
    avr_ssim = sum(ssim_list) / len(ssim_list) if ssim_list else 0.0

    return avr_psnr, avr_ssim


In [None]:
def validation_sr(net, sr_val_loader, device):
    psnr_list = []
    ssim_list = []
    for lr, hr in sr_val_loader:
        with torch.no_grad():
            lr, hr = lr.to(device), hr.to(device)
            sr_out, _ = net(lr, sr = False)
        psnr_list.extend(to_psnr(sr_out, hr))
        ssim_list.extend(to_ssim(sr_out, hr))

    avr_psnr = sum(psnr_list) / len(psnr_list) if psnr_list else 0.0
    avr_ssim = sum(ssim_list) / len(ssim_list) if ssim_list else 0.0
    return avr_psnr, avr_ssim


### test

In [28]:
# old_val_psnr, old_val_ssim = validationB(net, val_data_loader, device, category)

In [29]:
# psnr, ssim = validationB(net, val_data_loader, device, category)
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # psnr, ssim = validationB(model, val_loader, device, "indoor", save_tag=True)
# print(f"Validation PSNR: {psnr:.2f}, SSIM: {ssim:.4f}")


In [30]:
execution_env_widget = widgets.Dropdown(options=['local', 'kaggle'], value='local', description='Execution Env:')
display(execution_env_widget)

if os.path.exists('/kaggle'):
    execution_env_widget.value = 'kaggle' 

Dropdown(description='Execution Env:', options=('local', 'kaggle'), value='local')

In [31]:

# --- Create widgets for each hyper-parameter ---
learning_rate_widget = widgets.FloatText(value=1e-4, description='Learning Rate:')
crop_size_widget = widgets.Text(value='128,128', description='Crop Size:')
train_batch_size_widget = widgets.IntText(value=6, description='Train Batch Size:')
version_widget = widgets.IntText(value=0, description='Version:')
growth_rate_widget = widgets.IntText(value=16, description='Growth Rate:')
lambda_loss_widget = widgets.FloatText(value=0.04, description='Lambda Loss:')
val_batch_size_widget = widgets.IntText(value=2, description='Val Batch Size:')
category_widget = widgets.Dropdown(options=['indoor', 'outdoor', 'reside', 'nh'], value='reside', description='Category:')

# --- Display the widgets ---
display(
    learning_rate_widget, crop_size_widget, train_batch_size_widget, version_widget,
    growth_rate_widget, lambda_loss_widget, 
    val_batch_size_widget, category_widget
)

# --- Function to parse crop size ---
def parse_crop_size(crop_size_str):
    return [int(x) for x in crop_size_str.split(',')]

# --- Assign the widget values to variables ---
learning_rate = learning_rate_widget.value
crop_size = parse_crop_size(crop_size_widget.value)
train_batch_size = train_batch_size_widget.value
version = version_widget.value
growth_rate = growth_rate_widget.value
lambda_loss = lambda_loss_widget.value
val_batch_size = val_batch_size_widget.value
category = category_widget.value

execution_env = execution_env_widget.value  # Local or Kaggle


print('\nHyper-parameters set:')
print(f'learning_rate: {learning_rate}')
print(f'crop_size: {crop_size}')
print(f'train_batch_size: {train_batch_size}')
print(f'version: {version}')
print(f'growth_rate: {growth_rate}')
print(f'lambda_loss: {lambda_loss}')
print(f'val_batch_size: {val_batch_size}')
print(f'category: {category}')
print(f'execution_env: {execution_env}')

# --- Set category-specific hyper-parameters ---
if category == 'indoor':
    num_epochs = 1500
    train_data_dir = './data/train/indoor/'
    val_data_dir = './data/test/SOTS/indoor/'
elif category == 'outdoor':
    num_epochs = 10
    train_data_dir = './data/train/outdoor/'
    val_data_dir = './data/test/SOTS/outdoor/'
elif category == 'reside':
    num_epochs = 85
    train_data_dir = '/kaggle/input/reside6k/RESIDE-6K/train'
    val_data_dir = '/kaggle/input/reside6k/RESIDE-6K/train'
    test_data_dir = '/kaggle/input/reside6k/RESIDE-6K/test'
elif category == 'nh':
    num_epochs = 50
    train_data_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/hazy'
    val_data_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/GT'
else:
    raise Exception('Wrong image category. Set it to indoor or outdoor for RESIDE dataset.')

# --- Adjust paths based on execution environment ---
# if execution_env == 'kaggle':
    # train_data_dir = '/kaggle/input/reside-dataset/' + train_data_dir.strip('./')
    # val_data_dir = '/kaggle/input/reside-dataset/' + val_data_dir.strip('./')
    # train_data_dir = '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T'
    # val_data_dir = '/kaggle/input/nh-dense-haze/NH-HAZE-V/NH-HAZE-V' 
    # train_data_dir = '/kaggle/input/o-haze/O-HAZY/hazy'
    # val_data_dir = '/kaggle/input/o-haze/O-HAZY/GT' 
print('\nFinal dataset paths:')
print(f'Training directory: {train_data_dir}')
print(f'Validation directory: {val_data_dir}')
print(f'Number of epochs: {num_epochs}')


FloatText(value=0.0001, description='Learning Rate:')

Text(value='128,128', description='Crop Size:')

IntText(value=6, description='Train Batch Size:')

IntText(value=0, description='Version:')

IntText(value=16, description='Growth Rate:')

FloatText(value=0.04, description='Lambda Loss:')

IntText(value=2, description='Val Batch Size:')

Dropdown(description='Category:', index=2, options=('indoor', 'outdoor', 'reside', 'nh'), value='reside')


Hyper-parameters set:
learning_rate: 0.0001
crop_size: [128, 128]
train_batch_size: 6
version: 0
growth_rate: 16
lambda_loss: 0.04
val_batch_size: 2
category: reside
execution_env: kaggle

Final dataset paths:
Training directory: /kaggle/input/reside6k/RESIDE-6K/train
Validation directory: /kaggle/input/reside6k/RESIDE-6K/train
Number of epochs: 85


In [32]:
# hazeeffected_images_dir_train = f"{train_data_dir}/IN"
hazeeffected_images_dir_train = f"{train_data_dir}/hazy"
hazefree_images_dir_train = f"{train_data_dir}/GT"

# hazeeffected_images_dir_valid = f"{val_data_dir}/IN"
hazeeffected_images_dir_valid = f"{val_data_dir}/hazy"
hazefree_images_dir_valid = f"{val_data_dir}/GT"

In [33]:
# ..

In [34]:
# import os
# import glob
# import shutil

# hazeeffected_images_dir_train = f"{train_data_dir}/IN"
# hazefree_images_dir_train = f"{train_data_dir}/GT"

# hazeeffected_images_dir_valid = f"{val_data_dir}/IN"
# hazefree_images_dir_valid = f"{val_data_dir}/GT"

# # Create validation directories if they don't exist
# os.makedirs(hazeeffected_images_dir_valid, exist_ok=True)
# os.makedirs(hazefree_images_dir_valid, exist_ok=True)

# # List all hazy and clean images
# hazy_images = sorted(glob.glob(f"{hazeeffected_images_dir_train}/*"))
# clean_images = sorted(glob.glob(f"{hazefree_images_dir_train}/*"))

# # Ensure matching hazy-clean pairs
# assert len(hazy_images) == len(clean_images), "Mismatch in hazy and clean images count!"

# # Shuffle while keeping the hazy-clean correspondence
# paired_images = list(zip(hazy_images, clean_images))
# # random.shuffle(paired_images)

# # Define split ratio (e.g., 80% train, 20% validation)
# split_ratio = 0.8
# split_idx = int(len(paired_images) * split_ratio)

# # Split into train and validation
# train_pairs = paired_images[:split_idx]
# valid_pairs = paired_images[split_idx:]

# # Move validation images
# for hazy_path, clean_path in valid_pairs:
#     shutil.move(hazy_path, hazeeffected_images_dir_valid)
#     shutil.move(clean_path, hazefree_images_dir_valid)

# print(f"Moved {len(valid_pairs)} image pairs to validation set.")


In [35]:
# # hazeeffected_images_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/hazy'
# # hazefree_images_dir = '/Volumes/S/dev/project/code/Aphase/Dehaze_2/data/NH-Haze_Dense-Haze_datasets/NH-HAZE-T/train/GT'
# # hazeeffected_images_dir = '/kaggle/input/o-haze/O-HAZY/hazy'
# # hazefree_images_dir = '/kaggle/input/o-haze/O-HAZY/GT'

# hazeeffected_images_dir_train = '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN'
# hazefree_images_dir_train = '/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/GT'
# hazeeffected_images_dir_valid = '/kaggle/input/nh-dense-haze/NH-HAZE-V/NH-HAZE-V/IN'
# hazefree_images_dir_valid = '/kaggle/input/nh-dense-haze/NH-HAZE-V/NH-HAZE-V/GT'

In [36]:
def print_log(epoch, num_epochs, one_epoch_time, train_psnr, val_psnr, val_ssim, category):
    log_dir = "./training_log"
    os.makedirs(log_dir, exist_ok=True)  # Ensure the directory exists

    log_path = os.path.join(log_dir, f"{category}_log.txt")

    print('({0:.0f}s) Epoch [{1}/{2}], Train_PSNR:{3:.2f}, Val_PSNR:{4:.2f}, Val_SSIM:{5:.4f}'
          .format(one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, val_ssim))

    # --- Write the training log --- #
    with open(log_path, 'a') as f:
        print('Date: {0}, Time_Cost: {1:.0f}s, Epoch: [{2}/{3}], Train_PSNR: {4:.2f}, Val_PSNR: {5:.2f}, Val_SSIM: {6:.4f}'
              .format(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
                      one_epoch_time, epoch, num_epochs, train_psnr, val_psnr, val_ssim), file=f)

In [37]:
def adjust_learning_rate(optimizer, epoch, category, lr_decay=0.90):
    """
    Adjusts the learning rate based on the epoch and dataset category.

    :param optimizer: The optimizer (e.g., Adam, SGD).
    :param epoch: Current epoch number.
    :param category: Dataset category ('indoor', 'outdoor', or 'NH').
    :param lr_decay: Multiplicative factor for learning rate decay.
    """
    # Define learning rate decay steps based on category
    step_dict = {'indoor': 18, 'outdoor': 3, 'NH': 20}
    step = step_dict.get(category, 3)  # Default step size if category is unknown

    # Decay learning rate at the specified step
    if epoch > 0 and epoch % step == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= lr_decay
            print(f"Epoch {epoch}: Learning rate adjusted to {param_group['lr']:.6f}")


## Perceptual Loss

In [None]:
# --- Perceptual Feature Loss Network --- #
class PerceptualLossNet(nn.Module):
    def __init__(self, vgg_model):
        super().__init__()
        self.feature_extractor = vgg_model
        self.feature_layers = {'3': "low_level", '8': "mid_level", '15': "high_level"}

    def get_feature_maps(self, x):
        feature_maps = []
        for layer_id, layer in self.feature_extractor.named_children():
            x = layer(x)
            if layer_id in self.feature_layers:
                feature_maps.append(x)
        return feature_maps

    def forward(self, predicted, target):
        pred_features = self.get_feature_maps(predicted)
        target_features = self.get_feature_maps(target)
        
        # Compute perceptual loss as mean squared error across feature maps
        loss = torch.stack([F.mse_loss(p, t) for p, t in zip(pred_features, target_features)]).mean()
        
        return loss

class SSFM(nn.Module):
    def __init__(self, loss_type='l1'):
        super(SSFM, self).__init__()
        assert loss_type in ['l1', 'l2'], "loss_type must be 'l1' or 'l2'"
        self.loss_type = loss_type

    def forward(self, student_feats, teacher_feats):
        """
        student_feats: List of feature maps from student RDBs [rdb1, rdb2, rdb3, rdb4]
        teacher_feats: List of corresponding feature maps from teacher
        """
        assert len(student_feats) == len(teacher_feats), "Feature lists must match"

        total_loss = 0.0
        for s_feat, t_feat in zip(student_feats, teacher_feats):
            # Match resolution
            if s_feat.shape != t_feat.shape:
                t_feat = F.interpolate(t_feat, size=s_feat.shape[2:], mode='bilinear', align_corners=False)
            
            if self.loss_type == 'l1':
                loss = F.l1_loss(s_feat, t_feat)
            else:
                loss = F.mse_loss(s_feat, t_feat)
            
            total_loss += loss

        return total_loss / len(student_feats)


In [39]:
# --- Imports --- #
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.models import vgg16

# --- Device Setup --- #
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device_ids = list(range(torch.cuda.device_count()))

# --- Initialize Model --- #
net = DeepGuidedNetwork().to(device)

# --- Enable Multi-GPU (if available) --- #
if len(device_ids) > 1:
    net = nn.DataParallel(net, device_ids=device_ids)

# --- Optimizer --- #
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# --- Load Pretrained VGG16 for Perceptual Loss --- #
vgg_features = vgg16(pretrained=True).features[:16].to(device)
for param in vgg_features.parameters():
    param.requires_grad = False

loss_network = PerceptualLossNet(vgg_features)
loss_network.eval()

# --- Load Model Weights (if available) --- #
model_name = 'formernew'
# checkpoint_path = f"{model_name}_{category}_haze_best_{version}"
checkpoint_path = "/kaggle/input/reside-dehaze/pytorch/default/2/formernewreside_haze_iter_85.pth" 

try:
    net.load_state_dict(torch.load(checkpoint_path, weights_only=False, map_location=torch.device('cpu')))
    print(f"✅ Model weights loaded from {checkpoint_path}")
except FileNotFoundError:
    print(f"⚠️ No pretrained weights found at {checkpoint_path}")

# --- Compute Total Trainable Parameters --- #
total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f"📊 Total Trainable Parameters: {total_params:,}")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]




Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth



  0%|          | 0.00/528M [00:00<?, ?B/s]


  3%|▎         | 13.5M/528M [00:00<00:03, 141MB/s]


  7%|▋         | 34.8M/528M [00:00<00:02, 189MB/s]


 11%|█         | 55.8M/528M [00:00<00:02, 203MB/s]


 14%|█▍        | 75.1M/528M [00:00<00:02, 199MB/s]


 18%|█▊        | 94.1M/528M [00:00<00:02, 198MB/s]


 21%|██▏       | 113M/528M [00:00<00:02, 198MB/s] 


 25%|██▌       | 132M/528M [00:00<00:02, 199MB/s]


 29%|██▊       | 151M/528M [00:00<00:01, 198MB/s]


 32%|███▏      | 170M/528M [00:00<00:01, 198MB/s]


 36%|███▌      | 190M/528M [00:01<00:01, 199MB/s]


 40%|███▉      | 209M/528M [00:01<00:01, 200MB/s]


 43%|████▎     | 228M/528M [00:01<00:01, 199MB/s]


 47%|████▋     | 247M/528M [00:01<00:01, 199MB/s]


 50%|█████     | 266M/528M [00:01<00:01, 199MB/s]


 54%|█████▍    | 285M/528M [00:01<00:01, 197MB/s]


 58%|█████▊    | 304M/528M [00:01<00:01, 197MB/s]


 61%|██████▏   | 323M/528M [00:01<00:01, 198MB/s]


 65%|██████▍   | 343M/528M [00:01<00:00, 199MB/s]


 69%|██████▊   | 362M/528M [00:01<00:00, 192MB/s]


 72%|███████▏  | 380M/528M [00:02<00:00, 190MB/s]


 76%|███████▌  | 402M/528M [00:02<00:00, 200MB/s]


 80%|███████▉  | 421M/528M [00:02<00:00, 200MB/s]


 83%|████████▎ | 440M/528M [00:02<00:00, 199MB/s]


 87%|████████▋ | 459M/528M [00:02<00:00, 198MB/s]


 91%|█████████ | 478M/528M [00:02<00:00, 198MB/s]


 94%|█████████▍| 497M/528M [00:02<00:00, 198MB/s]


 98%|█████████▊| 516M/528M [00:02<00:00, 198MB/s]


100%|██████████| 528M/528M [00:02<00:00, 197MB/s]




✅ Model weights loaded from /kaggle/input/reside-dehaze/pytorch/default/2/formernewreside_haze_iter_85.pth
📊 Total Trainable Parameters: 4,645,694


In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureAffinityModule(nn.Module):
    def __init__(self, channels):
        super(FeatureAffinityModule, self).__init__()
        self.channels = channels

    def forward(self, student_features, teacher_features):
        # Normalize features
        student_norm = F.normalize(student_features.view(student_features.size(0), self.channels, -1), dim=2)
        teacher_norm = F.normalize(teacher_features.view(teacher_features.size(0), self.channels, -1), dim=2)

        # Compute affinity matrices
        student_affinity = torch.bmm(student_norm, student_norm.transpose(1, 2))
        teacher_affinity = torch.bmm(teacher_norm, teacher_norm.transpose(1, 2))

        # Compute KL divergence
        loss = F.kl_div(F.log_softmax(student_affinity, dim=-1),
                        F.softmax(teacher_affinity, dim=-1),
                        reduction='batchmean')
        return loss


In [41]:
# Create train and validation datasets
train_dataset = HazeDataset(crop_size=crop_size, 
                            hazeeffected_images_dir=hazeeffected_images_dir_train,
                            hazefree_images_dir=hazefree_images_dir_train,
                            split="train")

val_dataset = HazeDataset(crop_size=crop_size, 
                          hazeeffected_images_dir=hazeeffected_images_dir_train,
                          hazefree_images_dir=hazefree_images_dir_train,
                          split="valid")

print(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

Train samples: 4800, Validation samples: 1200


In [42]:
train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
val_data_loader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

In [43]:
import os
import glob
import torch
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from PIL import Image, UnidentifiedImageError

class SRDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, scale='x2', split='train', split_ratio=0.9):
        """
        Super-Resolution dataset that matches LR and HR image pairs based on naming pattern,
        with support for train/val split.

        Args:
            lr_dir (str): Directory containing low-resolution images (e.g., x2, x3, x4).
            hr_dir (str): Directory containing high-resolution images.
            scale (str): Scale suffix (e.g., 'x2', 'x3', 'x4').
            split (str): Either 'train' or 'val'.
            split_ratio (float): Ratio of training data (e.g., 0.9 means 90% train, 10% val).
        """
        super().__init__()
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.scale = scale
        self.split = split.lower()

        valid_ext = (".jpg", ".jpeg", ".png", ".bmp", ".tiff")
        lr_images = sorted([
            f for f in glob.glob(os.path.join(lr_dir, "*.*"))
            if f.lower().endswith(valid_ext)
        ])

        lr_hr_pairs = []
        for lr_path in lr_images:
            lr_name = os.path.basename(lr_path)
            hr_name = lr_name.replace(scale, '')
            hr_path = os.path.join(hr_dir, hr_name)

            if not os.path.exists(hr_path):
                print(f"Warning: Ground-truth missing for {lr_name}, skipping.")
                continue

            lr_hr_pairs.append((lr_path, hr_path))

        if not lr_hr_pairs:
            raise ValueError("No matching LR-HR image pairs found.")

        # Split dataset
        split_idx = int(len(lr_hr_pairs) * split_ratio)
        if self.split == 'train':
            self.lr_hr_pairs = lr_hr_pairs[:split_idx]
        elif self.split == 'val':
            self.lr_hr_pairs = lr_hr_pairs[split_idx:]
        else:
            raise ValueError("split must be either 'train' or 'val'")

    def __len__(self):
        return len(self.lr_hr_pairs)

    def __getitem__(self, idx):
        lr_path, hr_path = self.lr_hr_pairs[idx]

        try:
            lr_img = Image.open(lr_path).convert('RGB')
            hr_img = Image.open(hr_path).convert('RGB')
        except UnidentifiedImageError:
            raise ValueError(f"Unidentified image at {lr_path} or {hr_path}")

        return ToTensor()(lr_img), ToTensor()(hr_img)


In [44]:
from torchvision.transforms import functional as TF

def custom_collate_fn(batch):
    min_height = min([x[0].shape[1] for x in batch])//4
    min_width = min([x[0].shape[2] for x in batch])//4
    resized_batch = [(TF.resize(x[0], [min_height, min_width]), TF.resize(x[1], [min_height, min_width])) for x in batch]
    return torch.utils.data.dataloader.default_collate(resized_batch)


In [45]:
# # Paths
# sr_hr_dir = '/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR'
# sr_lr_dir = '/kaggle/input/flickr2k/Flickr2K/Flickr2K_LR_bicubic/X2'

# # Train SR DataLoader
# sr_train_dataset = SRDataset(lr_dir, hr_dir, scale='x2', split='train')
# sr_valid_dataset = SRDataset(lr_dir, hr_dir, scale='x2', split='val')

# sr_train_loader = DataLoader(sr_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)

# # If you have a separate validation split:
# sr_val_loader = DataLoader(sr_valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=custom_collate_fn)


In [46]:
# train_data_loader = DataLoader(TrainData(crop_size, hazeeffected_images_dir_train, hazefree_images_dir_train), batch_size=train_batch_size, shuffle=True)
# val_data_loader = DataLoader(TrainData(crop_size, hazeeffected_images_dir_valid, hazefree_images_dir_valid), batch_size=val_batch_size, shuffle=False)

In [47]:
# train_size = len(TrainData(crop_size, hazeeffected_images_dir_train, hazefree_images_dir_train))
# val_size = len(TrainData(crop_size, hazeeffected_images_dir_valid, hazefree_images_dir_valid))

# print(f"Train Size: {train_size}, Val Size: {val_size}")

In [48]:
# --- SR Dataset Setup --- #
sr_enabled = True
if sr_enabled:
    sr_hr_dir = '/kaggle/input/flickr2k/Flickr2K/Flickr2K_HR'
    sr_lr_dir = '/kaggle/input/flickr2k/Flickr2K/Flickr2K_LR_bicubic/X2'
    sr_train_dataset = SRDataset(lr_dir=sr_lr_dir, hr_dir=sr_hr_dir, scale='x2', split='train')
    sr_val_dataset = SRDataset(lr_dir=sr_lr_dir, hr_dir=sr_hr_dir, scale='x2', split='val')
    sr_train_loader = DataLoader(sr_train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=custom_collate_fn)
    sr_val_loader = DataLoader(sr_val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=custom_collate_fn)
    sr_iter = iter(sr_train_loader)


In [49]:
for i,o in train_data_loader:
    print(i.shape, o.shape)
    break

torch.Size([6, 3, 128, 128]) torch.Size([6, 3, 128, 128])


In [50]:
for i,o in sr_val_loader:
    print(i.shape, o.shape)
    break

torch.Size([2, 3, 175, 255]) torch.Size([2, 3, 175, 255])


In [51]:
# ..

In [None]:
# --- Teacher Network --- #
teacher_net = DeepGuidedNetwork(radius=1).to(device)
# teacher_net.load_state_dict(torch.load('teacher_model.pth'))
# teacher_net.eval()

# --- Feature Affinity Module --- #
fam = FeatureAffinityModule(channels=64).to(device)
ssfm_loss = SSFM(loss_type='l1') 


# --- Initial Validation --- #
old_val_psnr, old_val_ssim = validationB(net, val_data_loader, device, category)
print(f"[Dehazing Init Val] PSNR: {old_val_psnr:.2f}, SSIM: {old_val_ssim:.4f}")
if sr_enabled:
    sr_val_psnr, sr_val_ssim = validation_sr(net, sr_val_loader, device)
    print(f"[SR Init Val] PSNR: {sr_val_psnr:.2f}, SSIM: {sr_val_ssim:.4f}")

# --- Training Loop --- #
best_psnr = old_val_psnr
train_psnr_prev = 0
distillation_weight = 1

for epoch in range(num_epochs):
    psnr_list = []
    start_time = time.time()

    adjust_learning_rate(optimizer, epoch, category=category)
    net.train()

    for batch_id, (haze, gt) in enumerate(train_data_loader):
        haze, gt = haze.to(device), gt.to(device)
        optimizer.zero_grad()

        # Forward Pass - Student
        dehaze, base, s1, s2, s3, s4 = net(haze)

        # Teacher Output
        with torch.no_grad():
            teacher_dehaze, _, t1, t2, t3, t4 = teacher_net(haze)

        # Losses
        base_loss = F.smooth_l1_loss(base, gt)
        smooth_loss = F.smooth_l1_loss(dehaze, gt)
        perceptual_loss = loss_network(dehaze, gt)
        
        student_feats = [s1, s2, s3, s4]  
        teacher_feats = [t1, t2, t3, t4]  
        s_loss = ssfm_loss(student_feats, teacher_feats)
        
        distillation_loss = fam(dehaze, teacher_dehaze)
        # print("distillation_loss: ", distillation_loss)
        total_loss = smooth_loss + lambda_loss * perceptual_loss + base_loss + distillation_weight * distillation_loss + s_loss

        # --- SR Training --- #
        if sr_enabled:
            try:
                sr_lr, sr_hr = next(sr_iter)
            except StopIteration:
                sr_iter = iter(sr_train_loader)
                sr_lr, sr_hr = next(sr_iter)
            sr_lr, sr_hr = sr_lr.to(device), sr_hr.to(device)
            sr_out, _ = net(sr_lr)
            sr_loss = F.l1_loss(sr_out, sr_hr)
            total_loss += sr_loss

        total_loss.backward()
        optimizer.step()
        psnr_list.extend(to_psnr(dehaze, gt))

        if batch_id % num_epochs == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Iteration [{batch_id}]")

    # Save model checkpoint
    if epoch % 5 == 0:
        iter_model_path = f"{model_name}{category}_haze_iter_{epoch}.pth"
        torch.save(net.state_dict(), iter_model_path)
        print(f"Model saved in epoch {epoch}.")

    train_psnr = sum(psnr_list) / len(psnr_list)
    model_path = f"{model_name}{category}_haze_{version}.pth"

    # --- Validation --- #
    net.eval()
    val_psnr, val_ssim = validationB(net, val_data_loader, device, category)
    if sr_enabled:
        sr_val_psnr, sr_val_ssim = validation_sr(net, sr_val_loader, device)
        print(f"[SR Val] PSNR: {sr_val_psnr:.2f}, SSIM: {sr_val_ssim:.4f}")

    epoch_duration = time.time() - start_time
    print_log(epoch + 1, num_epochs, epoch_duration, train_psnr, val_psnr, val_ssim, model_path)

    if train_psnr < train_psnr_prev:
        adjust_learning_rate(optimizer, num_epochs, category=category)

    if val_psnr >= best_psnr:
        best_model_path = f"{model_name}{category}_haze_best_{version}.pth"
        torch.save(net.state_dict(), best_model_path)
        best_psnr = val_psnr

    train_psnr_prev = train_psnr

# Final save
final_path = f"{model_name}{category}_final_{epoch}.pth"
torch.save(net.state_dict(), final_path)

[Dehazing Init Val] PSNR: 28.81, SSIM: 0.9250


[SR Init Val] PSNR: 27.26, SSIM: 0.9609


Epoch [0/85], Iteration [0]


Epoch [0/85], Iteration [85]


Epoch [0/85], Iteration [170]


Epoch [0/85], Iteration [255]


Epoch [0/85], Iteration [340]


Epoch [0/85], Iteration [425]


Epoch [0/85], Iteration [510]


Epoch [0/85], Iteration [595]


Epoch [0/85], Iteration [680]


Epoch [0/85], Iteration [765]


Model saved in epoch 0.


[SR Val] PSNR: 29.76, SSIM: 0.9338
(479s) Epoch [1/85], Train_PSNR:12.36, Val_PSNR:12.61, Val_SSIM:0.4007


Epoch [1/85], Iteration [0]


Epoch [1/85], Iteration [85]


Epoch [1/85], Iteration [170]


Epoch [1/85], Iteration [255]


Epoch [1/85], Iteration [340]


Epoch [1/85], Iteration [425]


Epoch [1/85], Iteration [510]


Epoch [1/85], Iteration [595]


Epoch [1/85], Iteration [680]


Epoch [1/85], Iteration [765]


[SR Val] PSNR: 25.77, SSIM: 0.8795
(412s) Epoch [2/85], Train_PSNR:12.21, Val_PSNR:12.80, Val_SSIM:0.4087


Epoch [2/85], Iteration [0]


Epoch [2/85], Iteration [85]


Epoch [2/85], Iteration [170]


Epoch [2/85], Iteration [255]


Epoch [2/85], Iteration [340]


Epoch [2/85], Iteration [425]


Epoch [2/85], Iteration [510]


Epoch [2/85], Iteration [595]


Epoch [2/85], Iteration [680]


Epoch [2/85], Iteration [765]


[SR Val] PSNR: 17.86, SSIM: 0.7113
(416s) Epoch [3/85], Train_PSNR:12.19, Val_PSNR:12.00, Val_SSIM:0.3621
Epoch 3: Learning rate adjusted to 0.000090


Epoch [3/85], Iteration [0]


Epoch [3/85], Iteration [85]


Epoch [3/85], Iteration [170]


Epoch [3/85], Iteration [255]


Epoch [3/85], Iteration [340]


Epoch [3/85], Iteration [425]


Epoch [3/85], Iteration [510]


Epoch [3/85], Iteration [595]


Epoch [3/85], Iteration [680]


Epoch [3/85], Iteration [765]


[SR Val] PSNR: 36.18, SSIM: 0.9714
(422s) Epoch [4/85], Train_PSNR:12.18, Val_PSNR:12.67, Val_SSIM:0.3995


Epoch [4/85], Iteration [0]


Epoch [4/85], Iteration [85]


Epoch [4/85], Iteration [170]


Epoch [4/85], Iteration [255]


Epoch [4/85], Iteration [340]


Epoch [4/85], Iteration [425]


Epoch [4/85], Iteration [510]


Epoch [4/85], Iteration [595]


Epoch [4/85], Iteration [680]


Epoch [4/85], Iteration [765]


[SR Val] PSNR: 17.57, SSIM: 0.7292
(436s) Epoch [5/85], Train_PSNR:12.35, Val_PSNR:11.60, Val_SSIM:0.3725


Epoch [5/85], Iteration [0]


Epoch [5/85], Iteration [85]


Epoch [5/85], Iteration [170]


Epoch [5/85], Iteration [255]


Epoch [5/85], Iteration [340]


Epoch [5/85], Iteration [425]


Epoch [5/85], Iteration [510]


Epoch [5/85], Iteration [595]


Epoch [5/85], Iteration [680]


Epoch [5/85], Iteration [765]


Model saved in epoch 5.


[SR Val] PSNR: 35.60, SSIM: 0.9582
(445s) Epoch [6/85], Train_PSNR:12.16, Val_PSNR:12.57, Val_SSIM:0.4083
Epoch 6: Learning rate adjusted to 0.000081


Epoch [6/85], Iteration [0]


Epoch [6/85], Iteration [85]


Epoch [6/85], Iteration [170]


Epoch [6/85], Iteration [255]


Epoch [6/85], Iteration [340]


Epoch [6/85], Iteration [425]


Epoch [6/85], Iteration [510]


Epoch [6/85], Iteration [595]


Epoch [6/85], Iteration [680]


Epoch [6/85], Iteration [765]


[SR Val] PSNR: 37.72, SSIM: 0.9740
(452s) Epoch [7/85], Train_PSNR:12.22, Val_PSNR:12.91, Val_SSIM:0.4182


Epoch [7/85], Iteration [0]


Epoch [7/85], Iteration [85]


Epoch [7/85], Iteration [170]


Epoch [7/85], Iteration [255]


Epoch [7/85], Iteration [340]


Epoch [7/85], Iteration [425]


Epoch [7/85], Iteration [510]


Epoch [7/85], Iteration [595]


Epoch [7/85], Iteration [680]


Epoch [7/85], Iteration [765]


[SR Val] PSNR: 33.14, SSIM: 0.9387
(463s) Epoch [8/85], Train_PSNR:12.44, Val_PSNR:12.86, Val_SSIM:0.4072


Epoch [8/85], Iteration [0]


Epoch [8/85], Iteration [85]


Epoch [8/85], Iteration [170]


Epoch [8/85], Iteration [255]


Epoch [8/85], Iteration [340]


Epoch [8/85], Iteration [425]


Epoch [8/85], Iteration [510]


Epoch [8/85], Iteration [595]


Epoch [8/85], Iteration [680]


Epoch [8/85], Iteration [765]


[SR Val] PSNR: 37.89, SSIM: 0.9790
(468s) Epoch [9/85], Train_PSNR:12.36, Val_PSNR:12.32, Val_SSIM:0.3975
Epoch 9: Learning rate adjusted to 0.000073


Epoch [9/85], Iteration [0]


Epoch [9/85], Iteration [85]


Epoch [9/85], Iteration [170]


Epoch [9/85], Iteration [255]


Epoch [9/85], Iteration [340]


Epoch [9/85], Iteration [425]


Epoch [9/85], Iteration [510]


Epoch [9/85], Iteration [595]


Epoch [9/85], Iteration [680]


Epoch [9/85], Iteration [765]


[SR Val] PSNR: 37.85, SSIM: 0.9748
(479s) Epoch [10/85], Train_PSNR:12.37, Val_PSNR:12.91, Val_SSIM:0.3997


Epoch [10/85], Iteration [0]


Epoch [10/85], Iteration [85]


Epoch [10/85], Iteration [170]


Epoch [10/85], Iteration [255]


Epoch [10/85], Iteration [340]


Epoch [10/85], Iteration [425]


Epoch [10/85], Iteration [510]


Epoch [10/85], Iteration [595]


Epoch [10/85], Iteration [680]


Epoch [10/85], Iteration [765]


Model saved in epoch 10.


[SR Val] PSNR: 40.20, SSIM: 0.9834
(490s) Epoch [11/85], Train_PSNR:12.50, Val_PSNR:12.75, Val_SSIM:0.3874


Epoch [11/85], Iteration [0]


Epoch [11/85], Iteration [85]


Epoch [11/85], Iteration [170]


Epoch [11/85], Iteration [255]


Epoch [11/85], Iteration [340]


Epoch [11/85], Iteration [425]


Epoch [11/85], Iteration [510]


Epoch [11/85], Iteration [595]


Epoch [11/85], Iteration [680]


Epoch [11/85], Iteration [765]


[SR Val] PSNR: 40.86, SSIM: 0.9881
(499s) Epoch [12/85], Train_PSNR:12.63, Val_PSNR:12.99, Val_SSIM:0.4127
Epoch 12: Learning rate adjusted to 0.000066


Epoch [12/85], Iteration [0]


Epoch [12/85], Iteration [85]


Epoch [12/85], Iteration [170]


Epoch [12/85], Iteration [255]


Epoch [12/85], Iteration [340]


Epoch [12/85], Iteration [425]


Epoch [12/85], Iteration [510]


Epoch [12/85], Iteration [595]


Epoch [12/85], Iteration [680]


Epoch [12/85], Iteration [765]


[SR Val] PSNR: 39.92, SSIM: 0.9815
(504s) Epoch [13/85], Train_PSNR:12.64, Val_PSNR:12.89, Val_SSIM:0.4123


Epoch [13/85], Iteration [0]


Epoch [13/85], Iteration [85]


Epoch [13/85], Iteration [170]


Epoch [13/85], Iteration [255]


Epoch [13/85], Iteration [340]


Epoch [13/85], Iteration [425]


Epoch [13/85], Iteration [510]


Epoch [13/85], Iteration [595]


Epoch [13/85], Iteration [680]


Epoch [13/85], Iteration [765]


[SR Val] PSNR: 38.93, SSIM: 0.9711
(514s) Epoch [14/85], Train_PSNR:12.46, Val_PSNR:12.69, Val_SSIM:0.3876


Epoch [14/85], Iteration [0]


Epoch [14/85], Iteration [85]


Epoch [14/85], Iteration [170]


Epoch [14/85], Iteration [255]


Epoch [14/85], Iteration [340]


Epoch [14/85], Iteration [425]


Epoch [14/85], Iteration [510]


Epoch [14/85], Iteration [595]


Epoch [14/85], Iteration [680]


Epoch [14/85], Iteration [765]


[SR Val] PSNR: 38.61, SSIM: 0.9693
(517s) Epoch [15/85], Train_PSNR:12.45, Val_PSNR:12.69, Val_SSIM:0.3967
Epoch 15: Learning rate adjusted to 0.000059


Epoch [15/85], Iteration [0]


Epoch [15/85], Iteration [85]


Epoch [15/85], Iteration [170]


Epoch [15/85], Iteration [255]


Epoch [15/85], Iteration [340]


Epoch [15/85], Iteration [425]


Epoch [15/85], Iteration [510]


Epoch [15/85], Iteration [595]


Epoch [15/85], Iteration [680]


Epoch [15/85], Iteration [765]


Model saved in epoch 15.


[SR Val] PSNR: 38.45, SSIM: 0.9692
(522s) Epoch [16/85], Train_PSNR:12.64, Val_PSNR:12.67, Val_SSIM:0.4031


Epoch [16/85], Iteration [0]


Epoch [16/85], Iteration [85]


Epoch [16/85], Iteration [170]


Epoch [16/85], Iteration [255]


Epoch [16/85], Iteration [340]


Epoch [16/85], Iteration [425]


Epoch [16/85], Iteration [510]


Epoch [16/85], Iteration [595]


Epoch [16/85], Iteration [680]


Epoch [16/85], Iteration [765]


[SR Val] PSNR: 42.76, SSIM: 0.9923
(536s) Epoch [17/85], Train_PSNR:12.55, Val_PSNR:13.15, Val_SSIM:0.4250


Epoch [17/85], Iteration [0]


Epoch [17/85], Iteration [85]


Epoch [17/85], Iteration [170]


Epoch [17/85], Iteration [255]


Epoch [17/85], Iteration [340]


Epoch [17/85], Iteration [425]


Epoch [17/85], Iteration [510]


Epoch [17/85], Iteration [595]


Epoch [17/85], Iteration [680]


Epoch [17/85], Iteration [765]


[SR Val] PSNR: 41.09, SSIM: 0.9899
(548s) Epoch [18/85], Train_PSNR:12.76, Val_PSNR:13.35, Val_SSIM:0.4067
Epoch 18: Learning rate adjusted to 0.000053


Epoch [18/85], Iteration [0]


Epoch [18/85], Iteration [85]


Epoch [18/85], Iteration [170]


Epoch [18/85], Iteration [255]


Epoch [18/85], Iteration [340]


Epoch [18/85], Iteration [425]


Epoch [18/85], Iteration [510]


Epoch [18/85], Iteration [595]


Epoch [18/85], Iteration [680]


Epoch [18/85], Iteration [765]


[SR Val] PSNR: 42.80, SSIM: 0.9918
(556s) Epoch [19/85], Train_PSNR:12.55, Val_PSNR:13.01, Val_SSIM:0.4014


Epoch [19/85], Iteration [0]


Epoch [19/85], Iteration [85]


Epoch [19/85], Iteration [170]


Epoch [19/85], Iteration [255]


Epoch [19/85], Iteration [340]


Epoch [19/85], Iteration [425]


Epoch [19/85], Iteration [510]


Epoch [19/85], Iteration [595]


Epoch [19/85], Iteration [680]


Epoch [19/85], Iteration [765]


[SR Val] PSNR: 40.03, SSIM: 0.9790
(565s) Epoch [20/85], Train_PSNR:12.67, Val_PSNR:12.65, Val_SSIM:0.3900


Epoch [20/85], Iteration [0]


Epoch [20/85], Iteration [85]


Epoch [20/85], Iteration [170]


Epoch [20/85], Iteration [255]


Epoch [20/85], Iteration [340]


Epoch [20/85], Iteration [425]


Epoch [20/85], Iteration [510]


Epoch [20/85], Iteration [595]


Epoch [20/85], Iteration [680]


Epoch [20/85], Iteration [765]


Model saved in epoch 20.


[SR Val] PSNR: 40.80, SSIM: 0.9748
(570s) Epoch [21/85], Train_PSNR:12.53, Val_PSNR:12.84, Val_SSIM:0.4056
Epoch 21: Learning rate adjusted to 0.000048


Epoch [21/85], Iteration [0]


Epoch [21/85], Iteration [85]


Epoch [21/85], Iteration [170]


Epoch [21/85], Iteration [255]


Epoch [21/85], Iteration [340]


Epoch [21/85], Iteration [425]


Epoch [21/85], Iteration [510]


Epoch [21/85], Iteration [595]


Epoch [21/85], Iteration [680]


Epoch [21/85], Iteration [765]


[SR Val] PSNR: 43.27, SSIM: 0.9900
(573s) Epoch [22/85], Train_PSNR:12.73, Val_PSNR:12.95, Val_SSIM:0.4014


Epoch [22/85], Iteration [0]


Epoch [22/85], Iteration [85]


Epoch [22/85], Iteration [170]


Epoch [22/85], Iteration [255]


Epoch [22/85], Iteration [340]


Epoch [22/85], Iteration [425]


Epoch [22/85], Iteration [510]


Epoch [22/85], Iteration [595]


Epoch [22/85], Iteration [680]


Epoch [22/85], Iteration [765]


[SR Val] PSNR: 44.00, SSIM: 0.9942
(585s) Epoch [23/85], Train_PSNR:12.79, Val_PSNR:13.15, Val_SSIM:0.4158


Epoch [23/85], Iteration [0]


Epoch [23/85], Iteration [85]


Epoch [23/85], Iteration [170]


Epoch [23/85], Iteration [255]


Epoch [23/85], Iteration [340]


Epoch [23/85], Iteration [425]


Epoch [23/85], Iteration [510]


Epoch [23/85], Iteration [595]


Epoch [23/85], Iteration [680]


Epoch [23/85], Iteration [765]


[SR Val] PSNR: 42.82, SSIM: 0.9886
(601s) Epoch [24/85], Train_PSNR:12.90, Val_PSNR:13.06, Val_SSIM:0.4019
Epoch 24: Learning rate adjusted to 0.000043


Epoch [24/85], Iteration [0]


Epoch [24/85], Iteration [85]


Epoch [24/85], Iteration [170]


Epoch [24/85], Iteration [255]


Epoch [24/85], Iteration [340]


Epoch [24/85], Iteration [425]


Epoch [24/85], Iteration [510]


Epoch [24/85], Iteration [595]


Epoch [24/85], Iteration [680]


Epoch [24/85], Iteration [765]


[SR Val] PSNR: 42.28, SSIM: 0.9919
(601s) Epoch [25/85], Train_PSNR:12.89, Val_PSNR:13.23, Val_SSIM:0.4154


Epoch [25/85], Iteration [0]


Epoch [25/85], Iteration [85]


Epoch [25/85], Iteration [170]


Epoch [25/85], Iteration [255]


Epoch [25/85], Iteration [340]


Epoch [25/85], Iteration [425]


Epoch [25/85], Iteration [510]


Epoch [25/85], Iteration [595]


Epoch [25/85], Iteration [680]


Epoch [25/85], Iteration [765]


Model saved in epoch 25.


[SR Val] PSNR: 43.25, SSIM: 0.9926
(612s) Epoch [26/85], Train_PSNR:12.90, Val_PSNR:13.09, Val_SSIM:0.4114


Epoch [26/85], Iteration [0]


Epoch [26/85], Iteration [85]


Epoch [26/85], Iteration [170]


Epoch [26/85], Iteration [255]


Epoch [26/85], Iteration [340]


Epoch [26/85], Iteration [425]


Epoch [26/85], Iteration [510]


Epoch [26/85], Iteration [595]


Epoch [26/85], Iteration [680]


Epoch [26/85], Iteration [765]


[SR Val] PSNR: 43.75, SSIM: 0.9897
(624s) Epoch [27/85], Train_PSNR:12.93, Val_PSNR:13.06, Val_SSIM:0.4070
Epoch 27: Learning rate adjusted to 0.000039


Epoch [27/85], Iteration [0]


Epoch [27/85], Iteration [85]


Epoch [27/85], Iteration [170]


Epoch [27/85], Iteration [255]


Epoch [27/85], Iteration [340]


Epoch [27/85], Iteration [425]


Epoch [27/85], Iteration [510]


Epoch [27/85], Iteration [595]


Epoch [27/85], Iteration [680]


Epoch [27/85], Iteration [765]


[SR Val] PSNR: 44.57, SSIM: 0.9945
(626s) Epoch [28/85], Train_PSNR:13.01, Val_PSNR:13.01, Val_SSIM:0.4029


Epoch [28/85], Iteration [0]


Epoch [28/85], Iteration [85]


Epoch [28/85], Iteration [170]


Epoch [28/85], Iteration [255]


Epoch [28/85], Iteration [340]


Epoch [28/85], Iteration [425]


Epoch [28/85], Iteration [510]


Epoch [28/85], Iteration [595]


Epoch [28/85], Iteration [680]


Epoch [28/85], Iteration [765]


[SR Val] PSNR: 44.81, SSIM: 0.9948
(640s) Epoch [29/85], Train_PSNR:13.07, Val_PSNR:13.54, Val_SSIM:0.4169


Epoch [29/85], Iteration [0]


Epoch [29/85], Iteration [85]


Epoch [29/85], Iteration [170]


Epoch [29/85], Iteration [255]


Epoch [29/85], Iteration [340]


Epoch [29/85], Iteration [425]


Epoch [29/85], Iteration [510]


Epoch [29/85], Iteration [595]


Epoch [29/85], Iteration [680]


Epoch [29/85], Iteration [765]


[SR Val] PSNR: 41.44, SSIM: 0.9752
(655s) Epoch [30/85], Train_PSNR:13.10, Val_PSNR:13.52, Val_SSIM:0.4086
Epoch 30: Learning rate adjusted to 0.000035


Epoch [30/85], Iteration [0]


Epoch [30/85], Iteration [85]


Epoch [30/85], Iteration [170]


Epoch [30/85], Iteration [255]


Epoch [30/85], Iteration [340]


Epoch [30/85], Iteration [425]


Epoch [30/85], Iteration [510]


Epoch [30/85], Iteration [595]


Epoch [30/85], Iteration [680]


Epoch [30/85], Iteration [765]


Model saved in epoch 30.


[SR Val] PSNR: 44.67, SSIM: 0.9920
(662s) Epoch [31/85], Train_PSNR:13.18, Val_PSNR:13.39, Val_SSIM:0.4172


Epoch [31/85], Iteration [0]


Epoch [31/85], Iteration [85]


Epoch [31/85], Iteration [170]


Epoch [31/85], Iteration [255]


Epoch [31/85], Iteration [340]


Epoch [31/85], Iteration [425]


Epoch [31/85], Iteration [510]


Epoch [31/85], Iteration [595]


Epoch [31/85], Iteration [680]


Epoch [31/85], Iteration [765]


[SR Val] PSNR: 45.21, SSIM: 0.9960
(669s) Epoch [32/85], Train_PSNR:13.06, Val_PSNR:13.29, Val_SSIM:0.4067


Epoch [32/85], Iteration [0]


Epoch [32/85], Iteration [85]


Epoch [32/85], Iteration [170]


Epoch [32/85], Iteration [255]


Epoch [32/85], Iteration [340]


Epoch [32/85], Iteration [425]


Epoch [32/85], Iteration [510]


Epoch [32/85], Iteration [595]


Epoch [32/85], Iteration [680]


Epoch [32/85], Iteration [765]


[SR Val] PSNR: 45.47, SSIM: 0.9958
(673s) Epoch [33/85], Train_PSNR:13.17, Val_PSNR:13.33, Val_SSIM:0.4073
Epoch 33: Learning rate adjusted to 0.000031


Epoch [33/85], Iteration [0]


Epoch [33/85], Iteration [85]


Epoch [33/85], Iteration [170]


Epoch [33/85], Iteration [255]


Epoch [33/85], Iteration [340]


Epoch [33/85], Iteration [425]


Epoch [33/85], Iteration [510]


Epoch [33/85], Iteration [595]


Epoch [33/85], Iteration [680]


Epoch [33/85], Iteration [765]


[SR Val] PSNR: 42.45, SSIM: 0.9784
(678s) Epoch [34/85], Train_PSNR:12.96, Val_PSNR:13.37, Val_SSIM:0.4071


Epoch [34/85], Iteration [0]


Epoch [34/85], Iteration [85]


Epoch [34/85], Iteration [170]


Epoch [34/85], Iteration [255]


Epoch [34/85], Iteration [340]


Epoch [34/85], Iteration [425]


Epoch [34/85], Iteration [510]


Epoch [34/85], Iteration [595]


Epoch [34/85], Iteration [680]


Epoch [34/85], Iteration [765]


[SR Val] PSNR: 45.27, SSIM: 0.9946
(700s) Epoch [35/85], Train_PSNR:13.15, Val_PSNR:13.20, Val_SSIM:0.4099


Epoch [35/85], Iteration [0]


Epoch [35/85], Iteration [85]


Epoch [35/85], Iteration [170]


Epoch [35/85], Iteration [255]


Epoch [35/85], Iteration [340]


Epoch [35/85], Iteration [425]


Epoch [35/85], Iteration [510]


Epoch [35/85], Iteration [595]


Epoch [35/85], Iteration [680]


Epoch [35/85], Iteration [765]


Model saved in epoch 35.


[SR Val] PSNR: 45.48, SSIM: 0.9947
(709s) Epoch [36/85], Train_PSNR:13.22, Val_PSNR:13.65, Val_SSIM:0.4192
Epoch 36: Learning rate adjusted to 0.000028


Epoch [36/85], Iteration [0]


Epoch [36/85], Iteration [85]


Epoch [36/85], Iteration [170]


Epoch [36/85], Iteration [255]


Epoch [36/85], Iteration [340]


Epoch [36/85], Iteration [425]


Epoch [36/85], Iteration [510]


Epoch [36/85], Iteration [595]


Epoch [36/85], Iteration [680]


Epoch [36/85], Iteration [765]


[SR Val] PSNR: 45.06, SSIM: 0.9922
(716s) Epoch [37/85], Train_PSNR:13.23, Val_PSNR:13.81, Val_SSIM:0.4160


Epoch [37/85], Iteration [0]


Epoch [37/85], Iteration [85]


Epoch [37/85], Iteration [170]


Epoch [37/85], Iteration [255]


Epoch [37/85], Iteration [340]


Epoch [37/85], Iteration [425]


Epoch [37/85], Iteration [510]


Epoch [37/85], Iteration [595]


Epoch [37/85], Iteration [680]


Epoch [37/85], Iteration [765]


[SR Val] PSNR: 45.70, SSIM: 0.9946
(731s) Epoch [38/85], Train_PSNR:13.31, Val_PSNR:13.58, Val_SSIM:0.4140


Epoch [38/85], Iteration [0]


Epoch [38/85], Iteration [85]


Epoch [38/85], Iteration [170]


Epoch [38/85], Iteration [255]


Epoch [38/85], Iteration [340]


Epoch [38/85], Iteration [425]


Epoch [38/85], Iteration [510]


Epoch [38/85], Iteration [595]


Epoch [38/85], Iteration [680]


Epoch [38/85], Iteration [765]


[SR Val] PSNR: 45.48, SSIM: 0.9959
(738s) Epoch [39/85], Train_PSNR:13.21, Val_PSNR:13.70, Val_SSIM:0.4320
Epoch 39: Learning rate adjusted to 0.000025


Epoch [39/85], Iteration [0]


Epoch [39/85], Iteration [85]


Epoch [39/85], Iteration [170]


Epoch [39/85], Iteration [255]


Epoch [39/85], Iteration [340]


Epoch [39/85], Iteration [425]


Epoch [39/85], Iteration [510]


Epoch [39/85], Iteration [595]


Epoch [39/85], Iteration [680]


Epoch [39/85], Iteration [765]


[SR Val] PSNR: 45.78, SSIM: 0.9963
(743s) Epoch [40/85], Train_PSNR:13.19, Val_PSNR:13.46, Val_SSIM:0.4096


Epoch [40/85], Iteration [0]


Epoch [40/85], Iteration [85]


Epoch [40/85], Iteration [170]


Epoch [40/85], Iteration [255]


Epoch [40/85], Iteration [340]


Epoch [40/85], Iteration [425]


Epoch [40/85], Iteration [510]


Epoch [40/85], Iteration [595]


Epoch [40/85], Iteration [680]


Epoch [40/85], Iteration [765]


Model saved in epoch 40.


[SR Val] PSNR: 46.22, SSIM: 0.9961
(753s) Epoch [41/85], Train_PSNR:13.34, Val_PSNR:13.59, Val_SSIM:0.4115


Epoch [41/85], Iteration [0]


Epoch [41/85], Iteration [85]


Epoch [41/85], Iteration [170]


Epoch [41/85], Iteration [255]


Epoch [41/85], Iteration [340]


Epoch [41/85], Iteration [425]


Epoch [41/85], Iteration [510]


Epoch [41/85], Iteration [595]


Epoch [41/85], Iteration [680]


Epoch [41/85], Iteration [765]


[SR Val] PSNR: 46.22, SSIM: 0.9959
(762s) Epoch [42/85], Train_PSNR:13.38, Val_PSNR:13.52, Val_SSIM:0.4141
Epoch 42: Learning rate adjusted to 0.000023


Epoch [42/85], Iteration [0]


Epoch [42/85], Iteration [85]


Epoch [42/85], Iteration [170]


Epoch [42/85], Iteration [255]


Epoch [42/85], Iteration [340]


Epoch [42/85], Iteration [425]


Epoch [42/85], Iteration [510]


Epoch [42/85], Iteration [595]


Epoch [42/85], Iteration [680]


Epoch [42/85], Iteration [765]


[SR Val] PSNR: 46.20, SSIM: 0.9967
(772s) Epoch [43/85], Train_PSNR:13.34, Val_PSNR:13.51, Val_SSIM:0.4043


Epoch [43/85], Iteration [0]


Epoch [43/85], Iteration [85]


Epoch [43/85], Iteration [170]


Epoch [43/85], Iteration [255]


Epoch [43/85], Iteration [340]


Epoch [43/85], Iteration [425]


Epoch [43/85], Iteration [510]


Epoch [43/85], Iteration [595]


Epoch [43/85], Iteration [680]


Epoch [43/85], Iteration [765]


[SR Val] PSNR: 46.22, SSIM: 0.9963
(780s) Epoch [44/85], Train_PSNR:13.30, Val_PSNR:13.67, Val_SSIM:0.4280


Epoch [44/85], Iteration [0]


Epoch [44/85], Iteration [85]


Epoch [44/85], Iteration [170]


Epoch [44/85], Iteration [255]


Epoch [44/85], Iteration [340]


Epoch [44/85], Iteration [425]


Epoch [44/85], Iteration [510]


Epoch [44/85], Iteration [595]


Epoch [44/85], Iteration [680]


Epoch [44/85], Iteration [765]


[SR Val] PSNR: 46.37, SSIM: 0.9964
(788s) Epoch [45/85], Train_PSNR:13.39, Val_PSNR:13.78, Val_SSIM:0.4080
Epoch 45: Learning rate adjusted to 0.000021


Epoch [45/85], Iteration [0]


Epoch [45/85], Iteration [85]


Epoch [45/85], Iteration [170]


Epoch [45/85], Iteration [255]


Epoch [45/85], Iteration [340]


Epoch [45/85], Iteration [425]


Epoch [45/85], Iteration [510]


Epoch [45/85], Iteration [595]


Epoch [45/85], Iteration [680]


Epoch [45/85], Iteration [765]


Model saved in epoch 45.


[SR Val] PSNR: 46.61, SSIM: 0.9968
(791s) Epoch [46/85], Train_PSNR:13.43, Val_PSNR:13.56, Val_SSIM:0.4146


Epoch [46/85], Iteration [0]


Epoch [46/85], Iteration [85]


Epoch [46/85], Iteration [170]


Epoch [46/85], Iteration [255]


Epoch [46/85], Iteration [340]


Epoch [46/85], Iteration [425]


Epoch [46/85], Iteration [510]


Epoch [46/85], Iteration [595]


Epoch [46/85], Iteration [680]


Epoch [46/85], Iteration [765]


[SR Val] PSNR: 46.61, SSIM: 0.9967
(804s) Epoch [47/85], Train_PSNR:13.40, Val_PSNR:13.71, Val_SSIM:0.4208


Epoch [47/85], Iteration [0]


Epoch [47/85], Iteration [85]


Epoch [47/85], Iteration [170]


Epoch [47/85], Iteration [255]


Epoch [47/85], Iteration [340]


Epoch [47/85], Iteration [425]


Epoch [47/85], Iteration [510]


Epoch [47/85], Iteration [595]


Epoch [47/85], Iteration [680]


Epoch [47/85], Iteration [765]


[SR Val] PSNR: 46.68, SSIM: 0.9969
(817s) Epoch [48/85], Train_PSNR:13.40, Val_PSNR:13.70, Val_SSIM:0.4227
Epoch 48: Learning rate adjusted to 0.000019


Epoch [48/85], Iteration [0]


Epoch [48/85], Iteration [85]


Epoch [48/85], Iteration [170]


Epoch [48/85], Iteration [255]


Epoch [48/85], Iteration [340]


Epoch [48/85], Iteration [425]


Epoch [48/85], Iteration [510]


Epoch [48/85], Iteration [595]


Epoch [48/85], Iteration [680]


Epoch [48/85], Iteration [765]


[SR Val] PSNR: 46.71, SSIM: 0.9969
(823s) Epoch [49/85], Train_PSNR:13.47, Val_PSNR:13.78, Val_SSIM:0.4250


Epoch [49/85], Iteration [0]


Epoch [49/85], Iteration [85]


Epoch [49/85], Iteration [170]


Epoch [49/85], Iteration [255]


Epoch [49/85], Iteration [340]


Epoch [49/85], Iteration [425]


Epoch [49/85], Iteration [510]


Epoch [49/85], Iteration [595]


Epoch [49/85], Iteration [680]


Epoch [49/85], Iteration [765]


[SR Val] PSNR: 46.59, SSIM: 0.9969
(838s) Epoch [50/85], Train_PSNR:13.50, Val_PSNR:13.68, Val_SSIM:0.4011


Epoch [50/85], Iteration [0]


Epoch [50/85], Iteration [85]


Epoch [50/85], Iteration [170]


Epoch [50/85], Iteration [255]


Epoch [50/85], Iteration [340]


Epoch [50/85], Iteration [425]


Epoch [50/85], Iteration [510]


Epoch [50/85], Iteration [595]


Epoch [50/85], Iteration [680]


Epoch [50/85], Iteration [765]


Model saved in epoch 50.


[SR Val] PSNR: 47.00, SSIM: 0.9970
(851s) Epoch [51/85], Train_PSNR:13.55, Val_PSNR:13.67, Val_SSIM:0.4316
Epoch 51: Learning rate adjusted to 0.000017


Epoch [51/85], Iteration [0]


Epoch [51/85], Iteration [85]


Epoch [51/85], Iteration [170]


Epoch [51/85], Iteration [255]


Epoch [51/85], Iteration [340]


Epoch [51/85], Iteration [425]


Epoch [51/85], Iteration [510]


Epoch [51/85], Iteration [595]


Epoch [51/85], Iteration [680]


Epoch [51/85], Iteration [765]


[SR Val] PSNR: 47.05, SSIM: 0.9973
(858s) Epoch [52/85], Train_PSNR:13.54, Val_PSNR:13.99, Val_SSIM:0.4241


Epoch [52/85], Iteration [0]


Epoch [52/85], Iteration [85]


Epoch [52/85], Iteration [170]


Epoch [52/85], Iteration [255]


Epoch [52/85], Iteration [340]


Epoch [52/85], Iteration [425]


Epoch [52/85], Iteration [510]


Epoch [52/85], Iteration [595]


Epoch [52/85], Iteration [680]


Epoch [52/85], Iteration [765]


[SR Val] PSNR: 47.02, SSIM: 0.9973
(869s) Epoch [53/85], Train_PSNR:13.49, Val_PSNR:13.70, Val_SSIM:0.4141


Epoch [53/85], Iteration [0]


Epoch [53/85], Iteration [85]


Epoch [53/85], Iteration [170]


Epoch [53/85], Iteration [255]


Epoch [53/85], Iteration [340]


Epoch [53/85], Iteration [425]


Epoch [53/85], Iteration [510]


Epoch [53/85], Iteration [595]


Epoch [53/85], Iteration [680]


Epoch [53/85], Iteration [765]


[SR Val] PSNR: 47.13, SSIM: 0.9965
(869s) Epoch [54/85], Train_PSNR:13.62, Val_PSNR:13.84, Val_SSIM:0.4146
Epoch 54: Learning rate adjusted to 0.000015


Epoch [54/85], Iteration [0]


Epoch [54/85], Iteration [85]


Epoch [54/85], Iteration [170]


Epoch [54/85], Iteration [255]


Epoch [54/85], Iteration [340]


Epoch [54/85], Iteration [425]


Epoch [54/85], Iteration [510]


Epoch [54/85], Iteration [595]


Epoch [54/85], Iteration [680]


Epoch [54/85], Iteration [765]


[SR Val] PSNR: 47.22, SSIM: 0.9973
(865s) Epoch [55/85], Train_PSNR:13.63, Val_PSNR:13.84, Val_SSIM:0.4301


Epoch [55/85], Iteration [0]


Epoch [55/85], Iteration [85]


Epoch [55/85], Iteration [170]


Epoch [55/85], Iteration [255]


Epoch [55/85], Iteration [340]


Epoch [55/85], Iteration [425]


Epoch [55/85], Iteration [510]


Epoch [55/85], Iteration [595]


Epoch [55/85], Iteration [680]


Epoch [55/85], Iteration [765]


Model saved in epoch 55.


[SR Val] PSNR: 47.23, SSIM: 0.9973
(882s) Epoch [56/85], Train_PSNR:13.56, Val_PSNR:13.87, Val_SSIM:0.4123


Epoch [56/85], Iteration [0]


Epoch [56/85], Iteration [85]


Epoch [56/85], Iteration [170]


Epoch [56/85], Iteration [255]


Epoch [56/85], Iteration [340]


Epoch [56/85], Iteration [425]


Epoch [56/85], Iteration [510]


Epoch [56/85], Iteration [595]


Epoch [56/85], Iteration [680]


Epoch [56/85], Iteration [765]


[SR Val] PSNR: 47.34, SSIM: 0.9971
(895s) Epoch [57/85], Train_PSNR:13.57, Val_PSNR:13.74, Val_SSIM:0.4123
Epoch 57: Learning rate adjusted to 0.000014


Epoch [57/85], Iteration [0]


Epoch [57/85], Iteration [85]


Epoch [57/85], Iteration [170]


Epoch [57/85], Iteration [255]


Epoch [57/85], Iteration [340]


Epoch [57/85], Iteration [425]


Epoch [57/85], Iteration [510]


Epoch [57/85], Iteration [595]


Epoch [57/85], Iteration [680]


Epoch [57/85], Iteration [765]


[SR Val] PSNR: 47.42, SSIM: 0.9971
(901s) Epoch [58/85], Train_PSNR:13.55, Val_PSNR:13.98, Val_SSIM:0.4219


Epoch [58/85], Iteration [0]


Epoch [58/85], Iteration [85]


Epoch [58/85], Iteration [170]


Epoch [58/85], Iteration [255]


Epoch [58/85], Iteration [340]


Epoch [58/85], Iteration [425]


Epoch [58/85], Iteration [510]


Epoch [58/85], Iteration [595]


Epoch [58/85], Iteration [680]


Epoch [58/85], Iteration [765]


[SR Val] PSNR: 47.52, SSIM: 0.9973
(925s) Epoch [59/85], Train_PSNR:13.62, Val_PSNR:13.76, Val_SSIM:0.4189


Epoch [59/85], Iteration [0]


Epoch [59/85], Iteration [85]


Epoch [59/85], Iteration [170]


Epoch [59/85], Iteration [255]


Epoch [59/85], Iteration [340]


Epoch [59/85], Iteration [425]


Epoch [59/85], Iteration [510]


Epoch [59/85], Iteration [595]


Epoch [59/85], Iteration [680]


Epoch [59/85], Iteration [765]


[SR Val] PSNR: 47.50, SSIM: 0.9975
(918s) Epoch [60/85], Train_PSNR:13.64, Val_PSNR:13.93, Val_SSIM:0.4199
Epoch 60: Learning rate adjusted to 0.000012


Epoch [60/85], Iteration [0]


Epoch [60/85], Iteration [85]


Epoch [60/85], Iteration [170]


Epoch [60/85], Iteration [255]


Epoch [60/85], Iteration [340]


Epoch [60/85], Iteration [425]


Epoch [60/85], Iteration [510]


Epoch [60/85], Iteration [595]


Epoch [60/85], Iteration [680]


Epoch [60/85], Iteration [765]


Model saved in epoch 60.


[SR Val] PSNR: 46.96, SSIM: 0.9973
(929s) Epoch [61/85], Train_PSNR:13.74, Val_PSNR:13.50, Val_SSIM:0.4102


Epoch [61/85], Iteration [0]


Epoch [61/85], Iteration [85]


Epoch [61/85], Iteration [170]


Epoch [61/85], Iteration [255]


Epoch [61/85], Iteration [340]


Epoch [61/85], Iteration [425]


Epoch [61/85], Iteration [510]


Epoch [61/85], Iteration [595]


Epoch [61/85], Iteration [680]


Epoch [61/85], Iteration [765]


[SR Val] PSNR: 47.54, SSIM: 0.9973
(935s) Epoch [62/85], Train_PSNR:13.61, Val_PSNR:13.85, Val_SSIM:0.4248


Epoch [62/85], Iteration [0]


Epoch [62/85], Iteration [85]


Epoch [62/85], Iteration [170]


Epoch [62/85], Iteration [255]


Epoch [62/85], Iteration [340]


Epoch [62/85], Iteration [425]


Epoch [62/85], Iteration [510]


Epoch [62/85], Iteration [595]


Epoch [62/85], Iteration [680]


Epoch [62/85], Iteration [765]


[SR Val] PSNR: 47.59, SSIM: 0.9976
(936s) Epoch [63/85], Train_PSNR:13.69, Val_PSNR:14.02, Val_SSIM:0.4196
Epoch 63: Learning rate adjusted to 0.000011


Epoch [63/85], Iteration [0]


Epoch [63/85], Iteration [85]


Epoch [63/85], Iteration [170]


Epoch [63/85], Iteration [255]


Epoch [63/85], Iteration [340]


Epoch [63/85], Iteration [425]


Epoch [63/85], Iteration [510]


Epoch [63/85], Iteration [595]


Epoch [63/85], Iteration [680]


Epoch [63/85], Iteration [765]


In [None]:
# Initialize model
# model_path = "/kaggle/input/rdb-and-transformer/pytorch/default/1/formernewnh_final_49.pth"
model_path = "/kaggle/input/reside-dehaze/pytorch/default/3/formernewreside_haze_best_0.pth"
# model = DehazingNet().to(device)
# model = SR_model(upscale_factor=1).to(device)
# net.load_state_dict(torch.load(model_path, map_location=device))
net.eval()

In [None]:
# -----------------------------
# LOAD TEST DATA
# -----------------------------
test_hazy_dir = "/kaggle/input/reside6k/RESIDE-6K/test/hazy"
test_gt_dir = "/kaggle/input/reside6k/RESIDE-6K/test/GT"
# test_hazy_dir = "/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/IN"
# test_gt_dir = "/kaggle/input/nh-dense-haze/NH-HAZE-T/NH-HAZE-T/GT"

hazy_images = sorted(glob.glob(os.path.join(test_hazy_dir, "*.*")))
gt_images = sorted(glob.glob(os.path.join(test_gt_dir, "*.*")))

transform = Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

to_pil = ToPILImage()

In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION FOR SPECIFIC IMAGES
# -----------------------------
image_indices = [11, 12, 13,14]  # Indices of images to visualize

plt.figure(figsize=(10, len(image_indices) * 5))

for idx, i in enumerate(image_indices):
    hazy_img = Image.open(hazy_images[i+1])
    gt_img = Image.open(gt_images[i+1])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        res = net(input_tensor)
        print(res[0].shape)
        output_tensor = res[0].cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(len(image_indices), 3, 3 * idx + 1)
    plt.imshow(hazy_img)
    plt.title(f"Hazy Input ")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 2)
    plt.imshow(output_img)
    plt.title(f"Dehazed Output ")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 3)
    plt.imshow(gt_img)
    plt.title(f"Ground Truth ")
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION FOR SPECIFIC IMAGES
# -----------------------------
image_indices = [70, 75, 90, 100]  # Indices of images to visualize

plt.figure(figsize=(10, len(image_indices) * 5))

for idx, i in enumerate(image_indices):
    hazy_img = Image.open(hazy_images[i+1])
    gt_img = Image.open(gt_images[i+1])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        res = net(input_tensor)
        print(res[0].shape)
        output_tensor = res[0].cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(len(image_indices), 3, 3 * idx + 1)
    plt.imshow(hazy_img)
    plt.title(f"Hazy Input {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 2)
    plt.imshow(output_img)
    plt.title(f"Dehazed Output {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 3)
    plt.imshow(gt_img)
    plt.title(f"Ground Truth {i}")
    plt.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# -----------------------------
# INFERENCE & VISUALIZATION FOR SPECIFIC IMAGES
# -----------------------------
image_indices = [1, 3, 5]  # Indices of images to visualize

plt.figure(figsize=(10, len(image_indices) * 5))

for idx, i in enumerate(image_indices):
    hazy_img = Image.open(hazy_images[i+1])
    gt_img = Image.open(gt_images[i+1])

    # Transform for model input
    input_tensor = transform(hazy_img).unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        res = net(input_tensor)
        print(res[0].shape)
        output_tensor = res[0].cpu().squeeze(0)

    # Convert back to image
    output_img = to_pil(output_tensor)

    # Display results
    plt.subplot(len(image_indices), 3, 3 * idx + 1)
    plt.imshow(hazy_img)
    plt.title(f"Hazy Input {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 2)
    plt.imshow(output_img)
    plt.title(f"Dehazed Output {i}")
    plt.axis("off")

    plt.subplot(len(image_indices), 3, 3 * idx + 3)
    plt.imshow(gt_img)
    plt.title(f"Ground Truth {i}")
    plt.axis("off")

plt.tight_layout()
plt.show()
