Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: Introduce ICA.get_explained_variance_ratio() to easily retrieve relative explained variances after a fit #11141

Merged
merged 17 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Enhancements
- Add ``==`` and ``!=`` comparison between `mne.Projection` objects (:gh:`11147` by `Mathieu Scheltienne`_)
- Parse automatically temperature channel with :func:`mne.io.read_raw_edf` (:gh:`11150` by `Eric Larson`_ and `Alex Gramfort`_)
- Add ``encoding`` parameter to :func:`mne.io.read_raw_edf` and :func:`mne.io.read_raw_bdf` to support custom (non-UTF8) annotation channel encodings (:gh:`11154` by `Clemens Brunner`_)
- :class:`mne.preprocessing.ICA` gained a new method, :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio`, that allows the retrieval of the proportion of variance explained by ICA components (:gh:`11141` by `Richard Höchenberger`_)

Bugs
~~~~
Expand All @@ -58,6 +59,7 @@ Bugs
- Fix bug in :func:`mne.viz.plot_filter` when plotting filters created using ``output='ba'`` mode with ``compensation`` turned on. (:gh:`11040` by `Marian Dovgialo`_)
- Fix bug in :func:`mne.io.read_raw_bti` where EEG, EMG, and H/VEOG channels were not detected properly, and many non-ECG channels were called ECG. The logic has been improved, and any channels of unknown type are now labeled as ``misc`` (:gh:`11102` by `Eric Larson`_)
- Fix bug in :func:`mne.viz.plot_topomap` when providing ``sphere="eeglab"`` (:gh:`11081` by `Mathieu Scheltienne`_)
- The string and HTML representation of :class:`mne.preprocessing.ICA` reported incorrect values for the explained variance. This information has been removed from the representations, and should instead be retrieved via the new :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio` method (:gh:`11141` by `Richard Höchenberger`_)

API changes
~~~~~~~~~~~
Expand Down
4 changes: 0 additions & 4 deletions mne/html_templates/repr/ica.html.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@
<th>ICA components</th>
<td>{{ n_components }}</td>
</tr>
<tr>
<th>Explained variance</th>
<td>{{ (explained_variance * 100) | round(1) }}&nbsp;%</td>
</tr>
<tr>
<th>Available PCA components</th>
<td>{{ n_pca_components }}</td>
Expand Down
149 changes: 139 additions & 10 deletions mne/preprocessing/ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

from inspect import isfunction
from collections import namedtuple
from collections.abc import Sequence
from copy import deepcopy
from numbers import Integral
from time import time
from dataclasses import dataclass
from typing import Optional, List
import warnings

import math
import json
Expand Down Expand Up @@ -452,7 +454,6 @@ class _InfosForRepr:
fit_n_samples: Optional[int]
fit_n_components: Optional[int]
fit_n_pca_components: Optional[int]
fit_explained_variance: Optional[float]
ch_types: List[str]
excludes: List[str]

Expand All @@ -470,11 +471,6 @@ class _InfosForRepr:
fit_n_pca_components = getattr(self, 'pca_components_', None)
if fit_n_pca_components is not None:
fit_n_pca_components = len(self.pca_components_)
fit_explained_variance = getattr(self, 'pca_explained_variance_', None)
if fit_explained_variance is not None:
abs_vars = self.pca_explained_variance_
rel_vars = abs_vars / abs_vars.sum()
fit_explained_variance = rel_vars[:fit_n_components].sum()

