Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 54 additions & 36 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os.path as op
import sys
import warnings
from collections import OrderedDict
from collections.abc import Iterable
from functools import partial

Expand Down Expand Up @@ -416,6 +417,50 @@ def plot_evoked_field(evoked, surf_maps, time=None, time_label='t = %0.0f ms',
return renderer.scene()


def _get_meg_surf(meg, info, picks, trans, coord_frame, warn_meg):
from ..forward import _create_meg_coils
meg_rrs, meg_tris = list(), list()
coil_transs = [_loc_to_coil_trans(info['chs'][pick]['loc'])
for pick in picks]
coils = _create_meg_coils(
[info['chs'][pick] for pick in picks], acc='normal')
offset = 0
for coil, coil_trans in zip(coils, coil_transs):
rrs, tris = _sensor_shape(coil)
rrs = apply_trans(coil_trans, rrs)
meg_rrs.append(rrs)
meg_tris.append(tris + offset)
offset += len(meg_rrs[-1])
if len(meg_rrs) == 0:
if warn_meg:
warn('MEG sensors not found. Cannot plot MEG locations.')
else:
meg_rrs = apply_trans(trans,
np.concatenate(meg_rrs, axis=0))
meg_tris = np.concatenate(meg_tris, axis=0)
return dict(rr=meg_rrs, tris=meg_tris)


def _get_coord_frame_trans(coord_frame, info, trans):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems confusing, could you just transform to head if in "meg", then transform to mri if in originally in "meg" (now "head" because it was just transformed) or "head"? That seems like it might be more straightforward

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I guess you might need them multiple times so that might not work

head_mri_t, _ = _get_trans(trans, 'head', 'mri')
dev_head_t, _ = _get_trans(info['dev_head_t'], 'meg', 'head')
trans = OrderedDict()
if coord_frame == 'meg':
trans["head"] = invert_transform(dev_head_t)
trans["meg"] = Transform('meg', 'meg')
trans["mri"] = invert_transform(combine_transforms(
dev_head_t, head_mri_t, 'meg', 'mri'))
elif coord_frame == 'mri':
trans["head"] = head_mri_t
trans["meg"] = combine_transforms(dev_head_t, head_mri_t, 'meg', 'mri')
trans["mri"] = Transform('mri', 'mri')
else: # coord_frame == 'head'
trans["head"] = Transform('head', 'head')
trans["meg"] = dev_head_t
trans["mri"] = invert_transform(head_mri_t)
return trans


@verbose
def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
surfaces='auto', coord_frame='head',
Expand Down Expand Up @@ -552,7 +597,7 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,

.. versionadded:: 0.15
"""
from ..forward import _create_meg_coils, Forward
from ..forward import Forward
from ..coreg import get_mni_fiducials
# Update the backend
from .backends.renderer import _get_renderer
Expand Down Expand Up @@ -663,22 +708,11 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
trans = _find_trans(subject, subjects_dir)
head_mri_t, _ = _get_trans(trans, 'head', 'mri')
dev_head_t, _ = _get_trans(info['dev_head_t'], 'meg', 'head')
del trans

# Figure out our transformations
if coord_frame == 'meg':
head_trans = invert_transform(dev_head_t)
meg_trans = Transform('meg', 'meg')
mri_trans = invert_transform(combine_transforms(
dev_head_t, head_mri_t, 'meg', 'mri'))
elif coord_frame == 'mri':
head_trans = head_mri_t
meg_trans = combine_transforms(dev_head_t, head_mri_t, 'meg', 'mri')
mri_trans = Transform('mri', 'mri')
else: # coord_frame == 'head'
head_trans = Transform('head', 'head')
meg_trans = dev_head_t
mri_trans = invert_transform(head_mri_t)
trans = _get_coord_frame_trans(coord_frame, info, trans)
head_trans, meg_trans, mri_trans = trans.values()
del trans

# both the head and helmet will be in MRI coordinates after this
surfs = dict()
Expand Down Expand Up @@ -845,7 +879,7 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
mri_trans, head_trans)

# determine points
meg_rrs, meg_tris = list(), list()
meg_surf = None
hpi_loc = list()
ext_loc = list()
car_loc = list()
Expand All @@ -865,23 +899,8 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
eeg_loc = list()
del eeg
if 'sensors' in meg:
coil_transs = [_loc_to_coil_trans(info['chs'][pick]['loc'])
for pick in meg_picks]
coils = _create_meg_coils([info['chs'][pick] for pick in meg_picks],
acc='normal')
offset = 0
for coil, coil_trans in zip(coils, coil_transs):
rrs, tris = _sensor_shape(coil)
rrs = apply_trans(coil_trans, rrs)
meg_rrs.append(rrs)
meg_tris.append(tris + offset)
offset += len(meg_rrs[-1])
if len(meg_rrs) == 0:
if warn_meg:
warn('MEG sensors not found. Cannot plot MEG locations.')
else:
meg_rrs = apply_trans(meg_trans, np.concatenate(meg_rrs, axis=0))
meg_tris = np.concatenate(meg_tris, axis=0)
meg_surf = _get_meg_surf(meg, info, meg_picks, meg_trans, coord_frame,
warn_meg)
del meg
if dig:
if dig == 'fiducials':
Expand Down Expand Up @@ -1039,10 +1058,9 @@ def plot_alignment(info=None, trans=None, subject=None, subjects_dir=None,
glyph_center=(0., -defaults['eegp_height'], 0),
glyph_resolution=20,
backface_culling=True)
if len(meg_rrs) > 0:
if meg_surf is not None and len(meg_surf['rr']) > 0:
color, alpha = (0., 0.25, 0.5), 0.25
surf = dict(rr=meg_rrs, tris=meg_tris)
renderer.surface(surface=surf, color=color,
renderer.surface(surface=meg_surf, color=color,
opacity=alpha, backface_culling=True)

if src is not None:
Expand Down
88 changes: 86 additions & 2 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@
from .callback import (ShowView, TimeCallBack, SmartCallBack,
UpdateLUT, UpdateColorbarScale)

from ...defaults import DEFAULTS
from ..utils import (_show_help_fig, _get_color_list, concatenate_images,
_generate_default_filename, _save_ndarray_img)
from .._3d import _process_clim, _handle_time, _check_views
from .._3d import (_process_clim, _handle_time, _check_views, _get_meg_surf,
_get_coord_frame_trans)

from ...externals.decorator import decorator
from ...io.pick import pick_types
from ...defaults import _handle_default
from ...surface import mesh_edges
from ...source_space import SourceSpaces, vertex_to_mni, read_talxfm
from ...transforms import apply_trans, invert_transform
from ...transforms import apply_trans, invert_transform, _find_trans
from ...utils import (_check_option, logger, verbose, fill_doc, _validate_type,
use_log_level, Bunch, _ReuseCycle, warn,
get_subjects_dir)
Expand Down Expand Up @@ -1770,6 +1773,87 @@ def _cortex_colormap(self, cortex):
)
return colormap_map[cortex]

@fill_doc
def add_sensors(self, raw, meg=True, eeg=True, trans=None):
"""Display the sensors.

Parameters
----------
raw : instance of Raw
The raw data to plot.
meg : bool
If True, display MEG sensors. Defaults to True.
eeg : bool
If True, display EEG sensors. Defaults to True.
%(trans)s
"""
info = raw.info.copy()
self._sensors_data = raw.get_data()
_validate_type(info, dict, 'info')
_validate_type(meg, bool, 'meg')
_validate_type(eeg, bool, 'eeg')
defaults = DEFAULTS['coreg']
if trans == 'auto':
trans = _find_trans(self._subject_id, self._subjects_dir)
trans = _get_coord_frame_trans('mri', info, trans)
head_trans, meg_trans, _ = trans.values()
del trans
if meg:
meg_picks = pick_types(info, meg=True, eeg=False, ref_meg=False)
meg_surf = _get_meg_surf(meg=["sensors"], info=info,
picks=meg_picks, trans=meg_trans,
coord_frame='mri', warn_meg=False)
color, alpha = (0., 0.25, 0.5), 0.25
self._renderer.surface(surface=meg_surf, color=color,
opacity=alpha, backface_culling=True)
if eeg:
eeg_picks = pick_types(info, meg=False, eeg=True, ref_meg=False)
eeg_loc = np.array([info['chs'][k]['loc'][:3] for k in eeg_picks])
eeg_loc = apply_trans(head_trans, eeg_loc)
color = defaults['eeg_color']
scale = defaults['eeg_scale']
alpha = 0.8

import vtk
from pyvista import UnstructuredGrid
from ..backends._pyvista import _sphere_glyph
points = eeg_loc
n_points = len(points)
cell_type = np.full(n_points, vtk.VTK_VERTEX)
cells = np.c_[np.full(n_points, 1), range(n_points)]
args = (cells, cell_type, points)
self._eeg_grid = UnstructuredGrid(*args)
mapper = _sphere_glyph(self._eeg_grid, factor=scale)
actor = self._renderer._actor(mapper)
self._renderer.plotter.add_actor(actor)

def _func(value):
scalars = self._sensors_data[eeg_picks.tolist(), int(value)]
rng = [np.min(scalars), np.max(scalars)]
self._renderer._set_mesh_scalars(
mesh=self._eeg_grid,
scalars=scalars,
name="data"
)
mapper.SetScalarRange(*rng)
self._renderer._update()

self.widgets = dict()
self.callbacks = dict()
self._renderer._dock_initialize()
self.widgets["sensors"] = self._renderer._dock_add_slider(
name="Sensors",
value=0,
rng=[0, self._sensors_data.shape[1] - 1],
double=True,
callback=_func,
compact=False,
)
self._renderer._dock_finalize()
# self.callbacks["sensors"].widget = self.widgets["sensors"]

self._renderer._update()

@verbose
def add_data(self, array, fmin=None, fmid=None, fmax=None,
thresh=None, center=None, transparent=False, colormap="auto",
Expand Down
4 changes: 4 additions & 0 deletions mne/viz/_brain/tests/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,10 @@ def __init__(self):
brain.add_foci([0], coords_as_verts=True,
hemi=hemi, color='blue')

# add sensors
evoked = read_evokeds(fname_evoked, baseline=(None, 0))[0]
brain.add_sensors(info=evoked.info)

# add text
brain.add_text(x=0, y=0, text='foo')
brain.close()
Expand Down
15 changes: 15 additions & 0 deletions mne/viz/backends/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,21 @@ def _arrow_glyph(grid, factor):
return mapper


def _sphere_glyph(grid, factor):
glyph = vtk.vtkSphereSource()
glyph.Update()

alg = _glyph(
grid,
factor=factor,
scale_mode=False,
geom=glyph.GetOutputPort(),
)
mapper = vtk.vtkDataSetMapper()
mapper.SetInputConnection(alg.GetOutputPort())
return mapper


def _glyph(dataset, scale_mode='scalar', orient=True, scalars=True, factor=1.0,
geom=None, tolerance=0.0, absolute=False, clamping=False, rng=None):
if geom is None:
Expand Down