In [1]:
import torch
from torchvision.datasets import CIFAR10
import torchvision.transforms as tvt

In [2]:
# Transform that adds Gaussian noise to a tensor (adapted from https://discuss.pytorch.org/t/how-to-add-noise-to-mnist-dataset-when-using-pytorch/59745)
class AddGaussianNoise(object):
    """Transform that adds i.i.d. Gaussian noise to a tensor.

    Computes ``tensor + randn * std + mean`` elementwise. The noise is therefore
    distributed as N(mean, std**2), and the output has the same shape as the input.

    Args:
        mean: offset added to every element together with the noise.
        std: scale applied to the standard-normal noise samples.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample the noise on the input's device so CUDA tensors work too.
        # The default float dtype is kept on purpose, so type promotion matches
        # the previous behaviour (e.g. integer inputs still yield a float result).
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
# Dataset that yields (noisy image, clean image) pairs for training a denoiser
class ImageDataset(CIFAR10):
  """CIFAR-10 wrapper that returns ``(noisy_image, clean_image)`` pairs.

  Each sample is converted to a tensor (if it is not one already) and normalized
  per channel; that is the clean image. Gaussian noise is then added to the same
  normalized tensor to produce the noisy image. The CIFAR-10 class label is discarded.

  Args:
    root, train, transform, target_transform, download: forwarded to ``CIFAR10``.
      ``transform`` is applied by the base class before the conversion and
      normalization done here.
    noise_mean: mean of the added Gaussian noise (default 0.0).
    noise_std: standard deviation of the added Gaussian noise (default 1.0).
      It is expressed in normalized units, because the noise is added after
      ``Normalize``.
  """

  # Per-channel normalization statistics used by the original notebook.
  MEAN = (0.4914, 0.4822, 0.4465)
  STD = (0.247, 0.243, 0.261)

  def __init__(self, root, train=True, transform=None, target_transform=None, download=False,
               noise_mean=0., noise_std=1.):
    super(ImageDataset, self).__init__(root, train=train, transform=transform,
                                       target_transform=target_transform, download=download)
    # Keep the pipeline on the instance so __getitem__ can reach it.
    self.to_tensor = tvt.ToTensor()
    self.normalize = tvt.Normalize(self.MEAN, self.STD)
    self.add_noise = AddGaussianNoise(noise_mean, noise_std)

  def __getitem__(self, idx):
    image, _target = super(ImageDataset, self).__getitem__(idx)
    # A user-supplied transform may already have produced a tensor;
    # ToTensor would reject it.
    if not isinstance(image, torch.Tensor):
      image = self.to_tensor(image)
    clean_image = self.normalize(image)
    # Derive the noisy image from the exact clean target, so the pair always matches.
    noisy_image = self.add_noise(clean_image)
    return noisy_image, clean_image


In [None]:
# dCNN model class used by the paper
class dCNN(torch.nn.Module):
  """Plain convolutional denoising network.

  Architecture:
    - ``Conv(channels -> 64) -> ReLU``
    - ``layers - 2`` blocks of ``Conv(64 -> 64) -> BatchNorm -> ReLU``
    - a final ``Conv(64 -> channels)``

  Every convolution is 3x3 with padding 1 and no bias. Spatial size is therefore
  preserved, and the model contains exactly ``layers`` convolutions.

  Args:
    channels: number of input (and output) image channels.
    layers: total number of convolution layers; must be at least 2.

  Raises:
    ValueError: if ``layers < 2``.

  The output is the raw result of the last convolution. NOTE(review): whether it
  is the denoised image or a predicted noise residual is decided by the training
  code, which is not written yet.
  """
  def __init__(self, channels, layers=17):
    super(dCNN, self).__init__()
    if layers < 2:
      raise ValueError("layers must be >= 2 (first and last conv), got {}".format(layers))
    self.conv1 = torch.nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=3, padding=1, bias=False)
    self.convList = torch.nn.ModuleList()
    for _ in range(layers - 2):
      self.convList.append(torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1, bias=False))
      self.convList.append(torch.nn.BatchNorm2d(64))
      self.convList.append(torch.nn.ReLU(inplace=True))
    self.conv2 = torch.nn.Conv2d(in_channels=64, out_channels=channels, kernel_size=3, padding=1, bias=False)

  def forward(self, x):
    x = torch.nn.functional.relu(self.conv1(x), inplace=True)
    # ModuleList only registers submodules and has no forward(), so apply them in order.
    for layer in self.convList:
      x = layer(x)
    return self.conv2(x)


In [None]:
#training function


In [None]:
# Function to compute PSNR (dB) and SSIM of the denoised images

In [None]:
#main pipeline

#create dataset object (MNIST, Fashion-MNIST, CIFAR-10)
#create dataloader
#implement training function
#implement function to check PSNR/SSIM values