# Translations and Convolutions

In preparation for the use of autograd to calculate the derivatives in scarlet, this notebook tests the use of FFTs in pytorch to perform fractional translations, and examines the errors and biases that they introduce.

In [None]:
import numpy as np
import scipy.ndimage.filters as spif

try:
    import torch
    use_torch = True
    import scarlet.torch
except ImportError:
    use_torch = False

import scarlet

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

matplotlib.rc('image', cmap='inferno')
matplotlib.rc('image', interpolation='none')

## Testing FFT convolution with images

We first convolve a point source with a Gaussian PSF to see how FFT convolution in numpy and pytorch compares to true (direct) convolution in scipy, in terms of both accuracy and runtime.

First we generate our PSF and deconvolved image.

In [None]:
def gauss2d(x, y, A=1, x0=0, y0=0, sigma_x=2, sigma_y=2):
    """Evaluate an axis-aligned 2D Gaussian.

    The Gaussian has peak value ``A`` at ``(x0, y0)`` and widths
    ``sigma_x`` and ``sigma_y`` along the two axes (it is circular when
    the two widths are equal, as with the defaults). ``x`` and ``y`` may be
    scalars or numpy arrays; the result follows numpy broadcasting.
    """
    # Half the squared, width-scaled distance from the centre along each axis
    x_term = 0.5*(x-x0)**2/sigma_x**2
    y_term = 0.5*(y-y0)**2/sigma_y**2
    return A * np.exp(-(x_term+y_term))

def create_psf(radius):
    """Build a square Gaussian PSF image of side ``2*radius + 1``.

    The pixel grid runs from ``-radius`` to ``radius`` along both axes and
    the Gaussian is evaluated with the default parameters of ``gauss2d``.
    """
    # The same 1D coordinate axis serves for both dimensions
    axis = np.linspace(-radius, radius, 2*radius + 1)
    # Note: numpy's meshgrid returns the grids in the reverse order of pytorch's
    grid_x, grid_y = np.meshgrid(axis, axis)
    return gauss2d(grid_x, grid_y)

def project_image(img, shape):
    """Embed ``img`` at the centre of a zero-filled array of the given shape.

    The pixel ``img[h // 2, w // 2]`` lands on ``result[N // 2, M // 2]``,
    where ``(h, w)`` is the shape of ``img`` and ``(N, M)`` is ``shape``.
    Both odd- and even-sized images are supported.

    Parameters
    ----------
    img : numpy array
        2D image to embed.
    shape : tuple of int
        ``(N, M)`` shape of the output array.

    Returns
    -------
    result : numpy array
        Array of zeros with shape ``shape`` containing ``img`` at its centre.

    Raises
    ------
    ValueError
        If ``img`` is larger than ``shape`` along either axis.
    """
    N, M = shape
    height, width = img.shape[0], img.shape[1]
    if height > N or width > M:
        raise ValueError(
            "image of shape {0} does not fit in an array of shape {1}".format(
                img.shape, (N, M)))

    result = np.zeros(shape)
    # Size the slices from the image itself rather than assuming an odd
    # `2*half + 1` extent, so even-sized images fit as well
    top = N // 2 - height // 2
    left = M // 2 - width // 2
    result[top:top + height, left:left + width] = img
    return result

# Geometry of the test image: a square of side N with its centre at (cy, cx)
psf_radius = 10
cx = cy = 10
N = 2*cx + 1

# Gaussian PSF used as the convolution kernel
psf = create_psf(psf_radius)

# The "deconvolved" scene is a single unit point source at the central pixel
deconvolved = np.zeros((N, N))
deconvolved[cy, cx] = 1

# Show both inputs, each in its own figure
for _image, _title in ((psf, "PSF"), (deconvolved, "deconvolved image")):
    plt.imshow(_image)
    plt.title(_title)
    plt.show()

First we test the accuracy of the three different algorithms. Below is the code for numpy, while the pytorch implementation is coded into scarlet with the necessary modifications to work on `Tensors` as opposed to numpy arrays.

In [None]:
def _show_residual(residual, title):
    """Plot a residual image on a symmetric diverging colour scale.

    Returns the largest absolute residual, which sets the colour limits.
    """
    peak = np.max(np.abs(residual))
    plt.imshow(residual, vmin=-peak, vmax=peak, cmap="seismic")
    plt.title(title)
    plt.colorbar()
    plt.show()
    return peak


# Reference result: direct convolution with scipy
scipy_img = spif.convolve(deconvolved, psf)
# FFT convolution through scarlet's numpy implementation
np_img = scarlet.fft_convolve(deconvolved, psf)

