In [1]:
import numpy as np
import scipy.signal as signal
from scipy.signal import convolve, find_peaks
from scipy.signal import butter, sosfreqz, sosfiltfilt, filtfilt
import scipy.io
import h5py
import os, sys
import pandas as pd
import csv

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.widgets import Slider, Button, Cursor, CheckButtons
from matplotlib.cm import ScalarMappable
import ipywidgets as widgets
from ipywidgets import interactive, Output, IntSlider, FloatRangeSlider, IntRangeSlider, FloatSlider, interact, fixed, interactive_output, VBox, HBox, Dropdown, SelectMultiple, Button, widgets, Label
import tkinter as tk
from tkinter import filedialog, OptionMenu, StringVar
from IPython.display import display, clear_output

%matplotlib qt

In [2]:
# Notebook-wide state filled in by main(): every path loaded so far, the most recently
# loaded array, and that file's path with forward slashes.
dir_path, loaded_data, cleaned_path = [], None, None

def load_mat_file(file_path):
    """Open a MATLAB .mat file with whichever reader understands its format.

    scipy.io.loadmat handles pre-v7.3 files and raises NotImplementedError for v7.3
    (HDF5-based) files, which are then opened read-only with h5py instead.

    Returns
    -------
    (container, is_hdf5)
        ``container`` is the dict from loadmat, or an open ``h5py.File`` that the
        caller is responsible for closing; ``is_hdf5`` tells which one it is.
    """
    try:
        return scipy.io.loadmat(file_path), False
    except NotImplementedError:
        # MATLAB v7.3 files are HDF5 containers.
        return h5py.File(file_path, 'r'), True

def extract_data_from_mat(mat_data, is_hdf5=False):
    """Pick one numeric array out of a loaded .mat container.

    Parameters
    ----------
    mat_data : mapping
        Result of ``load_mat_file``: the dict from ``scipy.io.loadmat`` or an open
        ``h5py.File`` for MATLAB v7.3 files.
    is_hdf5 : bool
        True when ``mat_data`` is an HDF5 file.

    Returns
    -------
    numpy.ndarray
        The selected variable. HDF5 datasets are copied into memory with ``np.array``
        so the file can be closed afterwards.

    Raises
    ------
    ValueError
        If no candidate array exists, or (when several exist) the interactive
        selection is not an integer or is out of range.
    """
    if is_hdf5:
        # Keys starting with '#' (e.g. '#refs#') are HDF5 bookkeeping. Entries without a
        # ``shape`` (HDF5 groups, e.g. MATLAB structs) cannot be read as one array and
        # previously crashed the listing below, so they are skipped.
        data_keys = [key for key in mat_data.keys()
                     if not key.startswith('#') and hasattr(mat_data[key], 'shape')]
    else:
        # loadmat adds metadata entries whose names start with '__'.
        data_keys = [key for key in mat_data.keys() if not key.startswith('__')]

    if not data_keys:
        raise ValueError("No data arrays found in .mat file.")

    if len(data_keys) == 1:
        selected_key = data_keys[0]
    else:
        selected_key = _prompt_for_mat_key(mat_data, data_keys)

    eeg_data = np.array(mat_data[selected_key]) if is_hdf5 else mat_data[selected_key]
    print(f"Loaded data from key: '{selected_key}'")
    return eeg_data


def _prompt_for_mat_key(mat_data, data_keys):
    """List the candidate variables and return the key the user selects (1-based)."""
    print(f"\nMultiple data arrays found in .mat file:")
    for i, key in enumerate(data_keys, 1):
        print(f"  {i}. '{key}' - Shape: {mat_data[key].shape}, Type: {mat_data[key].dtype}")

    selection = int(input("\nSelect the data array by number: ").strip())
    if selection < 1 or selection > len(data_keys):
        raise ValueError("Invalid selection.")
    return data_keys[selection - 1]

def load_eeg_data(file_path, num_channels=None):
    """Load an EEG recording as a (channels x samples) array.

    Supported formats (extension matched case-insensitively):
      * ``.bin`` - raw int16 samples interleaved across channels; ``num_channels``
        is required to de-interleave them.
      * ``.npy`` - a 2-D array is assumed to be stored as (samples x channels) and is
        transposed; a 1-D array is treated as a single channel.
      * ``.mat`` - MATLAB file (pre-v7.3 via scipy, v7.3 via h5py). May prompt the
        user interactively to choose a variable or confirm a transpose.

    Returns the array, or ``None`` after printing the error if anything fails. Errors
    are reported rather than raised so the calling notebook cell keeps running.
    """
    try:
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"The file '{file_path}' does not exist.")

        lower_path = file_path.lower()

        if lower_path.endswith('.bin'):
            if num_channels is None:
                raise ValueError("Number of channels must be specified for .bin files.")
            # Consecutive int16 values belong to consecutive channels of one time step.
            eeg_data = np.fromfile(file_path, dtype=np.int16).reshape((-1, num_channels)).T

        elif lower_path.endswith('.npy'):
            eeg_data = np.load(file_path)
            if eeg_data.ndim == 1:
                # Previously returned 1-D, breaking the (channels x samples) contract.
                eeg_data = eeg_data.reshape(1, -1)
            elif eeg_data.ndim == 2:
                eeg_data = eeg_data.T
            else:
                raise ValueError(f"Unsupported data dimensions: {eeg_data.ndim}D. Expected 1D or 2D array.")

        elif lower_path.endswith('.mat'):
            # Load .mat file (handles both v7.3 and older formats)
            mat_data, is_hdf5 = load_mat_file(file_path)
            try:
                eeg_data = extract_data_from_mat(mat_data, is_hdf5)
            finally:
                # The HDF5 dataset has already been copied into memory, so the file can be closed.
                if is_hdf5:
                    mat_data.close()

            print(f"Data type: {eeg_data.dtype}")
            print(f"Original shape: {eeg_data.shape}")

            # Ensure the data is 2D (channels x samples).
            if eeg_data.ndim == 1:
                eeg_data = eeg_data.reshape(1, -1)
                print("Reshaped 1D data to (1, samples)")
            elif eeg_data.ndim == 2:
                if is_hdf5:
                    # h5py exposes MATLAB arrays with their dimensions reversed.
                    print(f"HDF5 detected. Transposing from {eeg_data.shape} to {eeg_data.T.shape}")
                    eeg_data = eeg_data.T

                # More rows than columns usually means (samples x channels); ask before transposing.
                if eeg_data.shape[0] > eeg_data.shape[1]:
                    print(f"Data shape is {eeg_data.shape}, assuming (samples x channels)")
                    response = input("Transpose to (channels x samples)? (y/n): ").strip().lower()
                    if response == 'y':
                        eeg_data = eeg_data.T
            else:
                raise ValueError(f"Unsupported data dimensions: {eeg_data.ndim}D. Expected 1D or 2D array.")

        else:
            raise ValueError("Unsupported file format. Only .bin, .npy, and .mat are supported.")

        print(f"File '{file_path}' loaded successfully!")
        print(f"Final data shape: {eeg_data.shape} (channels x samples)")
        return eeg_data

    except Exception as e:
        print(f"Error loading file '{file_path}': {e}")
        return None

def main():
    """Prompt for an EEG file, load it, and publish the result through module globals.

    On success ``loaded_data`` holds a float32 (channels x samples) array,
    ``cleaned_path`` the path with forward slashes, and the path is appended to
    ``dir_path``. If loading fails, ``loaded_data`` is left as ``None``.
    """
    global dir_path, loaded_data, cleaned_path

    # Strip whitespace and any quotes left over from drag-and-drop / copy-as-path.
    raw_answer = input("Enter the full path to the EEG file (.bin, .npy, or .mat): ")
    file_path = raw_answer.strip().strip("'\"")

    # Raw .bin files carry no header, so the channel count has to come from the user.
    num_channels = None
    if file_path.endswith('.bin'):
        try:
            num_channels = int(input("Enter the number of channels (required for .bin files): ").strip())
        except ValueError:
            print("Invalid input. Number of channels must be an integer.")
            return

    loaded_data = load_eeg_data(file_path, num_channels)
    if loaded_data is None:
        print("No data loaded.")
        return

    loaded_data = loaded_data.astype(np.float32)
    cleaned_path = file_path.replace("\\", "/")
    dir_path.append(cleaned_path)

    print(f"\nLoaded data shape: {loaded_data.shape}")
    print(f"Data type: {loaded_data.dtype}")
    print(f"Cleaned file path: {cleaned_path}")

if __name__ == "__main__":
    main()

Enter the full path to the EEG file (.bin, .npy, or .mat):  '/Volumes/kovi ssd/egyetem/7.felev-szakdoga/THUMB/patients/epi52_sleep/epi52_01_sleep_2_thumb_1.mat'


Loaded data from key: 'lfpdata'
Data type: float64
Original shape: (30000200, 23)
HDF5 detected. Transposing from (30000200, 23) to (23, 30000200)
File '/Volumes/kovi ssd/egyetem/7.felev-szakdoga/THUMB/patients/epi52_sleep/epi52_01_sleep_2_thumb_1.mat' loaded successfully!
Final data shape: (23, 30000200) (channels x samples)

Loaded data shape: (23, 30000200)
Data type: float32
Cleaned file path: /Volumes/kovi ssd/egyetem/7.felev-szakdoga/THUMB/patients/epi52_sleep/epi52_01_sleep_2_thumb_1.mat


In [3]:
def remove_outliers(samples, curation_needed=True, threshold_factor=1.5,
                    lower_percentile=1, upper_percentile=99, pad_samples=400):
    """Zero out extreme-amplitude stretches of each channel.

    ``samples`` is indexed as (samples x channels): thresholds are computed per column
    (``axis=0``) and each column is curated independently. Callers holding
    (channels x samples) data must pass the transpose.

    For each channel, with ``P_lo`` / ``P_hi`` the ``lower_percentile`` /
    ``upper_percentile`` percentiles (1st / 99th by default, a wider spread than a
    quartile-based IQR) and ``spread = P_hi - P_lo``, a sample is an outlier if it lies
    below ``P_lo - threshold_factor * spread`` or above ``P_hi + threshold_factor * spread``.
    Each outlier and ``pad_samples`` samples on either side are set to 0 in a copy.

    Returns
    -------
    curated_data : ndarray
        Copy of ``samples`` with the flagged stretches zeroed (``samples`` itself when
        ``curation_needed`` is False).
    outlier_indices : ndarray
        Sorted, unique row indices flagged in at least one channel (empty if none).
    """
    if not curation_needed:
        return samples, np.array([])  # No curation needed, return the original data and empty outlier indices

    lower_ref = np.percentile(samples, lower_percentile, axis=0)
    upper_ref = np.percentile(samples, upper_percentile, axis=0)
    spread = upper_ref - lower_ref
    lower_thresholds = lower_ref - threshold_factor * spread
    upper_thresholds = upper_ref + threshold_factor * spread

    n_samples = samples.shape[0]
    curated_data = np.copy(samples)
    per_channel_outliers = []

    for channel in range(samples.shape[1]):
        column = samples[:, channel]
        channel_indices = np.flatnonzero((column < lower_thresholds[channel]) |
                                         (column > upper_thresholds[channel]))
        curated_data[_padded_mask(channel_indices, n_samples, pad_samples), channel] = 0
        per_channel_outliers.append(channel_indices)

    if per_channel_outliers:
        outlier_indices = np.unique(np.concatenate(per_channel_outliers))
    else:
        outlier_indices = np.array([])

    return curated_data, outlier_indices


def _padded_mask(indices, length, pad):
    """Boolean mask of ``length`` that is True within ``pad`` samples of any index in ``indices``."""
    # Difference array: +1 where a padded window starts, -1 just past where it ends;
    # a positive running sum means "inside at least one window". Linear in length.
    edges = np.zeros(length + 1, dtype=np.int64)
    if len(indices):
        starts = np.maximum(indices - pad, 0)
        stops = np.minimum(indices + pad, length - 1) + 1
        np.add.at(edges, starts, 1)
        np.add.at(edges, stops, -1)
    return np.cumsum(edges[:length]) > 0

# Ask the user if they want to perform curation.
# remove_outliers() treats axis 0 as samples and axis 1 as channels, while loaded_data is
# (channels x samples). Passing it untransposed made the function compute percentiles
# across channels and loop over every sample as if it were a channel. So transpose on
# the way in and back on the way out.
curation_decision = input("Do you want to perform outlier curation? (y/n): ").strip().lower()

if curation_decision == 'y':
    curated_samples, outlier_indices = remove_outliers(loaded_data.T)
    loaded_data = curated_samples.T
else:
    outlier_indices = np.array([])

# Define filtering functions
def apply_hamming_filter(input_data, sampling_rate, channels_to_exclude=None, double=False, axis=0):
    """Smooth data with the 3-tap kernel [0.23, 0.54, 0.23] (sums to 1).

    Parameters
    ----------
    input_data : ndarray
        Data to smooth. Convolution runs along ``axis`` (default 0, i.e. the data is
        (samples x channels)); pass ``axis=1`` for (channels x samples) data.
    sampling_rate : float
        Accepted for interface compatibility; not used by the computation.
    channels_to_exclude : int or sequence of int, optional
        For 2-D input, channels (indices along the other axis) copied through unfiltered.
    double : bool
        Apply the kernel a second time.

    Returns
    -------
    ndarray
        Smoothed copy, same shape as ``input_data``, edges handled by ``mode='same'``.

    Raises
    ------
    ValueError
        If the data is shorter than the kernel along ``axis``, or channel exclusion
        is requested for non-2-D data.
    """
    center_coefficient = 0.54
    side_coefficient = 0.23
    filter_kernel = np.array([side_coefficient, center_coefficient, side_coefficient])

    data = np.asarray(input_data)

    # Ensure the data is longer than the kernel along the filtered axis.
    if data.shape[axis] < len(filter_kernel):
        raise ValueError("Input data must be longer than the kernel size")

    def _smooth(arr):
        return np.apply_along_axis(
            lambda x: np.convolve(x, filter_kernel, mode='same'),
            axis=axis,
            arr=arr
        )

    hamming_data = _smooth(data)
    if double:
        hamming_data = _smooth(hamming_data)

    if channels_to_exclude is not None:
        # atleast_1d accepts a scalar channel; the old truthiness test skipped channel 0
        # and raised for numpy arrays.
        excluded = np.atleast_1d(channels_to_exclude)
        if excluded.size:
            if data.ndim != 2:
                raise ValueError("channels_to_exclude requires 2-D input data")
            index = [slice(None), slice(None)]
            index[1 - (axis % 2)] = excluded  # channels live on the non-filtered axis
            hamming_data[tuple(index)] = data[tuple(index)]

    return hamming_data

