# Stroke detection

## Load and inspect data
Load pickle file and inspect contents

In [None]:
import os
import pickle

# Import necessary pyologger utilities
from pyologger.load_data.datareader import DataReader
from pyologger.load_data.metadata import Metadata
from pyologger.plot_data.plotter import *
from pyologger.process_data.sampling import *
from pyologger.process_data.cropping import *
from pyologger.calibrate_data.zoc import *
from pyologger.calibrate_data.calibrate_acc_mag import *

# Change the current working directory to the pyologger repository root.
# Set the PYOLOGGER_ROOT environment variable to point at your own checkout;
# when it is unset, the original hard-coded location is used.
root_dir = os.environ.get("PYOLOGGER_ROOT", "/Users/jessiekb/Documents/GitHub/pyologger")
os.chdir(root_dir)

# Re-read the directory from the OS so later paths are built from the real cwd
root_dir = os.getcwd()
data_dir = os.path.join(root_dir, "data")
color_mapping_path = os.path.join(root_dir, "color_mappings.json")

# Verify the current working directory
print(f"Current working directory: {root_dir}")

In [None]:
# Build the metadata object and fetch its databases (quietly)
metadata = Metadata()
metadata.fetch_databases(verbose=False)

# Keep a handle on each database, fetched in this fixed order
dep_db, logger_db, rec_db, animal_db = [
    metadata.get_metadata(db_name)
    for db_name in ("dep_DB", "logger_DB", "rec_DB", "animal_DB")
]

# With the metadata and dep_db available, resolve the deployment folder
datareader = DataReader()
deployment_folder = datareader.check_deployment_folder(dep_db, data_dir)

# Only read the files when check_deployment_folder returned something truthy
if deployment_folder:
    datareader.read_files(metadata, save_csv=True, save_parq=True)

In [None]:
# Load the data_reader object from the pickle file.
# NOTE: unpickling can execute arbitrary code, so only open pickle files that
# were produced by your own processing pipeline.
pkl_path = os.path.join(deployment_folder, 'outputs', 'data.pkl')

with open(pkl_path, 'rb') as file:
    data_pkl = pickle.load(file)

# Report the sampling frequency recorded for each logger
for logger_id, info in data_pkl.info.items():
    sampling_frequency = info.get('datetime_metadata', {}).get('fs', None)
    if sampling_frequency is not None:
        # Format the sampling frequency to 5 significant digits
        print(f"Sampling frequency for {logger_id}: {float(sampling_frequency):.5g} Hz")
    else:
        print(f"No sampling frequency available for {logger_id}")

In [None]:
# Change out preferred source of IMU or ephys data depending on your deployment
imu_logger = 'CC-96'
ephys_logger = 'UF-01'

# Sampling rates are stored as whole numbers of Hz. Round instead of truncating
# so that a stored value such as 49.999 becomes 50 rather than 49.
if imu_logger is not None:
    imu_fs = int(round(float(data_pkl.info[imu_logger]['datetime_metadata']['fs'])))
    print(f"IMU Logger {imu_logger} sampled at: {imu_fs} Hz")
if ephys_logger is not None:
    ephys_fs = int(round(float(data_pkl.info[ephys_logger]['datetime_metadata']['fs'])))
    print(f"ePhys Logger {ephys_logger} sampled at: {ephys_fs} Hz")



## Find time chunk when stroking is dominant activity
Use the interactive plot to locate a start time and an end time that bracket a period when stroking is the dominant activity.

In [None]:
# Channels shown in the interactive overview (no ephys channels are plotted)
imu_channels_to_plot = [
    'depth',
    'corr_accX', 'corr_accY', 'corr_accZ',
    'pitch', 'roll', 'heading',
]
ephys_channels_to_plot = []

imu_df = data_pkl.data[imu_logger]
ephys_df = data_pkl.data[ephys_logger]

# Limit the view to the time window covered by both loggers
latest_start = max(imu_df['datetime'].min(), ephys_df['datetime'].min())
earliest_end = min(imu_df['datetime'].max(), ephys_df['datetime'].max())
start_time = latest_start.to_pydatetime()
end_time = earliest_end.to_pydatetime()

# Notes to annotate, mapped to the channel they are passed with
notes_to_plot = {'exhalation_breath': 'depth'}

plot_tag_data_interactive(
    data_pkl,
    imu_channels_to_plot,
    imu_sampling_rate=10,
    ephys_channels=ephys_channels_to_plot,
    imu_logger=imu_logger,
    ephys_logger=ephys_logger,
    note_annotations=notes_to_plot,
    time_range=(start_time, end_time),
    color_mapping_path=color_mapping_path,
)

