# In [None]:
import functools
import gc
import logging
import os
import shutil
import zlib

import mne
import numpy as np
import pandas as pd
import torch
from scipy.signal import butter, filtfilt, sosfiltfilt

# --- Segmentation / labelling parameters (durations in seconds, rate in Hz) ---
SAMPLING_RATE = 256
SEGMENT_DURATION = 30
WINDOW_DURATION = 3
NUM_WINDOWS = SEGMENT_DURATION // WINDOW_DURATION  # sub-windows per segment
PRE_ICTAL_DURATION = 1800   # span before each seizure onset labelled pre-ictal
POST_ICTAL_DURATION = 1800  # span after each seizure offset that is discarded
EXCLUDE_WINDOW = 7200       # margin around seizures not used as inter-ictal data
TRAIN_TEST_SPLIT_RATIO = 0.8  # fraction of shuffled segments put in the train split

# Log to ./pipeline.log; filemode 'w' truncates the log of any previous run.
logging.basicConfig(
    filename='pipeline.log',
    filemode='w',
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Inputs (EDF recordings + seizure summary) and the output location for this patient.
edf_dir = 'raw_data/chb04/'
summary_path = os.path.join(edf_dir, 'chb04-summary.txt')
output_dir = 'output2/chb04/'

# NOTE: destructive - any previous contents of output_dir are deleted as soon as
# this module is executed/imported, then the directory is recreated empty.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True)

def parse_time(time_str):
    """Convert an ``HH:MM:SS`` clock string to seconds since midnight.

    Hours of 24 or more (clock values written past midnight) are wrapped
    modulo 24, so the result is always a time-of-day. Detecting that the
    recording crossed midnight is left to the caller: ``parse_summary``
    adds a day whenever the clock goes backwards. Wrapping every hour >= 24,
    rather than only exactly 24, keeps values such as ``25:30:00`` from
    being counted as an extra day on top of that rollover.

    Args:
        time_str: A string of the form ``"H:M:S"`` with integer fields.

    Returns:
        int seconds in ``[0, 86400)`` for well-formed minutes/seconds.

    Raises:
        ValueError: If the string does not contain exactly three integer fields.
    """
    hours, minutes, seconds = map(int, time_str.split(':'))
    hours %= 24
    return hours * 3600 + minutes * 60 + seconds

def parse_summary(summary_path):
    """Parse a seizure summary text file into per-file timing information.

    Only the ``File Name:`` prefix is matched; the lines after it are read
    positionally:

        File Name: <name>
        <label>: HH:MM:SS          file start clock time
        <label>: ...               skipped (file end time)
        <label>: <n>               number of seizures
        <label>: <sec> ...         seizure start, repeated n times together
        <label>: <sec> ...         seizure end      with the line above

    Start times are clock times. When a file starts earlier in the day than the
    previous file, the recording is assumed to have crossed midnight, and
    86400 s is added to that file and every later one.

    Returns:
        dict mapping file name -> {'global_start': seconds from midnight of
        the first day, 'seizures': list of (global_start, global_end) tuples}.
    """
    with open(summary_path, 'r') as handle:
        rows = handle.readlines()

    records = {}
    day_offset = 0
    previous_clock = -1
    cursor = 0
    while cursor < len(rows):
        header = rows[cursor]
        if not header.startswith('File Name:'):
            cursor += 1
            continue

        name = header.split(':')[1].strip()

        # Everything after the first ':' on the next line is the HH:MM:SS value.
        clock_text = ':'.join(rows[cursor + 1].strip().split(':')[1:]).strip()
        clock = parse_time(clock_text)
        if previous_clock != -1 and clock < previous_clock:
            day_offset += 1  # clock went backwards -> a new day started
        file_start = day_offset * 86400 + clock
        record = {'global_start': file_start, 'seizures': []}
        records[name] = record
        previous_clock = clock

        # Jump over the start-time and end-time lines onto the seizure count.
        cursor += 3
        seizure_count = int(rows[cursor].split(':')[1].strip())
        if seizure_count <= 0:
            cursor += 1
            continue

        for _ in range(seizure_count):
            onset = int(rows[cursor + 1].split(':')[1].split()[0])
            offset = int(rows[cursor + 2].split(':')[1].split()[0])
            record['seizures'].append((file_start + onset, file_start + offset))
            cursor += 2
        # cursor now rests on the last seizure-end line; the loop re-examines it.
    return records

@functools.lru_cache(maxsize=32)
def _bandpass_sos(lowcut, highcut, fs, order):
    """Design (and memoize) a Butterworth band-pass filter as second-order sections."""
    nyq = 0.5 * fs
    return butter(order, [lowcut / nyq, highcut / nyq], btype='band', output='sos')


def butter_bandpass_filter(data, lowcut=0.1, highcut=127, fs=256, order=5):
    """Apply a zero-phase Butterworth band-pass filter along the last axis.

    The filter is designed and run in second-order-section (SOS) form. With
    the defaults, the low edge is 0.1 Hz at 256 Hz (~0.0008 of Nyquist), and
    the band-pass is 10th order. The equivalent single (b, a) polynomial is
    numerically ill-conditioned there, while cascaded sections are not.
    The design is cached, because this function is called once per window.

    Args:
        data: Array whose last axis is time. It must be longer than
            ``sosfiltfilt``'s default edge padding.
        lowcut: Lower pass-band edge in Hz (must be > 0).
        highcut: Upper pass-band edge in Hz (must be < fs / 2).
        fs: Sampling rate in Hz.
        order: Butterworth prototype order (the band-pass has twice this order).

    Returns:
        ndarray with the same shape as ``data``, filtered forward and backward
        (zero phase).
    """
    sos = _bandpass_sos(lowcut, highcut, fs, order)
    return sosfiltfilt(sos, data, axis=-1)

def extract_spectral_features(segment, fs=256, bands=None):
    """Compute one spectral feature per frequency band along the last axis.

    For each band ``(low, high)``, the rFFT bins with ``low <= f <= high``
    (both edges inclusive) are selected. The feature is the average of the
    band's mean power ``|X|**2`` and its mean magnitude ``|X|``.

    Args:
        segment: Array whose last axis is time. Any leading axes (e.g.
            channels) are preserved.
        fs: Sampling rate in Hz.
        bands: Ordered mapping ``name -> (low_hz, high_hz)``. Defaults to
            delta/theta/alpha/beta/gamma (0.1-4, 4-8, 8-12, 12-30, 30-127 Hz).
            ``None`` avoids a shared mutable default.

    Returns:
        ndarray of shape ``segment.shape[:-1] + (len(bands),)``, with bands in
        mapping order. A band that contains no FFT bin yields NaN (mean of
        an empty slice).
    """
    if bands is None:
        bands = {'delta': (0.1, 4), 'theta': (4, 8), 'alpha': (8, 12),
                 'beta': (12, 30), 'gamma': (30, 127)}

    n_samples = segment.shape[-1]
    freqs = np.fft.rfftfreq(n_samples, 1 / fs)
    spectrum = np.fft.rfft(segment, axis=-1)

    features = []
    for low, high in bands.values():
        in_band = (freqs >= low) & (freqs <= high)
        # '...' keeps all leading axes, so 1-D and N-D inputs work alike.
        magnitude = np.abs(spectrum[..., in_band])
        power = np.mean(magnitude ** 2, axis=-1)
        mean_amp = np.mean(magnitude, axis=-1)
        features.append((power + mean_amp) / 2)
    return np.stack(features, axis=-1)

def get_common_channels(file_names, edf_dir):
    """Return the sorted channel names present in every listed EDF file.

    Only headers are read (``preload=False``).

    Args:
        file_names: Iterable of EDF file names, relative to ``edf_dir``.
        edf_dir: Directory containing the EDF files.

    Returns:
        Sorted list of channel names common to all files. The list is empty
        when ``file_names`` is empty; previously, ``set.intersection()`` with no
        sets raised TypeError in that case.
    """
    common = None
    for file_name in file_names:
        raw = mne.io.read_raw_edf(os.path.join(edf_dir, file_name), preload=False, verbose=False)
        names = set(raw.ch_names)
        common = names if common is None else common & names
    if common is None:
        return []
    return sorted(common)

def _stable_sequence_id(file_name):
    """Deterministic per-file id in [0, 10000); unlike hash(), unaffected by PYTHONHASHSEED."""
    return zlib.crc32(file_name.encode('utf-8')) % 10000


def _overlaps(start, end, periods):
    """Return True if the closed interval [start, end] intersects any (s, e) in periods."""
    return any(s <= end and e >= start for s, e in periods)


def _window_features(segment_data, fs):
    """Split a segment into NUM_WINDOWS windows and reduce each to one value per channel.

    Each window is band-pass filtered, then turned into a (channels, bands)
    spectral feature matrix, which is averaged over the band axis.

    Returns:
        ndarray of shape (NUM_WINDOWS, n_channels).
    """
    window_samples = int(WINDOW_DURATION * fs)
    per_window = []
    for w in range(NUM_WINDOWS):
        window_data = segment_data[:, w * window_samples:(w + 1) * window_samples]
        filtered = butter_bandpass_filter(window_data, fs=fs)
        band_features = extract_spectral_features(filtered, fs=fs)  # (channels, bands)
        # NOTE(review): this averages across frequency bands, leaving one value
        # per channel - confirm that collapsing the bands is intended.
        per_window.append(band_features.mean(axis=1))
    return np.stack(per_window, axis=0)


def extract_segments(file_name, file_info, common_channels):
    """Yield labelled, featurized SEGMENT_DURATION-second segments from one EDF file.

    All seizure times in ``file_info`` are global (seconds from midnight of the
    first recording day), so seizures from *every* file are considered. For
    example, a seizure near the end of the previous file can put the start of
    this file inside its post-ictal period.

    Labelling rules for a segment [t, t + SEGMENT_DURATION):
      * overlaps any seizure or its POST_ICTAL_DURATION tail -> skipped;
      * starts within PRE_ICTAL_DURATION before a seizure onset -> label 1;
      * otherwise overlaps +/- EXCLUDE_WINDOW around a seizure -> skipped;
      * otherwise -> label 0 (inter-ictal).

    Args:
        file_name: EDF file name inside ``edf_dir``.
        file_info: Output of ``parse_summary``.
        common_channels: Channel names to keep.

    Yields:
        (file_name, t, t_end, window_features, label, sequence_id, window_index),
        where t/t_end are seconds from the start of the file and
        window_features has shape (NUM_WINDOWS, n_channels).

    Raises:
        ValueError: If the file's sampling rate differs from SAMPLING_RATE.
    """
    raw = mne.io.read_raw_edf(os.path.join(edf_dir, file_name), preload=False, verbose=False)
    raw.pick_channels(common_channels)
    logging.info(f"After selection, file {file_name} has {raw.info['nchan']} channels")
    fs = raw.info['sfreq']
    if fs != SAMPLING_RATE:
        raise ValueError(f"Sampling rate is {fs}, expected {SAMPLING_RATE}")
    global_start = file_info[file_name]['global_start']
    sequence_id = _stable_sequence_id(file_name)

    all_seizures = [seizure for info in file_info.values() for seizure in info['seizures']]
    ictal_periods = list(all_seizures)
    post_ictal_periods = [(e, e + POST_ICTAL_DURATION) for _, e in all_seizures]
    pre_ictal_periods = [(s - PRE_ICTAL_DURATION, s) for s, _ in all_seizures]
    exclude_windows = [(s - EXCLUDE_WINDOW, e + EXCLUDE_WINDOW) for s, e in all_seizures]

    # Step in whole samples so every segment has exactly segment_samples samples.
    # The old bound (times[-1] - SEGMENT_DURATION + 1) could admit a final
    # segment running up to ~1 s past the end of the recording.
    segment_samples = int(SEGMENT_DURATION * fs)
    for start_sample in range(0, raw.n_times - segment_samples + 1, segment_samples):
        t = start_sample / fs
        t_end = t + SEGMENT_DURATION
        absolute_t = global_start + t
        absolute_t_end = absolute_t + SEGMENT_DURATION

        if (_overlaps(absolute_t, absolute_t_end, ictal_periods)
                or _overlaps(absolute_t, absolute_t_end, post_ictal_periods)):
            continue

        if any(p <= absolute_t < s for p, s in pre_ictal_periods):
            label = 1
        elif _overlaps(absolute_t, absolute_t_end, exclude_windows):
            continue  # too close to a seizure to serve as inter-ictal data
        else:
            label = 0

        window_index = start_sample // segment_samples
        segment_data = raw.get_data(start=start_sample, stop=start_sample + segment_samples)
        window_features = _window_features(segment_data, fs)

        yield (file_name, t, t_end, window_features, label, sequence_id, window_index)

def _collect_split(segments, split_name, patient_id):
    """Convert one split's segment dicts into aligned feature/label/id lists plus CSV rows.

    Each feature tensor is the segment's feature array cast to float32, with a
    leading batch axis of size 1 added. A segment whose features cannot be
    converted is logged and skipped, so all returned lists stay aligned.

    Returns:
        (features, labels, sequence_ids, window_indices, csv_rows)
    """
    features, labels, sequence_ids, window_indices, csv_rows = [], [], [], [], []
    for segment in segments:
        try:
            feature_tensor = torch.from_numpy(segment['features']).float().unsqueeze(0)
        except Exception as e:  # best effort: drop the bad segment but record it
            logging.error(f"Error processing {split_name} segment in {segment['file_name']}: {str(e)}")
            continue
        features.append(feature_tensor)
        labels.append(segment['label'])
        sequence_ids.append(segment['sequence_id'])
        window_indices.append(segment['window_index'])
        csv_rows.append({
            'patient_id': patient_id,
            'file_name': segment['file_name'],
            'start_time': segment['start'],
            'end_time': segment['end'],
            'label': segment['label'],
        })
    return features, labels, sequence_ids, window_indices, csv_rows


def _save_split(split_name, patient_id, features_tensor, labels, sequence_ids, window_indices, csv_rows):
    """Write one split to output_dir as ``<split>_fold1_<patient>.pt`` (tensors) and ``.csv`` (metadata)."""
    stem = f'{split_name}_fold1_{patient_id}'
    torch.save({
        'features': features_tensor,
        'labels': torch.tensor(labels, dtype=torch.long),
        'sequence_ids': torch.tensor(sequence_ids, dtype=torch.long),
        'window_indices': torch.tensor(window_indices, dtype=torch.long),
    }, os.path.join(output_dir, f'{stem}.pt'))
    pd.DataFrame(csv_rows).to_csv(os.path.join(output_dir, f'{stem}.csv'), index=False)
    logging.info(f"Saved {stem} with {len(labels)} segments.")


def process_and_save_all_files(file_names, file_info, patient_id, common_channels):
    """Extract segments from all files, split them into train/test, normalize, and save.

    Segments are shuffled with a fixed seed (42, global NumPy RNG) and split
    by TRAIN_TEST_SPLIT_RATIO. Features are z-scored with statistics computed
    over the first two axes of the *training* features, i.e. one mean/std per
    last-axis entry. When there are no training segments, the test features
    are saved unnormalized and a warning is logged. Previously this case
    raised UnboundLocalError.

    Args:
        file_names: EDF file names to process, in order.
        file_info: Output of ``parse_summary``.
        patient_id: Identifier used in output file names and CSV rows.
        common_channels: Channel names passed to ``extract_segments``.
    """
    all_segments = []
    for file_name in file_names:
        logging.info(f"Collecting segments from file: {file_name}")
        segment_generator = extract_segments(file_name, file_info, common_channels)
        for fname, start, end, features, label, seq_id, win_idx in segment_generator:
            all_segments.append({
                'file_name': fname,
                'start': start,
                'end': end,
                'features': features,
                'label': label,
                'sequence_id': seq_id,
                'window_index': win_idx,
                'global_start': file_info[fname]['global_start'] + start,
            })
        gc.collect()

    # NOTE(review): segments are shuffled individually, so neighbouring 30 s
    # segments of the same recording/seizure can land in both splits. Consider
    # a file- or seizure-level split if test scores must reflect unseen data.
    np.random.seed(42)
    np.random.shuffle(all_segments)

    total_segments = len(all_segments)
    train_size = int(total_segments * TRAIN_TEST_SPLIT_RATIO)
    train_segments = all_segments[:train_size]
    test_segments = all_segments[train_size:]

    n_pre_ictal_test = sum(1 for seg in test_segments if seg['label'] == 1)
    n_inter_ictal_test = len(test_segments) - n_pre_ictal_test
    logging.info(f"Test segments: {len(test_segments)} (Pre-ictal: {n_pre_ictal_test}, Inter-ictal: {n_inter_ictal_test})")
    logging.info(f"Total segments: {total_segments}, Train: {len(train_segments)}, Test: {len(test_segments)}")

    train_features, *train_rest = _collect_split(train_segments, 'train', patient_id)
    test_features, *test_rest = _collect_split(test_segments, 'test', patient_id)

    mean = std = None
    if train_features:
        train_tensor = torch.cat(train_features, dim=0)
        mean = train_tensor.mean(dim=(0, 1), keepdim=True)
        std = train_tensor.std(dim=(0, 1), keepdim=True) + 1e-6  # epsilon avoids divide-by-zero
        _save_split('train', patient_id, (train_tensor - mean) / std, *train_rest)
    else:
        logging.warning(f"No train segments to save for {patient_id}.")

    if test_features:
        test_tensor = torch.cat(test_features, dim=0)
        if mean is not None:
            test_tensor = (test_tensor - mean) / std  # train statistics only: no test leakage
        else:
            logging.warning(f"No train statistics for {patient_id}; saving test features unnormalized.")
        _save_split('test', patient_id, test_tensor, *test_rest)
    else:
        logging.warning(f"No test segments to save for {patient_id}.")

    gc.collect()

def main():
    """Run the chb04 pipeline: check disk space, parse annotations, pick channels, extract and save."""
    patient_id = 'chb04'

    free_gb = shutil.disk_usage(output_dir).free / (2**30)
    if free_gb < 10:
        logging.warning(f"Low disk space: {free_gb:.2f} GB free. May cause write failures.")

    logging.info("Parsing summary...")
    file_info = parse_summary(summary_path)
    files = list(file_info)  # summary order

    common_channels = get_common_channels(files, edf_dir)
    logging.info(f"Common channels for {patient_id}: {common_channels} (Count: {len(common_channels)})")

    logging.info(f"Processing all files for patient {patient_id}")
    process_and_save_all_files(files, file_info, patient_id, common_channels)


if __name__ == "__main__":
    main()