In [1]:
import numpy as np
from sklearn.decomposition import PCA
from scipy.interpolate import interp1d as interp1d

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import widgets, interact

import color_utils

In [2]:
def spec2spec(wavelengths1, values1, wavelengths2):
    """Resample a spectrum from one wavelength grid onto another.

    A cubic interpolant is built over ``values1`` sampled at ``wavelengths1``
    and evaluated at ``wavelengths2``. Target wavelengths outside the source
    range are extrapolated instead of raising an error.

    Parameters
    ----------
    wavelengths1 : array_like
        Wavelengths at which ``values1`` was sampled.
    values1 : array_like
        Spectrum values. Interpolation runs along the last axis (interp1d's
        default), so a stack of spectra, one per row, is resampled at once.
    wavelengths2 : array_like
        Wavelengths to resample onto.

    Returns
    -------
    numpy.ndarray
        ``values1`` evaluated at ``wavelengths2``.
    """
    resampler = interp1d(
        wavelengths1,
        values1,
        kind='cubic',
        bounds_error=False,
        # Extrapolating a cubic may overshoot well beyond the source range.
        fill_value='extrapolate',
    )
    resampled = resampler(wavelengths2)
    return resampled

In [3]:
# Wavelength grid (nm) shared by the filter spectra, the PCA basis plots and
# the resampled sensor response.
WAVELENGTHS = np.linspace(400., 700., 33)

# Filter spectra for filters with 1, 2 or 3 distinct waveplate layers
# (selected by the radio buttons below), one spectrum per row.
WP1_DATASET = np.load('data/wp-1.npy')
WP2_DATASET = np.load('data/wp-2.npy')
WP3_DATASET = np.load('data/wp-3.npy')

# PCA is fitted row-wise and its components are plotted against WAVELENGTHS;
# the sensor option multiplies each row element-wise by the resampled response.
# Fail fast here rather than getting mis-aligned plots or a broadcasting error later.
for dataset_name, dataset in (('wp-1', WP1_DATASET), ('wp-2', WP2_DATASET), ('wp-3', WP3_DATASET)):
    assert dataset.ndim == 2 and dataset.shape[1] == WAVELENGTHS.size, (
        '{}: expected shape (n_spectra, {}), got {}'.format(dataset_name, WAVELENGTHS.size, dataset.shape))
del dataset_name, dataset

# The sensor database is sampled every 10 nm over 400-720 nm; resample it onto
# WAVELENGTHS so it lines up with the filter spectra.
SENSOR_DATABASE_WAVELENGTHS = np.linspace(400., 720., 33)
# NOTE(review): an earlier cell defines a local spec2spec with the same
# signature; confirm whether color_utils.spec2spec is the intended one here.
NIKON_RESPONSE = color_utils.spec2spec(SENSOR_DATABASE_WAVELENGTHS, np.load('data/nikon5100.npy'), WAVELENGTHS)

In [4]:
# --- Filter / sensor selection -------------------------------------------
wp_label = widgets.Label(
    value='Number of distinct waveplate layers in filter (DOF - 1, distinct meaning they have different properties):',
    layout=widgets.Layout(width="99%"),
)
wp_rb = widgets.RadioButtons(
    options=['1', '2', '3'],
    value='2',
    disabled=False,
    layout=widgets.Layout(width="99%"),
)

sensor_label = widgets.Label(
    value='Multiply filter spectra by Nikon sensor responses before running PCA: ',
    layout=widgets.Layout(width="45%"),
)
sensor_cb = widgets.Checkbox(description='', value=False, disabled=False)

# --- Basis-function visibility -------------------------------------------
# One checkbox per basis function, labelled '1'..'8'; the first four start checked.
basis_label = widgets.Label(value='Show basis function: ', layout=widgets.Layout(width="40%"))
basis_cb1, basis_cb2, basis_cb3, basis_cb4, basis_cb5, basis_cb6, basis_cb7, basis_cb8 = [
    widgets.Checkbox(description=str(number), value=number <= 4, disabled=False)
    for number in range(1, 9)
]
basis_row = widgets.HBox([basis_cb1, basis_cb2, basis_cb3, basis_cb4,
                          basis_cb5, basis_cb6, basis_cb7, basis_cb8])

# --- Read-outs (text is filled in by update()) ---------------------------
ev_label = widgets.Label(value='Explained variance % for basis functions 1-8: ')
ev_value = widgets.Label(value='')
pc_label = widgets.Label(value='Total explained variance % for *selected* basis functions: ')
pc_value = widgets.Label(value='')

