In [124]:
import os
import math
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
import ipywidgets as widgets

from IPython.display import display, Markdown, clear_output


def numpy_to_latex_bmatrix(array):
    r"""
    Format a two-dimensional array as a LaTeX ``bmatrix`` string.

    Each element is converted with ``str``; elements in a row are separated
    by `` & `` and every row (including the last) ends with `` \\``.

    Parameters:
    array (array-like): Anything ``np.asarray`` turns into a 2D array,
        e.g. a NumPy array, a pandas DataFrame or a list of equal-length
        lists.

    Returns:
    str: e.g. ``\begin{bmatrix}1 & 2 \\3 & 4 \\\end{bmatrix}``.

    Raises:
    ValueError: if the input is not two-dimensional.
    """
    # asarray makes plain nested lists work and, for DataFrames, iterates
    # over the values rather than over the column labels.
    array = np.asarray(array)
    if array.ndim != 2:
        raise ValueError("Input must be a 2D numpy array")

    body = "".join(" & ".join(map(str, row)) + r" \\" for row in array)
    return r"\begin{bmatrix}" + body + r"\end{bmatrix}"


def plot_path(df):
    """
    Draw a point's trajectory through 3D space with Plotly and show it.

    The path is drawn as a line through every row of ``df`` in order, with a
    green "Start" marker on the first row and a red "End" marker on the last.

    Parameters:
    df (pd.DataFrame): Trajectory samples; the first three columns are read
    as the x, y and z coordinates.
    """
    xs, ys, zs = (df.iloc[:, col] for col in range(3))

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(x=xs, y=ys, z=zs, mode='lines', name='Path'))

    # Highlight the first and last samples of the trajectory.
    endpoints = ((0, 'green', 'Start'), (-1, 'red', 'End'))
    for row, colour, label in endpoints:
        fig.add_trace(go.Scatter3d(
            x=[df.iloc[row, 0]],
            y=[df.iloc[row, 1]],
            z=[df.iloc[row, 2]],
            mode='markers',
            marker=dict(size=10, color=colour),
            name=label,
        ))

    axis_titles = dict(
        xaxis_title='X Position',
        yaxis_title='Y Position',
        zaxis_title='Z Position',
    )
    fig.update_layout(title='3D Path Over Time', scene=axis_titles, showlegend=True)

    fig.show()


