In [None]:
%matplotlib widget

import pathlib
import numpy as np
import matplotlib.pyplot as plt
import mne
from ipywidgets import IntSlider, HBox, VBox, Label


# Set to True to recompute the dipole forward solutions; when False, only the
# forward solutions already stored on disk are loaded further below.
REGENERATE_FWDS = False

subject = 'sample'
data_path = pathlib.Path('data')
subjects_dir = data_path / 'subjects'

# Measurement info: clear the projectors and keep only MEG and EEG channels.
info = mne.io.read_info(data_path / 'sample-ave.fif', verbose='warning')
info['projs'] = []
info = mne.pick_info(info, mne.pick_types(info, meg=True, eeg=True))

# Load the trans file.
trans = mne.read_trans(data_path / 'sample-trans.fif')

# Disabled: loading the T1 image (would additionally require nilearn).
# t1_fname = str(subjects_dir / subject / 'mri' / 'T1.mgz')
# t1_img = nilearn.image.load_img(t1_fname)

if REGENERATE_FWDS:
    # The BEM solution is only needed when forward solutions are recomputed.
    bem = mne.read_bem_solution(data_path / 'sample-bem-sol.fif', verbose=False)

In [None]:
# Forward solutions, keyed by dipole position (x, y, z) in mm.
fwds = {}

# The dipole keeps a fixed x/y position; only its z coordinate varies, in
# steps of 10 mm from z_min to z_max (inclusive).
x_dip, y_dip = 20, 20
z_min, z_max = 60, 120
z_dips = np.arange(z_min, z_max + 1, 10)

if REGENERATE_FWDS:
    for z_dip in z_dips:
        print(f'==> Generating dipole fwd solution for depth z={z_dip} mm.\n')
        pos = (x_dip, y_dip, z_dip)
        # The same location (converted from mm to m) is repeated three times,
        # once for each unit orientation along x, y and z.
        pos_meters = np.tile(np.array([pos]) / 1e3, (3, 1))
        dip = mne.Dipole(times=np.arange(3), pos=pos_meters,
                         amplitude=[10e-9] * 3, ori=np.eye(3),
                         gof=[100] * 3)
        fwd, _ = mne.make_forward_dipole(dip, bem=bem, info=info, trans=trans,
                                         verbose='error')
        fwd_fname = data_path / f'sample-{pos[0]}-{pos[1]}-{pos[2]}-fwd.fif'
        mne.write_forward_solution(fwd_fname, fwd, overwrite=True)

    # Drop the loop temporaries so they do not linger in the notebook namespace.
    del dip, fwd, pos, pos_meters, z_dip

# Load the stored forward solution for every depth and convert it to fixed
# orientation.
for z_dip in z_dips:
    pos = (x_dip, y_dip, z_dip)
    fwd_fname = data_path / f'sample-{pos[0]}-{pos[1]}-{pos[2]}-fwd.fif'
    fwds[pos] = mne.convert_forward_solution(
        fwd=mne.read_forward_solution(fwd_fname, verbose='warning'),
        force_fixed=True, verbose='warning')

del z_dip, pos, fwd_fname

