In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from ipywidgets import interact, SelectionSlider, Layout

import sys
# Make the project's helper modules in ../code/ importable (path is relative to the
# notebook's working directory). Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
CODE_DIR = '../code/'
if CODE_DIR not in sys.path:
    sys.path.append(CODE_DIR)
# NOTE(review): not referenced in the cells shown here; presumably consumed by
# helpers or later cells as a compute-device name — confirm.
device = 'cpu'

In [None]:
# Slider label for each column of theta_samples, in column order.
labels = ['L2 Inhib', 'L5 Inhib', 'L2 Prob']

# Grid-sweep outputs: dipole waveforms and the parameter vectors (theta) that
# produced them. Row i of theta_samples is used to index row i of dpl_data below.
dpl_data = np.load('../data/grid_sweep/dpl_sim_grid.npy')
theta_samples = np.load('../data/grid_sweep/theta_sim_grid.npy')

n_sims, n_params = theta_samples.shape
# Fail fast if the data files and the labels disagree; otherwise sliders would be
# mislabelled or the lookup below could point at the wrong waveform.
assert len(labels) == n_params, f'expected {n_params} labels, got {len(labels)}'
assert dpl_data.shape[0] == n_sims, 'dipole and theta arrays have different numbers of simulations'

# Distinct values taken by each parameter -> the options for its slider.
param_values = [np.unique(theta_samples[:, idx]) for idx in range(n_params)]
# Map a full parameter combination back to the row index of its simulation.
lookup_dict = {tuple(theta_row): idx for idx, theta_row in enumerate(theta_samples)}

# One SelectionSlider per parameter, keyed 'p0', 'p1', ... so the dict can be
# splatted into interact() and matched to the plotting callback's arguments.
slider_dict = {
    f'p{idx}': SelectionSlider(
        options=param_values[idx],
        description=labels[idx],
        style={'description_width': '150px'},
        layout=Layout(width='500px'),
    )
    for idx in range(n_params)
}


In [None]:
%matplotlib widget

fig, axes = plt.subplots(1, 4, figsize=(11, 3), tight_layout=True)
# Each input type has a pre-rendered marginals image per simulation index; these
# fill the first three panels, and the dipole waveform goes in the last one.
input_type_list = ['min_amp', 'nn_features', 'raw_waveform']


@interact(**slider_dict)
def plot_dipole(p0, p1, p2):
    """Redraw every panel for the simulation selected by the three sliders.

    Called by ``interact`` whenever a slider moves; the keyword names match the
    'p0'..'p2' keys of ``slider_dict``.

    Parameters
    ----------
    p0, p1, p2
        Currently selected slider values for ``labels[0]``..``labels[2]``.
    """
    for ax in axes:
        ax.clear()

    dipole_ax = axes[3]
    cond_idx = lookup_dict.get((p0, p1, p2))
    if cond_idx is None:
        # Not every combination of slider values is guaranteed to exist in the
        # sweep; show a message instead of raising KeyError inside the widget.
        for ax in axes:
            ax.axis('off')
        dipole_ax.set_title('No simulation for this parameter set')
        fig.canvas.draw_idle()
        return

    dipole_ax.plot(dpl_data[cond_idx, :], color='C3')
    dipole_ax.set_ylabel('Conditioning Waveform')
    # NOTE(review): x values are sample indices; this label assumes 1 sample per ms — confirm.
    dipole_ax.set_xlabel('Time (ms)')
    dipole_ax.set_ylim([-20, 20])
    dipole_ax.set_title('Dipole')

    for ax, input_type in zip(axes, input_type_list):
        arr_img = plt.imread(f'../data/{input_type}/params3_sims10240_{input_type}_marginals_{cond_idx}.png')
        im = OffsetImage(arr_img, zoom=0.23)
        # Place the image against data point (1, 0); box_alignment (1.1, 0) sets the
        # image's reference point just past its right edge at its bottom.
        ab = AnnotationBbox(im, (1, 0), xycoords='data', box_alignment=(1.1, 0))
        ax.add_artist(ab)
        ax.axis('off')
        ax.set_title(input_type)

    # Explicitly request a redraw so the interactive canvas reflects the new artists.
    fig.canvas.draw_idle()

