In [None]:
from numba import jit, njit, prange
import numpy as np

from numpy.linalg import lstsq
import mne

@njit(parallel=True)
def create_time_lag_signals(Y, sample_freq, min_lag, max_lag):
    """Stack time-shifted copies of every row of ``Y`` for a range of lags.

    Parameters
    ----------
    Y : 2-D array, shape (n_channels, n_times)
        Signals to shift; the last axis is treated as time.
    sample_freq : float
        Sampling frequency used to convert lags into sample offsets.
    min_lag, max_lag : float
        First and last lag (inclusive), in the time unit that matches
        ``sample_freq``. One lagged copy is produced per integer sample
        offset between ``round(min_lag * sample_freq)`` and
        ``round(max_lag * sample_freq)``.

    Returns
    -------
    2-D float array, shape (n_lags * n_channels, n_times)
        Rows are ordered lag-major: all channels for the first lag, then all
        channels for the next lag, and so on.

    Raises
    ------
    ValueError
        If ``max_lag`` maps to a smaller sample offset than ``min_lag``.

    Notes
    -----
    A positive offset delays the signal (``out[t] = Y[t - offset]``); a
    negative offset advances it. Samples shifted in from outside the record
    are filled with the channel's own first (or last) sample.
    """
    n_channels, n_times = Y.shape

    # Work in whole samples from the start so rounding happens exactly once
    # and float truncation cannot drop or duplicate a lag.
    first_offset = int(round(min_lag * sample_freq))
    last_offset = int(round(max_lag * sample_freq))
    if last_offset < first_offset:
        raise ValueError("max_lag must not be smaller than min_lag")
    n_lags = last_offset - first_offset + 1

    time_lag_signals = np.zeros((n_lags, n_channels, n_times))
    for l in prange(n_lags):
        # Clamp so a lag longer than the record degrades to a constant row
        # instead of producing out-of-range slices.
        shift = min(max(first_offset + l, -n_times), n_times)
        for c in range(n_channels):
            # Shift along the time axis of each channel separately; rolling
            # the whole 2-D array would move samples across channels.
            if shift > 0:
                time_lag_signals[l, c, :shift] = Y[c, 0]
                time_lag_signals[l, c, shift:] = Y[c, :n_times - shift]
            elif shift < 0:
                time_lag_signals[l, c, :n_times + shift] = Y[c, -shift:]
                time_lag_signals[l, c, n_times + shift:] = Y[c, n_times - 1]
            else:
                time_lag_signals[l, c, :] = Y[c, :]

    return time_lag_signals.reshape(n_lags * n_channels, n_times)


@njit(parallel=True)
def compute_regression(X, Y):
    """Regress every column of ``Y`` on ``X`` and return the residuals.

    Both ``X`` and each column of ``Y`` are mean-centred before the
    least-squares fit, so the returned residuals are zero-mean.
    (The original comment labels this the "Sk implementation" -- presumably
    it mirrors scikit-learn's centring; confirm with the author.)

    Parameters
    ----------
    X : 2-D array, shape (n_samples, n_regressors)
        Regressor matrix. It is not modified.
    Y : 2-D array, shape (n_samples, n_channels)
        Signals to clean, one per column. It is not modified.

    Returns
    -------
    2-D array, shape (n_channels, n_samples)
        Residual of each channel after removing its least-squares projection
        onto the centred regressors.
    """
    n_regressors = X.shape[1]
    n_channels = Y.shape[1]

    # Centre into a fresh array so the caller's matrix is left untouched.
    # Column means are taken one column at a time with np.mean on 1-D slices,
    # so the result does not depend on how axis-wise reductions are handled
    # inside numba-compiled code.
    X_centered = np.empty_like(X)
    for j in prange(n_regressors):
        X_centered[:, j] = X[:, j] - np.mean(X[:, j])

    residuals = np.zeros_like(Y)
    for c in prange(n_channels):
        # The subtraction allocates a new vector; Y itself is only read.
        y = Y[:, c] - np.mean(Y[:, c])
        coef_, resids, rank, s = lstsq(X_centered, y, rcond=-1)
        residuals[:, c] = y - np.dot(X_centered, coef_)
    return residuals.T


