In [1]:
import json

import numpy as np
import pandas as pd

import dash
import dash_table
import dash_core_components as dcc
import dash_html_components as html
from dash.dependencies import Input, Output
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from jupyter_dash import JupyterDash

In [2]:
def graphimage(data, point=0):
    try:
        number = data['points'][point]['customdata'][0]
        img = 255*images[number]
        img = img.astype(np.uint8)
    except:
        img = np.zeros((28,28))

    z = np.zeros((28,28,3), dtype=np.uint8)
    z[:,:,1] =  img
    f = go.Figure(go.Image(z=z))
    f.update_layout(
        height=48,
        width=48,
        margin=dict(l=4,r=4,t=4,b=4)
    )
    f.update_xaxes(showticklabels=False) 
    f.update_yaxes(showticklabels=False) 
    return f

def imagediv(fg, id):
    res = html.Div([
        dcc.Graph(id=id, figure=fg)],
        className='',
        #style={'width':'200px'}
    )
    return res

colors_list = px.colors.qualitative.Plotly
colors = {i:c for i,c in enumerate(colors_list)}
print(colors)

{0: '#636EFA', 1: '#EF553B', 2: '#00CC96', 3: '#AB63FA', 4: '#FFA15A', 5: '#19D3F3', 6: '#FF6692', 7: '#B6E880', 8: '#FF97FF', 9: '#FECB52'}


In [4]:
images = np.load('/Users/cjw/Code/MNIST/Data/mnist_norm.npy')
images = images.reshape((-1, 28, 28))

df = pd.read_pickle('/Users/cjw/Code/MNIST/Data/umap.pkl')#.sort_values('label')
df = df.drop('strlabel', axis=1)
df = df.reset_index()
df = df.sort_values('label')
df.head()

Unnamed: 0,index,x,y,label
30207,30207,14.818077,8.978702,0
5662,5662,13.744578,7.798445,0
55366,55366,14.832038,6.889963,0
14160,14160,14.265167,7.510245,0
14161,14161,14.693933,7.937889,0


In [5]:
fig = px.scatter(df, x='x', y='y', color='label',
                 custom_data=[df.index, df.label])

fig.update_traces(marker_size=4)
fig.update_layout(clickmode='event+select',
                  height=500,
                  width=500,
                  margin=dict(l=4,r=4,t=4,b=4),
                  legend={'itemsizing':'constant','itemwidth':60})

z = np.zeros((28,28,3), dtype=np.uint8)
z[:,:,1] =255*(images[14]).astype('uint8') 
mnistfig = go.Figure(go.Image(z=z))
selfig = go.Figure() # make_subplots(1,4, horizontal_spacing=.05)
mnistfig.update_layout(height=200,
                       width=200)
mnistfig.update_xaxes(showticklabels=False) 
mnistfig.update_yaxes(showticklabels=False) 

num_options = [{'label':s, 'value':s} for s in range(10)]
num_options.insert(0, {'label':'All', 'value':-1})

In [6]:
external_stylesheets = ['https://codepen.io/chriddyp/pen/bWLwgP.css']
app = JupyterDash(__name__, external_stylesheets=external_stylesheets)

app.layout = html.Div([
        html.Div([
           html.Label(['Pick a label', dcc.Dropdown(options=num_options,
                                                    id='num-dropdown')], 
            style={'width':'20%'}),
            html.Div(
                dcc.Graph(
                    id='basic-interactions',
                    figure=fig,
                ),
                className='hoverimage', style={'width':'90%', 'float':'left'})],
            style={'width':'40%', 'float':'left'}
        ),
        html.Div([
            html.Div(dcc.Graph(id='mnist-image',
                            figure=mnistfig), 
                className='',
                style={}),
            html.Div(
                children=[imagediv(selfig, "s0")],
                id='selected-images', className='', 
                style={'width':'95%', 'height':'160px', 'display':'flex',
                       'flex-wrap':'wrap'}),
            html.Div([
                dash_table.DataTable(
                    id='table',
                    columns=[{'name':i, 'id': i} for i in df.columns],
                    data=df.sample(10).to_dict('records'),
                    style_cell={'font-size':'small', 'height':'10px'}
                ),     
            ], style={'position':'absolute', 'width':'80%','bottom':'0px', 'left':'0px',
                      }),
        ], style={'width':'40%', 'position':'absolute', 'height':'750px',
                  'left':'50%', 'top':'50px','border':'3px solid #73AD21'}),
])

In [7]:
@app.callback(
    Output('basic-interactions', 'figure'),
    Input('num-dropdown', 'value')
)
def pick_label(val_label):

    if val_label is None:
        _df = df
        xcolor = 'label'
    elif val_label == -1:
        _df = df
        xcolor = 'label'
    else:
        _df = df[df.label == val_label].copy()
        xcolor = 'xcolor'
        _df['xcolor'] = colors[val_label]
        print(xcolor, val_label, colors[val_label])
        print(_df.head())
     
    f1 = px.scatter(_df, x='x', y='y', color='label',
                 custom_data=[_df.index, _df.label])

    f1.update_traces(marker_size=4)
    if val_label is not None and val_label >= 0:
        f1.update_traces(marker_color=colors[val_label])

    f1.update_layout(clickmode='event+select',
                  height=500,
                  width=500,
                  margin=dict(l=4,r=4,t=4,b=4),  
                  legend={'itemsizing':'constant','itemwidth':60})
    return f1
    
@app.callback(
    Output('mnist-image', 'figure'),
    Input('basic-interactions', 'hoverData'))
def display_hover_data(hoverData):
    f = graphimage(hoverData)
    return f #, json.dumps(hoverData, indent=2)

@app.callback(
    [Output('table', 'data'),
     Output('selected-images', 'children')],
    Input('basic-interactions', 'selectedData'))
def display_hover_data(selectedData):
    if selectedData is None:
        #return json.dumps(selectedData, indent=2), [] 
        return [{'a':0}], [] 

    n = len(selectedData['points'])
    px = list()
    g = list() 
    nums = np.random.randint(0,n, 12)
    for i, ix in enumerate(nums):
        fg = graphimage(selectedData, point=ix)
        fgdiv = imagediv(fg, f"si{i}")
        g.append(fgdiv)
        px.append(selectedData['points'][ix]['customdata'][0])
   
    res_table = df.loc[df.index.isin(px)].to_dict('records')
    return res_table, g

In [8]:
app.run_server(mode='inline')

In [9]:
28*28

784

xcolor 4 #FFA15A
       index         x          y  label   xcolor
52589  52589 -1.793161  13.782656      4  #FFA15A
58980  58980 -2.458477  13.312469      4  #FFA15A
27612  27612  0.343376  13.500982      4  #FFA15A
36640  36640 -0.747245  13.325110      4  #FFA15A
53873  53873 -1.107094  12.985000      4  #FFA15A
