In [None]:
%matplotlib widget

from ipywidgets import HBox, VBox, Label, IntSlider, Accordion, Button, Output, ToggleButtons
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseButton
from functools import partial
import requests
import warnings

import nilearn
from nilearn.plotting import plot_anat
import nibabel
import mne
from mne.transforms import (read_trans, invert_transform, combine_transforms,
                            apply_trans, Transform)

# Locations of the demo data, relative to the current working directory.
data_path = pathlib.Path('data')
# Folder the forward-solution files are written to and read from.
fwd_path = data_path / 'fwd'
subjects_dir = data_path / 'subjects'
subject = 'sample'

# MRI image.
t1_fname = str(subjects_dir / subject/ 'mri' / 'T1.mgz')
t1_img = nilearn.image.load_img(t1_fname)
del t1_fname

# Transformation matrices
#
# HEAD <> MRI (Surface RAS)
trans_fname = data_path / 'sample-trans.fif'
head_to_mri_t = read_trans(trans_fname)
mri_to_head_t = invert_transform(head_to_mri_t)

# RAS <> VOXEL
# Both matrices are taken from the header of the T1 image loaded above.
ras_to_vox_t = Transform(fro='ras', to='mri_voxel',
                         trans=t1_img.header.get_ras2vox())
vox_to_mri_t = Transform(fro='mri_voxel', to='mri',
                         trans=t1_img.header.get_vox2ras_tkr())

# RAS <> MRI
# Chained as RAS -> voxel -> MRI.
ras_to_mri_t = combine_transforms(ras_to_vox_t,
                                  vox_to_mri_t,
                                  fro='ras', to='mri')

# RAS <> HEAD
# Chained as RAS -> MRI -> head; this is the transform applied to the
# coordinates clicked on in the slice figures.
ras_to_head_t = combine_transforms(ras_to_mri_t,
                                   mri_to_head_t,
                                   fro='ras', to='head')

# Drop the intermediate transforms; head_to_mri_t, ras_to_mri_t and
# ras_to_head_t stay available.
del mri_to_head_t, ras_to_vox_t, vox_to_mri_t, trans_fname


# BEM solution.
# bem_fname = data_path / 'sample-bem-sol.fif'
# bem = mne.read_bem_solution(bem_fname, verbose=False)
# del bem_fname

# Evoked & Info object.
# Only the first evoked response stored in the file is used, restricted to
# the MEG and EEG channels.
evoked_fname = data_path / 'sample-ave.fif'
evoked = mne.read_evokeds(evoked_fname, verbose='warning')[0]
evoked.pick_types(meg=True, eeg=True)

# The measurement info is shared with the simulated data created later on;
# projectors and the list of bad channels are emptied here.
info = evoked.info
info['projs'] = []
info['bads'] = []
del evoked_fname

# Zenodo repository for retrieving forward solutions.
# zenodo_deposit = '3746500'
# zenodo_url = f'https://zenodo.org/api/records/{zenodo_deposit}'
# zenodo_request = requests.get(zenodo_url)
# zenodo_files = zenodo_request.json()['files']
# del zenodo_request, zenodo_deposit, zenodo_request


# Switch off Matplotlib's interactive mode; the figures are embedded in the
# widget layout and redrawn explicitly via their canvases.
plt.ioff()

# This widget will capture the MNE output.
# Create it here so we can use it as a function decorator.
output_widget = Output()

In [None]:
# Close every Matplotlib figure that is still open before the figures of the
# app are created further down in this cell.
plt.close('all')


def retrieve_fwd_output_widgetfrom_zenodo(subject, dipole_pos, zenodo_files, overwrite=False):
    """Download a pre-computed forward solution listed in a Zenodo record.

    NOTE(review): the function name looks like the result of an accidental
    search-and-replace; the commented-out call in ``plot_evoked`` refers to
    ``retrieve_fwd_from_zenodo``. The name is kept so existing callers do
    not break.

    Parameters
    ----------
    subject : str
        Subject name; first component of the file name.
    dipole_pos : sequence of length 3
        The x, y and z values that are embedded in the file name.
    zenodo_files : iterable of dict
        File entries of the Zenodo record. Each entry is read via
        ``entry['key']`` (file name) and ``entry['links']['self']``
        (download URL).
    overwrite : bool
        Download the file even if it already exists locally.

    Raises
    ------
    RuntimeError
        If the file is not part of ``zenodo_files`` or the download does not
        succeed.
    """
    x, y, z = dipole_pos
    fname = f'{subject}-{x}-{y}-{z}-fwd.fif'
    target = fwd_path / fname

    if target.exists() and not overwrite:
        return

    for file in zenodo_files:
        if file['key'] != fname:
            continue

        fwd_url = file['links']['self']
        fwd_request = requests.get(fwd_url)
        # Without this check the body of an error response would be stored
        # on disk as if it were a forward solution.
        if fwd_request.status_code != 200:
            msg = (f'Could not download the requested forward solution from Zenodo!\n'
                   f'Download URL was: {fwd_url}')
            raise RuntimeError(msg)

        # The target folder is not guaranteed to exist yet.
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'wb') as f:
            f.write(fwd_request.content)

        return

    raise RuntimeError('Could not find the requested forward solution online!')


