In [1]:
import itertools
import os

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import widgets, interact

# Work from the notebook's parent directory so the relative 'data/' paths used
# below resolve. The starting directory is remembered on the first run, so
# re-running this cell is idempotent (a bare `%cd ".."` would climb one more
# level on every re-run).
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.abspath(os.path.join(NOTEBOOK_DIR, os.pardir)))
print(os.getcwd())

c:\Users\kate\hyperspectral


In [2]:
M = 10  # NOTE(review): the plot cell's x values were range(1, M), i.e. this assumes each array has M - 1 entries along axis 1 — confirm
dir_prefix = 'data/'

# Squared reconstruction-error results, one file per (filter-choosing strategy,
# noise code). Order is strategy-major, which matches `labels` in the plot cell.
strategies = ['naive', 'random', 'optimal']
noise_codes = ['0', '2', '5']  # presumably noise std 0.00 / 0.02 / 0.05, per the plot labels — confirm

data = [np.load(f'{dir_prefix}{strategy}-n{noise}-meta.npy')
        for strategy in strategies
        for noise in noise_codes]

# Individual handles kept for ad-hoc inspection.
(naive_00_sqerr, naive_02_sqerr, naive_05_sqerr,
 random_00_sqerr, random_02_sqerr, random_05_sqerr,
 opt_00_sqerr, opt_02_sqerr, opt_05_sqerr) = data


def summarize_errors(errors, axis):
    """Summarize reconstruction errors by reducing over `axis`.

    Returns (mean, std, minus_err, median, plus_err). minus_err is
    median - 25th percentile and plus_err is 75th percentile - median: the
    relative form plotly's ``error_y.arrayminus`` / ``error_y.array`` expect.
    """
    p25, p50, p75 = np.percentile(errors, q=[25, 50, 75], axis=axis)
    return (np.mean(errors, axis=axis), np.std(errors, axis=axis),
            p50 - p25, p50, p75 - p50)


# NOTE: *_25 lists hold (median - 25th pct) and *_75 lists hold
# (75th pct - median), not the raw percentiles.
refl_mean, refl_std, refl_25, refl_50, refl_75 = [], [], [], [], []
illum_mean, illum_std, illum_25, illum_50, illum_75 = [], [], [], [], []

# Version *with* meta trials: axis 0 indexes meta trials (absent in the
# non-meta variant). Along axis 2, the last entry is summarized as the
# illuminant error and all earlier entries are pooled as reflectance error.
for d in data:
    r_mean, r_std, r_minus, r_med, r_plus = summarize_errors(d[:, :, :-1], axis=(0, 2))
    refl_mean.append(r_mean)
    refl_std.append(r_std)
    refl_25.append(r_minus)
    refl_50.append(r_med)
    refl_75.append(r_plus)

    i_mean, i_std, i_minus, i_med, i_plus = summarize_errors(d[:, :, -1], axis=0)
    illum_mean.append(i_mean)
    illum_std.append(i_std)
    illum_25.append(i_minus)
    illum_50.append(i_med)
    illum_75.append(i_plus)

# Version *without* meta trials (no leading trial axis): reflectance would be
# np.mean(d[:, :-1], axis=1) and the illuminant taken directly as d[:, -1].

In [4]:
labels = ['Naive choice (noise std = 0.00)', 'Naive choice (noise std = 0.02)', 'Naive choice (noise std = 0.05)',
          'Random choice (noise std = 0.00)', 'Random choice (noise std = 0.02)', 'Random choice (noise std = 0.05)',
          'Optimal choice (noise std = 0.00)', 'Optimal choice (noise std = 0.02)', 'Optimal choice (noise std = 0.05)']

colors = ['#636EFA','#EF553B','#00CC96','#AB63FA','#FFA15A','#19D3F3','#FF6692','#B6E880','#FF97FF','#FECB52']


def _checkbox(description, value):
    """Build a toggle checkbox at 30% width, the size shared by every toggle in this UI."""
    return widgets.Checkbox(description=description, value=value,
                            layout=widgets.Layout(width="30%"))


def _row(*children):
    """Lay out checkboxes side by side across the full width."""
    return widgets.HBox(children=list(children), layout=widgets.Layout(width="100%"))


data_type_label = widgets.Label(value = 'Show reflectance reconstruction data and/or illumination reconstruction data: ')
refl_cb = _checkbox('Show reflectance data', True)
illum_cb = _checkbox('Show illumination data', True)
row_data_type = widgets.HBox(children = [refl_cb, illum_cb])

# One checkbox per entry of `labels`, in the same order; only the first starts checked.
(naive_00_cb, naive_02_cb, naive_05_cb,
 random_00_cb, random_02_cb, random_05_cb,
 opt_00_cb, opt_02_cb, opt_05_cb) = [_checkbox(label, i == 0) for i, label in enumerate(labels)]

filter_strategy_label = widgets.Label(value = 'Filter choosing strategy (standard deviation of measurement noise): ')

