<a href="https://colab.research.google.com/github/naomifelleke/Seizure-Detection/blob/main/Seizure_Detection_Algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
DRIVE_MOUNTPOINT = '/content/drive'  # Drive contents appear here; the dataset zip is read from MyDrive below
drive.mount(DRIVE_MOUNTPOINT)

Mounted at /content/drive


In [None]:
## Downloading Data

In [None]:
# Clone the AITRICS repository; its `builder.utils` package is imported by the preprocessing cell below.
!git clone https://github.com/AITRICS/EEG_real_time_seizure_detection.git


In [None]:
# Work from inside the clone so the `builder.utils...` imports resolve (assumes the kernel started in /content).
%cd EEG_real_time_seizure_detection


In [None]:
# Sanity check: list the repository contents.
!ls


In [None]:
from google.colab import files

# Opens a browser upload prompt; the uploaded files are returned as a dict of name -> bytes.
# NOTE(review): presumably this is how the requirements file installed in the next cell arrives — confirm.
uploaded = files.upload()

In [None]:
# Install the project's Python dependencies.
# NOTE(review): "requirements (1).txt" looks like the name an uploaded duplicate received; verify the file name before re-running.
!pip install -r requirements\ \(1\).txt


In [None]:
# pyedflib provides `highlevel.read_edf`, used to load the EDF recordings in the preprocessing cell.
# libpython3-dev presumably supplies the Python headers in case pyedflib has to be compiled — confirm it is still needed.
!apt-get install -y libpython3-dev
!pip install --upgrade pip
!pip install --no-cache-dir pyedflib

In [None]:
# MNE is imported by the preprocessing cell (`mne`, `mne.io.edf.edf`).
!pip install mne


In [None]:
# Extract the TUH seizure corpus from Drive into /content/tuh_eeg_seizure;
# `eeg_data_directory` in the preprocessing cell must point inside this directory.
!unzip /content/drive/MyDrive/tuh_eeg_seizure.zip -d /content/tuh_eeg_seizure


In [None]:
## Data Preprocessing

In [None]:
# -*- coding: utf-8 -*-
# Copyright (c) 2022, Kwanhyung Lee, AITRICS. All rights reserved.
#
# Licensed under the MIT License;
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pyedflib import highlevel, EdfReader
from scipy.io.wavfile import write
from scipy import signal as sci_sig
from scipy.spatial.distance import pdist
from scipy.signal import stft, hilbert, butter, freqz, filtfilt, find_peaks
from builder.utils.process_util import run_multi_process
from builder.utils.utils import search_walk
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import os
import argparse
import torch
import glob
import pickle
import random
import mne
from mne.io.edf.edf import _read_annotations_edf, _read_edf_header
from itertools import groupby

# GLOBAL_DATA holds the configuration read by the generate_* functions below;
# the other three dicts are not referenced anywhere in this cell.
GLOBAL_DATA, label_dict, sample_rate_dict, sev_label = dict(), dict(), dict(), dict()


def label_sampling_tuh(labels, feature_samplerate, label_map=None):
    """Expand TSE annotation lines into one label character per feature step.

    Each annotation line is ``"<begin> <end> <label> ..."`` with times in
    seconds. A line contributes ``(end - begin) * feature_samplerate`` copies of
    the string form of its class id. The fractional remainder of each segment
    is carried into the next one, so rounding does not drift over a long
    recording.

    Args:
        labels: iterable of annotation lines (header already stripped).
            Blank or malformed lines with fewer than three whitespace-separated
            fields are skipped.
        feature_samplerate: number of label characters per second.
        label_map: mapping from label name to class id. Defaults to
            ``GLOBAL_DATA['disease_labels']``. Ids are written with ``str()``,
            so only single-digit ids keep one character per step.

    Returns:
        str: the concatenated per-step class ids, e.g. ``"000111"``.
    """
    if label_map is None:
        label_map = GLOBAL_DATA['disease_labels']
    feature_intv = 1 / float(feature_samplerate)
    pieces = []
    remained = 0
    for line in labels:
        fields = line.split()
        if len(fields) < 3:
            # e.g. a trailing blank line at the end of a .tse / .tse_bi file
            continue
        begin, end, label = fields[:3]
        intv_count, remained = divmod(float(end) - float(begin) + remained, feature_intv)
        pieces.append(int(intv_count) * str(label_map[label]))
    return "".join(pieces)


def generate_training_data_leadwise_tuh_train(file):
    """Slice one TUH EDF recording into background / seizure-centred segments and pickle them.

    Steps:
      1. Read the EDF ``file`` and its annotation file
         ``<file without extension>.<GLOBAL_DATA['args.label_type']>``, then expand
         the annotations to ``GLOBAL_DATA['feature_sample_rate']`` labels per second.
      2. Keep the leads in ``GLOBAL_DATA['label_list']`` (in that order),
         resampling any lead recorded above ``GLOBAL_DATA['sample_rate']``.
         Trim or pad the label string to the signal length.
      3. Remap labels: ``selected_diseases`` go through ``target_dictionary`` and
         everything else becomes background ``"0"``. Cut the recording into runs:
         - background runs of at least ``min_binary_slicelength`` s;
         - seizure runs, extended with up to ``max_bckg_before_slicelength`` s of
           preceding background and ``max_bckg_after_seiz_length`` s of following
           data, then capped at ``max_binary_slicelength`` s.
      4. Pickle each segment to ``GLOBAL_DATA['data_file_directory']`` as
         ``<recording>_c<i>_pre<pre-bckg labels>_len<labels>_label_<label>.pkl``.

    Returns ``None`` without writing anything when:
      - the first channel's sample rate is below the target rate,
      - a required lead is missing, or
      - the annotation file has no annotation lines.

    Args:
        file: path to an ``.edf`` recording.
    """
    sample_rate = GLOBAL_DATA['sample_rate']    # EX) 200Hz
    feature_sample_rate = GLOBAL_DATA['feature_sample_rate']
    fsr_sr_ratio = GLOBAL_DATA['fsr_sr_ratio']  # raw samples per label step
    file_name = ".".join(file.split(".")[:-1])  # path without extension, EX) .../00007235/s003_2010_11_20/00007235_s003_t000
    data_file_name = file_name.split("/")[-1]   # EX) 00007235_s003_t000
    signals, signal_headers, header = highlevel.read_edf(file)
    # EX) "EEG FP1-REF" or "EEG FP1-LE" --> "EEG FP1"
    label_list_c = [signal_headers[idx]['label'].split("-")[0] for idx in range(len(signals))]

    ############################# part 1: labeling  ###############################
    # EX) 00007235_s003_t003.tse or 00007235_s003_t003.tse_bi; skip the first two (header) lines
    with open(file_name + "." + GLOBAL_DATA['args.label_type'], 'r') as label_file:
        y = label_file.readlines()[2:]
    signal_sample_rate = int(signal_headers[0]['sample_rate'])
    if sample_rate > signal_sample_rate:
        return
    if not all(elem in label_list_c for elem in GLOBAL_DATA['label_list']):  # a required lead is missing
        return
    y_sampled = label_sampling_tuh(y, feature_sample_rate)
    if not y_sampled:
        # No annotation lines: nothing to pad the labels with or to slice on.
        return

    ############################# part 2: input data filtering #############################
    signal_list = []
    signal_label_list = []
    for idx, signal in enumerate(signals):
        label = signal_headers[idx]['label'].split("-")[0]
        if label not in GLOBAL_DATA['label_list']:
            continue
        channel_rate = int(signal_headers[idx]['sample_rate'])
        if channel_rate > sample_rate:
            # Duration must be computed from this channel's own rate, not channel 0's.
            secs = len(signal) / float(channel_rate)
            signal = sci_sig.resample(signal, int(secs * sample_rate))
        signal_list.append(signal)
        signal_label_list.append(label)

    if len(signal_label_list) != len(GLOBAL_DATA['label_list']):
        print("Not enough labels: ", signal_label_list)
        return

    # Put the leads into the canonical GLOBAL_DATA['label_list'] order.
    signal_final_list_raw = [signal_list[signal_label_list.index(lead)] for lead in GLOBAL_DATA['label_list']]

    # Align the label string with the signal length (measured in label steps).
    # int() is required: a float cannot be used as a slice index.
    new_length = int(len(signal_final_list_raw[0]) * (float(feature_sample_rate) / sample_rate))
    if len(y_sampled) > new_length:
        y_sampled = y_sampled[:new_length]
    elif len(y_sampled) < new_length:
        y_sampled += y_sampled[-1] * (new_length - len(y_sampled))

    ############################# part 3: slicing for easy training  #############################
    selected = GLOBAL_DATA['selected_diseases']
    y_sampled = ["0" if l not in selected else l for l in y_sampled]
    if any(l in selected for l in y_sampled):
        y_sampled = [str(GLOBAL_DATA['target_dictionary'][int(l)]) if l in selected else l for l in y_sampled]

    # Run boundaries are computed on the *remapped* labels so that
    # label_change_idxs[k] is the last index of run label_order[k].
    # The final index is appended so the last run also has an end
    # (otherwise a recording ending in a seizure indexes past the array).
    y_sampled_np = np.array(list(map(int, y_sampled)))
    label_change_idxs = np.where(y_sampled_np[:-1] != y_sampled_np[1:])[0]
    label_change_idxs = np.append(label_change_idxs, len(y_sampled_np) - 1)
    label_order = [x[0] for x in groupby(y_sampled)]

    raw_data = torch.tensor(np.stack(signal_final_list_raw), dtype=torch.float32).permute(1, 0)  # (time, lead)

    max_seg_len_before_seiz_label = GLOBAL_DATA['max_bckg_before_slicelength'] * feature_sample_rate
    max_seg_len_after_seiz_label = GLOBAL_DATA['max_bckg_after_seiz_length'] * feature_sample_rate
    max_seg_len_after_seiz_raw = GLOBAL_DATA['max_bckg_after_seiz_length'] * sample_rate
    min_seg_len_label = GLOBAL_DATA['min_binary_slicelength'] * feature_sample_rate
    min_seg_len_raw = GLOBAL_DATA['min_binary_slicelength'] * sample_rate
    max_seg_len_label = GLOBAL_DATA['max_binary_slicelength'] * feature_sample_rate
    max_seg_len_raw = GLOBAL_DATA['max_binary_slicelength'] * sample_rate

    sliced_raws = []
    sliced_labels = []
    pre_bckg_lens_label = []
    label_list_for_filename = []

    def _label_tensor(chars):
        # label characters -> uint8 tensor of class ids
        return torch.Tensor(list(map(int, chars))).byte()

    def _keep(raw, y_tensor, pre_len, run_label):
        # record one segment to be pickled below
        sliced_raws.append(raw)
        sliced_labels.append(y_tensor)
        pre_bckg_lens_label.append(pre_len)
        label_list_for_filename.append(run_label)

    start_raw_idx = 0
    start_label_idx = 0
    previous_bckg_len = 0

    for idx, label in enumerate(label_order):
        is_last = (idx + 1 == len(label_order))

        if label == "0" and is_last:
            # trailing background: keep the remainder if it is long enough
            sliced_y1 = _label_tensor(y_sampled[start_label_idx:])
            if sliced_y1.size(0) < min_seg_len_label:
                continue
            _keep(raw_data[start_raw_idx:].permute(1, 0), sliced_y1, 0, label)

        elif label == "0":
            # background followed by a seizure: keep it and remember its length
            end_raw_idx = (label_change_idxs[idx] + 1) * fsr_sr_ratio
            end_label_idx = label_change_idxs[idx] + 1
            sliced_raw_data = raw_data[start_raw_idx:end_raw_idx].permute(1, 0)
            sliced_y1 = _label_tensor(y_sampled[start_label_idx:end_label_idx])
            previous_bckg_len = end_label_idx - start_label_idx
            start_raw_idx = end_raw_idx
            start_label_idx = end_label_idx
            if sliced_y1.size(0) < min_seg_len_label:
                continue
            _keep(sliced_raw_data, sliced_y1, 0, label)

        elif idx == 0:
            # recording starts with a seizure ("seiz" 1 ~ 8): keep it plus a post-ictal tail
            end_raw_idx = (label_change_idxs[idx] + 1) * fsr_sr_ratio
            end_label_idx = label_change_idxs[idx] + 1
            remaining = len(y_sampled) - end_label_idx
            if remaining > max_seg_len_after_seiz_label:
                post_len_label = max_seg_len_after_seiz_label
                post_len_raw = max_seg_len_after_seiz_raw
            else:
                post_len_label = remaining
                post_len_raw = remaining * fsr_sr_ratio
            post_ictal_end_label = end_label_idx + post_len_label
            post_ictal_end_raw = end_raw_idx + post_len_raw

            start_raw_idx = end_raw_idx
            start_label_idx = end_label_idx
            if len(y_sampled) < min_seg_len_label:
                continue

            sliced_y1 = _label_tensor(y_sampled[:post_ictal_end_label])
            if sliced_y1.size(0) > max_seg_len_label:
                _keep(raw_data[:post_ictal_end_raw][:max_seg_len_raw].permute(1, 0),
                      sliced_y1[:max_seg_len_label], 0, label)
            elif sliced_y1.size(0) >= min_seg_len_label:
                _keep(raw_data[:post_ictal_end_raw].permute(1, 0), sliced_y1, 0, label)
            else:
                # too short: take the minimum length from the start of the recording
                _keep(raw_data[:min_seg_len_raw].permute(1, 0),
                      _label_tensor(y_sampled[:min_seg_len_label]), 0, label)

        else:
            # seizure ("seiz" 1 ~ 8) after another run: prepend up to
            # max_seg_len_before_seiz_label of the preceding background and append a post-ictal tail
            end_raw_idx = (label_change_idxs[idx] + 1) * fsr_sr_ratio
            end_label_idx = label_change_idxs[idx] + 1
            remaining = len(y_sampled) - end_label_idx
            if remaining > max_seg_len_after_seiz_label:
                post_len_label = max_seg_len_after_seiz_label
                post_len_raw = max_seg_len_after_seiz_raw
            else:
                post_len_label = remaining
                post_len_raw = remaining * fsr_sr_ratio
            post_ictal_end_label = end_label_idx + post_len_label
            post_ictal_end_raw = end_raw_idx + post_len_raw

            pre_seiz_label_len = min(previous_bckg_len, max_seg_len_before_seiz_label)
            pre_seiz_raw_len = pre_seiz_label_len * fsr_sr_ratio

            sample_len = post_ictal_end_label - (start_label_idx - pre_seiz_label_len)
            if sample_len < min_seg_len_label:
                post_ictal_end_label = start_label_idx - pre_seiz_label_len + min_seg_len_label
                post_ictal_end_raw = start_raw_idx - pre_seiz_raw_len + min_seg_len_raw
            if len(y_sampled) < post_ictal_end_label:
                start_raw_idx = end_raw_idx
                start_label_idx = end_label_idx
                continue

            sliced_raw_data = raw_data[start_raw_idx - pre_seiz_raw_len:post_ictal_end_raw].permute(1, 0)
            sliced_y1 = _label_tensor(y_sampled[start_label_idx - pre_seiz_label_len:post_ictal_end_label])
            if sliced_y1.size(0) > max_seg_len_label:
                _keep(sliced_raw_data.permute(1, 0)[:max_seg_len_raw].permute(1, 0),
                      sliced_y1[:max_seg_len_label], pre_seiz_label_len, label)
            else:
                _keep(sliced_raw_data, sliced_y1, pre_seiz_label_len, label)
            start_raw_idx = end_raw_idx
            start_label_idx = end_label_idx

    for data_idx, (sliced_raw, sliced_y) in enumerate(zip(sliced_raws, sliced_labels)):
        sliced_y_map = list(map(int, sliced_y))
        binary_target1 = GLOBAL_DATA['binary_target1']
        binary_target2 = GLOBAL_DATA['binary_target2']
        sliced_y2 = torch.Tensor([binary_target1[i] for i in sliced_y_map]).byte() if binary_target1 is not None else None
        sliced_y3 = torch.Tensor([binary_target2[i] for i in sliced_y_map]).byte() if binary_target2 is not None else None

        new_data = {
            'RAW_DATA': [sliced_raw],
            'LABEL1': [sliced_y],
            'LABEL2': [sliced_y2],
            'LABEL3': [sliced_y3],
        }
        prelabel_len = pre_bckg_lens_label[data_idx]
        label = label_list_for_filename[data_idx]
        out_path = GLOBAL_DATA['data_file_directory'] + "/{}_c{}_pre{}_len{}_label_{}.pkl".format(
            data_file_name, str(data_idx), str(prelabel_len), str(len(sliced_y)), str(label))
        with open(out_path, 'wb') as _f:
            pickle.dump(new_data, _f)

