In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
!pip install pyvistaqt
!pip install ipywidgets
!pip install mne
import ipywidgets
import mne
from mne.preprocessing import find_bad_channels_maxwell

### Find bad channels and run Maxwell filtering (SSS)

In [None]:
# Subject / run to process. Folder and file names are zero-padded to two digits,
# e.g. sub=2, run=11 -> <DATA_ROOT>/sacsamp02_s02/run11.fif
sub = 2
run = 11
DATA_ROOT = "/Volumes/PortableSSD/SACSAMP"  # NOTE(review): machine-specific location; adjust per setup

sample_data_folder = f"{DATA_ROOT}/sacsamp{sub:02d}_s{sub:02d}/"
file_name = f"run{run:02d}"
sample_data_raw_file = os.path.join(sample_data_folder, file_name + ".fif")

# Fine calibration file (site-specific info about sensor orientation and calibration)
fine_cal_file = os.path.join(sample_data_folder, "sss_config/sss_cal_3101_160108.dat")
# Crosstalk compensation file (reduces interference between co-located magnetometers and gradiometers)
crosstalk_file = os.path.join(sample_data_folder, "sss_config/ct_sparse.fif")

# Lower MNE's global log level so only errors are printed.
# (mne.utils.use_log_level(...) is a context manager: called on its own, outside a
# `with` block, it has no effect — which is why the previous line did nothing.)
mne.set_log_level("ERROR")
raw = mne.io.read_raw_fif(sample_data_raw_file, verbose="Error", allow_maxshield=True, preload=True)

# raw.crop(tmin=58, tmax=250, include_tmax=False)      # crop bad segment from sub 2 / run 11

# Find bad channels. Start from an empty list so that any bads stored in the file
# do not leak into the automatic detection.
raw.info["bads"] = []
raw_check = raw.copy()
auto_noisy_chs, auto_flat_chs, auto_scores = find_bad_channels_maxwell(
    raw_check,
    cross_talk=crosstalk_file,
    calibration=fine_cal_file,
    h_freq=40,  # low-pass filter cutoff (Hz)
    min_count=5,  # n of segments in which a channel must exceed the limit to be marked bad
    return_scores=True,
    verbose=False,
)
bads = raw.info["bads"] + auto_noisy_chs + auto_flat_chs
print(f"Automatic detection of bad channels: {bads}")
raw.info["bads"] = bads

# raw.info["bads"] += ["MEG0621"]      # manually include bad channels

# Run Maxwell filtering (SSS) with the bad-channel list set above, then save the result
# next to the raw file as <run>_sss.fif.
raw_sss = mne.preprocessing.maxwell_filter(raw, cross_talk=crosstalk_file, calibration=fine_cal_file, verbose=False)
raw_sss.save(os.path.join(sample_data_folder, file_name + "_sss.fif"), fmt="double", overwrite=True, verbose="Error")

In [None]:
# Extract trigger onsets from the STI101 stimulus channel of the (non-SSS) raw recording.
events = mne.find_events(
    raw,
    stim_channel="STI101",
    output="onset",
    consecutive="increasing",
    min_duration=0.001,
    shortest_event=1,
    mask=None,
    uint_cast=False,
    mask_type="and",
    initial_event=False,
    verbose=None,
)
print(events)

In [None]:
%matplotlib tk
raw.plot(duration=20, n_channels=50)

In [None]:
# Plot power spectral density before/after

%matplotlib tk
# inline / tk (interactive)

fig = plt.figure(figsize=(12,3))
ax1 = fig.add_subplot(141)
ax2 = fig.add_subplot(142)
ax3 = fig.add_subplot(143)
ax4 = fig.add_subplot(144)
raw.compute_psd(fmax=60).plot(average=False, picks="data", exclude="bads", amplitude=False, axes=[ax1, ax3])
raw_sss.compute_psd(fmax=60).plot(average=False, picks="data", exclude="bads", amplitude=False, axes=[ax2, ax4])
ax1.set_ylim(0,80)
ax2.set_ylim(0,80)
ax3.set_ylim(0,80)
ax4.set_ylim(0,80)
ax1.set_title("Gradiometers/Before")
ax2.set_title("Gradiometers/After")
ax3.set_title("Magnetometers/Before")
ax4.set_title("Magnetometers/After")

plt.show()

