# *Pre-processing and Feature Extraction*

### Import Utility Methods

In [1]:
from utility.preProcessor import *
from utility.featureExtractor import *

In [3]:
import os
import re
import requests
from bs4 import BeautifulSoup
from urllib.parse import urljoin

def find_files(url, headers):
    # Access the directory URL
    response = requests.get(url, auth=(headers['user'], headers['passwd']))
    soup = BeautifulSoup(response.text, features="html.parser")
    
    # Separate files and directories
    hrefs_files = []
    hrefs_dirs = []
    
    for link in soup.find_all('a'):
        href = link.get('href')
        if href and not href.startswith('.'):
            if href.endswith('/'):
                hrefs_dirs.append(href.strip('/'))
            else:
                hrefs_files.append(href)
    return hrefs_files, hrefs_dirs

def download_file(download_file_url, file_path, headers, output=False):
    if output:
        print('Downloading:', download_file_url)
    r = requests.get(download_file_url, auth=(headers['user'], headers['passwd']))
    with open(file_path, 'wb') as f:
        f.write(r.content)

def download_TUH(DOWNLOAD_DIR, headers, sub_dir='', output=False):
    # Base URL for the dataset
    base_url = 'https://isip.piconepress.com/projects/nedc/data/tuh_eeg/tuh_eeg_seizure/v2.0.3/edf/'
    dir_url = urljoin(base_url, sub_dir)
    
    # Clean up export_dir path for local storage
    export_dir = os.path.join(DOWNLOAD_DIR, re.sub(r'.*edf/', '', sub_dir))
    
    if not os.path.exists(export_dir):
        os.makedirs(export_dir, exist_ok=True)

    # Get lists of files and directories
    files, dirs = find_files(dir_url, headers)
    
    # Download all files in the current directory
    for file in files:
        if re.search(r'\.xlsx$|\.edf$|\.txt$|\.tse(?!_)', file):
            file_path = os.path.join(export_dir, file)
            if not os.path.exists(file_path):
                download_file(urljoin(dir_url, file), file_path, headers, output)

    # Recursively process each subdirectory
    for subfolder in dirs:
        next_sub_dir = os.path.join(sub_dir, subfolder)
        download_TUH(DOWNLOAD_DIR, headers, next_sub_dir, output)


In [4]:
from getpass import getpass
import os
import sys
import os
from bs4 import BeautifulSoup
import requests
import re
import wget
import zipfile


DOWNLOAD_DIR = os.path.expanduser('tuh_data')  # Set a local path

if not os.path.exists(DOWNLOAD_DIR):
  os.makedirs(DOWNLOAD_DIR)

user = "nedc-tuh-eeg"
key = "RLYF8ZhBMZwNnsYA8FsP"

auth_dict = {'user': user, 'passwd': key}

download_TUH(DOWNLOAD_DIR, auth_dict, '', output=True)


ConnectTimeout: HTTPSConnectionPool(host='isip.piconepress.com', port=443): Max retries exceeded with url: /projects/nedc/data/tuh_eeg/tuh_eeg_seizure/v2.0.3/edf/train%5Caaaaastr (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x0000025F7718FA40>, 'Connection to isip.piconepress.com timed out. (connect timeout=None)'))

### *Process and Export Feature Dump*

In [6]:
TUH_FILT_OVERWRITE = True  # or False, depending on your desired behavior
TUH_FEAT_OVERWRITE = True  # or False, depending on your desired behavior
TUH_UDWT_OVERWRITE = True  # or False, depending on your desired behavior
TUH_FILT_SAVE_PATH = 'filtered_data.h5'  # Replace with your desired file path
TUH_FEAT_SAVE_PATH = 'features_data.h5'  # Replace with your desired file path
TUH_UDWT_SAVE_PATH = 'udwt_data.h5'  # Replace with your desired file path
TUH_code = 'fnsz' # or another seizure type if you wish

In [9]:

import os
import re
import pyedflib
import numpy as np
import pandas as pd
from scipy import signal
import tables
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import requests
from bs4 import BeautifulSoup
import warnings
# Import the getpass module
from getpass import getpass

def find_files(url, headers):
    # get a soup of the directory url
    soup = BeautifulSoup(requests.get(url, auth=(headers['user'], headers['passwd'])).text, 
                         features="html.parser")
    # make a list of all the links in the url
    hrefs_list = []
    for link in soup.find_all('a'):
        hrefs_list.append(link.get('href'))

    return hrefs_list
    
    
def download_file(download_file_url, file_path, headers, output=False):
    if output:
        # print it is downloading
        print('Downloading: '+ download_file_url)
    # download the file to the directory
    r = requests.get(download_file_url, auth=(headers['user'], headers['passwd']))
    with open(file_path, 'wb') as f:
      f.write(r.content)

# needs a directory to download it to
def download_TUH(DIR, headers, sub_dir, output=False):
    
    # directory url
    dir_url = 'https://www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_seizure/v1.5.0/'+sub_dir

    hrefs_dir_list = find_files(dir_url, headers)
    
    # for each link in the directory
    for link in hrefs_dir_list:
        # download the files outside of participant folders we want
        if re.findall('.xlsx|\.edf|\.tse(?!_)', str(link)):
            # if the file doesnt already exist in the directory
            if not os.path.exists(os.path.join(DIR, link)):
                download_file(dir_url+'/'+str(link), DIR+'/'+str(link), headers, output)