def generate_training_data_leadwise_tuh_train_final(file):
    """Cut one TUH EDF recording into fixed-length training windows and pickle them.

    The recording's labels are expanded to ``GLOBAL_DATA['feature_sample_rate']``
    per second and the leads in ``GLOBAL_DATA['label_list']`` are kept
    (resampled above ``GLOBAL_DATA['sample_rate']``).

    The recording is then split into consecutive windows of
    ``GLOBAL_DATA['min_binary_slicelength']`` seconds, and each window is tagged:
      - pure background: ``0_patT`` / ``0_patF``, depending on whether any
        ``.tse_bi`` file under the patient directory (two levels above the
        recording) contains a non-``bckg`` event;
      - otherwise: ``<max label>_beg`` / ``_middle`` / ``_whole`` / ``_end``,
        according to where the seizure sits in the window.

    If the last window ended inside a seizure, the final
    ``min_binary_slicelength`` seconds are also saved as ``_end`` when they
    end in background.

    Raw data is stored as float16. Files are written to
    ``GLOBAL_DATA['data_file_directory']`` as ``<recording>_c<i>_label_<tag>.pkl``.

    Returns ``None`` without writing anything when:
      - the first channel's sample rate is below the target rate,
      - a required lead is missing, or
      - the recording is shorter than one window.

    Args:
        file: path to an ``.edf`` recording.
    """
    sample_rate = GLOBAL_DATA['sample_rate']    # EX) 200Hz
    feature_sample_rate = GLOBAL_DATA['feature_sample_rate']
    file_name = ".".join(file.split(".")[:-1])  # path without extension, EX) .../072/00007235/s003_2010_11_20/00007235_s003_t000
    data_file_name = file_name.split("/")[-1]   # EX) 00007235_s003_t000
    signals, signal_headers, header = highlevel.read_edf(file)
    # EX) "EEG FP1-REF" or "EEG FP1-LE" --> "EEG FP1"
    label_list_c = [signal_headers[idx]['label'].split("-")[0] for idx in range(len(signals))]

    ############################# part 1: labeling  ###############################
    # EX) 00007235_s003_t003.tse or 00007235_s003_t003.tse_bi; skip the first two (header) lines
    with open(file_name + "." + GLOBAL_DATA['args.label_type'], 'r') as label_file:
        y = label_file.readlines()[2:]
    signal_sample_rate = int(signal_headers[0]['sample_rate'])
    if sample_rate > signal_sample_rate:
        return
    if not all(elem in label_list_c for elem in GLOBAL_DATA['label_list']):  # a required lead is missing
        return
    y_sampled = label_sampling_tuh(y, feature_sample_rate)
    if not y_sampled:
        # No annotation lines: nothing to pad the labels with or to slice on.
        return

    # Patient-level flag: does any .tse_bi file under the patient directory contain a non-background event?
    patient_wise_dir = "/".join(file_name.split("/")[:-2])  # EX) .../072/00007235
    patient_bool = False
    for tse_bi_file in search_walk({'path': patient_wise_dir, 'extension': ".tse_bi"}):
        with open(tse_bi_file, 'r') as bi_file:
            bi_lines = bi_file.readlines()[2:]
        if any(len(line) > 5 and line.split(" ")[2] != 'bckg' for line in bi_lines):
            patient_bool = True
            break

    ############################# part 2: input data filtering #############################
    signal_list = []
    signal_label_list = []
    for idx, signal in enumerate(signals):
        label = signal_headers[idx]['label'].split("-")[0]
        if label not in GLOBAL_DATA['label_list']:
            continue
        channel_rate = int(signal_headers[idx]['sample_rate'])
        if channel_rate > sample_rate:
            # Duration must be computed from this channel's own rate, not channel 0's.
            secs = len(signal) / float(channel_rate)
            signal = sci_sig.resample(signal, int(secs * sample_rate))
        signal_list.append(signal)
        signal_label_list.append(label)

    if len(signal_label_list) != len(GLOBAL_DATA['label_list']):
        print("Not enough labels: ", signal_label_list)
        return

    # Put the leads into the canonical GLOBAL_DATA['label_list'] order.
    signal_final_list_raw = [signal_list[signal_label_list.index(lead)] for lead in GLOBAL_DATA['label_list']]

    # Align the label string with the signal length (measured in label steps).
    # int() is required: a float cannot be used as a slice index.
    new_length = int(len(signal_final_list_raw[0]) * (float(feature_sample_rate) / sample_rate))
    if len(y_sampled) > new_length:
        y_sampled = y_sampled[:new_length]
    elif len(y_sampled) < new_length:
        y_sampled += y_sampled[-1] * (new_length - len(y_sampled))

    ############################# part 3: slicing for easy training  #############################
    selected = GLOBAL_DATA['selected_diseases']
    y_sampled = ["0" if l not in selected else l for l in y_sampled]
    if any(l in selected for l in y_sampled):
        y_sampled = [str(GLOBAL_DATA['target_dictionary'][int(l)]) if l in selected else l for l in y_sampled]

    raw_data = torch.tensor(np.stack(signal_final_list_raw), dtype=torch.float32).permute(1, 0)  # (time, lead)
    raw_data = raw_data.type(torch.float16)

    min_seg_len_label = GLOBAL_DATA['min_binary_slicelength'] * feature_sample_rate
    min_seg_len_raw = GLOBAL_DATA['min_binary_slicelength'] * sample_rate
    min_binary_edge_seiz_label = GLOBAL_DATA['min_binary_edge_seiz'] * feature_sample_rate
    min_binary_edge_seiz_raw = GLOBAL_DATA['min_binary_edge_seiz'] * sample_rate

    if len(y_sampled) < min_seg_len_label:
        return

    sliced_raws = []
    sliced_labels = []
    label_list_for_filename = []
    # untouched copies, used for the trailing "_end" window after the loop
    y_sampled_2nd = list(y_sampled)
    raw_data_2nd = raw_data

    while len(y_sampled) >= min_seg_len_label:
        is_at_middle = False
        sliced_y = y_sampled[:min_seg_len_label]
        labels = [x[0] for x in groupby(sliced_y)]
        single_run = len(labels) == 1
        starts_bckg = sliced_y[0] == '0'
        ends_bckg = sliced_y[-1] == '0'
        seiz_label = str(max(map(int, labels)))
        sliced_raw_data = raw_data[:min_seg_len_raw].permute(1, 0)

        if single_run and starts_bckg:
            # pure background window
            tag = "0_patT" if patient_bool else "0_patF"
        elif not single_run and starts_bckg and not ends_bckg:
            # Seizure onset inside the window. If only a short seizure tail is
            # included, shift the window forward by min_binary_edge_seiz seconds
            # (when the recording is long enough) so it contains more of the seizure.
            boundary_seizlen = sliced_y[::-1].index("0") + 1
            if (boundary_seizlen < min_binary_edge_seiz_label
                    and len(y_sampled) > (min_seg_len_label + min_binary_edge_seiz_label)):
                sliced_y = y_sampled[min_binary_edge_seiz_label:min_seg_len_label + min_binary_edge_seiz_label]
                sliced_raw_data = raw_data[min_binary_edge_seiz_raw:min_seg_len_raw + min_binary_edge_seiz_raw].permute(1, 0)
            tag = seiz_label + "_beg"
            is_at_middle = True
        elif not single_run and not starts_bckg and not ends_bckg:
            tag = seiz_label + "_whole"
            is_at_middle = True
        elif single_run:
            # the whole window is one seizure run
            tag = seiz_label + "_middle"
            is_at_middle = True
        elif not starts_bckg and ends_bckg:
            tag = seiz_label + "_end"
        else:
            # starts and ends with background, seizure inside
            tag = seiz_label + "_whole"

        sliced_raws.append(sliced_raw_data)
        sliced_labels.append(sliced_y)
        label_list_for_filename.append(tag)
        # windows always advance by exactly one window length
        y_sampled = y_sampled[min_seg_len_label:]
        raw_data = raw_data[min_seg_len_raw:]

    if is_at_middle:
        # The last window stopped inside a seizure: also keep the final window if it ends in background.
        sliced_y = y_sampled_2nd[-min_seg_len_label:]
        sliced_raw_data = raw_data_2nd[-min_seg_len_raw:].permute(1, 0)
        if sliced_y[-1] == '0':
            sliced_raws.append(sliced_raw_data)
            sliced_labels.append(sliced_y)
            label_list_for_filename.append(str(max(map(int, labels))) + "_end")

    for data_idx, (sliced_raw, sliced_chars) in enumerate(zip(sliced_raws, sliced_labels)):
        sliced_y_map = list(map(int, sliced_chars))
        sliced_y = torch.Tensor(sliced_y_map).byte()
        binary_target1 = GLOBAL_DATA['binary_target1']
        binary_target2 = GLOBAL_DATA['binary_target2']
        sliced_y2 = torch.Tensor([binary_target1[i] for i in sliced_y_map]).byte() if binary_target1 is not None else None
        sliced_y3 = torch.Tensor([binary_target2[i] for i in sliced_y_map]).byte() if binary_target2 is not None else None

        new_data = {
            'RAW_DATA': [sliced_raw],
            'LABEL1': [sliced_y],
            'LABEL2': [sliced_y2],
            'LABEL3': [sliced_y3],
        }
        label = label_list_for_filename[data_idx]
        out_path = GLOBAL_DATA['data_file_directory'] + "/{}_c{}_label_{}.pkl".format(
            data_file_name, str(data_idx), str(label))
        with open(out_path, 'wb') as _f:
            pickle.dump(new_data, _f)

