Skip to content

Commit

Permalink
add HPI SNR calculation and plotting (mne-tools#9570)
Browse files Browse the repository at this point in the history
* add SNR option to chpi amplitude calculation

* docstring

* add to python reference

* minor cleanup

Co-authored-by: Eric Larson <larson.eric.d@gmail.com>

* add test

* add viz function

* nicer units

* add to tutorial

* fix rebase snafus

* speed up tutorial

* allow user-passed axes

* simpler title

* fix docstrings

* change plot order

* validate axes arg

* fix errant backtick

* fix test

* add plotting test

* modernize

* add changelog [skip actions][skip azp]

* fix plot legend

* improve viz test

Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
  • Loading branch information
2 people authored and marsipu committed Aug 2, 2021
1 parent 114442c commit 24123b0
Show file tree
Hide file tree
Showing 9 changed files with 292 additions and 46 deletions.
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ Enhancements

- The :meth:`mne.Evoked.get_data` method now has a ``units`` parameter (:gh:`9578` by `Stefan Appelhoff`_)

- Add `mne.chpi.compute_chpi_snr` and `mne.viz.plot_chpi_snr` for computing and plotting the time-varying SNR of continuously-active HPI coils (:gh:`9570` by `Daniel McCloy`_ and `Jussi Nurminen`_)

- Add :func:`mne.get_montage_volume_labels` to find the regions of interest in a Freesurfer atlas anatomical segmentation for an intracranial electrode montage and :func:`mne.viz.plot_channel_labels_circle` to plot them (:gh:`9545` by `Alex Rockhill`_)

- Add :func:`mne.viz.Brain.add_volume_labels` to plot subcortical surfaces and other regions of interest (:gh:`9540` by `Alex Rockhill`_ and `Eric Larson`_)
Expand Down
1 change: 1 addition & 0 deletions doc/preprocessing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ EEG referencing:
:toctree: generated/

compute_chpi_amplitudes
compute_chpi_snr
compute_chpi_locs
compute_head_pos
extract_chpi_locs_ctf
Expand Down
1 change: 1 addition & 0 deletions doc/visualization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Visualization
mne_analyze_colormap
plot_bem
plot_brain_colorbar
plot_chpi_snr
plot_connectivity_circle
plot_cov
plot_channel_labels_circle
Expand Down
181 changes: 149 additions & 32 deletions mne/chpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .io.kit.kit import RawKIT as _RawKIT
from .io.meas_info import _simplify_info, Info
from .io.pick import (pick_types, pick_channels, pick_channels_regexp,
pick_info)
pick_info, _picks_to_idx)
from .io.proj import Projection, setup_proj
from .io.constants import FIFF
from .io.ctf.trans import _make_ctf_coord_trans_set
Expand Down Expand Up @@ -589,30 +589,35 @@ def _setup_hpi_amplitude_fitting(info, t_window, remove_aliased=False,
if t_window <= 0:
raise ValueError('t_window (%s) must be > 0' % (t_window,))
logger.info('Using time window: %0.1f ms' % (1000 * t_window,))
model_n_window = int(round(float(t_window) * info['sfreq']))
# build model to extract sinusoidal amplitudes.
slope = np.linspace(-0.5, 0.5, model_n_window)[:, np.newaxis]
rps = np.arange(model_n_window)[:, np.newaxis].astype(float)
rps *= 2 * np.pi / info['sfreq'] # radians/sec
f_t = hpi_freqs[np.newaxis, :] * rps
l_t = line_freqs[np.newaxis, :] * rps
model = [np.sin(f_t), np.cos(f_t)] # hpi freqs
model += [np.sin(l_t), np.cos(l_t)] # line freqs
model += [slope, np.ones(slope.shape)]
model = np.concatenate(model, axis=1)
window_nsamp = np.rint(t_window * info['sfreq']).astype(int)
model = _setup_hpi_glm(hpi_freqs, line_freqs, info['sfreq'], window_nsamp)
inv_model = np.linalg.pinv(model)
inv_model_reord = _reorder_inv_model(inv_model, len(hpi_freqs))
proj, proj_op, meg_picks = _setup_ext_proj(info, ext_order)

# include mag and grad picks separately, for SNR computations
mag_picks = _picks_to_idx(info, 'mag', allow_empty=True)
grad_picks = _picks_to_idx(info, 'grad', allow_empty=True)
# Set up magnetic dipole fits
hpi = dict(meg_picks=meg_picks, hpi_pick=hpi_pick,
model=model, inv_model=inv_model, t_window=t_window,
inv_model_reord=inv_model_reord,
on=hpi_ons, n_window=model_n_window, proj=proj, proj_op=proj_op,
freqs=hpi_freqs, line_freqs=line_freqs)
hpi = dict(
meg_picks=meg_picks, mag_picks=mag_picks, grad_picks=grad_picks,
hpi_pick=hpi_pick, model=model, inv_model=inv_model, t_window=t_window,
inv_model_reord=inv_model_reord, on=hpi_ons, n_window=window_nsamp,
proj=proj, proj_op=proj_op, freqs=hpi_freqs, line_freqs=line_freqs)
return hpi


def _setup_hpi_glm(hpi_freqs, line_freqs, sfreq, window_nsamp):
"""Initialize a general linear model for HPI amplitude estimation."""
slope = np.linspace(-0.5, 0.5, window_nsamp)[:, np.newaxis]
radians_per_sec = 2 * np.pi * np.arange(window_nsamp, dtype=float) / sfreq
f_t = hpi_freqs[np.newaxis, :] * radians_per_sec[:, np.newaxis]
l_t = line_freqs[np.newaxis, :] * radians_per_sec[:, np.newaxis]
model = [np.sin(f_t), np.cos(f_t), # hpi freqs
np.sin(l_t), np.cos(l_t), # line freqs
slope, np.ones_like(slope)] # drift, DC
return np.hstack(model)


@jit()
def _reorder_inv_model(inv_model, n_freqs):
# Reorder for faster computation
Expand Down Expand Up @@ -651,14 +656,16 @@ def _time_prefix(fit_time):
return (' t=%0.3f:' % fit_time).ljust(17)


def _fit_chpi_amplitudes(raw, time_sl, hpi):
def _fit_chpi_amplitudes(raw, time_sl, hpi, snr=False):
"""Fit amplitudes for each channel from each of the N cHPI sinusoids.
Returns
-------
sin_fit : ndarray, shape (n_freqs, n_channels))
sin_fit : ndarray, shape (n_freqs, n_channels)
The sin amplitudes matching each cHPI frequency.
Will be all nan if this time window should be skipped.
snr : ndarray, shape (n_freqs, 2)
Estimated SNR for this window, separately for mag and grad channels.
"""
# No need to detrend the data because our model has a DC term
with use_log_level(False):
Expand All @@ -676,6 +683,10 @@ def _fit_chpi_amplitudes(raw, time_sl, hpi):
n_on = ons.all(axis=-1).sum(axis=0)
if not (n_on >= 3).all():
return None
if snr:
return _fast_fit_snr(
this_data, len(hpi['freqs']), hpi['model'], hpi['inv_model'],
hpi['mag_picks'], hpi['grad_picks'])
return _fast_fit(this_data, hpi['proj_op'], len(hpi['freqs']),
hpi['model'], hpi['inv_model_reord'])

Expand All @@ -686,8 +697,8 @@ def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord):
if this_data.shape[1] != model.shape[0]:
model = model[:this_data.shape[1]]
inv_model_reord = _reorder_inv_model(np.linalg.pinv(model), n_freqs)
proj_data = np.dot(proj, this_data)
X = np.dot(inv_model_reord, proj_data.T)
proj_data = proj @ this_data
X = inv_model_reord @ proj_data.T