def data_load(data_file, selected_channels=[]):

    try:
        # use the reader to get an EdfReader file
        f = pyedflib.EdfReader(data_file)

        # get the names of the signals
        channel_names = f.getSignalLabels()
        # get the sampling frequencies of each signal
        channel_freq = f.getSampleFrequencies()
        
        # get a list of the EEG channels
        if len(selected_channels) == 0:
            selected_channels = channel_names

        # make an empty file of 0's
        sigbufs = np.zeros((f.getNSamples()[0],len(selected_channels)))
        # for each of the channels in the selected channels
        for i, channel in enumerate(selected_channels):
        
            try:
              # add the channel data into the array
              sigbufs[:, i] = f.readSignal(channel_names.index(channel))
            
            except:
              ValueError
              # This happens if the sampling rate of that channel is 
              # different to the others.
              # For simplicity, in this case we just make it na.
              sigbufs[:, i] = np.nan

        # turn to a pandas df and save a little space
        df = pd.DataFrame(sigbufs, columns = selected_channels).astype('float32')

        # get equally increasing numbers upto the length of the data depending
        # on the length of the data divided by the sampling frequency
        index_increase = np.linspace(0,
                                      len(df)/channel_freq[0],
                                      len(df), endpoint=False)

        # round these to the lowest nearest decimal to get the seconds
        seconds = np.floor(index_increase).astype('uint16')

        seconds = index_increase
        
        # make a column the timestamp
        df['Time'] = seconds

        # make the time stamp the index
        df = df.set_index('Time')

        # name the columns as channel
        df.columns.name = 'Channel'

        return df, channel_freq[0]

    except OSError as error:
        print('Error '+data_file+': '+str(error))
        return pd.DataFrame(), None
    except ValueError as error:
        print('Error '+data_file+': '+str(error))
        return pd.DataFrame(), None

def create_events(file_name, df, code = None):

    data_y = pd.Series(index=df.index)
    data_y.name = 'Events'

    events_tse = pd.read_csv(file_name, 
                             skiprows=1,
                             sep = ' ',
                             header=None,
                             names =['Start', 'End', 'Code', 'Certainty'])
    
    data_y = data_y.fillna('bckg')
    
    for pos, row in events_tse.iterrows():
        # if you want to manually set the code
        if code != None:
          if row['Code'] == code:
              data_y[row['Start']:row['End']] = code
        # let it be the code it is in the event file
        else:
          data_y[row['Start']:row['End']] = row['Code']

    return data_y

def window_y(events, window_size, overlap, target=None, baseline=None):
    
  # window the data so each row is another epoch
  events_windowed = window(events, w = window_size, o = overlap, copy = True)
  
  if target:
    # turn to array of bools if seizure in the
    # windowed data
    bools = events_windowed == target
    # are there any seizure seconds in the data?
    data_y = np.any(bools,axis=1)
    # turn to 0's and 1's
    data_y = data_y.astype(int)
    # expand the dimensions so running down one column
    data_y = np.expand_dims(data_y, axis=1)
  
  elif baseline:
    # replace all baseline labels to nan
    data_y = pd.DataFrame(events_windowed).replace(baseline, np.nan)
    # get the most common other than baseline
    data_y = data_y.mode(1)
    # change nan back to baseline class
    data_y = data_y.fillna(baseline).values
    # if there was nothing but baseline there will be an empty array
    if data_y.size == 0:
        data_y = np.array([baseline]*data_y.shape[0])
        data_y = np.expand_dims(data_y, -1)
  
  else:
    # get the value most frequent in the window
    data_y = pd.DataFrame(events_windowed).mode(1).values

  return data_y

def downsample(data_x, data_y, freq):
    if freq > 256:
        if freq >= 1000:
          subsample = 4
        else:
          subsample = 2

        freq = freq/subsample
        data_x = data_x[::subsample]
        data_y = data_y[::subsample]

    return data_x, data_y, freq

def sel_file_list(set_name, seiz_type):
  # load the training information
  train_info = pd.read_excel('dataset/tuh_data/DOCS/seizures_types_v02.xlsx', set_name)
  # just want the info per file here
  file_info = train_info.iloc[1:6101,1:15]
  # cleans some of the names
  file_info_cols = ['File No.', 'Patient', 'Session', 'File', 
                        'EEG Type', 'EEG SubType', 'LTM or Routine', 
                        'Normal/Abnormal', 'No. Seizures File', 
                        'No. Seizures/Session', 'Filename', 'Seizure Start', 
                        'Seizure Stop', 'Seizure Type']
  file_info.columns = file_info_cols

  # we forward fill as there are gaps in the excel file to represent the info 
  # is the same as above (apart from in the filename, seizure start, seizure stop 
  # and seizure type columns)
  for col_name in file_info.columns[:-4]:
    file_info[col_name] = file_info[col_name].ffill()

  # patient ID is an integer rather than float
  file_info['Patient'] = file_info['Patient'].astype(int)

  if seiz_type:
    # Get a list of patient event files that have a specifc type of seizure
    return list(file_info[file_info['Seizure Type']==seiz_type]['Filename'])
  else:
    return list(file_info['Filename'])

