# Imports

In [None]:
!pip install -q kornia
!pip install -q wandb
!pip install -q torchmetrics
!pip install -q einops
import os
import numpy as np
import matplotlib.pyplot as plt
import glob

import torch
from torchvision import transforms
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import DataLoader

import torch.nn.functional as F
from torch.autograd import Variable
import math
from math import exp
from tqdm import tqdm

from kornia.filters.sobel import Sobel
import wandb
from torchvision.utils import make_grid
import gc
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM



from torch.nn.parameter import Parameter
from einops import rearrange
from torch.nn import init 

In [None]:
# Pick the compute device once: use the GPU when CUDA is usable, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Bare expression so the notebook cell displays the selected device.
device

# Loss, Metric and Other Functions

In [None]:
def normalize(sample):
    """
    Min-max normalize Digital Elevation Model (DEM) data to the range [0, 1].

    Every value is shifted by the sample minimum and divided by the
    sample's value range (max - min).

    A perfectly flat sample (max == min) has a zero range; dividing by it
    would fill the output with NaN, so such a sample is mapped to all zeros
    instead.

    Parameters:
    -----------
    sample : array-like exposing ``.min()`` / ``.max()`` and arithmetic
        Input DEM data containing elevation values (e.g. a numpy.ndarray).

    Returns:
    --------
    Same container type as ``sample``
        Normalized DEM data with values scaled between 0 and 1
        (all zeros when the input is constant).
    """
    lowest = sample.min()
    value_range = sample.max() - lowest
    if value_range == 0:
        # Flat input: avoid 0/0. Dividing by 1 (rather than returning the
        # shifted data directly) keeps the same dtype promotion as the
        # regular path.
        value_range = 1
    return (sample - lowest) / value_range

In [None]:
def _psnr_as_array(img):
    """Return ``img`` as a float64 numpy array (torch tensors are detached and moved to CPU first)."""
    if hasattr(img, 'detach'):
        img = img.detach().cpu().numpy()
    return np.asarray(img, dtype=np.float64)