def plot_trc(fp, output, max_markers=36):
    """
    Show an animated 3D view of the marker trajectories stored in a TRC file.

    The file is read as tab-separated text after skipping its first three
    lines. Columns 0-1 are frame number and time; every following group of
    three columns holds one marker's coordinates. Row 0 of the parsed table
    is the per-axis sub-header, so data start at row 1. Coordinates are
    divided by 10 (mm -> cm), and each marker's second and third coordinate
    columns are drawn on the figure's z and y axes respectively.

    Each trajectory is drawn as small static points; the animation adds a
    larger point per marker for every 10th data row, controlled by
    Play/Pause buttons and a slider.

    Parameters:
    fp (str): Path to the .trc file. Its file name without directory or
        extension is used in the figure title.
    output (ipywidgets.Output): Output area that is cleared and then
        receives the figure.
    max_markers (int or None): Maximum number of markers to plot (default 36,
        the count that used to be hard-coded). Fewer are plotted when the
        file contains fewer; None plots every marker in the file.
    """
    prefix = 'Mobility_markerset:'

    testtrc = pd.read_csv(fp, sep='\t', skiprows=3)

    # Marker names sit in every third header cell from column 2 onwards.
    # Remove the marker-set prefix explicitly: str.lstrip() strips any leading
    # characters belonging to the given *set*, so the old
    # lstrip('Mobility_markerset:') also ate the start of marker names
    # (e.g. 'Mobility_markerset:MTH1' -> 'TH1').
    marker_names = [
        name[len(prefix):] if name.startswith(prefix) else name
        for name in testtrc.columns[2::3]
    ]
    newtrccols = [f"{name}_{axis}" for name in marker_names for axis in ['x', 'y', 'z']]
    testtrc.columns = list(testtrc.columns[:2]) + newtrccols
    testtrc.iloc[1:, 2:] = testtrc.iloc[1:, 2:].astype(float) / 10  # Convert mm to cm

    # Plot what the file actually contains instead of assuming 36 markers.
    n_markers = len(marker_names)
    if max_markers is not None:
        n_markers = min(n_markers, max_markers)

    # File column positions of each marker's 1st, 2nd and 3rd coordinate.
    marker_cols = [(2 + 3 * i, 3 + 3 * i, 4 + 3 * i) for i in range(n_markers)]

    fig = go.Figure()

    # Static trace per marker. The file's 2nd coordinate is plotted on the
    # figure's z axis and the 3rd on its y axis.
    for c1, c2, c3 in marker_cols:
        x = testtrc.iloc[1:, c1].to_numpy(dtype=float)
        z = testtrc.iloc[1:, c2].to_numpy(dtype=float)
        y = testtrc.iloc[1:, c3].to_numpy(dtype=float)
        fig.add_trace(go.Scatter3d(x=x, y=y, z=z, mode='markers', marker=dict(size=1)))

    # One animation frame per 10th data row. Data rows are 1 .. len(testtrc)-1;
    # the previous bound range(1, len(t), 10) with len(t) == len(testtrc) - 1
    # could never include the final row.
    frames = [
        go.Frame(
            data=[
                go.Scatter3d(
                    x=[testtrc.iloc[k, c1]],
                    y=[testtrc.iloc[k, c3]],
                    z=[testtrc.iloc[k, c2]],
                    mode='markers',
                    marker=dict(size=2),
                )
                for c1, c2, c3 in marker_cols
            ],
            name=str(k)
        )
        for k in range(1, len(testtrc), 10)
    ]

    # Fixed axis ranges so the view does not rescale during the animation.
    x_min, x_max = -500, 120
    y_min, y_max = -50, 200
    z_min, z_max = 0, 200

    # Aspect ratio proportional to each axis' span.
    x_range = x_max - x_min
    y_range = y_max - y_min
    z_range = z_max - z_min
    max_range = max(x_range, y_range, z_range)
    aspect_ratio = dict(x=x_range / max_range, y=y_range / max_range, z=z_range / max_range)

    camera = dict(
        eye=dict(x=.1, y=0.5, z=.1),  # initial zoom level
        up=dict(x=0, y=0, z=1),  # up direction
        center=dict(x=0, y=0, z=0)  # center of the view
    )

    # Animation controls and fixed axis ranges.
    fig.update_layout(
        scene=dict(
            xaxis=dict(title='X Axis', range=[x_min, x_max], autorange=False),
            yaxis=dict(title='Y Axis', range=[y_min, y_max], autorange=False),
            zaxis=dict(title='Z Axis', range=[z_min, z_max], autorange=False),
            aspectratio=aspect_ratio,
            camera=camera
        ),
        updatemenus=[dict(
            type='buttons',
            showactive=False,
            buttons=[
                dict(label='Play', method='animate', args=[None, dict(frame=dict(duration=80, redraw=True), fromcurrent=True)]),
                dict(label='Pause', method='animate', args=[[None], dict(frame=dict(duration=0, redraw=False), mode='immediate')])
            ]
        )],
        sliders=[{
            'steps': [{'args': [[f.name], {'frame': {'duration': 0, 'redraw': True}, 'mode': 'immediate'}],
                    'label': str(i),
                    'method': 'animate'} for i, f in enumerate(frames)],
            'transition': {'duration': 0},
            'x': 0.1,
            'len': 0.9
        }],
        # Use the bare file name; the old fp[13:-4] only worked for paths of
        # the exact form 'trcFiles/XXX/<name>.trc'.
        title=f'3D Animation of {os.path.splitext(os.path.basename(fp))[0]}'
    )

    fig.frames = frames

    with output:
        output.clear_output()
        fig.show()


