# nVital Analysis Notebook

# Full Pipeline

## Setup

### Environment Init

In [None]:
# Set up environment
%matplotlib
%matplotlib

# Imports
import platform
import matplotlib#must be before all other matplotlib imports
#matplotlib.use('Qt5Agg')
import numpy as np
import numpy.matlib
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib.dates as md
from matplotlib.ticker import FuncFormatter
from typing import NamedTuple
from scipy import signal
from datetime import datetime
from datetime import timezone
import sys
import math
from pathlib import Path
import os
from glob import glob
from scipy.signal import savgol_filter, resample
import csv
from datetime import datetime, timedelta
from tkinter import filedialog, Tk
from matplotlib.widgets import Slider
from scipy.interpolate import CubicSpline, make_interp_spline
import re
import psutil


### Settings 

In [None]:

# ---------------------------------------------------------------------------
# Master toggles
# ---------------------------------------------------------------------------
export_to_csv = True    # False disables every CSV export
plot_output = True      # False disables every plot

# ANSI escape sequences used to style console text
BOLD = "\033[1m"
GREEN = "\033[32m"
RED = "\033[31m"
DEFAULT = "\033[0m"     # reset to the terminal's default style

# Acquisition / analysis
sampling_rate = 800     # rate (Hz) the accelerometer data are resampled to (params.fs)
window = 1              # activity window length (seconds)

# Temperature
plot_temperature = True

# Raw acceleration, one flag per axis
plot_accel_x = False
plot_accel_y = False
plot_accel_z = False

# Activity
plot_activity = True

# Cardiac
plot_accel_z_clean = False     # z acceleration with motion artifacts removed
plot_hs = False                # heart sound
plot_hs_sh_raw = False         # Shannon energy envelope of the heart sound
plot_hs_sh_filtered = False    # Shannon energy envelope after the low-pass filter
plot_hs_lt = False             # length transform of the heart sound
plot_HR_peaks = False          # detected heartbeat peaks
plot_HR = True                 # heart rate
plot_SQI = True                # signal quality index

# Respiratory
use_hs_to_calculate_RR = False                       # derive respiratory rate from the heart sound
use_shannon_peaks_only_to_calculate_resp_env = True  # False: use Shannon AND length-transform peaks for the respiratory envelope
plot_resp_sig_prefilt = False                        # respiratory signal before the band-pass filter
overlay_resp_sig_prefilt_w_hs = False                # overlay the pre-filter respiratory signal on hs_sh_filtered
plot_resp_sig = False                                # respiratory signal after the band-pass filter
plot_RR_peaks = False                                # peaks of the respiratory signal
plot_RR = True                                       # respiratory rate

# Debugging plots (not functional yet)
plot_SR = False        # sampling rate of the incoming data
plot_std_xx = False    # standard deviation of the acceleration

# Interactive plotting
scrolling = False      # also produce an interactive scrolling figure
scroll_window = 2      # scrolling window size (seconds)

####################################################################################################################

class paramstruct(NamedTuple):
    """Typed field list of the analysis parameters.

    NOTE(review): in the visible notebook this tuple is never instantiated.
    The code does ``params = paramstruct`` and then sets attributes on the
    class itself, including names not declared here (``vital_ovlp``,
    ``downsample``, ``resp_peak_*``, ``start_time``). The field list therefore
    serves as documentation of the expected attributes and their types.
    """

    fs: int                # sampling rate (Hz) the accelerometer data are resampled to
    total_time: float
    # HR Parameters
    vital_w: int           # vital-sign calculation window (seconds)
    hs: tuple              # heart-sound band-pass edges (Hz)
    hs_len: float          # maximum heart-sound length (seconds) used by the length transform
    hr: tuple              # (min, max) expected heart rate (beats per minute)
    hs_thresh: tuple       # (low, high) amplitude bounds for accepted Shannon-envelope peaks
    activity_thresh: tuple
    min_scale: int         # AMPD minimum peak scale (samples); set by extract_cardiac
    max_scale: int         # AMPD maximum peak scale (samples); set by extract_cardiac
    min_scale_resp: int
    max_scale_resp: int
    fs_ds: float           # heart-sound rate after decimation (Hz) = fs / factor; set by extract_cardiac
    resp: tuple            # respiratory band-pass edges (Hz)
    resp_axis: str         # accelerometer axis used for respiration: 'x', 'y' or 'z'
    ds_resp: float         # target rate (Hz) for respiration analysis (configured as 50.)
    fs_resp: float         # respiration rate after decimation (Hz) = fs / factor; set by extract_respiratory

# NOTE: ``params`` is the paramstruct *class* itself, not an instance. It is
# used as a mutable namespace, so every setting below is a class attribute.
params = paramstruct
params.fs = sampling_rate  # rate (Hz) the accelerometer data are resampled to

# --- Vital-sign windowing --------------------------------------------------
params.vital_w = 4         # vital sign calculation window in seconds (alternatives noted before: 8, 60)
params.vital_ovlp = 0.25   # window hop as a fraction of the window: 1 is no overlap, 0 is complete overlap

# --- Heart rate --------------------------------------------------------------
params.hs = (25, 390)              # heart-sound band-pass edges (Hz) - original values
params.hs_len = .165               # max length (sec) of heart sound (Luisada, Mendoza, Alimurung (1948))
params.hs_thresh = (.00001, .3)    # accepted amplitude range of Shannon-envelope peaks
params.hr = (400, 900)             # min and max heart rate expected (beats per minute)
params.downsample = 400            # frequency (Hz) to downsample sensor data to for vitals analysis

# --- Activity ----------------------------------------------------------------
params.activity_thresh = (.00001, .5)

# --- Respiration -------------------------------------------------------------
params.resp = (1.25, 3.0)   # respiratory band-pass edges (Hz) - original values
params.resp_axis = 'x'      # axis used to analyze respiration data (x recommended)
params.ds_resp = 50.        # frequency (Hz) to downsample sensor data to for respiration analysis

# Optional scipy.signal.find_peaks criteria for respiratory peaks (None disables a criterion)
params.resp_peak_height = None
params.resp_peak_prominence = 0.00002
params.resp_peak_width = None
params.resp_peak_wlen = None

# --- Chunking ----------------------------------------------------------------
disconnection_length = 1000  # start a new chunk if the device disconnects for longer than this (ms)
min_chunk_length = 60        # discard chunks shorter than this (s)

## Select and Preprocess Files

### Select and Load Files

In [None]:
# Import data (Select the folder containing the Accel/Temp/Event CSVs)

def select_folder_and_files():
    """Ask the user for a data folder and list the CSV or parquet files in it.

    A hidden Tk root window is created so the directory picker stays on top
    of other windows. The root is destroyed once the dialog closes.

    Returns:
        ``(folder_path, file_paths)``. ``file_paths`` holds the folder's
        ``.csv`` files or, when it has none, its ``.pq`` files. The paths use
        forward slashes.

        If the dialog is cancelled, ``("", [])`` is returned. If the folder
        mixes CSV and PQ files, ``(folder_path, [])`` is returned. Always
        returning a 2-tuple keeps
        ``folder_path, file_paths = select_folder_and_files()`` from raising;
        the old ``[]`` return did not unpack.
    """
    root = Tk()
    try:
        root.withdraw()
        root.attributes("-topmost", True)
        root.update()
        folder_path = filedialog.askdirectory(title="Select a Folder")
    finally:
        # Release the hidden root so repeated calls do not leak Tk instances.
        root.destroy()

    if not folder_path:
        return "", []

    files_in_folder = os.listdir(folder_path)

    def _paths_with(ext):
        # Full paths (forward slashes) of the folder entries ending in ``ext``.
        return [os.path.join(folder_path, f).replace("\\", "/") for f in files_in_folder if f.endswith(ext)]

    csv_files = _paths_with('.csv')
    pq_files = _paths_with('.pq')

    if csv_files and pq_files:
        print("The folder contains a mix of CSV and PQ files. Please select a valid folder.")
        return folder_path, []

    return folder_path, csv_files if csv_files else pq_files

def _load_csv_rows(files, label):
    """Read every CSV in ``files`` and stack their rows into one array.

    ``DataFrame.values`` is used directly, with no per-row copy into a Python
    list. A header-only file therefore contributes a correctly shaped empty
    block instead of a 1-D ``[]`` that ``np.concatenate`` rejects. Returns an
    empty 1-D array when ``files`` is empty.
    """
    blocks = []
    for file in files:
        values = pd.read_csv(file).values
        print(f"Extracted {values.shape[0]} rows from {label} CSV")
        blocks.append(values)
    return np.concatenate(blocks) if blocks else np.array([])


def _parse_time_column(data, label, datetime_format):
    """Return ``data`` with column 0 (timestamp strings) parsed into ``datetime`` objects.

    Empty input is returned unchanged.
    """
    if len(data) == 0:
        return data
    times = [datetime.strptime(t, datetime_format) for t in data[:, 0]]
    print(f"Finished converting {len(times)} {label}_data rows to datetimes")
    return np.column_stack((times, data[:, 1:]))


def process_csv_files(file_paths):
    """Load and merge the Accel/Temp/Event CSV exports listed in ``file_paths``.

    Files are grouped by the ``_Accel_`` / ``_Temp_`` / ``_Event_`` tag in
    their name. Each group is ordered by the trailing ``_<number>.csv`` index
    and concatenated. Column 0 is parsed with ``"%Y-%m-%d %H:%M:%S.%f %z"``
    into timezone-aware ``datetime`` objects; the other columns are kept as
    read. Process memory (RSS) is printed before and after.

    Returns:
        ``(accel_data, temp_data, event_data)``. A group without files yields
        an empty array.
    """
    print(f"before process_csv_files RAM used: {psutil.Process().memory_info().rss / 1e9:.2f} GB")
    datetime_format = "%Y-%m-%d %H:%M:%S.%f %z"

    def get_file_number(filename):
        # "..._Accel_12.csv" -> 12
        return int(filename.split('_')[-1].replace('.csv', ''))

    groups = (
        ('_Accel_', 'accel', 'accelerometer'),
        ('_Temp_', 'temp', 'temperature'),
        ('_Event_', 'event', 'event'),
    )
    outputs = []
    for tag, label, description in groups:
        files = sorted((f for f in file_paths if tag in f and f.endswith('.csv')), key=get_file_number)
        data = _parse_time_column(_load_csv_rows(files, label), label, datetime_format)
        print(f"Processed {description} files")
        outputs.append(data)

    print(f"process_csv_files RAM used: {psutil.Process().memory_info().rss / 1e9:.2f} GB")

    accel_data, temp_data, event_data = outputs
    return accel_data, temp_data, event_data

def process_pq_files(file_paths):
    """Load the accelerometer and temperature parquet streams.

    The first path containing ``stream-0x30.pq`` (accelerometer) or
    ``stream-0x13.pq`` (temperature) is read with ``pd.read_parquet`` and
    converted to a NumPy array. A stream with no matching path yields an
    empty array. No event stream is read from parquet, so ``event_data`` is
    always empty.

    Returns:
        ``(accel_data, temp_data, event_data)``
    """
    def _read_stream(stream_name):
        match = next((path for path in file_paths if stream_name in path), None)
        if match is None:
            return np.array([])
        return pd.read_parquet(match).values

    accel_data = _read_stream('stream-0x30.pq')
    temp_data = _read_stream('stream-0x13.pq')

    print("Processed accelerometer files")
    print("Processed temperature files")

    return accel_data, temp_data, np.array([])

print(f"RAM used: {psutil.Process().memory_info().rss / 1e9:.2f} GB")

# Choose the data folder: a Tk dialog on Windows, a hard-coded path elsewhere.
os_name = platform.system()
if os_name == "Windows":
    selection = select_folder_and_files()
    # Tolerate a selector that returns an empty result (cancel / invalid folder).
    folder_path, file_paths = selection if selection else ("", [])
else: #Mac
    folder_path = 'your/folder/path/here'  # TODO: point at the folder with the exported data

    # Mirror select_folder_and_files: prefer CSVs, otherwise use parquet files.
    # Only CSVs used to be listed here, so parquet folders could never load.
    files_in_folder = os.listdir(folder_path)
    csv_files = [os.path.join(folder_path, f).replace("\\", "/") for f in files_in_folder if f.endswith('.csv')]
    pq_files = [os.path.join(folder_path, f).replace("\\", "/") for f in files_in_folder if f.endswith('.pq')]
    file_paths = csv_files if csv_files else pq_files

