In [None]:
import plotly
import plotly.graph_objs as go
import numpy as np
from plotly.offline import iplot
from ipywidgets import GridspecLayout, widgets, HBox


# Initialise plotly's offline notebook mode so figures render inline
# (connected=False is plotly's default, spelled out here explicitly).
plotly.offline.init_notebook_mode(connected=False)

In [None]:
# Grid resolution (points per axis) and marker size used by every trace below.
n_points, marker_size = 10, 3

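## Base Figure

Build the static scene: a regular target grid on the plane $z=1$, its initial (identity-mapped) source copy on the plane $z=-1$, a square frame around each plane, and lines joining the four corresponding corners.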

In [None]:
def create_target(n_points, marker_size):
    """Build the regular target grid and its 3-D scatter trace.

    Parameters
    ----------
    n_points : int
        Number of grid points along each axis. The grid has ``n_points**2``
        points spanning [0, 1] x [0, 1].
    marker_size : int or float
        Marker size passed to the scatter trace.

    Returns
    -------
    target : np.ndarray, shape (3, n_points**2)
        Homogeneous coordinates ``(x, y, 1)`` of every grid point, with
        x varying fastest.
    G : plotly Scatter3d
        Green marker trace drawn on the plane ``z = 1``.
    """
    axis = np.linspace(0, 1, n_points)
    grid_x, grid_y = np.meshgrid(axis, axis)
    flat_x = grid_x.reshape(-1)
    flat_y = grid_y.reshape(-1)
    # Third row of ones turns (x, y) into homogeneous coordinates.
    target = np.vstack([flat_x, flat_y, np.ones(flat_x.size)])

    trace = go.Scatter3d(
        x=target[0], y=target[1], z=target[2],
        hoverinfo='none', mode='markers',
        marker=dict(size=marker_size, color='green'),
    )
    return target, trace


def create_init_transformation(target, n_points, marker_size=None):
    """Apply the identity affine map to the target grid and build its trace.

    Parameters
    ----------
    target : np.ndarray, shape (3, N)
        Homogeneous target coordinates ``(x, y, 1)`` as produced by
        ``create_target``.
    n_points : int
        Grid points per axis. Kept for backward compatibility only; the number
        of points is taken from ``target`` so the two can never disagree.
    marker_size : int or float, optional
        Marker size of the trace. When omitted, falls back to the
        notebook-level ``marker_size`` constant, which earlier versions of
        this function read implicitly.

    Returns
    -------
    source : np.ndarray, shape (3, N)
        The transformed x/y coordinates placed on the plane ``z = -1``.
    G_transformed : plotly Scatter3d
        Green marker trace of the source grid.
    """
    if marker_size is None:
        # Backward compatibility: the original relied on the global constant.
        marker_size = globals()['marker_size']

    # Identity transform: s = 1, t_x = t_y = 0.
    A_theta = np.array([[1, 0, 0],
                        [0, 1, 0]])
    source_xy = A_theta @ target
    # Point count comes from ``target`` rather than ``n_points**2``.
    n_total = target.shape[1]
    source = np.vstack([source_xy, -np.ones((1, n_total))])

    G_transformed = go.Scatter3d(x=source[0], y=source[1], z=source[2],
                                 hoverinfo='none', mode='markers', visible=True,
                                 marker=dict(size=marker_size, color='green'))
    return source, G_transformed


def create_squares(min_val=-0.1, max_val=1.1):
    """Return two square outlines framing the upper and lower planes.

    Parameters
    ----------
    min_val, max_val : float
        Lower and upper coordinate of the square in both x and y.

    Returns
    -------
    tuple of plotly Scatter3d
        ``(square_1, square_2)``: a light-green outline at ``z = 1`` and a
        black outline at ``z = -1``.
    """
    def _outline(z_level, colour):
        # Closed path: four corners plus the first corner again.
        corner_x = np.array([min_val, max_val, max_val, min_val, min_val])
        corner_y = np.array([min_val, min_val, max_val, max_val, min_val])
        return go.Scatter3d(x=corner_x, y=corner_y, z=np.full(5, z_level),
                            mode='lines', line=dict(color=colour),
                            hoverinfo='none')

    upper = _outline(1, 'lightgreen')
    lower = _outline(-1, 'black')
    return upper, lower


def create_connecting_corner_lines(source, target, n_points):
    """Build line segments joining each grid corner to its source location.

    Parameters
    ----------
    source, target : np.ndarray, shape (3, n_points**2)
        Source and target point coordinates (rows are x, y, z).
    n_points : int
        Grid points per axis, used to locate the four corners in the
        flattened grid.

    Returns
    -------
    list of plotly Scatter3d
        Four semi-transparent black segments, one per corner, in the order
        first, end of first row, start of last row, last.
    """
    last = n_points - 1
    corner_ids = (0, last, n_points * last, n_points * n_points - 1)

    def _segment(idx):
        xs, ys, zs = (np.array([source[axis][idx], target[axis][idx]])
                      for axis in range(3))
        return go.Scatter3d(x=xs, y=ys, z=zs, mode='lines', hoverinfo='none',
                            line=dict(color='black', width=2), opacity=0.5)

    return [_segment(idx) for idx in corner_ids]


def create_layout(width=750, height=450):
    """Return the plotly layout dict used for the 3-D scene.

    Parameters
    ----------
    width, height : int
        Figure size.

    Returns
    -------
    dict
        A fresh layout with no title or legend, tight margins, a fixed camera
        and 3-D drag interaction disabled (``dragmode=False``).
    """
    camera = {
        'up': {'x': 0, 'y': 0.5, 'z': 0},
        'center': {'x': 0.9, 'y': -0.6, 'z': -0.5},
        'eye': {'x': -1., 'y': 0.5, 'z': 1.5},
    }
    scene = {'camera': camera, 'dragmode': False}
    return {
        'title': '',
        'showlegend': False,
        'width': width,
        'height': height,
        'margin': {'t': 0, 'r': 0, 'l': 10, 'b': 0},
        'scene': scene,
    }


# Assemble the static scene. Trace order matters: the interactive callbacks
# address the source markers as data[-5] and the corner lines as data[-4:].
layout = create_layout()
target, G_target = create_target(n_points, marker_size)
source, G_source = create_init_transformation(target, n_points)
square_1, square_2 = create_squares()
corners = create_connecting_corner_lines(source, target, n_points)

data = [G_target, square_1, square_2, G_source, *corners]
base_figure = go.Figure(data=data, layout=layout)
base_figure

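## Attention Transformation

Use the sliders to set the isotropic scale $s$ and the offsets $t_x$, $t_y$ of the affine map $T_{\boldsymbol{\theta}}$. Each target grid point $G_i$ (plane $z=1$) is moved to its source location $\widetilde{G}_i = T_{\boldsymbol{\theta}}(G_i)$ (plane $z=-1$). *Reset* restores the identity map.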

In [None]:
from ipywidgets import GridspecLayout, widgets, HBox



def _corner_indices(n):
    """Flat indices of the four corners of an ``n x n`` grid flattened row-wise."""
    return [0, n - 1, n * (n - 1), n ** 2 - 1]


def _update_source_traces():
    """Recompute the source grid from the current slider values and redraw it.

    Reads the notebook-level ``target`` grid, ``n_points`` and the sliders
    ``s``, ``t_x`` and ``t_y``. Writes into ``fig.data``, assuming (as in the
    base figure) that entry -5 is the source marker trace and the last four
    entries are the corner connecting lines.
    """
    A_theta = np.array([[s.value, 0, t_x.value],
                        [0, s.value, t_y.value]])
    source_xy = A_theta @ target
    # Append z = -1 so the source grid stays on the lower plane; the point
    # count comes from ``target`` instead of a hard-coded ``n_points**2``.
    source = np.vstack([source_xy, -np.ones((1, target.shape[1]))])
    with fig.batch_update():
        for offset, i_point in enumerate(_corner_indices(n_points)):
            corner_line = fig.data[-4 + offset]
            corner_line['x'] = np.array([source[0][i_point], target[0][i_point]])
            corner_line['y'] = np.array([source[1][i_point], target[1][i_point]])
            corner_line['z'] = np.array([source[2][i_point], target[2][i_point]])
        # The source markers' z is constant (-1), so only x and y change.
        fig.data[-5]['x'] = source[0]
        fig.data[-5]['y'] = source[1]


def response(change):
    """Slider observer: redraw the source grid after a slider value change.

    Parameters
    ----------
    change : dict
        traitlets change notification passed by ``observe``; not used.
    """
    _update_source_traces()


def reset_values(b):
    """Button callback: restore the identity transform and redraw.

    Parameters
    ----------
    b : widgets.Button
        The clicked button (passed by ``on_click``); not used.
    """
    with fig.batch_update():
        # Assignments that change a value may also fire registered observers
        # (e.g. ``response``). Redraw explicitly as well so the figure is
        # refreshed even when the sliders were already at their defaults.
        s.value, t_x.value, t_y.value = 1., 0., 0.
        _update_source_traces()
        
        
fig = go.FigureWidget(base_figure)

# Sliders for the affine parameters; all share the same range and step.
_slider_range = dict(min=-1, max=1, step=0.1)
s = widgets.FloatSlider(value=1, **_slider_range)    # isotropic scale
t_x = widgets.FloatSlider(value=0, **_slider_range)  # x offset
t_y = widgets.FloatSlider(value=0, **_slider_range)  # y offset
# Redraw the figure whenever any slider value changes.
for _slider in (s, t_x, t_y):
    _slider.observe(response, names="value")

# Button that restores the identity transform.
reset_button = widgets.Button(description="Reset")
reset_button.on_click(reset_values)

# Wrap the figure in a Box so it can be placed in the grid layout.
fig_box = widgets.Box([fig])
# Title shown above the figure.
title = widgets.HTML(value="<h2 style='color:#303F5F'>(Standard) Attention</h2>")
# LaTeX for the affine map. The transform matrix is 2x3, so its array
# environment needs three column specifiers ({ccc}). A single {c} declares
# one column, which standard LaTeX rejects ("Extra alignment tab") and
# renderers may lay out incorrectly.
formula_str = (
    r'$  \widetilde{G}_i = \left[ \begin{array}{c} x_i^s \\ y_i^s \end{array} \right] = '
    r'T_{\boldsymbol{\theta}} (G_i) = \left[ \begin{array}{ccc} s & 0 & t_x \\'
    r'0 & s & t_y \end{array} \right]'
    r'\left[\begin{array}{c} x_i^t \\ y_i^t \\ 1 \end{array} \right] $'
)
formula_label = widgets.Label(value=formula_str)

# Put everything together in a GridspecLayout: the figure fills the left
# columns and the controls occupy the rightmost three columns.
n_rows, n_cols = 9, 8
grid_spec = GridspecLayout(n_rows, n_cols, height='auto', width='100%')
controls_start = n_cols - 3  # first column of the control panel

# One labelled slider per row, starting below the title and formula rows.
slider_rows = [
    (r'$s \text{ (isotropic)} \quad$', s),
    (r'$t_x \text{ (offset x)} \quad$', t_x),
    (r'$t_y \text{ (offset y)} \quad$', t_y),
]
for row, (label_tex, slider_widget) in enumerate(slider_rows, start=2):
    grid_spec[row, controls_start:n_cols] = HBox(
        [widgets.Label(value=label_tex), slider_widget])
# Reset button directly below the last slider.
grid_spec[2 + len(slider_rows), controls_start:n_cols] = reset_button

grid_spec[2:n_rows-1, 0:controls_start] = fig_box
# Title spans the full grid width.
grid_spec[0, 0:n_cols] = title
grid_spec[1, 2:] = formula_label

grid_spec