Install dependencies

In [None]:
# Install third-party packages; pytorch_msssim and pytorch-wavelets are imported by later cells.
# NOTE(review): timm, h5py, cv2 (opencv) and skimage (scikit-image) are imported further down
# but are not installed here - presumably the runtime already provides them; confirm.
!pip install pytorch_msssim ripser scikit-tda pytorch-wavelets
# gdown is the downloader invoked by the next cell.
!pip install -U --no-cache-dir gdown --pre

download data

In [None]:
# Download the two HDF5 archives by file ID and store them under fixed local names.
# The extraction cell below reads 'SD_Astro_combined.h5'.
!gdown '14yLiI7R8ghl0BlAGMSgMEuafA9t9Whxa' -O 'SD_Astro_combined.h5'
!gdown '1x0B4vog2sqJ4P-kJOKpYupuR0hwkbTMs' -O 'SD_Astro4.h5'

Extract the images stored in the HDF5 files into a folder. Run the extraction once per file: first the SD combined dataset (34,000 images), then the Astro4 dataset (+11,000 images).

In [None]:
import os
import h5py
from PIL import Image
import numpy as np
from multiprocessing import Pool, cpu_count

def save_image(image_data, idx, output_dir):
    """Write one image array to ``output_dir`` as ``image_<idx, 5 digits>.png``.

    Args:
        image_data: Array-like with an ``astype`` method; it is cast to
            ``uint8`` and interpreted as RGB.
        idx: Integer used to build the zero-padded file name.
        output_dir: Directory the PNG is written to.
    """
    pixels = image_data.astype('uint8')
    picture = Image.fromarray(pixels, 'RGB')
    destination = os.path.join(output_dir, f'image_{idx:05d}.png')
    picture.save(destination)

def extract_images_chunk(h5_file, output_dir, start_idx, end_idx, initial_idx):
    """Save the images ``[start_idx, end_idx)`` of an HDF5 file as PNG files.

    Args:
        h5_file: Path of the HDF5 file; it must contain an ``images`` dataset.
        output_dir: Directory the PNG files are written to.
        start_idx: First dataset row of this chunk (inclusive).
        end_idx: Last dataset row of this chunk (exclusive).
        initial_idx: Offset added to every dataset row to obtain the output
            file number.
    """
    first_output_idx = initial_idx + start_idx
    with h5py.File(h5_file, 'r') as f:
        # Materialise the whole chunk in memory before writing it out.
        chunk = np.array(f['images'][start_idx:end_idx])
        for offset, image in enumerate(chunk):
            output_idx = first_output_idx + offset
            save_image(image, output_idx, output_dir)
            print(f"Saved image {output_idx}")

def extract_images_parallel(h5_file, output_dir, chunk_size=1000):
    """Extract every image of an HDF5 file to PNG files using a process pool.

    New files are numbered after the highest ``image_<number>.png`` already
    present in ``output_dir``, so calling this for several HDF5 files appends
    to the same folder.

    Args:
        h5_file: Path of the HDF5 file; it must contain an ``images`` dataset.
        output_dir: Target directory (created when missing).
        chunk_size: Number of images handled by one worker task.

    Raises:
        Exception: Any exception raised inside a worker task is re-raised here
            once all tasks have finished, instead of being silently dropped.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Only files following the image_<digits>.png pattern define the numbering;
    # unrelated PNG files in the folder no longer crash the index parsing.
    png_files = [name for name in os.listdir(output_dir) if name.endswith('.png')]
    existing_indices = []
    for name in png_files:
        prefix, _, number = name[:-len('.png')].partition('_')
        if prefix == 'image' and number.isdigit():
            existing_indices.append(int(number))
    initial_idx = max(existing_indices) + 1 if existing_indices else 0

    total_images_before = len(png_files)
    print(f"Total images before adding new images: {total_images_before}")

    with h5py.File(h5_file, 'r') as f:
        num_images = f['images'].shape[0]

    chunks = [(start_idx, min(start_idx + chunk_size, num_images))
              for start_idx in range(0, num_images, chunk_size)]

    with Pool(cpu_count()) as pool:
        pending = [
            pool.apply_async(extract_images_chunk, (h5_file, output_dir, start_idx, end_idx, initial_idx))
            for start_idx, end_idx in chunks
        ]
        pool.close()
        pool.join()
        # apply_async hides worker exceptions until the result is fetched.
        for job in pending:
            job.get()

    # Count the total images after adding the new images
    total_images_after = len([name for name in os.listdir(output_dir) if name.endswith('.png')])
    print(f"Total images after adding new images: {total_images_after}")

# Path to the HDF5 file and output directory
# Only SD_Astro_combined.h5 is extracted by this cell. extract_images_parallel numbers
# new files after the highest index already in output_dir, so pointing sd_astro_path at
# another file (e.g. the downloaded SD_Astro4.h5) and re-running appends to the folder.
sd_astro_path = "SD_Astro_combined.h5"
output_dir = "Astro_images"

# Extract images using multiprocessing
extract_images_parallel(sd_astro_path, output_dir)
print(f"Images extracted to {output_dir}")


Create dataset and dataloader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import os

class ImageFolderDataset(Dataset):
    """Dataset yielding (optionally cropped and transformed) RGB images from a folder.

    Args:
        image_dir: Directory whose entries are treated as image files.
        transform: Optional callable applied to the image after cropping.
        random_crop_size: Size passed to ``transforms.RandomCrop``; pass
            ``None`` (or 0) to disable cropping.
    """

    def __init__(self, image_dir, transform=None, random_crop_size=256):
        self.image_dir = image_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.transform = transform
        # Only build the crop when a size is given, so cropping can be switched off.
        self.random_crop = transforms.RandomCrop(random_crop_size) if random_crop_size else None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the image at ``idx``, falling back to the next readable file.

        Raises:
            IndexError: If ``idx`` is out of range (keeps plain iteration finite).
            RuntimeError: If no file in the folder can be loaded.
        """
        num_files = len(self.image_files)
        if not -num_files <= idx < num_files:
            raise IndexError(f"index {idx} is out of range for {num_files} images")

        image = None
        # Try each file at most once; the original recursion never terminated
        # when every file in the folder was unreadable.
        for offset in range(num_files):
            img_path = os.path.join(self.image_dir, self.image_files[(idx + offset) % num_files])
            try:
                # The context manager closes the file handle once the pixels are converted.
                with Image.open(img_path) as handle:
                    image = handle.convert('RGB')
                break
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
        if image is None:
            raise RuntimeError(f"No readable image found in {self.image_dir}")

        if self.random_crop is not None:
            image = self.random_crop(image)

        if self.transform:
            image = self.transform(image)

        return image


# Define transformations: PIL image -> tensor
transform = transforms.Compose([
    transforms.ToTensor(),
])


# Create dataset
image_dir = "Astro_images"
image_dataset = ImageFolderDataset(image_dir, transform=transform)

# Split the dataset into training and validation sets (95% / 5%).
# The split is seeded so that the same images end up in the validation set after a
# kernel restart; an unseeded split would mix validation images into training when
# a run is resumed.
train_size = int(0.95 * len(image_dataset))
val_size = len(image_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(image_dataset, [train_size, val_size], generator=split_generator)

# Create data loaders
batch_size = 6  # Adjusted for efficiency
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=7, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=True)

initialize networks

In [None]:
import math
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block
import pytorch_wavelets

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with an additive skip connection.

    When ``stride != 1`` or the channel count changes, the skip path is a
    strided 1x1 convolution followed by batch-norm; otherwise it is an empty
    ``nn.Sequential`` that passes the input through.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)
        main += self.shortcut(x)
        return self.relu(main)