In [None]:
def plot(phi, theta, z_pos, info, fwds, axes, fig_axes):
    """Redraw the dipole panel and the three sensor topomaps.

    The dipole sits at ``(x_dip, y_dip, z_pos)``, where ``x_dip`` and
    ``y_dip`` are module-level values. Its orientation is the unit vector
    ``(sin(theta) cos(phi), sin(theta) sin(phi), cos(theta))``.

    Parameters
    ----------
    phi : float
        Azimuth of the dipole orientation, in degrees.
    theta : float
        Polar angle of the dipole orientation, in degrees.
    z_pos : int
        z coordinate of the dipole in mm; part of the key used to look up
        the forward solution in ``fwds``.
    info
        Measurement info handed to ``mne.EvokedArray``.
    fwds : dict
        Forward solutions keyed by ``(x, y, z)`` position tuples.
    axes : list
        ``axes[0]`` is the 3D axes receiving the dipole location plot;
        ``axes[1]``, ``axes[2]`` and ``axes[3]`` receive the 'mag', 'grad'
        and 'eeg' topomaps.
    fig_axes : numpy.ndarray
        Axes whose decorations are switched off after drawing.
    """
    azimuth = np.deg2rad(phi)
    polar = np.deg2rad(theta)

    # Orientation as a (1, 3) row vector.
    ori = np.array([[np.sin(polar) * np.cos(azimuth),
                     np.sin(polar) * np.sin(azimuth),
                     np.cos(polar)]])

    dip_pos = (x_dip, y_dip, z_pos)
    leadfield = fwds[dip_pos]['sol']['data']
    # Project the leadfield onto the requested orientation (forward model).
    evoked = mne.EvokedArray(np.dot(leadfield, ori.T), info)

    # Clearing the axes below resets the 3D camera, so remember the current
    # view and restore it once everything has been redrawn.
    camera = dict(elev=axes[0].elev, azim=axes[0].azim)

    for ax in axes:
        ax.clear()

    topo_axes = {'mag': axes[1], 'grad': axes[2], 'eeg': axes[3]}
    for ch_type, topo_ax in topo_axes.items():
        # FIXME: explicit per-channel-type limits (10 * 1e15 for mag,
        # 10 * 1e13 for grad, 10 * 1e6 for eeg) did not scale as expected, so
        # the color limits are left to plot_topomap (vmax=None).
        evoked.plot_topomap(ch_type=ch_type, colorbar=False,
                            outlines='skirt',
                            nrows=1, ncols=1,
                            times=evoked.times[-1],
                            vmax=None,
                            res=256,
                            show=False,
                            axes=topo_ax)
        topo_ax.set_title(ch_type, fontweight='bold')

    # Dipole position is converted from mm to m.
    dip = mne.Dipole(times=[0], pos=np.array([dip_pos]) / 1e3,
                     amplitude=[10e-9], ori=ori, gof=[100])
    dip.plot_locations(trans=trans, subject=subject, subjects_dir=subjects_dir,
                       ax=axes[0], show=False, coord_frame='head')

    dipole_ax = axes[0]
    dipole_ax.axis('off')
    dipole_ax.set_title(None)
    dipole_ax.get_figure().suptitle(
        'Dipole Position & Orientation vs Sensor Signal', weight='bold')

    # Restore camera position.
    dipole_ax.view_init(**camera)

    for background_ax in fig_axes.ravel():
        background_ax.axis('off')


def _on_phi_slider_change(change):
    """Handle a change of the azimuth (phi) slider and redraw the figure.

    Parameters
    ----------
    change : dict
        Change notification delivered through the slider's ``observe``
        registration; only its ``'new'`` entry (the new value) is read.
    """
    # Flip the sliders' disabled state while the figure is being redrawn.
    _toggle_sliders_disabled()
    try:
        _widget_vals['phi_slider'] = change['new']
        _update_plot()
    finally:
        # Always undo the toggle, even if redrawing fails; otherwise the
        # sliders would stay disabled and every later toggle would be inverted.
        _toggle_sliders_disabled()


def _on_theta_slider_change(change):
    """Handle a change of the polar angle (theta) slider and redraw the figure.

    Parameters
    ----------
    change : dict
        Change notification delivered through the slider's ``observe``
        registration; only its ``'new'`` entry (the new value) is read.
    """
    # Flip the sliders' disabled state while the figure is being redrawn.
    _toggle_sliders_disabled()
    try:
        _widget_vals['theta_slider'] = change['new']
        _update_plot()
    finally:
        # Always undo the toggle, even if redrawing fails; otherwise the
        # sliders would stay disabled and every later toggle would be inverted.
        _toggle_sliders_disabled()


def _on_z_slider_change(change):
    """Handle a change of the z position slider and redraw the figure.

    Parameters
    ----------
    change : dict
        Change notification delivered through the slider's ``observe``
        registration; only its ``'new'`` entry (the new value) is read.
    """
    # Flip the sliders' disabled state while the figure is being redrawn.
    _toggle_sliders_disabled()
    try:
        _widget_vals['z_slider'] = change['new']
        _update_plot()
    finally:
        # Always undo the toggle, even if redrawing fails; otherwise the
        # sliders would stay disabled and every later toggle would be inverted.
        _toggle_sliders_disabled()


def _toggle_sliders_disabled():
    """Invert the ``disabled`` flag of the phi, theta and z sliders.

    Each slider is flipped independently, so calling this twice restores the
    state every slider had before the first call.
    """
    for slider_name in ('phi_slider', 'theta_slider', 'z_slider'):
        slider = widgets[slider_name]
        slider.disabled = not slider.disabled