folder_name = os.path.basename(folder_path)
csv_not_pq = True  # set to False below when parquet files are loaded

print("loading files")

if file_paths and all(f.endswith('.csv') for f in file_paths):
    accel_data, temp_data, event_data = process_csv_files(file_paths)
elif file_paths and all(f.endswith('.pq') for f in file_paths):
    csv_not_pq = False
    accel_data, temp_data, event_data = process_pq_files(file_paths)
else:
    print("No valid files were selected.")

### Chunk Data

In [None]:
# Separate data into chunks

def remove_duplicates_from_data(accel_data):
    """Drop rows whose timestamp is not strictly later than the previous row's.

    Column 0 holds the sample timestamps. Each row is compared with the row
    directly before it in the input, not with the last kept row. Repeated
    timestamps and backward time jumps are both removed. The number of
    dropped rows is printed.

    Args:
        accel_data: 2-D array whose first column holds comparable timestamps.

    Returns:
        ``accel_data`` restricted to the kept rows. Inputs with fewer than two
        rows, including an empty array, are returned unchanged instead of
        raising a mask-length IndexError.
    """
    keep = np.ones(len(accel_data), dtype=bool)
    if len(accel_data) > 1:
        keep[1:] = accel_data[1:, 0] > accel_data[:-1, 0]
    accel_data_clean = accel_data[keep]
    num_removed = len(accel_data) - len(accel_data_clean)
    print(f"Removed {num_removed} duplicate entries")

    return accel_data_clean

accel_data = remove_duplicates_from_data(accel_data)


def _epoch_s(t):
    """Return ``t`` as epoch seconds.

    Datetime-like values (``datetime`` from the CSV path, ``pd.Timestamp``)
    use ``.timestamp()``. Plain numbers are taken to be epoch seconds already.
    NOTE(review): confirm the time unit of the parquet streams.
    """
    return t.timestamp() if hasattr(t, "timestamp") else float(t)


# Millisecond offset of every accelerometer sample from the first one, rounded
# to 1 microsecond. Both input formats go through _epoch_s. The previous
# parquet branch multiplied a number by 1000 for the first sample but then
# called .timestamp() on every element, so parquet input could not get past here.
accel_first_dt = accel_data[0, 0]
accel_first_ms = _epoch_s(accel_first_dt) * 1000
accel_ms = np.round([(_epoch_s(dt) * 1000 - accel_first_ms) for dt in accel_data[:, 0]], decimals=3)
accel_ms_diff = np.diff(accel_ms)

# A gap longer than disconnection_length (ms) ends a chunk. chunk_end is
# inclusive: chunk i covers accel_ms[chunk_start[i]] .. accel_ms[chunk_end[i]].
gap_idx = np.flatnonzero(accel_ms_diff > disconnection_length)
chunk_start = np.append([0], gap_idx + 1)
chunk_end = np.append(gap_idx, len(accel_ms_diff))
# Keep chunks spanning more than min_chunk_length seconds' worth of samples at params.fs.
chunk_mask = chunk_end - chunk_start > params.fs * min_chunk_length
chunk_start = chunk_start[chunk_mask]
chunk_end = chunk_end[chunk_mask]
num_chunks = len(chunk_start)

accel_x = accel_data[:, 1]
accel_y = accel_data[:, 2]
accel_z = accel_data[:, 3]

# Temperature: seconds relative to the first accelerometer sample.
# The parquet temperature values are divided by 1000.
if temp_data.size == 0:
    temp_s = np.array([])
    temp = np.array([])
else:
    temp_s = np.round([(_epoch_s(dt) - accel_first_ms / 1000) for dt in temp_data[:, 0]], decimals=3)
    temp = temp_data[:, 1] if csv_not_pq else temp_data[:, 1] / 1000

# Events: milliseconds relative to the first accelerometer sample.
if event_data.size == 0:
    event_ms = []
    event = []
else:
    event_ms = np.round([(_epoch_s(dt) * 1000 - accel_first_ms) for dt in event_data[:, 0]], decimals=3)
    event = event_data[:, 1]

### Resample Data

In [None]:
# Resample Acceleration Data to 800 Hz

# Resample each chunk onto a uniform params.fs grid with cubic splines.
# chunk_start / chunk_end are inclusive indices into accel_ms, so the spline
# knots must include accel_ms[chunk_end]. The last grid point can fall on that
# time, and the old exclusive slice made the spline extrapolate the tail.
accel_ms_old = accel_ms
ms_parts, x_parts, y_parts, z_parts, chunk_lengths = [], [], [], [], []

for chunk in range(num_chunks):
    print(f"Resampling chunk {chunk+1} out of {num_chunks}")
    first, last = chunk_start[chunk], chunk_end[chunk]
    new_start_time = accel_ms[first]
    new_num_samples = int(np.floor(params.fs*(accel_ms[last] - new_start_time)/1000) + 1)
    new_end_time = accel_ms[first] + round((1000*(new_num_samples - 1)/params.fs), 3)
    new_accel_ms_chunk = np.linspace(new_start_time, new_end_time, new_num_samples)

    knots = slice(first, last + 1)  # inclusive of the chunk's last sample
    for source, parts in ((accel_x, x_parts), (accel_y, y_parts), (accel_z, z_parts)):
        parts.append(CubicSpline(accel_ms[knots], source[knots])(new_accel_ms_chunk))

    ms_parts.append(new_accel_ms_chunk)
    chunk_lengths.append(new_num_samples)

# Concatenate once instead of np.append-ing inside the loop.
new_accel_ms = np.concatenate(ms_parts) if ms_parts else np.array([])
new_accel_x = np.concatenate(x_parts) if x_parts else np.array([])
new_accel_y = np.concatenate(y_parts) if y_parts else np.array([])
new_accel_z = np.concatenate(z_parts) if z_parts else np.array([])

# Inclusive start/end index of each chunk within the resampled arrays.
chunk_lengths = np.asarray(chunk_lengths, dtype=int)
new_chunk_end = np.cumsum(chunk_lengths) - 1
new_chunk_start = new_chunk_end - chunk_lengths + 1

# Absolute time of every resampled sample. A numeric first timestamp
# (parquet) is converted from epoch seconds, because pd.Timedelta cannot be
# added to a plain number. NOTE(review): confirm the parquet time unit.
accel_base_dt = (accel_first_dt if hasattr(accel_first_dt, "timestamp")
                 else datetime.fromtimestamp(float(accel_first_dt), tz=timezone.utc))
accel_dt = [accel_base_dt + pd.Timedelta(milliseconds=ms) for ms in new_accel_ms]

chunk_start = new_chunk_start
chunk_end = new_chunk_end
accel_ms = new_accel_ms
accel_x = new_accel_x
accel_y = new_accel_y
accel_z = new_accel_z

## Bluetooth Connectivity Metrics

In [None]:
#Export BLE connectivity metrics

def create_connection_status_csv(accel_dt, chunk_start, chunk_end, output_path):
    """Write a BLE connection report for the recorded chunks to ``output_path``.

    Each chunk counts as "Connected". The time between the end of one chunk
    and the start of the next counts as "Disconnected".

    Args:
        accel_dt: per-sample timestamps (datetime-like), indexed by chunk_start / chunk_end.
        chunk_start: index into ``accel_dt`` of each chunk's first sample.
        chunk_end: index into ``accel_dt`` of each chunk's last sample (inclusive).
        output_path: CSV file to create/overwrite.

    The file holds three tables, each preceded by a title line:
      * "Connection Summary": total time and percent of the recording per status.
      * "Hourly Connection Rates": percent of each whole hour, counted from the
        first chunk's start, that is covered by chunks. A trailing partial
        hour is not reported.
      * "Connection Log": every connected/disconnected interval, sorted by start time.
    """
    # Work from the arguments only. The interval loops used to read the
    # global ``num_chunks`` instead of the chunk lists passed in.
    spans = [(accel_dt[s], accel_dt[e]) for s, e in zip(chunk_start, chunk_end)]

    overall_start_time = accel_dt[chunk_start[0]]
    overall_end_time = accel_dt[chunk_end[-1]]
    hours_rounded = int(((overall_end_time - overall_start_time).total_seconds()) / 3600)

    hourly_conn_rates = []
    for hour in range(hours_rounded):
        hour_start_time = overall_start_time + timedelta(hours=hour)
        hour_stop_time = hour_start_time + timedelta(hours=1)

        connection_time = timedelta(seconds=0)
        for chunk_start_time, chunk_stop_time in spans:
            if chunk_start_time < hour_stop_time and chunk_stop_time > hour_start_time:
                # Add the part of this chunk that overlaps the hour
                connection_time += min(chunk_stop_time, hour_stop_time) - max(chunk_start_time, hour_start_time)

        connection_rate = 100 * connection_time.total_seconds() / 3600

        hourly_conn_rates.append({
            'hour': hour + 1,
            'start_time': hour_start_time,
            'end_time': hour_stop_time,
            'connection_rate': round(connection_rate, 3)
        })

    # Connected intervals are the chunks; disconnected intervals are the gaps between them.
    intervals = [("Connected", start, stop) for start, stop in spans]
    intervals += [("Disconnected", prev_stop, next_start)
                  for (_, prev_stop), (next_start, _) in zip(spans, spans[1:])]

    ble_status = [status for status, _, _ in intervals]
    start_times = [start for _, start, _ in intervals]
    stop_times = [stop for _, _, stop in intervals]
    total_times = [stop - start for _, start, stop in intervals]
    total_times_s = [t.total_seconds() for t in total_times]

    total_duration = sum(total_times_s)
    # Avoid dividing by zero for a recording of zero length.
    percent_times = [round((t / total_duration) * 100, 3) if total_duration > 0 else 0.0
                     for t in total_times_s]

    hourly_df = pd.DataFrame(hourly_conn_rates)

    chunk_df = pd.DataFrame({
        'ble_status': ble_status,
        'start_time': start_times,
        'stop_time': stop_times,
        'total_time': total_times,
        'percent_time': percent_times
    })

    chunk_df = chunk_df.sort_values('start_time')

    summary_df = chunk_df.groupby('ble_status').agg({
        'total_time': 'sum',
        'percent_time': 'sum'
    }).reset_index()

    with open(output_path, 'w', newline='') as f:
        f.write("Connection Summary\n")
        summary_df.to_csv(f, index=False)
        f.write("\nHourly Connection Rates\n")
        hourly_df.to_csv(f, index=False)
        f.write("\nConnection Log\n")
        chunk_df.to_csv(f, index=False)
    
# Write the BLE dropout report next to the input data as <folder>_dropouts.csv.
output_file = os.path.join(folder_path, folder_name + "_dropouts.csv")
create_connection_status_csv(accel_dt=accel_dt, chunk_start=chunk_start, chunk_end=chunk_end, output_path=output_file)

## Class and Function Definitions 

In [None]:
# Functions for Calculating Multiple Metrics

def downsample_stages(x, original, target):
    """Decimate ``x`` from ``original`` Hz toward ``target`` Hz in power-of-two steps.

    Factors of two are accumulated while halving the current rate would stay
    above ``target``. Whenever the pending factor reaches 8 it is applied with
    ``signal.decimate`` and the count restarts. At the end, the rate that is
    closer to ``target`` (pending factor, or twice the pending factor) is chosen.

    Returns:
        ``(decimated_x, total_factor)`` with ``total_factor == 2**stages``.
    """
    pending = 1        # factor accumulated but not yet applied
    stages = 0         # number of factor-of-two reductions in total
    rate = original    # sampling rate of x after the decimations applied so far

    while rate / (pending * 2) > target:
        pending *= 2
        stages += 1
        # Apply at most a factor of 8 per decimate call
        if pending > 4:
            x = signal.decimate(x, pending)
            rate = rate / pending
            pending = 1

    rate_if_kept = rate / pending
    rate_if_halved = rate / (pending * 2)
    if (rate_if_kept - target) > (target - rate_if_halved):
        x = signal.decimate(x, pending * 2)
        stages += 1
    else:
        x = signal.decimate(x, pending)
    return x, 2 ** stages

def accel_detrend(accel):
    """Subtract a Savitzky-Golay trend (21-sample window, 8th-order polynomial) from ``accel``."""
    trend = savgol_filter(accel, window_length=21, polyorder=8)
    return accel - trend

In [None]:
#Calculate sampling rate for debugging #Do not use yet