def generate_training_data_leadwise_tuh_dev(file):
    """Cut one TUH EDF recording into evaluation windows and pickle them.

    Labels are expanded to ``GLOBAL_DATA['feature_sample_rate']`` per second,
    and the leads in ``GLOBAL_DATA['label_list']`` are kept (resampled above
    ``GLOBAL_DATA['sample_rate']``).

    The recording is split into windows of
    ``GLOBAL_DATA['min_binary_slicelength']`` seconds. When a window's last
    label is not background, the window is extended to the end of that seizure
    plus a background margin, so the seizure is not cut at a window edge:
      - a random 1..``slice_end_margin_length`` s margin when enough background
        follows;
      - otherwise half of the following background, or all of it if it runs to
        the end of the recording.
    If no background follows at all, the rest of the recording is taken.

    Files are written to ``GLOBAL_DATA['data_file_directory']`` as
    ``<recording>_c<i>_len<n_labels>_label_<first seizure label or 0>.pkl``.

    ``GLOBAL_DATA['slice_end_margin_length']`` (an int number of seconds) is
    required, read from GLOBAL_DATA like every other setting; the notebook's
    ``args`` Namespace has no ``slice_end_margin_length`` or ``samplerate``.

    Args:
        file: path to an ``.edf`` recording.
    """
    sample_rate = GLOBAL_DATA['sample_rate']    # EX) 200Hz
    feature_sample_rate = GLOBAL_DATA['feature_sample_rate']
    fsr_sr_ratio = GLOBAL_DATA['fsr_sr_ratio']  # raw samples per label step
    file_name = ".".join(file.split(".")[:-1])  # path without extension, EX) .../00007235/s003_2010_11_20/00007235_s003_t000
    data_file_name = file_name.split("/")[-1]   # EX) 00007235_s003_t000
    signals, signal_headers, header = highlevel.read_edf(file)
    # EX) "EEG FP1-REF" or "EEG FP1-LE" --> "EEG FP1"
    label_list_c = [signal_headers[idx]['label'].split("-")[0] for idx in range(len(signals))]

    ############################# part 1: labeling  ###############################
    # EX) 00007235_s003_t003.tse or 00007235_s003_t003.tse_bi; skip the first two (header) lines
    with open(file_name + "." + GLOBAL_DATA['args.label_type'], 'r') as label_file:
        y = label_file.readlines()[2:]
    signal_sample_rate = int(signal_headers[0]['sample_rate'])
    if sample_rate > signal_sample_rate:
        return
    if not all(elem in label_list_c for elem in GLOBAL_DATA['label_list']):  # a required lead is missing
        return
    y_sampled = label_sampling_tuh(y, feature_sample_rate)
    if not y_sampled:
        # No annotation lines: nothing to pad the labels with or to slice on.
        return

    ############################# part 2: input data filtering #############################
    signal_list = []
    signal_label_list = []
    for idx, signal in enumerate(signals):
        label = signal_headers[idx]['label'].split("-")[0]
        if label not in GLOBAL_DATA['label_list']:
            continue
        channel_rate = int(signal_headers[idx]['sample_rate'])
        if channel_rate > sample_rate:
            # Duration must be computed from this channel's own rate, not channel 0's.
            secs = len(signal) / float(channel_rate)
            signal = sci_sig.resample(signal, int(secs * sample_rate))
        signal_list.append(signal)
        signal_label_list.append(label)

    if len(signal_label_list) != len(GLOBAL_DATA['label_list']):
        print("Not enough labels: ", signal_label_list)
        return

    # Put the leads into the canonical GLOBAL_DATA['label_list'] order.
    signal_final_list_raw = [signal_list[signal_label_list.index(lead)] for lead in GLOBAL_DATA['label_list']]

    # Align the label string with the signal length (measured in label steps).
    # int() is required: a float cannot be used as a slice index.
    new_length = int(len(signal_final_list_raw[0]) * (float(feature_sample_rate) / sample_rate))
    if len(y_sampled) > new_length:
        y_sampled = y_sampled[:new_length]
    elif len(y_sampled) < new_length:
        y_sampled += y_sampled[-1] * (new_length - len(y_sampled))

    ############################# part 3: slicing #############################
    selected = GLOBAL_DATA['selected_diseases']
    y_sampled = ["0" if l not in selected else l for l in y_sampled]
    if any(l in selected for l in y_sampled):
        y_sampled = [str(GLOBAL_DATA['target_dictionary'][int(l)]) if l in selected else l for l in y_sampled]

    raw_data = torch.tensor(np.stack(signal_final_list_raw), dtype=torch.float32).permute(1, 0)  # (time, lead)
    raw_data = raw_data.type(torch.float16)

    min_seg_len_label = GLOBAL_DATA['min_binary_slicelength'] * feature_sample_rate
    min_seg_len_raw = GLOBAL_DATA['min_binary_slicelength'] * sample_rate
    slice_end_margin_length = GLOBAL_DATA['slice_end_margin_length']  # seconds; passed to random.randint
    min_end_margin_label = slice_end_margin_length * feature_sample_rate

    if len(y_sampled) < min_seg_len_label:
        return

    def _window_label(chars):
        # "0" for a pure background window, else the first non-background label
        runs = [x[0] for x in groupby(chars)]
        if len(runs) == 1 and runs[0] == '0':
            return "0"
        return "".join(runs).replace("0", "")[0]

    sliced_raws = []
    sliced_labels = []
    label_list_for_filename = []

    while len(y_sampled) >= min_seg_len_label:
        if y_sampled[:min_seg_len_label][-1] == '0':
            # window ends in background: take it as-is
            seg_len_label, seg_len_raw = min_seg_len_label, min_seg_len_raw
        elif '0' in y_sampled[min_seg_len_label:]:
            # Window ends inside a seizure: extend to the end of the seizure (end_1 steps)
            # plus a margin taken from the following background run (end_2 steps long).
            end_1 = y_sampled[min_seg_len_label:].index('0')
            temp_y_sampled = y_sampled[min_seg_len_label + end_1:]
            if len(set(temp_y_sampled)) == 1:
                # background runs to the end of the recording
                end_2 = len(temp_y_sampled)
                one_left_slice = True
            else:
                temp_y_sampled_order = [x[0] for x in groupby(temp_y_sampled)]
                end_2 = temp_y_sampled.index(temp_y_sampled_order[1])
                one_left_slice = False

            if end_2 >= min_end_margin_label:
                # NOTE(review): unseeded randomness; seed `random` upstream for reproducible splits.
                temp_sec = random.randint(1, slice_end_margin_length)
                seg_len_label = int(min_seg_len_label + (temp_sec * feature_sample_rate) + end_1)
                seg_len_raw = int(min_seg_len_raw + (temp_sec * sample_rate) + (end_1 * fsr_sr_ratio))
            else:
                temp_label = end_2 if one_left_slice else end_2 // 2
                seg_len_label = int(min_seg_len_label + temp_label + end_1)
                seg_len_raw = int(min_seg_len_raw + (temp_label * fsr_sr_ratio) + (end_1 * fsr_sr_ratio))
        else:
            # the recording ends inside a seizure: take everything that is left
            seg_len_label, seg_len_raw = len(y_sampled), raw_data.size(0)

        sliced_y = y_sampled[:seg_len_label]
        sliced_raw_data = raw_data[:seg_len_raw].permute(1, 0)
        raw_data = raw_data[seg_len_raw:]
        y_sampled = y_sampled[seg_len_label:]

        sliced_raws.append(sliced_raw_data)
        sliced_labels.append(sliced_y)
        label_list_for_filename.append(_window_label(sliced_y))

    for data_idx, (sliced_raw, sliced_y) in enumerate(zip(sliced_raws, sliced_labels)):
        sliced_y_map = list(map(int, sliced_y))
        binary_target1 = GLOBAL_DATA['binary_target1']
        binary_target2 = GLOBAL_DATA['binary_target2']
        sliced_y2 = torch.Tensor([binary_target1[i] for i in sliced_y_map]).byte() if binary_target1 is not None else None
        sliced_y3 = torch.Tensor([binary_target2[i] for i in sliced_y_map]).byte() if binary_target2 is not None else None

        # NOTE(review): unlike the training variants, LABEL1 is stored as the list of
        # label characters rather than a byte tensor — confirm the data loader expects this.
        new_data = {
            'RAW_DATA': [sliced_raw],
            'LABEL1': [sliced_y],
            'LABEL2': [sliced_y2],
            'LABEL3': [sliced_y3],
        }
        label = label_list_for_filename[data_idx]
        out_path = GLOBAL_DATA['data_file_directory'] + "/{}_c{}_len{}_label_{}.pkl".format(
            data_file_name, str(data_idx), str(len(sliced_y)), str(label))
        with open(out_path, 'wb') as _f:
            pickle.dump(new_data, _f)


def search_walk(args):
    """Recursively list files under ``args['path']`` whose names end with ``args['extension']``.

    This module-level definition shadows the ``search_walk`` imported from
    ``builder.utils.utils`` at the top of the cell.

    Args:
        args: dict with keys ``'path'`` (root directory) and ``'extension'``
            (case-sensitive filename suffix, e.g. ``".edf"``).

    Returns:
        list[str]: full paths in ``os.walk`` order; ``[]`` if the directory does not exist.
    """
    path = args['path']
    extension = args['extension']
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(path)
        for name in files
        if name.endswith(extension)
    ]

def main(args):
    """Preprocessing entry point (stub).

    Reads each preprocessing option from ``args`` -- raising AttributeError if
    any of them is missing -- but does not act on them yet.
    """
    option_names = (
        "save_directory",
        "data_type",
        "dataset",
        "label_type",
        "sample_rate",
        "cpu_num",
        "feature_type",
        "feature_sample_rate",
        "task_type",
    )
    options = {name: getattr(args, name) for name in option_names}

    # TODO: add the real preprocessing code here, driven by `options`.
    del options

# 👇 Create a manual args object
# Hand-built stand-in for the argparse result so the preprocessing statements
# below can run inside the notebook without a command line.
from argparse import Namespace
args = Namespace(
    save_directory="/content/tuh_eeg_seizure/edf",
    dataset="tuh",
    task_type="binary",
    data_type="train",
    label_type="tse",
    sample_rate=250,
    cpu_num=2,
    feature_type="default",
    feature_sample_rate=1

)
# Output folder for the generated .pkl slices, e.g.
# /content/tuh_eeg_seizure/edf/dataset-tuh_task-binary_datatype-train_v6
data_file_directory = f"{args.save_directory}/dataset-{args.dataset}_task-{args.task_type}_datatype-{args.data_type}_v6"


# EEG channel names kept from every recording (stored as GLOBAL_DATA['label_list']).
labels = ['EEG FP1', 'EEG FP2', 'EEG F3', 'EEG F4', 'EEG F7', 'EEG F8',
                    'EEG C3', 'EEG C4', 'EEG CZ', 'EEG T3', 'EEG T4',
                    'EEG P3', 'EEG P4', 'EEG O1', 'EEG O2', 'EEG T5', 'EEG T6', 'EEG PZ', 'EEG FZ']

# Raw EDFs of the requested split live in a sub-folder named after it
# (".../train", ".../dev"). The path separator is required: formatting the split
# directly onto the root produced "/content/tuh_eeg_seizuretrain".
# NOTE(review): confirm whether the split folders sit directly under this root
# or under ".../edf" in the unzipped archive.
eeg_data_directory = os.path.join("/content/tuh_eeg_seizure", args.data_type)
# eeg_data_directory = "/mnt/aitrics_ext/ext01/shared/edf/tuh_final/{}".format(data_type)

# Annotation name -> integer class id.
if args.label_type == "tse":
    disease_labels = {
        'bckg': 0, 'cpsz': 1, 'mysz': 2, 'gnsz': 3,
        'fnsz': 4, 'tnsz': 5, 'tcsz': 6, 'spsz': 7, 'absz': 8
    }
elif args.label_type == "tse_bi":
    disease_labels = {
        'bckg': 0, 'seiz': 1
    }
else:
    raise ValueError(f"Unsupported args.label_type: {args.label_type}")

# Inverse mapping: class id -> annotation name.
disease_labels_inv = {v: k for k, v in disease_labels.items()}


import shutil

# Collect recordings with either lower- or upper-case extension.
edf_list1 = search_walk({'path': eeg_data_directory, 'extension': ".edf"})
edf_list2 = search_walk({'path': eeg_data_directory, 'extension': ".EDF"})
edf_list = edf_list1 + edf_list2 if edf_list2 else edf_list1

# Start from an empty output directory. shutil/os calls avoid going through a
# shell (no quoting problems with unusual paths), and os.makedirs also creates
# missing parent directories, which a plain `mkdir` does not.
if os.path.isdir(data_file_directory):
    shutil.rmtree(data_file_directory)
os.makedirs(data_file_directory, exist_ok=True)