class ResidualDenseBlock(nn.Module):
    """Densely connected 3x3 convolutions with a scaled residual output.

    Layer ``i`` receives the input concatenated with the outputs of all
    previous layers. Every layer emits ``growth_rate`` channels except the
    last one, which emits ``in_channels`` so it can be added to the input.
    """

    def __init__(self, in_channels, growth_rate, num_layers=4):
        super().__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        last = num_layers - 1
        for i in range(num_layers):
            layer_in = in_channels + i * growth_rate
            layer_out = in_channels if i >= last else growth_rate
            self.layers.append(nn.Conv2d(layer_in, layer_out, kernel_size=3, padding=1))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x + 0.2 * (output of the last dense layer)``."""
        feature_bank = [x]
        for conv in self.layers:
            stacked = torch.cat(feature_bank, dim=1)
            out = self.relu(conv(stacked))
            feature_bank.append(out)
        return out * 0.2 + x

class AttentionBlock(nn.Module):
    """Self-attention over all spatial positions with a learnable residual gate.

    Queries and keys use ``in_channels // 8`` channels, values keep
    ``in_channels``. ``gamma`` starts at zero, so a freshly built block
    returns its input unchanged.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        batch, channels, dim_a, dim_b = x.size()
        positions = dim_a * dim_b

        queries = self.query_conv(x).view(batch, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(batch, -1, positions)
        # (batch, positions, positions); each row is normalised over the key axis.
        weights = F.softmax(torch.bmm(queries, keys), dim=-1)

        values = self.value_conv(x).view(batch, -1, positions)
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, dim_a, dim_b)
        return self.gamma * attended + x

class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.channel_attention = ChannelAttentionBlock(in_channels, reduction)
        self.spatial_attention = SpatialAttentionBlock()

    def forward(self, x):
        """Apply the channel gate first, then the spatial gate."""
        return self.spatial_attention(self.channel_attention(x))

class ChannelAttentionBlock(nn.Module):
    """Per-channel gate computed from globally average-pooled features.

    The pooled vector goes through a 1x1-conv bottleneck
    (``in_channels // reduction`` channels) and a sigmoid; the result scales
    each input channel.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        squeezed = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, in_channels, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        gate = self.fc(self.avg_pool(x))
        return x * gate

class SpatialAttentionBlock(nn.Module):
    """Per-pixel gate computed from the channel-wise mean and max maps."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled position-wise by the learned gate."""
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True)[0]
        descriptor = torch.cat((mean_map, max_map), dim=1)
        gate = self.sigmoid(self.conv(descriptor))
        return x * gate

