Test for stroke count and stroke rate

1. Load data

In [None]:
import sqlite3
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks, butter, filtfilt
import ipywidgets as widgets
from IPython.display import display, clear_output
from pathlib import Path

# Reuse frames left over from an earlier run of this cell; otherwise start empty.
df = globals().get('df', pd.DataFrame())
csv_df = globals().get('csv_df', pd.DataFrame())

# --- SQLite (.db) loader controls ---
db_file_input = widgets.Text(
    description='DB file path:',
    placeholder="C:/path/to/your/database.db",
    value="",
    layout=widgets.Layout(width='80%'),
    style={'description_width': 'initial'},
)
load_db_button = widgets.Button(button_style='primary', description="Load DB")
db_load_output = widgets.Output()

# --- Annotated-stroke (.csv) loader controls ---
csv_file_input = widgets.Text(
    description='CSV file path:',
    placeholder="C:/path/to/your/stroke_data.csv",
    value="",
    layout=widgets.Layout(width='80%'),
    style={'description_width': 'initial'},
)
load_csv_button = widgets.Button(button_style='info', description="Load CSV")
csv_load_output = widgets.Output()


def _manila_datetime_from_epoch_ms(frame):
    """Return a tz-aware Asia/Manila datetime Series derived from *frame*, or None.

    Looks for an epoch-millisecond column, preferring ``unix_ts`` over
    ``timestamp``. The values are interpreted as UTC milliseconds and
    converted to Asia/Manila (GMT+8). Returns ``None`` when neither column
    exists.
    """
    for column in ("unix_ts", "timestamp"):
        if column in frame.columns:
            return (
                pd.to_datetime(frame[column], unit='ms', utc=True)
                  .dt.tz_convert('Asia/Manila')
            )
    return None


def load_db(_):
    """Load the ``sensor_data`` table from the SQLite file named in ``db_file_input``.

    Button callback; all messages and previews go to ``db_load_output``.
    Behaviour:

    - On success the module-level ``df`` is replaced with the table plus a
      ``datetime`` column (Asia/Manila) built from ``unix_ts`` or
      ``timestamp``.
    - If neither timestamp column exists, ``df`` is still replaced with the
      raw table and a message is printed.
    - On any exception ``df`` is left unchanged and the error is printed.
    - The SQLite connection is always closed, including on early returns and
      errors.

    Parameters
    ----------
    _ : ignored
        The clicked button, passed by ``Button.on_click``.
    """
    global df
    with db_load_output:
        clear_output()
        db_path = Path(db_file_input.value)
        if not db_path.exists():
            print(f"Path not found: {db_path}")
            return

        table_name = 'sensor_data'
        conn = None
        try:
            conn = sqlite3.connect(str(db_path))
            print("Available tables:")
            tables = pd.read_sql("SELECT name FROM sqlite_master WHERE type='table';", conn)
            display(tables)

            if table_name not in tables['name'].values:
                print(f"Table '{table_name}' not found in database.")
                return

            # table_name is a fixed literal (not user input), so interpolating it is safe.
            loaded = pd.read_sql(f"SELECT * FROM {table_name};", conn)

            # UNIX (epoch ms) to Asia/Manila datetimes
            datetimes = _manila_datetime_from_epoch_ms(loaded)
            if datetimes is None:
                df = loaded
                print(f"Neither 'unix_ts' nor 'timestamp' column found in '{table_name}'. Cannot create 'datetime' column.")
                return
            loaded['datetime'] = datetimes
            # Publish only after the frame is fully prepared.
            df = loaded

            print(f"\nPreview of '{table_name}':")
            display(df.head(5))

            print("Columns:")
            print(df.columns.tolist())

            print("\nTime range in dataset (Asia/Manila):")
            print(df['datetime'].min(), "to", df['datetime'].max())
        except Exception as e:
            print(f"Failed to open DB: {e}")
        finally:
            # sqlite3 connections are not closed by garbage-free early returns; close explicitly.
            if conn is not None:
                conn.close()


# Route "Load DB" clicks to load_db.
load_db_button.on_click(callback=load_db)


