In [8]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy import signal
import scipy.sparse.linalg
import torch
import re
from torch.nn.functional import conv2d
import functools
from matplotlib.widgets import Slider, Button, RadioButtons, TextBox
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from crf.utils import read_image, read_pfm, read_pgm
from crf.features import Vgg16features
from crf.crf import *
from crf.depth import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# Load the left input image and place the VGG16 feature extractor on the GPU when one is available.
img1 = read_image('imL.png')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
VGG = Vgg16features()
_ = VGG.to(device)  # bound to `_` only so the cell does not echo the return value

In [45]:
%matplotlib widget


# Two panels side by side: the left one shows the 4x-subsampled image used to pick a
# pixel, the right one is overwritten by `callback` for the picked pixel.
f, axarr = plt.subplots(1, 2)
img_small = img1[::4, ::4]
shape = (3, 5)
im1 = np.random.rand(*shape)  # NOTE(review): random placeholder, looks like leftover scaffolding
ax0 = axarr[0].imshow(img_small)
ax1 = axarr[1].imshow(img_small)
# 90-degree rotation (and its inverse) between (x, y) and (row, col) conventions.
rotate_left = np.array([[0, -1], [1, 0]])
rotate_right = np.linalg.inv(rotate_left)

# Per-pixel (row, col) grid of the subsampled image, normalized by the image diagonal.
h, w = img_small.shape[:2]
position = np.mgrid[:h, :w].transpose((1, 2, 0)) / np.sqrt(h**2 + w**2)
all_features = VGG.get_all_features(img1)

## Widgets: three bandwidth sliders, a feature-layer selector and a projection-size box.
axcolor = 'black'
sigma1ax, sigma2ax, sigma3ax = (
    plt.axes([0.25, bottom, 0.65, 0.03], facecolor=axcolor)
    for bottom in (0.1, 0.15, 0.2)
)

sigma1 = Slider(sigma1ax, r'$\sigma_p$', 0.01, 0.3, valinit=.1)   # divides the position channels
sigma2 = Slider(sigma2ax, r'$\sigma_c$', 0.01, 1.0, valinit=.1)   # divides the image channels
sigma3 = Slider(sigma3ax, r'$\sigma_f$', 0.05, 10.0, valinit=3)   # divides the projected features

rax = plt.axes([0.05, 0.10, 0.075, 0.15], facecolor='white')
layer_selector = RadioButtons(rax, (0, 1, 2), active=0)           # index into all_features
axbox = plt.axes([0.05, 0.30, 0.05, 0.05])
dim_selector = TextBox(axbox, 'd', initial='10')                  # number of projected feature dims


def callback(i,j):
    """Show the map returned by ``f.get_W(i, j)`` in the right panel and mark pixel (i, j).

    Args:
        i: row index into the subsampled image shown in the left panel.
        j: column index into the same image.
    """
    # Remove the marker drawn by the previous call so only one red dot is visible.
    if axarr[1].lines:
        axarr[1].lines[-1].remove()
    # NOTE(review): the factor 20 looks like a display gain so small weights are visible — confirm.
    ax1.set_data(20 * f.get_W(i, j))
    # Axes.plot takes (x, y) = (column, row).
    axarr[1].plot(j, i, ".r", markersize=4)
    # Request a redraw explicitly so the update shows even when pyplot is not in interactive mode.
    f.canvas.draw_idle()
    
def on_move(fig, axes,callback, event):
    """Translate mouse motion over the left panel into a pixel index and report it.

    Args:
        fig: figure carrying the picker state attributes ``frozen`` (bool) and
            ``last_ij`` ([row, col] of the last reported pixel).
        axes: sequence of axes; only ``axes[0]`` (the panel showing the image) is used.
        callback: called as ``callback(i, j)`` whenever the pointer enters a new pixel.
        event: matplotlib mouse event.

    Does nothing while ``fig.frozen`` is true or when the pointer is outside ``axes[0]``.
    """
    if fig.frozen:
        return
    if not axes[0].in_axes(event):
        return
    # Public accessor instead of the private ``_A`` attribute.
    n_rows, n_cols = axes[0].get_images()[0].get_array().shape[:2]
    # Display -> axes coordinates: (0, 0) is the lower-left corner, (1, 1) the upper-right.
    x_ax, y_ax = axes[0].transAxes.inverted().transform((event.x, event.y))
    # Image rows grow downwards while axes y grows upwards: floor(-y * n_rows) wrapped
    # modulo n_rows is the row counted from the top; the column is floor(x * n_cols).
    i = int(np.floor(-y_ax * n_rows)) % n_rows
    j = int(np.floor(x_ax * n_cols)) % n_cols
    if [i, j] != fig.last_ij:
        callback(i, j)
        fig.last_ij = [i, j]