class SampleRate:
    """Sampling-rate estimates for one data chunk (debugging aid, not used yet)."""

    def __init__(self, chunk, time_SR, SR):
        # chunk   -- index of the data chunk these estimates belong to
        # time_SR -- time value(s) of the estimates
        # SR      -- sample-rate estimate(s)
        self.chunk, self.time_SR, self.SR = chunk, time_SR, SR

def calculate_sampling_rate(chunk, accel_ms, accel_ms_old, accel_ms_diff):
    """Estimate the original (pre-resampling) sampling rate in sliding windows.

    Debugging helper; the notebook marks it "do not use yet".

    Args:
        chunk: chunk index, stored on the result.
        accel_ms: resampled sample times (ms) at ``params.fs``; they define the windows.
        accel_ms_old: original sample times (ms).
        accel_ms_diff: intervals between consecutive original samples (ms).
            Presumably ``np.diff(accel_ms_old)``, so entry j follows
            ``accel_ms_old[j]`` (confirm at the call site).

    Returns:
        SampleRate. ``time_SR`` holds the window centres in seconds, offset by
        ``params.start_time``. ``SR`` holds one rate (Hz) per window, 0 when no
        original sample falls inside the window.
    """
    window_n = params.vital_w*params.fs
    hop = int(np.floor(window_n*params.vital_ovlp))
    time_seg_start_idx = np.arange(0, int(len(accel_ms)-window_n), hop)

    # Window centres in seconds, like the HR/activity time axes. This was
    # previously left in samples and then added to a time in seconds.
    time = (time_seg_start_idx + window_n/2)/params.fs + params.start_time

    data = np.zeros(len(time_seg_start_idx))
    for i, seg_start in enumerate(time_seg_start_idx):
        time_seg = accel_ms[seg_start:int(seg_start+window_n)]
        idx = np.flatnonzero((accel_ms_old >= time_seg[0]) & (accel_ms_old < time_seg[-1]))
        # accel_ms_diff is one element shorter than accel_ms_old
        idx = idx[idx < len(accel_ms_diff)]
        # Store one value per window; the old code overwrote the whole array with a scalar
        data[i] = 1000/(np.mean(accel_ms_diff[idx])) if idx.size > 0 else 0

    return SampleRate(chunk, time, data)

In [None]:
# Classes and Functions for Calculating Activity

class Activity:
    """Activity-index results for one data chunk."""

    def __init__(self, chunk, time, data):
        # chunk -- index of the data chunk these values belong to
        # time  -- window-centre times of the activity values
        # data  -- activity value of each window
        self.chunk, self.time, self.data = chunk, time, data
        
def calculate_activity(chunk, accel_x_chunk, accel_y_chunk, accel_z_clean):
    """Compute the activity index of one chunk over overlapping windows.

    Each axis of the selected chunk is decimated toward ``params.downsample``
    Hz. ``calculate_variance`` is then evaluated on ``window``-second windows
    that advance by a quarter window.

    Args:
        chunk: index selecting the chunk from each per-chunk sequence below.
        accel_x_chunk: per-chunk x acceleration arrays.
        accel_y_chunk: per-chunk y acceleration arrays.
        accel_z_clean: per-chunk z acceleration arrays (motion artifacts removed,
            per the settings comments).

    Returns:
        Activity with a list of window-centre times (seconds, offset by
        ``params.start_time``) and an array of one activity value per window.
    """
    down_x, ds_factor = downsample_stages(accel_x_chunk[chunk], params.fs, params.downsample)
    down_y, _ = downsample_stages(accel_y_chunk[chunk], params.fs, params.downsample)
    down_z, _ = downsample_stages(accel_z_clean[chunk], params.fs, params.downsample)

    # Rate of the decimated axes, derived here instead of read from
    # params.fs_ds. That attribute only exists after extract_cardiac has run,
    # which computes it with the same fs / factor formula.
    fs_act = params.fs / ds_factor
    act_window_N = window * fs_act    # samples per activity window
    act_overlap = 0.25                # hop as a fraction of the window

    act_index = np.arange(0, int(len(down_x) - act_window_N), int(np.floor(act_window_N * act_overlap)))

    # Window-centre times in seconds
    act_time = list((act_index + act_window_N / 2) / fs_act + params.start_time)
    for i in np.flatnonzero(np.isnan(act_time)):
        print(f"NaN at index {i} (idx={act_index[i]})")

    act_data = np.zeros(len(act_index))
    for i, start in enumerate(act_index):
        stop = int(start + act_window_N)
        act_data[i] = calculate_variance(down_x[start:stop], down_y[start:stop], down_z[start:stop])

    return Activity(chunk, act_time, act_data)

def calculate_variance(x, y, z):
    """Activity index of three acceleration windows.

    For each axis, the excess of its variance over the fixed ``sys_var``
    baseline is taken, relative to that baseline. The three values are
    averaged. The square root of the mean, clipped at 0, is returned.
    """
    sys_var = 4.38840806 * 10**-5   # baseline ("system") variance
    mean_excess = 0
    for axis in (x, y, z):
        mean_excess = mean_excess + ((np.std(axis))**2 - sys_var)/sys_var
    mean_excess = mean_excess * (1/3)
    return (max(mean_excess, 0))**.5

In [None]:
# Classes and Functions for Calculating Heart Rate

class Cardiac:
    """Cardiac analysis results for one data chunk: signals, detected peaks, HR and SQI."""

    def __init__(self, chunk, time_sig, sig, sig_sh, sig_sh_env, time_sig_ds, sig_sh_env_ds, sig_lt_ds, all_pks_sh, all_pks_lt, all_pks_matched, time_HR, HR, SQI):
        # Which chunk of the recording these results describe
        self.chunk = chunk

        # Full-rate signals; their time values are in time_sig
        self.time_sig = time_sig
        self.sig = sig                      # band-pass filtered heart sound
        self.sig_sh = sig_sh                # Shannon energy envelope before low-pass filtering
        self.sig_sh_env = sig_sh_env        # Shannon energy envelope after low-pass filtering

        # Decimated signals; their time values are in time_sig_ds
        self.time_sig_ds = time_sig_ds
        self.sig_sh_env_ds = sig_sh_env_ds  # filtered Shannon envelope, decimated
        self.sig_lt_ds = sig_lt_ds          # heart-sound length transform, decimated

        # Peak indices into the decimated signals
        self.all_pks_sh = all_pks_sh            # peaks of sig_sh_env_ds
        self.all_pks_lt = all_pks_lt            # peaks of sig_lt_ds
        self.all_pks_matched = all_pks_matched  # Shannon peaks matched with length-transform peaks

        # Per-window outputs; their time values are in time_HR
        self.time_HR = time_HR
        self.HR = HR                        # heart rate
        self.SQI = SQI                      # signal quality index

def autocorr(x):
    """One-sided autocorrelation of ``x`` for lags 0 .. len(x)-1, normalised so lag 0 is 1."""
    full = signal.correlate(x, x)
    zero_lag = len(x) - 1          # position of lag 0 in the 'full' correlation
    one_sided = full[zero_lag:]
    return one_sided / one_sided[0]