def calculate_psnr(img1, img2, border=0 ,data_min=0.0 ,data_max=1.0 ):
    """
    Calculates Peak Signal-to-Noise Ratio (PSNR) between two images.

    PSNR is computed as ``20 * log10((data_max - data_min) / sqrt(MSE))``.
    Higher values indicate that ``img2`` is closer to ``img1``.

    The two trailing axes are treated as height and width, so inputs may be
    plain (H, W) images or carry leading axes such as (B, C, H, W). The
    optional ``border`` crop is applied to those two trailing axes only.

    Parameters:
    -----------
    img1 : torch.Tensor or numpy.ndarray
        First input image (reference/ground truth)
    img2 : torch.Tensor or numpy.ndarray
        Second input image (predicted/generated)
    border : int, optional
        Number of border pixels to exclude on every side (default: 0)
    data_min : float, optional
        Minimum value of the data range (default: 0.0)
    data_max : float, optional
        Maximum value of the data range (default: 1.0)

    Returns:
    --------
    float
        PSNR value in decibels (dB). Returns infinity if images are identical

    Raises:
    -------
    ValueError
        If the inputs differ in shape, have fewer than two dimensions,
        ``border`` is negative, or the border crop leaves no pixels.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    reference = _psnr_as_array(img1)
    estimate = _psnr_as_array(img2)
    if reference.ndim < 2:
        raise ValueError('Input images must have at least two dimensions.')
    if border < 0:
        raise ValueError('border must be non-negative.')

    h, w = reference.shape[-2:]
    # Crop the spatial (last two) axes; leading batch/channel axes are kept.
    reference = reference[..., border:h - border, border:w - border]
    estimate = estimate[..., border:h - border, border:w - border]
    if reference.size == 0:
        raise ValueError('border is too large for the image size.')

    mse = np.mean((reference - estimate) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10((data_max - data_min) / math.sqrt(mse))

In [None]:
class gradientAwareLoss(nn.Module):
    """
    A custom loss function that computes the L1 difference between edge maps of
    high-resolution and super-resolved images using Sobel edge detection.

    This loss helps in preserving edge and gradient information during super-resolution,
    ensuring better structural fidelity in the output.

    The sub-modules are placed on the GPU when CUDA is available and on the
    CPU otherwise, so the loss can also be constructed on CPU-only machines.

    Attributes:
    -----------
    sobelFilter : Sobel
        Sobel edge detection filter (kornia)
    l1Loss : nn.L1Loss
        L1 loss applied to the two edge maps

    Methods:
    --------
    forward(hr, sr):
        Computes the gradient-aware loss between high-resolution(Ground Truth) and super-resolved(Model's Output) images
    """
    def __init__(self):
        super().__init__()
        # Same target as before on CUDA machines; falls back to CPU instead
        # of unconditionally requesting 'cuda'.
        target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.sobelFilter = Sobel().to(target_device)
        self.l1Loss = nn.L1Loss().to(target_device)

    def forward(self, hr, sr):
        """Return the mean absolute difference between the Sobel edge maps of ``hr`` and ``sr``."""
        hrEdgeMap = self.sobelFilter(hr)
        srEdgeMap = self.sobelFilter(sr)
        return self.l1Loss(hrEdgeMap, srEdgeMap)

# Dataset and Dataloader

In [None]:
# # Dataset class
# class Dataset(data.Dataset):
#     def __init__(self, load_dir, normalize = True,transform=transforms.Compose([transforms.ToTensor()])):
#         self.load_dir = load_dir
#         self.tranform = transform
#         self.downsampler = torch.nn.Upsample(scale_factor=0.5, mode='bilinear')
#         self.normalize = normalize      
#     def __getitem__(self, idx):        
#         try:
#             if self.normalize:
#                 im = normalize(torch.load(self.load_dir[idx]))
#             else:
#                 im = torch.load(self.dems[idx])
#             HR = im.copy().astype(np.float32) 
#             HR = torch.from_numpy(HR).unsqueeze(0) 
            
#             # print("shape of LR", LR.shape)
#             LR = self.downsampler(HR.unsqueeze(0)).squeeze()     
#             # print("shape of LR", LR.shape)
#             return HR, LR.unsqueeze(0)
#         except:
#             print(idx)
#             print(self.load_dir[idx])
#     def __len__(self):
#         return len(self.load_dir)

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as data
import numpy as np
import torchvision.transforms as transforms
import random
import cv2
import os

def custom_collate(batch):
    """
    Collate a batch while discarding samples that are ``None``.

    Parameters:
    -----------
    batch : list
        Samples produced by the dataset; failed samples appear as None

    Returns:
    --------
    Collated batch or None
        Result of PyTorch's default collate on the remaining samples, or
        None when no sample is left after filtering
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return torch.utils.data.dataloader.default_collate(kept)

class Dataset(data.Dataset):
    """
    Paired HR/LR dataset for DEM super-resolution training.

    Each item loads one high-resolution DEM from disk, optionally min-max
    normalizes it, and derives a half-resolution counterpart with a
    downsampling method picked at random from ``downsample_methods``.

    ``__getitem__`` returns ``None`` (after printing diagnostics) when a
    sample cannot be processed; pair the dataset with a collate function that
    drops ``None`` entries (``custom_collate`` in this file does so).
    """

    def __init__(self, load_dir, normalize=True, transform=transforms.Compose([transforms.ToTensor()]),
                 downsample_methods=('bilinear', 'area', 'gaussian', 'motion_blur', 'median_blur')):
        """
        Parameters:
            load_dir (list): List of paths to high-resolution images
            normalize (bool): Whether to normalize the images
            transform (torchvision.transforms): Stored on the instance; not
                applied anywhere in this class
            downsample_methods (sequence of str): Names of the downsampling
                methods to sample from

        Raises:
            ValueError: if ``downsample_methods`` is empty or names a method
                that is not implemented
        """
        self.load_dir = load_dir
        self.transform = transform
        self.normalize = normalize

        # Predefined downsampling methods
        self.downsamplers = {
            'bilinear': self._bilinear_downsample,
            'area': self._area_downsample,
            'gaussian': self._gaussian_downsample,
            'motion_blur': self._motion_blur_downsample,
            'median_blur': self._median_blur_downsample
        }

        # Validate up front: otherwise a bad name would only surface as a
        # swallowed KeyError (i.e. a silent None sample) in __getitem__.
        self.downsample_methods = list(downsample_methods)
        if not self.downsample_methods:
            raise ValueError("downsample_methods must contain at least one method")
        unknown = [name for name in self.downsample_methods if name not in self.downsamplers]
        if unknown:
            raise ValueError(f"Unknown downsample method(s): {unknown}")

    @staticmethod
    def _resize_to_half(image, HR):
        """Area-resize a 2-D numpy image to half the spatial size of the (1, H, W) tensor ``HR``."""
        # cv2 expects the target size as (width, height).
        half_size = (HR.shape[2] // 2, HR.shape[1] // 2)
        return cv2.resize(image, half_size, interpolation=cv2.INTER_AREA)

    def _bilinear_downsample(self, HR):
        """Bilinear downsampling"""
        return torch.nn.functional.interpolate(HR.unsqueeze(0), scale_factor=0.5, mode='bilinear').squeeze()

    def _area_downsample(self, HR):
        """Area interpolation downsampling"""
        return torch.nn.functional.interpolate(HR.unsqueeze(0), scale_factor=0.5, mode='area').squeeze()

    def _gaussian_downsample(self, HR):
        """Gaussian blur + downsampling"""
        hr_np = HR.squeeze(0).numpy()
        blurred = cv2.GaussianBlur(hr_np, (5, 5), 0)
        downsampled = self._resize_to_half(blurred, HR)
        return torch.from_numpy(downsampled).unsqueeze(0)

    def _motion_blur_downsample(self, HR):
        """Motion blur + downsampling"""
        hr_np = HR.squeeze(0).numpy()
        kernel_size = 5
        # Only the centre column is non-zero: a vertical averaging kernel.
        kernel_v = np.zeros((kernel_size, kernel_size))
        kernel_v[:, kernel_size // 2] = np.ones(kernel_size)
        kernel_v /= kernel_size
        blurred = cv2.filter2D(hr_np, -1, kernel_v)
        downsampled = self._resize_to_half(blurred, HR)
        return torch.from_numpy(downsampled).unsqueeze(0)  # Add channel back

    def _median_blur_downsample(self, HR):
        """Median blur + downsampling"""
        hr_np = HR.squeeze(0).numpy()
        # Keep the data in float32: casting to uint8 would truncate the
        # (typically [0, 1]-normalized) elevations to 0/1. OpenCV's
        # medianBlur accepts 32-bit float input for kernel sizes 3 and 5.
        blurred = cv2.medianBlur(np.ascontiguousarray(hr_np, dtype=np.float32), 5)
        downsampled = self._resize_to_half(blurred, HR)
        return torch.from_numpy(downsampled).unsqueeze(0)  # Add channel back

    def __getitem__(self, idx):
        """
        Return the ``(HR, LR)`` tensor pair for ``idx``, each with a leading
        channel axis, or ``None`` if the sample could not be processed.

        Raises:
            IndexError: if ``idx`` is negative or not smaller than ``len(self)``
        """
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} is out of range")

        try:
            image_path = self.load_dir[idx]

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            # NOTE(review): torch.load unpickles the file; only use it on
            # trusted data files.
            im = torch.load(image_path)

            if self.normalize:
                im = normalize(im)

            if not isinstance(im, torch.Tensor):
                im = torch.tensor(im, dtype=torch.float32)

            # Downsamplers expect a (1, H, W) tensor.
            if im.dim() == 2:
                im = im.unsqueeze(0)

            HR = im.float()

            downsample_method = random.choice(self.downsample_methods)

            LR = self.downsamplers[downsample_method](HR)

            # Drop any singleton axes, then re-add exactly one channel axis
            # so every method yields the same layout.
            HR = HR.squeeze()
            LR = LR.squeeze().float()

            return HR.unsqueeze(0), LR.unsqueeze(0)

        except Exception as e:
            # Deliberate best-effort: report the failure and let the collate
            # function skip this sample instead of aborting the epoch.
            print(f"Error processing image at index {idx}")
            print(f"Image path: {self.load_dir[idx]}")
            print(f"Error details: {type(e).__name__}: {str(e)}")

            return None

    def __len__(self):
        """Number of high-resolution files in ``load_dir``."""
        return len(self.load_dir)

In [None]:
## Data Directory Setup
# Collects the training and testing DEM tensor files (.pt) and reports how
# many were found.
#
# Variables:
# ----------
# train_dir : list
#     Sorted list of paths to training DEM tensor files
# test_dir : list
#     Sorted list of paths to testing DEM tensor files
#
# glob returns matches in arbitrary order; sorting keeps sample indices
# stable between runs and machines.
train_dir = sorted(glob.glob('../data_DEMs_no_nan/train/*.pt'))
test_dir = sorted(glob.glob('../data_DEMs_no_nan/test/*.pt'))

# This line removed one corrupted DEM file.
# test_dir.remove('../data_DEMs_no_nan/test/NAC_DTM_MESSIER3_block_211.pt')


# Print the number of samples in the training and testing datasets.
print(len(train_dir))
print(len(test_dir))

In [None]:
# Dataset and DataLoader Initialization
# Builds the training and testing datasets and wraps them in DataLoaders.
#
# Variables:
# ----------
# trainset : Dataset
#     Custom dataset instance for training data
# testset : Dataset
#     Custom dataset instance for testing data
# trainloader : DataLoader
#     Batches of 4 training samples, reshuffled every epoch
# testloader : DataLoader
#     Batches of 4 testing samples in fixed order

trainset = Dataset(load_dir = train_dir)
testset  = Dataset(load_dir = test_dir)

# Dataset.__getitem__ returns None for samples it fails to load; the default
# collate function cannot handle None, so use custom_collate, which drops
# such samples (and yields None only if a whole batch failed).
trainloader = DataLoader(trainset,batch_size=4,shuffle=True,collate_fn=custom_collate)
testloader = DataLoader(testset,batch_size=4,shuffle=False,collate_fn=custom_collate)

In [None]:
import matplotlib.pyplot as plt
import torchvision

def visualize_dataloader(dataloader, n_batches=5, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """
    Plot the second element of the first ``n_batches`` batches of a DataLoader.

    Every batch is expected to be a pair; the first element is ignored and
    the second is tiled into a single-row image grid and shown with
    matplotlib.

    Args:
        dataloader: PyTorch DataLoader object yielding 2-tuples of tensors.
        n_batches: Number of batches to visualize.
        mean: Currently unused (kept for interface compatibility; the
            denormalization step is disabled).
        std: Currently unused (see ``mean``).
    """
    # zip() stops as soon as the range is exhausted, so no batch beyond the
    # ones actually displayed is loaded from the dataloader.
    for batch_index, (_, images) in zip(range(n_batches), dataloader):
        # One row containing every image of the batch.
        grid = torchvision.utils.make_grid(images, nrow=images.shape[0])
        grid = grid.permute(1, 2, 0).numpy()  # Convert to HWC format
        # grid = (grid * std + mean).clip(0, 1)  # Denormalize and clip

        # Plot the images
        plt.figure(figsize=(12, 6))
        plt.imshow(grid)
        plt.axis("off")
        plt.title(f"Batch {batch_index + 1}")
        plt.show()
# Preview the first four training batches.
visualize_dataloader(trainloader, n_batches=4)

In [None]:
# Visualization: inspect the first training batch only. Prints tensor shapes
# and value ranges, then shows one HR/LR pair side by side.
for b, (hr, lr) in enumerate(trainloader):
    print(hr.shape)
    print(lr.shape)

    print(hr.min())
    print(hr.max())

    print(lr.min())
    print(lr.max())

    # Show the third sample (index 2) as before, but fall back to the last
    # available one when the batch holds fewer than three samples.
    sample_idx = min(2, hr.shape[0] - 1)

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 12))

    axes[0].set_yticklabels([])

    axes[0].imshow(hr[sample_idx].squeeze(), cmap='gray')
    axes[1].imshow(lr[sample_idx].squeeze(), cmap='gray')

    axes[0].set_title("HR")
    axes[1].set_title("LR")


    break

# Model

In [None]:
class depthwise_separable_conv(nn.Module):
    """
        Depthwise separable convolution: a per-channel spatial convolution
        followed by a 1x1 convolution that mixes channels.

        Parameters:
        -----------
        nin : int
            Number of input channels
        nout : int
            Number of output channels
        kernel_size : int, optional
            Kernel size of the depthwise convolution (default: 3)
        padding : int, optional
            Padding of the depthwise convolution (default: 1)
        bias : bool, optional
            Whether both convolutions carry a learnable bias (default: False)
    """

    def __init__(self, nin, nout, kernel_size = 3, padding = 1, bias=False):
        super().__init__()
        # groups == nin gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels=nin,
            out_channels=nin,
            kernel_size=kernel_size,
            padding=padding,
            groups=nin,
            bias=bias,
        )
        # 1x1 convolution recombines the per-channel responses into nout maps.
        self.pointwise = nn.Conv2d(in_channels=nin, out_channels=nout, kernel_size=1, bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))
    
    
    
    