def retrieve_fwd_from_github(subject, dipole_pos, overwrite=False):
    """Download a pre-computed forward solution from GitHub.

    The file is stored in ``fwd_path`` under the name
    ``{subject}-{x}-{y}-{z}-fwd.fif``.

    Parameters
    ----------
    subject : str
        Subject name; first component of the file name.
    dipole_pos : sequence of length 3
        The x, y and z values that are embedded in the file name.
    overwrite : bool
        Download the file even if it already exists locally.

    Raises
    ------
    RuntimeError
        If the server does not answer with HTTP status 200.
    """
    x, y, z = dipole_pos
    fname = f'{subject}-{x}-{y}-{z}-fwd.fif'
    target = fwd_path / fname

    if target.exists() and not overwrite:
        return

    fwd_url = f'https://github.com/hoechenberger/dipoles_demo_data/raw/master/{fname}'
    fwd_request = requests.get(fwd_url, allow_redirects=True)
    if fwd_request.status_code != 200:
        msg = (f'Could not download the requested forward solution from GitHub!\n'
               f'Download URL was: {fwd_url}')
        raise RuntimeError(msg)

    # The target folder is not guaranteed to exist yet; without it the
    # write below would fail although the download succeeded.
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'wb') as f:
        f.write(fwd_request.content)


def create_head_grid(info, grid_steps=25):
    """Create a regular grid spanning the digitized points.

    Finds the maximum extension of the head in each dimension (taken from
    the ``'r'`` entry of every element of ``info['dig']``) and creates a
    grid corresponding to our pre-computed forward solutions.

    Parameters
    ----------
    info : mapping
        Must provide ``info['dig']``, a sequence whose elements expose a
        three-element coordinate as ``element['r']``.
    grid_steps : int
        Number of grid points per dimension.

    Returns
    -------
    sequence of three arrays
        Sparse ``np.meshgrid`` output (``indexing='ij'``) for x, y and z.
        The grid coordinates are rounded to three decimals.

    Raises
    ------
    ValueError
        If ``info['dig']`` is missing or empty, as no extent can be derived.
    """
    dig = info['dig']
    if dig is None or len(dig) == 0:
        raise ValueError('Cannot create the head grid: info["dig"] does not '
                         'contain any points.')

    coords = np.array([point['r'] for point in dig])
    # Unpacking into exactly three names keeps the requirement that every
    # point has three coordinates.
    xmin, ymin, zmin = coords.min(axis=0)
    xmax, ymax, zmax = coords.max(axis=0)

    x_grid = np.linspace(start=xmin, stop=xmax, num=grid_steps).round(3)
    y_grid = np.linspace(start=ymin, stop=ymax, num=grid_steps).round(3)
    z_grid = np.linspace(start=zmin, stop=zmax, num=grid_steps).round(3)
    return np.meshgrid(x_grid, y_grid, z_grid, indexing='ij', sparse=True)


def find_closest(a, x):
    """Find the element in the array a that's closest to the scalar x.

    ``a`` may have any number of dimensions; it is flattened first. Using
    ``ravel`` instead of ``squeeze`` keeps this working for arrays holding a
    single element, which ``squeeze`` would turn into a non-indexable 0-d
    array.

    Parameters
    ----------
    a : array-like
        Candidate values.
    x : scalar
        Value to match.

    Returns
    -------
    scalar
        The element of ``a`` with the smallest absolute difference to ``x``
        (the first one in case of ties).
    """
    flat = np.asarray(a).ravel()
    idx = np.abs(flat - x).argmin()
    return flat[idx]


