In [19]:
from aws_utils import s3_client, bucket_name, fs, s3_resource
import pyedflib
import io
import tempfile
from extraction import extract_interictal_preictal
from pipeline import Pipeline
import numpy as np
import torch
import pickle
import os
from pathlib import Path


# Creating AWS Session

# Load Parameters from config.yaml

In [11]:
import yaml
def load_config(config_path):
    """Parse a YAML file and hand back whatever ``yaml.safe_load`` produces.

    Parameters
    ----------
    config_path : str or path-like
        Location of the YAML configuration file.

    Returns
    -------
    The parsed document (a dict for a mapping-style config file).
    """
    with open(config_path, 'r') as handle:
        return yaml.safe_load(handle)

# Parse the run configuration once; every later cell reads its settings from `config`.
CONFIG_PATH = "config.yaml"
config = load_config(CONFIG_PATH)
# Subjects to process; entries are formatted with `:02d` below, so presumably ints.
subject_list = config["subject_range"]

# Get Interictal and Preictal Ranges from the Summary Files

In [12]:
# S3 keys of the per-subject summary files, e.g. "chb01/chb01-summary.txt".
summs = [
    f"chb{subject:02d}/chb{subject:02d}-summary.txt"
    for subject in subject_list
]

# Auxiliary function
def get_summary_file_object(bucket_name, key, client=None):
    """Download a text object from S3 and wrap it in an in-memory file.

    Parameters
    ----------
    bucket_name : str
        Bucket to read from.
    key : str
        Object key inside the bucket.
    client : optional
        Object exposing ``get_object(Bucket=..., Key=...)`` that returns a
        dict with a readable/closable ``"Body"``. Defaults to the
        module-level ``s3_client``; pass a stub to use the function offline.

    Returns
    -------
    io.StringIO
        The object's contents decoded as UTF-8, positioned at the start.
    """
    if client is None:
        client = s3_client
    response = client.get_object(Bucket=bucket_name, Key=key)
    body = response["Body"]
    try:
        content = body.read().decode("utf-8")
    finally:
        # Close the streaming body explicitly so it is not left open when
        # reading or decoding fails.
        body.close()
    return io.StringIO(content)

# Map each subject directory (e.g. "chb01") to the ranges parsed from its
# summary file. The key is the S3 prefix before the first "/" rather than a
# fixed-width slice, so it always matches the f"chb{subj:02d}" lookup used
# when processing EDF files, even for subject numbers wider than two digits.
all_ranges = {}
for s3_key in summs:
    subject_dir = s3_key.split("/", 1)[0]
    summary_file_obj = get_summary_file_object(bucket_name, s3_key)
    all_ranges[subject_dir] = extract_interictal_preictal(summary_file_obj)


In [13]:
def _run_pipeline_on_s3_edf(edf_path, ranges_dict):
    """Download one EDF file from S3 and run it through ``Pipeline``.

    The file is staged in a local temporary file because ``Pipeline.CONFIG``
    is given a local filename. The temporary file is removed even if the
    download or the pipeline raises.

    Returns the ``(combined_epochs, epoch_labels)`` pair produced by
    ``Pipeline.run_pipeline``.
    """
    # Create (and immediately close) a named temp file so fs.get can write to its path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".edf") as tmp_file:
        tmp_filename = tmp_file.name
    try:
        fs.get(edf_path, tmp_filename)
        pipe = Pipeline()
        pipe.CONFIG(
            fname=tmp_filename,
            fs=config["fs"],
            window_size=config["window_size"],
            overlap=config["overlap"],
            f_low=config["f_low"],
            f_high=config["f_high"],
            ranges_dict=ranges_dict,
        )
        return pipe.run_pipeline()
    finally:
        os.remove(tmp_filename)


# Per-epoch outputs of the pipeline (exact shape is defined by Pipeline — TODO document).
all_time_bins = []
# Labels returned alongside the epochs; expected one per epoch (1 = pre-ictal, 0 = inter-ictal).
all_labels = []
for subj in subject_list:
    print(f"Processing subject: {subj}")
    subject_id = f"chb{subj:02d}"
    ranges = all_ranges[subject_id]
    edf_files = fs.glob(f"s3://maniks-chb-mit/{subject_id}/*.edf")

    for edf_path in edf_files:
        print(edf_path)
        # fs.glob yields "bucket/chbNN/<file>.edf"; the ranges are keyed by the file name.
        file_num = edf_path.rsplit("/", 1)[-1]
        combined_epochs, epoch_labels = _run_pipeline_on_s3_edf(edf_path, ranges[file_num])
        all_time_bins.extend(combined_epochs)
        all_labels.extend(epoch_labels)


