In [1]:
# Run on Colab

!pip install pyedflib

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyedflib
  Downloading pyEDFlib-0.1.30-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m67.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyedflib
Successfully installed pyedflib-0.1.30


In [2]:
import pyedflib
import numpy as np
import pickle
import os

In [3]:
# Run on Colab

from google.colab import drive
drive.mount('/content/drive')
summary_path = "/content/drive/MyDrive/CI_Dataset/Phase3/Summary/"
eeg_files_path = "/content/drive/MyDrive/CI_Dataset/Phase3/EEGFiles/"

Mounted at /content/drive


In [4]:
# # Run on Pycharm

# summary_path = "Summary/"
# eeg_files_path = "EEGFiles/"

In [5]:
def read_file(file_name, noise_flag):
  file_content = pyedflib.EdfReader(eeg_files_path + file_name)

  number_of_signals = file_content.signals_in_file
  signal_labels = file_content.getSignalLabels()

  signal_content = np.zeros((number_of_signals, file_content.getNSamples()[0]))
  for i in np.arange(number_of_signals):
    signal_content[i, :] = file_content.readSignal(i)

  selected_channels = []

  for index in range(len(signal_labels)):
    if signal_labels[index] == "FZ-CZ" or signal_labels[index] == "CZ-PZ":
      selected_channels.append(signal_content[index,:])
  
  if len(selected_channels) != 2:
    return -1
    
  if noise_flag:
    # Check Gap
    for i in range(len(signal_content[0])):
      if signal_content[0][i] == ' ':
        return -1

    # Check Dummy
    for i in range(len(signal_content[0])):
      if signal_content[0][i] == '-':
        return -1

  return np.array(selected_channels)

