Test for stroke count and stroke rate

1. Load data

In [None]:
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import find_peaks, butter, filtfilt
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path

# Keep frames loaded by a previous run of this cell; otherwise start with empty
# placeholders so the later cells can safely test `.empty`.
try:
    df
except NameError:
    df = pd.DataFrame()
try:
    csv_df
except NameError:
    csv_df = pd.DataFrame()

# --- SQLite database loader controls ---
db_file_input = widgets.Text(
    description='DB file path:',
    placeholder="C:/path/to/your/database.db",
    value="",
    layout=widgets.Layout(width='80%'),
    style={'description_width': 'initial'},
)
load_db_button = widgets.Button(button_style='primary', description="Load DB")
db_load_output = widgets.Output()  # messages/previews from load_db go here

# --- CSV loader controls ---
csv_file_input = widgets.Text(
    description='CSV file path:',
    placeholder="C:/path/to/your/stroke_data.csv",
    value="",
    layout=widgets.Layout(width='80%'),
    style={'description_width': 'initial'},
)
load_csv_button = widgets.Button(button_style='info', description="Load CSV")
csv_load_output = widgets.Output()  # messages/previews from load_csv go here

def load_db(_):
    """Load the ``sensor_data`` table from the SQLite file named in ``db_file_input``.

    Click handler for ``load_db_button``; the argument passed by the button is
    ignored. All messages and previews are rendered inside ``db_load_output``.

    On success the table is stored in the module-level ``df`` with an added
    ``datetime`` column: ``unix_ts`` (or ``timestamp`` when ``unix_ts`` is
    absent) is parsed as epoch milliseconds in UTC and converted to
    Asia/Manila. Problems are reported as printed messages rather than raised.
    """
    global df
    table_name = 'sensor_data'

    with db_load_output:
        clear_output()
        db_path = Path(db_file_input.value)
        if not db_path.exists():
            print(f"Path not found: {db_path}")
            return

        conn = None
        try:
            conn = sqlite3.connect(str(db_path))
            print("Available tables:")
            tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type='table';", conn)
            display(tables)

            if table_name not in tables['name'].values:
                print(f"Table '{table_name}' not found in database.")
                return

            # table_name is a fixed constant (not user input), so formatting it
            # into the statement is safe here.
            df = pd.read_sql(f"SELECT * FROM {table_name};", conn)

            # UNIX to ISO: prefer 'unix_ts', fall back to 'timestamp'.
            ts_col = next((c for c in ("unix_ts", "timestamp") if c in df.columns), None)
            if ts_col is None:
                print(f"Neither 'unix_ts' nor 'timestamp' column found in '{table_name}'. Cannot create 'datetime' column.")
                return
            df['datetime'] = (
                pd.to_datetime(df[ts_col], unit='ms', utc=True)
                  .dt.tz_convert('Asia/Manila')
            )

            print(f"\nPreview of '{table_name}':")
            display(df.head(5))

            print("Columns:")
            print(df.columns.tolist())

            print("\nTime range in dataset (Asia/Manila):")
            print(df['datetime'].min(), "to", df['datetime'].max())
        except Exception as e:
            # Top-level UI boundary: show the problem in the output widget.
            print(f"Failed to open DB: {e}")
        finally:
            # Always release the file handle, including on the early returns
            # above; otherwise every click leaks an open connection.
            if conn is not None:
                conn.close()

# Run load_db whenever the "Load DB" button is clicked.
load_db_button.on_click(load_db)

def load_csv(_):
    """Load the annotated-stroke CSV named in ``csv_file_input`` into ``csv_df``.

    Click handler for ``load_csv_button``; the argument passed by the button is
    ignored. All messages and previews are rendered inside ``csv_load_output``.

    A ``datetime`` column is derived from ``unix_ts`` (or ``timestamp`` when
    ``unix_ts`` is absent), parsed as epoch milliseconds in UTC and converted
    to Asia/Manila. This matches the columns accepted by the DB loader and by
    the comparison cell. Problems are reported as printed messages rather than
    raised.
    """
    global csv_df

    with csv_load_output:
        clear_output()
        csv_path = Path(csv_file_input.value)
        if not csv_path.exists():
            print(f"Path not found: {csv_path}")
            return
        try:
            csv_df = pd.read_csv(csv_path)

            # Prefer 'unix_ts', fall back to 'timestamp' like the DB loader.
            ts_col = next((c for c in ("unix_ts", "timestamp") if c in csv_df.columns), None)
            if ts_col is not None:
                csv_df['datetime'] = (
                    pd.to_datetime(csv_df[ts_col], unit='ms', utc=True)
                      .dt.tz_convert('Asia/Manila')
                )
            elif 'datetime' not in csv_df.columns:
                # Say so now instead of failing later in the analysis cells.
                print("Neither 'unix_ts' nor 'timestamp' column found in CSV. Cannot create 'datetime' column.")

            print(f"\nPreview of '{csv_path.name}':")
            display(csv_df.head(5))

            print("Columns:")
            print(csv_df.columns.tolist())

            if 'datetime' in csv_df.columns:
                print("\nTime range in CSV dataset (Asia/Manila):")
                print(csv_df['datetime'].min(), "to", csv_df['datetime'].max())

        except Exception as e:
            # Top-level UI boundary: show the problem in the output widget.
            print(f"Failed to load CSV: {e}")