Processing subject: 1
maniks-chb-mit/chb01/chb01_01.edf
maniks-chb-mit/chb01/chb01_02.edf
maniks-chb-mit/chb01/chb01_03.edf
maniks-chb-mit/chb01/chb01_04.edf
maniks-chb-mit/chb01/chb01_05.edf
maniks-chb-mit/chb01/chb01_06.edf
maniks-chb-mit/chb01/chb01_07.edf
maniks-chb-mit/chb01/chb01_08.edf
maniks-chb-mit/chb01/chb01_09.edf
maniks-chb-mit/chb01/chb01_10.edf
maniks-chb-mit/chb01/chb01_11.edf
maniks-chb-mit/chb01/chb01_12.edf
maniks-chb-mit/chb01/chb01_13.edf
maniks-chb-mit/chb01/chb01_14.edf
maniks-chb-mit/chb01/chb01_15.edf
maniks-chb-mit/chb01/chb01_16.edf
maniks-chb-mit/chb01/chb01_17.edf
maniks-chb-mit/chb01/chb01_18.edf
maniks-chb-mit/chb01/chb01_19.edf
maniks-chb-mit/chb01/chb01_20.edf
maniks-chb-mit/chb01/chb01_21.edf
maniks-chb-mit/chb01/chb01_22.edf
maniks-chb-mit/chb01/chb01_23.edf
maniks-chb-mit/chb01/chb01_24.edf
maniks-chb-mit/chb01/chb01_25.edf
maniks-chb-mit/chb01/chb01_26.edf
maniks-chb-mit/chb01/chb01_27.edf
maniks-chb-mit/chb01/chb01_29.edf
maniks-chb-mit/chb01/chb01

# Reshape Data

In [14]:
# Prepend a sample axis to every epoch and drop its singleton axis 2
# (axis 3 once the new leading axis exists), so the epochs can be
# concatenated along axis 0 in the next cell.
expanded = [
    np.squeeze(np.expand_dims(epoch, axis=0), axis=3)
    for epoch in all_time_bins
]

In [15]:
# Join the per-epoch arrays along their new leading axis: X gets one row per epoch.
if not expanded:
    # Previously X was set to None here, which made torch.tensor(None) fail
    # with an unrelated error; fail loudly with the real cause instead.
    raise ValueError("The pipeline produced no epochs; cannot build X/y tensors.")
X = np.concatenate(expanded, axis=0)

y = np.array(all_labels)
# Each sample needs exactly one label, otherwise X and y are silently misaligned.
if len(X) != len(y):
    raise ValueError(f"Got {len(X)} samples but {len(y)} labels.")

X_tf = torch.tensor(X, dtype=torch.float32)
y_tf = torch.tensor(y, dtype=torch.long)


In [20]:
# Serialize the tensors locally, then upload the file to the bucket root.
# NOTE: pickle executes code on load — only load this file from a trusted source.
local_path = Path("final_data.pkl")
with local_path.open("wb") as f:
    pickle.dump({"X": X_tf, "y": y_tf}, f, protocol=4)

# Use a dedicated name instead of reassigning `bucket_name` (imported from
# aws_utils), so re-running earlier cells still reads from the configured bucket.
upload_bucket = "maniks-chb-mit"
object_key = "final_data.pkl"  # root of the bucket

s3_resource.Bucket(upload_bucket).upload_file(
    Filename=str(local_path),
    Key=object_key,
)

print(f"✅  uploaded → s3://{upload_bucket}/{object_key}")


✅  uploaded → s3://maniks-chb-mit/final_data.pkl


## Check for Class Imbalance

In [26]:
# Count samples per class (1 = pre-ictal, 0 = inter-ictal) to gauge imbalance.
n_preictal = np.sum(y == 1)
n_interictal = np.sum(y == 0)
print("total samples :", len(y))
print("  pre‑ictal 1 :", n_preictal)
print("inter‑ictal 0 :", n_interictal)

total samples : 14464
  pre‑ictal 1 : 3231
inter‑ictal 0 : 11233
