In [16]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import ipywidgets as widgets
from IPython.display import display, clear_output
import re

class BVHJoint:
    """A single node of a BVH skeleton hierarchy.

    Attributes:
        name: Joint name as read from the file (End Sites get a synthesized name
            from the parser).
        parent: Parent BVHJoint, or None for the root.
        children: Child joints, in the order they were added.
        offset: Length-3 float array, initialised to zeros.
        channels: Channel names from the joint's CHANNELS line.
        end_site: True when this node represents an End Site block.
    """

    def __init__(self, name, parent=None):
        self.name, self.parent = name, parent
        self.end_site = False
        self.channels = list()
        self.children = list()
        self.offset = np.zeros(3)

    def add_child(self, child):
        """Append `child` to this joint's children (does not modify child.parent)."""
        self.children += [child]
        
class BVHMotion:
    """A parsed BVH (Biovision Hierarchy) motion file.

    Attributes:
        joints: Mapping of joint name -> BVHJoint, including synthesized
            "<parent>_end" entries for End Site blocks.
        root: The ROOT joint, or None before a hierarchy has been parsed.
        frames: Frame count declared by the "Frames:" line.
        frame_time: Value of the "Frame Time:" line (used by the previews as seconds).
        motion_data: One list of floats per motion-data line, in file order.
    """

    # MOTION must be the first token on its line; a plain substring search would
    # also split on joint names that merely contain "MOTION".
    _MOTION_RE = re.compile(r'^[ \t]*MOTION\b', re.MULTILINE)

    def __init__(self):
        self._reset()

    def _reset(self):
        """Clear all parsed state so the instance can (re)load a file."""
        self.joints = {}  # name -> BVHJoint
        self.root = None
        self.frames = 0
        self.frame_time = 0
        self.motion_data = []

    def parse_file(self, file_path):
        """Parse a BVH file, replacing any previously parsed data.

        Args:
            file_path: Path of the BVH file to read.
        """
        with open(file_path, 'r') as file:
            content = file.read()

        # Reset only after a successful read, and always before parsing, so
        # re-parsing never accumulates joints or motion rows from a prior call.
        self._reset()

        parts = self._MOTION_RE.split(content, maxsplit=1)
        hierarchy = parts[0]
        motion = parts[1] if len(parts) > 1 else ""

        self._parse_hierarchy(hierarchy)
        if motion:
            self._parse_motion(motion)

    def _parse_hierarchy(self, text):
        """Parse the HIERARCHY section into BVHJoint objects.

        Keywords are recognised only as the first token of a line, so joint
        names containing e.g. "ROOT" or "JOINT" are not misinterpreted.
        """
        joint_stack = []  # currently open blocks; the innermost is joint_stack[-1]

        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            tokens = line.split()
            keyword = tokens[0]
            rest = line[len(keyword):].strip()  # e.g. the joint name

            if keyword == 'HIERARCHY':
                continue

            if keyword == 'ROOT':
                joint = BVHJoint(rest)
                self.root = joint
                self.joints[rest] = joint
                joint_stack.append(joint)

            elif keyword == 'JOINT':
                if not joint_stack:
                    continue
                parent = joint_stack[-1]
                joint = BVHJoint(rest, parent)
                parent.add_child(joint)
                self.joints[rest] = joint
                joint_stack.append(joint)

            elif keyword == 'End' and tokens[1:2] == ['Site']:
                if not joint_stack:
                    continue
                # End sites are unnamed in BVH; use the parent name + "_end".
                parent = joint_stack[-1]
                name = f"{parent.name}_end"
                end_joint = BVHJoint(name, parent)
                end_joint.end_site = True
                parent.add_child(end_joint)
                self.joints[name] = end_joint
                joint_stack.append(end_joint)

            elif keyword == '{':
                continue

            elif keyword.startswith('}'):
                # Normally a lone '}', but pop once per brace if several share a line.
                for _ in range(line.count('}')):
                    if joint_stack:
                        joint_stack.pop()

            elif keyword == 'OFFSET':
                if joint_stack:
                    joint_stack[-1].offset = np.array([float(x) for x in tokens[1:]])

            elif keyword == 'CHANNELS':
                if joint_stack:
                    num_channels = int(tokens[1])
                    joint_stack[-1].channels = tokens[2:2 + num_channels]

    def _parse_motion(self, text):
        """Parse the MOTION section: frame count, frame time and per-frame values."""
        for raw_line in text.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            if 'Frames:' in line:
                self.frames = int(line.split(':', 1)[1].strip())
            elif 'Frame Time:' in line:
                self.frame_time = float(line.split(':', 1)[1].strip())
            else:
                try:
                    self.motion_data.append([float(x) for x in line.split()])
                except ValueError:
                    # Deliberately best-effort: skip lines that aren't all numbers.
                    pass

    def extract_frame_data(self, frame_idx):
        """
        Extract global joint positions for a specific frame.

        Args:
            frame_idx: Frame index to extract, 0 <= frame_idx < self.frames.

        Returns:
            Dict mapping joint names to 3D positions (numpy arrays); empty if the
            file contained no motion rows.

        Raises:
            ValueError: If frame_idx is outside 0..frames-1, or the file declares
                more frames than it actually contains rows for.
        """
        if frame_idx < 0 or frame_idx >= self.frames:
            raise ValueError(f"Frame index {frame_idx} out of range (0-{self.frames-1})")

        if not self.motion_data:
            return {}

        if frame_idx >= len(self.motion_data):
            raise ValueError(
                f"Frame index {frame_idx} has no motion data "
                f"(file declares {self.frames} frames but contains "
                f"{len(self.motion_data)} rows)"
            )

        frame_data = self.motion_data[frame_idx]
        positions = {}

        if self.root:
            self._update_joint_position(self.root, frame_data, 0, np.zeros(3), np.identity(3), positions)

        return positions

    @staticmethod
    def _axis_rotation(channel, angle_deg):
        """Return the 3x3 rotation matrix for one rotation channel.

        `channel` is a channel name such as 'Zrotation'; a channel naming no
        known axis yields the identity.
        """
        angle_rad = np.radians(angle_deg)
        c, s = np.cos(angle_rad), np.sin(angle_rad)
        if 'Xrotation' in channel:
            return np.array([[1, 0, 0], [0, c, -s], [0, s, c]])
        if 'Yrotation' in channel:
            return np.array([[c, 0, s], [0, 1, 0], [-s, 0, c]])
        if 'Zrotation' in channel:
            return np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]])
        return np.identity(3)

    def _update_joint_position(self, joint, frame_data, channel_offset, parent_pos, parent_rotation, positions):
        """
        Recursively compute global positions for `joint` and its descendants.

        A joint's position is parent_pos + parent_rotation @ offset (the root
        additionally takes its translation channels). Its global rotation is
        parent_rotation @ local, where local multiplies the per-channel
        rotations in channel order.

        Args:
            joint: The joint to update
            frame_data: Motion values for the current frame
            channel_offset: Index in frame_data of this joint's first channel
            parent_pos: Parent joint global position
            parent_rotation: Parent joint global rotation matrix
            positions: Dictionary that receives joint name -> position

        Returns:
            The channel offset just past this joint's subtree.
        """
        joint_pos = parent_pos.copy()
        local_rotation = np.identity(3)

        # End sites have no channels and therefore keep the identity rotation.
        if not joint.end_site and joint.channels:
            for idx, channel in enumerate(joint.channels):
                lowered = channel.lower()
                if 'position' in lowered:
                    # Only the root's translation channels are applied; position
                    # channels on other joints are consumed but ignored.
                    if joint.parent is None:
                        for axis, key in enumerate(('Xposition', 'Yposition', 'Zposition')):
                            if key in channel:
                                joint_pos[axis] = frame_data[channel_offset + idx]
                                break
                elif 'rotation' in lowered:
                    angle_deg = frame_data[channel_offset + idx]
                    local_rotation = local_rotation @ self._axis_rotation(channel, angle_deg)

            channel_offset += len(joint.channels)

        rotation = parent_rotation @ local_rotation

        # The offset lives in the parent's frame, so rotate it by the parent's rotation.
        joint_pos += parent_rotation @ joint.offset
        positions[joint.name] = joint_pos

        # Children consume channels in file (depth-first) order.
        for child in joint.children:
            channel_offset = self._update_joint_position(
                child, frame_data, channel_offset, joint_pos, rotation, positions
            )

        return channel_offset