GLOBAL_DATA['label_list'] = labels # 'EEG FP1', 'EEG FP2', 'EEG F3', ...
GLOBAL_DATA['disease_labels'] = disease_labels #  {'bckg': 0, 'cpsz': 1, 'mysz': 2, ...
GLOBAL_DATA['disease_labels_inv'] = disease_labels_inv #  {0:'bckg', 1:'cpsz', 2:'mysz', ...
GLOBAL_DATA['data_file_directory'] = data_file_directory
GLOBAL_DATA['args.label_type'] = args.label_type # "tse_bi" ...
GLOBAL_DATA['args.feature_type'] = args.feature_type
GLOBAL_DATA['args.feature_sample_rate'] = args.feature_sample_rate
GLOBAL_DATA['args.sample_rate'] = args.sample_rate
GLOBAL_DATA['fsr_sr_ratio'] = (args.sample_rate // args.feature_sample_rate)
GLOBAL_DATA['min_binary_slicelength'] = getattr(args, 'min_binary_slicelength', 30)
GLOBAL_DATA['min_binary_edge_seiz'] = getattr(args, 'min_binary_edge_seiz', 3)

# Resolve the disease list ONCE so the target mapping and the recorded
# 'disease_type' always agree. With separate defaults ([] for the mapping loop,
# all eight types for GLOBAL_DATA) an args object without `disease_type` left
# target_dictionary == {0: 0} and selected_diseases == [].
disease_type = getattr(args, 'disease_type', ['gnsz', 'fnsz', 'spsz', 'cpsz', 'absz', 'tnsz', 'tcsz', 'mysz'])
target_dictionary = {0: 0}  # background always maps to target 0
selected_diseases = []
for disease in disease_type:
    if disease not in disease_labels:
        # e.g. label_type "tse_bi" only knows 'bckg' / 'seiz'
        print("Skipping disease type not present in disease_labels: {}".format(disease))
        continue
    selected_diseases.append(str(disease_labels[disease]))
    # Targets are numbered 1..K in disease_type order.
    target_dictionary[disease_labels[disease]] = len(selected_diseases)

GLOBAL_DATA['disease_type'] = disease_type
GLOBAL_DATA['target_dictionary'] = target_dictionary # 'tse' + default list: {0: 0, 3: 1, 4: 2, 7: 3, 1: 4, 8: 5, 5: 6, 6: 7, 2: 8}
GLOBAL_DATA['selected_diseases'] = selected_diseases # 'tse' + default list: ['3', '4', '7', '1', '8', '5', '6', '2']
GLOBAL_DATA['binary_target1'] = getattr(args, 'binary_target1', {0:0, 1:1, 2:1, 3:1, 4:1, 5:1, 6:1, 7:1, 8:1})
GLOBAL_DATA['binary_target2'] = getattr(args, 'binary_target2', {0:0, 1:1, 2:2, 3:2, 4:2, 5:1, 6:3, 7:4, 8:5})

print("########## Preprocessor Setting Information ##########")
print("Number of EDF files: ", len(edf_list))
for i in GLOBAL_DATA:
        print("{}: {}".format(i, GLOBAL_DATA[i]))
with open(data_file_directory + '/preprocess_info.infopkl', 'wb') as pkl:
        pickle.dump(GLOBAL_DATA, pkl, protocol=pickle.HIGHEST_PROTOCOL)
print("################ Preprocess begins... ################\n")

if (args.task_type == "binary") and (args.data_type == "train"):
  if not hasattr(args, 'cpu_num') or args.cpu_num < 1:
    args.cpu_num = os.cpu_count() or 1

    run_multi_process(generate_training_data_leadwise_tuh_train_final, edf_list, n_processes=args.cpu_num)

elif (args.task_type == "binary") and (args.data_type == "dev"):
    cpu_num = max(1, getattr(args, 'cpu_num', os.cpu_count() or 1))
    run_multi_process(generate_training_data_leadwise_tuh_train_final, edf_list, n_processes=cpu_num)

if __name__ == '__main__':
    # make sure all edf file names are different!!! if not, additional coding is necessary
    import sys

    # Jupyter/Colab start the kernel as `... -f <connection-file>.json`. Drop exactly
    # that "-f" flag and the .json path so argparse does not reject them. Matching
    # arg == "-f" (not arg.startswith("-f")) keeps the real -ft / -fsr options.
    sys.argv = [sys.argv[0]] + [arg for arg in sys.argv[1:] if arg != "-f" and not arg.endswith(".json")]

    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', '-sd', type=int, default=1004,
                        help='Random seed number')
    parser.add_argument('--samplerate', '-sr', type=int, default=200,
                        help='Sample Rate')
    parser.add_argument('--save_directory', '-sp', type=str,
                        help='Path to save data')
    parser.add_argument('--label_type', '-lt', type=str,
                        default='tse',
                        help='tse_bi = global with binary label, tse = global with various labels, cae = severance CAE seizure label.')
    parser.add_argument('--cpu_num', '-cn', type=int,
                        default=32,
                        help='select number of available cpus')
    # NOTE(review): the default is a list, while a value given on the CLI arrives as a plain str.
    parser.add_argument('--feature_type', '-ft', type=str,
                        default=['rawsignal'])
    parser.add_argument('--sample_rate', type=int,
                        default=200,
                        help='Sampling rate of the data')

    parser.add_argument('--feature_sample_rate', '-fsr', type=int,
                        default=50,
                        help='select features sample rate')
    parser.add_argument('--dataset', '-st', type=str,
                        default='tuh',
                        choices=['tuh'])
    parser.add_argument('--data_type', '-dt', type=str,
                        default='train',
                        choices=['train', 'dev'])
    parser.add_argument('--task_type', '-tt', type=str,
                        default='binary',
                        choices=['anomaly', 'multiclassification', 'binary'])

    ##### Target Grouping #####
    # nargs='+' with type=str: `--disease_type gnsz fnsz` yields ['gnsz', 'fnsz'] and each
    # item is validated against `choices` (type=list would split a value into characters).
    parser.add_argument('--disease_type', type=str, nargs='+', default=['gnsz', 'fnsz', 'spsz', 'cpsz', 'absz', 'tnsz', 'tcsz', 'mysz'], choices=['gnsz', 'fnsz', 'spsz', 'cpsz', 'absz', 'tnsz', 'tcsz', 'mysz'])

    ### for binary detector ###
    # key numbers represent index of --disease_type + 1  ### -1 is "not being used"
    # NOTE(review): type=dict cannot parse a command-line string (dict("...") raises
    # ValueError), so in practice only these defaults are usable.
    parser.add_argument('--binary_target1', type=dict, default={0:0, 1:1, 2:1, 3:1, 4:1, 5:1, 6:1, 7:1, 8:1})
    parser.add_argument('--binary_target2', type=dict, default={0:0, 1:1, 2:2, 3:2, 4:2, 5:1, 6:3, 7:4, 8:5})
    parser.add_argument('--min_binary_slicelength', type=int, default=30)
    parser.add_argument('--min_binary_edge_seiz', type=int, default=3)
    args = parser.parse_args()
    main(args)




In [None]:
!pip install torchinfo

