# Imports

In [1]:
import glob, os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml
import mne
from mne.io import concatenate_raws, read_raw_edf
from mne import Epochs
from mne.decoding import Scaler
from scipy import signal
from scipy.signal import butter, lfilter
from scipy.signal import ShortTimeFFT
from extraction import extract_interictal_preictal
from pipeline import Pipeline
import torch as tf
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset

# Creating Spectrograms

In [2]:

def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents.

    Parameters
    ----------
    config_path : str or path-like
        Location of the YAML file to read.

    Returns
    -------
    object
        Whatever ``yaml.safe_load`` produces for the file (a ``dict`` for a
        mapping document, ``None`` for an empty file).

    Raises
    ------
    OSError
        If the file cannot be opened.
    yaml.YAMLError
        If the content is not valid YAML.
    """
    # Pin the encoding: the default used by open() depends on the platform
    # locale, so a config containing non-ASCII characters (e.g. in a path)
    # could otherwise be decoded differently on another machine.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    return config

# Load configuration from config.yaml (an empty file parses to None, so fall
# back to an empty dict and let the validation below report what is missing).
config = load_config("config.yaml") or {}

# Fail fast with one clear message if the config is incomplete, instead of a
# bare KeyError on the first missing key (or, for the pipeline settings, only
# once the first EDF file is reached).
required_keys = ("base_subject_dir", "subject_range", "fs", "window_size",
                 "overlap", "f_low", "f_high")
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
    raise KeyError(f"config.yaml is missing required keys: {missing_keys}")

base_subject_dir = config["base_subject_dir"]
subject_list = config["subject_range"]

# Optional quick-test switch: set `max_subjects: 2` in config.yaml to process
# only the first 2 subjects. When the key is absent every subject is processed.
# (Previously a 2-subject slice was computed here but never used by the loop.)
max_subjects = config.get("max_subjects")
selected_subjects = subject_list[:max_subjects] if max_subjects else subject_list

# Accumulators filled by the loop below.
all_time_bins = []  # pipeline outputs per EDF file; expected shape (22, 5, t_i) -- TODO confirm against Pipeline
all_labels = []     # labels returned by the pipeline; expected to be one per time bin -- TODO confirm
all_ranges = {}     # subject folder name -> ranges parsed from that subject's summary file

for subj in selected_subjects:
    subject_folder = f"chb{subj:02d}"
    subject_dir = os.path.join(base_subject_dir, subject_folder)
    summary_file = os.path.join(subject_dir, f"{subject_folder}-summary.txt")
    print(f"\nProcessing summary file for subject {subject_folder}: {summary_file}")

    subject_ranges = extract_interictal_preictal(summary_file)
    all_ranges[subject_folder] = subject_ranges

    # Run the preprocessing pipeline once per EDF file listed for this subject.
    for edf_fname, ranges in subject_ranges.items():
        edf_file = os.path.join(subject_dir, edf_fname)

        pipe = Pipeline()
        pipe.CONFIG(
            fname=edf_file,
            fs=config["fs"],
            window_size=config["window_size"],
            overlap=config["overlap"],
            f_low=config["f_low"],
            f_high=config["f_high"],
            ranges_dict=ranges
        )

        combined_epochs, epoch_labels = pipe.run_pipeline()

        # The pipeline already returns matching epochs and labels, so both
        # lists are extended directly (no per-epoch label replication here).
        all_time_bins.extend(combined_epochs)
        all_labels.extend(epoch_labels)


KeyError: 'base_subject_dir'

In [5]:
# Concatenate the per-file arrays along the last (time-bin) axis.
# X stays None when nothing was collected, as before.
if all_time_bins:
    X = np.concatenate(all_time_bins, axis=-1)  # expected final shape: (22, 5, total_time_bins)
else:
    X = None

y = np.array(all_labels)  # one entry per collected label

# Every time bin needs exactly one label; catch a misaligned pipeline output
# here rather than as an obscure indexing/size error in a later cell.
if X is not None and X.shape[-1] != len(y):
    raise ValueError(
        f"Time-bin/label mismatch: X has {X.shape[-1]} time bins but y has {len(y)} labels"
    )

In [18]:
# Class name -> integer id used by the model.
mapping = {
    "Interictal": 0,
    "Preictal": 1
}

# Map the string labels in `y` (built in the concatenation cell) to integers.
# The previous version read `copy_y`, which is not defined anywhere in this
# notebook, so the cell only worked with leftover kernel state.
label_values = np.asarray(y).tolist()

# Reject labels the mapping does not know about: `mapping.get` would quietly
# turn them into None and produce an object array instead of integers.
unknown_labels = sorted(set(label_values) - set(mapping))
if unknown_labels:
    raise ValueError(f"Labels without a numeric mapping: {unknown_labels}")

# A plain lookup also handles an empty `y`, which np.vectorize cannot.
numeric_labels = np.array([mapping[label] for label in label_values], dtype=np.int64)

print(numeric_labels)

[0 0 0 ... 0 0 0]


# Creating tensors for ML

In [22]:
# Split the time bins (last axis of X) by class.
# Assuming X.shape = (22, 5, 71893) and numeric_labels.shape = (71893,).
#
# Boolean-mask indexing replaces the per-time-bin Python loop: it selects the
# same slices in the same order, keeps X's dtype and leading dimensions even
# when a class is empty (result shape (..., 0)), and raises an IndexError if
# the label vector and the time axis have different lengths.
preictal_mask = np.asarray(numeric_labels) == 1

X_preictal = X[:, :, preictal_mask]      # label == 1
X_interictal = X[:, :, ~preictal_mask]   # every other label, as before

# Print the shapes for verification
print("X_preictal shape:", X_preictal.shape)
print("X_interictal shape:", X_interictal.shape)


X_preictal shape: (22, 5, 6048)
X_interictal shape: (22, 5, 65845)


In [32]:
# Build torch tensors from the NumPy feature array and the integer labels.
# Note: `tf` is this notebook's alias for torch (see the imports cell), not TensorFlow.
x1, y1 = tf.tensor(X), tf.tensor(numeric_labels)

In [36]:
# Bring the last axis of x1 (the time bins) to the front so that indexing the
# dataset yields one time bin's feature slice together with its label.
x2 = tf.permute(x1, (2, 0, 1))
dataset = TensorDataset(x2, y1)

In [44]:
# 80/20 train/test split.
# A dedicated, seeded generator makes the split identical on every
# Restart & Run All; without it the partition changes on each run.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)
train_data, test_data = random_split(
    dataset,
    [train_size, test_size],
    generator=tf.Generator().manual_seed(SPLIT_SEED),
)
# NOTE(review): the pipeline is configured with a window overlap, so randomly
# splitting individual time bins may place near-duplicate neighbours in both
# train and test -- confirm whether a per-recording/per-subject split is needed.

In [51]:
# Show the shape of x1, the feature tensor before the permute that builds x2.
print(x1.shape)

torch.Size([22, 5, 71893])


# Pickling the data

In [48]:
import pickle

# Output path for the serialized dataset.
dataset_file = "dataset.pkl"

# Write the `dataset` object (as currently bound in the kernel) to disk.
# Reminder: only unpickle files you trust -- pickle.load can execute code.
with open(dataset_file, "wb") as pkl_out:
    pickle.dump(dataset, pkl_out)



In [56]:

class RandomDataset(Dataset):
    """In-memory dataset of random image-like tensors with random integer labels.

    Parameters
    ----------
    num_samples : int
        Number of (sample, label) pairs to generate.
    channels, height, width : int
        Shape of each generated sample, ``(channels, height, width)``.
    num_classes : int
        Labels are drawn uniformly from ``0 .. num_classes - 1``.
    seed : int or None
        When given, data and labels are drawn from a private generator seeded
        with this value, so two instances built with the same arguments hold
        identical tensors. ``None`` (the default) uses torch's global RNG.
    """

    def __init__(self, num_samples=100, channels=3, height=32, width=32, num_classes=10, seed=None):
        self.num_samples = num_samples
        # A private generator keeps seeding local; the global RNG state is untouched.
        generator = None if seed is None else tf.Generator().manual_seed(seed)
        # Standard-normal samples shaped like a batch of images (e.g. 32x32 RGB).
        self.data = tf.randn(num_samples, channels, height, width, generator=generator)
        # Integer labels in [0, num_classes).
        self.labels = tf.randint(0, num_classes, (num_samples,), generator=generator)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at position ``idx``."""
        return self.data[idx], self.labels[idx]

# Create the random dataset under its own name: binding it to `dataset` would
# overwrite the EEG TensorDataset built earlier, so re-running an earlier cell
# (e.g. the pickling cell) would silently operate on random data instead.
random_dataset = RandomDataset(num_samples=100)
loader = DataLoader(random_dataset, batch_size=10, shuffle=True)

# Print the shape of one batch as a quick test
images, labels = next(iter(loader))
print("Batch images shape:", images.shape, "Batch labels shape:", labels.shape)

# Save the dataset object to disk using torch.save. This pickles the whole
# RandomDataset instance, so loading it back requires the class definition
# to be importable in the loading session.
dataset_filename = "random_dataset.pt"
tf.save(random_dataset, dataset_filename)
print(f"Dataset saved as {dataset_filename}")

Batch images shape: torch.Size([10, 3, 32, 32]) Batch labels shape: torch.Size([10])
Dataset saved as random_dataset.pt
