# ***LTP Task***

## Notes

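Vertical grating = white = annotation `1`

Horizontal grating = black = annotation `-1`

Which grating is tetanised depends on the subject: vertical for sub0, sub1, sub8 and sub9; horizontal for sub2–sub7.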

In [194]:
%load_ext autoreload
%autoreload 2
%matplotlib qt

import csv
import glob
import os
import re
from datetime import datetime

import autoreject
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from atpbar import atpbar
from autoreject import AutoReject
from mne.preprocessing import ICA, corrmap, create_ecg_epochs, create_eog_epochs
from mne_icalabel import label_components
from pyprep.find_noisy_channels import NoisyChannels
from tqdm import tqdm


# Keep MNE quiet: only messages at ERROR level or above are printed.
# Switch to "WARNING" or "INFO" here when debugging the pipeline.
mne.set_log_level(verbose="ERROR")


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Import 

In [216]:
# Event-code convention used throughout this notebook:
#   * annotation '1' = vertical grating, '-1' = horizontal grating;
#   * the magnitude of the returned code (1-6) identifies the recording block
#     (pre/post x baseline/early/late, see BLOCK_CODES);
#   * the sign encodes the stimulus role for that subject:
#     positive = TETANUS stimulus, negative = NORMAL stimulus.
# For sub0, sub1, sub8 and sub9 the vertical grating was tetanised; for
# sub2-sub7 it was the horizontal one, so their sign is flipped.
VERTICAL_TETANUS_SUBS = ("sub0", "sub1", "sub8", "sub9")
HORIZONTAL_TETANUS_SUBS = ("sub2", "sub3", "sub4", "sub5", "sub6", "sub7")

# Base (unsigned) event code of each (timepoint, condition) recording block.
BLOCK_CODES = {
    ("pre", "baseline"): 1,
    ("pre", "early"): 2,
    ("pre", "late"): 3,
    ("post", "baseline"): 4,
    ("post", "early"): 5,
    ("post", "late"): 6,
}


def raw_to_events(raw, timepoint, condition, sub):
    """Build an MNE events array from the '-1'/'1' annotations of *raw*.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose stimulus annotations are labelled '-1' and '1'.
    timepoint : {'pre', 'post'}
        Session relative to the tetanus.
    condition : {'baseline', 'early', 'late'}
        Block within the session.
    sub : str
        Subject label, e.g. 'sub0'.

    Returns
    -------
    numpy.ndarray
        Integer events array as returned by ``mne.events_from_annotations``,
        coded as described above. If no matching annotations are found, a
        single all-zero dummy row is returned instead (presumably so that
        ``mne.Epochs`` is never handed an empty events array).

    Raises
    ------
    ValueError
        If *sub* or the (timepoint, condition) pair is not recognised.
    """
    if sub in VERTICAL_TETANUS_SUBS:
        inversion = 1
    elif sub in HORIZONTAL_TETANUS_SUBS:
        inversion = -1
    else:
        raise ValueError(f"unknown subject {sub!r}")

    try:
        code = BLOCK_CODES[(timepoint, condition)]
    except KeyError:
        raise ValueError(
            f"unknown timepoint/condition pair {(timepoint, condition)!r}"
        ) from None

    event_id = {'-1': -code * inversion, '1': code * inversion}
    events = mne.events_from_annotations(raw, event_id=event_id)[0]
    if len(events) == 0:
        events = np.zeros((1, 3), dtype=int)
    return events

In [217]:
# Raw files are named like "sub0-LTP_post-baseline_ltp-raw_phot-events.fif";
# capture subject, timepoint (pre/post) and condition (e.g. baseline/early/late).
LTP_FILENAME_RE = re.compile(r"(sub\d+)-LTP_(pre|post)-([a-z]+)_ltp")


def parse_ltp_path(path):
    """Return ``(sub, timepoint, condition)`` parsed from an LTP raw-file path.

    Only the file's basename is matched, so the result does not depend on
    positional indices into a split of the whole path.

    Raises
    ------
    ValueError
        If the basename does not follow the expected naming pattern.
    """
    match = LTP_FILENAME_RE.match(os.path.basename(path))
    if match is None:
        raise ValueError(f"unexpected LTP filename: {path}")
    return match.groups()


paths = sorted(glob.glob("../../data/mne_raw_events/sub?-LTP_*-*_ltp-raw_phot-events.fif"))
for path in paths:
    sub, timepoint, condition = parse_ltp_path(path)

    raw = mne.io.read_raw_fif(path, preload=True)
    montage = mne.channels.make_standard_montage('standard_1020')
    raw.set_montage(montage)
    # Analyse the occipital midline channel only ('POz' was an alternative).
    raw.pick_channels(['Oz'])

    # Band-pass 0.3-45 Hz, then notch 60/76/120 Hz plus a narrow 84 Hz notch.
    # NOTE(review): every notch frequency lies above the 45 Hz low-pass edge,
    # so these notches presumably add little — confirm they are still needed.
    filt_raw = raw.copy()
    filt_raw.filter(0.3, 45)
    filt_raw.notch_filter(freqs=[60, 76, 120])
    filt_raw.notch_filter(freqs=84, notch_widths=1, phase='zero')

    events = raw_to_events(filt_raw, timepoint, condition, sub)
    # 600 ms epochs (-100 ms to +500 ms), baseline-corrected on the 100 ms
    # preceding each event.
    epochs = mne.Epochs(filt_raw, events, tmin=-0.1, tmax=0.5,
                        baseline=(-0.1, 0), preload=True)

    # NOTE(review): this minimum-count check runs before amplitude rejection,
    # so a saved file may still end up with fewer than 10 epochs.
    if len(epochs) < 10:
        print(f'not enough epochs for {path}')
        continue

    # Drop epochs whose peak-to-peak EEG amplitude exceeds 100 µV.
    # Alternative: autoreject.AutoReject(n_interpolate=[1, 2, 3, 4], random_state=11).
    epochs.drop_bad(reject=dict(eeg=100e-6))

    os.makedirs("epochs", exist_ok=True)
    export_path = f"epochs/{sub}_{timepoint}_{condition}_epo.fif"
    epochs.save(export_path, overwrite=True)


not enough epochs for ../../data/mne_raw_events/sub6-LTP_pre-late_ltp-raw_phot-events.fif


# Plot

In [218]:
# Display the sorted list of raw input files collected by the epoching cell.
paths

['../../data/mne_raw_events/sub0-LTP_post-baseline_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub0-LTP_post-early_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub0-LTP_post-late_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub0-LTP_pre-baseline_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub0-LTP_pre-early_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub0-LTP_pre-late_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub1-LTP_post-baseline_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub1-LTP_post-early_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub1-LTP_post-late_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub1-LTP_pre-baseline_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub1-LTP_pre-early_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub1-LTP_pre-late_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub2-LTP_post-baseline_ltp-raw_phot-events.fif',
 '../../data/mne_raw_events/sub2-

In [236]:
# Display the current value of `sub` (left over from whichever loop ran last).
sub

6

In [238]:
# Event-code layout (see raw_to_events): |code| 1-3 = pre baseline/early/late,
# 4-6 = post baseline/early/late; negative = normal, positive = tetanus.
PHASES = ("baseline", "early", "late")
FIRST_CODE = {"pre": 1, "post": 4}
STIMULI = (("normal", -1), ("tetanus", 1))


def build_evokeds(epochs, timepoint):
    """Group the single-epoch evokeds of *epochs* for one timepoint.

    Returns a dict ``{"<timepoint>_<phase>_<stimulus>": [Evoked, ...]}``,
    ordered normal-then-tetanus, each in baseline/early/late order, which is
    the form ``mne.viz.plot_compare_evokeds`` expects. Conditions whose event
    id is absent or that have no remaining epochs are omitted (e.g. sub6
    pre-late, which was skipped during epoching). This way no per-subject
    special case is needed.
    """
    evokeds = {}
    for stimulus, sign in STIMULI:
        for offset, phase in enumerate(PHASES):
            key = str(sign * (FIRST_CODE[timepoint] + offset))
            if key not in epochs.event_id:
                continue
            subset = epochs[key]
            if len(subset) == 0:
                continue
            evokeds[f"{timepoint}_{phase}_{stimulus}"] = list(subset.iter_evoked())
    return evokeds


for sub in range(10):
    paths = sorted(glob.glob(f"epochs/sub{sub}*"))
    if not paths:
        # concatenate_epochs would fail on an empty list.
        print(f"no epochs found for sub{sub}")
        continue
    epochs = mne.concatenate_epochs([mne.read_epochs(path) for path in paths])

    for timepoint in ("pre", "post"):
        evokeds = build_evokeds(epochs, timepoint)
        if not evokeds:
            continue
        # Save the figure object MNE returns instead of whatever
        # matplotlib considers the "current" figure.
        figs = mne.viz.plot_compare_evokeds(
            evokeds, combine="mean", picks='Oz',
            title=f"sub{sub}: {timepoint}", show=False,
        )
        os.makedirs("plots", exist_ok=True)
        figs[0].savefig(f'plots/sub{sub}_{timepoint}.png')
        plt.close(figs[0])