# In [2]:
from torch.autograd import Variable
import torch
import torch.nn as nn
import numpy as np
from numpy.fft import fft2, fftshift, ifft2, ifftshift
import os
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F
import scipy.io as sio

# In [None]:
class MyConv2d(nn.Module):
    """
    Performs a 1-D convolution of signals with a constant (frozen) filter.

    Despite the class name, the computation uses ``F.conv1d`` together with
    1-D padding layers.

    Attributes
    ----------
        kernel (torch.nn.Parameter): conv1d filter of size out_c*in_c*k,
                                     excluded from training (requires_grad=False)
        mode                  (str): 'single' or 'batch'
        stride                (int): stride of the convolution
        padding         (nn.Module): nn.ReplicationPad1d, nn.ReflectionPad1d
                                     or nn.ConstantPad1d (zero padding)
    """
    def __init__(self, kernel, mode, pad_type='replicate', padding=0, stride=1):
        """
        Parameters
        ----------
            kernel (torch.FloatTensor): conv1d filter, size out_c*in_c*k
            mode                 (str): 'single' if the input is one signal,
                                        'batch' if it is a batch of signals
            pad_type             (str): 'replicate' (default), 'reflect',
                                        or 'constant' / 'zeros' (zero padding)
            padding              (int): padding size on each side; 0 (default)
                                        selects (k-1)//2 automatically
            stride               (int): stride of the convolution (default is 1)

        Raises
        ------
            ValueError: if pad_type is not one of the supported values.
        """
        super(MyConv2d, self).__init__()
        # Stored as a Parameter so it follows the module across .to()/.cuda(),
        # but frozen so an optimizer never updates it.
        self.kernel = nn.Parameter(kernel, requires_grad=False)
        self.mode = mode  # 'single' or 'batch'
        self.stride = stride
        if padding == 0:
            # Half the kernel length on each side: for an odd length and
            # stride 1 this keeps the output as long as the input.
            size_padding = (kernel.shape[-1] - 1) // 2
        else:
            size_padding = padding
        if pad_type == 'replicate':
            self.padding = nn.ReplicationPad1d(size_padding)
        elif pad_type == 'reflect':
            self.padding = nn.ReflectionPad1d(size_padding)
        elif pad_type in ('constant', 'zeros'):
            self.padding = nn.ConstantPad1d(size_padding, 0)
        else:
            # Previously an unknown pad_type left self.padding undefined and
            # forward() failed later with an AttributeError.
            raise ValueError(
                "pad_type must be 'replicate', 'reflect', 'constant' or "
                "'zeros', got %r" % (pad_type,))

    def forward(self, x):
        """
        Pads x and convolves it with the frozen kernel.

        Parameters
        ----------
            x (torch.FloatTensor): mode='single': one signal (typically c*l);
                                   a batch axis is added and removed internally.
                                   mode='batch': a batch of signals, n*c*l.
        Returns
        -------
            (torch.FloatTensor): mode='single': out_c*l' ;
                                 mode='batch': n*out_c*l'
                                 (l' == l for an odd kernel, stride 1 and
                                 automatic padding)
        Raises
        ------
            ValueError: if self.mode is neither 'single' nor 'batch'.
        """
        # NOTE(review): both branches go through `.data`, which detaches the
        # result from autograd (no gradient flows back to x). This is kept
        # from the original; confirm it is intended.
        if self.mode == 'single':
            out = F.conv1d(self.padding(x.unsqueeze(0)), self.kernel, stride=self.stride)
            return out.data[0]
        if self.mode == 'batch':
            return F.conv1d(self.padding(x.data), self.kernel, stride=self.stride)
        # Previously an unknown mode silently returned None.
        raise ValueError("mode must be 'single' or 'batch', got %r" % (self.mode,))