In [None]:
    import torch
    import torch.nn as nn

    class CNN2D_LSTM_V1(nn.Module):
      """CNN front-end + LSTM over a sequence of single-channel 2-D frames.

      forward(x) expects x shaped [batch, seq_len, 1, H, W]. Every frame goes through
      Conv2d(1->16, 3x3, padding=1) + ReLU + 2x2 max-pool and is flattened. The LSTM
      then runs over the frame sequence, and its last step is mapped to 2 logits.
      The LSTM input size is fixed at 16 * 32 * 32, so frames must be 64x64
      (pooled to 32x32).

      `args` and `device` are accepted for signature compatibility but not used.
      """
      def __init__(self, args, device):
        super(CNN2D_LSTM_V1, self).__init__()
        self.conv = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.lstm = nn.LSTM(input_size=16 * 32 * 32, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Linear(64, 2)

      def forward(self, x):
        batch_size, seq_len, c, h, w = x.shape
        # reshape() instead of view(): also accepts non-contiguous inputs (e.g. after a permute).
        x = x.reshape(batch_size * seq_len, c, h, w)
        x = self.pool(torch.relu(self.conv(x)))
        x = x.reshape(batch_size, seq_len, -1)
        x, _ = self.lstm(x)
        return self.fc(x[:, -1])

    # Define `device` in this cell so model creation does not rely on a later cell having run.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN2D_LSTM_V1(args, device).to(device)


In [None]:
from torch.optim.lr_scheduler import _LRScheduler
import math

class CosineAnnealingWarmUpSingle(_LRScheduler):
    """Linear warm-up followed by a single cosine decay (one-cycle style).

    The LR starts at ``max_lr / div_factor`` and rises linearly to ``max_lr``.
    The warm-up covers the first ``pct_start`` fraction of the
    ``epochs * steps_per_epoch`` scheduler steps. The LR then follows a half
    cosine down to ``max_lr / final_div_factor`` and stays there once the
    schedule is exhausted. Call ``step()`` once per optimizer step.

    Args:
        optimizer: wrapped optimizer; its lr is replaced by the warm-up start value.
        max_lr: peak learning rate.
        epochs, steps_per_epoch: their product is the total number of steps.
        pct_start: fraction of steps spent warming up.
        div_factor: start lr = max_lr / div_factor.
        final_div_factor: floor lr = max_lr / final_div_factor.
        last_epoch: index of the last step when resuming (-1 for a fresh run).
        verbose: forwarded to the base class only when truthy (the argument is
            deprecated in recent PyTorch releases).
    """

    def __init__(self, optimizer, max_lr, epochs, steps_per_epoch, pct_start=0.3, div_factor=25.0, final_div_factor=1e4, last_epoch=-1, verbose=False):
        self.max_lr = max_lr
        self.total_steps = epochs * steps_per_epoch
        self.warmup_steps = int(self.total_steps * pct_start)
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor
        self.min_lr = max_lr / final_div_factor
        if last_epoch == -1:
            # The base class copies group['initial_lr'] into self.base_lrs, so the
            # warm-up start value has to be installed there (as OneCycleLR does).
            for group in optimizer.param_groups:
                group['initial_lr'] = max_lr / div_factor
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # last_epoch has already been incremented when the base class calls get_lr(),
        # so it equals the number of step() calls so far (0 right after construction).
        step = self.last_epoch
        if step < self.warmup_steps:
            warmup_factor = step / self.warmup_steps
            return [base_lr + warmup_factor * (self.max_lr - base_lr) for base_lr in self.base_lrs]
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        # Clamp so the LR stays at min_lr instead of climbing back after total_steps.
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return [self.min_lr + (self.max_lr - self.min_lr) * cosine_decay for _ in self.base_lrs]


In [None]:
import mne

def load_eeg_data(file_path):
    """Read an EDF recording with MNE and return its samples as a float32 tensor.

    Args:
        file_path: path to the .edf file.

    Returns:
        torch.Tensor laid out as (channels, time), as returned by ``Raw.get_data()``.
    """
    recording = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
    samples = recording.get_data()
    return torch.tensor(samples, dtype=torch.float32)


In [None]:
def __getitem__(self, index):
    # NOTE(review): placeholder fragment, not a working method. It is defined at
    # module level (outside any Dataset class), its body is elided with `...`,
    # and train_x / train_y / seq_len / target_len are not assigned here, so a
    # call raises NameError unless those names happen to exist as globals.
    ...
    return train_x, train_y, seq_len, target_len  # ← maybe just 4 items


In [None]:
class Detector_Dataset(torch.utils.data.Dataset):
    """Skeleton dataset: remembers the file list and config; loading is still TODO."""

    def __init__(self, file_list, args):
        # load/process your data here
        self.args = args
        self.file_list = file_list


In [None]:
Training Model

In [None]:
# Copyright (c) 2022, Kwanhyung Lee, Hyewon Jeong, Seyun Kim AITRICS. All rights reserved.
#
# Licensed under the MIT License;
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import os
import argparse
import random
from datetime import datetime
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.io.wavfile import write
from itertools import groupby
import math
import time
import glob

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.rnn as rnn_utils
from torch.autograd import Variable
from torchsummary import summary
from torchinfo import summary

from builder.utils.lars import LARC
from control.config import args

from builder.data.data_preprocess import get_data_preprocessed
# from builder.data.data_preprocess_temp1 import get_data_preprocessed
from builder.models import get_detector_model, grad_cam
from builder.utils.logger import Logger
from builder.utils.utils import set_seeds, set_devices
from builder.utils.cosine_annealing_with_warmup import CosineAnnealingWarmUpRestarts
from builder.utils.cosine_annealing_with_warmupSingle import CosineAnnealingWarmUpSingle
from builder.trainer import get_trainer
from builder.trainer import *
from torch.utils.data import DataLoader
class CustomLogger(Logger):
    """Project Logger that additionally collects multi-class targets and predictions."""

    def __init__(self, args):
        super().__init__(args)
        # Accumulators for multi-class ground truth and predictions, both start empty.
        self.y_true_multi, self.y_pred_multi = [], []
# CustomLogger.__init__ already starts y_true_multi / y_pred_multi as empty lists.
logger = CustomLogger(args)


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
list_of_test_results_per_seed = []
from torch.optim.lr_scheduler import _LRScheduler
import math

class CosineAnnealingWarmUpSingle(_LRScheduler):
    """Linear warm-up followed by a single cosine decay (one-cycle style).

    The LR starts at ``max_lr / div_factor`` and rises linearly to ``max_lr``.
    The warm-up covers the first ``pct_start`` fraction of the
    ``epochs * steps_per_epoch`` scheduler steps. The LR then follows a half
    cosine down to ``max_lr / final_div_factor`` and stays there once the
    schedule is exhausted. Call ``step()`` once per optimizer step.

    Args:
        optimizer: wrapped optimizer; its lr is replaced by the warm-up start value.
        max_lr: peak learning rate.
        epochs, steps_per_epoch: their product is the total number of steps.
        pct_start: fraction of steps spent warming up.
        div_factor: start lr = max_lr / div_factor.
        final_div_factor: floor lr = max_lr / final_div_factor.
        last_epoch: index of the last step when resuming (-1 for a fresh run).
        verbose: forwarded to the base class only when truthy (the argument is
            deprecated in recent PyTorch releases).
    """

    def __init__(self, optimizer, max_lr, epochs, steps_per_epoch,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4, last_epoch=-1, verbose=False):
        self.max_lr = max_lr
        self.total_steps = epochs * steps_per_epoch
        self.warmup_steps = int(self.total_steps * pct_start)
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor
        self.min_lr = max_lr / final_div_factor
        if last_epoch == -1:
            # Assigning self.base_lrs here would be overwritten: the base class
            # rebuilds it from group['initial_lr'], so install the start value there.
            for group in optimizer.param_groups:
                group['initial_lr'] = max_lr / div_factor
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # The base class increments last_epoch BEFORE calling get_lr(), so it already
        # equals the number of step() calls (0 right after construction); adding 1
        # would shift the whole schedule one step early.
        step = self.last_epoch
        if step < self.warmup_steps:
            warmup_factor = step / self.warmup_steps
            return [
                base_lr + warmup_factor * (self.max_lr - base_lr)
                for base_lr in self.base_lrs
            ]
        decay_step = step - self.warmup_steps
        total_decay_steps = max(1, self.total_steps - self.warmup_steps)  # avoid /0 when pct_start == 1
        # Clamp so the LR stays at min_lr instead of climbing back after total_steps.
        progress = min(1.0, decay_step / total_decay_steps)
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        return [
            self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
            for _ in self.optimizer.param_groups
        ]


# define result class
class ResultsSaver:
    """Collects per-seed result objects in memory and echoes each one to stdout."""

    def __init__(self, args):
        self.results = []
        self.args = args

    def results_all_seeds(self, test_results):
        """Record the results of the current seed (``self.args.seed``)."""
        self.results += [test_results]
        print(f"Saved results from seed {self.args.seed}: {test_results}")

# Factories returning independent savers for validation and test results.
def experiment_results_validation(args):
    """Create a saver for validation results."""
    saver = ResultsSaver(args)
    return saver

def experiment_results(args):
    """Create a saver for test results."""
    saver = ResultsSaver(args)
    return saver

# Initialize your results savers
save_valid_results = experiment_results_validation(args)
save_test_results = experiment_results(args)

# Your existing loop
# NOTE(review): this loop only seeds RNGs, picks a device, builds a fresh Logger
# and records `logger.test_results` per seed; no training or evaluation runs here.
# seed_num / device / logger keep the LAST seed's values for later cells.
for seed_num in args.seed_list:
    args.seed = seed_num
    set_seeds(args)
    device = set_devices(args)
    print(device)
    logger = Logger(args)
    logger.evaluator.best_auc = 0



    save_valid_results.results_all_seeds(logger.test_results)




import numpy as np
import torch
from torch.utils.data import Dataset
import mne  # For actual EDF file loading

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import mne
import glob
import random
from torch import nn

class Detector_Dataset(Dataset):
    """Per-recording EEG dataset that turns one EDF file into five feature tensors.

    ``__getitem__`` returns ``(v, w, x, y, z)``:

    * ``v`` -- per-channel mean / std / median (time domain), ``[C, 1, 3]``
    * ``w`` -- magnitudes of the first three FFT bins, ``[C, 1, 3]``
    * ``x`` -- means of three equal-width time windows, ``[C, 1, 3]``
      (``transform`` is applied to this tensor only)
    * ``y`` -- scalar ``torch.long`` label from ``_get_label``
    * ``z`` -- per-channel range and 75th / 25th percentiles, ``[C, 1, 3]``

    with ``C = required_channels`` (the first 12 channels of the file).
    """

    def __init__(self, args, edf_files, transform=None):
        self.edf_files = edf_files
        self.args = args
        self.transform = transform
        self.required_channels = 12

    def __len__(self):
        return len(self.edf_files)

    def __getitem__(self, idx):
        file_path = self.edf_files[idx]
        try:
            raw_data = self._load_edf(file_path)  # [C, timepoints]

            v = torch.tensor(self._create_time_features(raw_data), dtype=torch.float32)
            w = torch.tensor(self._create_freq_features(raw_data), dtype=torch.float32)
            x = torch.tensor(self._create_raw_features(raw_data), dtype=torch.float32)
            y = torch.tensor(self._get_label(file_path), dtype=torch.long)
            z = torch.tensor(self._get_metadata(raw_data), dtype=torch.float32)

            if self.transform:
                x = self.transform(x)

            return v, w, x, y, z

        except Exception as e:
            # Best effort: a file that cannot be loaded or featurised becomes an
            # all-zero sample with label 0 so one bad EDF does not abort the epoch.
            print(f"Error processing {file_path}: {str(e)}")
            dummy = torch.zeros((self.required_channels, 1, 3), dtype=torch.float32)
            return dummy, dummy, dummy, torch.tensor(0, dtype=torch.long), dummy

    def _load_edf(self, file_path):
        """Load an EDF file and return its first ``required_channels`` channels as ``[C, timepoints]``.

        Raises:
            ValueError: wrapping any error raised while reading the file.
        """
        try:
            raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
            raw.pick(range(self.required_channels))
            return raw.get_data()
        except Exception as e:
            raise ValueError(f"EDF loading failed: {str(e)}") from e

    def _create_time_features(self, data):
        """Time-domain features per channel (mean, std, median) -> ``[C, 1, 3]``."""
        features = np.zeros((data.shape[0], 1, 3))
        features[:, 0, 0] = data.mean(axis=1)
        features[:, 0, 1] = data.std(axis=1)
        features[:, 0, 2] = np.median(data, axis=1)
        return features

    def _create_freq_features(self, data):
        """Magnitudes of the first three FFT bins of each channel -> ``[C, 1, 3]``."""
        spectrum_mag = np.abs(np.fft.fft(data)[:, :3])
        return spectrum_mag.reshape(data.shape[0], 1, 3)

    def _create_raw_features(self, data):
        """Mean of three equal-width time windows per channel -> ``[C, 1, 3]``.

        Trailing samples beyond ``3 * (T // 3)`` are ignored.
        """
        window = data.shape[1] // 3
        features = np.zeros((data.shape[0], 1, 3))
        for i in range(3):
            features[:, 0, i] = data[:, i * window:(i + 1) * window].mean(axis=1)
        return features

    def _get_label(self, file_path):
        """Heuristic binary label: 1 if the file *name* contains "seizure", else 0.

        Only the basename is inspected. The dataset root used by the loaders
        ("/content/tuh_eeg_seizure/...") itself contains "seizure", so matching
        the full path labelled every recording as 1.
        NOTE(review): this name-based rule is a placeholder; confirm the real label
        source (e.g. per-recording annotation files, if available).
        """
        return 1 if "seizure" in os.path.basename(str(file_path)).lower() else 0

    def _get_metadata(self, data):
        """Per-channel range and 75th / 25th percentiles -> ``[C, 1, 3]``."""
        metadata = np.zeros((data.shape[0], 1, 3))
        metadata[:, 0, 0] = data.max(axis=1) - data.min(axis=1)
        metadata[:, 0, 1] = np.percentile(data, 75, axis=1)
        metadata[:, 0, 2] = np.percentile(data, 25, axis=1)
        return metadata

def collate_fn(batch):
    """Stack each of the five per-sample tensors (v, w, x, y, z) along a new batch dim.

    Returns a 5-tuple of batched tensors; any stacking error is printed and re-raised.
    """
    try:
        return tuple(torch.stack([sample[k] for sample in batch]) for k in range(5))
    except Exception as e:
        print(f"Collate error: {e}")
        raise

def get_data_preprocessed(args):
    """Find the EDF files, split them 70/15/15 and build train/dev/eval DataLoaders.

    Args:
        args: must provide ``batch_size``.

    Returns:
        (train_loader, dev_loader, eval_loader), each yielding (v, w, x, y, z) batches.

    Raises:
        ValueError: if no EDF files are found, or too few to populate the training split.
    """
    file_list = glob.glob("/content/tuh_eeg_seizure/edf/**/*.edf", recursive=True)
    if not file_list:
        raise ValueError("No EDF files found!")

    # Split is reproducible only if the `random` module was seeded beforehand.
    random.shuffle(file_list)
    n_files = len(file_list)
    split_idx = [int(0.7 * n_files), int(0.85 * n_files)]
    # np.split yields numpy string arrays; convert back to plain lists of str.
    train_files, dev_files, eval_files = (
        [str(path) for path in part] for part in np.split(np.asarray(file_list), split_idx)
    )
    if not train_files:
        # A shuffled DataLoader over an empty dataset fails at construction time.
        raise ValueError(f"Too few EDF files ({n_files}) to build a training split")

    train_dataset = Detector_Dataset(args, train_files)
    dev_dataset = Detector_Dataset(args, dev_files)
    eval_dataset = Detector_Dataset(args, eval_files)

    # One worker setting for all loaders: Windows needs num_workers=0, and
    # os.cpu_count() may return None.
    num_workers = 0 if os.name == 'nt' else min(4, (os.cpu_count() or 1) // 2)
    # Pinned memory only benefits host->GPU copies; without CUDA PyTorch just warns.
    pin_memory = torch.cuda.is_available()

    loader_kwargs = dict(
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, persistent_workers=False, **loader_kwargs)
    dev_loader = DataLoader(dev_dataset, shuffle=False, **loader_kwargs)
    eval_loader = DataLoader(eval_dataset, shuffle=False, **loader_kwargs)

    # Verify first batch
    print("\nDataLoader Sanity Check:")
    sample = next(iter(train_loader))
    for i, tensor in enumerate(sample):
        print(f"Item {i}: shape={tensor.shape}, dtype={tensor.dtype}")

    return train_loader, dev_loader, eval_loader

# Main execution
# For each seed: seed RNGs, choose a device, build a Logger and the data loaders.
# NOTE(review): nothing is trained inside this loop; the model/optimizer/training
# cells below run once, using the LAST seed's seed_num / device / logger /
# train_loader. If loading fails for every seed, train_loader is never bound.
for seed_num in args.seed_list:
    args.seed = seed_num
    set_seeds(args)
    device = set_devices(args)
    logger = Logger(args)
    logger.evaluator.best_auc = 0

    # Load Data
    try:
        train_loader, dev_loader, eval_loader = get_data_preprocessed(args)
        print(f"Train batches: {len(train_loader)}")
    except Exception as e:
        print(f"Error loading data: {e}")
        continue  # Skip to next seed if data loading fails

    # Rest of your training loop...
    one_epoch_iter_num = len(train_loader)  # Now this will work
    print("Iterations per epoch: ", one_epoch_iter_num)

class CNN2D_LSTM_V1(nn.Module):
    """1x1-conv channel mixer followed by a (bi)LSTM over the feature axis.

    ``forward(v, w, x, y, z)`` takes four feature tensors shaped
    ``[batch, in_channels, 1, 3]``; ``y`` is accepted for call-site compatibility
    and ignored. The tensors are concatenated on the last axis into
    ``[batch, in_channels, 1, 12]``. The EEG channels act as conv input channels,
    and the 12 concatenated features form the LSTM sequence.
    Returns logits shaped ``[batch, num_classes]``.

    Args:
        args: stored on the instance; not otherwise used.
        in_channels: number of EEG channels (default 12).
        hidden_size: LSTM hidden size per direction (default 128).
        num_classes: number of output logits (default 2).
        bidirectional: use a bidirectional LSTM (default True).
    """

    def __init__(self, args, in_channels=12, hidden_size=128, num_classes=2, bidirectional=True):
        super().__init__()
        self.args = args
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        # 1x1 convolutions mix the channels independently at every feature position.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=(1, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),

            nn.Conv2d(32, 64, kernel_size=(1, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )

        self.lstm = nn.LSTM(
            input_size=64,  # must match the conv block's output channels
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=bidirectional,
        )

        lstm_output_size = hidden_size * 2 if bidirectional else hidden_size
        self.fc = nn.Sequential(
            nn.Linear(lstm_output_size, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, v, w, x, y, z):
        # 4 x [B, C, 1, 3] -> [B, C, 1, 12]; already (batch, channels, height, width) for Conv2d.
        combined = torch.cat([v, w, x, z], dim=-1)

        feats = self.conv_block(combined)          # [B, 64, 1, 12]
        seq = feats.squeeze(2).permute(0, 2, 1)    # [B, 12, 64] = (batch, seq_len, features)

        out, _ = self.lstm(seq)                    # [B, 12, hidden_size * num_directions]

        if self.bidirectional:
            h = self.hidden_size
            # The forward direction has seen the whole sequence at the last step and the
            # backward direction at step 0; concatenate both summaries.
            last = torch.cat([out[:, -1, :h], out[:, 0, h:]], dim=-1)
        else:
            last = out[:, -1, :]

        return self.fc(last)

# 2. Initialize device
# NOTE(review): this replaces the `device` chosen by set_devices(args) in the seed loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 3. Initialize model
# Uses whichever CNN2D_LSTM_V1 definition was executed last (the (v, w, x, y, z) variant above).
model = CNN2D_LSTM_V1(args).to(device)
print("\nModel Architecture:")
print(model)





# NOTE(review): dev_per_epochs is not referenced in this cell.
dev_per_epochs = 10
# Per-sample losses (reduction='none'); the training cell rebinds `criterion`
# to a mean-reduced, label-smoothed loss before training.
criterion = nn.CrossEntropyLoss(reduction='none')

# Optionally resume from a saved checkpoint of the last seed (seed_num from the seed loop).
if args.checkpoint:
    if args.last:
        ckpt_path = args.dir_result + '/' + args.project_name + '/ckpts/last_{}.pth'.format(str(seed_num))
    elif args.best:
        ckpt_path = args.dir_result + '/' + args.project_name + '/ckpts/best_{}.pth'.format(str(seed_num))
    else:
        # Without an explicit choice ckpt_path would be unbound (or stale from an earlier run).
        raise ValueError("args.checkpoint is set but neither args.last nor args.best selects a checkpoint to load")

    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    logger.best_auc = checkpoint['score']
    # NOTE(review): resumes AT the saved epoch (re-running it); confirm whether epoch + 1 is intended.
    start_epoch = checkpoint['epoch']
    del checkpoint
else:
    logger.best_auc = 0
    start_epoch = 1


# Build the optimizer named by args.optim; the *_lars variants wrap the base
# optimizer in LARC (layer-wise adaptive rate clipping).
if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(), lr=args.lr_init, weight_decay=args.weight_decay)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(), lr=args.lr_init, momentum=args.momentum, weight_decay=args.weight_decay)
elif args.optim == 'adamw':
    optimizer = optim.AdamW(model.parameters(), lr=args.lr_init, weight_decay=args.weight_decay)
elif args.optim == 'adam_lars':
    optimizer = optim.Adam(model.parameters(), lr=args.lr_init, weight_decay=args.weight_decay)
    optimizer = LARC(optimizer=optimizer, eps=1e-8, trust_coefficient=0.001)
elif args.optim == 'sgd_lars':
    optimizer = optim.SGD(model.parameters(), lr=args.lr_init, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = LARC(optimizer=optimizer, eps=1e-8, trust_coefficient=0.001)
elif args.optim == 'adamw_lars':
    optimizer = optim.AdamW(model.parameters(), lr=args.lr_init, weight_decay=args.weight_decay)
    optimizer = LARC(optimizer=optimizer, eps=1e-8, trust_coefficient=0.001)
else:
    # Fail here rather than with a NameError on `optimizer` much later.
    raise ValueError(
        f"Unsupported args.optim: {args.optim!r} "
        "(expected one of adam, sgd, adamw, adam_lars, sgd_lars, adamw_lars)"
    )

one_epoch_iter_num = len(train_loader)
print("Iterations per epoch: ", one_epoch_iter_num)
iteration_num = args.epochs * one_epoch_iter_num

# Learning Rate Scheduler Setup (stepped once per batch in the training loop)
if args.lr_scheduler == "CosineAnnealing":
    # Enhanced CosineAnnealingWarmUpRestarts with better defaults
    scheduler = CosineAnnealingWarmUpRestarts(
        optimizer,
        T_0=10 * one_epoch_iter_num,      # 10 epoch cycles (prevents LR from getting too small)
        T_mult=1,                         # Keep cycle length constant
        eta_max=args.lr_init * 5,         # Peak LR = 5x initial LR
        T_up=2 * one_epoch_iter_num,      # 2 epoch warmup
        gamma=0.7                         # Gentle decay factor
    )
    print(f"Using CosineAnnealingWarmUpRestarts: "
          f"{args.epochs} epochs, {one_epoch_iter_num} iters/epoch, "
          f"peak LR={args.lr_init * 5:.1e}")

elif args.lr_scheduler == "OneCycle":
    # OneCycleLR with the LR scaled by sqrt(batch_size / 32)
    base_lr = args.lr_init * math.sqrt(args.batch_size / 32)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=base_lr * 10,              # 10x base LR (automatically scaled)
        steps_per_epoch=len(train_loader),
        epochs=args.epochs,
        pct_start=0.3,                    # 30% of iterations spent increasing LR
        div_factor=25,                     # Starts from max_lr/25
        final_div_factor=1e4,              # Ends at (max_lr/25)/1e4
        anneal_strategy='cos',             # Smoother annealing
        cycle_momentum=True if isinstance(optimizer, optim.SGD) else False
    )
    print(f"Using OneCycleLR: max_lr={base_lr * 10:.1e} "
          f"(batch_size scaled from base {base_lr:.1e})")

else:
    # Fail early instead of leaving `scheduler` undefined for the training loop.
    raise ValueError(
        f"Unsupported args.lr_scheduler: {args.lr_scheduler!r} "
        "(expected 'CosineAnnealing' or 'OneCycle')"
    )

# Report the scheduler's LR once, right after setup. Per-batch LR / momentum
# logging needs the loop variables `epoch` and `batch_idx`, which do not exist
# at this point (referencing them here raised NameError), so it lives inside
# the training loop instead.
current_lr = scheduler.get_last_lr()[0]
print(f"Initial LR = {current_lr:.2e}")
if args.lr_scheduler == "OneCycle" and isinstance(optimizer, optim.SGD):
    # OneCycleLR also cycles SGD momentum; show its starting value.
    print(f"Current momentum: {optimizer.param_groups[0]['momentum']:.3f}")

model.train()
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Prevents overconfidence; default reduction='mean'
epoch_losses = []
for epoch in range(start_epoch, args.epochs + 1):
    model.train()
    logger.loss = 0.0  # Reset epoch logger
    epoch_losses = []   # Reset per-epoch losses
    for batch_idx, (v, w, x, y, z) in enumerate(train_loader):
        try:
            if epoch == start_epoch and batch_idx == 0:
                # Shapes are fixed by the dataset; print them once instead of every batch.
                print(f"Batch shapes - v:{v.shape}, w:{w.shape}, x:{x.shape}, y:{y.shape}, z:{z.shape}")

            # Validate inputs BEFORE the optimisation step: a NaN batch detected only
            # after optimizer.step() would already have corrupted the weights.
            for feat_name, feat in (("v", v), ("w", w), ("x", x), ("z", z)):
                if feat is None:
                    raise ValueError(f"Batch features are None ({feat_name})")
                if not isinstance(feat, torch.Tensor):
                    raise ValueError(f"Expected tensor for {feat_name}, got {type(feat)}")
                if torch.isnan(feat).any():
                    raise ValueError(f"NaN values detected in input features ({feat_name})")

            # Move to device
            v = v.to(device, non_blocking=True)
            w = w.to(device, non_blocking=True)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            z = z.to(device, non_blocking=True)

            # Verify target dtype
            assert y.dtype == torch.long, f"Target y must be long dtype, got {y.dtype}"

            # Forward pass
            optimizer.zero_grad(set_to_none=True)
            outputs = model(v, w, x, y, z)

            # Verify output shape
            assert outputs.shape[0] == y.shape[0], "Batch size mismatch"
            assert outputs.shape[1] == 2, "Output should have shape [batch, 2]"

            loss = criterion(outputs, y)
            loss_value = loss.item()

            # Backward pass. No manual LR override here: the scheduler configured
            # earlier has its own warm-up (T_up / pct_start), and writing
            # param_group['lr'] before optimizer.step() made it ignored for the first epochs.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()

            if batch_idx % 50 == 0:
                # LR monitoring (after optimizer.step())
                current_lr = scheduler.get_last_lr()[0]
                print(f"Epoch {epoch} Batch {batch_idx}: LR = {current_lr:.2e}")
                if args.lr_scheduler == "OneCycle" and isinstance(optimizer, optim.SGD):
                    print(f"Current momentum: {optimizer.param_groups[0]['momentum']:.3f}")

                # Gradient validation
                grad_norms = [p.grad.norm().item() for p in model.parameters() if p.grad is not None]
                if not grad_norms:
                    print("Warning: No gradients detected!")
                elif sum(grad_norms) / len(grad_norms) < 1e-6:
                    print("Warning: Vanishing gradients detected!")

            # Logging
            logger.loss += loss_value
            epoch_losses.append(loss_value)

            if batch_idx % 10 == 0:
                print(f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | Loss: {loss_value:.4f}')

            # Periodic output validation, in eval mode so this extra forward pass
            # neither updates BatchNorm running statistics nor applies dropout.
            if batch_idx % 20 == 0:
                model.eval()
                try:
                    with torch.no_grad():
                        sample_output = model(v[:1], w[:1], x[:1], y[:1], z[:1])
                finally:
                    model.train()
                print(f"Output range: [{sample_output.min():.4f}, {sample_output.max():.4f}]")

        except Exception as e:
            # Best effort: report the failing batch and continue with the next one.
            print(f"\nError in batch {batch_idx}:")
            print(f"Exception: {str(e)}")
            print("Batch contents:")
            for i, item in enumerate([v, w, x, y, z]):
                print(f"Item {i}: Type={type(item)}")
                if torch.is_tensor(item):
                    print(f"Shape: {item.shape}")
            continue

    avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print(f"Epoch {epoch} complete | Average Loss: {avg_loss}")




# End of training: summarise the final epoch. epoch_losses is reset at the start
# of every epoch, so it only holds the last epoch's batch losses.
if epoch_losses:
    avg_epoch_loss = np.mean(epoch_losses)
    print(f'\nEpoch {epoch} complete | Average Loss: {avg_epoch_loss:.4f}')
else:
    # np.mean([]) would warn and return nan; report the empty case explicitly.
    avg_epoch_loss = float('nan')
    print('\nNo training losses were recorded.')
# NOTE(review): ps_per_epoch / div_factor are not referenced in the visible code;
# they look like remnants of a scheduler setup -- confirm before removing.
ps_per_epoch = one_epoch_iter_num
div_factor = math.sqrt(args.batch_size)

save_valid_results.results_all_seeds(logger.test_results)
# Free the trained network before the test section.
del model



print("#################################################")
print("################# Test Begins ###################")
print("#################################################")
logger = Logger(args)

# Set up evaluation
class Evaluator:
    """Placeholder evaluator; keeps no state yet."""

    def __init__(self, args):
        # TODO: initialise evaluation state from `args`.
        return None
evaluator = Evaluator(args)
# Project folders whose checkpoints are evaluated (just the current project here).
names = [args.project_name]
average_speed_over = 10
time_taken = 0
# NOTE(review): 30 looks like a total-length constant (in window units?) -- confirm.
num_windows = 30 - args.window_size




for name in names:
  # Resolve the checkpoint file for this project; only the LAST name's path
  # survives the loop. args.seed holds the last seed set by the seed loops.

  if args.last:
       ckpt_path = os.path.join(args.dir_result, name, "ckpts", f"last_{args.seed}.pth")
  elif args.best:
       ckpt_path = os.path.join(args.dir_result, name, "ckpts", f"best_{args.seed}.pth")
  else:
       ckpt_path = os.path.join(args.dir_result, name, "ckpts", "best.pth")  # fallback default


# Load checkpoint if exists
if os.path.exists(ckpt_path):
    print(f"Loading checkpoint from {ckpt_path}")
    checkpoint = torch.load(ckpt_path, map_location=device)

    if 'model' not in globals():
        # The trained network is deleted after training (`del model`); build a
        # fresh instance of the same architecture to load the weights into.
        model = CNN2D_LSTM_V1(args).to(device)

    # Load model state
    model.load_state_dict(checkpoint['model'])

    # Resume training from next epoch
    start_epoch = checkpoint['epoch'] + 1

    # Optional: Load other training states if available
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'scheduler' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler'])
    if 'best_auc' in checkpoint:
        logger.best_auc = checkpoint['best_auc']

    print(f"Resuming training from epoch {start_epoch}")
else:
    print("No checkpoint found - training from scratch")


    # initialize test step
def evaluate(model, eval_loader, device, logger=None):
    """Score a two-class model on a loader of ``(v, w, x, y, z)`` batches.

    Every batch is moved to ``device`` and fed to ``model(v, w, x, y, z)``;
    ``y`` doubles as the label tensor. The class-1 softmax probability is the
    score used for AUC / AP, and the argmax class is used for F1.

    Args:
        model: module whose forward takes ``(v, w, x, y, z)`` and returns
            logits with at least two columns.
        eval_loader: iterable of 5-tuples of tensors.
        device: device the batch tensors are moved to.
        logger: unused; optional so callers may omit it (older callers that
            pass it still work).

    Returns:
        ``(auc, apr, f1)``. All three are ``nan`` when the loader yields no
        samples; ``auc`` is ``nan`` when only one class is present in the
        labels (ROC AUC is undefined there).
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for v, w, x, y, z in eval_loader:
            v, w, x, y, z = (t.to(device) for t in (v, w, x, y, z))
            probs = torch.softmax(model(v, w, x, y, z), dim=1)
            pred_chunks.append(probs.cpu().numpy())
            label_chunks.append(y.cpu().numpy())

    nan = float('nan')
    if not label_chunks:
        return nan, nan, nan
    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)
    if all_labels.size == 0:
        return nan, nan, nan

    # sklearn is only needed once there is something to score.
    from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
    pos_scores = all_preds[:, 1]
    # roc_auc_score raises ValueError when y_true holds a single class.
    auc = roc_auc_score(all_labels, pos_scores) if np.unique(all_labels).size > 1 else nan
    apr = average_precision_score(all_labels, pos_scores)
    f1 = f1_score(all_labels, np.argmax(all_preds, axis=1))

    return auc, apr, f1

# Close the writer owned by the test-phase logger.
logger.writer.close()

# Per-seed metric collections for the summary lines below.
auc_list = []
apr_list = []
f1_list = []
tpr_list = []
tnr_list = []
# print() instead of os.system("echo ..."): in Jupyter/Colab the stdout of a
# subprocess started by os.system does not show up in the cell output, and
# shell-quoting formatted values is fragile.
print('#######################################')
print('##### Final test results per seed #####')
print('#######################################')
# Each entry is (result, tpr, tnr); result holds (seed_case, auc, apr, f1, ...)
# as the labels below indicate.
for result, tpr, tnr in list_of_test_results_per_seed:
    print('seed_case:{} -- auc: {}, apr: {}, f1_score: {}, tpr: {}, tnr: {}'.format(
        result[0], result[1], result[2], result[3], tpr, tnr))
    auc_list.append(result[1])
    apr_list.append(result[2])
    f1_list.append(result[3])
    tpr_list.append(tpr)
    tnr_list.append(tnr)
# Values are passed in the same order as their labels (tnr, then tpr).
print('Total average -- auc: {}, apr: {}, f1_score: {}, tnr: {}, tpr: {}'.format(
    np.mean(auc_list), np.mean(apr_list), np.mean(f1_list), np.mean(tnr_list), np.mean(tpr_list)))
print('Total std -- auc: {}, apr: {}, f1_score: {}, tnr: {}, tpr: {}'.format(
    np.std(auc_list), np.std(apr_list), np.std(f1_list), np.std(tnr_list), np.std(tpr_list)))

In [None]:
Optimization

In [None]:
# Add these before optimize_for_tinyml()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import prune

class Tiny_CNN2D_LSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(12, 16, kernel_size=(1,1)),
            nn.BatchNorm2d(16),
            nn.ReLU6(),
            nn.Dropout(0.1),
            nn.Conv2d(16, 32, kernel_size=(1,1)),
            nn.BatchNorm2d(32),
            nn.ReLU6(),
        )
        self.lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=1, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU6(),
            nn.Linear(32, 2)
        )

    def forward(self, v, w, x, y, z):
        x_combined = torch.cat([v, w, x, z], dim=-1)
        x = x_combined.permute(0, 3, 1, 2)
        x = self.conv_block(x)
        x = x.squeeze(2).permute(0, 2, 1)
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        return self.fc(x)

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss: weighted soft (KL) term plus hard (CE) term.

    loss = soft_weight * T^2 * KL(softmax(teacher / T) || softmax(student / T))
           + hard_weight * CE(student, labels)

    Args:
        temperature: softening temperature ``T`` applied to both logit sets.
        soft_weight: weight of the distillation (KL) term; default 0.7.
        hard_weight: weight of the cross-entropy term on true labels; default 0.3.
            The defaults reproduce the previously hard-coded 0.7 / 0.3 mix.
    """
    def __init__(self, temperature=2.0, soft_weight=0.7, hard_weight=0.3):
        super().__init__()
        self.temperature = temperature
        self.soft_weight = soft_weight
        self.hard_weight = hard_weight
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_output, teacher_output, labels):
        t = self.temperature
        # KLDivLoss expects log-probabilities for the input and probabilities
        # for the target; the T^2 factor keeps gradient scale comparable to CE.
        soft_loss = self.kl_loss(
            F.log_softmax(student_output / t, dim=1),
            F.softmax(teacher_output / t, dim=1),
        ) * (t ** 2)
        hard_loss = F.cross_entropy(student_output, labels)
        return self.soft_weight * soft_loss + self.hard_weight * hard_loss

def prepare_for_quantization(model):
    """Attach the default fbgemm QAT qconfig and return a QAT-prepared model.

    ``model`` is put into training mode (``prepare_qat`` requires it) and its
    ``qconfig`` attribute is set in place; the returned module is whatever
    ``torch.quantization.prepare_qat`` produces from it.
    """
    qat_qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model.qconfig = qat_qconfig
    model.train()
    prepared_model = torch.quantization.prepare_qat(model)
    return prepared_model

def quantize_model(model):
    """Put ``model`` in eval mode and return ``torch.quantization.convert(model)``.

    ``convert`` is called with its default (non-in-place) behaviour, so the
    quantized module is a separate object from ``model``.
    """
    model.eval()
    quantized = torch.quantization.convert(model)
    return quantized

def apply_pruning(model, amount=0.2):
    """Globally L1-prune the weights of every Linear / Conv2d layer in ``model``.

    The smallest-magnitude ``amount`` of those weights, pooled across all such
    layers, is zeroed. The pruning re-parametrisation is then made permanent
    with ``prune.remove``, so each layer keeps a plain ``weight`` parameter.

    Args:
        model: module to prune in place.
        amount: fraction in [0, 1], or absolute number of weights, to prune.

    Returns:
        The same ``model`` object. A model with no Linear/Conv2d layers
        (e.g. LSTM-only) is returned unchanged instead of raising.
    """
    targets = [
        (module, 'weight')
        for module in model.modules()
        if isinstance(module, (nn.Linear, nn.Conv2d))
    ]
    if not targets:
        # global_unstructured would fail concatenating an empty parameter list.
        return model
    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )
    for module, name in targets:
        prune.remove(module, name)
    return model

def train_with_distillation(student, teacher, train_loader, epochs=10, lr=1e-4, device=None):
    """Train ``student`` to match a frozen ``teacher`` using ``DistillationLoss``.

    The teacher runs in eval mode under ``torch.no_grad()``; only the student
    is optimised (AdamW). ``y`` from each ``(v, w, x, y, z)`` batch is both
    passed to the models and used as the hard-label target.

    Args:
        student: module being trained (updated in place).
        teacher: module providing soft targets; not updated.
        train_loader: iterable of 5-tuples of tensors.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate (default 1e-4, the previous fixed value).
        device: device batches are moved to; defaults to the device of the
            student's first parameter instead of a notebook-global ``device``.
    """
    if device is None:
        device = next(student.parameters()).device
    student.train()
    teacher.eval()
    criterion = DistillationLoss()
    # torch.optim is referenced explicitly; this cell does not import `optim`.
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr)

    for _ in range(epochs):
        for v, w, x, y, z in train_loader:
            v, w, x, y, z = (t.to(device) for t in (v, w, x, y, z))
            with torch.no_grad():
                teacher_logits = teacher(v, w, x, y, z)
            student_logits = student(v, w, x, y, z)
            loss = criterion(student_logits, teacher_logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

In [None]:
def optimize_for_tinyml(args, train_loader, dev_loader):
    """Distil, QAT-fine-tune, prune and quantize a small student model, then export it.

    Steps: build the teacher (``CNN2D_LSTM_V1``) and student
    (``Tiny_CNN2D_LSTM``), optionally train the teacher, distil into the
    student, run QAT fine-tuning, prune, convert, evaluate on ``dev_loader``,
    then write ``tinyml_eeg_model.pth`` and ``tinyml_eeg_model.onnx`` to the
    working directory.

    Relies on notebook-level names: ``device``, ``train_model``, ``evaluate``,
    ``train_with_distillation``, ``prepare_for_quantization``,
    ``apply_pruning`` and ``quantize_model``.

    Returns:
        ``(tiny_model, eval_results)``, where ``eval_results`` is whatever the
        ``evaluate`` in scope returns.
    """
    # 1. Create original and tiny models.
    # CNN2D_LSTM_V1 takes (args, device); calling it with args alone raises TypeError.
    original_model = CNN2D_LSTM_V1(args, device).to(device)
    tiny_model = Tiny_CNN2D_LSTM(args).to(device)

    # 2. Train original model (if not already trained)
    # NOTE(review): when args.checkpoint is set the teacher is neither trained
    # nor loaded here, so distillation would use randomly initialised teacher
    # weights — confirm whether a checkpoint load belongs in this branch.
    if not args.checkpoint:
        train_model(original_model, train_loader, dev_loader)

    # 3. Apply knowledge distillation
    # NOTE(review): distillation calls teacher(v, w, x, y, z); confirm the
    # CNN2D_LSTM_V1 in scope accepts that signature.
    print("Applying knowledge distillation...")
    train_with_distillation(tiny_model, original_model, train_loader)

    # 4. Prepare for quantization aware training
    print("Preparing for quantization...")
    tiny_model = prepare_for_quantization(tiny_model)

    # 5. Fine-tune with QAT
    print("Quantization aware training...")
    train_model(tiny_model, train_loader, dev_loader, epochs=5)

    # 6. Apply pruning
    print("Applying pruning...")
    tiny_model = apply_pruning(tiny_model)

    # 7. Final quantization
    print("Final quantization...")
    tiny_model = quantize_model(tiny_model)

    # 8. Evaluate
    print("Evaluating optimized model...")
    eval_results = evaluate(tiny_model, dev_loader, device)

    # 9. Save for deployment
    torch.save(tiny_model.state_dict(), 'tinyml_eeg_model.pth')

    # 10. Export to ONNX.
    # Tiny_CNN2D_LSTM.forward concatenates v, w, x, z on the last axis (the 12
    # conv channels) and squeezes input axis 1 after the conv block, so each
    # feature input must be shaped (batch, 1, time, 3). The time length of 12
    # is an arbitrary tracing length.
    # NOTE(review): the converted model has no QuantStub/DeQuantStub, so its
    # quantized layers may still reject float inputs here — TODO confirm.
    seq_len = 12
    dummy_input = (torch.randn(1, 1, seq_len, 3), torch.randn(1, 1, seq_len, 3),
                   torch.randn(1, 1, seq_len, 3), torch.zeros(1, dtype=torch.long),
                   torch.randn(1, 1, seq_len, 3))
    torch.onnx.export(tiny_model, dummy_input, 'tinyml_eeg_model.onnx')

    return tiny_model, eval_results

In [None]:

# Run the TinyML pipeline only when explicitly requested; getattr keeps this
# cell from raising AttributeError when the config defines no such option.
if getattr(args, 'run_tinyml_optimization', False):
    optimized_model, results = optimize_for_tinyml(args, train_loader, dev_loader)
    print(f"Optimized model results: {results}")

In [None]:
Test Data

In [None]:
import torch
import torch.nn as nn
import os
from tqdm import tqdm
import numpy as np

# 1. Define your 12-channel model
class CNN2D_LSTM_V1(nn.Module):
    """Two (1x3) convolutions over a 12-channel input followed by an LSTM head.

    Expects ``x`` shaped ``(batch, 12, 1, time)``. The height axis is squeezed
    away before the LSTM, and the linear head reads the final time step.
    ``device`` is accepted for constructor compatibility but not used.
    """

    def __init__(self, args, device):
        super().__init__()
        conv_kw = dict(kernel_size=(1, 3), padding=(0, 1))
        self.conv1 = nn.Conv2d(12, 32, **conv_kw)
        self.conv2 = nn.Conv2d(32, 32, **conv_kw)
        self.relu = nn.ReLU()
        self.lstm = nn.LSTM(input_size=32, hidden_size=128, batch_first=True)
        self.fc = nn.Linear(128, args.output_dim)

    def forward(self, x):
        channels = x.size(1)
        if channels != 12:
            raise ValueError(f"Expected 12 input channels, got {x.size(1)}. Input shape: {x.shape}")

        for conv in (self.conv1, self.conv2):
            x = self.relu(conv(x))
        # (batch, 32, 1, time) -> (batch, time, 32)
        seq = x.squeeze(2).permute(0, 2, 1)
        lstm_out, _ = self.lstm(seq)
        return self.fc(lstm_out[:, -1, :])

# Initialize device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN2D_LSTM_V1(args, device).to(device)

# Load data
train_loader, val_loader, test_loader, len_train_dir, len_val_dir, len_test_dir = get_data_preprocessed(args)

# Resolve the checkpoint file. The best checkpoint carries the seed
# (best_<seed>.pth), matching the other checkpoint loaders in this notebook.
# With neither flag set, fall back to best.pth rather than leaving ckpt_path
# undefined (NameError).
ckpt_dir = os.path.join(args.dir_result, args.project_name, 'ckpts')
if args.last:
    ckpt_path = os.path.join(ckpt_dir, 'last.pth')
elif args.best:
    ckpt_path = os.path.join(ckpt_dir, f'best_{args.seed}.pth')
else:
    ckpt_path = os.path.join(ckpt_dir, 'best.pth')

if not os.path.exists(ckpt_path):
    # Raise rather than exit(1): exit() would stop the whole kernel /
    # interpreter instead of just failing this cell.
    raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")

ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt['model'])
print(f"Loaded model from {ckpt_path}")