class ERAM(nn.Module):

    """
    Enhanced Residual Attention Module (ERAM) for feature refinement in super-resolution.

    Implementation based on:
    Li, Y., Zhou, L., Xu, F., & Chen, S. (2022). OGSRN: Optical-guided super-resolution
    network for SAR image. Chinese Journal of Aeronautics, 35(5), 204-219.
    https://doi.org/10.1016/j.cja.2021.08.036

    Parameters:
    -----------
    channel_begin : int
        Number of input (and output) channels
    dimension : int
        Nominal spatial size of the input feature map. Kept for interface
        compatibility: the channel statistics are pooled over the whole map
        whatever its size, which matches the previous ``AvgPool2d(dimension)``
        when the input is ``dimension`` x ``dimension``.

    Architecture components:
    - Channel attention branch built from per-channel mean and variance
    - Spatial attention branch with a depthwise separable convolution
    - Sigmoid gate multiplied onto the input
    """

    def __init__(self, channel_begin, dimension):
        super().__init__()
        self.dimension = dimension
        # NOTE: `conv` is not used by forward(); it is kept so parameter
        # names and initialization order stay compatible with existing
        # checkpoints.
        self.conv = nn.Conv2d(channel_begin, channel_begin, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        # Global average pooling to a 1x1 map, independent of input size.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self.conv1 = nn.Conv2d(channel_begin, channel_begin//2, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(channel_begin//2, channel_begin, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(channel_begin, channel_begin, kernel_size=3, stride=1, padding=1)

        self.dconv = depthwise_separable_conv(channel_begin, channel_begin, kernel_size = 3, padding = 1, bias=False)
        self.sigmoid = nn.Sigmoid()


    def forward(self, x):
        """
            Forward pass combining channel and spatial attention mechanisms.

            Parameters:
            -----------
            x : torch.Tensor
                Input feature map [B, C, H, W]

            Returns:
            --------
            torch.Tensor
                Input re-weighted by the attention gate, same shape as ``x``
        """
        # Channel attention: per-channel mean + variance -> [B, C, 1, 1]
        channel_stats = self.avgpool(x) + torch.var(x, dim=(2, 3), keepdim=True)
        mi_ca = self.conv2(self.relu(self.conv1(channel_stats)))

        # Spatial attention -> [B, C, H, W]
        mi_sa = self.conv3(self.relu(self.dconv(x)))

        # The channel term broadcasts over H and W; gate the input with it.
        return self.sigmoid(mi_ca + mi_sa) * x

    

class SelfAttn(nn.Module):
    """
        Multi-head self-attention over a sequence of tokens.

        A single linear layer produces queries, keys and values; attention is
        evaluated independently per head and the heads are concatenated and
        passed through an output projection.

        Parameters:
        -----------
        dim : int
            Input feature dimension
        num_heads : int, optional
            Number of attention heads (default: 8)
        bias : bool, optional
            Whether the QKV projection has a bias (default: False)

        Attributes:
        -----------
        scale : float
            ``(dim // num_heads) ** -0.5``, applied to the attention scores
        qkv : nn.Linear
            Joint projection to query, key and value
        proj_out : nn.Linear
            Output projection
    """

    def __init__(self, dim, num_heads=8, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=bias)
        self.proj_out = nn.Linear(dim, dim)

    def forward(self, x):
        """
            Apply self-attention to the input tokens.

            Parameters:
            -----------
            x : torch.Tensor
                Input tensor of shape [batch_size, num_tokens, channels]

            Returns:
            --------
            torch.Tensor
                Tensor of shape [batch_size, num_tokens, channels]
        """
        _batch, _tokens, _channels = x.shape  # input must be 3-D

        # [b, N, 3c] -> three [b, N, c] chunks -> each [b, head, N, c//head]
        queries, keys, values = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.num_heads)
            for part in self.qkv(x).chunk(3, dim=-1)
        )

        # Scaled dot-product scores per head: [b, head, N, N]
        scores = torch.einsum('bijc, bikc -> bijk', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)

        # Weighted sum of values, then concatenate the heads: [b, N, c]
        per_head = torch.einsum('bijk, bikc -> bijc', weights, values)
        merged = rearrange(per_head, 'b i j c -> b j (i c)')
        return self.proj_out(merged)