def load_csv(_):
    """Load the annotated-stroke CSV named in ``csv_file_input``.

    Button callback; messages and the preview are shown in ``csv_load_output``.
    The module-level ``csv_df`` is replaced with the file contents. When a
    ``unix_ts`` (preferred) or ``timestamp`` column exists, a ``datetime``
    column is added: those values are read as epoch milliseconds (UTC) and
    converted to Asia/Manila. Exceptions raised while loading are printed,
    not propagated.

    Parameters
    ----------
    _ : ignored
        The clicked button, passed by ``Button.on_click``.
    """
    global csv_df
    with csv_load_output:
        clear_output()
        source_path = Path(csv_file_input.value)
        if not source_path.exists():
            print(f"Path not found: {source_path}")
            return
        try:
            csv_df = pd.read_csv(source_path)

            # First epoch-ms column present, in order of preference.
            epoch_col = next(
                (name for name in ("unix_ts", "timestamp") if name in csv_df.columns),
                None,
            )
            if epoch_col is not None:
                utc_times = pd.to_datetime(csv_df[epoch_col], unit='ms', utc=True)
                csv_df['datetime'] = utc_times.dt.tz_convert('Asia/Manila')

            print(f"\nPreview of '{source_path.name}':")
            display(csv_df.head(30))

            print("Columns:")
            print(csv_df.columns.tolist())

            has_datetime = 'datetime' in csv_df.columns
            if not has_datetime:
                print("Warning: CSV has no 'unix_ts' or 'timestamp' column, so 'datetime' could not be created.")
            else:
                print("\nTime range in CSV dataset (Asia/Manila):")
                print(csv_df['datetime'].min(), "to", csv_df['datetime'].max())

        except Exception as e:
            print(f"Failed to load CSV: {e}")


# Route "Load CSV" clicks to load_csv.
load_csv_button.on_click(callback=load_csv)

# Show both loaders: DB controls first, then CSV controls.
display(
    db_file_input, load_db_button, db_load_output,
    csv_file_input, load_csv_button, csv_load_output,
)