# Evaluation function
def evaluate(model, loader, device, task_type=None):
    """Run ``model`` over ``loader`` and return ``(y_true, y_pred)`` arrays.

    Batches may be dicts with ``'data'`` / ``'label'`` keys, or sequences whose
    first two items are ``(x, y)``. Inputs are coerced to
    ``[batch, 12, 1, time_steps]``: a 3-D ``x`` gets a height axis inserted at
    dim 2, and a single-channel ``x`` is repeated to 12 channels.

    Args:
        model: module called as ``model(x)``.
        loader: iterable of batches as described above.
        device: device inputs and labels are moved to.
        task_type: ``"binary"`` applies a sigmoid to the outputs; any other
            value applies a softmax over dim 1. ``None`` (default) reads the
            notebook-global ``args.task_type``.

    Returns:
        ``(y_true, y_pred)``: labels and probabilities concatenated over batches.

    Raises:
        ValueError: if a batch has neither 1 nor 12 channels, or if ``loader``
            yields no batches.
    """
    # NOTE(review): this definition replaces the earlier ``evaluate`` in the
    # notebook, which returns (auc, apr, f1) instead; cells run after this
    # one get this version.
    if task_type is None:
        task_type = args.task_type
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            # Handle different batch formats
            if isinstance(batch, dict):
                x = batch['data']
                y = batch['label']
            else:
                x, y = batch[0], batch[1]

            # Prepare input
            x = x.to(device).float()
            y = y.to(device)

            # Ensure correct input shape [batch, 12, 1, time_steps]
            if x.dim() == 3:
                x = x.unsqueeze(2)  # Add height dimension if missing

            # Handle single-channel input by repeating
            if x.size(1) == 1:
                x = x.repeat(1, 12, 1, 1)
            elif x.size(1) != 12:
                raise ValueError(f"Expected 1 or 12 input channels, got {x.size(1)}")

            # Forward pass
            outputs = model(x)
            preds = torch.sigmoid(outputs) if task_type == "binary" else torch.softmax(outputs, dim=1)

            y_true.append(y.cpu().numpy())
            y_pred.append(preds.cpu().numpy())

    if not y_true:
        # np.concatenate([]) would otherwise fail with an unhelpful message.
        raise ValueError("evaluate(): loader yielded no batches")
    return np.concatenate(y_true), np.concatenate(y_pred)