def _update_plot():
    """Redraw the figure from the values cached in ``_widget_vals``.

    The status label reads 'Updating…' while ``plot`` runs and is set back to
    'Ready.' once it has returned.
    """
    status_label = widgets['fig_title']
    status_label.value = 'Updating…'

    plot(phi=_widget_vals['phi_slider'],
         theta=_widget_vals['theta_slider'],
         z_pos=_widget_vals['z_slider'],
         info=info,
         fwds=fwds,
         axes=axes,
         fig_axes=fig_axes)

    status_label.value = 'Ready.'


def _get_widgets():
    """Create the control widgets of the app.

    Returns
    -------
    dict
        Widgets keyed by name: ``'phi_slider'``, ``'theta_slider'`` and
        ``'z_slider'`` (``IntSlider`` instances, each already observing its
        change callback) and ``'fig_title'`` (a ``Label`` used as status
        line, initially 'Ready.').
    """
    # Let the description take as much width as it needs.
    style = {'description_width': 'initial'}

    phi_slider = IntSlider(description='azimuth φ (deg)', min=0, max=360,
                           step=1, value=0, continuous_update=False, style=style)
    phi_slider.observe(_on_phi_slider_change, names='value')

    # plot() builds the orientation as (sin θ cos φ, sin θ sin φ, cos θ), so θ
    # is measured from the +z axis. It must span 0..180 degrees to reach
    # orientations with a negative z component; the former -90..90 range kept
    # cos θ >= 0 and only duplicated orientations already reachable via φ.
    theta_slider = IntSlider(description='polar angle θ (deg)', min=0, max=180,
                             step=1, value=45, continuous_update=False, style=style)
    theta_slider.observe(_on_theta_slider_change, names='value')

    # The step of 10 mm matches the spacing of the precomputed dipole depths.
    z_slider = IntSlider(description='z pos (mm)', min=z_min, max=z_max, step=10,
                         value=80, continuous_update=False, orientation='vertical',
                         style=style)
    z_slider.observe(_on_z_slider_change, names='value')

    fig_title = Label('Ready.')

    return dict(phi_slider=phi_slider,
                theta_slider=theta_slider,
                z_slider=z_slider,
                fig_title=fig_title)


def _get_axes():
    """Create the figure and the axes the app draws into.

    Interactive mode is switched off via ``plt.ioff()`` and the canvas toolbar
    and header are hidden.

    Returns
    -------
    fig
        The created figure.
    fig_axes
        The 1 x 4 subplot axes created by ``plt.subplots``.
    axes : list
        Four additional axes placed on the same grid cells: a 3D axes
        (initial view azim=-135, elev=40) followed by three regular axes.
    """
    plt.ioff()
    fig, fig_axes = plt.subplots(
        nrows=1, ncols=4, figsize=(10, 3),
        gridspec_kw=dict(width_ratios=[8, 3, 3, 3], height_ratios=[5]))
    fig.canvas.toolbar_visible = False
    fig.canvas.header_visible = False

    # Reuse the grid of the subplots to place the drawing axes on top of them.
    gridspec = fig_axes[1].get_gridspec()

    dipole_ax = fig.add_subplot(gridspec[0], projection='3d')
    dipole_ax.view_init(azim=-135, elev=40)
    topomap_axes = [fig.add_subplot(gridspec[col]) for col in (1, 2, 3)]

    return fig, fig_axes, [dipole_ax, *topomap_axes]

In [None]:
# Create the controls and cache their current values by widget name; the
# slider callbacks update this cache on every change.
widgets = _get_widgets()
_widget_vals = {name: widget.value for name, widget in widgets.items()}

fig, fig_axes, axes = _get_axes()

# Initial draw with the sliders' start values.
plot(phi=_widget_vals['phi_slider'],
     theta=_widget_vals['theta_slider'],
     z_pos=_widget_vals['z_slider'],
     info=info,
     fwds=fwds,
     axes=axes,
     fig_axes=fig_axes)

# Disabled code, kept for reference:
# output = interactive_plot.children[-1]
# output.layout.height = '600px'
# interactive_plot

# Layout: status label on top, then the z slider next to the figure canvas,
# then the two angle sliders stacked below.
first_row = HBox([widgets['z_slider'], fig.canvas])
second_row = VBox([widgets['phi_slider'], widgets['theta_slider']])
app_layout = VBox([widgets['fig_title'], first_row, second_row])
app_layout