if self.info is not None:
ch_types = [c for c in _DATA_CH_TYPES_SPLIT if c in self]
Expand All @@ -493,7 +489,6 @@ class _InfosForRepr:
fit_n_samples=fit_n_samples,
fit_n_components=fit_n_components,
fit_n_pca_components=fit_n_pca_components,
fit_explained_variance=fit_explained_variance,
ch_types=ch_types,
excludes=excludes
)
Expand All @@ -511,8 +506,6 @@ def __repr__(self):
f' (fit in {infos.fit_n_iter} iterations on '
f'{infos.fit_n_samples} samples), '
f'{infos.fit_n_components} ICA components '
f'explaining {round(infos.fit_explained_variance * 100, 1)} % '
f'of variance '
f'({infos.fit_n_pca_components} PCA components available), '
f'channel types: {", ".join(infos.ch_types)}, '
f'{len(infos.excludes) or "no"} sources marked for exclusion'
Expand All @@ -531,7 +524,6 @@ def _repr_html_(self):
n_samples=infos.fit_n_samples,
n_components=infos.fit_n_components,
n_pca_components=infos.fit_n_pca_components,
explained_variance=infos.fit_explained_variance,
ch_types=infos.ch_types,
excludes=infos.excludes
)
Expand Down Expand Up @@ -962,6 +954,141 @@ def get_components(self):
return np.dot(self.mixing_matrix_[:, :self.n_components_].T,
self.pca_components_[:self.n_components_]).T

def get_explained_variance_ratio(
self, inst, *, components=None, ch_type=None
):
"""Get the proportion of data variance explained by ICA components.

Parameters
----------
inst : mne.io.BaseRaw | mne.BaseEpochs | mne.Evoked
The uncleaned data.
components : array-like of int | int | None
The component(s) for which to do the calculation. If more than one
component is specified, explained variance will be calculated
jointly across all supplied components. If ``None`` (default), uses
all available components.
ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | array-like of str | None
The channel type(s) to include in the calculation. If ``None``, all
available channel types will be used.

Returns
-------
dict (str, float)
The fraction of variance in ``inst`` that can be explained by the
ICA components, calculated separately for each channel type.
Dictionary keys are the channel types, and corresponding explained
variance ratios are the values.

Notes
-----
A value similar to EEGLAB's ``pvaf`` (percent variance accounted for)
will be calculated for the specified component(s).

Since ICA components cannot be assumed to be aligned orthogonally, the
sum of the proportion of variance explained by all components may not
be equal to 1. In certain situations, the proportion of variance
explained by a component may even be negative.

.. versionadded:: 1.2
""" # noqa: E501
if self.current_fit == 'unfitted':
raise ValueError('ICA must be fitted first.')

_validate_type(
item=inst, types=(BaseRaw, BaseEpochs, Evoked),
item_name='inst'
)
_validate_type(
item=components, types=(None, 'int-like', Sequence, np.ndarray),
item_name='components', type_name='int, array-like of int, or None'
)
if isinstance(components, (Sequence, np.ndarray)):
for item in components:
_validate_type(
item=item, types='int-like',
item_name='Elements of "components"'
)

_validate_type(
item=ch_type, types=(Sequence, np.ndarray, str, None),
item_name='ch_type', type_name='str, array-like of str, or None'
)
if isinstance(ch_type, str):
ch_types = [ch_type]
elif ch_type is None:
ch_types = inst.get_channel_types(unique=True, only_data_chs=True)
else:
assert isinstance(ch_type, (Sequence, np.ndarray))
ch_types = ch_type

assert len(ch_types) >= 1
allowed_ch_types = ('mag', 'grad', 'planar1', 'planar2', 'eeg')
for ch_type in ch_types:
if ch_type not in allowed_ch_types:
raise ValueError(
f'You requested operation on the channel type '
f'"{ch_type}", but only the following channel types are '
f'supported: {", ".join(allowed_ch_types)}'
)
del ch_type

# Input data validation ends here
if components is None:
components = range(self.n_components_)

explained_var_ratios = [
self._get_explained_variance_ratio_one_ch_type(
inst=inst, components=components, ch_type=ch_type
) for ch_type in ch_types
]
result = dict(zip(ch_types, explained_var_ratios))
return result

