# Arm Visualizer
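Uses Plotly to render an interactive 3D view of a robot arm's configuration for a given set of joint angles, including each joint's rotation axis, range of motion, and coordinate frame.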

### Adjust Path
To import the `RobotArm` class from `inverse_kinematics_prediction/utils/robot_arm.py`, the project root (the parent of the notebook's working directory) must be on `sys.path`. The next cell adds it.

In [52]:
import sys
import os

# The project root is assumed to be one level above the kernel's working
# directory (i.e. the notebook is run from <project>/notebooks). Prepend it to
# sys.path exactly once so that the `utils` package can be imported below.
working_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(working_dir, ".."))
already_on_path = project_root in sys.path
if not already_on_path:
    sys.path.insert(0, project_root)

print("Current working directory:", working_dir)
print("Project root added to sys.path:", project_root)

Current working directory: c:\Users\cjsta\git\inverse_kinematics_prediction\notebooks
Project root added to sys.path: c:\Users\cjsta\git\inverse_kinematics_prediction


### Create RobotArm Object

In [53]:
import numpy as np

from utils.robot_arm import RobotArm

# Create a 4-DOF robot arm: 4 joints and 5 links with the lengths listed below
# (there is one more link than there are joints).
num_joints = 4
joint_limits = [(-1*np.pi, 1*np.pi), (-0.8*np.pi, 0.8*np.pi), (-0.8*np.pi, 0.8*np.pi), (-0.8*np.pi, 0.8*np.pi)]  # (min, max) in radians, per joint
rotation_axes = ['z', 'y', 'y', 'y']   # local axis each joint rotates about
link_lengths = [0.2, 0.2, 0.5, 0.5, 0.2]
link_axes = ['z', 'z', 'z', 'z', 'z']  # local axis each link extends along

# Fail fast on an inconsistent configuration. get_joint_rotation_axes reads
# link_axes[joint_idx + 1], so num_joints + 1 link entries are required.
assert len(joint_limits) == len(rotation_axes) == num_joints, "need one limit and one rotation axis per joint"
assert len(link_lengths) == len(link_axes) == num_joints + 1, "need num_joints + 1 links"
assert all(lo < hi for lo, hi in joint_limits), "joint limits must be (min, max) with min < max"

# Instantiate the RobotArm object
robot_arm = RobotArm(num_joints=num_joints, 
                     joint_limits=joint_limits, 
                     link_lengths=link_lengths, 
                     rotation_axes=rotation_axes,
                     link_axes=link_axes)

# Test forward kinematics on a random configuration.
# NOTE: sampling is not seeded here, so the configuration (and every output
# below) changes on each run; use the fixed vector instead for a reproducible figure.
test_angles = robot_arm.sample_random_joint_angles()
# test_angles = np.array([0.0, 0.0, 0.0, 0.0])

pos, quat = robot_arm.forward_kinematics(test_angles)
print("Joint angles (radians): ", test_angles)
print("End Effector Position (point x,y,z): ", pos)
print("End Effector Orientation (quaternion x,y,z,w): ", quat)

Joint angles (radians):  [ 2.93147599 -1.47213581  2.25517577 -0.92163089]
End Effector Position (point x,y,z):  [ 0.16868597 -0.03597471  1.00171875]
End Effector Orientation (quaternion x,y,z,w):  (np.float64(0.0688582613585474), np.float64(-0.007260867165628025), np.float64(0.992099709607962), np.float64(0.10461350699827655))


### Visualization Helper Functions

In [54]:
import numpy as np
import plotly.graph_objects as go

# --- Helper Functions for Plotly Visualization ---

def get_rotation_matrix_for_axis(axis_str):
    """
    Return a 3x3 rotation matrix that maps the default cylinder axis (+z) onto
    the +x, +y, or +z axis.

    Parameters:
    - axis_str: 'x', 'y', or 'z' (case-insensitive)

    Returns:
    - A 3x3 numpy array R (a proper rotation, det = +1) such that
      R @ [0, 0, 1] is the unit vector along +axis_str.

    Raises:
    - ValueError: if axis_str is not 'x', 'y', or 'z'.
    """
    axis_str = axis_str.lower()
    if axis_str == 'x':
        # +90° about y: z -> +x, x -> -z. Entries are written exactly to avoid
        # the ~6e-17 round-off of cos(pi/2).
        R = np.array([
            [ 0.0, 0.0, 1.0],
            [ 0.0, 1.0, 0.0],
            [-1.0, 0.0, 0.0],
        ])
    elif axis_str == 'y':
        # -90° about x: z -> +y, y -> -z.
        R = np.array([
            [1.0,  0.0, 0.0],
            [0.0,  0.0, 1.0],
            [0.0, -1.0, 0.0],
        ])
    elif axis_str == 'z':
        R = np.eye(3)
    else:
        raise ValueError("Invalid axis string. Must be 'x', 'y', or 'z'.")
    return R

def cylinder_mesh(center, axis_str, radius=0.05, height=0.15, resolution=20):
    """
    Build surface grids (X, Y, Z) for an open cylinder.

    The cylinder is first sampled around the z-axis, then rotated with
    get_rotation_matrix_for_axis(axis_str), and finally shifted so that its
    midpoint sits at `center`.

    Parameters:
    - center: indexable (cx, cy, cz) position of the cylinder's midpoint
    - axis_str: 'x', 'y', or 'z' — direction of the cylinder's long axis
    - radius: cylinder radius
    - height: cylinder length along its axis
    - resolution: number of samples both around the circumference and along the axis

    Returns:
    - Tuple (X, Y, Z) of 2D arrays, each of shape (resolution, resolution).
    """
    heights = np.linspace(-height/2, height/2, resolution)
    angles = np.linspace(0, 2*np.pi, resolution)
    angle_grid, height_grid = np.meshgrid(angles, heights)
    grid_shape = angle_grid.shape

    # One column per surface sample: a 3 x (resolution**2) matrix.
    local_pts = np.stack([
        (radius * np.cos(angle_grid)).ravel(),
        (radius * np.sin(angle_grid)).ravel(),
        height_grid.ravel(),
    ])

    world_pts = get_rotation_matrix_for_axis(axis_str) @ local_pts

    # Restore the grid layout per coordinate and translate to the requested center.
    return tuple(world_pts[k].reshape(grid_shape) + center[k] for k in range(3))

def quaternion_to_rotation_matrix(quat):
    """
    Turn an (x, y, z, w) quaternion (ROS ordering, scalar last) into a 3x3
    rotation matrix.

    The standard unit-quaternion formula is applied as-is; the input is not
    normalized, so a non-unit quaternion yields a non-orthonormal matrix.
    """
    qx, qy, qz, qw = quat

    # Squared vector components feed the diagonal terms.
    sq_x, sq_y, sq_z = qx**2, qy**2, qz**2
    # Pairwise products feed the off-diagonal terms.
    xy, xz, yz = qx*qy, qx*qz, qy*qz
    xw, yw, zw = qx*qw, qy*qw, qz*qw

    row0 = [1 - 2*(sq_y + sq_z), 2*(xy - zw), 2*(xz + yw)]
    row1 = [2*(xy + zw), 1 - 2*(sq_x + sq_z), 2*(yz - xw)]
    row2 = [2*(xz - yw), 2*(yz + xw), 1 - 2*(sq_x + sq_y)]
    return np.array([row0, row1, row2])

def create_circle_arc_points(center, rotation_axis, current_link_dir, next_link_dir, joint_limit, radius=0.05, n_points=60):
    """
    Create points forming an arc in 3D space within the joint limits.

    The arc lies in the plane perpendicular to `rotation_axis` and is centred on
    `center`. Angle 0 is the reference direction v1: next_link_dir projected into
    that plane, falling back to current_link_dir, and finally to an arbitrary
    in-plane direction when both are (anti)parallel to the axis. Positive angles
    follow the right-hand rule about rotation_axis (v2 = rotation_axis x v1).

    Parameters:
    - center: The center point of the circle (3D coordinates, array-like)
    - rotation_axis: A vector perpendicular to the circle's plane (rotation axis)
    - current_link_dir: Direction vector of the current link
    - next_link_dir: Direction vector of the next link
    - joint_limit: Tuple (min_angle, max_angle) in radians; a span of 2*pi or more
      is drawn as a full circle starting at angle 0
    - radius: The radius of the circle
    - n_points: Number of points to generate along the arc

    Returns:
    - arc_points: (n_points, 3) array of 3D coordinates of the arc
    - theta: (n_points,) array of the angles used to generate arc_points
    - basis_vectors: tuple (v1, v2) of orthonormal in-plane unit vectors; the point
      at angle t is center + radius * (cos(t) * v1 + sin(t) * v2)

    Raises:
    - ValueError: if rotation_axis has (near) zero length.
    """
    center = np.asarray(center, dtype=float).reshape(3)
    rotation_axis = np.asarray(rotation_axis, dtype=float)
    axis_norm = np.linalg.norm(rotation_axis)
    if axis_norm < 1e-12:
        raise ValueError("rotation_axis must be a non-zero vector.")
    rotation_axis = rotation_axis / axis_norm

    def _project_onto_plane(vec):
        """Normalize vec and remove its component along rotation_axis (zero vector stays zero)."""
        vec = np.asarray(vec, dtype=float)
        norm = np.linalg.norm(vec)
        if norm < 1e-12:
            return np.zeros(3)
        vec = vec / norm
        return vec - np.dot(vec, rotation_axis) * rotation_axis

    # The next link is what this joint actually rotates, so prefer it as the
    # angle-zero reference; fall back to the current link when it is parallel
    # to the axis (its projection then vanishes).
    v1 = None
    for candidate in (next_link_dir, current_link_dir):
        projected = _project_onto_plane(candidate)
        if np.linalg.norm(projected) >= 1e-5:
            v1 = projected
            break

    if v1 is None:
        # Both link directions are parallel to the axis: pick any direction that
        # genuinely lies in the rotation plane.
        if np.abs(rotation_axis[2]) < 0.9:
            # (-a_y, a_x, 0) is perpendicular to a; its norm is sqrt(1 - a_z^2) > 0.43.
            v1 = np.array([-rotation_axis[1], rotation_axis[0], 0.0])
        else:
            # Axis is close to z: project the world x-axis into the plane
            # (|a_x| <= sqrt(0.19), so the projection cannot vanish).
            v1 = _project_onto_plane(np.array([1.0, 0.0, 0.0]))

    v1 = v1 / np.linalg.norm(v1)

    # v2 completes a right-handed in-plane basis.
    v2 = np.cross(rotation_axis, v1)
    v2 = v2 / np.linalg.norm(v2)

    min_angle, max_angle = joint_limit
    if max_angle - min_angle >= 2 * np.pi:
        # Unlimited (or >= full turn) joint: draw the whole circle.
        theta = np.linspace(0, 2 * np.pi, n_points)
    else:
        theta = np.linspace(min_angle, max_angle, n_points)

    arc_points = center + radius * (np.outer(np.cos(theta), v1) + np.outer(np.sin(theta), v2))

    # Returned so callers can place markers at arbitrary angles on the same circle.
    basis_vectors = (v1, v2)

    return arc_points, theta, basis_vectors


# Get joint rotation axes in world frame

# Unit vectors for the local axis labels used by RobotArm.rotation_axes / link_axes.
_UNIT_AXES = {
    'x': np.array([1.0, 0.0, 0.0]),
    'y': np.array([0.0, 1.0, 0.0]),
    'z': np.array([0.0, 0.0, 1.0]),
}


def _axis_str_to_vector(axis_str):
    """Map 'x'/'y'/'z' (case-insensitive) to a fresh copy of the matching unit vector."""
    try:
        return _UNIT_AXES[axis_str.lower()].copy()
    except (KeyError, AttributeError):
        raise ValueError(f"Invalid axis {axis_str!r}. Must be 'x', 'y', or 'z'.") from None


def get_joint_rotation_axes(robot_arm, joint_angles):
    """
    Get the rotation axes for each joint in world coordinates along with joint limits and link directions.

    Parameters:
    - robot_arm: The RobotArm instance (uses get_joint_poses, rotation_axes,
      link_axes, joint_limits and _quaternion_to_rotation_matrix)
    - joint_angles: Joint angles

    Returns:
    - List of (joint_position, rotation_axis, link_dir, next_link_dir, joint_limit)
      tuples, one per joint. The three direction vectors are the joint's local
      axes rotated into the world frame by the joint's pose orientation.

    Raises:
    - ValueError: if an entry of rotation_axes or link_axes is not 'x', 'y', or 'z'.
    """
    poses = robot_arm.get_joint_poses(joint_angles)
    rotation_axes_world = []

    # poses[0] is the base and poses[-1] the end effector; the poses in between
    # are the joints, so pose index joint_idx + 1 belongs to joint joint_idx.
    for joint_idx in range(len(poses) - 2):
        pose = poses[joint_idx + 1]
        joint_pos = pose[0]
        joint_orientation = pose[1]

        # Local-frame directions. Validating here means a bad label raises instead
        # of silently reusing the previous joint's vector.
        local_rot_axis = _axis_str_to_vector(robot_arm.rotation_axes[joint_idx])
        current_link_dir = _axis_str_to_vector(robot_arm.link_axes[joint_idx])
        next_link_dir = _axis_str_to_vector(robot_arm.link_axes[joint_idx + 1])

        # Transform to world frame using joint orientation (quaternion)
        R = robot_arm._quaternion_to_rotation_matrix(joint_orientation)
        world_rot_axis = R @ local_rot_axis
        world_current_link_dir = R @ current_link_dir
        world_next_link_dir = R @ next_link_dir

        joint_limit = robot_arm.joint_limits[joint_idx]

        rotation_axes_world.append((joint_pos, world_rot_axis, world_current_link_dir, world_next_link_dir, joint_limit))

    return rotation_axes_world

### Visualize
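Define `visualize_robot_arm`, which combines the helpers above into one interactive 3D Plotly figure, then render the sample configuration created earlier.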

In [55]:
def _joint_color(joint_idx):
    """
    Return a per-joint 'rgb(r, g, b)' colour string.

    Channels follow 50+50i, 100+30i and 150-30i, wrapped modulo 256 so that arms
    with many joints still produce valid colour strings.
    """
    r = (50 + joint_idx * 50) % 256
    g = (100 + joint_idx * 30) % 256
    b = (150 - joint_idx * 30) % 256
    return f'rgb({r}, {g}, {b})'


def visualize_robot_arm(robot_arm, joint_angles):
    """
    Build an interactive 3D Plotly figure of the arm in the given configuration.

    The figure shows:
    - a square base plate under the base pose;
    - the kinematic chain (links and joint markers);
    - for every joint, its arc within the joint limits, the limit markers, the
      rotation axis and a marker at the current joint angle;
    - coordinate frames (X=red, Y=green, Z=blue) for the base, every joint and the
      end effector.

    Parameters:
    - robot_arm: RobotArm instance (uses get_joint_poses, _quaternion_to_rotation_matrix
      and, through get_joint_rotation_axes, rotation_axes / link_axes / joint_limits)
    - joint_angles: joint angles in radians, indexed per joint

    Returns:
    - plotly.graph_objects.Figure (not shown; call .show() on it)
    """
    # Decoration sizes in world units. The axis padding is derived from them so
    # that arcs, rotation-axis arrows and coordinate frames are not clipped.
    arc_radius = 0.1
    joint_axis_length = 0.15
    frame_axis_length = 0.1       # base and joint frames
    ee_frame_axis_length = 0.2    # end-effector frame is drawn larger
    base_size = 0.05              # half-width of the base square

    poses = robot_arm.get_joint_poses(joint_angles)
    positions = np.array([pose[0] for pose in poses], dtype=float)
    joint_axes = get_joint_rotation_axes(robot_arm, joint_angles)

    # --- Axis ranges ---------------------------------------------------------
    # Use the same span on every axis (a cube centred on the data) so that the
    # 1:1:1 aspect ratio set in the layout shows true proportions.
    padding = max(0.1, arc_radius, joint_axis_length, frame_axis_length, ee_frame_axis_length)
    data_min = positions.min(axis=0) - padding
    data_max = positions.max(axis=0) + padding
    data_center = (data_min + data_max) / 2
    half_range = (data_max - data_min).max() / 2

    x_range = [data_center[0] - half_range, data_center[0] + half_range]
    y_range = [data_center[1] - half_range, data_center[1] + half_range]
    # Prefer a floor at z = 0, but only when it cannot clip the arm: every pose
    # must be at or above z = 0 and the padded top must still fit in the cube.
    if positions[:, 2].min() >= 0 and data_max[2] <= 2 * half_range:
        z_low = 0.0
    else:
        z_low = data_center[2] - half_range
    z_range = [z_low, z_low + 2 * half_range]

    fig = go.Figure()

    # --- Base ------------------------------------------------------------------
    # Filled square plate centred on the base pose: a 2x2 surface grid is a
    # single quad spanning the square. A constant surfacecolor with explicit
    # cmin/cmax keeps it a single solid colour.
    bx, by, bz = positions[0]
    fig.add_trace(go.Surface(
        x=np.array([[bx - base_size, bx + base_size],
                    [bx - base_size, bx + base_size]]),
        y=np.array([[by - base_size, by - base_size],
                    [by + base_size, by + base_size]]),
        z=np.full((2, 2), bz),
        surfacecolor=np.zeros((2, 2)),
        cmin=0,
        cmax=1,
        colorscale=[[0, 'darkblue'], [1, 'darkblue']],
        showscale=False,
        name='Base'
    ))

    # Outline of the base square (first corner repeated to close the loop).
    square = base_size * np.array([[-1, -1], [1, -1], [1, 1], [-1, 1], [-1, -1]])
    fig.add_trace(go.Scatter3d(
        x=bx + square[:, 0],
        y=by + square[:, 1],
        z=np.full(len(square), bz),
        mode='lines',
        line=dict(color='darkblue', width=3),
        showlegend=False
    ))

    # Line from the base to the first joint.
    if len(positions) > 1:
        first_joint = positions[1]
        fig.add_trace(go.Scatter3d(
            x=[bx, first_joint[0]],
            y=[by, first_joint[1]],
            z=[bz, first_joint[2]],
            mode='lines',
            line=dict(color='darkblue', width=5),
            name='Base link'
        ))

    # --- Links (starting from the first joint) -------------------------------
    fig.add_trace(go.Scatter3d(
        x=positions[1:, 0],  # Skip the base position
        y=positions[1:, 1],
        z=positions[1:, 2],
        mode='lines+markers',
        line=dict(color='black', width=4),
        marker=dict(size=6, color='blue'),
        name='Links'
    ))

    # --- Per-joint motion range, limits, axis and current angle ----------------
    for i, (joint_pos, rotation_axis, current_link_dir, next_link_dir, joint_limit) in enumerate(joint_axes):
        joint_color = _joint_color(i)

        arc_points, _theta, (v1, v2) = create_circle_arc_points(
            joint_pos, rotation_axis, current_link_dir, next_link_dir, joint_limit, radius=arc_radius)

        fig.add_trace(go.Scatter3d(
            x=arc_points[:, 0],
            y=arc_points[:, 1],
            z=arc_points[:, 2],
            mode='lines',
            line=dict(color=joint_color, width=3),
            name=f'Joint {i+1} motion range'
        ))

        # Limit markers only make sense for a partial arc.
        min_angle, max_angle = joint_limit
        if max_angle - min_angle < 2 * np.pi:
            min_arc_pt = arc_points[0]
            max_arc_pt = arc_points[-1]
            fig.add_trace(go.Scatter3d(
                x=[min_arc_pt[0], max_arc_pt[0]],
                y=[min_arc_pt[1], max_arc_pt[1]],
                z=[min_arc_pt[2], max_arc_pt[2]],
                mode='markers',
                marker=dict(size=8, color=joint_color, symbol='circle'),
                name=f'Joint {i+1} limits'
            ))

        # Dotted segment along the rotation axis.
        axis_end = joint_pos + rotation_axis * joint_axis_length
        fig.add_trace(go.Scatter3d(
            x=[joint_pos[0], axis_end[0]],
            y=[joint_pos[1], axis_end[1]],
            z=[joint_pos[2], axis_end[2]],
            mode='lines',
            line=dict(color=joint_color, width=4, dash='dot'),
            name=f'Joint {i+1} axis'
        ))

        # Current joint angle on the same circle, measured from v1 toward v2.
        current_angle = joint_angles[i]
        current_pt = joint_pos + arc_radius * (np.cos(current_angle) * v1 + np.sin(current_angle) * v2)
        fig.add_trace(go.Scatter3d(
            x=[current_pt[0]],
            y=[current_pt[1]],
            z=[current_pt[2]],
            mode='markers',
            marker=dict(size=8, color='yellow', symbol='diamond'),
            name=f'Joint {i+1} current position'
        ))

    # --- Coordinate frames for base, joints and end effector ------------------
    axis_colors = ['red', 'green', 'blue']
    axis_labels = ['X', 'Y', 'Z']
    last_idx = len(poses) - 1
    for i, pose in enumerate(poses):
        frame_pos = positions[i]
        R = robot_arm._quaternion_to_rotation_matrix(pose[1])

        is_end_effector = (i == last_idx)
        if i == 0:
            element_name = "Base"
        elif is_end_effector:
            element_name = "End-effector"
        else:
            element_name = f"Joint {i}"

        # Short, thin axes for base/joints to avoid clutter; larger for the end effector.
        length = ee_frame_axis_length if is_end_effector else frame_axis_length
        width = 5 if is_end_effector else 3

        for j, (axis_color, axis_label) in enumerate(zip(axis_colors, axis_labels)):
            tip = frame_pos + R[:, j] * length
            fig.add_trace(go.Scatter3d(
                x=[frame_pos[0], tip[0]],
                y=[frame_pos[1], tip[1]],
                z=[frame_pos[2], tip[2]],
                mode='lines',
                line=dict(color=axis_color, width=width),
                name=f'{element_name} {axis_label}',
                # One legend entry per frame; the shared group makes that entry
                # toggle all three axes of the frame together.
                legendgroup=f'frame_{i}',
                showlegend=(j == 0)
            ))

    fig.update_layout(
        width=1000,
        height=600,
        scene=dict(
            xaxis=dict(title="X", range=x_range),
            yaxis=dict(title="Y", range=y_range),
            zaxis=dict(title="Z", range=z_range),
            aspectmode='manual',
            aspectratio=dict(x=1, y=1, z=1),
            camera=dict(
                projection=dict(type="orthographic"),
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        title="3D Visualization of Manipulator Configuration"
    )

    return fig

In [56]:
# Summarize the configuration being drawn. The rotation/link axes are the
# *local* axis labels configured on the RobotArm (not world-frame vectors).
# The end-effector pose is recomputed here rather than relying on the generic
# globals `pos`/`quat` from an earlier cell.
ee_pos, ee_quat = robot_arm.forward_kinematics(test_angles)
print("Joint angles (radians): ", test_angles)
print("Joint rotation axes (local frame): ", list(robot_arm.rotation_axes))
print("Link axes (local frame): ", list(robot_arm.link_axes))
print("End Effector Position (point x,y,z): ", ee_pos)
print("End Effector Orientation (quaternion x,y,z,w): ", ee_quat)

# Call the visualization function with your robot_arm and test_angles
fig = visualize_robot_arm(robot_arm, test_angles)
fig.show(config={
    'scrollZoom': True,  # Enable scroll zoom
    'modeBarButtonsToAdd': ['resetCameraDefault3d'],  # Add reset button
    'doubleClick': 'reset',  # Reset on double click
    'displayModeBar': True,  # Always display the mode bar
})

Joint angles (radians):  [ 2.93147599 -1.47213581  2.25517577 -0.92163089]
Joint rotation axes (world frame):  ['z', 'y', 'y', 'y']
Link axes (world frame):  ['z', 'z', 'z', 'z']
End Effector Position (point x,y,z):  [ 0.16868597 -0.03597471  1.00171875]
End Effector Orientation (quaternion x,y,z,w):  (np.float64(0.0688582613585474), np.float64(-0.007260867165628025), np.float64(0.992099709607962), np.float64(0.10461350699827655))