def AMPD_pks(x, min_scale=None, max_scale=None):
    """Find peaks in quasi-periodic noisy signals using ASS-AMPD algorithm.

    Method adapted from Scholkmann et al. (2012), "An Efficient Algorithm
    for Automatic Peak Detection in Noisy Periodic and Quasi-Periodic Signals".

    Args:
        x: 1-D signal. It is linearly detrended before the search.
        min_scale: smallest expected peak spacing (samples). Scales below
            ``min_scale // 2`` are never selected. ``None`` means 0 (no lower bound).
        max_scale: largest expected peak spacing (samples). ``max_scale // 2``
            scales are examined. ``None`` means ``len(x)``, i.e. scales up to
            half the signal length (the defaults used to raise TypeError).

    Returns:
        Ascending indices of samples that are strictly greater than their
        neighbours at every distance 1..l. The scale l is the one with the
        highest weighted count of local maxima, and is at least 1.
    """
    x = signal.detrend(x)
    N = len(x)

    if min_scale is None:
        min_scale = 0
    if max_scale is None:
        max_scale = N

    L = max_scale // 2
    cut = min_scale // 2

    # Local scalogram matrix: LSM[k-1, i] is True when x[i] exceeds its
    # neighbours at distance k. Near the edges only the existing neighbour is compared.
    LSM = np.ones((L, N), dtype=bool)
    for k in np.arange(1, L + 1):
        LSM[k - 1, 0:N - k] &= (x[0:N - k] > x[k:N])  # compare to right neighbours
        LSM[k - 1, k:N] &= (x[k:N] > x[0:N - k])      # compare to left neighbours

    # Count of local maxima per scale, weighted to adjust for the edge regions
    G = LSM.sum(axis=1)
    G = G * np.arange(N // 2, N // 2 - L, -1)
    # At least one scale, so the reduction below never runs over an empty slice
    l_scale = max(cut + np.argmax(G[cut:]), 1)

    # Find peaks that persist on all scales up to l_scale
    pks_logical = np.min(LSM[0:l_scale, :], axis=0)
    pks = np.flatnonzero(pks_logical)

    return pks

def lengthtransform(x, w, fs):
    """Length transform of ``x``: curve length over trailing windows of ``w`` seconds.

    As described in Zong, Moody, Jiang (2003), "A Robust Open-source Algorithm
    to Detect Onset and Duration of QRS Complexes".

    Each sample step contributes ``sqrt((1/fs)**2 + dy**2)``. Here ``dy`` is
    the first difference of ``x``, taken with a 0 before ``x[0]``. Output
    sample i is the summed length of the last ``ceil(w*fs)`` steps (fewer at
    the start), minus ``ceil(w*fs)/fs``, the length of a flat window.
    """
    win = int(np.ceil(w * fs))
    step = np.array(np.diff(x, prepend=0)).astype(float)
    seg_len = np.sqrt(1 / fs**2 + step**2)

    running = np.cumsum(seg_len)
    running[win:] = running[win:] - running[:-win]
    return running - win / fs


def shannon_energy_env(x):
    """Shannon energy of each sample: ``-x**2 * log(x**2)``.

    Samples equal to 0 give 0, the limit of ``-e*log(e)`` as ``e -> 0``.
    The old code gave NaN here (0 * -inf), and a single NaN would spread
    through the zero-phase filtering later applied to the whole envelope.
    NaN inputs still produce NaN.
    """
    energy = np.asarray(x, dtype=float)**2
    x_env = np.zeros_like(energy)
    nonzero = energy != 0
    x_env[nonzero] = -energy[nonzero] * np.log(energy[nonzero])
    return x_env

# Heart Rate Calculations
def calculate_b2b(envelope, length_transform, params):
    """Detect heartbeats in one window and return the beat-to-beat intervals.

    Peaks are found with AMPD in the Shannon envelope and in the length
    transform. Envelope peaks outside ``params.hs_thresh`` are discarded. An
    envelope peak counts as a heartbeat ("matched") when a length-transform
    peak lies within 150 ms of it.

    Args:
        envelope: Shannon-envelope window sampled at ``params.fs_ds`` Hz.
        length_transform: length-transform window at the same rate.
        params: parameter namespace; reads min_scale, max_scale, hs_thresh, hr and fs_ds.

    Returns:
        ``(pks_sh, pks_lt, matched, b2b, SQI)``:
          * pks_sh: thresholded envelope peaks.
          * pks_lt: length-transform peaks.
          * matched: indices of the matched beats.
          * b2b: beat-to-beat intervals in seconds, within the expected HR range.
          * SQI: cleaned beats / detected beats (0 when nothing was detected).
    """
    pks_sh = AMPD_pks(envelope, min_scale=params.min_scale, max_scale=params.max_scale)
    pks_lt = AMPD_pks(length_transform, min_scale=params.min_scale, max_scale=params.max_scale)
    N_detected = min(len(pks_sh), len(pks_lt))
    pks_sh = pks_sh[(envelope[pks_sh] > params.hs_thresh[0]) & (envelope[pks_sh] < params.hs_thresh[1])]

    if N_detected:
        # Distance from each envelope peak to its nearest length-transform
        # peak. The old signed minimum was negative whenever any LT peak came
        # later, so nearly every envelope peak was accepted.
        residual = np.min(np.abs(pks_sh[np.newaxis, :] - pks_lt[:, np.newaxis]), axis=0)
        matched = pks_sh[residual < .150*params.fs_ds]

        # Beat-to-beat intervals (s); b2b[k] spans matched[k] -> matched[k+1]
        b2b = np.diff(matched)/params.fs_ds
        keep_beat = np.ones(len(matched), dtype=bool)
        # Two consecutive too-short intervals mean the beat between them
        # (matched[k+1]) is spurious. Merge the intervals and drop that beat.
        # The old code dropped matched[k] instead, and its `matched > 0`
        # filter also lost any real beat at index 0.
        for k in range(len(b2b)-1):
            if b2b[k] < 60/params.hr[1] and b2b[k+1] < 60/params.hr[1]:
                b2b[k] = b2b[k] + b2b[k+1]
                b2b[k+1] = 0
                keep_beat[k+1] = False
        # Remove intervals outside expected HR range (this also drops the zeroed ones)
        b2b = b2b[(b2b < 60/params.hr[0]) & (b2b > 60/params.hr[1])]
        matched = matched[keep_beat]
        # n intervals are delimited by n + 1 beats (was len(b2b+1), i.e. just len(b2b))
        N_cleaned = len(b2b) + 1 if len(b2b) else 0

        SQI = N_cleaned/N_detected
    else:
        matched, b2b, SQI = np.array([]), np.array([]), 0

    return(pks_sh, pks_lt, matched, b2b, SQI)
        
def extract_cardiac(chunk, params, time, accel_z_clean):
    """Extract the heart sound of one chunk and estimate heart rate and SQI per window.

    Pipeline:
      1. Band-pass the z acceleration to ``params.hs``.
      2. Compute the Shannon energy envelope (low-passed at 14 Hz) and the
         length transform.
      3. Decimate both toward ``params.downsample``.
      4. Per window: detect beats (calculate_b2b) and compute an SQI from
         the envelope autocorrelation.

    Side effects: sets ``params.fs_ds``, ``params.min_scale`` and ``params.max_scale``.

    Args:
        chunk: index selecting the chunk from ``accel_z_clean``.
        params: parameter namespace (see paramstruct); also needs ``start_time`` (s).
        time: time values of this chunk's full-rate samples.
        accel_z_clean: per-chunk z acceleration arrays.

    Returns:
        Cardiac with the intermediate signals, detected peaks and per-window
        HR/SQI. HR and SQI stay 0 for windows with fewer than two valid beat
        intervals.
    """
    time_sig = time
    sos_hs = signal.butter(8, np.array(params.hs)/(params.fs/2), btype='bandpass', output='sos')
    sos_shan = signal.butter(8, 14/(params.fs/2), btype='low', output='sos')
    sig = signal.sosfiltfilt(sos_hs, accel_z_clean[chunk])  # heart sound

    sig_lt = lengthtransform(sig, params.hs_len, params.fs)  # heart sound length transform
    sig_sh = shannon_energy_env(sig)                         # Shannon envelope without filter
    sig_sh_env = signal.sosfiltfilt(sos_shan, sig_sh)        # Shannon envelope with low-pass filter

    sig_lt_ds, _ = downsample_stages(sig_lt, params.fs, params.downsample)
    sig_sh_env_ds, ds_hs_sh = downsample_stages(sig_sh_env, params.fs, params.downsample)
    # Thin the time axis by the actual decimation factor. It was hard-coded to
    # 2, which is only right for the default 800 -> 400 Hz setting.
    time_sig_ds = time[::ds_hs_sh]

    params.fs_ds = params.fs/ds_hs_sh

    win_N = params.vital_w*params.fs_ds  # samples in one analysis window
    hop = int(np.floor(win_N*params.vital_ovlp))
    idx_v = np.arange(0, int(len(sig_sh_env_ds)-win_N), hop)  # window start indices
    time_HR = (idx_v + win_N/2)/params.fs_ds + params.start_time  # window centres (s)

    # AMPD search scales (samples) spanning the expected heart-rate range
    params.min_scale = int(np.floor(60/params.hr[1]*params.fs_ds))  # Assuming initial heart rate is non-tachycardia
    params.max_scale = int(np.ceil(60/params.hr[0]*params.fs_ds))

    all_pks_sh = []
    all_pks_lt = []
    all_pks_matched = []
    HR = np.zeros(len(idx_v))
    SQI = np.zeros(len(idx_v))
    min_beat_distance = 60/params.hr[1]*params.fs_ds  # samples between beats at max HR

    for i, start in enumerate(idx_v):
        stop = int(start + win_N)
        x_lt = sig_lt_ds[start:stop]
        x_sh = sig_sh_env_ds[start:stop]

        pks_sh, pks_lt, pks_matched, b2b, _ = calculate_b2b(x_sh, x_lt, params)

        # Keep detected peaks as chunk-level indices (debugging/plots).
        # extend() avoids rebuilding the lists on every window.
        all_pks_sh.extend(pks_sh + start)
        all_pks_lt.extend(pks_lt + start)
        all_pks_matched.extend(pks_matched + start)

        # Signal quality from the envelope autocorrelation: regular heartbeats
        # give a strong first autocorrelation peak. Motion artifacts or weak
        # heartbeats do not.
        ac = autocorr(x_sh)
        ac_peaks, _ = signal.find_peaks(ac, distance=min_beat_distance, height=0.2, prominence=0.1)
        SQI_temp = ac[ac_peaks[0]] if ac_peaks.size > 0 else 0

        # Only windows with at least two valid intervals get HR / SQI values
        if len(b2b) > 1:
            HR[i] = (60/np.mean(b2b))
            SQI[i] = SQI_temp

    return Cardiac(chunk, time_sig, sig, sig_sh, sig_sh_env, time_sig_ds, sig_sh_env_ds, sig_lt_ds, all_pks_sh, all_pks_lt, all_pks_matched, time_HR, HR, SQI)

In [None]:
# Classes and Functions for Calculating Respiratory Rate
        
class Respiratory:
    """Respiratory analysis results for one data chunk."""

    def __init__(self, chunk, time_sig, sig_prefilt, sig, pks, time_RR, RR):
        # chunk       -- index of the data chunk these results describe
        # time_sig    -- time values for sig_prefilt and sig
        # sig_prefilt -- respiratory signal before band-pass filtering
        # sig         -- respiratory signal after band-pass filtering
        # pks         -- peak indices into sig
        # time_RR     -- time values for RR
        # RR          -- respiratory rate computed from pks
        (self.chunk, self.time_sig, self.sig_prefilt, self.sig,
         self.pks, self.time_RR, self.RR) = (chunk, time_sig, sig_prefilt, sig, pks, time_RR, RR)

def extract_respiratory(chunk, params, time, accel_x_chunk, accel_y_chunk, accel_z_chunk, all_pks_sh, all_pks_matched, sig_sh_env, accel_z_clean):

    if not use_hs_to_calculate_RR:
        time_sig = time[::16]
        if params.resp_axis == 'x':
            sig_prefilt, ds_f_resp = downsample_stages(accel_x_chunk,params.fs,params.ds_resp)
        elif params.resp_axis == 'y':
            sig_prefilt, ds_f_resp = downsample_stages(accel_y_chunk,params.fs,params.ds_resp)
        elif params.resp_axis == 'z':
            sig_prefilt, ds_f_resp = downsample_stages(accel_z_chunk,params.fs,params.ds_resp)
        params.fs_resp =  params.fs/ds_f_resp
        filt = signal.butter(8,[x/(params.fs_resp/2) for x in params.resp],btype='bandpass',output='sos')
        sig = signal.sosfiltfilt(filt, sig_prefilt)

        # Setup Output DataFrame
        window_n = params.vital_w*params.fs_resp # 8s *50 Hz = 400 samples
        sig_seg_start_idx = np.arange(0,int(len(sig_prefilt)-window_n),int(np.floor(window_n*params.vital_ovlp)))

        time_RR = (sig_seg_start_idx + window_n/2)/params.fs_resp+params.start_time

        pks = []

        RR = np.zeros(len(sig_seg_start_idx))
        
        resp_peak_height = getattr(params, "resp_peak_height", None)
        resp_peak_prominence = getattr(params, "resp_peak_prominence", None)
        resp_peak_width = getattr(params, "resp_peak_width", None)
        resp_peak_wlen = getattr(params, "resp_peak_wlen", None)

        for i in range(len(sig_seg_start_idx)):
            sig_seg = sig[sig_seg_start_idx[i]:int(sig_seg_start_idx[i] + window_n)]
            pks_seg,_ = signal.find_peaks(sig_seg, 
                                          height=resp_peak_height,
                                          distance=params.fs_resp/params.resp[1],
                                          prominence=resp_peak_prominence, 
                                          width=resp_peak_width, 
                                          wlen=resp_peak_wlen)
            pks = pks + list((pks_seg + sig_seg_start_idx[i]))
            RR[i] = 60*params.fs_resp/np.mean(np.diff(pks_seg))#60/np.mean(resp) # Respiration Rate

        return Respiratory(chunk, time_sig, sig_prefilt, sig, pks, time_RR, RR)
                
    elif use_hs_to_calculate_RR:
        hs_pks_idx = []
        if use_shannon_peaks_only_to_calculate_resp_env:
            hs_pks_idx = np.array(all_pks_sh) #if only using hs_sh
        else:
            hs_pks_idx = np.array(all_pks_matched) #if using both sig_sh_env and sig_lt
        hs_pks_idx, _ = np.unique(hs_pks_idx, return_index=True)
        
        hs_pks_idx = hs_pks_idx * params.fs / params.fs_ds
        hs_pks_idx = hs_pks_idx.astype(int)

        for i in range(len(hs_pks_idx)):
            if(hs_pks_idx[i] > 7):
                if use_shannon_peaks_only_to_calculate_resp_env:
                    hs_pks_idx[i] = np.argmax (sig_sh_env[hs_pks_idx[i] - 8 : hs_pks_idx[i] + 8]) + (hs_pks_idx[i] - 8 ) #if only using sig_sh_env
                else:
                    hs_pks_idx[i] = np.argmax(accel_z_clean[hs_pks_idx[i] - 8 : hs_pks_idx[i] + 8]) + (hs_pks_idx[i] - 8 ) #if using both sig_sh_env and sig_lt
        hs_pks_idx, unique_index = np.unique(hs_pks_idx, return_index=True)
        
        time_sig = time[::16]
        idx_sig = time_sig * params.fs #50 Hz
        
        cs = None
        if use_shannon_peaks_only_to_calculate_resp_env:
            #cs = CubicSpline(hs_pks_idx, sig_sh_env[hs_pks_idx]) #if only using sig_sh_env
            cs = CubicSpline(idx_sig[0] + hs_pks_idx, sig_sh_env[hs_pks_idx]) #if only using sig_sh_env
        else:
            #cs = CubicSpline(hs_pks_idx, accel_z_clean[hs_pks_idx]) #if using both sig_sh_env and sig_lt
            cs = CubicSpline(idx_sig[0] + hs_pks_idx, accel_z_clean[hs_pks_idx]) #if using both sig_sh_env and sig_lt
        
        valid_mask = ((idx_sig[0] + hs_pks_idx[0]) <= idx_sig) & (idx_sig <= (idx_sig[0] + hs_pks_idx[-1]))  # Ensure idx_sig stays within bounds
        time_sig = time_sig[valid_mask]  # Trim time_sig accordingly
        idx_sig = idx_sig[valid_mask]    # Trim idx_sig accordingly
        sig_prefilt = cs(idx_sig)
        
        filt = signal.butter(4,np.array(params.resp)/(params.ds_resp/2),btype='bandpass',output='sos')
        sig = signal.sosfiltfilt(filt,sig_prefilt)

        window_n = params.vital_w*50
        sig_seg_start_idx = np.arange(0,int(len(sig_prefilt)-window_n),int(np.floor(window_n*params.vital_ovlp)))
        
        time_RR = (sig_seg_start_idx + window_n/2)/50+params.start_time

        RR = np.zeros(len(sig_seg_start_idx))
        pks = []
        
        resp_peak_height = getattr(params, "resp_peak_height", None)
        resp_peak_prominence = getattr(params, "resp_peak_prominence", None)
        resp_peak_width = getattr(params, "resp_peak_width", None)
        resp_peak_wlen = getattr(params, "resp_peak_wlen", None)
        
        for i in range(len(sig_seg_start_idx)):
            sig_seg = sig[sig_seg_start_idx[i]:int(sig_seg_start_idx[i]+window_n)]
            pks_seg = []
#             if(np.std(sig_seg) < 0.00001):
#                 pks_seg,_ = signal.find_peaks(sig_seg,distance=params.fs/params.resp[1],prominence=.0000005)
#             else:
#                 pks_seg,_ = signal.find_peaks(sig_seg,distance=params.fs/params.resp[1],prominence=.00037, width=[0,250], wlen=300) 
            pks_seg, _ = signal.find_peaks(sig_seg, 
                                           height = resp_peak_height,
                                           distance=50/params.resp[1],
                                           prominence=resp_peak_prominence, 
                                           width=resp_peak_width, 
                                           wlen=resp_peak_wlen)
        
            pks = pks + list((pks_seg + sig_seg_start_idx[i]))
            if len(pks_seg) > 1:
                RR[i] = 60*50/np.mean(np.diff(pks_seg))
    
        return Respiratory(chunk, time_sig, sig_prefilt, sig, pks, time_RR, RR)

## Analyze Data

In [None]:
# Analyze Data
#
# Split the recording into chunks and run the per-chunk pipelines. For chunk k:
#   chunk_time[k], accel_{x,y,z}_chunk[k]       time (accel_ms / 1000) and acceleration slices
#   chunk_temp_s[k], chunk_temp[k]              temperature samples within the chunk's time span
#   event_s_chunk[k], event_chunk[k]            event times (event_ms / 1000) and labels in the chunk
#   accel_z_clean[k]                            detrended Z acceleration
#   cardiacX[k], respiratoryX[k], activityX[k]  analysis results

#scale = 4096
scale = 1

chunk_time = []
accel_x_chunk = []
accel_y_chunk = []
accel_z_chunk = []
chunk_temp_s = []
chunk_temp = []

time_activity = []
activity = []

x_resp = []
all_pks_resp = []
all_matched = []
all_pks_sh = []
all_pks_lt = []
accel_z_clean = []
accel_x_clean = []
accel_y_clean = []
hs = []
hs_lt = []
hs_sh_prefilt = []
hs_sh = []
chunk_time_ds = []
chunk_time_rrds = []
all_matched_idx_interp = []
contour_interp = []
contour_interp_filt = []
all_pks_resp_heart = []
all_std_xx = []

time = []
HR = []
SQI = []
RR = []
RR_y = []
RR_z = []
RR_heart = []

sampleRateX = {}
activityX = {}
cardiacX = {}
respiratoryX = {}

event_s_chunk = []
event_chunk = []

for chunk in range(num_chunks):
    print(f"Analyzing chunk {chunk+1} out of {num_chunks}")
    chunk_slice = slice(chunk_start[chunk], chunk_end[chunk])
    chunk_time.append(accel_ms[chunk_slice] / 1000)

    start_time = chunk_time[chunk][0]

    accel_x_chunk.append(accel_x[chunk_slice]/scale)
    accel_y_chunk.append(accel_y[chunk_slice]/scale)
    accel_z_chunk.append(accel_z[chunk_slice]/scale)

    # Temperature samples inside this chunk's time span (both ends inclusive).
    in_chunk = (temp_s >= chunk_time[chunk][0]) & (temp_s <= chunk_time[chunk][-1])
    chunk_temp_s.append(temp_s[in_chunk])
    chunk_temp.append(temp[in_chunk])

    # Events in [accel_ms[chunk_start], accel_ms[chunk_end]).
    # NOTE(review): requires chunk_end[chunk] < len(accel_ms); confirm for the last chunk.
    event_indices_chunk = np.where((event_ms >= accel_ms[chunk_start[chunk]]) & (event_ms < accel_ms[chunk_end[chunk]]))[0]

    # Append even when empty so event_chunk[k] / event_s_chunk[k] always belong to chunk k;
    # the plotting cells index these lists by chunk number.
    event_s_chunk.append(event_ms[event_indices_chunk] / 1000)
    event_chunk.append(event[event_indices_chunk])

    params.start_time = start_time

    accel_z_clean.append(accel_detrend(accel_z_chunk[chunk]))

    #sampleRateX[chunk] = calculate_sampling_rate(chunk, chunk_time[chunk], accel_ms_old, accel_ms_diff)
    cardiacX[chunk] = extract_cardiac(chunk, params, chunk_time[chunk], accel_z_clean)
    respiratoryX[chunk] = extract_respiratory(chunk, params, chunk_time[chunk], accel_x_chunk[chunk], accel_y_chunk[chunk], accel_z_chunk[chunk], cardiacX[chunk].all_pks_sh, cardiacX[chunk].all_pks_matched, cardiacX[chunk].sig_sh_env, accel_z_clean[chunk])
    activityX[chunk] = calculate_activity(chunk, accel_x_chunk, accel_y_chunk, accel_z_clean)

    #if use_hs_to_calculate_RR:
        #To-Do: rename "fs_ds" to reflect that it is meant for heart rate, not respiratory rate.
        #respiratoryX[chunk].time_sig = (respiratoryX[chunk].time_sig / params.fs_ds) + chunk_time[chunk][0];

## Plot and Export Data

### Individual Chunk Time Plots/CSVs

In [None]:
#Export data and plots for INDIVIDUAL chunks
#
# For each chunk: draw the enabled plot_* panels as stacked subplots sharing the time axis
# (optionally with a scroll slider), save that chunk's figure as a PNG when export_to_csv is on,
# and export accel / physical-activity / cardiac / respiratory CSVs.


def _percentile_fences(values, lo_pct, hi_pct):
    """Return (lower, upper) outlier fences: the lo_pct / hi_pct percentiles of *values*
    pushed outward by 1.5x the distance between them."""
    lo = np.percentile(values, lo_pct)
    hi = np.percentile(values, hi_pct)
    spread = hi - lo
    return lo - 1.5 * spread, hi + 1.5 * spread


def _make_scroll_update(fig, axes, slider, window):
    """Return a Slider callback that pans every subplot of *fig* to [slider.val, slider.val + window].

    fig/axes/slider are bound here (not read from notebook globals) so each chunk's slider keeps
    driving its own figure after the loop has moved on to later chunks.
    """
    def update(val):
        new_xlim = slider.val, slider.val + window
        for ax in axes:
            ax.set_xlim(new_xlim)
        fig.canvas.draw_idle()
    return update


# Matplotlib widgets stop responding once garbage-collected, so keep every chunk's slider alive.
scroll_sliders = []

if plot_output:
    for chunk in range(num_chunks):
        print(f"Exporting data and plots for chunk {chunk+1} out of {num_chunks}")
        idx_good = (cardiacX[chunk].SQI >= .0)

        plot_list = [plot_temperature, plot_accel_x, plot_accel_y, plot_accel_z, plot_activity, plot_accel_z_clean, plot_hs, plot_hs_sh_raw,
                plot_hs_sh_filtered, plot_hs_lt, plot_HR, plot_SQI, plot_resp_sig_prefilt, plot_resp_sig, plot_RR]
        # Highest enabled panel index; that panel gets the HH:MM:SS tick labels.
        last_plot_index = max((index for index, value in enumerate(plot_list) if value), default=0)

        # When true, 'Resp.\nSig.' is drawn on the Shannon-envelope panel instead of its own subplot.
        overlay_resp = bool(plot_hs_sh_filtered & plot_resp_sig_prefilt & overlay_resp_sig_prefilt_w_hs & use_hs_to_calculate_RR)
        num_plots = sum(plot_list) - 1 if overlay_resp else sum(plot_list)
        if num_plots == 0:
            continue  # nothing enabled; plt.subplots(0, 1) would raise

        fig, axes = plt.subplots(num_plots, 1, sharex=True)

        if num_plots == 1:
            axes = [axes]

        # Robust y-limits: drop samples outside the 5th/95th-percentile fences.
        z_clean = accel_z_clean[chunk]
        z_lo, z_hi = _percentile_fences(z_clean, 5, 95)
        accel_z_clean_no_outliers = z_clean[(z_clean >= z_lo) & (z_clean <= z_hi)]
        accel_z_clean_min = np.min(accel_z_clean_no_outliers)
        accel_z_clean_max = np.max(accel_z_clean_no_outliers)

        sh_env = cardiacX[chunk].sig_sh_env_ds
        hs_sh_no_outliers_max = np.max(sh_env[sh_env <= _percentile_fences(sh_env, 5, 95)[1]])

        lt_sig = cardiacX[chunk].sig_lt_ds
        hs_lt_no_outliers_max = np.max(lt_sig[lt_sig <= _percentile_fences(lt_sig, 5, 95)[1]])

        plot_data = [
            (chunk_temp_s[chunk], chunk_temp[chunk], 'm', [chunk_temp[chunk].min() - 1, chunk_temp[chunk].max() + 1], 'Temp\n(C)', 'line'),
            (chunk_time[chunk], accel_x_chunk[chunk], 'g', [-2, 2], 'Accel\nX (g)', 'line'),
            (chunk_time[chunk], accel_y_chunk[chunk], 'r', [-2, 2], 'Accel\nY (g)', 'line'),
            (chunk_time[chunk], accel_z_chunk[chunk], 'b', [-2, 2], 'Accel\nZ (g)', 'line'),
            (activityX[chunk].time, activityX[chunk].data, 'k', [0.0, 1.1*activityX[chunk].data.max()], 'Phys.\nAct.\n(a.u.)', 'line'),
            (chunk_time[chunk], accel_z_clean[chunk], 'b', [accel_z_clean_min, accel_z_clean_max], 'Accel\nZ\nCln.', 'line'),
            (cardiacX[chunk].time_sig, cardiacX[chunk].sig, 'b', [cardiacX[chunk].sig.min(), cardiacX[chunk].sig.max()], 'H.S.', 'line'),
            (cardiacX[chunk].time_sig, cardiacX[chunk].sig_sh, 'b', [0, cardiacX[chunk].sig_sh.max()], 'H.S.\nFilt.', 'line'),
            (cardiacX[chunk].time_sig_ds, cardiacX[chunk].sig_sh_env_ds, 'b', [0, hs_sh_no_outliers_max], 'H.S.\nShan.\nEnv.', 'line'),
            (cardiacX[chunk].time_sig_ds, cardiacX[chunk].sig_lt_ds, 'b', [0, hs_lt_no_outliers_max], 'H.S.\nLen.\nTrans.', 'line'),
            (cardiacX[chunk].time_HR, cardiacX[chunk].HR, 'r',  [100, 800], 'HR\n(bpm)', 'scatter'),
            (cardiacX[chunk].time_HR, cardiacX[chunk].SQI, 'k', [0, 1], 'HR\nSQI', 'scatter'),
            (respiratoryX[chunk].time_sig, respiratoryX[chunk].sig_prefilt, 'k', [respiratoryX[chunk].sig_prefilt.min(),respiratoryX[chunk].sig_prefilt.max()], 'Resp.\nSig.', 'line'),
            (respiratoryX[chunk].time_sig, respiratoryX[chunk].sig, 'b', [respiratoryX[chunk].sig.min(), respiratoryX[chunk].sig.max()], 'Resp.\nSig.\nFilt.', 'line'),
            # nanmax: windows without a measurable breath interval may hold NaN.
            (respiratoryX[chunk].time_RR, respiratoryX[chunk].RR, 'g', [0, np.nanmax(respiratoryX[chunk].RR)], 'RR\n(brpm)', 'scatter'),
        ]

        # Events recorded inside this chunk (guard in case the event lists are shorter than num_chunks).
        if chunk < len(event_chunk):
            chunk_event_names = np.asarray(event_chunk[chunk])
            chunk_event_times = np.asarray(event_s_chunk[chunk])
        else:
            chunk_event_names = np.array([])
            chunk_event_times = np.array([])
        unique_events = np.unique(chunk_event_names)
        event_colors = plt.cm.tab10(np.linspace(0, 1, len(unique_events)))
        # One legend entry per event type (not one per subplot).
        legend_elements = [Line2D([0], [0], color=c, label=name) for name, c in zip(unique_events, event_colors)]

        plot_idx = 0
        hs_sh_plot_idx = 0

        for i, show_plot in enumerate(plot_list):
            if not show_plot:
                continue
            x_data, y_data, color, y_lim, y_label, plot_type = plot_data[i]

            if plot_type == 'line':
                if y_label == 'H.S.\nShan.\nEnv.' and overlay_resp:
                    hs_sh_plot_idx = plot_idx
                if y_label == 'Resp.\nSig.' and overlay_resp:
                    # Shares the Shannon-envelope panel; no subplot of its own.
                    axes[hs_sh_plot_idx].plot(x_data, y_data, c=color, linewidth=0.5)
                    continue
                axes[plot_idx].plot(x_data, y_data, c=color, linewidth=0.5)
                if y_label == 'H.S.\nShan.\nEnv.' and plot_HR_peaks:
                    for point in cardiacX[chunk].all_pks_sh:
                        axes[plot_idx].scatter(cardiacX[chunk].time_sig_ds[point], cardiacX[chunk].sig_sh_env_ds[point], c='r', marker='o', s=3)
                if y_label == 'Resp.\nSig.\nFilt.' and plot_RR_peaks:
                    for point in respiratoryX[chunk].pks:
                        axes[plot_idx].scatter(respiratoryX[chunk].time_sig[point], respiratoryX[chunk].sig[point], c='m', marker='o', s=3)
            elif plot_type == 'scatter':
                # The SQI mask lines up only with the cardiac (HR/SQI) series; RR has its own time base.
                mask = idx_good if y_label in ('HR\n(bpm)', 'HR\nSQI') else slice(None)
                axes[plot_idx].scatter(x_data[mask], y_data[mask], c=color, s=1, marker='o')

            axes[plot_idx].set_ylim(y_lim)
            axes[plot_idx].set_ylabel(y_label, fontsize=8)
            axes[plot_idx].tick_params(axis='x', labelsize=8)
            axes[plot_idx].tick_params(axis='y', labelsize=8)

            # Vertical marker at each event time, one color per event type.
            for event_name, event_color in zip(unique_events, event_colors):
                for event_time in chunk_event_times[chunk_event_names == event_name]:
                    axes[plot_idx].plot([event_time, event_time], [y_lim[0], y_lim[1]],
                                        color=event_color, linewidth=1, alpha=0.7)

            if i == last_plot_index:
                ticks = np.linspace(chunk_time[chunk][0], chunk_time[chunk][-1], 7)
                tick_labels = [(accel_first_dt + timedelta(seconds=tick)).time().strftime("%H:%M:%S") for tick in ticks]
                axes[plot_idx].set_xticks(ticks)
                axes[plot_idx].set_xticklabels(tick_labels)

            plot_idx += 1

        if legend_elements:
            axes[0].legend(handles=legend_elements, bbox_to_anchor=(0., 1.02, 1., .102), loc='lower right', ncols=2, borderaxespad=0.)
        fig.align_ylabels()
        fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.08, hspace=0.4)

        if scrolling:
            ax_slider = plt.axes([0.12, 0.01, 0.76, 0.03])  # [left, bottom, width, height]
            slider = Slider(ax_slider, 'Time (s)', chunk_time[chunk][0], chunk_time[chunk][-1] - scroll_window, valinit=scroll_window)
            scroll_sliders.append(slider)

            # 1-second ticks spanning the last plotted series (only a short window is visible at once).
            # The subplots share x, so setting ticks on the bottom axis applies to all of them.
            ticks = np.arange(np.floor(min(x_data)), np.ceil(max(x_data)) + 1, 1)
            tick_labels = [(accel_first_dt + timedelta(seconds=tick)).time().strftime("%H:%M:%S") for tick in ticks]
            axes[-1].set_xticks(ticks)
            axes[-1].set_xticklabels(tick_labels)

            update = _make_scroll_update(fig, axes, slider, scroll_window)
            update(0)
            slider.on_changed(update)

            plt.show()

        # Save THIS chunk's figure (the export loop below used to re-save only the last figure).
        if export_to_csv:
            chunk_dir = os.path.join(folder_path, f'chunk_{chunk}')
            os.makedirs(chunk_dir, exist_ok=True)
            png_output_file = os.path.join(chunk_dir, f"{folder_name}_chunk_{chunk}_time_plot.png")
            fig.savefig(png_output_file, dpi=300)