class Mlp(nn.Module):
    """
        Two-layer feed-forward network with a GELU in between.

        The hidden width is ``in_features * mlp_ratio``; the output width
        equals the input width.

        Parameters:
        -----------
        in_features : int
            Number of input features
        mlp_ratio : int, optional
            Expansion ratio for hidden dimension (default: 4)

        Attributes:
        -----------
        fc : nn.Sequential
            Linear (expand) -> GELU -> Linear (project back)
    """
    def __init__(self, in_features, mlp_ratio=4):
        super().__init__()
        width = in_features * mlp_ratio
        expand = nn.Linear(in_features, width)
        activation = nn.GELU()
        project = nn.Linear(width, in_features)
        self.fc = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        out = self.fc(x)
        return out


def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (b, h, w, c); h and w must be multiples of window_size
        window_size (int): window size
    Returns:
        windows: (num_windows*b, window_size, window_size, c)
    """
    pattern = 'b (h s1) (w s2) c -> (b h w) s1 s2 c'
    windows = rearrange(x, pattern, s1=window_size, s2=window_size)
    return windows


def window_reverse(windows, window_size, h, w):
    """
    Reassemble non-overlapping windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        h (int): Height of image (multiple of window_size)
        w (int): Width of image (multiple of window_size)
    Returns:
        x: (b, h, w, c)
    """
    rows = h // window_size
    cols = w // window_size
    # Exact integer arithmetic for the batch size (no float division and
    # truncation).
    b = windows.shape[0] // (rows * cols)
    return rearrange(windows, '(b h w) s1 s2 c -> b (h s1) (w s2) c', b=b, h=rows, w=cols)


class Transformer(nn.Module):
    """
        Window-based Transformer module for processing DEM features.

        This module combines a convolutional positional embedding,
        self-attention within non-overlapping windows, and an MLP, each
        wrapped in a skip connection. Inputs whose height or width is not a
        multiple of the window size are zero-padded for the attention step
        and cropped back afterwards, so any spatial size is accepted.

        Parameters:
        -----------
        dim : int
            Feature dimension
        num_heads : int, optional
            Number of attention heads (default: 4)
        window_size : int, optional
            Size of attention windows (default: 8)
        mlp_ratio : int, optional
            MLP expansion ratio (default: 4)
        qkv_bias : bool, optional
            Whether to use bias in QKV projection (default: False)

        Attributes:
        -----------
        pos_embed : nn.Conv2d
            Depthwise convolution for positional embedding
        norm1, norm2 : nn.LayerNorm
            Layer normalization modules
        attn : SelfAttn
            Self-attention module
        mlp : Mlp
            Multi-layer perceptron module
    """
    def __init__(self, dim, num_heads=4, window_size=8, mlp_ratio=4, qkv_bias=False):
        super(Transformer, self).__init__()
        self.window_size = window_size
        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = SelfAttn(dim, num_heads, qkv_bias)
        self.norm2 = nn.LayerNorm(dim)

        self.mlp = Mlp(dim, mlp_ratio)


    def forward(self, x):
        """
            Forward pass of the Transformer.

            Process steps:
            1. Add positional embedding
            2. Normalize, pad and partition the input into windows
            3. Apply self-attention within windows
            4. Merge windows, crop the padding, add the skip connection
            5. Apply the MLP with its own skip connection

            Parameters:
            -----------
            x : torch.Tensor
                Input features [B, C, H, W]

            Returns:
            --------
            torch.Tensor
                Transformed features [B, C, H, W]
        """
        x = x + self.pos_embed(x)
        x = rearrange(x, 'b c h w -> b h w c')
        b, h, w, c = x.shape

        shortcut = x
        x = self.norm1(x)

        # Zero-pad bottom/right so both spatial sizes are multiples of the
        # window size. F.pad lists pads from the last axis backwards:
        # (c_left, c_right, w_left, w_right, h_top, h_bottom).
        ws = self.window_size
        pad_r = (ws - w % ws) % ws
        pad_b = (ws - h % ws) % ws
        x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
        _, Hp, Wp, _ = x.shape

        x_windows = window_partition(x, ws)  # nW*b, ws, ws, c
        x_windows = rearrange(x_windows, 'B s1 s2 c -> B (s1 s2) c', s1=ws, s2=ws)  # nW*b, ws*ws, c

        # Window attention; the output keeps the (nW*b, ws*ws, c) layout, so
        # no size-specific reshape is needed before merging.
        attn_windows = self.attn(x_windows)

        # merge windows
        attn_windows = rearrange(attn_windows, 'B (s1 s2) c -> B s1 s2 c', s1=ws, s2=ws)
        x = window_reverse(attn_windows, ws, Hp, Wp)  # b H' W' c

        # remove the padding added above
        if pad_r > 0 or pad_b > 0:
            x = x[:, :h, :w, :].contiguous()

        x = x + shortcut
        x = x + self.mlp(self.norm2(x))
        return rearrange(x, 'b h w c -> b c h w')


class ResBlock(nn.Module):
    """
        Residual block: 1x1 expansion, per-channel 3x3 convolution, 1x1
        reduction, added back onto the input.

        Parameters:
        -----------
        in_features : int
            Number of input (and output) channels
        ratio : int, optional
            Channel expansion ratio (default: 4)

        Attributes:
        -----------
        net : nn.Sequential
            - 1x1 conv expanding to in_features * ratio channels
            - LeakyReLU(0.2)
            - 3x3 conv with groups = in_features * ratio
            - LeakyReLU(0.2)
            - 1x1 conv reducing back to in_features channels
    """
    def __init__(self, in_features, ratio=4):
        super().__init__()
        expanded = in_features * ratio

        layers = [
            nn.Conv2d(in_features, expanded, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, inplace=True),
            # groups == channels: each expanded channel is filtered on its own
            nn.Conv2d(expanded, expanded, kernel_size=3, stride=1, padding=1, groups=expanded),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(expanded, in_features, kernel_size=1, stride=1, padding=0),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the block output plus the identity shortcut."""
        residual = self.net(x)
        return residual + x