def on_click(event):
    """Toggle the frozen state of the pixel picker when the left panel is clicked.

    Clicking inside ``axarr[0]`` flips ``f.frozen``. When the click unfreezes the
    picker, the right panel is refreshed immediately for the clicked pixel.
    """
    if not axarr[0].in_axes(event):
        return
    f.frozen = not f.frozen
    if not f.frozen:
        # Refresh right away instead of waiting for the next mouse motion.
        on_move(f, axarr, callback, event)
            
def update_sigmas(*args):
    """Rebuild the scaled reference stack from the slider values and refresh the view.

    Per subsampled pixel, the stack holds the first three image channels divided by
    sigma_c, the normalized position divided by sigma_p, and the projected features
    ``f.feats`` divided by sigma_f. The stack is passed to ``lazy_W``, the result is
    stored as ``f.get_W``, and the right panel is redrawn for ``f.last_ij``.

    Args:
        *args: ignored; accepted so this can be registered as a widget callback.
    """
    s1, s2, s3 = sigma1.val, sigma2.val, sigma3.val
    # Size the stack from the stored features, not the text box: TextBox.text changes
    # on every keystroke, so an edited-but-unsubmitted value would not match f.feats.
    d = f.feats.shape[-1]
    ref = np.zeros((h, w, 5 + d))
    ref[..., :3] = img1[::4, ::4] / s2
    ref[..., 3:5] = position / s1
    ref[..., 5:] = f.feats / s3
    f.get_W = lazy_W(ref)
    callback(*f.last_ij)
    
def update_features(*args):
    """Randomly project the selected VGG layer to ``d`` channels and refresh the view.

    Reads the layer index from ``layer_selector`` and ``d`` from ``dim_selector``,
    multiplies that layer of ``all_features`` by a random ``(C, d)`` matrix,
    subsamples it with stride ``2 ** (2 - q)``, standardizes each projected channel,
    stores the result as ``f.feats`` and calls ``update_sigmas``.

    NOTE(review): the stride presumably makes the layer line up with the (h, w) grid
    used by ``update_sigmas``; this depends on the resolutions returned by
    ``get_all_features`` — confirm.

    Text in the box that is not a positive integer is ignored, so a typo does not
    raise inside the widget callback.

    Args:
        *args: forwarded to ``update_sigmas`` (which ignores them).
    """
    q = int(layer_selector.value_selected)
    try:
        d = int(dim_selector.text)
    except ValueError:
        return
    if d < 1:
        return
    layer_feats = all_features[q]
    # NOTE(review): unseeded uniform [0, 1) projection, so results change on every update.
    projection_matrix = np.random.rand(layer_feats.shape[-1], d)
    step = 2 ** (2 - q)
    feats = (layer_feats @ projection_matrix)[::step, ::step]
    std = feats.std((0, 1))
    # Constant channels would otherwise divide by zero and produce NaNs.
    std = np.where(std > 0, std, 1)
    f.feats = (feats - feats.mean((0, 1))) / std
    update_sigmas(*args)

# Initial picker state: top-left pixel selected, frozen until the left panel is clicked.
f.last_ij = [0, 0]
f.frozen = True
update_features()

## Attach event handlers and widget callbacks
on_move_wrapper = functools.partial(on_move, f, axarr, callback)
f.canvas.mpl_connect('motion_notify_event', on_move_wrapper)
f.canvas.mpl_connect('button_press_event', on_click)

# A plain loop instead of a list comprehension: only the side effect of registering is wanted.
for slider in (sigma1, sigma2, sigma3):
    slider.on_changed(update_sigmas)
dim_selector.on_submit(update_features)
layer_selector.on_clicked(update_features)

plt.show()

FigureCanvasNbAgg()