def _describe_clicked_point(point_ras):
    """Return the label text describing a clicked-on point.

    ``point_ras`` is a dict with the keys 'x', 'y' and 'z'. The text shows
    these values rounded to integers, followed by the same point after
    applying ``ras_to_head_t`` and dividing by 1000, rounded to 3 decimals.
    """
    point_head = apply_trans(trans=ras_to_head_t, pts=(point_ras['x'],
                                                       point_ras['y'],
                                                       point_ras['z']))
    point_head /= 1000

    return (f"x={int(round(point_ras['x']))}, "
            f"y={int(round(point_ras['y']))}, "
            f"z={int(round(point_ras['z']))} [mm, MRI RAS] ⟶ "
            f"x={round(point_head[0], 3)}, "
            f"y={round(point_head[1], 3)}, "
            f"z={round(point_head[2], 3)} [m, MNE Head]")


def handle_click2(event):
    """React to a left mouse click on one of the slice figures.

    Depending on ``state['mode']`` the click either

    * moves the slices and the crosshair to the clicked-on coordinates and
      resets the dipole ('slice_browser'),
    * sets the dipole position ('set_dipole_pos'), or
    * sets the point defining the dipole orientation ('set_dipole_ori').

    Afterwards markers, arrows and crosshairs are redrawn, and the topomaps
    are recomputed once both position and orientation are set and differ.

    Parameters
    ----------
    event : matplotlib mouse event
        Only ``button``, ``inaxes``, ``xdata`` and ``ydata`` are read.
    """
    if event.button != MouseButton.LEFT:
        return

    in_ax = event.inaxes
    # A click on the figure margin happens outside of any axes; there are
    # no data coordinates to work with then.
    if in_ax is None or event.xdata is None or event.ydata is None:
        return

    if in_ax.figure in widget['topomap_fig'].values():
        return

    for axis, fig in widget['fig'].items():
        if fig is in_ax.figure:
            break
    else:
        # The click did not happen in one of the slice figures.
        return

    x = event.xdata
    y = event.ydata

    # The figure showing the slice along `axis` displays the two other
    # coordinates: (horizontal, vertical).
    in_plane = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}
    x_idx, y_idx = in_plane[axis]

    mode = state['mode']
    if mode == 'slice_browser':
        state['slice_coord'][x_idx]['val'] = x
        state['slice_coord'][y_idx]['val'] = y
        state['crosshair_pos'][x_idx] = x
        state['crosshair_pos'][y_idx] = y

        remove_dipole_arrows()
        remove_dipole_ori_markers()
        remove_dipole_pos_markers()

        widget['label']['dipole_pos'].value = 'Not set'
        widget['label']['dipole_ori'].value = 'Not set'

        for key in ('x', 'y', 'z'):
            state['dipole_pos'][key] = None
        for key in ('x', 'y', 'z'):
            state['dipole_ori'][key] = None

        # The clicked-on figure keeps its slice; the two other figures are
        # redrawn at the new coordinates.
        widget['fig'][x_idx].axes[0].clear()
        widget['fig'][y_idx].axes[0].clear()
        plot_slice(x_idx, x)
        plot_slice(y_idx, y)
        reset_topomaps()
    elif mode in ('set_dipole_pos', 'set_dipole_ori'):
        # Construct the 3D coordinates of the clicked-on point: two values
        # come from the click, the third one is the coordinate of the slice
        # displayed in the clicked-on figure.
        clicked_point = dict()
        clicked_point[x_idx] = x
        clicked_point[y_idx] = y
        clicked_point[axis] = state['slice_coord'][axis]['val']

        target = 'dipole_pos' if mode == 'set_dipole_pos' else 'dipole_ori'
        state[target] = clicked_point
        widget['label'][target].value = _describe_clicked_point(clicked_point)

        if mode == 'set_dipole_pos':
            leave_set_dipole_pos_mode()
        else:
            leave_set_dipole_ori_mode()

    have_pos = state['dipole_pos']['x'] is not None
    have_ori = state['dipole_ori']['x'] is not None
    # The dipole is only fully specified if both points are set and differ.
    dipole_complete = (have_pos and have_ori and
                       state['dipole_pos'] != state['dipole_ori'])

    if have_pos:
        plot_dipole_pos_marker()

    if have_ori:
        plot_dipole_ori_marker()

    if dipole_complete:
        draw_dipole_arrows()

    _draw_crosshairs()

    if dipole_complete:
        plot_evoked()