class BaseBlock(nn.Module):
    """
        Base building block combining Transformer, ResBlock, and ERAM modules in sequence.

        One stage is created per entry of ``ratios``; each stage applies, in order:
        1. Window-based Transformer
        2. Residual block
        3. Enhanced Residual Attention Module (ERAM)

        Parameters:
        -----------
        dim : int
            Feature dimension
        num_heads : int, optional
            Number of attention heads in Transformer (default: 8)
        window_size : int, optional
            Size of attention windows (default: 8)
        ratios : sequence of int, optional
            Expansion ratio of each stage, used both as the Transformer's MLP
            ratio and as the ResBlock ratio (default: (1, 2, 2, 4, 4))
        qkv_bias : bool, optional
            Whether to use bias in Transformer QKV projection (default: False)
        eram_dimension : int, optional
            Value forwarded to every ERAM as its ``dimension`` argument, i.e.
            the side length of the square input feature map (default: 128)

        Attributes:
        -----------
        layers : nn.ModuleList
            List of processing stages, each containing Transformer, ResBlock, and ERAM
    """

    def __init__(self, dim, num_heads=8, window_size=8, ratios=(1, 2, 2, 4, 4), qkv_bias=False,
                 eram_dimension=128):
        super(BaseBlock, self).__init__()
        self.layers = nn.ModuleList([])
        for ratio in ratios:
            self.layers.append(nn.ModuleList([
                Transformer(dim, num_heads, window_size, ratio, qkv_bias),
                ResBlock(dim, ratio),
                ERAM(dim, eram_dimension)
            ]))

    def forward(self, x):
        """Run ``x`` through every stage (Transformer -> ResBlock -> ERAM) in order."""
        for tblock, rblock, eram in self.layers:
            x = tblock(x)
            x = rblock(x)
            x = eram(x)
        return x


