In [None]:
import cv2
import torch
import numpy as np
from eprocessing.dataload import ImageDataset
from eprocessing.etransforms import Scale, RandCrop, AddAWGN
from etrain.trainer import NNTrainer
from modelbuild.denoiser import DivergentRestorer
from emetrics.metrics import *

from pathlib import Path
import matplotlib.pyplot as plt

In [None]:
import torch.nn as nn

class PixelFrequencyLayer(nn.Module):
    """Replace every pixel value by the empirical frequency of that value.

    The layer holds a ``pixel_probabilities`` buffer of length ``num_bins``
    (uniform at construction). ``compute_frequencies`` overwrites it with the
    normalized histogram of a batch, and ``forward`` looks each pixel up in it.
    """

    def __init__(self, num_bins=256):
        """
        Initialize the layer.
        Args:
            num_bins (int): Number of bins for the pixel intensity values (default: 256 for 8-bit images).
        """
        super(PixelFrequencyLayer, self).__init__()
        self.num_bins = num_bins
        self.register_buffer("pixel_probabilities", torch.ones(num_bins) / num_bins)

    def compute_frequencies(self, images):
        """
        Compute pixel intensity frequencies and update probabilities.

        Integer value ``k`` in ``[0, num_bins - 1]`` is counted in bin ``k``.
        Values outside that range fall in no bin but still count toward the
        total, matching the previous normalization. Empty input leaves the
        current probabilities untouched (instead of producing NaNs).

        Args:
            images (torch.Tensor): Input images (batch_size, channels, height, width).
                Any shape is accepted; values are flattened.
        """
        with torch.no_grad():
            flat_pixels = images.detach().flatten()
            # torch.histc only supports floating-point tensors; cast integer images (e.g. uint8).
            if not flat_pixels.is_floating_point():
                flat_pixels = flat_pixels.float()

            total_pixels = flat_pixels.numel()
            if total_pixels == 0:
                # Nothing to count; dividing by zero would fill the table with NaNs.
                return

            # With min=0 and max=num_bins-1 the bin width is (num_bins-1)/num_bins,
            # which maps integer value k exactly onto bin k.
            hist = torch.histc(flat_pixels, bins=self.num_bins, min=0, max=self.num_bins - 1)

            # Normalize histogram to probabilities
            self.pixel_probabilities = hist / total_pixels

    def forward(self, images):
        """
        Transform the input image pixels into probabilities.

        Pixel values are truncated to integers with ``.long()``, so inputs are
        expected to hold raw intensities in ``0..num_bins-1``; images scaled to
        [0, 1] collapse onto bins 0 and 1. Out-of-range values are clamped to
        the nearest valid bin rather than wrapping (negative index) or raising.

        Args:
            images (torch.Tensor): Input images (batch_size, channels, height, width).
        Returns:
            torch.Tensor: Transformed images with probabilities (same shape as ``images``).
        """
        # Out-of-place clamp: `.long()` returns the same tensor for int64 input,
        # so an in-place clamp would modify the caller's data.
        pixel_indices = images.long().clamp(0, self.num_bins - 1)
        return self.pixel_probabilities[pixel_indices]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelwiseVariance(nn.Module):
    """Sliding-window variance computed independently for every channel."""

    def __init__(self, kernel_size: int, stride: int = 1, padding: int = 0):
        """
        Custom layer to compute channel-wise variance maps.

        Args:
            kernel_size (int): Size of the kernel (assumed square).
            stride (int): Stride for the sliding window.
            padding (int): Number of pixels added on every side of the input
                before windowing. Edge pixels are replicated so border windows
                are not pulled toward zero.
        """
        super(ChannelwiseVariance, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # Kept as an int; replicate padding is applied explicitly in forward().
        self.padding = padding

    def forward(self, x):
        """
        Compute channel-wise variance maps.

        Args:
            x (Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            Tensor: Population variance (mean of squared deviations) of each
            window, of shape (B, C, H', W') with
            H' = (H + 2 * padding - kernel_size) // stride + 1 (same for W').
        """
        B, C = x.shape[:2]

        if self.padding > 0:
            p = self.padding
            x = F.pad(x, (p, p, p, p), mode='replicate')
        H, W = x.shape[-2:]

        # Unfold yields (B, C * kernel_size^2, L) with channels outermost, so it
        # can be split into (B, C, kernel_size^2, L) without mixing channels.
        patches = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)
        patches = patches.reshape(B, C, self.kernel_size ** 2, -1)

        mean = patches.mean(dim=2, keepdim=True)               # (B, C, 1, L)
        variance = ((patches - mean) ** 2).mean(dim=2)         # (B, C, L)

        # Output grid size for the (already padded) input.
        h_out = (H - self.kernel_size) // self.stride + 1
        w_out = (W - self.kernel_size) // self.stride + 1
        return variance.reshape(B, C, h_out, w_out)