# Run load_csv whenever the "Load CSV" button is clicked.
load_csv_button.on_click(load_csv)

# Render both loaders: path box, button, and the output area each handler prints into.
display(db_file_input, load_db_button, db_load_output,
        csv_file_input, load_csv_button, csv_load_output)

Text(value='', description='DB file path:', layout=Layout(width='80%'), placeholder='C:/path/to/your/database.…

Button(button_style='primary', description='Load DB', style=ButtonStyle())

Output()

Text(value='', description='CSV file path:', layout=Layout(width='80%'), placeholder='C:/path/to/your/stroke_d…

Button(button_style='info', description='Load CSV', style=ButtonStyle())

Output()

2. Butterworth filter

References:
https://doi.org/10.1016/j.proeng.2010.04.055

http://dx.doi.org/10.4236/jsip.2012.34062

In [2]:
def butter_bandpass_filter(data, lowcut=0.25, highcut=0.5, fs=50, order=2):
    """Band-pass ``data`` with a Butterworth filter applied via ``filtfilt``.

    Args:
        data: 1-D sequence of samples to filter.
        lowcut: Lower cut-off frequency, in the same unit as ``fs``.
        highcut: Upper cut-off frequency, in the same unit as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order passed to ``scipy.signal.butter``.

    Returns:
        The filtered samples, as returned by ``scipy.signal.filtfilt``.
    """
    # butter() expects cut-offs normalised to the Nyquist frequency (fs / 2).
    nyquist = fs / 2.0
    band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band, btype='band')
    return filtfilt(numerator, denominator, data)

3. Sampling rate estimator

In [3]:
def estimate_sampling_rate(df):
    """Estimate the sampling rate (Hz) from the ``datetime`` column of ``df``.

    The rate is the reciprocal of the median gap, in seconds, between
    consecutive rows.
    """
    gaps = df['datetime'].diff()
    gap_seconds = gaps.dt.total_seconds().dropna()  # first row has no predecessor
    typical_gap = gap_seconds.median()
    return 1.0 / typical_gap

4. Stroke cycle identification (Peak detection)

In [4]:
def identify_stroke_cycles(segment):
    """Count stroke cycles in ``segment`` by peak detection on filtered acceleration.

    ``accel_y`` and ``accel_z`` are each band-pass filtered at the sampling
    rate estimated from the segment, summed, and every local maximum of the
    sum is counted as one stroke.

    Args:
        segment: DataFrame with ``datetime``, ``accel_y`` and ``accel_z`` columns.

    Returns:
        ``(count, peaks, signal)``: the number of peaks, their positional
        indices into ``segment``, and the filtered ``accel_y + accel_z``
        signal. ``(0, [], [])`` is returned, after printing a warning, when
        the required columns are missing or the segment cannot be filtered.
    """
    if not all(col in segment.columns for col in ['accel_y', 'accel_z']):
        print("⚠️ accel_y and accel_z not found.")
        return 0, [], []

    fs = estimate_sampling_rate(segment)
    try:
        ay_f = butter_bandpass_filter(segment['accel_y'].values, fs=fs)
        az_f = butter_bandpass_filter(segment['accel_z'].values, fs=fs)
    except ValueError as exc:
        # scipy raises ValueError when the segment is too short for filtfilt's
        # edge padding, or when the estimated rate gives cut-offs outside
        # (0, Nyquist), e.g. a NaN rate from a single-row segment. Report it
        # instead of letting the exception escape the click callback.
        print(f"⚠️ Could not filter segment ({len(segment)} samples, fs={fs:.2f} Hz): {exc}")
        return 0, [], []
    signal = ay_f + az_f

    peaks, _ = find_peaks(signal)
    valid_peaks = list(peaks)

    return len(valid_peaks), valid_peaks, signal


5. Output

In [None]:
# Controls for the metrics cell: which dataset to analyse and the time window.
data_source_select = widgets.Dropdown(
    description='Data Source:',
    options=[('Database (Accelerometer)', 'db'), ('CSV (Annotated Strokes)', 'csv')],
    value='db',
    style={'description_width': 'initial'},
)

# Window bounds are typed as clock times; the handler attaches the date.
start_time_input = widgets.Text(
    description='Start Time (HH:MM:SS GMT+8):',
    value="07:00:00",
    layout=widgets.Layout(width='90%'),
    style={'description_width': 'initial'},
)
end_time_input = widgets.Text(
    description='End Time (HH:MM:SS GMT+8):',
    value="07:05:00",
    layout=widgets.Layout(width='90%'),
    style={'description_width': 'initial'},
)

# Everything calculate_stroke_metrics prints or plots is captured here.
output = widgets.Output()

def _strokes_per_minute(stroke_count, duration_seconds):
    """Convert a stroke count over ``duration_seconds`` to strokes/min (0 if the duration is not positive)."""
    if duration_seconds <= 0:
        return 0
    return (stroke_count / duration_seconds) * 60


def calculate_stroke_metrics(change):
    """Report stroke count and stroke rate for the selected source and time window.

    Click handler for ``calc_button``; ``change`` (the value passed by the
    button) is not used. Everything is printed or plotted inside ``output``.

    The start/end clock times from the text inputs are combined with the date
    of the first row of the selected frame and interpreted as Asia/Manila
    time. For the DB source, strokes are detected with
    ``identify_stroke_cycles`` and plotted; for the CSV source, every row in
    the window counts as one annotated stroke. The rate is computed over the
    requested window length (end minus start), not over the span of the rows
    actually found.
    """
    with output:
        clear_output()

        source = data_source_select.value
        if source == 'db':
            if 'df' not in globals() or df.empty:
                print("DB data not loaded. Please load the .db file first.")
                return
            current_df = df
        elif source == 'csv':
            if 'csv_df' not in globals() or csv_df.empty:
                print("CSV data not loaded. Please load the .csv file first.")
                return
            current_df = csv_df
        else:
            # Should not happen with the dropdown, but avoid using an unset frame.
            print(f"Unknown data source: {source}")
            return

        if 'datetime' not in current_df.columns:
            print("Error: 'datetime' column not found in the selected data. Please ensure the data is loaded correctly with a 'timestamp' or 'unix_ts' column.")
            return

        if current_df['datetime'].empty:
            print("Error: 'datetime' column is empty. No time data available for analysis.")
            return

        print(f"\nFull data time range (GMT+8): {current_df['datetime'].min().strftime('%Y-%m-%d %H:%M:%S')} to {current_df['datetime'].max().strftime('%Y-%m-%d %H:%M:%S')}")

        try:
            start_time_str = start_time_input.value
            end_time_str = end_time_input.value

            # The inputs carry only a clock time; borrow the date from the first row.
            base_date = current_df['datetime'].iloc[0].date()

            start_time = pd.to_datetime(f"{base_date} {start_time_str}").tz_localize('Asia/Manila')
            end_time = pd.to_datetime(f"{base_date} {end_time_str}").tz_localize('Asia/Manila')

            print(f"Filtering for: {start_time.strftime('%Y-%m-%d %H:%M:%S')} GMT+8 to {end_time.strftime('%Y-%m-%d %H:%M:%S')} GMT+8")

        except ValueError:
            print("Invalid time format. Please use HH:MM:SS.")
            return
        except Exception as e:
            print(f"Error parsing times: {e}")
            return

        if end_time < start_time:
            # A reversed window can never match any row; say why.
            print("End time is earlier than start time.")
            return

        mask = (current_df['datetime'] >= start_time) & (current_df['datetime'] <= end_time)
        segment = current_df[mask]

        if segment.empty:
            print("No data in selected time range.")
            return

        duration_seconds = (end_time - start_time).total_seconds()

        if source == 'db':
            # stroke cycle identification for DB data
            stroke_count, peaks, signal = identify_stroke_cycles(segment)
            print(f"Detected Stroke Count (DB): {stroke_count}")

            stroke_rate = _strokes_per_minute(stroke_count, duration_seconds)
            print(f"Stroke Rate (DB): {stroke_rate:.2f} strokes/min")

            # identify_stroke_cycles returns an empty signal when it could not
            # analyse the segment; plotting it against the timestamps would
            # raise a shape-mismatch error.
            if len(signal) == 0:
                print("No filtered signal available; skipping plot.")
                return

            # results plot
            plt.figure(figsize=(12,4))
            plt.plot(segment['datetime'], signal, label="Filtered ay+az")
            if len(peaks) > 0:
                plt.plot(segment['datetime'].iloc[peaks], signal[peaks], "rx", label="Strokes")
            plt.xlabel("Time")
            plt.ylabel("Filtered Acceleration (m/s^2)")
            plt.legend()
            plt.title(f"Stroke Analysis (DB Data) from {start_time.strftime('%H:%M:%S')} to {end_time.strftime('%H:%M:%S')}")
            plt.show()
        elif source == 'csv':
            # each annotated row in the window is one stroke
            stroke_count = segment.shape[0]
            print(f"Annotated Stroke Count (CSV): {stroke_count}")

            stroke_rate = _strokes_per_minute(stroke_count, duration_seconds)
            print(f"Stroke Rate (CSV): {stroke_rate:.2f} strokes/min")

calc_button = widgets.Button(
    description="Calculate Stroke Metrics",
    button_style='success'
)

# Register the click handler. The button is created fresh each time this cell
# runs, so there are no stale handlers to clear. Note that on_click(None) does
# not clear handlers: it registers None as a callback, which then raises a
# TypeError on every click.
calc_button.on_click(calculate_stroke_metrics)

display(data_source_select, start_time_input, end_time_input, calc_button, output)

Dropdown(description='Data Source:', options=(('Database (Accelerometer)', 'db'), ('CSV (Annotated Strokes)', …

Text(value='07:00:00', description='Start Time (HH:MM:SS GMT+8):', layout=Layout(width='90%'), style=TextStyle…

Text(value='07:05:00', description='End Time (HH:MM:SS GMT+8):', layout=Layout(width='90%'), style=TextStyle(d…

Button(button_style='success', description='Calculate Stroke Metrics', style=ButtonStyle())

Output()

In [None]:
# Output area that compare_all_data prints into, and the button that triggers it.
compare_output = widgets.Output()
compare_all_button = widgets.Button(
    button_style='warning',
    description="Compare Data",
)

def compare_all_data(_):
    """Compare the annotated stroke count (CSV) with the detected count (DB).

    Click handler for ``compare_all_button``; the argument passed by the
    button is ignored. Everything is printed inside ``compare_output``.

    The comparison window runs from the earliest to the latest CSV timestamp
    (``unix_ts``, or ``timestamp`` when ``unix_ts`` is absent, in epoch
    milliseconds). Every CSV row counts as one annotated stroke; the DB rows
    inside the window are passed to ``identify_stroke_cycles``.
    """
    with compare_output:
        clear_output()
        if 'csv_df' not in globals() or csv_df.empty:
            print("CSV data not loaded. Please load the .csv file first.")
            return
        if 'df' not in globals() or df.empty:
            print("DB data not loaded. Please load the .db file first.")
            return
        if 'datetime' not in csv_df.columns or 'datetime' not in df.columns:
            print("Error: 'datetime' column not found in one or both data sources.")
            return

        ts_col = next((c for c in ('unix_ts', 'timestamp') if c in csv_df.columns), None)
        if ts_col is None:
            print("CSV does not have 'unix_ts' or 'timestamp' column.")
            return

        # Use min/max rather than the first/last row so the window does not
        # depend on the CSV row order; this matches the range the CSV loader
        # prints.
        start_unix = csv_df[ts_col].min()
        end_unix = csv_df[ts_col].max()
        start_dt = pd.to_datetime(start_unix, unit='ms', utc=True).tz_convert('Asia/Manila')
        end_dt = pd.to_datetime(end_unix, unit='ms', utc=True).tz_convert('Asia/Manila')
        print(f"Comparing over time range: {start_dt} to {end_dt}")

        csv_stroke_count = csv_df.shape[0]
        print(f"\nAnnotated Stroke Count: {csv_stroke_count}")

        db_segment = df[(df['datetime'] >= start_dt) & (df['datetime'] <= end_dt)]
        if db_segment.empty:
            print("No DB data in this time range.")
            db_stroke_count = 0
        else:
            # Only the count is needed here; peak positions and signal are discarded.
            db_stroke_count, _, _ = identify_stroke_cycles(db_segment)
        print(f"Stroke Count (peak detection): {db_stroke_count}")

# Run compare_all_data whenever the "Compare Data" button is clicked.
compare_all_button.on_click(compare_all_data)

# Render the button together with the output area its handler prints into.
display(compare_all_button, compare_output)





Output()