def plot_dipole_pos_marker():
    """Draw the dipole position as a filled red circle in every slice figure.

    Existing position markers are removed first. The new scatter artists are
    stored in ``markers['dipole_pos']`` so they can be removed again later.
    """
    # Coordinates displayed (horizontal, vertical) in the figure of each axis.
    in_plane = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}

    remove_dipole_pos_markers()
    position = state['dipole_pos']
    for axis in position:
        horizontal, vertical = in_plane[axis]
        figure = widget['fig'][axis]
        markers['dipole_pos'][axis] = figure.axes[0].scatter(
            position[horizontal], position[vertical],
            marker='o', s=50, facecolors='r', edgecolors='r',
            label='dipole_pos_marker')
        figure.canvas.draw()


def remove_dipole_pos_markers():
    """Remove the dipole position marker from every slice figure.

    Figures without a stored marker are left untouched (and not redrawn).
    """
    for axis in state['dipole_pos'].keys():
        marker = markers['dipole_pos'][axis]
        if marker is None:
            continue

        marker.remove()
        markers['dipole_pos'][axis] = None
        widget['fig'][axis].canvas.draw()

                                            
def plot_dipole_ori_marker():
    """Draw the dipole orientation point as a red cross in every slice figure.

    Existing orientation markers are removed first. The new scatter artists
    are stored in ``markers['dipole_ori']`` so they can be removed again
    later.
    """
    # Coordinates displayed (horizontal, vertical) in the figure of each axis.
    in_plane = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}

    remove_dipole_ori_markers()
    orientation = state['dipole_ori']
    for axis in orientation:
        horizontal, vertical = in_plane[axis]
        figure = widget['fig'][axis]
        markers['dipole_ori'][axis] = figure.axes[0].scatter(
            orientation[horizontal], orientation[vertical],
            marker='x', s=50, facecolors='r', edgecolors='r',
            label='dipole_ori_marker')
        figure.canvas.draw()

def remove_dipole_ori_markers():
    """Remove the dipole orientation marker from every slice figure.

    Figures without a stored marker are left untouched (and not redrawn).
    """
    for axis in state['dipole_ori'].keys():
        marker = markers['dipole_ori'][axis]
        if marker is None:
            continue

        marker.remove()
        markers['dipole_ori'][axis] = None
        widget['fig'][axis].canvas.draw()


def _create_format_coord(axis):
    if axis == 'topomap':

        def format_coord(x, y):
            x *= 1000
            y *= 1000
            x = int(round(x))
            y = int(round(y))
            return f'x={x} mm, y={y} mm'
        
        return format_coord

    if axis == 'x':
        x_label = 'y'
        y_label = 'z'
    elif axis == 'y':
        x_label = 'x'
        y_label = 'z'
    elif axis == 'z':
        x_label = 'x'
        y_label = 'y'
    
    # FIXME
    def format_coord(x, y, x_label, y_label):
        return f'{x_label}={x:.1f} mm, {y_label}={y:.1f} mm'

    return partial(format_coord, x_label=x_label, y_label=y_label)


def handle_leave(event):
    """Placeholder callback for 'figure_leave_event'; does nothing."""
    return None


def handle_enter(event):
    """Placeholder callback for 'figure_enter_event'; does nothing."""
    return None


def handle_mode_change(change):
    """Switch the interaction mode when the mode selector changes.

    ``change['new']`` holds the text of the selected toggle button. It is
    translated into the internal mode name stored in ``state['mode']``;
    unknown values leave the state untouched.
    """
    selected = change['new']

    if selected == 'Slice Browser':
        leave_set_dipole_ori_mode()
        leave_set_dipole_pos_mode()
        state['mode'] = 'slice_browser'
        return

    if selected == 'Set Dipole Origin':
        enter_set_dipole_pos_mode()
        state['mode'] = 'set_dipole_pos'
        return

    if selected == 'Set Dipole Orientation':
        enter_set_dipole_ori_mode()
        state['mode'] = 'set_dipole_ori'


def _update_axis_label(axis):
    """Show the current ``state['label_text'][axis]`` in the label widget
    above the slice figure of ``axis``.
    """
    target = widget['label']['axis'][axis]
    target.value = state['label_text'][axis]
    

def _update_topomap_label(ch_type):
    """Show the current label text of the ``ch_type`` topomap in its label
    widget. Text and widget are both stored under the key
    ``'topomap_' + ch_type``.
    """
    key = 'topomap_' + ch_type
    target = widget['label'][key]
    target.value = state['label_text'][key]


