In [1]:
%matplotlib widget

from ipywidgets import HBox, VBox, Label, IntSlider, Accordion, Button, Output
from ipyevents import Event
import pathlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseButton
from matplotlib.widgets import Cursor
from functools import partial

import nilearn
from nilearn.plotting import plot_anat
import nibabel
import mne
from mne.transforms import (read_trans, invert_transform, combine_transforms,
                            apply_trans, Transform)

data_path = pathlib.Path('data')
subjects_dir = data_path / 'subjects'
subject = 'sample'

# MRI image.
t1_fname = str(subjects_dir / subject / 'mri' / 'T1.mgz')
t1_img = nilearn.image.load_img(t1_fname)
del t1_fname

# Transformation matrices
#
# HEAD <> MRI (surface RAS). MNE stores this transform in metres.
trans_fname = data_path / 'sample-trans.fif'
head_to_mri_t = read_trans(trans_fname)
mri_to_head_t = invert_transform(head_to_mri_t)

# Scanner RAS <> VOXEL <> surface RAS. The nibabel header affines operate on
# millimetres.
ras_to_vox_t = Transform(fro='ras', to='mri_voxel',
                         trans=t1_img.header.get_ras2vox())
vox_to_mri_t = Transform(fro='mri_voxel', to='mri',
                         trans=t1_img.header.get_vox2ras_tkr())

# Scanner RAS -> surface RAS (mm -> mm).
ras_to_mri_t = combine_transforms(ras_to_vox_t,
                                  vox_to_mri_t,
                                  fro='ras', to='mri')

# Scanner RAS -> HEAD. The MRI -> HEAD transform is in metres, so its
# translation is rescaled to millimetres before chaining. Otherwise a metre
# offset would be added to millimetre coordinates. The resulting transform
# maps mm -> mm; the dipole handlers divide its output by 1000 to get metres.
mri_to_head_mm_t = Transform(fro='mri', to='head',
                             trans=mri_to_head_t['trans'].copy())
mri_to_head_mm_t['trans'][:3, 3] *= 1000.
ras_to_head_t = combine_transforms(ras_to_mri_t,
                                   mri_to_head_mm_t,
                                   fro='ras', to='head')

del mri_to_head_t, mri_to_head_mm_t, ras_to_vox_t, vox_to_mri_t, trans_fname


# BEM solution.
bem_fname = data_path / 'sample-bem-sol.fif'
bem = mne.read_bem_solution(bem_fname, verbose=False)
del bem_fname

# Evoked & Info object. Only MEG and EEG channels are kept. Projectors and
# bad-channel marks are cleared, so every remaining channel is used.
evoked_fname = data_path / 'sample-ave.fif'
evoked = mne.read_evokeds(evoked_fname, verbose='warning')[0]
evoked.pick_types(meg=True, eeg=True)

info = evoked.info
info['projs'] = []
info['bads'] = []
del evoked_fname

# The widget app is shown explicitly with display() later; disable
# interactive auto-showing of new figures.
plt.ioff()

In [7]:
# Close figures left over from a previous run of this cell, so re-executing
# it does not accumulate open Matplotlib figures.
plt.close(fig='all')

def handle_click2(event):
    """Move the crosshair to the point left-clicked in a slice view.

    In the clicked view, the two displayed world coordinates are updated in
    ``state``. The two *other* slice views are cleared and re-plotted at the
    new coordinates. If both dipole points are set and differ, the dipole
    arrow is redrawn; finally the crosshairs are redrawn.

    These clicks are ignored:
    - clicks with buttons other than the left one;
    - clicks outside any axes;
    - clicks on a topomap figure;
    - clicks on a figure that is not one of the slice views.

    Parameters
    ----------
    event : matplotlib.backend_bases.MouseEvent
        The ``button_press_event`` delivered by the canvas.
    """
    if event.button != MouseButton.LEFT:
        return

    in_ax = event.inaxes
    if in_ax is None:
        # Click landed on the figure but outside any axes (e.g. the margin).
        return

    if in_ax.figure in widget['topomap_fig'].values():
        return

    axis = next((name for name, fig in widget['fig'].items()
                 if fig is in_ax.figure), None)
    if axis is None:
        # Not one of the slice figures; previously this silently fell back
        # to the last loop value ('z').
        return

    # World axes shown horizontally / vertically in each slice view.
    x_idx, y_idx = {'x': ('y', 'z'),
                    'y': ('x', 'z'),
                    'z': ('x', 'y')}[axis]
    x = event.xdata
    y = event.ydata

    state['slice_coord'][x_idx]['val'] = x
    state['slice_coord'][y_idx]['val'] = y
    state['crosshair_pos'][x_idx] = x
    state['crosshair_pos'][y_idx] = y

    widget['fig'][x_idx].axes[0].clear()
    widget['fig'][y_idx].axes[0].clear()
    plot_slice(x_idx, x)
    plot_slice(y_idx, y)

    if (state['dipole_pos']['x'] is not None and
            state['dipole_ori']['x'] is not None and
            state['dipole_pos'] != state['dipole_ori']):
        draw_dipole_arrows()

    _draw_crosshairs()


