# Pose Trajectory Visualization

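This notebook visualizes pose data collected in OCD-format Zarr files. For one episode it plots, as 3D trajectories and per-axis time series, (1) the grasped object's pose in the world frame, (2) the goal pose of the target object, and (3) the grasped object's pose relative to the target.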

In [1]:
import zarr
import numpy as np
from scipy.spatial.transform import Rotation as R
import plotly.graph_objects as go
from plotly.subplots import make_subplots

## Load Zarr Data

In [11]:
# Load the zarr data.
# `mode` is passed by keyword: it is keyword-only in zarr-python 3, so the
# positional form `zarr.open(path, 'r')` raises TypeError there. The keyword
# form works on both zarr 2 and zarr 3.
zarr_path = '/home/maggie/research/object_centric_diffusion/data/maggie_block/episodes/episode_6/zarr'
root = zarr.open(zarr_path, mode='r')

# Extract data arrays (materialised as in-memory numpy arrays)
state = np.array(root['data']['state'])  # Relative to target
state_in_world = np.array(root['data']['state_in_world'])  # World frame
goal = np.array(root['data']['goal'])  # Target object pose

# Fail early with a readable message: every later cell slices columns 0:3 / 3:7
# and shares one `time_steps` axis across all three arrays.
for _pose_name, _pose_array in (('state', state), ('state_in_world', state_in_world), ('goal', goal)):
    assert _pose_array.ndim == 2 and _pose_array.shape[1] == 7, (
        f"{_pose_name} must have shape (N, 7), got {_pose_array.shape}"
    )
assert len(state) == len(state_in_world) == len(goal), (
    f"keyframe counts differ: state={len(state)}, "
    f"state_in_world={len(state_in_world)}, goal={len(goal)}"
)
assert len(state) > 0, "episode contains no keyframes"

print(f"State (relative) shape: {state.shape}")
print(f"State (world) shape: {state_in_world.shape}")
print(f"Goal shape: {goal.shape}")
print(f"Number of keyframes: {len(state)}")

State (relative) shape: (13, 7)
State (world) shape: (13, 7)
Goal shape: (13, 7)
Number of keyframes: 13


## Helper Functions

In [12]:
def quaternion_to_euler(quaternions):
    """
    Convert quaternions (x, y, z, w) to Euler angles (roll, pitch, yaw) in degrees.

    Parameters
    ----------
    quaternions : array-like, shape (N, 4) or (4,)
        Quaternions in scalar-last (x, y, z, w) order, which is the order
        scipy's ``Rotation.from_quat`` expects. A single quaternion of shape
        (4,) is treated as a batch of one.

    Returns
    -------
    np.ndarray, shape (N, 3)
        Angles about x, y, z (roll, pitch, yaw) in degrees, using scipy's
        lowercase ``'xyz'`` sequence, i.e. extrinsic (fixed-axis) rotations.
        An empty input yields an empty array of shape (0, 3).
    """
    quats = np.asarray(quaternions, dtype=float)
    if quats.size == 0:
        # Keep a well-formed (0, 3) result so callers can still index columns.
        return np.empty((0, 3))

    # One batched conversion instead of building a Rotation per row.
    quats = np.atleast_2d(quats)
    return R.from_quat(quats).as_euler('xyz', degrees=True)

def extract_positions(poses):
    """Return the x, y, z columns (the first three) of a 7D pose array."""
    xyz_columns = slice(0, 3)
    return poses[:, xyz_columns]

def extract_quaternions(poses):
    """Return the quaternion columns (index 3 onward) of a 7D pose array."""
    quaternion_columns = slice(3, None)
    return poses[:, quaternion_columns]

## 1. World Frame Trajectory (Grasped Object in World Coordinates)

### 3D Trajectory Plot

In [15]:
# 3D view of the world-frame trajectory
pos_world = extract_positions(state_in_world)

# Full path; markers are coloured by keyframe index so time order is visible.
world_path_trace = go.Scatter3d(
    x=pos_world[:, 0],
    y=pos_world[:, 1],
    z=pos_world[:, 2],
    mode='lines+markers',
    marker={
        'size': 4,
        'color': np.arange(len(pos_world)),
        'colorscale': 'Viridis',
        'showscale': True,
        'colorbar': {'title': 'Time'},
    },
    line={'color': 'blue', 'width': 2},
    name='Trajectory',
)
fig = go.Figure(data=[world_path_trace])

# Highlight the first and last keyframes with large diamonds.
for endpoint_name, endpoint_color, endpoint_idx in (('Start', 'green', 0), ('End', 'red', -1)):
    fig.add_trace(go.Scatter3d(
        x=[pos_world[endpoint_idx, 0]],
        y=[pos_world[endpoint_idx, 1]],
        z=[pos_world[endpoint_idx, 2]],
        mode='markers',
        marker={'size': 10, 'color': endpoint_color, 'symbol': 'diamond'},
        name=endpoint_name,
    ))

# aspectmode='data' keeps the three axes at the same metric scale.
fig.update_layout(
    title='World Frame Trajectory (Grasped Object)',
    scene={
        'xaxis_title': 'X (m)',
        'yaxis_title': 'Y (m)',
        'zaxis_title': 'Z (m)',
        'aspectmode': 'data',
    },
    width=800,
    height=600,
)

fig.show()

### Position Time Series

In [14]:
# World-frame position, one stacked subplot per coordinate
time_steps = np.arange(len(pos_world))

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=('X Position', 'Y Position', 'Z Position'),
    vertical_spacing=0.08,
)

# Column i of pos_world goes to subplot row i + 1.
for axis_idx, (axis_name, axis_color) in enumerate((('X', 'red'), ('Y', 'green'), ('Z', 'blue'))):
    subplot_row = axis_idx + 1
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=pos_world[:, axis_idx],
            mode='lines+markers',
            name=axis_name,
            line=dict(color=axis_color),
        ),
        row=subplot_row,
        col=1,
    )
    fig.update_yaxes(title_text='m', row=subplot_row, col=1)

# Only the bottom subplot carries the shared x-axis label.
fig.update_xaxes(title_text='Keyframe', row=3, col=1)

fig.update_layout(title='World Frame Position Over Time', height=600, showlegend=False)
fig.show()

### Rotation Time Series (Euler Angles)

In [6]:
# World-frame orientation as Euler angles, one stacked subplot per angle
quat_world = extract_quaternions(state_in_world)
euler_world = quaternion_to_euler(quat_world)

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=('Roll', 'Pitch', 'Yaw'),
    vertical_spacing=0.08,
)

# Column i of euler_world goes to subplot row i + 1.
for angle_idx, (angle_name, angle_color) in enumerate((('Roll', 'red'), ('Pitch', 'green'), ('Yaw', 'blue'))):
    subplot_row = angle_idx + 1
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=euler_world[:, angle_idx],
            mode='lines+markers',
            name=angle_name,
            line=dict(color=angle_color),
        ),
        row=subplot_row,
        col=1,
    )
    fig.update_yaxes(title_text='degrees', row=subplot_row, col=1)

# Only the bottom subplot carries the shared x-axis label.
fig.update_xaxes(title_text='Keyframe', row=3, col=1)

fig.update_layout(title='World Frame Rotation Over Time (Euler Angles)', height=600, showlegend=False)
fig.show()

## 2. Goal Pose Trajectory (Target Object)

### 3D Trajectory Plot

In [7]:
# 3D view of the goal pose positions
pos_goal = extract_positions(goal)

# Full path; markers are coloured by keyframe index so time order is visible.
goal_path_trace = go.Scatter3d(
    x=pos_goal[:, 0],
    y=pos_goal[:, 1],
    z=pos_goal[:, 2],
    mode='lines+markers',
    marker={
        'size': 4,
        'color': np.arange(len(pos_goal)),
        'colorscale': 'Plasma',
        'showscale': True,
        'colorbar': {'title': 'Time'},
    },
    line={'color': 'orange', 'width': 2},
    name='Goal Trajectory',
)
fig = go.Figure(data=[goal_path_trace])

# Highlight the first and last keyframes with large diamonds.
for endpoint_name, endpoint_color, endpoint_idx in (('Start', 'green', 0), ('End', 'red', -1)):
    fig.add_trace(go.Scatter3d(
        x=[pos_goal[endpoint_idx, 0]],
        y=[pos_goal[endpoint_idx, 1]],
        z=[pos_goal[endpoint_idx, 2]],
        mode='markers',
        marker={'size': 10, 'color': endpoint_color, 'symbol': 'diamond'},
        name=endpoint_name,
    ))

# aspectmode='data' keeps the three axes at the same metric scale.
fig.update_layout(
    title='Goal Pose Trajectory (Target Object)',
    scene={
        'xaxis_title': 'X (m)',
        'yaxis_title': 'Y (m)',
        'zaxis_title': 'Z (m)',
        'aspectmode': 'data',
    },
    width=800,
    height=600,
)

fig.show()

### Position Time Series

In [8]:
# Goal position, one stacked subplot per coordinate
fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=('X Position', 'Y Position', 'Z Position'),
    vertical_spacing=0.08,
)

# Column i of pos_goal goes to subplot row i + 1.
for axis_idx, (axis_name, axis_color) in enumerate((('X', 'red'), ('Y', 'green'), ('Z', 'blue'))):
    subplot_row = axis_idx + 1
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=pos_goal[:, axis_idx],
            mode='lines+markers',
            name=axis_name,
            line=dict(color=axis_color),
        ),
        row=subplot_row,
        col=1,
    )
    fig.update_yaxes(title_text='m', row=subplot_row, col=1)

# Only the bottom subplot carries the shared x-axis label.
fig.update_xaxes(title_text='Keyframe', row=3, col=1)

fig.update_layout(title='Goal Position Over Time', height=600, showlegend=False)
fig.show()

### Rotation Time Series (Euler Angles)

In [9]:
# Goal orientation as Euler angles, one stacked subplot per angle
quat_goal = extract_quaternions(goal)
euler_goal = quaternion_to_euler(quat_goal)

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=('Roll', 'Pitch', 'Yaw'),
    vertical_spacing=0.08,
)

# Column i of euler_goal goes to subplot row i + 1.
for angle_idx, (angle_name, angle_color) in enumerate((('Roll', 'red'), ('Pitch', 'green'), ('Yaw', 'blue'))):
    subplot_row = angle_idx + 1
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=euler_goal[:, angle_idx],
            mode='lines+markers',
            name=angle_name,
            line=dict(color=angle_color),
        ),
        row=subplot_row,
        col=1,
    )
    fig.update_yaxes(title_text='degrees', row=subplot_row, col=1)

# Only the bottom subplot carries the shared x-axis label.
fig.update_xaxes(title_text='Keyframe', row=3, col=1)

fig.update_layout(title='Goal Rotation Over Time (Euler Angles)', height=600, showlegend=False)
fig.show()

## 3. Relative Trajectory (Grasped Object Relative to Target)

### 3D Trajectory Plot

In [10]:
# 3D view of the relative (target-frame) trajectory
pos_relative = extract_positions(state)

# Full path; markers are coloured by keyframe index so time order is visible.
relative_path_trace = go.Scatter3d(
    x=pos_relative[:, 0],
    y=pos_relative[:, 1],
    z=pos_relative[:, 2],
    mode='lines+markers',
    marker={
        'size': 4,
        'color': np.arange(len(pos_relative)),
        'colorscale': 'Cividis',
        'showscale': True,
        'colorbar': {'title': 'Time'},
    },
    line={'color': 'purple', 'width': 2},
    name='Relative Trajectory',
)
fig = go.Figure(data=[relative_path_trace])

# Highlight the first and last keyframes with large diamonds.
for endpoint_name, endpoint_color, endpoint_idx in (('Start', 'green', 0), ('End', 'red', -1)):
    fig.add_trace(go.Scatter3d(
        x=[pos_relative[endpoint_idx, 0]],
        y=[pos_relative[endpoint_idx, 1]],
        z=[pos_relative[endpoint_idx, 2]],
        mode='markers',
        marker={'size': 10, 'color': endpoint_color, 'symbol': 'diamond'},
        name=endpoint_name,
    ))

# Reference marker at (0, 0, 0): the target object's location in its own frame.
fig.add_trace(go.Scatter3d(
    x=[0],
    y=[0],
    z=[0],
    mode='markers',
    marker={'size': 12, 'color': 'yellow', 'symbol': 'x'},
    name='Target Origin',
))

# aspectmode='data' keeps the three axes at the same metric scale.
fig.update_layout(
    title='Relative Trajectory (Grasped Object in Target Frame)',
    scene={
        'xaxis_title': 'X (m)',
        'yaxis_title': 'Y (m)',
        'zaxis_title': 'Z (m)',
        'aspectmode': 'data',
    },
    width=800,
    height=600,
)

fig.show()

### Position Time Series

In [None]:
# Relative position, one stacked subplot per coordinate
fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=('X Position', 'Y Position', 'Z Position'),
    vertical_spacing=0.08,
)

# Column i of pos_relative goes to subplot row i + 1.
for axis_idx, (axis_name, axis_color) in enumerate((('X', 'red'), ('Y', 'green'), ('Z', 'blue'))):
    subplot_row = axis_idx + 1
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=pos_relative[:, axis_idx],
            mode='lines+markers',
            name=axis_name,
            line=dict(color=axis_color),
        ),
        row=subplot_row,
        col=1,
    )
    fig.update_yaxes(title_text='m', row=subplot_row, col=1)

# Only the bottom subplot carries the shared x-axis label.
fig.update_xaxes(title_text='Keyframe', row=3, col=1)

fig.update_layout(title='Relative Position Over Time', height=600, showlegend=False)
fig.show()

### Rotation Time Series (Euler Angles)

In [None]:
# Relative orientation as Euler angles, one stacked subplot per angle
quat_relative = extract_quaternions(state)
euler_relative = quaternion_to_euler(quat_relative)

fig = make_subplots(
    rows=3,
    cols=1,
    shared_xaxes=True,
    subplot_titles=('Roll', 'Pitch', 'Yaw'),
    vertical_spacing=0.08,
)

# Column i of euler_relative goes to subplot row i + 1.
for angle_idx, (angle_name, angle_color) in enumerate((('Roll', 'red'), ('Pitch', 'green'), ('Yaw', 'blue'))):
    subplot_row = angle_idx + 1
    fig.add_trace(
        go.Scatter(
            x=time_steps,
            y=euler_relative[:, angle_idx],
            mode='lines+markers',
            name=angle_name,
            line=dict(color=angle_color),
        ),
        row=subplot_row,
        col=1,
    )
    fig.update_yaxes(title_text='degrees', row=subplot_row, col=1)

# Only the bottom subplot carries the shared x-axis label.
fig.update_xaxes(title_text='Keyframe', row=3, col=1)

fig.update_layout(title='Relative Rotation Over Time (Euler Angles)', height=600, showlegend=False)
fig.show()

## Summary Statistics

In [None]:
# Plain-text recap of the three position trajectories plotted above.
banner = "=" * 50
print(banner)
print("TRAJECTORY SUMMARY")
print(banner)
print(f"\nNumber of keyframes: {len(state)}")

# --- World frame: endpoints and per-axis extent ---
world_min, world_max = pos_world.min(axis=0), pos_world.max(axis=0)
print("\nWorld Frame Position:")
print("  Start: [{:.3f}, {:.3f}, {:.3f}]".format(pos_world[0, 0], pos_world[0, 1], pos_world[0, 2]))
print("  End:   [{:.3f}, {:.3f}, {:.3f}]".format(pos_world[-1, 0], pos_world[-1, 1], pos_world[-1, 2]))
print(
    f"  Range: X=[{world_min[0]:.3f}, {world_max[0]:.3f}], "
    f"Y=[{world_min[1]:.3f}, {world_max[1]:.3f}], "
    f"Z=[{world_min[2]:.3f}, {world_max[2]:.3f}]"
)

# --- Goal: first keyframe, plus whether every keyframe matches it ---
goal_is_static = np.allclose(pos_goal, pos_goal[0])
print("\nGoal Position:")
print("  Position: [{:.3f}, {:.3f}, {:.3f}]".format(pos_goal[0, 0], pos_goal[0, 1], pos_goal[0, 2]))
print(f"  (Static: {goal_is_static})")

# --- Relative: endpoints and per-axis extent ---
relative_min, relative_max = pos_relative.min(axis=0), pos_relative.max(axis=0)
print("\nRelative Position:")
print("  Start: [{:.3f}, {:.3f}, {:.3f}]".format(pos_relative[0, 0], pos_relative[0, 1], pos_relative[0, 2]))
print("  End:   [{:.3f}, {:.3f}, {:.3f}]".format(pos_relative[-1, 0], pos_relative[-1, 1], pos_relative[-1, 2]))
print(
    f"  Range: X=[{relative_min[0]:.3f}, {relative_max[0]:.3f}], "
    f"Y=[{relative_min[1]:.3f}, {relative_max[1]:.3f}], "
    f"Z=[{relative_min[2]:.3f}, {relative_max[2]:.3f}]"
)

# Distance traveled: Euclidean length of each step between consecutive keyframes
step_vectors = np.diff(pos_world, axis=0)
distances = np.linalg.norm(step_vectors, axis=1)
print(f"\nTotal distance traveled: {distances.sum():.3f} m")
print(f"Average step size: {distances.mean():.4f} m")