In [None]:
#Useful imports
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
from matplotlib import rc
rc('text', usetex=True)

In [None]:
# Output directory for figures written by plot_matrix (prefixed to `saveas`).
FOLDER = "plots/"

def full_frame(width=None, height=None):
    """Create a new figure whose single, frameless axes fills the whole canvas.

    Side effect: sets the global rcParam ``savefig.pad_inches`` to 0, so every
    figure saved afterwards has no padding.

    Parameters
    ----------
    width, height : float, optional
        Figure size in inches. When ``width`` is None the matplotlib default
        size is used and ``height`` is ignored.

    Returns
    -------
    (fig, ax)
        The new figure and its borderless axes (x/y axes hidden).
    """
    import matplotlib as mpl
    mpl.rcParams['savefig.pad_inches'] = 0

    size = None if width is None else (width, height)
    figure = plt.figure(figsize=size)

    # One axes covering [0, 1] x [0, 1] in figure coordinates, with no frame.
    axes = plt.axes([0, 0, 1, 1], frameon=False)
    for axis in (axes.get_xaxis(), axes.get_yaxis()):
        axis.set_visible(False)

    plt.autoscale(tight=True)
    return figure, axes
    
def plot_matrix(A, saveas='', cmap='RdBu', vmin=None, vmax=None):
    """Draw the 2-D array ``A`` as a borderless image, one inch per entry.

    Parameters
    ----------
    A : 2-D array
        Matrix to draw; the figure is ``A.shape[1]`` x ``A.shape[0]`` inches.
    saveas : str, optional
        File name, relative to ``FOLDER``, to save the figure to. Empty (or
        None) means the figure is only displayed, not saved.
    cmap : str, optional
        Colormap passed to ``matshow``.
    vmin, vmax : float, optional
        Colour limits passed to ``matshow``. Pass the same values to several
        calls to make their colours comparable.
    """
    import os

    fig, ax = full_frame(A.shape[1], A.shape[0])
    ax.matshow(A, cmap=cmap, vmin=vmin, vmax=vmax)

    if saveas:
        # savefig does not create directories; create FOLDER so that a fresh
        # checkout survives Restart & Run All.
        os.makedirs(FOLDER, exist_ok=True)
        fig.savefig(os.path.join(FOLDER, saveas))

## Setup

In [None]:
kernel = np.array([[-1, -2],
                   [-3, -4]])
S = np.array([3, 4])  # image shape (rows, cols)

# Seed the legacy global RNG so the random test image is reproducible.
np.random.seed(1)
image = np.random.uniform(0, 1.0, size=S)
# Pin one pixel to 1 and one to 0, presumably so the gray colormap spans
# exactly [0, 1].
image[0, 0], image[-1, -1] = 1.0, 0.0

plot_matrix(image, '', 'gray')
plot_matrix(kernel, '', vmax=0)

## Original 2D convolution

In [None]:
from scipy.signal import convolve2d

# Reference result. mode='valid' keeps only positions where the kernel fully
# overlaps the image, giving shape (S[0] - kh + 1, S[1] - kw + 1). No padding
# happens in this mode, so a `boundary` argument would have no effect and is
# omitted.
result_orig = convolve2d(image, kernel, mode='valid')

# Shared colour limits, so every result plot below uses the same gray scale.
vmin = np.min(result_orig)
vmax = np.max(result_orig)
plot_matrix(result_orig, 'k.eps', 'gray', vmin=vmin, vmax=vmax)

## Vectorized 2D convolution

In [None]:
from scipy.linalg import circulant

## Setup
n = S[0]*S[1]  # number of pixels = length of the vectorised image

# Put the *flipped* kernel in the top-left corner of an image-sized array and
# flatten it row-major into a (1, n) row vector.
kernel_vect = np.zeros(S)
kernel_vect[:kernel.shape[0], :kernel.shape[1]] = kernel[::-1, ::-1]
kernel_vect = kernel_vect.reshape([1, n])
plot_matrix(kernel_vect, '', vmax=0)

# scipy's circulant(c) puts c in the first *column*. The transpose puts it in
# the first row: kernel_matrix[i, j] = c[(j - i) % n]. Pass a 1-D vector,
# because newer SciPy releases treat N-D input as a batch of first columns.
kernel_matrix = circulant(kernel_vect.ravel()).T
plot_matrix(kernel_matrix, 'k_matrix.eps', vmax=0)