if export_to_csv:
    for chunk in range(num_chunks):
        chunk_dir = os.path.join(folder_path, f'chunk_{chunk}')
        os.makedirs(chunk_dir, exist_ok=True)

        # Raw acceleration, every 100th sample.
        accel_export = np.zeros((len(chunk_time[chunk]), 4))
        accel_export[:,0] = chunk_time[chunk]
        accel_export[:,1] = accel_x_chunk[chunk]
        accel_export[:,2] = accel_y_chunk[chunk]
        accel_export[:,3] = accel_z_chunk[chunk]
        accel_export = accel_export[::100,:]
        accel_export_headers = "time (s),accel X (g),accel Y (g),accel Z (g)"
        accel_export_with_headers = np.vstack([accel_export_headers.split(','), accel_export])
        accel_output_file = os.path.join(chunk_dir, f"{folder_name}_chunk_{chunk}_accel.csv")
        np.savetxt(accel_output_file, accel_export_with_headers, delimiter=',', fmt='%s')

        PA_export = np.zeros((len(activityX[chunk].time), 2))
        PA_export[:,0] = activityX[chunk].time
        PA_export[:,1] = activityX[chunk].data
        PA_export_headers = "time (s),PA (a.u.)"
        PA_export_with_headers = np.vstack([PA_export_headers.split(','), PA_export])
        PA_output_file = os.path.join(chunk_dir, f"{folder_name}_chunk_{chunk}_PA.csv")
        np.savetxt(PA_output_file, PA_export_with_headers, delimiter=',', fmt='%s')

        # NOTE(review): time_HR / time_RR appear to be in seconds although the column says "(ms)";
        # header left unchanged because downstream readers may key on this column name.
        cardiac_df = pd.DataFrame({
            "Time (ms)": np.round(cardiacX[chunk].time_HR, 3),
            "Heart Rate (bpm)": np.round(cardiacX[chunk].HR, 3),
            "SQI (a.u.)": np.round(cardiacX[chunk].SQI, 3),
        })

        # Save to CSV
        cardiac_output_file = os.path.join(chunk_dir, f"{folder_name}_chunk_{chunk}_cardiac.csv")
        cardiac_df.to_csv(cardiac_output_file, index=False)

        respiratory_df = pd.DataFrame({
            "Time (ms)": np.round(respiratoryX[chunk].time_RR, 3),
            "Respiratory Rate (brpm)": np.round(respiratoryX[chunk].RR, 3),
        })

        # Save to CSV
        respiratory_output_file = os.path.join(chunk_dir, f"{folder_name}_chunk_{chunk}_respiratory.csv")
        respiratory_df.to_csv(respiratory_output_file, index=False)

