In [None]:
import pandas as pd
import os
import numpy as np
import pytz
import plotly.graph_objects as go
import plotly.express as px
import plotly.subplots as sp
import pytz as tz
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from scipy.signal import butter, filtfilt
from sklearn.discriminant_analysis import StandardScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA

In [None]:
RESPECK_FILE = '../data/bishkek_csr/03_train_ready/respeck/05-04-2025_respeck.csv'
PSG_FILE = '../data/bishkek_csr/03_train_ready/nasal_files/05-04-2025_nasal.csv'
LABELS_FILE = '../data/bishkek_csr/03_train_ready/event_exports/05-04-2025_event_export.csv'
# NOTE(review): the output name carries a different date (08-05-2025) than the
# inputs (05-04-2025) - confirm this is intended.
OUTPUT_FILE = './08-05-2025_respeck_features.csv'

# --- Load Data ---
print("Loading data...")

# Every file stores its time column in milliseconds; the values are parsed,
# marked as UTC and converted to Bishkek local time.
# (This rebinds `tz`, which the import cell also uses as an alias for pytz.)
tz = pytz.timezone('Asia/Bishkek')

respeck_df = pd.read_csv(RESPECK_FILE)
respeck_df['timestamp'] = (
    pd.to_datetime(respeck_df['alignedTimestamp'], unit='ms')
    .dt.tz_localize('UTC')
    .dt.tz_convert(tz)
)

psg_df = pd.read_csv(PSG_FILE)
psg_df['timestamp'] = (
    pd.to_datetime(psg_df['UnixTimestamp'], unit='ms')
    .dt.tz_localize('UTC')
    .dt.tz_convert(tz)
)

labels_df = pd.read_csv(LABELS_FILE)
labels_df['timestamp'] = (
    pd.to_datetime(labels_df['UnixTimestamp'], unit='ms')
    .dt.tz_localize('UTC')
    .dt.tz_convert(tz)
)

# Keep only the period covered by both the RESpeck and the PSG recording.
start_time_respeck = respeck_df['timestamp'].min()
end_time_respeck = respeck_df['timestamp'].max()

start_time_psg = psg_df['timestamp'].min()
end_time_psg = psg_df['timestamp'].max()

overlap_start = max(start_time_respeck, start_time_psg)
overlap_end = min(end_time_respeck, end_time_psg)

print(overlap_start)
print(overlap_end)

# .copy() makes the trimmed frames independent of the loaded ones, so the
# column assignments performed later do not write into a slice
# (pandas SettingWithCopyWarning).
respeck_df = respeck_df[
    (respeck_df['timestamp'] >= overlap_start) & (respeck_df['timestamp'] <= overlap_end)
].copy()
psg_df = psg_df[
    (psg_df['timestamp'] >= overlap_start) & (psg_df['timestamp'] <= overlap_end)
].copy()

# Dynamically calculate the sampling rate from the timestamps
time_diffs_ms = respeck_df['alignedTimestamp'].diff().median()
if pd.isna(time_diffs_ms) or time_diffs_ms <= 0:
    # Fewer than two rows, or duplicated/decreasing timestamps: there is no
    # usable sampling interval, so the filter cannot be designed.
    print("  - WARNING: Skipping filter. Could not determine a valid sampling rate.")
else:
    fs = 1000.0 / time_diffs_ms  # Sampling frequency in Hz
    print(f"    - Calculated sampling rate: {fs:.2f} Hz")

    # Define filter parameters
    lowcut = 0.1   # Lower cutoff frequency in Hz
    highcut = 1.5  # Upper cutoff frequency in Hz
    order = 2      # Filter order (2 is a good choice to avoid distortion)

    try:
        # Design the Butterworth bandpass filter (cutoffs normalised to Nyquist)
        nyquist = 0.5 * fs
        low = lowcut / nyquist
        high = highcut / nyquist
        b, a = butter(order, [low, high], btype='band')

        # 1. Keep the unfiltered signal for reference
        respeck_df['original_breathingSignal'] = respeck_df['breathingSignal']

        # 2. Apply the filter and OVERWRITE the 'breathingSignal' column with the clean data
        # NOTE(review): NaN samples in 'breathingSignal' are expected to propagate
        # through filtfilt - confirm the signal is gap-free at this point.
        respeck_df['breathingSignal'] = filtfilt(b, a, respeck_df['breathingSignal'])
    except ValueError as e:
        print(f"  - WARNING: Skipping session. Filter could not be applied. Error: {e}")


In [None]:
# Jack's Util file

def nans(dims):
    """Return a new float array of shape ``dims`` with every entry set to NaN."""
    return np.full(dims, np.nan)

def rms(signal):
    """Return the root-mean-square of ``signal``.

    ``signal`` must support element-wise ``** 2`` (e.g. a NumPy array).
    """
    squared = signal ** 2
    mean_square = np.mean(squared)
    return np.sqrt(mean_square)

def rmsHamming(signal):
    """Return the RMS of ``signal`` weighted by a Hamming window of equal length.

    The squared samples are averaged with the window values as weights, and
    the square root of that weighted mean is returned.
    """
    window = np.hamming(len(signal))
    squared = signal ** 2

    numerator = 0.0
    denominator = 0.0
    # Accumulate sample by sample, in order, pairing each square with its weight.
    for square, weight in zip(squared, window):
        numerator += square * weight
        denominator += weight

    return np.sqrt(numerator / denominator)

def findIslandLimits(signal, minIslandLength=0, minIslandGap=0):
    """Find runs ("islands") of consecutive falsy entries in ``signal``.

    Each island is returned as ``(start, end)`` with ``end`` exclusive.
    A run consisting of a single entry is never reported.

    signal          -- sequence whose falsy entries belong to islands
    minIslandLength -- islands with ``end - start`` below this are dropped
    minIslandGap    -- when positive, a new island whose
                       ``start - previous_end - 1`` is below this value is
                       merged into the previous island
    """
    found = []
    run_start = None
    run_end = None
    last_index = len(signal) - 1

    for idx in range(len(signal)):
        closed = False

        if signal[idx]:
            # A truthy entry terminates whatever run was in progress.
            if run_start is not None:
                if run_end is None:
                    # the run had a single entry: forget it
                    run_start = None
                else:
                    closed = True
        elif run_start is None:
            run_start = idx
        else:
            run_end = idx + 1
            # a run touching the final entry has to be closed here
            closed = idx == last_index

        if not closed:
            continue

        if minIslandGap > 0 and found:
            prev_start, prev_end = found[-1]
            if run_start - prev_end - 1 < minIslandGap:
                # merge the new island with the previous one
                found[-1] = (prev_start, run_end)
            else:
                found.append((run_start, run_end))
        else:
            found.append((run_start, run_end))

        run_start = None
        run_end = None

    # now return only the islands that are long enough
    return [island for island in found if (island[1] - island[0]) >= minIslandLength]

def calculateThresholdLevels(signal, rmsBackwardLength, rmsForwardLength, rmsMultiplier, symmetrical):
    """Compute per-sample thresholds from a sliding RMS window.

    For sample ``i`` the window spans ``signal[i - rmsBackwardLength]`` up to
    and including ``signal[i + rmsForwardLength - 1]``.  The threshold is the
    RMS of the window samples multiplied by ``rmsMultiplier``.

    Parameters
    ----------
    signal : sequence of float
        Samples; NaN entries are allowed.
    rmsBackwardLength, rmsForwardLength : int
        Window extent before the current sample / starting at it.
    rmsMultiplier : float
        Scale applied to the RMS.
    symmetrical : bool
        If true, one RMS over all non-NaN window samples is used and column 1
        is the negation of column 0.  Otherwise samples ``>= 0`` and ``< 0``
        are accumulated separately.

    Returns
    -------
    numpy.ndarray, shape ``(len(signal), 2)``
        Column 0 is the positive threshold, column 1 the negative one.  A
        value is NaN near both ends of the signal, while a NaN sample lies in
        the window, and when the window holds no sample of the relevant sign
        (previously this divided by zero).  All NaN if the signal is shorter
        than the window.
    """
    thresholds = np.full((len(signal), 2), np.nan)

    windowLength = rmsBackwardLength + rmsForwardLength
    if len(signal) < windowLength:
        return thresholds

    def scaledRms(sumOfSquares, count):
        # No samples -> no RMS.  Dividing by a zero count would warn and can
        # turn floating-point residue in the running sum into +/-inf.
        if count <= 0:
            return np.nan
        return np.sqrt(sumOfSquares / count) * rmsMultiplier

    firstInvalidTail = len(signal) - rmsForwardLength
    lastNanIndex = np.nan  # index of the most recent NaN that entered the window

    if not symmetrical:
        # Pre-load all but the last sample of the first full window.
        posValues = []
        negValues = []
        for i in range(windowLength - 1):
            if signal[i] >= 0:
                posValues.append(signal[i])
            elif signal[i] < 0:
                negValues.append(signal[i])
            else:  # NaN compares false both ways
                lastNanIndex = i

        sumOfSquaresPos = np.sum(np.array(posValues) ** 2)
        posCount = len(posValues)
        sumOfSquaresNeg = np.sum(np.array(negValues) ** 2)
        negCount = len(negValues)

        for i in range(len(signal)):
            if i < rmsBackwardLength or i >= firstInvalidTail:
                continue  # row stays NaN

            # The newest sample enters the window ...
            newValue = signal[i + rmsForwardLength - 1]
            if np.isnan(newValue):
                lastNanIndex = i + rmsForwardLength - 1
            elif newValue >= 0:
                sumOfSquaresPos += newValue ** 2
                posCount += 1
            else:
                sumOfSquaresNeg += newValue ** 2
                negCount += 1

            # ... the thresholds are read unless a NaN is still inside ...
            nanInWindow = not np.isnan(lastNanIndex) and i - lastNanIndex <= rmsBackwardLength
            if not nanInWindow:
                thresholds[i, 0] = scaledRms(sumOfSquaresPos, posCount)
                thresholds[i, 1] = -scaledRms(sumOfSquaresNeg, negCount)

            # ... and the oldest sample leaves it.
            oldValue = signal[i - rmsBackwardLength]
            if oldValue >= 0:
                sumOfSquaresPos -= oldValue ** 2
                posCount -= 1
            elif oldValue < 0:
                sumOfSquaresNeg -= oldValue ** 2
                negCount -= 1

        return thresholds

    # Symmetrical mode: a single accumulator over every non-NaN sample.
    allValues = []
    for i in range(windowLength - 1):
        if np.isnan(signal[i]):
            lastNanIndex = i
        else:
            allValues.append(signal[i])

    sumOfSquaresAll = np.sum(np.array(allValues) ** 2)
    allCount = len(allValues)

    for i in range(len(signal)):
        if i < rmsBackwardLength or i >= firstInvalidTail:
            continue  # row stays NaN

        newValue = signal[i + rmsForwardLength - 1]
        if np.isnan(newValue):
            lastNanIndex = i + rmsForwardLength - 1
        else:
            sumOfSquaresAll += newValue ** 2
            allCount += 1

        nanInWindow = not np.isnan(lastNanIndex) and i - lastNanIndex <= rmsBackwardLength
        if not nanInWindow:
            level = scaledRms(sumOfSquaresAll, allCount)
            thresholds[i, 0] = level
            thresholds[i, 1] = -level

        oldValue = signal[i - rmsBackwardLength]
        if not np.isnan(oldValue):
            sumOfSquaresAll -= oldValue ** 2
            allCount -= 1

    return thresholds

def calculateBreathTimes(signal, posThresholds, negThresholds, minThreshold, zeroCrossingBreathStart):
    """Locate breath boundaries with a threshold-crossing state machine.

    A sample is usable when its positive threshold exceeds ``minThreshold``
    or its negative threshold is below ``-minThreshold``.  Runs of usable
    samples are obtained from ``findIslandLimits`` and each run is scanned
    separately.  Within a run, an index is recorded every time the signal
    drops below the negative threshold while the state is falling or unknown.

    Parameters
    ----------
    signal : sequence of float
    posThresholds, negThresholds : sequence of float
        Per-sample thresholds, indexed like ``signal``.
    minThreshold : float
    zeroCrossingBreathStart : bool
        If true, each recorded index is replaced by the nearest earlier index
        (searching back to index 0) at which ``signal >= 0``.

    Returns
    -------
    list of list of int
        One list of sample indices per run that produced at least one index.
    """
    LOW, MID_FALLING, MID_UNKNOWN, MID_RISING, HIGH = range(5)

    def breathTimes(startIndex, endIndex):
        """Scan ``signal[startIndex..endIndex]`` (both inclusive)."""
        if signal[startIndex] < negThresholds[startIndex]:
            state = LOW
        elif signal[startIndex] > posThresholds[startIndex]:
            state = HIGH
        else:
            state = MID_UNKNOWN

        times = []
        for i in range(startIndex + 1, endIndex + 1):
            posThreshold = posThresholds[i]
            negThreshold = negThresholds[i]
            if state == LOW and signal[i] > negThreshold:
                state = MID_RISING
            elif state == HIGH and signal[i] < posThreshold:
                state = MID_FALLING
            elif (state == MID_RISING or state == MID_UNKNOWN) and signal[i] > posThreshold:
                state = HIGH
            elif (state == MID_FALLING or state == MID_UNKNOWN) and signal[i] < negThreshold:
                state = LOW
                times.append(i)

        if not zeroCrossingBreathStart:
            return times

        zeroCrossingBreathTimes = []
        for t in times:
            for i in range(t, -1, -1):
                if signal[i] >= 0:
                    zeroCrossingBreathTimes.append(i)
                    break
        return zeroCrossingBreathTimes

    invalidated = np.ones(np.shape(signal), dtype=bool)
    for i in range(len(invalidated)):
        if posThresholds[i] > minThreshold or negThresholds[i] < -minThreshold:
            invalidated[i] = False

    minIslandLength = 0
    islandLimits = findIslandLimits(invalidated, minIslandLength)

    times = []
    for (start, end) in islandLimits:
        # findIslandLimits reports an exclusive end while breathTimes expects
        # an inclusive one.  Passing `end` unchanged scanned one invalidated
        # sample and indexed past the signal when a run reached its last sample.
        bt = breathTimes(start, end - 1)
        if len(bt) > 0:
            times.append(bt)

    return times


# Code from Jack Taylor

def countLocalMaximas(values):
    """Count local maxima of ``values``, end points included.

    An end point counts when it is strictly greater than its only neighbour;
    an interior point counts when it is strictly greater than both
    neighbours.  Sequences shorter than three elements yield 1.
    """
    size = len(values)
    if size < 3:
        return 1

    total = int(values[0] > values[1]) + int(values[-1] > values[-2])
    total += sum(
        1
        for k in range(1, size - 1)
        if values[k] > values[k - 1] and values[k] > values[k + 1]
    )
    return total

def countLocalMinimas(values):
    """Count local minima of ``values``, end points included.

    An end point counts when it is strictly smaller than its only neighbour;
    an interior point counts when it is strictly smaller than both
    neighbours.  Sequences shorter than three elements yield 1.
    """
    size = len(values)
    if size < 3:
        return 1

    total = int(values[0] < values[1]) + int(values[-1] < values[-2])
    total += sum(
        1
        for k in range(1, size - 1)
        if values[k] < values[k - 1] and values[k] < values[k + 1]
    )
    return total