sin_fit = np.zeros((n_freqs, X.shape[1]))
for fi in range(n_freqs):
Expand All @@ -699,6 +710,34 @@ def _fast_fit(this_data, proj, n_freqs, model, inv_model_reord):
return sin_fit


@jit()
def _fast_fit_snr(this_data, n_freqs, model, inv_model, mag_picks, grad_picks):
# first or last window
if this_data.shape[1] != model.shape[0]:
model = model[:this_data.shape[1]]
inv_model = np.linalg.pinv(model)
coefs = np.ascontiguousarray(inv_model) @ np.ascontiguousarray(this_data.T)
# average sin & cos terms (special property of sinusoids: power=A²/2)
hpi_power = (coefs[:n_freqs] ** 2 + coefs[n_freqs:(2 * n_freqs)] ** 2) / 2
resid = this_data - np.ascontiguousarray((model @ coefs).T)
# can't use np.var(..., axis=1) with Numba, so do it manually:
resid_mean = np.atleast_2d(resid.sum(axis=1) / resid.shape[1]).T
squared_devs = np.abs(resid - resid_mean) ** 2
resid_var = squared_devs.sum(axis=1) / squared_devs.shape[1]
# output array will be (n_freqs, 3 * n_ch_types). The 3 columns for each
# channel type are the SNR, the mean cHPI power and the residual variance
# (which gets tiled to shape (n_freqs,) because it's a scalar).
snrs = np.empty((n_freqs, 0))
# average power & compute residual variance separately for each ch type
for _picks in (mag_picks, grad_picks):
if len(_picks):
avg_power = hpi_power[:, _picks].sum(axis=1) / len(_picks)
avg_resid = resid_var[_picks].mean() * np.ones(n_freqs)
snr = 10 * np.log10(avg_power / avg_resid)
snrs = np.hstack((snrs, np.stack((snr, avg_power, avg_resid), 1)))
return snrs