plt.imshow(scipy_img)
plt.title("truth")
plt.colorbar()
plt.show()

# Difference between the direct and the numpy FFT convolution
residual = scipy_img-np_img
max_residual = _show_residual(residual, "numpy FFT residual")

if use_torch:
    # Same comparison for the pytorch implementation, which works on Tensors
    torch_deconvolved = torch.tensor(deconvolved)
    torch_psf = torch.tensor(psf)
    torch_img = scarlet.torch.filters.fft_convolve(torch_deconvolved, torch_psf)

    residual = scipy_img-np.array(torch_img)
    max_residual = _show_residual(residual, "pytorch FFT residual")

Next we test all three algorithms for speed.

In [None]:
# Half-widths of the benchmark images; each image is (2*img_size + 1) pixels on a side
img_sizes = np.array([10, 20, 30, 40, 50])
# Mean and standard deviation of the runtime in seconds, one column per image size.
# Row 0 is filled by this cell (scipy); rows 1 and 2 are filled by the numpy and
# pytorch timing cells.
times = np.zeros((3, len(img_sizes)))
time_stdev = np.zeros((3, len(img_sizes)))

# scipy: time direct convolution for each image size
for n, img_size in enumerate(img_sizes):
    cx = cy = img_size
    N = 2*cx + 1
    
    # The PSF is zero-padded to the full (N, N) image size, so the kernel passed
    # to the convolution is as large as the image itself.
    # NOTE(review): this presumably inflates the scipy timing relative to
    # convolving with the unpadded PSF - confirm this is the intended comparison.
    psf = project_image(create_psf(psf_radius), (N, N))
    # Point source at the centre of the image
    deconvolved = np.zeros((N, N))
    deconvolved[cy, cx] = 1
    
    # -o returns the timing result object, -q suppresses printing,
    # -n 10 runs the statement 10 times per timing loop
    result = %timeit -o -q -n 10 spif.convolve(deconvolved, psf)
    times[0, n] = result.average
    time_stdev[0, n] = result.stdev

In [None]:
# numpy: time scarlet's numpy FFT convolution for each image size
for n, img_size in enumerate(img_sizes):
    cx = cy = img_size
    N = 2*cx + 1
    
    # Same inputs as the scipy benchmark: PSF zero-padded to (N, N) and a
    # point source at the centre of the image
    psf = project_image(create_psf(psf_radius), (N, N))
    deconvolved = np.zeros((N, N))
    deconvolved[cy, cx] = 1
    
    # numpy
    # -o returns the timing result object, -q suppresses printing,
    # -n 100 runs the statement 100 times per timing loop
    result = %timeit -o -q -n 100 scarlet.fft_convolve(deconvolved, psf)
    # Row 1 of the timing tables holds the numpy results
    times[1, n] = result.average
    time_stdev[1, n] = result.stdev

In [None]:
# pytorch: time scarlet's pytorch FFT convolution for each image size
# (skipped entirely when torch could not be imported)
if use_torch:
    for n, img_size in enumerate(img_sizes):
        cx = cy = img_size
        N = 2*cx + 1

        # Same inputs as the scipy and numpy benchmarks
        psf = project_image(create_psf(psf_radius), (N, N))
        deconvolved = np.zeros((N, N))
        deconvolved[cy, cx] = 1
    
        # Convert to Tensors before timing so the conversion is not measured
        torch_psf = torch.tensor(psf)
        torch_deconvolved = torch.tensor(deconvolved)
        # pytorch
        # -o returns the timing result object, -q suppresses printing,
        # -n 100 runs the statement 100 times per timing loop
        result = %timeit -o -q -n 100 scarlet.torch.filters.fft_convolve(torch_deconvolved, torch_psf)
        # Row 2 of the timing tables holds the pytorch results
        times[2, n] = result.average
        time_stdev[2, n] = result.stdev

In [None]:
# Side length in pixels of each benchmarked image
x = img_sizes*2 + 1

# Rows of `times` / `time_stdev`, in order; the pytorch row is only plotted
# when torch is available
_labels = ["scipy", "numpy"]
if use_torch:
    _labels.append("pytorch")
for _row, _label in enumerate(_labels):
    plt.errorbar(x, times[_row], time_stdev[_row], label=_label)
plt.xlabel("image pixels")
plt.ylabel("time (s)")
plt.yscale("log")
plt.legend()
plt.show()

result_str = (
    "For {0} pixels: scipy = {1:.2f} ms, numpy = {2:.3f} ms, pytorch = {3:.3f} ms"
    if use_torch else
    "For {0} pixels: scipy = {1:.2f} ms, numpy = {2:.3f} ms"
)
# Summarise the smallest and the largest image; times are converted from s to ms
for _col in (0, -1):
    _ms = 1000*times[:, _col]
    print(result_str.format(x[_col], _ms[0], _ms[1], _ms[2]))

