 Imports

In [9]:
from statistics import LinearRegression
import json
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from math import ceil
from scipy import stats

import pandas as pd
import sklearn
from sklearn.svm import SVR
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn import linear_model

import os
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import random

import specparam

from neurodsp.spectral import compute_spectrum
from neurodsp.burst import detect_bursts_dual_threshold, compute_burst_stats
from neurodsp.plts.time_series import plot_time_series, plot_bursts

import itertools

 analysis functions

In [10]:
# the simulated signals were presented in randomized order to each human labeler. The ordering was recorded as a list of indices.
# the selections the labelers made were recorded in the order they were made.
# This function finds the index at which the desired signal index appears in the ordering of signal indices.
def reverse_order_search(j, order):
    for i in range(len(order)):
        if j == order[i]:
            return i


# the linear coefficients come from a linear regression model searching for the best fit line between
# signal-noise ratio and _ratio parameter used to generate signals
def ratio_to_snr_converter(_ratio):
    snr = -19.65 * (_ratio) + 9.668
    return snr


# this decodes ONE entry in the list of params used to generate one signal.
# the function returns a dictionary with named parameters for visual inspection
# this function is only used when investigating the signal with the worst f1 score.
def decode_params(param_list) -> dict | None:
    if len(param_list) != 5:
        return None
    retDict = {}
    retDict["freq"] = 2 * param_list[0]
    retDict["n_cycles"] = param_list[1]
    retDict["rise-decay asymmetry"] = param_list[2]
    retDict["aperiodic exponent"] = param_list[3]
    retDict["signal-noise ratio"] = ratio_to_snr_converter(param_list[4])
    return retDict


# this decodes ONE entry in the list of params used to generate one signal.
# the function returns a numpy array with the same parameters as decode_params.
# this function is used when preparing data for regression analysis.
# It differs from decode_params in that it returns a numpy array instead of a dictionary.
def decode_params_np(param_list) -> dict | None:
    if len(param_list) != 5:
        return None
    retArray = np.zeros(5)
    retArray[0] = 2 * param_list[0]
    retArray[1] = param_list[1]
    retArray[2] = param_list[2]
    retArray[3] = param_list[3]
    retArray[4] = ratio_to_snr_converter(param_list[4])
    return retArray


# this function takes a list of parameters and returns a dictionary with named parameters.
# the function is used when preparing data for regression analysis.
# this function calls decode_params to ensure frequency and snr are accurate.
def param_list_to_training_data(param_list):
    num_samples = len(param_list)
    num_features = 5
    retArray = np.zeros((num_samples, num_features))
    for i in range(len(param_list)):
        row = decode_params_np(param_list[i])
        retArray[i][0:num_features] = row[:num_features]
    return retArray


def create_signal_images(signal_data, output_directory):
    """
    Save signal data as cropped images with specific requirements.

    Parameters:
        signal_data (list or array-like): A list of signals, where each signal is an array of amplitude values.
        output_directory (str): Directory where the images will be saved.
    """
    dpi = 100
    figsize_width = 1000.0 / float(dpi)
    figsize_height = 1.0

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    else:
        print(f"Directory {output_directory} already exists. Skipping image creation.")
        return

    for i, signal in enumerate(signal_data):
        if i % 100 == 0:
            print(i)

        # Normalize signal only if its full y-axis isn't in range [-3, 3]
        signal_min, signal_max = np.min(signal), np.max(signal)
        if signal_min < -3 or signal_max > 3:
            signal = (signal - signal_min) / (
                signal_max - signal_min
            ) * 5.8 - 2.9  # Normalize to [-3, 3]

        filename = f"sig_{i}.png"
        filepath = os.path.join(output_directory, filename)

        # Create the plot
        fig = plt.figure(figsize=(figsize_width, figsize_height), dpi=dpi)
        plt.ylim(-3, 3)

        # Remove axes and internal padding
        plt.gca().set_axis_off()
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

        # Plot the signal
        plt.plot(signal)
        plt.savefig(
            filepath,
            bbox_inches="tight",  # Crop tightly to the plot content
            pad_inches=0,  # Remove any padding
            transparent=False,  # Optional: Save with a transparent background
        )
        plt.close(fig)

        # Crop the image
        img = Image.open(filepath)
        box = (45, 0, 955, 100)  # Define the cropping box
        img = img.crop(box)
        img.save(filepath)

        print(f"Saved cropped signal image to: {filepath}")

def num_bursting_intervals(is_burst) -> int:
    bursting_interval_count = 0
    if is_burst[0]:
        bursting_interval_count=1
    siglen = len(is_burst)
    for i in range(0, siglen-1):
        val_at_i = is_burst[i]
        val_at_i_plus_1 = is_burst[i+1]
        if (not val_at_i) and val_at_i_plus_1:
            bursting_interval_count += 1

    return bursting_interval_count