image_vect = np.reshape(image, [-1, 1])
# Pixels lie in [0, 1]. vmin/vmax hold the (negative) range of the convolution
# result, and with those limits every pixel would clip to a single colour.
plot_matrix(image_vect, '', 'gray', vmin=0.0, vmax=1.0)
plot_matrix(image_vect.T, '', 'gray', vmin=0.0, vmax=1.0)

In [None]:
## Convolution
# Multiplying by the circulant matrix performs a *circular* operation on the
# flattened image. Outputs near the right/bottom edges wrap around, so only the
# top-left block of shape result_orig.shape matches the 'valid' 2-D result.
result_vect_vect = kernel_matrix.dot(image_vect)
plot_matrix(result_vect_vect, '', 'gray', vmin=vmin, vmax=vmax)

result_vect = result_vect_vect.reshape(S)
plot_matrix(result_vect, '', 'gray', vmin=vmin, vmax=vmax)

valid_rows, valid_cols = result_orig.shape
assert np.allclose(result_vect[:valid_rows, :valid_cols], result_orig)

## Vectorized Fourier-domain convolution

In [None]:
## First, find the 1-D kernel that, convolved with the vectorised (flattened)
## input image, yields the vectorised output image. Given that kernel we
## can then move to the Fourier domain.
# Unlike kernel_vect, this kernel is NOT flipped: np.convolve does the flip.
test_kernel = np.zeros_like(image)
test_kernel[:kernel.shape[0], :kernel.shape[1]] = kernel
test_kernel = test_kernel.reshape([1, n])
plot_matrix(test_kernel)

# Linear 1-D convolution of the flattened image with the flattened kernel.
result_test_vect = np.convolve(image_vect.flatten(), test_kernel.flatten(), mode='full')

# After row-major flattening, kernel entry (a, b) sits a*S[1] + b samples from
# the origin. So the sample aligned with pixel (0, 0) of the 'valid' result is at
# (kh - 1)*S[1] + (kw - 1). For S = (3, 4) and a 2x2 kernel this is 5, which
# only coincidentally equals n//2 - 1 (and the mode='same' offset).
offset = (kernel.shape[0] - 1) * S[1] + (kernel.shape[1] - 1)
result_reduced = result_test_vect[offset:offset + n]

result_test = result_reduced.reshape(S)

valid_rows, valid_cols = result_orig.shape
assert np.allclose(result_test[:valid_rows, :valid_cols], result_orig)

In [None]:
# np.fft.fft transforms along the LAST axis. image_vect is an (n, 1) column, so
# np.fft.fft(image_vect) would compute n length-1 DFTs (an identity) instead of
# one length-n DFT. Transform flat 1-D signals and reshape afterwards.
kernel_flat = test_kernel.flatten()
image_flat = image_vect.flatten()
fft_kernel_vect = np.fft.fft(kernel_flat).reshape((-1, 1))
fft_image_vect = np.fft.fft(image_flat).reshape((-1, 1))

## Sanity checks
# fft/ifft round trip
assert np.allclose(np.fft.ifft(np.fft.fft(kernel_flat)), kernel_flat)

# Convolution theorem: fft turns *circular* convolution into a pointwise
# product, fft(a (*) b) = fft(a) . fft(b), with no fftshift involved.
# np.convolve(..., 'full') is *linear* convolution of length 2n - 1. Truncating
# it to n samples is not enough: the n - 1 tail samples must be wrapped back
# onto the head.
for sig_a, sig_b in [(kernel_flat, kernel_flat), (kernel_flat, image_flat)]:
    linear = np.convolve(sig_a, sig_b, 'full')
    circular = linear[:n].copy()
    circular[:n - 1] += linear[n:]
    assert np.allclose(np.fft.fft(circular), np.fft.fft(sig_a) * np.fft.fft(sig_b))

## Fourier-domain convolution
fft_result_fft_vect = np.multiply(fft_kernel_vect, fft_image_vect)
# Invert along axis 0, the length-n axis of these (n, 1) columns.
result_fft_vect = np.fft.ifft(fft_result_fft_vect, axis=0)

# The circular convolution is shifted relative to the circulant-matrix product
# by the offset of the (0, 0) 'valid' output in the flattened signal.
offset = (kernel.shape[0] - 1) * S[1] + (kernel.shape[1] - 1)
result_fft = np.reshape(np.roll(np.real(result_fft_vect), -offset, axis=0), S)

plot_matrix(result_fft, '', 'gray', vmin=vmin, vmax=vmax)