def save_to_database(save_dir, file_title, group, data_x, data_y, 
                     feature_columns=None):

    # open the file in append mode (make it if doesnt exist)
    h5file = tables.open_file(save_dir, mode="a", title=file_title)
    
    # save space
    data_x = data_x.astype(np.float32)
    #data_y = data_y.astype(np.int16)
    
    # get filters to compress file
    filters = tables.Filters(complevel=1, complib='zlib')
    
    # if there is already a node for the particpant...
    if "/"+group in h5file:
        # ...put in the directory of where it is found
        part_x_array = h5file.get_node("/" + group + '/Data_x')
        part_y_array = h5file.get_node("/" + group + '/Data_y')
    
    else:
        # create the group directory
        part_group = h5file.create_group("/", group, 'Group Data')
        # make an atom which has the datatype found in the data we want to store
        x_atom = tables.Atom.from_dtype(data_x.dtype)
        y_atom = tables.Atom.from_dtype(data_y.dtype)

        if file_title == 'UDWT_Data':
          shape = (0,data_x.shape[1], data_x.shape[2], data_x.shape[3])
        else:
          shape = (0,data_x.shape[1], data_x.shape[2])
        
        # create an array we can append onto later
        part_x_array = h5file.create_earray("/" + group,                   # parentnode
                                            'Data_x',                        # name 
                                            x_atom,                          # atom
                                            shape, # shape
                                            'Feature Array',
                                            filters=filters
                                           )                 # title

        part_y_array = h5file.create_earray("/" + group, 
                                            'Data_y', 
                                            y_atom, 
                                            (0,),
                                            'Events Array',
                                            filters=filters
                                           )
        
        if feature_columns:
          # create the feature names array (we only need to do this once)
          h5file.create_array("/" + group,                                   # where
                              'Feature_Names',                             # name 
                              np.array(feature_columns, dtype='unicode'),  # obj
                              "Names of Each Feature")                         # title
    
    # append the data to the array directory
    part_x_array.append(data_x)
    part_y_array.append(data_y)
    
    # flush the data to disk
    h5file.flush()
    # close the file
    h5file.close()

def udwt_spectrogram(data, waveletname, level, window_size):
  data_ucwt = np.ndarray(shape=(data.shape[0], level, data.shape[1], data.shape[2]))
  
  for ii in range(data.shape[0]):
    for jj in range(data.shape[-1]):
      signal = data[ii, :, jj]
      coeffs_list = swt(signal, waveletname, level=level)
      
      coeffs_array = np.zeros((len(coeffs_list), window_size))
      
      for i, array_tuple in enumerate(coeffs_list[::-1]):

        coeffs_array[i,:] = np.array(array_tuple)[1,:]
      
      power = np.abs(coeffs_array)**2
      
      data_ucwt[ii, :, :, jj] = power
      
  return data_ucwt