class imuData:
    """
    IMU channels of one Delsys Trigno sensor, extracted from a trial table.

    ``df`` has one row per channel (the index holds channel names) and one
    column per frame. Starting at the row labelled
    ``'DelsysTrignoBase 1: Sensor <sensor_num>IM ACC Pitch'``, six consecutive
    rows are taken: three accelerometer rows followed by three gyroscope rows.
    ``df`` itself is not modified.

    Attributes:
    name: label supplied by the caller, shown by __str__.
    indices: list of df's row labels.
    start_row_index: position of the sensor's "IM ACC Pitch" row in df.
    all_data: the six sensor rows of df.
    a_local: accelerometer data, one row per frame, columns
        'a_pitch', 'a_roll', 'a_yaw'.
    net_acc: per-frame Euclidean norm of the three accelerometer channels.
    omega_local: gyroscope data, one row per frame, columns
        'omega_pitch', 'omega_roll', 'omega_yaw'.
    measurements: a_local joined with omega_local (six columns).
    frames: number of frames (columns of df).

    Raises:
    ValueError: (from list.index) if the sensor's pitch row is not in df.
    """

    def __init__(self, name, df, sensor_num):
        self.name = name
        self.indices = list(df.index)

        self.start_row_index = self.indices.index(
            'DelsysTrignoBase 1: Sensor ' + str(sensor_num) + 'IM ACC Pitch')

        self.all_data = df.iloc[
            self.start_row_index : self.start_row_index + 6]

        acc_data = self.all_data.iloc[0:3]
        self.a_local = acc_data.transpose()
        # Rename by assigning a new Index. The previous in-place write to
        # ``columns.values`` modified an array that is a view of df.index,
        # silently renaming rows of the caller's DataFrame (and it can fail
        # outright when pandas exposes that array as read-only).
        self.a_local.columns = ['a_pitch', 'a_roll', 'a_yaw']

        # Magnitude of the acceleration vector per frame: sqrt of the sum of
        # the squared axis components (np.sum on raw rows propagates NaN).
        acc_squared = np.square(self.a_local)
        net_acc_sq = acc_squared.apply(np.sum, axis=1, raw=True)
        self.net_acc = np.sqrt(net_acc_sq)

        # all_data holds exactly six rows, so the gyroscope rows are 3..5.
        gyr_data = self.all_data.iloc[3:6]
        self.omega_local = gyr_data.transpose()
        self.omega_local.columns = ['omega_pitch', 'omega_roll', 'omega_yaw']

        self.measurements = self.a_local.join(self.omega_local)

        self.frames = len(acc_data.columns)

    def __str__(self):
        return f"ImuData object.\nname:  '{self.name}'\nframes: {self.frames}"