def _draw_crosshairs():
    """Redraw the white crosshair in all three slice figures.

    Lines labelled 'crosshair' are removed first, then one vertical and one
    horizontal line are added per figure at ``state['crosshair_pos']``, and
    finally all three canvases are redrawn.
    """
    # Coordinate used for the (vertical line, horizontal line) in the figure
    # of each axis.
    in_plane = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}

    # Remove potentially existing crosshairs.
    # Each line is removed through Artist.remove() rather than by assigning
    # a filtered list to `ax.lines`: recent Matplotlib releases expose
    # `Axes.lines` as a read-only sequence, so the assignment fails there.
    for axis in widget['fig'].keys():
        ax = widget['fig'][axis].axes[0]
        for line in list(ax.lines):
            if line.get_label() == 'crosshair':
                line.remove()

    crosshair_pos = state['crosshair_pos']
    for axis, (vline_key, hline_key) in in_plane.items():
        ax = widget['fig'][axis].axes[0]
        ax.axvline(crosshair_pos[vline_key], color='white', label='crosshair', lw=0.5)
        ax.axhline(crosshair_pos[hline_key], color='white', label='crosshair', lw=0.5)

    for axis in in_plane:
        widget['fig'][axis].canvas.draw()

    
def plot_slice(axis, pos):
    """Plot the T1 slice at coordinate ``pos`` along ``axis``.

    While the slice is being plotted, ' [updating]' is appended to the label
    of the figure. The original label is restored afterwards, also if
    plotting fails, so the label cannot get stuck in the updating state.

    Parameters
    ----------
    axis : str
        'x', 'y' or 'z'; passed to nilearn as ``display_mode`` and used to
        select the figure.
    pos : number
        Cut coordinate of the slice.
    """
    old_label_text = state['label_text'][axis]
    state['label_text'][axis] = old_label_text + ' [updating]'
    _update_axis_label(axis)

    try:
        fig = widget['fig'][axis]

        # Meant to suppress a DeprecationWarning; note that the filter
        # silences every warning raised while plotting.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            img = nilearn.plotting.plot_anat(t1_img, display_mode=axis,
                                             cut_coords=(pos,),
                                             figure=fig, dim=-0.5)

        # The axes of the display are looked up by their cut coordinate.
        img.axes[pos].ax.format_coord = _create_format_coord(axis)
        fig.canvas.draw()
    finally:
        state['label_text'][axis] = old_label_text
        _update_axis_label(axis)


def create_fig():
    """Create a small figure for one MRI slice and hook up its callbacks.

    Mouse clicks are routed to ``handle_click2``; entering and leaving the
    figure are routed to ``handle_enter`` and ``handle_leave``. Toolbar,
    header and resizing of the canvas are switched off.
    """
    fig = plt.figure(figsize=(2,2))
    canvas = fig.canvas

    callbacks = (('button_press_event', handle_click2),
                 ('figure_enter_event', handle_enter),
                 ('figure_leave_event', handle_leave))
    for event_name, callback in callbacks:
        canvas.mpl_connect(event_name, callback)

    canvas.toolbar_visible = False
    canvas.header_visible = False
    canvas.resizable = False
    return fig


def create_topomap_fig():
    """Create a small figure with a single, hidden axes for a topomap.

    Toolbar, header and resizing of the canvas are switched off and all
    callbacks registered on the canvas are cleared.
    """
    figure, axes = plt.subplots(1, figsize=(2,2))
    canvas = figure.canvas

    for trait in ('toolbar_visible', 'header_visible', 'resizable'):
        setattr(canvas, trait, False)

    canvas.callbacks.callbacks.clear()
    axes.set_axis_off()
    figure.set_tight_layout(True)
    return figure


def remove_dipole_arrows():
    """Remove the dipole arrow from every slice figure and redraw it.

    Every child artist of the axes that is labelled 'dipole' is removed via
    ``Artist.remove()``.

    NOTE(review): the previous implementation assigned a filtered list to
    ``ax.artists``. Recent Matplotlib releases expose ``Axes.artists`` as a
    read-only sequence, and the arrow created by ``Axes.arrow`` appears to
    be stored among the patches there, so it would not have been found in
    ``ax.artists``. Going through ``get_children()`` does not depend on
    where the arrow is stored.
    """
    for axis, fig in widget['fig'].items():
        ax = fig.axes[0]
        # get_children() returns a new list, so removing while iterating is
        # safe.
        for artist in ax.get_children():
            if artist.get_label() == 'dipole':
                artist.remove()
        fig.canvas.draw()


