 Imports

In [47]:
from statistics import LinearRegression
import json
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from math import ceil
from scipy import stats

import pandas as pd
import sklearn
from sklearn.svm import SVR
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn import linear_model

import os
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
import random

import specparam

from neurodsp.spectral import compute_spectrum
from neurodsp.burst import detect_bursts_dual_threshold, compute_burst_stats
from neurodsp.plts.time_series import plot_time_series, plot_bursts
import warnings

from bycycle.features import compute_features
from bycycle.plts import plot_burst_detect_summary
from bycycle import BycycleGroup


from neurodsp.filt import filter_signal

import itertools

 analysis functions

In [48]:
# the simulated signals were presented in randomized order to each human labeler. The ordering was recorded as a list of indices.
# the selections the labelers made were recorded in the order they were made.
# This function finds the index at which the desired signal index appears in the ordering of signal indices.
def reverse_order_search(j, order):
    for i in range(len(order)):
        if j == order[i]:
            return i


# the linear coefficients come from a linear regression model searching for the best fit line between
# signal-noise ratio and _ratio parameter used to generate signals
def ratio_to_snr_converter(_ratio):
    snr = -19.65 * (_ratio) + 9.668
    return snr


# this decodes ONE entry in the list of params used to generate one signal.
# the function returns a dictionary with named parameters for visual inspection
# this function is only used when investigating the signal with the worst f1 score.
def decode_params(param_list) -> dict | None:
    if len(param_list) != 5:
        return None
    retDict = {}
    retDict["freq"] = 2 * param_list[0]
    retDict["n_cycles"] = param_list[1]
    retDict["rise-decay asymmetry"] = param_list[2]
    retDict["aperiodic exponent"] = param_list[3]
    retDict["signal-noise ratio"] = ratio_to_snr_converter(param_list[4])
    return retDict


# this decodes ONE entry in the list of params used to generate one signal.
# the function returns a numpy array with the same parameters as decode_params.
# this function is used when preparing data for regression analysis.
# It differs from decode_params in that it returns a numpy array instead of a dictionary.
def decode_params_np(param_list) -> dict | None:
    if len(param_list) != 5:
        return None
    retArray = np.zeros(5)
    retArray[0] = 2 * param_list[0]
    retArray[1] = param_list[1]
    retArray[2] = param_list[2]
    retArray[3] = param_list[3]
    retArray[4] = ratio_to_snr_converter(param_list[4])
    return retArray


# this function takes a list of parameters and returns a dictionary with named parameters.
# the function is used when preparing data for regression analysis.
# this function calls decode_params to ensure frequency and snr are accurate.
def param_list_to_training_data(param_list):
    num_samples = len(param_list)
    num_features = 5
    retArray = np.zeros((num_samples, num_features))
    for i in range(len(param_list)):
        row = decode_params_np(param_list[i])
        retArray[i][0:num_features] = row[:num_features]
    return retArray


def create_signal_images(signal_data, output_directory):
    """
    Save signal data as cropped images with specific requirements.

    Parameters:
        signal_data (list or array-like): A list of signals, where each signal is an array of amplitude values.
        output_directory (str): Directory where the images will be saved.
    """
    dpi = 100
    figsize_width = 1000.0 / float(dpi)
    figsize_height = 1.0

    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    else:
        print(f"Directory {output_directory} already exists. Skipping image creation.")
        return

    for i, signal in enumerate(signal_data):
        if i % 100 == 0:
            print(i)

        # Normalize signal only if its full y-axis isn't in range [-3, 3]
        signal_min, signal_max = np.min(signal), np.max(signal)
        if signal_min < -3 or signal_max > 3:
            signal = (signal - signal_min) / (
                signal_max - signal_min
            ) * 5.8 - 2.9  # Normalize to [-3, 3]

        filename = f"sig_{i}.png"
        filepath = os.path.join(output_directory, filename)

        # Create the plot
        fig = plt.figure(figsize=(figsize_width, figsize_height), dpi=dpi)
        plt.ylim(-3, 3)

        # Remove axes and internal padding
        plt.gca().set_axis_off()
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)

        # Plot the signal
        plt.plot(signal)
        plt.savefig(
            filepath,
            bbox_inches="tight",  # Crop tightly to the plot content
            pad_inches=0,  # Remove any padding
            transparent=False,  # Optional: Save with a transparent background
        )
        plt.close(fig)

        # Crop the image
        img = Image.open(filepath)
        box = (45, 0, 955, 100)  # Define the cropping box
        img = img.crop(box)
        img.save(filepath)

        print(f"Saved cropped signal image to: {filepath}")