class KF6:
    """
    Kalman filter over a 16-element state with 6-element measurements.

    State layout, as used by get_A, get_H and predict:
        x[0:3]   advanced by dt * x[3:6] each step
        x[3:6]   advanced by dt * x[6:9]; x[5] additionally loses 9.8 * dt
        x[6:9]   observed as C.T @ x[6:9] in z[0:3]
        x[9:13]  quaternion (scalar first) driving C; renormalised after
                 every prediction
        x[13:16] rates (wN, wE, wD) that propagate the quaternion and are
                 observed as C.T @ x[13:16] in z[3:6]
    where C = quat2matrix(x[9:13]). The names suggest position, velocity,
    acceleration and angular rate in a North-East-Down frame; this is not
    enforced here.

    Parameters:
    dt: filter time step.
    Q: (16, 16) process-noise covariance.
    R: (6, 6) measurement-noise covariance.
    x_0: (16, 1) initial state column vector.
    P_0: (16, 16) initial state covariance.

    All inputs are stored as float arrays, so integer inputs are accepted.

    Raises:
    ValueError: if Q, R, x_0 or P_0 has the wrong shape.
    """

    def __init__(self, dt, Q, R, x_0, P_0):

        self.dt = float(dt)

        Q = np.array(Q, dtype=float)
        if Q.shape != (16, 16):
            raise ValueError(
                '"Q" shape error: expected (16, 16), '
                f'got {Q.shape} instead'
            )
        self.Q = Q

        R = np.array(R, dtype=float)
        if R.shape != (6, 6):
            raise ValueError(
                '"R" shape error: expected (6, 6), '
                f'got {R.shape} instead'
            )
        self.R = R

        # Float state: with an int state, predict()'s in-place
        # "xp[5] -= 9.8 * dt" would fail to cast.
        x_0 = np.array(x_0, dtype=float)
        if x_0.shape != (16, 1):
            raise ValueError(
                '"x_0" shape error: '
                'expected 2D column vector (16, 1), '
                f'got {x_0.shape} instead'
            )
        self.x = x_0

        P_0 = np.array(P_0, dtype=float)
        if P_0.shape != (16, 16):
            raise ValueError(
                '"P_0" shape error: expected (16, 16), '
                f'got {P_0.shape} instead'
            )
        self.P = P_0

    def get_A(self, x, dt):
        """
        Return the (16, 16) state-transition matrix for state x and step dt.

        Identity, plus dt coupling x[i] to x[i + 3] for i in 0..5, plus the
        quaternion block I + (dt / 2) * M(w) on rows/cols 9:13, where M(w)
        is the antisymmetric 4x4 matrix built from the rates x[13:16].
        """
        wN, wE, wD = [float(x[i, 0]) for i in range(13, 16)]
        # Must be a float matrix: the former np.diag(16*[1]) was an int array,
        # so assigning dt (< 1) and the fractional quaternion terms below
        # truncated them to 0 and nothing propagated.
        A = np.eye(16)
        for i in range(6):
            A[i, i + 3] = dt
        A[9:13, 9:13] = np.array(
            [[1, -dt*wN/2, -dt*wE/2, -dt*wD/2],
             [dt*wN/2, 1, dt*wD/2, -dt*wE/2],
             [dt*wE/2, -dt*wD/2, 1, dt*wN/2],
             [dt*wD/2, dt*wE/2, -dt*wN/2, 1]],
            dtype=float,
        )
        return A

    def quat2matrix(self, q):
        """Return the 3x3 rotation matrix of scalar-first quaternion q (4x1)."""
        q0, q1, q2, q3 = [float(q[i, 0]) for i in range(4)]
        return np.array(
            [[1 - 2*(q2**2 + q3**2), 2*(q1*q2 - q0*q3), 2*(q1*q3 + q0*q2)],
             [2*(q1*q2 + q0*q3), 1 - 2*(q1**2 + q3**2), 2*(q2*q3 - q0*q1)],
             [2*(q1*q3 - q0*q2), 2*(q2*q3 + q0*q1), 1 - 2*(q1**2 + q2**2)]]
        )

    def get_H(self, x):
        """Return the (6, 16) measurement matrix at the attitude in x[9:13].

        z[0:3] = C.T @ x[6:9] and z[3:6] = C.T @ x[13:16].
        """
        C = self.quat2matrix(x[9:13])
        H = np.zeros((6, 16))
        H[0:3, 6:9] = C.T
        H[3:6, 13:16] = C.T
        return H

    def quat_norm(self, q):
        """Return q scaled to unit norm; raise ValueError if q is zero."""
        norm = np.linalg.norm(q)
        if norm == 0:
            raise ValueError('Cannot normalize a zero vector')
        return q/norm

    def predict(self):
        """
        Propagate the state and covariance one time step.

        Sets self.A, self.xp (predicted state), self.Pp (predicted covariance)
        and self.H. H is evaluated at the *predicted* attitude, so update()
        compares z with h(xp) and linearizes about xp, as an EKF should
        (it used to be evaluated at the previous posterior self.x).
        """
        self.A = self.get_A(self.x, self.dt)
        self.xp = self.A @ self.x
        self.xp[5] -= 9.8*self.dt
        self.xp[9:13] = self.quat_norm(self.xp[9:13])
        self.H = self.get_H(self.xp)
        self.Pp = self.A @ self.P @ self.A.T + self.Q

    def update(self, z):
        """
        Correct the prediction with measurement z (6 values) and return x.

        Must be called after predict(), which provides self.H, self.xp and
        self.Pp. Sets self.y (innovation), self.K (gain), self.x and self.P.
        """
        z = np.array(z).reshape(-1, 1)

        self.y = z - self.H @ self.xp
        self.K = self.Pp @ self.H.T @ np.linalg.inv(
            self.H @ self.Pp @ self.H.T + self.R)
        self.x = self.xp + self.K @ (self.y)
        self.P = (np.eye(16) - self.K @ self.H) @ self.Pp
        return self.x