In [None]:
class ChannelVarianceLayer(nn.Module):
    """Per-channel 2x2 local variance map with the same spatial size as the input."""

    def __init__(self, in_channels: int = 3):
        super(ChannelVarianceLayer, self).__init__()
        # Depthwise 2x2 averaging kernel, initialized to 1/4 (mean of 4 pixels).
        # A non-persistent buffer follows `.to(device)` without adding a state_dict entry.
        self.register_buffer("kernel", torch.ones((in_channels, 1, 2, 2)) / 4.0, persistent=False)

    def forward(self, x):
        """
        Args:
            x (Tensor): Input of shape (B, C, H, W); any C is accepted.

        Returns:
            Tensor: Variance map of shape (B, C, H, W), clamped to be >= 0.
        """
        channels = x.size(1)

        # All channels share the same averaging kernel, so build one per input
        # channel, matching the input's dtype and device.
        weight = self.kernel[:1].to(dtype=x.dtype, device=x.device).repeat(channels, 1, 1, 1)

        # Replicate-pad one row on top and one column on the left. The 2x2
        # window for output (i, j) then covers input rows i-1..i and cols j-1..j,
        # so the output keeps the input's H x W.
        x_padded = F.pad(x, (1, 0, 1, 0), mode='replicate')

        # Second moment E[X^2] and first moment E[X] over each 2x2 window.
        mean_squared = F.conv2d(x_padded ** 2, weight, stride=1, padding=0, groups=channels)
        mean = F.conv2d(x_padded, weight, stride=1, padding=0, groups=channels)

        # variance = E[X^2] - (E[X])^2; clamp tiny negatives from float cancellation.
        return (mean_squared - mean ** 2).clamp_min(0)

In [None]:
# Dataset and checkpoint locations (absolute local paths -- adjust per machine).
# NOTE(review): input (xp) and target (yp) point at the same `test/y` directory;
# presumably intentional because noise is added to the input synthetically by
# AddAWGN(..., both=False) in the dataset cell -- confirm.
xp = yp = Path('D:/Projects/datasets/GoPro/GoPro_Large/orig_blur/awgn-0-0/test/y')
modelp = Path('D:/Projects/torch-admm-deconv/trained_models/denoiser_gopro_divergent_attention_epoch72_vloss0.0397.tar')

In [None]:
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Spatial size (H, W) used for random crops.
im_shape = (256, 256)
# NOTE(review): min_std/max_std are not referenced by the dataset cell, which
# hardcodes AddAWGN(std_range=(25, 30)).
min_std, max_std = 5, 15

In [None]:
# PSNR and SSIM metric objects, both constructed for the configured device.
psnr, ssim = PSNRMetric(device), SSIMMetric(device)

In [None]:
# Per-stage settings handed to DivergentRestorer through `admms`.
# The first stage is parameterized by `lmbda`, the second by `rho`.
DECONV1 = dict(kern_size=(), max_iters=100, lmbda=0.02, iso=True)
DECONV2 = dict(kern_size=(), max_iters=100, rho=0.004, iso=True)