# Run evaluation
print("\nStarting evaluation...")
# y_true / y_pred are the concatenated labels and probabilities from evaluate().
y_true, y_pred = evaluate(model, test_loader, device)

# Calculate metrics
if args.task_type == "binary":
    from sklearn.metrics import roc_auc_score, accuracy_score, f1_score

    # y_pred holds sigmoid outputs here. NOTE(review): AUC and the 0.5
    # threshold below assume one score per sample (output_dim == 1); with two
    # output columns these calls would not score class 1 alone — confirm.
    auc = roc_auc_score(y_true, y_pred)
    # Fixed 0.5 decision threshold on the sigmoid probability.
    y_pred_class = (y_pred > 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred_class)
    f1 = f1_score(y_true, y_pred_class)

    print("\nTest Results:")
    print(f"AUC: {auc:.4f}")
    print(f"Accuracy: {acc:.4f}")
    print(f"F1 Score: {f1:.4f}")
else:
    from sklearn.metrics import accuracy_score

    # Multi-class: predicted class is the argmax of the softmax probabilities.
    y_pred_class = np.argmax(y_pred, axis=1)
    acc = accuracy_score(y_true, y_pred_class)
    print(f"Accuracy: {acc:.4f}")

print("\nEvaluation complete!")

In [None]:
Seizure Test

