In [None]:
import bokeh
from bokeh.io import output_notebook
import bokeh.plotting as plt
from bokeh.models.layouts import LayoutDOM
from bokeh.layouts import row, gridplot, Spacer
from IPython.display import display, HTML

import numpy as np
import sys
import scipy
import scipy.misc
import time
import scipy.ndimage
import random
import torch
import torchvision
from torchvision.transforms.functional import normalize
import matplotlib.cm as cm
# Render bokeh figures inline in this notebook.
output_notebook()
# Make the in-repo sources under ../src importable (presumably where the
# `ebconv` package imported in the next cell lives).
sys.path.insert(0, '../src')

In [None]:
from ebconv.splines import BSplineElement
from ebconv.kernel import CardinalBSplineKernel, create_random_centers
from ebconv.nn.functional import cbsconv, translate
from ebconv.kernel import sampling_domain

# CBSConv gradient visualization

## Create a single basis

Let's start by creating a single kernel basis and a centered Dirac delta (a unit impulse) to use as the input.

In [None]:
# Notebook-wide experiment configuration.
INPUT_SIZE: tuple = (65, 65)   # spatial size of the impulse input image (last two tensor dims)
KERNEL_SIZE: tuple = (10, 10)  # size of the grid the kernel basis is sampled on
K: int = 3                     # spline argument `k` forwarded to the kernel and to cbsconv
SHIFT: tuple = (1, 2)          # translation applied to the convolution to build the target

In [None]:
def generate_2d_input_and_kernel(input_size, kernel_size, k, s=3):
    """Build a centered unit-impulse input and a single-basis kernel, and plot both.

    Args:
        input_size: spatial size of the input; the tensor is created as
            ``(1, 1, *input_size)`` in double precision.
        kernel_size: size of the grid the kernel basis is sampled on (one
            ``sampling_domain`` per axis, combined with an ``'ij'`` meshgrid).
        k: spline argument forwarded to ``CardinalBSplineKernel.create``.
        s: second positional argument of ``CardinalBSplineKernel.create``;
            presumably the kernel scaling (read back later as ``kernel.s``) —
            TODO confirm. Defaults to the previously hard-coded ``3``.

    Returns:
        ``(input_, kernel, kernel_size)``: the impulse tensor, the kernel object
        with a single basis centered at ``(0.0, 0.0)``, and ``kernel_size``
        unchanged.
    """
    input_center = np.array(input_size) // 2
    kernel_center = np.array(kernel_size) // 2

    # Zero image with a single 1 at its center: convolving with it reproduces
    # the kernel itself.
    input_ = torch.zeros(1, 1, *input_size, dtype=torch.double)
    input_[:, :, input_center[0], input_center[1]] = 1

    kernel = CardinalBSplineKernel.create([(0.0, 0.0)], s, k)
    sampling = np.meshgrid(
        *[sampling_domain(k_s) for k_s in kernel_size], indexing='ij')
    basis = kernel(*sampling)[0]

    # (title, image, origin used to center the image, image extent)
    panels = [
        ('Input', input_.detach().numpy().squeeze(), input_center, input_size),
        ('Kernel', basis.squeeze(), kernel_center, kernel_size),
    ]
    figures = []
    for title, image, origin, size in panels:
        # Both panels use the input's axis ranges so the kernel's support can
        # be compared directly against the input grid.
        fig = plt.figure(
            match_aspect=True,
            x_range=(-input_center[0], input_center[0]),
            y_range=(-input_center[1], input_center[1]),
            tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
        )
        fig.title.text = title
        fig.image(
            image=[image],
            x=-origin[0], y=-origin[1],
            dw=size[0], dh=size[1],
            palette='Viridis256',
        )
        figures.append(fig)

    plt.show(gridplot([figures]))
    return input_, kernel, kernel_size
    
input_, kernel, kernel_size = generate_2d_input_and_kernel(
    input_size=INPUT_SIZE, kernel_size=KERNEL_SIZE, k=K)

## Convolution and shifted convolution
Let's now compute the convolution and translate the result. The shifted output serves as a target: it lets us check that, during learning, the gradient with respect to the kernel center actually points in the right direction.

