In [None]:
import sys
import os

sys.path.append("../../")

import zarr
import numpy as np

from plotly.offline import init_notebook_mode, iplot
from plotly.subplots import make_subplots
import plotly.graph_objs as go
import plotly.io as pio
import plotly.express as px

from PyriteUtility.computer_vision.imagecodecs_numcodecs import register_codecs
# Presumably registers the imagecodecs-based numcodecs codecs so compressed image
# arrays (e.g. camera0_rgb) can be decoded by zarr — NOTE(review): confirm in PyriteUtility.
register_codecs()

pio.templates.default = "plotly_dark"
pio.renderers.default = "vscode"
init_notebook_mode(connected=True)

# Dataset root; set the DATASET_PATH environment variable instead of editing this cell.
dataset_path = os.environ.get("DATASET_PATH", "/path/to/dataset/folder")
store = zarr.DirectoryStore(path=dataset_path)
# Read-only: this notebook only inspects the data. mode="a" would allow writes and
# would silently create an empty group if dataset_path were wrong, deferring the
# failure to a confusing KeyError on buffer['data'].
buffer = zarr.open(store=store, mode="r")

# buffer.tree()  # uncomment to print the dataset hierarchy


In [None]:
from spatialmath import SE3
from spatialmath.base import q2r, r2q

# Set to False to render every episode without waiting for keyboard input
# (input() is unavailable when the notebook is executed non-interactively).
PAUSE_BETWEEN_EPISODES = True

# Row-major subplot titles: position, position error, raw force, filtered force.
SUBPLOT_TITLES = ('X', 'Y', 'Z',
                  'dX', 'dY', 'dZ',
                  'fx', 'fy', 'fz',
                  'fx_filtered', 'fy_filtered', 'fz_filtered')


def _wrench_to_world_matrices(poses_WT):
    """Build per-sample matrices that re-express tool-frame wrenches in the world frame.

    Args:
        poses_WT: (N, 7) array of tool poses in the world frame; columns 0:3 are the
            translation and columns 3:7 the quaternion passed to spatialmath ``q2r``.
            NOTE(review): ``q2r`` defaults to scalar-first (w, x, y, z) order —
            confirm that matches how the recorder stores ts_pose_fb.

    Returns:
        (N, 6, 6) array ``M`` such that ``wrench_W[i] = M[i] @ wrench_T[i]``,
        with ``M[i] = Ad(T_TW)^T`` and ``T_TW`` the inverse of the i-th pose.
    """
    mats = np.empty((len(poses_WT), 6, 6))
    for i, pose7 in enumerate(poses_WT):
        T_WT = SE3.Rt(q2r(pose7[3:7]), pose7[0:3], check=False)
        # A wrench changes frame with the transposed adjoint of the inverse
        # transform (the dual of the twist map). Computed once per sample and
        # reused for both the raw and filtered wrench.
        mats[i] = T_WT.inv().Ad().T
    return mats


def _plot_episode(title, times, pose_fb, pose_target, delta_pose, wrench_W, wrench_filtered_W):
    """Return a 4x3 plotly figure of one episode's first three channels against ``times``.

    Row 1: feedback vs. virtual-target position; row 2: target minus feedback;
    row 3: world-frame wrench; row 4: world-frame filtered wrench.
    Column k plots component k of each series.
    """
    fig = make_subplots(rows=4, cols=3, shared_xaxes=True, subplot_titles=SUBPLOT_TITLES)
    rows = (
        (1, (('ts_pose_fb', pose_fb), ('ts_pose_virtual_target', pose_target))),
        (2, (('delta_pose', delta_pose),)),
        (3, (('wrench_W', wrench_W),)),
        (4, (('wrench_filtered_W', wrench_filtered_W),)),
    )
    for row, series in rows:
        for label, data in series:
            for axis in range(3):
                # Unique per-component names so legend entries are distinguishable.
                fig.add_trace(go.Scatter(x=times, y=data[:, axis], name=f'{label}{axis}'),
                              row=row, col=axis + 1)
    fig.update_layout(height=600, width=1200, title_text=title)
    return fig


for ep, ep_data in buffer['data'].items():
    print(ep)

    # Episode groups contain (as seen in episode_1722317514): camera0_rgb,
    # low_dim_time_stamps, stiffness, ts_pose_command, ts_pose_fb,
    # ts_pose_virtual_target, visual_time_stamps, wrench, wrench_filtered.
    # Only the low-dimensional arrays below are read.
    ts_pose_fb = ep_data["ts_pose_fb"][:]
    ts_pose_command = ep_data["ts_pose_command"][:]  # loaded but not plotted
    ts_pose_virtual_target = ep_data["ts_pose_virtual_target"][:]
    wrench = ep_data["wrench"][:]
    wrench_filtered = ep_data["wrench_filtered"][:]
    times = ep_data["low_dim_time_stamps"][:]

    # Every plotted series is paired sample-by-sample with ts_pose_fb and drawn
    # against `times`; a length mismatch would silently misalign the plots.
    N = ts_pose_fb.shape[0]
    for key, arr in (("ts_pose_virtual_target", ts_pose_virtual_target),
                     ("wrench", wrench),
                     ("wrench_filtered", wrench_filtered),
                     ("low_dim_time_stamps", times)):
        if arr.shape[0] != N:
            raise ValueError(
                f"{ep}: {key} has {arr.shape[0]} samples but ts_pose_fb has {N}")

    # Re-express the tool-frame wrenches in the world frame.
    wrench_to_W = _wrench_to_world_matrices(ts_pose_fb)
    wrench_W = np.einsum("nij,nj->ni", wrench_to_W, wrench)
    wrench_filtered_W = np.einsum("nij,nj->ni", wrench_to_W, wrench_filtered)

    # Element-wise difference of whole pose rows; only columns 0:3 (position) are
    # plotted — the quaternion columns of this difference have no geometric meaning.
    delta_pose = ts_pose_virtual_target - ts_pose_fb

    fig = _plot_episode(ep, times, ts_pose_fb, ts_pose_virtual_target,
                        delta_pose, wrench_W, wrench_filtered_W)
    fig.show()
    # fig.write_html('output.html')

    if PAUSE_BETWEEN_EPISODES:
        input("Press Enter to continue...")

