In [None]:
from ipywidgets import GridspecLayout, Output
from IPython.display import display

import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '../src')
import ebconv
import ebconv.nn.functional
import scipy
import scipy.misc
import time
import scipy.ndimage
import torch
import torchvision
from torchvision.transforms.functional import normalize
import matplotlib.cm as cm


In [None]:
from ebconv.splines import BSpline
from ebconv.kernel import BSplineKernel, create_random_centers
from ebconv.utils import tensordot

# B-spline examples

In [None]:
def b_splines_example():
    """Plot two panels of cardinal B-splines.

    Left panel: the splines returned by ``BSpline.create_cardinal(c)`` for
    c = -4, ..., 4, sampled on [-8, 8], plus their pointwise sum.
    Right panel: three splines built with different ``c``, ``s`` and ``n``
    arguments to ``BSpline.create_cardinal``.
    """
    _, (ax_basis, ax_params) = plt.subplots(1, 2, figsize=(21, 5))

    # Left panel: one spline per integer c, accumulated into a running sum.
    grid = np.linspace(-8, 8, 1000)
    running_sum = np.zeros_like(grid)
    for center in np.arange(-4, 4 + 1):
        values = BSpline.create_cardinal(center)(grid)
        running_sum += values
        ax_basis.plot(grid, values)
    ax_basis.plot(grid, running_sum, label='BSplines sum', linewidth=1.5)
    ax_basis.grid()
    ax_basis.legend()
    ax_basis.set_title("Standard basis on interval")

    # Right panel: (s, c, n) triples passed positionally as (c, s, n).
    grid = np.linspace(-5, 5, 1000)
    parameter_sets = zip((1, 2, 0.5), (-0.3, -1.7, 2.5), (2, 3, 4))
    for scale, center, order in parameter_sets:
        spline = BSpline.create_cardinal(center, scale, order)
        ax_params.plot(grid, spline(grid),
                       label=f'c={center}, s={scale}, n={order}')
    ax_params.grid()
    ax_params.legend()
    ax_params.set_title("Different parameters")
    plt.show()


b_splines_example()

# 2D B-spline kernel basis

In [None]:
def bsplinekernel_example(N=16, s=48, k=4, seed=None, resolution=200):
    """Example of how to use the BSplineKernel basis.

    Draws ``N`` random centres inside a square interval, creates a
    ``BSplineKernel`` on them, then applies three random weight vectors
    (one per RGB channel). Each weighted kernel is sampled on a
    ``resolution`` x ``resolution`` grid and plotted with its centres
    overlaid.

    Parameters
    ----------
    N : int
        Number of basis-function centres.
    s : float
        Forwarded as ``s`` to ``BSplineKernel.create`` (presumably the
        spline scale -- confirm against ebconv).
    k : int
        Forwarded as ``k`` to ``BSplineKernel.create`` (presumably the
        spline order -- confirm against ebconv).
    seed : int or None
        If given, seeds NumPy's global RNG before any random draw, so the
        channel weights are reproducible (and the centres too, assuming
        ``create_random_centers`` draws from the global RNG -- confirm).
        ``None`` keeps the previous unseeded behaviour.
    resolution : int
        Samples per axis of the plotting grid (default 200, as before).

    Returns
    -------
    tuple
        ``(kb_channels, interval)``: a tuple with one kernel copy per
        channel, and the (2, 2) array of per-axis (start, stop) bounds.
    """
    if seed is not None:
        np.random.seed(seed)

    # Row i of ``interval`` is the (start, stop) pair of axis i:
    # axis 0 runs from 100 down to -100, axis 1 from -100 up to 100.
    interval = np.array((100, -100, -100, 100)).reshape(2, 2)
    centers = create_random_centers(interval, N)
    kb = BSplineKernel.create(c=centers, s=s, k=k)

    # One random weight vector in [-1, 1) per RGB channel.
    weights = np.random.rand(3, N) * 2 - 1

    # Row i of the transposed linspace spans axis i of ``interval``; with the
    # default 'xy' indexing, xx varies along columns and yy along rows.
    xx, yy = np.meshgrid(*np.linspace(*zip(*interval), resolution).T)

    kb_channels = []
    _, axes = plt.subplots(1, 3, figsize=(15, 6))
    for ax, channel_weights, cmap in zip(axes, weights, ('Reds', 'Greens', 'Blues')):
        kb.w = channel_weights
        # Store a copy because ``kb.w`` is overwritten on the next iteration
        # (assumes ``copy()`` does not share ``w`` with the original).
        kb_channels.append(kb.copy())
        # Distinct name so the ``k`` argument is not shadowed.
        kernel_values = kb(xx, yy)
        ax.scatter(*kb.c.T, marker='x', color='white')
        ax.imshow(kernel_values, extent=(-100, 100, -100, 100), cmap=cmap)
        ax.set_title(cmap)
    plt.show()
    return tuple(kb_channels), interval