def _check_chpi_param(chpi_, name):
if name == 'chpi_locs':
want_ndims = dict(times=1, rrs=3, moments=3, gofs=2)
Expand Down Expand Up @@ -913,6 +952,43 @@ def _unit_quat_constraint(x):
return 1 - (x * x).sum()


@verbose
def compute_chpi_snr(raw, t_step_min=0.01, t_window='auto', ext_order=1,
tmin=0, tmax=None, verbose=None):
"""Compute time-varying estimates of cHPI SNR.
Parameters
----------
raw : instance of Raw
Raw data with cHPI information.
t_step_min : float
Minimum time step to use.
%(chpi_t_window)s
%(chpi_ext_order)s
%(raw_tmin)s
%(raw_tmax)s
%(verbose)s
Returns
-------
chpi_snrs : dict
The time-varying cHPI SNR estimates, with entries "times", "freqs",
"snr_mag", "power_mag", and "resid_mag" (and/or "snr_grad",
"power_grad", and "resid_grad", depending on which channel types are
present in ``raw``).
See Also
--------
mne.chpi.compute_chpi_locs, mne.chpi.compute_chpi_amplitudes
Notes
-----
.. versionadded:: 0.24
"""
return _compute_chpi_amp_or_snr(raw, t_step_min, t_window, ext_order,
tmin, tmax, verbose, snr=True)