class ResidualInResidualDenseBlock(nn.Module):
    """Chain of ``ResidualDenseBlock`` modules wrapped in a scaled skip connection."""

    def __init__(self, in_channels, growth_rate, num_blocks=3, num_layers=4):
        super().__init__()
        dense_blocks = [
            ResidualDenseBlock(in_channels, growth_rate, num_layers)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(dense_blocks)

    def forward(self, x):
        """Return ``x + 0.2 * blocks(x)``."""
        refined = x
        for dense_block in self.blocks:
            refined = dense_block(refined)
        return refined * 0.2 + x

class DynamicConv2d(nn.Module):
    """2-D convolution holding its own weight and bias parameters.

    The parameters are initialised in ``reset_parameters`` and applied with
    ``F.conv2d`` using the stored stride and padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform weights; bias uniform in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in = nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Convolve ``x`` with the stored weight and bias."""
        return F.conv2d(x, self.weight, self.bias,
                        stride=self.stride, padding=self.padding)

class WaveletTransform(nn.Module):
    """Single-level Haar wavelet decomposition followed by its inverse.

    The coefficients produced by ``DWTForward`` are handed to ``DWTInverse``
    without any processing in between.
    """

    def __init__(self):
        super().__init__()
        self.dwt = pytorch_wavelets.DWTForward(J=1, wave='haar', mode='zero')
        self.iwt = pytorch_wavelets.DWTInverse(wave='haar', mode='zero')

    def forward(self, x):
        """Decompose ``x`` and return the reconstruction."""
        lowpass, highpass = self.dwt(x)
        return self.iwt((lowpass, highpass))

class NonLocalBlock(nn.Module):
    """Non-local (all-pairs) attention block with a residual connection.

    The output projection ``W`` is initialised to zero, so a freshly built
    block returns its input unchanged.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.inter_channels = in_channels // 2

        self.g = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.theta = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, self.inter_channels, kernel_size=1)
        self.W = nn.Conv2d(self.inter_channels, in_channels, kernel_size=1)

        # Zero-initialise the output projection (weight and bias).
        for tensor in (self.W.weight, self.W.bias):
            nn.init.constant_(tensor, 0)

    def forward(self, x):
        """Return ``x + W(softmax(theta(x)^T phi(x)) g(x))``."""
        batch_size, _, h, w = x.size()
        inter = self.inter_channels

        # (batch, positions, inter)
        g_x = self.g(x).view(batch_size, inter, -1).transpose(1, 2)
        theta_x = self.theta(x).view(batch_size, inter, -1).transpose(1, 2)
        # (batch, inter, positions)
        phi_x = self.phi(x).view(batch_size, inter, -1)

        affinity = F.softmax(torch.matmul(theta_x, phi_x), dim=-1)

        aggregated = torch.matmul(affinity, g_x)
        aggregated = aggregated.transpose(1, 2).contiguous().view(batch_size, inter, h, w)
        return self.W(aggregated) + x

class DenseBlock(nn.Module):
    """Dense block whose output keeps the input and every layer's features.

    The output has ``in_channels + num_layers * growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for layer_idx in range(num_layers):
            layer_in = in_channels + layer_idx * growth_rate
            self.layers.append(self._make_layer(layer_in, growth_rate))

    def _make_layer(self, in_channels, growth_rate):
        """Build one conv3x3 -> batch-norm -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(growth_rate),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the channel-wise concatenation of ``x`` and all layer outputs."""
        collected = [x]
        for stage in self.layers:
            collected.append(stage(torch.cat(collected, dim=1)))
        return torch.cat(collected, dim=1)

class PixelShuffleBlock(nn.Module):
    """Upsample by ``upscale_factor`` with conv3x3 -> pixel shuffle -> ReLU."""

    def __init__(self, in_channels, out_channels, upscale_factor):
        super().__init__()
        expanded = out_channels * (upscale_factor ** 2)
        self.conv = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the upsampled, rectified feature map."""
        return self.relu(self.pixel_shuffle(self.conv(x)))

class MultiScaleFeatureExtractor(nn.Module):
    """Sum of 3x3, 5x5 and 7x7 conv branches followed by a non-local block.

    Every branch is conv -> batch-norm -> ReLU and preserves the channel
    count and spatial size of its input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.bn3 = nn.BatchNorm2d(in_channels)

        self.non_local = NonLocalBlock(in_channels)

    def forward(self, x):
        """Return the non-local refinement of the summed branch outputs."""
        branches = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        small, medium, large = [self.relu(norm(conv(x))) for conv, norm in branches]
        return self.non_local(small + medium + large)

class SwinTransformerBlock(nn.Module):
    """1x1 projection followed by ``depths`` timm transformer ``Block`` layers.

    The feature map is flattened so that every spatial position becomes one
    token. ``window_size`` is stored on the module but is not used by the
    layers built here.
    """

    def __init__(self, embed_dim, depths, num_heads, window_size=7):
        super().__init__()
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_heads = num_heads
        self.window_size = window_size

        self.proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        layers = [Block(embed_dim, num_heads, mlp_ratio=4., qkv_bias=True) for _ in range(depths)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        batch_size, channels, height, width = x.shape
        # (batch, channels, H, W) -> (batch, H*W, channels)
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        for layer in self.blocks:
            tokens = layer(tokens)
        # Back to the image layout.
        return tokens.transpose(1, 2).reshape(batch_size, channels, height, width)

class ImprovedDenoisingNetwork(nn.Module):
    """Encoder / attention / decoder network for image denoising.

    The encoder's two stride-2 residual blocks reduce the resolution, and the
    decoder's two x2 pixel-shuffle stages bring it back up. In between, the
    features pass through multi-scale, transformer, residual-dense, CBAM,
    dynamic-conv and wavelet stages at 128 channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial_conv = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)

        # DenseBlock(64, 32, 4) outputs 64 + 4 * 32 = 192 channels.
        self.encoder = nn.Sequential(
            DenseBlock(64, 32, num_layers=4),
            ResidualBlock(192, 128, stride=2),
            ResidualBlock(128, 128, stride=2),
        )

        self.multi_scale = MultiScaleFeatureExtractor(128)
        self.transformer = SwinTransformerBlock(embed_dim=128, depths=2, num_heads=4)

        self.rir_block1 = ResidualInResidualDenseBlock(128, 32, num_blocks=2, num_layers=3)
        self.attention1 = CBAMBlock(128)

        self.rir_block2 = ResidualInResidualDenseBlock(128, 32, num_blocks=2, num_layers=3)
        self.attention2 = CBAMBlock(128)

        self.dynamic_conv = DynamicConv2d(128, 128, kernel_size=3, padding=1)
        self.wavelet_transform = WaveletTransform()

        self.decoder = nn.Sequential(
            PixelShuffleBlock(128, 64, upscale_factor=2),
            PixelShuffleBlock(64, 64, upscale_factor=2),
            nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Run the full pipeline and return the decoded image tensor."""
        features = self.encoder(self.initial_conv(x))
        features = self.transformer(self.multi_scale(features))

        features = self.attention1(self.rir_block1(features))
        features = self.attention2(self.rir_block2(features))

        features = self.wavelet_transform(self.dynamic_conv(features))
        return self.decoder(features)

Noise augmentations and loss functions

In [None]:

import torchvision.transforms.functional as TF
from torchvision import models
from torch.fft import fft2, ifft2
from pytorch_msssim import ms_ssim
import numpy as np
import cv2
from skimage.feature import canny
from skimage.color import rgb2lab, deltaE_ciede2000, rgb2yuv

# Function to scale Gaussian noise to emulate different sensor resolutions
def scale_gaussian_noise(noise_tensor, scale_factor):
    """Resample each noise plane to ``scale_factor`` times its size and back.

    Every (sample, channel) plane is resized with OpenCV bilinear
    interpolation to ``(int(width * scale_factor), int(height * scale_factor))``
    and then resized back to the original size.

    Args:
        noise_tensor: Tensor indexed as (batch, channel, height, width).
        scale_factor: Multiplier applied to width and height for the
            intermediate size.

    Returns:
        Tensor with the same shape as ``noise_tensor`` on the same device.
    """
    batch_size, channels, height, width = noise_tensor.shape
    intermediate_size = (int(width * scale_factor), int(height * scale_factor))
    original_size = (width, height)

    def _resample(plane):
        plane_np = plane.cpu().numpy()
        plane_np = cv2.resize(plane_np, intermediate_size, interpolation=cv2.INTER_LINEAR)
        plane_np = cv2.resize(plane_np, original_size, interpolation=cv2.INTER_LINEAR)
        return torch.tensor(plane_np).to(noise_tensor.device)

    samples = []
    for i in range(batch_size):
        planes = [_resample(noise_tensor[i, c]) for c in range(channels)]
        samples.append(torch.stack(planes, 0))
    return torch.stack(samples, 0)

# Define the function to apply rotational Gaussian noise with scaling
def apply_rotational_gaussian_noise(image_tensor, std=20, mean=0, num_frames=40):
    """Add the average of ``num_frames`` slightly rotated copies of one noise field.

    A single Gaussian noise field (in 0-255 units) is drawn, resampled with
    ``scale_gaussian_noise``, rotated by a small random angle around a random
    centre for every frame, and the frame average is added to the images.

    The noise is accumulated frame by frame instead of keeping ``num_frames``
    full copies of the image batch; the result equals the mean of the stacked
    noisy frames up to floating-point rounding.

    Args:
        image_tensor: Image batch indexed as (batch, channel, height, width),
            expected in [0, 1] (the output is clamped to that range).
        std: Base noise standard deviation in 0-255 units; scaled by a random
            factor in [0.5, 2.5].
        mean: Noise mean in 0-255 units.
        num_frames: Number of rotated frames to average (at least 1).

    Returns:
        Noisy image batch clamped to [0, 1], same shape as ``image_tensor``.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")

    batch_size, channels, height, width = image_tensor.shape

    # Randomly determine the rotation centre once (x first, then y).
    center_x = np.random.randint(0, width)
    center_y = np.random.randint(0, height)
    center = (center_x, center_y)

    # Randomly vary the standard deviation at the start
    std_variation = np.random.uniform(0.5, 2.5) * std

    noise_tensor = torch.randn_like(image_tensor) * std_variation + mean

    # Randomly select a scale factor between 0.5 and 4
    scale_factor = np.random.uniform(0.5, 4)
    scaled_noise_tensor = scale_gaussian_noise(noise_tensor, scale_factor)

    # The noise does not change between frames: convert it to HWC numpy once.
    noise_planes = [
        np.ascontiguousarray(scaled_noise_tensor[i].permute(1, 2, 0).cpu().numpy())
        for i in range(batch_size)
    ]
    noise_sums = [np.zeros_like(plane) for plane in noise_planes]

    rotation_factor = np.random.uniform(1, 2)
    for _ in range(num_frames):
        angle = np.random.uniform(0, rotation_factor)
        rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
        for total, plane in zip(noise_sums, noise_planes):
            rotated = cv2.warpAffine(plane, rotation_matrix, (width, height))
            # warpAffine drops the channel axis for single-channel input.
            total += rotated.reshape(height, width, channels)

    mean_noise = torch.from_numpy(np.stack(noise_sums)).permute(0, 3, 1, 2)
    mean_noise = mean_noise.to(image_tensor.device) / num_frames

    # Noise is in 0-255 units; images are in [0, 1].
    return torch.clamp(image_tensor + mean_noise / 255.0, 0, 1)

# Define the function to add combined noise
def add_combined_noise(image, mean=0, std=0.05, noise_type=None, sp_prob=0.002):
    """Corrupt an image batch with one of several synthetic noise models.

    Args:
        image: Image batch, expected in [0, 1].
        mean: Mean of the Gaussian noise (image units).
        std: Base standard deviation of the Gaussian noise (image units).
        noise_type: ``'rotational'``, ``'poisson'``, ``'salt_and_pepper'``
            (Gaussian noise plus impulses), or anything else for plain
            resampled Gaussian noise.
        sp_prob: Total probability of a salt or pepper impulse per element.

    Returns:
        Noisy image batch clamped to [0, 1].
    """
    # Drawn for every call (scale factor between 0.5 and 3.5), even when unused.
    scale_factor = np.random.uniform(0.5, 3.5)

    if noise_type == 'rotational':
        frames = int(np.random.uniform(40, 80))
        noisy_image = apply_rotational_gaussian_noise(image, std=std*255, mean=mean*255, num_frames=frames)
        return torch.clamp(noisy_image, 0., 1.)

    if noise_type == 'poisson':
        multi = np.random.uniform(14, 25)
        # The Poisson rate must be non-negative.
        lambda_ = torch.clamp(image * multi, min=0)
        poisson_noise = torch.poisson(lambda_) / multi - image
        return torch.clamp(image + poisson_noise, 0., 1.)

    # Gaussian noise with a randomly varied standard deviation.
    std_variation = np.random.uniform(0.55, 1.75) * std
    gaussian_noise = torch.randn(image.size(), device=image.device) * std_variation + mean
    gaussian_noise = scale_gaussian_noise(gaussian_noise, scale_factor)

    # Keep the intermediate result non-negative.
    noisy_image = torch.clamp(image + gaussian_noise, 0., None)

    if noise_type == 'salt_and_pepper':
        sp_noise = torch.rand(image.size(), device=image.device)
        ones = torch.ones_like(image)
        salt_pepper_noise = torch.where(sp_noise < sp_prob / 2, ones, torch.zeros_like(image))
        salt_pepper_noise = torch.where(sp_noise > 1 - sp_prob / 2, -ones, salt_pepper_noise)

        # The impulses go through the same resampling as the Gaussian noise.
        salt_pepper_noise = scale_gaussian_noise(salt_pepper_noise, scale_factor)
        noisy_image += salt_pepper_noise

    return torch.clamp(noisy_image, 0., 1.)

from skimage.color import rgb2lab, rgb2yuv

class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two images.

    Features are compared at the outputs of the modules with indices
    ``self.layers`` in ``vgg19().features``.

    NOTE(review): inputs are fed to VGG as-is, without ImageNet mean/std
    normalisation - confirm this is intended.
    """

    def __init__(self):
        super(PerceptualLoss, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        # Indices into vgg19.features: conv1_2, conv2_2, conv3_2, conv4_2, conv5_2
        # (each tapped before its ReLU).
        self.layers = [2, 7, 12, 21, 30]
        self.model = nn.Sequential(*[vgg[i] for i in range(max(self.layers) + 1)])
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, denoised, target):
        """Return the sum of L1 feature distances over all tapped layers.

        Both inputs go through the network once and features are compared as
        they appear, rather than re-running the network prefix for every tap.
        """
        tapped = set(self.layers)
        loss = 0
        denoised_features, target_features = denoised, target
        for index, layer in enumerate(self.model):
            if isinstance(layer, nn.ReLU):
                # Out-of-place ReLU: an in-place one would overwrite feature maps
                # that the L1 terms above still need for their backward pass.
                denoised_features = F.relu(denoised_features)
                target_features = F.relu(target_features)
            else:
                denoised_features = layer(denoised_features)
                target_features = layer(target_features)
            if index in tapped:
                loss = loss + F.l1_loss(denoised_features, target_features)
        return loss

class MSSSIMLoss(nn.Module):
    """Loss defined as ``1 - MS-SSIM`` for inputs with a data range of 1.0."""

    def forward(self, denoised, target):
        """Return ``1 - ms_ssim(denoised, target)`` averaged over the batch."""
        similarity = ms_ssim(denoised, target, data_range=1.0, size_average=True)
        return 1 - similarity

class FrequencyDomainLoss(nn.Module):
    """L1 distance between the 2-D FFT magnitude spectra of two images."""

    def forward(self, denoised, target):
        """Return the mean absolute difference of the FFT magnitudes."""
        denoised_magnitude = torch.fft.fft2(denoised).abs()
        target_magnitude = torch.fft.fft2(target).abs()
        return F.l1_loss(denoised_magnitude, target_magnitude)

class EnhancedEdgeLoss(nn.Module):
    """L1 distance between Sobel edge maps at full and half resolution.

    The Sobel filters are depthwise 3x3 convolutions over 3-channel images.
    Their kernels are fixed: they are excluded from gradient computation so
    they cannot be changed by an optimiser and no gradients are accumulated
    for them.
    """

    def __init__(self):
        super(EnhancedEdgeLoss, self).__init__()
        self.sobel_x = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False, groups=3)
        self.sobel_y = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False, groups=3)
        sobel_kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self._set_fixed_kernel(self.sobel_x, sobel_kernel_x)
        self._set_fixed_kernel(self.sobel_y, sobel_kernel_y)

    @staticmethod
    def _set_fixed_kernel(conv, kernel):
        """Copy one 3x3 kernel into every channel of ``conv`` and freeze it."""
        with torch.no_grad():
            conv.weight.copy_(kernel.view(1, 1, 3, 3).repeat(3, 1, 1, 1))
        conv.weight.requires_grad_(False)

    def _edge_l1(self, denoised, target):
        """Sum of the L1 distances of the horizontal and vertical edge maps."""
        loss_x = F.l1_loss(self.sobel_x(denoised), self.sobel_x(target))
        loss_y = F.l1_loss(self.sobel_y(denoised), self.sobel_y(target))
        return loss_x + loss_y

    def forward(self, denoised, target):
        """Return the full-resolution edge loss plus 0.5 x the half-resolution one."""
        full_scale = self._edge_l1(denoised, target)

        # Multi-scale edge loss: downsample each input once and reuse it.
        denoised_half = F.interpolate(denoised, scale_factor=0.5, mode='bilinear')
        target_half = F.interpolate(target, scale_factor=0.5, mode='bilinear')
        half_scale = self._edge_l1(denoised_half, target_half)

        return full_scale + 0.5 * half_scale

