---
title: "Keypoint Visualizer"
author: "Ali Zaidi"
date: "2025-11-19"
categories: [Data Visualization]
description: "Once we have the resources to extract keypoints and normalize them, let's see how they look side by side"
format:
  html:
    code-fold: true
jupyter: python3
---

In [56]:
#| include: false
from fastai.vision.all import *
from swing_class import *
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from utils import *

In [48]:
#| include: false
# Recording sessions available under the player's data folder; the last
# entry is the one inspected in this notebook.
swing_days = ['jun8', 'aug9', 'sep14']
base_path = '../../../data/full_videos/ymirza'
parent_dir = f'{base_path}/{swing_days[-1]}'
# Keep only the pickle files whose filename begins with 'IMG'.
# NOTE(review): later cells index into this list by position (files[4:7]);
# confirm get_files returns a stable ordering.
files = [
    pkl
    for pkl in get_files(parent_dir, extensions='.pkl')
    if pkl.name.startswith('IMG')
]

In [49]:
#| include: false
# Hand-assigned score for each of the three clips compared below.
scores = [1, 3, 1]
# Clip name = filename up to the first '.', minus its last two
# underscore-separated fields.
names = [
    '_'.join(pkl.name.split('.')[0].split('_')[:-2])
    for pkl in files[4:7]
]
# One subplot title per clip: "<clip name>_score_<score>".
lbls = [f'{names[x]}_score_{scores[x]}' for x, _ in enumerate(scores)]
# Keypoint arrays for the same three files (positions 4, 5 and 6).
kps_holder = [KpExtractor(files[pos]).kps for pos in (4, 5, 6)]
# Cell output: number of clips and the shape of each keypoint array.
(len(kps_holder), [kps.shape for kps in kps_holder])

(3, [(180, 17, 3), (180, 17, 3), (180, 17, 3)])

In [50]:
#| code-fold: true
def animate_keypoints(keypoint_sequences, dark_mode=False,
                      labels=None, vertical=False, fps=60, skeleton=None):
    """
    Animate 1 to 3 keypoint sequences as stick figures, one subplot each.

    Keypoints whose x or y is NaN or not greater than 0.1 are treated as
    missing: they are not drawn and any bone touching them is hidden for
    that frame. Sequences shorter than the longest one hold their last frame
    until the animation ends.

    Args:
        keypoint_sequences: A single array, or a list/tuple of arrays, each
            shaped (Frames, KPs, 2) or (Frames, KPs, 3). Only the first two
            channels (x, y) are plotted. If more than 3 sequences are given,
            only the first 3 are used.
        dark_mode (bool): Black background with lime bones and white joints
            if True; white background with black bones and red joints if False.
        labels (list[str] | None): Optional subplot titles, matched to the
            sequences by position.
        vertical (bool): If True, stacks plots vertically. If False, places
            them side-by-side.
        fps (int): Frames per second of the animation.
        skeleton (list[tuple[int, int]] | None): Pairs of keypoint indices to
            connect with a line. Defaults to the 17-keypoint layout defined
            below. Every index must be valid for each sequence's KPs axis.

    Returns:
        matplotlib.animation.FuncAnimation: The animation. Its figure is
        closed before returning so it is not also rendered as a static plot.

    Raises:
        ValueError: If no keypoint sequence is provided.
        ZeroDivisionError: If ``fps`` is 0.
    """
    if skeleton is None:
        skeleton = [
            (0, 1), (0, 2), (1, 3), (2, 4),  # Face
            (5, 6), (5, 7), (7, 9),          # Left Arm
            (6, 8), (8, 10),                 # Right Arm
            (5, 11), (6, 12),                # Torso
            (11, 12),                        # Hips
            (11, 13), (13, 15),              # Left Leg
            (12, 14), (14, 16)               # Right Leg
        ]

    # 1. Input Normalization
    if isinstance(keypoint_sequences, (list, tuple)):
        keypoint_sequences = list(keypoint_sequences)
    else:
        keypoint_sequences = [keypoint_sequences]

    if not keypoint_sequences:
        raise ValueError("animate_keypoints needs at least one keypoint sequence.")

    if len(keypoint_sequences) > 3:
        print("Warning: Limiting visualization to first 3 sequences.")
        keypoint_sequences = keypoint_sequences[:3]

    num_plots = len(keypoint_sequences)

    # Validate labels
    if labels and len(labels) != num_plots:
        print(f"Warning: Provided {len(labels)} labels for {num_plots} plots. Labels may not match.")

    # Computed before any figure exists so that a bad fps cannot leave an
    # orphaned figure behind.
    interval = 1000 / fps

    # 2. Style Configuration
    if dark_mode:
        bg_color = 'black'
        line_color = 'lime'
        joint_color = 'white'
        text_color = 'white'
    else:
        bg_color = 'white'
        line_color = 'black'
        joint_color = 'red'
        text_color = 'black'

    # 3. Figure & Subplot Setup
    if vertical:
        nrows, ncols = num_plots, 1
        figsize = (2, 2 * num_plots)
    else:
        nrows, ncols = 1, num_plots
        figsize = (2 * num_plots, 2)

    # squeeze=False always yields a 2-D array of axes, so a single plot needs
    # no special handling.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    fig.patch.set_facecolor(bg_color)
    axes = axes.ravel()

    # One (scatter, bone lines, xy data) entry per subplot.
    plot_objects = []

    for i, (ax, data) in enumerate(zip(axes, keypoint_sequences)):
        # Drop the third channel when present; only x and y are plotted.
        if data.shape[-1] == 3:
            data = data[..., :2]

        ax.set_facecolor(bg_color)
        ax.axis('off')
        ax.set_aspect('equal')

        # --- Add Label if provided ---
        if labels and i < len(labels):
            ax.set_title(labels[i], color=text_color, fontsize=8, pad=8)

        # --- LIMITS CONFIGURATION ---
        # Fixed limits over the whole sequence keep the view from jumping
        # between frames.
        all_x = data[..., 0].flatten()
        all_y = data[..., 1].flatten()
        valid_mask = (all_x > 0.1) & (all_y > 0.1)

        if valid_mask.any():
            vx, vy = all_x[valid_mask], all_y[valid_mask]
            pad = 50
            ax.set_xlim(vx.min() - pad, vx.max() + pad)
            # Pass (MAX, MIN) to put max values (feet) at the bottom
            ax.set_ylim(vy.max() + pad, vy.min() - pad)
        else:
            # No valid point anywhere: fall back to a 640x480 frame.
            ax.set_xlim(0, 640)
            ax.set_ylim(480, 0)

        # Create graphics objects
        scat = ax.scatter([], [], s=30, c=joint_color, zorder=2)
        lines = [ax.plot([], [], color=line_color, lw=2)[0] for _ in skeleton]

        plot_objects.append((scat, lines, data))

    # 4. Animation Logic
    def init():
        """Blank every artist and return them all, as blitting requires."""
        all_artists = []
        for scat, lines, _ in plot_objects:
            scat.set_offsets(np.empty((0, 2)))
            for line in lines:
                line.set_data([], [])
            all_artists.append(scat)
            all_artists.extend(lines)
        return all_artists

    def update(frame_idx):
        """Draw frame ``frame_idx`` in every subplot and return the artists."""
        all_artists = []
        for scat, lines, data in plot_objects:
            # Shorter sequences freeze on their last frame.
            current_frame = data[min(frame_idx, len(data) - 1)]

            # NaN compares False against 0.1, so NaN coordinates are
            # excluded without a separate isnan check.
            mask = (current_frame[:, 0] > 0.1) & (current_frame[:, 1] > 0.1)

            scat.set_offsets(current_frame[mask])
            all_artists.append(scat)

            for line, (start, end) in zip(lines, skeleton):
                if mask[start] and mask[end]:
                    line.set_data(
                        [current_frame[start, 0], current_frame[end, 0]],
                        [current_frame[start, 1], current_frame[end, 1]]
                    )
                else:
                    line.set_data([], [])
                all_artists.append(line)

        return all_artists

    max_frames = max(len(d) for d in keypoint_sequences)

    anim = FuncAnimation(
        fig,
        update,
        frames=max_frames,
        init_func=init,
        blit=True,
        interval=interval
    )

    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
    return anim

