In [1]:
import cv2
import torch
import numpy as np
import torchvision
from eprocessing.dataload import ImageDataset
from eprocessing.etransforms import Scale, RandCrop, AddAWGN
from etrain.trainer import NNTrainer
from modelbuild.denoiser import DivergentRestorer
from emetrics.metrics import *

from pathlib import Path
import matplotlib.pyplot as plt

In [2]:
import torch.nn as nn

class PixelFrequencyLayer(nn.Module):
    """Replace every pixel intensity by its empirical probability.

    The layer owns a lookup table ``pixel_probabilities`` (a buffer of length
    ``num_bins``). The table starts out uniform, is re-estimated from data by
    :meth:`compute_frequencies`, and is indexed by :meth:`forward`.

    Intensities are expected on the integer scale ``[0, num_bins - 1]``
    (e.g. 0..255 for 8-bit images). Both methods use the same binning rule:
    truncate to an integer, then clamp into the valid range.
    """

    def __init__(self, num_bins=256):
        """
        Initialize the layer.
        Args:
            num_bins (int): Number of bins for the pixel intensity values (default: 256 for 8-bit images).
        """
        super(PixelFrequencyLayer, self).__init__()
        self.num_bins = num_bins
        # Uniform table until compute_frequencies() has been called.
        self.register_buffer("pixel_probabilities", torch.ones(num_bins) / num_bins)

    def _to_bin_indices(self, images):
        """Convert intensities to bin indices in ``[0, num_bins - 1]``.

        Values are truncated towards zero and then clamped, so out-of-range
        intensities land in the first / last bin instead of raising (too large)
        or silently wrapping around (negative) when used as an index.
        """
        # Non in-place clamp: ``.long()`` may return the caller's own tensor.
        return images.long().clamp(0, self.num_bins - 1)

    def compute_frequencies(self, images):
        """
        Compute pixel intensity frequencies and update probabilities.
        Args:
            images (torch.Tensor): Input images (batch_size, channels, height, width).
                Any dtype convertible to integer is accepted (uint8, int64, float, ...).
        """
        with torch.no_grad():
            bin_indices = self._to_bin_indices(images).flatten()

            total_pixels = bin_indices.numel()
            if total_pixels == 0:
                # Nothing to count; keep the current table rather than writing NaNs.
                return

            # bincount works on integer indices, so the histogram uses exactly the
            # same bin assignment that forward() uses for the lookup.
            counts = torch.bincount(bin_indices, minlength=self.num_bins)

            # Normalize histogram to probabilities (keeps the buffer's dtype).
            self.pixel_probabilities = counts.to(self.pixel_probabilities.dtype) / total_pixels

    def forward(self, images):
        """
        Transform the input image pixels into probabilities.
        Args:
            images (torch.Tensor): Input images (batch_size, channels, height, width).
        Returns:
            torch.Tensor: Tensor of the same shape as ``images`` holding, for each
            pixel, the probability stored for its intensity bin.
        """
        pixel_indices = self._to_bin_indices(images)
        return self.pixel_probabilities[pixel_indices]

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelwiseVariance(nn.Module):
    def __init__(self, kernel_size: int, stride: int = 1, padding: int = 0):
        """
        Custom layer to compute channel-wise variance maps.

        Each output value is the (biased) variance of one
        ``kernel_size x kernel_size`` window of a single input channel.

        Args:
            kernel_size (int): Size of the kernel (assumed square).
            stride (int): Stride for the sliding window.
            padding (int): Number of border pixels added on every side of the
                input. Borders are filled by replicating the edge pixels, so
                padding does not introduce artificial variance at the borders.
        """
        super(ChannelwiseVariance, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Keep the amount as a plain int: it is needed both for padding and for
        # the output-size arithmetic in forward().
        self.padding = padding

    def forward(self, x):
        """
        Compute channel-wise variance maps.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            Tensor: Variance map of shape (B, C, H', W').
        """
        B, C, H, W = x.shape

        # Replication padding is applied explicitly; F.unfold itself only
        # supports zero padding, so it is called without padding below.
        if self.padding > 0:
            x = F.pad(x, (self.padding,) * 4, mode="replicate")

        # Extract sliding windows. Shape: (B, C * kernel_size^2, L)
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)

        # Reshape to (B, C, kernel_size*kernel_size, L)
        patches = patches.view(B, C, self.kernel_size**2, -1)

        # Compute mean along patch dimension
        mean = patches.mean(dim=2, keepdim=True)  # Shape: (B, C, 1, L)

        # Compute variance along patch dimension
        variance = ((patches - mean) ** 2).mean(dim=2)  # Shape: (B, C, L)

        # Reshape back to spatial dimensions
        h_out = (H + 2 * self.padding - self.kernel_size) // self.stride + 1
        w_out = (W + 2 * self.padding - self.kernel_size) // self.stride + 1
        variance_map = variance.view(B, C, h_out, w_out)

        return variance_map