def preview_bvh(file_path, show_labels=False, use_plotly=False):
    """
    Preview a BVH animation file in a Jupyter notebook.

    Args:
        file_path: Path to the BVH file
        show_labels: Whether to show joint labels
        use_plotly: Whether to use Plotly instead of Matplotlib. Falls back to
            Matplotlib only when Plotly cannot be imported.

    Returns:
        BVHMotion object
    """
    bvh = BVHMotion()
    bvh.parse_file(file_path)

    if use_plotly:
        # Guard only the availability probe. An ImportError (or any error) raised
        # while *building* the Plotly figure must surface, not be mistaken for
        # "Plotly missing" and silently trigger the fallback.
        try:
            import plotly.graph_objects  # noqa: F401  (availability check only)
        except ImportError:
            print("Plotly not available, falling back to Matplotlib")
        else:
            return _preview_with_plotly(bvh, show_labels)

    # Default to Matplotlib
    return _preview_with_matplotlib(bvh, show_labels)

def _preview_with_matplotlib(bvh, show_labels=False):
    """Show an interactive Matplotlib + ipywidgets preview of a parsed BVH file.

    Joint positions for every frame are pre-computed, then a 3D skeleton is
    drawn with Play / Frame / Speed / "Show labels" controls. The BVH Y axis is
    drawn on Matplotlib's vertical axis.

    Args:
        bvh: A parsed BVHMotion.
        show_labels: Initial state of the "Show labels" checkbox.

    Returns:
        The same BVHMotion object.
    """
    if bvh.frames <= 0:
        # ipywidgets rejects sliders whose max (frames - 1) is below min, and
        # there would be no frame 0 to draw.
        print("No motion frames to preview")
        return bvh

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    plt.close(fig)  # Close the figure so it doesn't display immediately; we display(fig) explicitly

    # Pre-calculate positions for all frames so scrubbing stays responsive.
    all_positions = [bvh.extract_frame_data(i) for i in range(bvh.frames)]

    # Global bounds over every frame keep the view fixed while animating.
    frame_arrays = [np.array(list(p.values())) for p in all_positions if p]
    if frame_arrays:
        stacked = np.vstack(frame_arrays)
        min_vals = stacked.min(axis=0)
        max_vals = stacked.max(axis=0)
        center = (min_vals + max_vals) / 2
        max_range = np.max(max_vals - min_vals) / 2
    else:
        # No joint positions at all: avoid inf/NaN bounds; limits are skipped below.
        center = np.zeros(3)
        max_range = 0.0

    def to_plot(pos):
        # Swap Y and Z so the BVH Y axis lands on Matplotlib's vertical (z) axis.
        return [pos[0], pos[2], pos[1]]

    def frame_interval_ms(speed=1.0):
        # Play-widget timer in ms; clamp to >= 1 so a zero/tiny Frame Time
        # never yields a 0 ms interval.
        return max(1, int((bvh.frame_time * 1000) / speed))

    def update_plot(frame_idx, show_joint_labels):
        """Redraw the skeleton for `frame_idx`; returns the created artists."""
        ax.clear()
        positions = all_positions[frame_idx]

        points = []
        segments = []
        text_objects = []

        for joint_name, pos in positions.items():
            plot_pos = to_plot(pos)
            points.append(plot_pos)

            if show_joint_labels:
                text_objects.append(
                    ax.text(plot_pos[0], plot_pos[1], plot_pos[2], joint_name, size=8)
                )

            # One bone from each joint to its parent.
            joint = bvh.joints[joint_name]
            if joint.parent is not None:
                segments.append((plot_pos, to_plot(positions[joint.parent.name])))

        xs, ys, zs = zip(*points) if points else ([], [], [])
        scatter_points = ax.scatter(xs, ys, zs, color='blue', s=20)

        line_objects = [
            ax.plot([a[0], b[0]], [a[1], b[1]], [a[2], b[2]], color='red')[0]
            for a, b in segments
        ]

        ax.set_xlabel('X')
        ax.set_ylabel('Z')
        ax.set_zlabel('Y')

        if max_range > 0:
            # Plot axes are (X, Z, Y) in BVH terms, hence the center index order.
            ax.set_xlim(center[0] - max_range, center[0] + max_range)
            ax.set_ylim(center[2] - max_range, center[2] + max_range)
            ax.set_zlim(center[1] - max_range, center[1] + max_range)

        ax.set_title(f"Frame {frame_idx} / {bvh.frames-1}")

        return [scatter_points] + line_objects + text_objects

    output_widget = widgets.Output()

    frame_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=bvh.frames-1,
        step=1,
        description='Frame:',
        continuous_update=True
    )

    play_button = widgets.Play(
        value=0,
        min=0,
        max=bvh.frames-1,
        step=1,
        interval=frame_interval_ms(),
        description="Play",
        disabled=False
    )

    speed_slider = widgets.FloatSlider(
        value=1.0,
        min=0.1,
        max=5.0,
        step=0.1,
        description='Speed:',
        continuous_update=True
    )

    labels_checkbox = widgets.Checkbox(
        value=show_labels,
        description='Show labels',
        disabled=False
    )

    def update_speed(change):
        play_button.interval = frame_interval_ms(speed_slider.value)

    speed_slider.observe(update_speed, names='value')

    # The Play widget drives the frame slider on the front end.
    widgets.jslink((play_button, 'value'), (frame_slider, 'value'))

    def update_animation(change):
        # display(fig) renders the figure itself; no pyplot draw call is needed
        # (plt.draw() would target pyplot's *current* figure, not `fig`, which
        # has been closed).
        with output_widget:
            clear_output(wait=True)
            update_plot(frame_slider.value, labels_checkbox.value)
            display(fig)

    frame_slider.observe(update_animation, names='value')
    labels_checkbox.observe(update_animation, names='value')

    controls = widgets.VBox([
        widgets.HBox([play_button, frame_slider]),
        widgets.HBox([speed_slider, labels_checkbox])
    ])

    display(controls)

    with output_widget:
        update_plot(0, show_labels)
        display(fig)

    display(output_widget)

    return bvh

