In [None]:
import numpy as np
from matplotlib import pyplot as plt

from mushroom_hyperscanning.data import load_eeg

In [None]:
# Which recording to preprocess: subject label and ceremony identifier passed to
# `load_eeg`, plus the root directory of the BIDS dataset they are read from.
SUBJECT, CEREMONY = "03", "ceremony1"
BIDS_ROOT = "../data/003_sanitization"

In [None]:
# Load the selected recording and keep only a short window for quick iteration.
CROP_START_MIN, CROP_STOP_MIN = 35, 40  # window bounds, in minutes from the start of the recording
raw = load_eeg(SUBJECT, CEREMONY, root=BIDS_ROOT, preload=True)
# Keeps (CROP_STOP_MIN - CROP_START_MIN) minutes of data -- 5 minutes with the values above.
raw.crop(tmin=60 * CROP_START_MIN, tmax=60 * CROP_STOP_MIN)
raw.info

In [None]:
# Quick visual overview of every channel over the cropped window.
# `raw.get_data()` is (n_channels, n_times), so transpose to draw one line per channel,
# and plot against `raw.times` so the x-axis is in seconds rather than sample index.
fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(raw.times, raw.get_data().T)
ax.set(title="Raw signal, all channels", xlabel="Time (s)", ylabel="Amplitude (SI units)")
plt.show()

In [None]:
# Power spectrum of the cropped, not-yet-filtered recording, for comparison with the
# spectrum plotted after filtering.
spectrum_raw = raw.compute_psd()
spectrum_raw.plot()

In [None]:
# Filter raw data: band-pass, then notch out power-line noise and its harmonics
# below Nyquist. The ICA below is fitted on epochs cut from this band-passed signal.
L_FREQ, H_FREQ = 1.0, 90.0  # band-pass edges (Hz)
LINE_FREQ = 60.0  # power-line frequency (Hz)
notch_freqs = np.arange(LINE_FREQ, raw.info["sfreq"] / 2, LINE_FREQ)
# `raw.copy()` makes the only copy needed: `filter` and `notch_filter` modify that
# copy in place and return it, so chaining them avoids duplicating the data twice.
raw_filtered = raw.copy().filter(L_FREQ, H_FREQ).notch_filter(notch_freqs)
raw_filtered.compute_psd().plot()

In [None]:
# Segment signals in 1 s epochs
from mne import make_fixed_length_epochs

epochs = make_fixed_length_epochs(raw_filtered, duration=1.0, preload=True)

# Baseline AutoReject run with its default n_interpolate / consensus search.
# Its results get their own names so they are not silently overwritten by the tuned
# run in the next cell (which defines the `ar` / `arlog` used downstream).
from autoreject import AutoReject

ar_default = AutoReject(
    n_jobs=-1,
    random_state=69,  # fixed seed (same value as the ICA fit) so the run is reproducible
)
ar_default.fit(epochs)
arlog_default = ar_default.get_reject_log(epochs)

# plot autoreject results
arlog_default.plot("horizontal")

In [None]:
# Segment signals in 1 s epochs
from mne import make_fixed_length_epochs

epochs = make_fixed_length_epochs(raw_filtered, duration=1.0, preload=True)

# Run a first autoreject before ICA. The candidate grid is pinned to a single
# (n_interpolate=4, consensus=0.8) setting instead of AutoReject's default search.
# `ar` and `arlog` from this cell are the ones used by the cells below.
from autoreject import AutoReject

ar = AutoReject(
    n_jobs=-1,
    n_interpolate=[4],
    consensus=[0.8],
    random_state=69,  # fixed seed (same value as the ICA fit) so the run is reproducible
)
ar.fit(epochs)
arlog = ar.get_reject_log(epochs)

# plot autoreject results
arlog.plot("horizontal")

In [None]:
# Apply the fitted AutoReject to the epochs; only the cleaned Epochs object is kept
# (no reject log is requested here).
cleaned = ar.transform(epochs, return_log=False)

