# **Change runtime type to T4 GPU for faster processing.**

Step 1: Run this cell first to install the dependencies and download the model checkpoint. This may take a few minutes.

In [None]:
# @title
# Install the third-party packages imported by the model-definition cell below.
!pip install timm pytorch_wavelets
# Download the checkpoint (referenced by file id) and save it as
# 'latest_checkpoint.pth', the file name later passed to torch.load.
!gdown --id '1MM8fZ-TFhWJYCwUnL8oI2KFDiNuV_x0K' -O 'latest_checkpoint.pth'

Step 2: Run this cell next to define the model architecture.

In [None]:
# @title
import math

import pytorch_wavelets
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages with an additive skip connection.

    The skip path is an empty ``nn.Sequential`` (identity) unless
    ``stride != 1`` or the channel count changes; in that case a strided 1x1
    convolution followed by batch norm projects the input so it can be added
    to the main path.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return self.relu(main)

class ResidualDenseBlock(nn.Module):
    """Densely connected 3x3 convolutions wrapped in a scaled residual.

    Layer ``i`` receives the block input concatenated (along channels) with
    the outputs of all earlier layers. Every layer except the last emits
    ``growth_rate`` channels; the last one maps back to ``in_channels`` so its
    output can be added to the block input.
    """

    def __init__(self, in_channels, growth_rate, num_layers=4):
        super().__init__()
        self.num_layers = num_layers
        final_idx = num_layers - 1
        convs = []
        for idx in range(num_layers):
            fan_in = in_channels + idx * growth_rate
            fan_out = in_channels if idx >= final_idx else growth_rate
            convs.append(nn.Conv2d(fan_in, fan_out, kernel_size=3, padding=1))
        self.layers = nn.ModuleList(convs)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x + 0.2 * last_layer_output``."""
        collected = [x]
        for conv in self.layers:
            stacked = torch.cat(collected, dim=1)
            out = self.relu(conv(stacked))
            collected.append(out)
        # The last layer's activation is the residual branch, scaled by 0.2.
        return out * 0.2 + x

class AttentionBlock(nn.Module):
    """Self-attention over all spatial positions of a feature map.

    Queries and keys are 1x1-projected to ``in_channels // 8`` channels,
    values keep ``in_channels``. The attended result is scaled by the
    learnable scalar ``gamma`` (initialised to zero) and added to the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        n, c, dim_a, dim_b = x.size()
        num_pos = dim_a * dim_b

        q = self.query_conv(x).view(n, -1, num_pos).permute(0, 2, 1)
        k = self.key_conv(x).view(n, -1, num_pos)
        # (n, num_pos, num_pos) weights, normalised over the last axis.
        weights = F.softmax(torch.bmm(q, k), dim=-1)

        v = self.value_conv(x).view(n, -1, num_pos)
        attended = torch.bmm(v, weights.permute(0, 2, 1)).view(n, c, dim_a, dim_b)
        return self.gamma * attended + x

