# Viewing a 2D projection and comparing it with the raw high-dimensional values

There are many methods for visualizing high-dimensional data in 2D; popular ones include PCA, t-SNE and autoencoder variants, among many others. These methods often do a great job of giving us a feel for what high-dimensional data looks like, something our mere mortal brains can't do on their own.

The main aim of this notebook is to set up a Plotly dashboard that lets us interact with the 2D projection. This dashboard consists of three panels:

- The left panel shows a scatter plot of the 2D projection
- The middle panel shows a line plot of the *raw high-dimensional data* for the point currently hovered over in the 2D projection in the left panel. The idea is to compare nearby points in the 2D embedding and check whether their high-dimensional counterparts are also similar.
- The right panel shows not a single line plot but a collection of raw high-dimensional data *points*, taken from a group of points selected in the 2D projection in the left panel, together with their average. The aim is similar to the middle panel, but instead of focusing on a single data point we want to look at a collection that we can select using the box-select or lasso-select tool.

In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# plotly
from plotly.graph_objs import FigureWidget
from plotly.callbacks import Points, InputDeviceState


# ipywidgets
from ipywidgets import HTML
from ipywidgets import HBox, VBox, Button

## Data and projection setup

We load the data, log-transform it, and project it to 2D. Here we use PCA.

In [3]:
from sklearn.datasets import load_iris

In [4]:
# Load the iris dataset: X is the numeric feature matrix and y the class labels.
iris = load_iris()
X, y = iris.data, iris.target


In [5]:
# Inspect the feature matrix dimensions: (n_samples, n_features).
X.shape

(150, 4)

In [6]:
# Log-transform every feature value. np.log10 yields -inf/nan for non-positive
# entries, so this assumes every value in X is strictly positive.
X = np.log10(X)

In [7]:
from sklearn.decomposition import PCA

In [8]:
# Project the (log-transformed) features onto their first two principal
# components; z has one 2D row per sample of X.
pca = PCA(n_components=2)
z = pca.fit_transform(X)

## Dashboard setup

Here is the magic of using Plotly to build our dashboard!

In [9]:
## left panel figure
# Scatter of the 2D embedding z. Marker colour encodes selection state:
# value 0 = not selected (light gray), value 1 = selected (red).
fig = FigureWidget(
    data=[
        dict(
            type='scattergl',
            x=z[:, 0],
            y=z[:, 1],
            mode='markers',
        )
    ],
)
# The colorscale is a hard step at 0.5. With cmin=-0.5 and cmax=1.5 the colour
# values 0 and 1 normalise to 0.25 and 0.75, i.e. the middle of the gray and
# red halves respectively.
fig.data[0].marker.cmax = 1.5
fig.data[0].marker.cmin = -0.5
fig.data[0].marker.colorscale = [[0, 'lightgray'], [0.5, 'lightgray'], [0.5, 'red'], [1, 'red']]
# Start with nothing selected.
fig.data[0].marker.color = np.zeros(X.shape[0])
fig.layout.width = 500
fig.layout.hovermode = "closest"
fig.layout.title = "Embedding"
fig.layout.xaxis.title = "z1"
fig.layout.yaxis.title = "z2"

# Handle on the embedding trace; the interaction callbacks are attached to it.
scatt1 = fig.data[0]

In [16]:
#

In [10]:
# middle panel setup
# Line plot of a single sample's (log10-transformed) feature values, one x
# position per feature. It starts on sample 0; the hover callback swaps in the
# hovered sample's values.
fig2 = FigureWidget(
    data=[dict(type='scattergl', x=np.arange(X.shape[1]), y=X[0, :])],
    layout=dict(
        width=500,
        hovermode="closest",
        xaxis=dict(title="Features"),
        title="Selected sample",
    ),
)

In [11]:
# right panel setup
# Trace 0 starts as the mean (log10) feature profile over all samples. The
# selection callback overwrites its y values with the mean of the selected
# samples and overlays faint per-sample lines.
fig3 = FigureWidget(
    data=[
        dict(
            type='scattergl',
            x=np.arange(X.shape[1]),
            y=np.mean(X, 0),
        )
    ],
)