In [None]:
# Eyeball a handful of cleaned epochs. Drawing every 1 s epoch would emit hundreds of
# 20x5 figures (one per second of data) and bloat both memory and the saved notebook.
N_EPOCHS_TO_PLOT = 10  # set to None to plot every epoch
epoch_data = cleaned.get_data()  # (n_epochs, n_channels, n_times)
for idx, ep in enumerate(epoch_data[:N_EPOCHS_TO_PLOT]):
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.plot(ep.T)  # one line per channel
    ax.set(title=f"Cleaned epoch {idx}", xlabel="Sample", ylabel="Amplitude (SI units)")
    plt.show()
    plt.close(fig)  # release the figure even on non-inline backends

In [None]:
# Fit ICA without bad epochs: only the epochs the AutoReject log did not flag are
# used for the decomposition.
from mne.preprocessing import ICA

kept_epochs = epochs[~arlog.bad_epochs]
ica = ICA(n_components=15, max_iter="auto", random_state=69)
ica.fit(kept_epochs)

In [None]:
from mne.preprocessing import create_ecg_epochs, create_eog_epochs

## Find ECG components
ecg_threshold = 0.50  # CTPS threshold passed to find_bads_ecg
ecg_epochs = create_ecg_epochs(raw_filtered, ch_name="ECG")
ecg_inds, ecg_scores = ica.find_bads_ecg(ecg_epochs, ch_name="ECG", method="ctps", threshold=ecg_threshold)
if not ecg_inds:
    # Nothing crossed the threshold: deliberately fall back to the single component
    # with the largest absolute score so there is always an ECG candidate to inspect.
    ecg_inds = [int(np.argmax(np.abs(ecg_scores)))]
    print(
        f"No ECG component above threshold {ecg_threshold}; "
        f"falling back to best-scoring component {ecg_inds[0]}."
    )

# Plot
ica.plot_scores(ecg_scores, exclude=ecg_inds)
ica.plot_properties(ecg_epochs, picks=ecg_inds)

In [None]:
## Find EOG components (frontal channels Fp1/Fp2 used as EOG references)
eog_threshold = 2  # outlier threshold passed to find_bads_eog (in units of its `measure`)
EOG_FALLBACK_TO_MAX = False  # True: if nothing crosses the threshold, take the top-scoring component
eog_epochs = create_eog_epochs(raw_filtered, ch_name=["Fp1", "Fp2"])
eog_inds, eog_scores = ica.find_bads_eog(eog_epochs, ch_name=["Fp1", "Fp2"], threshold=eog_threshold)

if EOG_FALLBACK_TO_MAX and not eog_inds:
    # With several reference channels eog_scores holds one score array per channel, so
    # take each component's strongest absolute score across channels before the argmax
    # (a plain argmax over the stacked scores would return a flattened index).
    eog_inds = [int(np.abs(np.atleast_2d(eog_scores)).max(axis=0).argmax())]

ica.plot_scores(eog_scores, exclude=eog_inds)
if eog_inds:
    ica.plot_properties(eog_epochs, picks=eog_inds)
else:
    # plot_properties with an empty pick list has nothing to show (and MNE may raise on it).
    print(f"No EOG component above threshold {eog_threshold}; skipping plot_properties.")

In [None]:
# Reconstruct raw without artifact components: remove every ICA component flagged
# as ECG- or EOG-related above (de-duplicated, in ascending order).
# Set `ica.exclude = []` instead to reconstruct the signal unchanged for comparison.
ica.exclude = sorted(set(ecg_inds) | set(eog_inds))
raw_clean = raw_filtered.copy()
ica.apply(raw_clean)  # modifies raw_clean in place

In [None]:
# Resegment the ICA-cleaned signal into 1 s epochs and run a second AutoReject pass
# with the default parameter search, then plot its reject log.
epochs_clean = make_fixed_length_epochs(raw_clean, duration=1.0, preload=True)
ar_clean = AutoReject(
    n_jobs=-1,
    random_state=69,  # fixed seed (same value as the ICA fit) so the run is reproducible
)
ar_clean.fit(epochs_clean)
arlog_clean = ar_clean.get_reject_log(epochs_clean)
arlog_clean.plot("horizontal")

In [None]:
# Display the epochs cut from the ICA-cleaned signal.
# NOTE(review): ar_clean is only fitted in the cells above (no ar_clean.transform call),
# so these epochs still include the ones its reject log flags as bad.
epochs_clean