In [None]:
# Example crop window picked from the interactive plot (naive wall-clock times)
# NOTE(review): `pd` and `pytz` are not imported explicitly in this notebook;
# presumably they arrive through the wildcard pyologger imports — confirm.
crop_window = ('2024-01-16 10:06:30', '2024-01-16 10:07:00')

# Time zone recorded for the selected deployment
time_zone_str = data_pkl.selected_deployment['Time Zone']
time_zone = pytz.timezone(time_zone_str)

# Parse each bound and attach the deployment time zone to it
new_start_time, new_end_time = (
    time_zone.localize(pd.to_datetime(stamp)) for stamp in crop_window
)

data_crop = crop_data(
    data_pkl,
    imu_logger=imu_logger,
    ephys_logger=ephys_logger,
    start_time=new_start_time,
    end_time=new_end_time,
)

# Peek at the data after cropping.
# NOTE(review): this plots `data_pkl`, not `data_crop`, and passes no
# time_range — confirm whether crop_data modifies data_pkl in place or whether
# data_crop was meant to be plotted here.
plot_tag_data_interactive(
    data_pkl,
    imu_channels_to_plot,
    imu_sampling_rate=10,
    ephys_channels=ephys_channels_to_plot,
    imu_logger=imu_logger,
    ephys_logger=ephys_logger,
    note_annotations=notes_to_plot,
    color_mapping_path=color_mapping_path,
)  # time_range=(new_start_time, new_end_time),

In [None]:
import numpy as np
import plotly.graph_objs as go
from scipy.signal import welch
import numpy.polynomial.polynomial as poly

def dsf(A, sampling_rate=None, fc=2.5, Nfft=None, channels=None, interpolate=False):
    """
    Estimate the dominant stroke frequency (DSF) for one or more triaxial sensors.

    For each requested channel the data are optionally low-pass filtered,
    differenced sample-to-sample, and passed to Welch's method. The power
    spectra of the three axes are summed and the frequency of the largest
    summed power is reported. The per-axis spectra of every channel are also
    drawn in a single Plotly figure, which is shown before returning.

    Parameters
    ----------
    A : dict
        A dictionary where each key is the name of a sensor (e.g., 'acc', 'gyr', 'mag')
        and each value is an nx3 matrix with columns [x, y, z].
    sampling_rate : float
        The sampling rate of the sensor data in Hz (samples per second). Required,
        even though it defaults to None in the signature.
    fc : float, optional
        The cut-off frequency in Hz for a low-pass filter applied before computing
        the spectra. Default is 2.5 Hz. The filter is skipped when `fc` is None or
        is not below the Nyquist frequency (sampling_rate / 2).
    Nfft : int, optional
        The Welch segment length and therefore the frequency resolution. Defaults to
        20 * sampling_rate. The value is always forced to the nearest power of two.
    channels : list, optional
        Keys of `A` to analyse. Defaults to ['acc'].
    interpolate : bool, optional
        When True, refine the peak frequency with a three-point parabolic fit
        around the strongest spectral bin. Default is False, which returns the
        frequency of the strongest bin itself.

    Returns
    -------
    dict
        A dictionary keyed by channel name. Each value is a dictionary with:
        - 'fpk': The dominant stroke frequency in Hz.
        - 'q': The quality of the peak: peak power divided by the mean power of
          the summed spectrum.

    Raises
    ------
    ValueError
        If `sampling_rate` is None.
    KeyError
        If a requested channel is not a key of `A`.
    """

    if channels is None:
        channels = ['acc']  # Default to accelerometer only if no channels are specified

    if sampling_rate is None:
        raise ValueError("sampling_rate is a required input.")

    # Fail before any processing or plotting if a requested channel is absent
    missing = [channel for channel in channels if channel not in A]
    if missing:
        raise KeyError(f"Channel(s) {missing} not found in A; available: {list(A)}")

    # Default FFT length if not provided
    if Nfft is None:
        Nfft = int(20 * sampling_rate)

    # Force Nfft to the nearest power of 2
    Nfft = 2 ** int(np.round(np.log2(Nfft)))

    results = {}
    fig = go.Figure()

    # Smallest positive float: keeps log10 finite when a bin has zero power
    power_floor = np.finfo(float).tiny
    axis_styles = (('X', 'red'), ('Y', 'green'), ('Z', 'blue'))

    # Iterate over each channel (sensor type)
    for channel in channels:
        data = np.asarray(A[channel], dtype=float)

        # Apply low-pass filter if cutoff frequency is valid
        if fc is not None and fc < (sampling_rate / 2):
            data = low_pass_filter(data, cutoff=fc, fs=sampling_rate)

        # Compute the differential (difference between consecutive samples)
        data_diff = np.diff(data, axis=0)

        # Calculate power spectral density using Welch's method
        f, Pxx = welch(data_diff, fs=sampling_rate, nperseg=Nfft, axis=0)

        # Sum spectral power in the three axes
        v = np.sum(Pxx, axis=1)

        # Identify the bin with maximum power
        n = int(np.argmax(v))
        m = v[n]
        fpk = f[n]

        # Optional parabolic refinement through the peak bin and its two
        # neighbours. The vertex offset (in bins) of that parabola is
        # 0.5 * (left - right) / (left - 2*peak + right); it is only a maximum
        # when the curvature term is negative.
        if interpolate and 0 < n < len(f) - 1:
            left, right = v[n - 1], v[n + 1]
            curvature = left - 2 * m + right
            if curvature < 0:
                fpk = f[n] + 0.5 * (left - right) / curvature * (f[1] - f[0])

        # Calculate the quality of the peak
        q = m / np.mean(v)

        # Store results
        results[channel] = {'fpk': fpk, 'q': q}

        # Plot the per-axis power density spectrum in dB using Plotly
        for column, (axis_name, color) in enumerate(axis_styles):
            power_db = 10 * np.log10(np.maximum(Pxx[:, column], power_floor))
            fig.add_trace(go.Scatter(x=f, y=power_db, mode='lines',
                                     name=f'{channel.upper()} {axis_name}', line=dict(color=color)))
        fig.add_vline(x=fpk, line_dash="dash", line_color="black",
                      annotation_text=f'DSF: {fpk:.2f} Hz', annotation_position="top right")

    # Update layout for logarithmic x-axis
    fig.update_layout(
        xaxis_type="log",
        title="Power Density Spectrum",
        xaxis_title="Frequency (Hz)",
        yaxis_title="Power/Frequency (dB/Hz)",
        showlegend=True
    )

    fig.show()

    return results