# Rows group by noise level; columns are naive / random / optimal.
row_00 = _row(naive_00_cb, random_00_cb, opt_00_cb)
row_02 = _row(naive_02_cb, random_02_cb, opt_02_cb)
row_05 = _row(naive_05_cb, random_05_cb, opt_05_cb)

misc_label = widgets.Label(value = 'Results of 20 trials, data scatter point is 50th percentile of reconstruction error, error bars represent 25th and 75th percentiles of reconstruction error')

ui = widgets.VBox(children = [data_type_label, row_data_type, filter_strategy_label, row_00, row_02, row_05, misc_label])

layout = go.Layout(
    xaxis_title_text = 'number of measurements',
    yaxis_title_text = 'reconstruction error (vs. ground truth)',
    yaxis2 = go.layout.YAxis(side = 'right', overlaying = 'y1'),
    showlegend = True
)


def _empty_trace(name, color, yaxis, dash=None):
    """Build a placeholder scatter trace with no points yet and asymmetric y error bars.

    The data arrays start empty; `update` fills them in when the trace is shown.
    """
    line = dict(color=color)
    if dash is not None:
        line['dash'] = dash
    return go.Scatter(x=np.array([]), y=np.array([]),
                      mode='lines+markers',
                      line=line,
                      error_y=dict(type='data',
                                   array=np.array([]),
                                   arrayminus=np.array([]),
                                   visible=True),
                      name=name,
                      yaxis=yaxis)


# Per label: [reflectance trace on the left axis, dashed illuminant trace on the
# right axis]. Flattened, trace 2*i / 2*i + 1 belong to labels[i].
dummies = [[_empty_trace(label + ', reflectance', color, 'y1'),
            _empty_trace(label + ', illuminant', color, 'y2', dash='dash')]
           for label, color in zip(labels, colors)]

fig = go.FigureWidget(data=list(itertools.chain(*dummies)), layout=layout)

def _show_error_trace(trace, show, y, err_plus, err_minus):
    """Fill `trace` with `y` plotted at x = 1..len(y) with asymmetric error bars, or empty it when `show` is false."""
    if show:
        # x is derived from the data length rather than a separate constant,
        # so the x and y arrays can never disagree in length.
        trace.x = np.arange(1, len(y) + 1)
        trace.y = y
        trace.error_y.array = err_plus
        trace.error_y.arrayminus = err_minus
    else:
        # Empty the trace instead of removing it so trace indices stay fixed.
        trace.x = np.array([])
        trace.y = np.array([])
        trace.error_y.array = np.array([])
        trace.error_y.arrayminus = np.array([])


def update(naive_00_cb, naive_02_cb, naive_05_cb,
           random_00_cb, random_02_cb, random_05_cb,
           opt_00_cb, opt_02_cb, opt_05_cb,
           refl_cb, illum_cb):
    """Show or hide each strategy's reflectance and illuminant traces in `fig`.

    Receives the current value of every checkbox as a keyword argument (names
    must match the controls mapping passed to ``interactive_output``). Trace
    2*i is the reflectance trace and 2*i + 1 the illuminant trace for labels[i].
    A trace is shown only when both its strategy checkbox and its data-type
    checkbox (`refl_cb` / `illum_cb`) are set.
    """
    strategy_shown = [naive_00_cb, naive_02_cb, naive_05_cb,
                      random_00_cb, random_02_cb, random_05_cb,
                      opt_00_cb, opt_02_cb, opt_05_cb]

    # Batch all trace edits into a single front-end update.
    with fig.batch_update():
        for i, shown in enumerate(strategy_shown):
            _show_error_trace(fig.data[2 * i], refl_cb and shown,
                              refl_50[i], refl_75[i], refl_25[i])
            _show_error_trace(fig.data[2 * i + 1], illum_cb and shown,
                              illum_50[i], illum_75[i], illum_25[i])

# Each `update` keyword argument is fed by the checkbox of the same name.
controls = {'naive_00_cb': naive_00_cb, 'naive_02_cb': naive_02_cb, 'naive_05_cb': naive_05_cb,
            'random_00_cb': random_00_cb, 'random_02_cb': random_02_cb, 'random_05_cb': random_05_cb,
            'opt_00_cb': opt_00_cb, 'opt_02_cb': opt_02_cb, 'opt_05_cb': opt_05_cb,
            'refl_cb': refl_cb, 'illum_cb': illum_cb}
out = widgets.interactive_output(update, controls)

# interactive_output runs `update` inside the `out` Output widget, which also
# captures any traceback it raises. `out` must be displayed, or errors from
# `update` are silently lost and the figure just stops changing.
display(ui, out)
fig

VBox(children=(Label(value='Show reflectance reconstruction data and/or illumination reconstruction data: '), …

FigureWidget({
    'data': [{'error_y': {'array': array([5.31342135, 2.18519498, 1.69715104, 1.00512257, 0.818…