if TUH_FILT_OVERWRITE or TUH_FEAT_OVERWRITE or TUH_UDWT_OVERWRITE:

  if TUH_FILT_OVERWRITE and os.path.exists(TUH_FILT_SAVE_PATH):
    os.remove(TUH_FILT_SAVE_PATH)

  if TUH_FEAT_OVERWRITE and os.path.exists(TUH_FEAT_SAVE_PATH):
    os.remove(TUH_FEAT_SAVE_PATH)

  if TUH_UDWT_OVERWRITE and os.path.exists(TUH_UDWT_SAVE_PATH):
    os.remove(TUH_UDWT_SAVE_PATH)

  # ---------
  # TUH SETUP
  # ---------
  DOWNLOAD_DIR = "TUH Database"
  if not os.path.exists(DOWNLOAD_DIR):
    os.makedirs(DOWNLOAD_DIR)

  user = getpass('TUH Username: ')
  key = getpass('TUH Password: ')
  auth_dict = {'user': user, 'passwd': key}

  # --------------
  # GET FILE PATHS
  # --------------
  # download info files
  download_TUH(DOWNLOAD_DIR, auth_dict, '_DOCS')

  seiz_types_path = 'dataset/tuh_data/DOCS/seizures_types_v02.xlsx'
  seiz_types = pd.read_excel(seiz_types_path)

  seiz_types = seiz_types.set_index('Class Code')

  int_code = seiz_types.to_dict()['Class No.']
  # change to lower case
  int_code = { k.lower() : v for k,v in int_code.items() if not isinstance(k, float)}

  # get a list of files
  tuh_file_list = sel_file_list('train', TUH_code)+sel_file_list('dev_test', TUH_code)

  # get a list of the montages
  montage = []
  for file in tuh_file_list:
    montage.append(file.split('/')[3])

  # count how many times the montages appear in the data
  montage_counts = pd.Series(montage).value_counts()

  # remove all files apart from those in the most common montage
  regex = re.compile(montage_counts.index[0])
  tuh_file_list = [i for i in tuh_file_list if regex.search(i)]
  # remove duplicates
  tuh_file_list = list(set(tuh_file_list))

  # They changed "dev_test" to just "dev" in their file structure so 
  # I need to account for this now...
  for i, string in enumerate(tuh_file_list):
    tuh_file_list[i] = re.sub('_test', '', string)


  # --------------------
  # GET SIMILAR CHANNELS
  # --------------------
  # this is to make sure all the data have the same channels
  all_channels = []
  for events_path in tqdm(tuh_file_list, desc = 'Finding Channels'):
    file_ID = events_path.split('/')[-1][:-4]
    # we use the above to get the file directory this file is in
    pat_file_dir = '/'.join(events_path.split('/')[1:-1])
    # this will download all edf and event files for the selected patient
    download_TUH(DOWNLOAD_DIR, auth_dict, pat_file_dir, output=True)
    
    with pyedflib.EdfReader(DOWNLOAD_DIR+'/'+file_ID+'.edf') as f:
        # get the names of the signals
        all_channels.extend(f.getSignalLabels())

  # turn the list into a pandas series
  all_channels = pd.Series(all_channels)

  # count how many times the channels appear in each participant
  channel_counts = all_channels.value_counts()
  
  # threshold the channels to only those found in all raw data
  channel_keeps = list(channel_counts[channel_counts >= channel_counts[0]].index)
  regex = re.compile('30|PHOTIC|EKG|PG')
  channel_keeps = [i for i in channel_keeps if not regex.search(i)]

  # ---------------
  # CREATE FEATURES
  # ---------------
  for events_path in tqdm(tuh_file_list, desc='Creating Features'):
    file_ID = events_path.split('/')[-1][:-4]
    # we use the above to get the file directory this file is in
    pat_file_dir = '/'.join(events_path.split('/')[1:-1])
    # patient ID
    pat_ID = events_path.split('/')[-3]

    # this will download all edf and event files for the selected patient
    #download_TUH(DOWNLOAD_DIR, auth_dict, pat_file_dir, output=True)

    # load data
    raw_data, freq = data_load(DOWNLOAD_DIR+'/'+file_ID+'.edf', channel_keeps)

    if raw_data.empty:
      print('Skipped: '+file_ID)
    else:
      raw_events = create_events(DOWNLOAD_DIR+'/'+file_ID+'.tse', raw_data)
      # change to integer representation
      raw_events = raw_events.replace(int_code)

      if TUH_FILT_OVERWRITE or TUH_UDWT_OVERWRITE:

        # downsample
        down_x, down_y, down_freq = downsample(raw_data.values, raw_events.values, freq)

        # window y
        window_size = 256*2
        overlap = 256
        
        if TUH_FILT_OVERWRITE:
          # filter the data
          b, a = signal.butter(4, [1/(down_freq/2), 30/(down_freq/2)], 'bandpass', analog=False)
          filt_data = signal.filtfilt(b, a, down_x.T).T
          
          # scale the data over each channel
          SS = StandardScaler()
          scaled_data = SS.fit_transform(filt_data)
          scaled_data = pd.DataFrame(scaled_data, columns = raw_data.columns, index = down_y)
          
          # drop na
          scaled_data = scaled_data.dropna()
          
          # window x
          data_x = window_x(scaled_data, window_size, overlap)
          data_y = window_y(scaled_data.index.values, window_size, overlap, target=None, 
                            baseline=6)
          
          # to stop it printing warnings like: object name is not a valid 
          # Python identifier: '00001113'
          with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=tables.NaturalNameWarning)
            # save the data
            save_to_database(TUH_FILT_SAVE_PATH, 'Filt_Data', pat_ID, data_x, 
                            data_y[:,0], list(scaled_data.columns))
          
        # TODO: add in a dropna bit
        if TUH_UDWT_OVERWRITE:
          # window x
          data_x = window_x(pd.DataFrame(down_x, columns = raw_data.columns), 
                            window_size, overlap)
          data_y = window_y(down_y, window_size, overlap, target=None, 
                          baseline=6)

          udwt_data = udwt_spectrogram(data_x, 'db4', 6, window_size)

          # get the shape of this data
          orig_shape = udwt_data.shape
          # reshape to merge the levels and data
          udwt_data_reshape = np.reshape(udwt_data, (-1, orig_shape[-1]))
          # scale across channels
          SS = StandardScaler()
          udwt_data_scaled = SS.fit_transform(udwt_data_reshape)
          # shape the data back
          udwt_data_scaled = np.reshape(udwt_data_scaled, orig_shape)

          #display(udwt_data_scaled.shape)

          # to stop it printing warnings like: object name is not a valid 
          # Python identifier: '00001113'
          with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=tables.NaturalNameWarning)
            # save the data
            save_to_database(TUH_UDWT_SAVE_PATH, 'UDWT_Data', 
                            pat_ID,
                            #file_ID, # had to use file ID as appending to an existing array was not working 
                            udwt_data_scaled, data_y[:,0])

        
      if TUH_FEAT_OVERWRITE:
        # TODO: in 00001795 relative_power = bandpass_2/bandpass_1 creates an inf
        # due to dividing by 0. Fix this with a checks in the function, but for now,
        # for a quick patch I'll just not use power_ratio.
        feat = Seizure_Features(sf = freq,
                                window_size=2,
                                overlap=1,
                                levels=6,
                                bandpasses = [[1,4],[4,8],[8,12],
                                              [12,30],[30,70]],
                                feature_list=['power', #'power_ratio', 
                                              'mean', 'mean_abs', 
                                              'std', 'ratio', 'LSWT', 'fft_corr', 
                                              'fft_eigen', 'time_corr', 'time_eigen'],
                                scale = True,
                                baseline=6)
        

        
        # just to ignore the runtime warnings about na's
        with warnings.catch_warnings():
          warnings.filterwarnings("ignore", category=RuntimeWarning)
          
          x_feat, y_feat = feat.transform(raw_data.values, 
                                          raw_events.values,
                                          channel_names_list = list(raw_data.columns))

        
        feat_df = pd.DataFrame(x_feat,
                                index=y_feat[:,0], 
                                columns = feat.feature_names)

        feat_df = feat_df.dropna()
        
        
        # to stop it printing warnings like: object name is not a valid 
        # Python identifier: '00001113'
        with warnings.catch_warnings():
          warnings.filterwarnings("ignore", category=tables.NaturalNameWarning)
          feat_df.to_hdf(TUH_FEAT_SAVE_PATH, 
                        pat_ID,
                        format='table', append=True)

  if re.findall('.xlsx|\.edf|\.tse(?!_)', str(link)):


