In [None]:
#Useful imports
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
from matplotlib import rc
rc('text', usetex=True)

In [None]:
def matshow(A, title=''):
    """Show matrix ``A`` as an image in a fresh figure, axes hidden, under ``title``."""
    image = plt.matshow(A)
    # plt.matshow creates a new figure/axes; work on that axes object directly.
    ax = image.axes
    ax.set_axis_off()
    ax.set_title(title)

## Conventional Approach with circulant matrix

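A convolution can be written as a matrix-vector product: a circulant matrix built from the (mirrored, zero-padded) convolution kernel multiplies the vectorized input matrix (e.g. an image). Here the image is flattened row by row and treated as one long 1-D signal, so the product is a 1-D circular convolution. It agrees with the ordinary 2-D convolution in every column except the last `len(k) - 1` ones, where the kernel wraps around into the next row (or, for the last row, back to the start). An exact 2-D circular convolution would instead need a block-circulant matrix with circulant blocks.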

In [None]:
k = np.array([-1, 1])  # convolution kernel
S = np.array([3, 3])   # image size (rows, cols)

# Image with random entries in [0, 1); seeded so Restart & Run All reproduces it.
rng = np.random.default_rng(0)
I_i = rng.random(size=tuple(S))

# k' = k mirrored and zero-padded to the length of the flattened image.
# Mirroring is needed because each row of circ(k') is dotted with the
# vectorized image, i.e. the matrix product computes a correlation, while
# we want a convolution.
# (np.prod replaces np.product, which was removed in NumPy 2.0.)
k_prime = np.zeros(np.prod(S))
k_prime[:len(k)] = k[::-1]
print(I_i)
print(k_prime)

matshow(I_i, '$I_i$')

In [None]:
from scipy.linalg import circulant

# The image as a column vector, read row by row (C order).
I_i_vect = I_i.reshape(-1, 1)

# scipy's circulant(c) uses c as its first COLUMN; transposing puts k' in the
# first row, and every following row is k' shifted one more step to the right.
circ_k_prime = circulant(k_prime).T
matshow(circ_k_prime, "$circ(k')$")

# Convolution as one matrix-vector product, reshaped back to image size.
# Note: in the last len(k) - 1 columns the kernel wraps into the next row.
I_nv_vert = circ_k_prime @ I_i_vect
I_nv = I_nv_vert.reshape(S)
matshow(I_nv, '$I_{nv}$')

In [None]:
from scipy.signal import convolve2d

# Reference result: ordinary 2-D convolution in 'valid' mode (no wrap-around).
I_conv = convolve2d(I_i, k.reshape(1, -1), mode='valid')
matshow(I_conv, '$I_{conv}$')

# Only the first S[1] - len(k) + 1 columns of I_nv are free of wrap-around
# (the circulant product wraps the kernel into the next row / back to the start).
n_valid = S[1] - len(k) + 1
assert np.allclose(I_conv, I_nv[:, :n_valid])

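## Fourier Approach

Every circulant matrix is diagonalized by the discrete Fourier transform (DFT), and its eigenvalues are the DFT of the generating kernel. A circular convolution therefore becomes an element-wise product in the Fourier domain, and undoing it (deconvolution) becomes an element-wise division. With the FFT both cost $O(n \log n)$, compared with $O(n^2)$ for the dense matrix-vector product or $O(n^3)$ for a dense linear solve. For images, the 2-D circular convolution is diagonalized by the 2-D DFT, and `psf2otf` below computes its eigenvalues (the optical transfer function, OTF).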

In [None]:
# the following two functions are taken from 
# https://github.com/aboucaud/pypher
def zero_pad(image, shape, position='corner'):
    """
    Extends image to a certain size with zeros
    Parameters
    ----------
    image: real `numpy.ndarray`
        Input image (any number of dimensions)
    shape: tuple of int
        Desired output shape of the image; one entry per dimension of
        `image`, each at least as large as the image along that axis
    position : str, optional
        The position of the input image in the output one:
            * 'corner'
                top-left corner (default)
            * 'center'
                centered
    Returns
    -------
    padded_img: real `numpy.ndarray`
        The zero-padded image (the input object itself if its shape already
        equals `shape`)
    Raises
    ------
    ValueError
        If `shape` has a different number of dimensions than `image`,
        contains non-positive entries, is smaller than `image` along some
        axis, or (for ``position='center'``) differs from it by an odd amount.
    """
    shape = np.asarray(shape, dtype=int)
    imshape = np.asarray(image.shape, dtype=int)

    if shape.shape != imshape.shape:
        raise ValueError("ZERO_PAD: target shape has a different number of "
                         "dimensions than the source one")

    # np.all replaces np.alltrue, which was removed in NumPy 2.0.
    if np.all(imshape == shape):
        return image

    if np.any(shape <= 0):
        raise ValueError("ZERO_PAD: null or negative shape given")

    dshape = shape - imshape
    if np.any(dshape < 0):
        raise ValueError("ZERO_PAD: target size smaller than source one")

    if position == 'center':
        if np.any(dshape % 2 != 0):
            raise ValueError("ZERO_PAD: source and target shapes "
                             "have different parity.")
        offsets = dshape // 2
    else:
        offsets = np.zeros_like(dshape)

    pad_img = np.zeros(shape, dtype=image.dtype)
    # Slice assignment (instead of fancy indexing with np.indices and a
    # hard-coded (x, y) offset pair) works for any number of dimensions.
    region = tuple(slice(off, off + size) for off, size in zip(offsets, imshape))
    pad_img[region] = image
    return pad_img

def psf2otf(psf, shape):
    """
    Convert point-spread function to optical transfer function.
    Compute the Fast Fourier Transform (FFT) of the point-spread
    function (PSF) array and creates the optical transfer function (OTF)
    array that is not influenced by the PSF off-centering.
    To ensure that the OTF is not altered due to PSF off-centering, PSF2OTF
    post-pads the PSF array (down or to the right) with zeros to match
    dimensions specified in `shape`, then circularly shifts the values of
    the PSF array up (or to the left) until the central pixel reaches the
    [0, 0] position.
    Parameters
    ----------
    psf : `numpy.ndarray`
        PSF array
    shape : tuple of int
        Output shape of the OTF array (one entry per axis of `psf`, each at
        least as large as the PSF along that axis)
    Returns
    -------
    otf : `numpy.ndarray`
        OTF array of shape `shape`. Real-valued if the imaginary part of the
        FFT is within the estimated round-off error, complex otherwise.
    Notes
    -----
    Adapted from MATLAB psf2otf function
    """
    if np.all(psf == 0):
        # The OTF of an all-zero PSF is all zeros -- with the requested
        # output shape, not the shape of the PSF.
        return np.zeros(tuple(np.asarray(shape, dtype=int)))

    inshape = psf.shape
    # Pad the PSF to the output shape
    psf = zero_pad(psf, shape, position='corner')

    # Circularly shift so that the 'center' of the PSF (index size // 2
    # along each axis) becomes the [0, ..., 0] element of the array.
    for axis, axis_size in enumerate(inshape):
        psf = np.roll(psf, -(axis_size // 2), axis=axis)

    # Compute the OTF (fftn over all axes; identical to fft2 for 2-D input)
    otf = np.fft.fftn(psf)

    # Estimate the rough number of operations involved in the FFT
    # and discard the imaginary part if it is within round-off error.
    # With tol > 1, np.real_if_close interprets tol in units of machine
    # epsilon, i.e. the threshold is n_ops * eps.
    n_ops = np.sum(psf.size * np.log2(psf.shape))
    otf = np.real_if_close(otf, tol=n_ops)
    return otf

In [None]:
# Convolution theorem: for a 2-D *circular* convolution, fft2(k (*) I) equals
# otf * fft2(I) element-wise, with otf = psf2otf(k, S). No mirroring of k is
# needed here because the FFT product is a true convolution; the circulant
# approach above needed k[::-1] only because each row of circ(k') is dotted
# with the image vector, which computes a correlation.
otf = psf2otf(k.reshape(1, -1), S)
fft_i = np.fft.fft2(I_i)
fft_nv = np.fft.fft2(I_nv)

# fft_nv is NOT equal to otf * fft_i: I_nv wraps from the end of each row into
# the next row (it is a 1-D circular convolution of the flattened image),
# whereas the FFT product wraps within the same row. The two agree on the
# columns that involve no wrap-around, so compare there, in the image domain.
I_circ = np.real(np.fft.ifft2(otf * fft_i))
n_valid = S[1] - len(k) + 1
assert np.allclose(I_circ[:, :n_valid], I_nv[:, :n_valid])


In [None]:
print(otf)

# Naive inverse filtering. The spectrum of the 2-D circular convolution of I_i
# with k is otf * fft_i; dividing by otf undoes it wherever otf != 0.
# fft_nv (the spectrum of I_nv) is not used: I_nv wraps from the end of one
# row into the next, so it is not a 2-D circular convolution and otf does not
# diagonalize it.
fft_obs = otf * fft_i
# Use the magnitude, not `otf > 0` (complex numbers compare lexicographically).
nonzero = ~np.isclose(otf, 0)
# Frequencies where otf == 0 carry no information: set them to 0 explicitly.
# np.divide with `where=` but without `out=` would leave them uninitialized.
fft_rec = np.divide(fft_obs, otf, out=np.zeros_like(fft_obs), where=nonzero)
I_i_recovered = np.real(np.fft.ifft2(fft_rec))

# Exactly the frequency content of I_i where otf is non-zero comes back.
# For k = [-1, 1] the whole v = 0 column of otf is zero, i.e. the mean of
# every row is lost, so I_i_recovered equals I_i minus its row means.
I_i_observable = np.real(np.fft.ifft2(np.where(nonzero, fft_i, 0)))
assert np.allclose(I_i_recovered, I_i_observable)

matshow(I_i_recovered, r'$I_i$ recovered')
matshow(I_i, '$I_i$')