In [51]:
#| echo: false
# Run the frame search on each clip with its first 30 frames skipped, then
# add 30 back so the result indexes the full clip. Per the message printed
# below, this is the frame where the straight arm is found in the backswing.
top_idxs = [30 + get_frame_plot(kps[30:])[0] for kps in kps_holder]

# Shortest clip length and the latest detected frame across the clips.
lowest_frame_count = np.min([kps.shape[0] for kps in kps_holder])
highest_peak_frame = np.max(top_idxs)

# Frames left between the latest detected frame and the end of the shortest
# clip; used as a symmetric margin around that frame.
diff = lowest_frame_count - highest_peak_frame
start_idx, end_idx = highest_peak_frame - diff, highest_peak_frame + diff

print(
    'The frame index where the straight arm is found '
    f'in the backswing are: {top_idxs}'
)
print(f'The highest peak frame where this is found is: {highest_peak_frame}')
print(f'The lowest frame count in any of our clips is: {lowest_frame_count}')
print(f'We have a difference of {diff} frames between where this happens')
print(
    'The indexes we will use to index all these videos are:\n '
    f'start:{start_idx} and end:{end_idx}'
)

The frame index where the straight arm is found in the backswing are: [136, 141, 136]
The highest peak frame where this is found is: 141
The lowest frame count in any of our clips is: 180
We have a difference of 39 frames between where this happens
The indexes we will use to index all these videos are:
 start:102 and end:180


In [52]:
#| include: false
# Window each clip around its detected frame: 70 frames before, 45 after.
# The start is clamped at 0 because a negative slice start would wrap around
# and count from the end of the clip instead of beginning at its first frame.
idx_bounds = [(max(top - 70, 0), top + 45) for top in top_idxs]
# Slicing truncates an end index that runs past the clip, so windows can be
# shorter than 115 frames.
test_kps = [kps[lo:hi] for kps, (lo, hi) in zip(kps_holder, idx_bounds)]
[kps.shape for kps in test_kps]

[(114, 17, 3), (109, 17, 3), (114, 17, 3)]

In [55]:
# Stack the three windowed clips vertically in dark mode, titled with their
# score labels, and embed the result as an interactive JS/HTML player.
animator = animate_keypoints(
    test_kps,
    labels=lbls,
    dark_mode=True,
    vertical=True,
)
display(HTML(animator.to_jshtml()))

Now let's see how these keypoints look once normalized...