### Combined Chunk Time Plot/CSVs

In [None]:
#Export data and plots for MERGED chunks

def remove_nans(data):
    """Return a NumPy array of the entries of *data* that are neither None nor NaN.

    Used to strip the None chunk separators (and any NaN samples) from merged
    arrays before computing statistics such as percentiles or min/max.
    """
    kept = []
    for value in data:
        if value is None or math.isnan(value):
            continue
        kept.append(value)
    return np.array(kept)

def merge_line(data):
    """Flatten per-chunk sequences into one array with a single None between chunks.

    The None separators make a line plot break between consecutive chunks
    instead of joining them. No separator follows the last chunk; a None
    entry in *data* contributes a single None.
    """
    merged = []
    for position, segment in enumerate(data):
        if position > 0:
            merged.append(None)
        if segment is None:
            merged.append(None)
        else:
            merged.extend(segment)
    return np.array(merged)

def merge_scatter(data):
    """Join per-chunk arrays end-to-end without separators (for scatter-plot series)."""
    joined = np.concatenate(data, axis=0)
    return joined

def merge_object(data):
    """Flatten per-chunk sequences with a None between chunks; identical to merge_line.

    This used to be a verbatim copy of merge_line's body. It now delegates so
    the two cannot drift apart. The name is retained in case other cells call it.
    """
    return merge_line(data)

def merge_peak(data, some_time):
    """Convert per-chunk peak indices into indices of the merge_line()-merged arrays.

    Chunk k's peaks are shifted by the combined length of the preceding chunks'
    arrays in *some_time*, plus k for the one None separator merge_line inserts
    after each preceding chunk.
    """
    shifted = []
    for chunk_no, peaks in enumerate(data):
        preceding_len = sum(map(len, some_time[:chunk_no]))
        shifted.extend(index + chunk_no + preceding_len for index in peaks)
    return np.array(shifted)

def _percentile_fences(values, lo_pct, hi_pct):
    """Return (lower, upper) outlier fences: the lo_pct / hi_pct percentiles of *values*
    pushed outward by 1.5x the distance between them."""
    lo = np.percentile(values, lo_pct)
    hi = np.percentile(values, hi_pct)
    spread = hi - lo
    return lo - 1.5 * spread, hi + 1.5 * spread


# Line series get a None separator between chunks (merge_line); scatter series are concatenated.
merged_chunk_time = merge_line(chunk_time)
merged_chunk_time_ds = merge_line([item.time_sig_ds for item in cardiacX.values()])
merged_chunk_time_rrds = merge_line([item.time_sig for item in respiratoryX.values()])
# Time base of the heart-sound signals (the per-chunk plots use cardiac.time_sig for them too).
merged_time_hs = merge_line([item.time_sig for item in cardiacX.values()])

merged_chunk_temp_s = merge_line(chunk_temp_s)

merged_accel_x_chunk = merge_line(accel_x_chunk)
merged_accel_y_chunk = merge_line(accel_y_chunk)
merged_accel_z_chunk = merge_line(accel_z_chunk)
merged_chunk_temp = merge_line(chunk_temp)
merged_accel_z_clean = merge_line(accel_z_clean)

merged_hs = merge_line([item.sig for item in cardiacX.values()])

merged_hs_sh_prefilt = merge_line([item.sig_sh for item in cardiacX.values()])
merged_hs_sh = merge_line([item.sig_sh_env_ds for item in cardiacX.values()])
merged_hs_lt = merge_line([item.sig_lt_ds for item in cardiacX.values()])

merged_time_activity = merge_line([item.time for item in activityX.values()])
merged_activity = merge_line([item.data for item in activityX.values()])

merged_time_HR = merge_scatter([item.time_HR for item in cardiacX.values()])
merged_HR = merge_scatter([item.HR for item in cardiacX.values()])
merged_SQI = merge_scatter([item.SQI for item in cardiacX.values()])

merged_time_resp = merge_line([item.time_sig for item in respiratoryX.values()])
merged_resp_prefilt = merge_line([item.sig_prefilt for item in respiratoryX.values()])
merged_resp = merge_line([item.sig for item in respiratoryX.values()])
merged_time_RR = merge_scatter([item.time_RR for item in respiratoryX.values()])
merged_RR = merge_scatter([item.RR for item in respiratoryX.values()])

# Peak indices re-based onto the merged (None-separated) arrays.
merged_all_pks_resp = merge_peak([item.pks for item in respiratoryX.values()], [item.time_sig for item in respiratoryX.values()])
merged_all_matched = merge_peak([item.all_pks_matched for item in cardiacX.values()], [item.time_sig_ds for item in cardiacX.values()])

merged_all_pks_sh = merge_peak([item.all_pks_sh for item in cardiacX.values()], [item.time_sig_ds for item in cardiacX.values()])
merged_all_pks_lt = merge_peak([item.all_pks_lt for item in cardiacX.values()], [item.time_sig_ds for item in cardiacX.values()])

if len(event_chunk) > 0:
    merged_event_chunk = merge_scatter(event_chunk)
    merged_event_s_chunk = merge_scatter(event_s_chunk)
else:
    merged_event_chunk = []
    merged_event_s_chunk = []

idx_good = (merged_SQI >= .0)

plot_list = [plot_temperature, plot_accel_x, plot_accel_y, plot_accel_z, plot_activity, plot_accel_z_clean, plot_hs, plot_hs_sh_raw,
            plot_hs_sh_filtered, plot_hs_lt, plot_HR, plot_SQI, plot_resp_sig_prefilt, plot_resp_sig, plot_RR]

# Highest index of the displayed plots (0 when none are enabled).
last_plot_index = max((index for index, value in enumerate(plot_list) if value), default=0)

# Same rule as the per-chunk figures: the raw resp. envelope is overlaid on the Shannon-envelope
# panel (one subplot fewer) only when it is derived from heart sounds.
overlay_resp = bool(plot_hs_sh_filtered & plot_resp_sig_prefilt & overlay_resp_sig_prefilt_w_hs & use_hs_to_calculate_RR)
num_plots = sum(plot_list) - 1 if overlay_resp else sum(plot_list)

# Output folder for the merged PNG and CSVs; defined up front so CSV export also works without plotting.
merged_dir = os.path.join(folder_path, 'chunk_all')
if plot_output or export_to_csv:
    os.makedirs(merged_dir, exist_ok=True)

