In [1]:
import numpy as np
import pandas as pd
import pickle

# Directory that holds the pickled result files ('plot_<dataset>.pickle')
# read by viz_for_dataset.
data_folder = 'output_tsne'

In [2]:
# Enable IPython's autoreload extension (mode 2) so that edited modules are
# picked up again without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [3]:
# NOTE(review): `py` is not used in the cells visible here, and the
# plotly.plotly module was moved to the separate chart_studio package in
# plotly 4 -- confirm which plotly version this notebook is pinned to.
import plotly.plotly as py
import plotly.graph_objs as go

from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Initialise plotly's offline notebook mode so iplot() can render figures
# inline in the notebook.
init_notebook_mode(connected=True)


def plot_surface(pivot_data, z_title='', chart_title=''):
    """Render a pivoted grid of values as an interactive 3D surface.

    Parameters
    ----------
    pivot_data : pandas.DataFrame
        Grid to plot. Its column labels are used as the x coordinates, its
        index labels as the y coordinates and its cell values as the surface
        height. The axis titles are fixed to 'Perplexity' (x) and
        'Learning rate' (y), so the frame is expected to be laid out with
        perplexities as columns and learning rates as index.
    z_title : str, optional
        Title of the z axis.
    chart_title : str, optional
        Title of the whole figure.

    Returns
    -------
    None
        The figure is displayed through plotly's ``iplot``.
    """
    perps = pivot_data.columns
    # The index holds the same labels as pivot_data.transpose().columns,
    # without building a transposed copy of the frame.
    lrs = pivot_data.index
    data = [
        go.Surface(
            x=perps,
            y=lrs,
            # DataFrame.as_matrix() was deprecated in pandas 0.23 and removed
            # in pandas 1.0; .values yields the same 2-D array on every
            # pandas version.
            z=pivot_data.values
        )
    ]
    layout = go.Layout(
        title=chart_title,
        autosize=True,
        width=800,
        height=800,
        margin=dict(
            l=65,
            r=50,
            b=65,
            t=90
        ),
        scene=dict(
            xaxis=dict(title='Perplexity'),
            yaxis=dict(title='Learning rate'),
            zaxis=dict(title=z_title),
        )
    )
    fig = go.Figure(data=data, layout=layout)
    iplot(fig)

In [4]:
def viz_for_dataset(dataset_name='MNIST-SMALL', key_to_pivot='q_link',
                    min_learning_rate=None, max_perplexity=None):
    """Load the stored results of one dataset and plot one metric as a surface.

    Reads ``'<data_folder>/plot_<dataset_name>.pickle'``, takes its
    ``'results'`` entry, builds a DataFrame from it (leaving out the
    ``'embedding'`` field), pivots ``key_to_pivot`` into a
    learning_rate x perplexity grid and hands that grid to ``plot_surface``.

    Parameters
    ----------
    dataset_name : str
        Dataset identifier used to build the pickle file name.
    key_to_pivot : str
        Name of the result field to plot; also used as the chart title.
    min_learning_rate : float or None, optional
        When given, keep only rows whose learning_rate is strictly greater
        than this value. ``None`` (default) applies no filter.
    max_perplexity : float or None, optional
        When given, keep only rows whose perplexity is strictly smaller than
        this value. ``None`` (default) applies no filter.

    Returns
    -------
    None
        The surface is displayed by ``plot_surface``.
    """
    in_name = '{}/plot_{}.pickle'.format(data_folder, dataset_name)
    # NOTE: unpickling can execute arbitrary code -- only point this at
    # result files that were produced locally / come from a trusted source.
    # The context manager makes sure the file handle is closed again.
    with open(in_name, 'rb') as in_file:
        pkl_data = pickle.load(in_file)
    embeddeds = pkl_data['results']

    df = pd.DataFrame.from_records(embeddeds, exclude=["embedding"])

    # The previous version evaluated
    #   df.loc[(df['learning_rate'] > 10.0) & df['perplexity'] < 100.0]
    # and threw the result away. Besides being a no-op, the expression was
    # mis-parenthesised: `&` binds tighter than `<`, so it compared
    # (mask & perplexity) with 100.0. The intended range filter is now
    # opt-in, with each condition evaluated on its own; the defaults keep
    # every row, which matches what was effectively plotted before.
    if min_learning_rate is not None:
        df = df.loc[df['learning_rate'] > min_learning_rate]
    if max_perplexity is not None:
        df = df.loc[df['perplexity'] < max_perplexity]

    pivot_df = df[['learning_rate', 'perplexity', key_to_pivot]].pivot(
        index='learning_rate', columns='perplexity', values=key_to_pivot)

    plot_surface(pivot_data=pivot_df, chart_title=key_to_pivot)

In [5]:
import ipywidgets as widgets

# Dropdown selecting which dataset's results to load
# (displayed label -> dataset identifier).
datasetX = widgets.Dropdown(
    description='Dataset:',
    options={
        'MNIST small': 'MNIST-SMALL',
        'Country Indicators 1999': 'COUNTRY1999'
    },
    value='COUNTRY1999',
)

# Dropdown selecting which result field is drawn as the surface
# (displayed label -> field name).
pivotKeyX = widgets.Dropdown(
    description='Surface: ',
    options={
        'NLL in LD': 'q_link',
        'NLL in HD': 'p_link',
        'KL-loss': 'loss'
    },
    value='q_link',
)

# Both dropdowns laid out next to each other in one horizontal box.
ui = widgets.HBox(children=[datasetX, pivotKeyX])

In [6]:
from ipywidgets import interact, interactive_output
# Bind each dropdown to the viz_for_dataset keyword argument of the same
# name; the call's output is captured in `out`.
out = widgets.interactive_output(
    viz_for_dataset,
    dict(dataset_name=datasetX, key_to_pivot=pivotKeyX)
)
# Show the controls followed by the output area.
display(ui, out)

A Jupyter Widget

A Jupyter Widget