def rgb_to_lab_tensor(image):
    """Convert an RGB batch to Lab with ``rgb2lab`` and rescale the channels.

    The input is detached and moved to the CPU, so the result carries no
    gradient. L is divided by 100; a and b are shifted by 128 and divided
    by 255.

    Args:
        image: Tensor indexed as (batch, channel, height, width).

    Returns:
        CPU tensor in (batch, channel, height, width) layout.
    """
    channels_last = image.permute(0, 2, 3, 1).detach().cpu().numpy()
    lab = rgb2lab(channels_last)
    # L nominally spans [0, 100]; a and b nominally span [-128, 127].
    lab[..., 0] = lab[..., 0] / 100.0
    lab[..., 1:] = (lab[..., 1:] + 128) / 255.0
    return torch.from_numpy(lab).permute(0, 3, 1, 2)

def rgb_to_yuv_tensor(image):
    """Convert an RGB batch to YUV with ``rgb2yuv`` and map all channels to [0, 1].

    The input is detached and moved to the CPU, so the result carries no
    gradient. For float RGB in [0, 1], ``rgb2yuv`` already returns Y in
    [0, 1], so Y is left untouched (dividing by 255 shrank it to ~[0, 0.004]).
    U lies in [-0.436, 0.436] and V in [-0.615, 0.615]; both are rescaled to
    [0, 1] using those bounds.

    Args:
        image: Float tensor indexed as (batch, channel, height, width) with
            RGB values in [0, 1].

    Returns:
        CPU tensor in (batch, channel, height, width) layout.
    """
    image = image.permute(0, 2, 3, 1).detach().cpu().numpy()
    yuv = rgb2yuv(image)
    u_max, v_max = 0.436, 0.615  # extreme |U| and |V| for RGB in [0, 1]
    yuv[:, :, :, 1] = yuv[:, :, :, 1] / (2 * u_max) + 0.5
    yuv[:, :, :, 2] = yuv[:, :, :, 2] / (2 * v_max) + 0.5
    return torch.from_numpy(yuv).permute(0, 3, 1, 2)