Text(value='', description='DB file path:', layout=Layout(width='80%'), placeholder='C:/path/to/your/database.…

Button(button_style='primary', description='Load DB', style=ButtonStyle())

Output()

Text(value='', description='CSV file path:', layout=Layout(width='80%'), placeholder='C:/path/to/your/stroke_d…

Button(button_style='info', description='Load CSV', style=ButtonStyle())

Output()

2. Butterworth filter

References:
https://doi.org/10.1016/j.proeng.2010.04.055

https://doi.org/10.4236/jsip.2012.34062

In [10]:
def butter_bandpass_filter(data, lowcut=0.25, highcut=0.5, fs=50, order=2):
    """Apply a zero-phase Butterworth band-pass filter to *data*.

    Designs an ``order``-th order Butterworth band-pass with ``scipy.signal.butter``
    and runs it forwards and backwards with ``filtfilt``, so the result has no
    phase lag relative to the input.

    Parameters
    ----------
    data : array-like
        Samples to filter (filtered along the last axis).
    lowcut, highcut : float
        Pass-band edges in Hz; must satisfy ``0 < lowcut < highcut < fs / 2``.
    fs : float
        Sampling rate in Hz.
    order : int
        Butterworth design order.

    Returns
    -------
    numpy.ndarray
        Filtered samples (float), same shape as *data*.

    Raises
    ------
    ValueError
        If the normalized band does not lie strictly inside (0, 1), e.g. a
        cutoff at/above Nyquist or a non-finite ``fs``.
    """
    x = np.asarray(data, dtype=float)
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    # Validate up front so a bad fs (e.g. inf/NaN from duplicate timestamps)
    # gives a clear error instead of an opaque one from butter().
    if not (0.0 < low < high < 1.0):
        raise ValueError(
            f"Band-pass edges must satisfy 0 < lowcut < highcut < fs/2; "
            f"got lowcut={lowcut}, highcut={highcut}, fs={fs}."
        )
    if x.size == 0:
        return x

    b, a = butter(order, [low, high], btype='band')
    # filtfilt's default padlen is 3 * max(len(a), len(b)) and it raises when the
    # input is not longer than that; shrink the padding for short segments.
    padlen = min(3 * max(len(a), len(b)), x.shape[-1] - 1)
    return filtfilt(b, a, x, padlen=padlen)

3. Sampling rate estimator

In [11]:
def estimate_sampling_rate(df):
    """Estimate the sampling rate (Hz) from the spacing of ``df['datetime']``.

    Uses the median of consecutive timestamp differences, which tolerates
    occasional gaps or jitter.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a datetime-like ``'datetime'`` column in time order.

    Returns
    -------
    float
        ``1 / median_interval_seconds``.

    Raises
    ------
    ValueError
        If there are fewer than two timestamps, or the median interval is not
        positive (duplicate or descending timestamps). Previously these cases
        silently produced NaN / inf / negative rates that crashed the filters later.
    """
    diffs = df['datetime'].diff().dt.total_seconds().dropna()
    if diffs.empty:
        raise ValueError("Need at least two timestamps to estimate the sampling rate.")
    median_interval = diffs.median()
    if not median_interval > 0:
        raise ValueError(
            f"Median sample interval must be positive to estimate the sampling rate; "
            f"got {median_interval} s (duplicate or unsorted timestamps?)."
        )
    return 1.0 / median_interval

4. Stroke cycle identification (Peak detection)

In [12]:
def identify_stroke_cycles(segment):
    """Detect stroke cycles from the accelerometer's y and z axes.

    Each axis is band-passed with ``butter_bandpass_filter`` at the sampling
    rate estimated from ``segment['datetime']``. The two filtered traces are
    summed, and every local maximum of the sum (``find_peaks`` with no extra
    constraints) counts as one stroke.

    Returns
    -------
    tuple
        ``(stroke_count, peak_indices, signal)``. ``peak_indices`` are
        positional indices into ``signal`` (aligned with the rows of
        ``segment``). If ``accel_y``/``accel_z`` are missing, a message is
        printed and ``(0, [], [])`` is returned.
    """
    accel_axes = ('accel_y', 'accel_z')
    if any(axis not in segment.columns for axis in accel_axes):
        print("accel_y and accel_z not found.")
        return 0, [], []

    rate_hz = estimate_sampling_rate(segment)
    filtered_y = butter_bandpass_filter(segment['accel_y'].values, fs=rate_hz)
    filtered_z = butter_bandpass_filter(segment['accel_z'].values, fs=rate_hz)
    combined = filtered_y + filtered_z

    peak_idx, _props = find_peaks(combined)
    peak_list = [*peak_idx]
    return len(peak_list), peak_list, combined


In [13]:
def detect_strokes_butterfly(segment):
    """Butterfly-specific stroke detection using gyro magnitude.

    Kept separate from the default accel-based pipeline; it is only used when
    "Butterfly" is selected in the evaluation widgets.

    The method works as follows:

    - Take the per-sample magnitude of ``gyro_x/gyro_y/gyro_z`` and smooth it
      with a centred moving average about 0.3 s wide.
    - Count peaks that are at least ~1.2 s apart.
    - Require each peak's prominence to be at least
      ``max(5.0, 1.2 * std(smoothed))``.

    The constants are hand-tuned by the author for roughly 22–30 strokes/min.

    Returns
    -------
    tuple
        ``(stroke_count, peak_indices, smoothed_gyro_magnitude)``. If any gyro
        column is missing, a message is printed and ``(0, [], [])`` is
        returned.
    """
    gyro_axes = ('gyro_x', 'gyro_y', 'gyro_z')
    if any(axis not in segment.columns for axis in gyro_axes):
        print("Butterfly detection: gyro_x/gyro_y/gyro_z not found.")
        return 0, [], []

    rate_hz = estimate_sampling_rate(segment)

    x_rate = segment['gyro_x'].values
    y_rate = segment['gyro_y'].values
    z_rate = segment['gyro_z'].values
    magnitude = np.sqrt(x_rate**2 + y_rate**2 + z_rate**2)

    # Centred moving average over ~0.3 s (at least one sample).
    smoothing_window = max(1, int(rate_hz * 0.3))
    smoothed = (
        pd.Series(magnitude)
        .rolling(smoothing_window, center=True, min_periods=1)
        .mean()
        .to_numpy()
    )

    # Minimum spacing between strokes (~1.2 s) guards against double counting;
    # the prominence floor rejects small wiggles without missing real strokes.
    min_gap_samples = max(1, int(rate_hz * 1.2))
    min_prominence = max(5.0, 1.2 * np.std(smoothed))

    peak_idx, _props = find_peaks(smoothed, distance=min_gap_samples, prominence=min_prominence)
    return len(peak_idx), [*peak_idx], smoothed

5. Output

In [14]:
# inputs: data source plus a same-day HH:MM:SS window in GMT+8
data_source_select = widgets.Dropdown(
    description='Data Source:',
    options=[('Database (Accelerometer)', 'db'), ('CSV (Annotated Strokes)', 'csv')],
    value='db',
    style={'description_width': 'initial'},
)

start_time_input = widgets.Text(
    description='Start Time (HH:MM:SS GMT+8):',
    value="07:00:00",
    layout=widgets.Layout(width='90%'),
    style={'description_width': 'initial'},
)
end_time_input = widgets.Text(
    description='End Time (HH:MM:SS GMT+8):',
    value="07:05:00",
    layout=widgets.Layout(width='90%'),
    style={'description_width': 'initial'},
)

# Shared output area for calculate_stroke_metrics.
output = widgets.Output()

def calculate_stroke_metrics(change):
    """Compute stroke count and stroke rate for the selected source and time window.

    Button callback; all output goes to ``output``. The data source comes
    from ``data_source_select``. The HH:MM:SS window (GMT+8) comes from
    ``start_time_input`` / ``end_time_input`` and is anchored to the calendar
    date of the first row of the selected frame.

    - ``'db'``: strokes are detected with ``identify_stroke_cycles`` and plotted.
    - ``'csv'``: each row inside the window counts as one annotated stroke.

    Stroke rate is ``count / (end - start) * 60``. This uses the requested
    window length, not the span of the data actually present in it.

    Parameters
    ----------
    change : ignored
        The clicked button, passed by ``Button.on_click``.
    """
    with output:
        clear_output()

        source = data_source_select.value
        if source == 'db':
            if 'df' not in globals() or df.empty:
                print("DB data not loaded. Please load the .db file first.")
                return
            current_df = df
        elif source == 'csv':
            if 'csv_df' not in globals() or csv_df.empty:
                print("CSV data not loaded. Please load the .csv file first.")
                return
            current_df = csv_df
        else:
            # Without this, current_df would stay None and the column check would crash.
            print(f"Unknown data source: {source!r}")
            return

        if 'datetime' not in current_df.columns:
            print("Error: 'datetime' column not found in the selected data. Please ensure the data is loaded correctly with a 'timestamp' or 'unix_ts' column.")
            return

        if current_df['datetime'].empty:
            print("Error: 'datetime' column is empty. No time data available for analysis.")
            return

        print(f"\nFull data time range (GMT+8): {current_df['datetime'].min().strftime('%Y-%m-%d %H:%M:%S')} to {current_df['datetime'].max().strftime('%Y-%m-%d %H:%M:%S')}")

        try:
            start_time_str = start_time_input.value
            end_time_str = end_time_input.value

            base_date = current_df['datetime'].iloc[0].date()

            start_time = pd.to_datetime(f"{base_date} {start_time_str}").tz_localize('Asia/Manila')
            end_time = pd.to_datetime(f"{base_date} {end_time_str}").tz_localize('Asia/Manila')

            print(f"Filtering for: {start_time.strftime('%Y-%m-%d %H:%M:%S')} GMT+8 to {end_time.strftime('%Y-%m-%d %H:%M:%S')} GMT+8")

        except ValueError:
            print("Invalid time format. Please use HH:MM:SS.")
            return
        except Exception as e:
            print(f"Error parsing times: {e}")
            return

        # A window that ends at/before its start has no duration, so no rate can
        # be computed (both times are anchored to the same date).
        if end_time <= start_time:
            print("End time must be later than start time.")
            return
        duration_seconds = (end_time - start_time).total_seconds()

        mask = (current_df['datetime'] >= start_time) & (current_df['datetime'] <= end_time)
        segment = current_df[mask]

        if segment.empty:
            print("No data in selected time range.")
            return

        if source == 'db':
            # stroke cycle identification for DB data; filter design/application
            # can raise ValueError (e.g. on a very short segment).
            try:
                stroke_count, peaks, signal = identify_stroke_cycles(segment)
            except ValueError as e:
                print(f"Stroke detection failed for this segment: {e}")
                return
            print(f"Detected Stroke Count (DB): {stroke_count}")

            stroke_rate = (stroke_count / duration_seconds) * 60
            print(f"Stroke Rate (DB): {stroke_rate:.2f} strokes/min")

            # identify_stroke_cycles returns an empty signal when accel columns
            # are missing; plotting it against the timestamps would raise.
            if len(signal) != len(segment):
                return

            # results plot
            plt.figure(figsize=(12, 4))
            plt.plot(segment['datetime'], signal, label="Filtered ay+az")
            if len(peaks) > 0:
                plt.plot(segment['datetime'].iloc[peaks], signal[peaks], "rx", label="Strokes")
            plt.xlabel("Time")
            plt.ylabel("Filtered Acceleration (m/s^2)")
            plt.legend()
            plt.title(f"Stroke Analysis (DB Data) from {start_time.strftime('%H:%M:%S')} to {end_time.strftime('%H:%M:%S')}")
            plt.show()
        else:  # source == 'csv'
            # one CSV row = one annotated stroke
            stroke_count = segment.shape[0]
            print(f"Annotated Stroke Count (CSV): {stroke_count}")

            stroke_rate = (stroke_count / duration_seconds) * 60
            print(f"Stroke Rate (CSV): {stroke_rate:.2f} strokes/min")

calc_button = widgets.Button(
    description="Calculate Stroke Metrics",
    button_style='success'
)

# The button is created fresh on every run of this cell, so there are no stale
# handlers to clear. (on_click(None) does not clear handlers - it registers None
# as a callback, which raises "'NoneType' object is not callable" on every click.)
calc_button.on_click(calculate_stroke_metrics)

display(data_source_select, start_time_input, end_time_input, calc_button, output)

Dropdown(description='Data Source:', options=(('Database (Accelerometer)', 'db'), ('CSV (Annotated Strokes)', …

Text(value='07:00:00', description='Start Time (HH:MM:SS GMT+8):', layout=Layout(width='90%'), style=TextStyle…

Text(value='07:05:00', description='End Time (HH:MM:SS GMT+8):', layout=Layout(width='90%'), style=TextStyle(d…

Button(button_style='success', description='Calculate Stroke Metrics', style=ButtonStyle())

Output()

In [15]:
# Trigger and output area for the CSV-vs-DB stroke count comparison.
compare_all_button = widgets.Button(button_style='warning', description="Compare Data")
compare_output = widgets.Output()


def compare_all_data(_):
    """Compare the CSV's annotated stroke count with strokes detected in the DB.

    Button callback; output goes to ``compare_output``. The comparison works
    as follows:

    - The window is the CSV's own time span, from its earliest to its latest
      ``datetime``.
    - Every CSV row counts as one annotated stroke.
    - ``identify_stroke_cycles`` is run on the DB rows inside that window.
    - The report prints the count error (DB - CSV), the absolute and percent
      error, and a count accuracy of ``max(0, 1 - |error| / annotated) * 100``.

    Parameters
    ----------
    _ : ignored
        The clicked button, passed by ``Button.on_click``.
    """
    with compare_output:
        clear_output()

        # Basic checks
        if 'csv_df' not in globals() or csv_df.empty:
            print("CSV data not loaded. Please load the .csv file first.")
            return
        if 'df' not in globals() or df.empty:
            print("DB data not loaded. Please load the .db file first.")
            return
        if 'datetime' not in csv_df.columns or 'datetime' not in df.columns:
            print("Error: 'datetime' column not found in one or both data sources.")
            return

        # The CSV loader derives 'datetime' from unix_ts/timestamp. A non-datetime
        # 'datetime' column (e.g. raw text from the file) means that never happened.
        if not pd.api.types.is_datetime64_any_dtype(csv_df['datetime']):
            print("CSV does not have 'unix_ts' or 'timestamp' column.")
            return

        # min/max instead of the first/last row: correct for an unsorted CSV and
        # ignores NaT rows.
        start_dt = csv_df['datetime'].min()
        end_dt = csv_df['datetime'].max()

        print(f"Comparing over time range: {start_dt} to {end_dt}")

        # Ground-truth strokes from CSV (one row = one annotated stroke)
        csv_stroke_count = csv_df.shape[0]
        print(f"\nAnnotated Stroke Count (CSV): {csv_stroke_count}")

        # Detected strokes from DB over the same time window
        db_segment = df[(df['datetime'] >= start_dt) & (df['datetime'] <= end_dt)]
        if db_segment.empty:
            print("No DB data in this time range.")
            db_stroke_count = 0
        else:
            try:
                db_stroke_count, _, _ = identify_stroke_cycles(db_segment)
            except ValueError as e:
                # e.g. a segment too short for the band-pass filter
                print(f"Stroke detection failed on the DB segment: {e}")
                return
        print(f"Detected Stroke Count (peak detection, DB): {db_stroke_count}")

        # Simple accuracy metrics based on counts
        if csv_stroke_count > 0:
            count_error = db_stroke_count - csv_stroke_count
            abs_error = abs(count_error)
            percent_error = (abs_error / csv_stroke_count) * 100.0
            count_accuracy = max(0.0, 1.0 - (abs_error / csv_stroke_count)) * 100.0

            print("\nAccuracy summary (count‑based):")
            print(f"  Count error (DB - CSV): {count_error}")
            print(f"  Absolute error: {abs_error}")
            print(f"  Percent error: {percent_error:.2f}%")
            print(f"  Count accuracy: {count_accuracy:.2f}%")
        else:
            print("\nCSV annotated stroke count is zero; cannot compute percentage accuracy.")


# Route "Compare Data" clicks to compare_all_data, then show the button and its output.
compare_all_button.on_click(callback=compare_all_data)
display(compare_all_button, compare_output)




Output()

In [16]:
# Stroke Count Accuracy Evaluation
# Controls for checking detected strokes against a manually counted total.

# What to evaluate: detection source (DB only) and which detector to use.
accuracy_source_select = widgets.Dropdown(
    description='Detection source:',
    options=[('Database (Accelerometer)', 'db')],
    value='db',
    style={'description_width': 'initial'},
)
stroke_type_select = widgets.Dropdown(
    description='Stroke type:',
    options=['Freestyle', 'Backstroke', 'Breaststroke', 'Butterfly'],
    value='Freestyle',
    style={'description_width': 'initial'},
)

# Which rows: the whole dataset, or an explicit window typed into the inputs below.
range_mode_select = widgets.Dropdown(
    description='Range:',
    options=[('Whole dataset', 'whole'), ('Time window (GMT+8)', 'window')],
    value='whole',
    style={'description_width': 'initial'},
)
acc_start_input = widgets.Text(
    description='Start (YYYY-MM-DD HH:MM:SS):',
    value="",  # e.g. "2025-05-20 07:19:30"
    layout=widgets.Layout(width='95%'),
    style={'description_width': 'initial'},
)
acc_end_input = widgets.Text(
    description='End   (YYYY-MM-DD HH:MM:SS):',
    value="",
    layout=widgets.Layout(width='95%'),
    style={'description_width': 'initial'},
)

# Ground-truth count, the trigger, and where results are printed.
actual_strokes_input = widgets.IntText(
    description='Actual stroke count:',
    value=0,
    style={'description_width': 'initial'},
)
eval_button = widgets.Button(button_style='success', description='Evaluate Accuracy')
eval_output = widgets.Output()


def _get_eval_segment():
    """Select the rows of the loaded DB frame (``df``) to evaluate.

    Returns a ``(segment, message)`` pair. On any problem the segment is an
    empty DataFrame and the message explains why; otherwise the message
    describes the selected range.

    - ``'whole'`` mode: a copy of all of ``df``.
    - ``'window'`` mode: rows whose ``datetime`` lies in the inclusive range
      typed into ``acc_start_input`` / ``acc_end_input``. Naive inputs are
      treated as Asia/Manila local time; tz-aware inputs are converted to it.
    """
    if 'df' not in globals() or df.empty:
        return pd.DataFrame(), "DB data not loaded. Please load the .db file first."
    if 'datetime' not in df.columns:
        return pd.DataFrame(), "DB data has no 'datetime' column."

    if range_mode_select.value == 'whole':
        everything = df.copy()
        stamps = everything['datetime']
        return everything, f"Whole dataset | {stamps.min()} to {stamps.max()}"

    # Time-window mode
    raw_bounds = (acc_start_input.value.strip(), acc_end_input.value.strip())
    if not all(raw_bounds):
        return pd.DataFrame(), "Please enter both start and end timestamps for the time window."

    try:
        parsed = [pd.to_datetime(text) for text in raw_bounds]
        # Interpret naive user input as Asia/Manila; convert aware input to it.
        lo, hi = [
            stamp.tz_localize('Asia/Manila') if stamp.tzinfo is None else stamp.tz_convert('Asia/Manila')
            for stamp in parsed
        ]
    except Exception as e:
        return pd.DataFrame(), f"Failed to parse start/end timestamps: {e}"

    in_window = df['datetime'].between(lo, hi)
    picked = df.loc[in_window].copy()
    if picked.empty:
        return pd.DataFrame(), "No DB data in the selected time window."

    return picked, f"Time window | {lo} to {hi} | rows: {len(picked)}"


def _run_detection(segment):
    """Run the detector that matches ``stroke_type_select`` on *segment*.

    - Butterfly: ``detect_strokes_butterfly`` (gyro-based).
    - Freestyle, Backstroke, Breaststroke: ``identify_stroke_cycles``
      (accel-based).

    Returns ``(stroke_count, peaks, signal, label)``. ``label`` is a
    human-readable name for ``signal`` that is used as the plot legend.
    """
    if stroke_type_select.value == 'Butterfly':
        detector, label = detect_strokes_butterfly, 'Smoothed gyro magnitude'
    else:
        detector, label = identify_stroke_cycles, 'Filtered accel signal (ay+az)'

    stroke_count, peaks, signal = detector(segment)
    return stroke_count, peaks, signal, label


def on_evaluate_accuracy(_):
    """Run the selected detector on the chosen DB range and score it.

    Button callback; output goes to ``eval_output``. It prints:

    - The detected stroke count.
    - A stroke rate based on the segment's actual time span.
    - When ``actual_strokes_input`` is positive, the count error, percent
      error and count accuracy against that manual count.

    It then plots the detection signal with the detected peaks.

    Parameters
    ----------
    _ : ignored
        The clicked button, passed by ``Button.on_click``.
    """
    with eval_output:
        clear_output()

        if accuracy_source_select.value != 'db':
            print("Only DB-based detection is supported in this widget.")
            return

        segment, msg = _get_eval_segment()
        if segment.empty:
            print(msg)
            return

        print(f"Evaluation segment: {msg}")
        print(f"Selected stroke type: {stroke_type_select.value}")

        # Run selected detection pipeline. Filter design/application can raise
        # ValueError (e.g. on a very short segment); report it rather than
        # letting a traceback replace the output.
        try:
            stroke_count, peaks, signal, signal_label = _run_detection(segment)
        except ValueError as e:
            print(f"Stroke detection failed for this segment: {e}")
            return
        print(f"\nDetected Stroke Count (peak detection): {stroke_count}")

        # Duration and stroke rate (actual data span, not a requested window)
        duration_seconds = (segment['datetime'].max() - segment['datetime'].min()).total_seconds()
        stroke_rate = (stroke_count / duration_seconds) * 60 if duration_seconds > 0 else 0
        print(f"Estimated Stroke Rate: {stroke_rate:.2f} strokes/min")

        # Compare to user-provided actual stroke count
        actual = actual_strokes_input.value
        if actual > 0:
            count_error = stroke_count - actual
            abs_error = abs(count_error)
            percent_error = (abs_error / actual) * 100.0
            count_accuracy = max(0.0, 1.0 - (abs_error / actual)) * 100.0

            print("\nAccuracy vs. manual strokes:")
            print(f"  Actual strokes: {actual}")
            print(f"  Detected strokes: {stroke_count}")
            print(f"  Count error (detected - actual): {count_error}")
            print(f"  Absolute error: {abs_error}")
            print(f"  Percent error: {percent_error:.2f}%")
            print(f"  Count accuracy: {count_accuracy:.2f}%")
        else:
            print("\nActual stroke count not provided (or zero). Enter a positive number to see accuracy.")

        # Detectors return an empty signal when their sensor columns are
        # missing; there is nothing meaningful to plot in that case.
        if len(signal) != len(segment):
            print("No detection signal to plot (required sensor columns missing).")
            return

        # Quick plot for tuning / visual inspection
        fig = None
        try:
            fig = plt.figure(figsize=(12, 4))
            plt.plot(segment['datetime'], signal, label=signal_label)
            if len(peaks) > 0:
                plt.plot(segment['datetime'].iloc[peaks], signal[peaks], 'rx', label='Detected strokes')
            plt.xlabel('Time')
            plt.ylabel('Signal value')
            plt.title('Stroke detection segment')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.show()
        except Exception as e:
            # Don't leave a half-drawn figure open in pyplot's figure registry.
            if fig is not None:
                plt.close(fig)
            print(f"Plotting failed: {e}")


# Hook the evaluation handler up to its button.
eval_button.on_click(callback=on_evaluate_accuracy)

# Display UI: heading, then the three selectors.
display(widgets.HTML(value="<h3>Stroke Count Accuracy Evaluation</h3>"))
display(accuracy_source_select, stroke_type_select, range_mode_select)

# Container for the start/end time-window inputs.
range_box = widgets.VBox(children=[acc_start_input, acc_end_input])


def _on_range_mode_change(change):
    """Show the start/end inputs only in 'window' mode; hide them for 'whole'.

    Observer for changes to ``range_mode_select.value``. ``change['new']``
    holds the newly selected mode.
    """
    visibility = 'none' if change['new'] == 'whole' else ''
    acc_start_input.layout.display = visibility
    acc_end_input.layout.display = visibility


range_mode_select.observe(handler=_on_range_mode_change, names='value')

# Start with the window inputs hidden, matching the initial 'whole' range mode.
acc_start_input.layout.display = acc_end_input.layout.display = 'none'

# Remaining UI: window inputs, ground-truth count, trigger button, results area.
display(range_box, actual_strokes_input, eval_button, eval_output)

HTML(value='<h3>Stroke Count Accuracy Evaluation</h3>')

Dropdown(description='Detection source:', options=(('Database (Accelerometer)', 'db'),), style=DescriptionStyl…

Dropdown(description='Stroke type:', options=('Freestyle', 'Backstroke', 'Breaststroke', 'Butterfly'), style=D…

Dropdown(description='Range:', options=(('Whole dataset', 'whole'), ('Time window (GMT+8)', 'window')), style=…

VBox(children=(Text(value='', description='Start (YYYY-MM-DD HH:MM:SS):', layout=Layout(display='none', width=…

IntText(value=0, description='Actual stroke count:', style=DescriptionStyle(description_width='initial'))

Button(button_style='success', description='Evaluate Accuracy', style=ButtonStyle())

Output()