class SRModel(nn.Module):
    """
        Super-resolution network for single-channel inputs.

        Data flow:
        1. ``head`` lifts the 1-channel input to ``n_feats`` feature maps
        2. ``body`` (BaseBlock) processes those features
        3. ``fuse`` merges the head and body features
        4. ``upsapling`` enlarges the maps with pixel shuffle
        5. ``tail`` reconstructs a 1-channel residual
        6. The residual is added to a bicubic upsampling of the input and
           squashed to (0, 1) with ``(tanh(.) + 1) / 2``

        Parameters:
        -----------
        n_feats : int, optional
            Number of feature channels (default: 40)
        n_heads : int, optional
            Number of attention heads (default: 8)
        ratios : list, optional
            Expansion ratios for BaseBlock stages (default: [4, 2, 2, 2, 4])
        upscaling_factor : int, optional
            Super-resolution scale factor (default: 2)

        Attributes:
        -----------
        head : nn.Conv2d
            Initial feature extraction
        body : BaseBlock
            Main feature processing block
        fuse : nn.Conv2d
            Feature fusion layer
        upsapling : nn.Sequential
            Pixel shuffle-based upsampling network (attribute name kept as is
            so existing checkpoints keep loading)
        tail : nn.Conv2d
            Final reconstruction layer
        act : nn.LeakyReLU
            Activation function
    """
    def __init__(self, n_feats=40, n_heads=8, ratios=[4, 2, 2, 2, 4], upscaling_factor=2):
        super().__init__()
        self.scale = upscaling_factor
        self.head = nn.Conv2d(1, n_feats, 3, 1, 1)

        self.body = BaseBlock(n_feats, num_heads=n_heads, ratios=ratios)

        self.fuse = nn.Conv2d(n_feats * 2, n_feats, 3, 1, 1)

        if self.scale == 4:
            # x4 is done as two consecutive x2 pixel-shuffle steps.
            stages = [
                nn.Conv2d(n_feats, n_feats * 4, 1, 1, 0),
                nn.PixelShuffle(2),
                nn.Conv2d(n_feats, n_feats * 4, 1, 1, 0),
                nn.PixelShuffle(2),
            ]
        else:
            # Any other factor: a single conv + pixel shuffle by that factor.
            stages = [
                nn.Conv2d(n_feats, n_feats * self.scale * self.scale, 1, 1, 0),
                nn.PixelShuffle(self.scale),
            ]
        self.upsapling = nn.Sequential(*stages)

        self.tail = nn.Conv2d(n_feats, 1, 3, 1, 1)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return the super-resolved image for the low-resolution batch ``x``."""
        shallow = self.head(x)
        deep = self.body(shallow)
        fused = self.fuse(torch.cat([shallow, deep], dim=1))
        upscaled = self.upsapling(fused)
        residual = self.tail(self.act(upscaled))
        # Bicubic upsampling of the input acts as a global skip connection.
        base = F.interpolate(x, scale_factor=self.scale, mode='bicubic', align_corners=False)
        return (torch.tanh(residual + base) + 1.0) / 2.0

# Smoke test: build the model with default settings and push one random
# single-channel 128x128 sample through it to check the output shape.
network = SRModel()
inp = torch.rand([1,1,128,128])
out = network(inp)
print(out.shape)

In [None]:
# Move the model to the selected device (GPU if available, otherwise CPU).
network = network.to(device)

In [None]:
class Blocks(nn.Module):
    """
        Basic convolutional block used in the discriminator network.

        Layer stack:
        - 3x3 convolution with padding 1 and a fixed stride of 2
        - Batch normalization
        - LeakyReLU(0.2)

        Parameters:
        -----------
        in_channels : int
            Number of input channels
        out_channels : int
            Number of output channels
        stride : int
            Accepted but currently NOT used: the convolution always runs
            with stride 2, whatever value is passed here.
            NOTE(review): confirm whether honouring this argument is
            intended before changing it; doing so alters every feature-map
            size downstream.
    """
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        downsample = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        normalize_features = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(downsample, normalize_features, activation)

    def forward(self, x):
        """Apply convolution, batch normalization and activation to ``x``."""
        out = self.conv(x)
        return out

class Discriminator(nn.Module):
    """
        Discriminator network for adversarial training in DEM super-resolution.

        Architecture:
        - `first_layer`: 3x3 stride-2 convolution + LeakyReLU
        - `Block1` .. `Block8`: `Blocks` units widening the channels from
          `features` to `features*8`
        - `Block9`: 3x3 stride-2 convolution down to `features*4` channels + LeakyReLU
        - `final_layer`: linear layer to a single value followed by a sigmoid

        The output is therefore a probability in (0, 1), not a logit.

        NOTE: `Blocks` always convolves with stride 2, so the `stride=1`
        arguments below have no effect and all ten convolutions halve the
        spatial size. The linear layer expects exactly `features*4` values per
        sample, i.e. `Block9` must produce a 1x1 map.

        Parameters:
        -----------
        in_channels : int
            Number of input channels
        features : int
            Base number of feature channels

        Attributes:
        -----------
        first_layer : nn.Sequential
            Initial feature extraction
        Block1-Block9 : Blocks/nn.Sequential
            Main processing blocks with different channel configurations
        final_layer : nn.Sequential
            Classification head with sigmoid activation
    """
    def __init__(self, in_channels, features):
        super().__init__()
        ch2, ch4, ch8 = features * 2, features * 4, features * 8

        self.first_layer = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )

        # Attribute names and creation order are kept so state_dict keys stay stable.
        self.Block1 = Blocks(features, ch2, stride=2)
        self.Block2 = Blocks(ch2, ch2, stride=1)
        self.Block3 = Blocks(ch2, ch4, stride=2)
        self.Block4 = Blocks(ch4, ch4, stride=1)
        self.Block5 = Blocks(ch4, ch8, stride=2)
        self.Block6 = Blocks(ch8, ch8, stride=1)
        self.Block7 = Blocks(ch8, ch8, stride=2)
        self.Block8 = Blocks(ch8, ch8, stride=2)
        self.Block9 = nn.Sequential(
            nn.Conv2d(ch8, ch4, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )

        self.final_layer = nn.Sequential(
            nn.Linear(ch4, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one probability per sample, shape (batch, 1)."""
        stages = (
            self.first_layer,
            self.Block1,
            self.Block2,
            self.Block3,
            self.Block4,
            self.Block5,
            self.Block6,
            self.Block7,
            self.Block8,
            self.Block9,
        )
        for stage in stages:
            x = stage(x)
        # Flatten everything but the batch axis before the linear head.
        flat = x.view(x.size(0), -1)
        return self.final_layer(flat)