class EnhancedColorLuminanceLoss(nn.Module):
    """Colour, luminance, colourfulness and dynamic-range discrepancy.

    Terms (weights 1 / 0.85 / 0.5 / 0.5):
      * mean CIEDE2000 colour difference, computed with scikit-image on
        detached CPU copies. It is NOT differentiable and contributes to the
        reported value only. It is evaluated on true CIELAB values, so it is
        expressed in delta-E units.
      * L1 distance of the luma channel, computed in torch so gradients reach
        ``denoised``.
      * L1 distance of a colourfulness statistic.
      * L1 distance of the global max-min range.
    """

    def __init__(self):
        super(EnhancedColorLuminanceLoss, self).__init__()

    def colorfulness_metric(self, image):
        """Colourfulness statistic from the red-green and yellow-blue opponents."""
        rg = image[:, 0, :, :] - image[:, 1, :, :]
        yb = 0.5 * (image[:, 0, :, :] + image[:, 1, :, :]) - image[:, 2, :, :]
        rg_std, rg_mean = torch.std_mean(rg)
        yb_std, yb_mean = torch.std_mean(yb)
        return torch.sqrt(rg_std ** 2 + yb_std ** 2) + 0.3 * torch.sqrt(rg_mean ** 2 + yb_mean ** 2)

    def dynamic_range_loss(self, denoised, target):
        """Absolute difference between the global (max - min) of both tensors."""
        return F.l1_loss(torch.max(denoised) - torch.min(denoised), torch.max(target) - torch.min(target))

    @staticmethod
    def _luminance(image):
        """Differentiable luma: the Y row (0.299, 0.587, 0.114) of the RGB->YUV matrix."""
        return 0.299 * image[:, 0, :, :] + 0.587 * image[:, 1, :, :] + 0.114 * image[:, 2, :, :]

    @staticmethod
    def _to_cielab(image):
        """Detached CPU copy of ``image`` as a channels-last CIELAB numpy array."""
        # Clamp to [0, 1], the range the sRGB -> Lab conversion is defined on.
        rgb = image.detach().clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy()
        return rgb2lab(rgb)

    def forward(self, denoised, target):
        """Return the weighted sum of the four terms described on the class."""
        # Colour difference on real CIELAB values (deltaE_ciede2000 expects them unnormalised).
        delta_e = deltaE_ciede2000(self._to_cielab(denoised), self._to_cielab(target))
        color_loss = torch.as_tensor(delta_e, dtype=denoised.dtype, device=denoised.device).mean()

        # Luminance loss using the Y channel, kept inside the autograd graph.
        luminance_loss = F.l1_loss(self._luminance(denoised), self._luminance(target))

        # Colorfulness metric loss
        colorfulness_l = F.l1_loss(self.colorfulness_metric(denoised), self.colorfulness_metric(target))

        # Dynamic range loss
        dynamic_range_l = self.dynamic_range_loss(denoised, target)

        return color_loss + 0.85 * luminance_loss + 0.5 * colorfulness_l + 0.5 * dynamic_range_l

class CharbonnierLoss(nn.Module):
    """Charbonnier loss: mean of ``sqrt(diff^2 + epsilon^2)``."""

    def __init__(self, epsilon=1e-3):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, denoised, target):
        """Return the mean Charbonnier penalty over all elements."""
        residual = denoised - target
        penalty = torch.sqrt(residual ** 2 + self.epsilon ** 2)
        return torch.mean(penalty)

class TVLoss(nn.Module):
    """Weighted squared total variation, divided by the batch size."""

    def __init__(self, weight=1e-6):
        super().__init__()
        self.weight = weight

    def forward(self, denoised):
        """Return ``weight * 2 * (vertical + horizontal squared differences) / batch``."""
        batch_size = denoised.size(0)
        vertical_diff = denoised[:, :, 1:, :] - denoised[:, :, :-1, :]
        horizontal_diff = denoised[:, :, :, 1:] - denoised[:, :, :, :-1]
        h_tv = vertical_diff.pow(2).sum()
        w_tv = horizontal_diff.pow(2).sum()
        return self.weight * 2 * (h_tv + w_tv) / batch_size

class CustomLoss(nn.Module):
    """Weighted combination of all the loss terms defined in this cell."""

    def __init__(self):
        super().__init__()
        self.charbonnier_loss = CharbonnierLoss()
        self.perceptual_loss = PerceptualLoss()
        self.msssim_loss = MSSSIMLoss()
        self.frequency_domain_loss = FrequencyDomainLoss()
        self.edge_loss = EnhancedEdgeLoss()
        self.color_luminance_loss = EnhancedColorLuminanceLoss()
        self.tv_loss = TVLoss()

    def forward(self, denoised, target, noisy_input):
        """Return the weighted total loss.

        ``noisy_input`` is accepted but not used by any term.
        """
        charbonnier = self.charbonnier_loss(denoised, target)
        # Perceptual and frequency terms are divided by 100 before weighting.
        perceptual = self.perceptual_loss(denoised, target) / 100
        msssim = self.msssim_loss(denoised, target)
        frequency = self.frequency_domain_loss(denoised, target) / 100
        edge = self.edge_loss(denoised, target)
        color_luminance = self.color_luminance_loss(denoised, target)
        tv = self.tv_loss(denoised)

        total = charbonnier
        total = total + 0.35 * perceptual
        total = total + 0.25 * msssim
        total = total + 0.1 * frequency
        total = total + 0.1 * edge
        total = total + 0.3 * color_luminance
        total = total + 0.05 * tv
        return total

In [None]:
# Importing required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from torch.fft import fft2, ifft2
from pytorch_msssim import ssim
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# Function to scale Gaussian noise to emulate different sensor resolutions
def scale_gaussian_noise(noise_tensor, scale_factor):
    """Resample each noise plane to ``scale_factor`` times its size and back.

    This cell defines the function again; the body matches the definition in
    the previous cell.

    Args:
        noise_tensor: Tensor indexed as (batch, channel, height, width).
        scale_factor: Multiplier applied to width and height for the
            intermediate size.

    Returns:
        Tensor with the same shape as ``noise_tensor`` on the same device.
    """
    batch_size, channels, height, width = noise_tensor.shape
    device = noise_tensor.device
    scaled_size = (int(width * scale_factor), int(height * scale_factor))

    resampled_batch = []
    for sample_idx in range(batch_size):
        resampled_planes = []
        for channel_idx in range(channels):
            plane = noise_tensor[sample_idx, channel_idx].cpu().numpy()
            # Resize to the emulated resolution, then back to the original one.
            plane = cv2.resize(plane, scaled_size, interpolation=cv2.INTER_LINEAR)
            plane = cv2.resize(plane, (width, height), interpolation=cv2.INTER_LINEAR)
            resampled_planes.append(torch.tensor(plane).to(device))
        resampled_batch.append(torch.stack(resampled_planes, 0))

    return torch.stack(resampled_batch, 0)