# Build the restorer with a sigmoid output and the two configured stages.
# NOTE(review): meaning of the eight positional arguments is defined by
# DivergentRestorer's signature -- consider passing them as keywords.
model = DivergentRestorer(
    3, 2, 3, 3, 4, 128, 128, 8,
    output_activation=torch.nn.Sigmoid(),
    admms=[DECONV1, DECONV2],
)

In [None]:
# Load the checkpoint onto the CPU first, so it also works on machines without
# the GPU it was saved from; the model is moved to `device` in a later cell.
# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
# execute code -- only load checkpoints from a trusted source. If the file holds
# only tensors/dicts/primitives, weights_only=True would be safer.
model_d = torch.load(modelp, map_location='cpu', weights_only=False)

In [None]:
# strict=True (the default, made explicit) raises if checkpoint and model keys differ.
model.load_state_dict(model_d['model_state_dict'], strict=True)

In [None]:
# Move to the target device and switch to inference mode (Module.eval returns self).
model = model.to(device).eval()

In [None]:
# Paired dataset: scale, random-crop to im_shape, then add Gaussian noise with std in [25, 30].
imd = ImageDataset(
    xp,
    yp,
    transforms=[
        Scale(),
        RandCrop(im_shape),
        AddAWGN(std_range=(25, 30), both=False),
    ],
)

In [None]:
# Shuffled mini-batches of 4 samples.
imdt = torch.utils.data.DataLoader(imd, batch_size=4, shuffle=True)

In [None]:
# Fetch sample 97 as a pair -- presumably (noisy input, clean target).
# NOTE(review): RandCrop/AddAWGN are presumably random and no seed is set in this
# notebook, so this sample likely differs on every call -- confirm.
imx, imy = imd[97]

In [None]:
# Add a leading batch dimension (the `[:, :, :]` slice in the old form was a no-op).
imxx = imx.unsqueeze(0)

In [None]:
# Broadcast to a single 3-channel image of the configured crop size
# (uses im_shape instead of repeating its 256s, so changing the config stays consistent).
imxx = imxx.expand(1, 3, *im_shape)

In [None]:
# NOTE(review): `ref` is only assigned in a later cell (torchvision.io.read_image);
# this cell fails on a fresh kernel unless that cell has run first.
# Inference only: no_grad avoids building and holding an autograd graph.
with torch.no_grad():
    out = model(ref.to(device))
# Keep the first (only) batch item, back on the CPU for plotting/metrics.
out = out[0].cpu()

In [None]:
# 2x2 local-variance layer for 3-channel input (in_channels=3 is the default, made explicit).
varmap = ChannelVarianceLayer(in_channels=3)

In [None]:
# Local 2x2 variance of the batched sample; the layer pads so H and W are preserved.
varm = varmap(imxx)

In [None]:
# Inspect the variance map's shape.
varm.shape

In [None]:
# Export the first image of `ref` to 'ffdin.png'.
# NOTE(review): `ref` is assigned in a later cell; run that cell first.
ref_hwc = ref[0].detach().cpu().permute(1, 2, 0).numpy()
# Quantize [0, 1] floats to uint8 explicitly instead of relying on imwrite's
# implicit float->8-bit conversion.
cv2_image = np.ascontiguousarray(np.clip(np.rint(ref_hwc * 255.0), 0, 255).astype(np.uint8))
# Assuming `ref` is RGB (as loaded by torchvision): OpenCV writes BGR, so swap channels.
cv2_image = cv2.cvtColor(cv2_image, cv2.COLOR_RGB2BGR)
# imwrite reports failure by returning False; surface it instead of ignoring it.
if not cv2.imwrite('ffdin.png', cv2_image):
    raise IOError("cv2.imwrite failed to write 'ffdin.png'")