# Participants are the sub-directories of trialData/. os.listdir() returns
# entries in arbitrary order, so sort them for a stable, predictable dropdown.
participants = sorted(
    d for d in os.listdir('trialData')
    if os.path.isdir(os.path.join('trialData', d))
)

# Dropdown for choosing the participant
participant_dropdown = widgets.Dropdown(
    options=participants,
    description='Participant:',
    disabled=False,
)

# Dropdown for choosing one of the selected participant's trial files;
# starts empty and is filled by update_files()
file_dropdown = widgets.Dropdown(
    options=[],
    description='File:',
    disabled=False,
)

# Update the file dropdown based on the selected participant
def update_files(*args):
    """
    Fill ``file_dropdown`` with the CSV trials of the selected participant.

    Called directly and as an observer of ``participant_dropdown``'s value
    (``*args`` absorbs the change notification). The options are the
    ``.csv`` file names in ``trialData/<participant>/`` with the extension
    removed, in sorted order; the first one is selected. Non-CSV files are
    ignored, and an empty folder (or no participant) leaves no options.
    """
    participant = participant_dropdown.value
    if participant is None:
        # No participant folders were found, so there is nothing to list.
        file_dropdown.options = []
        return

    folder = os.path.join('trialData', participant)
    # splitext removes exactly the extension; the old f.rstrip('.csv')
    # stripped any trailing '.', 'c', 's' or 'v' characters, turning e.g.
    # 'A01_Walks.csv' into 'A01_Walk' (whose .csv path then does not exist).
    files = sorted(
        os.path.splitext(f)[0]
        for f in os.listdir(folder)
        if f.endswith('.csv')
    )
    file_dropdown.options = files
    if files:
        file_dropdown.value = files[0]

# Re-run update_files whenever participant_dropdown's value changes.
participant_dropdown.observe(update_files, 'value')

# Display the participant selector
display(participant_dropdown)

# Populate the file dropdown once for the participant selected at start-up
update_files()

# Display the file dropdown after it has been populated
display(file_dropdown)

# Output area that update_output() (and optionally plot_trc) draws into
output = widgets.Output()
display(output)