# Define the function to apply rotational Gaussian noise with scaling
def apply_rotational_gaussian_noise(image_tensor, std=40, mean=0, num_frames=50):
    """Add the average of ``num_frames`` slightly rotated copies of one noise field.

    This cell defines the function again with different defaults and random
    ranges than the previous cell; once this cell has run, this definition is
    the one in effect.

    The noise is accumulated frame by frame instead of keeping ``num_frames``
    full copies of the image batch; the result equals the mean of the stacked
    noisy frames up to floating-point rounding.

    Args:
        image_tensor: Image batch indexed as (batch, channel, height, width),
            expected in [0, 1] (the output is clamped to that range).
        std: Base noise standard deviation in 0-255 units; scaled by a random
            factor in [0.8, 2.0].
        mean: Noise mean in 0-255 units.
        num_frames: Number of rotated frames to average (at least 1).

    Returns:
        Noisy image batch clamped to [0, 1], same shape as ``image_tensor``.

    Raises:
        ValueError: If ``num_frames`` is smaller than 1.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be at least 1")

    batch_size, channels, height, width = image_tensor.shape

    # Randomly determine the rotation centre once (x first, then y).
    center_x = np.random.randint(0, width)
    center_y = np.random.randint(0, height)
    center = (center_x, center_y)

    # Randomly vary the standard deviation at the start
    std_variation = np.random.uniform(0.8, 2.0) * std

    noise_tensor = torch.randn_like(image_tensor) * std_variation + mean

    # Randomly select a scale factor between 0.75 and 2.0
    scale_factor = np.random.uniform(0.75, 2.0)
    scaled_noise_tensor = scale_gaussian_noise(noise_tensor, scale_factor)

    # The noise does not change between frames: convert it to HWC numpy once.
    noise_planes = [
        np.ascontiguousarray(scaled_noise_tensor[i].permute(1, 2, 0).cpu().numpy())
        for i in range(batch_size)
    ]
    noise_sums = [np.zeros_like(plane) for plane in noise_planes]

    rotation_factor = np.random.uniform(1, 3)
    for _ in range(num_frames):
        angle = np.random.uniform(0, rotation_factor)
        rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
        for total, plane in zip(noise_sums, noise_planes):
            rotated = cv2.warpAffine(plane, rotation_matrix, (width, height))
            # warpAffine drops the channel axis for single-channel input.
            total += rotated.reshape(height, width, channels)

    mean_noise = torch.from_numpy(np.stack(noise_sums)).permute(0, 3, 1, 2)
    mean_noise = mean_noise.to(image_tensor.device) / num_frames

    # Noise is in 0-255 units; images are in [0, 1].
    return torch.clamp(image_tensor + mean_noise / 255.0, 0, 1)

# Define the function to add combined noise
def add_combined_noise(image, mean=0, std=0.075, noise_type=None, sp_prob=0.002):
    """Corrupt ``image`` with one of several noise models and clamp to [0, 1].

    Args:
        image: Image tensor; the final clamp to [0, 1] suggests values are
            expected in that range -- confirm against callers.
        mean: Mean of the Gaussian noise (multiplied by 255 for the
            rotational variant).
        std: Base standard deviation of the Gaussian noise (multiplied by 255
            for the rotational variant).
        noise_type: ``'rotational'``, ``'poisson'``, ``'salt_and_pepper'``
            (Gaussian noise plus impulses), or anything else for plain
            Gaussian noise.
        sp_prob: Total probability of an impulse per element; half of it goes
            to +1 impulses and half to -1 impulses.

    Returns:
        The noisy tensor, clamped to [0, 1].
    """
    # Drawn for every noise type (range 0.75-2.5), even the ones that never
    # use it, so the NumPy random stream is consumed the same way each call.
    scale_factor = np.random.uniform(0.75, 2.5)

    if noise_type == 'rotational':
        corrupted = apply_rotational_gaussian_noise(image, std=std*255, mean=mean*255)
        return torch.clamp(corrupted, 0., 1.)

    if noise_type == 'poisson':
        # The multiplier sets the simulated photon count and hence intensity
        multi = np.random.uniform(15, 25)
        lambda_ = image * multi
        shot_noise = torch.poisson(lambda_) / multi - image
        corrupted = image + shot_noise
        return torch.clamp(corrupted, 0., 1.)

    # Gaussian noise is the base layer for every remaining noise type
    std_variation = np.random.uniform(0.75, 1.75) * std
    gaussian_noise = torch.randn(image.size(), device=image.device) * std_variation + mean
    gaussian_noise = scale_gaussian_noise(gaussian_noise, scale_factor)

    # Keep the intermediate result non-negative before any impulses are added
    corrupted = torch.clamp(image + gaussian_noise, 0., None)

    if noise_type == 'salt_and_pepper':
        draws = torch.rand(image.size(), device=image.device)
        half_prob = sp_prob / 2
        # Lowest draws get +1 impulses, highest draws get -1 impulses
        impulses = torch.where(draws < half_prob, torch.ones_like(image), torch.zeros_like(image))
        impulses = torch.where(draws > 1 - half_prob, -torch.ones_like(image), impulses)
        impulses = scale_gaussian_noise(impulses, scale_factor)
        corrupted += impulses

    return torch.clamp(corrupted, 0., 1.)
# Function to visualize images
def visualize_images(original, noisy_images, titles, figsize=(20, 10)):
    """Show the original image next to each noisy variant in a single row.

    Args:
        original: Channels-first image tensor; it is permuted to
            channels-last before being handed to ``imshow``.
        noisy_images: Sequence of tensors with the same layout as
            ``original``. May be empty, in which case only the original is
            shown.
        titles: Panel titles; ``titles[i]`` labels ``noisy_images[i]``.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    def _to_hwc(tensor):
        # detach() lets tensors that still track gradients be plotted too
        return tensor.detach().permute(1, 2, 0).cpu().numpy()

    num_images = len(noisy_images) + 1
    # squeeze=False keeps `axes` two-dimensional even for a single panel, so
    # indexing below also works when `noisy_images` is empty.
    fig, axes = plt.subplots(1, num_images, figsize=figsize, squeeze=False)
    axes = axes[0]

    # Original image
    axes[0].imshow(_to_hwc(original))
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Noisy images
    for i, noisy_image in enumerate(noisy_images):
        axes[i+1].imshow(_to_hwc(noisy_image))
        axes[i+1].set_title(titles[i])
        axes[i+1].axis('off')

    plt.show()

# Load an example image as RGB and turn it into a batched tensor on the GPU
# when one is available (CPU otherwise)
image_path = '/notebooks/astro9.jpg'  # Update this path to your image file
image = Image.open(image_path).convert('RGB')
transform = transforms.Compose([transforms.ToTensor()])
_demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_tensor = transform(image).unsqueeze(0).to(_demo_device)

# Generate one example per supported noise type
noisy_image_gaussian = add_combined_noise(image_tensor, noise_type=None)
noisy_image_rotational = add_combined_noise(image_tensor, noise_type='rotational')
noisy_image_poisson = add_combined_noise(image_tensor, noise_type='poisson')
noisy_image_SP = add_combined_noise(image_tensor, noise_type='salt_and_pepper')

# Pair each result with its panel title, then plot everything next to the
# original (the batch dimension is dropped for plotting)
_demo_panels = [
    ("Gaussian Noise", noisy_image_gaussian),
    ("Rotational Gaussian Noise", noisy_image_rotational),
    ("poisson", noisy_image_poisson),
    ("SP", noisy_image_SP),
]
visualize_images(
    image_tensor.squeeze(0),
    [panel.squeeze(0) for _, panel in _demo_panels],
    [label for label, _ in _demo_panels]
)


Training: model setup, checkpoint loading, and the training loop

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
import os
import random
from accelerate import Accelerator

# Initialize the accelerator (from the `accelerate` package). It is used
# further down in this cell for device placement (`accelerator.prepare`),
# the backward pass, autocast, and checkpoint saving.
accelerator = Accelerator()

# Disable LaTeX rendering in matplotlib
plt.rcParams['text.usetex'] = False


# Random augmentations applied to each input batch in the training loop
# before noise is added: horizontal/vertical flips plus a mild colour jitter.
# NOTE(review): the loop passes the whole batch tensor through this transform
# at once, so each random decision is presumably shared by all images in the
# batch -- confirm this is intended.
augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.05, contrast=0.15, saturation=0.15, hue=0.05),
])