ValueError: Worksheet named 'train' not found

In [None]:
import os
import re
import pyedflib
import numpy as np
import pandas as pd
from scipy import signal
import tables
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import requests
from bs4 import BeautifulSoup
import warnings
# Import the getpass module
from getpass import getpass

def find_files(url, headers):
    # get a soup of the directory url
    soup = BeautifulSoup(requests.get(url, auth=(headers['user'], headers['passwd'])).text, 
                         features="html.parser")
    # make a list of all the links in the url
    hrefs_list = []
    for link in soup.find_all('a'):
        hrefs_list.append(link.get('href'))

    return hrefs_list
    
    
def download_file(download_file_url, file_path, headers, output=False):
    if output:
        # print it is downloading
        print('Downloading: '+ download_file_url)
    # download the file to the directory
    r = requests.get(download_file_url, auth=(headers['user'], headers['passwd']))
    with open(file_path, 'wb') as f:
      f.write(r.content)

# needs a directory to download it to
def download_TUH(DIR, headers, sub_dir, output=False):
    
    # directory url
    dir_url = 'https://www.isip.piconepress.com/projects/tuh_eeg/downloads/tuh_eeg_seizure/v1.5.0/'+sub_dir

    hrefs_dir_list = find_files(dir_url, headers)
    
    # for each link in the directory
    for link in hrefs_dir_list:
        # download the files outside of participant folders we want
        if re.findall('.xlsx|\.edf|\.tse(?!_)', str(link)):
            # if the file doesnt already exist in the directory
            if not os.path.exists(os.path.join(DIR, link)):
                download_file(dir_url+'/'+str(link), DIR+'/'+str(link), headers, output)

def data_load(data_file, selected_channels=[]):

    try:
        # use the reader to get an EdfReader file
        f = pyedflib.EdfReader(data_file)

        # get the names of the signals
        channel_names = f.getSignalLabels()
        # get the sampling frequencies of each signal
        channel_freq = f.getSampleFrequencies()
        
        # get a list of the EEG channels
        if len(selected_channels) == 0:
            selected_channels = channel_names

        # make an empty file of 0's
        sigbufs = np.zeros((f.getNSamples()[0],len(selected_channels)))
        # for each of the channels in the selected channels
        for i, channel in enumerate(selected_channels):
        
            try:
              # add the channel data into the array
              sigbufs[:, i] = f.readSignal(channel_names.index(channel))
            
            except:
              ValueError
              # This happens if the sampling rate of that channel is 
              # different to the others.
              # For simplicity, in this case we just make it na.
              sigbufs[:, i] = np.nan

        # turn to a pandas df and save a little space
        df = pd.DataFrame(sigbufs, columns = selected_channels).astype('float32')

        # get equally increasing numbers upto the length of the data depending
        # on the length of the data divided by the sampling frequency
        index_increase = np.linspace(0,
                                      len(df)/channel_freq[0],
                                      len(df), endpoint=False)

        # round these to the lowest nearest decimal to get the seconds
        seconds = np.floor(index_increase).astype('uint16')

        seconds = index_increase
        
        # make a column the timestamp
        df['Time'] = seconds

        # make the time stamp the index
        df = df.set_index('Time')

        # name the columns as channel
        df.columns.name = 'Channel'

        return df, channel_freq[0]

    except OSError as error:
        print('Error '+data_file+': '+str(error))
        return pd.DataFrame(), None
    except ValueError as error:
        print('Error '+data_file+': '+str(error))
        return pd.DataFrame(), None