def generate_RRV(sliced):
    """Return a spectral variability score for one window of breathing signal.

    NaN samples are dropped, the magnitude spectrum of the first ``N // 2``
    FFT bins is taken, and the score is ``100 - 100 * H1 / DC`` where ``DC``
    is the largest spectral magnitude and ``H1`` the largest of the remaining
    ones.  Returns NaN for an empty window and 0.0 when fewer than two
    spectral bins are available.
    """
    samples = sliced.dropna()
    if samples.size == 0:
        return np.nan

    values = samples.values
    count = values.shape[-1]
    spectrum = 2.0 / count * np.abs(np.fft.fft(values)[:count // 2])

    if len(spectrum) < 2:
        return 0.0

    strongest = np.amax(spectrum)
    strongest_at = np.argmax(spectrum)
    runner_up = np.amax(np.delete(spectrum, strongest_at))
    return 100 - (runner_up / strongest) * 100

def getBreaths(df):
    """Detect breath boundaries in ``df.breathingSignal``.

    Thresholds come from ``calculateThresholdLevels`` (asymmetric mode) using
    a window of about 20 seconds centred on each sample; breath boundaries
    come from ``calculateBreathTimes``.  Two summary lines (signal coverage
    and min/max breath length in samples) are printed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``breathingSignal`` and ``timestamp`` columns.

    Returns
    -------
    list of list of int
        Breath boundary indices as returned by ``calculateBreathTimes``.

    Raises
    ------
    ValueError
        If no positive mean sampling period can be derived from ``timestamp``.
    """
    minThreshold = 0.001
    mult = 0.0125
    windowSeconds = 20  # total window; half of it goes to each side of a sample

    signal = list(df.breathingSignal)

    samplePeriod = df['timestamp'].diff().dropna().apply(lambda x: x.total_seconds()).mean()
    # `not (x > 0)` also catches NaN, which an empty or one-row frame produces.
    if not samplePeriod > 0:
        raise ValueError("Cannot derive a positive sampling period from df['timestamp']")

    window_size = int((windowSeconds / samplePeriod) // 2)
    threshs = calculateThresholdLevels(signal, window_size, window_size, mult, False)
    posThresh = threshs[:, 0]
    negThresh = threshs[:, 1]

    times = calculateBreathTimes(signal, posThresh, negThresh, minThreshold, False)

    # Coverage statistics: every sample between two consecutive boundaries of
    # the same group counts as belonging to a breath.
    total = set()
    minBreathLength = float("inf")
    maxBreathLength = float("-inf")
    for vals in times:
        for start, end in zip(vals, vals[1:]):
            length = end - start + 1
            minBreathLength = min(minBreathLength, length)
            maxBreathLength = max(maxBreathLength, length)
            total.update(range(start, end + 1))

    validCount = int(df.breathingSignal.notna().sum())
    # An all-NaN signal previously raised ZeroDivisionError here.
    percent = round((len(total) / validCount) * 100, 2) if validCount else 0.0
    print(f"Uses Breath From {len(total)}/{validCount} = {percent}% Signal")
    print(f"Max Breath Length: {maxBreathLength} points. Min Breath Length: {minBreathLength} points")

    return times


def mode(l):
    """Return the most frequent rounded value of ``l`` with its statistics.

    Values are rounded with ``np.around``; NaN and infinite entries are
    ignored.  When several values share the highest count, the middle one of
    the sorted tied values (index ``len // 2``) is returned.

    Parameters
    ----------
    l : sequence of numbers

    Returns
    -------
    tuple ``(mode, counts, dist)``
        ``mode``   -- the modal value as ``int``, or NaN if there is no
                      finite value.
        ``counts`` -- dict mapping each rounded value to its occurrence count.
        ``dist``   -- float array where ``dist[v]`` is the count of value
                      ``v`` for ``0 <= v <= max``; negative values cannot be
                      indexed and are left out.  ``[]`` if there is no
                      finite value.
    """
    rounded = np.around(np.asarray(l, dtype=float)).ravel()
    rounded = rounded[np.isfinite(rounded)].astype(int)
    if rounded.size == 0:
        return np.nan, {}, []

    # np.unique returns the distinct values in ascending order.
    uniqueValues, occurrences = np.unique(rounded, return_counts=True)
    counts = dict(zip(uniqueValues.tolist(), occurrences.tolist()))

    dist = np.zeros(max(int(uniqueValues[-1]), -1) + 1)
    for value, count in counts.items():
        if value >= 0:
            dist[value] = count

    tied = uniqueValues[occurrences == occurrences.max()]
    return int(tied[len(tied) // 2]), counts, dist
    

def extractFeatures(df):
    """Build one feature row per inhalation and per exhalation.

    Breaths are taken from ``getBreaths(df)``: every pair of consecutive
    boundary indices ``(start, end)`` in a group is one breath.  The breath is
    split at the first sample whose value is ``>= 0.005``; samples before it
    form the "Inhalation" row and the rest the "Exhalation" row.  A phase
    with fewer than two samples produces no row.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``breathingSignal``, ``timestamp``, ``activityLevel``
        and ``activityType`` columns.

    Returns
    -------
    pandas.DataFrame
        Columns: type, area, peakRespiratoryFlow, extremas, duration,
        meanActivityLevel, modeActivityType, startTimestamp, endTimestamp.
    """
    times = getBreaths(df)

    # np.trapezoid exists from NumPy 2.0; older releases name it np.trapz.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz

    activityLevel = np.array(df.activityLevel)
    activityType = np.array(df.activityType)
    signal = np.array(df.breathingSignal)
    timestamps = list(df.timestamp)

    columns = {
        "type": [],
        "area": [],
        "peakRespiratoryFlow": [],
        "extremas": [],
        "duration": [],
        "meanActivityLevel": [],
        "modeActivityType": [],
        "startTimestamp": [],
        "endTimestamp": [],
    }

    def addPhase(label, lo, hi, countExtrema, level, modeType):
        # Append the feature row for signal[lo:hi]; phases under two samples are skipped.
        values = signal[lo:hi]
        if len(values) <= 1:
            return
        phaseTimes = timestamps[lo:hi]
        duration = (phaseTimes[-1] - phaseTimes[0]).total_seconds()
        # n samples span n - 1 intervals; dividing by n under-estimated the
        # sample spacing and therefore the area.
        dx = duration / (len(values) - 1)

        columns["type"].append(label)
        columns["area"].append(abs(integrate(y=values, dx=dx)))
        columns["peakRespiratoryFlow"].append(np.max(np.abs(values)))
        columns["extremas"].append(countExtrema(values))
        columns["duration"].append(duration)
        columns["meanActivityLevel"].append(level)
        columns["modeActivityType"].append(modeType)
        columns["startTimestamp"].append(phaseTimes[0])
        columns["endTimestamp"].append(phaseTimes[-1])

    for i, vals in enumerate(times):
        if i % 25 == 0:
            print(f"{i}/{len(times)}... ", end=" ")

        for start, end in zip(vals, vals[1:]):
            breath = signal[start:end + 1]

            # First sample at or above 0.005 (arbitrary, but removes noise);
            # without one the whole breath is treated as exhalation.
            above = np.flatnonzero(breath >= 0.005)
            breakPoint = start + int(above[0]) if above.size else start

            level = activityLevel[start:end + 1].mean()
            modeType = mode(activityType[start:end + 1])[0]

            addPhase("Inhalation", start, breakPoint, countLocalMaximas, level, modeType)
            addPhase("Exhalation", breakPoint, end + 1, countLocalMinimas, level, modeType)

    return pd.DataFrame(data=columns)


def getRegularity(df):
    """Score each row of ``df`` by its distance from the PCA axes of its features.

    Two min-max-scaled feature sets are used:

    * ``area`` and ``peakRespiratoryFlow`` - distance to the first principal
      axis (``te``);
    * ``area``, ``peakRespiratoryFlow`` and ``BR_mean`` - distances ``d1``,
      ``d2``, ``d3`` to the three principal axes.

    The combined value ``te + (d1 - d2 + d3)`` is min-max normalised over all
    rows and returned as ``1 - value``.

    Distances are measured to the line through the origin along each axis,
    as in the original code (the scaled data are not re-centred).

    NOTE(review): if every row has the same combined distance the final
    normalisation divides by zero and the result is NaN - confirm callers
    never pass such a frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide ``area``, ``peakRespiratoryFlow`` and ``BR_mean``.

    Returns
    -------
    numpy.ndarray
        One value per row of ``df``.
    """
    def scaled(columns):
        return MinMaxScaler().fit_transform(df[columns])

    def distanceToAxis(points, axis):
        # Norm of what is left after removing the projection onto ``axis``.
        projection = np.outer(points.dot(axis), axis)
        return np.linalg.norm(points - projection, axis=1)

    # get distance to 1st PC for area, PRF only
    shapeFeatures = scaled(['area', 'peakRespiratoryFlow'])
    shapeAxis = PCA(n_components=1).fit(shapeFeatures).components_[0]
    te = distanceToAxis(shapeFeatures, shapeAxis)

    # add the resampled breathing rate as the third dimension
    allFeatures = scaled(['area', 'peakRespiratoryFlow', 'BR_mean'])
    firstAxis, secondAxis, thirdAxis = PCA(n_components=3).fit(allFeatures).components_

    distances_to_first_component = distanceToAxis(allFeatures, firstAxis)
    distances_to_second_component = distanceToAxis(allFeatures, secondAxis)
    distances_to_third_component = distanceToAxis(allFeatures, thirdAxis)

    # Linear combination of these distances
    distances_difference = te + (
        distances_to_first_component - distances_to_second_component + distances_to_third_component
    )
    distances_difference = (distances_difference - distances_difference.min()) / (
        distances_difference.max() - distances_difference.min()
    )

    return 1 - distances_difference

def combineDfs(respeck_df, original_respeck_df):
    """Attach 30-second window statistics of the raw recording to each breath.

    For the raw recording, the median/mean/std of ``breathingRate`` and
    ``activityLevel`` and the ``generate_RRV`` score of ``breathingSignal``
    are computed per 30 s window; every raw sample receives the value of its
    nearest window.  ``RRV3MA`` is a centred 3-window rolling mean of RRV
    multiplied by 0.65.  For each breath row, these sample-level values are
    averaged over the samples whose timestamp lies within
    ``[startTimestamp, endTimestamp]``.

    Unlike the previous implementation, ``original_respeck_df`` is not
    modified (it used to be re-indexed in place and given extra columns, so a
    second call on the same frame failed).

    Parameters
    ----------
    respeck_df : pandas.DataFrame
        Per-breath features with columns type, startTimestamp, endTimestamp,
        area, extremas, meanActivityLevel, modeActivityType,
        peakRespiratoryFlow and duration.
    original_respeck_df : pandas.DataFrame
        Raw recording with columns timestamp, breathingRate, activityLevel
        and breathingSignal.

    Returns
    -------
    pandas.DataFrame
        One row per breath: the per-breath columns plus BR_md, BR_mean,
        BR_std, AL_md, AL_mean, AL_std, RRV and RRV3MA.
    """
    # set_index without inplace returns a new frame; the caller's is untouched.
    samples = original_respeck_df.set_index('timestamp')

    def spread(windowed):
        # Give every raw sample the value of its nearest 30 s window.
        return windowed.reindex(samples.index, method='nearest')

    for source, prefix in (('breathingRate', 'BR'), ('activityLevel', 'AL')):
        windows = samples[[source]].resample('30s')
        samples[f'{prefix}_md'] = spread(windows.median())
        samples[f'{prefix}_mean'] = spread(windows.mean())
        samples[f'{prefix}_std'] = spread(windows.std())

    RRV = samples[["breathingSignal"]].resample('30s').apply(generate_RRV)
    # A score of 0 is treated as missing and filled from neighbouring windows.
    RRV = RRV.replace(0, np.nan).ffill().bfill()
    samples['RRV'] = spread(RRV)

    # average of 3 neighbours
    RRV3MA = RRV.rolling(window=3, center=True).mean() * 0.65
    samples['RRV3MA'] = spread(RRV3MA)

    samples = samples.reset_index()
    sampleTimes = samples['timestamp']
    statColumns = ['BR_md', 'BR_mean', 'BR_std', 'AL_md', 'AL_mean', 'AL_std', 'RRV', 'RRV3MA']

    breath_averages = []
    for _, row in respeck_df.iterrows():
        start_timestamp = pd.to_datetime(row['startTimestamp'])
        end_timestamp = pd.to_datetime(row['endTimestamp'])

        inBreath = (sampleTimes >= start_timestamp) & (sampleTimes <= end_timestamp)
        windowMeans = samples.loc[inBreath, statColumns].mean()

        record = {
            'type': row['type'],
            'startTimestamp': start_timestamp,
            'endTimestamp': end_timestamp,
            'area': row['area'],
            'extremas': row['extremas'],
            'meanActivityLevel': row['meanActivityLevel'],
            'modeActivityType': row['modeActivityType'],
            'peakRespiratoryFlow': row['peakRespiratoryFlow'],
            'duration': row['duration'],
        }
        for column in statColumns:
            record[column] = windowMeans[column]
        breath_averages.append(record)

    return pd.DataFrame(breath_averages)


def calculate_breathing_rate_from_breaths(df, breath_times, window_minutes=1):
    """
    Calculate breathing rate from detected breath times.

    For every detected breath, the breaths at or before it are counted
    backwards until one lies more than ``window_minutes`` in the past; the
    count is converted to breaths per minute.

    Parameters:
    - df: DataFrame with a timestamp column (not modified)
    - breath_times: Output from getBreaths function (lists of row positions)
    - window_minutes: Time window for rate calculation in minutes

    Returns:
    - DataFrame with columns 'timestamp' and 'calculated_breathing_rate',
      one row per breath whose position lies inside ``df``
    """
    stamps = df['timestamp']
    rowCount = len(df)

    # Flatten and sort all breath positions; positions beyond the frame are ignored.
    all_breath_indices = sorted(idx for breath_group in breath_times for idx in breath_group)
    # Look every timestamp up once instead of once per comparison.
    breathStamps = [stamps.iloc[idx] for idx in all_breath_indices if idx < rowCount]

    window_seconds = window_minutes * 60

    breathing_rates = []
    for i, current_time in enumerate(breathStamps):
        # Count breaths in the past window, walking backwards from this breath.
        breath_count = 0
        for j in range(i, -1, -1):
            if (current_time - breathStamps[j]).total_seconds() <= window_seconds:
                breath_count += 1
            else:
                break

        # Convert to breaths per minute
        breathing_rates.append((breath_count / window_seconds) * 60)

    return pd.DataFrame({
        'timestamp': breathStamps,
        'calculated_breathing_rate': breathing_rates
    })

# Detect breath boundaries on the RESpeck frame prepared above; getBreaths
# also prints coverage and breath-length statistics.
old_breaths = getBreaths(respeck_df)
# Convert 'startTimestamp' to datetime
# breath_features['startTimestamp'] = pd.to_datetime(breath_features['startTimestamp'])

# # Count the number of breaths (inhalations + exhalations)
# breath_features['breath_count'] = 1  # Each row corresponds to a breath

# # Total number of breaths
# total_breaths = breath_features['breath_count'].sum()

# # Get the total duration of the DataFrame in minutes
# start_time = breath_features['startTimestamp'].min()
# end_time = breath_features['startTimestamp'].max()
# total_duration_minutes = (end_time - start_time).total_seconds() / 60  # Convert to minutes

# # Calculate average breaths per minute
# if total_duration_minutes > 0:
#     avg_breaths_per_minute = total_breaths / total_duration_minutes
# else:
#     avg_breaths_per_minute = 0

# print(f'Total Breaths: {total_breaths}')
# print(f'Total Duration (minutes): {total_duration_minutes:.2f}')
# print(f'Average Breaths per Minute: {avg_breaths_per_minute:.2f}')

## Accelerometer

In [None]:
# --- START OF FILE jack-breaths.py ---

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler

# Jack's Util file

def nans(dims):
    """Return a new float array of shape ``dims`` with every entry set to NaN."""
    return np.full(dims, np.nan)

def rms(signal):
    """Return the root-mean-square of ``signal``.

    ``signal`` must support element-wise ``** 2`` (e.g. a NumPy array).
    """
    squared = signal ** 2
    mean_square = np.mean(squared)
    return np.sqrt(mean_square)

def rmsHamming(signal):
    """Return the RMS of ``signal`` weighted by a Hamming window of equal length.

    The squared samples are averaged with the window values as weights, and
    the square root of that weighted mean is returned.
    """
    window = np.hamming(len(signal))
    squared = signal ** 2

    numerator = 0.0
    denominator = 0.0
    # Accumulate sample by sample, in order, pairing each square with its weight.
    for square, weight in zip(squared, window):
        numerator += square * weight
        denominator += weight

    return np.sqrt(numerator / denominator)

def findIslandLimits(signal, minIslandLength=0, minIslandGap=0):
    """Find runs ("islands") of consecutive falsy entries in ``signal``.

    Each island is returned as ``(start, end)`` with ``end`` exclusive.
    A run consisting of a single entry is never reported.

    signal          -- sequence whose falsy entries belong to islands
    minIslandLength -- islands with ``end - start`` below this are dropped
    minIslandGap    -- when positive, a new island whose
                       ``start - previous_end - 1`` is below this value is
                       merged into the previous island
    """
    found = []
    run_start = None
    run_end = None
    last_index = len(signal) - 1

    for idx in range(len(signal)):
        closed = False

        if signal[idx]:
            # A truthy entry terminates whatever run was in progress.
            if run_start is not None:
                if run_end is None:
                    # the run had a single entry: forget it
                    run_start = None
                else:
                    closed = True
        elif run_start is None:
            run_start = idx
        else:
            run_end = idx + 1
            # a run touching the final entry has to be closed here
            closed = idx == last_index

        if not closed:
            continue

        if minIslandGap > 0 and found:
            prev_start, prev_end = found[-1]
            if run_start - prev_end - 1 < minIslandGap:
                # merge the new island with the previous one
                found[-1] = (prev_start, run_end)
            else:
                found.append((run_start, run_end))
        else:
            found.append((run_start, run_end))

        run_start = None
        run_end = None

    # now return only the islands that are long enough
    return [island for island in found if (island[1] - island[0]) >= minIslandLength]

def calculateThresholdLevels(signal, rmsBackwardLength, rmsForwardLength, rmsMultiplier, symmetrical):
    """Compute per-sample thresholds from a sliding RMS window.

    For sample ``i`` the window spans ``signal[i - rmsBackwardLength]`` up to
    and including ``signal[i + rmsForwardLength - 1]``.  The threshold is the
    RMS of the window samples multiplied by ``rmsMultiplier``.

    When ``symmetrical`` is true a single RMS over all non-NaN window samples
    is used and column 1 is the negation of column 0; otherwise samples
    ``>= 0`` and ``< 0`` are accumulated separately.

    Returns an array of shape ``(len(signal), 2)`` (column 0 positive,
    column 1 negative threshold).  A value is NaN near both ends of the
    signal, while a NaN sample lies in the window, and when the window holds
    no sample of the relevant sign.  All NaN if the signal is shorter than
    the window.
    """
    thresholds = np.full((len(signal), 2), np.nan)

    windowLength = rmsBackwardLength + rmsForwardLength
    if len(signal) < windowLength:
        return thresholds

    def scaledRms(sumOfSquares, count):
        # An empty accumulator has no RMS.
        if count > 0:
            return np.sqrt(sumOfSquares / count) * rmsMultiplier
        return np.nan

    firstInvalidTail = len(signal) - rmsForwardLength
    lastNanIndex = np.nan  # index of the most recent NaN that entered the window

    if not symmetrical:
        # Pre-load all but the last sample of the first full window.
        posValues = []
        negValues = []
        for i in range(windowLength - 1):
            if signal[i] >= 0:
                posValues.append(signal[i])
            elif signal[i] < 0:
                negValues.append(signal[i])
            else:  # NaN compares false both ways
                lastNanIndex = i

        sumOfSquaresPos = np.sum(np.array(posValues) ** 2)
        posCount = len(posValues)
        sumOfSquaresNeg = np.sum(np.array(negValues) ** 2)
        negCount = len(negValues)

        for i in range(len(signal)):
            if i < rmsBackwardLength or i >= firstInvalidTail:
                continue  # row stays NaN

            # The newest sample enters the window ...
            newValue = signal[i + rmsForwardLength - 1]
            if np.isnan(newValue):
                lastNanIndex = i + rmsForwardLength - 1
            elif newValue >= 0:
                sumOfSquaresPos += newValue ** 2
                posCount += 1
            else:
                sumOfSquaresNeg += newValue ** 2
                negCount += 1

            # ... the thresholds are read unless a NaN is still inside ...
            nanInWindow = not np.isnan(lastNanIndex) and i - lastNanIndex <= rmsBackwardLength
            if not nanInWindow:
                thresholds[i, 0] = scaledRms(sumOfSquaresPos, posCount)
                thresholds[i, 1] = -scaledRms(sumOfSquaresNeg, negCount)

            # ... and the oldest sample leaves it.
            oldValue = signal[i - rmsBackwardLength]
            if oldValue >= 0:
                sumOfSquaresPos -= oldValue ** 2
                posCount -= 1
            elif oldValue < 0:
                sumOfSquaresNeg -= oldValue ** 2
                negCount -= 1

        return thresholds

    # Symmetrical mode: a single accumulator over every non-NaN sample.
    allValues = []
    for i in range(windowLength - 1):
        if np.isnan(signal[i]):
            lastNanIndex = i
        else:
            allValues.append(signal[i])

    sumOfSquaresAll = np.sum(np.array(allValues) ** 2)
    allCount = len(allValues)

    for i in range(len(signal)):
        if i < rmsBackwardLength or i >= firstInvalidTail:
            continue  # row stays NaN

        newValue = signal[i + rmsForwardLength - 1]
        if np.isnan(newValue):
            lastNanIndex = i + rmsForwardLength - 1
        else:
            sumOfSquaresAll += newValue ** 2
            allCount += 1

        nanInWindow = not np.isnan(lastNanIndex) and i - lastNanIndex <= rmsBackwardLength
        if not nanInWindow:
            level = scaledRms(sumOfSquaresAll, allCount)
            thresholds[i, 0] = level
            thresholds[i, 1] = -level

        oldValue = signal[i - rmsBackwardLength]
        if not np.isnan(oldValue):
            sumOfSquaresAll -= oldValue ** 2
            allCount -= 1

    return thresholds

def calculateBreathTimes(signal, posThresholds, negThresholds, minThreshold, zeroCrossingBreathStart):
    """Detect breath start indices with a threshold-hysteresis state machine.

    Samples are first masked: an index is treated as usable when its positive
    threshold exceeds ``minThreshold`` or its negative threshold is below
    ``-minThreshold``. ``findIslandLimits`` turns that mask into ``(start, end)``
    runs, and each run is scanned independently.

    Args:
        signal: Indexable breathing signal.
        posThresholds: Per-sample upper threshold, indexable like ``signal``.
        negThresholds: Per-sample lower threshold, indexable like ``signal``.
        minThreshold: Minimum threshold magnitude for a sample to be usable.
        zeroCrossingBreathStart: When truthy, each detection is moved back to
            the nearest index at or before it where ``signal >= 0``; detections
            with no such index are dropped.

    Returns:
        list[list[int]]: One list of sample indices per island that produced at
        least one detection.
    """
    # Hysteresis states, ordered from below the negative threshold to above
    # the positive one.
    LOW, MID_FALLING, MID_UNKNOWN, MID_RISING, HIGH = range(5)

    def breathTimes(startIndex, endIndex):
        """Scan ``signal[startIndex..endIndex]`` (inclusive) for detections."""
        first = signal[startIndex]
        if first < negThresholds[startIndex]:
            state = LOW
        elif first > posThresholds[startIndex]:
            state = HIGH
        else:
            state = MID_UNKNOWN

        detections = []
        for idx in range(startIndex + 1, endIndex + 1):
            upper = posThresholds[idx]
            lower = negThresholds[idx]
            sample = signal[idx]
            if state == LOW:
                if sample > lower:
                    state = MID_RISING
            elif state == HIGH:
                if sample < upper:
                    state = MID_FALLING
            elif state in (MID_RISING, MID_UNKNOWN) and sample > upper:
                state = HIGH
            elif state in (MID_FALLING, MID_UNKNOWN) and sample < lower:
                # Dropping below the negative threshold marks a detection.
                state = LOW
                detections.append(idx)

        if not zeroCrossingBreathStart:
            return detections

        # Walk backwards from each detection to the closest non-negative sample.
        shifted = []
        for detected in detections:
            cursor = detected
            while cursor >= 0 and not signal[cursor] >= 0:
                cursor -= 1
            if cursor >= 0:
                shifted.append(cursor)
        return shifted

    # True marks samples whose thresholds are both too small to be trusted.
    invalidated = np.ones(np.shape(signal), dtype=bool)
    for pos in range(len(invalidated)):
        usable = posThresholds[pos] > minThreshold or negThresholds[pos] < -minThreshold
        if usable:
            invalidated[pos] = False

    # NOTE(review): the `end - 1` below suggests findIslandLimits returns an
    # exclusive end index — confirm against its definition.
    islandLimits = findIslandLimits(invalidated, 0)

    perIsland = (breathTimes(first, last - 1) for (first, last) in islandLimits)
    return [found for found in perIsland if len(found) > 0]


# Code from Jack Taylor

def countLocalMaximas(values):
    """Count strict local maxima in ``values``, endpoints included.

    An endpoint counts when it is strictly greater than its only neighbour; an
    interior point counts when it is strictly greater than both neighbours.
    Sequences shorter than three elements always yield 1.
    """
    size = len(values)
    if size < 3:
        return 1

    # Endpoints first, then every interior position.
    total = int(values[0] > values[1]) + int(values[-1] > values[-2])
    total += sum(
        1
        for mid in range(1, size - 1)
        if values[mid - 1] < values[mid] > values[mid + 1]
    )
    return total

def countLocalMinimas(values):
    """Count strict local minima in ``values``, endpoints included.

    An endpoint counts when it is strictly smaller than its only neighbour; an
    interior point counts when it is strictly smaller than both neighbours.
    Sequences shorter than three elements always yield 1.
    """
    size = len(values)
    if size < 3:
        return 1

    # Endpoints first, then every interior position.
    total = int(values[0] < values[1]) + int(values[-1] < values[-2])
    total += sum(
        1
        for mid in range(1, size - 1)
        if values[mid - 1] > values[mid] < values[mid + 1]
    )
    return total

def generate_RRV(sliced):
    """Return a spectral-dominance score (0-100) for a breathing-signal slice.

    The score is ``100 - 100 * (second largest / largest)`` over the magnitudes
    of the first ``N // 2`` FFT bins of the NaN-free samples.

    Args:
        sliced: pandas Series of breathing-signal samples (NaNs are dropped).

    Returns:
        ``np.nan`` when no samples remain after dropping NaNs, ``0.0`` when the
        half-spectrum has at most one bin, otherwise the score described above.
    """
    samples = sliced.dropna()
    if samples.size == 0:
        return np.nan

    waveform = samples.values
    count = waveform.shape[-1]

    # Magnitudes of the first half of the spectrum, scaled by 2/N.
    half_spectrum = 2.0 / count * np.abs(np.fft.fft(waveform)[:count // 2])
    if len(half_spectrum) <= 1:
        return 0.0

    # Largest component, then the largest of what is left once it is removed.
    strongest_bin = np.argmax(half_spectrum)
    strongest = np.amax(half_spectrum)
    runner_up = np.amax(np.delete(half_spectrum, strongest_bin))
    return 100 - (runner_up / strongest) * 100

def getBreathsConservative(df, return_dataframe=True):
    """
    Run `getBreaths` and format its output as per-event rows plus summary stats.

    Consecutive detection indices inside each island delimit one breath. Within
    a breath the sample with the largest signal value is taken as the turning
    point: the breath start is recorded as an 'Inhalation' event and the
    turning point as an 'Exhalation' event.

    The input DataFrame is not modified.

    Args:
        df (pd.DataFrame): Input dataframe with 'timestamp' and 'breathingSignal'.
        return_dataframe (bool): If True, returns (DataFrame, stats). Otherwise,
                                 returns the raw 'times' list and stats.

    Returns:
        tuple: A tuple containing:
            - breath_df (pd.DataFrame): 'timestamp' and 'type' for each detected
              breath event (or the raw `times` list when `return_dataframe` is
              False). Timezone-aware timestamps are converted to UTC and
              returned as naive datetimes.
            - stats (dict): 'inhalations', 'exhalations', 'breaths_per_minute'.
    """
    times = getBreaths(df)

    signal_array = np.array(df.breathingSignal)
    timestamps = list(df.timestamp)
    breath_events = []
    inhalation_count = 0
    exhalation_count = 0

    for island in times:
        for j in range(len(island) - 1):
            start_idx = island[j]
            end_idx = island[j+1]

            # Skip degenerate or out-of-range breath spans.
            if start_idx >= end_idx or end_idx >= len(signal_array):
                continue

            breath_segment = signal_array[start_idx:end_idx+1]
            peak_idx = start_idx + np.argmax(breath_segment)

            # Inhalation only exists if the peak is not the very first sample.
            if peak_idx > start_idx:
                breath_events.append({'timestamp': timestamps[start_idx], 'type': 'Inhalation'})
                inhalation_count += 1

            # Exhalation only exists if the peak is not the very last sample.
            if end_idx > peak_idx:
                breath_events.append({'timestamp': timestamps[peak_idx], 'type': 'Exhalation'})
                exhalation_count += 1

    if not breath_events:
        breath_df = pd.DataFrame(columns=['timestamp', 'type'])
    else:
        breath_df = pd.DataFrame(breath_events)
        breath_df['timestamp'] = pd.to_datetime(breath_df['timestamp'])

        # Downstream comparison code localizes these timestamps from UTC, so
        # hand back naive UTC datetimes instead of timezone-aware ones.
        if breath_df['timestamp'].dt.tz is not None:
            breath_df['timestamp'] = breath_df['timestamp'].dt.tz_convert('UTC').dt.tz_localize(None)

    total_breaths = min(inhalation_count, exhalation_count)

    if not df.empty and not df['timestamp'].empty:
        # Convert into a local series: writing the result back into `df` would
        # silently mutate the caller's DataFrame.
        recording_times = pd.to_datetime(df['timestamp'])
        duration_seconds = (recording_times.max() - recording_times.min()).total_seconds()
        duration_minutes = duration_seconds / 60 if duration_seconds > 0 else 0
    else:
        duration_minutes = 0

    breaths_per_minute = total_breaths / duration_minutes if duration_minutes > 0 else 0

    stats = {
        'inhalations': inhalation_count,
        'exhalations': exhalation_count,
        'breaths_per_minute': breaths_per_minute
    }

    if return_dataframe:
        return breath_df, stats
    else:
        return times, stats
# ==============================================================================
# END OF NEW FUNCTION
# ==============================================================================

def mode(l):
    """Return the most frequent integer value in ``l`` plus its frequency tables.

    Values are passed through ``np.nan_to_num`` (NaN becomes 0) and truncated
    to ``int`` before counting. When several values tie for the highest count,
    the middle element of the sorted tied values is returned (upper-middle for
    an even number of ties).

    Args:
        l: Sequence (list or array) of numeric values.

    Returns:
        tuple: ``(mode_value, counts, dist)`` where ``counts`` and ``dist`` are
        two separate dicts mapping each value to its number of occurrences.
        For empty input the result is ``(np.nan, {}, [])``.
    """
    # np.nan rather than np.NaN: the upper-case alias was removed in NumPy 2.0.
    if len(l) == 0:
        return np.nan, {}, []

    # Sorting keeps dict insertion order ascending, as callers may iterate it.
    truncated = np.sort([int(x) for x in np.nan_to_num(l)])

    counts = {}
    for value in truncated:
        counts[value] = counts.get(value, 0) + 1

    highest = max(counts.values())
    tied = np.sort([value for value, seen in counts.items() if seen == highest])

    # Median of the tied modes; second dict is an independent copy.
    return tied[len(tied) // 2], counts, dict(counts)

def extractFeatures(df):
    """Build a per-phase feature table from the breaths found by `getBreaths`.

    Each pair of consecutive detection indices inside an island is one breath.
    The breath is split at the first sample whose value is >= 0.005: samples
    before the split are labelled "Inhalation", samples from the split to the
    end are labelled "Exhalation". A phase is only recorded when it has more
    than one sample.

    Args:
        df: DataFrame exposing `breathingSignal`, `timestamp`, `activityLevel`
            and `activityType` columns.

    Returns:
        pd.DataFrame with columns type, area, peakRespiratoryFlow, extremas,
        duration, meanActivityLevel, modeActivityType, startTimestamp and
        endTimestamp; one row per recorded phase.
    """
    times = getBreaths(df)

    # Column lists are declared in the order they appear in the output frame.
    columns = {
        "type": [],
        "area": [],
        "peakRespiratoryFlow": [],
        "extremas": [],
        "duration": [],
        "meanActivityLevel": [],
        "modeActivityType": [],
        "startTimestamp": [],
        "endTimestamp": [],
    }

    activityLevel = np.array(df.activityLevel)
    activityType = np.array(df.activityType)
    signal = np.array(df.breathingSignal)
    timestamps = list(df.timestamp)

    def recordPhase(label, samples, stamps, countExtrema, level, modeType):
        """Compute the features of one phase and append them as a row."""
        peak = max(abs(np.array(samples)))
        extrema = countExtrema(samples)
        elapsed = (stamps[-1] - stamps[0]).total_seconds()
        # NOTE(review): spacing divides by the sample count, not count - 1 —
        # confirm whether that is intended before changing it.
        area = abs(np.trapezoid(y=samples, dx=elapsed / len(samples)))

        row = (label, area, peak, extrema, elapsed, level, modeType, stamps[0], stamps[-1])
        for bucket, value in zip(columns.values(), row):
            bucket.append(value)

    for i, vals in enumerate(times):
        if i % 25 == 0:
            print(f"Processing island {i}/{len(times)}... ", end=" ")

        for start, end in zip(vals[:-1], vals[1:]):
            breath = signal[start:end+1]
            # First sample at or above the noise floor; stays at `start` if none.
            breakPoint = next(
                (start + offset for offset, sample in enumerate(breath) if sample >= 0.005),
                start,
            )

            level = activityLevel[start:end+1].mean()
            modeType = mode(activityType[start:end+1])[0]

            inhalation = signal[start:breakPoint]
            if len(inhalation) > 1:
                recordPhase("Inhalation", inhalation, timestamps[start:breakPoint],
                            countLocalMaximas, level, modeType)

            exhalation = signal[breakPoint:end+1]
            if len(exhalation) > 1:
                recordPhase("Exhalation", exhalation, timestamps[breakPoint:end+1],
                            countLocalMinimas, level, modeType)

    return pd.DataFrame(data=columns)


## OLD vs NEW

## PSG breaths

In [None]:
# Input files for the 11-05-2025 recording session.
RESPECK_FILE = '../data/bishkek_csr/03_train_ready/respeck/11-05-2025_respeck.csv'
PSG_FILE = '../data/bishkek_csr/03_train_ready/nasal_files/11-05-2025_nasal.csv'
LABELS_FILE = '../data/bishkek_csr/03_train_ready/event_exports/11-05-2025_event_export.csv'
# NOTE(review): the output name carries a different date (08-05-2025) than the
# inputs above — confirm this is intended.
OUTPUT_FILE = './08-05-2025_respeck_features.csv'

# --- Load Data ---
print("Loading data...")

# Each file stores epoch milliseconds; parse them as UTC and convert to
# Asia/Bishkek so all three frames share one timezone-aware 'timestamp' column.
respeck_df = pd.read_csv(RESPECK_FILE)
respeck_df['timestamp'] = pd.to_datetime(respeck_df['alignedTimestamp'], unit='ms')
tz = pytz.timezone('Asia/Bishkek')
respeck_df['timestamp'] = respeck_df['timestamp'].dt.tz_localize('UTC').dt.tz_convert(tz)

psg_df = pd.read_csv(PSG_FILE)
psg_df['timestamp'] = pd.to_datetime(psg_df['UnixTimestamp'], unit='ms')
tz = pytz.timezone('Asia/Bishkek')
psg_df['timestamp'] = psg_df['timestamp'].dt.tz_localize('UTC').dt.tz_convert(tz)

labels_df = pd.read_csv(LABELS_FILE)
labels_df['timestamp'] = pd.to_datetime(labels_df['UnixTimestamp'], unit='ms')
tz = pytz.timezone('Asia/Bishkek')
labels_df['timestamp'] = labels_df['timestamp'].dt.tz_localize('UTC').dt.tz_convert(tz)

# Find the time span covered by both the RESpeck and the PSG recording.
# (No forward/back filling is performed here, despite what an earlier comment said.)

start_time_respeck = respeck_df['timestamp'].min()
end_time_respeck = respeck_df['timestamp'].max()

start_time_psg = psg_df['timestamp'].min()
end_time_psg = psg_df['timestamp'].max()

# Overlap = latest start to earliest end.
overlap_start = max(start_time_respeck, start_time_psg)
overlap_end = min(end_time_respeck, end_time_psg)


print(overlap_start)
print(overlap_end)

# Keep only the rows inside the overlap (both bounds inclusive).
# labels_df is left untrimmed.
respeck_df = respeck_df[(respeck_df['timestamp'] >= overlap_start) & (respeck_df['timestamp'] <= overlap_end)]
psg_df = psg_df[(psg_df['timestamp'] >= overlap_start) & (psg_df['timestamp'] <= overlap_end)]

# Dynamically calculate the sampling rate from the timestamps
time_diffs_ms = respeck_df['alignedTimestamp'].diff().median()
if pd.isna(time_diffs_ms) or time_diffs_ms <= 0:
    # Without a positive sample period there is no sampling rate to design a
    # filter for. (The filter body used to sit under this condition, so it only
    # ran for invalid data and was skipped for every valid recording.)
    print("  - WARNING: Skipping filter. Could not determine a valid sampling rate.")
else:
    fs = 1000.0 / time_diffs_ms  # Sampling frequency in Hz
    print(f"    - Calculated sampling rate: {fs:.2f} Hz")

    # Define filter parameters
    lowcut = 0.1   # Lower cutoff frequency in Hz
    highcut = 1.5  # Upper cutoff frequency in Hz
    order = 2      # Filter order (2 is a good choice to avoid distortion)

    try:
        # Design the Butterworth bandpass filter (cutoffs normalised to Nyquist)
        nyquist = 0.5 * fs
        low = lowcut / nyquist
        high = highcut / nyquist
        b, a = butter(order, [low, high], btype='band')

        # respeck_df was produced by boolean-mask slicing; take an explicit
        # copy so the column assignments below act on an independent frame.
        respeck_df = respeck_df.copy()

        # 1. Keep the unfiltered signal for reference
        respeck_df['original_breathingSignal'] = respeck_df['breathingSignal']

        # 2. Apply the filter and OVERWRITE the 'breathingSignal' column with the clean data.
        #    filtfilt spreads a single NaN over the whole output, so gaps are
        #    interpolated for filtering only and restored as NaN afterwards.
        raw_signal = respeck_df['breathingSignal']
        gap_mask = raw_signal.isna().to_numpy()
        gap_free = raw_signal.interpolate(limit_direction='both').to_numpy()
        filtered_signal = filtfilt(b, a, gap_free)
        respeck_df['breathingSignal'] = np.where(gap_mask, np.nan, filtered_signal)
    except ValueError as e:
        # Raised by butter for cutoffs outside (0, 1) and by filtfilt for
        # signals shorter than its padding length.
        print(f"  - WARNING: Skipping session. Filter could not be applied. Error: {e}")



In [None]:
import numpy as np
import pandas as pd
from scipy.signal import find_peaks, detrend, butter, filtfilt
from skimage.filters import threshold_otsu


def calculate_robust_min_distance(detrended_signal, sampling_rate, height_threshold):
    """
    Estimate a minimum peak spacing (in samples) from the signal's own rhythm.

    A rough peak pass over ``|detrended_signal|`` gives inter-peak intervals;
    IQR outliers are discarded and 80% of the median interval is used, clipped
    to [0.8 s, 4 s] and then floored at the spacing of 40 events per minute.
    When fewer than 3 rough peaks or fewer than 2 inlier intervals are found,
    a fixed 1.2 s spacing is returned instead.

    Args:
        detrended_signal: 1-D array of the detrended breathing signal.
        sampling_rate: Sampling rate in Hz.
        height_threshold: Peak height threshold; 80% of it is used for the
            rough pass.

    Returns:
        int: Minimum distance between peaks, in samples.
    """
    candidate_peaks, _ = find_peaks(
        np.abs(detrended_signal),
        height=height_threshold * 0.8,
        distance=int(0.8 * sampling_rate),
    )

    # Too few peaks to measure a rhythm: use the fixed fallback.
    if len(candidate_peaks) < 3:
        return int(sampling_rate * 1.2)

    gaps = np.diff(candidate_peaks)

    # Tukey fences: keep intervals within 1.5 IQR of the quartiles.
    upper_quartile, lower_quartile = np.percentile(gaps, [75, 25])
    spread = upper_quartile - lower_quartile
    within_fences = (gaps >= lower_quartile - 1.5 * spread) & (gaps <= upper_quartile + 1.5 * spread)
    inlier_gaps = gaps[within_fences]

    if len(inlier_gaps) < 2:
        return int(sampling_rate * 1.2)

    shortest_allowed = int(sampling_rate * 0.8)
    longest_allowed = int(sampling_rate * 4.0)
    clipped = int(np.clip(np.median(inlier_gaps) * 0.8,
                          a_min=shortest_allowed,
                          a_max=longest_allowed))

    # Never return a spacing that would allow more than 40 events per minute.
    rate_floor = int(60 * sampling_rate / 40)
    return max(clipped, rate_floor)

def _consolidate_event_group(event_group):
    """
    Merge a run of consecutive same-type events into a single event dict.

    A single-event group is returned unchanged (the same dict object). For
    longer groups the merged event spans from the first to the last timestamp
    and takes its amplitude, raw amplitude and index from the member with the
    largest absolute amplitude; the remaining descriptive fields come from the
    first member.
    """
    if len(event_group) == 1:
        return event_group[0]

    first_event = event_group[0]
    last_event = event_group[-1]

    # Timestamps may be datetime-like (timedelta result) or numpy datetime64.
    elapsed = last_event['timestamp'] - first_event['timestamp']
    if hasattr(elapsed, 'total_seconds'):
        duration = elapsed.total_seconds()
    else:
        duration = elapsed / np.timedelta64(1, 's')

    amplitudes = [member['amplitude'] for member in event_group]
    raw_amplitudes = [member['raw_amplitude'] for member in event_group]
    strongest = np.argmax([abs(value) for value in amplitudes])

    return {
        'type': first_event['type'],
        'index': event_group[strongest]['index'],
        'timestamp': first_event['timestamp'],
        'end_timestamp': last_event['timestamp'],
        'duration_seconds': duration,
        'amplitude': amplitudes[strongest],
        'raw_amplitude': raw_amplitudes[strongest],
        'event_type': first_event['event_type'],
        'orientation_type': first_event['orientation_type'],
        'gravity_influence': first_event['gravity_influence'],
        'events_merged': len(event_group),
        'is_consolidated': True,
    }


def _calibrate_orientation_thresholds_with_accelerometer(signal, sampling_rate, accel_x, accel_y, accel_z, fallback_low=0.15, fallback_high=0.5):
    """
    Classify gravity influence for a window from the mean z-axis acceleration.

    Gravity influence is "high" (with "high_pass_filter" preprocessing) when
    `accel_x` is non-empty and `|mean(accel_z)| > 0.7`; otherwise it is
    "medium" (with "detrend_only"). Any error while inspecting the
    accelerometer data also yields "medium".

    Args:
        signal, sampling_rate, accel_y, fallback_low, fallback_high: Accepted
            for interface compatibility; not used by the current logic.
        accel_x: Sequence whose length decides whether accelerometer data exists.
        accel_z: Sequence of z-axis accelerations.

    Returns:
        tuple: `((0.1, 0.5), info)` where the first element is a fixed
        threshold pair and `info` has keys 'orientation_type' (always
        'unknown'), 'gravity_influence' and 'preprocessing_needed'.
    """
    try:
        z_dominant = len(accel_x) > 0 and abs(np.mean(accel_z)) > 0.7
    except Exception:
        # Best-effort classification: unusable accelerometer input falls back
        # to "medium". Catching Exception (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        z_dominant = False

    if z_dominant:
        gravity_influence = "high"
        preprocessing_needed = "high_pass_filter"
    else:
        gravity_influence = "medium"
        preprocessing_needed = "detrend_only"

    return (0.1, 0.5), {'orientation_type': 'unknown', 'gravity_influence': gravity_influence, 'preprocessing_needed': preprocessing_needed}

def _detect_high_breathing_rate_periods(signal, sampling_rate, window_minutes=2):
    """
    Flag samples that fall inside windows with a high rate of signal extrema.

    The signal is scanned with windows of `window_minutes`, advanced by a
    quarter window. In each window the mean is removed, peaks of the absolute
    signal are counted, and the count is converted to extrema per minute. All
    samples of a window whose rate exceeds 20 are marked True.

    Args:
        signal: 1-D array of signal samples.
        sampling_rate: Sampling rate in Hz.
        window_minutes: Window length in minutes.

    Returns:
        np.ndarray: Boolean mask with the same length as `signal`. All False
        when the signal is shorter than one window.
    """
    window_samples = int(window_minutes * 60 * sampling_rate)
    high_rate_mask = np.zeros(len(signal), dtype=bool)
    if window_samples <= 0 or len(signal) < window_samples:
        return high_rate_mask

    # Guard against a zero step (tiny windows) and a zero peak distance (low
    # sampling rates); range() and find_peaks both reject 0.
    step_size = max(1, window_samples // 4)
    min_peak_spacing = max(1, int(0.4 * sampling_rate))

    # +1 so the window ending exactly at the last sample is evaluated too.
    for start in range(0, len(signal) - window_samples + 1, step_size):
        window_signal = signal[start:start + window_samples]
        detrended = detrend(window_signal, type='constant')
        rough_peaks, _ = find_peaks(np.abs(detrended), distance=min_peak_spacing)
        # Count per minute = count / window length in minutes. The previous
        # `count * (60 / window_minutes)` treated the window length as seconds.
        # NOTE(review): peaks of |signal| include both peaks and troughs, so
        # this is roughly twice the breath-cycle rate — confirm the threshold.
        estimated_rate = len(rough_peaks) / window_minutes
        if estimated_rate > 20:
            high_rate_mask[start:start + window_samples] = True
    return high_rate_mask


# =============================================================================
# THE MAIN FUNCTION: ORIGINAL LOGIC + DYNAMIC DISTANCE FIX
# =============================================================================
def adaptive_breath_detection_original_fixed(df, adaptation_window_minutes=10, 
                                           sensitivity='medium', method='peaks',
                                           pad_duration_minutes=20):
    """
    Detect inhalation/exhalation events with window-adaptive peak detection.

    The breathing signal is reflect-padded, scanned with overlapping windows
    (step = one quarter of the window), and in each window peaks are reported
    as 'Inhalation' events and troughs as 'Exhalation' events. Height and
    prominence thresholds scale with the window's standard deviation, and the
    minimum peak distance is derived from the window's own rough peak spacing.
    Events from all windows are then sorted, de-duplicated (0.1 s spacing) and
    runs of consecutive same-type events are merged.

    Args:
        df (pd.DataFrame): Must contain 'breathingSignal' and 'timestamp'.
            Columns 'x', 'y' and 'z' are also read unconditionally;
            'activityLevel' is read when present.
        adaptation_window_minutes: Length of each analysis window in minutes.
        sensitivity: 'low', 'medium' or 'high'; unknown values fall back to
            'medium'.
        method: Not referenced anywhere in the function body.
        pad_duration_minutes: Reflect padding added to each end, in minutes.

    Returns:
        tuple: (breath_df, stats). `breath_df` has one row per consolidated
        event; `stats` holds 'breaths_per_minute', 'breathing_cycles',
        'inhalations', 'exhalations' and 'duration_minutes'. When no event is
        detected the result is `(pd.DataFrame(), {'error': 'No events detected'})`.

    Raises:
        ValueError: If a required column is missing or fewer than 200 finite
            signal samples are available.
    """
    
    print("🚀 ORIGINAL ALGORITHM - WITH DYNAMIC DISTANCE FIX")
    print("=" * 75)
    
    # --- 1. Input Validation and Data Preparation ---
    # NOTE(review): 'x', 'y' and 'z' are accessed below but are not part of
    # this check, so a frame without them fails with KeyError instead.
    required_columns = ['breathingSignal', 'timestamp']
    if not all(col in df.columns for col in required_columns):
        raise ValueError(f"DataFrame must contain columns: {required_columns}")
    
    # Treat infinities as missing, then work only on the finite samples.
    signal_series = df['breathingSignal'].copy().replace([np.inf, -np.inf], np.nan)
    valid_signal = signal_series.dropna()
    if len(valid_signal) < 200:
        raise ValueError(f"Insufficient valid samples: {len(valid_signal)}")
    
    original_signal = valid_signal.values
    valid_indices = valid_signal.index
    original_timestamps = df.loc[valid_indices, 'timestamp'].values
    original_accel_x = df.loc[valid_indices, 'x'].values
    original_accel_y = df.loc[valid_indices, 'y'].values
    original_accel_z = df.loc[valid_indices, 'z'].values
    activity_level = df.loc[valid_indices, 'activityLevel'].values if 'activityLevel' in df.columns else None
    
    # Sampling rate from the median spacing of the kept timestamps;
    # falls back to a 0.02 s period when it cannot be determined.
    time_diffs = pd.Series(pd.to_datetime(original_timestamps)).diff().dt.total_seconds().dropna()
    avg_sample_period = time_diffs.median()
    if pd.isna(avg_sample_period) or avg_sample_period <= 0:
        avg_sample_period = 0.02
    sampling_rate = 1 / avg_sample_period
    
    # Pad both ends so that windows can cover the edges of the recording.
    pad_samples = int(pad_duration_minutes * 60 * sampling_rate)
    signal_padded = np.pad(original_signal, pad_samples, mode='reflect')
    accel_x_padded = np.pad(original_accel_x, pad_samples, mode='reflect')
    accel_y_padded = np.pad(original_accel_y, pad_samples, mode='reflect')
    accel_z_padded = np.pad(original_accel_z, pad_samples, mode='reflect')
    # NOTE(review): activity_padded is computed but not used further below.
    activity_padded = np.pad(activity_level, pad_samples, mode='edge') if activity_level is not None else None

    # Only used to tag events with 'high_rate_period'; it does not change detection.
    high_rate_mask = _detect_high_breathing_rate_periods(signal_padded, sampling_rate)
    
    # --- 2. Sliding-window processing loop ---
    adaptation_window_samples = int(adaptation_window_minutes * 60 * sampling_rate)
    
    sensitivity_params = {
        'low': {'base_height': 0.6, 'base_prominence': 0.5},
        'medium': {'base_height': 0.5, 'base_prominence': 0.4},
        'high': {'base_height': 0.4, 'base_prominence': 0.3}
    }
    params = sensitivity_params.get(sensitivity, sensitivity_params['medium'])
    all_breath_events = []
    # Windows overlap by 75%, so the same extremum is usually seen several
    # times; duplicates are removed in the finalization step.
    step_size = adaptation_window_samples // 4
    window_start = 0
    
    while window_start + adaptation_window_samples <= len(signal_padded):
        window_end = window_start + adaptation_window_samples
        
        window_signal = signal_padded[window_start:window_end]
        window_accel_x = accel_x_padded[window_start:window_end]
        window_accel_y = accel_y_padded[window_start:window_end]
        window_accel_z = accel_z_padded[window_start:window_end]
        
        is_high_rate_window = np.mean(high_rate_mask[window_start:window_end]) > 0.15        
        try:
            (_, _), window_orientation_info = _calibrate_orientation_thresholds_with_accelerometer(
                window_signal, sampling_rate, window_accel_x, window_accel_y, window_accel_z
            )
        except:
            window_orientation_info = {'orientation_type': 'unknown', 'gravity_influence': 'medium', 'preprocessing_needed': 'detrend_only'}
        
        # --- Per-window threshold calculation ---
        detrended_signal = detrend(window_signal, type='constant')
        signal_std = np.std(detrended_signal)
        # Thresholds depend on the gravity classification of the window. The
        # helper in this file only returns 'high' or 'medium', so the final
        # `else` acts as a defensive default.
        gravity_influence = window_orientation_info['gravity_influence']
        if gravity_influence == 'high': min_amplitude, base_height_factor = 0.015, 0.25
        elif gravity_influence == 'medium': min_amplitude, base_height_factor = 0.008, 0.15
        else: min_amplitude, base_height_factor = 0.005, 0.1
            
        height_threshold = max(signal_std * params['base_height'] * base_height_factor, min_amplitude * 0.3)
        prominence_threshold = height_threshold * 0.8
        
        # =========================================================================
        # DYNAMIC MINIMUM PEAK DISTANCE
        # The distance is derived from the rhythm of the current window rather
        # than from the high-rate flag.
        # =========================================================================
        # 1. Perform a lenient first pass to estimate the local rhythm of the current window.
        rough_peaks, _ = find_peaks(np.abs(detrended_signal),distance=int(0.4 * sampling_rate))
        
        # 2. If enough peaks were found, calculate the median interval between them.
        if len(rough_peaks) > 2:
            median_interval_samples = np.median(np.diff(rough_peaks))
            # 3. Set the definitive min_distance to a fraction of that median interval.
            # We clip it to prevent it from being too short (noise) or too long (missed breaths).
            min_distance = int(np.clip(median_interval_samples * 0.7,      # 70% of median interval
                                   a_min=int(sampling_rate * 0.8),     # Never shorter than 0.8 s (75 BPM)
                                   a_max=int(sampling_rate * 4.0)))    # Never longer than 4 s (15 BPM)
        else:
            # 4. If not enough peaks were found, fall back to a fixed 1.5 s spacing.
            min_distance = int(sampling_rate * 1.5) 
        # =========================================================================

        # --- Peak detection and event creation ---
        processed_signal = detrended_signal # Replaced by a high-passed copy when requested below
        if window_orientation_info.get('preprocessing_needed') == 'high_pass_filter':
            try:
                b, a = butter(2, 0.1 / (sampling_rate / 2), btype='high')
                processed_signal = filtfilt(b, a, processed_signal)
            except Exception: pass
        
        # NOTE(review): any exception raised in this block is silently
        # discarded, so a failing window simply contributes no events.
        try:
            peaks, _ = find_peaks(processed_signal, height=height_threshold, distance=min_distance, prominence=prominence_threshold)
            troughs, _ = find_peaks(-processed_signal, height=height_threshold, distance=min_distance, prominence=prominence_threshold)
            
            # Map window-relative indices back to the unpadded signal and keep
            # only events that fall inside the real recording.
            for peak_idx in peaks:
                global_padded_idx = window_start + peak_idx
                original_signal_idx = global_padded_idx - pad_samples
                if 0 <= original_signal_idx < len(original_timestamps):
                    all_breath_events.append({
                        'type': 'Inhalation', 'index': valid_indices[original_signal_idx],
                        'timestamp': original_timestamps[original_signal_idx], 'amplitude': processed_signal[peak_idx], 
                        'raw_amplitude': original_signal[original_signal_idx], 'event_type': 'peak', 
                        'orientation_type': window_orientation_info['orientation_type'],
                        'gravity_influence': gravity_influence, 'high_rate_period': is_high_rate_window
                    })
            
            for trough_idx in troughs:
                global_padded_idx = window_start + trough_idx
                original_signal_idx = global_padded_idx - pad_samples
                if 0 <= original_signal_idx < len(original_timestamps):
                    all_breath_events.append({
                        'type': 'Exhalation', 'index': valid_indices[original_signal_idx],
                        'timestamp': original_timestamps[original_signal_idx], 'amplitude': processed_signal[trough_idx], 
                        'raw_amplitude': original_signal[original_signal_idx], 'event_type': 'trough', 
                        'orientation_type': window_orientation_info['orientation_type'],
                        'gravity_influence': gravity_influence, 'high_rate_period': is_high_rate_window
                    })
                    
        except Exception:
            pass
        
        window_start += step_size
    
    # --- 3. Finalization and Stats ---
    if not all_breath_events:
        return pd.DataFrame(), {'error': 'No events detected'}
    
    all_breath_events.sort(key=lambda x: x['timestamp'])
    
    # Drop any event within 0.1 s of the previously kept one, regardless of
    # its type; this removes the duplicates produced by overlapping windows.
    filtered_events = []
    last_timestamp = None
    min_event_spacing = pd.Timedelta(seconds=0.1) # Slightly shorter to allow faster rates
    for event in all_breath_events:
        if last_timestamp is None or (pd.Timestamp(event['timestamp']) - last_timestamp) > min_event_spacing:
            filtered_events.append(event)
            last_timestamp = pd.Timestamp(event['timestamp'])
    
    # Merge runs of consecutive events that share the same type.
    consolidated_events = []
    current_group = []
    for event in filtered_events:
        if not current_group or current_group[-1]['type'] == event['type']:
            current_group.append(event)
        else:
            consolidated_events.append(_consolidate_event_group(current_group))
            current_group = [event]
    if current_group:
        consolidated_events.append(_consolidate_event_group(current_group))
    
    breath_df = pd.DataFrame(consolidated_events).sort_values('timestamp').reset_index(drop=True)
    
    # A breathing cycle needs one inhalation and one exhalation.
    inhalations = len(breath_df[breath_df['type'] == 'Inhalation'])
    exhalations = len(breath_df[breath_df['type'] == 'Exhalation'])
    breathing_cycles = min(inhalations, exhalations)
    duration_minutes = (pd.to_datetime(original_timestamps[-1]) - pd.to_datetime(original_timestamps[0])).total_seconds() / 60
    breaths_per_minute = breathing_cycles / duration_minutes if duration_minutes > 0 else 0
    
    stats = {
        'breaths_per_minute': breaths_per_minute, 'breathing_cycles': breathing_cycles,
        'inhalations': inhalations, 'exhalations': exhalations, 'duration_minutes': duration_minutes
    }
    
    return breath_df, stats

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy.stats import pearsonr
import seaborn as sns

# NeuroKit2 is an optional dependency (pip install neurokit2).
# NEUROKIT_AVAILABLE records whether the import succeeded so that functions
# relying on `nk` can check it before use.
try:
    import neurokit2 as nk
    NEUROKIT_AVAILABLE = True
    print("✅ NeuroKit2 successfully imported")
except ImportError:
    NEUROKIT_AVAILABLE = False
    print("❌ NeuroKit2 not available. Install with: pip install neurokit2")

def cal_timeseries_instantaneous_rr(signal, sampling_rate=12, window=10):
    """
    Compute a respiratory-rate time series for `signal` with NeuroKit2.

    Thin wrapper around `nk.rsp_rate` using the 'trough' method with
    'khodadad2018' peak detection and 'monotone_cubic' interpolation.

    Parameters:
    - signal (array-like): The respiratory signal data.
    - sampling_rate (float): Sampling rate of the signal in Hz.
    - window (int): Window size forwarded to `nk.rsp_rate`.

    Returns:
    - The value returned by `nk.rsp_rate`, or None if that call raised.

    Raises:
    - ImportError: If NeuroKit2 could not be imported.
    """
    if not NEUROKIT_AVAILABLE:
        raise ImportError("NeuroKit2 is required but not installed")

    # Fixed configuration for the rate estimate; only the sampling rate and
    # window are caller-controlled.
    rate_options = {
        'troughs': None,
        'sampling_rate': sampling_rate,
        'window': window,
        'hop_size': 1,
        'method': 'trough',
        'peak_method': 'khodadad2018',
        'interpolation_method': 'monotone_cubic',
    }

    try:
        return nk.rsp_rate(signal, **rate_options)
    except Exception as e:
        # Best effort: report the failure and let the caller handle None.
        print(f"NeuroKit rsp_rate failed: {e}")
        return None

def _pearson_on_complete_pairs(frame, col_a, col_b):
    """Return ``(r, p)`` from ``pearsonr`` using only rows where both columns are non-NaN.

    Returns ``None`` when a column is missing or fewer than three complete pairs remain.
    """
    if col_a not in frame.columns or col_b not in frame.columns:
        return None
    pairs = frame[[col_a, col_b]].astype(float).dropna()
    if len(pairs) < 3:
        return None
    r_value, p_value = pearsonr(pairs[col_a], pairs[col_b])
    return r_value, p_value


def compare_algorithms_vs_respeck_builtin_with_neurokit(respeck_df, window_minutes=5):
    """
    Compare three breathing-rate estimators against RESpeck's built-in rate.

    The recording is cut into consecutive, non-overlapping windows of
    ``window_minutes``. Every window holding at least 50 samples and at least one
    numeric ``breathingRate`` value is processed with:

    1. New adaptive algorithm (``adaptive_breath_detection_original_fixed``)
    2. Old algorithm (``getBreathsConservative``)
    3. NeuroKit2 respiratory rate (``cal_timeseries_instantaneous_rr``)

    A failure of any single method is caught per window, printed, and recorded as
    ``<method>_success = False`` with a NaN rate, so failed windows can never be
    mistaken for a measured 0 bpm in the statistics, correlations or plots.

    Parameters:
    - respeck_df (DataFrame): needs ``timestamp``, ``breathingRate`` and
      ``breathingSignal`` columns.
    - window_minutes (int | float): length of each analysis window in minutes.

    Returns:
    - ``None`` when ``breathingRate`` is missing/empty, the recording is shorter
      than one window, or no window could be processed at all.
    - Otherwise a dict with keys ``all_windows``, ``valid_windows``,
      ``summary_stats``, ``correlations``, ``num_valid_windows``,
      ``total_windows``, ``window_duration_minutes`` and ``neurokit_available``.
      The dict is also returned (with empty ``valid_windows``) when windows were
      processed but no estimator succeeded, so callers always see one shape.
    """

    def _package(valid_windows, stats, corrs):
        # Single place that defines the result shape returned to callers.
        return {
            'all_windows': results_df,
            'valid_windows': valid_windows,
            'summary_stats': stats,
            'correlations': corrs,
            'num_valid_windows': len(valid_windows),
            'total_windows': len(results_df),
            'window_duration_minutes': window_minutes,
            'neurokit_available': NEUROKIT_AVAILABLE
        }

    print("🔍 ALGORITHMS + NEUROKIT vs RESpeck BUILT-IN COMPARISON")
    print("=" * 80)
    print(f"📅 Using {window_minutes}-minute non-overlapping windows")

    # --- 1. Check RESpeck built-in breathing rate data ---
    if 'breathingRate' not in respeck_df.columns:
        print("❌ No 'breathingRate' column found in RESpeck data")
        return None

    respeck_sensor_data = respeck_df[['timestamp', 'breathingRate']].copy()
    respeck_sensor_data['breathingRate'] = pd.to_numeric(respeck_sensor_data['breathingRate'], errors='coerce')

    # Only keep valid breathing rate measurements
    valid_respeck = respeck_sensor_data.dropna(subset=['breathingRate'])

    if valid_respeck.empty:
        print("❌ No valid RESpeck breathing rate data available")
        return None

    print(f"✅ Found {len(valid_respeck):,} valid RESpeck measurements")
    print(f"   Time range: {valid_respeck['timestamp'].min()} to {valid_respeck['timestamp'].max()}")
    print(f"   Rate range: {valid_respeck['breathingRate'].min():.1f} - {valid_respeck['breathingRate'].max():.1f} breaths/min")

    # --- 2. Calculate sampling rate ---
    time_diffs = respeck_df['timestamp'].diff().dropna()
    avg_sample_period = time_diffs.apply(lambda delta: delta.total_seconds()).median()
    if pd.isna(avg_sample_period) or avg_sample_period <= 0:
        # Fall back to a 0.02 s period (50 Hz) when it cannot be measured.
        avg_sample_period = 0.02
    sampling_rate = 1 / avg_sample_period
    print(f"📊 Detected sampling rate: {sampling_rate:.1f} Hz")

    # --- 3. Determine time windows ---
    data_start = respeck_df['timestamp'].min()
    data_end = respeck_df['timestamp'].max()

    total_duration = (data_end - data_start).total_seconds() / 60  # minutes
    num_windows = int(total_duration // window_minutes)

    print(f"\n📊 Dataset Overview:")
    print(f"   Data range: {data_start} to {data_end}")
    print(f"   Total duration: {total_duration:.1f} minutes")
    print(f"   Number of {window_minutes}-min windows: {num_windows}")

    if num_windows < 1:
        print(f"❌ Insufficient data for {window_minutes}-minute windows")
        return None

    # --- 4. Process each window ---
    results = []

    for i in range(num_windows):
        window_start = data_start + pd.Timedelta(minutes=i * window_minutes)
        window_end = window_start + pd.Timedelta(minutes=window_minutes)

        print(f"\n--- Window {i+1}/{num_windows}: {window_start.strftime('%H:%M')} to {window_end.strftime('%H:%M')} ---")

        # Extract data for this window (half-open interval, so windows never share a sample)
        respeck_window = respeck_df[(respeck_df['timestamp'] >= window_start) &
                                   (respeck_df['timestamp'] < window_end)].copy()

        if len(respeck_window) < 50:
            print(f"   ⚠️  Insufficient data in window {i+1}")
            continue

        # Built-in rate: coerce to numeric here too, so text values in the raw
        # column are dropped the same way as in the global validity check above.
        respeck_builtin_rates = pd.to_numeric(respeck_window['breathingRate'], errors='coerce').dropna()
        if len(respeck_builtin_rates) == 0:
            print(f"   ⚠️  No valid RESpeck breathing rates in window {i+1}")
            continue

        window_result = {
            'window_id': i + 1,
            'start_time': window_start,
            'end_time': window_end,
            'respeck_samples': len(respeck_window),
            'respeck_builtin_bpm': respeck_builtin_rates.mean(),
            'respeck_builtin_std': respeck_builtin_rates.std(),
            'respeck_builtin_count': len(respeck_builtin_rates)
        }

        # --- Method 1: New Algorithm ---
        try:
            print("   🚀 Running new algorithm...")
            breath_df_new, stats_new = adaptive_breath_detection_original_fixed(
                respeck_window,
                adaptation_window_minutes=0.5,
                pad_duration_minutes=1,
            )

            window_result['new_algo_events'] = len(breath_df_new)
            window_result['new_algo_cycles'] = stats_new.get('breathing_cycles', 0)
            window_result['new_algo_bpm'] = stats_new.get('breaths_per_minute', 0)
            window_result['new_algo_success'] = True

        except Exception as e:
            print(f"   ❌ New algorithm failed: {e}")
            # NaN (not 0) marks "no estimate" so it is excluded from later statistics.
            window_result.update({
                'new_algo_events': 0, 'new_algo_cycles': 0, 'new_algo_bpm': np.nan, 'new_algo_success': False
            })

        # --- Method 2: Old Algorithm ---
        try:
            print("   📜 Running old algorithm...")
            breath_df_old, stats_old = getBreathsConservative(respeck_window)

            window_result['old_algo_events'] = len(breath_df_old)
            window_result['old_algo_cycles'] = min(stats_old.get('inhalations', 0), stats_old.get('exhalations', 0))
            window_result['old_algo_bpm'] = stats_old.get('breaths_per_minute', 0)
            window_result['old_algo_success'] = True

        except Exception as e:
            print(f"   ❌ Old algorithm failed: {e}")
            window_result.update({
                'old_algo_events': 0, 'old_algo_cycles': 0, 'old_algo_bpm': np.nan, 'old_algo_success': False
            })

        # --- Method 3: NeuroKit Algorithm ---
        try:
            print("   🧠 Running NeuroKit algorithm...")

            if not NEUROKIT_AVAILABLE:
                raise ImportError("NeuroKit2 not available")

            # Extract breathing signal and clean it
            breathing_signal = respeck_window['breathingSignal'].dropna().values

            if len(breathing_signal) < 30:  # Need minimum samples
                raise ValueError("Insufficient signal length for NeuroKit")

            # Adaptive window: at most 20 s, at most half the signal, never below 1 s
            # (truncating a sub-second value to 0 would hand NeuroKit an empty window).
            window_size_seconds = min(20, len(breathing_signal) / sampling_rate / 2)
            instantaneous_rr = cal_timeseries_instantaneous_rr(
                breathing_signal,
                sampling_rate=sampling_rate,
                window=max(1, int(window_size_seconds))
            )

            if instantaneous_rr is None or len(instantaneous_rr) == 0:
                raise ValueError("NeuroKit returned no valid data")

            instantaneous_rr = np.asarray(instantaneous_rr, dtype=float)
            valid_rates = instantaneous_rr[~np.isnan(instantaneous_rr)]
            if len(valid_rates) == 0:
                raise ValueError("No valid instantaneous rates calculated")

            # Remove extreme outliers (outside 5-50 bpm range)
            valid_rates = valid_rates[(valid_rates >= 5) & (valid_rates <= 50)]
            if len(valid_rates) == 0:
                raise ValueError("No valid rates after outlier removal")

            # Use median for robustness
            avg_neurokit_bpm = np.median(valid_rates)
            std_neurokit_bpm = np.std(valid_rates)

            window_result['neurokit_bpm'] = avg_neurokit_bpm
            window_result['neurokit_std'] = std_neurokit_bpm
            window_result['neurokit_valid_samples'] = len(valid_rates)
            window_result['neurokit_success'] = True

            print(f"   🧠 NeuroKit: {avg_neurokit_bpm:.1f} ± {std_neurokit_bpm:.1f} bpm ({len(valid_rates)} valid samples)")

        except Exception as e:
            print(f"   ❌ NeuroKit algorithm failed: {e}")
            window_result.update({
                'neurokit_bpm': np.nan, 'neurokit_std': np.nan, 'neurokit_valid_samples': 0, 'neurokit_success': False
            })

        results.append(window_result)

        # Print window summary
        print(f"   📈 Window {i+1} Results:")
        print(f"      RESpeck Built-in: {window_result.get('respeck_builtin_bpm', 0):.1f} bpm")
        print(f"      New Algorithm: {window_result.get('new_algo_cycles', 0)} cycles ({window_result.get('new_algo_bpm', 0):.1f} bpm)")
        print(f"      Old Algorithm: {window_result.get('old_algo_cycles', 0)} cycles ({window_result.get('old_algo_bpm', 0):.1f} bpm)")
        print(f"      NeuroKit: {window_result.get('neurokit_bpm', 0):.1f} bpm")

    # --- 5. Analyze Results ---
    if not results:
        print("❌ No valid windows processed")
        return None

    results_df = pd.DataFrame(results)

    print(f"\n📊 OVERALL ANALYSIS ({len(results_df)} windows)")
    print("=" * 80)

    # Filter successful detections
    new_ok = results_df['new_algo_success'].astype(bool)
    old_ok = results_df['old_algo_success'].astype(bool)
    neurokit_ok = results_df['neurokit_success'].astype(bool)
    valid_results = results_df[new_ok & old_ok & neurokit_ok].copy()

    print(f"✅ Valid windows (all algorithms succeeded): {len(valid_results)}/{len(results_df)}")

    if len(valid_results) == 0:
        print("❌ No windows where all algorithms succeeded")
        # Try with just successful RESpeck + NeuroKit
        valid_results = results_df[neurokit_ok].copy()
        print(f"🔄 Fallback: Windows with NeuroKit success: {len(valid_results)}")

        if len(valid_results) == 0:
            # Same dict shape as the normal path; nothing to analyse or plot.
            return _package(valid_results, {}, {})

    # --- 6. Statistical Comparisons ---
    print(f"\n📈 STATISTICAL SUMMARY:")

    methods = ['respeck_builtin', 'new_algo', 'old_algo', 'neurokit']
    method_names = ['RESpeck Built-in', 'New Algorithm', 'Old Algorithm', 'NeuroKit']

    summary_stats = {}

    for method, name in zip(methods, method_names):
        bpm_col = f'{method}_bpm'

        if bpm_col in valid_results.columns and valid_results[bpm_col].notna().any():
            bpm_values = valid_results[bpm_col].dropna().values

            if len(bpm_values) > 0:
                mean_bpm = np.mean(bpm_values)
                std_bpm = np.std(bpm_values)
                median_bpm = np.median(bpm_values)

                summary_stats[method] = {
                    'name': name,
                    'mean_bpm': mean_bpm,
                    'std_bpm': std_bpm,
                    'median_bpm': median_bpm,
                    'values': bpm_values
                }

                print(f"{name}:")
                print(f"   Mean: {mean_bpm:.1f} ± {std_bpm:.1f} bpm")
                print(f"   Median: {median_bpm:.1f} bpm")
                print(f"   Range: {np.min(bpm_values):.1f} - {np.max(bpm_values):.1f} bpm")
                print()

    # --- 7. Correlation Analysis ---
    print(f"📈 CORRELATION ANALYSIS (vs RESpeck Built-in):")

    correlations = {}

    # (result key, first column, second column, printed label); the last row is the cross-comparison.
    comparisons = [
        ('new_vs_respeck', 'new_algo_bpm', 'respeck_builtin_bpm', 'New Algorithm vs RESpeck Built-in'),
        ('old_vs_respeck', 'old_algo_bpm', 'respeck_builtin_bpm', 'Old Algorithm vs RESpeck Built-in'),
        ('neurokit_vs_respeck', 'neurokit_bpm', 'respeck_builtin_bpm', 'NeuroKit vs RESpeck Built-in'),
        ('new_vs_old', 'new_algo_bpm', 'old_algo_bpm', 'New vs Old Algorithm'),
    ]

    for key, col_a, col_b, label in comparisons:
        try:
            outcome = _pearson_on_complete_pairs(valid_results, col_a, col_b)
        except Exception as e:
            print(f"❌ Correlation analysis failed: {e}")
            continue
        if outcome is None:
            continue
        correlations[key] = outcome
        print(f"{label}: r = {outcome[0]:.3f} (p = {outcome[1]:.3f})")

    # --- 8. Agreement Analysis ---
    print(f"\n🎯 AGREEMENT ANALYSIS (vs RESpeck Built-in as Reference):")

    algorithm_methods = [('new_algo', 'New Algorithm'), ('old_algo', 'Old Algorithm'), ('neurokit', 'NeuroKit')]

    for method, name in algorithm_methods:
        bpm_col = f'{method}_bpm'

        if bpm_col in valid_results.columns and valid_results[bpm_col].notna().any():
            valid_comparison = valid_results.dropna(subset=[bpm_col, 'respeck_builtin_bpm'])

            if len(valid_comparison) > 0:
                differences = valid_comparison[bpm_col] - valid_comparison['respeck_builtin_bpm']

                mean_diff = np.mean(differences)
                std_diff = np.std(differences)
                mae = np.mean(np.abs(differences))

                within_2 = np.sum(np.abs(differences) <= 2) / len(differences) * 100
                within_3 = np.sum(np.abs(differences) <= 3) / len(differences) * 100

                print(f"{name} (n={len(valid_comparison)}):")
                print(f"   Mean difference: {mean_diff:+.1f} ± {std_diff:.1f} bpm")
                print(f"   Mean Absolute Error: {mae:.1f} bpm")
                print(f"   Within ±2 bpm: {within_2:.1f}%")
                print(f"   Within ±3 bpm: {within_3:.1f}%")
                print()

    # --- 9. Create Visualizations ---
    create_enhanced_comparison_plots(valid_results, summary_stats, correlations, window_minutes)

    # --- 10. Return Results ---
    return _package(valid_results, summary_stats, correlations)

def create_enhanced_comparison_plots(valid_results, summary_stats, correlations, window_minutes):
    """
    Draw the comparison figure for the breathing-rate methods, including NeuroKit.

    Layout (3 rows): time series across the top, a box plot plus one scatter panel
    per algorithm in the middle, and a Bland-Altman panel across the bottom.

    Parameters:
    - valid_results (DataFrame): per-window results; the ``*_bpm`` columns are read.
    - summary_stats (dict): accepted for interface compatibility; not used here.
    - correlations (dict): maps a comparison key to ``(r, p)``. Both
      ``'<method>_vs_respeck'`` and the short ``'new_vs_respeck'`` /
      ``'old_vs_respeck'`` spellings are recognised.
    - window_minutes: window length, used only in titles.

    Returns:
    - None. Prints a message and returns early when fewer than two methods have data.
    """
    base_methods = [
        ('respeck_builtin_bpm', 'RESpeck Built-in', 'black', 'o'),
        ('new_algo_bpm', 'New Algorithm', 'red', 'o'),
        ('old_algo_bpm', 'Old Algorithm', 'blue', 's'),
        ('neurokit_bpm', 'NeuroKit', 'green', '^')
    ]

    # Keep only methods that have at least one non-NaN value
    available_methods = [
        spec for spec in base_methods
        if spec[0] in valid_results.columns and valid_results[spec[0]].notna().any()
    ]

    if len(available_methods) < 2:
        print("❌ Insufficient methods with data for plotting")
        return

    algorithm_methods = [spec for spec in available_methods if spec[1] != 'RESpeck Built-in']

    # Reference series; empty when the built-in column is absent so the
    # reference-based panels simply stay empty instead of raising KeyError.
    if 'respeck_builtin_bpm' in valid_results.columns:
        reference = valid_results['respeck_builtin_bpm'].dropna()
    else:
        reference = pd.Series(dtype=float)

    # One middle-row column for the box plot plus one per algorithm, so no
    # algorithm's scatter panel is dropped for lack of a slot.
    n_methods = len(available_methods)
    fig = plt.figure(figsize=(20, 15) if n_methods == 4 else (18, 12))
    gs = fig.add_gridspec(3, len(algorithm_methods) + 1, hspace=0.3, wspace=0.3)

    # 1. Time series comparison (spans top row)
    ax1 = fig.add_subplot(gs[0, :])

    for col, name, color, marker in available_methods:
        valid_data = valid_results[col].dropna()
        ax1.plot(valid_data.index, valid_data.values,
                 marker=marker, linestyle='-', label=name, color=color,
                 alpha=0.8, markersize=6, linewidth=2 if name == 'RESpeck Built-in' else 1.5)

    ax1.set_title(f'Breathing Rate Comparison Across {window_minutes}-Minute Windows')
    ax1.set_xlabel('Window Number')
    ax1.set_ylabel('Breathing Rate (breaths/min)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Box plot comparison
    ax2 = fig.add_subplot(gs[1, 0])

    box_data = []
    box_labels = []
    box_colors = []

    for col, name, color, marker in available_methods:
        valid_data = valid_results[col].dropna()
        if len(valid_data) > 0:
            box_data.append(valid_data.values)
            box_labels.append(name.replace(' ', '\n'))
            box_colors.append(color)

    if box_data:
        bp = ax2.boxplot(box_data, labels=box_labels, patch_artist=True)
        for patch, color in zip(bp['boxes'], box_colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.6)

    ax2.set_title('Distribution Comparison')
    ax2.set_ylabel('Breathing Rate (breaths/min)')
    ax2.grid(True, alpha=0.3)

    # 3. Correlation plots - one panel per algorithm vs RESpeck
    for slot, (col, name, color, marker) in enumerate(algorithm_methods, start=1):
        ax = fig.add_subplot(gs[1, slot])

        method_data = valid_results[col].dropna()
        common_idx = reference.index.intersection(method_data.index)

        if len(common_idx) > 1:
            x_vals = reference.loc[common_idx].values
            y_vals = method_data.loc[common_idx].values

            ax.scatter(x_vals, y_vals, alpha=0.7, color=color, s=50)

            # Perfect agreement line
            min_val = min(np.min(x_vals), np.min(y_vals))
            max_val = max(np.max(x_vals), np.max(y_vals))
            ax.plot([min_val, max_val], [min_val, max_val], 'k--', alpha=0.5, label='Perfect Agreement')

            # Regression line. The correlations dict may key the new/old
            # algorithms as 'new_vs_respeck'/'old_vs_respeck', so try both spellings.
            prefix = col.replace('_bpm', '')
            corr_entry = None
            for corr_key in (f"{prefix}_vs_respeck", f"{prefix.replace('_algo', '')}_vs_respeck"):
                if corr_key in correlations:
                    corr_entry = correlations[corr_key]
                    break

            # A straight-line fit is undefined when every x value is identical.
            if corr_entry is not None and np.ptp(x_vals) > 0:
                fit = np.poly1d(np.polyfit(x_vals, y_vals, 1))
                x_sorted = np.sort(x_vals)
                ax.plot(x_sorted, fit(x_sorted), color=color, alpha=0.8,
                        label=f'r = {corr_entry[0]:.3f}')

            ax.set_xlabel('RESpeck Built-in (breaths/min)')
            ax.set_ylabel(f'{name} (breaths/min)')
            ax.set_title(f'{name} vs RESpeck Built-in')
            ax.legend()
            ax.grid(True, alpha=0.3)

    # 4. Bland-Altman plot comparing all algorithms vs RESpeck Built-in
    ax_ba = fig.add_subplot(gs[2, :])

    # Perfect agreement line
    ax_ba.axhline(y=0, color='black', linestyle='-', alpha=0.5, label='Perfect Agreement')

    for col, name, color, marker in algorithm_methods:
        method_data = valid_results[col].dropna()
        common_idx = reference.index.intersection(method_data.index)

        if len(common_idx) > 0:
            paired_method = method_data.loc[common_idx]
            paired_reference = reference.loc[common_idx]
            differences = paired_method - paired_reference
            pair_means = (paired_method + paired_reference) / 2
            mean_diff = differences.mean()

            # Per-window points: x = mean of the two rates, y = their difference
            ax_ba.scatter(pair_means.values, differences.values, color=color,
                          marker=marker, alpha=0.6, s=40)

            # Mean line for this method
            ax_ba.axhline(y=mean_diff, color=color, linestyle='-', linewidth=2,
                          label=f'{name} (Mean: {mean_diff:.1f})')

    ax_ba.set_xlabel('Mean of Algorithm and RESpeck Built-in (breaths/min)')
    ax_ba.set_ylabel('Algorithm - RESpeck Built-in (breaths/min)')
    ax_ba.set_title('Bland-Altman: All Algorithms vs RESpeck Built-in')
    ax_ba.legend()
    ax_ba.grid(True, alpha=0.3)

    plt.suptitle(f'Enhanced Algorithms vs RESpeck Built-in Comparison\n({len(valid_results)} valid {window_minutes}-minute windows)',
                 fontsize=16, fontweight='bold')

    plt.tight_layout()
    plt.show()

def run_enhanced_respeck_comparison(respeck_df, window_minutes=5):
    """
    Run the enhanced comparison including NeuroKit and print a final summary.

    Parameters:
    - respeck_df (DataFrame): passed straight to
      ``compare_algorithms_vs_respeck_builtin_with_neurokit``.
    - window_minutes: window length in minutes, forwarded to the comparison.

    Returns:
    - ``None`` when the comparison returns ``None``; otherwise the comparison's
      result dict. If the comparison hands back only a per-window DataFrame
      (no window had a successful estimator), it is wrapped in the same dict
      shape with zero valid windows instead of failing on a key lookup.
    """
    print("🚀 STARTING ENHANCED ALGORITHMS vs RESpeck COMPARISON")
    print("=" * 80)

    # Check NeuroKit availability
    if not NEUROKIT_AVAILABLE:
        print("⚠️  NeuroKit2 not available - install with: pip install neurokit2")
        print("   Continuing with available algorithms only...")

    # Run the comparison
    results = compare_algorithms_vs_respeck_builtin_with_neurokit(respeck_df, window_minutes=window_minutes)

    if results is None:
        print("❌ Comparison failed - check your data")
        return None

    if isinstance(results, pd.DataFrame):
        # Bare per-window table: normalise to the dict shape used everywhere below.
        results = {
            'all_windows': results,
            'valid_windows': results.iloc[0:0],
            'summary_stats': {},
            'correlations': {},
            'num_valid_windows': 0,
            'total_windows': len(results),
            'window_duration_minutes': window_minutes,
            'neurokit_available': NEUROKIT_AVAILABLE
        }

    # Print final summary
    print("\n" + "="*80)
    print("🎯 ENHANCED FINAL SUMMARY")
    print("="*80)

    valid_windows = results['num_valid_windows']
    total_windows = results['total_windows']

    print(f"📊 Windows analyzed: {valid_windows}/{total_windows} successful")
    print(f"⏱️  Window duration: {window_minutes} minutes each")
    print(f"🧠 NeuroKit available: {results['neurokit_available']}")

    if valid_windows > 0:
        correlations = results['correlations']

        print(f"\n📈 PERFORMANCE RANKINGS (vs RESpeck Built-in Reference):")

        # Rank methods by correlation with RESpeck built-in
        if correlations:
            # Checked in this order; first token found in the key wins.
            display_names = (('neurokit', 'NeuroKit'), ('new', 'New Algorithm'), ('old', 'Old Algorithm'))
            method_correlations = []

            for key, (corr, p_val) in correlations.items():
                # Only "<method>_vs_respeck" entries are ranked (skips e.g. 'new_vs_old').
                if '_vs_respeck' not in key:
                    continue
                fallback_name = key.replace('_vs_respeck', '').replace('_', ' ').title()
                method_name = next(
                    (label for token, label in display_names if token in key.lower()),
                    fallback_name
                )
                method_correlations.append((method_name, corr, p_val))

            # Strongest |r| first; undefined (NaN) correlations go last so they
            # cannot scramble the ordering.
            method_correlations.sort(key=lambda item: (bool(np.isnan(item[1])), -abs(item[1])))

            for i, (method, corr, p_val) in enumerate(method_correlations, 1):
                significance = "***" if p_val < 0.001 else "**" if p_val < 0.01 else "*" if p_val < 0.05 else ""
                print(f"   {i}. {method}: r = {corr:.3f}{significance}")

        # Calculate mean absolute errors for recommendation
        valid_data = results['valid_windows']
        if len(valid_data) > 0:
            print(f"\n📏 ACCURACY COMPARISON (vs RESpeck Built-in):")

            algorithm_methods = [
                ('new_algo_bpm', 'New Algorithm'),
                ('old_algo_bpm', 'Old Algorithm'),
                ('neurokit_bpm', 'NeuroKit')
            ]

            best_method = None
            best_mae = float('inf')

            for col, name in algorithm_methods:
                if col in valid_data.columns and valid_data[col].notna().any():
                    valid_comparison = valid_data.dropna(subset=[col, 'respeck_builtin_bpm'])

                    if len(valid_comparison) > 0:
                        mae = np.mean(np.abs(valid_comparison[col] - valid_comparison['respeck_builtin_bpm']))
                        print(f"   {name} MAE: {mae:.1f} breaths/min (n={len(valid_comparison)})")

                        if mae < best_mae:
                            best_mae = mae
                            best_method = name

            print(f"\n🏆 RECOMMENDATION:")
            if best_method:
                print(f"✅ Best performing method: {best_method} (MAE: {best_mae:.1f} bpm)")
            else:
                print(f"🔄 Unable to determine best method - insufficient data")

    return results

# Usage example: run the full comparison on the RESpeck frame loaded at the top of the notebook, in 5-minute windows.
results = run_enhanced_respeck_comparison(respeck_df, window_minutes=5)

## PSG ANALYSIS

In [None]:
import numpy as np
import pandas as pd
from scipy.signal import find_peaks, detrend

def simple_5min_breath_comparison_nasal(respeck_df, nasal_df, respeck_algorithm_func, 
                                        nasal_flow_column='flow_rate'):
    """
    Simple 5-minute window comparison: count breaths detected by respeck vs nasal cannula.

    Parameters:
    - respeck_df (DataFrame): input handed to ``respeck_algorithm_func``.
    - nasal_df (DataFrame): nasal recording with a ``timestamp`` column and the
      flow column named by ``nasal_flow_column``.
    - respeck_algorithm_func (callable): called as
      ``func(respeck_df, adaptation_window_minutes=10, pad_duration_minutes=10)``;
      the first element of its return value is passed to
      ``convert_respeck_events_to_cycles``.
    - nasal_flow_column (str): name of the flow column in ``nasal_df``.

    Returns:
    - DataFrame with columns ``Window``, ``Start_Time``, ``Nasal_Breaths``,
      ``RESPeck_Breaths`` and ``Difference`` (RESPeck minus nasal), one row per
      5-minute window inside the period covered by both sensors. It is empty
      (same columns) when either sensor produced no breaths or the shared
      period is shorter than one window.
    """
    result_columns = ['Window', 'Start_Time', 'Nasal_Breaths', 'RESPeck_Breaths', 'Difference']
    window_length = pd.Timedelta(minutes=5)

    print("📊 SIMPLE 5-MINUTE WINDOW COMPARISON")
    print("=" * 50)

    # 1. Get Nasal Cannula breaths
    print("Processing Nasal Cannula...")
    nasal_signal, nasal_sampling_rate = process_nasal_signal(nasal_df, nasal_flow_column)
    nasal_breath_indices, _ = detect_nasal_breaths(nasal_signal, nasal_sampling_rate)
    nasal_breath_times = pd.to_datetime(nasal_df.loc[nasal_breath_indices, 'timestamp'])
    if nasal_breath_times.dt.tz is not None:
        # Compare both sensors on naive UTC timestamps
        nasal_breath_times = nasal_breath_times.dt.tz_convert('UTC').dt.tz_localize(None)

    # 2. Get RESPeck breath cycles
    print("Processing RESPeck...")
    respeck_results, _ = respeck_algorithm_func(respeck_df, adaptation_window_minutes=10, pad_duration_minutes=10)
    respeck_cycles_df, _ = convert_respeck_events_to_cycles(respeck_results)
    respeck_cycle_times = pd.to_datetime(respeck_cycles_df['cycle_time'])
    if respeck_cycle_times.dt.tz is not None:
        respeck_cycle_times = respeck_cycle_times.dt.tz_convert('UTC').dt.tz_localize(None)

    # Without breaths from both sensors there is no shared period to window.
    if nasal_breath_times.empty or respeck_cycle_times.empty:
        print("⚠️  No breaths detected for at least one sensor - nothing to compare")
        return pd.DataFrame(columns=result_columns)

    # 3. Find overlap period (min/max, so unsorted breath times are handled too)
    overlap_start = max(nasal_breath_times.min(), respeck_cycle_times.min())
    overlap_end = min(nasal_breath_times.max(), respeck_cycle_times.max())

    print(f"Overlap: {overlap_start.strftime('%H:%M:%S')} to {overlap_end.strftime('%H:%M:%S')}")

    # 4. Create 5-minute windows; the last start leaves room for a full window
    window_starts = pd.date_range(start=overlap_start, end=overlap_end - window_length, freq='5min')

    results = []

    for i, window_start in enumerate(window_starts):
        window_end = window_start + window_length

        # Count breaths in this 5-minute window (half-open interval)
        nasal_count = int(((nasal_breath_times >= window_start) & (nasal_breath_times < window_end)).sum())
        respeck_count = int(((respeck_cycle_times >= window_start) & (respeck_cycle_times < window_end)).sum())

        results.append({
            'Window': i+1,
            'Start_Time': window_start.strftime('%H:%M:%S'),
            'Nasal_Breaths': nasal_count,
            'RESPeck_Breaths': respeck_count,
            'Difference': respeck_count - nasal_count
        })

    # 5. Create results DataFrame (columns fixed so an empty result keeps its schema)
    df_results = pd.DataFrame(results, columns=result_columns)

    # 6. Display results - every window is listed; no breath-count filter is applied
    print(f"\n📋 BREATH COUNTS PER 5-MINUTE WINDOW:")
    print(df_results.to_string(index=False))

    # 7. Summary stats
    print(f"\n📊 SUMMARY:")
    print(f"Valid windows: {len(df_results)}")
    if not df_results.empty:
        print(f"Nasal average: {df_results['Nasal_Breaths'].mean():.1f} breaths per 5 minutes")
        print(f"RESPeck average: {df_results['RESPeck_Breaths'].mean():.1f} breaths per 5 minutes")
        print(f"Average difference: {df_results['Difference'].mean():.1f} breaths per 5 minutes")

    return df_results

def plot_breath_comparison(df_results):
    """
    Show two figures for the per-window breath counts and print difference statistics.

    Figure 1 overlays the nasal and RESPeck counts per window with dashed guide
    lines at 58 and 110 breaths; figure 2 is a bar chart of the per-window
    difference (RESPeck minus nasal). Prints a notice and returns without
    plotting when ``df_results`` is empty.
    """
    import matplotlib.pyplot as plt

    if df_results.empty:
        print("No valid windows to plot")
        return

    windows = df_results['Window'].values
    nasal_breaths = df_results['Nasal_Breaths'].values
    respeck_breaths = df_results['RESPeck_Breaths'].values

    def label_windows():
        # Up to 20 windows: label each with its start time; otherwise thin the
        # ticks with a stride of len(windows)//10.
        if len(df_results) <= 20:
            plt.xticks(windows, df_results['Start_Time'].values, rotation=45)
        else:
            stride = max(1, len(windows)//10)
            plt.xticks(windows[::stride])

    # ---- Figure 1: both count series ----
    plt.figure(figsize=(12, 8))
    plt.plot(windows, nasal_breaths, 'o-', label='Nasal Cannula', color='blue', linewidth=2, markersize=6)
    plt.plot(windows, respeck_breaths, 's-', label='RESPeck', color='red', linewidth=2, markersize=6)

    # Guide lines at the 58 / 110 breath marks
    for level, legend_text in ((58, 'Valid range (58-110)'), (110, None)):
        if legend_text is None:
            plt.axhline(y=level, color='gray', linestyle='--', alpha=0.5)
        else:
            plt.axhline(y=level, color='gray', linestyle='--', alpha=0.5, label=legend_text)

    plt.xlabel('Window Number', fontsize=12)
    plt.ylabel('Breaths per 5 minutes', fontsize=12)
    plt.title('Breath Count Comparison: RESPeck vs Nasal Cannula\n(Valid Windows Only: 58-110 breaths)', fontsize=14)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)

    # Pad the y-range by 5 breaths on either side of the data
    lowest = min(min(nasal_breaths), min(respeck_breaths))
    highest = max(max(nasal_breaths), max(respeck_breaths))
    plt.ylim(lowest - 5, highest + 5)

    label_windows()
    plt.tight_layout()
    plt.show()

    # ---- Figure 2: per-window difference ----
    plt.figure(figsize=(12, 6))
    differences = df_results['Difference'].values

    bar_colors = []
    for delta in differences:
        bar_colors.append('green' if delta >= 0 else 'orange')
    plt.bar(windows, differences, color=bar_colors, alpha=0.7, edgecolor='black', linewidth=0.5)

    plt.axhline(y=0, color='black', linestyle='-', linewidth=1)
    plt.xlabel('Window Number', fontsize=12)
    plt.ylabel('Difference (RESPeck - Nasal)', fontsize=12)
    plt.title('Breath Count Differences by Window\n(Positive = RESPeck higher, Negative = Nasal higher)', fontsize=14)
    plt.grid(True, alpha=0.3, axis='y')

    label_windows()
    plt.tight_layout()
    plt.show()

    # ---- Console statistics ----
    respeck_higher = sum(1 for delta in differences if delta > 0)
    nasal_higher = sum(1 for delta in differences if delta < 0)
    exact_matches = sum(1 for delta in differences if delta == 0)

    print(f"\n📈 PLOT STATISTICS:")
    print(f"Mean absolute difference: {abs(differences).mean():.1f} breaths")
    print(f"RESPeck higher in {respeck_higher} windows")
    print(f"Nasal higher in {nasal_higher} windows")
    print(f"Exact match in {exact_matches} windows")

def process_nasal_signal(nasal_df, flow_column='flow_rate'):
    """
    Centre the nasal cannula flow signal around zero.

    The baseline is a centred rolling mean spanning 2 minutes of samples and
    is subtracted from the raw flow. Samples for which the rolling mean is
    undefined come back as NaN.

    Args:
        nasal_df: DataFrame with a 'timestamp' column and the flow column.
        flow_column: Name of the column holding the nasal flow values.

    Returns:
        (centered_signal, sampling_rate): the baseline-removed flow Series and
        the sampling rate in Hz, estimated as 1 / median timestamp spacing.
    """
    flow = nasal_df[flow_column].copy()

    # Sampling rate from the median gap (seconds) between consecutive samples
    sample_gaps = pd.to_datetime(nasal_df['timestamp']).diff().dt.total_seconds().dropna()
    sampling_rate = 1 / sample_gaps.median()

    # Baseline: centred rolling mean over 2 minutes' worth of samples
    baseline_window = int(2 * 60 * sampling_rate)
    baseline = flow.rolling(window=baseline_window, center=True).mean()

    return flow - baseline, sampling_rate

def detect_nasal_breaths(centered_signal, sampling_rate):
    """
    Detect breaths in nasal cannula data.

    The signal polarity is chosen automatically: peaks are located on both the
    signal and its negation, and whichever side has the larger mean peak value
    is used for the final detection. Final peaks must exceed MAD-scaled height
    and prominence thresholds and be at least ~1 second apart.

    Args:
        centered_signal: pandas Series of baseline-removed flow; NaNs are dropped.
        sampling_rate: Sampling rate in Hz, used to convert the 1-second
            minimum breath spacing into samples.

    Returns:
        (breath_index, 'breath'): index labels of ``centered_signal`` (after
        dropping NaNs) at the detected peaks, and the constant label 'breath'.
    """
    valid_signal = centered_signal.dropna()

    # Minimum 1 second between breaths. find_peaks requires distance >= 1,
    # and int() truncates to 0 for sampling rates below 1 Hz, so clamp it.
    min_peak_distance = max(1, int(sampling_rate * 1.0))

    # Test both polarities to find which direction represents inhalation
    pos_peaks, _ = find_peaks(valid_signal, distance=min_peak_distance)
    neg_peaks, _ = find_peaks(-valid_signal, distance=min_peak_distance)

    pos_strength = np.mean(valid_signal.iloc[pos_peaks]) if len(pos_peaks) > 0 else 0
    neg_strength = np.mean(-valid_signal.iloc[neg_peaks]) if len(neg_peaks) > 0 else 0

    # Choose stronger polarity (ties fall through to the inverted signal)
    if pos_strength > neg_strength:
        signal_for_detection = valid_signal
    else:
        signal_for_detection = -valid_signal

    # Detect breaths using MAD-based thresholds
    signal_mad = np.median(np.abs(signal_for_detection - np.median(signal_for_detection)))
    height_threshold = 1.5 * signal_mad
    prominence_threshold = 1.0 * signal_mad

    breath_peaks, _ = find_peaks(
        signal_for_detection,
        height=height_threshold,
        prominence=prominence_threshold,
        distance=min_peak_distance
    )

    # breath_peaks are positions within valid_signal; map back to index labels
    return valid_signal.index[breath_peaks], 'breath'

def convert_respeck_events_to_cycles(respeck_df):
    """
    Pair RESPeck inhalation and exhalation events into breath cycles.

    Each inhalation (in row order) is matched to the closest exhalation that
    has not been paired yet and lies within 5 seconds of it; inhalations
    without such a partner are dropped. A cycle's time is the earlier of the
    two paired timestamps.

    Args:
        respeck_df: DataFrame with a 'type' column ('Inhalation' /
            'Exhalation') and a 'timestamp' column.

    Returns:
        (cycles_df, n_cycles): DataFrame with one 'cycle_time' per breath
        cycle, and the number of cycles. An empty input yields
        (empty DataFrame, 0).
    """
    if respeck_df.empty:
        return pd.DataFrame(), 0

    event_type = respeck_df['type']
    inhalations = respeck_df[event_type == 'Inhalation'].copy()
    exhalations = respeck_df[event_type == 'Exhalation'].copy()

    inhalations['timestamp'] = pd.to_datetime(inhalations['timestamp'])
    exhalations['timestamp'] = pd.to_datetime(exhalations['timestamp'])
    exhale_times = exhalations['timestamp']

    paired_positions = set()
    cycles = []

    for _, inhale_event in inhalations.iterrows():
        inhale_time = inhale_event['timestamp']
        gaps = np.abs((exhale_times - inhale_time).dt.total_seconds())

        # Walk exhalations nearest-first; take the first unpaired one within 5 s
        match = next(
            (pos for pos in gaps.argsort()
             if pos not in paired_positions and gaps.iloc[pos] <= 5),
            None,
        )
        if match is None:
            continue

        cycles.append({'cycle_time': min(inhale_time, exhale_times.iloc[match])})
        paired_positions.add(match)

    return pd.DataFrame(cycles), len(cycles)

# USAGE:
def run_comparison(respeck_df, nasal_df, nasal_flow_column='flow_rate'):
    """
    Compare RESPeck and nasal cannula breath counts, then plot the outcome.

    Delegates to ``simple_5min_breath_comparison_nasal`` using
    ``adaptive_breath_detection_original_fixed`` as the RESPeck algorithm and
    hands the result to ``plot_breath_comparison``.

    Args:
        respeck_df: RESPeck data passed through to the comparison.
        nasal_df: Nasal cannula data passed through to the comparison.
        nasal_flow_column: Column of ``nasal_df`` holding the flow signal.

    Returns:
        Whatever ``simple_5min_breath_comparison_nasal`` returns.
    """
    comparison = simple_5min_breath_comparison_nasal(
        respeck_df=respeck_df,
        nasal_df=nasal_df,
        respeck_algorithm_func=adaptive_breath_detection_original_fixed,
        nasal_flow_column=nasal_flow_column,
    )

    # Visualise before handing the results back to the caller
    plot_breath_comparison(comparison)
    return comparison

# Example usage: compare the RESPeck data against the PSG recording, using the
# 'Resp nasal' column of psg_df as the nasal flow signal. This runs (and plots)
# as soon as the statement is executed.
results = run_comparison(respeck_df, psg_df, 'Resp nasal')

## Feature Analysis

In [None]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import pickle
from scipy.stats import spearmanr, skew, kurtosis
from scipy import signal as scipy_signal
import matplotlib.pyplot as plt

def load_pytorch_model(model_path):
    """
    Load a PyTorch artefact and report what kind of object it contains.

    Args:
        model_path: Path of the file to pass to ``torch.load``.

    Returns:
        (obj, kind) where kind is one of:
          - 'full_model': the loaded object has an ``eval`` attribute;
          - 'state_dict': the loaded object is a dict. If it has a
            'model_state_dict' key, that entry is returned, otherwise the
            dict itself;
          - 'unknown': anything else (the loaded object is returned as-is).
        (None, None) is returned when loading raises.
    """
    try:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location='cpu')
    except Exception as e:
        # Best effort: callers fall back to feature-only analysis on failure
        print(f"Error loading model: {e}")
        return None, None

    if hasattr(checkpoint, 'eval'):
        print("Loaded full model")
        return checkpoint, 'full_model'

    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            print("Loaded checkpoint with state_dict")
            return checkpoint['model_state_dict'], 'state_dict'
        print("Loaded raw state_dict")
        return checkpoint, 'state_dict'

    print("Unknown model format")
    return checkpoint, 'unknown'


def robust_breath_features(signal_segment, timestamps_segment):
    """
    Extract statistical, spectral and breath-level features from one window.

    The spectral features and the sample strides used below assume a 12.5 Hz
    signal (``fs=12.5`` is hard-coded).

    Args:
        signal_segment: 1-D sequence of signal samples (list, ndarray or
            Series); coerced to a float ndarray for the computations here.
        timestamps_segment: Timestamps aligned with ``signal_segment``; only
            forwarded to ``calculate_TS_breathFeatures``.

    Returns:
        dict mapping feature name to value. On the normal path the key set is
        always the same: the breath-level features (amplitude_*,
        breath_duration_*, long_breath_ratio, respiratory_rate_*) default to 0
        whenever the breath helper is unavailable, fails, or returns no data
        for them. If extraction fails as a whole, only 'signal_mean',
        'signal_std' and 'signal_range' are returned (all 0.0 for an empty
        segment).

    NOTE(review): Welch segments are at most 64 samples, so at 12.5 Hz the
    first non-zero frequency bin is ~0.195 Hz. No bin can fall inside the VLF
    (0.008-0.04 Hz) or LF (0.04-0.15 Hz) bands, hence 'vlf_power_ratio' and
    'lf_power_ratio' always come out as 0. Left unchanged here because a
    different ``nperseg`` would shift every spectral feature value.
    """
    features = {}

    try:
        # Float ndarray so list / Series inputs support the vectorised maths
        # below (``** 2``, subtraction of the mean, strided slicing).
        sig = np.asarray(signal_segment, dtype=float)

        # Basic statistical features
        features['signal_mean'] = np.mean(sig)
        features['signal_std'] = np.std(sig)
        features['signal_var'] = np.var(sig)
        features['signal_min'] = np.min(sig)
        features['signal_max'] = np.max(sig)
        features['signal_range'] = features['signal_max'] - features['signal_min']
        features['signal_skewness'] = skew(sig) if len(sig) > 2 else 0
        features['signal_kurtosis'] = kurtosis(sig) if len(sig) > 3 else 0

        # Zero crossing rate of the mean-removed signal
        centred = sig - features['signal_mean']
        zero_crossings = np.sum(np.diff(np.sign(centred)) != 0)
        features['zero_crossing_rate'] = zero_crossings / len(sig)

        # RMS (Root Mean Square) - energy measure
        features['rms'] = np.sqrt(np.mean(sig**2))

        # Frequency domain features
        try:
            # Power spectral density
            freqs, psd = scipy_signal.welch(sig, fs=12.5, nperseg=min(64, len(sig)//4))

            # Define frequency bands for breathing analysis
            very_low_freq = (freqs >= 0.008) & (freqs < 0.04)   # Apnea cycling
            low_freq = (freqs >= 0.04) & (freqs < 0.15)         # Abnormal patterns
            normal_breathing = (freqs >= 0.15) & (freqs < 0.5)   # Normal breathing
            high_freq = (freqs >= 0.5) & (freqs < 2.0)          # Effort/artifacts

            # Calculate power in each band (epsilon avoids division by zero)
            total_power = np.sum(psd) + 1e-10
            features['vlf_power_ratio'] = np.sum(psd[very_low_freq]) / total_power if np.any(very_low_freq) else 0
            features['lf_power_ratio'] = np.sum(psd[low_freq]) / total_power if np.any(low_freq) else 0
            features['normal_power_ratio'] = np.sum(psd[normal_breathing]) / total_power if np.any(normal_breathing) else 0
            features['hf_power_ratio'] = np.sum(psd[high_freq]) / total_power if np.any(high_freq) else 0

            # Dominant frequency
            if len(psd) > 0:
                dominant_freq_idx = np.argmax(psd)
                features['dominant_frequency'] = freqs[dominant_freq_idx]
                features['dominant_power'] = psd[dominant_freq_idx]
            else:
                features['dominant_frequency'] = 0
                features['dominant_power'] = 0

        except Exception as e:
            print(f"Frequency analysis failed: {e}")
            features.update({
                'vlf_power_ratio': 0, 'lf_power_ratio': 0,
                'normal_power_ratio': 0, 'hf_power_ratio': 0,
                'dominant_frequency': 0, 'dominant_power': 0
            })

        # Breath-level features. Seed the defaults first so every window
        # carries the same keys even when the helper returns only part of them
        # (downstream code builds fixed-width vectors from these dicts).
        breath_feature_defaults = {
            'amplitude_mean': 0, 'amplitude_std': 0, 'amplitude_cv': 0,
            'amplitude_p10': 0, 'amplitude_p50': 0, 'amplitude_p90': 0,
            'amplitude_reduction_ratio': 0, 'breath_duration_mean': 0,
            'breath_duration_std': 0, 'breath_duration_cv': 0,
            'long_breath_ratio': 0, 'respiratory_rate_mean': 0,
            'respiratory_rate_std': 0, 'respiratory_rate_cv': 0
        }
        features.update(breath_feature_defaults)

        try:
            from calculateContinuousBreathFeatures import calculate_TS_breathFeatures
            # Forward the caller's original objects: the helper's expected
            # input types are not visible from here.
            breath_features = calculate_TS_breathFeatures(timestamps_segment, signal_segment)

            if breath_features:
                # Extract key features if they exist and are not empty
                if 'amplitude' in breath_features and len(breath_features['amplitude']) > 0:
                    amplitudes = np.array(breath_features['amplitude'])
                    features['amplitude_mean'] = np.mean(amplitudes)
                    features['amplitude_std'] = np.std(amplitudes)
                    features['amplitude_cv'] = features['amplitude_std'] / (features['amplitude_mean'] + 1e-10)

                    # Amplitude reduction analysis
                    features['amplitude_p10'] = np.percentile(amplitudes, 10)
                    features['amplitude_p50'] = np.percentile(amplitudes, 50)
                    features['amplitude_p90'] = np.percentile(amplitudes, 90)
                    features['amplitude_reduction_ratio'] = 1 - (features['amplitude_p10'] / (features['amplitude_p90'] + 1e-10))

                if 'breath_durations' in breath_features and len(breath_features['breath_durations']) > 0:
                    durations = np.array(breath_features['breath_durations'])
                    features['breath_duration_mean'] = np.mean(durations)
                    features['breath_duration_std'] = np.std(durations)
                    features['breath_duration_cv'] = features['breath_duration_std'] / (features['breath_duration_mean'] + 1e-10)
                    features['long_breath_ratio'] = np.sum(durations > 20) / len(durations)  # >20 sec breaths

                if 'rr' in breath_features and len(breath_features['rr']) > 0:
                    rr = np.array(breath_features['rr'])
                    rr = rr[~np.isnan(rr)]  # Remove NaN values
                    if len(rr) > 0:
                        features['respiratory_rate_mean'] = np.mean(rr)
                        features['respiratory_rate_std'] = np.std(rr)
                        features['respiratory_rate_cv'] = features['respiratory_rate_std'] / (features['respiratory_rate_mean'] + 1e-10)

        except Exception as e:
            print(f"Breath feature extraction failed: {e}")
            # Discard any partially written breath features
            features.update(breath_feature_defaults)

        # Additional interpretable features for OSA

        # Activity level: mean absolute sample-to-sample change
        features['activity_level'] = np.mean(np.abs(np.diff(sig)))

        # Signal variability in different time scales
        if len(sig) >= 20:
            # Short-term variability (every 5 samples ≈ 0.4 seconds at 12.5 Hz)
            short_segments = sig[::5]
            features['short_term_variability'] = np.std(short_segments)

            # Long-term trend: absolute slope of a straight-line fit
            if len(sig) >= 60:
                long_segments = sig[::12]  # Every ~1 second at 12.5 Hz
                features['long_term_trend'] = np.abs(np.polyfit(range(len(long_segments)), long_segments, 1)[0])
            else:
                features['long_term_trend'] = 0
        else:
            features['short_term_variability'] = 0
            features['long_term_trend'] = 0

        # Breathing irregularity proxy: variability of the Hilbert envelope
        signal_envelope = np.abs(scipy_signal.hilbert(centred))
        features['envelope_std'] = np.std(signal_envelope)
        features['envelope_cv'] = features['envelope_std'] / (np.mean(signal_envelope) + 1e-10)

    except Exception as e:
        print(f"Feature extraction failed completely: {e}")
        # Return basic features at minimum. np.max / np.min raise on an empty
        # segment, so that case gets explicit zeros instead of a second error.
        if np.size(signal_segment) > 0:
            features = {
                'signal_mean': np.mean(signal_segment),
                'signal_std': np.std(signal_segment),
                'signal_range': np.max(signal_segment) - np.min(signal_segment)
            }
        else:
            features = {'signal_mean': 0.0, 'signal_std': 0.0, 'signal_range': 0.0}

    return features

def _predict_with_pickle_model(model, features):
    """Predict with an sklearn-style model; returns None if it exposes no predict API."""
    feature_array = np.array([list(features.values())]).reshape(1, -1)

    if hasattr(model, 'predict_proba'):
        pred_proba = model.predict_proba(feature_array)[0]
        return {
            'class': np.argmax(pred_proba),
            'probabilities': pred_proba,
            'confidence': np.max(pred_proba)
        }

    if hasattr(model, 'predict'):
        return {
            'class': model.predict(feature_array)[0],
            'confidence': 0.5  # Unknown confidence for simple predict
        }

    return None


def _predict_with_pytorch_model(pytorch_model, signal_segment):
    """Run one raw signal window through a full PyTorch model and build a prediction dict."""
    pytorch_model.eval()

    # Prepare input; the layout is guessed from the model's class name
    model_input = torch.FloatTensor(signal_segment)
    model_kind = str(type(pytorch_model)).lower()
    if 'cnn' in model_kind:
        model_input = model_input.unsqueeze(0).unsqueeze(0)  # [batch, channel, length]
    elif 'lstm' in model_kind:
        model_input = model_input.unsqueeze(0).unsqueeze(-1)  # [batch, length, features]
    else:
        model_input = model_input.unsqueeze(0)  # [batch, length]

    with torch.no_grad():
        output = pytorch_model(model_input)
        if output.shape[-1] > 1:
            # Multi-class head: softmax over the last dimension
            probs = torch.softmax(output, dim=-1).numpy()[0]
            return {
                'class': np.argmax(probs),
                'probabilities': probs,
                'confidence': np.max(probs)
            }

        # Single-logit head: sigmoid, confidence = distance from 0.5 scaled to [0, 1]
        prob = torch.sigmoid(output).numpy()[0][0]
        return {
            'class': int(prob > 0.5),
            'probability': prob,
            'confidence': abs(prob - 0.5) * 2
        }


def analyze_with_model(model_path, respiratory_signal, timestamps, is_pickle=False, window_size=30):
    """
    Analyze a respiratory signal window-by-window with a loaded model.

    The signal is cut into ``window_size``-second windows (12.5 Hz assumed)
    with 50% overlap. Features are extracted for each window with
    ``robust_breath_features`` and, when possible, a model prediction is
    attached. The model is always loaded through ``load_pytorch_model``; if
    loading fails or only a state_dict is available, the call falls back to
    ``analyze_features_only``.

    Args:
        model_path: Path handed to ``load_pytorch_model``.
        respiratory_signal: Indexable 1-D sequence of samples.
        timestamps: Sequence aligned with ``respiratory_signal``; must support
            positional ``[0]`` / ``[-1]`` on a slice.
        is_pickle: If True, the loaded object is treated as an sklearn-style
            model and fed the feature vector; otherwise a full PyTorch model
            is fed the raw window.
        window_size: Window length in seconds.

    Returns:
        dict with 'predictions', 'features', 'segment_info' and
        'correlation_analysis' (None when no prediction has confidence > 0),
        or the result of ``analyze_features_only`` on the fallback paths.

    Raises:
        ValueError: if ``window_size`` yields a window of fewer than 1 sample.
    """
    print("Starting robust model interpretation...")

    model, model_type = load_pytorch_model(model_path)

    if model is None:
        print("Model loading failed. Analyzing features only...")
        return analyze_features_only(respiratory_signal, timestamps, window_size)

    sampling_rate = 12.5
    window_samples = int(window_size * sampling_rate)

    all_features = []
    all_predictions = []
    segment_info = []

    successful_segments = 0

    print(f"Processing {len(respiratory_signal)} samples into {window_size}s windows...")

    # A bare state_dict cannot be executed without its architecture
    pytorch_model = None
    if model_type == 'state_dict':
        print("Analyzing features only...")
        return analyze_features_only(respiratory_signal, timestamps, window_size)
    elif model_type == 'full_model':
        pytorch_model = model

    if window_samples < 1:
        raise ValueError("window_size is too small: it yields a window of fewer than 1 sample")

    # 50% overlap; clamp so a 1-sample window does not produce a zero step
    step = max(1, window_samples // 2)

    # Process in windows. The +1 keeps the final full window, whose start
    # index equals len(signal) - window_samples.
    for i in range(0, len(respiratory_signal) - window_samples + 1, step):
        signal_segment = respiratory_signal[i:i + window_samples]
        time_segment = timestamps[i:i + window_samples]

        try:
            # Extract robust features
            features = robust_breath_features(signal_segment, time_segment)

            # Try to get model prediction
            prediction = None
            if is_pickle:
                # For pickle models (sklearn, etc.)
                try:
                    prediction = _predict_with_pickle_model(model, features)
                except Exception as e:
                    print(f"Pickle model prediction failed for segment {successful_segments}: {e}")
            elif pytorch_model is not None:
                # For PyTorch models
                try:
                    prediction = _predict_with_pytorch_model(pytorch_model, signal_segment)
                except Exception as e:
                    print(f"PyTorch model prediction failed for segment {successful_segments}: {e}")

            # No usable model output: record a zero-confidence placeholder so
            # later code can rely on every prediction being a dict
            if prediction is None:
                prediction = {'class': 0, 'confidence': 0.0}

            info = {
                'start_time': time_segment[0],
                'end_time': time_segment[-1],
                'segment_index': successful_segments
            }

        except Exception as e:
            print(f"Error processing segment {i}: {e}")
            continue

        # Append together so the three lists always stay index-aligned
        all_features.append(features)
        all_predictions.append(prediction)
        segment_info.append(info)

        successful_segments += 1

        if successful_segments % 50 == 0:
            print(f"Successfully processed {successful_segments} segments...")

    print(f"Successfully processed {successful_segments} segments")

    # Analyze results
    if all_predictions and any(pred.get('confidence', 0) > 0 for pred in all_predictions):
        correlation_analysis = analyze_feature_correlations(all_predictions, all_features)

        # Print prediction distribution
        classes = [pred['class'] for pred in all_predictions]
        unique, counts = np.unique(classes, return_counts=True)
        print(f"\nModel predictions distribution:")
        class_names = {0: 'Normal', 1: 'Hypopnea', 2: 'Apnea'}
        for cls, count in zip(unique, counts):
            print(f"  {class_names.get(cls, f'Class {cls}')}: {count} segments ({count/len(classes)*100:.1f}%)")
    else:
        correlation_analysis = None
        print("No valid predictions obtained - analyzing features only")

    # Feature distribution analysis
    analyze_feature_distributions(all_features)

    return {
        'predictions': all_predictions,
        'features': all_features,
        'segment_info': segment_info,
        'correlation_analysis': correlation_analysis
    }

def analyze_features_only(respiratory_signal, timestamps, window_size=30):
    """
    Extract per-window features without any model predictions.

    The signal is cut into ``window_size``-second windows (12.5 Hz assumed)
    with 50% overlap; each window is passed to ``robust_breath_features`` and
    the collected features are summarised by ``analyze_feature_distributions``.

    Args:
        respiratory_signal: Indexable 1-D sequence of samples.
        timestamps: Sequence aligned with ``respiratory_signal``; must support
            positional ``[0]`` / ``[-1]`` on a slice.
        window_size: Window length in seconds.

    Returns:
        dict with 'features' (list of feature dicts) and 'segment_info'
        (list of start/end time and running index), index-aligned.

    Raises:
        ValueError: if ``window_size`` yields a window of fewer than 1 sample.
    """
    print("Analyzing features only...")

    sampling_rate = 12.5
    window_samples = int(window_size * sampling_rate)
    if window_samples < 1:
        raise ValueError("window_size is too small: it yields a window of fewer than 1 sample")

    # 50% overlap; clamp so a 1-sample window does not produce a zero step
    step = max(1, window_samples // 2)

    all_features = []
    segment_info = []
    successful_segments = 0

    # The +1 keeps the final full window, whose start index equals
    # len(signal) - window_samples.
    for i in range(0, len(respiratory_signal) - window_samples + 1, step):
        signal_segment = respiratory_signal[i:i + window_samples]
        time_segment = timestamps[i:i + window_samples]

        try:
            features = robust_breath_features(signal_segment, time_segment)
            info = {
                'start_time': time_segment[0],
                'end_time': time_segment[-1],
                'segment_index': successful_segments
            }
        except Exception as e:
            print(f"Error processing segment {i}: {e}")
            continue

        # Append together so both lists always stay index-aligned
        all_features.append(features)
        segment_info.append(info)
        successful_segments += 1

    print(f"Successfully processed {successful_segments} segments")
    analyze_feature_distributions(all_features)

    return {
        'features': all_features,
        'segment_info': segment_info
    }

def analyze_feature_correlations(predictions, features):
    """
    Rank features by their Spearman correlation with the model predictions.

    Args:
        predictions: List of prediction dicts; 'class' and 'confidence' are
            read, each defaulting to 0 when absent.
        features: List of feature dicts, index-aligned with ``predictions``.
            Keys missing from a given dict are treated as 0.

    Returns:
        None if either input is empty. Otherwise a list of
        (feature_name, stats) tuples sorted by descending absolute class
        correlation, where stats holds 'class_correlation', 'class_p_value',
        'confidence_correlation' and 'confidence_p_value'. NaN results are
        replaced by 0 (correlations) or 1 (p-values). Features with no
        variation are omitted. The top 15 are also printed.
    """
    if not features or not predictions:
        return None

    # Convert to arrays
    classes = [pred.get('class', 0) for pred in predictions]
    confidences = [pred.get('confidence', 0) for pred in predictions]

    # Union of keys in first-seen order, so a feature that is absent from the
    # first segment is still analysed
    feature_names = list(dict.fromkeys(name for f in features for name in f))
    feature_matrix = np.array([[f.get(name, 0) for name in feature_names] for f in features])

    correlations = {}

    for i, feature_name in enumerate(feature_names):
        feature_values = feature_matrix[:, i]

        # Skip if no variation
        if np.std(feature_values) < 1e-10:
            continue

        try:
            # Correlation with class
            corr_class, p_class = spearmanr(feature_values, classes)
            # Correlation with confidence
            corr_conf, p_conf = spearmanr(feature_values, confidences)
        except Exception as e:
            # Best effort: one bad feature must not abort the whole ranking,
            # but say so instead of dropping it silently
            print(f"Skipping feature '{feature_name}': correlation failed ({e})")
            continue

        correlations[feature_name] = {
            'class_correlation': corr_class if not np.isnan(corr_class) else 0,
            'class_p_value': p_class if not np.isnan(p_class) else 1,
            'confidence_correlation': corr_conf if not np.isnan(corr_conf) else 0,
            'confidence_p_value': p_conf if not np.isnan(p_conf) else 1
        }

    # Sort by absolute correlation with class
    sorted_correlations = sorted(correlations.items(),
                                key=lambda x: abs(x[1]['class_correlation']),
                                reverse=True)

    print("\n" + "="*80)
    print("MODEL INTERPRETATION RESULTS")
    print("="*80)
    print("Top features correlated with model predictions:")
    print("-" * 80)

    for i, (feature, corr_data) in enumerate(sorted_correlations[:15]):
        print(f"{i+1:2d}. {feature:<30} | Class corr: {corr_data['class_correlation']:6.3f} "
              f"(p={corr_data['class_p_value']:.4f}) | Conf corr: {corr_data['confidence_correlation']:6.3f}")

    return sorted_correlations

def analyze_feature_distributions(features):
    """
    Print mean / std / min / max for every feature across all segments.

    Feature names are the union of keys over all dicts, reported in sorted
    order. A key missing from a dict counts as 0, and NaN values are excluded
    before the statistics are computed. Nothing is printed for empty input.

    Args:
        features: List of feature dicts (name -> numeric value).
    """
    if not features:
        return

    rule = "=" * 60
    print("\n" + rule)
    print("FEATURE DISTRIBUTION ANALYSIS")
    print(rule)

    # Every feature name seen in any segment
    feature_names = sorted(set().union(*(record.keys() for record in features)))

    for name in feature_names:
        column = np.array([record.get(name, 0) for record in features])
        usable = column[~np.isnan(column)]  # drop NaN values
        if len(usable) == 0:
            continue

        print(f"{name:<25}: Mean={np.mean(usable):.3f}, Std={np.std(usable):.3f}, "
              f"Min={np.min(usable):.3f}, Max={np.max(usable):.3f}")

# Usage functions
def analyze_with_pytorch_model(pt_path, respiratory_signal, timestamps):
    """Run ``analyze_with_model`` on ``pt_path`` with ``is_pickle=False`` and the default window size."""
    return analyze_with_model(
        model_path=pt_path,
        respiratory_signal=respiratory_signal,
        timestamps=timestamps,
        is_pickle=False,
    )



# Example usage:
if __name__ == "__main__":
    # Inputs: raw numpy arrays taken from the module-level respeck_df
    respiratory_signal = respeck_df['breathingSignal'].values
    timestamps = respeck_df['timestamp'].values
    
    # With PyTorch model:
    # NOTE(review): absolute, machine-specific checkpoint path — it will not
    # resolve on another machine; analyze_with_model then falls back to
    # feature-only analysis because load_pytorch_model returns (None, None).
    results = analyze_with_pytorch_model('/Users/hkhes/Developer/msc/dissertation/DysfunctionalBreathingCharacterisation/results/cd2046784/final_sleep_apnea_model_20250718_200301.pt', respiratory_signal, timestamps)
    
    # Features only:
    # NOTE(review): this rebinds `results`, so the model-based output from the
    # call above is discarded — confirm that is intended.
    results = analyze_features_only(respiratory_signal, timestamps)
    
    pass

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal as scipy_signal
from scipy.stats import spearmanr
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

def define_clinical_breath_signatures(features_data, model_predictions=None):
    """
    Label every feature segment with a rule-based clinical breathing signature.

    Each row of the feature table is classified against the signatures from
    ``define_osa_breath_signatures``; the labelled table is then passed to the
    pattern analysis, decision-tree and visualisation helpers.

    Args:
        features_data: List of feature dictionaries from your analysis
        model_predictions: Optional model predictions; accepted but not used
            by this function.

    Returns:
        dict with 'clinical_signatures', 'classifications',
        'signature_analysis', 'decision_rules' and 'clinical_df' (the feature
        DataFrame with an added 'clinical_classification' column).
    """
    banner = "=" * 80
    print("🏥 CLINICAL BREATH FEATURE SIGNATURES FOR SLEEP APNEA")
    print(banner)
    print("Defining OSA events in terms clinicians understand")
    print(banner)

    # One row per segment, one column per feature
    clinical_df = pd.DataFrame(features_data)
    clinical_signatures = define_osa_breath_signatures()

    # Classify each segment; every result is a (classification, scores) pair
    per_segment = [
        classify_segment_clinically(row, clinical_signatures)
        for _, row in clinical_df.iterrows()
    ]
    segment_classifications = [label for label, _ in per_segment]
    clinical_scores = [scores for _, scores in per_segment]

    clinical_df['clinical_classification'] = segment_classifications

    # Downstream analyses run on the labelled table, in this order
    signature_analysis = analyze_signature_patterns(clinical_df, clinical_scores, clinical_signatures)
    decision_rules = create_clinical_decision_tree(clinical_df, clinical_signatures)
    create_clinical_signature_visualizations(clinical_df, clinical_signatures)

    return {
        'clinical_signatures': clinical_signatures,
        'classifications': segment_classifications,
        'signature_analysis': signature_analysis,
        'decision_rules': decision_rules,
        'clinical_df': clinical_df
    }

def define_osa_breath_signatures():
    """
    Build the rule table describing each breathing-event signature.

    Every entry carries a display name, a clinical definition, primary and
    secondary feature ranges ({'min', 'max'}, both inclusive where they are
    applied), a severity weight and a duration threshold in seconds. The
    ranges are hard-coded heuristics (described by the original author as
    based on clinical sleep medicine literature).

    Returns:
        dict keyed by event type: 'obstructive_apnea', 'hypopnea',
        'central_apnea', 'flow_limitation', 'normal_breathing'.
    """
    unbounded = np.inf

    def band(lower, upper):
        # One feature criterion: allowed value range
        return {'min': lower, 'max': upper}

    def signature(name, clinical_definition, primary, secondary, weight, duration_threshold):
        # One event type, with keys in the order consumers may iterate them
        return {
            'name': name,
            'clinical_definition': clinical_definition,
            'primary_features': primary,
            'secondary_features': secondary,
            'weight': weight,
            'duration_threshold': duration_threshold,
        }

    return {
        'obstructive_apnea': signature(
            'Obstructive Apnea',
            'Complete cessation of airflow with continued respiratory effort',
            primary={
                'amplitude_reduction_ratio': band(0.90, 1.0),      # >90% amplitude reduction
                'effort_flow_mismatch': band(2.0, unbounded),      # High effort, low flow
                'respiratory_rate_cv': band(0.1, unbounded),       # Irregular due to obstruction
                'envelope_cv': band(0.5, unbounded),               # Variable effort
            },
            secondary={
                'normal_power_ratio': band(0.0, 0.6),              # Reduced normal breathing
                'hf_power_ratio': band(0.15, unbounded),           # Increased effort artifacts
                'breath_duration_mean': band(15.0, unbounded),     # Long event duration
                'activity_level': band(0.01, unbounded),           # Increased activity/struggle
            },
            weight=4,                 # Highest severity
            duration_threshold=10.0,  # seconds
        ),

        'hypopnea': signature(
            'Hypopnea',
            'Partial reduction in airflow (30-90%) with arousal or desaturation',
            primary={
                'amplitude_reduction_ratio': band(0.30, 0.90),     # 30-90% reduction
                'respiratory_rate_cv': band(0.12, unbounded),      # Moderate irregularity
                'envelope_cv': band(0.35, 0.7),                    # Some effort variability
                'normal_power_ratio': band(0.4, 0.8),              # Reduced but present
            },
            secondary={
                'breath_duration_cv': band(0.15, unbounded),       # Variable durations
                'hf_power_ratio': band(0.10, 0.25),                # Some effort increase
                'dominant_frequency': band(0.15, 0.4),             # Normal-ish frequency
                'activity_level': band(0.005, 0.02),               # Mild activity increase
            },
            weight=2,                 # Moderate severity
            duration_threshold=10.0,
        ),

        'central_apnea': signature(
            'Central Apnea',
            'Cessation of both airflow and respiratory effort',
            primary={
                'amplitude_reduction_ratio': band(0.85, 1.0),      # High amplitude reduction
                'activity_level': band(0.0, 0.008),                # Very low activity
                'envelope_cv': band(0.0, 0.3),                     # Low effort variability
                'respiratory_rate_cv': band(0.0, 0.15),            # Can be regular
            },
            secondary={
                'normal_power_ratio': band(0.0, 0.5),              # Very low normal breathing
                'hf_power_ratio': band(0.0, 0.12),                 # Low effort artifacts
                'rms': band(0.0, 0.05),                            # Low signal energy
                'signal_std': band(0.0, 0.05),                     # Low variability
            },
            weight=3,                 # High severity
            duration_threshold=10.0,
        ),

        'flow_limitation': signature(
            'Flow Limitation',
            'Flattened inspiratory flow with increased effort',
            primary={
                'amplitude_reduction_ratio': band(0.15, 0.45),     # Mild-moderate reduction
                'effort_flow_mismatch': band(1.2, 3.0),            # Moderate mismatch
                'envelope_cv': band(0.4, 0.8),                     # Effort variability
                'respiratory_rate_cv': band(0.08, 0.25),           # Some irregularity
            },
            secondary={
                'normal_power_ratio': band(0.6, 0.85),             # Mostly preserved
                'hf_power_ratio': band(0.08, 0.20),                # Some effort increase
                'dominant_frequency': band(0.12, 0.35),            # Slightly altered
                'breath_duration_cv': band(0.10, 0.30),            # Mild variability
            },
            weight=1,                 # Mild severity
            duration_threshold=5.0,
        ),

        'normal_breathing': signature(
            'Normal Breathing',
            'Regular, unobstructed breathing pattern',
            primary={
                'amplitude_reduction_ratio': band(0.0, 0.25),      # Minimal reduction
                'respiratory_rate_cv': band(0.0, 0.12),            # Regular pattern
                'envelope_cv': band(0.1, 0.4),                     # Consistent effort
                'normal_power_ratio': band(0.7, 1.0),              # High normal breathing
            },
            secondary={
                'breath_duration_cv': band(0.0, 0.15),             # Regular durations
                'hf_power_ratio': band(0.0, 0.15),                 # Low effort artifacts
                'activity_level': band(0.002, 0.012),              # Normal activity
                'dominant_frequency': band(0.15, 0.35),            # Normal frequency
            },
            weight=0,                 # No pathology
            duration_threshold=0.0,
        ),
    }

def classify_segment_clinically(segment_data, clinical_signatures, min_threshold=0.4):
    """
    Classify a single segment using clinical breath feature criteria.

    Every signature is scored by the fraction of its primary and secondary
    feature ranges that contain the segment's value; the two fractions are
    combined as ``0.7 * primary + 0.3 * secondary``. The signature with the
    highest composite score wins (on ties, the first one in the iteration
    order of ``clinical_signatures``).

    Args:
        segment_data: Mapping-like object (dict or pandas Series) from feature
            name to value. Features that are absent or NaN never match.
        clinical_signatures: Dict mapping event type to a signature dict with
            'primary_features', 'secondary_features' (each mapping feature
            name to ``{'min': ..., 'max': ...}``) and 'weight'.
        min_threshold: Minimum composite score the winning signature needs.
            A winner below this value is reported as 'indeterminate', unless
            the winner is 'normal_breathing', which is always accepted.

    Returns:
        Tuple ``(classification, scores)`` where ``scores`` maps each event
        type to its primary/secondary counts, ratios, composite score and
        weight. With no signatures at all, returns ``('indeterminate', {})``.
    """
    # Without any signature there is nothing to compare against; max() on an
    # empty mapping would raise ValueError, so report the segment as unknown.
    if not clinical_signatures:
        return 'indeterminate', {}

    scores = {}

    for event_type, signature in clinical_signatures.items():
        primary_features = signature['primary_features']
        secondary_features = signature['secondary_features']
        primary_total = len(primary_features)
        secondary_total = len(secondary_features)

        primary_score = _count_matching_criteria(segment_data, primary_features)
        secondary_score = _count_matching_criteria(segment_data, secondary_features)

        # A signature with no criteria in a group contributes 0 for that group.
        primary_ratio = primary_score / primary_total if primary_total > 0 else 0
        secondary_ratio = secondary_score / secondary_total if secondary_total > 0 else 0

        # Weight primary features more heavily
        composite_score = (0.7 * primary_ratio) + (0.3 * secondary_ratio)

        scores[event_type] = {
            'primary_score': primary_score,
            'secondary_score': secondary_score,
            'primary_ratio': primary_ratio,
            'secondary_ratio': secondary_ratio,
            'composite_score': composite_score,
            'weight': signature['weight']
        }

    # Find best matching signature
    classification, best_scores = max(
        scores.items(), key=lambda item: item[1]['composite_score']
    )

    # Require a minimum match quality for pathological classifications only.
    if best_scores['composite_score'] < min_threshold and classification != 'normal_breathing':
        classification = 'indeterminate'

    return classification, scores


def _count_matching_criteria(segment_data, feature_criteria):
    """Count criteria whose inclusive [min, max] range contains the segment's value."""
    matched = 0
    for feature, criteria in feature_criteria.items():
        if feature not in segment_data:
            continue
        value = segment_data[feature]
        if pd.notna(value) and criteria['min'] <= value <= criteria['max']:
            matched += 1
    return matched

def analyze_signature_patterns(df, clinical_scores, clinical_signatures):
    """
    Print a report on how segments were classified and return the raw tallies.

    Args:
        df: DataFrame with a 'clinical_classification' column and, optionally,
            the feature columns named in the signatures.
        clinical_scores: Sequence of per-segment score dicts, positionally
            aligned with the rows of ``df``; each maps an event type to a dict
            holding at least 'composite_score'.
        clinical_signatures: Dict mapping event type to its signature
            ('name', 'clinical_definition', 'primary_features').

    Returns:
        Dict with 'class_counts' (value counts of the classification column)
        and 'signature_consistency' (event type -> list of the composite
        scores of the segments assigned to it).
    """
    rule_line = "-" * 60
    assigned_labels = df['clinical_classification']

    print("\n📊 CLINICAL SIGNATURE ANALYSIS:")
    print(rule_line)

    # --- How many segments landed in each class ---
    class_counts = assigned_labels.value_counts()
    total_segments = len(df)

    print("Clinical Classification Distribution:")
    for label, n_segments in class_counts.items():
        share = (n_segments / total_segments) * 100
        fallback_name = label.title()
        display_name = clinical_signatures.get(label, {}).get('name', fallback_name)
        print(f"  {display_name:<20}: {n_segments:4d} segments ({share:5.1f}%)")

    # --- Per-class statistics of the primary features ---
    print("\n🎯 KEY DISCRIMINATING FEATURES BY EVENT TYPE:")
    print(rule_line)

    for event_type, signature in clinical_signatures.items():
        # Classes with five or fewer segments are not summarised.
        if not (event_type in class_counts and class_counts[event_type] > 5):
            continue

        event_segments = df[assigned_labels == event_type]

        print(f"\n{signature['name']} (n={len(event_segments)}):")
        print(f"  Clinical Definition: {signature['clinical_definition']}")
        print("  Primary Features:")

        for feature, criteria in signature['primary_features'].items():
            if feature not in event_segments.columns:
                continue
            observed = event_segments[feature].dropna()
            if len(observed) == 0:
                continue
            centre = observed.mean()
            spread = observed.std()
            print(f"    {feature:<25}: μ={centre:.3f}±{spread:.3f} (criteria: {criteria['min']:.2f}-{criteria['max']:.2f})")

    # --- Composite score of each segment under its own assigned class ---
    signature_consistency = {}
    for position, segment_scores in enumerate(clinical_scores):
        label = assigned_labels.iloc[position]
        if label not in segment_scores:
            continue
        own_score = segment_scores[label]['composite_score']
        signature_consistency.setdefault(label, []).append(own_score)

    print("\n📈 SIGNATURE CONSISTENCY (how well segments match their assigned signature):")
    print(rule_line)
    for event_type, own_scores in signature_consistency.items():
        if event_type not in clinical_signatures:
            continue
        average = np.mean(own_scores)
        display_name = clinical_signatures[event_type]['name']
        print(f"  {display_name:<20}: {average:.3f} (higher = better match)")

    return {
        'class_counts': class_counts,
        'signature_consistency': signature_consistency
    }

def create_clinical_decision_tree(df, clinical_signatures):
    """
    Print a fixed set of if-then rules for each event type and return them.

    The "validation" step does not evaluate the rule text against the data:
    for every rule it only reports how many rows of ``df`` carry the matching
    'clinical_classification' label (the condition name lower-cased with
    spaces replaced by underscores).

    Args:
        df: DataFrame with a 'clinical_classification' column.
        clinical_signatures: Accepted for interface symmetry; not used here.

    Returns:
        List of rule dicts with keys 'condition', 'rule', 'clinical_meaning'.
    """
    wide_line = "-" * 60

    print("\n🌳 CLINICAL DECISION TREE RULES:")
    print(wide_line)
    print("If-then rules for clinicians to identify OSA events:")
    print(wide_line)

    # (condition, rule text, clinical meaning), in reporting order.
    rule_table = (
        (
            "Obstructive Apnea",
            "IF amplitude_reduction_ratio > 0.90 AND effort_flow_mismatch > 2.0 AND respiratory_rate_cv > 0.1",
            "Complete airflow cessation with continued respiratory effort",
        ),
        (
            "Central Apnea",
            "IF amplitude_reduction_ratio > 0.85 AND activity_level < 0.008 AND envelope_cv < 0.3",
            "Both airflow and effort cessation",
        ),
        (
            "Hypopnea",
            "IF 0.30 < amplitude_reduction_ratio < 0.90 AND respiratory_rate_cv > 0.12",
            "Partial airflow reduction with pattern disruption",
        ),
        (
            "Flow Limitation",
            "IF 0.15 < amplitude_reduction_ratio < 0.45 AND effort_flow_mismatch > 1.2",
            "Increased effort with mild flow reduction",
        ),
        (
            "Normal Breathing",
            "IF amplitude_reduction_ratio < 0.25 AND respiratory_rate_cv < 0.12 AND normal_power_ratio > 0.7",
            "Regular, unobstructed breathing",
        ),
    )
    rules = [
        {'condition': name, 'rule': text, 'clinical_meaning': meaning}
        for name, text, meaning in rule_table
    ]

    for entry in rules:
        print(f"\n{entry['condition']}:")
        print(f"  Rule: {entry['rule']}")
        print(f"  Meaning: {entry['clinical_meaning']}")

    # Report how many segments were assigned to each rule's event type.
    print("\n✅ RULE VALIDATION:")
    print("-" * 40)

    assigned_labels = df['clinical_classification']
    for entry in rules:
        label_key = entry['condition'].lower().replace(' ', '_')
        matching_segments = int((assigned_labels == label_key).sum())
        if matching_segments:
            print(f"{entry['condition']:<15}: {matching_segments:4d} segments match this rule")

    return rules

def create_clinical_signature_visualizations(df, clinical_signatures):
    """
    Plot per-class histograms of the key breath features on a 2x3 grid.

    One panel is drawn per available key feature (at most six). Within a
    panel, a density histogram is drawn for every classification that has a
    signature entry and more than five non-missing values of that feature.

    Args:
        df: DataFrame with a 'clinical_classification' column plus feature
            columns.
        clinical_signatures: Dict mapping event type to a signature dict with
            a 'name' used in the legend.

    Returns:
        The matplotlib figure, or None when fewer than three of the key
        features are present in ``df``.
    """
    candidate_features = (
        'amplitude_reduction_ratio', 'effort_flow_mismatch', 'respiratory_rate_cv',
        'envelope_cv', 'normal_power_ratio', 'activity_level',
    )
    available_features = [name for name in candidate_features if name in df.columns]

    if len(available_features) < 3:
        print("Not enough features available for visualization")
        return None

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle('Clinical Breath Feature Signatures for OSA Events', fontsize=16)
    axes = axes.flatten()

    palette = ['blue', 'red', 'green', 'orange', 'purple', 'brown']
    assigned_labels = df['clinical_classification']
    # Colour is tied to a class's position in this array, so it is the same
    # in every panel.
    classifications = assigned_labels.unique()

    for ax, feature in zip(axes, available_features[:6]):
        pretty_name = feature.replace('_', ' ').title()

        for position, classification in enumerate(classifications):
            if classification not in clinical_signatures:
                continue
            class_data = df[assigned_labels == classification][feature].dropna()
            if len(class_data) <= 5:  # Only plot if enough data
                continue
            signature_name = clinical_signatures[classification]['name']
            ax.hist(
                class_data,
                bins=20,
                alpha=0.6,
                label=f'{signature_name} (n={len(class_data)})',
                color=palette[position % len(palette)],
                density=True,
            )

        ax.set_xlabel(pretty_name)
        ax.set_ylabel('Density')
        ax.set_title(f'{pretty_name} by Event Type')
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)

    # Hide the panels that have no feature to show.
    for spare_ax in axes[len(available_features):]:
        spare_ax.set_visible(False)

    plt.tight_layout()
    plt.show()

    return fig

def calculate_clinical_ahi(classifications, window_duration_sec=30):
    """
    Estimate an Apnea-Hypopnea Index from per-window classification labels.

    Each window whose label contains 'apnea' or 'hypopnea' (case-insensitive
    substring match) is counted as one event; the analysed time is the number
    of windows times ``window_duration_sec``.

    Args:
        classifications: Sized iterable of string labels, one per window.
        window_duration_sec: Duration attributed to every window, in seconds.

    Returns:
        Dict with 'ahi', 'apnea_events', 'hypopnea_events', 'total_events',
        'analysis_hours' and 'severity'. 'ahi' is 0 when the analysed time is
        not positive.
    """
    lowered_labels = [label.lower() for label in classifications]
    apnea_events = sum('apnea' in label for label in lowered_labels)
    hypopnea_events = sum('hypopnea' in label for label in lowered_labels)
    total_events = apnea_events + hypopnea_events

    total_time_hours = (len(classifications) * window_duration_sec) / 3600
    if total_time_hours > 0:
        ahi = total_events / total_time_hours
    else:
        ahi = 0

    print("\n🏥 CLINICAL AHI CALCULATION:")
    print("-" * 40)
    print(f"Apnea Events: {apnea_events}")
    print(f"Hypopnea Events: {hypopnea_events}")
    print(f"Total Events: {total_events}")
    print(f"Analysis Duration: {total_time_hours:.2f} hours")
    print(f"Calculated AHI: {ahi:.1f} events/hour")

    # Severity bands: the first exclusive upper bound the AHI falls under wins.
    severity = "Severe OSA"
    severity_bands = (
        (5, "Normal (No OSA)"),
        (15, "Mild OSA"),
        (30, "Moderate OSA"),
    )
    for upper_bound, band_name in severity_bands:
        if ahi < upper_bound:
            severity = band_name
            break

    print(f"Clinical Severity: {severity}")

    return {
        'ahi': ahi,
        'apnea_events': apnea_events,
        'hypopnea_events': hypopnea_events,
        'total_events': total_events,
        'analysis_hours': total_time_hours,
        'severity': severity
    }

# Main analysis function
def run_clinical_breath_signature_analysis(features_data):
    """
    Run the full clinical breath-signature pipeline and print a summary.

    Delegates classification to ``define_clinical_breath_signatures``, feeds
    its 'classifications' into ``calculate_clinical_ahi``, then prints one
    summary paragraph per classified event type that has a signature.

    Args:
        features_data: Passed unchanged to
            ``define_clinical_breath_signatures``.

    Returns:
        The dict returned by ``define_clinical_breath_signatures`` with the
        AHI summary added under 'ahi_results'.
    """
    banner = "=" * 80

    print("🚀 RUNNING CLINICAL BREATH SIGNATURE ANALYSIS")
    print(banner)
    print("Defining OSA events in terms of clinical breath features")
    print(banner)

    # Classify every segment, then derive the AHI from those labels.
    results = define_clinical_breath_signatures(features_data)
    ahi_results = calculate_clinical_ahi(results['classifications'])

    print("\n💡 CLINICAL INTERPRETATION SUMMARY:")
    print(banner)
    print("Your sophisticated attention-CNN model likely learned these patterns:")

    segment_total = len(results['classifications'])
    class_counts = results['signature_analysis']['class_counts']
    for classification, count in class_counts.items():
        signatures = results['clinical_signatures']
        # Labels without a signature (e.g. a fallback class) get no paragraph.
        if classification not in signatures:
            continue
        signature = signatures[classification]
        percentage = (count / segment_total) * 100
        print(f"\n{signature['name']} ({percentage:.1f}% of segments):")
        print(f"  Clinical Definition: {signature['clinical_definition']}")
        print("  Key Pattern: Multi-scale CNN likely detects this through attention on specific frequency/temporal features")

    results['ahi_results'] = ahi_results
    return results

# Usage:
# Run the clinical breath-signature analysis on the extracted feature table
# and keep the returned result dict in `clinical_results`.
# NOTE(review): `results` is not defined in this section; it must already
# exist at module level with a 'features' entry — confirm which earlier step
# produces it.
clinical_results = run_clinical_breath_signature_analysis(results['features'])