if plot_output and num_plots > 0:
    fig, axes = plt.subplots(num_plots, 1, sharex=True)

    if num_plots == 1:
        axes = [axes]

    # None/NaN-free copies, computed once (the merged arrays can be very long).
    z_clean_vals = remove_nans(merged_accel_z_clean)
    hs_vals = remove_nans(merged_hs)
    hs_prefilt_vals = remove_nans(merged_hs_sh_prefilt)
    hs_sh_vals = remove_nans(merged_hs_sh)
    hs_lt_vals = remove_nans(merged_hs_lt)
    temp_vals = remove_nans(merged_chunk_temp)
    resp_prefilt_vals = remove_nans(merged_resp_prefilt)
    resp_vals = remove_nans(merged_resp)

    # Robust y-limits that ignore extreme outliers.
    z_lo, z_hi = _percentile_fences(z_clean_vals, 5, 95)
    accel_z_clean_no_outliers = z_clean_vals[(z_clean_vals >= z_lo) & (z_clean_vals <= z_hi)]
    accel_z_clean_min = np.min(accel_z_clean_no_outliers)
    accel_z_clean_max = np.max(accel_z_clean_no_outliers)

    hs_sh_prefilt_no_outliers_max = np.max(hs_prefilt_vals[hs_prefilt_vals <= _percentile_fences(hs_prefilt_vals, 1, 99)[1]])
    hs_sh_no_outliers_max = np.max(hs_sh_vals[hs_sh_vals <= _percentile_fences(hs_sh_vals, 5, 95)[1]])
    hs_lt_no_outliers_max = np.max(hs_lt_vals[hs_lt_vals <= _percentile_fences(hs_lt_vals, 5, 95)[1]])

    plot_data = [
        (merged_chunk_temp_s, merged_chunk_temp, 'm', [np.min(temp_vals) - 1, np.max(temp_vals) + 1], 'Temp\n(C)', 'line'),
        (merged_chunk_time, merged_accel_x_chunk, 'g', [-2, 2], 'Accel\nX (g)', 'line'),
        (merged_chunk_time, merged_accel_y_chunk, 'r', [-2, 2], 'Accel\nY (g)', 'line'),
        (merged_chunk_time, merged_accel_z_chunk, 'b', [-2, 2], 'Accel\nZ (g)', 'line'),
        (merged_time_activity, merged_activity, 'k', [0.0, 1.1*remove_nans(merged_activity).max()], 'Phys.\nAct.\n(a.u.)', 'line'),
        (merged_chunk_time, merged_accel_z_clean, 'b', [accel_z_clean_min, accel_z_clean_max], 'Accel\nZ\nCln.(g)', 'line'),

        (merged_time_hs, merged_hs, 'b', [np.min(hs_vals), np.max(hs_vals)], 'H.S.', 'line'),
        (merged_time_hs, merged_hs_sh_prefilt, 'b', [0, hs_sh_prefilt_no_outliers_max], 'H.S.\nFilt.', 'line'),
        (merged_chunk_time_ds, merged_hs_sh, 'b', [0, hs_sh_no_outliers_max], 'H.S.\nShan.\nEnv.', 'line'),
        (merged_chunk_time_ds, merged_hs_lt, 'b', [0, hs_lt_no_outliers_max], 'H.S.\nLen.\nTrans.', 'line'),
        (merged_time_HR, merged_HR, 'r',  [100, 800], 'HR\n(bpm)', 'scatter'),
        (merged_time_HR, merged_SQI, 'k', [0, 1], 'HR\nSQI', 'scatter'),

        (merged_time_resp, merged_resp_prefilt, 'g', [np.min(resp_prefilt_vals), np.max(resp_prefilt_vals)], 'Resp.\nSig.', 'line'),
        (merged_time_resp, merged_resp, 'g', [np.min(resp_vals), np.max(resp_vals)], 'Resp.\nSig.\nFilt.', 'line'),
        # nanmax: windows without a measurable breath interval may hold NaN.
        (merged_time_RR, merged_RR, 'g', [0, np.nanmax(merged_RR)], 'RR\n(brpm)', 'scatter'),
    ]

    # Event markers: one color per event type, one legend entry per type.
    merged_event_names = np.asarray(merged_event_chunk)
    merged_event_times = np.asarray(merged_event_s_chunk)
    unique_events = np.unique(merged_event_names)
    event_colors = plt.cm.tab10(np.linspace(0, 1, len(unique_events)))
    legend_elements = [Line2D([0], [0], color=c, label=name) for name, c in zip(unique_events, event_colors)]

    plot_idx = 0
    hs_sh_plot_idx = 0

    for i, show_plot in enumerate(plot_list):
        if not show_plot:
            continue
        x_data, y_data, color, y_lim, y_label, plot_type = plot_data[i]

        if plot_type == 'line':
            if y_label == 'H.S.\nShan.\nEnv.' and overlay_resp:
                hs_sh_plot_idx = plot_idx
            if y_label == 'Resp.\nSig.' and overlay_resp:
                # Drawn on the Shannon-envelope panel; this entry has no subplot of its own.
                axes[hs_sh_plot_idx].plot(x_data, y_data, c=color, linewidth=0.5)
                continue
            axes[plot_idx].plot(x_data, y_data, c=color, linewidth=0.5)
            # One scatter call per series; indices refer to the merged, None-separated arrays.
            if y_label == 'H.S.\nShan.\nEnv.' and plot_HR_peaks and len(merged_all_pks_sh):
                pts = merged_all_pks_sh.astype(int)
                axes[plot_idx].scatter(merged_chunk_time_ds[pts].astype(float), merged_hs_sh[pts].astype(float), c='r', marker='o', s=3)
            if y_label == 'Resp.\nSig.\nFilt.' and plot_RR_peaks and len(merged_all_pks_resp):
                pts = merged_all_pks_resp.astype(int)
                axes[plot_idx].scatter(merged_time_resp[pts].astype(float), merged_resp[pts].astype(float), c='m', marker='o', s=3)
        elif plot_type == 'scatter':
            # The SQI mask lines up only with the cardiac (HR/SQI) series; RR has its own time base.
            mask = idx_good if y_label in ('HR\n(bpm)', 'HR\nSQI') else slice(None)
            axes[plot_idx].scatter(np.array(x_data[mask]), np.array(y_data[mask]), c=color, s=1, marker='o')

        axes[plot_idx].set_ylim(y_lim)
        axes[plot_idx].set_ylabel(y_label, fontsize=8)
        axes[plot_idx].tick_params(axis='x', labelsize=8)
        axes[plot_idx].tick_params(axis='y', labelsize=8)

        for event_name, event_color in zip(unique_events, event_colors):
            for event_time in merged_event_times[merged_event_names == event_name]:
                axes[plot_idx].plot([event_time, event_time], [y_lim[0], y_lim[1]],
                                    color=event_color, linewidth=0.5, alpha=0.7)

        if i == last_plot_index:
            ticks = np.linspace(merged_chunk_time[0], merged_chunk_time[-1], 7)
            tick_labels = [(accel_first_dt + timedelta(seconds=tick)).time().strftime("%H:%M:%S") for tick in ticks]
            axes[plot_idx].set_xticks(ticks)
            axes[plot_idx].set_xticklabels(tick_labels)

        plot_idx += 1

    if legend_elements:
        axes[0].legend(handles=legend_elements, bbox_to_anchor=(0., 1.02, 1., .102), loc='lower right', ncols=2, borderaxespad=0.)

    fig.align_ylabels()
    fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.08, hspace=0.4)

    png_output_file = os.path.join(merged_dir, f"{folder_name}_chunk_all_time_plot.png")
    fig.savefig(png_output_file, dpi=300)

if export_to_csv:
    # Raw acceleration, every 100th sample (None separators become NaN rows).
    accel_export = np.zeros((len(merged_chunk_time), 4))
    accel_export[:,0] = merged_chunk_time
    accel_export[:,1] = merged_accel_x_chunk
    accel_export[:,2] = merged_accel_y_chunk
    accel_export[:,3] = merged_accel_z_chunk
    accel_export = accel_export[::100,:]
    accel_export_headers = "time (s),accel X (g),accel Y (g),accel Z (g)"
    accel_export_with_headers = np.vstack([accel_export_headers.split(','), accel_export])
    accel_output_file = os.path.join(merged_dir, f"{folder_name}_chunk_all_accel.csv")
    np.savetxt(accel_output_file, accel_export_with_headers, delimiter=',', fmt='%s')

    PA_export = np.zeros((len(merged_time_activity), 2))
    PA_export[:,0] = merged_time_activity
    PA_export[:,1] = merged_activity
    PA_export_headers = "time (s),PA (a.u.)"
    PA_export_with_headers = np.vstack([PA_export_headers.split(','), PA_export])
    PA_output_file = os.path.join(merged_dir, f"{folder_name}_chunk_all_PA.csv")
    np.savetxt(PA_output_file, PA_export_with_headers, delimiter=',', fmt='%s')

    # NOTE(review): time_HR / time_RR appear to be in seconds although the column says "(ms)";
    # header left unchanged because downstream readers may key on this column name.
    cardiac_df = pd.DataFrame({
        "Time (ms)": np.round(merged_time_HR, 3),
        "Heart Rate (bpm)": np.round(merged_HR, 3),
        "SQI (a.u.)": np.round(merged_SQI, 3),
    })

    # Save to CSV
    cardiac_output_file = os.path.join(merged_dir, f"{folder_name}_chunk_all_cardiac.csv")
    cardiac_df.to_csv(cardiac_output_file, index=False)

    respiratory_df = pd.DataFrame({
        "Time (ms)": np.round(merged_time_RR, 3),
        "Respiratory Rate (brpm)": np.round(merged_RR, 3),
    })

    # Save to CSV
    respiratory_output_file = os.path.join(merged_dir, f"{folder_name}_chunk_all_respiratory.csv")
    respiratory_df.to_csv(respiratory_output_file, index=False)

### Combined Chunk Radar Plots

In [None]:
import os
import re
import pandas as pd
import numpy as np
import math
from datetime import datetime, timedelta
import plotly.graph_objects as go
from tkinter import Tk, filedialog
import importlib.util
import subprocess
import sys
import pkg_resources

def ensure_kaleido_version(required_version="0.1.0.post1", package="kaleido"):
    """Make sure *package* is installed at exactly *required_version*.

    If the package is missing, runs ``python -m pip install <package>==<version>`` with
    the current interpreter; if a different version is installed, does the same with
    ``--force-reinstall``.

    Args:
        required_version: Exact version string that must be installed.
        package: Distribution name to check (defaults to ``"kaleido"``).

    Returns:
        True if pip was run (the kernel must be restarted before the newly installed
        version is importable), False if the required version was already present.

    Raises:
        subprocess.CalledProcessError: if the pip install command fails.
    """
    # Stdlib replacement for the deprecated pkg_resources distribution API.
    from importlib import metadata

    label = package.capitalize()
    requirement = f"{package}=={required_version}"
    try:
        current_version = metadata.version(package)
    except metadata.PackageNotFoundError:
        print(f"{label} not installed. Installing required version...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", requirement, "--quiet"])
        print(f"{label} installed successfully. Please restart the kernel.")
        return True

    if current_version == required_version:
        return False

    print(f"{label} version mismatch: {current_version} installed, {required_version} required.")
    print("Reinstalling correct version...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", requirement, "--quiet", "--force-reinstall"])
    print(f"{label} installed successfully. Please restart the kernel.")
    return True

# Run the kaleido version check when this cell executes; kaleido is presumably needed for
# plotly's fig.write_image PNG export used by make_radar_plot below. If pip had to install
# it, the kernel must be restarted before the new version is usable.
ensure_kaleido_version()

# Colour theme per radar metric, keyed by the value column passed to make_radar_plot.
# 'fill' is the semi-transparent (alpha 0.4) area colour; 'line' is the same colour, opaque.
RADAR_THEMES = dict(
    cardiac=dict(fill='rgba(0, 123, 255, 0.4)', line='rgb(0, 123, 255)'),        # sky blue
    sqi=dict(fill='rgba(155, 89, 182, 0.4)', line='rgb(155, 89, 182)'),         # violet purple
    PA=dict(fill='rgba(72, 201, 176, 0.4)', line='rgb(72, 201, 176)'),          # mint green
    respiratory=dict(fill='rgba(255, 94, 77, 0.4)', line='rgb(255, 94, 77)'),   # sunset orange
)