def low_pass_filter(data, cutoff, fs, order=4):
    """
    Apply a zero-phase Butterworth low-pass filter to the data.

    The filter is run forward and backward (`filtfilt`) along axis 0, so an
    nx3 matrix is filtered column by column. Used in this notebook to extract
    the static (slowly varying) component of a signal.

    Parameters
    ----------
    data : numpy.ndarray
        Input data to filter; samples run along axis 0.
    cutoff : float
        Cutoff frequency for the low-pass filter, in the same units as `fs`.
    fs : float
        Sampling rate in Hz.
    order : int, optional
        The order of the filter. Default is 4.

    Returns
    -------
    numpy.ndarray
        Filtered data, same shape as `data`.
    """
    # Imported here so the function does not rely on these names leaking in
    # through a wildcard import elsewhere in the notebook.
    from scipy.signal import butter, filtfilt

    nyquist = 0.5 * fs
    # butter() expects the cutoff as a fraction of the Nyquist frequency
    normal_cutoff = cutoff / nyquist
    b, a = butter(order, normal_cutoff, btype='low', analog=False)
    filtered_data = filtfilt(b, a, data, axis=0)
    return filtered_data

# Overall dynamic body acceleration (ODBA)
def calculate_odba(accX, accY, accZ, cutoff=0.1, fs=10):
    """
    Compute ODBA from the three accelerometer axes.

    Each axis is low-pass filtered with `low_pass_filter(axis, cutoff, fs)` to
    obtain its static component; the static component is subtracted to leave the
    dynamic acceleration, and the absolute dynamic accelerations of the three
    axes are added together.

    Parameters
    ----------
    accX, accY, accZ : array-like
        Acceleration along each axis.
    cutoff : float, optional
        Cutoff frequency passed to the low-pass filter. Default is 0.1.
    fs : float, optional
        Sampling rate passed to the low-pass filter. Default is 10.

    Returns
    -------
    Sample-by-sample sum of the absolute dynamic acceleration of the three axes.
    """
    dynamic_magnitudes = []
    for axis_acc in (accX, accY, accZ):
        # Static part = slowly varying component of this axis
        static_component = low_pass_filter(axis_acc, cutoff, fs)
        dynamic_magnitudes.append(np.abs(axis_acc - static_component))

    x_mag, y_mag, z_mag = dynamic_magnitudes
    return x_mag + y_mag + z_mag

