In [6]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd

import scipy
import pywt
import mne

import sys
from tqdm import tqdm


sys.path.append('../utils')
from ERP_utils import *
from update_sub_lists import *
from wavelet_utils import *
import glob
import os


#import seaborn as sns
import warnings

# Suppress all FutureWarnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [7]:
# ---- dataset selection: these three values are interpolated into both paths below
task = 'motor'
erp_window = '_n05to05'
freq_high = 30

# ---- subjects
# alternative: discover the subjects from disk (subs that exist in pre should exist in post)
#subjects_to_process = find_existing_subjects(task = task, period = 'pre',erp_window=erp_window,freq_high=freq_high)
subjects_to_process = ['10', '11', '12']

# ---- channels: 'all', or a list of channel indices
# alternative: ch_to_process = ch_index(['Cz'])
ch_to_process = 'all'

# ---- wavelet settings handed to wavelet_batch
wavelet_params = dict(
    fs=128,  # example sampling frequency in Hz
    centerfreq=1,
    bandwidth=1.5,
    level=10,
    scale_values=[6, 150, 40],  # third entry is used by wavelet_batch as the number of frequency rows
)

# ---- input / output locations
# NOTE(review): epochs_dir is an absolute path on one machine; adjust it before running elsewhere.
epochs_dir = f'/Users/cindyzhang/Documents/M2/Audiomotor_Piano/AM-EEG/analysis_{task}/{task}_epochs_data{erp_window}_{freq_high}Hz'
# erp_window already starts with "_", so this directory name contains a double underscore
wavelet_dir = f"./wavelets_mat_{task}_{erp_window}_{freq_high}Hz_time"

# silence MNE logging below CRITICAL
mne.set_log_level('CRITICAL')

print('processing subjects:', subjects_to_process)
print('processing channels:', ch_to_process)

processing subjects: ['10', '11', '12']
processing channels: all


## Batch processing wavelets