ui = widgets.VBox([
    wp_label, wp_rb,
    sensor_label, sensor_cb,
    basis_label, basis_row,
    ev_label, ev_value,
    pc_label, pc_value,
])

# One line trace per basis function, all starting empty; update() fills in
# the y values of the traces whose checkbox is ticked.
plots = [
    go.Scatter(x=WAVELENGTHS, y=np.array([]), mode='lines', name=f'basis {number}')
    for number in range(1, 9)
]

layout = go.Layout(
    margin=dict(l=0, r=0, b=0, t=0),
    xaxis_title='wavelength (nm)',
    yaxis_title='basis function value',
)

fig = go.FigureWidget(data=plots, layout=layout)

def update(sensor_cb, waveplate_rb, basis_cb1, basis_cb2, basis_cb3, basis_cb4, 
           basis_cb5, basis_cb6, basis_cb7, basis_cb8):
    """Refit PCA for the selected filter configuration and refresh the UI.

    Intended to be driven by ``widgets.interactive_output`` with the current
    widget values.

    Parameters
    ----------
    sensor_cb : bool
        If True, every filter spectrum is multiplied by every row of
        ``NIKON_RESPONSE`` (presumably one row per sensor colour channel) before
        fitting, so PCA sees one product spectrum per (filter, row) pair.
    waveplate_rb : str
        '1', '2' or '3'; selects WP1_DATASET, WP2_DATASET or WP3_DATASET.
    basis_cb1 ... basis_cb8 : bool
        Whether basis function 1..8 is plotted and counted in the total
        explained variance.

    Side effects
    ------------
    Sets ``ev_value.value``, ``pc_value.value`` and the ``y`` data of the
    traces in ``fig``.

    Raises
    ------
    ValueError
        If ``waveplate_rb`` is not '1', '2' or '3'.
    """
    basis_cb = [basis_cb1, basis_cb2, basis_cb3, basis_cb4, basis_cb5, basis_cb6, basis_cb7, basis_cb8]
    n_basis = len(basis_cb)

    # Built per call so that re-running the data cell is picked up.
    datasets = {'1': WP1_DATASET, '2': WP2_DATASET, '3': WP3_DATASET}
    if waveplate_rb not in datasets:
        raise ValueError('unexpected waveplate selection: {!r}'.format(waveplate_rb))
    spectra = datasets[waveplate_rb]

    if sensor_cb:
        # (N, 1, F) * (1, C, F) -> (N, C, F) -> (N*C, F): row i*C + c is filter
        # spectrum i times response row c, the same ordering as repeat/tile,
        # but C comes from the response array instead of a hard-coded 3.
        response = np.atleast_2d(NIKON_RESPONSE)
        spectra = (spectra[:, np.newaxis, :] * response[np.newaxis, :, :]).reshape(-1, spectra.shape[-1])

    # random_state only matters if sklearn chooses its randomized SVD solver;
    # fixing it keeps the plotted basis identical between refits.
    pca = PCA(n_components = n_basis, random_state = 0)
    pca.fit(spectra)
    ratios = pca.explained_variance_ratio_

    ev_value.value = '[' + ', '.join('{:.2f}%'.format(r * 100.) for r in ratios[:n_basis]) + ']'
    selected_sum = sum(r for r, shown in zip(ratios, basis_cb) if shown)
    pc_value.value = '{:.2f}%'.format(selected_sum * 100.)

    # Batch the trace edits into a single front-end update.
    with fig.batch_update():
        for trace, shown, component in zip(fig.data, basis_cb, pca.components_):
            trace.y = component if shown else np.array([])

# interactive_output calls update() with the current control values, and again
# whenever any of them changes. The keyword names must match update()'s parameters.
basis_controls = (basis_cb1, basis_cb2, basis_cb3, basis_cb4, basis_cb5, basis_cb6, basis_cb7, basis_cb8)
controls = {'sensor_cb': sensor_cb, 'waveplate_rb': wp_rb}
controls.update({'basis_cb{}'.format(number): checkbox
                 for number, checkbox in enumerate(basis_controls, start = 1)})
out = widgets.interactive_output(update, controls)

# update() runs inside the `out` Output widget, so any exception it raises is
# rendered there. Display it, otherwise errors are swallowed silently.
display(ui, out)
fig

VBox(children=(Label(value='Number of distinct waveplate layers in filter (DOF - 1, distinct meaning they have…

FigureWidget({
    'data': [{'mode': 'lines',
              'name': 'basis 1',
              'type': 'scatter'…