import scipy.signal as sp_sig

def bandpass_filter(signal, lowcut=1, highcut=20, fs=500, order=5):
    """
    Apply a zero-phase Butterworth bandpass filter to the signal.

    The filter is designed as second-order sections (SOS) rather than as a
    single (b, a) transfer function: the transfer-function form of a bandpass
    filter loses numerical precision when the band is narrow or low relative
    to the sampling rate, which can produce unstable or NaN output.

    Parameters
    ----------
    signal : array-like
        The input signal to filter (filtered along its last axis).
    lowcut : float, optional
        The low cut-off frequency of the bandpass filter (default is 1 Hz).
    highcut : float, optional
        The high cut-off frequency of the bandpass filter (default is 20 Hz).
    fs : float, optional
        The sampling rate of the signal (default is 500 Hz).
    order : int, optional
        The order of the Butterworth filter (default is 5).

    Returns
    -------
    y : numpy.ndarray
        The filtered signal.

    Raises
    ------
    ValueError
        If the cut-offs do not satisfy 0 < lowcut < highcut < fs / 2.
    """
    nyquist = 0.5 * fs  # Nyquist frequency
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Cut-off frequencies must satisfy 0 < lowcut < highcut < Nyquist "
            f"(got lowcut={lowcut}, highcut={highcut}, Nyquist={nyquist})"
        )
    low = lowcut / nyquist  # Normalize lowcut frequency
    high = highcut / nyquist  # Normalize highcut frequency
    sos = sp_sig.butter(order, [low, high], btype='band', output='sos')  # Design bandpass filter
    y = sp_sig.sosfiltfilt(sos, signal)  # Apply filter forward and backward
    return y



In [None]:
# Pull the calibrated accelerometer, gyroscope, and magnetometer axes out of
# the cropped dataset
imu_crop = data_crop.data[imu_logger]

accX, accY, accZ = (imu_crop[col].values for col in ('corr_accX', 'corr_accY', 'corr_accZ'))
gyrX, gyrY, gyrZ = (imu_crop[col].values for col in ('corr_gyrX', 'corr_gyrY', 'corr_gyrZ'))
magX, magY, magZ = (imu_crop[col].values for col in ('corr_magX', 'corr_magY', 'corr_magZ'))

# One nx3 matrix per sensor, columns ordered [x, y, z]
acc_data = np.vstack([accX, accY, accZ]).T
gyr_data = np.vstack([gyrX, gyrY, gyrZ]).T
mag_data = np.vstack([magX, magY, magZ]).T

# dsf expects a dictionary of sensor name -> nx3 matrix
sensor_data = dict(acc=acc_data, gyr=gyr_data, mag=mag_data)

# Estimate the dominant stroke frequency for all three sensors
result = dsf(sensor_data, sampling_rate=imu_fs, channels=['acc', 'gyr', 'mag'])

# Report the dominant stroke frequency and peak quality per sensor
for sensor_label, channel in (('Accelerometer', 'acc'), ('Gyroscope', 'gyr'), ('Magnetometer', 'mag')):
    print(f"{sensor_label} - Dominant Stroke Frequency: {result[channel]['fpk']} Hz")
    print(f"{sensor_label} - Quality: {result[channel]['q']}")


## Define Low-Pass and High-Pass Cut-Off Frequencies

The low cut-off frequency should be about 70% of your dominant stroking frequency (`dsf`), although anywhere between 50% and 70% can work well. The terms "high-pass" and "low-pass" can be confusing: a high-pass filter lets everything above its cut-off frequency pass, and a low-pass filter lets everything below its cut-off frequency pass. So the high-pass filter sets the lower cut-off frequency (below your `dsf`), and the low-pass filter sets the higher cut-off frequency (above your `dsf`). We recommend:

$$
\text{High-Pass Filter} = \text{Low Cut-Off Frequency} = 0.70 \times \text{dsf}
$$

$$
\text{Low-Pass Filter} = \text{High Cut-Off Frequency} = 2 \times \text{dsf}
$$


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sp_sig

