In [None]:
import plotly
import plotly.graph_objs as go
import numpy as np
from plotly.offline import iplot
from ipywidgets import GridspecLayout, widgets, HBox


# Enable offline plotly rendering inside the notebook (plotly.js is embedded, not fetched from a CDN).
plotly.offline.init_notebook_mode(connected=False)

In [None]:
# Display settings shared by all cells below:
#   n_points    -> the sampling grids have n_points x n_points markers
#   marker_size -> plotly marker size of the grid points
#   x_shift     -> x offset of the right-hand (inverse-transformation) panel
#   opacity     -> opacity of grid markers and corner lines
n_points, marker_size = 8, 3
x_shift, opacity = 3.5, 0.75

In [None]:
from torchvision import datasets, transforms
import torch
import torch.nn.functional as F

def place_digits_randomly(mnist_imgs, canvas_size=64):
    """Paste digit images at random positions on a blank square canvas.

    Args:
        mnist_imgs: sequence of channel-first image tensors of shape (1, H, H)
            (as produced by ``transforms.ToTensor()``); channel 0 of each is pasted.
            All images are assumed to share the size of the first one.
        canvas_size: side length of the output canvas (default 64).

    Returns:
        A float tensor of shape (canvas_size, canvas_size). Where digits overlap
        their pixel values are summed, so values may exceed 1 there.
    """
    new_img = torch.zeros([canvas_size, canvas_size])
    if len(mnist_imgs) == 0:
        return new_img

    digit_size = mnist_imgs[0].shape[-1]
    # torch.randint's upper bound is exclusive: +1 so the last valid offset
    # (canvas_size - digit_size, digit flush with the border) can be drawn too.
    x_pos, y_pos = torch.randint(0, canvas_size - digit_size + 1, (2, len(mnist_imgs)))
    for img, x0, y0 in zip(mnist_imgs, x_pos, y_pos):
        new_img[y0:y0 + digit_size, x0:x0 + digit_size] += img[0]
    return new_img


SEED = 7
# Seed numpy (picks the MNIST samples) first, then torch (picks the digit positions).
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(SEED)

mnist = datasets.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
i_sample = np.random.randint(len(mnist), size=2)
# Each dataset item is (image, label); only the image tensors are pasted onto the canvas.
extended_img = place_digits_randomly([mnist[idx][0] for idx in i_sample])
# Reverse the row order: the surfaces below map row 0 to y = -1 (bottom of the plot),
# so flipping keeps the digits upright in the 3-D view.
extended_img = extended_img.flip(0)

In [None]:
def create_target(n_points, marker_size):
    """Build a regular n_points x n_points grid on [-1, 1]^2 lying in the plane z = 1.

    Returns:
        target: (3, n_points**2) array whose rows are x, y and a constant 1
            (homogeneous coordinates; the 1 also serves as the z level).
        G: green Scatter3d trace drawing those points.
    """
    axis = np.linspace(-1, 1, n_points)
    mesh_x, mesh_y = np.meshgrid(axis, axis)
    target = np.vstack([mesh_x.ravel(), mesh_y.ravel(), np.ones(mesh_x.size)])

    marker_style = dict(size=marker_size, color='green', opacity=opacity)
    G = go.Scatter3d(x=target[0], y=target[1], z=target[2], hoverinfo='none',
                     mode='markers', marker=marker_style)
    return target, G

def create_init_transformation(target, n_points):
    """Map the target grid with the identity transform onto the source plane z = -4.

    Args:
        target: (3, n_points**2) homogeneous grid as returned by ``create_target``.
        n_points: grid resolution per axis.

    Returns:
        source: (3, n_points**2) array with the transformed x, y and z = -4.
        G_transformed: green, visible Scatter3d trace of the source points.
    """
    identity_theta = np.eye(2, 3)  # [[1, 0, 0], [0, 1, 0]]
    source_xy = identity_theta @ target
    source_z = np.full((1, n_points**2), -4.0)
    source = np.vstack([source_xy, source_z])

    marker_style = dict(size=marker_size, color='green', opacity=opacity)
    G_transformed = go.Scatter3d(x=source[0], y=source[1], z=source[2],
                                 hoverinfo='none', mode='markers', visible=True,
                                 marker=marker_style)
    return source, G_transformed

def create_connecting_corner_lines(source, target, n_points):
    """Return four blue line traces joining each grid corner in `target` to the same point in `source`.

    Corners are taken in the flattened grid order: first, end of first row,
    start of last row, last.
    """
    last = n_points - 1
    corner_ids = (0, last, n_points * last, n_points**2 - 1)
    line_style = dict(color='blue', width=2)
    return [
        go.Scatter3d(x=np.array([target[0][k], source[0][k]]),
                     y=np.array([target[1][k], source[1][k]]),
                     z=np.array([target[2][k], source[2][k]]),
                     mode='lines', hoverinfo='none',
                     line=line_style, opacity=opacity)
        for k in corner_ids
    ]

