Skip to content

Commit

Permalink
[MRG] Add viz of auto-noisy segment detection (#145)
Browse files Browse the repository at this point in the history
* Add viz of auto-noisy segment detection

This is currently non-interactive (purely static viz) until we'vedecided
how to best add interactivity.

* Add diagnostic plots to report

* Require json_tricks

* Formatting

* Require seaborn
  • Loading branch information
hoechenberger committed Jul 6, 2020
1 parent 0950c94 commit 96c3329
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Expand Up @@ -35,7 +35,7 @@ jobs:
name: Install dependencies in conda base environment
command: |
conda update -n base -c defaults conda
pip install numpy scipy pandas matplotlib nibabel coloredlogs python-picard
pip install numpy scipy pandas json_tricks matplotlib seaborn nibabel coloredlogs python-picard
pip install -U scikit-learn
pip install --upgrade https://api.github.com/repos/mne-tools/mne-python/zipball/master
git clone https://github.com/mne-tools/mne-bids.git --depth 1
Expand Down
48 changes: 32 additions & 16 deletions 01-import_and_maxfilter.py
Expand Up @@ -37,6 +37,7 @@

import numpy as np
import pandas as pd
import json_tricks

import mne
from mne.preprocessing import find_bad_channels_maxwell
Expand Down Expand Up @@ -110,17 +111,34 @@ def find_bad_channels(raw, subject, session, task, run):

logger.info(gen_log_message(message=msg, step=1, subject=subject,
session=session))

deriv_path = config.get_subject_deriv_path(subject=subject,
session=session,
kind=config.get_kind())

bids_basename = make_bids_basename(subject=subject,
session=session,
task=config.get_task(),
acquisition=config.acq,
run=run,
processing=config.proc,
recording=config.rec,
space=config.space,
prefix=deriv_path)

raw_lp_filtered_for_maxwell = (raw.copy()
.filter(l_freq=None,
h_freq=40))
auto_noisy_chs, auto_flat_chs = find_bad_channels_maxwell(
auto_noisy_chs, auto_flat_chs, auto_scores = find_bad_channels_maxwell(
raw=raw_lp_filtered_for_maxwell,
calibration=config.mf_cal_fname,
cross_talk=config.mf_ctc_fname)
cross_talk=config.mf_ctc_fname,
return_scores=True)
del raw_lp_filtered_for_maxwell

preexisting_bads = raw.info['bads'].copy()
bads = preexisting_bads.copy()

if config.find_flat_channels_meg:
msg = f'Found {len(auto_flat_chs)} flat channels.'
logger.info(gen_log_message(message=msg, step=1,
Expand All @@ -138,21 +156,19 @@ def find_bad_channels(raw, subject, session, task, run):
logger.info(gen_log_message(message=msg, step=1,
subject=subject, session=session))

# Write the bad channels to disk.
deriv_path = config.get_subject_deriv_path(subject=subject,
session=session,
kind=config.get_kind())
if config.find_noisy_channels_meg:
auto_scores_fname = bids_basename.copy().update(suffix='scores.json')
with open(auto_scores_fname, 'w') as f:
json_tricks.dump(auto_scores, fp=f, allow_nan=True,
sort_keys=False)

bads_tsv_fname = make_bids_basename(subject=subject,
session=session,
task=config.get_task(),
acquisition=config.acq,
run=run,
processing=config.proc,
recording=config.rec,
space=config.space,
prefix=deriv_path,
suffix='bad_channels.tsv')
if config.interactive:
import matplotlib.pyplot as plt
config.plot_auto_scores(auto_scores)
plt.show()

# Write the bad channels to disk.
bads_tsv_fname = bids_basename.copy().update(suffix='bad_channels.tsv')
bads_for_tsv = []
reasons = []

Expand Down
43 changes: 43 additions & 0 deletions 99-make_reports.py
Expand Up @@ -76,6 +76,42 @@ def plot_er_psd(subject, session):
return fig


def plot_auto_scores(subject, session):
"""Plot automated bad channel detection scores.
"""
import json_tricks

deriv_path = config.get_subject_deriv_path(subject=subject,
session=session,
kind=config.get_kind())

fname_scores = make_bids_basename(subject=subject,
session=session,
task=config.get_task(),
acquisition=config.acq,
run=None,
processing=config.proc,
recording=config.rec,
space=config.space,
prefix=deriv_path,
suffix='scores.json')

all_figs = []
all_captions = []
for run in config.get_runs():
with open(fname_scores.update(run=run), 'r') as f:
auto_scores = json_tricks.load(f)

figs = config.plot_auto_scores(auto_scores)
all_figs.extend(figs)

# Could be more than 1 fig, e.g. "grad" and "mag"
captions = [f'Run {run}'] * len(figs)
all_captions.extend(captions)

return all_figs, all_captions


def run_report(subject, session=None):
deriv_path = config.get_subject_deriv_path(subject=subject,
session=session,
Expand All @@ -102,6 +138,13 @@ def run_report(subject, session=None):

rep.parse_folder(deriv_path, verbose=True)

# Visualize automated noisy channel detection.
if config.find_noisy_channels_meg:
figs, captions = plot_auto_scores(subject=subject, session=session)
rep.add_figs_to_section(figs=figs,
captions=captions,
section='Data Quality')

# Visualize events.
events_fig = plot_events(subject=subject, session=session,
deriv_path=deriv_path)
Expand Down
58 changes: 58 additions & 0 deletions config.py
Expand Up @@ -1049,3 +1049,61 @@ def wrapper(*args, **kwargs):
logger.critical(message)
return wrapper
return failsafe_run_decorator


def plot_auto_scores(auto_scores):
"""Plot scores of automated bad channel detection.
"""
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

if ch_types == ['meg']:
ch_types_ = ['grad', 'mag']
else:
ch_types_ = ch_types

figs = []
for ch_type in ch_types_:
# Only select the data for mag or grad channels.
ch_subset = auto_scores['ch_types'] == ch_type
ch_names = auto_scores['ch_names'][ch_subset]
scores = auto_scores['scores_noisy'][ch_subset]
limits = auto_scores['limits_noisy'][ch_subset]
bins = auto_scores['bins'] # The the windows that were evaluated.

# We will label each segment by its start and stop time, with up to 3
# digits before and 3 digits after the decimal place (1 ms precision).
bin_labels = [f'{start:3.3f}{stop:3.3f}'
for start, stop in bins]

# We store the data in a Pandas DataFrame. The seaborn heatmap function
# we will call below will then be able to automatically assign the correct
# labels to all axes.
data_to_plot = pd.DataFrame(data=scores,
columns=pd.Index(bin_labels, name='Time (s)'),
index=pd.Index(ch_names, name='Channel'))

# First, plot the "raw" scores.
fig, ax = plt.subplots(1, 2, figsize=(12, 8))
fig.suptitle(f'Automated noisy channel detection: {ch_type}',
fontsize=16, fontweight='bold')
sns.heatmap(data=data_to_plot, cmap='Reds', cbar_kws=dict(label='Score'),
ax=ax[0])
[ax[0].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray')
for x in range(1, len(bins))]
ax[0].set_title('All Scores', fontweight='bold')

# Now, adjust the color range to highlight segments that exceeded the limit.
sns.heatmap(data=data_to_plot,
vmin=np.nanmin(limits), # bads in input data have NaN limits
cmap='Reds', cbar_kws=dict(label='Score'), ax=ax[1])
[ax[1].axvline(x, ls='dashed', lw=0.25, dashes=(25, 15), color='gray')
for x in range(1, len(bins))]
ax[1].set_title('Scores > Limit', fontweight='bold')

# The figure title should not overlap with the subplots.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
figs.append(fig)

return figs

0 comments on commit 96c3329

Please sign in to comment.