def _create_format_coord(axis):
    if axis == 'topomap':

        def format_coord(x, y):
            x *= 1000
            y *= 1000
            x = int(round(x))
            y = int(round(y))
            return f'x={x} mm, y={y} mm'
        
        return format_coord

    if axis == 'x':
        x_label = 'y'
        y_label = 'z'
    elif axis == 'y':
        x_label = 'x'
        y_label = 'z'
    elif axis == 'z':
        x_label = 'x'
        y_label = 'y'
    
    # FIXME
    def format_coord(x, y, x_label, y_label):
        return f'{x_label}={x:.1f} mm, {y_label}={y:.1f} mm'

    return partial(format_coord, x_label=x_label, y_label=y_label)


def handle_leave(event):
    """Clear the red cursor of a slice figure when the pointer leaves it.

    Connected to ``figure_leave_event``. The figure is taken from the event's
    canvas, because ``event.inaxes`` may be ``None`` at this point (e.g. when
    the pointer left through the figure margin).

    Parameters
    ----------
    event : matplotlib.backend_bases.LocationEvent
        The figure-leave event.
    """
    figure = event.canvas.figure
    axis = next((name for name, fig in widget['fig'].items()
                 if fig is figure), None)
    if axis is None:
        return

    cursor = widget['cursor'][axis]
    if cursor is None:
        # plot_slice() has not created a cursor for this view yet.
        return

    cursor.clear(None)
    cursor.ax.figure.canvas.draw()


def handle_enter(event):
    """Handle ``figure_enter_event`` for the slice figures.

    Currently does nothing. It exists so that the ``figure_enter_event``
    connection made in ``create_fig`` has a target.
    """


def _on_slider_change(change):
    changed_slider = change['owner']
    changed_slider.disabled = True
    new_value = change['new']

    for axis, slider in widget['slider'].items():
        if slider is changed_slider:
            break

    state['slice_coord'][axis] = new_value
    widget['fig'][axis].axes[-1].clear()
    plot_slice(axis=axis, pos=new_value)
    changed_slider.disabled = False


def _update_axis_label(axis):
    label = widget['label']['axis'][axis]
    label_text = state['label_text'][axis]
    label.value = label_text
    

def _update_topomap_label(ch_type):
    label = widget['label']['topomap_' + ch_type]
    label_text = state['label_text']['topomap_' + ch_type]
    label.value = label_text


def _draw_crosshairs():
    # Remove potentially existing crosshairs.
    for axis in widget['fig'].keys():
        ax = widget['fig'][axis].axes[0]
        lines_to_keep = [line for line in ax.lines 
                         if line.get_label() != 'crosshair']
        ax.lines = lines_to_keep

    widget['fig']['x'].axes[0].axvline(state['crosshair_pos']['y'], color='white', label='crosshair', lw=0.5)
    widget['fig']['x'].axes[0].axhline(state['crosshair_pos']['z'], color='white', label='crosshair', lw=0.5)
    widget['fig']['y'].axes[0].axvline(state['crosshair_pos']['x'], color='white', label='crosshair', lw=0.5)
    widget['fig']['y'].axes[0].axhline(state['crosshair_pos']['z'], color='white', label='crosshair', lw=0.5)
    widget['fig']['z'].axes[0].axvline(state['crosshair_pos']['x'], color='white', label='crosshair', lw=0.5)
    widget['fig']['z'].axes[0].axhline(state['crosshair_pos']['y'], color='white', label='crosshair', lw=0.5)

    widget['fig']['x'].canvas.draw()
    widget['fig']['y'].canvas.draw()
    widget['fig']['z'].canvas.draw()

    