# Bandpass filter function
def bandpass_filter(signal, lowcut=1, highcut=20, fs=500, order=5):
    """
    Apply a zero-phase Butterworth bandpass filter to `signal`.

    The filter is built as second-order sections because the (b, a)
    transfer-function form of a bandpass filter becomes numerically unreliable
    (possibly NaN output) when the band is narrow or low relative to `fs`.

    Parameters
    ----------
    signal : array-like
        Input signal, filtered along its last axis.
    lowcut, highcut : float, optional
        Lower and upper cut-off frequencies in Hz (defaults 1 and 20).
    fs : float, optional
        Sampling rate in Hz (default 500).
    order : int, optional
        Butterworth filter order (default 5).

    Returns
    -------
    numpy.ndarray
        The filtered signal.

    Raises
    ------
    ValueError
        If the cut-offs do not satisfy 0 < lowcut < highcut < fs / 2.
    """
    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Cut-off frequencies must satisfy 0 < lowcut < highcut < Nyquist "
            f"(got lowcut={lowcut}, highcut={highcut}, Nyquist={nyquist})"
        )
    low = lowcut / nyquist
    high = highcut / nyquist
    sos = sp_sig.butter(order, [low, high], btype='band', output='sos')
    y = sp_sig.sosfiltfilt(sos, signal)
    return y

# Band-pass the x-axis acceleration of the full dataset around the stroke frequency
accX = data_pkl.data[imu_logger]['corr_accX'].values

# Count non-finite samples in the input before filtering
nan_count = np.isnan(accX).sum()
inf_count = np.isinf(accX).sum()
print(f"NaNs in signal: {nan_count}, Infs in signal: {inf_count}")

# Cut-offs are set relative to the accelerometer dominant stroke frequency
acc_dsf = result['acc']['fpk']
lowcut = 0.70 * acc_dsf
highcut = 2 * acc_dsf

# Both cut-offs must be positive and below the Nyquist frequency
nyquist = 0.5 * imu_fs
print(f"Low cutoff: {lowcut}, High cutoff: {highcut}, Nyquist: {nyquist}")
cutoff_problems = (lowcut >= nyquist, highcut >= nyquist, lowcut <= 0, highcut <= 0)
if any(cutoff_problems):
    raise ValueError("Invalid cutoff frequencies")

# Filter at the logger's sampling rate
filtered_data = bandpass_filter(signal=accX, lowcut=lowcut, highcut=highcut, fs=imu_fs, order=4)

# Report how many NaNs the filter produced
print(f"Filtered data contains NaNs: {np.isnan(filtered_data).sum()}")

# Elapsed time in seconds for each sample
time = np.arange(len(accX)) / imu_fs

# Overlay the raw and filtered signals
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(time, accX, label='Original Data', alpha=0.5)
ax.plot(time, filtered_data, label='Filtered Data', linewidth=2)
ax.set_title('Original and Bandpass Filtered Accelerometer Data')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Acceleration (m/s^2)')
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Compute ODBA from the calibrated accelerometer axes of the full dataset
accX = data_pkl.data[imu_logger]['corr_accX'].values
accY = data_pkl.data[imu_logger]['corr_accY'].values
accZ = data_pkl.data[imu_logger]['corr_accZ'].values

# Pass the logger's actual sampling rate: calculate_odba otherwise assumes
# 10 Hz, which would misplace the low-pass cut-off for any other rate.
odba = calculate_odba(accX, accY, accZ, fs=imu_fs)

# Store ODBA alongside the other IMU channels so it can be plotted
data_pkl.data[imu_logger]['odba'] = odba

imu_channels_to_plot = ['depth', 'corr_accX', 'corr_accY', 'corr_accZ', 'odba', 'pitch', 'roll', 'heading']
ephys_channels_to_plot = []
imu_logger_to_use = imu_logger
ephys_logger_to_use = ephys_logger

# Get the overlapping time range
imu_df = data_pkl.data[imu_logger_to_use]
ephys_df = data_pkl.data[ephys_logger_to_use]
start_time = max(imu_df['datetime'].min(), ephys_df['datetime'].min()).to_pydatetime()
end_time = min(imu_df['datetime'].max(), ephys_df['datetime'].max()).to_pydatetime()

# Define notes to plot
notes_to_plot = {
    'exhalation_breath': 'depth'
}

plot_tag_data_interactive(data_pkl, imu_channels_to_plot, imu_sampling_rate=10, ephys_channels=ephys_channels_to_plot, 
                          imu_logger=imu_logger_to_use, ephys_logger=ephys_logger_to_use, note_annotations= notes_to_plot,
                          time_range=(start_time, end_time), color_mapping_path=color_mapping_path)