In [1]:
import boto3
import pyedflib
import io
import tempfile
import s3fs
from dotenv import dotenv_values, load_dotenv
from extraction import extract_interictal_preictal
from pipeline import Pipeline
import os

load_dotenv()  # This will load variables from .env into os.environ

api_key = os.getenv("aws_access_key_id")
secret_key = os.getenv("aws_secret_access_key")
region = os.getenv("region_name")

# Creating AWS Session

In [2]:
env_variables = dotenv_values(".env")  # Reads local .env if it exists

api_key = env_variables.get("aws_access_key_id") or os.getenv("aws_access_key_id")
secret_key = env_variables.get("aws_secret_access_key") or os.getenv("aws_secret_access_key")
region = env_variables.get("region_name") or os.getenv("region_name")

session = boto3.Session(
    aws_access_key_id=api_key,
    aws_secret_access_key=secret_key,
    region_name=region
)

fs = s3fs.S3FileSystem(
    key=api_key,
    secret=secret_key,
    client_kwargs={'region_name': region}
)
s3 = session.client("s3")

bucket_name = "maniks-chb-mit"

# Loading in Parameters

In [3]:
import yaml
def load_config(config_path):
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config

# Load configuration from config.yaml
config = load_config("config.yaml")
subject_list = config["subject_range"]

# Get Ictal Ranges

In [4]:
# Get Files
summs = []
for patient_num in subject_list:
    summs.append(f"chb{patient_num:02d}/chb{patient_num:02d}-summary.txt")

# Auxillary Function
def get_summary_file_object(bucket_name, key):
    obj = s3.get_object(Bucket=bucket_name, Key=key)
    content = obj["Body"].read().decode("utf-8")
    return io.StringIO(content)

all_ranges = {}

# Get the file-like object from S3
for s3_key in summs:
    summary_file_obj = get_summary_file_object(bucket_name, s3_key)
    subject_ranges = extract_interictal_preictal(summary_file_obj)
    all_ranges[s3_key[0:5]] = subject_ranges


In [5]:
subject_list[:4]

[1, 2, 3, 5]

In [6]:
all_time_bins = []  # Will accumulate STFT outputs of shape (22, 5, t_i)
all_labels = []   
break_case = 0
for subj in subject_list[:4]:
    print(f"Processing subject: {subj}")
    pattern = f"s3://maniks-chb-mit/chb{subj:02d}/*.edf"
    ranges = all_ranges[f"chb{subj:02d}"]
    edf_files = fs.glob(pattern)
    
    break_case = break_case + 1
    for edf_path in edf_files:
        # Create a temporary file for the current EDF file
        print(edf_path)

        file_num = edf_path.split('/')[2]

        with tempfile.NamedTemporaryFile(delete=False, suffix=".edf") as tmp_file:
            tmp_filename = tmp_file.name

        # Download the EDF file from S3 to the temporary file
        fs.get(edf_path, tmp_filename)

        # Run the file through our pipeline
        pipe = Pipeline()
        pipe.CONFIG(
            fname=tmp_filename,
            fs=config["fs"],
            window_size=config["window_size"],
            overlap=config["overlap"],
            f_low=config["f_low"],
            f_high=config["f_high"],
            ranges_dict=ranges[file_num]
        )
        
        combined_epochs, epoch_labels = pipe.run_pipeline()
        
        # ***** The change is here: Instead of looping over epochs and replicating labels, 
        # simply extend the global lists with the epochs and labels returned by the pipeline.
        all_time_bins.extend(combined_epochs)
        all_labels.extend(epoch_labels)

        # Clean up: remove the temporary file
        os.remove(tmp_filename)


Processing subject: 1
maniks-chb-mit/chb01/chb01_01.edf
maniks-chb-mit/chb01/chb01_02.edf
maniks-chb-mit/chb01/chb01_03.edf
maniks-chb-mit/chb01/chb01_04.edf
maniks-chb-mit/chb01/chb01_05.edf
maniks-chb-mit/chb01/chb01_06.edf
maniks-chb-mit/chb01/chb01_07.edf
maniks-chb-mit/chb01/chb01_08.edf
maniks-chb-mit/chb01/chb01_09.edf
maniks-chb-mit/chb01/chb01_10.edf
maniks-chb-mit/chb01/chb01_11.edf
maniks-chb-mit/chb01/chb01_12.edf
maniks-chb-mit/chb01/chb01_13.edf
maniks-chb-mit/chb01/chb01_14.edf
maniks-chb-mit/chb01/chb01_15.edf
maniks-chb-mit/chb01/chb01_16.edf
maniks-chb-mit/chb01/chb01_17.edf
maniks-chb-mit/chb01/chb01_18.edf
maniks-chb-mit/chb01/chb01_19.edf
maniks-chb-mit/chb01/chb01_20.edf
maniks-chb-mit/chb01/chb01_21.edf
maniks-chb-mit/chb01/chb01_22.edf
maniks-chb-mit/chb01/chb01_23.edf
maniks-chb-mit/chb01/chb01_24.edf
maniks-chb-mit/chb01/chb01_25.edf
maniks-chb-mit/chb01/chb01_26.edf
maniks-chb-mit/chb01/chb01_27.edf
maniks-chb-mit/chb01/chb01_29.edf
maniks-chb-mit/chb01/chb01

In [7]:
for i in range(len(all_time_bins)):
    print(all_time_bins[i].shape)

(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1471, 1)
(22, 1

In [None]:
# Concatenate along the time axis: bake
if all_time_bins:
    X = np.concatenate(all_time_bins, axis=-1)  # Final shape: (22, 5, total_time_bins)
else:
    X = None

y = np.array(all_labels)  # y has length equal to the total number of labels

In [None]:
# Mapping dictionary
mapping = {
    "Interictal": 0,
    "Preictal": 1
}

# vectorize the mapping
map_func = np.vectorize(mapping.get)
numeric_labels = map_func(copy_y)

print(numeric_labels)

In [None]:
x1 = tf.tensor(X)
y1 = tf.tensor(numeric_labels)

In [None]:
x2 = x1.permute(2,0,1)
dataset = TensorDataset(x2, y1)