In [None]:
"""
    Initializes and configures the super-resolution generator and discriminator models.

    Setup includes:
    1. Creating generator (SRModel) instance
    2. Loading pre-trained generator weights
    3. Creating discriminator instance
    4. Moving both models to specified device (CPU/GPU)

    Variables:
    ----------
    generator : SRModel
        Super-resolution generator model
    discriminator : Discriminator
        Discriminator model for adversarial training
    device : torch.device
        Computation device (CPU/GPU)
"""
pre_trained_weight_path = 'model.pt'
generator = SRModel()
generator = generator.to(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved on a GPU can still be loaded on a CPU-only machine (and vice versa).
generator.load_state_dict(torch.load(pre_trained_weight_path, map_location=device))
# Arguments are (in_channels, features): 1 input channel and a base width of
# 128 feature channels. The second argument is NOT the input image size.
discriminator = Discriminator(1, 128)
discriminator = discriminator.to(device)

# Training 

In [None]:
# Start the Weights & Biases run that the wandb.log(...) calls in the
# training and validation functions report to.
wandb.init(project="data_Dem_pretrained_wt_mutlpleDownsampling", name="EXP-1")

In [None]:
"""
    Sets up multiple loss functions and optimizers for training.

    Loss Components:
    - L1 Loss: For pixel-wise accuracy
    - Edge Loss: For gradient preservation
    - SSIM: For structural similarity
    - Additional L1: For supplementary feature matching

    Optimizers:
    - Generator: Adam optimizer with learning rate 1e-5
    - Discriminator: Adam optimizer with learning rate 1e-5
"""

# Loss / metric modules, each placed on the active device.
l1Loss = nn.L1Loss().to(device)            # pixel-wise reconstruction term
edgeLoss = gradientAwareLoss().to(device)  # gradient (edge) term
ssim = SSIM(data_range=1.0).to(device)     # structural similarity on [0, 1] data
anotherl1Loss = nn.L1Loss().to(device)     # second, independent L1 instance

# Generator and discriminator each get their own Adam optimizer at lr = 1e-5.
optim_G = torch.optim.Adam(generator.parameters(), lr=1e-5)
optim_D = torch.optim.Adam(discriminator.parameters(), lr=1e-5)

In [None]:
# Sobel edge filter on the notebook-wide `device` (not a hard-coded 'cuda'),
# so the cell also runs on CPU-only machines.
sobelFilter = Sobel().to(device)

In [None]:
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
# SSIM metric for data in [0, 1], placed on the notebook-wide `device` instead
# of a hard-coded 'cuda' so the cell also runs on CPU-only machines.
ssim = SSIM(data_range=1.0).to(device)

In [None]:

# Number of batches in each loader, stored as floats.
# NOTE(review): train_one_epoch / valid_one_epoch call len(...) on the loaders
# themselves, so these two names look unused in this part of the notebook -
# confirm nothing else reads them before removing.
num_train_batches = float(len(trainloader))
num_val_batches = float(len(testloader))

def train_one_epoch(epoch):
    """
        Performs one epoch of adversarial training for the super-resolution model.

        For every batch:
        1. Generator step (discriminator weights frozen) on the combined loss
               L1 + 100 * (1 - SSIM) + 50 * edge + 50 * adversarial
           where the L1 and edge terms are computed on images scaled by 1000.
        2. Discriminator step on real HR images (target 1) and freshly
           generated images (target 0).

        The discriminator ends in a sigmoid, so its outputs are probabilities
        and the adversarial terms use `F.binary_cross_entropy`. Using the
        `_with_logits` variant would apply a second sigmoid.

        Parameters:
        -----------
        epoch : int
            Current epoch number (used for the progress message only)

        Returns:
        --------
        float
            Average PSNR value for the epoch

        Logging:
        --------
        - L1 Loss
        - Edge Loss
        - SSIM Loss
        - Total Loss
        - Discriminator Loss
        - SSIM
        - PSNR
    """

    print(f"Epoch {epoch}: ", end ="")

    l1_loss_per_epoch = 0.0
    edge_loss_per_epoch = 0.0
    ssim_loss_per_epoch = 0.0
    ssim_per_epoch = 0.0
    psnr_per_epoch = 0.0
    total_loss_per_epoch = 0.0
    D_adv_loss = 0.0

    generator.train()
    for hr, lr in tqdm(trainloader):
        # Cast both tensors once so every loss and both networks see the same dtype.
        lr_images = lr.to(device).float()
        hr_images = hr.to(device).float()

        # ---- generator step: freeze the discriminator so only G receives gradients
        for p in discriminator.parameters():
            p.requires_grad = False
        optim_G.zero_grad()

        predicted_hr_images = generator(lr_images)
        predicted_hr_labels = discriminator(predicted_hr_images)
        # adversarial loss: the generator wants its output to be classified as real
        gf_loss = F.binary_cross_entropy(predicted_hr_labels, torch.ones_like(predicted_hr_labels))

        # reconstruction loss
        l1_loss_per_sample = l1Loss(hr_images*1000, predicted_hr_images*1000)
        ssim_per_sample = ssim(hr_images, predicted_hr_images)
        ssim_loss_per_sample = 1 - ssim_per_sample
        edge_loss = edgeLoss(hr_images*1000, predicted_hr_images*1000)
        reconstruction_loss = l1_loss_per_sample + 100*(ssim_loss_per_sample) + 50*edge_loss
        t_loss = reconstruction_loss + 50*gf_loss

        t_loss.backward()
        optim_G.step()

        psnr_per_sample = calculate_psnr(hr_images.detach().cpu().numpy(), predicted_hr_images.detach().cpu().numpy())

        l1_loss_per_epoch += l1_loss_per_sample.item()
        edge_loss_per_epoch += edge_loss.item()
        ssim_loss_per_epoch += ssim_loss_per_sample.item()
        ssim_per_epoch += ssim_per_sample.item()
        psnr_per_epoch += psnr_per_sample
        total_loss_per_epoch += t_loss.item()

        # ---- discriminator step
        for p in discriminator.parameters():
            p.requires_grad = True
        optim_D.zero_grad()
        # Regenerate with the just-updated generator; no_grad keeps the generator
        # out of the graph without building and then discarding one.
        with torch.no_grad():
            predicted_hr_images = generator(lr_images)
        adv_hr_real = discriminator(hr_images)
        adv_hr_fake = discriminator(predicted_hr_images)
        df_loss = (
            F.binary_cross_entropy(adv_hr_real, torch.ones_like(adv_hr_real))
            + F.binary_cross_entropy(adv_hr_fake, torch.zeros_like(adv_hr_fake))
        )
        D_adv_loss += df_loss.item()
        df_loss.backward()
        optim_D.step()

    num_batches = float(len(trainloader))
    l1_loss_per_epoch /= num_batches
    edge_loss_per_epoch /= num_batches
    ssim_loss_per_epoch /= num_batches
    ssim_per_epoch /= num_batches
    psnr_per_epoch /= num_batches
    total_loss_per_epoch /= num_batches
    D_adv_loss /= num_batches

    # One log call keeps all epoch metrics on the same wandb step.
    wandb.log({
        "Train L1 Loss": l1_loss_per_epoch,
        "Train Edge Loss": edge_loss_per_epoch,
        "Train SSIM Loss": ssim_loss_per_epoch,
        "Train Total Loss": total_loss_per_epoch,
        "Train Discriminator Loss": D_adv_loss,
        "Train SSIM": ssim_per_epoch,
        "Train PSNR": psnr_per_epoch,
    })

    print(f"(Train) L1 Loss: {l1_loss_per_epoch:.3f} | SSIM Loss: {ssim_loss_per_epoch:.3f} | Edge Loss: {edge_loss_per_epoch:.3f} | Total Loss: {total_loss_per_epoch:.3f}")
    print(f"SSIM: {ssim_per_epoch:.3f} | PSNR: {psnr_per_epoch}")

    torch.cuda.empty_cache()
    gc.collect()

    return psnr_per_epoch




In [None]:

def valid_one_epoch(epoch):
    """
        Performs one epoch of validation for the super-resolution model.

        The validation process includes:
        1. Forward pass through the generator (eval mode, gradients disabled)
        2. Calculation of SSIM and PSNR metrics per batch
        3. Visualization of the first four samples of every batch
           (LR, HR, and SR images)
        4. Logging of the averaged metrics

        Parameters:
        -----------
        epoch : int
            Current epoch number (not used inside the function; kept so the
            call signature matches train_one_epoch)

        Returns:
        --------
        float
            Average PSNR value for the validation set

        Logging:
        --------
        - SSIM
        - PSNR
        - Sample image visualizations (LR, HR, SR)
    """
    ssim_per_epoch = 0.0
    psnr_per_epoch = 0.0

    generator.eval()
    with torch.no_grad():
        for hr, lr in tqdm(testloader):
            # Same float cast as in training, so the generator and the SSIM
            # metric receive matching dtypes.
            batched_hr = hr.to(device).float()
            batched_lr = lr.to(device).float()
            predicted_sr = generator(batched_lr)

            # .item() keeps the running sum a plain Python float instead of a device tensor.
            ssim_per_epoch += ssim(batched_hr, predicted_sr).item()
            psnr_per_epoch += calculate_psnr(batched_hr.cpu().numpy(), predicted_sr.cpu().numpy())

            # One log call so the three image grids share a wandb step.
            wandb.log({
                "Original LR": wandb.Image(make_grid(batched_lr[:4]), caption="Low Resolution DEM"),
                "Original HR": wandb.Image(make_grid(batched_hr[:4]), caption="High Resolution DEM"),
                "Reconstruced": wandb.Image(make_grid(predicted_sr[:4]), caption="Reconstructed High Resolution DEM"),
            })

        num_batches = float(len(testloader))
        ssim_per_epoch /= num_batches
        psnr_per_epoch /= num_batches

        wandb.log({
            "Test Predicted SSIM": ssim_per_epoch,
            "Test Predicted PSNR": psnr_per_epoch,
        })

        print(f"(Val) SSIM: {ssim_per_epoch:.3f} | PSNR: {psnr_per_epoch:.3f}")

        torch.cuda.empty_cache()
        gc.collect()

        return psnr_per_epoch

In [None]:

# Training Loop Configuration
"""
    Main training loop for the super-resolution model with model saving logic.

    Components:
    - Training for specified number of epochs
    - Validation after each epoch
    - Model checkpoint saving based on best PSNR
    - Memory management with CUDA cache clearing
"""


# Run configuration and bookkeeping for checkpoint selection.
expNumber = 5
num_epochs = 1
best_psnr = 0
prev_psnr = 0

# Checkpoints for this experiment go under savedModels/Exp<expNumber>/.
os.makedirs(f"savedModels/Exp{expNumber}", exist_ok=True)

for i in range(num_epochs):
    # Release cached GPU memory and collect garbage before each epoch.
    torch.cuda.empty_cache()
    gc.collect()

    train_psnr = train_one_epoch(i)
    valid_psnr = valid_one_epoch(i)

    # Keep a checkpoint only when validation PSNR beats the best seen so far.
    improved = valid_psnr > best_psnr
    if improved:
        best_psnr = valid_psnr
        torch.save(generator.state_dict(), f"savedModels/Exp{expNumber}/SRmodel_{best_psnr}.pt")
        print("Model saved!")