In [None]:
def conv_shift_conv(input_, kernel, kernel_size=None, k=None, shift=None):
    """Convolve ``input_`` with ``kernel``'s basis and build a shifted target.

    The convolution output is translated by ``shift`` and detached from the
    autograd graph, so it can serve as a fixed target when probing the loss
    landscape with respect to the kernel center. The plain and shifted
    convolutions are plotted side by side.

    Args:
        input_: input tensor forwarded to ``cbsconv``.
        kernel: kernel object exposing ``kernel.c`` (reshaped to a ``(1, 1, 2)``
            center tensor) and ``kernel.s`` (reshaped to a ``(1, 1)`` scaling
            tensor).
        kernel_size: kernel sampling size forwarded to ``cbsconv``; defaults to
            the notebook constant ``KERNEL_SIZE``.
        k: spline argument forwarded to ``cbsconv``; defaults to ``K``.
        shift: translation forwarded to ``translate``; defaults to ``SHIFT``.

    Returns:
        ``(shifted_out, center, weights, scaling)``. ``shifted_out`` is the
        detached, translated convolution. ``center`` is the ``(1, 1, 2)`` center
        tensor, derived from a leaf created with ``requires_grad=True``.
        ``weights`` is a ``(1, 1, 1)`` tensor of ones. ``scaling`` is the
        ``(1, 1)`` scaling tensor.
    """
    # Resolve defaults at call time so edits to the notebook constants are
    # picked up without re-defining this function.
    kernel_size = KERNEL_SIZE if kernel_size is None else kernel_size
    k = K if k is None else k
    shift = SHIFT if shift is None else shift

    center = torch.tensor(kernel.c, requires_grad=True).reshape(1, 1, 2)
    scaling = torch.tensor(kernel.s, requires_grad=False).reshape(1, 1)
    weights = torch.ones(1, 1, 1, requires_grad=False, dtype=torch.double)

    out = cbsconv(input_, kernel_size, weights, center, scaling, k)
    # The target must not carry autograd history back to `center`.
    shifted_out = translate(out.detach().clone(), shift)

    output_size = np.array(out.shape[2:])
    output_center = output_size // 2

    panels = [
        ('Convolution', out.detach().numpy().squeeze()),
        ('Shifted convolution', shifted_out.detach().numpy().squeeze()),
    ]
    figures = []
    for title, image in panels:
        fig = plt.figure(
            match_aspect=True,
            x_range=(-output_center[0], output_center[0]),
            y_range=(-output_center[1], output_center[1]),
            tooltips=[("x", "$x"), ("y", "$y"), ("value", "@image")],
        )
        fig.title.text = title
        fig.image(
            image=[image],
            x=-output_center[0], y=-output_center[1],
            dw=output_size[0], dh=output_size[1],
            palette='Viridis256',
        )
        figures.append(fig)

    plt.show(gridplot([figures]))
    return shifted_out, center, weights, scaling
    
shifted_conv, center, weights, scaling = conv_shift_conv(input_=input_, kernel=kernel)

## Plot the gradient field

Now let's compute, for each candidate kernel-center position on a grid, the loss and its gradient, to inspect the gradient landscape.