In [38]:
def get_segmented_and_augmented_data(window_size, augmentation_stride, data):
  segmented_and_augmented_data = [[], []]

  data_dimension = len(data)
  channels = len(data[0])

  for data_dim in range(data_dimension):
    for channel in range(channels):
      data_length = len(data[data_dim][0])
      for start_of_window in range(0, (data_length // 256) - window_size + augmentation_stride, augmentation_stride):
        if (start_of_window + window_size) * 256 < data[data_dim][channel].shape[0]:
          start = start_of_window * 256
          end = (start_of_window + window_size) * 256
          segmented_and_augmented_data[channel].append(data[data_dim][channel][start:end])

  return segmented_and_augmented_data

In [7]:
def get_seizure_and_not_seizure_seizure_period(data_file_name, information, number_of_seizures, noise_flag, window_size, augmentation_stride):
  signal_content = read_file(data_file_name, noise_flag)

  if len(signal_content) == 1:
    if signal_content == -1:
      return -1

  start = information[information.find("File Start Time: ")+len("File Start Time: "):]
  start = int(start[:start.find(":")])

  end = information[information.find("File End Time: ")+len("File End Time: "):]
  end = int(end[:end.find(":")])

  duration = end - start

  off_data_range = 3600 * duration * 5 // 100
  
  seizure_time = information[information.find("Seizure "):]

  seizure_data = []
  not_seizure_before_data = []
  not_seizure_after_data = []

  seizure_segmented_and_augmented_data = [[], []]
  not_seizure_before_segmented_and_augmented_data = [[], []]
  not_seizure_after_segmented_and_augmented_data = [[], []]

  for seizure in range(number_of_seizures):
    start_of_seizures = seizure_time[seizure_time.find("Start Time: ") + len("Start Time: "):]
    start_of_seizures = int(start_of_seizures[:start_of_seizures.find(" seconds")])

    end_of_seizures = seizure_time[seizure_time.find("End Time: ") + len("End Time: "):]
    end_of_seizures = int(end_of_seizures[:end_of_seizures.find(" seconds")])

    if noise_flag:
      if start_of_seizures - (window_size // 2) <= off_data_range:
        seizure_time = seizure_time[seizure_time.find("End Time: " + str(end_of_seizures) + " seconds") + len("End Time: " + str(end_of_seizures) + " seconds\n"):]
        continue
      if end_of_seizures + (window_size // 2) >= 3600 * duration - off_data_range:
        seizure_time = seizure_time[seizure_time.find("End Time: " + str(end_of_seizures) + " seconds") + len("End Time: " + str(end_of_seizures) + " seconds\n"):]
        continue
    
    seizure_data.append(signal_content[:, (start_of_seizures - (window_size // 2)) * 256:(end_of_seizures + (window_size // 2)) * 256])

    not_seizure_before_data.append(signal_content[:, (start_of_seizures - 2 * window_size - (window_size // 2)) * 256:(start_of_seizures + (window_size // 2)) * 256])
    not_seizure_after_data.append(signal_content[:, (end_of_seizures - (window_size // 2)) * 256:(end_of_seizures + 2 * window_size + (window_size // 2)) * 256])

    seizure_time = seizure_time[seizure_time.find("End Time: " + str(end_of_seizures) + " seconds") + len("End Time: " + str(end_of_seizures) + " seconds\n"):]

  result_seizure = get_segmented_and_augmented_data(window_size, augmentation_stride, seizure_data)
  result_not_seizure_before = get_segmented_and_augmented_data(window_size, augmentation_stride, not_seizure_before_data)
  result_not_seizure_after = get_segmented_and_augmented_data(window_size, augmentation_stride, not_seizure_after_data)
    
  seizure_segmented_and_augmented_data[0].append(result_seizure[0])
  seizure_segmented_and_augmented_data[1].append(result_seizure[1])

  not_seizure_before_segmented_and_augmented_data[0].append(result_not_seizure_before[0])
  not_seizure_before_segmented_and_augmented_data[1].append(result_not_seizure_before[1])

  not_seizure_after_segmented_and_augmented_data[0].append(result_not_seizure_after[0])
  not_seizure_after_segmented_and_augmented_data[1].append(result_not_seizure_after[1])

  seizure_count = len(seizure_segmented_and_augmented_data[0][0])
  not_seizure_before_count = len(not_seizure_before_segmented_and_augmented_data[0][0])
  not_seizure_after_count = len(not_seizure_after_segmented_and_augmented_data[0][0])

  seizure_segmented_and_augmented_data = np.array(seizure_segmented_and_augmented_data).reshape(2, seizure_count, 256 * window_size)
  not_seizure_before_segmented_and_augmented_data = np.array(not_seizure_before_segmented_and_augmented_data).reshape(2, not_seizure_before_count, 256 * window_size)
  not_seizure_after_segmented_and_augmented_data = np.array(not_seizure_after_segmented_and_augmented_data).reshape(2, not_seizure_after_count, 256 * window_size)

  return seizure_count, seizure_segmented_and_augmented_data, not_seizure_before_count, not_seizure_before_segmented_and_augmented_data, not_seizure_after_count, not_seizure_after_segmented_and_augmented_data

In [8]:
def get_not_seizure_not_seizure_period(data_file_name, noise_flag, window_size, augmentation_stride):
  signal_content = read_file(data_file_name, noise_flag)

  if len(signal_content) == 1:
    if signal_content == -1:
      return -1

  not_seizure_data = []

  not_seizure_segmented_and_augmented_data = [[], []]

  start = 1770
  end = 1830 + window_size - 1
  
  not_seizure_data.append(signal_content[:, start * 256:end * 256])
    
  result_not_seizure = get_segmented_and_augmented_data(window_size, augmentation_stride, not_seizure_data)
    
  not_seizure_segmented_and_augmented_data[0].append(result_not_seizure[0])
  not_seizure_segmented_and_augmented_data[1].append(result_not_seizure[1])

  not_seizure_count = len(not_seizure_segmented_and_augmented_data[0][0])

  not_seizure_segmented_and_augmented_data = np.array(not_seizure_segmented_and_augmented_data).reshape(2, not_seizure_count, 256 * window_size)

  return not_seizure_count, not_seizure_segmented_and_augmented_data

In [9]:
def read_summary(subject_number):
  if subject_number < 10:
    file_path = summary_path + "chb0" + str(subject_number) + "-summary.txt"
    file_name = "chb0" + str(subject_number)
  else:
    file_path = summary_path + "chb" + str(subject_number) + "-summary.txt"
    file_name = "chb" + str(subject_number)

  with open(file_path) as f:
    content = f.read()
  
  content = content[content.find("File Name"):]

  return file_name, content

In [10]:
def get_data_points(target_data_points_count, seizure_percent=50, before_or_after_percent=6.25, not_seizure_percent=37.5, window_size=5, augmentation_stride=1, noise_flag=False, number_of_files=10):
  seizure_count = 0
  seizure_flag = True
  target_seizure_count = int(target_data_points_count * seizure_percent / 100)
  seizure_data_points = [[], []]

  not_seizure_before_seizure_count = 0
  not_seizure_before_seizure_flag = True
  target_not_seizure_before_seizure_count = int(target_data_points_count * before_or_after_percent / 100)
  not_seizure_before_seizure_data_points = [[], []]

  not_seizure_after_seizure_count = 0
  not_seizure_after_seizure_flag = True
  target_not_seizure_after_seizure_count = int(target_data_points_count * before_or_after_percent / 100)
  not_seizure_after_seizure_data_points = [[], []]

  not_seizure_not_seizure_count = 0
  not_seizure_not_seizure_flag = True
  target_not_seizure_not_seizure_count = int(target_data_points_count * not_seizure_percent / 100)
  not_seizure_not_seizure_data_points = [[], []]

  subject_number = 1

  while (seizure_flag or not_seizure_before_seizure_flag or not_seizure_after_seizure_flag or not_seizure_not_seizure_flag) and subject_number <= 22:
    file_name, content = read_summary(subject_number)

    subject_number += 1

    data_file_number = 1
    data_files_selected = 0

    while data_files_selected < number_of_files and data_file_number < 50:
      # Remove ECG & VNS
      if (subject_number - 1 == 4 and data_file_number >= 7) or (subject_number - 1 == 9 and data_file_number >= 2):
        break
      
      if data_file_number < 10:
        index = content.find("File Name: " + file_name + "_0" + str(data_file_number))
      else:
        index = content.find("File Name: " + file_name + "_" + str(data_file_number))
      
      if index != -1:
        information = content[index:]
        number_of_seizures = information[information.find("Number of Seizures in File: ") + len("Number of Seizures in File: "):]
        number_of_seizures = int(number_of_seizures[:number_of_seizures.find("\n")])
        
        if data_file_number < 10:
          data_file_name = file_name + "_0" + str(data_file_number) + ".edf"
        else:
          data_file_name = file_name + "_" + str(data_file_number) + ".edf"

        # if data_file_name == "chb06_01.edf":
        # if data_file_name == "chb01_03.edf":
        if data_file_name == "chb01_01.edf" or data_file_name == "chb01_02.edf" or data_file_name == "chb01_03.edf" or data_file_name == "chb06_01.edf":
          if number_of_seizures == 0:
            if not_seizure_not_seizure_flag:
              result = get_not_seizure_not_seizure_period(data_file_name, noise_flag, window_size, augmentation_stride)

              if len(result) == 1:
                if result == -1:
                  break
              
              count, data_points = result
              not_seizure_not_seizure_count += count

              for i in range(count):
                  not_seizure_not_seizure_data_points[0].append(data_points[0][i])
                  not_seizure_not_seizure_data_points[1].append(data_points[1][i])

              if not_seizure_not_seizure_count >= target_not_seizure_not_seizure_count:
                not_seizure_not_seizure_flag = False
          else:
            if seizure_flag or not_seizure_before_seizure_flag or not_seizure_after_seizure_flag:
              result = get_seizure_and_not_seizure_seizure_period(data_file_name, information, number_of_seizures, noise_flag, window_size, augmentation_stride)
              
              if len(result) == 1:
                if result == -1:
                  break
              
              count, seizure_data, before_count, before_data, after_count, after_data = result

              if seizure_flag:
                seizure_count += count

                for i in range(count):
                  seizure_data_points[0].append(seizure_data[0][i])
                  seizure_data_points[1].append(seizure_data[1][i])
                
                if seizure_count >= target_seizure_count:
                  seizure_flag = False
              
              if not_seizure_before_seizure_flag:
                not_seizure_before_seizure_count += before_count

                for i in range(before_count):
                  not_seizure_before_seizure_data_points[0].append(before_data[0][i])
                  not_seizure_before_seizure_data_points[1].append(before_data[1][i])
                
                if not_seizure_before_seizure_count >= target_not_seizure_before_seizure_count:
                  not_seizure_before_seizure_flag = False
              
              if not_seizure_after_seizure_flag:
                not_seizure_after_seizure_count += after_count

                for i in range(after_count):
                  not_seizure_after_seizure_data_points[0].append(after_data[0][i])
                  not_seizure_after_seizure_data_points[1].append(after_data[1][i])
                
                if not_seizure_after_seizure_count >= target_not_seizure_after_seizure_count:
                  not_seizure_after_seizure_flag = False
        
        data_files_selected += 1
      data_file_number += 1

  return [np.array(seizure_data_points)[:,:target_seizure_count], np.array(not_seizure_before_seizure_data_points)[:,:target_not_seizure_before_seizure_count], np.array(not_seizure_after_seizure_data_points)[:,:target_not_seizure_after_seizure_count], np.array(not_seizure_not_seizure_data_points)[:,:target_not_seizure_not_seizure_count]], [target_seizure_count, target_not_seizure_before_seizure_count, target_not_seizure_after_seizure_count, target_not_seizure_not_seizure_count]

In [11]:
target_data_points = 160

In [39]:
data_points, data_counts = get_data_points(target_data_points)
class_information = [[80, 1], [80, 0]]
data_set_information = "hello"

In [40]:
print(data_points[3].shape)

(2, 60, 1280)


In [14]:
def get_subjects_report(report_type):
  if report_type == 'gender':
    subject_type_1 = [0, 2, 4, 5, 6, 8, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21]
    subject_type_2 = [1, 3, 7, 9, 14]
  elif report_type == 'age':
    subject_type_1 = [4, 5, 7, 8, 9, 11, 12, 13, 15, 19, 21]
    subject_type_2 = [0, 1, 2, 3, 6, 10, 14, 16, 17, 18, 20]

  return subject_type_1, subject_type_2

In [33]:
def get_data_points_with_constraint(target_data_points_count, constraint_type, window_size=5, augmentation_stride=1):
  # Constraint 1
  count_c1 = 0
  target_count_c1 = target_data_points_count

  seizure_count_c1 = 0
  seizure_flag_c1 = True
  target_seizure_count_c1 = target_count_c1 // 2
  seizure_data_points_c1 = [[], []]

  not_seizure_count_c1 = 0
  not_seizure_flag_c1 = True
  target_not_seizure_count_c1 = target_count_c1 // 2
  not_seizure_data_points_c1 = [[], []]

  # Constraint 2
  count_c2 = 0
  target_count_c2 = target_data_points_count

  seizure_count_c2 = 0
  seizure_flag_c2 = True
  target_seizure_count_c2 = target_count_c2 // 2
  seizure_data_points_c2 = [[], []]

  not_seizure_count_c2 = 0
  not_seizure_flag_c2 = True
  target_not_seizure_count_c2 = target_count_c2 // 2
  not_seizure_data_points_c2 = [[], []]

  subject_type_1, subject_type_2 = get_subjects_report(constraint_type)
  
  number_of_files = 15
  subject_index = 0

  while (seizure_flag_c1 or not_seizure_flag_c1) and subject_index < len(subject_type_1):
    subject = subject_type_1[subject_index] + 1
    subject_index += 1
    file_name, content = read_summary(subject)

    data_file_number = 1
    data_files_selected = 0

    while data_files_selected < number_of_files and data_file_number < 50:
      # Remove ECG & VNS
      if (subject == 4 and data_file_number >= 7) or (subject == 9 and data_file_number >= 2):
        break

      if data_file_number < 10:
        index = content.find("File Name: " + file_name + "_0" + str(data_file_number))
      else:
        index = content.find("File Name: " + file_name + "_" + str(data_file_number))

      if index != -1:
        information = content[index:]
        number_of_seizures = information[information.find("Number of Seizures in File: ") + len("Number of Seizures in File: "):]
        number_of_seizures = int(number_of_seizures[:number_of_seizures.find("\n")])

        if data_file_number < 10:
          data_file_name = file_name + "_0" + str(data_file_number) + ".edf"
        else:
          data_file_name = file_name + "_" + str(data_file_number) + ".edf"

        # if data_file_name == "chb06_01.edf":
        # if data_file_name == "chb01_03.edf":
        if data_file_name == "chb01_01.edf" or data_file_name == "chb01_02.edf" or data_file_name == "chb01_03.edf" or data_file_name == "chb06_01.edf":
          if number_of_seizures == 0:
            if not_seizure_flag_c1:
              result = get_not_seizure_not_seizure_period(data_file_name, False, window_size, augmentation_stride)

              if len(result) == 1:
                if result == -1:
                  break

              count, data_points = result
              not_seizure_count_c1 += count

              for i in range(count):
                  not_seizure_data_points_c1[0].append(data_points[0][i])
                  not_seizure_data_points_c1[1].append(data_points[1][i])

              if not_seizure_count_c1 >= target_not_seizure_count_c1:
                not_seizure_flag_c1 = False
          else:
            if seizure_flag_c1 or not_seizure_flag_c1:
              result = get_seizure_and_not_seizure_seizure_period(data_file_name, information, number_of_seizures, False, window_size, augmentation_stride)

              if len(result) == 1:
                if result == -1:
                  break

              count, seizure_data, before_count, before_data, after_count, after_data = result

              if seizure_flag_c1:
                seizure_count_c1 += count

                for i in range(count):
                  seizure_data_points_c1[0].append(seizure_data[0][i])
                  seizure_data_points_c1[1].append(seizure_data[1][i])

                if seizure_count_c1 >= target_seizure_count_c1:
                  seizure_flag_c1 = False

              if not_seizure_flag_c1:
                not_seizure_count_c1 += before_count + after_count

                for i in range(before_count):
                  not_seizure_data_points_c1[0].append(before_data[0][i])
                  not_seizure_data_points_c1[1].append(before_data[1][i])

                for i in range(after_count):
                  not_seizure_data_points_c1[0].append(after_data[0][i])
                  not_seizure_data_points_c1[1].append(after_data[1][i])

                if not_seizure_count_c1 >= target_not_seizure_count_c1:
                  not_seizure_flag_c1 = False

        data_files_selected += 1
      data_file_number += 1
  
  subject_index = 0
  while (seizure_flag_c2 or not_seizure_flag_c2) and subject_index < len(subject_type_2):
    subject = subject_type_2[subject_index] + 1
    subject_index += 1
    file_name, content = read_summary(subject)

    data_file_number = 1
    data_files_selected = 0

    while data_files_selected < number_of_files and data_file_number < 50:
      # Remove ECG & VNS
      if (subject == 4 and data_file_number >= 7) or (subject == 9 and data_file_number >= 2):
        break

      if data_file_number < 10:
        index = content.find("File Name: " + file_name + "_0" + str(data_file_number))
      else:
        index = content.find("File Name: " + file_name + "_" + str(data_file_number))

      if index != -1:
        information = content[index:]
        number_of_seizures = information[information.find("Number of Seizures in File: ") + len("Number of Seizures in File: "):]
        number_of_seizures = int(number_of_seizures[:number_of_seizures.find("\n")])

        if data_file_number < 10:
          data_file_name = file_name + "_0" + str(data_file_number) + ".edf"
        else:
          data_file_name = file_name + "_" + str(data_file_number) + ".edf"

        # if data_file_name == "chb06_01.edf":
        # if data_file_name == "chb01_03.edf":
        if data_file_name == "chb01_01.edf" or data_file_name == "chb01_02.edf" or data_file_name == "chb01_03.edf" or data_file_name == "chb06_01.edf":
          if number_of_seizures == 0:
            if not_seizure_flag_c2:
              result = get_not_seizure_not_seizure_period(data_file_name, False, window_size, augmentation_stride)

              if len(result) == 1:
                if result == -1:
                  break

              count, data_points = result
              not_seizure_count_c2 += count

              for i in range(count):
                  not_seizure_data_points_c2[0].append(data_points[0][i])
                  not_seizure_data_points_c2[1].append(data_points[1][i])

              if not_seizure_count_c2 >= target_not_seizure_count_c2:
                not_seizure_flag_c2 = False
          else:
            if seizure_flag_c2 or not_seizure_flag_c2:
              result = get_seizure_and_not_seizure_seizure_period(data_file_name, information, number_of_seizures, False, window_size, augmentation_stride)

              if len(result) == 1:
                if result == -1:
                  break

              count, seizure_data, before_count, before_data, after_count, after_data = result

              if seizure_flag_c2:
                seizure_count_c2 += count

                for i in range(count):
                  seizure_data_points_c2[0].append(seizure_data[0][i])
                  seizure_data_points_c2[1].append(seizure_data[1][i])

                if seizure_count_c2 >= target_seizure_count_c2:
                  seizure_flag_c2 = False

              if not_seizure_flag_c2:
                not_seizure_count_c2 += before_count + after_count

                for i in range(before_count):
                  not_seizure_data_points_c2[0].append(before_data[0][i])
                  not_seizure_data_points_c2[1].append(before_data[1][i])

                for i in range(after_count):
                  not_seizure_data_points_c2[0].append(after_data[0][i])
                  not_seizure_data_points_c2[1].append(after_data[1][i])

                if not_seizure_count_c2 >= target_not_seizure_count_c2:
                  not_seizure_flag_c2 = False

        data_files_selected += 1
      data_file_number += 1

  return [np.array(seizure_data_points_c1)[:,:target_seizure_count_c1], np.array(not_seizure_data_points_c1)[:,:target_not_seizure_count_c1]], [np.array(seizure_data_points_c2)[:,:target_seizure_count_c2], np.array(not_seizure_data_points_c2)[:,:target_not_seizure_count_c2]], [target_seizure_count_c1, target_not_seizure_count_c1], [target_seizure_count_c2, target_not_seizure_count_c2]

In [57]:
data_points_c1, data_points_c2, data_counts_c1, data_counts_c2 = get_data_points_with_constraint(80, "gender")
class_information = [[80, 1, [1, 0]], [80, 0, [0, 1]]]
data_set_information = "hello"

In [32]:
print(data_points_c1[0].shape, data_points_c1[1].shape, data_points_c2[0].shape, data_points_c2[1].shape)

(2, 40, 1280) (2, 40, 1280) (2, 40, 1280) (2, 40, 1280)


In [64]:
def to_pickle(input_data_points, input_class_information, input_data_set_information, data_set_number, data_type=None):
  if data_type == "constraint":
    data_points_channel_1 = np.concatenate((input_data_points[0][0], input_data_points[1][0]), axis=0)
    data_points_channel_2 = np.concatenate((input_data_points[0][1], input_data_points[1][1]), axis=0)
  elif data_type == "test":
    data_points_channel_1 = input_data_points[0]
    data_points_channel_2 = input_data_points[1]
  else:
    data_points_channel_1 = np.concatenate((input_data_points[0][0], input_data_points[1][0], input_data_points[2][0], data_points[3][0]), axis=0)
    data_points_channel_2 = np.concatenate((input_data_points[0][1], input_data_points[1][1], input_data_points[2][1], data_points[3][1]), axis=0)
  
  labels = np.concatenate([np.full(shape=input_class_information[class_index][0], fill_value=input_class_information[class_index][1], dtype=int) for class_index in range(len(input_class_information))])
  labels_hot_ones = np.concatenate([np.full(shape=(input_class_information[class_index][0], 2), fill_value=input_class_information[class_index][2], dtype=int) for class_index in range(len(input_class_information))])

  # folder = 'C:/Users/Poorya Sadr/Desktop/EEG_phase3/Datasets/' + str(data_set_number) + '/'
  # if not os.path.isdir(folder):
  #   os.mkdir(folder)

  # file_name = 'data_points_channel_1.pkl'
  # file_path = os.path.join(folder, file_name)
  # pickle.dump(data_points_channel_1, open(file_path, 'wb'))

  # file_name = 'data_points_channel_2.pkl'
  # file_path = os.path.join(folder, file_name)
  # pickle.dump(data_points_channel_2, open(file_path, 'wb'))

  # file_name = 'labels.pkl'
  # file_path = os.path.join(folder, file_name)
  # pickle.dump(labels, open(file_path, 'wb'))

  # file_name = 'labels_hot_ones.pkl'
  # file_path = os.path.join(folder, file_name)
  # pickle.dump(labels_hot_ones, open(file_path, 'wb'))

  # file_name = 'information.txt'
  # file_path = os.path.join(folder, file_name)
  # f = open(file_path, "w")
  # f.write(input_data_set_information)
  # f.close()

In [63]:
to_pickle(data_points, class_information, data_set_information, 1)

[[1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [1 0]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]
 [0 1]

In [54]:
def prepare_test_data():
  data_points_seizure = []
  data_points_not_seizure = []

  seizure_segmented_and_augmented_data = [[], []]
  not_seizure_segmented_and_augmented_data = [[], []]

  window_size = 5
  stride = 1

  count = 1
  subject = 10
  # seizure_file_number = [12, 20, 27, 30, 31, 38, 89]
  # start = [6330, 6888, 2382, 3021, 3801, 4618, 1383]
  # end = [6348, 6958, 2447, 3079, 3877, 4707, 1437]
  # not_seizure_file_number = [21, 22, 28]

  seizure_file_number = [3]
  start = [2996]
  end = [3036]
  not_seizure_file_number = [1, 2]

  for seizure in range(count):
    # if seizure_file_number[seizure] < 10:
    #   data_file_name = "chb10_0" + str(seizure_file_number[seizure]) + ".edf"
    # else:
    #   data_file_name = "chb10_" + str(seizure_file_number[seizure]) + ".edf"
    
    if seizure_file_number[seizure] < 10:
      data_file_name = "chb01_0" + str(seizure_file_number[seizure]) + ".edf"
    else:
      data_file_name = "chb01_" + str(seizure_file_number[seizure]) + ".edf"

    signal_content = read_file(data_file_name, False)
    
    start_of_seizures = start[seizure]
    end_of_seizures = end[seizure]
    
    data_points_seizure.append(signal_content[:, start_of_seizures * 256:end_of_seizures * 256])

  result_seizure = get_segmented_and_augmented_data(window_size, stride, data_points_seizure)
    
  seizure_segmented_and_augmented_data[0].append(result_seizure[0])
  seizure_segmented_and_augmented_data[1].append(result_seizure[1])

  seizure_count = len(seizure_segmented_and_augmented_data[0][0])

  seizure_segmented_and_augmented_data = np.array(seizure_segmented_and_augmented_data).reshape(2, seizure_count, 256 * window_size)

  for not_seizure in range(count):
    # if not_seizure_file_number[not_seizure] < 10:
    #   data_file_name = "chb10_0" + str(not_seizure_file_number[not_seizure]) + ".edf"
    # else:
    #   data_file_name = "chb10_" + str(not_seizure_file_number[not_seizure]) + ".edf"
    
    if not_seizure_file_number[not_seizure] < 10:
      data_file_name = "chb01_0" + str(not_seizure_file_number[not_seizure]) + ".edf"
    else:
      data_file_name = "chb01_" + str(not_seizure_file_number[not_seizure]) + ".edf"

    signal_content = read_file(data_file_name, False)
    
    start = 1770
    end = 1830
    
    data_points_not_seizure.append(signal_content[:, start * 256:end * 256])

  result_not_seizure = get_segmented_and_augmented_data(window_size, stride, data_points_not_seizure)
    
  not_seizure_segmented_and_augmented_data[0].append(result_not_seizure[0])
  not_seizure_segmented_and_augmented_data[1].append(result_not_seizure[1])

  not_seizure_count = len(not_seizure_segmented_and_augmented_data[0][0])

  not_seizure_segmented_and_augmented_data = np.array(not_seizure_segmented_and_augmented_data).reshape(2, not_seizure_count, 256 * window_size)

  return np.array(seizure_segmented_and_augmented_data), seizure_count, np.array(not_seizure_segmented_and_augmented_data), not_seizure_count

  # pickle.dump(new_seizure, open('seizure_test_data_points.pkl', 'wb'))
  # pickle.dump(new_not_seizure, open('not_seizure_test_data_points.pkl', 'wb'))

In [53]:
test_seizure_data_points, test_seizure_count, test_not_seizure_data_points, test_not_seizure_count = prepare_test_data()

(array([[[  42.78388278,   35.75091575,   25.59218559, ...,
          -183.44322344, -195.16483516, -206.10500611],
         [ -20.9035409 ,  -22.85714286,  -25.59218559, ...,
           -17.77777778,  -17.38705739,  -11.91697192],
         [  61.53846154,   53.33333333,   40.43956044, ...,
          -281.51404151, -277.21611722, -279.16971917],
         ...,
         [ -69.35286935,  -65.44566545,  -60.36630037, ...,
           176.01953602,  172.5030525 ,  165.07936508],
         [-213.52869353, -227.2039072 , -238.92551893, ...,
            27.15506716,   22.85714286,   18.94993895],
         [   2.14896215,    1.36752137,    0.97680098, ...,
            23.63858364,   23.24786325,   24.02930403]],
 
        [[ 113.89499389,  114.67643468,  113.5042735 , ...,
          -176.41025641, -170.94017094, -162.34432234],
         [  16.60561661,   21.29426129,   20.51282051, ...,
            25.2014652 ,   32.23443223,   40.83028083],
         [  46.3003663 ,   42.002442  ,   33.01587302, 