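# Question 3

Deblurring noisy BSDS300 test images: compare (1) Wiener deconvolution, (2) a U-Net trained for joint deblurring and denoising, and (3) Wiener deconvolution followed by a denoising U-Net, reporting the mean PSNR of each method.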

In [8]:
import os
import statistics

import numpy as np
import skimage.io
import torch
from torch.fft import fft2, ifft2
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Lambda

from utils.create_noisy_images_utils import BSDS300Dataset
from utils.models import Unet

In [9]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
class BlurredBSDS300Dataset(BSDS300Dataset):
    """BSDS300 images blurred with a single MNIST-digit kernel.

    Relies on the base class (defined elsewhere) for ``self.images`` and
    ``self.patchify``; presumably ``self.images`` is an (N, C, H, W) tensor --
    TODO confirm against ``BSDS300Dataset``.

    The blur kernel is the first MNIST training digit, resized to
    ``kernel_size`` x ``kernel_size`` and normalized to unit sum. Every image
    is blurred with that same kernel by multiplication with its OTF in the
    Fourier domain (i.e. circular convolution).

    Args:
        root, patch_size, split: forwarded to ``BSDS300Dataset``;
            ``patch_size`` is also used to patchify the blurred images.
        use_patches: accepted for interface compatibility; unused here.
        kernel_size: side length of the square blur kernel.
        sigma: accepted for interface compatibility; unused by ``__init__``.
        return_kernel: when True, ``__getitem__`` also returns the kernel.
    """

    def __init__(self, root='./data/BSDS300', patch_size=32, split='train', use_patches=True,
                 kernel_size=7, sigma=2, return_kernel=True):
        super().__init__(root, patch_size, split)

        # trim images to even size by dropping the last row and column
        self.images = self.images[..., :-1, :-1]
        self.kernel_size = kernel_size
        self.return_kernel = return_kernel

        # extract blur kernel (use an MNIST digit); each digit is normalized to unit sum
        self.kernel_dataset = MNIST('./', train=True, download=True,
                                    transform=Compose([Lambda(lambda x: np.array(x)),
                                                       ToTensor(),
                                                       Lambda(lambda x: x / torch.sum(x))]))

        # one digit per image (zip stops once every image has a digit);
        # MNIST yields (image, label) pairs, so x[0] is the digit tensor
        kernels = torch.cat([x[0] for (x, _) in zip(self.kernel_dataset, np.arange(self.images.shape[0]))])
        # resize to kernel_size x kernel_size, then renormalize each kernel to unit sum
        kernels = torch.nn.functional.interpolate(kernels[:, None, ...], size=2*(kernel_size,))
        kernels = kernels / torch.sum(kernels, dim=(-1, -2), keepdim=True)
        # only the first digit is used: replicate it once per image -> (N, 1, k, k)
        self.kernel = kernels[[0]].repeat(kernels.shape[0], 1, 1, 1)

        # blur the images: multiplying by the OTF == circular convolution with the kernel
        H = psf2otf(self.kernel, self.images.shape)
        self.blurred_images = ifft2(fft2(self.images) * H).real
        self.blurred_patches = self.patchify(self.blurred_images, patch_size)

        # save which blur kernel is used for each patch; assumes every image
        # yields the same number of patches, ordered image by image
        self.patch_kernel = self.kernel.repeat(1, len(self.blurred_patches) // len(self.images), 1, 1)
        self.patch_kernel = self.patch_kernel.view(-1, *self.kernel.shape[-2:])

        # (N, 1, k, k) -> (N, k, k). Squeeze only the channel axis: a bare
        # squeeze() would also drop the image axis when N == 1, and then
        # self.kernel[[idx]] in __getitem__ would pick a row, not a kernel.
        self.kernel = self.kernel.squeeze(1)

    def get_kernel(self, kernel_size, sigma):
        """Build a separable 2-D kernel as the outer product of ``self.gaussian`` with itself.

        NOTE(review): ``gaussian`` is not defined in this class; presumably the
        base class provides it and returns a 1-D tensor -- TODO confirm.
        Not called by ``__init__``.
        """
        kernel = self.gaussian(kernel_size, sigma)
        # (k, 1) @ (1, k) -> (k, k) outer product, assuming a 1-D input kernel
        kernel_2d = torch.matmul(kernel.unsqueeze(-1), kernel.unsqueeze(-1).t())
        return kernel_2d

    def __getitem__(self, idx):
        """Return ``[blurred, clean]`` (plus the kernel if ``return_kernel``) for image ``idx``.

        Each returned tensor gets a leading axis of size 1 and is moved to the
        module-level ``device``.
        """
        out = [self.blurred_images[idx][None, ...].to(device),
               self.images[idx][None, ...].to(device)]
        if self.return_kernel:
            out.append(self.kernel[[idx]].to(device))

        return out

In [18]:
def img_to_numpy(x):
    """Convert an image tensor to a numpy array clipped to [0, 1] for saving/plotting.

    Size-1 axes are squeezed away first. A 3-D (C, H, W) result is returned
    channels-last as (H, W, C); a 2-D single-channel result is returned as
    (H, W). Previously single-channel images crashed in ``transpose``.

    Raises:
        ValueError: if the squeezed array is neither 2-D nor 3-D.
    """
    arr = x.detach().cpu().numpy().squeeze()
    if arr.ndim == 3:
        arr = arr.transpose(1, 2, 0)
    elif arr.ndim != 2:
        raise ValueError(f'expected a single image, got array of shape {arr.shape}')
    return np.clip(arr, 0, 1)


def psf2otf(psf, shape):
    """Return the optical transfer function of ``psf`` for images of size ``shape``.

    The PSF is zero-padded at the bottom/right to ``shape[-2:]``, then rolled
    circularly so that its centre sits at index [0, 0]. Its 2-D FFT is
    therefore the frequency response of circular convolution with the PSF,
    with no added phase shift.
    """
    kh, kw = psf.shape[-2], psf.shape[-1]

    # F.pad lists padding from the last dim backwards: (W_left, W_right, H_top, H_bottom, ...)
    padding = (0, shape[-1] - kw, 0, shape[-2] - kh, 0, 0)
    padded = torch.nn.functional.pad(psf, padding)

    # move the PSF centre to the origin so the OTF carries no linear phase
    centred = torch.roll(padded, shifts=(-int(kh / 2), -int(kw / 2)), dims=(-2, -1))

    return fft2(centred)


def calc_psnr(x, gt, max_val=1.0):
    """Peak signal-to-noise ratio (in dB) between ``x`` and the reference ``gt``.

    Args:
        x, gt: arrays/tensors of broadcast-compatible shape supporting
            ``.mean().item()`` (torch tensors or numpy arrays).
        max_val: peak signal value; the default 1.0 matches images in [0, 1].

    Returns:
        ``10 * log10(max_val**2 / MSE)``; ``inf`` when the inputs are identical
        (previously this raised ``ZeroDivisionError``).
    """
    mse = ((x - gt) ** 2).mean().item()
    if mse == 0:
        return float('inf')
    return 10 * np.log10(max_val ** 2 / mse)


def wiener_deconv(x, kernel, snr=100):
    """Deconvolve ``x`` by ``kernel`` with a Wiener filter in the Fourier domain.

    The filter is G = conj(H) / (|H|^2 + 1/snr), where H is the OTF of
    ``kernel`` padded to the spatial size of ``x`` (circular boundary model).

    Args:
        x: blurred (possibly noisy) image tensor; its last two dims are H, W.
        kernel: blur PSF whose last two dims are the kernel height/width; must
            broadcast against ``x`` after padding.
        snr: signal-to-noise ratio used as regularizer (default 100, the value
            prescribed for the reported results).

    Returns:
        The real part of the deconvolved image.
    """
    # build the OTF on x's own device so x and G never end up on different devices
    H = psf2otf(kernel, x.shape).to(x.device)
    # H * conj(H) == |H|^2, with an exactly zero imaginary part
    G = torch.conj(H) / (1 / snr + H * torch.conj(H))
    return ifft2(fft2(x) * G).real


def load_models():
    """Instantiate the two pretrained U-Nets on ``device``.

    Returns:
        ``(model_deblur_denoise, model_denoise)``, with weights loaded from
        ``utils/models/pretrained/deblur_denoise.pth`` and
        ``utils/models/pretrained/denoise.pth`` respectively.
    """
    def _pretrained_unet(checkpoint):
        net = Unet().to(device)
        weights = torch.load(checkpoint, map_location=device)
        net.load_state_dict(weights)
        return net

    deblur_denoise_net = _pretrained_unet('utils/models/pretrained/deblur_denoise.pth')
    denoise_net = _pretrained_unet('utils/models/pretrained/denoise.pth')
    return deblur_denoise_net, denoise_net

In [23]:
def _save_png(path, img):
    """Write an image tensor (values expected in [0, 1]) to ``path`` as an 8-bit PNG."""
    skimage.io.imsave(path, (img_to_numpy(img) * 255).astype(np.uint8))


def evaluate_model(sigmas=(0.02,), out_dir='data/noisy_images/', seed=None):
    """Compare three deblurring methods on the blurred BSDS300 test split.

    For every noise level in ``sigmas`` and every test image, zero-mean
    Gaussian noise with standard deviation ``sigma`` is added to the blurred
    image. The noisy image is then restored with:
      1. Wiener deconvolution,
      2. the deblur+denoise network applied directly, and
      3. Wiener deconvolution followed by the denoising network.
    The ground truth, the noisy input and the three reconstructions are saved
    as PNGs in ``out_dir``.

    Args:
        sigmas: sequence of noise standard deviations to evaluate.
        out_dir: directory for the PNG outputs; created if missing.
        seed: optional seed for the noise RNG, for reproducible PSNRs.

    Returns:
        Flat list of mean PSNRs, three per sigma, in the order
        [Wiener, network, Wiener + network]. The list is also printed.
    """
    os.makedirs(out_dir, exist_ok=True)
    rng = np.random.default_rng(seed)

    # create the dataset and load the pretrained networks
    dataset = BlurredBSDS300Dataset(split='test')
    model_deblur_denoise, model_denoise = load_models()

    # put into evaluation mode
    model_deblur_denoise.eval()
    model_denoise.eval()

    PSNRs = []

    # pure inference: skip autograd bookkeeping (saves memory and time)
    with torch.no_grad():
        for sigma in sigmas:
            psnr_m1, psnr_m2, psnr_m3 = [], [], []

            for index, (image, gt, kernel) in enumerate(dataset):
                # add noise to the image (float64 noise, so the Wiener filter
                # below runs in double precision as before)
                noise = torch.from_numpy(rng.standard_normal(tuple(image.shape))).to(device)
                image = image + sigma * noise

                # Method 1: Wiener deconvolution
                image_deconv = wiener_deconv(image, kernel)
                psnr_m1.append(calc_psnr(image_deconv, gt))

                # Method 2: network trained for joint deblurring + denoising
                image_neural = model_deblur_denoise(image.to(dtype=torch.float))
                psnr_m2.append(calc_psnr(image_neural, gt))

                # Method 3: Wiener deconvolution, then the denoising network
                image_deconv_neural = model_denoise(image_deconv.to(dtype=torch.float))
                psnr_m3.append(calc_psnr(image_deconv_neural, gt))

                # save sample images for the write-up; prefix with sigma only when
                # several noise levels would otherwise overwrite each other's files
                prefix = str(index) if len(sigmas) == 1 else f'sigma{sigma}_{index}'
                outputs = {
                    '_gt.png': gt,
                    '_noisy.png': image,
                    '_deconv.png': image_deconv,
                    '_neural.png': image_neural,
                    '_deconv_neural.png': image_deconv_neural,
                }
                for suffix, img in outputs.items():
                    _save_png(os.path.join(out_dir, prefix + suffix), img)

            # mean PSNR of each method for this sigma
            PSNRs.extend(statistics.mean(p) for p in (psnr_m1, psnr_m2, psnr_m3))

    print(PSNRs)
    return PSNRs

In [24]:
def _main():
    """Entry point when run as a script or notebook: evaluate all three methods."""
    evaluate_model()


if __name__ == '__main__':
    _main()

[23.046599738102696, 26.93645977897099, 29.30034218455255]