def draw_dipole_arrows():
    """Draw a white arrow from the dipole position to the orientation point
    in every slice figure.

    Existing arrows are removed first. The arrows carry the label 'dipole',
    which is how ``remove_dipole_arrows`` recognises them.
    """
    # Coordinates displayed (horizontal, vertical) in the figure of each axis.
    in_plane = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}

    remove_dipole_arrows()
    for axis, fig in widget['fig'].items():
        target_ax = widget['fig'][axis].axes[0]
        horizontal, vertical = in_plane[axis]

        tail = state['dipole_pos']
        tip = state['dipole_ori']

        target_ax.arrow(x=tail[horizontal],
                        y=tail[vertical],
                        dx=tip[horizontal] - tail[horizontal],
                        dy=tip[vertical] - tail[vertical],
                        color='white',
                        width=3, head_width=15, length_includes_head=True,
                        label='dipole')

        fig.canvas.draw()


def reset_topomaps():
    """Replace every topomap by the plain sensor layout of the module-level
    ``evoked`` object.
    """
    for ch_type in ('mag', 'grad', 'eeg'):
        figure = widget['topomap_fig'][ch_type]
        figure.axes[0].clear()
        plot_sensors(ch_type=ch_type, evoked=evoked)
        widget['topomap_fig'][ch_type].canvas.draw()


def plot_sensors(ch_type, evoked):
    """Plot the sensor layout of ``evoked`` for ``ch_type`` into the first
    axes of the corresponding topomap figure, without a title and without
    showing the figure.
    """
    options = dict(ch_type=ch_type,
                   title='',
                   show=False,
                   axes=widget['topomap_fig'][ch_type].axes[0])
    evoked.plot_sensors(**options)

    
def enter_set_dipole_pos_mode():
    """Placeholder hook called when switching to 'Set Dipole Origin'; does
    nothing.
    """
    return None


def leave_set_dipole_pos_mode():
    """Placeholder hook called after the dipole position was set or when
    switching to 'Slice Browser'; does nothing.
    """
    return None


def enter_set_dipole_ori_mode():
    """Placeholder hook called when switching to 'Set Dipole Orientation';
    does nothing.
    """
    return None


def leave_set_dipole_ori_mode():
    """Placeholder hook called after the dipole orientation was set or when
    switching to 'Slice Browser'; does nothing.
    """
    return None

        
def gen_forward_solution(pos, bem, info, trans):
    """Compute a forward solution for a dipole at ``pos``.

    A dipole with three time points is created. The position is repeated
    for every time point, while the orientation at each time point is one
    row of the 3x3 identity matrix. Output produced by MNE while computing
    the forward solution is captured by the output widget.

    Returns the forward solution (first element returned by
    ``mne.make_forward_dipole``).
    """
    n_orientations = 3
    dip = mne.Dipole(times=np.arange(n_orientations),
                     pos=np.tile(pos, (n_orientations, 1)),
                     amplitude=n_orientations * [10e-9],
                     ori=np.eye(n_orientations),
                     gof=n_orientations * [100])

    with widget['output']:
        fwd, _ = mne.make_forward_dipole(dip, bem=bem, info=info, trans=trans)

    return fwd


def gen_evoked(pos, ori, info, fwd):
    """Create an evoked object from a forward solution and an orientation.

    The leadfield stored in ``fwd['sol']['data']`` is multiplied with the
    transposed orientation ``ori``; the result is wrapped in an
    ``mne.EvokedArray`` using ``info``. ``pos`` is accepted but not used.
    """
    sensor_data = np.dot(fwd['sol']['data'], ori.T)  # compute forward
    return mne.EvokedArray(sensor_data, info)