class CBAMBlock(nn.Module):
    """Apply channel attention, then spatial attention, to a feature map."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.channel_attention = ChannelAttentionBlock(in_channels, reduction)
        self.spatial_attention = SpatialAttentionBlock()

    def forward(self, x):
        """Return ``spatial_attention(channel_attention(x))``."""
        return self.spatial_attention(self.channel_attention(x))

class ChannelAttentionBlock(nn.Module):
    """Per-channel gating computed from globally average-pooled features.

    The pooled descriptor goes through a 1x1-conv bottleneck
    (``in_channels // reduction`` hidden channels), a ReLU, a second 1x1 conv
    and a sigmoid; the resulting per-channel gate multiplies the input.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        hidden = in_channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, in_channels, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        gate = self.fc(self.avg_pool(x))
        return x * gate

class SpatialAttentionBlock(nn.Module):
    """Per-position gating from channel-wise mean and max maps.

    The two single-channel maps are concatenated, passed through a 7x7
    bias-free convolution and a sigmoid, and the result multiplies the input.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled position-wise by the learned gate."""
        mean_map = x.mean(dim=1, keepdim=True)
        peak_map = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((mean_map, peak_map), dim=1)
        gate = self.sigmoid(self.conv(descriptor))
        return x * gate

class ResidualInResidualDenseBlock(nn.Module):
    """A chain of ``ResidualDenseBlock`` modules inside an outer residual.

    The chained output is scaled by 0.2 before being added to the input.
    """

    def __init__(self, in_channels, growth_rate, num_blocks=3, num_layers=4):
        super().__init__()
        dense_blocks = [
            ResidualDenseBlock(in_channels, growth_rate, num_layers)
            for _ in range(num_blocks)
        ]
        self.blocks = nn.ModuleList(dense_blocks)

    def forward(self, x):
        """Run the blocks in sequence and return ``0.2 * result + x``."""
        features = x
        for dense_block in self.blocks:
            features = dense_block(features)
        return features * 0.2 + x

class DynamicConv2d(nn.Module):
    """2-D convolution that owns its weight and bias as raw parameters.

    The forward pass calls ``F.conv2d`` directly with the stored ``stride``
    and ``padding``; parameters are initialised in ``reset_parameters``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        # Allocated uninitialised; reset_parameters() fills both tensors.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size)
        )
        self.bias = nn.Parameter(torch.empty(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init for the weight, fan-in-bounded uniform for the bias."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Convolve ``x`` with the stored weight and bias."""
        return F.conv2d(
            x, self.weight, self.bias,
            stride=self.stride, padding=self.padding,
        )

class WaveletTransform(nn.Module):
    """Single-level Haar wavelet decomposition followed by reconstruction.

    The coefficients produced by the forward transform are handed to the
    inverse transform unchanged; nothing in this class modifies them in
    between.
    """

    def __init__(self):
        super().__init__()
        self.dwt = pytorch_wavelets.DWTForward(J=1, wave='haar', mode='zero')
        self.iwt = pytorch_wavelets.DWTInverse(wave='haar', mode='zero')

    def forward(self, x):
        """Return the inverse transform of the forward transform of ``x``."""
        low_pass, high_pass = self.dwt(x)
        coefficients = (low_pass, high_pass)
        return self.iwt(coefficients)

class NonLocalBlock(nn.Module):
    """Non-local (all-pairs) attention block with a residual connection.

    ``theta``/``phi``/``g`` are 1x1 projections to ``in_channels // 2``
    channels. The output projection ``W`` is zero-initialised, so a freshly
    constructed block returns its input unchanged.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.inter_channels = in_channels // 2
        mid = self.inter_channels

        self.g = nn.Conv2d(in_channels, mid, kernel_size=1)
        self.theta = nn.Conv2d(in_channels, mid, kernel_size=1)
        self.phi = nn.Conv2d(in_channels, mid, kernel_size=1)
        self.W = nn.Conv2d(mid, in_channels, kernel_size=1)

        # Start with a zero residual branch.
        for param in (self.W.weight, self.W.bias):
            nn.init.constant_(param, 0)

    def forward(self, x):
        """Return ``W(attention(x)) + x`` with the same shape as ``x``."""
        n, _, rows, cols = x.size()
        mid = self.inter_channels

        # (n, positions, mid)
        value = self.g(x).view(n, mid, -1).permute(0, 2, 1)
        # (n, positions, mid)
        query = self.theta(x).view(n, mid, -1).permute(0, 2, 1)
        # (n, mid, positions)
        key = self.phi(x).view(n, mid, -1)

        # Pairwise affinities, normalised over the last axis.
        affinity = F.softmax(torch.matmul(query, key), dim=-1)

        aggregated = torch.matmul(affinity, value).permute(0, 2, 1).contiguous()
        aggregated = aggregated.view(n, mid, rows, cols)
        return self.W(aggregated) + x

class DenseBlock(nn.Module):
    """Dense block whose output keeps the input and every layer's features.

    Each layer is conv3x3 -> batch norm -> ReLU producing ``growth_rate``
    channels from the concatenation of everything before it, so the output
    has ``in_channels + num_layers * growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            layer_in = in_channels + idx * growth_rate
            self.layers.append(self._make_layer(layer_in, growth_rate))

    def _make_layer(self, in_channels, growth_rate):
        """Build one conv3x3 -> BN -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(growth_rate),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Return the channel-wise concatenation of ``x`` and all layer outputs."""
        collected = [x]
        for stage in self.layers:
            stage_input = torch.cat(collected, dim=1)
            collected.append(stage(stage_input))
        return torch.cat(collected, dim=1)

class PixelShuffleBlock(nn.Module):
    """Upsampling stage: conv3x3 -> pixel shuffle -> ReLU.

    The convolution produces ``out_channels * upscale_factor ** 2`` channels,
    which the pixel shuffle rearranges into ``out_channels`` channels at
    ``upscale_factor`` times the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, upscale_factor):
        super().__init__()
        expanded = out_channels * (upscale_factor ** 2)
        self.conv = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the upsampled, rectified feature map."""
        return self.relu(self.pixel_shuffle(self.conv(x)))

class MultiScaleFeatureExtractor(nn.Module):
    """Sum of 3x3, 5x5 and 7x7 conv branches followed by a non-local block.

    Each branch is conv -> batch norm -> ReLU and keeps the channel count and
    spatial size of the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3)
        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.bn3 = nn.BatchNorm2d(in_channels)

        self.non_local = NonLocalBlock(in_channels)

    def forward(self, x):
        """Return ``non_local(branch3x3 + branch5x5 + branch7x7)``."""
        branches = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        small, medium, large = [
            self.relu(norm(conv(x))) for conv, norm in branches
        ]
        fused = small + medium + large
        return self.non_local(fused)

class SwinTransformerBlock(nn.Module):
    """Transformer layers applied to a feature map flattened into tokens.

    The layers are ``timm.models.vision_transformer.Block`` instances run on
    the full sequence of spatial positions. ``window_size`` is stored on the
    module but is not otherwise used by this class.
    """

    def __init__(self, embed_dim, depths, num_heads, window_size=7):
        super().__init__()
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_heads = num_heads
        self.window_size = window_size

        self.proj = nn.Conv2d(embed_dim, embed_dim, kernel_size=1)
        layers = [
            Block(embed_dim, num_heads, mlp_ratio=4., qkv_bias=True)
            for _ in range(depths)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Project, run the transformer layers, and restore the input shape."""
        n, c, rows, cols = x.shape
        # (n, c, rows, cols) -> (n, rows * cols, c): one token per position.
        tokens = self.proj(x).flatten(2).transpose(1, 2)
        for layer in self.blocks:
            tokens = layer(tokens)
        return tokens.transpose(1, 2).reshape(n, c, rows, cols)

class ImprovedDenoisingNetwork(nn.Module):
    """Encoder / bottleneck / decoder denoising network.

    Pipeline: 3x3 stem -> dense + two stride-2 residual blocks (encoder) ->
    multi-scale extractor -> transformer layers -> two (RRDB, CBAM) pairs ->
    dynamic conv -> wavelet round trip -> two x2 pixel-shuffle stages and a
    final 3x3 conv (decoder).

    Submodule attribute names are part of the checkpoint format
    (``state_dict`` keys) and must not be renamed.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stem_ch, growth, dense_layers, width = 64, 32, 4, 128
        # DenseBlock keeps its input and appends `growth` channels per layer.
        dense_out = stem_ch + dense_layers * growth  # 192

        self.initial_conv = nn.Conv2d(in_channels, stem_ch, kernel_size=3, stride=1, padding=1)

        self.encoder = nn.Sequential(
            DenseBlock(stem_ch, growth, num_layers=dense_layers),
            ResidualBlock(dense_out, width, stride=2),
            ResidualBlock(width, width, stride=2),
        )

        self.multi_scale = MultiScaleFeatureExtractor(width)
        self.transformer = SwinTransformerBlock(embed_dim=width, depths=2, num_heads=4)

        self.rir_block1 = ResidualInResidualDenseBlock(width, growth, num_blocks=2, num_layers=3)
        self.attention1 = CBAMBlock(width)

        self.rir_block2 = ResidualInResidualDenseBlock(width, growth, num_blocks=2, num_layers=3)
        self.attention2 = CBAMBlock(width)

        self.dynamic_conv = DynamicConv2d(width, width, kernel_size=3, padding=1)
        self.wavelet_transform = WaveletTransform()

        self.decoder = nn.Sequential(
            PixelShuffleBlock(width, stem_ch, upscale_factor=2),
            PixelShuffleBlock(stem_ch, stem_ch, upscale_factor=2),
            nn.Conv2d(stem_ch, out_channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Run the full pipeline and return the decoder output."""
        feats = self.encoder(self.initial_conv(x))
        feats = self.transformer(self.multi_scale(feats))

        feats = self.attention1(self.rir_block1(feats))
        feats = self.attention2(self.rir_block2(feats))

        feats = self.wavelet_transform(self.dynamic_conv(feats))
        return self.decoder(feats)

Step 3: Finally, run this cell and select an image file to denoise. The denoised image will be saved to the "denoised" folder in the Colab file browser (/content/denoised).

In [None]:
# @title
import torch
from torchvision import transforms
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import gaussian_filter
from google.colab import files
import os

# Upload the image
uploaded = files.upload()
if not uploaded:
    # Cancelling the upload dialog leaves nothing to process; fail with a
    # clear message instead of an IndexError further down.
    raise RuntimeError("No file was uploaded; re-run this cell and select an image.")

# Get the image path (only the first uploaded file is processed)
image_path = next(iter(uploaded))

torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained weights from 'latest_checkpoint.pth' (the file written by
# the download command in the first cell) onto the selected device, and put
# the model in evaluation mode for inference.
checkpoint = torch.load('latest_checkpoint.pth', map_location=device)
model = ImprovedDenoisingNetwork(in_channels=3, out_channels=3)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Function to load an image
def load_image(image_path):
    """Open the file at ``image_path`` and return it converted to RGB."""
    return Image.open(image_path).convert('RGB')

# Function to split image into overlapping tiles
def image_to_tiles(image, tile_size, overlap):
    """Split ``image`` into overlapping tiles on a regular grid.

    Tiles start every ``tile_size - overlap`` pixels in both directions and
    are clipped at the right and bottom image borders, so edge tiles may be
    smaller than ``tile_size``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and
            ``crop((left, upper, right, lower))``, e.g. a PIL image.
        tile_size: Maximum tile edge length in pixels.
        overlap: Number of pixels shared by neighbouring tiles.

    Returns:
        ``(tiles, positions)`` where ``tiles`` are the cropped images and
        ``positions[k]`` is ``(top, left, width, height)`` of ``tiles[k]``.

    Raises:
        ValueError: If ``overlap >= tile_size``. A non-positive step would
            otherwise either be rejected by ``range`` or silently yield no
            tiles at all.
    """
    step = tile_size - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than tile_size ({tile_size})"
        )

    w, h = image.size
    tiles = []
    positions = []
    for top in range(0, h, step):
        bottom = min(top + tile_size, h)
        for left in range(0, w, step):
            right = min(left + tile_size, w)
            tiles.append(image.crop((left, top, right, bottom)))
            positions.append((top, left, right - left, bottom - top))
    return tiles, positions

# Create an alpha mask for blending
def create_alpha_mask(tile_size, overlap):
    """Build a ``(tile_size, tile_size)`` float32 blending-weight mask.

    The weights are 1 in the interior and fall off linearly over the outer
    ``overlap`` rows/columns on every side. The ramp deliberately excludes
    0: a pixel covered by only one tile (for example the first row or column
    of the image) must keep a positive weight, otherwise it is lost when the
    blended image is normalised by the accumulated weights. The ramp and its
    mirror image sum to 1, so two tiles overlapping by ``overlap`` pixels
    cross-fade evenly.

    Args:
        tile_size: Edge length of the square mask.
        overlap: Width of the fade at each border; 0 gives an all-ones mask.

    Returns:
        A float32 array of shape ``(tile_size, tile_size)`` with values in
        ``(0, 1]``.

    Raises:
        ValueError: If ``overlap`` is negative or larger than ``tile_size``.
    """
    if not 0 <= overlap <= tile_size:
        raise ValueError(
            f"overlap must be between 0 and tile_size ({tile_size}), got {overlap}"
        )

    mask = np.ones((tile_size, tile_size), dtype=np.float32)
    if overlap == 0:
        # Nothing to fade; also avoids `mask[-0:]`, which selects the whole array.
        return mask

    # overlap + 2 samples on [0, 1] with both endpoints dropped:
    # 1/(overlap+1), ..., overlap/(overlap+1) -- strictly inside (0, 1).
    ramp = np.linspace(0, 1, overlap + 2)[1:-1]
    mask[:overlap, :] *= ramp[:, None]
    mask[-overlap:, :] *= ramp[::-1, None]
    mask[:, :overlap] *= ramp[None, :]
    mask[:, -overlap:] *= ramp[None, ::-1]
    return mask

# Function to merge tiles back to image with alpha blending
def tiles_to_image(tiles, positions, image_size, tile_size, overlap):
    """Blend processed tiles back into one image using weighted averaging.

    Args:
        tiles: Processed tile images, in the same order as ``positions``.
        positions: ``(row, col, width, height)`` of each tile in the output.
        image_size: ``(width, height)`` of the output image.
        tile_size: Tile edge length used to build the blending mask.
        overlap: Overlap width used to build the blending mask.

    Returns:
        The blended result as an 8-bit PIL image.
    """
    out_w, out_h = image_size[0], image_size[1]
    canvas = np.zeros((out_h, out_w, 3), dtype=np.float32)
    weight_sum = np.zeros((out_h, out_w, 3), dtype=np.float32)
    blend_mask = create_alpha_mask(tile_size, overlap)

    for idx, (top, left, width, height) in enumerate(positions):
        # A processed tile may be larger than the region it was cut from;
        # keep only the part that maps back onto that region.
        pixels = np.array(tiles[idx])[:height, :width]
        rows, cols, _ = pixels.shape

        # Trim the mask to the tile and add a channel axis for broadcasting.
        tile_weight = blend_mask[:rows, :cols, np.newaxis]

        region = (slice(top, top + height), slice(left, left + width))
        canvas[region] += pixels * tile_weight
        weight_sum[region] += tile_weight

    # Normalise by the accumulated weights; the floor avoids division by zero.
    blended = canvas / np.maximum(weight_sum, 1e-8)
    as_bytes = np.clip(blended, 0, 255).astype(np.uint8)
    return Image.fromarray(as_bytes)

# Transformations
transform = transforms.Compose([transforms.ToTensor()])

# Load the image and cut it into overlapping tiles
original_image = load_image(image_path)
tile_size = 256  # Tile edge length fed to the model
overlap = 32     # Pixels shared by neighbouring tiles, used for blending
tiles, positions = image_to_tiles(original_image, tile_size, overlap)

# Run the model on every tile without tracking gradients
to_pil = transforms.ToPILImage()
processed_tiles = []
with torch.no_grad():
    for tile in tiles:
        batch = transform(tile).unsqueeze(0).to(device)
        # Clamp to [0, 1] so the conversion to an 8-bit image is well defined
        denoised = torch.clamp(model(batch).squeeze(0), 0, 1)
        processed_tiles.append(to_pil(denoised.cpu()))

# Reconstruct the image from tiles
reconstructed_image = tiles_to_image(processed_tiles, positions, original_image.size, tile_size, overlap)

# Function to save the denoised image
def save_image(image, path):
    """Save ``image`` to ``path``, creating the parent directory if needed.

    Args:
        image: Object with a ``save(path)`` method, e.g. a PIL image.
        path: Destination file path. It may be a bare file name, in which
            case the file is written to the current working directory.
    """
    directory = os.path.dirname(path)
    # dirname() is '' for a bare file name, and os.makedirs('') raises, so
    # only create the directory when there is one. exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    if directory:
        os.makedirs(directory, exist_ok=True)
    image.save(path)
    print(f"Image saved at {path}")

# Function to display images
def show_images(original, reconstructed):
    """Display the original and reconstructed images side by side."""
    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = (
        (original, 'Original Image'),
        (reconstructed, 'Reconstructed Image'),
    )
    for axis, (picture, caption) in zip(axes, panels):
        axis.imshow(np.asarray(picture))
        axis.set_title(caption)
        axis.axis('off')
    plt.show()

# Display the images
show_images(original_image, reconstructed_image)

# Save the result under /content/denoised, prefixing the uploaded file name
output_dir = '/content/denoised'
output_path = os.path.join(output_dir, f'denoised_{image_path}')
save_image(reconstructed_image, output_path)