fig3.layout.width = 500
fig3.layout.hovermode = "closest"
fig3.layout.title = "Average values"
# The x axis indexes features, matching the middle panel.
fig3.layout.xaxis.title = "Features"

In [12]:
# interactivity setup

# Row indices (into X) of the points in the most recent box/lasso selection.
selection_inds = []


def brush(trace, points, state):
    """Update the dashboard for a new box/lasso selection in the embedding.

    Registered as the selection callback of the left-panel scatter
    (``scatt1``). Each selection replaces the previous one:

    * ``selection_inds`` is overwritten with the selected row indices of ``X``;
    * in the right panel (``fig3``), trace 0 shows the mean profile of the
      selection. One faint line per selected sample is added after it, and
      lines from earlier selections are removed first;
    * selected points get colour value 1 (red) in the left panel and all
      other points get 0 (light gray).

    An empty selection shows the mean over all samples and clears highlights.

    Parameters
    ----------
    trace : the trace the selection was made on (unused).
    points : ``plotly.callbacks.Points``; ``points.point_inds`` lists the
        selected point indices.
    state : selector information passed by plotly (unused).
    """
    selection_inds.clear()
    selection_inds.extend(points.point_inds)
    # Explicit int dtype: np.array([]) is float64 and cannot be used as an index.
    inds = np.asarray(selection_inds, dtype=int)
    features = np.arange(X.shape[1])

    # np.mean over zero rows is NaN (plus a RuntimeWarning); fall back to the
    # mean over all samples instead.
    mean_profile = np.mean(X[inds, :], 0) if inds.size else np.mean(X, 0)

    # Drop per-sample lines left by the previous selection, keeping only the
    # mean trace at index 0. Plotly removes traces when `data` is assigned a
    # subset of the existing traces.
    fig3.data = fig3.data[:1]
    fig3.data[0].y = mean_profile
    if inds.size:
        # One add_traces call instead of one message per sample.
        fig3.add_traces([
            dict(
                type='scattergl',
                x=features,
                y=X[i, :],
                marker={"color": "lightblue"},
                opacity=.3,
                showlegend=False,
            )
            for i in inds
        ])

    # Highlight exactly the current selection (not earlier ones).
    highlight = np.zeros(X.shape[0])
    highlight[inds] = 1
    scatt1.marker.color = highlight


scatt1.on_selection(brush)


def hover_fn(trace, points, state):
    """Show the hovered embedding point's raw feature profile in the middle panel.

    Hover callback for the left-panel scatter (``scatt1``). It copies the row
    of ``X`` for the first hovered point into the middle panel's trace
    (``fig2.data[0]``).

    Parameters
    ----------
    trace : the hovered trace (unused).
    points : ``plotly.callbacks.Points``; ``points.point_inds`` lists the
        hovered point indices.
    state : input-device state passed by plotly (unused).
    """
    # Defensive: indexing [0] would raise IndexError if no point were reported.
    if not points.point_inds:
        return
    ind = points.point_inds[0]
    fig2.data[0].y = X[ind, :]


scatt1.on_hover(hover_fn)


# Reset brush
def reset_brush(btn):
    """Return the dashboard to its unselected state.

    Click handler for the "clear" button; ``btn`` (the clicked button) is
    ignored. It does three things:

    * empties ``selection_inds``;
    * rebuilds the right panel with only the mean profile over all samples;
    * un-highlights every point in the embedding.
    """
    selection_inds.clear()

    # Assigning an empty tuple removes every trace from the right panel.
    fig3.data = ()
    overall_mean = np.mean(X, 0)
    fig3.add_scattergl(
        x=np.arange(X.shape[1]),
        y=overall_mean,
        showlegend=False,
    )

    scatt1.marker.color = np.zeros(X.shape[0])


# Create reset button
button = Button(description="clear")
button.on_click(reset_brush)

In [13]:
# Three panels side by side, with the "clear" button underneath.
dashboard = VBox(children=[HBox(children=[fig, fig2, fig3]), button])
dashboard

VBox(children=(HBox(children=(FigureWidget({
    'data': [{'marker': {'cmax': 1.5,
                         'c…