Test for stroke count and stroke rate

In [None]:
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import find_peaks, butter, filtfilt
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path

# Create an empty module-level DataFrame only if `df` is not already defined,
# so an existing `df` (e.g. from an earlier run of the notebook) is left untouched.
if 'df' not in globals():
    df = pd.DataFrame()

# Text input for the .db path (free-text entry for now).
file_input = widgets.Text(
    value="",
    placeholder="C:/path/to/your/database.db",
    description='DB file path:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='80%')
)
# Button the user clicks to load the database, and the output area that
# load_db prints its messages and previews into.
load_button = widgets.Button(description="Load DB", button_style='primary')
load_output = widgets.Output()

def load_db(_):
    """Button callback: load the 'sensor_data' table into the global ``df``.

    Reads the path typed into ``file_input``, lists the tables found in the
    SQLite file, and, if a ``sensor_data`` table exists, loads it into the
    module-level ``df``. When a ``timestamp`` column is present it is
    interpreted as UNIX milliseconds and converted to a timezone-aware
    ``datetime`` column in Asia/Manila time. All messages and previews are
    written to ``load_output``.

    Parameters
    ----------
    _ : object
        Argument supplied by the button click handler; unused.
    """
    global df

    with load_output:
        clear_output()
        db_path = Path(file_input.value)
        if not db_path.exists():
            print(f"Path not found: {db_path}")
            return
        if not db_path.is_file():
            # An empty text box becomes Path('.'), which exists but is a
            # directory; report it instead of letting sqlite fail opaquely.
            print(f"Not a file: {db_path}")
            return

        conn = None
        try:
            conn = sqlite3.connect(str(db_path))
            print("Available tables:")
            tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type='table';", conn)
            display(tables)

            # table_name is a fixed constant, so interpolating it into the
            # query below does not expose user input to SQL.
            table_name = 'sensor_data'
            if table_name in tables['name'].values:
                df = pd.read_sql(f"SELECT * FROM {table_name};", conn)

                # Convert UNIX ms timestamp to datetime
                if "timestamp" in df.columns:
                    df['datetime'] = (
                        pd.to_datetime(df['timestamp'], unit='ms', utc=True)
                          .dt.tz_convert('Asia/Manila')
                    )

                print(f"\nPreview of '{table_name}':")
                display(df.head(5))

                print("Columns:")
                print(df.columns.tolist())

                if 'datetime' in df.columns:
                    print("\nTime range in dataset (Asia/Manila):")
                    print(df['datetime'].min(), "to", df['datetime'].max())
            else:
                print(f"⚠️ Table '{table_name}' not found in database.")
        except Exception as e:
            # UI boundary: report any failure in the output widget rather
            # than raising out of the click handler.
            print(f"Failed to open DB: {e}")
        finally:
            # Always release the connection so the database file is not
            # left open (and locked) after the callback finishes.
            if conn is not None:
                conn.close()

# Run load_db each time the "Load DB" button is clicked.
load_button.on_click(load_db)

# Render the path input, the button and the output area in the notebook.
display(file_input, load_button, load_output)