def _preview_with_plotly(bvh, show_labels=False):
    """Show an interactive Plotly 3D animation of a parsed BVH file.

    Args:
        bvh: A parsed BVHMotion.
        show_labels: If True, joint names are shown as hover text.

    Returns:
        The same BVHMotion object.
    """
    import plotly.graph_objects as go
    from IPython.display import display

    # Pre-compute all frames data for better performance
    all_positions = [bvh.extract_frame_data(i) for i in range(bvh.frames)]

    frame_arrays = [np.array(list(p.values())) for p in all_positions if p]
    if not frame_arrays:
        # Without data the axis ranges below would be NaN.
        print("No motion frames to preview")
        return bvh

    # Global bounds over every frame give a consistent view.
    stacked = np.vstack(frame_arrays)
    min_vals = stacked.min(axis=0)
    max_vals = stacked.max(axis=0)
    center = (min_vals + max_vals) / 2
    max_range = np.max(max_vals - min_vals) / 2 * 1.2  # Add 20% padding
    if max_range <= 0:
        # A single static point would produce zero-width axes.
        max_range = 1.0

    fig = go.Figure()

    frames = []
    frame_indices = []  # BVH frame index of each Plotly frame (empty frames are skipped)

    for frame_idx, positions in enumerate(all_positions):
        if not positions:
            continue

        points_x, points_y, points_z = [], [], []
        lines_x, lines_y, lines_z = [], [], []
        labels = []

        for joint_name, pos in positions.items():
            # Coordinates are kept as in the file; the camera is told Y is up instead.
            points_x.append(pos[0])
            points_y.append(pos[1])
            points_z.append(pos[2])
            labels.append(joint_name)

            joint = bvh.joints[joint_name]
            if joint.parent is not None:
                parent_pos = positions[joint.parent.name]
                # A None breaks the polyline so every bone is its own segment.
                lines_x.extend([pos[0], parent_pos[0], None])
                lines_y.extend([pos[1], parent_pos[1], None])
                lines_z.extend([pos[2], parent_pos[2], None])

        frame_data = [
            # Points (joints)
            go.Scatter3d(
                x=points_x, y=points_y, z=points_z,
                mode='markers',
                marker=dict(size=6, color='blue'),
                text=labels if show_labels else None,
                hoverinfo='text' if show_labels else 'none',
                showlegend=False
            ),
            # Lines (skeleton)
            go.Scatter3d(
                x=lines_x, y=lines_y, z=lines_z,
                mode='lines',
                line=dict(color='red', width=4),
                hoverinfo='none',
                showlegend=False
            )
        ]

        frames.append(go.Frame(data=frame_data, name=f"frame_{frame_idx}"))
        frame_indices.append(frame_idx)

    # Initial data is the first frame; frames is non-empty because frame_arrays was.
    fig.add_traces(frames[0].data)
    fig.frames = frames

    # Slider: at most ~100 steps, only for frames that exist, always including the last.
    stride = max(1, len(frame_indices) // 100)
    step_indices = frame_indices[::stride]
    if step_indices[-1] != frame_indices[-1]:
        step_indices.append(frame_indices[-1])

    animation_settings = dict(
        frame=dict(duration=int(bvh.frame_time * 1000), redraw=True),
        # Plotly's default 500 ms easing between frames would fight the frame timing.
        transition=dict(duration=0),
        fromcurrent=True,
        mode="immediate"
    )

    fig.update_layout(
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[None, animation_settings]
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[[None], dict(frame=dict(duration=0, redraw=False), mode="immediate")]
                    )
                ],
                x=0.1,
                y=0,
                xanchor="right",
                yanchor="top"
            )
        ],
        sliders=[
            dict(
                active=0,
                steps=[
                    dict(
                        method="animate",
                        args=[
                            [f"frame_{k}"],
                            dict(
                                mode="immediate",
                                frame=dict(duration=0, redraw=True),
                                transition=dict(duration=0)
                            )
                        ],
                        label=f"{k}"
                    )
                    for k in step_indices
                ],
                x=0.1,
                y=0,
                currentvalue=dict(
                    prefix="Frame: ",
                    visible=True,
                    xanchor="right"
                ),
                len=0.9,
                pad=dict(b=10, t=50),
                xanchor="left",
                yanchor="top"
            )
        ]
    )

    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[center[0] - max_range, center[0] + max_range]),
            yaxis=dict(range=[center[1] - max_range, center[1] + max_range]),
            zaxis=dict(range=[center[2] - max_range, center[2] + max_range]),
            aspectmode='cube',
            # The Matplotlib preview draws BVH Y vertically; do the same here by
            # pointing the camera's up-vector along +Y instead of remapping data,
            # so hover values stay in file coordinates.
            camera=dict(up=dict(x=0, y=1, z=0))
        ),
        margin=dict(l=0, r=0, b=0, t=30),
        height=700,
        title=dict(
            text="BVH Animation",
            x=0.5,
            xanchor="center"
        )
    )

    display(fig)

    return bvh

In [18]:
# Preview with joint labels, preferring Plotly (falls back to Matplotlib if Plotly can't be imported).
viewer = preview_bvh('bvh/fight2.bvh', show_labels=True, use_plotly=True)

VBox(children=(HBox(children=(Play(value=0, description='Play', interval=66, max=21), IntSlider(value=0, descr…

Output()