def plot_slice(axis, pos):
    """Draw the T1 slice of view ``axis`` at cut coordinate ``pos``.

    While plotting, the view's Label shows an ``' [updating]'`` suffix. The
    original label text is restored in a ``finally`` block, so a failed plot
    neither leaves the suffix behind nor stacks it on the next call.

    On success:
    - the coordinate readout of the plotted axes is replaced;
    - a new red Cursor is created on the figure's last axes and stored in
      ``widget['cursor'][axis]``;
    - the canvas is redrawn.

    Parameters
    ----------
    axis : {'x', 'y', 'z'}
        nilearn ``display_mode`` of the view (sagittal, coronal, axial).
    pos : float
        Cut coordinate, passed as ``cut_coords`` and used as the key into
        the returned display's ``axes`` mapping.
    """
    old_label_text = state['label_text'][axis]
    state['label_text'][axis] = old_label_text + ' [updating]'
    _update_axis_label(axis)
    try:
        fig = widget['fig'][axis]
        img = nilearn.plotting.plot_anat(t1_img, display_mode=axis,
                                         cut_coords=(pos,),
                                         figure=fig, dim=-0.5)
        img.axes[pos].ax.format_coord = _create_format_coord(axis)
        widget['cursor'][axis] = Cursor(fig.axes[-1],
                                        useblit=True,
                                        color='red', linewidth=2,
                                        label='cursor')
        fig.canvas.draw()
    finally:
        state['label_text'][axis] = old_label_text
        _update_axis_label(axis)


def create_fig():
    """Create a small, fixed-size slice figure wired to the mouse handlers.

    The toolbar and header are hidden, and the canvas is not resizable.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(2, 2))
    handlers = (('button_press_event', handle_click2),
                ('figure_enter_event', handle_enter),
                ('figure_leave_event', handle_leave))
    for event_name, handler in handlers:
        fig.canvas.mpl_connect(event_name, handler)

    canvas = fig.canvas
    canvas.toolbar_visible = False
    canvas.header_visible = False
    canvas.resizable = False
    return fig


def create_topomap_fig():
    """Create a small, fixed-size figure to hold a topomap.

    The toolbar and header are hidden and the canvas is not resizable. All
    entries of the canvas callback registry are dropped, so no handler
    connected via ``mpl_connect`` remains. The axis decorations of the
    single axes are switched off.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, ax = plt.subplots(nrows=1, figsize=(2, 2))
    canvas = fig.canvas
    for attr in ('toolbar_visible', 'header_visible', 'resizable'):
        setattr(canvas, attr, False)
    canvas.callbacks.callbacks.clear()
    ax.set_axis_off()
    return fig



# def create_slider(description, min, max, step, initial_value,
#                   callback_func):
#     style = {'description_width': 'initial'}
#     slider = IntSlider(description=description,
#                        min=min, max=max, step=step,
#                        value=initial_value,
#                        continuous_update=True,
#                        style=style,
#                        readout=False)
#     slider.observe(callback_func, names='value')
#     return slider


def draw_dipole_arrows():
    """Draw the dipole as a white arrow on each of the three slice views.

    The arrow starts at ``state['dipole_pos']`` and ends at
    ``state['dipole_ori']``, projected onto the two world axes each view
    displays. Arrows from earlier calls (label ``'dipole'``) are removed
    first, so repeated calls replace rather than accumulate arrows.

    Assumes both points are fully set (no ``None`` components); callers
    check this before calling.
    """
    # View -> (world axis on the horizontal, world axis on the vertical).
    layout = {'x': ('y', 'z'), 'y': ('x', 'z'), 'z': ('x', 'y')}
    start = state['dipole_pos']
    end = state['dipole_ori']

    for axis, fig in widget['fig'].items():
        ax = fig.axes[0]
        # Depending on the Matplotlib version, the FancyArrow created by
        # Axes.arrow() is registered in ax.patches or ax.artists. Neither
        # list can be reassigned in recent versions, so remove old arrows
        # individually. Build a snapshot first, since remove() mutates
        # the lists.
        for artist in [*ax.patches, *ax.artists]:
            if artist.get_label() == 'dipole':
                artist.remove()

        h_axis, v_axis = layout[axis]
        ax.arrow(x=start[h_axis], y=start[v_axis],
                 dx=end[h_axis] - start[h_axis],
                 dy=end[v_axis] - start[v_axis],
                 color='white', width=3, head_width=15,
                 length_includes_head=True, label='dipole')
        fig.canvas.draw()