# Identical to the circulant-matrix result everywhere, and to the 2-D 'valid'
# convolution on its valid region.
assert np.allclose(result_fft, result_vect)
valid_rows, valid_cols = result_orig.shape
assert np.allclose(result_fft[:valid_rows, :valid_cols], result_orig)

In [None]:
# Create circulant matrix
kernel_fft = np.fft.fftshift(np.fft.fft2(kernel_matrix))
kernel_diagonal = np.real(kernel_fft)
plot_matrix(np.real(kernel_fft), '')
plot_matrix(np.imag(kernel_fft), '')

## Practical implementation with `psf2otf`

In [None]:
## Setup

from psf2otf import psf2otf
# Flip BOTH axes, so the PSF passed to psf2otf is exactly the flipped kernel
# being plotted (kernel[::-1] alone would flip only the rows).
flipped_kernel = kernel[::-1, ::-1]
plot_matrix(flipped_kernel)
print(kernel)
test_psf2otf = psf2otf(flipped_kernel, S)
print(test_psf2otf)
# Magnitude of the OTF, as the variable name says (not just its real part).
abs_psf2otf = np.abs(test_psf2otf)
print(abs_psf2otf)
plot_matrix(abs_psf2otf, '')

In [None]:
# Presumably psf2otf zero-pads the kernel to S and circularly shifts its centre
# to [0, 0] before the FFT (the steps re-implemented in the next cell). The
# pointwise product below is then a circular convolution whose top-left block
# matches the 'valid' result.
otf_kernel = psf2otf(kernel, S)

fft_image = np.fft.fft2(image)

fft_result_psf2otf = np.multiply(otf_kernel, fft_image)

result_psf2otf = np.fft.ifft2(fft_result_psf2otf)

plot_matrix(np.real(result_psf2otf), '', 'gray', vmin=vmin, vmax=vmax)

# Real PSF and real image: the imaginary part is round-off only.
assert np.allclose(np.imag(result_psf2otf), 0)
valid_rows, valid_cols = result_orig.shape
assert np.allclose(np.real(result_psf2otf)[:valid_rows, :valid_cols], result_orig)

In [None]:
from psf2otf import zero_pad

## Re-implement psf2otf step by step to inspect the intermediate PSF and OTF.
# Derive inputs from `kernel` and `S` instead of duplicating their literals, so
# this cell stays in sync with the comparison against result_orig later on.
print(kernel)
psf = kernel.copy()
print('psf input', psf)
shape = S.copy()

inshape = psf.shape
# Zero-pad the PSF to the output shape, with the kernel in the top-left corner
psf = zero_pad(psf, shape, position='corner')
plot_matrix(psf)
print(psf)

# Circularly shift the PSF so that its 'center' lands on the [0, 0] element
for axis, axis_size in enumerate(inshape):
    psf = np.roll(psf, -(axis_size // 2), axis=axis)

plot_matrix(psf)

# Compute the OTF
otf = np.fft.fft2(psf)
plot_matrix(np.real(otf))
plot_matrix(np.imag(otf))
plot_matrix(np.abs(otf))

# Roughly estimate the number of operations in the FFT, and drop the OTF's
# imaginary part if it is within that many machine epsilons
# (np.finfo(float).eps) of zero.
n_ops = np.sum(psf.size * np.log2(psf.shape))
otf = np.real_if_close(otf, tol=n_ops)

In [None]:
## Assumption: otf is simply fft2(psf). otf is multiplied pointwise with
## fft2(image), and result = ifft2(otf .* fft2(image)).
## By the convolution theorem, that product is the *circular* convolution of
## psf with image. A linear convolve2d with symmetric boundaries does not
## reproduce it, so compute the circular convolution directly:
##     output[i, j] = sum_{a, b} psf[a, b] * image[(i - a) % H, (j - b) % W]
## (psf was padded to the image shape in the previous cell.)
output = np.zeros_like(image)
for (row_shift, col_shift), weight in np.ndenumerate(psf):
    output += weight * np.roll(image, (row_shift, col_shift), axis=(0, 1))

plot_matrix(output, '', 'gray', vmin=vmin, vmax=vmax)

# Equal to the FFT-based psf2otf result everywhere ...
assert np.allclose(output, np.real(result_psf2otf))
# ... and to the 'valid' 2-D convolution on its top-left valid region.
valid_rows, valid_cols = result_orig.shape
assert np.allclose(output[:valid_rows, :valid_cols], result_orig)