def update_output(*args):
    """
    Load the currently selected trial and redraw its IMU plots in ``output``.

    Used as an observer of ``file_dropdown``'s value; ``*args`` absorbs the
    change notification. Reads ``trialData/<participant>/<file>.csv`` (first
    column is the channel-name index), extracts sensor 4 with ``imuData`` and
    displays:

    * the measurement table, with accelerations multiplied by 9.8 and
      angular rates multiplied by pi/180 (degrees to radians);
    * the min and max net acceleration over the first 2500 frames;
    * a line plot of the acceleration channels, net acceleration and +/-g;
    * the percentage of frames whose net acceleration lies within +/-x of g,
      for tolerances x from 0 to 1;
    * a 3D scatter of the first 2500 acceleration samples, with marker size
      growing as the net acceleration approaches g.

    Does nothing when no participant or file is selected.
    """
    if participant_dropdown.value is None or file_dropdown.value is None:
        # e.g. an empty participant folder: there is no file to load.
        return

    # Scale applied to the acceleration data (presumably units of g -> m/s^2).
    g = 9.8

    csv_path = f'trialData/{participant_dropdown.value}/{file_dropdown.value}.csv'
    # Only used by the optional plot_trc call below.
    trc_path = f'trcFiles/{participant_dropdown.value}/{file_dropdown.value}.trc'
    testdf = pd.read_csv(csv_path, index_col=0)
    testimu = imuData('test', testdf, 4)
    testdata = testimu.measurements
    testdata.iloc[:, 0:3] = testdata.iloc[:, 0:3] * g
    testdata.iloc[:, 3:6] = testdata.iloc[:, 3:6] * np.pi / 180

    # testacc is testimu.a_local itself, extended with the net magnitude and
    # constant +/-1 guide lines, then scaled by g as a whole.
    testacc = testimu.a_local
    testacc['a_net'] = testimu.net_acc
    testacc['g'], testacc['-g'] = 1, -1
    testacc *= g

    fig, ax = plt.subplots(figsize=(12, 4))
    testacc.plot(ax=ax, title=f'acceleration of {file_dropdown.value} lower left shank')

    # Percentage of frames whose net acceleration lies within +/-x of g.
    # The band is symmetric about g (the lower bound previously used 9.86,
    # making the band empty for small x) and the fraction is scaled by 100
    # to match the 'Percentage' axis label (it was scaled by 10).
    a_net = testacc['a_net']
    x_values = np.linspace(0.0, 1.0, 1000)
    y_values = [((a_net >= g - x) & (a_net <= g + x)).mean() * 100 for x in x_values]

    fig2, ax2 = plt.subplots(figsize=(12, 4))
    ax2.plot(x_values, y_values, label='Cumulative Distribution')
    ax2.set_xlabel('Value')
    ax2.set_ylabel('Percentage')
    ax2.set_title('Cumulative Distribution of a_net')

    # Marker size grows as the net acceleration approaches g.
    proximity_to_g = 0.01 / (1. + 100. * abs(a_net[:2500] - g))
    fig3 = px.scatter_3d(
        x=testacc.iloc[:2500, 0],  # first column on the x-axis
        y=testacc.iloc[:2500, 2],  # third column on the y-axis
        z=testacc.iloc[:2500, 1],  # second column on the z-axis
        size=proximity_to_g,
        size_max=10,
        title='3D Scatter Plot of Testacc'
    )
    fig3.update_layout(
        scene=dict(
            xaxis=dict(range=[-15, 15]),
            yaxis=dict(range=[-15, 15]),
            zaxis=dict(range=[-15, 15]),
            aspectmode='cube'
        )
    )
    # Black reference lines along the three axes through the origin
    # (None breaks the line between segments).
    fig3.add_trace(go.Scatter3d(
        x=[-15, 15, None, 0, 0, None, 0, 0],
        y=[0, 0, None, -15, 15, None, 0, 0],
        z=[0, 0, None, 0, 0, None, -15, 15],
        mode='lines',
        line=dict(color='black', width=2),
        showlegend=False
    ))

    with output:
        clear_output(wait=True)
        # plot_trc(trc_path, output)
        display(testdata)
        display(a_net[:2500].min(), a_net[:2500].max())
        display(fig)
        display(fig2)
        display(fig3)
    # The matplotlib figures have been rendered into `output`; close them so
    # they do not accumulate across selections.
    plt.close(fig)
    plt.close(fig2)

# Redraw the IMU plots whenever file_dropdown's value changes.
file_dropdown.observe(handler=update_output, names='value')
# update_output()  # uncomment to render the initially selected file right away

Dropdown(description='Participant:', options=('A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'C03', 'C04', 'C05', '…

Dropdown(description='File:', options=('A01_Fast_01', 'A01_Fast_02', 'A01_Fast_03', 'A01_Fast_04', 'A01_Fast_0…

Output()