def plot_sensors(ch_type, evoked):
    """Show the ``ch_type`` sensor layout of ``evoked`` in the matching topomap figure."""
    target_ax = widget['topomap_fig'][ch_type].axes[0]
    evoked.plot_sensors(ch_type=ch_type, title='', show=False, axes=target_ax)
        
        
def _describe_point(point_ras):
    """Format a stored crosshair point for the dipole read-out labels.

    Parameters
    ----------
    point_ras : dict
        Mapping with keys ``'x'``, ``'y'``, ``'z'`` (crosshair coordinates,
        labelled as mm in MRI RAS).

    Returns
    -------
    str
        The point rounded to whole millimetres, followed by its image under
        ``ras_to_head_t`` divided by 1000 (labelled metres, MNE head frame)
        and rounded to 3 decimals.
    """
    ras = (point_ras['x'], point_ras['y'], point_ras['z'])
    head = apply_trans(trans=ras_to_head_t, pts=ras) / 1000
    return (f"x={int(round(ras[0]))}, "
            f"y={int(round(ras[1]))}, "
            f"z={int(round(ras[2]))} [mm, MRI RAS] — "
            f"x={round(head[0], 3)}, "
            f"y={round(head[1], 3)}, "
            f"z={round(head[2], 3)} [m, MNE Head]")


def _refresh_dipole_arrows_and_button():
    """Redraw the dipole arrow and toggle "Plot evoked field".

    Only acts when both dipole points are set and differ: then the arrow is
    redrawn and the button enabled. Otherwise the button is disabled.
    """
    pos = state['dipole_pos']
    ori = state['dipole_ori']
    ready = (pos['x'] is not None and
             ori['x'] is not None and
             pos != ori)
    if ready:
        draw_dipole_arrows()
    widget['button']['plot_evoked'].disabled = not ready


def handle_set_dipole_pos(button):
    """Store the current crosshair position as the dipole origin.

    Parameters
    ----------
    button : ipywidgets.Button
        The clicked button (unused).
    """
    state['dipole_pos'] = state['crosshair_pos'].copy()
    widget['label']['dipole_pos'].value = _describe_point(state['dipole_pos'])
    _refresh_dipole_arrows_and_button()


def handle_set_dipole_ori(button):
    """Store the current crosshair position as the dipole orientation point.

    The dipole points from its origin towards this point.

    Parameters
    ----------
    button : ipywidgets.Button
        The clicked button (unused).
    """
    state['dipole_ori'] = state['crosshair_pos'].copy()
    widget['label']['dipole_ori'].value = _describe_point(state['dipole_ori'])
    _refresh_dipole_arrows_and_button()

        
def gen_forward_solution(pos, bem, info, trans):
    """Compute a forward solution for three orthogonal dipoles at ``pos``.

    Three dipoles share the location ``pos`` and are oriented along the
    three coordinate axes (identity orientation matrix). Each has an
    amplitude of 10e-9 and gof 100. Any orientation's field can then be
    formed as a linear combination of the three leadfield columns, which is
    what ``gen_evoked`` does. Output produced by ``mne.make_forward_dipole``
    is routed into the ``widget['output']`` Output widget.

    Parameters
    ----------
    pos : array-like, shape (1, 3) or (3,)
        Dipole location; tiled to one row per dipole.
    bem, info, trans
        Passed straight to ``mne.make_forward_dipole``.

    Returns
    -------
    fwd
        The forward solution returned by ``mne.make_forward_dipole``.
    """
    n_dipoles = 3
    dipoles = mne.Dipole(times=np.arange(n_dipoles),
                         pos=np.tile(pos, (n_dipoles, 1)),
                         amplitude=[10e-9] * n_dipoles,
                         ori=np.eye(n_dipoles),
                         gof=[100] * n_dipoles)
    with widget['output']:
        forward, _ = mne.make_forward_dipole(dipoles, bem=bem, info=info,
                                             trans=trans)
    return forward