def num_bursting_intervals(is_burst) -> int:
    bursting_interval_count = 0
    if is_burst[0]:
        bursting_interval_count=1
    siglen = len(is_burst)
    for i in range(0, siglen-1):
        val_at_i = is_burst[i]
        val_at_i_plus_1 = is_burst[i+1]
        if (not val_at_i) and val_at_i_plus_1:
            bursting_interval_count += 1

    return bursting_interval_count

# accommodates multiple intervals right now. can be optimized (a tiny bit) by restricting one interval
def get_bursting_intervals(is_burst):
    burst_starts = np.where(np.diff(is_burst.astype(int)) == 1)[0] + 1
    burst_ends = np.where(np.diff(is_burst.astype(int)) == -1)[0] + 1

    burst_intervals = list(zip(burst_starts, burst_ends))

    retVal = np.zeros((len(burst_intervals), 2))
    for i in range(len(burst_intervals)):
        retVal[i][0] = burst_intervals[i][0]
        retVal[i][1] = burst_intervals[i][1]
    # print(burst_intervals)
    return retVal

def is_burst_to_window_bounds(is_burst):
    retVal=[]
    curr_pair = [-1,-1]
    pair_active=False
    for i in range(len(is_burst)):
        if not pair_active and is_burst[i]:
            pair_active=True
            curr_pair[0]=i
        
        elif pair_active and not is_burst[i]:
            curr_pair[1]=i
            retVal.append(curr_pair)
            pair_active=False
    
    if pair_active:
        curr_pair[1]=len(is_burst)
        retVal.append(curr_pair)
        
    return retVal

def merge_burst_selections(detections_for_signal):
    """
    (Mutates Parameters): Process YOLO burst detections to extract and average onsets and offsets.
    for each burst, the function calculates the average onset and offset times.

    Parameters:
        burst_detections list with length i: A list of (onset,offset) pairs
    Returns:
        Nothing
    """

    # Sort intervals by start time
    detections_for_signal.sort(key=lambda x: x[0])

    if len(detections_for_signal)==0:
        return []
    # naive solution, O(n) runtime. Good.
    combine_table = np.full(len(detections_for_signal) - 1, False)
    len_table = len(combine_table)
    for i in range(len_table):
        last_start = detections_for_signal[i][0]
        last_end = detections_for_signal[i][1]
        next_start = detections_for_signal[i + 1][0]
        next_end = detections_for_signal[i+1][1]
        if next_start <= last_end or last_start <= next_end:
            combine_table[i] = True

    for i in range(len_table):
        if combine_table[len_table - i - 1]:
            detections_for_signal[len_table - i-1][0] = min(
                detections_for_signal[len_table - i - 1][0],
                detections_for_signal[len_table - i][0],
            )
            detections_for_signal[len_table - i-1][1] = max(
                detections_for_signal[len_table - i - 1][1],
                detections_for_signal[len_table - i][1],
            )
            # we want to do longest common interval, for each interval in the set. So we want to take each interval out of the bag at some point.
            detections_for_signal.pop(len_table - i)





 Import File Data

In [49]:

# Load data from results json exported from firebase
with open("./voyteklabstudy-default-rtdb-export.json") as f:
    results = json.load(f)


 Set constants

In [50]:


# this is the number of real recorded EEG signal we used in the study platform we hosted. The
# signals were arranged in (real signals, simulated signals) order. Thus, num_real_sigs is used
# as an array offset in this analysis.
num_real_sigs = 49



fs=1000

# List of names of human collaborators who labeled data
# length of which gives us number of labelings. Allows us to iterate through labelers
who = list(results["selections"].keys())
print(who)

# the classes are "non-bursting" and "bursting"
num_classes = 2


onsets = np.zeros((num_real_sigs, len(who)))
offsets = np.zeros((num_real_sigs, len(who)))


['Andrew Bender@1714089263343', 'Bradley Voytek Apr22 2024@1713819121072', 'Dillan@1713909205994', 'Eena Kosik@1713821677039', 'MJ@1714513556139', 'Quirine@1714514427397', 'Ryan Hammonds@1713819289745', 'Sydney Smith@1714416232441', 'rgao@1715689500559']


This is the analysis on the real signals.

