# rMMN: per-subject preprocessing, epoching, and evoked-response plots

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

import numpy as np
import csv
import matplotlib.pyplot as plt
import pandas as pd
import os
import glob
from tqdm import tqdm 
from atpbar import atpbar
from datetime import datetime
import mne
from autoreject import AutoReject
from mne.preprocessing import ICA, corrmap, create_ecg_epochs, create_eog_epochs

from pyprep.find_noisy_channels import NoisyChannels
from mne_icalabel import label_components

import autoreject

In [156]:
def process_events_array(events):
    """Relabel the event-code column (index 2) of an events array, in place.

    Rows are visited in order and the third column is rewritten:

    * a code of ``-1`` becomes ``0`` and restarts the running count;
    * any other code becomes its 1-based position since the most recent
      ``-1`` (rows before the first ``-1`` are counted from the start).

    Parameters
    ----------
    events : numpy.ndarray
        2-D array whose third column holds event codes (MNE events layout).

    Returns
    -------
    numpy.ndarray
        The same ``events`` object, modified in place.
    """
    run_length = 0
    # Iterating a 2-D ndarray yields row views, so assigning row[2] writes
    # straight back into ``events``.
    for row in events:
        if row[2] == -1:
            run_length = 0
        else:
            run_length += 1
        row[2] = run_length
    return events


def raw_to_events(raw):
    """Convert the '-1' / '1' annotations of ``raw`` into relabelled events.

    Annotations named ``'-1'`` and ``'1'`` are mapped to event codes -1 and 1
    by ``mne.events_from_annotations``; the result is then passed through
    ``process_events_array`` (code -1 -> 0, others -> position since last -1).
    If no event is found, a single all-zero int row is used instead so callers
    always receive a non-empty array.
    """
    annotation_codes = {'-1': -1, '1': 1}
    found, _ = mne.events_from_annotations(raw, event_id=annotation_codes)
    if not len(found):
        # NOTE(review): the placeholder's code 0 is not -1, so
        # process_events_array relabels it to 1 — confirm that is intended.
        found = np.zeros((1, 3)).astype(int)
    return process_events_array(found)

In [3]:
# Silence MNE: only messages at level ERROR or above are reported.
# (Other levels that can be swapped in: 'CRITICAL', 'WARNING', 'INFO'.)
mne.set_log_level(verbose='ERROR')



In [239]:
# Per-subject pipeline: keep one channel, band-pass + notch filter, epoch around
# the relabelled events, drop high-amplitude epochs, then save the epochs and a
# plot comparing the evoked response per condition.
sub = '*'      # glob pattern for the subject ID ('*' = all subjects)
picks = 'Cz'   # the single channel kept for analysis (also used by later cells)

DATA_PATTERN = f"../../data/mne_raw_events/sub{sub}-LTP_*-rmmn-raw_phot-events.fif"
EVENTS_SUFFIX = "-rmmn-raw_phot-events.fif"
EPOCHS_DIR = "epochs"
PLOTS_DIR = "plots"
BANDPASS = (0.1, 60)               # Hz: (l_freq, h_freq) for Raw.filter
NOTCH_FREQS = [60, 76, 120]        # Hz, notched with MNE's default widths
NARROW_NOTCH_FREQ = 84             # Hz, notched 1 Hz wide, zero-phase
TMIN, TMAX = -0.1, 0.5             # s, epoch window
BASELINE = (-0.1, 0)               # s, baseline-correction interval
REJECT_CRITERIA = dict(eeg=80e-6)  # peak-to-peak rejection threshold (V)
# Plot label -> event code produced by raw_to_events:
# '0' = event originally coded -1, 'k' = k-th event after the last -1.
CONDITIONS = {
    'oddball': '0',
    't1': '1', 't2': '2', 't3': '3',
    't4': '4', 't5': '5', 't6': '6',
}


def _export_name_from_path(path):
    """Return '<subject>_<session>' for '<subject>-LTP_<session>-rmmn-raw_phot-events.fif'."""
    stem = os.path.basename(path)
    if stem.endswith(EVENTS_SUFFIX):
        stem = stem[:-len(EVENTS_SUFFIX)]
    # partition keeps multi-digit subject IDs intact (fixed-offset slicing did not).
    subject, _, session = stem.partition('-LTP_')
    return f"{subject}_{session}"


def _plot_condition_evokeds(epochs, title, out_path, conditions, picks):
    """Plot the mean (with CI) of each condition present in ``epochs`` and save it."""
    evokeds = {}
    for label, code in conditions.items():
        # Skip codes that never occur in this recording (e.g. no 6th repeat).
        if code not in epochs.event_id:
            continue
        trials = list(epochs[code].iter_evoked())
        if trials:  # every epoch of this code may have been rejected
            evokeds[label] = trials
    if not evokeds:
        return
    figs = mne.viz.plot_compare_evokeds(
        evokeds, combine="mean", picks=picks, title=title, show=False
    )
    figs[0].savefig(out_path)
    plt.close(figs[0])


os.makedirs(EPOCHS_DIR, exist_ok=True)
os.makedirs(PLOTS_DIR, exist_ok=True)

paths = sorted(glob.glob(DATA_PATTERN))
if not paths:
    raise FileNotFoundError(f"No input files match {DATA_PATTERN!r}")

for path in tqdm(paths, desc="subjects"):
    raw = mne.io.read_raw_fif(path, preload=True)
    montage = mne.channels.make_standard_montage('standard_1020')
    raw.set_montage(montage)
    raw.pick([picks])

    filt_raw = raw.copy()
    filt_raw.filter(*BANDPASS)
    filt_raw.notch_filter(freqs=NOTCH_FREQS)
    filt_raw.notch_filter(freqs=NARROW_NOTCH_FREQ, notch_widths=1, phase='zero')

    events = raw_to_events(filt_raw)
    epochs = mne.Epochs(filt_raw, events, tmin=TMIN, tmax=TMAX,
                        baseline=BASELINE, preload=True)
    epochs.drop_bad(reject=REJECT_CRITERIA)

    export_name = _export_name_from_path(path)
    # NOTE(review): MNE expects epochs files to end in '-epo.fif'; the bare
    # name is kept because downstream code may load it — TODO confirm.
    epochs.save(os.path.join(EPOCHS_DIR, export_name), overwrite=True)

    _plot_condition_evokeds(epochs, export_name,
                            os.path.join(PLOTS_DIR, export_name),
                            CONDITIONS, picks)
    # Alternative view: per-condition averages, e.g.
    # {c: epochs[c].average() for c in ('0', '5')} with ci=True.

# `epochs` intentionally remains bound to the last processed file; the next
# cell reads it.

# Plot time–frequency power (Morlet wavelets)

In [None]:
# Morlet time-frequency power of the condition-'1' epochs.
# NOTE: `epochs` is whatever the processing loop above produced for the LAST
# file it handled (hidden state) — rerun that cell first.
frequencies = np.arange(7, 30, 3)  # Hz: 7, 10, ..., 28
power = epochs['1'].compute_tfr(
    "morlet", n_cycles=2, return_itc=False, freqs=frequencies, decim=3, average=True
)
# Use the same channel the loop kept; a hard-coded name breaks if `picks` changes.
power.plot([picks]);