def gen_evoked(pos, ori, info, fwd):
    """Project a dipole orientation through the leadfield into an Evoked.

    Parameters
    ----------
    pos : array-like
        Not used; the dipole location is taken from ``fwd``.
    ori : ndarray
        Dipole orientation, combined with the leadfield as ``gain @ ori.T``.
    info : mne.Info
        Measurement info for the resulting Evoked.
    fwd : dict-like
        Forward solution; its gain matrix is read from ``fwd['sol']['data']``.

    Returns
    -------
    mne.EvokedArray
    """
    gain = fwd['sol']['data']
    sensor_data = gain @ ori.T  # forward model: sensor signals for this orientation
    return mne.EvokedArray(sensor_data, info)


def handle_plot_evoked(button):
    """Simulate the configured dipole and plot its MEG/EEG field maps.

    Steps:
    1. Read the dipole origin and orientation point from ``state``.
    2. Map both with ``ras_to_head_t`` and scale by 1/1000 (mm -> m).
    3. Use the unit vector from the origin to the orientation point as the
       dipole orientation. This matches the arrow drawn in the slice views;
       the previous code wrongly used the orientation point's own position
       vector.
    4. Plot one topomap per channel type into the topomap figures.

    While working, each topomap label shows an ``' [updating]'`` suffix.
    The original texts are restored even if the computation fails.

    Parameters
    ----------
    button : ipywidgets.Button
        The clicked button (unused).

    Raises
    ------
    ValueError
        If the origin and the orientation point coincide in head
        coordinates (the orientation is then undefined).
    """
    ch_types = ('mag', 'grad', 'eeg')
    old_label_texts = {ch_type: state['label_text'][f'topomap_{ch_type}']
                       for ch_type in ch_types}
    for ch_type in ch_types:
        state['label_text'][f'topomap_{ch_type}'] = (old_label_texts[ch_type]
                                                     + ' [updating]')
        _update_topomap_label(ch_type)

    try:
        pos_ras = [state['dipole_pos'][k] for k in 'xyz']
        target_ras = [state['dipole_ori'][k] for k in 'xyz']

        # The 1/1000 factor converts millimetres to metres.
        pos_head = np.asarray(apply_trans(trans=ras_to_head_t, pts=pos_ras),
                              dtype=float) / 1000
        target_head = np.asarray(apply_trans(trans=ras_to_head_t, pts=target_ras),
                                 dtype=float) / 1000

        # Orientation = direction from the origin towards the target point.
        direction = target_head - pos_head
        norm = np.linalg.norm(direction)
        if norm == 0:
            raise ValueError('Dipole origin and orientation point coincide; '
                             'cannot determine a dipole orientation.')

        dipole_pos = pos_head.reshape(1, 3)
        dipole_ori = (direction / norm).reshape(1, 3)

        fwd = gen_forward_solution(dipole_pos, bem=bem, info=info,
                                   trans=head_to_mri_t)
        evoked = gen_evoked(pos=dipole_pos, ori=dipole_ori, info=info, fwd=fwd)

        for ch_type, fig in widget['topomap_fig'].items():
            ax = fig.axes[0]
            ax.clear()
            evoked.plot_topomap(ch_type=ch_type,
                                colorbar=False,
                                outlines='head',
                                contours=0,
                                nrows=1, ncols=1,
                                times=evoked.times[-1],
                                res=256,
                                show=False,
                                axes=ax)
            ax.set_title(None)
            ax.format_coord = _create_format_coord('topomap')
            fig.canvas.draw()
    finally:
        for ch_type in ch_types:
            state['label_text'][f'topomap_{ch_type}'] = old_label_texts[ch_type]
            _update_topomap_label(ch_type)


state = dict(
    # Current cut coordinate of each slice view, plus min/max bounds.
    slice_coord=dict(x=dict(val=0, min=-60, max=60),
                     y=dict(val=0, min=-70, max=70),
                     z=dict(val=0, min=-20, max=60)),
    # Position of the white crosshair shared by all slice views.
    crosshair_pos=dict(x=0, y=0, z=0),
    # Dipole origin and orientation point; None until set via the buttons.
    dipole_pos=dict.fromkeys('xyz'),
    dipole_ori=dict.fromkeys('xyz'),
    # Titles shown above each plot; ' [updating]' is appended while redrawing.
    label_text=dict(x='sagittal',
                    y='coronal',
                    z='axial',
                    topomap_mag='Evoked magnetometer field',
                    topomap_grad='Evoked gradiometer field',
                    topomap_eeg='Evoked EEG field'),
    # Not read by any of the functions in this cell.
    dipole_arrows=[],
)