In [4]:
class ChannelVarianceLayer(nn.Module):
    """Per-channel local variance over 2x2 windows, same spatial size as the input.

    The variance is computed as ``E[X^2] - (E[X])^2`` using a fixed 2x2
    averaging kernel applied to each channel independently (grouped
    convolution). The input is padded by one replicated pixel on the left and
    on the top so that the output keeps the input's height and width.
    """

    def __init__(self, in_channels: int = 3):
        """
        Args:
            in_channels (int): Number of channels the layer is built for.
        """
        super(ChannelVarianceLayer, self).__init__()
        # One 2x2 averaging filter (all weights 1/4) per channel. Registered as a
        # non-persistent buffer so it follows .to()/.cuda()/.double() with the
        # module without adding a new key to the state_dict.
        self.register_buffer(
            "kernel", torch.ones((in_channels, 1, 2, 2)) / 4.0, persistent=False
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Input of shape (B, C, H, W); C is expected to equal
                ``in_channels``.

        Returns:
            Tensor: Non-negative variance map of shape (B, C, H, W).
        """
        channels = x.size(1)

        # Match the input's dtype/device so the layer also works for inputs that
        # are not float32 or that live on a different device than the module.
        kernel = self.kernel.to(dtype=x.dtype, device=x.device)

        # Pad one pixel on the left and on the top (order: left, right, top, bottom)
        # so the 2x2 valid convolution returns an H x W map.
        x_padded = F.pad(x, (1, 0, 1, 0), mode='replicate')

        # Second moment E[X^2] over each 2x2 window.
        mean_squared = F.conv2d(x_padded ** 2, kernel, stride=1, padding=0, groups=channels)

        # First moment E[X] over each 2x2 window.
        mean = F.conv2d(x_padded, kernel, stride=1, padding=0, groups=channels)

        # variance = E[X^2] - (E[X])^2; floating-point cancellation can make this
        # marginally negative on flat regions, so clamp at zero.
        variance_map = torch.clamp(mean_squared - mean ** 2, min=0)

        return variance_map

In [5]:
# Input / target image directories handed to ImageDataset, and the checkpoint file.
# These are absolute, machine-specific paths; adjust them before re-running elsewhere.
# NOTE(review): xp and yp are the same 'test/y' directory. This may be intentional
# (an AddAWGN transform is applied when the dataset is built) — TODO confirm it is
# not a typo for a separate input directory.
xp = Path('D:/Projects/datasets/GoPro/GoPro_Large/orig_blur/awgn-0-0/test/y')
yp = Path('D:/Projects/datasets/GoPro/GoPro_Large/orig_blur/awgn-0-0/test/y')
# Checkpoint archive read with torch.load further down.
modelp = Path('D:/Projects/torch-admm-deconv/trained_models/01-01-25/denoiser_gopro_divergent_attention_wconvs_epoch00_vloss0.0380.tar')

In [6]:
# Device used for the model and the metric objects; requires a CUDA-capable GPU.
device = 'cuda'
# Crop size passed to RandCrop when the dataset is built.
im_shape = (256,256)
# NOTE(review): min_std / max_std are not referenced by any other visible cell; the
# AddAWGN calls use literal std ranges instead — TODO confirm whether they should be wired in.
min_std, max_std = 5, 15

In [7]:
# Metric callables used at the end of the notebook to score images against ref_im.
# PSNRMetric / SSIMMetric are not imported by name in the visible cells; presumably
# they come from the star import of emetrics.metrics.
psnr = PSNRMetric(device)
ssim = SSIMMetric(device)

In [8]:
# Option sets passed to the model through its `admms` argument (presumably one per
# ADMM deconvolution stage — confirm against DivergentRestorer). The two differ only
# in their third entry: `lmbda` for the first, `rho` for the second.
DECONV1 = dict(kern_size=(), max_iters=100, lmbda=0.02, iso=True)
DECONV2 = dict(kern_size=(), max_iters=100, rho=0.004, iso=True)

# The positional arguments are kept in their original order; their meaning is defined
# by DivergentRestorer's signature, which is not visible in this notebook.
model = DivergentRestorer(
    3, 2, 3,
    3, 4, 86,
    86, 8,
    output_activation=torch.nn.Sigmoid(),
    admms=[DECONV1, DECONV2],
)

In [9]:
# Read the checkpoint archive. map_location='cpu' makes loading independent of the
# device the checkpoint was saved from (and avoids a second copy of the weights on
# the GPU); load_state_dict copies the values into the model's own parameters anyway.
# NOTE(security): weights_only=False unpickles arbitrary Python objects — only load
# checkpoint files from a trusted source.
model_d = torch.load(modelp, map_location='cpu', weights_only=False)

In [10]:
# Copy the weights stored under the checkpoint's 'model_state_dict' key into the model.
# The call's return value is echoed as the cell output (reports key matching).
model.load_state_dict(model_d['model_state_dict'])

<All keys matched successfully>

In [11]:
# Move the network to the target device and switch it to evaluation mode.
model = model.to(device).eval()

In [12]:
# Dataset over the xp / yp directories with three transforms applied in list order:
# Scale, a random crop of im_shape, and AddAWGN.
# NOTE(review): std_range=(15, 1) has its larger value first, whereas the later
# AddAWGN call uses (15, 16); presumably the tuple is (low, high) — TODO confirm
# against AddAWGN and check whether (min_std, max_std) was intended here.
imd = ImageDataset(xp, yp, transforms=[Scale(), RandCrop(im_shape), AddAWGN(std_range=(15, 1), both=False)])

In [13]:
# Batched (4 samples), shuffled loader over the dataset.
# NOTE(review): imdt is not consumed by any cell visible in this notebook.
imdt = torch.utils.data.DataLoader(imd, shuffle=True, batch_size=4)

## Load test image and add noise

In [63]:
# Resize transform with size 256. In the visible cells it is only referenced from a
# commented-out line in the cell that reads the reference image.
resize = torchvision.transforms.Resize(256)

In [12]:
# Read the clean reference image, divide by 255 and prepend a batch axis -> (1, C, H, W).
# Resizing is currently disabled: ref_im = resize(ref_im)
ref_im = (
    torchvision.io.read_image('D:/Projects/torch-admm-deconv/test_imgs/baboon256.png')
    .div(255.0)
    .unsqueeze(0)
)

In [65]:
# Write the reference image back to disk with OpenCV.
# NOTE(review): the destination is the same file ref_im was read from, so this
# re-encodes the reference image in place.
# (C, H, W) in [0, 1] -> (H, W, C) on the 0..255 scale.
cv2_image = np.transpose(ref_im[0].numpy() * 255, (1, 2, 0))
# Quantise explicitly to 8 bits instead of relying on imwrite's implicit conversion
# of float arrays.
cv2_image = np.ascontiguousarray(np.clip(cv2_image, 0, 255).round().astype(np.uint8))
# torchvision delivers RGB, OpenCV writes BGR: swap the channel order.
cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_RGB2BGR)
cv2.imwrite('D:/Projects/torch-admm-deconv/test_imgs/baboon256.png', cv2_image)

True

In [13]:
# Build the noise transform (presumably additive white Gaussian noise with a std
# drawn from std_range — confirm against AddAWGN) and apply it to the reference.
# The transform is called with the reference as both arguments; only the first
# element of what it returns is kept.
awgn_gen = AddAWGN(std_range=(15, 16), both=False)
noisy_ref = awgn_gen(ref_im, ref_im)[0]

In [14]:
# Save the noisy image to disk with OpenCV.
# (C, H, W) in roughly [0, 1] -> (H, W, C) on the 0..255 scale.
cv2_image = np.transpose(noisy_ref[0].numpy() * 255, (1, 2, 0))
# Noise can push values outside 0..255: clip and quantise explicitly to 8 bits
# instead of relying on imwrite's implicit conversion of float arrays.
cv2_image = np.ascontiguousarray(np.clip(cv2_image, 0, 255).round().astype(np.uint8))
# The tensor is in RGB order, OpenCV writes BGR: swap the channel order.
cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_RGB2BGR)
cv2.imwrite('D:/Projects/torch-admm-deconv/test_imgs/ref_noisy.png', cv2_image)

True

In [15]:
# Run the denoiser on the noisy image. This is inference only, so autograd is
# disabled: no graph is recorded and the result carries no grad_fn.
with torch.no_grad():
    model_out = model(noisy_ref.to(device))

In [16]:
# Bring the network output back to host memory for saving and scoring.
model_out = model_out.cpu()

In [26]:
# Save the model output to disk with OpenCV.
# (C, H, W) -> (H, W, C) on the 0..255 scale; detach first so numpy() is allowed
# even when the output is attached to an autograd graph.
cv2_image = np.transpose(model_out[0].detach().cpu().numpy() * 255, (1, 2, 0))
# Quantise explicitly to 8 bits instead of relying on imwrite's implicit conversion
# of float arrays.
cv2_image = np.ascontiguousarray(np.clip(cv2_image, 0, 255).round().astype(np.uint8))
# The tensor is in RGB order, OpenCV writes BGR: swap the channel order.
cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_RGB2BGR)
# NOTE(review): 'model_dednoised' looks like a typo for 'model_denoised'; the file
# name is left unchanged in case other tooling reads it.
cv2.imwrite('D:/Projects/torch-admm-deconv/test_imgs/model_dednoised.png', cv2_image)

True

In [54]:
# Copy of the model output shifted by a constant 1.72/255 and clamped back to [0, 1].
# NOTE(review): the origin of the 1.72 offset is not documented, and model_out_b is
# not used by any other visible cell — TODO confirm its purpose.
model_out_b = torch.clamp(model_out + 1.72/255, 0, 1)

### Load FFDNet image

In [18]:
# Read the stored FFDNet image, divide by 255 and prepend a batch axis -> (1, C, H, W).
ffndet = (
    torchvision.io.read_image('D:/Projects/torch-admm-deconv/test_imgs/ffdnet.png')
    .div(255.0)
    .unsqueeze(0)
)

In [56]:
# PSNR of the model output against the clean reference image.
psnr(model_out, ref_im)

tensor(28.2484, device='cuda:0', grad_fn=<CloneBackward0>)

In [20]:
# SSIM of the model output against the clean reference image.
ssim(model_out, ref_im)

tensor(0.8669, device='cuda:0', grad_fn=<CloneBackward0>)

In [21]:
# PSNR of the image loaded from ffdnet.png against the clean reference image.
psnr(ffndet, ref_im)

tensor(26.5737, device='cuda:0')

In [22]:
# SSIM of the image loaded from ffdnet.png against the clean reference image.
ssim(ffndet, ref_im)

tensor(0.7555, device='cuda:0')