In [1]:
import os.path as path

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as functional
from IPython.display import display, clear_output
from ipywidgets import FloatSlider, interactive, VBox

from constants import *
from data_loaders import *
from model import CapsNet

%matplotlib inline

# Notebook settings
USE_GPU = True               # move the model to CUDA before loading weights
DEBUG_MODE = True            # print target / prediction / capsule vector for the chosen image
MODEL = "model37.pt"         # checkpoint file name, joined with SAVE_DIR when loading
RECONSTRUCTION_TYPE = "FC"   # decoder variant passed to CapsNet: "FC" or "Conv"
# Only one test batch of this size is drawn, to pick an image from; nothing is trained.
BATCH_SIZE = 32

Re-run the cell below to reset the model outputs if you have changed them (the sliders further down modify them in place).

In [24]:
# Load the trained CapsNet and run it once on a single test batch.
# Re-running this cell recomputes `output` from scratch; re-run the
# image-selection cell afterwards so `capsules` refers to the fresh tensor.
device = torch.device("cuda" if USE_GPU else "cpu")

capsnet = CapsNet(reconstruction_type=RECONSTRUCTION_TYPE).to(device)
model_path = path.join(SAVE_DIR, MODEL)
# map_location lets a checkpoint that was saved on GPU load when USE_GPU is False.
capsnet.load_state_dict(torch.load(model_path, map_location=device))
capsnet.eval()

_, test_loader = load_mnist(BATCH_SIZE)
# Use the builtin next(): recent PyTorch DataLoader iterators have no `.next()` method.
data, target = next(iter(test_loader))
target = torch.eye(10).index_select(dim=0, index=target)  # One-hot encode target
# Inference only: no autograd graph is needed. Without one, the slider
# callback's later in-place edits of slices of `output` are also safe.
with torch.no_grad():
    output, reconstruction, masked = capsnet(data.to(device))

The cell below picks a random image from the loaded batch to experiment with; re-run it to choose a different one.

In [25]:
# Pick one image of the loaded batch at random. Use the real batch length rather
# than BATCH_SIZE so a short batch can never yield an out-of-range index.
i = np.random.randint(len(data))  # index of chosen image in last batch
# Basic slicing keeps the batch dimension and returns a *view* of `output`:
# in-place edits made later through `capsules` also change `output`.
capsules = output[i:i+1]

# Find prediction: the class whose capsule vector (components along dim 2)
# is longest. Softmax is monotonic, so it does not change the argmax.
classes = capsules.norm(dim=2)
classes = functional.softmax(classes, dim=1)
_, prediction = classes.max(dim=1)

if DEBUG_MODE:
    print("Image:{}".format(i))
    print("Target:{}".format(target[i].argmax().item()))  # target is one-hot
    print("Prediction:{}".format(prediction.item()))

    predicted_capsule = capsules[:, prediction, :, :]
    print(predicted_capsule.shape)
    print(predicted_capsule)

Image:2
Target:1
Prediction:1
torch.Size([1, 1, 1, 16, 1])
tensor([[[[[ 0.5357],
           [-0.0752],
           [-0.5303],
           [ 0.5219],
           [-0.5333],
           [ 0.1194],
           [-0.1464],
           [ 0.5383],
           [-0.5247],
           [-0.5326],
           [ 0.3336],
           [-0.4676],
           [-0.2362],
           [-0.5337],
           [-0.1159],
           [-0.5238]]]]], device='cuda:0')


In [28]:
# Slider callback: overwrite the predicted class's capsule vector with the
# slider values, decode it, and plot the reconstruction next to the input.
# NOTE(review): still relies on the notebook globals `capsules`, `capsnet`,
# `data`, `target` and `i`; `prediction` is a parameter so ipywidgets can pass it.
def _to_unit_range(img):
    """Return a new float array linearly rescaled to [0, 1].

    The input is never modified. A constant image maps to all zeros, which
    avoids a division by zero.
    """
    img = np.asarray(img, dtype=np.float64)
    lo, hi = img.min(), img.max()
    if hi == lo:
        return np.zeros_like(img)
    return (img - lo) / (hi - lo)


def reconstruct(prediction, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15):
    """Decode the predicted capsule after replacing its 16 components.

    Args:
        prediction: index (tensor) of the predicted class along dim 1 of `capsules`.
        c0..c15: new values for components 0..15 (dim 2) of that capsule.

    Side effects:
        Mutates the global `capsules` in place. It is a slice of `output`, so
        `output` changes too. Also draws a two-panel matplotlib figure.
    """
    components = (c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15)
    for k, value in enumerate(components):
        capsules[:, prediction, k, :] = value

    # Inference only; send the one-hot target to wherever the capsules live
    # instead of hard-coding CUDA.
    with torch.no_grad():
        reconstruction, _ = capsnet.decoder(capsules, data, target[i:i+1].to(capsules.device))

    # Build fresh arrays. `data[i, 0].cpu().numpy()` shares memory with `data`,
    # so normalizing it in place would corrupt the input batch on every call.
    # Each image is normalized by its own min/max.
    recon_img = _to_unit_range(np.squeeze(reconstruction.detach().cpu().numpy()))
    input_img = _to_unit_range(data[i, 0].detach().cpu().numpy())

    fig, (ax_rec, ax_in) = plt.subplots(1, 2)
    ax_rec.set_title("Reconstruction")
    ax_rec.imshow(recon_img, cmap="gray")
    ax_in.set_title("Input")
    ax_in.imshow(input_img, cmap="gray")
    
def build_widgets():
    """Build the interactive slider panel that drives `reconstruct`.

    Creates one FloatSlider per capsule component, labelled "Capsule 0" to
    "Capsule 15" and mapped to the c0..c15 arguments of `reconstruct`. Range,
    step and update mode come from the globals MIN, MAX, STEP and
    CONTINUOUS_UPDATE, which must be defined before this function is called.

    `prediction` is wrapped in `widgets.fixed` so it is passed through
    unchanged. A raw non-widget value would otherwise be turned into a widget
    by ipywidgets' abbreviation rules.
    """
    sliders = {
        "c{}".format(k): FloatSlider(description="Capsule {}".format(k),
                                     min=MIN, max=MAX, step=STEP,
                                     continuous_update=CONTINUOUS_UPDATE)
        for k in range(16)
    }
    return interactive(reconstruct, prediction=widgets.fixed(prediction), **sliders)

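An ipywidgets slider starts at zero by default. Unless the sliders are initialized from the model's output, the first reconstruction will therefore not match the network's prediction. With `DEBUG_MODE = True`, the image-selection cell above prints the predicted capsule's 16 output values, which you can use as a reference when adjusting the sliders.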

In [29]:
# Slider settings read by build_widgets()
MIN = -1
MAX = 1
STEP = 1e-1
CONTINUOUS_UPDATE = True

# Initial values: a snapshot (plain Python floats) of the predicted capsule's
# current 16 components. Later slider edits to `capsules` do not change this list.
init = capsules[:, prediction, :, :].detach().cpu().reshape(-1).tolist()

w = build_widgets()
# Start each "Capsule k" slider at the model's own output, so the first
# reconstruction reflects the prediction instead of an all-zero vector.
# Values outside [MIN, MAX] are clamped by the slider. Each assignment triggers
# the interactive callback once before the panel is shown.
for child in w.children:
    if isinstance(child, FloatSlider) and child.description.startswith("Capsule "):
        child.value = init[int(child.description.split()[-1])]
display(w)

In [None]:
# Experimental improvements for interaction with visualization
# CURRENTLY NOT WORKING
# NOTE(review): likely reasons this does not work yet (if uncommented):
#   - inside reconstruct, `for i, widget in enumerate(widgets_list)` makes `i`
#     a local that ends at the last slider index, so `target[i:i+1]` and
#     `data[i, 0]` no longer refer to the chosen image;
#   - the cell-level `for i in range(CAPS_COUNT)` overwrites the global image
#     index `i` that the other cells use;
#   - `widget.observe(...)` is called without `names="value"`, so the handler
#     runs on every trait change rather than only on value changes;
#   - the final VBox is assigned to `w` but never displayed;
#   - it redefines `reconstruct`, which would shadow the slider callback above.

# def reconstruct(change, prediction, widgets_list):
#     for i, widget in enumerate(widgets_list):
#         capsules[:,prediction,i,:] = widget.value
    
#     reconstruction, _ = capsnet.decoder(capsules, data, target[i:i+1].cuda())
    
#     if DEBUG_MODE:
#         print(capsules)
#         print(target[i:i+1])
#         print(target[i:i+1].max(dim=1)[1].reshape(-1,1))
        
#     im = np.squeeze(reconstruction.data.cpu().numpy())
#     im += abs(im.min())
#     im /= im.max()
#     plt.subplot(1,2,1)
#     plt.title("Reconstruction")
#     plt.imshow(im, cmap="gray");
#     im2 = data[i, 0].data.cpu().numpy()
#     im2 += abs(im.min())
#     im2 /= im.max()
#     plt.subplot(1,2,2)
#     plt.title("Input")
#     plt.imshow(im2, cmap="gray");

# MIN = -1
# MAX = 1
# STEP = 1e-1
# CAPS_COUNT = 16
# CONTINUOUS_UPDATE = True

# # Credits to building these widgets: https://stackoverflow.com/questions/37622023
# widgets_list = []
# for i in range(CAPS_COUNT):
#     widgets_list.append(FloatSlider(description="Capsule "+str(i),
#                                     min=MIN, max=MAX, step=STEP, continuous_update=CONTINUOUS_UPDATE))
# for widget in widgets_list:
#     widget.observe(lambda change:reconstruct(change, prediction, widgets_list))
    
# w = VBox(children=widgets_list)