# accommodates multiple intervals right now. can be optimized (a tiny bit) by restricting one interval
def get_bursting_intervals(is_burst, num_intervals):
    bursting_intervals = [[0,0]]*num_intervals
    burst_interval_index = 0
    siglen = len(is_burst)

    if is_burst[0]:
        bursting_intervals[0][0]=0
    for i in range(0, siglen-1):
        val_at_i = is_burst[i]
        val_at_i_plus_1 = is_burst[i+1]
        if val_at_i and (not val_at_i_plus_1):
            bursting_intervals[burst_interval_index][1]=i
            burst_interval_index += 1
        elif (not val_at_i) and val_at_i_plus_1:
            bursting_intervals[burst_interval_index][0]=i
    return bursting_intervals




 Import File Data

In [11]:

# Load data from results json exported from firebase
with open("./voyteklabstudy-default-rtdb-export.json") as f:
    results = json.load(f)


 Set constants

In [12]:


# this is the number of real recorded EEG signal we used in the study platform we hosted. The
# signals were arranged in (real signals, simulated signals) order. Thus, num_real_sigs is used
# as an array offset in this analysis.
num_real_sigs = 49

fs=1000

# List of names of human collaborators who labeled data
# length of which gives us number of labelings. Allows us to iterate through labelers
who = list(results["selections"].keys())
print(who)

# the classes are "non-bursting" and "bursting"
num_classes = 2


onsets = np.zeros((num_real_sigs, len(who)))
offsets = np.zeros((num_real_sigs, len(who)))

['Andrew Bender@1714089263343', 'Bradley Voytek Apr22 2024@1713819121072', 'Dillan@1713909205994', 'Eena Kosik@1713821677039', 'MJ@1714513556139', 'Quirine@1714514427397', 'Ryan Hammonds@1713819289745', 'Sydney Smith@1714416232441', 'rgao@1715689500559']


In [13]:
# Setup process for dual threshold burst detection
y_pred = np.zeros((num_real_sigs, 2))
num_pred = 0
for i in range(num_real_sigs):
    # Here we have code to execute the burst labeling.
    test_signal = results["sigs"]["sig_" + str(i)]
    test_signal = np.array(test_signal)

    freqs, power_spectral_density = compute_spectrum(fs=fs, sig=test_signal)
    sm = specparam.SpectralModel(peak_width_limits=[1.0, 8.0], max_n_peaks=8)
    sm.fit(freqs, power_spectrum=power_spectral_density)
    [center_frequency, log_power, bandwidth] = specparam.analysis.get_band_peak(
        sm, [10, 20], select_highest=True
    )
    print(center_frequency)

    is_burst = detect_bursts_dual_threshold(
        sig=np.array(test_signal), fs=fs, f_range=(9, 21), dual_thresh=(1, 2)
    )
    # plot_bursts(np.linspace(0, 1, 1000), test_signal, is_burst)

    num_bursts = num_bursting_intervals(is_burst)
    if num_bursts > 1:
        print(3 / 0)
    intervals: list[list[int]] = get_bursting_intervals(
        num_intervals=num_bursts, is_burst=is_burst
    )
    # should only be one interval:
    if len(intervals) == 0:
        y_pred[i] = [np.nan, np.nan]
    else:
        y_pred[i] = [intervals[0][0], intervals[0][1]]
    print(y_pred[i])

for element in y_pred:
    element[0] = element[0] * (910.0 / 1000.0)
    element[1] = element[1] * (910.0 / 1000.0)



	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommend a lower bound of approximately 2x the frequency resolution.

nan
[        568         987]


	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommend a lower bound of approximately 2x the frequency resolution.

nan
[        530         891]


	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommend a lower bound of approximately 2x the frequency resolution.

19.518556702910356
[        524         945]


	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too lo

In [14]:
#Setup process for YOLO


# Here we want to generate an image dataset from the signals.
test_dir = "signal_images"
collection_real_sigs = [results["sigs"]["sig_" + str(i)] for i in range(num_real_sigs)]
create_signal_images(collection_real_sigs, test_dir)

# now we want to predict the onset and offset of the signals with yolo.

output_collage = "collage_dualthresh_vs_yolo.png"

# Load the model
model = YOLO(
    "/Users/kenton/HOME/coding/python/publish_the_paper/runs/detect/train50/weights/best.pt"
)

# Get all image files in the directory
all_images = [
    os.path.join(test_dir, f)
    for f in os.listdir(test_dir)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
]

# Optional: Load a font for better text rendering
try:
    font = ImageFont.truetype(
        "arial.ttf", size=16
    )  # Use a font installed on your system
except IOError:
    font = ImageFont.load_default()


Directory signal_images already exists. Skipping image creation.


In [15]:

# Annotate each selected image and record the predicted onsets and offsets
pred_onsets = np.zeros(num_real_sigs)
pred_offsets = np.zeros(num_real_sigs)
annotated_images = []
for i in range(len(all_images)):
    image_path = all_images[i]
    # Predict results for the image
    results = model.predict(source=image_path, conf=0.05)

    # Load the image using PIL
    image = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(image)

    # Process YOLO model predictions
    for r in results:
        found = False
        for box in r.boxes.data:
            # Extract bounding box and class information
            x1, y1, x2, y2, confidence, class_id = box.tolist()
            class_name = model.names[
                int(class_id)
            ]  # Get class name using model's class names
            if class_id == 1:
                print(f"box bounds: {x1}, {x2}")
                pred_onsets[i] = x1
                pred_offsets[i] = x2
                found = True
            else:
                continue

            # Draw the bounding box
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            draw.rectangle([y_pred[i][0], 2, y_pred[i][1], 98], outline="green", width=3)

            # Create a label
            label = f"{class_name} ({confidence:.2f})"

            # Draw label inside the bounding box
            text_bbox = draw.textbbox((x1, y1), label, font=font)
            text_width = text_bbox[2] - text_bbox[0]
            text_height = text_bbox[3] - text_bbox[1]

            # Position text inside the bounding box, adjusted to fit
            label_x = max(x1, 0) + 2
            label_y = max(y1, 0) + 2

            # Draw label background and text
            draw.rectangle(
                [label_x, label_y, label_x + text_width, label_y + text_height],
                fill="red",
            )
            draw.text((label_x, label_y), label, fill="white", font=font)
        if not found:
            pred_onsets[i] = np.nan
            pred_offsets[i] = np.nan

    # Add a black border around the image
    border_size = 5
    bordered_image = Image.new(
        "RGB",
        (image.width + 2 * border_size, image.height + 2 * border_size),
        color="black",
    )
    bordered_image.paste(image, (border_size, border_size))

    # Ensure the image is resized to 910x100
    resized_image = bordered_image.resize((910, 100))  # Natural resolution
    # resized_image.show(f"Image {i}")
    annotated_images.append(resized_image)




image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/signal_images/sig_16.png: 64x416 4 non-bursts, 1 burst, 31.6ms
Speed: 3.6ms preprocess, 31.6ms inference, 11.0ms postprocess per image at shape (1, 3, 64, 416)
box bounds: 69.64598846435547, 351.0574951171875

image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/signal_images/sig_17.png: 64x416 4 non-bursts, 1 burst, 10.9ms
Speed: 0.4ms preprocess, 10.9ms inference, 0.3ms postprocess per image at shape (1, 3, 64, 416)
box bounds: 412.0537414550781, 709.7371826171875

image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/signal_images/sig_15.png: 64x416 2 non-bursts, 1 burst, 9.7ms
Speed: 0.2ms preprocess, 9.7ms inference, 0.4ms postprocess per image at shape (1, 3, 64, 416)
box bounds: 277.7815246582031, 554.9505615234375

image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/signal_images/sig_29.png: 64x416 4 non-bursts, 3 bursts, 8.8ms
Speed: 0.2ms preprocess, 8.8ms inference, 0.3ms postprocess per im

In [16]:
# Determine collage dimensions
collage_width = 910  # Each image's width
collage_images_per_row = 3  # Number of images per row
collage_rows = ceil(len(annotated_images) / collage_images_per_row)
collage_height = collage_rows * 100  # 100 pixels per image height

# Create the blank collage canvas
collage = Image.new(
    "RGB", (collage_width * collage_images_per_row, collage_height), color="white"
)

# Paste each image into the collage
for i, annotated_image in enumerate(annotated_images):
    row = i // collage_images_per_row
    col = i % collage_images_per_row
    x_offset = col * 910
    y_offset = row * 100
    collage.paste(annotated_image, (x_offset, y_offset))

# Save the collage
collage.save(output_collage)
print(f"Collage saved to {output_collage}")
# collage.show()

print("onsets", onsets)
print("predicted onsets", onsets)
print("offsets", offsets)
print("predicted offsets", offsets)

valid_indices = np.where(~np.isnan(pred_onsets))
valid_indices2 = np.where(~np.isnan(y_pred[:,0]))

valid_indices = np.intersect1d(valid_indices, valid_indices2)

diff_onsets = y_pred[valid_indices,0] - pred_onsets[valid_indices]
diff_offsets = y_pred[valid_indices,1] - pred_offsets[valid_indices]
# print("diff onsets", diff_onsets)

avg_diff_onsets = np.mean(diff_onsets)
avg_diff_offsets = np.mean(diff_offsets)

print("average onset error, not counting missed onsets", avg_diff_onsets)
print("average offset error, not counting missed onsets", avg_diff_offsets)

find_rate = np.count_nonzero(~np.isnan(pred_onsets)) / num_real_sigs
print("find rate", find_rate)

# for i in range(len(onsets)):
#     print(f"real vs predicted onset for signal {center_onsets[i]} vs {pred_onsets[i]}")
#     print(f"real vs predicted offset for signal {center_offsets[i]} vs {pred_offsets[i]}")
#     print(f"diff between onsets for signal {i}: {center_onsets[i] - pred_onsets[i]}")
#     print(f"diff between offsets for signal {i}: {center_offsets[i] - pred_offsets[i]}")
#     print() # newline

# print("average onset error, not counting missed onsets", np.mean(onsets[onsets > 0] - onsets[onsets > 0]))


Collage saved to collage_dualthresh_vs_yolo.png
onsets [[          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0        