So we see that FFT convolution is significantly faster than true convolution, even for images this small, and the gap widens rapidly as the image size grows.

The main speed difference between numpy and pytorch is that pytorch has much slower indexing. This means that `sinc` is *much* slower in pytorch, so we use numpy to calculate the Lanczos kernel and then convert the result to a pytorch `Tensor`. There is an open [ticket](https://github.com/pytorch/pytorch/issues/5388) in pytorch to fix this, but until that happens we will probably see slower than expected runtimes with pytorch.

## Resampling the pixel grid

Recall that in scarlet all of the sources are modeled such that they are in the center of a bounding box that is reprojected into the blended scene, usually at some fractional pixel location. The Whittaker-Shannon Sampling Theorem tells us that we can perfectly reconstruct a continuous signal from a set of well-sampled ($\Omega$-bandlimited) discrete measurements with the formula

$$f(t) = \sum_{k=-\infty}^{k=\infty} f\left(\frac{k \pi}{\Omega}\right) \textrm{sinc}\left(\frac{\Omega t}{\pi} -k \right)$$

where $F(\omega)$, the Fourier transform of $f$, is piecewise continuous on $[-\Omega, \Omega]$, $t\in\mathbb{R}$, and the samples are obtained at the points $t_k=k\pi/\Omega$. Unfortunately this formula is not practically useful, as sinc falls off slowly and technically requires an infinite number of samples to perfectly reconstruct the signal.

Practically we must use some windowed function that approximates sinc with just a few dozen samples.

### Cubic Splines

The most common technique in computer graphics for resampling is the cubic spline, due to its speed and accuracy at approximating the sinc. Splines are windows on a sinc function of the form

$$ w(x, a, b) = \frac{1}{6}\cdot
\begin{cases}
(-6a-9b+12)\cdot |x|^3 + (6a+12b-18)\cdot |x|^2 -2b + 6 & 0\leq |x| \leq1 \\
(-6a-b)\cdot |x|^3 + (30a + 6b)\cdot |x|^2 -(48a + 12b)|x|+24a + 8b & 1 \leq |x| \leq 2 \\
0 & \textrm{otherwise}
\end{cases}
$$

where $a$ and $b$ are parameters that determine the sharpness and shape of the splines respectively. Two of the most common splines used in computer graphics are the Catmull-Rom spline ($a=0.5$, $b=0$) and the Mitchell-Netravali spline ($a=b=1/3$).

### Lanczos 

It is generally accepted that a better approximation to the sinc is a Lanczos kernel

$$L(x) =
\begin{cases}
\textrm{sinc}(x)\cdot \textrm{sinc}\left(x/a \right) & -a \leq x \leq a \\
0 & \textrm{otherwise}
\end{cases}
$$

Here $a$ is the half-width of the Lanczos window (not the spline parameter above). The Lanczos kernel is often computationally more expensive than a spline due to its use of trig functions.

### Testing resampling algorithms

Understanding the limitations of these interpolations and any biases that they introduce is important for our evaluation of scarlet.

We first test the accuracy of the three methods as well as linear interpolation of the nearest neighbors.

In [None]:
def test_resampling_algorithm(x0, y0, radius=10, moment=1, **kwargs):
    """Compare a resampled Gaussian against the true shifted Gaussian.

    A Gaussian centred on the pixel grid is shifted by ``(x0, y0)`` with
    ``fft_resample`` and compared to a Gaussian evaluated directly at the
    shifted position.

    Parameters
    ----------
    x0, y0 : float
        Offsets of the Gaussian centre along x and y.
    radius : int
        Half-width of the image; the image side is ``2*radius + 1``.
    moment : int
        Power to which the per-pixel residual is raised.
    **kwargs
        Passed through to ``fft_resample`` (e.g. ``kernel``).

    Returns
    -------
    numpy array
        ``(truth - interpolated)**moment`` for every pixel.
    """
    axis = np.linspace(-radius, radius, 2*radius + 1)
    X, Y = np.meshgrid(axis, axis)
    centered = gauss2d(X, Y, x0=0, y0=0)
    truth = gauss2d(X, Y, x0=x0, y0=y0)
    if use_torch:
        # Tensor input goes through the torch implementation, matching the
        # other torch code paths in this notebook
        interpolated = scarlet.torch.filters.fft_resample(torch.tensor(centered), x0, y0, **kwargs)
        # Convert back to numpy so the residual is a plain array
        interpolated = np.array(interpolated)
    else:
        interpolated = scarlet.filters.fft_resample(centered, x0, y0, **kwargs)
    return (truth-interpolated)**moment

def ordinal(x):
    """Return the English ordinal suffix ("st", "nd", "rd" or "th") for a number.

    Numbers ending in 11, 12 or 13 take "th" (11th, 112th), unlike other
    numbers ending in 1, 2 or 3 (21st, 22nd, 23rd).
    """
    # The teens are the exception to the last-digit rule
    if x % 100 in (11, 12, 13):
        return "th"
    return {1: "st", 2: "nd", 3: "rd"}.get(x % 10, "th")

def test_resampling(moment, algorithm, radius=10, **kwargs):
    """Plot the residuals of one resampling kernel over a grid of sub-pixel shifts.

    For every combination of dx and dy in {0, 0.1, ..., 0.5} the residual
    image returned by ``test_resampling_algorithm`` is computed and shown in
    a grid of subplots sharing one colour scale.

    Parameters
    ----------
    moment : int
        Power to which the residuals are raised (forwarded to
        ``test_resampling_algorithm``).
    algorithm : str
        Key into the module-level ``kernels`` dictionary.
    radius : int
        Half-width of the test image; the image side is ``2*radius + 1``.
    **kwargs
        Passed through to ``test_resampling_algorithm``.
    """
    kernel = kernels[algorithm]
    shifts = np.linspace(0, .5, 6)
    n_shifts = len(shifts)
    side = 2*radius + 1
    moments = np.zeros((n_shifts, n_shifts, side, side))

    # Residual image for every (dy, dx) shift. `radius` must be forwarded so
    # the returned image matches the shape allocated above.
    for i, dy in enumerate(shifts):
        for j, dx in enumerate(shifts):
            moments[i, j] = test_resampling_algorithm(
                dx, dy, radius=radius, moment=moment, kernel=kernel, **kwargs)

    # Use the same colour mapping for all of the subplots: odd powers keep
    # their sign and get a diverging map, even powers are non-negative
    vmax = np.max(np.abs(moments))
    if moment % 2:
        vmin, cmap = -vmax, "seismic"
    else:
        vmin, cmap = 0, "inferno"

    # One subplot per (dy, dx) combination: rows are dy, columns are dx
    fig, axes = plt.subplots(nrows=n_shifts, ncols=n_shifts, figsize=(15, 13),
                             gridspec_kw={"hspace": 0, "wspace": .1})
    for i, dy in enumerate(shifts):
        for j, dx in enumerate(shifts):
            ax = axes[i][j]
            im = ax.imshow(moments[i, j], vmin=vmin, vmax=vmax, cmap=cmap)
            ax.get_xaxis().set_visible(False)
            # Hide the y ticks but keep the axis so the dy label can be shown
            ax.get_yaxis().set_ticks([])
            if i == 0:
                ax.set_title("dx={0:.1f}".format(dx))
            if j == 0:
                ax.set_ylabel("dy={0:.1f}".format(dy))
    cbar = fig.colorbar(im, ax=axes.ravel().tolist(), pad=.02)
    cbar.set_label("{0}{1} moment".format(moment, ordinal(moment)))
    fig.suptitle("Algorithm: {0}".format(algorithm), y=.92)
    plt.show()

# Take the kernels from the filter module that matches the active backend;
# both modules are accessed through the same four attribute names
_filters = scarlet.torch.filters if use_torch else scarlet.filters
kernels = {
    "lanczos": _filters.lanczos,
    "catmull_rom": _filters.catmull_rom,
    "mitchel_netravali": _filters.mitchel_netravali,
    "bilinear": _filters.bilinear_interpolation,
}

# Plot the 1st moment for every kernel, then the 2nd moment for every kernel
for _moment in (1, 2):
    for kernel in kernels:
        test_resampling(_moment, kernel)


We see that the Catmull-Rom and Lanczos kernels have 1st moments that are qualitatively different but similar in magnitude, and both are more accurate than the Mitchell-Netravali spline; however, the Lanczos kernel slightly outperforms both of them in the second moment. Unsurprisingly, linear interpolation is the least accurate of the four in both moments.

When we calculate the runtime of each algorithm we get a surprise:

In [None]:
img = create_psf(10)

for kernel in kernels:
    print(kernel)
    if use_torch:
        %timeit scarlet.torch.filters.fft_resample(torch.tensor(img), .1, .4, kernels[kernel])
    else:
        %timeit scarlet.filters.fft_resample(img, .1, .4, kernels[kernel])

Apparently the implementation of piecewise functions in python is slower than using trig functions, so in this implementation the Lanczos kernel is not only the most accurate, but the fastest as well.