def _get_explained_variance_ratio_one_ch_type(
self, *, inst, components, ch_type
):
# The algorithm implemented below should be equivalent to
# https://sccn.ucsd.edu/pipermail/eeglablist/2014/009134.html
#
# Reconstruct ("back-project") the data using only the specified ICA
# components. Don't make use of potential "spare" PCA components in
# this process – we're only interested in the contribution of the ICA
# components!
kwargs = dict(
inst=inst.copy(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the copy really needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, yes, because we use ica.apply() below and this works in place. We could do without the copy if instead of using ica.apply(), we'd "manually" do the matrix multiplication. But ica.apply() does quite a few additional things and I'm worried I'd forget something important 😅

I trust that if we discover the copy() here causes issues, we can optimize things later on.

include=[components],
exclude=[],
n_pca_components=0,
verbose=False,
)
if (
isinstance(inst, (BaseEpochs, Evoked)) and
inst.baseline is not None
):
# Don't warn if data was baseline-corrected.
with warnings.catch_warnings():
warnings.filterwarnings(
action='ignore',
message='The data.*was baseline-corrected',
category=RuntimeWarning
)
inst_recon = self.apply(**kwargs)
else:
inst_recon = self.apply(**kwargs)

data_recon = inst_recon.get_data(picks=ch_type)
data_orig = inst.get_data(picks=ch_type)
data_diff = data_orig - data_recon

# To estimate the data variance, we first compute the variance across
# channels at each time point, and then we average these variances.
mean_var_diff = data_diff.var(axis=0).mean()
mean_var_orig = data_orig.var(axis=0).mean()
Comment on lines +1084 to +1087
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I worry a bit about this approach because it is going to be very sensitive to channel scaling. If you use a prewhitener in MNE MEG+EEG can be processed jointly (I think?), but this var calculation will make EEG the only thing that matters.

I think you can fix this by applying the pre-whitener to the data_diff and the data_orig before this calculation, then everything might "just work". But @agramfort would know better than me...

Copy link
Member Author

@hoechenberger hoechenberger Sep 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent point.

How about we make our lives easy and simply apply this algorithm here to each channel type separately, and return a dict with the results? One like we use for reject and flat when creating epochs?

ICA property plots also only show one channel type at a time, so this part would be consistent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think we have to do this. As it makes no sense compute variance with mixed channel types.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now implemented.


var_explained_ratio = 1 - mean_var_diff / mean_var_orig
return var_explained_ratio

def get_sources(self, inst, add_channels=None, start=None, stop=None):
"""Estimate sources given the unmixing matrix.

Expand Down Expand Up @@ -2247,6 +2374,8 @@ def _find_sources(sources, target, score_func):
def _ica_explained_variance(ica, inst, normalize=False):
"""Check variance accounted for by each component in supplied data.

This function is only used for sorting the components.

Parameters
----------
ica : ICA
Expand Down
91 changes: 89 additions & 2 deletions mne/preprocessing/tests/test_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,6 @@ def test_ica_core(method, n_components, noise_cov, n_pca_components,
assert 'raw data decomposition' in repr_
assert f'{ica.n_components_} ICA components' in repr_
assert 'Available PCA components' in repr_html_
assert 'Explained variance' in repr_html_

assert ('mag' in ica) # should now work without error

# test re-fit
Expand Down Expand Up @@ -949,6 +947,95 @@ def f(x, y):
ica.fit(raw_, picks=picks, reject_by_annotation=True)


@requires_sklearn
def test_get_explained_variance_ratio(tmp_path, short_raw_epochs):
"""Test ICA.get_explained_variance_ratio()."""
raw, epochs, _ = short_raw_epochs
ica = ICA(max_iter=1)

# Unfitted ICA should raise an exception
with pytest.raises(ValueError, match='ICA must be fitted first'):
ica.get_explained_variance_ratio(epochs)

with pytest.warns(RuntimeWarning, match='were baseline-corrected'):
ica.fit(epochs)

# components = int, ch_type = None
explained_var_comp_0 = ica.get_explained_variance_ratio(
epochs, components=0
)
# components = int, ch_type = str
explained_var_comp_0_eeg = ica.get_explained_variance_ratio(
epochs, components=0, ch_type='eeg'
)
# components = int, ch_type = list of str
explained_var_comp_0_eeg_mag = ica.get_explained_variance_ratio(
epochs, components=0, ch_type=['eeg', 'mag']
)
# components = list of int, single element, ch_type = None
explained_var_comp_1 = ica.get_explained_variance_ratio(
epochs, components=[1]
)
# components = list of int, multiple elements, ch_type = None
explained_var_comps_01 = ica.get_explained_variance_ratio(
epochs, components=[0, 1]
)
# components = None, i.e., all components, ch_type = None
explained_var_comps_all = ica.get_explained_variance_ratio(
epochs, components=None
)

assert 'grad' in explained_var_comp_0
assert 'mag' in explained_var_comp_0
assert 'eeg' in explained_var_comp_0

assert len(explained_var_comp_0_eeg) == 1
assert 'eeg' in explained_var_comp_0_eeg

assert 'mag' in explained_var_comp_0_eeg_mag
assert 'eeg' in explained_var_comp_0_eeg_mag
assert 'grad' not in explained_var_comp_0_eeg_mag

assert round(explained_var_comp_0['grad'], 4) == 0.1784
assert round(explained_var_comp_0['mag'], 4) == 0.0259
assert round(explained_var_comp_0['eeg'], 4) == 0.0229

assert np.isclose(
explained_var_comp_0['eeg'],
explained_var_comp_0_eeg['eeg']
)
assert np.isclose(
explained_var_comp_0['mag'],
explained_var_comp_0_eeg_mag['mag']
)
assert np.isclose(
explained_var_comp_0['eeg'],
explained_var_comp_0_eeg_mag['eeg']
)

assert round(explained_var_comp_1['eeg'], 4) == 0.0231
assert round(explained_var_comps_01['eeg'], 4) == 0.0459
assert (
explained_var_comps_all['grad'] ==
explained_var_comps_all['mag'] ==
explained_var_comps_all['eeg'] ==
1
)

# Test Raw
ica.get_explained_variance_ratio(raw)
# Test Evoked
evoked = epochs.average()
ica.get_explained_variance_ratio(evoked)
# Test Evoked without baseline correction
evoked.baseline = None
ica.get_explained_variance_ratio(evoked)

# Test invalid ch_type
with pytest.raises(ValueError, match='only the following channel types'):
ica.get_explained_variance_ratio(raw, ch_type='foobar')


@requires_sklearn
@pytest.mark.slowtest
@pytest.mark.parametrize('method, cov', [
Expand Down
40 changes: 37 additions & 3 deletions tutorials/preprocessing/40_artifact_correction_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@
filt_raw = raw.copy().filter(l_freq=1., h_freq=None)

# %%
# Fitting and plotting the ICA solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Fitting ICA
# ~~~~~~~~~~~
#
# .. admonition:: Ignoring the time domain
# :class: sidebar hint
Expand Down Expand Up @@ -262,8 +262,42 @@
# speed-up) and ``reject`` (for providing a rejection dictionary for maximum
# acceptable peak-to-peak amplitudes for each channel type, just like we used
# when creating epoched data in the :ref:`tut-overview` tutorial).
#

# %%
# Looking at the ICA solution
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Now we can examine the ICs to see what they captured.
#
# Using :meth:`~mne.preprocessing.ICA.get_explained_variance_ratio`, we can
# retrieve the fraction of variance in the original data that is explained by
# our ICA components in the form of a dictionary:

explained_var_ratio = ica.get_explained_variance_ratio(filt_raw)
for channel_type, ratio in explained_var_ratio.items():
print(
f'Fraction of {channel_type} variance explained by all components: '
f'{ratio}'
)

# %%
# The values were calculated for all ICA components jointly, but separately for
# each channel type (here: magnetometers and EEG).
#
# We can also explicitly request for which component(s) and channel type(s) to
# perform the computation:
explained_var_ratio = ica.get_explained_variance_ratio(
filt_raw,
components=[0],
ch_type='eeg'
)
# This time, print as percentage.
ratio_percent = round(100 * explained_var_ratio['eeg'])
print(
f'Fraction of variance in EEG signal explained by first component: '
f'{ratio_percent}%'
)

# %%
# `~mne.preprocessing.ICA.plot_sources` will show the time series of the
# ICs. Note that in our call to `~mne.preprocessing.ICA.plot_sources` we
# can use the original, unfiltered `~mne.io.Raw` object. A helpful tip is that
Expand Down