In [8]:
def wavelet_batch(subjects_to_process, channels, ep_dir, output_dir, wavelet_params, ave = False, overwrite = False, erp_begin = -0.5, erp_end = 0.5):

    """
    wavelet transforms epochs trial by trial for each subject and saves the result

    Every file in ep_dir that ends in ".fif" and whose last "_"-separated token
    (without extension) is in subjects_to_process is read with mne.read_epochs.
    Each (trial, channel) time series is passed to morwav_trans and the absolute
    value of the returned coefficients is stored.

    subjects_to_process: list of subject ids (strings, compared against the file name)
    channels: 'all' or a list of channel indices. If not 'all', only the wavelet
        transformed channels are saved and the others are discarded
    ep_dir: directory where epochs (.fif) are saved
    output_dir: dir to save wavelet data (created, including parents, if missing)
    wavelet_params: dict with wavelet parameters. Keys read here are 'centerfreq',
        'bandwidth', 'scale_values' (third entry = number of frequency rows) and,
        optionally, 'fs' (defaults to 128). Example:
        wavelet_params = {
                'fs' :128 , # example sampling frequency in Hz
                'centerfreq' : 5 ,
                'bandwidth': 1.5,
                'level': 10,
                'scale_values':[6, 150, 40]
            }
    ave: whether to average spectrograms over all trials before saving
        (NOT IMPLEMENTED: the argument is currently ignored)
    overwrite: if False, files whose output .mat already exists are skipped
    erp_begin, erp_end: passed to create_erp_times to build the time axis

     ----
    saves wavelet transforms to .mat files (one per epochs file, "epochs" in the
    file name replaced by "wavelet")
    'wavelet_transform': wavelet data of dim n_trials x n_channels x spect_freqs x spect_times

    also saves 'wavelet_record.mat' with the parameters of the run. The record is
    only written when at least one trial was transformed, because 'freqs' and
    'wavelet' come from morwav_trans.

    returns None
    """
    # makedirs creates missing parent directories and tolerates an existing one
    os.makedirs(output_dir, exist_ok=True)

    # assumes the third argument of create_erp_times is the sampling rate in Hz
    # (it used to be hard-coded to 128, the documented value of 'fs') -- TODO confirm
    fs = wavelet_params.get('fs', 128)
    times = create_erp_times(erp_begin, erp_end, fs)

    # loop-invariant sizes of one spectrogram
    n_freqs = wavelet_params['scale_values'][2]
    n_times = times.shape[0]
    use_all_channels = isinstance(channels, str) and channels == 'all'

    # stay None until morwav_trans has been called at least once
    freqs = None
    wavelet = None

    for fif_name in sorted(os.listdir(ep_dir)):
        #check file type
        if not fif_name.endswith(".fif"):
            print('skipping file, not epochs:', fif_name)
            continue

        #identify subject
        subject_id = fif_name.split("_")[-1].split(".")[0]
        if subject_id not in subjects_to_process:
            continue

        mat_name = fif_name.split(".")[0].replace("epochs", "wavelet")
        mat_path = os.path.join(output_dir, f"{mat_name}.mat")

        #skip if the file already exists
        if os.path.exists(mat_path) and not overwrite:
            continue

        print('processing', fif_name)

        #get data to loop over, indexed below as [trial, channel, time]
        epochs = mne.read_epochs(os.path.join(ep_dir, fif_name))
        epochs_data = epochs.get_data()
        n_trials = epochs_data.shape[0]

        # take the channel count from the data instead of assuming 64 channels
        ch_towav = np.arange(epochs_data.shape[1]) if use_all_channels else channels

        # allocate the full result once instead of stacking per-trial arrays afterwards
        wavelet_data = np.zeros((n_trials, len(ch_towav), n_freqs, n_times))

        for trial in tqdm(range(n_trials)):
            for j, ch in enumerate(ch_towav):
                trial_data = epochs_data[trial, ch, :]

                cwtmatr, freqs, wavelet = morwav_trans(trial_data,
                                                centerfreq=wavelet_params['centerfreq'],
                                                bandwidth=wavelet_params['bandwidth'],
                                                scale_values=wavelet_params['scale_values'])

                # keep the magnitude of the coefficients only
                wavelet_data[trial, j, :, :] = np.abs(cwtmatr)

        #save subject data to mat file
        savemat(mat_path, {'wavelet_transform': wavelet_data})

    #processing record for wavelet
    if freqs is None:
        # nothing was transformed, so there is no 'freqs'/'wavelet' to record;
        # leave any existing record untouched instead of failing
        print('no epochs were transformed; wavelet_record.mat not written')
        return

    wavelet_record = {
        'freqs': freqs,
        'wavelet': wavelet,
        'subjects': subjects_to_process,
        'centerfreq': wavelet_params['centerfreq'],
        'bandwidth': wavelet_params['bandwidth'],
        'scale_values': wavelet_params['scale_values'],
        'times': times,
        'channels': channels
    }
    savemat(os.path.join(output_dir, 'wavelet_record.mat'), wavelet_record)


In [9]:

# run the batch for the configured subjects/channels; overwrite=True recomputes
# and replaces any wavelet .mat files that already exist in wavelet_dir
wavelet_batch(
    subjects_to_process=subjects_to_process,
    channels=ch_to_process,
    ep_dir=epochs_dir,
    output_dir=wavelet_dir,
    wavelet_params=wavelet_params,
    overwrite=True,
)

processing motor_epochs_post_10.fif


  0%|          | 0/172 [00:00<?, ?it/s]

100%|██████████| 172/172 [00:49<00:00,  3.44it/s]


processing motor_epochs_post_11.fif


100%|██████████| 351/351 [01:41<00:00,  3.46it/s]


processing motor_epochs_post_12.fif


100%|██████████| 258/258 [01:15<00:00,  3.44it/s]


processing motor_epochs_pre_10.fif


100%|██████████| 253/253 [01:14<00:00,  3.39it/s]


processing motor_epochs_pre_11.fif


100%|██████████| 323/323 [01:34<00:00,  3.42it/s]


processing motor_epochs_pre_12.fif


100%|██████████| 287/287 [01:23<00:00,  3.45it/s]