# 50 Hz - power line
# 25-30 Hz - heartbeat

In [None]:
# Inspect noisy channels

%matplotlib inline   
# inline / tk (interactive)

ch_subset = auto_scores["ch_types"] == "grad"
ch_names = auto_scores["ch_names"][ch_subset]
scores = auto_scores["scores_noisy"][ch_subset]
limits = auto_scores["limits_noisy"][ch_subset]
bins = auto_scores["bins"]
bin_labels = [f"{start:3.3f} – {stop:3.3f}" for start, stop in bins]
data_to_plot = pd.DataFrame(data=scores,
    columns=pd.Index(bin_labels, name="Time (s)"),
    index=pd.Index(ch_names, name="Channel"))
fig, ax = plt.subplots(1, 2, figsize=(12, 8), layout="constrained")
fig.suptitle("Automated noisy channel detection: Gradiometers", fontsize=16, fontweight="bold")
sns.heatmap(data=data_to_plot, cmap="Reds", cbar_kws=dict(label="Score"), ax=ax[0])
[
    ax[0].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[0].set_title("All Scores", fontweight="bold")
sns.heatmap(data=data_to_plot,
    vmin=np.nanmin(limits),
    cmap="Reds",
    cbar_kws=dict(label="Score"),
    ax=ax[1])
[
    ax[1].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[1].set_title("Scores > Limit", fontweight="bold")

ch_subset = auto_scores["ch_types"] == "mag"
ch_names = auto_scores["ch_names"][ch_subset]
scores = auto_scores["scores_noisy"][ch_subset]
limits = auto_scores["limits_noisy"][ch_subset]
bins = auto_scores["bins"]
bin_labels = [f"{start:3.3f} – {stop:3.3f}" for start, stop in bins]
data_to_plot = pd.DataFrame(data=scores,
    columns=pd.Index(bin_labels, name="Time (s)"),
    index=pd.Index(ch_names, name="Channel"))
fig, ax = plt.subplots(1, 2, figsize=(12, 8), layout="constrained")
fig.suptitle("Automated noisy channel detection: Magnetometers", fontsize=16, fontweight="bold")
sns.heatmap(data=data_to_plot, cmap="Reds", cbar_kws=dict(label="Score"), ax=ax[0])
[
    ax[0].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[0].set_title("All Scores", fontweight="bold")
sns.heatmap(data=data_to_plot,
    vmin=np.nanmin(limits),
    cmap="Reds",
    cbar_kws=dict(label="Score"),
    ax=ax[1])
[
    ax[1].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[1].set_title("Scores > Limit", fontweight="bold")

plt.show()

In [None]:
# Inspect flat channels

%matplotlib inline   

ch_subset = auto_scores["ch_types"] == "grad"
ch_names = auto_scores["ch_names"][ch_subset]
scores = auto_scores["scores_flat"][ch_subset]
limits = auto_scores["limits_flat"][ch_subset]
bins = auto_scores["bins"]
bin_labels = [f"{start:3.3f} – {stop:3.3f}" for start, stop in bins]
data_to_plot = pd.DataFrame(data=scores,
    columns=pd.Index(bin_labels, name="Time (s)"),
    index=pd.Index(ch_names, name="Channel"))
fig, ax = plt.subplots(1, 2, figsize=(12, 8), layout="constrained")
fig.suptitle("Automated flat channel detection: Gradiometers", fontsize=16, fontweight="bold")
sns.heatmap(data=data_to_plot, cmap="Reds", cbar_kws=dict(label="Score"), ax=ax[0])
[
    ax[0].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[0].set_title("All Scores", fontweight="bold")
sns.heatmap(data=data_to_plot,
    vmax=np.nanmax(limits),
    cmap="Reds",
    cbar_kws=dict(label="Score"),
    ax=ax[1])
[
    ax[1].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[1].set_title("Scores > Limit", fontweight="bold")

ch_subset = auto_scores["ch_types"] == "mag"
ch_names = auto_scores["ch_names"][ch_subset]
scores = auto_scores["scores_flat"][ch_subset]
limits = auto_scores["limits_flat"][ch_subset]
bins = auto_scores["bins"]
bin_labels = [f"{start:3.3f} – {stop:3.3f}" for start, stop in bins]
data_to_plot = pd.DataFrame(data=scores,
    columns=pd.Index(bin_labels, name="Time (s)"),
    index=pd.Index(ch_names, name="Channel"))
fig, ax = plt.subplots(1, 2, figsize=(12, 8), layout="constrained")
fig.suptitle("Automated flat channel detection: Magnetometers", fontsize=16, fontweight="bold")
sns.heatmap(data=data_to_plot, cmap="Reds", cbar_kws=dict(label="Score"), ax=ax[0])
[
    ax[0].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[0].set_title("All Scores", fontweight="bold")
sns.heatmap(data=data_to_plot,
    vmax=np.nanmax(limits),
    cmap="Reds",
    cbar_kws=dict(label="Score"),
    ax=ax[1])
[
    ax[1].axvline(x, ls="dashed", lw=0.25, dashes=(25, 15), color="gray")
    for x in range(1, len(bins))
]
ax[1].set_title("Scores > Limit", fontweight="bold")

plt.show()

In [None]:
# Plot channels with events 

map_dict_act = {12: "BL_ACT", 20: "FP_ACT", 28: "RE_ON", 32: "RE_OFF"}
map_dict_pas = {13: "BL_PAS", 21: "FP_PAS", 29: "RE_ON", 33: "RE_OFF"}
map_dict_fix = {14: "BL_FIX", 22: "FP_FIX", 30: "RE_ON", 34: "RE_OFF"}

mapping = map_dict_pas
events = mne.find_events(raw, stim_channel="STI101", output='onset', consecutive='increasing', min_duration=0.001, shortest_event=1, mask=None, uint_cast=False, mask_type='and', initial_event=False, verbose=None)
annot_from_events = mne.annotations_from_events(
    events=events,
    event_desc=mapping,
    sfreq=raw.info["sfreq"],
    orig_time=raw.info["meas_date"],
)
raw.set_annotations(annot_from_events);

%matplotlib tk
raw.plot(duration=100, n_channels=50)

### Plot BIO/MISC channels

In [None]:
# Reference: MNE channel-type names (usable as `picks=` strings)
# meg  = MEG channels
# stim = Stimulus channels
# eog  = EOG channels
# ecg  = ECG channels
# emg  = EMG channels
# misc = Miscellaneous analog channels
# chpi = Continuous HPI coil channels
# ias  = Internal Active Shielding data
# syst = System status channel information
# bio  = Bio channels

# print(raw.info)
# print(raw.info["ch_names"])

In [None]:
plt.close()
start = 1000
timestamps = 1000
fig, axes = plt.subplots(3, 1, figsize=(16, 12))

plt.sca(axes[0])
plt.plot(np.transpose(raw.get_data(picks="bio")[2,start:start+timestamps]), label='BIO003')
plt.legend(loc=(1.,0.5))
# plt.ylabel("EMG (respiratory)"
plt.sca(axes[1])
plt.plot(np.transpose(raw.get_data(picks="misc")[0,start:start+timestamps]), label='MISC001')
plt.legend(loc=(1.,0.5))
# plt.ylabel("ECG")

plt.sca(axes[2])
plt.plot(np.transpose(raw.get_data(picks="misc")[1,start:start+timestamps]), label='MISC002')
plt.legend(loc=(1.,0.5))
# plt.ylabel("ECG")

In [None]:
plt.close()
start = 0
timestamps = 100000
fig, axes = plt.subplots(3, 1, figsize=(16, 12))

plt.sca(axes[0])
plt.plot(np.transpose(raw.get_data(picks="bio")[0,start:start+timestamps]), label='BIO001')  # HEOG
plt.plot(np.transpose(raw.get_data(picks="bio")[1,start:start+timestamps]), label='BIO002')  # VEOG
plt.legend(loc=(1.,0.5))
plt.ylabel("EOG")

plt.sca(axes[1])
plt.plot(np.transpose(raw.get_data(picks="misc")[2,start:start+timestamps]), label='MISC007')  # V eye pos
plt.plot(np.transpose(raw.get_data(picks="misc")[3,start:start+timestamps]), label='MISC008')  # H eye pos
plt.legend(loc=(1.,0.5))
plt.ylabel("Eye position")

plt.sca(axes[2])
plt.plot(np.transpose(raw.get_data(picks="misc")[4,start:start+timestamps]), label='MISC009')  # Pupil
plt.legend(loc=(1.,0.5))
plt.ylabel("Pupil")