Text(value='', description='DB file path:', layout=Layout(width='80%'), placeholder='C:/path/to/your/database.…

Button(button_style='primary', description='Load DB', style=ButtonStyle())

Output()

In [2]:
def butter_bandpass_filter(data, lowcut=0.25, highcut=0.5, fs=50, order=2):
    """Band-pass filter ``data`` with a Butterworth filter.

    The cut-off frequencies ``lowcut`` and ``highcut`` (same unit as ``fs``)
    are normalised by the Nyquist frequency before the filter is designed.
    The filter is then applied with ``filtfilt`` (forward and backward pass).

    Returns the filtered signal, with the same length as ``data``.
    """
    nyquist = fs / 2.0
    band = [cutoff / nyquist for cutoff in (lowcut, highcut)]
    numerator, denominator = butter(order, band, btype='bandpass')
    filtered = filtfilt(numerator, denominator, data)
    return filtered

In [3]:
def estimate_sampling_rate(df):
    """Estimate the sampling rate (Hz) from the 'datetime' column of ``df``.

    The rate is the reciprocal of the median interval, in seconds, between
    consecutive rows.
    """
    timestamps = df['datetime']
    intervals = timestamps.diff().dropna().dt.total_seconds()
    typical_interval = intervals.median()
    return 1.0 / typical_interval

In [4]:
def identify_stroke_cycles(segment, min_sep=0.5, max_sep=4.0, height=0.98):
    """
    Identify stroke cycles using ay + az, bandpass filtering, and constraints.

    Parameters
    ----------
    segment : pandas.DataFrame
        Rows to analyse; must provide 'datetime', 'accel_y' and 'accel_z'.
    min_sep : float
        Minimum separation between detected peaks, in seconds.
    max_sep : float
        Maximum gap to the previous detected peak, in seconds. A peak that
        follows a longer gap is discarded.
    height : float
        Amplitude threshold in m/s^2 applied to the combined filtered signal.

    Returns
    -------
    tuple
        ``(count, peaks, signal)``: the number of accepted peaks, their
        positional indices into ``segment``, and the combined filtered signal.
        When the segment cannot be analysed (missing columns, too few
        samples, unusable sampling rate, or a filter failure), a warning is
        printed and ``(0, [], zeros)`` is returned, where ``zeros`` has one
        entry per row of ``segment``. Every path returns the same 3-tuple
        shape so callers can always unpack it.
    """
    # Local import: numpy is only needed here for the placeholder signal and
    # the finiteness check.
    import numpy as np

    def _no_strokes():
        # Uniform "nothing detected" result with the same shape as a normal return.
        return 0, [], np.zeros(len(segment))

    if not all(col in segment.columns for col in ('accel_y', 'accel_z')):
        print("⚠️ accel_y and accel_z not found.")
        return _no_strokes()

    if len(segment) < 2:
        # At least one interval is required to estimate a sampling rate.
        print("⚠️ Not enough samples to estimate the sampling rate.")
        return _no_strokes()

    # Estimate sampling frequency
    fs = estimate_sampling_rate(segment)
    if not np.isfinite(fs) or fs <= 0:
        print(f"⚠️ Invalid sampling rate estimate: {fs}")
        return _no_strokes()

    # Filter ay and az; the filter raises ValueError when the segment is too
    # short or the sampling rate is too low for the pass band.
    try:
        ay_f = butter_bandpass_filter(segment['accel_y'].values, fs=fs)
        az_f = butter_bandpass_filter(segment['accel_z'].values, fs=fs)
    except ValueError as exc:
        print(f"⚠️ Could not filter segment: {exc}")
        return _no_strokes()

    # Combine signals
    signal = ay_f + az_f

    # Peak detection; find_peaks requires distance >= 1 sample.
    min_distance = max(1, int(fs * min_sep))
    peaks, _ = find_peaks(signal, height=height, distance=min_distance)

    # Temporal filtering: always keep the first peak, then keep each peak
    # whose gap to the preceding detected peak is at most max_sep seconds.
    valid_peaks = [peaks[0]] if len(peaks) > 0 else []
    for previous, current in zip(peaks, peaks[1:]):
        if (current - previous) / fs <= max_sep:
            valid_peaks.append(current)

    # Discard first second after push-off
    t0 = segment['datetime'].iloc[0]
    valid_peaks = [p for p in valid_peaks
                   if (segment['datetime'].iloc[p] - t0).total_seconds() > 1]

    return len(valid_peaks), valid_peaks, signal

In [5]:
# Inputs: start and end of the analysis window, entered as HH:MM:SS text.
start_time_input = widgets.Text(
    value="07:00:00",
    description='Start Time (HH:MM:SS):',
    style={'description_width': 'initial'}
)

end_time_input = widgets.Text(
    value="07:05:00",
    description='End Time (HH:MM:SS):',
    style={'description_width': 'initial'}
)

# Output area that calculate_stroke_count prints and plots into.
output = widgets.Output()

def calculate_stroke_count(change):
    """Button callback: count strokes in the selected time window and plot them.

    Builds start and end timestamps from the HH:MM:SS text inputs, using the
    calendar date of the first row of the global ``df`` and the timezone of
    its 'datetime' column. It then selects the rows inside that window
    (inclusive), runs ``identify_stroke_cycles`` on them, prints the stroke
    count, and plots the filtered signal with detected strokes marked. All
    output goes to the ``output`` widget.

    Parameters
    ----------
    change : object
        Argument supplied by the button click handler; unused.
    """
    with output:
        clear_output()

        # Check for loaded data up front. Otherwise the missing column raises
        # inside the try below and is misreported as an invalid time format.
        if df.empty or 'datetime' not in df.columns:
            print("No data loaded. Load a database with a 'timestamp' column first.")
            return

        try:
            base_date = df['datetime'].iloc[0].date()
            start_time = pd.to_datetime(f"{base_date} {start_time_input.value}")
            end_time = pd.to_datetime(f"{base_date} {end_time_input.value}")
            if df['datetime'].dt.tz is not None:
                # Localise the naive inputs so they compare against the
                # timezone-aware 'datetime' column.
                tz = df['datetime'].dt.tz
                start_time = start_time.tz_localize(tz)
                end_time = end_time.tz_localize(tz)
        except Exception as e:
            print(f"Invalid time format. Use HH:MM:SS. Error: {e}")
            return

        mask = (df['datetime'] >= start_time) & (df['datetime'] <= end_time)
        segment = df[mask]

        if segment.empty:
            print("No data in selected time range.")
            return

        # stroke cycle identification
        stroke_count, peaks, signal = identify_stroke_cycles(segment)

        print(f"Detected Stroke Count: {stroke_count}")

        # results
        plt.figure(figsize=(12,4))
        plt.plot(segment['datetime'], signal, label="Filtered ay+az")
        if len(peaks) > 0:
            plt.plot(segment['datetime'].iloc[peaks], signal[peaks], "rx", label="Strokes")
        plt.xlabel("Time")
        plt.ylabel("Filtered Acceleration (m/s^2)")
        plt.legend()
        plt.show()

# Button that runs calculate_stroke_count when clicked.
calc_button = widgets.Button(
    description="Calculate Stroke Count",
    button_style='success'
)
calc_button.on_click(calculate_stroke_count)

# Render the time inputs, the button and the output area in the notebook.
display(start_time_input, end_time_input, calc_button, output)

Text(value='07:00:00', description='Start Time (HH:MM:SS):', style=TextStyle(description_width='initial'))

Text(value='07:05:00', description='End Time (HH:MM:SS):', style=TextStyle(description_width='initial'))

Button(button_style='success', description='Calculate Stroke Count', style=ButtonStyle())

Output()