# Clustering

In [None]:
import collections
from collections import Counter
from fcmeans import FCM
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import pyedflib
import pynwb
from pynwb import NWBFile
from pynwb import TimeSeries
import pywt
import random
import re
import scipy.stats
from scipy.stats import entropy, tstd, tmean
from scipy import signal
import sklearn.cluster
from sklearn.cluster import Birch
from sklearn.cluster import KMeans
from sklearn import metrics
from tabulate import tabulate
import warnings

## Convert Data

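Download the dataset, then set `dataset_root` to the folder that holds the patient's EDF recordings and `summary_path` to the patient's `summary.txt` file.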

In [None]:
# Both paths are placeholders and must be filled in before running the cells below.
# dataset_root is walked recursively by load_data; summary_path is read by get_dataset_stats.
dataset_root = r"" # Path to patient EDF files
summary_path = r"" # Path to patient summary.txt file

def _find_seizure_bounds(summary, labels, samples_per_second):
  """Return the (start, end) sample pair for the first label in *labels* that matches.

  Each label is a line prefix such as ``Seizure`` or ``Seizure 2``; the
  annotated seconds are converted to samples with *samples_per_second*.
  """
  for label in labels:
    start = re.search(r'%s Start Time: (.*?) seconds\n' % label, summary)
    end = re.search(r'%s End Time: (.*?) seconds' % label, summary)
    if start and end:
      return (int(start.group(1)) * samples_per_second,
              int(end.group(1)) * samples_per_second)
  raise ValueError("No seizure start/end annotation found for %s" % " / ".join(labels))


def get_dataset_stats():
  """Parse the patient summary file into per-recording seizure annotations.

  Reads the text file at the module-level ``summary_path``, drops its first two
  blank-line-separated paragraphs and, for every remaining paragraph that has a
  ``File Name:`` line, extracts the file name, the seizure count and the
  seizure intervals.

  Returns:
    An object array of shape ``(n_files, 3)`` whose rows are
    ``[file_name, num_seizures, seizure_durations]``. ``seizure_durations`` is
    a list of ``(start, end)`` tuples: the annotated seconds multiplied by 256.

  Raises:
    ValueError: if a paragraph announces seizures but carries no matching
      start/end annotation.
  """
  samples_per_second = 256

  # Context manager so the file handle is released even if parsing fails.
  with open(summary_path, "r") as summary_file:
    summary_text = summary_file.read()

  parsed_summary = re.split("\n\n", summary_text)[2:]
  # \d+ instead of a single character so counts of 10 or more are not truncated.
  seizure_numbers = re.findall(r'Number of Seizures in File: (\d+)', summary_text)
  max_seizures_file = max((int(x) for x in seizure_numbers), default=0)

  dataset_stats = []
  for summary in parsed_summary:
    file_name = re.search(r'File Name: (.*?)\n', summary)
    if not file_name:
      continue
    file_name = file_name.group(1)

    num_seizures = int(re.search(r'Number of Seizures in File: (\d+)', summary).group(1))

    seizure_durations = []
    for seizure_num in range(1, num_seizures + 1):
      numbered = 'Seizure %i' % seizure_num
      # Summaries whose files hold at most one seizure are tried with the
      # un-numbered form first; otherwise the numbered form takes priority.
      if max_seizures_file < 2:
        labels = ('Seizure', numbered)
      else:
        labels = (numbered, 'Seizure')
      seizure_durations.append(_find_seizure_bounds(summary, labels, samples_per_second))

    dataset_stats.append([file_name, num_seizures, seizure_durations])

  # Fill an object array cell by cell: np.asarray on these ragged rows raises
  # on recent NumPy releases.
  stats = np.empty((len(dataset_stats), 3), dtype=object)
  for row, (name, count, durations) in enumerate(dataset_stats):
    stats[row, 0] = name
    stats[row, 1] = count
    stats[row, 2] = durations
  return stats

# Parse the summary file once; later cells index this array by file name.
# (A stray `return` that followed this call was removed: outside a function it
# is a SyntaxError.)
dataset_stats = get_dataset_stats()


In [None]:
def is_abnormal(index, seg_size, durations):
  """Report whether the window starting at *index* touches any seizure interval.

  Args:
    index: first sample of the window.
    seg_size: window length in samples.
    durations: iterable of (seizure_start, seizure_end) sample pairs.

  Returns:
    True when the window start lies inside a seizure interval or a seizure
    start lies inside the window (bounds inclusive); False otherwise.
  """
  window_first = index
  window_last = index + seg_size
  return any(
      span[0] <= window_first <= span[1] or window_first <= span[0] <= window_last
      for span in durations
  )

def segment(signal, durations):
  """Cut one channel into overlapping windows and label each window.

  Windows are 1280 samples long (annotated as 5 seconds in the original code)
  and start every ``ceil(0.1 * 1280)`` = 128 samples, so consecutive windows
  share 90% of their samples. Each window is reduced with
  ``get_dwt_features`` and labelled via ``is_abnormal``.

  If the stepped windows do not end exactly on the last sample, one extra
  window covering the final ``seg_size`` samples is appended so the tail of
  the signal is not lost.

  Args:
    signal: 1-D sequence of samples for a single channel.
    durations: iterable of (seizure_start, seizure_end) sample pairs.

  Returns:
    (segments, labels): the feature vector of every window and the matching
    labels, 0 for normal and 1 for abnormal.
  """
  seg_size = 1280
  step_fraction = 0.1  # step between window starts, as a fraction of seg_size
  step = math.ceil(step_fraction * seg_size)
  n_samples = len(signal)

  segments = []
  labels = []
  reached_end = False

  index = 0
  while index <= n_samples - seg_size:
    window = signal[index:index + seg_size]
    segments.append(get_dwt_features(window))
    labels.append(1 if is_abnormal(index, seg_size, durations) else 0)

    # The slice end is exclusive, so the window reaches the last sample when
    # index + seg_size equals the signal length (not length - 1).
    if index + seg_size == n_samples:
      reached_end = True
    index += step

  if not reached_end:
    # Label the tail window with the index it really starts at; clamp for
    # signals shorter than one window, where the slice starts at 0.
    tail_start = max(n_samples - seg_size, 0)
    segments.append(get_dwt_features(signal[-seg_size:]))
    labels.append(1 if is_abnormal(tail_start, seg_size, durations) else 0)

  return segments, labels

def load_data(stats, func):
  """Read every annotated EDF file under ``dataset_root`` and segment it.

  Args:
    stats: array whose rows are [file_name, num_seizures, seizure_durations],
      as produced by get_dataset_stats.
    func: accepted for interface compatibility; it is not used here because
      ``segment`` calls ``get_dwt_features`` directly.

  Returns:
    A list with one tuple per matching file:
    ``(file_name, X_data, Y_data, seizure_durations)``. ``X_data`` and
    ``Y_data`` hold the windows and labels of every channel, appended one
    channel after another.
  """
  # Map each file name to the first row that mentions it for O(1) lookups.
  first_row = {}
  for row, file_name in enumerate(stats[:, 0]):
    first_row.setdefault(file_name, row)

  dataset = []
  for root, _dirs, files in os.walk(dataset_root, topdown=False):
    for name in files:
      if name not in first_row:
        continue

      durations = stats[first_row[name]][2]
      X_data = []
      Y_data = []

      reader = pyedflib.EdfReader(os.path.join(root, name))
      try:
        for channel in range(reader.signals_in_file):
          samples = reader.readSignal(channel, digital=True)
          segments, labels = segment(samples, durations)
          X_data += segments
          Y_data += labels
      finally:
        # Release the EDF handle even if reading or segmenting fails.
        reader.close()

      dataset.append((name, X_data, Y_data, durations))
  return dataset

# Build the list of (file_name, X_data, Y_data, seizure_durations) tuples used by the
# evaluation cells; get_dwt_features is passed as load_data's second argument.
dataset = load_data(dataset_stats, get_dwt_features)

In [None]:
def under_sample(data):
  """Balance the classes by randomly dropping normal rows.

  Args:
    data: 2-column array; column 0 holds the feature vector of a window and
      column 1 its label (0 normal, 1 abnormal).

  Returns:
    (X, Y): the features of the kept rows stacked into an array and their
    labels, both in the same shuffled order. All abnormal rows are kept,
    together with an equally sized random subset of the normal rows (or all
    normal rows when there are fewer normal than abnormal ones).
  """
  labels = data[:, 1]
  normal = data[labels == 0]
  abnormal = data[labels == 1]

  # Sampling without replacement cannot draw more rows than exist, so cap the
  # request instead of letting np.random.choice raise.
  keep = min(normal.shape[0], abnormal.shape[0])
  picked = np.random.choice(normal.shape[0], keep, replace=False)

  balanced = np.concatenate((abnormal, normal[picked]))
  np.random.shuffle(balanced)

  return np.asarray(list(balanced[:, 0])), balanced[:, 1]

# Validation Methods

In [None]:
def get_overlap(start_1, end_1, start_2, end_2):
  """Return the samples shared by two intervals as a range.

  The range runs from the later start to the earlier end and is empty when
  the intervals do not intersect.
  """
  later_start = max(start_1, start_2)
  earlier_end = min(end_1, end_2)
  return range(later_start, earlier_end)

def is_overlaping(start_1, end_1, start_2, end_2):
  """Return True when the two intervals share at least one sample."""
  shared = get_overlap(start_1, end_1, start_2, end_2)
  return len(shared) != 0

def calculate_seziure_times(labels, tolerance=5, step=128):
  """Turn per-window labels into (start, end) seizure episodes.

  An episode opens at the first abnormal window and closes once ``tolerance``
  normal windows have been seen since it opened (the normal windows need not
  be consecutive). An episode that is still open when the labels run out is
  closed at the end of the label sequence.

  Args:
    labels: sequence of truthy (abnormal) / falsy (normal) window labels.
    tolerance: number of normal windows that closes an open episode.
    step: samples between consecutive window starts; converts window indices
      to sample positions.

  Returns:
    List of (start_sample, end_sample) tuples.
  """
  start = None
  normal_seen = 0
  seizures = []
  index = -1

  for index, label in enumerate(labels):
    if label:
      # Compare against None: an episode opening at window 0 has start == 0,
      # which a plain truthiness test would treat as "not started".
      if start is None:
        start = index * step
    elif start is not None:
      normal_seen += 1
      if normal_seen == tolerance:
        seizures.append((start, index * step))
        start = None
        normal_seen = 0

  # Do not drop an episode that runs into the end of the recording.
  if start is not None:
    seizures.append((start, (index + 1) * step))

  return seizures

def calculate_latency(seizure_durations, seizure_times):
  """Average delay between annotated seizure onsets and their detections.

  For every annotated interval, the first detected interval that overlaps it
  contributes ``(first shared sample - annotated start) / 256``.

  Args:
    seizure_durations: annotated (start, end) sample pairs.
    seizure_times: detected (start, end) sample pairs.

  Returns:
    The mean of the contributions, or -1 when there are no annotated seizures
    or none of them overlaps a detection.
  """
  if len(seizure_durations) == 0:
    return -1

  delays = []
  for true_start, true_end in seizure_durations:
    for found_start, found_end in seizure_times:
      shared = get_overlap(true_start, true_end, found_start, found_end)
      if len(shared) > 0:
        delays.append((shared[0] - true_start) / 256)
        break  # only the first overlapping detection counts

  if not delays:
    return -1
  return sum(delays) / len(delays)

def calculate_sensitivity(seizure_durations, seizure_times):
  """Fraction of annotated seizures that overlap at least one detection.

  Args:
    seizure_durations: annotated (start, end) sample pairs.
    seizure_times: detected (start, end) sample pairs.

  Returns:
    detected / total annotated seizures, or -1 when there are no annotated
    seizures.
  """
  total = len(seizure_durations)
  if total == 0:
    return -1

  detected = sum(
      1
      for true_start, true_end in seizure_durations
      if any(
          is_overlaping(true_start, true_end, found_start, found_end)
          for found_start, found_end in seizure_times
      )
  )
  return detected / total

def calculate_false_pos_rate(Y_test, Y_pred):
  """Scaled share of normal windows that were predicted abnormal.

  Args:
    Y_test: true window labels (0 normal, non-zero abnormal).
    Y_pred: predicted window labels, aligned with Y_test.

  Returns:
    ``3600 * 256 * (false positives / normal windows)``; 0.0 when Y_test has
    no normal window, since no false positive is possible then.
  """
  normal_total = sum(1 for label in Y_test if label == 0)
  if normal_total == 0:
    # Avoid dividing by zero on recordings without any normal window.
    return 0.0

  false_alarms = sum(
      1 for truth, guess in zip(Y_test, Y_pred) if guess and not truth
  )
  return 3600 * 256 * (false_alarms / normal_total)

def calculate_false_neg_rate(Y_test, Y_pred):
  # NOTE(review): the body appears to be cut off here. As written, N is bound
  # to the builtin `len` itself and the function returns None; no visible cell
  # calls it. Presumably meant to mirror calculate_false_pos_rate - confirm
  # against the original notebook.
  count = 0
  N = len

In [None]:
def evaluate(model, X_train, Y_train, X_test, Y_test, seizure_durations, cluster_assignments, ictal):
  """Score a fitted clustering model on one held-out recording.

  Each cluster is mapped to the row index of the contingency matrix (true
  training labels x cluster assignments) with the largest count in that
  cluster's column; test predictions are translated through that mapping.

  Args:
    model: fitted model exposing ``predict``.
    X_train: unused here; kept for the call signature.
    Y_train: true labels of the training windows.
    X_test: feature vectors of the held-out windows.
    Y_test: true labels of the held-out windows.
    seizure_durations: annotated (start, end) pairs of the held-out recording.
    cluster_assignments: cluster index of every training window.
    ictal: whether the held-out recording contains a seizure.

  Returns:
    (latency, sensitivity, false_pos_rate); latency and sensitivity are -1
    when ``ictal`` is falsy.
  """
  # Training: majority row per cluster column.
  table = metrics.cluster.contingency_matrix(Y_train, cluster_assignments)
  cluster_to_class = np.argmax(table, axis=0)

  # Testing: translate predicted clusters into class labels.
  predicted_clusters = model.predict(X_test)
  labels = [cluster_to_class[cluster] for cluster in predicted_clusters]
  seizure_times = calculate_seziure_times(labels)

  # Evaluation metrics; seizure-based ones only apply to ictal recordings.
  if ictal:
    latency = calculate_latency(seizure_durations, seizure_times)
    sensitivity = calculate_sensitivity(seizure_durations, seizure_times)
  else:
    latency = -1
    sensitivity = -1
  false_pos_rate = calculate_false_pos_rate(Y_test, labels)

  return latency, sensitivity, false_pos_rate

In [None]:
def get_train_test_set(index, test_set, dataset):
  """Split the dataset into one held-out recording and a balanced training set.

  Args:
    index: position of ``test_set`` inside ``dataset``.
    test_set: the held-out (file_name, X_data, Y_data, seizure_durations) tuple.
    dataset: list of all such tuples.

  Returns:
    (X_train, Y_train, X_test, Y_test, seizure_durations), where the training
    windows come from every other recording and are class-balanced with
    ``under_sample``.
  """
  X_test = np.asarray(test_set[1])
  Y_test = test_set[2]
  seizure_durations = test_set[3]

  train_set = dataset[:index] + dataset[index + 1:]
  X_train = [window for recording in train_set for window in recording[1]]
  Y_train = [label for recording in train_set for label in recording[2]]

  # Fill an object array cell by cell: np.array on (feature vector, label)
  # pairs is a ragged construction that raises on recent NumPy releases.
  paired = list(zip(X_train, Y_train))
  pairs = np.empty((len(paired), 2), dtype=object)
  for row, (features, label) in enumerate(paired):
    pairs[row, 0] = features
    pairs[row, 1] = label

  X_train, Y_train = under_sample(pairs)

  return X_train, Y_train, X_test, Y_test, seizure_durations

def evaluate_model_test(model, cmeans=False):
  """Run leave-one-recording-out evaluation and print the averaged metrics.

  Every entry of the module-level ``dataset`` is held out once; the model is
  refitted on the balanced remainder and scored with ``evaluate``. Latency,
  sensitivity and false-positive rate are averaged over recordings with
  seizures, and a separate false-positive rate is averaged over recordings
  without seizures. Negative latency/sensitivity values are counted as 0.

  Args:
    model: clustering model exposing ``fit`` and ``predict``.
    cmeans: when True, training cluster assignments come from
      ``model.predict(X_train)`` instead of ``model.labels_``.
  """
  ictal_folds = 0
  interictal_folds = 0
  latency_total = 0
  sensitivity_total = 0
  ictal_fpr_total = 0
  interictal_fpr_total = 0

  for fold, held_out in enumerate(dataset):
    X_train, Y_train, X_test, Y_test, seizure_durations = get_train_test_set(fold, held_out, dataset)
    has_seizure = len(seizure_durations) > 0

    # Refit on this fold's training windows.
    model.fit(X_train)
    cluster_assignments = model.predict(X_train) if cmeans else model.labels_

    latency, sensitivity, false_pos_rate = evaluate(
        model, X_train, Y_train, X_test, Y_test,
        seizure_durations, cluster_assignments, has_seizure)

    if not has_seizure:
      interictal_folds += 1
      interictal_fpr_total += false_pos_rate
      continue

    ictal_folds += 1
    latency_total += max(0, latency)
    sensitivity_total += max(0, sensitivity)
    ictal_fpr_total += false_pos_rate

  if ictal_folds != 0:
    latency_total /= ictal_folds
    sensitivity_total /= ictal_folds
    ictal_fpr_total /= ictal_folds

  if interictal_folds != 0:
    interictal_fpr_total /= interictal_folds

  row = [latency_total, sensitivity_total, ictal_fpr_total, interictal_fpr_total]
  column_names = ["Latency", "Sensitivity", "False Positive Rate (Ictal)", "False Positive Rate (Interictal)"]
  print(tabulate([row], headers=column_names, tablefmt="pretty"))

# Testing BIRCH

In [None]:
# Suppress every warning from this point on.
warnings.filterwarnings('ignore')

# Number of clusters requested from BIRCH.
ideal_k_birch = 2
birch = Birch(threshold=10, branching_factor=80, n_clusters=ideal_k_birch, compute_labels=True, copy=True)
# Default cmeans=False: training cluster assignments are read from birch.labels_.
evaluate_model_test(birch)

# Testing K-Means

In [None]:
# Number of clusters requested from K-Means.
ideal_k_kmeans = 4
kmeans = KMeans(n_clusters=ideal_k_kmeans)
# Default cmeans=False: training cluster assignments are read from kmeans.labels_.
evaluate_model_test(kmeans)

# Testing Fuzzy C-Means

In [None]:
# Number of clusters requested from fuzzy c-means.
ideal_k_cmeans = 4
cmeans = FCM(n_clusters=ideal_k_cmeans)
# cmeans=True: training cluster assignments come from model.predict(X_train)
# instead of model.labels_.
evaluate_model_test(cmeans, True)