def polar_sector_path(start_hour, end_hour, radius=1.0, resolution=48):
    """Build an SVG path string tracing an arc of a 24-hour dial.

    The day is sampled at ``resolution + 1`` evenly spaced hours (0 to 24 inclusive).
    Samples with ``start_hour <= hour <= end_hour`` are placed on a circle of *radius*
    centred at (0.5, 0.5), at angle ``90° - hour * 15°`` (hour 0 at maximum y, hour 6
    at maximum x). The points are joined with ``L`` commands and closed with ``Z``,
    i.e. a straight chord back to the first point; the centre is not part of the path.

    Returns:
        The path string, or ``""`` when no sampled hour lies in the range (this
        includes ``start_hour > end_hour``; wrap-around past midnight is not handled).
    """
    in_range_hours = [
        hour
        for hour in (step / resolution * 24 for step in range(resolution + 1))
        if start_hour <= hour <= end_hour
    ]
    if not in_range_hours:
        return ""

    def to_xy(hour):
        theta = math.radians(90 - (hour / 24) * 360)
        return 0.5 + radius * math.cos(theta), 0.5 + radius * math.sin(theta)

    (x0, y0), *rest = map(to_xy, in_range_hours)
    commands = [f"M {x0},{y0}"]
    commands.extend(f"L {x},{y}" for x, y in rest)
    commands.append("Z")
    return " ".join(commands)


def extract_start_time_from_filename(filename):
    """Parse the recording start time embedded in a file name.

    Looks for the first ``YYYYMMDD_THHMMSS`` token, e.g. ``20240102_T030405``.

    Returns:
        A naive ``datetime`` for that token, or ``None`` if the name has no such token.

    Raises:
        ValueError: if the token matches the digit pattern but is not a valid date/time.
    """
    match = re.search(r'(\d{8}_T\d{6})', filename)
    if match:
        return datetime.strptime(match.group(1), '%Y%m%d_T%H%M%S')
    return None


def _read_timestamped_csv(folder, keyword, columns):
    """Load the first suitable CSV in *folder* whose name contains *keyword*.

    Candidates are ``.csv`` files containing *keyword*, tried in sorted name order so the
    choice is deterministic. A candidate is skipped if it is empty, does not have exactly
    ``len(columns)`` columns (e.g. another export whose name happens to contain *keyword*),
    or has no start-time token in its name. The accepted file's columns are renamed to
    *columns* (the first must be ``'time'``), rows with a missing time are dropped, and a
    ``real_time`` column ``start_time + time seconds`` is added.

    Returns:
        The loaded DataFrame, or ``None`` if no file qualifies.
    """
    print(f"Searching folder '{os.path.basename(folder)}' for keyword '{keyword}'")
    for file in sorted(os.listdir(folder)):
        if keyword not in file or not file.endswith('.csv'):
            continue
        print(f"Found file: {file}")
        try:
            df = pd.read_csv(os.path.join(folder, file))
        except pd.errors.EmptyDataError:  # zero-byte / header-less file
            continue

        # Renaming to a different number of columns would raise, so skip mismatches.
        if df.empty or df.shape[1] != len(columns):
            continue

        start_time = extract_start_time_from_filename(file)
        if start_time is None:
            continue

        df.columns = list(columns)
        # A missing time cannot be placed on the clock (timedelta(seconds=NaN) fails);
        # the PA export has been seen to contain such rows.
        df = df[df['time'].notna()].copy()
        df['real_time'] = [start_time + timedelta(seconds=float(t)) for t in df['time']]
        return df
    return None


def load_data(folder, keyword):
    """Load a two-column ``time, <keyword>`` CSV from *folder* (e.g. PA, respiratory).

    Returns:
        DataFrame with columns ``['real_time', keyword]``; empty (same columns) if no
        suitable file is found.
    """
    df = _read_timestamped_csv(folder, keyword, ['time', keyword])
    if df is None:
        return pd.DataFrame(columns=['real_time', keyword])
    return df[['real_time', keyword]]


def load_cardiac_data(folder):
    """Load the three-column ``time, heart rate, SQI`` cardiac CSV from *folder*.

    NOTE(review): the cardiac export labels its time column "Time (ms)", but it is
    converted here as seconds (unchanged from before) — confirm the unit.

    Returns:
        DataFrame with columns ``['real_time', 'cardiac', 'sqi']``; empty (same columns)
        if no suitable file is found.
    """
    df = _read_timestamped_csv(folder, 'cardiac', ['time', 'cardiac', 'sqi'])
    if df is None:
        return pd.DataFrame(columns=['real_time', 'cardiac', 'sqi'])
    return df[['real_time', 'cardiac', 'sqi']]

def is_nighttime(label):
    """Return True if an ``HH:MM`` radar label's hour is 0-5 or 18-23.

    Leading ``⸺`` characters (the prefix used to mark minor, unlabelled spokes) are
    ignored. Only the hour before the first ':' is inspected.

    Raises:
        ValueError: if the text before the first ':' is not an integer.
    """
    hour_text, _, _ = label.lstrip('⸺').partition(':')
    hour = int(hour_text)
    return 0 <= hour < 6 or 18 <= hour <= 23

def make_radar_plot(df, value_col, title, save_file_name):
    """Draw, and optionally save, a 48-spoke radar chart of *value_col* over time.

    The span from the first to the last ``real_time`` in *df* is split into 48 equal-width
    time blocks. Each block's mean of *value_col* becomes one spoke, labelled with the
    block's start time (``HH:MM``). Blocks without data are plotted as 0. Only every 4th
    spoke is labelled, and labels whose hour is 0-5 or 18-23 are drawn in black, the rest
    in grey. Colours come from ``RADAR_THEMES[value_col]`` (black fallback).

    The figure is shown; if the module-level ``chunk_all_folder`` is set and
    *save_file_name* is truthy, it is also written to
    ``<chunk_all_folder>/<save_file_name>.png``. Save errors are printed, not raised.

    Args:
        df: DataFrame with a datetime ``real_time`` column and a numeric *value_col*.
            It is not modified.
        value_col: Column to average; also the ``RADAR_THEMES`` key.
        title: Figure title.
        save_file_name: PNG base name without extension, or falsy to skip saving.
    """
    n_blocks = 48  # number of spokes / equal-width time blocks

    if df.empty:
        print(f"No data for {title}")
        return

    start_time = df['real_time'].min()
    end_time = df['real_time'].max()
    snippet = (end_time - start_time) / n_blocks
    if snippet <= pd.Timedelta(0):
        # All samples share one timestamp, so the blocks would have zero width.
        print(f"Not enough time span to plot {title}")
        return

    theme = RADAR_THEMES.get(value_col, {})
    fill_color = theme.get('fill', 'rgba(0,0,0,0.3)')
    line_color = theme.get('line', 'rgb(0,0,0)')

    # Bin on a copy: callers pass column slices such as df_cardiac[['real_time', 'cardiac']],
    # and adding a column to a slice mutates it / raises SettingWithCopyWarning.
    binned = df.copy()
    # Block index per sample. The sample(s) at end_time land on index n_blocks, which has no
    # spoke of its own, so clamp them into the last block instead of losing them.
    block_index = ((binned['real_time'] - start_time) // snippet).clip(upper=n_blocks - 1)
    binned['label'] = (start_time + block_index * snippet).dt.strftime('%H:%M')

    block_labels = [(start_time + i * snippet).strftime('%H:%M') for i in range(n_blocks)]
    # NOTE(review): for spans over 24 h, different blocks can share an HH:MM label; such
    # blocks are averaged together and the label repeats on several spokes.
    means = binned.groupby('label')[value_col].mean().reindex(block_labels).fillna(0)
    values = means.tolist()

    # The '⸺' prefix marks minor spokes; their ticks are filtered out below.
    categories = [
        label if i % 4 == 0 else f'⸺{label}'
        for i, label in enumerate(block_labels)
    ]

    # Repeat the first point so the polygon closes.
    categories += [categories[0]]
    values += [values[0]]

    fig = go.Figure(
        data=go.Scatterpolar(
            r=values,
            theta=categories,
            fill='toself',
            name=title,
            fillcolor=fill_color,
            # Apply the theme's outline colour (it was previously computed but never used).
            line=dict(color=line_color),
        )
    )

    positive = [v for v in values if v > 0]
    fig.update_layout(
        polar=dict(
            radialaxis=dict(
                angle=-270,
                visible=True,
                range=[
                    min(positive) * 0.95 if positive else 0,
                    max(values) * 1.05,
                ],
            ),
            angularaxis=dict(
                tickmode='array',
                tickvals=[i for i, label in enumerate(categories) if not label.startswith('⸺')],
                ticktext=[
                    f"<span style='color:#8f8f8f; font-weight:bold'>{label}</span>" if not is_nighttime(label)
                    else f"<span style='color:#000000; font-weight:bold'>{label}</span>"
                    for label in categories if not label.startswith('⸺')
                ],
                rotation=90,
                direction='clockwise',
            ),
        ),
        showlegend=False,
        title=title,
    )

    fig.show()

    if chunk_all_folder and save_file_name:
        save_path = os.path.join(chunk_all_folder, f"{save_file_name}.png").replace("\\", "/")
        print(f"Saving radar chart to {save_path}")
        try:
            fig.write_image(save_path, width=800, height=600, format='png')
            print(f"Saved radar chart to {save_path}")
        except Exception as e:
            # Best-effort export (e.g. kaleido missing): report and keep the notebook running.
            print(f"❌ Failed to save radar chart: {e}")
    
def select_folder():
    """Ask the user for the trial folder with a Tk directory dialog.

    Returns:
        ``(trial_folder_name, chunk_all_folder)``: the selected folder's base name and
        ``<selected>/chunk_all`` with forward slashes. If the dialog is cancelled,
        returns ``("", "")`` so a caller's ``if chunk_all_folder:`` check sees no selection.
    """
    root = Tk()
    root.withdraw()
    root.attributes("-topmost", True)
    try:
        trial_folder = filedialog.askdirectory(title="Select folder containing chunk folder(s)")
    finally:
        # Destroy the hidden root so repeated runs don't accumulate Tk instances.
        root.destroy()

    # Cancel yields an empty value; without this guard the result would be the truthy
    # relative path "chunk_all".
    if not trial_folder:
        return "", ""

    trial_folder_name = os.path.basename(trial_folder)
    chunk_all_folder = os.path.join(trial_folder, 'chunk_all').replace("\\", "/")
    return trial_folder_name, chunk_all_folder

# Resolve the chunk_all folder: reuse folder_path / folder_name from the pipeline cells
# above when they exist in this session, otherwise ask the user to pick the trial folder.
try:
    print("trying to get path")
    trial_folder_name = folder_name
    chunk_all_folder = os.path.join(folder_path, "chunk_all").replace("\\", "/")
    print(f"chunk_all_folder: {chunk_all_folder}")
except NameError:
    print("could not get path, requesting manual entry")
    trial_folder_name, chunk_all_folder = select_folder()

print(f"trial folder name: {trial_folder_name}")

if not chunk_all_folder:
    print("No folder selected.")
elif not os.path.isdir(chunk_all_folder):
    # Avoid a FileNotFoundError from os.listdir inside the loaders.
    print(f"❌ chunk_all folder not found: {chunk_all_folder}")
else:
    df_cardiac = load_cardiac_data(chunk_all_folder)
    df_pa = load_data(chunk_all_folder, 'PA')
    df_resp = load_data(chunk_all_folder, 'respiratory')

    # (data, value column, title, file-name suffix). The cardiac slices are copied so the
    # plot can add helper columns without touching df_cardiac.
    radar_specs = [
        (df_cardiac[['real_time', 'cardiac']].copy(), 'cardiac', 'Average Heart Rate (BPM)', 'cardiac'),
        (df_cardiac[['real_time', 'sqi']].copy(), 'sqi', 'Heart Rate Signal Quality Index (SQI)', 'SQI'),
        (df_pa, 'PA', 'Average Physical Activity (a.u.)', 'PA'),
        (df_resp, 'respiratory', 'Average Respiratory Rate (BrPM)', 'respiratory'),
    ]
    for radar_df, radar_col, radar_title, radar_suffix in radar_specs:
        make_radar_plot(
            radar_df,
            radar_col,
            radar_title,
            f"{trial_folder_name}_chunk_all_radar_plot_{radar_suffix}",
        )