In [None]:
# `ref` is (C, H, W) right after loading but (1, C, H, W) once batched;
# drop the batch dim if present before moving channels last.
ref_chw = ref[0] if ref.dim() == 4 else ref
plt.imshow(ref_chw.detach().cpu().permute(1, 2, 0).numpy())
plt.axis('off');

In [None]:
# Show the restored image (C, H, W -> H, W, C for matplotlib).
plt.imshow(out.detach().permute(1, 2, 0).numpy())
plt.axis('off')

In [None]:
# Show the target image (C, H, W -> H, W, C for matplotlib).
plt.imshow(imy.detach().permute(1, 2, 0).numpy())
plt.axis('off')

In [None]:
# Show `ffdout` -- presumably an FFDNet result in 0..255 (later cells divide it by 255).
# NOTE(review): `ffdout` is never assigned in the cells shown here; it presumably
# comes from a deleted or external cell, so this fails on a fresh kernel.
plt.imshow(ffdout.permute((1,2,0)).detach().numpy())
plt.axis('off')

In [None]:
# PSNR between the image loaded from ffdnetout.png and the reference image.
# NOTE(review): `ffdnet` and `ref` are assigned in later cells (torchvision.io.read_image);
# this cell only works if those have run first.
psnr(ffdnet, ref)

In [None]:
# PSNR of the (re-batched) model output against `ref`.
psnr(out.unsqueeze(0), ref)

In [None]:
# SSIM of the (re-batched) model output against `ref`.
ssim(out.unsqueeze(0), ref)

In [None]:
# PSNR of `ffdout` (rescaled from 0..255 to 0..1) against the target.
# Batch both arguments, as the SSIM cells do, so the metric compares
# like-shaped (1, C, H, W) tensors rather than relying on broadcasting.
# NOTE(review): `ffdout` is not assigned in the cells shown here.
psnr(ffdout[torch.newaxis, ...] / 255, imy[torch.newaxis, ...])

In [None]:
# SSIM of `ffdout` (rescaled from 0..255 to 0..1) against the batched target.
ssim(ffdout.unsqueeze(0) / 255, imy.unsqueeze(0))

In [None]:
import torchvision

In [None]:
# Load both images as (3, H, W) floats in [0, 1]. Forcing RGB avoids an extra
# alpha channel (or a single grey channel) that the default UNCHANGED mode would keep.
# NOTE(review): absolute local paths -- adjust per machine.
ffdnet = torchvision.io.read_image(
    'D:/Projects/torch-admm-deconv/ffdnetout.png', mode=torchvision.io.ImageReadMode.RGB
) / 255.0
ref = torchvision.io.read_image(
    'D:/Projects/torch-admm-deconv/ref.png', mode=torchvision.io.ImageReadMode.RGB
) / 255.0

In [None]:
# Add a leading batch dimension only when it is missing, so re-running this
# cell does not keep stacking dimensions (it is now idempotent).
if ffdnet.dim() == 3:
    ffdnet = ffdnet[torch.newaxis, ...]
if ref.dim() == 3:
    ref = ref[torch.newaxis, ...]

In [None]:
# Check the batched shape -- expected (1, C, H, W) after the batching cell.
ffdnet.shape

In [None]:
# PSNR of the model output against the target, both batched to (1, C, H, W)
# to match the SSIM cell instead of relying on broadcasting.
psnr(out[torch.newaxis, ...], imy[torch.newaxis, ...])

In [None]:
# SSIM of the model output against the target, both batched.
ssim(out.unsqueeze(0), imy.unsqueeze(0))

In [None]:
# NOTE(review): duplicates an earlier PSNR cell for `ffdout`; consider removing one.
# Both arguments batched to (1, C, H, W) so the metric compares like-shaped tensors.
psnr(ffdout[torch.newaxis, ...] / 255, imy[torch.newaxis, ...])

In [None]:
# NOTE(review): duplicates an earlier SSIM cell for `ffdout`; consider removing one.
ssim(ffdout.unsqueeze(0) / 255, imy.unsqueeze(0))