In [51]:
# Setup process for dual threshold burst detection
y_pred = [[]]*num_real_sigs
num_pred = 0
for i in range(num_real_sigs):
    # Here we have code to execute the burst labeling.
    test_signal = results["sigs"]["sig_" + str(i)]
    test_signal = np.array(test_signal)

    freqs, power_spectral_density = compute_spectrum(fs=fs, sig=test_signal)
    sm = specparam.SpectralModel(peak_width_limits=[1.0, 8.0], max_n_peaks=8)
    sm.fit(freqs, power_spectrum=power_spectral_density)
    [center_frequency, log_power, bandwidth] = specparam.analysis.get_band_peak(
        sm, [10, 20], select_highest=True
    )
    print(center_frequency)

    is_burst = detect_bursts_dual_threshold(
        sig=np.array(test_signal), fs=fs, f_range=(10, 20), dual_thresh=(1, 2)
    )
    # plot_bursts(np.linspace(0, 1, 1000), test_signal, is_burst)

    intervals = is_burst_to_window_bounds(is_burst)

    y_pred[i]=intervals
    
for i in range(len(y_pred)):
    for j in range(len(y_pred[i])):
        for k in range(len(y_pred[i][j])):
            y_pred[i][j][k]*=.91



	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommend a lower bound of approximately 2x the frequency resolution.

nan


	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommend a lower bound of approximately 2x the frequency resolution.

nan


	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommend a lower bound of approximately 2x the frequency resolution.

19.518556702910356


	Lower bounds below frequency-resolution have no effect (effective lower bound is the frequency resolution).
	Too low a limit may lead to overfitting noise as small bandwidth peaks.
	We recommen

In [52]:
#Setup process for YOLO


# Here we want to generate an image dataset from the signals.
test_dir = "signal_images"
collection_real_sigs = [results["sigs"]["sig_" + str(i)] for i in range(num_real_sigs)]
for i in range(len(collection_real_sigs)):
    collection_real_sigs[i] = filter_signal(np.array(collection_real_sigs[i]), fs, 'lowpass', 30, n_seconds=.2, remove_edges=False)

create_signal_images(collection_real_sigs, test_dir)

# now we want to predict the onset and offset of the signals with yolo.

output_collage = "collage_dualthresh_vs_yolo.png"

# Load the model
model = YOLO(
    "/Users/kenton/HOME/coding/python/publish_the_paper/runs/detect/train50/weights/best.pt"
    # "./best.pt"
)

# Get all image files in the directory
all_images = [
    os.path.join(test_dir, f)
    for f in os.listdir(test_dir)
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
]

# Optional: Load a font for better text rendering
try:
    font = ImageFont.truetype(
        "arial.ttf", size=16
    )  # Use a font installed on your system
except IOError:
    font = ImageFont.load_default()


0
Saved cropped signal image to: signal_images/sig_0.png
Saved cropped signal image to: signal_images/sig_1.png
Saved cropped signal image to: signal_images/sig_2.png
Saved cropped signal image to: signal_images/sig_3.png
Saved cropped signal image to: signal_images/sig_4.png
Saved cropped signal image to: signal_images/sig_5.png
Saved cropped signal image to: signal_images/sig_6.png
Saved cropped signal image to: signal_images/sig_7.png
Saved cropped signal image to: signal_images/sig_8.png
Saved cropped signal image to: signal_images/sig_9.png
Saved cropped signal image to: signal_images/sig_10.png
Saved cropped signal image to: signal_images/sig_11.png
Saved cropped signal image to: signal_images/sig_12.png
Saved cropped signal image to: signal_images/sig_13.png
Saved cropped signal image to: signal_images/sig_14.png
Saved cropped signal image to: signal_images/sig_15.png
Saved cropped signal image to: signal_images/sig_16.png
Saved cropped signal image to: signal_images/sig_17.png


In [53]:
#Bycycle
warnings.filterwarnings("ignore", category=FutureWarning)
# Frequency band of interest
f_alpha = (5, 80)

# Tuned burst detection parameters
thresholds = {
    'amp_fraction': .2,
    'amp_consistency': .5,
    'period_consistency': .5,
    'monotonicity': .9,
    'min_n_cycles': 2
}
# convert results["sigs"][string_key] to numpy arrays
np_sigs = np.zeros((num_real_sigs, 1000))
for i in range(num_real_sigs):
    np_sigs[i] = np.array(results["sigs"]["sig_"+str(i)])

# # Apply lowpass filter to each signal
# for idx in range(len(np_sigs)):
#     np_sigs[idx] = filter_signal(np_sigs[idx], fs, 'lowpass', 30, n_seconds=.2, remove_edges=False)

