# Exercise: IMU-Based Stride Length Estimation

Can we estimate the stride length of a person using data from an inertial measurement unit (IMU) attached to their foot? In this exercise, we will use data from an IMU attached to the foot of a person walking a 5-meter distance on level ground to estimate their stride length.

In [2]:
# Import the usual libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from kielmat.datasets import keepcontrol

## Load the data

First, we load the data from the IMU worn on the left foot of subject "pp002", recorded during the trial "walking at half your normal speed" (`walkSlow`) of the Keep Control dataset.

In [3]:
DATASET_PATH = Path("../datasets/keepcontrol")
sub_id = "pp002"
task_name = "walkSlow"
track_sys = "imu"
tracked_points = ["left_foot"]  # a list, so the loops below iterate over tracked points rather than characters

In [4]:
recording = keepcontrol.load_recording(
    dataset_path=DATASET_PATH, id=sub_id, task=task_name, tracking_systems=[track_sys], tracked_points=tracked_points
)



## Extract relevant metadata

When working with IMU data, it is important to know the sampling rate of the sensor. This information is usually available in the metadata of the dataset.
Furthermore, we need to know in which units the acceleration and angular velocity data are provided.
This metadata is provided in the `KielMATRecording` object in the `channels` attribute.

Explore the `channels` attribute, and 
- extract the sampling rate, and
- units of the acceleration and angular velocity data

Put the units in a dictionary that maps each channel type (e.g., `ACCEL` and `GYRO`) to the corresponding units.
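
For reference, the same two pieces of metadata can also be pulled out of a channels table with a `groupby` instead of a dict comprehension. The sketch below runs on a small, made-up DataFrame; it assumes that the real `recording.channels["imu"]` has `type`, `units` and `sampling_frequency` columns, so check the column names of your own recording before reusing it.

```python
import pandas as pd

# Hypothetical stand-in for recording.channels["imu"]: the column names
# ("name", "type", "units", "sampling_frequency") are assumed, the rows are made up.
channels = pd.DataFrame(
    {
        "name": ["left_foot_ACCEL_x", "left_foot_ACCEL_y", "left_foot_GYRO_x", "left_foot_GYRO_y"],
        "type": ["ACCEL", "ACCEL", "GYRO", "GYRO"],
        "units": ["g", "g", "deg/s", "deg/s"],
        "sampling_frequency": [200.0, 200.0, 200.0, 200.0],
    }
)

# One unit per channel type: take the first entry within each group.
mapping_units = channels.groupby("type")["units"].first().to_dict()

# All channels should share a single sampling rate; fail loudly otherwise.
rates = channels["sampling_frequency"].unique()
assert len(rates) == 1, f"mixed sampling rates: {rates}"
sampling_freq_Hz = float(rates[0])

print(mapping_units)     # {'ACCEL': 'g', 'GYRO': 'deg/s'}
print(sampling_freq_Hz)  # 200.0
```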

In [6]:
# Replace the Ellipsis with the correct code
sampling_freq_Hz = ... # YOUR CODE HERE
mapping_units = {
    ...: recording.channels[track_sys][recording.channels[track_sys]["type"] == ...]["units"].iloc[0]
    for ... in recording.channels[track_sys]["type"].unique()
}
print(f"Sampling frequency: {sampling_freq_Hz} Hz")
print(f"Sensor units: {mapping_units}")

Sampling frequency: 200.0 Hz
Sensor units: {'ACCEL': 'g', 'GYRO': 'deg/s', 'MAGN': 'Gauss'}


## Visualize the data

Now, to get a first impression of what the data looks like, we will plot the acceleration and angular velocity data over time.
- Plot the acceleration data in the three axes (x, y, z) in a single plot.
- Plot the angular velocity data in the three axes (x, y, z) in a single plot.

In [9]:
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.1, horizontal_spacing=0.01)
for col_idx, tracked_point in enumerate(tracked_points):
    for row_idx, channel_type in enumerate(["ACCEL", "GYRO"]):
        for axis in ["x", "y", "z"]:
            fig.add_trace(
                go.Scatter(
                    x=np.arange(len(recording.data[track_sys])) / sampling_freq_Hz,
                    y=recording.data[track_sys][f"{tracked_point}_{channel_type}_{axis}"],
                    mode="lines",
                    name=f"{tracked_point}_{channel_type}_{axis}",
                ),
                row=row_idx + 1,
                col=col_idx + 1,
            )
        fig.update_yaxes(title_text=f"{mapping_units[channel_type]}", row=row_idx + 1, col=1)
fig.update_xaxes(title_text="time (s)", row=2, col=1)
fig.update_layout(margin=dict(l=30, r=30, t=20, b=20))
fig

Can you guess the sensor orientation based on these signals?
- What sensor axis corresponds to the vertical axis?
- What sensor axis corresponds to the mediolateral axis?

## Put the data in the expected format

To estimate the stride length, we need to calculate the displacement of the foot during each gait cycle. In principle, this can be done by integrating the acceleration data twice. However, before we can do that, we need to make sure we integrate the acceleration along the right axes, namely by aligning the sensor axes with the global axes and removing the contribution of gravity.
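
Why the alignment matters can be seen without any real data. In the toy example below, a sensor that is tilted by a hypothetical 30° pitch but completely motionless would appear to travel about 2.45 m in one second if its raw x axis were integrated twice; rotating the measurement into the global frame and subtracting gravity removes that artefact. The rotation convention and the gravity value of 9.81 m/s² are assumptions of this sketch.

```python
import numpy as np

G = 9.81  # assumed gravity magnitude in m/s^2


def rot_y(deg):
    """Rotation about the y axis; maps sensor-frame vectors to the global frame."""
    a = np.deg2rad(deg)
    return np.array(
        [
            [np.cos(a), 0.0, np.sin(a)],
            [0.0, 1.0, 0.0],
            [-np.sin(a), 0.0, np.cos(a)],
        ]
    )


R = rot_y(30.0)  # hypothetical: foot pitched by 30 degrees and held perfectly still
gravity_global = np.array([0.0, 0.0, G])  # a resting accelerometer reads +g along "up"
acc_sensor = R.T @ gravity_global  # what the tilted, motionless sensor reports

# Wrong: double-integrate the raw sensor x axis over 1 s (constant acceleration).
T = 1.0
fake_displacement = 0.5 * acc_sensor[0] * T**2  # about -2.45 m, although nothing moved

# Better: rotate into the global frame first, then subtract gravity.
acc_free = R @ acc_sensor - gravity_global  # ~[0, 0, 0], so no spurious motion
```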

Luckily, several research groups have faced the same issue, and have prepared some pretty useful tools. Here, we will explore the use of `gaitmap` to estimate the stride length.

Like any other tool, `gaitmap` expects the data to be in a specific format.
See: https://gaitmap.readthedocs.io/en/latest/source/user_guide/prepare_data.html

Prepare the data in the expected format:
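- Put the acceleration and angular velocity data in a single DataFrame with the columns `acc_x`, `acc_y`, `acc_z`, `gyr_x`, `gyr_y`, `gyr_z`, and put it in a `dict` whose keys are the tracked points.
- If necessary, convert the data to the expected units (`gaitmap` expects accelerations in m/s² and angular velocities in deg/s).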
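
A minimal sketch of this reformatting on a made-up two-sample DataFrame is shown below. The KielMAT-style input column names (`left_foot_ACCEL_x` and so on) and the 9.81 m/s² conversion factor are assumptions; the target column names match the output of the exercise cell. Calling `.copy()` right after selecting the columns is one way to avoid pandas' `SettingWithCopyWarning` when the frame is modified afterwards.

```python
import numpy as np
import pandas as pd

G = 9.81  # assumed conversion factor from g to m/s^2

# Hypothetical stand-in for recording.data["imu"]: two made-up samples with
# KielMAT-style column names, including a magnetometer channel we do not need.
raw = pd.DataFrame(
    {
        "left_foot_ACCEL_x": [-0.5, -0.6],
        "left_foot_ACCEL_y": [-0.1, 0.0],
        "left_foot_ACCEL_z": [0.9, 1.0],
        "left_foot_GYRO_x": [0.0, 10.0],
        "left_foot_GYRO_y": [1.0, 2.0],
        "left_foot_GYRO_z": [0.5, -0.5],
        "left_foot_MAGN_x": [0.2, 0.2],
    }
)
units = {"ACCEL": "g", "GYRO": "deg/s"}  # as reported by the metadata cell

tracked_point = "left_foot"
cols = [f"{tracked_point}_{ch}_{ax}" for ch in ("ACCEL", "GYRO") for ax in ("x", "y", "z")]
# .copy() returns an independent DataFrame, so the edits below do not
# trigger pandas' SettingWithCopyWarning.
df = raw[cols].copy()
df.columns = ["acc_x", "acc_y", "acc_z", "gyr_x", "gyr_y", "gyr_z"]

acc_cols, gyr_cols = ["acc_x", "acc_y", "acc_z"], ["gyr_x", "gyr_y", "gyr_z"]
if units["ACCEL"] == "g":
    df[acc_cols] = df[acc_cols] * G  # g -> m/s^2
if units["GYRO"] == "rad/s":
    df[gyr_cols] = np.rad2deg(df[gyr_cols])  # rad/s -> deg/s

dataset = {tracked_point: df}
```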

In [17]:
fsf_dataset = dict()
for tracked_point in tracked_points:
    # Extract the data for the tracked point
    fsf_dataset[tracked_point] = ... # YOUR CODE HERE

    # Update the column names
    fsf_dataset[tracked_point].columns = ... # YOUR CODE HERE

    # Convert the units
    if mapping_units["ACCEL"] == "g":
        ... # YOUR CODE HERE
    if mapping_units["GYRO"] == "rad/s":
        ... # YOUR CODE HERE
fsf_dataset

{'left_foot':          acc_x     acc_y     acc_z      gyr_x      gyr_y      gyr_z
 0    -4.914837 -1.389314  8.406459  -0.088907   0.082052   0.259661
 1    -4.943567 -1.336120  8.392271  -0.702367   0.175827   0.086554
 2    -4.923753 -1.355520  8.334337  -0.346738   0.082052   0.173108
 3    -4.967344 -1.384308  8.415918  -0.266722   0.527480  -0.086554
 4    -4.862330 -1.331739  8.387542   0.266722  -0.257879  -0.259661
 ...        ...       ...       ...        ...        ...        ...
 2579 -9.465098 -6.179320  6.150549 -49.139042 -42.573490  97.484344
 2580 -8.708207 -3.247366  7.731342 -22.466864 -42.667263  68.278603
 2581 -7.194426 -0.330432  9.120595 -10.926702 -39.080402  45.465488
 2582 -6.518771  1.896852  9.277846 -13.638373 -30.863436  31.913633
 2583 -5.124864  3.171016  9.321593 -16.172230 -26.584988  27.895063
 
 [2584 rows x 6 columns]}

## Put the data in the Foot Body Frame

When using IMU data there are different coordinate frames that the data can be expressed in. Most algorithms expect the data to be in a specific coordinate frame. In the case of `gaitmap`, the data should be in the Foot Body Frame (FBF) (see: [https://gaitmap.readthedocs.io/en/latest/source/user_guide/coordinate_systems.html](https://gaitmap.readthedocs.io/en/latest/source/user_guide/coordinate_systems.html)).

- Put the dataset in the foot body frame. You can make use of the functions provided by `gaitmap` to do this.

In [18]:
from gaitmap.utils.coordinate_conversion import convert_to_fbf

fbf_dataset = ... # YOUR CODE HERE
fbf_dataset

If you have done everything correctly, you should now have the data in the expected format and coordinate frame. That means that the column headers refer to the anatomical axes: `pa` (posterior-anterior), `ml` (medial-lateral) and `si` (superior-inferior).

In [21]:
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.1, horizontal_spacing=0.01)
for col_idx, tracked_point in enumerate(tracked_points):
    for row_idx, channel_type in enumerate(["acc", "gyr"]):
        for axis in ["pa", "ml", "si"]:
            fig.add_trace(
                go.Scatter(
                    x=np.arange(len(fbf_dataset[tracked_point])) / sampling_freq_Hz,
                    y=fbf_dataset[tracked_point][f"{channel_type}_{axis}"],
                    mode="lines",
                    name=f"{tracked_point}_{channel_type}_{axis}",
                ),
                row=row_idx + 1,
                col=col_idx + 1,
            )
        fig.update_yaxes(title_text={"acc": "m/s²", "gyr": "deg/s"}[channel_type], row=row_idx + 1, col=1)
fig.update_xaxes(title_text="time (s)", row=2, col=1)
fig.update_layout(margin=dict(l=30, r=30, t=20, b=20))
fig

## Segment the data into individual strides

The next step is to segment the data into consecutive strides. The strides are clearly visible in the data, and the mediolateral angular velocity signal is typically used to detect the start and end of each stride.

Here we will use the segmentation method based on dynamic time warping (DTW) provided by `gaitmap` ([Barth et al., 2013](https://doi.org/10.1109/EMBC.2013.6611104)).
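
To build some intuition for template matching with DTW, here is a deliberately simplified subsequence-DTW sketch in NumPy: it finds where a template fits best inside a longer signal, allowing the match to start at any sample. It is not the `BarthDtw` implementation, which adds template scaling, a cost threshold, peak finding for multiple matches and snapping to `gyr_ml` minima (compare the parameters printed below); the template and signal here are synthetic.

```python
import numpy as np


def subsequence_dtw_cost(template, signal):
    """Accumulated cost of matching `template` to a stretch of `signal` ending at each sample.

    Subsequence DTW: the first template sample may align with any signal sample,
    so the first row holds only the local cost (no accumulation along the signal).
    """
    n, m = len(template), len(signal)
    dist = np.abs(template[:, None] - signal[None, :])  # local cost |t_i - s_j|
    acc = np.full((n, m), np.inf)
    acc[0, :] = dist[0, :]  # free start anywhere in the signal
    for i in range(1, n):
        acc[i, 0] = acc[i - 1, 0] + dist[i, 0]
        for j in range(1, m):
            acc[i, j] = dist[i, j] + min(acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1])
    return acc[-1, :]  # cost of the best match ending at each signal sample


# Toy example: a one-period "stride" template hidden inside a flat signal.
template = np.sin(np.linspace(0.0, 2.0 * np.pi, 30))
signal = np.concatenate([np.zeros(20), template, np.zeros(20)])
cost = subsequence_dtw_cost(template, signal)
print(int(np.argmin(cost)))  # 49: the embedded template ends at sample 20 + 29
```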

In [23]:
from gaitmap.stride_segmentation import BarthDtw, BarthOriginalTemplate
dtw = BarthDtw(template=BarthOriginalTemplate())
dtw = ... # YOUR CODE HERE
dtw

BarthDtw(conflict_resolution=True, find_matches_method='find_peaks', max_cost=4.0, max_match_length_s=3.0, max_signal_stretch_ms=None, max_template_stretch_ms=None, memory=None, min_match_length_s=0.6, resample_template=True, snap_to_min_axis='gyr_ml', snap_to_min_win_ms=300, template=BarthOriginalTemplate(scaling=FixedScaler(offset=0, scale=500.0), use_cols=None))

In [24]:
dtw.stride_list_

{'left_foot':       start   end
 s_id             
 0       803  1080
 1      1080  1384
 2      1384  1657
 3      1657  1958
 4      1958  2249
 5      2249  2370}

In [27]:
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.1, horizontal_spacing=0.01)
for col_idx, tracked_point in enumerate(tracked_points):
    for row_idx, channel_type in enumerate(["acc", "gyr"]):
        for axis in ["pa", "ml", "si"]:
            fig.add_trace(
                go.Scatter(
                    x=np.arange(len(fbf_dataset[tracked_point])) / sampling_freq_Hz,
                    y=fbf_dataset[tracked_point][f"{channel_type}_{axis}"],
                    mode="lines",
                    name=f"{tracked_point}_{channel_type}_{axis}",
                ),
                row=row_idx + 1,
                col=col_idx + 1,
            )
        fig.update_yaxes(title_text={"acc": "m/s²", "gyr": "deg/s"}[channel_type], row=row_idx + 1, col=1)
# Mark the start of each stride found by the DTW segmentation
for start_idx in dtw.stride_list_[tracked_point]["start"]:
    fig.add_vline(x=start_idx / sampling_freq_Hz, line_width=1, line_dash="dash", line_color="green")
fig.update_xaxes(title_text="time (s)", row=2, col=1)
fig.update_layout(margin=dict(l=30, r=30, t=20, b=20))
fig

## Event Detection

Now that we have the individual strides, we can detect the events of the gait cycle. The events of interest are typically the initial contact (IC, also called heel strike, HS) and the final contact (FC, also called terminal contact, TC, or toe-off, TO).

These events can be detected from the acceleration and angular velocity signals ([Rampp et al., 2015](https://doi.org/10.1109/TBME.2014.2368211)).
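
The core idea of gyroscope-based event detection can be illustrated with a much cruder heuristic than the one by Rampp et al.: within a single stride, locate the mid-swing peak of `gyr_ml` and take the minima before and after it as TC and IC. The sketch assumes that mid-swing shows up as a positive peak (check the sign in your own plot) and uses a synthetic signal; it is not the `RamppEventDetection` algorithm, which also draws on the acceleration signal as noted above.

```python
import numpy as np


def simple_gait_events(gyr_ml):
    """Rough (TC, mid-swing, IC) sample indices within one stride of gyr_ml.

    Assumes mid-swing is the largest positive peak, TC the minimum before it
    and IC the minimum after it.
    """
    ms = int(np.argmax(gyr_ml))
    tc = int(np.argmin(gyr_ml[:ms]))
    ic = ms + int(np.argmin(gyr_ml[ms:]))
    return tc, ms, ic


t = np.arange(60)


def bump(center, width):
    """Gaussian-shaped pulse centred on sample `center`."""
    return np.exp(-((t - center) ** 2) / (2 * width**2))


# Synthetic stride: dip at push-off, large swing peak, dip at heel strike.
gyr_ml = -1.0 * bump(10, 2) + 3.0 * bump(25, 2) - 1.5 * bump(40, 2)
print(simple_gait_events(gyr_ml))  # (10, 25, 40)
```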

In [28]:
from gaitmap.event_detection import RamppEventDetection

ed = RamppEventDetection()
ed = ... # YOUR CODE HERE

In [30]:
ed.min_vel_event_list_["left_foot"]

| s_id | start | end | ic | tc | min_vel | pre_ic |
|---:|---:|---:|---:|---:|---:|---:|
| 0 | 995.0 | 1269.0 | 1196.0 | 1102.0 | 995.0 | 916.0 |
| 1 | 1269.0 | 1535.0 | 1489.0 | 1401.0 | 1269.0 | 1196.0 |
| 2 | 1535.0 | 1893.0 | 1778.0 | 1680.0 | 1535.0 | 1489.0 |
| 3 | 1893.0 | 2112.0 | 2072.0 | 1979.0 | 1893.0 | 1778.0 |


In [39]:
fig = make_subplots(rows=2, cols=1, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.1, horizontal_spacing=0.01)
for col_idx, tracked_point in enumerate(tracked_points):
    for row_idx, channel_type in enumerate(["acc", "gyr"]):
        for axis in ["pa", "ml", "si"]:
            fig.add_trace(
                go.Scatter(
                    x=np.arange(len(fbf_dataset[tracked_point])) / sampling_freq_Hz,
                    y=fbf_dataset[tracked_point][f"{channel_type}_{axis}"],
                    mode="lines",
                    name=f"{tracked_point}_{channel_type}_{axis}",
                ),
                row=row_idx + 1,
                col=col_idx + 1,
            )
        fig.update_yaxes(title_text={"acc": "m/s²", "gyr": "deg/s"}[channel_type], row=row_idx + 1, col=1)
# Mark the start of each (minimum-velocity) stride
for start_idx in ed.min_vel_event_list_[tracked_point]["start"]:
    fig.add_vline(x=start_idx / sampling_freq_Hz, line_width=1, line_dash="dash", line_color="green")
fig.add_trace(
    go.Scatter(
        x=ed.min_vel_event_list_["left_foot"]["ic"] / sampling_freq_Hz,
        y=fbf_dataset["left_foot"]["gyr_ml"][ed.min_vel_event_list_["left_foot"]["ic"].astype(int)],
        mode="markers",
        marker=dict(size=12),
        name="ic",
    ),
    row=2, col=1
)
fig.add_trace(
    go.Scatter(
        x=ed.min_vel_event_list_["left_foot"]["tc"] / sampling_freq_Hz,
        y=fbf_dataset["left_foot"]["gyr_ml"][ed.min_vel_event_list_["left_foot"]["tc"].astype(int)],
        mode="markers",
        marker=dict(size=12),
        name="tc",
    ),
    row=2, col=1
)
fig.update_xaxes(title_text="time (s)", row=2, col=1)
fig.update_layout(margin=dict(l=30, r=30, t=20, b=20))
fig

## Extract Relevant Spatio-Temporal Gait Parameters

Finally, to extract clinically relevant parameters, we can use the segmented strides and the corresponding events to calculate the stride length on a stride-by-stride basis. For that, we need to choose a method to estimate the orientation of the sensor (and thus of the foot) at each time step. The orientation is used to rotate the acceleration into the global frame and remove gravity, after which the acceleration is integrated twice to obtain the trajectory of the foot; the stride length is the horizontal displacement of the foot over one stride.

- Calculate the temporal parameters (stride, swing and stance time) for each gait cycle.
- Calculate the stride length for each gait cycle.
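
The integration step described above can be illustrated without `gaitmap`. The sketch assumes the acceleration has already been rotated into the global frame with gravity removed, and applies a simple linear dedrifting under the zero-velocity assumption (the foot is briefly at rest at each stride boundary, which fits the event table above, where every stride starts at its `min_vel` sample). This is a simpler technique than `ForwardBackwardIntegration`, not a reimplementation of it; the 1 m synthetic stride and the 0.2 m/s² bias are made up to show how even a small bias inflates the naive estimate.

```python
import numpy as np


def cumtrapz(y, dt):
    """Cumulative trapezoidal integral of y, starting at 0."""
    return np.concatenate(([0.0], np.cumsum((y[1:] + y[:-1]) * dt / 2.0)))


def stride_displacement(acc, fs, dedrift=True):
    """Final displacement over one stride from gravity-free, global-frame acceleration."""
    dt = 1.0 / fs
    vel = cumtrapz(acc, dt)
    if dedrift:
        # Zero-velocity assumption: the foot is still at both ends of the stride,
        # so any velocity left at the end is treated as linear drift and removed.
        vel = vel - np.linspace(vel[0], vel[-1], len(vel))
    pos = cumtrapz(vel, dt)
    return pos[-1]


# Synthetic 1 s stride covering exactly 1 m, plus a constant 0.2 m/s^2 bias.
fs, T, L = 200.0, 1.0, 1.0
t = np.linspace(0.0, T, int(T * fs) + 1)
acc_true = (L / T) * (2 * np.pi / T) * np.sin(2 * np.pi * t / T)
acc_meas = acc_true + 0.2

print(round(stride_displacement(acc_meas, fs, dedrift=False), 2))  # 1.1
print(round(stride_displacement(acc_meas, fs, dedrift=True), 2))   # 1.0
```

Without dedrifting, the bias adds 0.5 · 0.2 · 1² = 0.1 m to a 1 m stride; over a full walk such errors would accumulate from stride to stride, which is why integration is restarted at each stride boundary.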

In [58]:
from gaitmap.trajectory_reconstruction import (
    ForwardBackwardIntegration,
    MadgwickAHRS,
    StrideLevelTrajectory
)
from gaitmap.parameters import TemporalParameterCalculation, SpatialParameterCalculation

In [41]:
ori_method = MadgwickAHRS()
pos_method = ForwardBackwardIntegration()
trajectory = StrideLevelTrajectory(ori_method=ori_method, pos_method=pos_method)

In [42]:
trajectory = ... # YOUR CODE HERE

In [59]:
temp_params = TemporalParameterCalculation()
temp_params = ... # YOUR CODE HERE
temp_params.parameters_[tracked_point]

| s_id | stride_time | swing_time | stance_time |
|---:|---:|---:|---:|
| 0 | 1.4 | 0.47 | 0.93 |
| 1 | 1.465 | 0.44 | 1.025 |
| 2 | 1.445 | 0.49 | 0.955 |
| 3 | 1.47 | 0.465 | 1.005 |
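
The temporal parameters can be cross-checked by hand against the event table. With the sample indices copied from the event table and the 200 Hz sampling rate, stride time equals (`ic` - `pre_ic`) / fs, swing time equals (`ic` - `tc`) / fs, and stance time is their difference. These relations are inferred from the numbers shown here, not taken from the `gaitmap` documentation.

```python
import pandas as pd

fs = 200.0  # sampling frequency in Hz, from the metadata cell
events = pd.DataFrame(
    {
        "ic": [1196, 1489, 1778, 2072],
        "tc": [1102, 1401, 1680, 1979],
        "pre_ic": [916, 1196, 1489, 1778],
    },
    index=pd.Index([0, 1, 2, 3], name="s_id"),
)

temporal = pd.DataFrame(
    {
        "stride_time": (events["ic"] - events["pre_ic"]) / fs,  # IC to IC of the same foot
        "swing_time": (events["ic"] - events["tc"]) / fs,  # toe-off to the next heel strike
    }
)
temporal["stance_time"] = temporal["stride_time"] - temporal["swing_time"]
print(temporal)
```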


In [60]:
spatial_params = SpatialParameterCalculation()
spatial_params = ... # YOUR CODE HERE
spatial_params.parameters_[tracked_point]

| s_id | arc_length | gait_velocity | ic_angle | max_lateral_excursion | max_sensor_lift | stride_length | tc_angle | turning_angle |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 0 | 1.117491 | 0.7512 | -8.857714 | 0.032246 | 0.071285 | 1.051681 | -84.363631 | 0.471716 |
| 1 | 1.224676 | 0.782233 | -10.200285 | 0.057088 | 0.059521 | 1.145971 | -84.125525 | -2.65865 |
| 2 | 1.153647 | 0.730749 | -10.635335 | 0.058045 | 0.084934 | 1.055932 | -84.869341 | -0.558656 |
| 3 | 1.094145 | 0.711655 | -13.4265 | 0.022186 | 0.055925 | 1.046132 | -80.950529 | 3.362747 |
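
Similarly, the `gait_velocity` column is consistent with stride length divided by stride time, and averaging the four strides gives a first answer to the question posed at the start: about 1.075 m per stride for this slow-walking trial. The values below are copied from the two parameter tables; whether `gaitmap` computes the velocity in exactly this way is inferred from the numbers.

```python
import pandas as pd

# Values copied from the spatial and temporal parameter tables (s_id 0-3).
stride_length = pd.Series([1.051681, 1.145971, 1.055932, 1.046132], name="stride_length")
stride_time = pd.Series([1.4, 1.465, 1.445, 1.47], name="stride_time")

gait_velocity = stride_length / stride_time  # m/s, one value per stride
print(gait_velocity.round(4).tolist())  # [0.7512, 0.7822, 0.7307, 0.7117]
print(round(stride_length.mean(), 3))   # 1.075
```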