@verbose
def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto',
ext_order=1, tmin=0, tmax=None, verbose=None):
Expand All @@ -923,8 +999,7 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto',
raw : instance of Raw
Raw data with cHPI information.
t_step_min : float
Minimum time step to use. If correlations are sufficiently high,
t_step_max will be used.
Minimum time step to use.
%(chpi_t_window)s
%(chpi_ext_order)s
%(raw_tmin)s
Expand All @@ -937,7 +1012,7 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto',
See Also
--------
mne.chpi.compute_chpi_locs
mne.chpi.compute_chpi_locs, mne.chpi.compute_chpi_snr
Notes
-----
Expand All @@ -962,6 +1037,19 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto',
.. versionadded:: 0.20
"""
return _compute_chpi_amp_or_snr(raw, t_step_min, t_window, ext_order,
tmin, tmax, verbose)


def _compute_chpi_amp_or_snr(raw, t_step_min=0.01, t_window='auto',
ext_order=1, tmin=0, tmax=None, verbose=None,
snr=False):
"""Compute cHPI amplitude or SNR.
See compute_chpi_amplitudes for parameter descriptions. One additional
boolean parameter ``snr`` signals whether to return SNR instead of
amplitude.
"""
hpi = _setup_hpi_amplitude_fitting(raw.info, t_window, ext_order=ext_order)
tmin, tmax = raw._tmin_tmax_to_start_stop(tmin, tmax)
tmin = tmin / raw.info['sfreq']
Expand All @@ -974,14 +1062,26 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto',
% (len(hpi['freqs']), len(fit_idxs), tmax - tmin))
del tmin, tmax
sin_fits = dict()
sin_fits['proj'] = hpi['proj']
sin_fits['times'] = np.round(fit_idxs + raw.first_samp -
hpi['n_window'] / 2.) / raw.info['sfreq']
sin_fits['proj'] = hpi['proj']
sin_fits['slopes'] = np.empty(
(len(sin_fits['times']),
len(hpi['freqs']),
len(sin_fits['proj']['data']['col_names'])))
for mi, midpt in enumerate(ProgressBar(fit_idxs, mesg='cHPI amplitudes')):
n_times = len(sin_fits['times'])
n_freqs = len(hpi['freqs'])
n_chans = len(sin_fits['proj']['data']['col_names'])
if snr:
del sin_fits['proj']
sin_fits['freqs'] = hpi['freqs']
ch_types = raw.get_channel_types()
grad_offset = 3 if 'mag' in ch_types else 0
for ch_type in ('mag', 'grad'):
if ch_type in ch_types:
for key in ('snr', 'power', 'resid'):
cols = 1 if key == 'resid' else n_freqs
sin_fits[f'{ch_type}_{key}'] = np.empty((n_times, cols))
else:
sin_fits['slopes'] = np.empty((n_times, n_freqs, n_chans))
message = f"cHPI {'SNRs' if snr else 'amplitudes'}"
for mi, midpt in enumerate(ProgressBar(fit_idxs, mesg=message)):
#
# 0. determine samples to fit.
#
Expand All @@ -992,7 +1092,24 @@ def compute_chpi_amplitudes(raw, t_step_min=0.01, t_window='auto',
#
# 1. Fit amplitudes for each channel from each of the N sinusoids
#
sin_fits['slopes'][mi] = _fit_chpi_amplitudes(raw, time_sl, hpi)
amps_or_snrs = _fit_chpi_amplitudes(raw, time_sl, hpi, snr)
if snr:
# unpack the SNR estimates. mag & grad are returned in one array
# (because of Numba) so take care with which column is which.
# note that mean residual is a scalar (same for all HPI freqs) but
# is returned as a (tiled) vector (again, because Numba) so that's
# why below we take amps_or_snrs[0, 2] instead of [:, 2]
ch_types = raw.get_channel_types()
if 'mag' in ch_types:
sin_fits['mag_snr'][mi] = amps_or_snrs[:, 0] # SNR
sin_fits['mag_power'][mi] = amps_or_snrs[:, 1] # mean power
sin_fits['mag_resid'][mi] = amps_or_snrs[0, 2] # mean resid
if 'grad' in ch_types:
sin_fits['grad_snr'][mi] = amps_or_snrs[:, grad_offset]
sin_fits['grad_power'][mi] = amps_or_snrs[:, grad_offset + 1]
sin_fits['grad_resid'][mi] = amps_or_snrs[0, grad_offset + 2]
else:
sin_fits['slopes'][mi] = amps_or_snrs
return sin_fits


Expand Down
18 changes: 17 additions & 1 deletion mne/tests/test_chpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RawArray, read_raw_kit)
from mne.io.constants import FIFF
from mne.chpi import (compute_chpi_amplitudes, compute_chpi_locs,
compute_head_pos, _setup_ext_proj,
compute_chpi_snr, compute_head_pos, _setup_ext_proj,
_chpi_locs_to_times_dig, _compute_good_distances,
extract_chpi_locs_ctf, head_pos_to_trans_rot_t,
read_head_pos, write_head_pos, filter_chpi,
Expand Down Expand Up @@ -308,6 +308,22 @@ def test_calculate_chpi_positions_vv():
_calculate_chpi_positions(raw)


@pytest.mark.slowtest
def test_calculate_chpi_snr():
"""Test cHPI SNR calculation."""
raw = read_raw_fif(chpi_fif_fname, allow_maxshield='yes')
result = compute_chpi_snr(raw)
# make sure all the entries are there
keys = {f'{ch_type}_{key}' for ch_type in ('mag', 'grad') for key in
('snr', 'power', 'resid')}
assert set(result) == keys.union({'times', 'freqs'})
# make sure the values are plausible, given the sample data file
assert result['mag_snr'].min() > 1
assert result['mag_snr'].max() < 40
assert result['grad_snr'].min() > 1
assert result['grad_snr'].max() < 40


@testing.requires_testing_data
@pytest.mark.slowtest
def test_calculate_chpi_positions_artemis():
Expand Down
2 changes: 1 addition & 1 deletion mne/viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .misc import (plot_cov, plot_csd, plot_bem, plot_events,
plot_source_spectrogram, _get_presser,
plot_dipole_amplitudes, plot_ideal_filter, plot_filter,
adjust_axes)
adjust_axes, plot_chpi_snr)
from .evoked import (plot_evoked, plot_evoked_image, plot_evoked_white,
plot_snr_estimate, plot_evoked_topo,
plot_evoked_joint, plot_compare_evokeds)
Expand Down
Loading

0 comments on commit 24123b0

Please sign in to comment.