# Seeded so a fresh "Restart & Run All" reproduces the same kernels.
kb_channels, interval = bsplinekernel_example(seed=0)

# Example of a 2D convolution

In [None]:
def standard_convolution(kb_channels, interval, img, n_samples=100):
    """Filter ``img`` with densely sampled per-channel kernels.

    Each kernel in ``kb_channels`` is evaluated on an ``n_samples`` x
    ``n_samples`` grid spanning ``interval``. The stacked samples are applied
    depthwise (``groups=3``) with ``torch.nn.functional.conv2d``. Note that
    ``conv2d`` computes a cross-correlation, i.e. the kernels are not flipped.

    Parameters
    ----------
    kb_channels : sequence of callables
        Three kernels, each called as ``kb(xx, yy)`` and returning an
        ``(n_samples, n_samples)`` array.
    interval : array-like of shape (2, 2)
        Per-axis ``(start, stop)`` bounds of the kernel sampling grid.
    img : torch.Tensor of shape (1, 3, H, W)
        Single-image batch with three channels.
    n_samples : int, optional
        Samples per axis of the kernel grid (default 100, as before).

    Returns
    -------
    numpy.ndarray of shape (H - n_samples + 1, W - n_samples + 1, 3)
        The filtered image, min-max rescaled to [0, 1].
    """
    xx, yy = np.meshgrid(*np.linspace(*zip(*interval), n_samples).T)

    # Depthwise weights of shape (channels, 1, n_samples, n_samples).
    weights = torch.stack(
        [torch.as_tensor(kb(xx, yy), dtype=torch.float32)[None, :] for kb in kb_channels]
    )
    conv = torch.nn.functional.conv2d(img, weights, groups=3)

    # Drop only the batch axis: a bare ``squeeze()`` would also remove
    # spatial axes of size 1 and make the channel move below wrong.
    conv = conv.numpy().squeeze(0)
    conv = np.moveaxis(conv, 0, -1)
    # Min-max rescale to [0, 1] for display.
    return np.interp(conv, (conv.min(), conv.max()), (0, 1))

def shifted_convolution(kb_channels, interval, img):
    """Planned shift-based B-spline convolution (not implemented yet).

    Intended algorithm:

    1) Convolve the image once for every different center with a different
       decimal (fractional) part.
    2) Convolve the image for each of the different centers.
    3) Shift each result by its integral offset and sum the results.

    Raises
    ------
    NotImplementedError
        Always, until the algorithm above is implemented.
    """
    # A function body made only of comments is a syntax error, which used to
    # prevent this entire notebook cell from running.
    raise NotImplementedError("shifted_convolution is not implemented yet")


def _load_face():
    """Return SciPy's 'face' test image as an (H, W, 3) array."""
    # ``scipy.misc.face`` was removed in SciPy 1.12; prefer it when present
    # (no download needed), else use ``scipy.datasets.face`` (needs ``pooch``).
    face = getattr(scipy.misc, 'face', None)
    if face is None:
        import scipy.datasets
        face = scipy.datasets.face
    return face()


def bspline_convolution(kb_channels, interval):
    """Filter a test image with the per-channel kernels, time it and plot it.

    Parameters
    ----------
    kb_channels : sequence
        Three per-channel kernels, forwarded to ``standard_convolution``.
    interval : array-like of shape (2, 2)
        Kernel sampling bounds, forwarded to ``standard_convolution``.

    Returns
    -------
    numpy.ndarray
        The rescaled filtered image returned by ``standard_convolution``
        (the function previously returned ``None``).
    """
    # Channels-last image -> (1, 3, H, W) float tensor for conv2d.
    img = torch.Tensor(np.moveaxis(_load_face(), -1, 0))[None, :]

    start_time = time.perf_counter()
    conv = standard_convolution(kb_channels, interval, img)
    print('Standard convolution elapsed time: ', time.perf_counter() - start_time)

    _, ax = plt.subplots()
    ax.imshow(conv)
    ax.set_title('Standard convolution')
    plt.show()
    return conv


face_filtered = bspline_convolution(kb_channels, interval)