def create_layout(width=850, height=450):
    """Return the plotly layout dict for the 3-D scene.

    Uses a fixed camera, no legend, no title, near-zero margins, and disables
    drag interaction in the scene (``dragmode=False``).
    """
    camera = {
        'up': {'x': 0, 'y': 0.5, 'z': 0},
        'center': {'x': 0.9, 'y': -0.8, 'z': -0.1},
        'eye': {'x': -1.3, 'y': 0.6, 'z': 1.8},
    }
    return {
        'title': '',
        'showlegend': False,
        'width': width,
        'height': height,
        'margin': {'t': 0, 'r': 0, 'l': 10, 'b': 0},
        'scene': {'camera': camera, 'dragmode': False},
    }


# z level of the red reconstruction plane. The green source plane sits at z = -4
# (set inside create_init_transformation) and both target grids at z = 1.
z_reconstructed = 3.5

# Left (green) panel: regular target grid at z = 1 and its image on the source plane.
target, G_target = create_target(n_points, marker_size)
source, G_source = create_init_transformation(target, n_points)
corners = create_connecting_corner_lines(source, target, n_points)

# Right (red) panel: the same target grid shifted by x_shift along x ...
target_2, G_target_2 = create_target(n_points, marker_size)
target_2[0] = target_2[0] + x_shift
G_target_2['x'] = target_2[0]
G_target_2['marker']['color'] = 'red'

# ... paired with a grid lifted from the default z = -4 up to the reconstruction plane.
source_2, G_source_2 = create_init_transformation(target_2, n_points)
source_2[2] = z_reconstructed
G_source_2['z'] = source_2[2]
G_source_2['marker']['color'] = 'red'

corners_2 = create_connecting_corner_lines(source_2, target_2, n_points)

# Image surfaces share one colouring; their size follows the canvas built above.
img_h, img_w = extended_img.shape
x = np.linspace(-1, 1, img_w)
y = np.linspace(-1, 1, img_h)
flat = np.ones((img_h, img_w))
image_style = dict(surfacecolor=extended_img, cmin=0, cmax=1,
                   colorscale='Gray', showscale=False)
mnist_img_original = go.Surface(x=x, y=y, z=-4 * flat, **image_style)
mnist_img_resample = go.Surface(x=x, y=y, z=1 * flat, **image_style)
mnist_img_resample_2 = go.Surface(x=x + x_shift, y=y, z=1 * flat, **image_style)
mnist_img_reconstructed = go.Surface(x=x + x_shift, y=y, z=z_reconstructed * flat,
                                     **image_style)

# NOTE: the widget callbacks address traces by position in `data`; keep this order:
#   0 crop (z=1), 1 shifted crop (z=1), 2 reconstructed image, 3 original image (z=-4),
#   4 green target grid, 5 red grid on the crop plane, 6 green sampling grid,
#   7 red regular grid on the reconstruction plane,
#   then 4 green corner lines followed by 4 red corner lines.
data = [mnist_img_resample, mnist_img_resample_2, mnist_img_reconstructed, mnist_img_original,
        G_target, G_target_2, G_source, G_source_2] + corners + corners_2
layout = create_layout()

base_figure = go.Figure(data=data, layout=layout)
base_figure


In [None]:
from ipywidgets import GridspecLayout, widgets, HBox