class CWLCorrection():
    """Remove activity explained by reference (CWL) channels from other channels.

    The reference channels are expanded into time-lagged copies with
    ``create_time_lag_signals`` and regressed out of the picked channels;
    the picked channels are replaced by the regression residuals.
    """

    def __init__(self, info, cwl_ch_names, picks=['eeg', 'ecg'], min_lag=-0.1, max_lag=0.1):
        """Resolve channel indices and store the lag range.

        Parameters
        ----------
        info : mne.Info
            Measurement info used to resolve channel picks and to read the
            sampling frequency (``info['sfreq']``).
        cwl_ch_names : list of str
            Names of the reference channels used as regressors.
        picks : list of str
            Channels (names or types) to be corrected.
        min_lag, max_lag : float
            Lag range passed to ``create_time_lag_signals`` together with
            the sampling frequency.
        """
        # NOTE(review): `_picks_to_idx` is a private MNE helper and may move
        # or change between MNE releases.
        self.pick_indices = mne._fiff.pick._picks_to_idx(info, picks=picks)
        self.cwl_indices = mne._fiff.pick._picks_to_idx(info, picks=cwl_ch_names)
        self.sfreq = info['sfreq']
        self.min_lag = min_lag
        self.max_lag = max_lag

    def correct_epochs(self, epochs, verbose=False):
        """Correct every epoch independently and return new epochs.

        The returned ``EpochsArray`` carries over the events, event ids and
        start time of ``epochs`` so event-based selection and the time axis
        remain valid after correction.

        NOTE(review): depending on the MNE version ``epochs.get_data()`` may
        return a view of the underlying data, in which case ``epochs`` itself
        is modified as well; pass a copy to keep the original intact.
        """
        data = epochs.get_data()
        for d, dat in enumerate(data):
            data[d] = self.correct_data(dat)
        corrected = mne.EpochsArray(
            data,
            epochs.info,
            events=epochs.events,
            tmin=epochs.tmin,
            event_id=epochs.event_id,
            # Ids whose epochs were all dropped must not abort the rebuild.
            on_missing='ignore',
            verbose=verbose,
        )
        return corrected

    def correct_raw(self, raw, verbose=False):
        """Correct a continuous recording and return it as a new ``RawArray``.

        ``first_samp`` of the input is preserved so sample-based positions
        refer to the same samples as in ``raw``.
        """
        data = self.correct_data(raw.get_data())
        return mne.io.RawArray(data, raw.info, first_samp=raw.first_samp, verbose=verbose)

    def correct_data(self, data):
        """Regress the lagged reference channels out of the picked channels.

        Parameters
        ----------
        data : 2-D array, shape (n_channels, n_times)
            Channel data indexed consistently with the ``info`` given to the
            constructor. The picked rows are overwritten in place.

        Returns
        -------
        The same ``data`` array, with the picked rows replaced by residuals.
        """
        # Fancy indexing yields copies, so building the regressors and
        # targets does not touch `data` until the final assignment.
        X = data[self.cwl_indices]
        Y = data[self.pick_indices]

        X = create_time_lag_signals(X, self.sfreq, self.min_lag, self.max_lag)
        residuals = compute_regression(X.T, Y.T)

        data[self.pick_indices] = residuals
        return data

    def correct_data_sm(self, data, verbose=False):
        """Variant of ``correct_data`` that fits each channel with statsmodels OLS.

        Unlike ``correct_data`` no centring is applied and no constant term
        is added to the regressors. ``verbose`` is accepted but unused.
        The picked rows of ``data`` are overwritten in place and ``data`` is
        returned.
        """
        # Imported lazily so statsmodels is only required for this variant.
        import statsmodels.api as sm

        X = data[self.cwl_indices]
        X = create_time_lag_signals(X, self.sfreq, self.min_lag, self.max_lag)

        for pick in self.pick_indices:
            y = data[pick, :]
            results = sm.OLS(y.T, X.T).fit()
            data[pick] = results.resid
        return data

In [None]:
# NOTE(review): absolute, machine-specific paths; consider a configurable data directory.
# Raw string: the backslashes are path separators, not escape sequences.
template_path = r"D:\Soraya\OnlinePreproc\Analyzer_GA_P02_eyes_closed_mrion-edf.edf"
raw_template = mne.io.read_raw(template_path)