In [None]:
def check_gradient_landscape(input_, target_conv, weights, scaling, n=100,
                             kernel_size=None, k=None):
    """Sample the MSE loss and its gradient w.r.t. the kernel center on a grid.

    An ``n x n`` grid of candidate centers, spanning 1.2x the kernel size and
    centered at the origin, is evaluated one position at a time. At each
    position, the convolution of ``input_`` with a basis at that center is
    compared to ``target_conv`` with ``torch.nn.MSELoss`` and back-propagated.

    Args:
        input_: input tensor forwarded to ``cbsconv``.
        target_conv: fixed target the convolution is compared against.
        weights: forwarded unchanged to ``cbsconv``.
        scaling: forwarded unchanged to ``cbsconv``.
        n: number of samples per axis (``n * n`` convolutions in total).
        kernel_size: kernel sampling size; defaults to ``KERNEL_SIZE``.
        k: spline argument forwarded to ``cbsconv``; defaults to ``K``.

    Returns:
        ``(gradient, (xx, yy, loss_data), (kx, ky))``:

        - ``gradient`` has one row ``[c0, c1, dL/dc0, dL/dc1]`` per position
          whose loss carried an autograd graph.
        - ``xx`` and ``yy`` are the ``(n, n)`` ``'ij'`` meshgrid coordinates.
        - ``loss_data`` lists the loss at every position, in row-major
          ``(xx, yy)`` order.
        - ``(kx, ky)`` is the sampled extent per axis.
    """
    kernel_size = KERNEL_SIZE if kernel_size is None else kernel_size
    k = K if k is None else k

    # Sample a window 1.2x the kernel size, symmetric about the origin.
    kx = int(kernel_size[0] * 1.2)
    ky = int(kernel_size[1] * 1.2)
    x = torch.linspace(0, kx, n, dtype=torch.double) - kx / 2
    y = torch.linspace(0, ky, n, dtype=torch.double) - ky / 2
    xx, yy = torch.meshgrid(x, y, indexing='ij')
    # One (1, 1, 2) center tensor per grid point, in row-major (x, y) order.
    positions = torch.stack([xx, yy], dim=-1).reshape(-1, 1, 1, 2)

    criterion = torch.nn.MSELoss()
    gradient = []
    loss_data = []
    for position in positions:
        # Fresh leaf tensor per sample: its .grad only ever holds this
        # position's gradient, and the shared `positions` grid is untouched.
        center = position.detach().clone().requires_grad_(True)
        out = cbsconv(input_, kernel_size, weights, center, scaling, k)
        l_out = criterion(out, target_conv)
        loss_data.append(l_out.item())
        # Without an autograd graph (e.g. the output does not depend on the
        # center) there is nothing to record; such positions appear in
        # loss_data but not in gradient.
        if l_out.requires_grad:
            l_out.backward()
            gradient.append(np.concatenate([
                center.detach().numpy().squeeze(),
                center.grad.numpy().squeeze(),
            ]))
    return (np.array(gradient),
            (xx.numpy(), yy.numpy(), loss_data),
            (kx, ky))

gradient, loss_data, samples_size = check_gradient_landscape(
    input_=input_, target_conv=shifted_conv, weights=weights, scaling=scaling)

In [None]:
def plot_gradient(gradient, loss_data, samples_size, arrow_scale=1e3):
    """Plot the sampled loss surface with the gradient field drawn on top.

    Plot coordinates are the negated centers: the plot's y axis shows ``-c0``
    and its x axis ``-c1``. Each gradient sample is drawn as a segment starting
    at its negated center and extending by ``arrow_scale`` times its gradient.

    Args:
        gradient: array with rows ``[c0, c1, dL/dc0, dL/dc1]``, as returned by
            ``check_gradient_landscape``. May be empty, in which case only the
            loss image is drawn.
        loss_data: ``(xx, yy, losses)``. ``xx`` fixes the ``(n, n)`` grid shape.
            ``losses`` holds the loss at every grid point in row-major order.
        samples_size: ``(kx, ky)`` extent sampled for ``c0`` and ``c1``.
        arrow_scale: multiplier applied to the gradient components before they
            are drawn as segments.
    """
    xx, _, losses = loss_data
    extent_c0, extent_c1 = samples_size
    half_c0, half_c1 = extent_c0 / 2, extent_c1 / 2

    # The sample grid is symmetric about 0, so reversing the flat loss array
    # (i.e. flipping both grid axes) puts the loss of center (c0, c1) at plot
    # position (-c1, -c0), matching the segment origins below. Image rows run
    # along c0 (plot y) and columns along c1 (plot x).
    loss_image = np.flip(np.asarray(losses)).reshape(np.shape(xx))

    fig = plt.figure(
        x_range=(-half_c1, half_c1),
        y_range=(-half_c0, half_c0),
    )
    fig.image(
        [loss_image],
        x=-half_c1, y=-half_c0,
        dw=extent_c1, dh=extent_c0,
        palette='Viridis256',
    )

    gradient = np.asarray(gradient)
    if gradient.size:
        y0 = -gradient[:, 0]
        x0 = -gradient[:, 1]
        x1 = x0 + arrow_scale * gradient[:, 3]
        y1 = y0 + arrow_scale * gradient[:, 2]
        fig.segment(x0, y0, x1, y1)
    fig.title.text = 'Gradient'

    plt.show(fig)
    
plot_gradient(gradient=gradient, loss_data=loss_data, samples_size=samples_size)