def butter_bandpass(lowcut, highcut, fs, order=4):
    """Design a Butterworth filter in second-order-sections form.

    Band-pass when both cut-offs are given, low-pass with only ``highcut``, high-pass
    with only ``lowcut``. Cut-offs are in Hz and normalised by the Nyquist rate ``fs/2``.

    Raises
    ------
    ValueError
        If both cut-offs are None (or if scipy rejects the normalised frequencies).
    """
    half_rate = fs * 0.5
    has_low = lowcut is not None
    has_high = highcut is not None

    if not (has_low or has_high):
        raise ValueError("At least one of lowcut or highcut must be provided.")
    if has_low and has_high:
        return butter(order, [lowcut / half_rate, highcut / half_rate], btype='band', output='sos')
    if has_high:
        return butter(order, highcut / half_rate, btype='low', output='sos')
    return butter(order, lowcut / half_rate, btype='high', output='sos')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Zero-phase Butterworth filtering of each row of ``data``.

    Parameters
    ----------
    data : ndarray
        (channels x samples) array; a 1-D array is treated as a single channel and
        a 1-D result is returned.
    lowcut, highcut : float or None
        Cut-offs in Hz; see ``butter_bandpass`` for which filter type results.
    fs : float
        Sampling rate of ``data`` in Hz.
    order : int
        Butterworth order.

    Returns
    -------
    ndarray
        Filtered data with the same shape. Floating input keeps its dtype; integer
        input is promoted to floating point instead of being truncated.
    """
    # The design depends only on the cut-offs, rate and order, so build it once.
    sos = butter_bandpass(lowcut, highcut, fs, order=order)

    data = np.asarray(data)
    is_1d = data.ndim == 1
    rows = np.atleast_2d(data)

    # Writing into zeros_like(int_data) used to truncate the filtered signal to integers.
    out_dtype = np.result_type(rows.dtype, np.float32)
    filtered_data = np.empty(rows.shape, dtype=out_dtype)

    # Row by row keeps the float64 working set of sosfiltfilt to one channel at a time.
    for i in range(rows.shape[0]):
        filtered_data[i] = sosfiltfilt(sos, rows[i])

    return filtered_data[0] if is_1d else filtered_data

def preprocess_data(raw_data, fs, lowcut=1, highcut=30):
    """Band-pass filter, then 3-tap smooth, a (channels x samples) array sampled at ``fs`` Hz.

    Parameters
    ----------
    raw_data : ndarray
        (channels x samples) data.
    fs : float
        Sampling rate of ``raw_data``. Used for both steps; previously the smoothing
        step read the global ``fs_resampled`` instead.
    lowcut, highcut : float or None
        Band-pass cut-offs in Hz (default 1-30 Hz, as before).

    Returns
    -------
    ndarray
        Preprocessed (channels x samples) data.
    """
    bandpassed_data = butter_bandpass_filter(raw_data, lowcut, highcut, fs)

    # apply_hamming_filter convolves along axis 0. Given (channels x samples) it smoothed
    # across channels instead of over time, so hand it (samples x channels) and
    # transpose the result back.
    preprocessed_data = apply_hamming_filter(bandpassed_data.T, sampling_rate=fs, double=False).T

    return preprocessed_data

def compute_csd(data, spacing=1.0, data_is_first_derivative=False):
    """Second spatial derivative across channels (axis 0) of a (channels x time) array.

    Parameters
    ----------
    data : array_like
        Potentials per channel, or, when ``data_is_first_derivative`` is True, values
        that are already the first spatial derivative (gradient) of the potential.
    spacing : float
        Distance between adjacent channels.
    data_is_first_derivative : bool
        If True, one more finite difference (divided by ``spacing`` once) completes the
        second derivative. The previous version divided by ``spacing`` a second time.

    Returns
    -------
    ndarray (float64)
        Same shape as ``data``. Edge channels copy their nearest computed neighbour;
        all zeros when there are too few channels (fewer than 2 for gradient input,
        fewer than 3 for potentials).

    NOTE(review): this is the un-negated, unscaled second derivative; the usual CSD
    estimate is ``-sigma * d2V/dz2``.
    """
    d = np.asarray(data, dtype=np.float64)
    csd = np.zeros_like(d)

    if d.size == 0:
        return csd

    if data_is_first_derivative:
        if d.shape[0] < 2:
            return csd
        # d(gradient)/dz: a single division by the spacing.
        csd[1:] = np.diff(d, axis=0) / spacing
        csd[0] = csd[1]
    else:
        if d.shape[0] < 3:
            return csd
        # Central second difference: (V[i-1] - 2 V[i] + V[i+1]) / h^2
        csd[1:-1] = (d[:-2] - 2.0 * d[1:-1] + d[2:]) / (spacing ** 2)
        csd[0] = csd[1]
        csd[-1] = csd[-2]

    return csd

window_duration = 10  # visible window length in seconds

t = loaded_data.shape[1]  # number of raw samples per channel (a sample count, not seconds)
num_channels = loaded_data.shape[0]
fs = 20000 #recording freq
fs_resampled = 2000
fs_resampled_high = 5000
# Downsampling by plain slicing (keep every 4th / 10th sample). The ratios assume
# fs = 20000 Hz (20000/4 = 5000, 20000/10 = 2000). No anti-aliasing filter is applied
# first, so any content above the new Nyquist frequency (2.5 kHz / 1 kHz) folds back
# into the retained band.
resampled_data_high = loaded_data[:, ::4]
resampled_data = loaded_data[:, ::10]
# Band-passed (1-30 Hz) and smoothed copy of the 2 kHz stream.
# NOTE(review): resamp_data is not referenced by the other cells shown here.
resamp_data = preprocess_data(resampled_data, fs_resampled)
plot_height_inches = min(6 * num_channels, 9)  
plot_width_inches = 14  
figsize = (plot_width_inches, plot_height_inches)
INT16_TO_UV = 1 #0.30518 if in int16 change to this value
# Global amplitude range of the (possibly curated) full-rate data.
max_data = np.max(loaded_data)
min_data = np.min(loaded_data)
span = max_data - min_data

Do you want to perform outlier curation? (y/n):  n


In [4]:
# Interactive-viewer state shared by the callbacks in this cell.
start_time = 0            # left edge of the visible window, in seconds
filter_settings = {}      # saved filter settings keyed by figure number or snapshot path
snapshot_counter = 1      # numeric suffix of the next saved SVG snapshot
scale_amplitude = 1.0     # amplitude multiplier changed with the up/down keys
use_monochrome = True  # Toggle for Monochrome/Gradient
window_duration_str = '10'
# Noisy-interval marking: finished (start, end) pairs in seconds, the pair being
# clicked right now, and whether marking mode is on.
noisy_intervals, current_interval = [], []
is_noisy_mode = False
noisy_mode_checkbox = widgets.Checkbox(value=False, description='Noisy Data Mode')
noisy_intervals_label = widgets.Label(value="Noisy Intervals: []")

def calculate_starting_amplitude(data, multiplier=0.5):
    """Scale factor that maps the largest absolute value in ``data`` to ``multiplier``.

    Returns ``multiplier`` itself (i.e. treats the peak as 1) when ``data`` is empty,
    all zeros, or has no finite peak. Previously these cases divided by zero or raised.
    """
    values = np.asarray(data)
    if values.size == 0:
        return multiplier
    peak = np.nanmax(np.abs(values))
    if not np.isfinite(peak) or peak == 0:
        return multiplier
    return (1.0 / peak) * multiplier

def toggle_monochrome(event):
    """Button callback: flip monochrome/gradient colouring and redraw the current figure.

    ``event`` is the clicked button (unused). The redraw uses the main control widgets
    and the current global ``scale_amplitude``.
    """
    global use_monochrome
    use_monochrome = not use_monochrome

    plt.clf()
    current_options = dict(
        apply_hamming=apply_hamming_checkbox.value,
        apply_double_hamming=apply_double_hamming_checkbox.value,
        apply_bandpass=apply_bandpass_checkbox.value,
        lowcut=lowcut_text.value,
        highcut=highcut_text.value,
        scale_amplitude=scale_amplitude,
    )
    visualize_eeg_data(window_duration_text.value, hidden_channels_text.value, **current_options)
    plt.draw()

def toggle_noisy_mode(change):
    """Observer for ``noisy_mode_checkbox``: ``change.new`` is the checkbox's new value.

    When marking mode is switched off, a completed interval is stored and a
    half-marked one (start clicked, end not yet) is discarded. Otherwise the next
    click after re-enabling the mode would silently complete the stale interval.
    """
    global is_noisy_mode, current_interval
    is_noisy_mode = change.new
    if not is_noisy_mode:
        finalize_noisy_interval()
        if current_interval:
            print(f"Discarding unfinished noisy interval starting at {current_interval[0]:.2f} seconds")
            current_interval = []
noisy_mode_checkbox.observe(toggle_noisy_mode, names='value')

# Finalize the current noisy interval if any
def finalize_noisy_interval():
    """Store the two clicked times in ``current_interval`` as a (start, end) pair.

    The clicked times are already absolute seconds. They are ordered so that
    intervals marked right-to-left still satisfy ``start <= end``, which both the
    highlight drawing and the noisy-peak exclusion rely on. Does nothing unless
    exactly two times have been clicked.
    """
    global current_interval, noisy_intervals
    if len(current_interval) == 2:
        start, end = sorted(current_interval)
        noisy_intervals.append((round(start, 2), round(end, 2)))
        current_interval = []
        update_noisy_intervals_label()

def update_noisy_intervals_label():
    """Mirror the current ``noisy_intervals`` list into the on-screen label widget."""
    noisy_intervals_label.value = "Noisy Intervals: " + str(noisy_intervals)

def visualize_eeg_data(window_duration_str, hidden_channels_str, apply_hamming=False, apply_double_hamming=False, apply_bandpass=False, lowcut=None, highcut=None, scale_amplitude=1.0):
    """Redraw the current figure with the EEG window that starts at global ``start_time``.

    Parameters
    ----------
    window_duration_str : str or float
        Window length in seconds; also stored in the global ``window_duration``.
    hidden_channels_str : str
        Comma-separated 1-based channels / ranges to hide (see ``parse_hidden_channels``).
    apply_hamming, apply_double_hamming : bool
        Apply the 3-tap smoother once / twice along time.
    apply_bandpass : bool
        Apply ``butter_bandpass_filter`` with ``lowcut`` / ``highcut`` (Hz, '' or None = unset).
    scale_amplitude : float
        Multiplier applied to the data before plotting.

    Uses the 2 kHz stream unless ``highcut`` exceeds its Nyquist frequency, in which case
    the 5 kHz stream is used. Window indices and filter design always use the rate of
    the stream actually plotted. Draws either stacked traces or, if the CSD heatmap
    checkbox is ticked, a CSD heatmap. Noisy intervals are shaded red.
    """
    global resampled_data, resampled_data_high, fs_resampled, fs_resampled_high, num_channels, window_duration, start_time, filter_settings, use_monochrome
    plt.clf()

    window_duration = float(window_duration_str)

    hidden_channels = []
    if hidden_channels_str.strip():
        hidden_channels = parse_hidden_channels(hidden_channels_str)

    selected_channels = np.array([channel for channel in range(1, num_channels + 1) if channel not in hidden_channels])
    num_visible_channels = len(selected_channels)

    # Axes background: white in monochrome mode, black in gradient (tab20) mode.
    plt.gca().set_facecolor('white' if use_monochrome else 'black')

    if num_visible_channels > 0:
        # Convert lowcut and highcut to float only if they are valid, otherwise set them to None
        try:
            lowcut = float(lowcut) if lowcut not in [None, ''] else None
            highcut = float(highcut) if highcut not in [None, ''] else None
        except ValueError:
            print("Invalid input for lowcut or highcut. Please enter valid numerical values.")
            return

        # Pick the stream AND its rate together. Previously the 5 kHz stream was sliced
        # with 2 kHz indices and filtered as if it were 2 kHz.
        nyquist_freq = fs_resampled / 2
        if highcut is not None and highcut > nyquist_freq:
            print(f"Highcut frequency {highcut} exceeds Nyquist frequency {nyquist_freq}.")
            source_data, source_fs = resampled_data_high, fs_resampled_high
        else:
            source_data, source_fs = resampled_data, fs_resampled

        window_start = int(start_time * source_fs)
        window_end = window_start + int(window_duration * source_fs)
        data_to_visualize = source_data[selected_channels - 1, window_start:window_end] * scale_amplitude

        # apply_hamming_filter convolves along axis 0, so smooth the (time x channels)
        # transpose; smoothing the (channels x time) array directly mixed channels.
        if (apply_hamming or apply_double_hamming) and data_to_visualize.shape[1] >= 3:
            data_to_visualize = apply_hamming_filter(data_to_visualize.T, source_fs, double=apply_double_hamming).T

        if apply_bandpass and (lowcut is not None or highcut is not None):
            try:
                data_to_visualize = butter_bandpass_filter(data_to_visualize, lowcut, highcut, source_fs)
            except ValueError as e:
                # e.g. a cut-off at/above Nyquist of the chosen stream, or too short a window.
                print(f"Bandpass filter not applied: {e}")

        if 'csd_heatmap_checkbox' in globals() and csd_heatmap_checkbox.value:
            spacing = float(channel_spacing_text.value) if ('channel_spacing_text' in globals() and channel_spacing_text.value is not None) else 1.0
            is_first_deriv = ('csd_is_gradient_checkbox' in globals() and csd_is_gradient_checkbox.value)

            # CSD is channels x time, like the input.
            csd_data = compute_csd(data_to_visualize, spacing=spacing, data_is_first_derivative=is_first_deriv)

            # Symmetric limits around zero for the diverging colormap.
            vmax = np.nanmax(np.abs(csd_data)) if csd_data.size > 0 else 1.0
            if not np.isfinite(vmax) or vmax == 0:
                vmax = 1.0
            norm = Normalize(vmin=-vmax, vmax=vmax)

            # 'RdBu_r' is the reversed red-blue map ('RdBur' is not a registered colormap).
            im = plt.imshow(csd_data, aspect='auto', cmap='RdBu_r', norm=norm, origin='lower',
                            extent=[0, window_duration, selected_channels[0], selected_channels[-1]])
            plt.colorbar(im, label='CSD (arb. units)')
            plt.ylabel('Channel')
            try:
                plt.yticks(np.linspace(selected_channels[0], selected_channels[-1], min(10, num_visible_channels)).astype(int))
            except Exception:
                pass  # tick labels are cosmetic; keep the heatmap even if this fails
        else:
            # Optionally show CSD values as the stacked traces instead of the data rows.
            if 'csd_apply_checkbox' in globals() and csd_apply_checkbox.value:
                spacing = float(channel_spacing_text.value) if ('channel_spacing_text' in globals() and channel_spacing_text.value is not None) else 1.0
                is_first_deriv = ('csd_is_gradient_checkbox' in globals() and csd_is_gradient_checkbox.value)
                try:
                    data_to_visualize = compute_csd(data_to_visualize, spacing=spacing, data_is_first_derivative=is_first_deriv)
                except Exception as e:
                    print("Error computing CSD for traces:", e)
                    # fall back to the un-transformed data_to_visualize

            trace_offset = 5000  # vertical distance between stacked traces
            colors = plt.cm.tab20(np.linspace(0, 1, num_visible_channels)) if not use_monochrome else ['#0f4c5c'] * num_visible_channels
            # Sample times relative to start_time, at the rate of the stream being plotted
            # (a partial window at the end of the recording is no longer stretched).
            time = (window_start + np.arange(data_to_visualize.shape[1])) / source_fs - start_time
            for i, (channel_data, color) in enumerate(zip(data_to_visualize, colors)):
                plt.plot(time, channel_data + (num_visible_channels - 1 - i) * trace_offset, color=color, alpha=0.95, linewidth=0.5)

            plt.ylim(-trace_offset, num_visible_channels * trace_offset)
            plt.yticks(np.arange(0, num_visible_channels) * trace_offset, selected_channels[::-1])
    else:
        plt.ylim(0, 1)

    # Shade the parts of marked noisy intervals that fall inside this window.
    for start, end in noisy_intervals:
        highlight_start = max(start, start_time)
        highlight_end = min(end, start_time + window_duration)
        if highlight_start < highlight_end:
            plt.axvspan(highlight_start - start_time, highlight_end - start_time, color='red', alpha=0.3)

    plt.xticks(np.linspace(0, window_duration, 11), np.round(np.linspace(start_time, start_time + window_duration, 11), 2))
    plt.xlabel('Time (s)')
    plt.grid(True, axis='x', linestyle='--', color='#e36414', alpha=0.5)
    plt.tight_layout()
    plt.gca().yaxis.set_ticks_position('both')
    plt.tick_params(axis='y', labelleft=True, labelright=True)
    plt.subplots_adjust(left=0.05, right=0.95)
    plt.show()

    # Restore per-figure filter settings saved with save_filter_settings.
    fig = plt.gcf().number
    if fig in filter_settings:
        settings = filter_settings[fig]
        apply_hamming_checkbox.value = settings['apply_hamming']
        apply_double_hamming_checkbox.value = settings['apply_double_hamming']
        apply_bandpass_checkbox.value = settings['apply_bandpass']
        lowcut_text.value = settings['lowcut'] if settings['lowcut'] is not None else ''
        highcut_text.value = settings['highcut'] if settings['highcut'] is not None else ''

def save_filter_settings(_):
    """Button callback: remember the current filter widgets for the active figure.

    Settings are stored in ``filter_settings`` keyed by the matplotlib figure number.
    """
    global filter_settings
    figure_number = plt.gcf().number
    snapshot = {
        'apply_hamming': apply_hamming_checkbox.value,
        'apply_double_hamming': apply_double_hamming_checkbox.value,
        'apply_bandpass': apply_bandpass_checkbox.value,
        'lowcut': lowcut_text.value,
        'highcut': highcut_text.value,
        'scale_amplitude': scale_amplitude
    }
    filter_settings[figure_number] = snapshot
    print(f"Filter settings saved for figure {figure_number}.")

def parse_hidden_channels(hidden_channels_str):
    """Parse a comma-separated list of channels and inclusive ranges, e.g. ``"1, 4-6"``.

    Ranges may be written in either direction (``"6-4"`` equals ``"4-6"``; previously
    a reversed range hid nothing). Empty or malformed tokens are ignored.

    Returns
    -------
    list of int
        Sorted, de-duplicated channel numbers.
    """
    hidden_channels = set()
    for token in hidden_channels_str.split(','):
        token = token.strip()
        if not token:
            continue
        try:
            if '-' in token:
                first, last = (int(part) for part in token.split('-'))
                hidden_channels.update(range(min(first, last), max(first, last) + 1))
            else:
                hidden_channels.add(int(token))
        except ValueError:
            continue  # malformed token such as 'a' or '1-2-3': skip it, keep the rest
    return sorted(hidden_channels)

def save_snapshot(_):
    """Button callback: save the current figure as an SVG next to the loaded data file.

    The file is named ``<data file without extension>_overall_window<N>.svg``. The
    filter widgets in effect are recorded in ``filter_settings`` under that path.
    ``snapshot_counter`` only advances once the file has actually been written.
    """
    global snapshot_counter, filter_settings
    # splitext strips the real extension instead of assuming it is 3 characters long.
    base_path = os.path.splitext(cleaned_path)[0]
    snap_path = f"{base_path}_overall_window{snapshot_counter}.svg"
    filter_settings[snap_path] = {
        'apply_hamming': apply_hamming_checkbox.value,
        'apply_double_hamming': apply_double_hamming_checkbox.value,
        'apply_bandpass': apply_bandpass_checkbox.value,
        'lowcut': lowcut_text.value,
        'highcut': highcut_text.value,
    }
    plt.savefig(snap_path, bbox_inches='tight')
    snapshot_counter += 1

def on_key(event):
    """Key handler: scroll every open figure left/right in time.

    The step is 20 % of the window (a full window with "Fast Scroll" ticked).
    ``start_time`` is clamped to ``[0, recording duration - window]`` in seconds.
    Previously it was clamped against the raw sample count ``t``, so scrolling right
    was effectively unbounded. Other keys are ignored; the up/down keys are handled by
    ``on_arrow_key``. The redraw keeps the current ``scale_amplitude`` instead of
    resetting it to 1.0.
    """
    global start_time

    if event.key not in ('left', 'right'):
        return

    step = window_duration if scroll_fast_checkbox.value else round(window_duration * 0.2, 2)
    if event.key == 'left':
        start_time -= step
    else:
        start_time += step

    # Duration of the 2 kHz display stream (the 5 kHz stream covers the same time span).
    total_duration = resampled_data.shape[1] / fs_resampled
    latest_start = max(0, total_duration - window_duration)
    start_time = min(max(0, start_time), latest_start)

    for fig in plt.get_fignums():
        plt.figure(fig)
        visualize_eeg_data(window_duration_text.value,
                           hidden_channels_text.value,
                           apply_hamming=apply_hamming_checkbox.value,
                           apply_double_hamming=apply_double_hamming_checkbox.value,
                           apply_bandpass=apply_bandpass_checkbox.value,
                           lowcut=lowcut_text.value if lowcut_text.value != '' else None,
                           highcut=highcut_text.value if highcut_text.value != '' else None,
                           scale_amplitude=scale_amplitude)
        plt.draw()

def on_arrow_key(event):
    """Key handler: scale the trace amplitude by 1.5x (up) or 1/1.5 (down) and redraw.

    Other keys are ignored; previously every key press, including left/right,
    triggered an extra redraw here on top of the one done by ``on_key``.
    """
    global scale_amplitude
    if event.key == 'up':
        scale_amplitude *= 1.5
    elif event.key == 'down':
        scale_amplitude /= 1.5
    else:
        return
    visualize_eeg_data(window_duration_text.value,
                       hidden_channels_text.value,
                       apply_hamming=apply_hamming_checkbox.value,
                       apply_double_hamming=apply_double_hamming_checkbox.value,
                       apply_bandpass=apply_bandpass_checkbox.value,
                       lowcut=lowcut_text.value,
                       highcut=highcut_text.value,
                       scale_amplitude=scale_amplitude)
    plt.draw()

def reset_noisy_intervals(change):
    """Button callback: forget all marked noisy intervals, including a half-marked one."""
    global noisy_intervals, current_interval
    noisy_intervals, current_interval = [], []
    update_noisy_intervals_label()
    print("Noisy intervals have been reset.")

def on_click(event):
    """Mouse handler: in noisy-data mode, record interval start/end clicks.

    The x position (seconds relative to the window) is converted to absolute seconds
    by adding ``start_time``. The second click completes the interval through
    ``finalize_noisy_interval``. Clicks outside the axes have no x coordinate and
    are ignored; previously they raised a TypeError.
    """
    if not is_noisy_mode:
        return
    if event.inaxes is None or event.xdata is None:
        return

    absolute_time = event.xdata + start_time
    if len(current_interval) == 0:
        current_interval.append(absolute_time)
        print(f"Noisy interval start marked at: {absolute_time:.2f} seconds")
    elif len(current_interval) == 1:
        current_interval.append(absolute_time)
        print(f"Noisy interval end marked at: {absolute_time:.2f} seconds")
        finalize_noisy_interval()

# Attach the click event to the matplotlib figure
plt.gcf().canvas.mpl_connect('button_press_event', on_click)

# Button that clears every marked noisy interval.
reset_button = widgets.Button(button_style='danger', description='Reset Noisy Intervals')
reset_button.on_click(reset_noisy_intervals)

def reset_scale_amplitude(multiplier=None):
    """Reset the global ``scale_amplitude`` from the peak of the 2 kHz stream.

    ``multiplier`` defaults to the global ``initial_scale_multiplier`` when it exists,
    else 0.5 (``calculate_starting_amplitude``'s default). That global is not defined
    in the cells shown here, so the old unconditional reference raised NameError.
    """
    global scale_amplitude
    if multiplier is None:
        multiplier = globals().get('initial_scale_multiplier', 0.5)
    scale_amplitude = calculate_starting_amplitude(resampled_data, multiplier)

def new_window(_):
    """Button callback: open another figure showing the current window and settings.

    Copies of the main controls are created and displayed under the button, seeded
    from the main widgets. Nothing observes these copies: the new figure is redrawn
    by the shared keyboard handlers, which read the main widgets.
    """
    window_duration_text_new = widgets.FloatText(value=window_duration_text.value, description='Window Duration')
    hidden_channels_text_new = widgets.Text(value=hidden_channels_text.value, description='Hidden Channels')

    apply_hamming_checkbox_new = widgets.Checkbox(value=apply_hamming_checkbox.value, description='Apply Hamming')
    apply_double_hamming_checkbox_new = widgets.Checkbox(value=apply_double_hamming_checkbox.value, description='Apply Double Hamming')
    apply_bandpass_checkbox_new = widgets.Checkbox(value=apply_bandpass_checkbox.value, description='Apply Bandpass')
    # Cut-offs are free text where '' means "not set". A FloatText cannot hold '' (the
    # main widgets' default), which made this callback fail, so keep them as Text.
    lowcut_text_new = widgets.Text(value=lowcut_text.value, placeholder='Lowcut (Hz)', description='Lowcut')
    highcut_text_new = widgets.Text(value=highcut_text.value, placeholder='Highcut (Hz)', description='Highcut')

    plt.figure()
    canvas = plt.gcf().canvas
    canvas.mpl_connect('key_press_event', on_key)
    canvas.mpl_connect('key_press_event', on_arrow_key)
    # Let noisy-interval marking work in this figure as well.
    canvas.mpl_connect('button_press_event', on_click)

    # (The unused starting-amplitude computation that referenced the undefined
    # initial_scale_multiplier was removed; the current global scale is used instead.)
    visualize_eeg_data(window_duration_text_new.value,
                       hidden_channels_text_new.value,
                       apply_hamming=apply_hamming_checkbox_new.value,
                       apply_double_hamming=apply_double_hamming_checkbox_new.value,
                       apply_bandpass=apply_bandpass_checkbox_new.value,
                       lowcut=lowcut_text_new.value,
                       highcut=highcut_text_new.value,
                       scale_amplitude=scale_amplitude)

    snapshot_button_new = widgets.Button(description='Take a Snapshot', button_style='info')
    snapshot_button_new.on_click(save_snapshot)
    display(snapshot_button_new)
    display(window_duration_text_new)
    display(hidden_channels_text_new)

    display(apply_hamming_checkbox_new)
    display(apply_double_hamming_checkbox_new)
    display(apply_bandpass_checkbox_new)
    display(lowcut_text_new)
    display(highcut_text_new)

# Controls for the main viewer (read by visualize_eeg_data and the key/button callbacks).
window_duration_text = widgets.FloatText(description='Window Duration', value='10')
hidden_channels_text = widgets.Text(description='Hidden Channels', value='')
apply_hamming_checkbox = widgets.Checkbox(description='Apply Hamming', value=False)
apply_double_hamming_checkbox = widgets.Checkbox(description='Apply Double Hamming', value=False)
apply_bandpass_checkbox = widgets.Checkbox(description='Apply Bandpass', value=False)
# Free-text cut-offs: '' means "no cut-off on this side".
lowcut_text = widgets.Text(description='Lowcut:', placeholder='Lowcut (Hz)', value='')
highcut_text = widgets.Text(description='Highcut:', placeholder='Highcut (Hz)', value='')
scroll_fast_checkbox = widgets.Checkbox(description='Fast Scroll', value=False)
# CSD display options (looked up through globals() inside visualize_eeg_data).
csd_apply_checkbox = widgets.Checkbox(description='Apply CSD (for traces)', value=False)
csd_heatmap_checkbox = widgets.Checkbox(description='Show CSD Heatmap', value=False)
csd_is_gradient_checkbox = widgets.Checkbox(description='Input is 1st-derivative', value=True)
channel_spacing_text = widgets.FloatText(description='Channel Spacing', value=1.0)

# Observers for the cut-off text boxes. When the new value is the empty string they
# assign '' again, which leaves the widget unchanged (an empty field already means
# "no cut-off" to visualize_eeg_data), so in practice they are no-ops.
def on_lowcut_change(change):
    """Observer for ``lowcut_text``; only acts when the field was cleared."""
    if change['new'] != '':
        return
    lowcut_text.value = ''  # Allow the input to be empty

def on_highcut_change(change):
    """Observer for ``highcut_text``; only acts when the field was cleared."""
    if change['new'] != '':
        return
    highcut_text.value = ''  # Allow the input to be empty

# Attach the change events (on the ``value`` trait)
lowcut_text.observe(on_lowcut_change, names='value')
highcut_text.observe(on_highcut_change, names='value')

# Keyboard navigation for the figure that is current when this cell runs:
# left/right scroll in time (on_key), up/down rescale the amplitude (on_arrow_key).
plt.gcf().canvas.mpl_connect('key_press_event', on_key)
plt.gcf().canvas.mpl_connect('key_press_event', on_arrow_key)

# interact() calls visualize_eeg_data again whenever one of these widgets changes.
# NOTE(review): scale_amplitude is not bound to a widget here, so ipywidgets presumably
# builds one from its default value (1.0) and that value, not the global set by the
# up/down keys, is used for interact-driven redraws. Confirm this is intended.
interact(visualize_eeg_data,
         window_duration_str=window_duration_text,
         hidden_channels_str=hidden_channels_text,
         apply_hamming=apply_hamming_checkbox,
         apply_double_hamming=apply_double_hamming_checkbox,
         apply_bandpass=apply_bandpass_checkbox,
         lowcut=lowcut_text,
         highcut=highcut_text)

snapshot_button = widgets.Button(description='Take a Snapshot', button_style='info')
snapshot_button.on_click(save_snapshot)

save_filter_settings_button = widgets.Button(description='Save Filter Settings', button_style='success')
save_filter_settings_button.on_click(save_filter_settings)

new_window_button = widgets.Button(description='New Window', button_style='info')
new_window_button.on_click(new_window)
# NOTE(review): the label stays 'Polychrome' after toggling; toggle_monochrome does not update it.
monochrome_button = widgets.Button(description='Polychrome', button_style='warning')
monochrome_button.on_click(toggle_monochrome)

# Display the controls (the call order here is the on-screen order).
display(scroll_fast_checkbox)
display(monochrome_button)
display(snapshot_button)
display(save_filter_settings_button)
display(new_window_button)
display(noisy_mode_checkbox)

display(csd_apply_checkbox)
display(csd_heatmap_checkbox)
display(csd_is_gradient_checkbox)
display(channel_spacing_text)
display(noisy_intervals_label)
display(reset_button)

interactive(children=(FloatText(value=10.0, description='Window Duration'), Text(value='', description='Hidden…

Checkbox(value=False, description='Fast Scroll')



Button(button_style='info', description='Take a Snapshot', style=ButtonStyle())

Button(button_style='success', description='Save Filter Settings', style=ButtonStyle())

Button(button_style='info', description='New Window', style=ButtonStyle())

Checkbox(value=False, description='Noisy Data Mode')

Checkbox(value=False, description='Apply CSD (for traces)')

Checkbox(value=False, description='Show CSD Heatmap')

Checkbox(value=True, description='Input is 1st-derivative')

FloatText(value=1.0, description='Channel Spacing')

Label(value='Noisy Intervals: []')

Button(button_style='danger', description='Reset Noisy Intervals', style=ButtonStyle())

2025-11-24 19:19:43.637 python3.13[84592:26391550] +[IMKClient subclass]: chose IMKClient_Modern


In [9]:
# Reset the viewer position and per-channel thresholds for the detection stage.
start_time = 0
window_duration = 10  # Default window duration is 10 seconds
channel_thresholds = {}

# Get channels to plot and target channel from user input.
# NOTE(review): both answers are kept as raw strings here; presumably parsed in a later cell.
select_chan = input("Enter the channels to plot (comma-separated or interval): ")
target_chan = input("Select the target channel: ")

def initialize_global_variables():
    """Reset the shared detection/plotting state to empty defaults.

    The module-level containers are rebound to fresh objects (not cleared in
    place), so references held elsewhere to the old objects are unaffected.
    ``window_duration`` is restored to its 10-second default.
    """
    # window_duration was previously missing from this list, which made the
    # assignment below a dead local instead of a reset of the global.
    global channel_suprathreshold_events, real_peaks, channel_thresholds, channel_y_coords, y_ticks, all_absolute_peaks, window_duration

    # Reset all global variables
    channel_suprathreshold_events = {}  # channel -> list of event times (s)
    channel_thresholds = {}             # channel -> threshold offset from baseline
    channel_y_coords = {}               # channel -> clicked y position (plot units)
    y_ticks = []                        # y positions of the plotted channel baselines
    real_peaks = []                     # accepted target-channel peak sample indices
    all_absolute_peaks = {}
    window_duration = 10

# Call this right after defining it
initialize_global_variables()

def reset_detection_variables():
    """Empty the detected-event containers in place.

    The existing dict and list objects are cleared rather than replaced, so
    every holder of a reference to them observes the reset.
    """
    for container in (channel_suprathreshold_events, real_peaks):
        container.clear()

def exclude_noisy_intervals(peaks, fs_resampled):
    """Return the peaks (sample indices) that fall outside every noisy interval.

    Each peak is converted to seconds (``peak / fs_resampled``) and compared
    with the module-level ``noisy_intervals`` ``(start, end)`` pairs, both
    bounds inclusive. Input order is preserved.
    """
    def in_noise(sample):
        seconds = sample / fs_resampled
        return any(lo <= seconds <= hi for lo, hi in noisy_intervals)

    return [sample for sample in peaks if not in_noise(sample)]

def save_ev2_files(_):
    """Button callback: write one ``.ev2`` file per channel with detected events.

    For every channel in ``channel_suprathreshold_events`` the event times
    (seconds) are de-duplicated and sorted, converted to SampleTime ticks
    (``time * fs_resampled * 10``), filtered against the noisy intervals and
    written to ``<cleaned_path minus its last 4 chars>_channel_<N>_spa_peaks.ev2``.
    The argument is the clicked widget and is ignored. When no recording
    path is set, a message is printed and nothing is written.
    """
    if not cleaned_path:
        print("No recording path is set (cleaned_path); nothing was saved.")
        return

    for channel, events in channel_suprathreshold_events.items():
        unique_events = sorted(set(events))

        # Convert event times to SampleTime ticks (10 ticks per resampled sample)
        event_samples = [int(event * fs_resampled * 10) for event in unique_events]

        # Exclude events in noisy intervals. The values are in 10x ticks, so the
        # rate must be 10x as well for the comparison to happen in seconds;
        # passing plain fs_resampled compared times 10x too large.
        event_samples = exclude_noisy_intervals(event_samples, fs_resampled * 10)

        spa_info = [[channel, 0, 0, 0, 0, sample_time] for sample_time in event_samples]
        spa_info.sort(key=lambda x: x[5])

        output_file = cleaned_path[:-4] + f'_channel_{channel}_spa_peaks.ev2'

        with open(output_file, 'w') as f:
            f.write("Channel Zeros1 Zeros2 Zeros3 Zeros4 SampleTime\n")
            for info in spa_info:
                f.write(f"{info[0]} {info[1]} {info[2]} {info[3]} {info[4]} {info[5]}\n")

        print(f"Saved {len(spa_info)} peaks for channel {channel} to {output_file}")

def on_export_button_click(b):
    """Button callback: export suprathreshold events to a fixed CSV filename."""
    csv_name = "suprathreshold_events.csv"
    export_to_csv(csv_name)

def visualize_eeg_data(channels_to_plot_str, target_channel, window_duration_str, apply_hamming=False, apply_double_hamming=False, apply_bandpass=False, lowcut=1, highcut=30, scale_amplitude=1.0):
    """Plot a window of selected channels and mark threshold-crossing peaks.

    Parameters
    ----------
    channels_to_plot_str : str
        Comma-separated 1-based channel numbers and/or inclusive ranges
        ("4-10"). Blank items (e.g. a trailing comma) are ignored.
    target_channel : int or str
        Channel drawn in orange whose detected peaks are marked and added to
        ``real_peaks``. An int or a numeric string is accepted (ipywidgets
        ``interact`` passes the raw text-box string); anything that does not
        parse as int means "no target channel".
    window_duration_str : str
        Window length in seconds; blank falls back to 10.
    apply_hamming, apply_double_hamming, apply_bandpass : bool
        Optional filters applied to the visible segment before plotting.
    lowcut, highcut : float, str or None
        Band edges in Hz; blank/None disables that edge.
    scale_amplitude : float
        Multiplier applied to the segment before plotting and detection.

    Side effects: clears and redraws the current pyplot figure; updates the
    globals ``window_duration``, ``real_peaks`` and
    ``channel_suprathreshold_events``; calls ``reset_detection_variables()``
    on the very first invocation.
    """
    global channel_thresholds, real_peaks, all_absolute_peaks, channel_suprathreshold_events, window_duration

    if not hasattr(visualize_eeg_data, 'has_run'):
        reset_detection_variables()
        visualize_eeg_data.has_run = True

    # Normalise the target channel once. interact() passes the text-box string
    # while the key/click handlers pass an int; comparing an int channel number
    # against a str never matched, so the target was not highlighted.
    try:
        target_channel_num = int(target_channel)
    except (TypeError, ValueError):
        target_channel_num = None

    plt.clf()

    # Update window duration based on textbox input
    window_duration = float(window_duration_str) if window_duration_str.strip() else 10
    window_start = int(start_time * fs_resampled)
    # Ensure the window end does not exceed the length of the data
    window_end = min(window_start + int(window_duration * fs_resampled), resampled_data.shape[1])

    channels_to_plot = []
    for part in channels_to_plot_str.split(","):
        part = part.strip()
        if not part:
            continue  # blank input or a stray/trailing comma
        if "-" in part:
            start, end = map(int, part.split("-"))
            channels_to_plot.extend(range(start, end + 1))
        else:
            channels_to_plot.append(int(part))

    num_visible_channels = len(channels_to_plot)

    if num_visible_channels > 0:
        try:
            lowcut = float(lowcut) if lowcut not in [None, ''] else None
            highcut = float(highcut) if highcut not in [None, ''] else None
        except ValueError:
            print("Invalid input for lowcut or highcut. Please enter valid numerical values.")
            return

        channel_rows = np.array(channels_to_plot) - 1  # 1-based channel numbers -> row indices
        nyquist_freq = fs_resampled / 2
        if highcut is not None and highcut > nyquist_freq:
            print(f"Highcut frequency {highcut} exceeds Nyquist frequency {nyquist_freq}.")
            # NOTE(review): resampled_data_high is presumably a higher-rate copy of the data — confirm.
            data_to_visualize = resampled_data_high[channel_rows, window_start:window_end] * scale_amplitude
        else:
            data_to_visualize = resampled_data[channel_rows, window_start:window_end] * scale_amplitude

        if apply_hamming or apply_double_hamming:
            data_to_visualize = apply_hamming_filter(data_to_visualize, fs_resampled, double=apply_double_hamming)

        if apply_bandpass:
            try:
                if lowcut is not None or highcut is not None:
                    data_to_visualize = butter_bandpass_filter(data_to_visualize, lowcut, highcut, fs_resampled)
            except ValueError:
                # NOTE(review): despite the message, no data source is switched
                # here; the unfiltered segment is plotted.
                print("Switching from 2 kHz downsampled data to 5 kHz.")

        span = 5000  # vertical offset between stacked traces (plot units)

        for i, channel_data in enumerate(data_to_visualize):
            time = np.linspace(window_start / fs_resampled, window_end / fs_resampled, len(channel_data))
            original_channel_number = channels_to_plot[i]
            baseline = (num_visible_channels - 1 - i) * span  # first channel is drawn on top
            is_target = original_channel_number == target_channel_num

            plt.plot(
                time - start_time,
                channel_data + baseline,
                color='orange' if is_target else '#0f4c5c',
                alpha=0.9,
                linewidth=0.5
            )

            threshold = channel_thresholds.get(original_channel_number, None)
            if threshold is None:
                continue

            threshold_array = [threshold] * len(channel_data)
            peaks = detect_single_event_peaks(channel_data, threshold_array)
            averaged_peaks = average_peaks(peaks)

            # Detected indices are window-relative; shift to absolute sample indices.
            absolute_peaks = [peak + window_start for peak in averaged_peaks]
            filtered_peaks = exclude_noisy_intervals(absolute_peaks, fs_resampled)
            new_events = [peak / fs_resampled for peak in filtered_peaks]
            existing_events = channel_suprathreshold_events.get(original_channel_number, [])
            channel_suprathreshold_events[original_channel_number] = list(set(existing_events + new_events))

            if is_target:
                # Filter close peaks by amplitude (keeps the largest |value| per cluster)
                filtered_peaks = filter_peaks_by_amplitude(filtered_peaks, resampled_data[original_channel_number - 1])
                for peak in filtered_peaks:
                    plt.scatter(
                        peak / fs_resampled - start_time,
                        channel_data[peak - window_start] + baseline,
                        color='purple',
                        alpha=0.5,
                        s=50
                    )
                # Record each peak once. This loop used to be nested inside the
                # scatter loop above and re-ran for every plotted marker.
                for peak in filtered_peaks:
                    if not is_peak_already_annotated(peak, real_peaks, fs_resampled):
                        real_peaks.append(peak)

        plt.ylim(-span, num_visible_channels * span)
        plt.yticks(np.arange(0, num_visible_channels) * span, reversed(channels_to_plot))
    else:
        plt.ylim(0, 1)

    real_peaks = sorted(set(real_peaks))

    plt.xticks(np.linspace(0, window_duration, 11), np.round(np.linspace(start_time, start_time + window_duration, 11), 2))
    plt.xlabel('Time (s)')
    plt.ylabel('Channel')
    plt.grid(True, axis='x', linestyle='--', color='gray', alpha=0.5)
    plt.tight_layout()

    # Remember the channel baseline positions so mouse clicks can be mapped to channels.
    yticks_positions, _ = plt.yticks()
    y_ticks.clear()
    y_ticks.extend(yticks_positions)
    plt.gca().yaxis.set_ticks_position('both')
    plt.tick_params(axis='y', labelleft=True, labelright=True)
    plt.subplots_adjust(left=0.05, right=0.95)

    redraw_threshold_lines()

    plt.show()

def is_peak_already_annotated(new_peak, real_peaks, fs_resampled, min_distance_ms=100):
    """Tell whether ``new_peak`` lies within ``min_distance_ms`` of an annotated peak.

    All peaks are sample indices at ``fs_resampled`` Hz. The tolerance is
    converted to whole samples by truncation and is inclusive.
    """
    tolerance = int(min_distance_ms * fs_resampled / 1000)
    return any(abs(existing - new_peak) <= tolerance for existing in real_peaks)
    
def detect_suprathreshold_peaks_for_target_channel():
    """Re-detect the target channel's events over the whole recording.

    If a threshold is stored for ``int(target_chan)``, events are detected on
    that channel's full ``resampled_data`` row and merged by ``average_peaks``.
    They are then filtered against the noisy intervals, and the channel's
    entry in ``channel_suprathreshold_events`` is replaced with event times in
    seconds. Nothing happens when the target channel has no threshold.

    NOTE(review): thresholds are captured in plot units (scaled, possibly
    filtered data), but they are applied here to the raw ``resampled_data``.
    ``target_chan`` is also the start-up input, not the Target Channel widget.
    Confirm both are intended.
    """
    channel_number = int(target_chan)
    signal_row = resampled_data[channel_number - 1]
    threshold = channel_thresholds.get(channel_number, None)
    if threshold is None:
        return

    per_sample_threshold = [threshold] * len(signal_row)
    candidates = average_peaks(detect_single_event_peaks(signal_row, per_sample_threshold))
    kept = exclude_noisy_intervals(candidates, fs_resampled)
    channel_suprathreshold_events[channel_number] = [idx / fs_resampled for idx in kept]

def update_thresholds():
    """Recompute consensus peaks and every channel's suprathreshold events.

    First calls ``define_consensus_target_peaks()``. It then rebuilds
    ``channel_suprathreshold_events`` in place: for each channel with a stored
    threshold, events are detected on the full recording, merged by
    ``average_peaks``, filtered against the noisy intervals, and stored as
    times in seconds.
    """
    # Update peaks for entire recording
    define_consensus_target_peaks()

    channel_suprathreshold_events.clear()
    for channel_number, level in channel_thresholds.items():
        samples = resampled_data[channel_number - 1]
        grouped = average_peaks(detect_single_event_peaks(samples, [level] * len(samples)))
        clean = exclude_noisy_intervals(grouped, fs_resampled)
        channel_suprathreshold_events[channel_number] = [p / fs_resampled for p in clean]

def average_peaks(peaks):
    """Collapse runs of nearby peaks into their truncated mean index.

    A peak joins the current cluster when it is less than
    ``int(0.1 * fs_resampled)`` samples after the cluster's last member, so
    clusters can chain. Each cluster is replaced by ``int(mean(cluster))``.
    The input is expected in ascending order.
    """
    merged = []
    cluster = []

    def close_cluster():
        # Emit the truncated mean of the cluster being built, if any.
        if cluster:
            merged.append(int(np.mean(cluster)))

    for idx in peaks:
        if cluster and idx - cluster[-1] < int(0.1 * fs_resampled):
            cluster.append(idx)
        else:
            close_cluster()
            cluster = [idx]
    close_cluster()
    return merged

def compute_consensus_peaks(all_peaks):
    """Return mean positions of peak groups seen on at least two non-target channels.

    ``all_peaks`` maps channel -> iterable of peak sample indices. All peaks
    are pooled and sorted. A group starts at its earliest peak and takes every
    following peak within ``0.05 * fs_resampled`` samples of that first one.
    Groups with peaks from two or more channels other than ``int(target_chan)``
    contribute the mean of all their peaks, target channel included.
    """
    pooled = sorted((peak, channel) for channel, peaks in all_peaks.items() for peak in peaks)

    def channel_count(members):
        return len({ch for _, ch in members})

    candidate_groups = []
    group = []
    for peak, channel in pooled:
        if group and abs(peak - group[0][0]) <= 0.05 * fs_resampled:
            group.append((peak, channel))
            continue
        if group and channel_count(group) >= 2:
            candidate_groups.append(group)
        group = [(peak, channel)]
    if channel_count(group) >= 2:
        candidate_groups.append(group)

    consensus = []
    for candidate in candidate_groups:
        non_target = {ch for _, ch in candidate if ch != int(target_chan)}
        if len(non_target) >= 2:
            consensus.append(np.mean([peak for peak, _ in candidate]))
    return consensus
    
def get_all_absolute_peaks():
    """Return the module-level ``all_absolute_peaks`` object itself (not a copy)."""
    result = all_absolute_peaks
    return result

# Re-draw the stored threshold markers; called whenever the plot is refreshed.
def redraw_threshold_lines():
    """Add a dashed green horizontal line at every stored threshold click position.

    Iterates ``channel_y_coords`` (channel -> y coordinate in plot units) and
    calls ``plt.axhline`` once per entry on the current axes.
    """
    for y_level in channel_y_coords.values():
        plt.axhline(y=y_level, linestyle='--', color='green', alpha=0.5)

def detect_single_event_peaks(channel_data, threshold, min_distance_ms=100):
    """
    Detect one peak per threshold crossing, merge close peaks and drop noisy ones.

    An "event" is a run of consecutive samples beyond the threshold. That means
    above it for a positive threshold and below it for a negative one; a zero
    threshold never triggers. Each event contributes the index of its sample
    with the largest absolute value. An event still open at the end of the
    data is kept (it used to be silently dropped). A peak within
    ``min_distance_ms`` of the previous raw peak is merged into the last kept
    peak, keeping the one with the larger absolute value. Finally, peaks inside
    the noisy intervals are removed via ``exclude_noisy_intervals``.

    Parameters
    ----------
    channel_data : sequence of float
        Samples at ``fs_resampled`` Hz.
    threshold : sequence of float, or a scalar
        Per-sample thresholds. A scalar is applied to every sample, which
        spares callers building ``[threshold] * len(channel_data)``.
    min_distance_ms : float
        Merge distance in milliseconds.

    Returns
    -------
    list
        Peak indices into ``channel_data``.
    """
    if np.isscalar(threshold):
        threshold = [threshold] * len(channel_data)

    peaks = []
    in_event = False
    max_peak_index = None
    max_peak_value = -np.inf
    min_distance_samples = int(min_distance_ms * fs_resampled / 1000)  # Convert ms to samples

    for i, value in enumerate(channel_data):
        level = threshold[i]
        crossed = (level > 0 and value > level) or (level < 0 and value < level)
        if crossed:
            if not in_event:
                in_event = True
                max_peak_index = i
                max_peak_value = value
            elif abs(value) > abs(max_peak_value):
                max_peak_value = value
                max_peak_index = i
        elif in_event:
            peaks.append(max_peak_index)
            in_event = False

    # Close an event that is still above threshold when the data ends.
    if in_event:
        peaks.append(max_peak_index)

    if not peaks:
        return []

    # Merge peaks closer than min_distance_samples to the previous raw peak.
    filtered_peaks = [peaks[0]]
    for prev, cur in zip(peaks, peaks[1:]):
        if cur - prev > min_distance_samples:
            filtered_peaks.append(cur)
        elif abs(channel_data[cur]) > abs(channel_data[filtered_peaks[-1]]):
            filtered_peaks[-1] = cur

    # Exclude peaks within noisy intervals
    return exclude_noisy_intervals(filtered_peaks, fs_resampled)

# Module-level accumulators. NOTE(review): not referenced in the visible code;
# presumably used by other cells — confirm before removing.
all_peaks = []
timeframe = []

def filter_peaks_by_amplitude(peaks, target_data, window_ms=100):
    """Keep the strongest peak in each cluster of nearby peaks.

    Peaks are sorted and grouped so that every member lies within
    ``window_ms`` (converted to whole samples at ``fs_resampled``, inclusive)
    of its group's first peak. Each group keeps the peak with the largest
    ``abs(target_data[peak])``; on ties the earliest wins.
    """
    if not peaks:
        return []

    ordered = sorted(peaks)
    span = int(window_ms * fs_resampled / 1000)  # Convert ms to samples

    def strength(p):
        return abs(target_data[p])

    keep = []
    anchor = ordered[0]
    bucket = [anchor]
    for p in ordered[1:]:
        if p - anchor > span:
            keep.append(max(bucket, key=strength))
            anchor, bucket = p, [p]
        else:
            bucket.append(p)
    keep.append(max(bucket, key=strength))
    return keep

def update_interactive_components():
    """(Re)create the ``interact`` panel that drives ``visualize_eeg_data``.

    Only keywords matching parameters of ``visualize_eeg_data`` are passed.
    The previous version also passed ``threshold_str=threshold_text`` and
    ``peak_mode=peak_mode_dropdown``. Neither matches a parameter of
    ``visualize_eeg_data``, and ``threshold_text`` is not defined in the
    visible code, so the call raised NameError.
    ``scale_amplitude`` is frozen at its current value via ``fixed``.
    """
    interact(visualize_eeg_data,
             channels_to_plot_str=channels_to_plot_text,
             target_channel=target_channel_text,
             window_duration_str=window_duration_text,
             apply_hamming=apply_hamming_checkbox,
             apply_double_hamming=apply_double_hamming_checkbox,
             apply_bandpass=apply_bandpass_checkbox,
             lowcut=lowcut_text,
             highcut=highcut_text,
             scale_amplitude=fixed(scale_amplitude))

scale_amplitude = 1.0

def on_key(event):
    """Keyboard navigation for the threshold-detection view.

    left/right : scroll by one window (Fast Scroll) or by 20 % of it. The
                 result is clamped to ``[0, t - window_duration]``; ``t`` is
                 presumably the recording length in seconds — confirm.
    up/down    : multiply/divide the amplitude scale by 1.5.
    Any other key is ignored (it used to trigger a full redraw).
    """
    global start_time, scale_amplitude

    if event.key not in ('left', 'right', 'up', 'down'):
        return

    if event.key in ('left', 'right'):
        step = window_duration if scroll_fast_checkbox.value else round(window_duration * 0.2, 2)
        if event.key == 'left':
            start_time = max(0, start_time - step)
        else:
            # max(0, ...) guards recordings shorter than one window.
            start_time = max(0, min(t - window_duration, start_time + step))
    elif event.key == 'up':
        scale_amplitude *= 1.5
    else:
        scale_amplitude /= 1.5

    # Ensure window_duration_text.value is passed as a string
    window_duration_str = str(window_duration_text.value)

    visualize_eeg_data(
        channels_to_plot_text.value,  # Pass channels_to_plot_str
        int(target_channel_text.value),  # Pass target_channel
        window_duration_str,  # Pass window_duration_str
        apply_hamming=apply_hamming_checkbox.value,
        apply_double_hamming=apply_double_hamming_checkbox.value,
        apply_bandpass=apply_bandpass_checkbox.value,
        lowcut=lowcut_text.value or None,
        highcut=highcut_text.value or None,
        scale_amplitude=scale_amplitude
    )
    # visualize_eeg_data already redraws the threshold lines; calling
    # redraw_threshold_lines() again here stacked a duplicate of each line.
    plt.draw()

channel_y_coords = {}  # channel -> y position (plot units) of the last threshold click

def onclick(event):
    """Mouse handler: left click sets a channel threshold, right click removes it.

    The click height is compared with the baseline of each plotted channel
    (``y_ticks``). Of the two nearest channels, 'Positive' peak mode picks the
    higher-numbered one and 'Negative' mode the lower-numbered one. The stored
    threshold is the click height relative to that channel's baseline.
    Afterwards the target channel's events are re-detected and the view is
    redrawn.
    """
    global channel_y_coords, channel_thresholds

    if event.inaxes is None:
        return  # click outside the plot axes

    y_coord = event.ydata

    # Parse the channel list that is actually plotted (the widget value, as used
    # by visualize_eeg_data to build y_ticks). The start-up input select_chan
    # goes stale once the widget is edited, which mapped clicks to the wrong
    # channel or raised IndexError.
    channels_to_plot = []
    for part in channels_to_plot_text.value.strip().split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start, end = map(int, part.split("-"))
            channels_to_plot.extend(range(start, end + 1))
        else:
            channels_to_plot.append(int(part))

    # Channels are drawn top-down, so reverse the tick list to align with channel order.
    reversed_y_ticks = y_ticks[::-1]
    distances = {channel: abs(tick - y_coord)
                 for channel, tick in zip(channels_to_plot, reversed_y_ticks)}
    if not distances:
        return  # nothing plotted yet
    closest_channels = sorted(distances, key=distances.get)[:2]

    if peak_mode_dropdown.value == 'Negative':
        # Select the channel with the lower index
        closest_channel = min(closest_channels)
    else:
        # 'Positive': select the channel with the higher index
        closest_channel = max(closest_channels)

    closest_channel_index = channels_to_plot.index(closest_channel)

    if event.button == 1:  # Left click to add/update threshold
        channel_y_coords[closest_channel] = y_coord
        channel_thresholds[closest_channel] = y_coord - reversed_y_ticks[closest_channel_index]
        print("New threshold value for Channel", closest_channel, ":", channel_thresholds[closest_channel])
    elif event.button == 3:  # Right click to delete threshold
        channel_y_coords.pop(closest_channel, None)
        channel_thresholds.pop(closest_channel, None)
        print("Threshold value for Channel", closest_channel, "deleted")

    detect_suprathreshold_peaks_for_target_channel()

    # visualize_eeg_data also redraws the stored threshold lines, so no extra
    # redraw_threshold_lines() call is needed (it only duplicated them).
    visualize_eeg_data(channels_to_plot_text.value,
                       int(target_channel_text.value),
                       window_duration_text.value,
                       apply_hamming=apply_hamming_checkbox.value,
                       apply_double_hamming=apply_double_hamming_checkbox.value,
                       apply_bandpass=apply_bandpass_checkbox.value,
                       lowcut=lowcut_text.value or None,
                       highcut=highcut_text.value or None,
                       scale_amplitude=scale_amplitude)

    # Highlight a newly set target-channel threshold (not after a deletion).
    if event.button == 1 and closest_channel == int(target_channel_text.value):
        plt.axhline(y=y_coord, linestyle='--', color='green')

    plt.draw()

# --- Export button and interactive panel for the threshold-detection view ---
export_ev2_button = widgets.Button(description="Export to EV2")
export_ev2_button.on_click(save_ev2_files)  # one .ev2 file per thresholded channel
display(export_ev2_button)

# Define interactive components
window_duration_text = widgets.Text(value='10', description='Window Duration')
channels_to_plot_text = widgets.Text(value=select_chan, description='Channels to Plot')
target_channel_text = widgets.Text(value=target_chan, description='Target Channel')
apply_hamming_checkbox = widgets.Checkbox(value=True, description='Apply Hamming')
apply_double_hamming_checkbox = widgets.Checkbox(value=False, description='Apply Double Hamming')
apply_bandpass_checkbox = widgets.Checkbox(value=True, description='Apply Bandpass')
lowcut_text = widgets.Text(value='1', placeholder='Lowcut (Hz)', description='Lowcut:')
highcut_text = widgets.Text(value='30', placeholder='Highcut (Hz)', description='Highcut:')

# Global variable to store thresholds for each channel
# NOTE(review): `thresholds` stays empty in the visible code, so `**thresholds`
# below adds nothing; per-channel thresholds actually live in `channel_thresholds`.
thresholds = {}

# Build the widget panel; interact() calls visualize_eeg_data with the widget values.
# NOTE(review): `lowcut_text or None` / `highcut_text or None` presumably always
# evaluate to the widget itself (widgets are truthy), so `or None` is a no-op.
# fixed(scale_amplitude) freezes the value current at this point, so widget-driven
# redraws do not see later up/down amplitude changes made through on_key.
interact(visualize_eeg_data,
         channels_to_plot_str=channels_to_plot_text,
         target_channel=target_channel_text,
         window_duration_str=window_duration_text,
         apply_hamming=apply_hamming_checkbox,
         apply_double_hamming=apply_double_hamming_checkbox,
         apply_bandpass=apply_bandpass_checkbox,
         lowcut=lowcut_text or None,
         highcut=highcut_text or None,
         scale_amplitude=fixed(scale_amplitude),
         **thresholds)

# These two widgets are created after interact(); they are only read later by
# the keyboard/mouse handlers (on_key, onclick).
scroll_fast_checkbox = widgets.Checkbox(
    value=False,  # Unchecked by default
    description='Fast Scroll',
    disabled=False,
    indent=False
)

peak_mode_dropdown = widgets.Dropdown(
    options=['Positive', 'Negative'],
    value='Positive',
    description='Peak Mode'
)

# Keyboard scrolling/scaling and click-to-set-threshold on the current figure
plt.gcf().canvas.mpl_connect('key_press_event', on_key)
plt.gcf().canvas.mpl_connect('button_press_event', onclick)  # Connect onclick event
scale_amplitude = 1.0
# NOTE(review): on_lowcut_change / on_highcut_change are not defined in this chunk.
lowcut_text.observe(on_lowcut_change, names='value')
highcut_text.observe(on_highcut_change, names='value')

# save_snapshot resolves to whatever definition exists when this cell runs;
# a later cell redefines it.
snapshot_button = widgets.Button(description='Take a Snapshot', button_style='info')
snapshot_button.on_click(save_snapshot)
display(peak_mode_dropdown)
display(snapshot_button)
display(scroll_fast_checkbox)

# NOTE(review): this rebinds `reset_button`, which an earlier cell used for the
# 'Reset Noisy Intervals' button.
reset_button = widgets.Button(description="Reset Detection")
reset_button.on_click(lambda b: reset_detection_variables())
display(reset_button) 

Enter the channels to plot (comma-separated or interval):  4-10
Select the target channel:  6


Button(description='Export to EV2', style=ButtonStyle())

interactive(children=(Text(value='4-10', description='Channels to Plot'), Text(value='6', description='Target …

Dropdown(description='Peak Mode', options=('Positive', 'Negative'), value='Positive')

Button(button_style='info', description='Take a Snapshot', style=ButtonStyle())

Checkbox(value=False, description='Fast Scroll', indent=False)

Button(description='Reset Detection', style=ButtonStyle())

In [10]:
# Reuse peaks collected by earlier cells when they exist; otherwise start empty.
real_peaks = globals().get('real_peaks', [])

# Initialize global variables
start_time = 0          # seconds; left edge of the review window
scale_amplitude = 10.0  # default vertical gain for this review view

def read_ev2_file(file_path):
    """Parse an ``.ev2`` file and return its SampleTime column divided by 10.

    The first line is treated as a header and skipped. Only lines with exactly
    six whitespace-separated fields are used. Values are returned as floats,
    in file order. ``save_ev2_file`` in this notebook writes sample index × 10,
    so this recovers sample indices.
    """
    with open(file_path, 'r') as handle:
        body = handle.readlines()[1:]  # Skip the header line

    events = [int(fields[5]) / 10  # Divide by 10
              for fields in (line.split() for line in body)
              if len(fields) == 6]
    print(f"Read {len(events)} events from file.")
    return events

def import_ev2_file(_):
    """Button callback: load peaks from the ``.ev2`` path typed in ``file_path_text``.

    Surrounding spaces and quotes are stripped from the path, which helps
    with pasted paths. On success, ``real_peaks`` is replaced by the file's
    events, truncated to int sample indices, de-duplicated and sorted, and
    the review view is redrawn. If the path is not an existing regular file,
    a message is printed instead. This includes an empty box, which
    ``os.path.normpath`` turns into '.' and ``os.path.exists`` would accept.
    """
    global real_peaks

    file_path = file_path_text.value.strip(' "\'')
    file_path = os.path.normpath(file_path)  # Normalize the file path

    if not os.path.isfile(file_path):
        print("File not found. Please check the path and try again.")
        return

    events = read_ev2_file(file_path)
    # read_ev2_file already divided SampleTime by 10; truncate to sample indices,
    # remove duplicates and sort.
    real_peaks = sorted({int(event) for event in events})

    print(f"Imported {len(events)} events from {file_path}")
    print(f"Updated real_peaks: {real_peaks}")

    # Re-visualize the EEG data with the new events
    visualize_eeg_data(
        window_duration_text.value,
        hidden_channels_text.value,
        apply_hamming_checkbox.value,
        apply_double_hamming_checkbox.value,
        apply_bandpass_checkbox.value,
        lowcut_text.value,
        highcut_text.value,
        scale_amplitude
    )

def visualize_eeg_data(window_duration_str, hidden_channels_str, apply_hamming=False, apply_double_hamming=False, apply_bandpass=False, lowcut=1, highcut=30, scale_amplitude=10.0):
    """Plot every non-hidden channel for the current window and mark ``real_peaks``.

    NOTE: this review-cell definition shadows the earlier
    ``visualize_eeg_data(channels_to_plot_str, target_channel, ...)``. Handlers
    that still call it by name with the old argument order now reach this one.

    Parameters
    ----------
    window_duration_str : str
        Window length in seconds, parsed with ``float`` (a blank value raises ValueError).
    hidden_channels_str : str
        Comma-separated 1-based channel numbers to omit; blank shows all channels.
    apply_hamming, apply_double_hamming, apply_bandpass : bool
        Optional filters applied to the visible segment before plotting.
    lowcut, highcut : float, str or None
        Band edges in Hz; blank/None disables that edge.
    scale_amplitude : float
        Multiplier applied to the segment before plotting.

    Every peak in ``real_peaks`` (sample indices) that falls inside the window
    is drawn as a dashed vertical line with a ±150 ms shaded span.
    """
    plt.clf()  

    global resamp_data, fs_resampled, num_channels, window_duration, start_time, real_peaks

    window_duration = float(window_duration_str)
    # Window bounds in samples
    window_start = int(start_time * fs_resampled)
    window_end = window_start + int(window_duration * fs_resampled)

    # Clip the window to the end of the recording
    data_length = resamp_data.shape[1]
    if window_end > data_length:
        window_end = data_length

    # NOTE(review): an empty item such as a trailing comma ("3,") makes int('') raise ValueError.
    hidden_channels = []
    if hidden_channels_str.strip() and hidden_channels_str != '':
        hidden_channels = [int(ch) for ch in hidden_channels_str.split(",")]

    # 1-based numbers of the channels that stay visible
    selected_channels = np.array([channel for channel in range(1, num_channels + 1) if channel not in hidden_channels])

    num_visible_channels = len(selected_channels)
    span = 5000  # vertical offset between stacked traces (plot units)

    if num_visible_channels > 0:
        try:
            lowcut = float(lowcut) if lowcut not in [None, ''] else None
            highcut = float(highcut) if highcut not in [None, ''] else None
        except ValueError:
            print("Invalid input for lowcut or highcut. Please enter valid numerical values.")
            return

        nyquist_freq = fs_resampled / 2
        if highcut is not None and highcut > nyquist_freq:
            print(f"Highcut frequency {highcut} exceeds Nyquist frequency {nyquist_freq}.")
            # NOTE(review): resamp_data_high is presumably a higher-rate copy of the data — confirm.
            data_to_visualize = resamp_data_high[selected_channels - 1, window_start:window_end] * scale_amplitude
        else:
            data_to_visualize = resamp_data[selected_channels - 1, window_start:window_end] * scale_amplitude

        if apply_hamming or apply_double_hamming:
            data_to_visualize = apply_hamming_filter(data_to_visualize, fs_resampled, double=apply_double_hamming)

        if apply_bandpass:
            try:
                if lowcut is not None or highcut is not None:
                    data_to_visualize = butter_bandpass_filter(data_to_visualize, lowcut, highcut, fs_resampled)
            except ValueError:
                # NOTE(review): despite the message, nothing is switched here; the unfiltered segment is plotted.
                print("Switching from 2 kHz downsampled data to 5 kHz.")

        for i, channel_data in enumerate(data_to_visualize):
            time = np.linspace(window_start / fs_resampled, window_end / fs_resampled, len(channel_data))
            # First visible channel is drawn at the top
            plt.plot(time - start_time, channel_data + (num_visible_channels - 1 - i) * span, color='#335c67', alpha=0.7, linewidth=0.5)  
            
            # NOTE(review): this peak loop is nested inside the per-channel loop, so each
            # peak's span/line is drawn once per visible channel and the alpha=0.01 spans
            # stack up. Hoisting it out would visibly change the shading.
            for peak_index in list(real_peaks):
                peak_time = peak_index / fs_resampled
                if peak_time >= start_time and peak_time < start_time + window_duration:
                    # ±150 ms span around the peak, clipped to the visible window
                    peak_start = max(peak_time - 0.150, start_time)
                    peak_end = min(peak_time + 0.150, start_time + window_duration)
                    plt.axvspan(peak_start - start_time, peak_end - start_time, color='#540b0e', alpha=0.01)
                    plt.axvline(peak_time - start_time, color='k', linestyle='--', linewidth=0.5, alpha=0.3)

        plt.ylim(-span, num_visible_channels * span)  
        plt.yticks(np.arange(0, num_visible_channels) * span, reversed(selected_channels))  
    else:
        plt.ylim(0, 1)  

    # X axis is drawn relative to the window start but labelled in absolute seconds
    plt.xticks(np.linspace(0, window_duration, 11), np.round(np.linspace(start_time, start_time + window_duration, 11), 2))  
    plt.xlabel('Time (s)')
    plt.ylabel('Channel')
    plt.grid(True, axis='x', linestyle='--', color='#e09f3e', alpha=0.5)
    plt.tight_layout()
    plt.gca().yaxis.set_ticks_position('both')
    plt.tick_params(axis='y', labelleft=True, labelright=True)
    plt.subplots_adjust(left=0.05, right=0.95)
    plt.show()

snapshot_counter = 1  # number used in the next snapshot filename

def save_snapshot(_):
    """Button callback: save the current figure as a numbered SVG next to the recording.

    Files are named ``<cleaned_path minus its last 4 chars>_overall_window<N>.svg``.
    ``N`` only advances after a successful save, so a failed save does not skip
    a number. When no recording path is set, a message is printed and nothing
    is written.
    """
    global snapshot_counter
    if not cleaned_path:
        print("No recording path is set (cleaned_path); snapshot not saved.")
        return
    snap_path = cleaned_path[:-4] + f'_overall_window{snapshot_counter}.svg'
    plt.savefig(snap_path, bbox_inches='tight')
    snapshot_counter += 1

def on_key(event):
    """Left/right arrow navigation for the peak-review view.

    Scrolls by one window (Fast Scroll) or by 20 % of it, clamped to
    ``[0, t - window_duration]`` (``t`` is presumably the recording length in
    seconds — confirm). The view is then redrawn using this cell's
    ``visualize_eeg_data(window_duration_str, hidden_channels_str, ...)``
    signature. The previous call used the detection cell's argument order,
    which made ``float('<channel list>')`` fail. Other keys are ignored here;
    up/down are handled by ``on_arrow_key``.
    """
    global start_time
    if event.key not in ('left', 'right'):
        return

    step = window_duration if scroll_fast_checkbox.value else round(window_duration * 0.2, 2)
    if event.key == 'left':
        start_time = max(0, start_time - step)
    else:
        # max(0, ...) guards recordings shorter than one window.
        start_time = max(0, min(t - window_duration, start_time + step))

    visualize_eeg_data(
        window_duration_text.value,
        hidden_channels_text.value,
        apply_hamming_checkbox.value,
        apply_double_hamming_checkbox.value,
        apply_bandpass_checkbox.value,
        lowcut_text.value,
        highcut_text.value,
        scale_amplitude
    )
    plt.draw()

# --- Widgets for the peak-review view ---
# These rebind names the detection cell also used; handlers that read these names
# by global lookup now see the new widgets.
window_duration_text = widgets.Text(value='10', description='Window Duration')
hidden_channels_text = widgets.Text(value='', description='Hidden Channels')
apply_hamming_checkbox = widgets.Checkbox(value=True, description='Apply Hamming')
apply_double_hamming_checkbox = widgets.Checkbox(value=False, description='Apply Double Hamming')
apply_bandpass_checkbox = widgets.Checkbox(value=True, description='Apply Bandpass')
lowcut_text = widgets.Text(value='1', placeholder='Lowcut (Hz)', description='Lowcut:')
highcut_text = widgets.Text(value='30', placeholder='Highcut (Hz)', description='Highcut:')

# Left/right scrolling on the current figure
plt.gcf().canvas.mpl_connect('key_press_event', on_key)

scale_amplitude = 10.0  # review-view default gain (up/down keys rescale it)

def on_arrow_key(event):
    """Up/down arrow handler: rescale the trace amplitude by 1.5x and redraw.

    Other keys are ignored. Previously every key press, including the
    left/right keys that ``on_key`` already handles, triggered an extra full
    redraw.
    """
    global scale_amplitude
    if event.key == 'up':
        scale_amplitude *= 1.5
    elif event.key == 'down':
        scale_amplitude /= 1.5
    else:
        return
    visualize_eeg_data(
        window_duration_text.value,
        hidden_channels_text.value,
        apply_hamming_checkbox.value,
        apply_double_hamming_checkbox.value,
        apply_bandpass_checkbox.value,
        lowcut_text.value,
        highcut_text.value,
        scale_amplitude
    )
    plt.draw()

# Second key handler on the same figure, for up/down amplitude scaling; left/right scrolling is on_key.
plt.gcf().canvas.mpl_connect('key_press_event', on_arrow_key)

def remove_duplicates():
    """De-duplicate ``real_peaks`` and leave it in ascending order.

    ``list(set(...))`` alone left the order arbitrary. The other call sites in
    this notebook that de-duplicate peaks (``on_click``, ``import_ev2_file``)
    also sort, so this now does the same.
    """
    global real_peaks
    real_peaks = sorted(set(real_peaks))

def save_ev2_file(_):
    """Button callback: write ``real_peaks`` to ``<cleaned_path minus its last 4 chars>_spa_peaks.ev2``.

    Peaks are de-duplicated and written in ascending order, numbered from 1,
    with SampleTime = sample index * 10. ``read_ev2_file`` divides by 10 when
    importing. When no recording path is set, a message is printed and nothing
    is written.
    """
    if not cleaned_path:
        print("No recording path is set (cleaned_path); nothing was saved.")
        return

    output_file = cleaned_path[:-4] + '_spa_peaks.ev2'

    with open(output_file, 'w') as f:
        f.write("SPA Type Zeros1 Zeros2 Zeros3 SampleTime\n")
        # Sorted output keeps item numbers in time order even if real_peaks was not sorted.
        for item_number, peak in enumerate(sorted(set(real_peaks)), start=1):
            f.write(f"{item_number} 0 0 0 0 {peak * 10}\n")

def _redraw_eeg_view():
    """Re-render the EEG window from the current widget values and amplitude scale."""
    visualize_eeg_data(
        window_duration_text.value,
        hidden_channels_text.value,
        apply_hamming_checkbox.value,
        apply_double_hamming_checkbox.value,
        apply_bandpass_checkbox.value,
        lowcut_text.value,
        highcut_text.value,
        scale_amplitude
    )
    plt.draw()


def on_click(event):
    """Add (left click) or remove (right click) a peak marker, then redraw.

    Connected to the figure's 'button_press_event'. ``event.xdata`` is
    seconds relative to the current window start ``start_time``.
    ``real_peaks`` holds sample indices at ``fs_resampled``. Clicks outside
    the data axes, where matplotlib reports ``xdata`` as None, are ignored.

    Right click
        Deletes the stored peak nearest to the clicked time, if it lies within
        150 ms. It previously deleted whichever peak came first in the list.
    Left click
        Searches +/-150 ms around the click on the 1-based channel in
        ``target_chan_text`` of ``resamp_data``. It stores the sample of the
        maximum ('Positive'), minimum ('Negative') or largest absolute value
        (any other ``peak_mode_dropdown`` value). A non-numeric or
        non-existent channel, or a window lying outside the recording, is
        ignored.
    """
    global real_peaks
    if event.xdata is None:
        return  # click landed outside the data axes

    click_time = event.xdata + start_time
    half_window_s = 0.150  # match/search tolerance on either side of the click

    if event.button == 3:  # Right-click to remove a peak
        if len(real_peaks) > 0:
            peak_times = np.asarray(real_peaks) / fs_resampled
            nearest = int(np.argmin(np.abs(peak_times - click_time)))
            if abs(peak_times[nearest] - click_time) < half_window_s:
                del real_peaks[nearest]
        _redraw_eeg_view()

    elif event.button == 1:  # Left-click to add a peak
        try:
            target_channel = int(target_chan_text.value)
        except ValueError:
            return
        if not 1 <= target_channel <= resamp_data.shape[0]:
            return  # no such channel; also stops 0/negatives wrapping to the last rows

        click_sample = int(click_time * fs_resampled)
        half_window_samples = int(half_window_s * fs_resampled)
        window_start = max(click_sample - half_window_samples, 0)
        window_end = min(click_sample + half_window_samples, resamp_data.shape[1] - 1)
        if window_end <= window_start:
            # Click before t=0 or past the end: a negative window_end would
            # otherwise slice (almost) the whole recording.
            return
        window_data = resamp_data[target_channel - 1, window_start:window_end]

        mode = peak_mode_dropdown.value
        if mode == 'Positive':
            offset = np.argmax(window_data)
        elif mode == 'Negative':
            offset = np.argmin(window_data)
        else:
            offset = np.argmax(np.abs(window_data))

        real_peaks.append(window_start + int(offset))
        real_peaks = sorted(set(real_peaks))  # dedupe and keep chronological order
        _redraw_eeg_view()

# --- Peak-editing controls -------------------------------------------------
remove_duplicates_button = widgets.Button(description='Remove Duplicates')
# Button.on_click passes the clicked button to its handler; drop it so that
# remove_duplicates is always called with no arguments.
remove_duplicates_button.on_click(lambda _button: remove_duplicates())

# Target channel (1-based) used when a left click snaps to a peak. Pre-filled from
# `target_chan` when an earlier cell defined it; str() because a Text widget only
# accepts string values.
target_chan_text = widgets.Text(
    value=str(target_chan) if 'target_chan' in globals() else '',
    placeholder='Target Channel',
    description='Target Channel:',
)
display(target_chan_text)

save_ev2_button = widgets.Button(description='Save .ev2 file', button_style='info')
save_ev2_button.on_click(save_ev2_file)

file_path_text = widgets.Text(value='', description='File Path:')
import_ev2_button = widgets.Button(description='Import .ev2 file', button_style='success')
import_ev2_button.on_click(import_ev2_file)

display(file_path_text)
display(import_ev2_button)
display(remove_duplicates_button)
display(save_ev2_button)

# Re-filter when the band edges change; mouse clicks add/remove peaks.
lowcut_text.observe(on_lowcut_change, names='value')
highcut_text.observe(on_highcut_change, names='value')
plt.gcf().canvas.mpl_connect('button_press_event', on_click)
# on_key is already connected to this figure near the top of this cell; connecting
# it a second time made every key press run it twice.

# Created before interact() renders the viewer, so they already exist for anything
# that reads them during that first render (on_click reads peak_mode_dropdown;
# NOTE(review): check whether visualize_eeg_data / on_key read either widget).
scroll_fast_checkbox = widgets.Checkbox(
    value=False,  # Unchecked by default
    description='Scroll Fast',
    disabled=False,
    indent=False
)
peak_mode_dropdown = widgets.Dropdown(
    options=['Positive', 'Negative'],
    value='Positive',
    description='Peak Mode'
)

interact(
    visualize_eeg_data,
    window_duration_str=window_duration_text,
    hidden_channels_str=hidden_channels_text,
    apply_hamming=apply_hamming_checkbox,
    apply_double_hamming=apply_double_hamming_checkbox,
    apply_bandpass=apply_bandpass_checkbox,
    lowcut=lowcut_text,
    highcut=highcut_text,
    scale_amplitude=widgets.FloatSlider(value=scale_amplitude, min=1, max=100, step=1, description='Scale Amplitude')
)
display(peak_mode_dropdown)

snapshot_button = widgets.Button(description='Take a Snapshot', button_style='info')
snapshot_button.on_click(save_snapshot)
display(snapshot_button)
display(scroll_fast_checkbox)

Text(value='6', description='Target Channel:', placeholder='Target Channel')

Text(value='', description='File Path:')

Button(button_style='success', description='Import .ev2 file', style=ButtonStyle())

Button(description='Remove Duplicates', style=ButtonStyle())

Button(button_style='info', description='Save .ev2 file', style=ButtonStyle())

interactive(children=(Text(value='10', description='Window Duration'), Text(value='', description='Hidden Chan…

Dropdown(description='Peak Mode', options=('Positive', 'Negative'), value='Positive')

Button(button_style='info', description='Take a Snapshot', style=ButtonStyle())

Checkbox(value=False, description='Scroll Fast', indent=False)

In [11]:
def save_figure(fig, filename):
    """Write *fig* to *filename* as an SVG vector image."""
    svg_options = {'format': 'svg'}
    fig.savefig(filename, **svg_options)

def plot_heatmap_interactive(corrected_snippets, fs_resampled, save_button):
    """Show the event-averaged LFP as a channel x time heatmap with slider controls.

    Parameters
    ----------
    corrected_snippets : ndarray
        Baseline-corrected snippets indexed (event, channel, sample) in raw
        units. They are averaged over events and scaled to μV with INT16_TO_UV.
    fs_resampled : float
        Sampling rate of the snippets. Currently unused: the time axis is
        labelled -300..300 ms, matching the +/-0.3 s snippet window.
    save_button : ipywidgets.Button
        Clicking it saves the figure as ``<recording>_lfpg_heatmap.svg``.
    """
    # Average over events and convert to microvolts
    heatmap_data = np.mean(corrected_snippets, axis=0) * INT16_TO_UV
    n_channels = heatmap_data.shape[0]
    fig, ax = plt.subplots(figsize=(10, 6))

    # Channel range, 1-based and inclusive on both ends
    channel_slider = widgets.IntRangeSlider(
        value=[1, n_channels],
        min=1,
        max=n_channels,
        step=1,
        description="Channels",
        continuous_update=False,
    )

    # Colour range symmetric around zero. An all-zero map would make the slider
    # step 0, so fall back to +/-1 in that case.
    max_abs_val = float(np.max(np.abs(heatmap_data))) or 1.0
    amplitude_slider = widgets.FloatRangeSlider(
        value=[-max_abs_val, max_abs_val],
        min=-max_abs_val,
        max=max_abs_val,
        step=(2 * max_abs_val) / 100,
        description="Amplitude (μV)",
        continuous_update=False,
    )
    print("Heatmap min/max:", heatmap_data.min(), heatmap_data.max())

    colorbar_holder = {"colorbar": None}

    def update_plot(channel_range, amplitude_range):
        """Redraw the heatmap for the chosen channel rows and colour limits."""
        first, last = channel_range
        vmin, vmax = amplitude_range

        ax.clear()
        im = ax.imshow(
            heatmap_data[first - 1:last, :],
            aspect="auto",
            cmap="jet_r",
            # +/-0.5 centres each row on its channel number and keeps a
            # single-channel selection from collapsing to zero height.
            extent=[-300, 300, last + 0.5, first - 0.5],
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel("Channel")
        ax.set_title("LFPg Heatmap")

        if colorbar_holder["colorbar"] is None:
            colorbar_holder["colorbar"] = fig.colorbar(im, ax=ax, label="Amplitude (μV)")
        else:
            # ax.clear() discarded the image the colorbar was built for; re-attach
            # it to the new image so the bar follows the current limits.
            colorbar_holder["colorbar"].update_normal(im)

        # Redraw this figure specifically; plt.draw() redraws whichever figure is
        # current, which stops being this one once later plots are opened.
        fig.canvas.draw_idle()

    # Link the sliders to update_plot
    widgets.interactive(
        update_plot,
        channel_range=channel_slider,
        amplitude_range=amplitude_slider,
    )

    display(channel_slider, amplitude_slider)
    update_plot(channel_slider.value, amplitude_slider.value)

    def on_save_button_clicked(b):
        save_path = cleaned_path[:-4] + "_lfpg_heatmap.svg"
        save_figure(fig, save_path)
        print(f"Heatmap saved to {save_path}")

    save_button.on_click(on_save_button_clicked)

# Function to plot averaged waveforms
def plot_averaged_waveforms(averaged_waveforms, channels_to_plot, fs_resampled, save_button):
    """Plot event-averaged LFP waveforms, one line per selected channel.

    Parameters
    ----------
    averaged_waveforms : ndarray, shape (n_selected_channels, n_samples)
        Already in μV (process_lfp_data scales by INT16_TO_UV before calling).
    channels_to_plot : list of int
        0-based channel indices matching the rows; shown 1-based in the legend.
    fs_resampled : float
        Sampling rate of the waveforms in Hz, used to build the time axis.
    save_button : ipywidgets.Button
        Clicking it saves ``<recording>_lfpg_averaged_waveforms.svg``.
    """
    n_samples = averaged_waveforms.shape[1]
    # Snippets are cut as data[peak - R : peak + R] with R = n_samples // 2, so
    # sample k is (k - R) / fs seconds from the peak. The old
    # np.linspace(-0.3, 0.3, n) stretched this grid and put the peak sample
    # about half a sample after zero.
    time = (np.arange(n_samples) - n_samples // 2) / fs_resampled

    fig, ax = plt.subplots(figsize=(10, 6))
    for i, waveform in enumerate(averaged_waveforms):
        ax.plot(time, waveform, label=f'Channel {channels_to_plot[i] + 1}')  # 1-based label
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (μV)')
    ax.set_title('Averaged LFP Sweep Waveforms')
    ax.legend()
    plt.show()
    print("Waveforms min/max:", averaged_waveforms.min(), averaged_waveforms.max())

    def on_save_button_clicked(b):
        save_path = cleaned_path[:-4] + "_lfpg_averaged_waveforms.svg"
        save_figure(fig, save_path)
        print(f"Averaged waveforms saved to {save_path}")

    save_button.on_click(on_save_button_clicked)

# Modify process_lfp_data for 1-based indexing and baseline correction
def process_lfp_data(resampled_data, fs_resampled, real_peaks, select_chan):
    """Event-triggered LFP analysis around detected peaks, with plots and CSV export.

    Pipeline:

    1. Band-pass the recording at 1-30 Hz, then apply the Hamming filter.
    2. Cut +/-0.3 s snippets around each peak.
    3. Subtract each snippet's first-200 ms mean.
    4. Plot averaged waveforms and an interactive heatmap.
    5. Write ``<recording>_lfp_max_min_peaks.csv`` and ``<recording>_lfp_analysis.csv``.

    Amplitudes are converted to μV with INT16_TO_UV.

    Parameters
    ----------
    resampled_data : ndarray, shape (n_channels, n_samples)
        Recording sampled at ``fs_resampled``, in raw units.
    fs_resampled : float
        Sampling rate of ``resampled_data`` in Hz.
    real_peaks : sequence of int
        Peak sample indices into ``resampled_data``, in any order; sorted here.
    select_chan : str
        1-based channel selection such as ``"1,2,3"`` or ``"1-3"`` (inclusive).

    Raises
    ------
    ValueError
        If no channel is selected, a selected channel does not exist, or no peak
        has a complete +/-0.3 s window inside the recording.
    """
    # --- Channel selection: user enters 1-based numbers; work 0-based ---
    channels_to_plot = []
    for part in select_chan.strip().split(","):
        part = part.strip()
        if not part:
            continue  # tolerate empty input and stray commas
        if "-" in part:
            start, end = map(int, part.split("-"))
            channels_to_plot.extend(range(start - 1, end))  # inclusive range
        else:
            channels_to_plot.append(int(part) - 1)

    n_channels = resampled_data.shape[0]
    if not channels_to_plot:
        raise ValueError("No channels selected; enter e.g. '1,2,3' or '1-3'.")
    out_of_range = [chan + 1 for chan in channels_to_plot if not 0 <= chan < n_channels]
    if out_of_range:
        # A 0-based index of -1 would otherwise silently select the last channel.
        raise ValueError(f"Channel(s) {out_of_range} outside 1-{n_channels}.")

    # The inter-peak intervals below need chronological order. real_peaks can be
    # unsorted, e.g. after deduplication through a set.
    peaks = np.sort(np.asarray(real_peaks, dtype=np.int64))

    # Step A: filter the data
    filtered_data = butter_bandpass_filter(resampled_data, 1, 30, fs_resampled)
    hamming_filtered_data = apply_hamming_filter(filtered_data, fs_resampled)

    # Step B: +/-0.3 s snippets; peaks too close to either edge are skipped
    snippet_range = int(0.3 * fs_resampled)
    n_samples = hamming_filtered_data.shape[1]
    snippet_list = [
        hamming_filtered_data[:, peak - snippet_range:peak + snippet_range]
        for peak in peaks
        if peak - snippet_range >= 0 and peak + snippet_range < n_samples
    ]
    if not snippet_list:
        raise ValueError("No peak has a complete +/-0.3 s window inside the recording.")
    snippets = np.array(snippet_list)  # (num_events, num_channels, snippet_length)

    # Step C: baseline correction against the first 200 ms of each snippet
    baseline_period = int(0.2 * fs_resampled)
    baselines = snippets[:, :, :baseline_period].mean(axis=2, keepdims=True)
    corrected_snippets = snippets - baselines

    # Step D: recurrence frequency over the recording minus the noisy intervals
    total_noisy_duration = sum(end - start for start, end in noisy_intervals)
    total_file_duration = resampled_data.shape[1] / fs_resampled
    effective_duration = total_file_duration - total_noisy_duration
    recurrence_frequency = len(peaks) / effective_duration if effective_duration > 0 else 0

    # Step E: per-event extremes and averaged waveforms of the selected channels (μV)
    selected_uv = corrected_snippets[:, channels_to_plot, :] * INT16_TO_UV
    max_peaks = selected_uv.max(axis=2)
    min_peaks = selected_uv.min(axis=2)
    averaged_waveforms = selected_uv.mean(axis=0)
    avg_max_peaks = max_peaks.mean(axis=0)
    avg_min_peaks = min_peaks.mean(axis=0)
    averaged_waveform_peaks = np.max(np.abs(averaged_waveforms), axis=1)
    max_channel_idx = channels_to_plot[np.argmax(averaged_waveform_peaks)] + 1  # 1-based

    # Step F: inter-peak intervals. An interval becomes NaN when ANY noisy
    # interval overlaps it, not only one lying strictly between the two peaks.
    peak_times = peaks / fs_resampled
    isi_list = np.diff(peak_times)
    for start, end in noisy_intervals:
        overlaps = (peak_times[:-1] < end) & (peak_times[1:] > start)
        isi_list[overlaps] = np.nan
    isi_mean = np.nanmean(isi_list) if np.any(~np.isnan(isi_list)) else np.nan

    # Step G: plots sharing one save button
    save_button = widgets.Button(description="Save Plots", button_style='success')
    display(save_button)
    plot_averaged_waveforms(averaged_waveforms, channels_to_plot, fs_resampled, save_button)
    plot_heatmap_interactive(corrected_snippets, fs_resampled, save_button)

    # Step H: exports
    output_prefix = cleaned_path[:-4]
    peaks_csv = output_prefix + "_lfp_max_min_peaks.csv"
    analysis_csv = output_prefix + "_lfp_analysis.csv"

    max_min_peaks_data = {
        f"Channel {chan + 1} Max Peaks (μV)": max_peaks[:, i]
        for i, chan in enumerate(channels_to_plot)
    }
    max_min_peaks_data.update({
        f"Channel {chan + 1} Min Peaks (μV)": min_peaks[:, i]
        for i, chan in enumerate(channels_to_plot)
    })
    # ISI column has one fewer row than peaks; concat pads with NaN.
    max_min_peaks_df = pd.concat(
        [pd.DataFrame(max_min_peaks_data), pd.DataFrame({"ISI List (s)": isi_list})],
        axis=1,
    )
    max_min_peaks_df.to_csv(peaks_csv, index=False)

    results = {
        "Recurrence Frequency (Hz)": [recurrence_frequency],
        "ISI Mean (s)": [isi_mean],
    }
    for i, chan in enumerate(channels_to_plot):
        results[f"Avg Max Peak (Channel {chan + 1}) (μV)"] = [avg_max_peaks[i]]
        results[f"Avg Min Peak (Channel {chan + 1}) (μV)"] = [avg_min_peaks[i]]
    results["Channel with Highest Average Peak"] = [max_channel_idx]  # 1-based
    pd.DataFrame.from_dict(results, orient="index").T.to_csv(analysis_csv, index=False)

    print(f"Results saved to '{analysis_csv}' and '{peaks_csv}'.")
    print("corrected_snippets shape:", corrected_snippets.shape)
    print("max_peaks shape:", max_peaks.shape)
    print("min_peaks shape:", min_peaks.shape)

# Ask for the waveform channels only if no earlier cell already set `select_chan`.
if 'select_chan' not in globals():
    select_chan = input("Please enter the channels for which you need the waveforms (e.g., '1,2,3' or '1-3'): ")

# Relies on butter_bandpass_filter and apply_hamming_filter from earlier cells.
process_lfp_data(resampled_data, fs_resampled, real_peaks, select_chan)

Button(button_style='success', description='Save Plots', style=ButtonStyle())

Waveforms min/max: -0.2875447633280092 0.06181717994215024
Heatmap min/max: -0.2875447633280092 0.2570548262997873


IntRangeSlider(value=(1, 23), continuous_update=False, description='Channels', max=23, min=1)

FloatRangeSlider(value=(-0.2875447633280092, 0.2875447633280092), continuous_update=False, description='Amplit…

Results saved to 'lfp_analysis.csv' and 'lfp_max_min_peaks.csv'.
abs_norm_snippets shape: (423, 23, 1200)
max_peaks shape: (423, 7)
min_peaks shape: (423, 7)


2025-11-24 20:35:57.494 python3.13[84592:26391550] The class 'NSSavePanel' overrides the method identifier.  This method is implemented by class 'NSWindow'


In [12]:
def save_figure(fig, filename):
    """Write *fig* to *filename* as an SVG vector image."""
    svg_options = {'format': 'svg'}
    fig.savefig(filename, **svg_options)

def apply_laplacian_filter(data):
    """Second spatial difference across channels (axis 0), used as the CSD estimate.

    Interior rows get ``data[i+1] - 2*data[i] + data[i-1]``. The first and last
    rows lack a neighbour on one side and get one-sided first differences
    (``data[1] - data[0]`` and ``data[-1] - data[-2]``), as before.
    NOTE(review): the boundary rows are therefore a different order of
    derivative than the interior. No sign flip, electrode-spacing or
    conductivity scaling is applied.

    Parameters
    ----------
    data : ndarray
        Channels along axis 0. Presumably ordered by depth along the probe —
        confirm.

    Returns
    -------
    ndarray
        Same shape and dtype as ``data``. With fewer than two channels there is
        no neighbour to difference against and the result is all zeros; the
        previous loop raised IndexError for a single channel.
    """
    filtered_data = np.zeros_like(data)
    if data.shape[0] < 2:
        return filtered_data
    # Vectorised over all interior channels instead of a Python loop.
    filtered_data[1:-1] = data[2:] - 2 * data[1:-1] + data[:-2]
    filtered_data[0] = data[1] - data[0]     # forward difference
    filtered_data[-1] = data[-1] - data[-2]  # backward difference
    return filtered_data

def normalize_heatmap_data(snippets, fs_resampled):
    """Subtract, per event and channel, the mean of the first 200 ms of each snippet.

    ``snippets`` is indexed (event, channel, sample); the baseline spans
    ``int(0.2 * fs_resampled)`` samples.
    """
    n_baseline = int(0.2 * fs_resampled)
    baseline_mean = snippets[:, :, :n_baseline].mean(axis=2, keepdims=True)
    return snippets - baseline_mean

def plot_heatmap_interactive(corrected_snippets, fs_resampled, save_button, heatmap_endpoint):
    """Show the event-averaged CSD as a channel x time heatmap with slider controls.

    Parameters
    ----------
    corrected_snippets : ndarray
        Snippets indexed (event, channel, sample) in raw units. They are
        re-baselined with normalize_heatmap_data, averaged over events and
        scaled to μV with INT16_TO_UV.
    fs_resampled : float
        Sampling rate in Hz, passed to normalize_heatmap_data.
    save_button : ipywidgets.Button
        Clicking it saves the figure as ``<recording>_csd_heatmap.svg``.
    heatmap_endpoint : float
        Accepted for call compatibility; not used. The colour range is taken
        from the averaged map itself.
    """
    # Normalize against the first 200 ms, then average and convert to microvolts
    normalized_data = normalize_heatmap_data(corrected_snippets, fs_resampled)
    heatmap_data = np.mean(normalized_data, axis=0) * INT16_TO_UV
    print("Heatmap min/max:", heatmap_data.min(), heatmap_data.max())
    n_channels = heatmap_data.shape[0]
    fig, ax = plt.subplots(figsize=(10, 6))

    # Channel range, 1-based and inclusive on both ends
    channel_slider = widgets.IntRangeSlider(
        value=[1, n_channels],
        min=1,
        max=n_channels,
        step=1,
        description="Channels",
        continuous_update=False,
    )

    # Symmetric colour range; +/-1 fallback avoids a zero slider step on an all-zero map.
    max_abs_val = float(np.max(np.abs(heatmap_data))) or 1.0
    amplitude_slider = widgets.FloatRangeSlider(
        value=[-max_abs_val, max_abs_val],
        min=-max_abs_val,
        max=max_abs_val,
        step=(2 * max_abs_val) / 100,
        description="Amplitude (μV)",
        continuous_update=False,
    )

    colorbar_holder = {"colorbar": None}

    def update_plot(channel_range, amplitude_range):
        """Redraw the heatmap for the chosen channel rows and colour limits."""
        first, last = channel_range
        vmin, vmax = amplitude_range

        ax.clear()
        im = ax.imshow(
            heatmap_data[first - 1:last, :],
            aspect="auto",
            cmap="jet",
            # +/-0.5 centres each row on its channel number and keeps a
            # single-channel selection from collapsing to zero height.
            extent=[-300, 300, last + 0.5, first - 0.5],
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel("Channel")
        ax.set_title("CSD Heatmap")

        if colorbar_holder["colorbar"] is None:
            colorbar_holder["colorbar"] = fig.colorbar(im, ax=ax, label="Amplitude (μV)")
        else:
            # ax.clear() discarded the image the colorbar was built for; re-attach it.
            colorbar_holder["colorbar"].update_normal(im)

        # Redraw this figure, not whichever one happens to be current.
        fig.canvas.draw_idle()

    # Link the sliders to update_plot
    widgets.interactive(
        update_plot,
        channel_range=channel_slider,
        amplitude_range=amplitude_slider,
    )

    display(channel_slider, amplitude_slider)
    update_plot(channel_slider.value, amplitude_slider.value)

    def on_save_button_clicked(b):
        save_path = cleaned_path[:-4] + "_csd_heatmap.svg"
        save_figure(fig, save_path)
        print(f"Heatmap saved to {save_path}")

    save_button.on_click(on_save_button_clicked)

def plot_averaged_waveforms(averaged_waveforms, channels_to_plot, fs_resampled, save_button):
    """Plot event-averaged CSD waveforms, one line per selected channel.

    Parameters
    ----------
    averaged_waveforms : ndarray, shape (n_selected_channels, n_samples)
        Already in μV (process_csd_data converts before calling).
    channels_to_plot : list of int
        0-based channel indices matching the rows; shown 1-based in the legend.
    fs_resampled : float
        Sampling rate of the waveforms in Hz, used to build the time axis.
    save_button : ipywidgets.Button
        Clicking it saves ``<recording>_csd_averaged_waveforms.svg``.
    """
    n_samples = averaged_waveforms.shape[1]
    # Snippets are data[peak - R : peak + R] with R = n_samples // 2, so sample k
    # sits (k - R) / fs seconds from the peak. The old np.linspace(-0.3, 0.3, n)
    # stretched the grid and shifted the peak by about half a sample.
    time = (np.arange(n_samples) - n_samples // 2) / fs_resampled

    fig, ax = plt.subplots(figsize=(10, 6))
    for i, waveform in enumerate(averaged_waveforms):
        ax.plot(time, waveform, label=f'Channel {channels_to_plot[i] + 1}')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (μV)')
    ax.set_title('Averaged CSD Sweep Waveforms')
    ax.legend()
    plt.show()
    print("Waveforms min/max:", averaged_waveforms.min(), averaged_waveforms.max())

    def on_save_button_clicked(b):
        save_path = cleaned_path[:-4] + "_csd_averaged_waveforms.svg"
        save_figure(fig, save_path)
        print(f"Averaged waveforms saved to {save_path}")

    save_button.on_click(on_save_button_clicked)

def process_csd_data(resampled_data, fs_resampled, real_peaks, select_chan):
    """Event-triggered CSD analysis around detected peaks, with plots and CSV export.

    Pipeline:

    1. Band-pass the recording at 1-30 Hz, then apply the Hamming filter.
    2. Take the second spatial difference across channels (apply_laplacian_filter).
    3. Cut +/-0.3 s snippets around each peak.
    4. Subtract each snippet's first-200 ms mean.
    5. Plot averaged waveforms and a heatmap.
    6. Write ``<recording>_csd_max_min.csv`` and ``<recording>_csd_analysis.csv``.

    Amplitudes are converted to μV with INT16_TO_UV.

    Parameters
    ----------
    resampled_data : ndarray, shape (n_channels, n_samples)
        Recording sampled at ``fs_resampled``, in raw units.
    fs_resampled : float
        Sampling rate of ``resampled_data`` in Hz.
    real_peaks : sequence of int
        Peak sample indices into ``resampled_data``, in any order; sorted here.
    select_chan : str
        1-based channel selection such as ``"1,2,3"`` or ``"1-3"`` (inclusive).

    Raises
    ------
    ValueError
        If no channel is selected, a selected channel does not exist, or no peak
        has a complete +/-0.3 s window inside the recording.
    """
    # --- Channel selection: user enters 1-based numbers; work 0-based ---
    channels_to_plot = []
    for part in select_chan.strip().split(","):
        part = part.strip()
        if not part:
            continue  # tolerate empty input and stray commas
        if "-" in part:
            start, end = map(int, part.split("-"))
            channels_to_plot.extend(range(start - 1, end))  # inclusive range
        else:
            channels_to_plot.append(int(part) - 1)

    n_channels = resampled_data.shape[0]
    if not channels_to_plot:
        raise ValueError("No channels selected; enter e.g. '1,2,3' or '1-3'.")
    out_of_range = [chan + 1 for chan in channels_to_plot if not 0 <= chan < n_channels]
    if out_of_range:
        # A 0-based index of -1 would otherwise silently select the last channel.
        raise ValueError(f"Channel(s) {out_of_range} outside 1-{n_channels}.")

    # Chronological order is required for the inter-peak intervals below.
    peaks = np.sort(np.asarray(real_peaks, dtype=np.int64))

    bandpass_data = butter_bandpass_filter(resampled_data, 1, 30, fs_resampled)
    filtered_data = apply_hamming_filter(bandpass_data, fs_resampled)
    csd_data = apply_laplacian_filter(filtered_data)

    # +/-0.3 s snippets; peaks too close to either edge are skipped
    snippet_range = int(0.3 * fs_resampled)
    n_samples = csd_data.shape[1]
    snippet_list = [
        csd_data[:, peak - snippet_range:peak + snippet_range]
        for peak in peaks
        if peak - snippet_range >= 0 and peak + snippet_range < n_samples
    ]
    if not snippet_list:
        raise ValueError("No peak has a complete +/-0.3 s window inside the recording.")
    snippets = np.array(snippet_list)  # (num_events, num_channels, snippet_length)

    # Baseline correction against the first 200 ms of each snippet
    baseline_period = int(0.2 * fs_resampled)
    baselines = snippets[:, :, :baseline_period].mean(axis=2, keepdims=True)
    corrected_snippets = snippets - baselines

    # Convert to μV for all subsequent analyses/plots/exports
    corrected_snippets_uv = corrected_snippets * INT16_TO_UV
    selected_uv = corrected_snippets_uv[:, channels_to_plot, :]
    averaged_waveforms = selected_uv.mean(axis=0)

    save_button = widgets.Button(description="Save Plots", button_style='success')
    display(save_button)
    plot_averaged_waveforms(averaged_waveforms, channels_to_plot, fs_resampled, save_button)

    # Recurrence frequency over the recording minus the noisy intervals
    total_noisy_duration = sum(end - start for start, end in noisy_intervals)
    total_file_duration = resampled_data.shape[1] / fs_resampled
    effective_duration = total_file_duration - total_noisy_duration
    recurrence_frequency = len(peaks) / effective_duration if effective_duration > 0 else 0

    max_peaks = selected_uv.max(axis=2)
    min_peaks = selected_uv.min(axis=2)
    # Largest |extreme| among the selected channels; handed to the heatmap.
    heatmap_endpoint = max(np.abs(max_peaks).max(), np.abs(min_peaks).max())

    avg_max_peaks = max_peaks.mean(axis=0)
    avg_min_peaks = min_peaks.mean(axis=0)
    averaged_waveform_peaks = np.max(np.abs(averaged_waveforms), axis=1)
    max_channel_idx = channels_to_plot[np.argmax(averaged_waveform_peaks)] + 1  # 1-based

    # Inter-peak intervals; NaN when any noisy interval overlaps the interval.
    peak_times = peaks / fs_resampled
    isi_list = np.diff(peak_times)
    for start, end in noisy_intervals:
        overlaps = (peak_times[:-1] < end) & (peak_times[1:] > start)
        isi_list[overlaps] = np.nan
    isi_mean = np.nanmean(isi_list) if np.any(~np.isnan(isi_list)) else np.nan

    plot_heatmap_interactive(corrected_snippets, fs_resampled, save_button, heatmap_endpoint)

    output_prefix = cleaned_path[:-4]
    peaks_csv = output_prefix + "_csd_max_min.csv"
    analysis_csv = output_prefix + "_csd_analysis.csv"

    max_min_peaks_data = {
        f"Channel {chan + 1} Max Peaks (μV)": max_peaks[:, i]
        for i, chan in enumerate(channels_to_plot)
    }
    max_min_peaks_data.update({
        f"Channel {chan + 1} Min Peaks (μV)": min_peaks[:, i]
        for i, chan in enumerate(channels_to_plot)
    })
    # ISI column has one fewer row than peaks; concat pads with NaN.
    max_min_peaks_df = pd.concat(
        [pd.DataFrame(max_min_peaks_data), pd.DataFrame({"ISI List (s)": isi_list})],
        axis=1,
    )
    max_min_peaks_df.to_csv(peaks_csv, index=False)

    results = {
        "Recurrence Frequency (Hz)": [recurrence_frequency],
        "ISI Mean (s)": [isi_mean],
    }
    for i, chan in enumerate(channels_to_plot):
        results[f"Avg Max Peak (Channel {chan + 1}) (μV)"] = [avg_max_peaks[i]]
        results[f"Avg Min Peak (Channel {chan + 1}) (μV)"] = [avg_min_peaks[i]]
    results["Channel with Highest Average Peak"] = [max_channel_idx]
    pd.DataFrame.from_dict(results, orient="index").T.to_csv(analysis_csv, index=False)
    print(f"Results saved to '{analysis_csv}' and '{peaks_csv}'.")

# Needs butter_bandpass_filter / apply_hamming_filter from earlier cells and the
# `select_chan` chosen for the LFP analysis.
process_csd_data(
    resampled_data=resampled_data,
    fs_resampled=fs_resampled,
    real_peaks=real_peaks,
    select_chan=select_chan,
)

Button(button_style='success', description='Save Plots', style=ButtonStyle())

Waveforms min/max: -0.06174572981701603 0.06025774374594331
Heatmap min/max: -0.4981603936533857 0.5139616195482103


IntRangeSlider(value=(1, 23), continuous_update=False, description='Channels', max=23, min=1)

FloatRangeSlider(value=(-0.5139616195482103, 0.5139616195482103), continuous_update=False, description='Amplit…

Results saved to 'csd_analysis.csv' and 'csd_max_min.csv'.


In [13]:
def save_figure(fig, filename):
    """Write *fig* to *filename* as an SVG vector image."""
    svg_options = {'format': 'svg'}
    fig.savefig(filename, **svg_options)

# Peak indices multiplied by 10. NOTE(review): presumably maps resampled-rate sample
# indices back to the sampling rate used for the MUA snippets below — confirm.
defacto_peaks = [10 * peak_index for peak_index in real_peaks]

def butter_highpass_filter(data, cutoff, fs, order=4, axis=1):
    """Zero-phase Butterworth high-pass filter.

    The filter is designed as second-order sections and applied forward and
    backward with ``sosfiltfilt``. SciPy recommends the SOS form over (b, a)
    coefficients with ``filtfilt`` for numerical robustness. For the orders
    used here both forms give the same result to rounding error.

    Parameters
    ----------
    data : ndarray
        Signal(s) to filter.
    cutoff : float
        Cutoff frequency in Hz; ``butter`` requires 0 < cutoff < fs / 2.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        Butterworth order of one pass. The forward-backward application
        squares the magnitude response.
    axis : int, optional
        Time axis of ``data``. Defaults to 1 (channels x samples), as before.

    Returns
    -------
    ndarray
        Filtered data with the same shape as ``data``.
    """
    nyquist = 0.5 * fs
    sos = butter(order, cutoff / nyquist, btype='high', output='sos')
    return sosfiltfilt(sos, data, axis=axis)

def normalize_heatmap_data(snippets, fs):
    """Subtract, per event and channel, the mean over the first third of the samples.

    ``snippets`` is indexed (event, channel, sample). ``fs`` is accepted for
    signature compatibility and not used.
    """
    first_third = snippets.shape[-1] // 3
    return snippets - snippets[:, :, :first_third].mean(axis=2, keepdims=True)

def smooth_data_with_kernel(data, kernel_size, stride):
    """Moving average along the last axis, keeping every ``stride``-th value.

    Gives the same result as applying
    ``np.convolve(trace, np.ones(k) / k, mode='valid')[::stride]`` to every
    trace (the former ``np.apply_along_axis`` form). The work is now done in
    one vectorised step, and only the windows that are kept are averaged.

    Parameters
    ----------
    data : array_like
        Traces along the last axis.
    kernel_size : int
        Averaging window length in samples (>= 1).
    stride : int
        Decimation step applied to the smoothed trace (>= 1).

    Returns
    -------
    ndarray
        float64 array of shape
        ``data.shape[:-1] + (ceil((n - kernel_size + 1) / stride),)``.

    Raises
    ------
    ValueError
        If ``kernel_size`` or ``stride`` is not positive, or the traces are
        shorter than ``kernel_size``. In that last case ``np.convolve`` would
        silently swap its operands and return meaningless values.
    """
    if kernel_size < 1 or stride < 1:
        raise ValueError("kernel_size and stride must be positive integers")
    data = np.asarray(data, dtype=float)
    # View of shape (..., n - kernel_size + 1, kernel_size); no copy is made.
    windows = np.lib.stride_tricks.sliding_window_view(data, kernel_size, axis=-1)
    return windows[..., ::stride, :].mean(axis=-1)

def plot_heatmap_interactive(corrected_snippets, fs, save_button, heatmap_endpoint):
    """Show the event-averaged MUA as a channel x time heatmap with slider controls.

    The colour scale is always symmetric, -limit..+limit, where ``limit`` is
    the magnitude of the amplitude slider's upper handle. The lower handle has
    no effect.

    Parameters
    ----------
    corrected_snippets : ndarray
        Traces indexed (event, channel, sample) in raw units. They are
        re-baselined with normalize_heatmap_data, averaged over events and
        scaled to μV with INT16_TO_UV.
    fs : float
        Passed through to normalize_heatmap_data.
    save_button : ipywidgets.Button
        Clicking it saves the figure as ``<recording>_mua_heatmap.svg``.
    heatmap_endpoint : float
        Accepted for call compatibility; not used.
    """
    normalized_data = normalize_heatmap_data(corrected_snippets, fs)
    # Convert to microvolts
    heatmap_data = np.mean(normalized_data, axis=0) * INT16_TO_UV
    n_channels = heatmap_data.shape[0]
    fig, ax = plt.subplots(figsize=(10, 6))

    channel_slider = widgets.IntRangeSlider(
        value=[1, n_channels],
        min=1,
        max=n_channels,
        step=1,
        description="Channels",
        continuous_update=False,
    )

    # +/-1 fallback avoids a zero slider step on an all-zero map.
    max_abs_val = float(np.max(np.abs(heatmap_data))) or 1.0
    amplitude_slider = widgets.FloatRangeSlider(
        value=[-max_abs_val, max_abs_val],
        min=-max_abs_val,
        max=max_abs_val,
        step=(2 * max_abs_val) / 100,
        description="Amplitude (μV)",
        continuous_update=False,
    )

    colorbar_holder = {"colorbar": None}

    def update_plot(channel_range, amplitude_range):
        """Redraw the heatmap for the chosen channel rows and symmetric colour limit."""
        first, last = channel_range
        # abs(): an upper handle dragged below zero would otherwise give vmin > vmax.
        limit = abs(amplitude_range[1])

        ax.clear()
        im = ax.imshow(
            heatmap_data[first - 1:last, :],
            aspect="auto",
            cmap="jet_r",
            # +/-0.5 centres each row on its channel number and keeps a
            # single-channel selection from collapsing to zero height.
            extent=[-300, 300, last + 0.5, first - 0.5],
            vmin=-limit,
            vmax=limit,
        )
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel("Channel")
        ax.set_title("MUA Heatmap")

        if colorbar_holder["colorbar"] is None:
            colorbar_holder["colorbar"] = fig.colorbar(im, ax=ax, label="Amplitude (μV)")
        else:
            # ax.clear() discarded the image the colorbar was built for; re-attach it.
            colorbar_holder["colorbar"].update_normal(im)
        # Same ticks on the first and every later draw.
        colorbar_holder["colorbar"].set_ticks([-limit, 0, limit])

        # Redraw this figure, not whichever one happens to be current.
        fig.canvas.draw_idle()

    # Link the sliders to update_plot
    widgets.interactive(
        update_plot,
        channel_range=channel_slider,
        amplitude_range=amplitude_slider,
    )

    display(channel_slider, amplitude_slider)
    update_plot(channel_slider.value, amplitude_slider.value)

    def on_save_button_clicked(b):
        save_path = cleaned_path[:-4] + "_mua_heatmap.svg"
        save_figure(fig, save_path)
        print(f"Heatmap saved to {save_path}")

    save_button.on_click(on_save_button_clicked)

def plot_averaged_waveforms(averaged_waveforms, channels_to_plot, fs, save_button):
    """Plot the event-averaged, rectified and smoothed MUA of each selected channel.

    Row ``r`` of ``averaged_waveforms`` belongs to channel
    ``channels_to_plot[r]`` (0-based, shown 1-based in the legend). Traces are
    drawn on a -0.3..0.3 s axis spread evenly over their samples; ``fs`` is
    not used. Clicking ``save_button`` writes
    ``<recording>_mua_averaged_waveforms.svg``.
    """
    n_points = averaged_waveforms.shape[1]
    t_axis = np.linspace(-0.3, 0.3, n_points)
    fig, ax = plt.subplots(figsize=(10, 6))
    for row, trace in enumerate(averaged_waveforms):
        ax.plot(t_axis, trace, label=f'Channel {channels_to_plot[row] + 1}')
    ax.set(xlabel='Time (s)', ylabel='Amplitude (μV)', title='Averaged MUA Sweep Waveforms')
    ax.legend()
    plt.show()

    def on_save_button_clicked(_button):
        target_file = cleaned_path[:-4] + "_mua_averaged_waveforms.svg"
        save_figure(fig, target_file)
        print(f"Averaged waveforms saved to {target_file}")

    save_button.on_click(on_save_button_clicked)

def plot_highest_event_all_channels(snippets, fs, channels_to_plot, target_chan_idx, save_button):
    """Plot all selected channels for the event that peaks highest on the target channel.

    The chosen event is the one whose largest absolute value on
    ``target_chan_idx`` (0-based) is greatest. Its traces for
    ``channels_to_plot`` (0-based, labelled 1-based) are scaled with
    INT16_TO_UV and drawn on a -0.3..0.3 s axis spread evenly over the
    samples; ``fs`` is not used. Clicking ``save_button`` writes
    ``<recording>_mua_highest_event_all_channels.svg``.
    """
    peak_per_event = np.max(np.abs(snippets[:, target_chan_idx, :]), axis=1)
    best_event = np.argmax(peak_per_event)
    event_traces = snippets[best_event, channels_to_plot, :] * INT16_TO_UV

    t_axis = np.linspace(-0.3, 0.3, event_traces.shape[1])
    fig, ax = plt.subplots(figsize=(10, 6))
    for row, chan_idx in enumerate(channels_to_plot):
        ax.plot(t_axis, event_traces[row], label=f'Channel {chan_idx + 1}')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude (μV)')
    ax.set_title(f'Waveforms of Highest Event on Target Channel {target_chan_idx + 1} (Event {best_event + 1})')
    ax.legend()
    plt.show()

    def on_save_button_clicked(_button):
        target_file = cleaned_path[:-4] + "_mua_highest_event_all_channels.svg"
        save_figure(fig, target_file)
        print(f"Highest event waveforms saved to {target_file}")

    save_button.on_click(on_save_button_clicked)

def _parse_channel_spec(spec):
    """Parse a 1-based channel string such as ``"1,3,5-8"`` into 0-based indices.

    Ranges ``a-b`` are inclusive on both ends. Empty segments (e.g. from a
    trailing comma or an empty string) are ignored. Non-integer tokens raise
    ValueError from ``int``.
    """
    channels = []
    for part in spec.strip().split(","):
        part = part.strip()
        if not part:
            continue
        if "-" in part:
            start, end = map(int, part.split("-"))
            channels.extend(range(start - 1, end))
        else:
            channels.append(int(part) - 1)
    return channels


def _extract_snippets(data, peaks, half_window):
    """Cut a ``[peak - half_window, peak + half_window)`` window around each peak.

    Assumes ``data`` is indexed (channel, sample). Peaks whose window would run
    past either edge of the recording are skipped. Returns ``np.array`` of the
    collected windows (event, channel, sample), or an empty array if none fit.
    """
    n_samples = data.shape[1]
    snippets = [
        data[:, peak - half_window:peak + half_window]
        for peak in peaks
        # The slice end is exclusive, so a window ending exactly at n_samples is complete.
        if peak - half_window >= 0 and peak + half_window <= n_samples
    ]
    return np.array(snippets)


def _compute_isis(peaks, fs, noisy_intervals):
    """Return inter-peak intervals in seconds, NaN where a gap overlaps a noisy interval.

    ``noisy_intervals`` is a sequence of ``(start, end)`` pairs in seconds.
    The result has ``len(peaks) - 1`` entries (empty for fewer than two peaks).
    """
    peak_times = np.asarray(peaks) / fs
    isis = []
    for prev_t, curr_t in zip(peak_times[:-1], peak_times[1:]):
        overlaps_noise = any(prev_t < end and curr_t > start for start, end in noisy_intervals)
        isis.append(np.nan if overlaps_noise else curr_t - prev_t)
    return np.array(isis, dtype=float)


def process_mua_data(resampled_data, fs, real_peaks, select_chan):
    """Run the MUA analysis around detected events, plot it, and write CSV summaries.

    Steps:
    1. High-pass filter the data (cutoff 500, presumably Hz).
    2. Cut ±0.3 s snippets around each peak in ``real_peaks``.
    3. Normalize, smooth, rectify and re-normalize the snippets.
    4. Plot the averaged and largest-event waveforms plus a heatmap.
    5. Save per-event peak values and summary statistics next to ``cleaned_path``.

    Parameters
    ----------
    resampled_data : ndarray
        Recording indexed (channel, sample), sampled at ``fs``.
    fs : float
        Sampling rate of ``resampled_data`` in Hz.
    real_peaks : sequence of int
        Sample indices of detected events.
    select_chan : str
        1-based channel selection, e.g. ``"1,3,5-8"``.

    Notes
    -----
    Still reads these notebook globals:

    - ``target_chan_text`` (target-channel widget)
    - ``noisy_intervals`` (``(start, end)`` pairs in seconds)
    - ``cleaned_path`` (output prefix)
    - ``INT16_TO_UV``
    - the helpers ``butter_highpass_filter``, ``normalize_heatmap_data``,
      ``smooth_data_with_kernel`` and ``plot_heatmap_interactive``

    Raises
    ------
    ValueError
        If no channels are selected or no peak has a complete snippet window.
    """
    channels_to_plot = _parse_channel_spec(select_chan)
    if not channels_to_plot:
        raise ValueError("No channels selected; enter e.g. '1,3,5-8'.")

    filtered_data = butter_highpass_filter(resampled_data, 500, fs)

    # Use the peaks passed in (not a notebook global) so the function is self-consistent.
    snippet_range = int(0.3 * fs)
    snippets = _extract_snippets(filtered_data, real_peaks, snippet_range)
    if len(snippets) == 0:
        raise ValueError("No peaks have a complete ±0.3 s window inside the recording.")

    normalized_snippets = normalize_heatmap_data(snippets, fs)
    smoothed_snippets = smooth_data_with_kernel(normalized_snippets, kernel_size=21, stride=7)
    abs_snippets = np.abs(smoothed_snippets)
    abs_norm_snippets = normalize_heatmap_data(abs_snippets, fs)
    abs_norm_snippets_uv = abs_norm_snippets * INT16_TO_UV

    averaged_waveforms = np.mean(abs_norm_snippets_uv[:, channels_to_plot, :], axis=0)

    save_button = widgets.Button(description="Save Plots", button_style='success')
    display(save_button)

    target_chan_idx = int(target_chan_text.value.strip()) - 1
    plot_averaged_waveforms(averaged_waveforms, channels_to_plot, fs, save_button)
    # plot_highest_event_all_channels applies INT16_TO_UV itself, so pass unconverted data.
    plot_highest_event_all_channels(abs_norm_snippets, fs, channels_to_plot, target_chan_idx, save_button)

    # Event rate over the non-noisy part of the recording that was actually analysed.
    total_noisy_duration = sum(end - start for start, end in noisy_intervals)
    total_file_duration = resampled_data.shape[1] / fs
    effective_duration = total_file_duration - total_noisy_duration
    recurrence_frequency = len(real_peaks) / effective_duration if effective_duration > 0 else 0

    # Per-event extrema. The heatmap colour limit keeps the unconverted scale
    # (plot_heatmap_interactive converts internally); reported values are in μV
    # to match the CSV column labels.
    selected = abs_norm_snippets[:, channels_to_plot, :]
    raw_max_peaks = np.max(selected, axis=2)
    raw_min_peaks = np.min(selected, axis=2)
    heatmap_endpoint = max(np.abs(raw_max_peaks).max(), np.abs(raw_min_peaks).max())
    max_peaks = raw_max_peaks * INT16_TO_UV
    min_peaks = raw_min_peaks * INT16_TO_UV
    avg_max_peaks = np.mean(max_peaks, axis=0)
    avg_min_peaks = np.mean(min_peaks, axis=0)

    averaged_waveform_peaks = np.max(np.abs(averaged_waveforms), axis=1)
    max_channel_idx = channels_to_plot[np.argmax(averaged_waveform_peaks)] + 1

    isi_list = _compute_isis(real_peaks, fs, noisy_intervals)
    isi_mean = np.nanmean(isi_list) if np.any(~np.isnan(isi_list)) else np.nan

    # For the heatmap, pass the unconverted abs_snippets (the conversion is inside the function).
    plot_heatmap_interactive(abs_snippets, fs, save_button, heatmap_endpoint)

    output_prefix = cleaned_path[:-4]
    peaks_path = output_prefix + "_mua_max_min_peaks.csv"
    analysis_path = output_prefix + "_mua_analysis.csv"

    # Column order: all max-peak columns, then all min-peak columns, then ISIs.
    max_min_peaks_data = {}
    for col, chan in enumerate(channels_to_plot):
        max_min_peaks_data[f"Channel {chan + 1} Max Peaks (μV)"] = max_peaks[:, col]
    for col, chan in enumerate(channels_to_plot):
        max_min_peaks_data[f"Channel {chan + 1} Min Peaks (μV)"] = min_peaks[:, col]
    # ISIs have a different length than the per-event columns; concat pads with NaN.
    max_min_peaks_df = pd.concat(
        [pd.DataFrame(max_min_peaks_data), pd.DataFrame({"ISI List (s)": isi_list})],
        axis=1,
    )
    max_min_peaks_df.to_csv(peaks_path, index=False)

    results = {
        "Recurrence Frequency (Hz)": [recurrence_frequency],
        "ISI Mean (s)": [isi_mean],
    }
    for i, chan in enumerate(channels_to_plot):
        results[f"Avg Max Peak (Channel {chan + 1}) (μV)"] = [avg_max_peaks[i]]
        results[f"Avg Min Peak (Channel {chan + 1}) (μV)"] = [avg_min_peaks[i]]
    results["Channel with Highest Average Peak"] = [max_channel_idx]
    results_df = pd.DataFrame.from_dict(results, orient="index").T
    results_df.to_csv(analysis_path, index=False)

    print(f"Results saved to '{analysis_path}' and '{peaks_path}'.")

# Run the MUA analysis on the loaded recording. Expects fs, defacto_peaks and
# select_chan to be defined already. process_mua_data additionally reads these,
# which must also exist:
#   - noisy_intervals, target_chan_text, cleaned_path, INT16_TO_UV
#   - butter_highpass_filter, normalize_heatmap_data, smooth_data_with_kernel
process_mua_data(loaded_data, fs, defacto_peaks, select_chan)

Button(button_style='success', description='Save Plots', style=ButtonStyle())

IntRangeSlider(value=(1, 23), continuous_update=False, description='Channels', max=23, min=1)

FloatRangeSlider(value=(-0.0004299720597861914, 0.0004299720597861914), continuous_update=False, description='…

Results saved to 'mua_analysis.csv' and 'mua_max_min_peaks.csv'.
