In [2]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd

import scipy
import pywt
import mne

import sys
from tqdm import tqdm


sys.path.append('../utils')
from ERP_utils import *
from update_sub_lists import *
from wavelet_utils import *
import glob
import os


#import seaborn as sns
import warnings

# Suppress all FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [10]:
# ---------------------------------------------------------------------------
# Run configuration: everything a reader may want to tune lives in this cell.
# ---------------------------------------------------------------------------

# Identifiers that are interpolated into the output directory name below.
task = 'error'
erp_window = '_n05to05'  # already starts with "_", so wavelet_dir contains "__"
freq_high = 30

# Keep MNE quiet while the epochs files are being read.
mne.set_log_level('CRITICAL')

# --- Subjects --------------------------------------------------------------
# Alternative: discover them on disk (subjects that exist in pre should exist in post)
#subjects_to_process = find_existing_subjects(task = task, period = 'pre',erp_window=erp_window,freq_high=freq_high)
subjects_to_process = '01 04 05 06 07 08 09 12'.split()

# --- Channels --------------------------------------------------------------
# Either the string 'all' or a list of channel indices.
#ch_to_process = ch_index(['Cz'])
ch_to_process = 'all'

# --- Wavelet settings ------------------------------------------------------
wavelet_params = dict(
    fs=128,                      # example sampling frequency in Hz
    centerfreq=1,
    bandwidth=1.5,
    level=10,
    scale_values=[6, 150, 40],
)

# --- Input / output locations ----------------------------------------------
# NOTE(review): absolute, machine-specific path that ignores task / erp_window /
# freq_high. The parameterised alternative is kept here for reference:
#epochs_dir = f'/Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/analysis_{task}/{task}_epochs_data{erp_window}_{freq_high}Hz'
epochs_dir = '/Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/analysis_error/error_epochs_data_n05to05_30Hz_corrected'

wavelet_dir = f"./wavelets_mat_{task}_{erp_window}_{freq_high}Hz_time"

# Echo the selection so it is recorded in the cell output.
print('processing subjects:', subjects_to_process)
print('processing channels:', ch_to_process)

processing subjects: ['01', '04', '05', '06', '07', '08', '09', '12']
processing channels: all


## Batch processing wavelets

In [11]:
def wavelet_batch(subjects_to_process, channels, ep_dir, output_dir, wavelet_params, ave = False, overwrite = False, erp_begin = -0.5, erp_end = 0.5):

    """
    Wavelet-transform the epochs files in ``ep_dir`` trial by trial and save
    one .mat file per epochs file, plus a ``wavelet_record.mat`` summary.

    Parameters
    ----------
    subjects_to_process: list of subject ids (strings). The id of a file is the
        text after its last "_" and before its first ".".
    channels: 'all' or a list of channel indices. If not 'all', only the
        listed channels are transformed and saved; the others are discarded.
    ep_dir: directory where the epochs (.fif) files are saved.
    output_dir: directory to save wavelet data (created if missing,
        including parent directories).
    wavelet_params: dict with wavelet parameters. This function reads
        'centerfreq', 'bandwidth' and 'scale_values' (the third entry of
        'scale_values' is used as the number of frequencies). Example:
        wavelet_params = {
                'fs' :128 , # example sampling frequency in Hz
                'centerfreq' : 5 ,
                'bandwidth': 1.5,
                'level': 10,
                'scale_values':[6, 150, 40]
            }
    ave: whether to average spectrograms over all trials before saving.
        NOT IMPLEMENTED: the argument is accepted but currently ignored.
    overwrite: if False, files whose .mat output already exists are skipped.
    erp_begin, erp_end: passed to create_erp_times to build the time axis.

    Files that do not end in ".fif", whose subject id is not requested, or
    whose name contains 'inv', 'firsts', 'norm' or 'shinv' are skipped.

     ----
    Saves wavelet transforms to .mat files:
    '<name>.mat' with key 'wavelet_transform': magnitude of the transform,
        of dim n_trials x n_channels x spect_freqs x spect_times
    'wavelet_record.mat': freqs, wavelet, subjects, wavelet parameters, times
        and channels. It is only (re)written when at least one trial was
        transformed during this call.
    """
    # makedirs creates missing parents and does not fail if the dir exists.
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): sampling rate and channel count are fixed here rather than
    # read from wavelet_params['fs'] / the epochs themselves - confirm that
    # they match the recordings before reusing this on other data.
    sampling_rate = 128
    n_channels_all = 64

    times = create_erp_times(erp_begin, erp_end, sampling_rate)

    # Loop-invariant layout of the per-file output array.
    n_freqs = wavelet_params['scale_values'][2]
    n_times = times.shape[0]
    if isinstance(channels, str) and channels == 'all':
        ch_towav = np.arange(n_channels_all)
    else:
        ch_towav = channels
    n_channels = len(ch_towav)

    # Stay None unless at least one trial is transformed; the original code
    # referenced them unconditionally and crashed when every file was skipped.
    freqs = None
    wavelet = None

    for fif_name in sorted(os.listdir(ep_dir)):
        #check file type
        if not fif_name.endswith(".fif"):
            print('skipping file, not epochs:', fif_name)
            continue

        #identify subject
        subject_id = fif_name.split("_")[-1].split(".")[0]
        if subject_id not in subjects_to_process:
            continue

        if any(sub in fif_name for sub in ('inv', 'firsts', 'norm', 'shinv')):
            continue

        mat_name = fif_name.split(".")[0].replace("epochs", "wavelet")
        mat_path = os.path.join(output_dir, f"{mat_name}.mat")

        #skip if the file already exists
        if os.path.exists(mat_path) and not overwrite:
            continue

        print('processing', fif_name)

        #get data to loop over
        epochs = mne.read_epochs(os.path.join(ep_dir, fif_name))
        epochs_data = epochs.get_data()
        n_trials = epochs_data.shape[0]

        # Preallocate the full result instead of stacking a list of per-trial
        # arrays afterwards, which would hold two copies in memory at once.
        wavelet_data = np.zeros((n_trials, n_channels, n_freqs, n_times))

        for trial in tqdm(range(n_trials)):
            for j, ch in enumerate(ch_towav):
                cwtmatr, freqs, wavelet = morwav_trans(epochs_data[trial, ch, :],
                                                centerfreq=wavelet_params['centerfreq'],
                                                bandwidth=wavelet_params['bandwidth'],
                                                scale_values=wavelet_params['scale_values'])

                # only the magnitude of the transform is kept
                wavelet_data[trial, j, :, :] = np.abs(cwtmatr)

        #save subject data to mat file
        savemat(mat_path, {'wavelet_transform': wavelet_data})

    #processing record for wavelet
    if freqs is None:
        # Nothing was transformed (all outputs already exist, or no matching
        # files), so there is no freqs/wavelet to record.
        print('no epochs files were transformed; wavelet_record.mat left unchanged')
        return

    wavelet_record = {
        'freqs': freqs,
        'wavelet': wavelet,
        'subjects': subjects_to_process,
        'centerfreq': wavelet_params['centerfreq'],
        'bandwidth': wavelet_params['bandwidth'],
        'scale_values': wavelet_params['scale_values'],
        'times': times,
        'channels': channels
    }
    savemat(os.path.join(output_dir, 'wavelet_record.mat'), wavelet_record)


