In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from ipywidgets import interact, SelectionSlider, Layout

import sys
# Add ../code/ to the module search path so project modules stored there can be
# imported from this notebook. The path is relative, so it resolves against the
# kernel's current working directory.
sys.path.append('../code/')
# NOTE(review): `device` is not referenced in any of the cells visible here —
# confirm whether something else still needs it before removing.
device = 'cpu'

In [None]:
# Slider captions; entry i labels the slider built from column i of theta_samples.
labels = ['L2 Inhib', 'L5 Inhib', 'L2 Prob']

# Pre-computed grid sweep loaded from disk. The plotting cell indexes dpl_data
# with the row index of the matching theta_samples row, so the two arrays are
# presumably row-aligned (one row per simulation) — verify against the sweep script.
dpl_data = np.load('../data/grid_sweep/dpl_sim_grid.npy')
theta_samples = np.load('../data/grid_sweep/theta_sim_grid.npy')

# Unpacking into two names also enforces that theta_samples is 2-D.
n_sims, n_params = theta_samples.shape

# Sorted distinct values of every parameter column; these become the slider stops.
param_values = [np.unique(column) for column in theta_samples.T]

# Map each parameter combination (as a tuple) to the row that holds it.
# If a combination occurs more than once, the last occurrence wins.
lookup_dict = {tuple(row): row_idx for row_idx, row in enumerate(theta_samples)}

# One SelectionSlider per parameter, keyed 'p0', 'p1', ... so the dict can be
# splatted into `interact` as keyword arguments.
slider_dict = {
    f'p{pos}': SelectionSlider(
        options=values,
        description=labels[pos],
        style={'description_width': '150px'},
        layout=Layout(width='500px'),
    )
    for pos, values in enumerate(param_values)
}


In [None]:
# Select matplotlib's interactive widget backend (ipympl) for this notebook.
%matplotlib widget

# One row, two panels sharing a 7x3 inch figure; the width ratio 1 : 0.7 makes the
# left panel (axes[0], line plot) wider than the right one (axes[1], image).
fig, axes = plt.subplots(1, 2, figsize=(7,3), tight_layout=True, gridspec_kw={'width_ratios': [1, 0.7]})

@interact(**slider_dict)
def plot_dipole(p0, p1, p2):
    """Redraw both panels for the parameter combination picked on the sliders.

    The left panel shows the row of ``dpl_data`` that belongs to the selected
    combination; the right panel shows the marginals PNG saved under the same
    row index.

    Parameters
    ----------
    p0, p1, p2 : scalar
        Values selected on the sliders in ``slider_dict``. Together they form
        the tuple key looked up in ``lookup_dict``.

    Raises
    ------
    KeyError
        If ``(p0, p1, p2)`` is not a key of ``lookup_dict``.
    """
    sim_idx = lookup_dict[(p0, p1, p2)]
    dipole_ax, marginal_ax = axes[0], axes[1]

    # Left panel: remove only the previously drawn line(s) rather than clearing
    # the whole axes. get_lines() is copied first so removal cannot disturb the
    # iteration.
    for stale_line in list(dipole_ax.get_lines()):
        stale_line.remove()

    # Right panel: wipe everything; the image is re-attached below.
    marginal_ax.clear()

    dipole_ax.plot(dpl_data[sim_idx, :], color='C3')
    dipole_ax.set(
        ylabel='Dipole (nAm)',
        xlabel='Time (ms)',
        ylim=[-20, 20],
        title=sim_idx,  # the raw row index serves as the panel title
    )

    # Load the marginals figure for this row and embed it, scaled to 25 %.
    marginal_png = f'../data/marginal_figures/params3_sims10240_nn_features_marginals_{sim_idx}.png'
    thumbnail = OffsetImage(plt.imread(marginal_png), zoom=0.25)
    # Anchored at data point (1, 0); the box_alignment offsets were presumably
    # tuned by eye to position the image inside the panel.
    marginal_ax.add_artist(
        AnnotationBbox(thumbnail, (1, 0), xycoords='data', box_alignment=(1.1, -0.1))
    )
    marginal_ax.axis('off')