widget = dict()

# Figures: three interactive slice views and three static topomaps.
widget['fig'] = dict(x=create_fig(),
                     y=create_fig(),
                     z=create_fig())
widget['topomap_fig'] = dict(mag=create_topomap_fig(),
                             grad=create_topomap_fig(),
                             eeg=create_topomap_fig())
widget['cursor'] = dict(x=None, y=None, z=None)  # Will be created in plot_slice()

# Labels: plot titles, plus the dipole origin / orientation read-outs.
widget['label'] = dict()
widget['label']['axis'] = dict(x=Label(state['label_text']['x']),
                               y=Label(state['label_text']['y']),
                               z=Label(state['label_text']['z']))
widget['label']['topomap_mag'] = Label(state['label_text']['topomap_mag'])
widget['label']['topomap_grad'] = Label(state['label_text']['topomap_grad'])
widget['label']['topomap_eeg'] = Label(state['label_text']['topomap_eeg'])
widget['label']['dipole_pos'] = Label('Not set')
widget['label']['dipole_ori'] = Label('Not set')
widget['label']['dipole_pos_'] = Label('Dipole pos.:')
widget['label']['dipole_ori_'] = Label('Dipole ori.:')

# Buttons. "Plot evoked field" starts disabled; the set-dipole handlers
# enable it once both dipole points are set and differ.
widget['button'] = dict(set_dipole_pos=Button(description='Set dipole origin',
                                              icon='fa-map-marker-alt'),
                        set_dipole_ori=Button(description='Set dipole orientation',
                                              icon='fa-map-marker-alt'),
                        plot_evoked=Button(description='Plot evoked field',
                                           icon='fa-paint-brush',
                                           button_style='success',
                                           disabled=True))

widget['button']['set_dipole_pos'].on_click(handle_set_dipole_pos)
widget['button']['set_dipole_ori'].on_click(handle_set_dipole_ori)
widget['button']['plot_evoked'].on_click(handle_plot_evoked)

# Collapsible panel with the captured MNE output and a short usage guide.
widget['output'] = Output()
widget['accordion'] = Accordion(
    children=[widget['output'],
              Label('Click a slice to move the crosshair. '
                    '"Set dipole origin" stores the crosshair as the dipole '
                    'position, "Set dipole orientation" as the point the '
                    'dipole points to. Once both are set and differ, '
                    '"Plot evoked field" simulates the dipole and shows its '
                    'MEG and EEG field maps.')])
# Titles are set via set_title(), available in ipywidgets 7 and 8. The
# ``titles`` constructor argument is only understood by ipywidgets >= 8.
widget['accordion'].set_title(0, 'MNE Output')
widget['accordion'].set_title(1, 'Help')



# Initial rendering:
# - the three slice views at their starting cut coordinates;
# - the crosshairs on top of them;
# - the sensor layouts in the topomap panels, until a field is plotted.
for view in ('x', 'y', 'z'):
    plot_slice(view, state['slice_coord'][view]['val'])
del view  # keep the module namespace as before (no stray loop variable)
_draw_crosshairs()

for ch_type in ['mag', 'grad', 'eeg']:
    plot_sensors(ch_type=ch_type, evoked=evoked)

# Top: action buttons, then the current dipole origin / orientation read-outs.
button_row = HBox([widget['button'][name]
                   for name in ('set_dipole_pos', 'set_dipole_ori', 'plot_evoked')])
dipole_props_col = VBox([
    HBox([widget['label']['dipole_pos_'], widget['label']['dipole_pos']]),
    HBox([widget['label']['dipole_ori_'], widget['label']['dipole_ori']]),
])

app = VBox([
    button_row,
    dipole_props_col,
    # Slice views (sagittal, coronal, axial), each with its title above it.
    HBox([VBox([widget['label']['axis'][view], widget['fig'][view].canvas])
          for view in ('x', 'y', 'z')]),
    # Field maps (magnetometers, gradiometers, EEG), each with its title.
    HBox([VBox([widget['label'][f'topomap_{ch}'], widget['topomap_fig'][ch].canvas])
          for ch in ('mag', 'grad', 'eeg')]),
    widget['accordion'],
])

display(app)

  ax = fh.add_axes([fraction * index * (x1 - x0) + x0, y0,


VBox(children=(HBox(children=(Button(description='Set dipole origin', icon='map-marker-alt', style=ButtonStyle…