In [12]:

# Run the batch transform; with overwrite=False, files whose .mat output
# already exists in wavelet_dir are skipped.
wavelet_batch(
    subjects_to_process=subjects_to_process,
    channels=ch_to_process,
    ep_dir=epochs_dir,
    output_dir=wavelet_dir,
    wavelet_params=wavelet_params,
    overwrite=False,
)

processing error_epochs_all_post_01.fif


100%|██████████| 390/390 [02:06<00:00,  3.08it/s]


processing error_epochs_all_post_04.fif


100%|██████████| 362/362 [01:56<00:00,  3.12it/s]


processing error_epochs_all_post_05.fif


100%|██████████| 275/275 [02:07<00:00,  2.15it/s]


processing error_epochs_all_post_06.fif


100%|██████████| 386/386 [02:22<00:00,  2.70it/s]


processing error_epochs_all_post_07.fif


100%|██████████| 308/308 [01:31<00:00,  3.36it/s]


processing error_epochs_all_post_08.fif


100%|██████████| 325/325 [01:35<00:00,  3.40it/s]


processing error_epochs_all_post_09.fif


100%|██████████| 381/381 [01:52<00:00,  3.40it/s]


processing error_epochs_all_post_12.fif


100%|██████████| 358/358 [01:51<00:00,  3.21it/s]


processing error_epochs_all_pre_01.fif


100%|██████████| 285/285 [01:27<00:00,  3.24it/s]


processing error_epochs_all_pre_04.fif


100%|██████████| 185/185 [01:26<00:00,  2.14it/s]


processing error_epochs_all_pre_05.fif


100%|██████████| 303/303 [01:59<00:00,  2.53it/s]


processing error_epochs_all_pre_06.fif


100%|██████████| 369/369 [02:29<00:00,  2.48it/s]


processing error_epochs_all_pre_07.fif


100%|██████████| 373/373 [01:59<00:00,  3.13it/s]


processing error_epochs_all_pre_08.fif


100%|██████████| 267/267 [01:24<00:00,  3.16it/s]


processing error_epochs_all_pre_09.fif


100%|██████████| 394/394 [02:14<00:00,  2.93it/s]


processing error_epochs_all_pre_12.fif


100%|██████████| 281/281 [01:53<00:00,  2.48it/s]


processing error_epochs_others_post_01.fif


100%|██████████| 297/297 [01:38<00:00,  3.03it/s]


processing error_epochs_others_post_04.fif


100%|██████████| 307/307 [02:07<00:00,  2.41it/s]


processing error_epochs_others_post_05.fif


100%|██████████| 199/199 [01:03<00:00,  3.11it/s]


processing error_epochs_others_post_06.fif


100%|██████████| 293/293 [01:30<00:00,  3.23it/s]


processing error_epochs_others_post_07.fif


100%|██████████| 218/218 [01:02<00:00,  3.49it/s]


processing error_epochs_others_post_08.fif


100%|██████████| 244/244 [01:10<00:00,  3.48it/s]


processing error_epochs_others_post_09.fif


100%|██████████| 288/288 [01:23<00:00,  3.45it/s]


processing error_epochs_others_post_12.fif


100%|██████████| 337/337 [01:36<00:00,  3.51it/s]


processing error_epochs_others_pre_01.fif


100%|██████████| 224/224 [01:03<00:00,  3.52it/s]


processing error_epochs_others_pre_04.fif


100%|██████████| 154/154 [00:44<00:00,  3.48it/s]


processing error_epochs_others_pre_05.fif


100%|██████████| 250/250 [01:11<00:00,  3.47it/s]


processing error_epochs_others_pre_06.fif


100%|██████████| 273/273 [01:17<00:00,  3.51it/s]


processing error_epochs_others_pre_07.fif


100%|██████████| 353/353 [01:40<00:00,  3.51it/s]


processing error_epochs_others_pre_08.fif


100%|██████████| 217/217 [01:02<00:00,  3.49it/s]


processing error_epochs_others_pre_09.fif


100%|██████████| 298/298 [01:24<00:00,  3.51it/s]


processing error_epochs_others_pre_12.fif


100%|██████████| 260/260 [01:13<00:00,  3.52it/s]
