# Optimization in the Fourier domain

In section *A* we illustrate 2D convolution in three ways: as a standard
(spatial) convolution, as a matrix multiplication, and as a pointwise product
in the Fourier domain.

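In section *B* we use the Fourier domain to solve a convex optimization problem
in closed form. The example is guided, non-blind deblurring: a blurry
near-infrared (NIR) image is sharpened with a known blur kernel, using a color
(RGB) image as a guide.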

In [None]:
#Useful imports
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
from matplotlib import rc

# LaTeX text rendering needs an external `latex` executable. Enable it only if
# one is on the PATH, so that every plotting cell still renders on a machine
# without a TeX installation (matplotlib's built-in text engine is used then).
import shutil
rc('text', usetex=shutil.which('latex') is not None)
font = {'family': 'DejaVu Sans',
        'size': 20.0}
rc('font', **font)  # pass in the font dict as kwargs

## A - 2D convolution

In [None]:
# Side length of the toy 2D input.
s = 4

# Toy input: the 4x4 magic square, identical to MATLAB's `magic(4)`.
N = np.array([
    [16,  2,  3, 13],
    [ 5, 11, 10,  8],
    [ 9,  7,  6, 12],
    [ 4, 14, 15,  1],
])

# 1D finite-difference kernel used by every convolution variant below.
k = np.array([-1, 1])

# 2D shape [s, s] that kernels get padded to.
S = np.full(2, s)

### 1. Standard convolution

In [None]:
from scipy.signal import convolve2d
# convolve2d flips its kernel, so hand it k already reversed, as a 1 x len(k)
# row. The 'valid' result is then N[:, j + 1] - N[:, j] for each column pair.
k_2D = np.flip(k).reshape(1, len(k))
Rconv = convolve2d(N, k_2D, mode='valid')
print('Rconv = \n', Rconv)

### 2. Matrix multiplication

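One can write the 2D convolution as the product of a circulant matrix, built from the convolution kernel, with the vectorized input matrix (e.g. an image). The product is a *circular* convolution of the flattened input. Unlike the `'valid'` result above, it therefore has one extra column, whose values wrap around into the next row of the input.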

In [None]:
from scipy.linalg import circulant
# First row of the circulant operator: k followed by zeros, padded to the
# number of pixels. (np.prod replaces np.product, which NumPy 2.0 removed.)
k_prime = np.zeros(np.prod(S))
k_prime[:len(k)] = k

# Vectorize the input row by row (C order) into a column vector.
n = np.reshape(N, [-1, 1])

# scipy's circulant() takes the FIRST COLUMN; transposing makes k_prime the
# first row, so row i of K holds k at columns i, i+1, ... (modulo the length).
K = circulant(k_prime).T

# (K n)[i] = sum_j k[j] * n[i + j]. The last entry of each output row wraps
# into the next row of N (and the very last one into the first row).
Rmult_vert = K.dot(n)
Rmult = Rmult_vert.reshape(S)
print('Rmult = \n', Rmult)

### 3. Fourier Approach 

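Since the convolution matrix is circulant, it is diagonalized by the discrete Fourier transform. Convolution therefore becomes a pointwise product of the kernel's transfer function (computed with `psf2otf`) and the Fourier transform of the input. We can use this to compute convolutions, and to solve the associated linear systems, much more efficiently. Strictly, 2D circular convolution corresponds to a block-circulant matrix with circulant blocks, which is what the 2D DFT diagonalizes. Its wrap-around happens within each row, so its last column can differ from the matrix version above.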

In [None]:
from psf2otf import psf2otf

# Transfer function of the reversed kernel (same orientation as the spatial
# version above), zero-padded to S. reshape(1, -1) works for any kernel length
# instead of hard-coding 2. Assumes psf2otf follows MATLAB's convention
# (zero-pad, then circularly shift the kernel centre to the origin) — TODO confirm.
Kf = psf2otf(k[::-1].reshape(1, -1), S)
Nf = np.fft.fft2(N)

# Convolution theorem: circular convolution is a pointwise product of spectra.
Rf = np.multiply(Kf, Nf)
# Both inputs are real, so the imaginary part should be round-off only.
Rfourier = np.real(np.fft.ifft2(Rf))

print('Rfourier = \n', Rfourier)

## B - Convex optimization

In [None]:
# Read the NIR image to be deblurred, and the RGB guide
N_b = plt.imread('../input/nir_blurry.tiff')
RGB = plt.imread('../input/rgb.tiff')

# For simplicity we only use the color luminance, taken as the unweighted mean
# of the color channels. Only the first three channels are averaged, so an
# alpha channel (RGBA file) cannot leak into the guide.
Y = np.mean(RGB[..., :3], axis=2)

# Rescale both images so their maximum is 1. The builtin `float` replaces the
# `np.float` alias, which NumPy 1.24 removed.
Y = Y.astype(float) / np.max(Y)
N_b = N_b.astype(float) / np.max(N_b)

fig = plt.figure()
fig.set_size_inches(20, 10)
plt.subplot(1, 2, 1)
# The NIR image is presumably single-channel (it is used as a 2D array later),
# so show it in gray like the result figures.
plt.imshow(N_b, 'gray')
plt.title('Out-of-focus NIR image')
plt.subplot(1, 2, 2)
plt.imshow(RGB)
plt.title('Color guide')
plt.show()

In [None]:
from scipy.signal import convolve2d
# `gaussian` lives in scipy.signal.windows; the scipy.signal.gaussian alias
# was removed in SciPy 1.13.
from scipy.signal.windows import gaussian
from math import ceil
from time import perf_counter


# We run a blur estimation algorithm, with the strong assumption of a
# constant blur across the image, and that the blur is Gaussian
# We obtain an estimate of sigma = 0.394
#### blur kernel ####

sigma = 0.394
# NOTE(review): ceil(2*3*sigma + 1) == 4 here, i.e. an EVEN-sized kernel whose
# peak falls between two pixels. If psf2otf centres kernels at floor(size/2)
# (as MATLAB's does), this may introduce a half-pixel shift. Confirm against
# the MATLAB reference before changing it.
bsize = ceil(2*3*sigma+1)
b_one = gaussian(bsize, sigma, sym=True)
b = np.outer(b_one, b_one)

# make sure all elements sum up to 1 (like in matlabs fspecial('Gaussian') output)
b = b/np.sum(b)

lamda = 1.0
# float64 machine epsilon (MATLAB's `eps`); safeguards the division below.
eps = np.finfo(np.float64).eps

#### gradient kernels ####
# Horizontal (f1) and vertical (f2) finite-difference kernels.
f1 = np.array([-1, 1]).reshape([1, -1])
f2 = f1.T

#### color guides ####
# Gradients of the luminance guide; the deblurred NIR gradients are pulled
# towards them. NOTE(review): assumes convolve2d's 'same' alignment for this
# even-length kernel matches psf2otf's centring of f1/f2 below — confirm.
y1 = convolve2d(Y, f1, 'same')
y2 = convolve2d(Y, f2, 'same')

# Fourier domain optimization solution: every operator is diagonal in the
# Fourier domain, so the normal equations reduce to one pointwise division.

start_time = perf_counter()

f1F = psf2otf(f1, Y.shape)
f2F = psf2otf(f2, Y.shape)

# psf2otf is also applied to the full-size images. This circularly shifts
# them (presumably by -floor(size/2), like MATLAB's psf2otf), and the fftshift
# after the inverse transform undoes that shift — TODO confirm against psf2otf.
y1F = psf2otf(y1, Y.shape)
y2F = psf2otf(y2, Y.shape)

bF = psf2otf(b, Y.shape)
N_bF = psf2otf(N_b, Y.shape)

# EQ (15)
# NOTE(review): lamda scales every denominator term but no numerator term, so
# for lamda != 1 it only rescales the result by 1/lamda rather than weighting
# a prior. lamda == 1.0 here, so the result is unaffected. Check the paper's
# Eq. (15) before tuning it.
I_x = np.conj(f1F) * y1F + np.conj(f2F) * y2F + np.conj(bF) * N_bF
C = lamda * (np.abs(f1F)**2 + np.abs(f2F)**2 + np.abs(bF)**2) + eps

NF = I_x / C
# All inputs are real, so the solution is real up to round-off. Take the real
# part rather than the magnitude: np.abs would flip negative values (e.g.
# ringing next to edges) into spurious positive ones.
N = np.real(np.fft.fftshift(np.fft.ifft2(NF)))

timeTotal = perf_counter() - start_time

print('Done with optimization after {:2.2f}s'.format(timeTotal))

In [None]:
# Comparing results

import scipy

def sh_computation(image):
    """Return a no-reference sharpness score for a gray-scale image.

    The image is re-blurred along each axis separately (Gaussian, sigma = 1
    pixel). For each axis, the score is the fraction of absolute gradient
    magnitude that the re-blurring removes. A sharp image loses a lot; an
    already blurry one loses little. The blurrier axis decides. This appears
    to follow the re-blur metric of Crete-Roffet et al. (2007).

    Parameters
    ----------
    image : 2D array-like
        Gray-scale image. Integer images are converted to float first, so
        that the Gaussian re-blur is not rounded to the input dtype.

    Returns
    -------
    float
        Sharpness in [0, 1]; higher means sharper. The score is invariant to
        scaling the image intensities. If the image has no variation along
        an axis, that axis' ratio is 0/0 (NaN, with a RuntimeWarning).
    """
    # Explicit submodule import: a bare `import scipy` does not guarantee that
    # `scipy.ndimage` is loaded on older SciPy versions.
    from scipy.ndimage import gaussian_filter

    # gaussian_filter returns the input dtype by default; work in float so
    # integer images are not rounded after blurring.
    image = np.asarray(image, dtype=float)

    # Re-blur the image along each axis separately.
    Bver = gaussian_filter(image, [1, 0])
    Bhor = gaussian_filter(image, [0, 1])

    # Absolute gradients of the original and re-blurred images, per axis.
    D_Fver = np.abs(np.gradient(image, axis=0))
    D_Fhor = np.abs(np.gradient(image, axis=1))
    D_Bver = np.abs(np.gradient(Bver, axis=0))
    D_Bhor = np.abs(np.gradient(Bhor, axis=1))

    # Gradient magnitude removed by re-blurring, clamped at zero.
    Vver = np.maximum(D_Fver - D_Bver, 0)
    Vhor = np.maximum(D_Fhor - D_Bhor, 0)

    s_Fver = np.sum(D_Fver)
    s_Fhor = np.sum(D_Fhor)
    s_Vver = np.sum(Vver)
    s_Vhor = np.sum(Vhor)

    # Fraction of gradient magnitude that survives re-blurring, per axis.
    b_Fver = (s_Fver - s_Vver) / s_Fver
    b_Fhor = (s_Fhor - s_Vhor) / s_Fhor

    blur = max(b_Fver, b_Fhor)

    sharpness = 1 - blur
    return sharpness

# Higher is sharper; the deblurred image is expected to score above the input.
print('Sharpness (blurry NIR, deblurred NIR):', sh_computation(N_b), sh_computation(N))

# Rows / columns of the zoomed-in detail view.
CROP = (slice(500, 900), slice(650, 1200))


def _show_side_by_side(left, right, left_title, right_title):
    """Show two gray-scale images side by side in a 20x10 inch figure."""
    fig = plt.figure()
    fig.set_size_inches(20, 10)
    plt.subplot(1, 2, 1)
    plt.imshow(left, 'gray')
    plt.title(left_title)
    plt.subplot(1, 2, 2)
    plt.imshow(right, 'gray')
    plt.title(right_title)
    return fig


_show_side_by_side(N_b, N, 'Out-of-focus NIR image', 'Deblurred NIR image')
_show_side_by_side(N_b[CROP], N[CROP], 'Out-of-focus CROP', 'Deblurred CROP');