In [None]:
import os
import argparse
import numpy as np
import random
import math
import torch
from torch import nn
from tqdm import tqdm
from scipy.signal import find_peaks
from sklearn.metrics import precision_recall_curve
from builder.data.data_preprocess import get_data_preprocessed
from builder.models import get_detector_model
from builder.utils.metrics import Evaluator
from builder.utils.logger import Logger
from builder.utils.utils import set_seeds, set_devices

# Configure environment
# Enumerate CUDA devices in PCI bus order. NOTE(review): only effective if set
# before CUDA is first initialised in this process; earlier cells may already
# have done that.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Prefer deterministic cuDNN kernels and disable auto-tuning, for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class SeizureEvaluator:
    """Window-level and event-level evaluation of a trained seizure detector.

    Inputs are cut into consecutive windows of ``args.window_size`` steps along
    axis 1. Each window gives one class-1 probability and one label per batch
    element. The flattened score signal is scored with a best-F1
    precision/recall operating point, then thresholded into events for
    OVLP / TAES / latency metrics. Event boundaries are index positions in that
    flattened signal (half-open ``(start, end)`` pairs).
    """

    def __init__(self, args):
        self.args = args
        self.device = set_devices(args)
        self.criterion = nn.CrossEntropyLoss(reduction='none')
        self.logger = Logger(args)

    def calc_hf(self, ref, hyp):
        """Return ``(hit, fa)`` of an overlapping hypothesis event against a reference event.

        Both fractions are relative to the reference duration; ``fa`` is capped
        at 1. Events are ``(start, stop)`` pairs, and callers should pass only
        overlapping events. A zero-length reference yields ``(0.0, 0.0)``
        instead of dividing by zero.
        """
        start_r, stop_r = ref
        start_h, stop_h = hyp

        ref_dur = stop_r - start_r
        hyp_dur = stop_h - start_h
        hit = fa = 0.0
        if ref_dur <= 0:
            return hit, fa

        if start_h <= start_r and stop_h <= stop_r:  # Pre-prediction
            hit = (stop_h - start_r) / ref_dur
            fa = min((start_r - start_h) / ref_dur, 1.0)
        elif start_h >= start_r and stop_h >= stop_r:  # Post-prediction
            hit = (stop_r - start_h) / ref_dur
            fa = min((stop_h - stop_r) / ref_dur, 1.0)
        elif start_h < start_r and stop_h > stop_r:  # Over-prediction
            hit = 1.0
            fa = min(((stop_h - stop_r) + (start_r - start_h)) / ref_dur, 1.0)
        else:  # Under-prediction
            hit = hyp_dur / ref_dur

        return hit, fa

    def detect_events(self, signal, threshold=0.5):
        """Return half-open ``(start, end)`` index runs where ``signal >= threshold``.

        The signal is binarised before edge detection, so a score that rises
        gradually past the threshold still opens an event. (Comparing the raw
        first difference with the threshold misses those.)
        """
        active = (np.asarray(signal, dtype=float) >= threshold).astype(np.int8)
        edges = np.diff(np.concatenate(([0], active, [0])))
        starts = np.flatnonzero(edges == 1)
        ends = np.flatnonzero(edges == -1)
        return [(int(s), int(e)) for s, e in zip(starts, ends)]

    @staticmethod
    def _overlaps(a, b):
        """True when half-open intervals ``a`` and ``b`` share at least one index."""
        return a[0] < b[1] and b[0] < a[1]

    @staticmethod
    def _background_tnr(ref_events, hyp_events, n_samples):
        """Fraction of background indices (outside every reference event) not covered by any hypothesis.

        Returns ``nan`` when ``n_samples`` is missing/zero or there is no background.
        """
        if not n_samples:
            return float('nan')
        in_ref = np.zeros(n_samples, dtype=bool)
        in_hyp = np.zeros(n_samples, dtype=bool)
        for start, end in ref_events:
            in_ref[start:end] = True
        for start, end in hyp_events:
            in_hyp[start:end] = True
        background = ~in_ref
        n_background = int(background.sum())
        if n_background == 0:
            return float('nan')
        return float((background & ~in_hyp).sum()) / n_background

    def calculate_ovlp(self, ref_events, hyp_events, n_samples=None):
        """Any-overlap (OVLP) scoring.

        ``tpr``: fraction of reference events overlapped by at least one
        hypothesis event (``nan`` with no reference events). ``tnr``:
        index-level specificity over the background (``nan`` without
        ``n_samples``).
        """
        detected = sum(
            any(self._overlaps(ref, hyp) for hyp in hyp_events) for ref in ref_events
        )
        tpr = detected / len(ref_events) if ref_events else float('nan')
        return {'tpr': tpr, 'tnr': self._background_tnr(ref_events, hyp_events, n_samples)}

    def calculate_taes(self, ref_events, hyp_events, n_samples=None):
        """Time-aligned (TAES) scoring.

        Each reference event earns the summed ``calc_hf`` hit fraction of the
        hypotheses that overlap it, capped at 1. ``tpr`` is the mean over
        reference events. ``tnr`` is the same background specificity used by
        ``calculate_ovlp``.
        """
        credits = []
        for ref in ref_events:
            hit = sum(self.calc_hf(ref, hyp)[0] for hyp in hyp_events if self._overlaps(ref, hyp))
            credits.append(min(hit, 1.0))
        tpr = float(np.mean(credits)) if credits else float('nan')
        return {'tpr': tpr, 'tnr': self._background_tnr(ref_events, hyp_events, n_samples)}

    def calculate_latencies(self, ref_events, hyp_events):
        """Onset latency of detected reference events.

        For each reference event overlapped by a hypothesis, latency is the
        earliest overlapping hypothesis start minus the reference start,
        floored at 0. ``avg`` is ``nan`` if nothing was detected;
        ``detection_rate`` is detected / reference events.
        NOTE(review): values are in index steps of the score signal (one per
        evaluated window), not seconds — scale by the window duration if the
        printed "s" unit should be literal.
        """
        latencies = []
        for ref in ref_events:
            onsets = [hyp[0] for hyp in hyp_events if self._overlaps(ref, hyp)]
            if onsets:
                latencies.append(max(0, min(onsets) - ref[0]))
        avg = float(np.mean(latencies)) if latencies else float('nan')
        rate = len(latencies) / len(ref_events) if ref_events else float('nan')
        return {'avg': avg, 'detection_rate': rate}

    def evaluate_segments(self, model, test_loader):
        """Score ``model`` window by window over ``test_loader`` and compute all metrics.

        Batches are 6-tuples whose first two items are inputs ``x`` and labels
        ``y``, both sliced ``[:, i:i + window_size]`` along axis 1.
        - Scores: each window gives one class-1 probability per batch element.
          This is a softmax over the logits, matching the cross-entropy
          criterion the model is trained with.
        - Labels: each window also gives one label per batch element, positive
          if any label inside the window is positive. NOTE(review): confirm
          this labelling rule.
        Within a batch, each element's windows are kept contiguous in the
        flattened signal.
        """
        model.eval()
        all_refs = []
        all_hyps = []
        window = self.args.window_size

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                x, y, seq_lens, target_lens, _, _ = batch
                x = x.to(self.device)
                batch_probs, batch_labels = [], []

                # Process in sliding windows
                for i in range(0, x.size(1), window):
                    x_window = x[:, i:i + window]
                    if x_window.size(1) < self.args.min_segment_length:
                        continue

                    probs = torch.softmax(model(x_window), dim=1)[:, 1]
                    y_window = y[:, i:i + window]
                    labels = y_window.reshape(y_window.size(0), -1).amax(dim=1)
                    batch_probs.append(probs.cpu().numpy())
                    batch_labels.append(labels.cpu().numpy())

                if batch_probs:
                    # (batch, n_windows), flattened row-major.
                    all_hyps.extend(np.stack(batch_probs, axis=1).ravel().tolist())
                    all_refs.extend(np.stack(batch_labels, axis=1).ravel().tolist())

        return self.calculate_metrics(all_refs, all_hyps)

    def calculate_metrics(self, refs, hyps):
        """Compute best-F1 window metrics plus OVLP, TAES and latency event metrics.

        ``refs`` are per-window 0/1 labels and ``hyps`` the matching scores
        (equal-length flat sequences).
        """
        refs = np.asarray(refs)
        hyps = np.asarray(hyps)

        # 1. Window-level metrics at the F1-maximising threshold.
        precision, recall, thresholds = precision_recall_curve(refs, hyps)
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-9)
        best_idx = np.argmax(f1_scores)

        # 2. Event-level metrics.
        ref_events = self.detect_events(refs, threshold=0.5)
        hyp_events = self.detect_events(hyps, threshold=thresholds[best_idx])
        n_samples = len(refs)

        ovlp_metrics = self.calculate_ovlp(ref_events, hyp_events, n_samples)
        taes_metrics = self.calculate_taes(ref_events, hyp_events, n_samples)
        latency_metrics = self.calculate_latencies(ref_events, hyp_events)

        return {
            'precision': precision[best_idx],
            'recall': recall[best_idx],
            'f1': f1_scores[best_idx],
            'threshold': thresholds[best_idx],
            'ovlp': ovlp_metrics,
            'taes': taes_metrics,
            'latency': latency_metrics
        }

    def run_evaluation(self):
        """Load data, model and checkpoint; evaluate on the test split; print and return metrics.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        set_seeds(self.args)

        # Load data and model
        train_loader, val_loader, test_loader, _, _, _ = get_data_preprocessed(self.args)
        model = get_detector_model(self.args)(self.args, self.device).to(self.device)

        # Checkpoint: best_<seed>.pth with --best, otherwise last.pth.
        ckpt_path = os.path.join(
            self.args.dir_result,
            self.args.project_name,
            'ckpts',
            f"best_{self.args.seed}.pth" if self.args.best else "last.pth"
        )

        if os.path.exists(ckpt_path):
            ckpt = torch.load(ckpt_path, map_location=self.device)
            model.load_state_dict(ckpt['model'])
            print(f"Loaded model from {ckpt_path}")
        else:
            raise FileNotFoundError(f"Checkpoint not found at {ckpt_path}")

        # Run evaluation
        results = self.evaluate_segments(model, test_loader)

        # Print and return results
        print("\nEvaluation Results:")
        print(f"Precision: {results['precision']:.4f}")
        print(f"Recall: {results['recall']:.4f}")
        print(f"F1 Score: {results['f1']:.4f}")
        print(f"Optimal Threshold: {results['threshold']:.4f}")

        print("\nOVLP Metrics:")
        print(f"Sensitivity: {results['ovlp']['tpr']:.4f}")
        print(f"Specificity: {results['ovlp']['tnr']:.4f}")

        print("\nTAES Metrics:")
        print(f"Sensitivity: {results['taes']['tpr']:.4f}")
        print(f"Specificity: {results['taes']['tnr']:.4f}")

        print("\nLatency Metrics:")
        print(f"Average Latency: {results['latency']['avg']:.2f}s")
        print(f"Detection Rate: {results['latency']['detection_rate']:.2f}")

        return results

if __name__ == "__main__":
    # Initialize from config
    # NOTE(review): inside a notebook kernel __name__ is "__main__", so this
    # runs whenever the cell executes, and it rebinds the globals `args`,
    # `evaluator` and `results`. If control.config parses sys.argv, it may see
    # the kernel's own command-line arguments — TODO confirm it works under
    # Jupyter/Colab.
    from control.config import args

    # Run evaluation
    evaluator = SeizureEvaluator(args)
    results = evaluator.run_evaluation()