@output_widget.capture(clear_output=True)
def plot_evoked():
    """Simulate the sensor data of the current dipole and plot the topomaps.

    The dipole position and orientation point are read from ``state``,
    transformed with ``ras_to_head_t`` and divided by 1000. The forward
    solution pre-computed for the closest grid position is loaded
    (downloading it from GitHub if it is not available locally) and used to
    create the evoked data, which is plotted as one topomap per channel
    type.

    While this runs, ' [updating]' is appended to the topomap labels. The
    original labels are restored afterwards, also if something fails (for
    example the download), so they cannot get stuck in the updating state.
    Everything printed here is captured by the output widget.
    """
    ch_types = ['mag', 'grad', 'eeg']
    label_keys = ['topomap_' + ch_type for ch_type in ch_types]

    old_label_texts = {key: state['label_text'][key] for key in label_keys}
    for key in label_keys:
        state['label_text'][key] = old_label_texts[key] + ' [updating]'

    for ch_type in ch_types:
        _update_topomap_label(ch_type)

    try:
        dipole_pos = (state['dipole_pos']['x'],
                      state['dipole_pos']['y'],
                      state['dipole_pos']['z'])
        dipole_ori = (state['dipole_ori']['x'],
                      state['dipole_ori']['y'],
                      state['dipole_ori']['z'])

        dipole_pos = apply_trans(trans=ras_to_head_t, pts=dipole_pos)
        dipole_pos /= 1000

        # NOTE(review): the orientation *point* itself is normalised here,
        # i.e. the resulting vector points from the coordinate origin to
        # the clicked-on point. The arrows drawn in the slice figures point
        # from the dipole position to that point, which suggests the
        # difference of the two points may have been intended -- confirm.
        dipole_ori = apply_trans(trans=ras_to_head_t, pts=dipole_ori)
        dipole_ori /= 1000
        dipole_ori /= np.linalg.norm(dipole_ori)

        dipole_pos = np.array(dipole_pos).reshape(1, 3).round(3)
        dipole_ori = np.array(dipole_ori).reshape(1, 3).round(3)

#         fwd = gen_forward_solution(dipole_pos, bem=bem, info=info, trans=head_to_mri_t)

        # Retrieve the dipole pos closest to the one we have a pre-calculated fwd for.
        pos_head_grid = create_head_grid(info=info)
        dipole_pos_for_fwd = (find_closest(pos_head_grid[0], dipole_pos[0, 0]),
                              find_closest(pos_head_grid[1], dipole_pos[0, 1]),
                              find_closest(pos_head_grid[2], dipole_pos[0, 2]))

        print(f'Requested calculations for dipole located at:\n'
              f'    x={dipole_pos[0, 0]}, y={dipole_pos[0, 1]}, z={dipole_pos[0, 2]} [m, MNE Head]\n'
              f'Using a forward solution for the following location:\n'
              f'    x={dipole_pos_for_fwd[0]}, y={dipole_pos_for_fwd[1]}, z={dipole_pos_for_fwd[2]} [m, MNE Head]\n')

        fwd_fname = f'{subject}-{dipole_pos_for_fwd[0]}-{dipole_pos_for_fwd[1]}-{dipole_pos_for_fwd[2]}-fwd.fif'
        if (fwd_path / fwd_fname).exists():
            # This used to be a plain string, so the placeholder was printed
            # literally instead of the file name.
            print(f'\nUsing existing forward solution: {fwd_fname}')
        else:
#             print('Retrieving forward solution from Zenodo.')
#             retrieve_fwd_from_zenodo(subject=subject, dipole_pos=dipole_pos_for_fwd,
#                                      zenodo_files=zenodo_files)
            print('Retrieving forward solution from GitHub.')
            retrieve_fwd_from_github(subject=subject, dipole_pos=dipole_pos_for_fwd)

        fwd = mne.read_forward_solution(fwd_path / fwd_fname)
        del fwd_fname, pos_head_grid, dipole_pos_for_fwd

        fwd = mne.forward.convert_forward_solution(fwd=fwd, force_fixed=True)
        # Named differently from the module-level `evoked` to avoid
        # shadowing it.
        simulated = gen_evoked(pos=dipole_pos, ori=dipole_ori, info=info, fwd=fwd)

        for ch_type, fig in widget['topomap_fig'].items():
            ax = fig.axes[0]
            ax.clear()
            simulated.plot_topomap(ch_type=ch_type,
                                   colorbar=False,
                                   outlines='head',
                                   contours=0,
                                   nrows=1, ncols=1,
                                   times=simulated.times[-1],
                                   res=256,
                                   show=False,
                                   axes=ax)
            ax.set_title(None)
            ax.format_coord = _create_format_coord('topomap')
            fig.canvas.draw()
    finally:
        for key, text in old_label_texts.items():
            state['label_text'][key] = text

        for ch_type in ch_types:
            _update_topomap_label(ch_type)


# Mutable application state shared by all callbacks.
state = dict()
# Coordinate of the slice currently shown per axis. Only 'val' is read by
# the code in this cell; 'min' and 'max' are not used here.
state['slice_coord'] = dict(x=dict(val=0, min=-60, max=60),
                            y=dict(val=0, min=-70, max=70),
                            z=dict(val=0, min=-20, max=60))