def create_events(file_name, df, code = None):

    data_y = pd.Series(index=df.index)
    data_y.name = 'Events'

    events_tse = pd.read_csv(file_name, 
                             skiprows=1,
                             sep = ' ',
                             header=None,
                             names =['Start', 'End', 'Code', 'Certainty'])
    
    data_y = data_y.fillna('bckg')
    
    for pos, row in events_tse.iterrows():
        # if you want to manually set the code
        if code != None:
          if row['Code'] == code:
              data_y[row['Start']:row['End']] = code
        # let it be the code it is in the event file
        else:
          data_y[row['Start']:row['End']] = row['Code']

    return data_y

def window_y(events, window_size, overlap, target=None, baseline=None):
    
  # window the data so each row is another epoch
  events_windowed = window(events, w = window_size, o = overlap, copy = True)
  
  if target:
    # turn to array of bools if seizure in the
    # windowed data
    bools = events_windowed == target
    # are there any seizure seconds in the data?
    data_y = np.any(bools,axis=1)
    # turn to 0's and 1's
    data_y = data_y.astype(int)
    # expand the dimensions so running down one column
    data_y = np.expand_dims(data_y, axis=1)
  
  elif baseline:
    # replace all baseline labels to nan
    data_y = pd.DataFrame(events_windowed).replace(baseline, np.nan)
    # get the most common other than baseline
    data_y = data_y.mode(1)
    # change nan back to baseline class
    data_y = data_y.fillna(baseline).values
    # if there was nothing but baseline there will be an empty array
    if data_y.size == 0:
        data_y = np.array([baseline]*data_y.shape[0])
        data_y = np.expand_dims(data_y, -1)
  
  else:
    # get the value most frequent in the window
    data_y = pd.DataFrame(events_windowed).mode(1).values

  return data_y

def downsample(data_x, data_y, freq):
    if freq > 256:
        if freq >= 1000:
          subsample = 4
        else:
          subsample = 2

        freq = freq/subsample
        data_x = data_x[::subsample]
        data_y = data_y[::subsample]

    return data_x, data_y, freq

def sel_file_list(set_name, seiz_type):
  # load the training information
  train_info = pd.read_excel('dataset/tuh_data/DOCS/seizures_types_v02.xlsx', set_name)
  # just want the info per file here
  file_info = train_info.iloc[1:6101,1:15]
  # cleans some of the names
  file_info_cols = ['File No.', 'Patient', 'Session', 'File', 
                        'EEG Type', 'EEG SubType', 'LTM or Routine', 
                        'Normal/Abnormal', 'No. Seizures File', 
                        'No. Seizures/Session', 'Filename', 'Seizure Start', 
                        'Seizure Stop', 'Seizure Type']
  file_info.columns = file_info_cols

  # we forward fill as there are gaps in the excel file to represent the info 
  # is the same as above (apart from in the filename, seizure start, seizure stop 
  # and seizure type columns)
  for col_name in file_info.columns[:-4]:
    file_info[col_name] = file_info[col_name].ffill()

  # patient ID is an integer rather than float
  file_info['Patient'] = file_info['Patient'].astype(int)

  if seiz_type:
    # Get a list of patient event files that have a specifc type of seizure
    return list(file_info[file_info['Seizure Type']==seiz_type]['Filename'])
  else:
    return list(file_info['Filename'])

def save_to_database(save_dir, file_title, group, data_x, data_y, 
                     feature_columns=None):

    # open the file in append mode (make it if doesnt exist)
    h5file = tables.open_file(save_dir, mode="a", title=file_title)
    
    # save space
    data_x = data_x.astype(np.float32)
    #data_y = data_y.astype(np.int16)
    
    # get filters to compress file
    filters = tables.Filters(complevel=1, complib='zlib')
    
    # if there is already a node for the particpant...
    if "/"+group in h5file:
        # ...put in the directory of where it is found
        part_x_array = h5file.get_node("/" + group + '/Data_x')
        part_y_array = h5file.get_node("/" + group + '/Data_y')
    
    else:
        # create the group directory
        part_group = h5file.create_group("/", group, 'Group Data')
        # make an atom which has the datatype found in the data we want to store
        x_atom = tables.Atom.from_dtype(data_x.dtype)
        y_atom = tables.Atom.from_dtype(data_y.dtype)

        if file_title == 'UDWT_Data':
          shape = (0,data_x.shape[1], data_x.shape[2], data_x.shape[3])
        else:
          shape = (0,data_x.shape[1], data_x.shape[2])
        
        # create an array we can append onto later
        part_x_array = h5file.create_earray("/" + group,                   # parentnode
                                            'Data_x',                        # name 
                                            x_atom,                          # atom
                                            shape, # shape
                                            'Feature Array',
                                            filters=filters
                                           )                 # title

        part_y_array = h5file.create_earray("/" + group, 
                                            'Data_y', 
                                            y_atom, 
                                            (0,),
                                            'Events Array',
                                            filters=filters
                                           )
        
        if feature_columns:
          # create the feature names array (we only need to do this once)
          h5file.create_array("/" + group,                                   # where
                              'Feature_Names',                             # name 
                              np.array(feature_columns, dtype='unicode'),  # obj
                              "Names of Each Feature")                         # title
    
    # append the data to the array directory
    part_x_array.append(data_x)
    part_y_array.append(data_y)
    
    # flush the data to disk
    h5file.flush()
    # close the file
    h5file.close()