def _update_attention_figure(scale, shift_x, shift_y):
    """Redraw every transformation-dependent trace of `fig` for the given parameters.

    The affine transform is A = [[s, 0, t_x], [0, s, t_y], [0, 0, 1]]. Traces are
    addressed by their position in the base figure's data list:
    0/1 = crop surfaces (z=1), 2 = reconstructed image, 5 = red grid on the crop
    plane, 6 = green sampling grid on the original image, -8..-5 = green corner
    lines, -4..-1 = red corner lines.
    """
    A_theta = np.array([[scale, 0, shift_x],
                        [0, scale, shift_y],
                        [0, 0, 1]])
    # pinv rather than inv: the scale slider can reach s = 0, where A_theta is singular.
    A_theta_inv = np.linalg.pinv(A_theta)
    img_size = (1, 1) + tuple(extended_img.shape)
    corner_ids = [0, n_points - 1, n_points * (n_points - 1), n_points**2 - 1]

    with fig.batch_update():
        # Forward transform: where each regular target point samples the original image.
        source = A_theta @ target
        fig.data[6]['x'] = source[0]
        fig.data[6]['y'] = source[1]

        # Attention crop: resample the original image on the transformed grid.
        theta = torch.from_numpy(A_theta[0:2, :]).unsqueeze(0).type(torch.float32)
        grid = F.affine_grid(theta, size=img_size, align_corners=False)
        crop = F.grid_sample(extended_img.unsqueeze(0).unsqueeze(0), grid, align_corners=False)
        fig.data[0]['surfacecolor'] = crop[0, 0].numpy()
        fig.data[1]['surfacecolor'] = crop[0, 0].numpy()

        # Inverse transform: where each regular reconstruction point samples the crop.
        x_rec_grid = A_theta_inv @ target
        fig.data[5]['x'] = x_rec_grid[0] + x_shift
        fig.data[5]['y'] = x_rec_grid[1]

        # Reconstruction: resample the crop with the (pseudo-)inverse transform.
        theta_inv = torch.from_numpy(A_theta_inv[0:2, :]).unsqueeze(0).type(torch.float32)
        grid_inv = F.affine_grid(theta_inv, size=img_size, align_corners=False)
        reconstructed = F.grid_sample(crop, grid_inv, align_corners=False)
        fig.data[2]['surfacecolor'] = reconstructed[0, 0].numpy()

        for offset, i_point in enumerate(corner_ids):
            # Green lines: target corner -> its sampling location on the original image.
            fig.data[-8 + offset]['x'] = np.array([target[0][i_point], source[0][i_point]])
            fig.data[-8 + offset]['y'] = np.array([target[1][i_point], source[1][i_point]])
            # Red lines: inverse-mapped corner on the crop plane -> regular corner above it.
            fig.data[-4 + offset]['x'] = np.array([x_rec_grid[0][i_point] + x_shift,
                                                   target[0][i_point] + x_shift])
            fig.data[-4 + offset]['y'] = np.array([x_rec_grid[1][i_point],
                                                   target[1][i_point]])


def response(change):
    """Slider observer: redraw the figure for the current s, t_x, t_y slider values."""
    _update_attention_figure(s.value, t_x.value, t_y.value)


def reset_values(b):
    """Reset-button handler: restore the identity transform (s=1, t_x=t_y=0) and redraw."""
    with fig.batch_update():
        s.value, t_x.value, t_y.value = 1., 0., 0.
        # Redraw explicitly: observers only fire for values that actually changed.
        _update_attention_figure(s.value, t_x.value, t_y.value)
        
        
fig = go.FigureWidget(base_figure)

# Sliders for the transform parameters; each one redraws the figure on change.
s = widgets.FloatSlider(value=1, min=-1, max=1, step=0.1)
t_x = widgets.FloatSlider(value=0, min=-1, max=1, step=0.1)
t_y = widgets.FloatSlider(value=0, min=-1, max=1, step=0.1)
for slider in (s, t_x, t_y):
    slider.observe(response, names="value")

# Reset button restores the identity transform.
reset_button = widgets.Button(description="Reset")
reset_button.on_click(reset_values)

fig_box = widgets.Box([fig])
title = widgets.HTML(value="<h2 style='color:#303F5F'>Creation of Attention Crops and Inverse Transformation</h2>")

# LaTeX pieces shared by both formulas. The 3x3 matrix needs a three-column
# array spec ({ccc}); the vectors are single-column ({c}).
vec_source = r'\left[ \begin{array}{c} x_k^s \\ y_k^s \\ 1 \end{array} \right]'
mat_theta = (r'\left[ \begin{array}{ccc} s & 0 & t_x \\'
             r'0 & s & t_y  \\ 0 & 0 & 1\end{array} \right]')
vec_target = r'\left[\begin{array}{c} x_k^t \\ y_k^t \\ 1 \end{array} \right]'

# Green: forward transform used for the attention crop.
formula_str_1 = r'$\color{green}{' + vec_source + ' = ' + mat_theta + vec_target + '}$'
formula_label_1 = widgets.Label(value=formula_str_1)
# Red: pseudo-inverse transform used for the reconstruction.
formula_str_2 = (r'$\color{red}{' + vec_source + r' =  \left( ' + mat_theta
                 + r' \right)^{+} ' + vec_target + '}$')
formula_label_2 = widgets.Label(value=formula_str_2)

# Assemble everything in a grid: controls on the right, figure on the left.
n_rows, n_cols = 12, 8
grid_spec = GridspecLayout(n_rows, n_cols, height='auto', width='100%')
grid_spec[2, n_cols-2:n_cols] = HBox([widgets.Label(value=r'$s$'), s])
grid_spec[3, n_cols-2:n_cols] = HBox([widgets.Label(value=r'$t_x$'), t_x])
grid_spec[4, n_cols-2:n_cols] = HBox([widgets.Label(value=r'$t_y$'), t_y])

grid_spec[5, n_cols-2:n_cols] = reset_button

grid_spec[2:n_rows-1, 0:n_cols-2] = fig_box
grid_spec[0, 0:n_cols] = title
grid_spec[1, 1:3] = formula_label_1
grid_spec[1, 3:6] = formula_label_2

grid_spec