state['crosshair_pos'] = dict(x=0, y=0, z=0)
# Clicked-on dipole position and orientation point; None means "not set".
state['dipole_pos'] = dict(x=None, y=None, z=None)
state['dipole_ori'] = dict(x=None, y=None, z=None)
# Texts shown in the labels above the figures.
state['label_text'] = dict(x='sagittal',
                           y='coronal',
                           z='axial',
                           topomap_mag='Evoked magnetometer field',
                           topomap_grad='Evoked gradiometer field',
                           topomap_eeg='Evoked EEG field')
# Not read by the code in this cell.
state['dipole_arrows'] = []
# One of 'slice_browser', 'set_dipole_pos' or 'set_dipole_ori'.
state['mode'] = 'slice_browser'

# All widgets and figures of the app.
widget = dict()
# One figure per slice axis ...
widget['fig'] = dict(x=create_fig(),
                     y=create_fig(),
                     z=create_fig())
# ... and one topomap figure per channel type.
widget['topomap_fig'] = dict(mag=create_topomap_fig(),
                             grad=create_topomap_fig(),
                             eeg=create_topomap_fig())
widget['label'] = dict()
widget['label']['axis'] = dict(x=Label(state['label_text']['x']),
                               y=Label(state['label_text']['y']),
                               z=Label(state['label_text']['z']))
widget['label']['topomap_mag'] = Label(state['label_text']['topomap_mag'])
widget['label']['topomap_grad'] = Label(state['label_text']['topomap_grad'])
widget['label']['topomap_eeg'] = Label(state['label_text']['topomap_eeg'])
# Labels showing the current values ...
widget['label']['dipole_pos'] = Label('Not set')
widget['label']['dipole_ori'] = Label('Not set')
# ... and the captions displayed in front of them.
widget['label']['dipole_pos_'] = Label('Dipole origin:')
widget['label']['dipole_ori_'] = Label('Dipole orientation:')
# The button texts are the values handle_mode_change() compares against.
widget['toggle_buttons'] = dict(mode_selector=ToggleButtons(
    options=['Slice Browser', 'Set Dipole Origin', 'Set Dipole Orientation']))
widget['toggle_buttons']['mode_selector'].observe(handle_mode_change, 'value')
widget['output'] = output_widget
# The titles are passed to the constructor and set again via set_title().
widget['accordion'] = Accordion(
    children=[widget['output'],
              Label("now this isn't too helpful now, is it")],
    titles=('MNE Output', 'Help'))
widget['accordion'].set_title(0, 'MNE Output')
widget['accordion'].set_title(1, 'Help')

# Scatter artists marking the dipole position / orientation point in each
# slice figure; None while no marker is drawn.
markers = dict()
markers['dipole_pos'] = dict(x=None, y=None, z=None)
markers['dipole_ori'] = dict(x=None, y=None, z=None)

# Initial display: the three slices with crosshairs and the sensor layouts.
plot_slice('x', state['slice_coord']['x']['val'])
plot_slice('y', state['slice_coord']['y']['val'])
plot_slice('z', state['slice_coord']['z']['val'])
_draw_crosshairs()

for ch_type in ['mag', 'grad', 'eeg']:
    plot_sensors(ch_type=ch_type, evoked=evoked)

# Layout. Despite the uneven indentation below, the outer VBox has five
# direct children: the mode selector, the dipole properties, the row of
# slice figures, the row of topomap figures and the accordion.
dipole_props_col = VBox([HBox([widget['label']['dipole_pos_'], widget['label']['dipole_pos']]),
                         HBox([widget['label']['dipole_ori_'], widget['label']['dipole_ori']])])
app = VBox([widget['toggle_buttons']['mode_selector'],
            dipole_props_col,
            HBox([VBox([widget['label']['axis']['x'], widget['fig']['x'].canvas]),
                  VBox([widget['label']['axis']['y'], widget['fig']['y'].canvas]),
                          VBox([widget['label']['axis']['z'], widget['fig']['z'].canvas])]),
                    HBox([VBox([widget['label']['topomap_mag'], widget['topomap_fig']['mag'].canvas]),
                          VBox([widget['label']['topomap_grad'], widget['topomap_fig']['grad'].canvas]),
                          VBox([widget['label']['topomap_eeg'], widget['topomap_fig']['eeg'].canvas])]),
                    widget['accordion']])

display(app)