def udwt_spectrogram(data, waveletname, level, window_size):
  data_ucwt = np.ndarray(shape=(data.shape[0], level, data.shape[1], data.shape[2]))
  
  for ii in range(data.shape[0]):
    for jj in range(data.shape[-1]):
      signal = data[ii, :, jj]
      coeffs_list = swt(signal, waveletname, level=level)
      
      coeffs_array = np.zeros((len(coeffs_list), window_size))
      
      for i, array_tuple in enumerate(coeffs_list[::-1]):

        coeffs_array[i,:] = np.array(array_tuple)[1,:]
      
      power = np.abs(coeffs_array)**2
      
      data_ucwt[ii, :, :, jj] = power
      
  return data_ucwt


if TUH_FILT_OVERWRITE or TUH_FEAT_OVERWRITE or TUH_UDWT_OVERWRITE:

  if TUH_FILT_OVERWRITE and os.path.exists(TUH_FILT_SAVE_PATH):
    os.remove(TUH_FILT_SAVE_PATH)

  if TUH_FEAT_OVERWRITE and os.path.exists(TUH_FEAT_SAVE_PATH):
    os.remove(TUH_FEAT_SAVE_PATH)

  if TUH_UDWT_OVERWRITE and os.path.exists(TUH_UDWT_SAVE_PATH):
    os.remove(TUH_UDWT_SAVE_PATH)

  # ---------
  # TUH SETUP
  # ---------
  DOWNLOAD_DIR = "TUH Database"
  if not os.path.exists(DOWNLOAD_DIR):
    os.makedirs(DOWNLOAD_DIR)

  user = getpass('TUH Username: ')
  key = getpass('TUH Password: ')
  auth_dict = {'user': user, 'passwd': key}

  # --------------
  # GET FILE PATHS
  # --------------
  # download info files
  download_TUH(DOWNLOAD_DIR, auth_dict, '_DOCS')

  seiz_types_path = 'dataset/tuh_data/DOCS/seizures_types_v02.xlsx'
  seiz_types = pd.read_excel(seiz_types_path)

  seiz_types = seiz_types.set_index('Class Code')

  int_code = seiz_types.to_dict()['Class No.']
  # change to lower case
  int_code = { k.lower() : v for k,v in int_code.items() if not isinstance(k, float)}

  # get a list of files
  tuh_file_list = sel_file_list('train', TUH_code)+sel_file_list('dev_test', TUH_code)

  # get a list of the montages
  montage = []
  for file in tuh_file_list:
    montage.append(file.split('/')[3])

  # count how many times the montages appear in the data
  montage_counts = pd.Series(montage).value_counts()

  # remove all files apart from those in the most common montage
  regex = re.compile(montage_counts.index[0])
  tuh_file_list = [i for i in tuh_file_list if regex.search(i)]
  # remove duplicates
  tuh_file_list = list(set(tuh_file_list))

  # They changed "dev_test" to just "dev" in their file structure so 
  # I need to account for this now...
  for i, string in enumerate(tuh_file_list):
    tuh_file_list[i] = re.sub('_test', '', string)


  # --------------------
  # GET SIMILAR CHANNELS
  # --------------------
  # this is to make sure all the data have the same channels
  all_channels = []
  for events_path in tqdm(tuh_file_list, desc = 'Finding Channels'):
    file_ID = events_path.split('/')[-1][:-4]
    # we use the above to get the file directory this file is in
    pat_file_dir = '/'.join(events_path.split('/')[1:-1])
    # this will download all edf and event files for the selected patient
    download_TUH(DOWNLOAD_DIR, auth_dict, pat_file_dir, output=True)
    
    with pyedflib.EdfReader(DOWNLOAD_DIR+'/'+file_ID+'.edf') as f:
        # get the names of the signals
        all_channels.extend(f.getSignalLabels())

  # turn the list into a pandas series
  all_channels = pd.Series(all_channels)

  # count how many times the channels appear in each participant
  channel_counts = all_channels.value_counts()
  
  # threshold the channels to only those found in all raw data
  channel_keeps = list(channel_counts[channel_counts >= channel_counts[0]].index)
  regex = re.compile('30|PHOTIC|EKG|PG')
  channel_keeps = [i for i in channel_keeps if not regex.search(i)]

  # ---------------
  # CREATE FEATURES
  # ---------------
  for events_path in tqdm(tuh_file_list, desc='Creating Features'):
    file_ID = events_path.split('/')[-1][:-4]
    # we use the above to get the file directory this file is in
    pat_file_dir = '/'.join(events_path.split('/')[1:-1])
    # patient ID
    pat_ID = events_path.split('/')[-3]

    # this will download all edf and event files for the selected patient
    #download_TUH(DOWNLOAD_DIR, auth_dict, pat_file_dir, output=True)

    # load data
    raw_data, freq = data_load(DOWNLOAD_DIR+'/'+file_ID+'.edf', channel_keeps)

    if raw_data.empty:
      print('Skipped: '+file_ID)
    else:
      raw_events = create_events(DOWNLOAD_DIR+'/'+file_ID+'.tse', raw_data)
      # change to integer representation
      raw_events = raw_events.replace(int_code)

      if TUH_FILT_OVERWRITE or TUH_UDWT_OVERWRITE:

        # downsample
        down_x, down_y, down_freq = downsample(raw_data.values, raw_events.values, freq)

        # window y
        window_size = 256*2
        overlap = 256
        
        if TUH_FILT_OVERWRITE:
          # filter the data
          b, a = signal.butter(4, [1/(down_freq/2), 30/(down_freq/2)], 'bandpass', analog=False)
          filt_data = signal.filtfilt(b, a, down_x.T).T
          
          # scale the data over each channel
          SS = StandardScaler()
          scaled_data = SS.fit_transform(filt_data)
          scaled_data = pd.DataFrame(scaled_data, columns = raw_data.columns, index = down_y)
          
          # drop na
          scaled_data = scaled_data.dropna()
          
          # window x
          data_x = window_x(scaled_data, window_size, overlap)
          data_y = window_y(scaled_data.index.values, window_size, overlap, target=None, 
                            baseline=6)
          
          # to stop it printing warnings like: object name is not a valid 
          # Python identifier: '00001113'
          with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=tables.NaturalNameWarning)
            # save the data
            save_to_database(TUH_FILT_SAVE_PATH, 'Filt_Data', pat_ID, data_x, 
                            data_y[:,0], list(scaled_data.columns))
          
        # TODO: add in a dropna bit
        if TUH_UDWT_OVERWRITE:
          # window x
          data_x = window_x(pd.DataFrame(down_x, columns = raw_data.columns), 
                            window_size, overlap)
          data_y = window_y(down_y, window_size, overlap, target=None, 
                          baseline=6)

          udwt_data = udwt_spectrogram(data_x, 'db4', 6, window_size)

          # get the shape of this data
          orig_shape = udwt_data.shape
          # reshape to merge the levels and data
          udwt_data_reshape = np.reshape(udwt_data, (-1, orig_shape[-1]))
          # scale across channels
          SS = StandardScaler()
          udwt_data_scaled = SS.fit_transform(udwt_data_reshape)
          # shape the data back
          udwt_data_scaled = np.reshape(udwt_data_scaled, orig_shape)

          #display(udwt_data_scaled.shape)

          # to stop it printing warnings like: object name is not a valid 
          # Python identifier: '00001113'
          with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=tables.NaturalNameWarning)
            # save the data
            save_to_database(TUH_UDWT_SAVE_PATH, 'UDWT_Data', 
                            pat_ID,
                            #file_ID, # had to use file ID as appending to an existing array was not working 
                            udwt_data_scaled, data_y[:,0])

        
      if TUH_FEAT_OVERWRITE:
        # TODO: in 00001795 relative_power = bandpass_2/bandpass_1 creates an inf
        # due to dividing by 0. Fix this with a checks in the function, but for now,
        # for a quick patch I'll just not use power_ratio.
        feat = Seizure_Features(sf = freq,
                                window_size=2,
                                overlap=1,
                                levels=6,
                                bandpasses = [[1,4],[4,8],[8,12],
                                              [12,30],[30,70]],
                                feature_list=['power', #'power_ratio', 
                                              'mean', 'mean_abs', 
                                              'std', 'ratio', 'LSWT', 'fft_corr', 
                                              'fft_eigen', 'time_corr', 'time_eigen'],
                                scale = True,
                                baseline=6)
        

        
        # just to ignore the runtime warnings about na's
        with warnings.catch_warnings():
          warnings.filterwarnings("ignore", category=RuntimeWarning)
          
          x_feat, y_feat = feat.transform(raw_data.values, 
                                          raw_events.values,
                                          channel_names_list = list(raw_data.columns))

        
        feat_df = pd.DataFrame(x_feat,
                                index=y_feat[:,0], 
                                columns = feat.feature_names)

        feat_df = feat_df.dropna()
        
        
        # to stop it printing warnings like: object name is not a valid 
        # Python identifier: '00001113'
        with warnings.catch_warnings():
          warnings.filterwarnings("ignore", category=tables.NaturalNameWarning)
          feat_df.to_hdf(TUH_FEAT_SAVE_PATH, 
                        pat_ID,
                        format='table', append=True) from getpass import getpass
          
import os
import sys
import os
from bs4 import BeautifulSoup
import requests
import re
import wget
import zipfile


DOWNLOAD_DIR = os.path.expanduser('tuh_data')  # Set a local path

if not os.path.exists(DOWNLOAD_DIR):
  os.makedirs(DOWNLOAD_DIR)

user = "nedc-tuh-eeg"
key = "RLYF8ZhBMZwNnsYA8FsP"

auth_dict = {'user': user, 'passwd': key}

download_TUH(DOWNLOAD_DIR, auth_dict, '', output=True)

  if re.findall('.xlsx|\.edf|\.tse(?!_)', str(link)):


SyntaxError: invalid syntax (483134918.py, line 519)