# Function to load checkpoints
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model and optimizer state from ``checkpoint_path`` if it exists.

    The checkpoint is expected to be a dict with the keys
    ``'model_state_dict'``, ``'optimizer_state_dict'``, ``'epoch'`` and
    ``'step'``, as written by ``save_model``. ``save_model`` is called after
    the optimizer update, so the stored ``'step'`` is the index of the last
    step that was *completed*. The training loop skips every step with
    ``step < start_step``, so the value returned here is the stored step plus
    one; otherwise the already-completed step would be trained on twice.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        optimizer: Optimizer whose ``load_state_dict`` receives the saved state.
        checkpoint_path: Path of the checkpoint file.

    Returns:
        Tuple ``(model, optimizer, start_epoch, start_step)``. ``start_step``
        is the first step index of ``start_epoch`` that still has to run.
        Both are 0 when no checkpoint file exists.
    """
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found. Starting training from scratch.")
        return model, optimizer, 0, 0

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    last_completed_step = checkpoint['step']
    print(f"Loaded checkpoint from {checkpoint_path} (Epoch: {start_epoch}, Step: {last_completed_step})")

    # Resume with the step after the one that was already completed
    return model, optimizer, start_epoch, last_completed_step + 1

# Initialize the model, optimizer, LR scheduler, and loss function
model = ImprovedDenoisingNetwork(in_channels=3, out_channels=3)
optimizer = optim.AdamW(model.parameters(), lr=6e-5, weight_decay=0.1)
# Halve the learning rate when the epoch loss has not improved for 2 epochs
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)
hybrid_loss = CustomLoss()

# Prepare everything with the accelerator
model, optimizer, hybrid_loss, train_loader, val_loader = accelerator.prepare(model, optimizer, hybrid_loss, train_loader, val_loader)


# Directory that new checkpoints are written to (relative to the working directory)
model_save_dir = "saved_models"
# Checkpoint to resume from. This used to be wrapped in
# os.path.join(model_save_dir, ...), but os.path.join discards every earlier
# component once it meets an absolute one, so the join had no effect. The
# absolute path is now spelled out directly; the resulting value is unchanged.
checkpoint_path = "/notebooks/saved_models/latest_checkpoint_12_step_2500.pth"

# Create the save directory if it doesn't exist
os.makedirs(model_save_dir, exist_ok=True)

# Load model and optimizer state if checkpoint exists
model, optimizer, start_epoch, start_step = load_checkpoint(model, optimizer, checkpoint_path)



num_epochs = 150
print_every = 250  # Print progress every 250 steps
save_every = 1500   # Save a progress image (original / noisy / denoised) every 1500 steps

def save_model(epoch, step, model, optimizer, model_save_dir):
    """Write a training checkpoint into ``model_save_dir``.

    The file name uses 1-based epoch and step numbers, while the values stored
    inside the checkpoint are the 0-based ``epoch`` and ``step`` passed in.

    Args:
        epoch: Current (0-based) epoch index.
        step: Current (0-based) step index within the epoch.
        model: Model whose weights are collected via the accelerator.
        optimizer: Optimizer whose state dict is stored alongside the weights.
        model_save_dir: Directory the checkpoint file is written to.
    """
    checkpoint_name = f"latest_checkpoint_{epoch+1}_step_{step+1}.pth"
    model_path = os.path.join(model_save_dir, checkpoint_name)

    checkpoint_state = {
        'epoch': epoch,
        'step': step,
        'model_state_dict': accelerator.get_state_dict(model),
        'optimizer_state_dict': optimizer.state_dict()
    }
    accelerator.save(checkpoint_state, model_path)

    print(f"Model saved at {model_path}")

def save_progress_images(model, images, noisy_images, epoch, step, save_dir="denoising_progress"):
    """Save a side-by-side figure of the original, noisy and denoised image.

    Only the first sample of the batch is used. The model is switched to eval
    mode for the forward pass and is put back into whatever mode it was in
    before the call, even if the forward pass or the plotting fails.

    Args:
        model: Denoising network; called on a batch containing one image.
        images: Batch of clean images; ``images[0]`` is plotted.
        noisy_images: Batch of noisy images; ``noisy_images[0]`` is denoised
            and plotted.
        epoch: Current (0-based) epoch index; shown and stored 1-based.
        step: Current (0-based) step index; shown and stored 1-based.
        save_dir: Directory the PNG is written to; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            original = images[0].cpu()
            noisy = noisy_images[0].unsqueeze(0)
            denoised = model(noisy).squeeze(0)
            # The raw network output is not guaranteed to lie in [0, 1];
            # clamp it before the image conversion, as the inference cell of
            # this notebook does, so out-of-range values cannot corrupt the
            # saved picture.
            denoised = torch.clamp(denoised, 0, 1).cpu()
            noisy = noisy.squeeze(0).cpu()

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        try:
            axes[0].imshow(TF.to_pil_image(original))
            axes[0].set_title("Original")
            axes[1].imshow(TF.to_pil_image(noisy))
            axes[1].set_title("Noisy")
            axes[2].imshow(TF.to_pil_image(denoised))
            axes[2].set_title(f"Denoised - Epoch {epoch+1}, Step {step+1}")
            fig.savefig(os.path.join(save_dir, f"epoch_{epoch+1}_step_{step+1}.png"))
        finally:
            # Always release the figure so repeated calls do not pile up figures
            plt.close(fig)
    finally:
        # Restore the mode the model was in before this call
        model.train(was_training)

# Training loop
# Noise models sampled uniformly at random for every batch (None = plain Gaussian)
noise_types = [None, 'rotational', 'salt_and_pepper', 'poisson']
checkpoint_every = 2500  # Save a model checkpoint every 2500 steps

for epoch in range(start_epoch, num_epochs):
    model.train()
    running_loss = 0.0
    # Number of samples that actually contributed to running_loss. This can be
    # smaller than the dataset size when steps are skipped after a resume or
    # when the epoch is cut short by a NaN loss.
    samples_seen = 0

    for step, inputs in enumerate(train_loader):
        if epoch == start_epoch and step < start_step:
            continue  # Skip the steps already covered in the previous checkpoint

        with accelerator.accumulate(model):
            inputs = inputs.to(accelerator.device)
            augmented_inputs = augmentation_transform(inputs)
            selected_noise = random.choice(noise_types)
            noisy_inputs = add_combined_noise(augmented_inputs, noise_type=selected_noise)
            with accelerator.autocast():
                outputs = model(noisy_inputs)
                loss = hybrid_loss(outputs, augmented_inputs, noisy_inputs)

            # Abandon the rest of this epoch on a NaN loss; no backward pass or
            # optimizer update happens for the offending batch.
            if torch.isnan(loss):
                print(f"NaN loss detected at Epoch {epoch+1}, Step {step+1}")
                break

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            running_loss += loss.item() * inputs.size(0)
            samples_seen += inputs.size(0)

            if (step + 1) % print_every == 0:
                print(f"Epoch {epoch+1}, Step {step+1}, Loss: {loss.item():.4f}")

            if (step + 1) % save_every == 0:
                save_progress_images(model, augmented_inputs, noisy_inputs, epoch, step)

            if (step + 1) % checkpoint_every == 0:
                save_model(epoch, step, model, optimizer, model_save_dir)
                print(f"Model saved at epoch {epoch+1}, step {step+1}")

    if samples_seen == 0:
        # Nothing was trained this epoch, so there is no meaningful loss; a
        # fake 0.0 would permanently distort the scheduler's "best" value.
        print(f"Epoch {epoch+1}/{num_epochs}: no batches were processed; skipping scheduler step")
        continue

    # Average over the samples actually processed rather than the full dataset
    # size, so partial epochs do not report an artificially low loss.
    epoch_loss = running_loss / samples_seen
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    scheduler.step(epoch_loss)

Test the trained model: denoise an image tile by tile

In [None]:
import torch
from torchvision import transforms
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter

# Release unused cached GPU memory and choose the inference device
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the network and restore the trained weights from a checkpoint
# (point the path below at your own checkpoint file)
model = ImprovedDenoisingNetwork(in_channels=3, out_channels=3)
checkpoint = torch.load('/notebooks/saved_models/latest_checkpoint_8_step_2500.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
model.eval()


# Function to load an image
def load_image(image_path):
    """Open the file at ``image_path`` and return it converted to RGB."""
    return Image.open(image_path).convert('RGB')

# Function to split image into overlapping tiles
def image_to_tiles(image, tile_size, overlap):
    """Cut ``image`` into tiles whose origins are ``tile_size - overlap`` apart.

    Tiles that would extend past the right or bottom edge are cropped to the
    image bounds, so edge tiles can be smaller than ``tile_size``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and
            ``crop((left, upper, right, lower))``, e.g. a PIL image.
        tile_size: Maximum side length of a tile, in pixels; must be positive.
        overlap: Number of pixels shared by neighbouring tiles; must satisfy
            ``0 <= overlap < tile_size``.

    Returns:
        ``(tiles, positions)``: the cropped tiles in row-major order and, for
        each tile, a tuple ``(top, left, width, height)``.

    Raises:
        ValueError: If ``tile_size`` or ``overlap`` are out of range. Without
            this check a negative stride silently produced no tiles at all
            and a negative overlap left uncovered gaps between tiles.
    """
    if tile_size <= 0:
        raise ValueError(f"tile_size must be positive, got {tile_size}")
    if overlap < 0 or overlap >= tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < tile_size, got overlap={overlap}, tile_size={tile_size}"
        )

    w, h = image.size
    step = tile_size - overlap
    tiles = []
    positions = []
    for i in range(0, h, step):
        bottom = min(i + tile_size, h)
        for j in range(0, w, step):
            right = min(j + tile_size, w)
            tiles.append(image.crop((j, i, right, bottom)))
            positions.append((i, j, right - j, bottom - i))
    return tiles, positions

# Create an alpha mask for blending
def create_alpha_mask(tile_size, overlap):
    """Build a square blending-weight mask that fades out towards its edges.

    The mask is 1 in the centre and ramps linearly over ``overlap`` pixels on
    each of the four sides. The ramp deliberately excludes the end points 0
    and 1: ``tiles_to_image`` divides the weighted sum by the accumulated
    weights, and pixels on the top row and left column of the image are only
    covered by the outermost row/column of a tile. A weight of exactly 0 there
    leaves those pixels black after normalisation. The rising ramp of one tile
    and the falling ramp of its neighbour still sum to 1 across the overlap.

    Args:
        tile_size: Side length of the mask in pixels.
        overlap: Width of the ramp on each side, ``0 <= overlap <= tile_size``.
            With ``overlap == 0`` the mask is all ones.

    Returns:
        ``float32`` array of shape ``(tile_size, tile_size)`` with strictly
        positive weights.

    Raises:
        ValueError: If ``overlap`` is negative or larger than ``tile_size``.
    """
    if overlap < 0 or overlap > tile_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap <= tile_size, got overlap={overlap}, tile_size={tile_size}"
        )

    mask = np.ones((tile_size, tile_size), dtype=np.float32)
    if overlap == 0:
        # No blending region; also avoids the `mask[-0:]` whole-array slice
        return mask

    # overlap + 2 samples with both end points dropped: 1/(n+1) ... n/(n+1)
    ramp = np.linspace(0, 1, overlap + 2, dtype=np.float32)[1:-1]
    mask[:overlap, :] *= ramp[:, None]
    mask[-overlap:, :] *= ramp[::-1, None]
    mask[:, :overlap] *= ramp[None, :]
    mask[:, -overlap:] *= ramp[None, ::-1]
    return mask

# Function to merge tiles back to image with alpha blending
def tiles_to_image(tiles, positions, image_size, tile_size, overlap):
    """Blend processed tiles back into one image using weighted averaging.

    Every tile is multiplied by a blending mask from ``create_alpha_mask`` and
    accumulated on a canvas; the canvas is then divided by the accumulated
    weights so that overlapping regions become a weighted average.

    Args:
        tiles: Processed tiles, each convertible to an ``(H, W, 3)`` array.
        positions: One ``(top, left, width, height)`` tuple per tile, as
            returned by ``image_to_tiles``.
        image_size: ``(width, height)`` of the full image.
        tile_size: Side length used to build the blending mask.
        overlap: Ramp width used to build the blending mask.

    Returns:
        The reassembled image as an 8-bit PIL image.
    """
    canvas_w, canvas_h = image_size[0], image_size[1]
    weighted_sum = np.zeros((canvas_h, canvas_w, 3), dtype=np.float32)
    weight_total = np.zeros((canvas_h, canvas_w, 3), dtype=np.float32)
    blend_mask = create_alpha_mask(tile_size, overlap)

    for idx, (top, left, width, height) in enumerate(positions):
        # Trim the tile to the extent recorded when the image was split
        pixels = np.array(tiles[idx])[:height, :width]
        rows, cols, _ = pixels.shape

        # Cut the mask down to this tile and add a channel axis for broadcasting
        weights = blend_mask[:rows, :cols, np.newaxis]

        region = (slice(top, top + height), slice(left, left + width))
        weighted_sum[region] += pixels * weights
        weight_total[region] += weights

    # Normalise by the accumulated weights; the floor avoids division by zero
    blended = weighted_sum / np.maximum(weight_total, 1e-8)
    blended = np.clip(blended, 0, 255).astype(np.uint8)
    return Image.fromarray(blended)

# Transformations: PIL image -> tensor
transform = transforms.Compose([transforms.ToTensor()])

# Load the image and cut it into overlapping tiles
image_path = 'astro9.jpg'  # Replace with your image path
original_image = load_image(image_path)
tile_size = 256  # Define tile size to match your model's expected input size
overlap = 32     # Define overlap size
tiles, positions = image_to_tiles(original_image, tile_size, overlap)

# Run every tile through the model without tracking gradients
to_pil = transforms.ToPILImage()
processed_tiles = []
with torch.no_grad():
    for tile in tiles:
        input_tensor = transform(tile).unsqueeze(0).to(device)
        # Clamp so the output is in the valid [0, 1] range before conversion
        processed_tile = torch.clamp(model(input_tensor).squeeze(0), 0, 1)
        processed_tiles.append(to_pil(processed_tile.cpu()))

# Reconstruct the image from tiles
reconstructed_image = tiles_to_image(processed_tiles, positions, original_image.size, tile_size, overlap)

# Function to save the denoised image
def save_image(image, path):
    """Save ``image`` to ``path`` and print a confirmation line.

    Note: this definition replaces the ``save_image(image_data, idx,
    output_dir)`` helper defined in the image-extraction cell near the top of
    the notebook; re-running that cell's code after this one would pick up
    this version instead.

    Args:
        image: Object exposing ``save(path)``, e.g. a PIL image.
        path: Destination file path.
    """
    # Write first, then report, so the message is only printed after save() returns
    image.save(path)
    print(f"Image saved at {path}")

# Function to display images
def show_images(original, reconstructed):
    """Display the original and the reconstructed image side by side.

    Args:
        original: Image shown in the left panel.
        reconstructed: Image shown in the right panel.
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        ('Original Image', original),
        ('Reconstructed Image', reconstructed),
    )
    for ax, (title, picture) in zip(axes, panels):
        ax.imshow(np.asarray(picture))
        ax.set_title(title)
        ax.axis('off')
    plt.show()


# Display the input image next to the image reassembled from the processed tiles
show_images(original_image, reconstructed_image)

# Save the reassembled (denoised) image to disk
output_path = 'denoised_astro9.png'  # Replace with your desired output path
save_image(reconstructed_image, output_path)