In [1]:
import numpy as np

In [2]:
from bokeh import events
from bokeh.io import push_notebook, output_notebook, show
from bokeh.layouts import row
from bokeh.models import CustomJS, Div
from bokeh.plotting import figure, ColumnDataSource
import bokeh.palettes

In [3]:
from sklearn import datasets
from sklearn.decomposition import PCA

In [4]:
from PIL import Image
import base64
from io import BytesIO

In [5]:
def gnp2im(image_np, bit_depth_scale_factor):
    """Convert a 2-D grayscale NumPy array into an 8-bit PIL image.

    Parameters
    ----------
    image_np : numpy.ndarray
        2-D array of grayscale intensities.
    bit_depth_scale_factor : float
        Multiplier that maps the input intensity range onto 0..255.

    Returns
    -------
    PIL.Image.Image
        Grayscale image in mode 'L'.
    """
    # Clip before casting: floats outside 0..255 cast to uint8 wrap around
    # (or are undefined), which would silently corrupt the image.
    scaled = np.clip(np.asarray(image_np) * bit_depth_scale_factor, 0, 255)
    return Image.fromarray(scaled.astype(np.uint8), mode='L')

def to_base64(png):
    """Wrap raw PNG bytes in a base64 ``data:`` URI usable as an <img> src."""
    payload = base64.b64encode(png).decode("utf-8")
    return "data:image/png;base64," + payload

def get_thumbnails(data, bit_depth_scale_factor):
    """Encode every 2-D grayscale array in ``data`` as a PNG data URI.

    Parameters
    ----------
    data : iterable of numpy.ndarray
        Grayscale images, each passed to ``gnp2im``.
    bit_depth_scale_factor : float
        Intensity multiplier forwarded to ``gnp2im``.

    Returns
    -------
    list of str
        One ``data:image/png;base64,...`` string per input image, in order.
    """
    def _encode(image_np):
        # Render to an in-memory PNG, then base64-wrap the bytes.
        buffer = BytesIO()
        gnp2im(image_np, bit_depth_scale_factor).save(buffer, format='png')
        return to_base64(buffer.getvalue())

    return [_encode(image_np) for image_np in data]

In [6]:
def map_label_to_color(label, palette=None):
    """Return the palette color for an integer class label.

    Parameters
    ----------
    label : int
        Class label, used directly as an index into the palette.
    palette : sequence of str, optional
        Colors to index into. When omitted, the notebook-level
        ``viridis_palette`` is looked up at call time, so it must be defined
        before this function is called without ``palette``.

    Returns
    -------
    str
        The color at position ``label``.
    """
    if palette is None:
        palette = viridis_palette
    return palette[label]

def display_event(div, x, y, thumbnails, figure_width, figure_height, attributes=None, style='float:left;clear:left;font-size:13px'):
    """Build a CustomJS callback that shows the thumbnail nearest the cursor in ``div``.

    On each event the callback reads the cursor position from the event
    object, finds the nearest point in (x, y) by squared Euclidean distance,
    and replaces ``div.text`` with that point's index and thumbnail image.

    Parameters
    ----------
    div : bokeh.models.Div
        Model whose ``text`` is overwritten with the HTML for the match.
    x, y : sequence of float
        Point coordinates, in the same order as ``thumbnails``.
    thumbnails : sequence of str
        Image ``src`` values (e.g. data URIs), one per point.
    figure_width, figure_height : int
        Figure size in pixels; the thumbnail is drawn at half of each.
    attributes : list of str, optional
        Event attributes to read; ``'x'`` and ``'y'`` are used as the cursor
        position. Defaults to an empty list, in which case no point matches
        and the div is left unchanged.
    style : str, optional
        Inline CSS for the span holding the index label.

    Returns
    -------
    bokeh.models.CustomJS
    """
    if attributes is None:
        attributes = []
    return CustomJS(args=dict(div=div, x=x, y=y, thumbnails=thumbnails, figure_width=figure_width, figure_height=figure_height), code="""
        var attrs = %s;
        var n = x.length;

        // Read the cursor position only from the requested event attributes.
        var test_x;
        var test_y;
        for (var i = 0; i < attrs.length; i++) {
            if (attrs[i] == 'x') {
                test_x = Number(cb_obj[attrs[i]]);
            }
            if (attrs[i] == 'y') {
                test_y = Number(cb_obj[attrs[i]]);
            }
        }

        // Linear scan for the nearest point. Start from Infinity so any
        // finite distance wins, however far the cursor is from the data.
        var minDiffIndex = -1;
        var minDiff = Infinity;
        for (var i = 0; i < n; i++) {
            var squareDiff = (test_x - x[i]) ** 2 + (test_y - y[i]) ** 2;
            if (squareDiff < minDiff) {
                minDiff = squareDiff;
                minDiffIndex = i;
            }
        }

        // No match: empty data or missing x/y (NaN never compares less).
        // Leave the div as-is rather than rendering an undefined image.
        if (minDiffIndex < 0) {
            return;
        }

        var img_tag_attrs = "height='" + (figure_height * 0.5) + "' width='" + (figure_width * 0.5) + "' style='float: left; margin: 0px 15px 15px 0px;' border='2'";
        var img_tag = "<div><img src='" + thumbnails[minDiffIndex] + "' " + img_tag_attrs + "></img></div>";
        div.text = "<span style=%r>Index: " + minDiffIndex + "</span>" + img_tag + "\\n";
    """ % (attributes, style))

In [7]:
# Load BokehJS into the notebook so later show() calls render inline.
output_notebook()

In [8]:
# scikit-learn's bundled digits dataset; a dict-like Bunch read by key below.
dataset = datasets.load_digits()

In [9]:
# Image stack and the class label of each image.
data, labels = dataset['images'], dataset['target']

In [10]:
# Image count, height and width; used below to flatten each image into a vector.
n, h, w = data.shape

In [11]:
# Project the flattened images onto their first two principal components.
# random_state pins the result if scikit-learn selects a randomized SVD solver.
pca = PCA(n_components=2, random_state=0)

In [12]:
# Flatten each (h, w) image into one row, then fit PCA and project in one step.
z = pca.fit_transform(data.reshape(n, -1))

In [13]:
# Sanity check: one 2-D PCA coordinate per image.
z.shape

(1797, 2)

In [14]:
# load_digits pixels are integers in 0..16, not 0..1, so a factor of 255
# overflows uint8 and garbles the thumbnails. Map the observed maximum
# intensity onto 255 instead.
bit_depth_scale_factor = 255 / data.max()
thumbnails = get_thumbnails(data, bit_depth_scale_factor)

In [15]:
viridis_palette = bokeh.palettes.viridis(10)  # one color per label value 0-9
colors = [map_label_to_color(label) for label in labels]

In [16]:
# First and second principal-component coordinates of every image.
x, y = z[:, 0], z[:, 1]

In [17]:
figure_width = 400
figure_height = 400

# `width`/`height` replace `plot_width`/`plot_height`, which were deprecated
# in Bokeh 2.4 and removed in 3.0.
p = figure(
    width=figure_width,
    height=figure_height,
    title="Digits projected onto the first two principal components",
    x_axis_label="PC 1",
    y_axis_label="PC 2",
)
p.scatter(x, y, fill_color=colors, fill_alpha=0.6, line_color=None)

div = Div()

layout = row(p, div)

# Show the thumbnail of the point nearest the cursor. Use events.Tap instead
# of events.MouseMove to update only on click.
point_attributes = ['x', 'y']
p.js_on_event(events.MouseMove, display_event(div, x, y, thumbnails, figure_width, figure_height, attributes=point_attributes))

show(layout)