# Compute features for each signal
bg = BycycleGroup(thresholds=thresholds)
bg.fit(np_sigs, 500, f_alpha)

# Recompute cycles on edges of bursts with reduced thresholds
bg.recompute_edges(.01)

# Add group and subject ids to dataframes
groups = ['patient' if idx >= int(num_real_sigs/2) else 'control' for idx in range(num_real_sigs)]
subject_ids = [idx for idx in range(num_real_sigs)]

for idx, group in enumerate(groups):
    bg.df_features[idx]['group'] = group
    bg.df_features[idx]['subject_id'] = subject_ids[idx]

# Concatenate the list of dataframes
df_features = pd.concat(bg.df_features)

#get burst bool array

df_features.head()
df_features.to_html("df_features.html")

# bg[0].plot(xlim=(0, 10), figsize=(16, 3))

# for i in range(num_real_sigs):
#     if i%4 == 0:
#         is_burst = bg.df_features[i]['is_burst']
#         burst_intervals = get_bursting_intervals(is_burst)  
#         # make image with bounding boxes for each burst
#         # Load the image using PIL
#         image = Image.open(all_images[i]).convert("RGB")
#         print(image.size)
#         draw = ImageDraw.Draw(image)
#         for j in range(len(is_burst)):
#             if is_burst[j]:  # If the point is part of a burst
#                 draw.line([i*20, 10, i*20, 90], fill="green", width=3)  # Draw a green line for bursts
#         # for i in range(len(pred_onsets)):
#         draw.rectangle([pred_onsets[i], 2, pred_offsets[i], 98], outline="red", width=3)
        
#         image.show()

#         print(burst_intervals)

In [54]:

# Annotate each selected image and record the predicted onsets and offsets
pred_onsets = np.zeros(num_real_sigs)
pred_offsets = np.zeros(num_real_sigs)
annotated_images = []
for i in range(len(all_images)):
    image_path = all_images[i]
    # Predict results for the image
    prediction_results = model.predict(source=image_path, conf=0.2)

    # Load the image using PIL
    image = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(image)

    # Process YOLO model predictions
    yolo_intervals = []
    for r in prediction_results:
        found = False

        for box in r.boxes.data:
            # Extract bounding box and class information
            x1, y1, x2, y2, confidence, class_id = box.tolist()
            class_name = model.names[
                int(class_id)
            ]  # Get class name using model's class names
            if class_id == 1:
                print(f"box bounds: {x1}, {x2}")
                if x1<0:
                    print(f"here {x1} {i}")
                if x2 < 0:
                    print(f"here {x2} {i}")
                pred_onsets[i] = x1
                pred_offsets[i] = x2
                yolo_intervals.append([x1,x2])
                found = True
            else:
                continue

    merge_burst_selections(yolo_intervals)

    # display yolo detections
    print(f"signal {i}")
    for j in range(len(yolo_intervals)):
        print(yolo_intervals[j])
        draw.rectangle(
            [yolo_intervals[j][0], y1, yolo_intervals[j][1], y2],
            outline="red",
            width=3,
        )

        # Create a label
        label = f"{class_name} ({confidence:.2f})"

        # Draw label inside the bounding box
        text_bbox = draw.textbbox((yolo_intervals[j][0], y1), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Position text inside the bounding box, adjusted to fit
        label_x = yolo_intervals[j][0]
        label_y = max(y1, 0) + 2

        # Draw label background and text
        draw.rectangle(
            [label_x, label_y, label_x + text_width, label_y + text_height],
            fill="red",
        )
        draw.text((label_x, label_y), label, fill="white", font=font)

    for j in range(len(y_pred[i])):
        draw.rectangle(
            [y_pred[i][j][0], y1, y_pred[i][j][1], y2],
            outline="orange",
            width=3,
        )

        # Create a label
        label = f"dualthresh selection"

        # Draw label inside the bounding box
        text_bbox = draw.textbbox((y_pred[i][j][0], y1), label, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Position text inside the bounding box, adjusted to fit
        label_x = y_pred[i][j][0]
        label_y = max(y1, 0) + 2

        # Draw label background and text
        draw.rectangle(
            [label_x, label_y, label_x + text_width, label_y + text_height],
            fill="orange",
        )
        draw.text((label_x, label_y), label, fill="white", font=font)


    #get rows from bg.df_features so I can get sample_last_trough and sample_next_trough
    # for each burst
    # Select a specific row (e.g., first row)
    max_index = bg.df_features[i].shape[0]
    for j in range(max_index):
        row = bg.df_features[i].iloc[j]  # Change index as needed

        # Extract the last and next trough sample indices
        last_trough = row["sample_last_trough"] * 0.91
        next_trough = row["sample_next_trough"] * 0.91

        # print(f"Last trough: {last_trough}, Next trough: {next_trough}")

        # if is_burst, draw the box
        # make the lines of the rectangle but only draw one rectangle for adjacent bursts
        is_burst = row["is_burst"]
        if is_burst:

            # draw.rectangle([last_trough, 2, next_trough, 98], outline="blue", width=3)
            draw.line([last_trough, 10, next_trough, 10], fill="blue", width=3)  # Draw a green line for bursts
            draw.line([last_trough, 90, next_trough, 90], fill="blue", width=3)  # Draw a green line for bursts  
            # if the last trough is not a burst, draw the left side of the rectangle
            if  not bg.df_features[i].iloc[j-1]["is_burst"]:
                # label the burst as a bycycle burst
                label = "bycycle burst"
                text_bbox = draw.textbbox((last_trough, 10), label, font=font)
                text_width = text_bbox[2] - text_bbox[0]
                text_height = text_bbox[3] - text_bbox[1]
                label_x = max(last_trough, 0) + 2
                label_y = max(10, 0) + 2
                draw.rectangle(
                    [label_x, label_y, label_x + text_width, label_y + text_height],
                    fill="blue",
                )
                draw.text((label_x, label_y), label, fill="white", font=font)
                draw.line([last_trough, 10, last_trough, 90], fill="blue", width=3)

            # if this is not the last row and next trough is not a burst, draw the right side of the rectangle
            if j < max_index - 1 and not bg.df_features[i].iloc[j+1]["is_burst"]:
                draw.line([next_trough, 10, next_trough, 90], fill="blue", width=3)
            # if this is the last row and this is a burst, draw the right side of the rectangle
            if j == max_index-1 and is_burst:
                draw.line([next_trough, 10, next_trough, 90], fill="blue", width=3)
                

            # draw.rectangle([last_trough, 10, next_trough, 90], outline="blue", width=3)

            # draw.rectangle([last_trough, 2, next_trough, 98], outline="blue", width=3)


        
    # Add a black border around the image
    border_size = 5
    bordered_image = Image.new(
        "RGB",
        (image.width + 2 * border_size, image.height + 2 * border_size),
        color="black",
    )
    bordered_image.paste(image, (border_size, border_size))

    # Ensure the image is resized to 910x100
    resized_image = bordered_image.resize((910, 100))  # Natural resolution
    # resized_image.show(f"Image {i}")
    annotated_images.append(resized_image)




image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/algo_dualthresh/signal_images/sig_16.png: 64x416 3 non-bursts, 1 burst, 25.9ms
Speed: 4.1ms preprocess, 25.9ms inference, 7.3ms postprocess per image at shape (1, 3, 64, 416)
box bounds: 85.19955444335938, 360.5264587402344
signal 0
[85.19955444335938, 360.5264587402344]

image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/algo_dualthresh/signal_images/sig_17.png: 64x416 2 non-bursts, 1 burst, 7.0ms
Speed: 0.3ms preprocess, 7.0ms inference, 0.5ms postprocess per image at shape (1, 3, 64, 416)
box bounds: 456.92901611328125, 784.8284912109375
signal 1
[456.92901611328125, 784.8284912109375]

image 1/1 /Users/kenton/HOME/coding/python/publish_the_paper/algo_dualthresh/signal_images/sig_15.png: 64x416 2 non-bursts, 1 burst, 8.0ms
Speed: 0.3ms preprocess, 8.0ms inference, 0.4ms postprocess per image at shape (1, 3, 64, 416)
box bounds: 263.7947692871094, 510.1513977050781
signal 2
[263.7947692871094, 510.1513977050781]


In [55]:
# Determine collage dimensions
collage_width = 910  # Each image's width
collage_images_per_row = 3  # Number of images per row
collage_rows = ceil(len(annotated_images) / collage_images_per_row)
collage_height = collage_rows * 100  # 100 pixels per image height

# Create the blank collage canvas
collage = Image.new(
    "RGB", (collage_width * collage_images_per_row, collage_height), color="white"
)

# Paste each image into the collage
for i, annotated_image in enumerate(annotated_images):
    row = i // collage_images_per_row
    col = i % collage_images_per_row
    x_offset = col * 910
    y_offset = row * 100
    collage.paste(annotated_image, (x_offset, y_offset))

# Save the collage
collage.save(output_collage)
print(f"Collage saved to {output_collage}")
# collage.show()


Collage saved to collage_dualthresh_vs_yolo.png


In [56]:

print("onsets", onsets)
print("predicted onsets", onsets)
print("offsets", offsets)
print("predicted offsets", offsets)

valid_indices = np.where(~np.isnan(pred_onsets))
valid_indices2 = np.where(~np.isnan(y_pred[:][0][0]))

valid_indices = np.intersect1d(valid_indices, ar2=valid_indices2)

diff_onsets = y_pred[valid_indices][0][0] - pred_onsets[valid_indices]
diff_offsets = y_pred[valid_indices][0][1] - pred_offsets[valid_indices]
# print("diff onsets", diff_onsets)

avg_diff_onsets = np.mean(diff_onsets)
avg_diff_offsets = np.mean(diff_offsets)

print("average onset error, not counting missed onsets", avg_diff_onsets)
print("average offset error, not counting missed onsets", avg_diff_offsets)

# for i in range(len(onsets)):
#     print(f"real vs predicted onset for signal {center_onsets[i]} vs {pred_onsets[i]}")
#     print(f"real vs predicted offset for signal {center_offsets[i]} vs {pred_offsets[i]}")
#     print(f"diff between onsets for signal {i}: {center_onsets[i] - pred_onsets[i]}")
#     print(f"diff between offsets for signal {i}: {center_offsets[i] - pred_offsets[i]}")
#     print() # newline

# print("average onset error, not counting missed onsets", np.mean(onsets[onsets > 0] - onsets[onsets > 0]))


onsets [[          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0        

TypeError: only integer scalar arrays can be converted to a scalar index

In [None]:

print("onsets", onsets)
print("predicted onsets", onsets)
print("offsets", offsets)
print("predicted offsets", offsets)

valid_indices = np.where(~np.isnan(pred_onsets))
valid_indices2 = np.where(~np.isnan(y_pred[:,0]))

valid_indices = np.intersect1d(valid_indices, ar2=valid_indices2)

diff_onsets = y_pred[valid_indices,0] - pred_onsets[valid_indices]
diff_offsets = y_pred[valid_indices,1] - pred_offsets[valid_indices]
# print("diff onsets", diff_onsets)

avg_diff_onsets = np.mean(diff_onsets)
avg_diff_offsets = np.mean(diff_offsets)

print("average onset error, not counting missed onsets", avg_diff_onsets)
print("average offset error, not counting missed onsets", avg_diff_offsets)

find_rate = np.count_nonzero(~np.isnan(pred_onsets)) / float(num_real_sigs)
print("find rate yolo", find_rate)

find_rate_dualthresh = np.count_nonzero(~np.isnan(y_pred[:,0])) / float(num_real_sigs)
print("find rate dualthresh", find_rate_dualthresh)

find_rate_relative_to_detected = np.count_nonzero(~np.isnan(pred_onsets)) / float(np.count_nonzero(~np.isnan(y_pred[:,0])))
print("ratio yolo vs dualthresh", find_rate_relative_to_detected)

# count how many signals contain bursts by bycycle algorithm is_burst per bg.df_features[i]
count = 0
for i in range(num_real_sigs):
    is_burst = bg.df_features[i]['is_burst']
    if np.any(is_burst):
        count += 1

# print number of yolo detections/number of bycycle detections
print("number of yolo detections/number of bycycle detections", find_rate/find_rate_dualthresh)


# for i in range(len(onsets)):
#     print(f"real vs predicted onset for signal {center_onsets[i]} vs {pred_onsets[i]}")
#     print(f"real vs predicted offset for signal {center_offsets[i]} vs {pred_offsets[i]}")
#     print(f"diff between onsets for signal {i}: {center_onsets[i] - pred_onsets[i]}")
#     print(f"diff between offsets for signal {i}: {center_offsets[i] - pred_offsets[i]}")
#     print() # newline

# print("average onset error, not counting missed onsets", np.mean(onsets[onsets > 0] - onsets[onsets > 0]))


onsets [[          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0           0]
 [          0           0           0           0           0           0           0           0        

TypeError: list indices must be integers or slices, not tuple

: 

: 

: 

: 

: 

: 

: 

: 