clean_raw_path = r"D:\Soraya\OnlinePreproc\Analyzer_GA_PA_CWL_P02_eyes_closed_mrion-edf.edf"
clean_raw = mne.io.read_raw(clean_raw_path)

raw = mne.io.read_raw(r"D:\Soraya\OnlinePreproc\P02_eyes_closed_mrion_RecView_PA_GA_raw.fif").crop(10, 250)

# Rename channels positionally after the template recording. Pairing with zip
# stops at the shorter list instead of raising if the channel counts differ.
rename_dict = dict(zip(raw.ch_names, raw_template.ch_names))
raw.rename_channels(rename_dict)
raw.set_channel_types({'CWL1': 'misc', 'CWL2': 'misc', 'CWL3': 'misc', 'CWL4': 'misc', 'ECG': 'ecg'})

In [None]:
# Browse the EEG channels of `clean_raw` (loaded from the "..._PA_CWL_..." file;
# presumably the offline-corrected reference -- confirm).
clean_raw.plot(picks='eeg')

In [None]:
# Corrector that regresses the four CWL channels, lagged over -0.2..0.2,
# out of the default picks.
cwl = CWLCorrection(
    raw.info,
    cwl_ch_names=['CWL1', 'CWL2', 'CWL3', 'CWL4'],
    min_lag=-0.2,
    max_lag=0.2,
)

# Cut the recording into consecutive 5 s epochs; the correction runs on a
# copy so `epochs` stays available for comparison.
epochs = mne.make_fixed_length_epochs(
    raw,
    duration=5,
    preload=True,
)
epochs_corrected = cwl.correct_epochs(epochs.copy())

In [None]:
# Power spectral density of the EEG channels before CWL correction.
epochs.plot_psd(picks='eeg')

In [None]:
# Same 5 s segmentation as used for `raw`, applied to `clean_raw`, so the
# spectra can be compared.
clean_epochs = mne.make_fixed_length_epochs(
    clean_raw,
    duration=5,
    preload=True,
)
clean_epochs.plot_psd(picks='eeg')

In [None]:
# Power spectral density of the EEG channels after CWL correction.
epochs_corrected.plot_psd(picks='eeg')

In [None]:
import time

cwl = CWLCorrection(raw.info, cwl_ch_names=['CWL1', 'CWL2', 'CWL3', 'CWL4'], min_lag=-0.2, max_lag=0.2)

window_size = 5 # seconds
tmin = raw.times[0]
tmax =  raw.times[-1]

# Windows start every (window_size - 1) s, so consecutive windows overlap by
# 1 s and each later window overwrites the overlap written by the previous one.
# NOTE(review): samples after the last window are never written and stay zero.
starts = np.arange(tmin, tmax - window_size, window_size - 1)
ends = starts + window_size
exec_times = []

corrected_data = np.zeros_like(raw.get_data())
for start, end in zip(starts, ends):
    raw_window = raw.copy().crop(tmin=start, tmax=end, include_tmax=False, verbose=False)
    # perf_counter is a monotonic, high-resolution clock meant for timing intervals.
    t_start = time.perf_counter()
    raw_corrected = cwl.correct_raw(raw_window)
    t_end = time.perf_counter()
    exec_times.append(t_end - t_start)
    # Place the window by its first sample and its actual width so the target
    # slice always matches the shape of the corrected data.
    window_data = raw_corrected.get_data()
    first = raw.time_as_index(start, use_rounding=True)[0]
    corrected_data[:, first:first + window_data.shape[1]] = window_data

raw_corrected = mne.io.RawArray(corrected_data, raw.info)
print(np.mean(exec_times) , "s", "+-", np.std(exec_times), "s")

In [None]:
# Replace the sample right after each window start by the mean of its two
# neighbours. The smoothing is done on a copy: RawArray does not copy float64
# input by default, so editing `corrected_data` in place would also change
# the data behind `raw_corrected`.
window_start_idx = raw.time_as_index(starts)
corrected_data_smooth = corrected_data.copy()
corrected_data_smooth[:, window_start_idx + 1] = (
    corrected_data[:, window_start_idx + 2] + corrected_data[:, window_start_idx]
) / 2
raw_corrected_smooth = mne.io.RawArray(corrected_data_smooth, raw.info)