# Tendon Gliding Hand Action Recognition

# Description

- The goal is to perform Tendon Gliding Hand Action Recognition by classifying hand postures into five distinct classes using real-time video. Additionally, the task involves calculating the accuracy of hand skeleton occurrences compared to the ground truth.
- The classes are as follows: **Hand Open**, **Intrinsic Plus**, **Straight Fist**, **Hand Close**, and **Hook Hand**.  

<p align="center">
    <img src='../../IMAGES/Classes.jpg' width="500px" />
</p>

- The base workflow is as follows: 

<div align="center">

```mermaid
graph LR
    A[Import RAW Data]
    A --> B[Separate into 5 different classes]
    B --> C[Generate 2D keypoints of X and Y coordinates]
    C --> D[Train LSTM model]
    D --> E[Evaluate Performance]
    E --> F[Metrics: Accuracy, Specificity, Sensitivity, F1-Score, Confusion Matrix]
```

</div>

## 0. Import Libraries

In [1]:
import os

# Must be set BEFORE albumentations is imported: the library decides at import
# time whether to run its online version check (`check_for_updates()`), so
# setting the variable afterwards has no effect.
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"

import numpy as np
import matplotlib.pyplot as plt
import torch
import rerun as rr
import rerun.blueprint as rrb
import seedir as sd
import albumentations as A
from pyorbbecsdk import *
import cv2
from utils import frame_to_bgr_image
from pathlib import Path
import shutil
from collections import defaultdict


## 1. Utility Functions

In [2]:
def list_bag_files(root_dir='data', exclude_keyword='open-fist'):
    """
    List all .bag files under ``root_dir``, skipping folders whose name contains ``exclude_keyword``.

    The result is sorted so its order is deterministic. ``os.walk`` yields
    entries in filesystem-dependent order, but callers slice this list by
    position (e.g. "first three" / "last three" recordings).

    Only the files directly inside an excluded folder are skipped; the walk
    still descends into that folder's sub-folders.

    Args:
        root_dir (str): Root directory to start the search from.
        exclude_keyword (str): Substring marking a directory whose files are
            skipped. Defaults to 'open-fist'. Pass '' or None to disable exclusion.

    Returns:
        list: Sorted list of paths to .bag files.
    """
    bag_files = []

    for root, _dirs, files in os.walk(root_dir):
        # Skip the files of directories whose name contains the keyword.
        if exclude_keyword and exclude_keyword in os.path.basename(root):
            continue

        bag_files.extend(
            os.path.join(root, file) for file in files if file.endswith('.bag')
        )

    return sorted(bag_files)

def playback_state_callback(state):
    """
    React to bag-player state changes reported by the Orbbec playback object.

    Prints a short status line for the begin, end and paused states. On end it
    also sets the module-level ``playback_finished`` flag, which the frame loop
    in ``process_frames`` checks to stop reading. Other states are ignored.

    Args:
        state: An ``OBMediaState`` value supplied by the SDK.
    """
    global playback_finished
    if state == OBMediaState.OB_MEDIA_END:
        print("Bag player end")
        # Let the frame-reading loop know that playback is over.
        playback_finished = True
    elif state == OBMediaState.OB_MEDIA_BEGIN:
        print("Bag player begin")
    elif state == OBMediaState.OB_MEDIA_PAUSED:
        print("Bag player paused")

def process_frames(bag_file):
    """
    Play back an Orbbec ``.bag`` recording and collect its depth and color frames.

    Reading continues until ``playback_state_callback`` reports the end of the
    recording by setting the module-level ``playback_finished`` flag. The
    pipeline is always stopped before returning, even if processing raises.

    Args:
        bag_file (str): Path to the ``.bag`` file to play back.

    Returns:
        depth_image_list: List of 2D depth maps, one per depth frame (raw
            uint16 values cast to float32 and multiplied by the frame's depth scale).
        color_image_list: List of color images, one per usable color frame. The
            R and B channels are swapped relative to what ``frame_to_bgr_image``
            returns (via ``cv2.COLOR_RGB2BGR``).
    """
    global playback_finished
    playback_finished = False  # Reset flag; the playback callback sets it at end of file

    pipeline = Pipeline(bag_file)
    playback = pipeline.get_playback()
    playback.set_playback_state_callback(playback_state_callback)

    depth_image_list = []
    color_image_list = []

    pipeline.start()
    try:
        while not playback_finished:
            # Wait up to 100 ms for the next frameset.
            frames = pipeline.wait_for_frames(100)
            if frames is None:
                if playback_finished:
                    print("All frames have been processed and converted successfully.")
                    break
                continue

            # Retrieve frames once per iteration
            color_frame = frames.get_color_frame()
            depth_frame = frames.get_depth_frame()

            if depth_frame is not None:
                width = depth_frame.get_width()
                height = depth_frame.get_height()
                scale = depth_frame.get_depth_scale()

                # Raw depth buffer -> (height, width) grid -> scaled float values.
                depth_data = np.frombuffer(depth_frame.get_data(), dtype=np.uint16)
                depth_data = depth_data.reshape((height, width))
                depth_data = depth_data.astype(np.float32) * scale
                depth_image_list.append(depth_data)

            if color_frame is not None:
                color_data = frame_to_bgr_image(color_frame)
                # NOTE(review): frame_to_bgr_image is assumed to return None for
                # color formats it cannot decode; skip such frames instead of
                # crashing inside cv2.resize.
                if color_data is not None:
                    width = color_frame.get_width()
                    height = color_frame.get_height()
                    color_image = cv2.resize(color_data, (width, height))
                    # Swaps the R and B channels of frame_to_bgr_image's output.
                    color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
                    color_image_list.append(color_image)
    finally:
        # Release the playback device even when processing fails midway.
        pipeline.stop()

    return depth_image_list, color_image_list

def rerun_visualization(*image_lists):
    """
    Stream several image sequences side by side to a newly spawned Rerun viewer.

    Sequence ``i`` is logged under the entity path ``color_image_{i}`` and shown
    in its own 2D view. The views are laid out in a grid of at most three
    columns. All sequences share the ``frame`` timeline, so frame ``k`` of each
    sequence appears at the same step. Shorter sequences just stop contributing
    once they run out.

    Args:
        *image_lists: Any number of image sequences accepted by ``rr.Image``.

    Returns:
        None. When no sequences are given it returns immediately without
        spawning a viewer.
    """
    if not image_lists:
        # Nothing to show; avoids max() of an empty sequence and a 0-column grid.
        return

    stream = rr.new_recording("spawn", spawn=True)

    # One 2D view per sequence; the origins match the entity paths logged below.
    spatial_views = [
        rrb.Spatial2DView(origin=f'/color_image_{i}') for i in range(len(image_lists))
    ]

    num_columns = min(3, len(image_lists))  # Max 3 columns
    blueprint = rrb.Blueprint(
        rrb.Grid(*spatial_views, grid_columns=num_columns),
        collapse_panels=True
    )

    # Run the timeline for as long as the longest sequence.
    max_length = max(len(image_list) for image_list in image_lists)

    for idx in range(max_length):
        stream.set_time_sequence("frame", idx)
        for list_idx, image_list in enumerate(image_lists):
            if idx < len(image_list):
                stream.log(f"color_image_{list_idx}", rr.Image(image_list[idx]))

    stream.send_blueprint(blueprint)

def sliding_window_sample(images, window_size=16, stride=8):
    """
    Cut a frame sequence into fixed-length windows.

    Windows start at indices 0, stride, 2*stride, ... They overlap whenever
    ``stride < window_size``. Only windows that fit completely inside the
    sequence are produced; trailing frames that cannot fill a window are dropped.

    Args:
        images: Sequence of equally shaped arrays (e.g. a list of frames).
        window_size: Number of consecutive frames per window.
        stride: Distance between the start indices of consecutive windows.

    Returns:
        list of np.ndarray: Each entry stacks ``window_size`` frames along a new
        first axis. Empty when the sequence is shorter than ``window_size``.
    """
    last_start = len(images) - window_size
    return [
        np.stack(images[first:first + window_size])
        for first in range(0, last_start + 1, stride)
    ]

def visualize_windows(windows, max_windows_to_display=None):
    """
    Play back several image windows side by side in a newly spawned Rerun viewer.

    Window ``i`` is logged under the entity path ``window_{i}`` and shown in its
    own 2D view. The views are arranged in a grid of at most three columns. All
    windows share the ``frame`` timeline, so frame ``k`` of every window appears
    at the same step. The number of frames per window is taken from ``windows[0]``.

    Args:
        windows: Sequence of image windows, each indexable by frame first (for
            example the output of ``sliding_window_sample``).
        max_windows_to_display: Show at most this many windows (the first ones)
            to avoid overcrowding the grid. ``None`` (default) shows all windows.

    Returns:
        None. When there is nothing to show it prints a message and returns
        without spawning a viewer.
    """
    total_windows = len(windows)
    num_windows = total_windows
    if max_windows_to_display is not None:
        num_windows = min(num_windows, max_windows_to_display)

    if num_windows <= 0:
        # Nothing to draw; also avoids indexing windows[0] and a 0-column grid.
        print("No windows to visualize")
        return

    stream = rr.new_recording("spawn", spawn=True)

    # One 2D view per displayed window; the origins match the entity paths logged below.
    spatial_views = [rrb.Spatial2DView(origin=f'/window_{i}') for i in range(num_windows)]

    num_columns = min(3, num_windows)  # Max 3 columns
    blueprint = rrb.Blueprint(
        rrb.Grid(*spatial_views, grid_columns=num_columns),
        collapse_panels=True
    )

    window_size = windows[0].shape[0]  # Frames per window

    for frame_idx in range(window_size):
        stream.set_time_sequence("frame", frame_idx)
        for window_idx in range(num_windows):
            stream.log(f"window_{window_idx}",
                       rr.Image(windows[window_idx][frame_idx]))

    print(f"Visualizing {num_windows} windows out of {total_windows} total")
    print(f"Each window contains {window_size} frames")
    print(f"Window shape: {windows[0].shape}")

    stream.send_blueprint(blueprint)

def save_windowed_data(recordings, bag_files, window_size=16, stride=8, windowed_data_dir='windowed_data'):
    """
    Split each recording's RGB and depth frames into windows and save each window as a ``.npy`` file.

    Recording ``i`` (key ``"recording_{i}"`` in ``recordings``) is paired with
    ``bag_files[i]``. Its windows are written next to that bag file, into
    ``<bag dir>/<windowed_data_dir>/rgb/`` and ``<bag dir>/<windowed_data_dir>/depth/``,
    with names of the form ``<bag stem>_window_NNN.npy``.

    Args:
        recordings: Dict mapping ``"recording_{i}"`` to a dict that may contain
            ``"color_images"`` and/or ``"depth_images"`` frame sequences.
        bag_files: Bag file paths, ordered so that index ``i`` matches ``"recording_{i}"``.
        window_size: Number of frames per window.
        stride: Distance between the start frames of consecutive windows.
        windowed_data_dir: Name of the folder created next to each bag file.
    """
    print(f"Saving windowed data for {len(recordings)} recordings...")

    # (key inside a recording dict, output sub-folder name)
    key_to_folder = (("color_images", "rgb"), ("depth_images", "depth"))

    for recording_idx, bag_file_path in enumerate(bag_files):
        recording_key = f"recording_{recording_idx}"

        if recording_key not in recordings:
            print(f"Skipping recording {recording_idx}: Recording not found")
            continue

        recording_data = recordings[recording_key]
        bag_path = Path(bag_file_path)
        base_output_dir = bag_path.parent / windowed_data_dir
        file_stem = bag_path.stem

        for image_key, folder_name in key_to_folder:
            images = recording_data.get(image_key)
            if images is None:
                print(f"Skipping {image_key} for {recording_key}: Data not available")
                continue

            windows = sliding_window_sample(images, window_size=window_size, stride=stride)

            output_dir = base_output_dir / folder_name
            os.makedirs(output_dir, exist_ok=True)

            # One file per window: <stem>_window_000.npy, <stem>_window_001.npy, ...
            for window_idx, window in enumerate(windows):
                np.save(output_dir / f"{file_stem}_window_{window_idx:03d}.npy", window)

            print(f"Saved {len(windows)} {folder_name} windows for {recording_key} to {output_dir}")

    print("Windowed data saving complete!")

def prepare_recordings_dict(processed_data, num_recordings=6):
    """
    Select the entries ``recording_0`` .. ``recording_{num_recordings - 1}`` from ``processed_data``.

    Indices missing from ``processed_data`` are skipped, and any other keys are
    left out. The values are the same objects as in ``processed_data`` (no copy).

    Args:
        processed_data: Dict keyed by ``"recording_{i}"`` (other keys are ignored).
        num_recordings: Number of recording indices to look up, starting at 0.
            Defaults to 6, the number of recordings in this project.

    Returns:
        dict: The selected ``recording_{i}`` entries in index order.
    """
    wanted_keys = (f"recording_{i}" for i in range(num_recordings))
    return {key: processed_data[key] for key in wanted_keys if key in processed_data}

## 2. Data Preparation

- In total, there are six recordings of the tendon gliding task:
    - From 📁 `20250402`
        - `Record_20250402151124.bag`
        - `Record_20250402151609.bag`
        - `Record_20250402152331.bag`
    - From 📁 `20250506`
        - `Record_20250506145704.bag`
        - `Record_20250506152951.bag`
        - `Record_20250506162630.bag`
- Each recording consists of RGB frames and depth frames.
- Workflow: 

<div align="center">

```mermaid
graph LR
    A[Data Extraction] --> B[Data Visualization]
    B --> C[Windowing Process]
    C --> D[Save Windowed Data]
    D --> E[Data Class Generation]
```

</div>

### 2.1. Data Extraction

In [None]:
bag_files = list_bag_files()
# The first three bag files belong to the 20250402 session, the last three to 20250506.
data_20250402, data_20250506 = bag_files[:3], bag_files[-3:]
processed_data = {}

# (dataset label, its bag files, offset added to the per-dataset index).
# The offset keeps recording keys unique across datasets: 20250402 files start
# at recording_0 and 20250506 files start at recording_3.
datasets_to_process = (
    ("20250402", data_20250402, 0),
    ("20250506", data_20250506, 3),
)

for dataset_label, dataset_files, key_offset in datasets_to_process:
    for idx, bag_file in enumerate(dataset_files):
        print(f"Processing bag file: {bag_file}")
        depth_images, color_images = process_frames(bag_file)
        processed_data[f"recording_{idx + key_offset}"] = {
            "dataset": dataset_label,
            "depth_images": depth_images,
            "color_images": color_images,
        }

### 2.2. Data Visualization

In [None]:
# Per-recording handles (None when a recording is missing from processed_data).
(recording_0, recording_1, recording_2,
 recording_3, recording_4, recording_5) = (processed_data.get(f"recording_{i}") for i in range(6))

# Gather the color-image sequence of every recording that has one.
image_lists = []
for recording_idx in range(6):  # recordings are keyed recording_0 .. recording_5
    recording_key = f"recording_{recording_idx}"
    if recording_key not in processed_data:
        continue
    if "color_images" not in processed_data[recording_key]:
        continue
    image_lists.append(processed_data[recording_key]["color_images"])

# Show all gathered sequences together in one Rerun viewer.
if len(image_lists) > 0:
    rerun_visualization(*image_lists)

### 2.3. Detect Hands

In [3]:
import os
import numpy as np
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmpose.utils import adapt_mmdet_pipeline
from mmpose.evaluation.functional import nms
import warnings
import logging
from typing import Tuple, List, Optional, Union, Dict, Any

# Module-level cache of initialized hand detectors, shared by detect_hands() calls
# and keyed by f"{det_config}_{det_checkpoint}_{device}", so each detector is
# built (and its checkpoint loaded) only once instead of on every call.
MODEL_CACHE: Dict[str, Any] = {}

def detect_hands(img: Union[str, np.ndarray, List[Union[str, np.ndarray]]], 
                det_config: str = 'configs/rtmdet_nano_320-8xb32_hand.py',
                det_checkpoint: str = 'https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_nano_8xb32-300e_hand-267f9c8f.pth',
                device: str = 'cuda:0',
                det_cat_id: int = 0,
                bbox_thr: float = 0.3,
                nms_thr: float = 0.3,
                suppress_warnings: bool = True,
                input_size: Tuple[int, int] = (300, 300),
                crop_size: Tuple[int, int] = (600, 600),
                return_coordinates: str = 'original') -> Union[
                    Tuple[np.ndarray, np.ndarray, Dict[str, Any]],  # For 'original' mode
                    Tuple[np.ndarray, np.ndarray, Dict[str, Any], np.ndarray],  # For 'crop' or 'model' modes
                    List[Union[
                        Tuple[np.ndarray, np.ndarray, Dict[str, Any]],  # List of results in 'original' mode
                        Tuple[np.ndarray, np.ndarray, Dict[str, Any], np.ndarray]  # List of results in 'crop' or 'model' modes
                    ]]
                ]:
    """
    Detect hand bounding boxes in a single image or a list of images.

    Processing steps for each image:
    1. Center-crop the image to ``crop_size``.
    2. Resize the crop to ``input_size`` and run the MMDetection detector on it.
    3. Keep boxes with label ``det_cat_id`` and score > ``bbox_thr``.
    4. De-duplicate them with NMS.
    5. Map the boxes into the coordinate system chosen by ``return_coordinates``.

    Detectors are cached in the module-level ``MODEL_CACHE``, keyed by
    config, checkpoint and device.

    Args:
        img: Image path, numpy array, or list of image paths/arrays. Paths are
            read as RGB. A 3-channel numpy array goes through ``mmcv.bgr2rgb``
            before inference, i.e. it is treated as BGR input.
            NOTE(review): earlier docs described array input as RGB; confirm
            which channel order callers actually pass.
        det_config: Path to detection config file.
        det_checkpoint: Path or URL to detection checkpoint.
        device: Device to run inference on.
        det_cat_id: Category ID for hands.
        bbox_thr: Minimum (exclusive) confidence for a box to be kept.
        nms_thr: IoU threshold for non-maximum suppression.
        suppress_warnings: Whether to silence Python warnings and set the
            mmengine/mmdet/mmpose loggers to ERROR. This changes process-wide state.
        input_size: Model input size (width, height), defaults to (300, 300).
        crop_size: Size for center crop (width, height), defaults to (600, 600).
        return_coordinates: Coordinate system for returned bounding boxes:
                          - 'original': In original input image coordinates (default)
                          - 'crop': In center-cropped image coordinates
                          - 'model': In model input size coordinates (300x300 by default)

    Returns:
        For a single image, a tuple:
            bboxes: (N, 4) array of [x1, y1, x2, y2] boxes in the requested
                coordinate system, clipped to that system's image bounds.
            scores: (N,) array of confidence scores.
            crop_info: Dict with 'original_size', 'crop_size', 'input_size' and
                'crop_offset' (x, y) for mapping coordinates.
            processed_img: Present only for 'model' (the resized model input)
                and 'crop' (the center-cropped image); omitted for 'original'.
        For a list of images: a list of such tuples, one per image.

    Raises:
        ValueError: If ``return_coordinates`` is not 'original', 'crop' or 'model'.
    """
    global MODEL_CACHE

    valid_return_options = ['original', 'crop', 'model']
    if return_coordinates not in valid_return_options:
        raise ValueError(f"return_coordinates must be one of {valid_return_options}, got {return_coordinates}")

    if suppress_warnings:
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=FutureWarning)
        for logger_name in ('mmengine', 'mmdet', 'mmpose'):
            logging.getLogger(logger_name).setLevel(logging.ERROR)
        # Only affects Python interpreters started after this point (subprocesses).
        os.environ['PYTHONWARNINGS'] = 'ignore::FutureWarning'

    # Build the detector once per (config, checkpoint, device) and reuse it.
    model_key = f"{det_config}_{det_checkpoint}_{device}"
    if model_key not in MODEL_CACHE:
        detector = init_detector(det_config, det_checkpoint, device=device)
        detector.cfg = adapt_mmdet_pipeline(detector.cfg)
        MODEL_CACHE[model_key] = detector
    else:
        detector = MODEL_CACHE[model_key]

    def _clip_boxes(boxes, width, height):
        """Clip x columns (0, 2) to [0, width] and y columns (1, 3) to [0, height], in place."""
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    def _to_output_coordinates(boxes, crop_offset, original_width, original_height):
        """Map boxes from model-input coordinates to the requested system, in place."""
        if return_coordinates == 'model':
            # Already in model-input coordinates; clipping is normally a no-op.
            _clip_boxes(boxes, input_size[0], input_size[1])
            return

        # Undo the resize: model input -> center-crop coordinates.
        boxes[:, [0, 2]] *= crop_size[0] / input_size[0]
        boxes[:, [1, 3]] *= crop_size[1] / input_size[1]

        if return_coordinates == 'crop':
            _clip_boxes(boxes, crop_size[0], crop_size[1])
        else:  # 'original'
            # Undo the center crop: shift by the crop's top-left corner.
            boxes[:, [0, 2]] += crop_offset[0]
            boxes[:, [1, 3]] += crop_offset[1]
            _clip_boxes(boxes, original_width, original_height)

    def _process_single_image(current_img):
        """Run detection on one image and return the tuple described in the docstring."""
        if isinstance(current_img, str):
            current_img = mmcv.imread(current_img, channel_order='rgb')
        elif isinstance(current_img, np.ndarray) and current_img.shape[-1] == 3:
            # 3-channel arrays are treated as BGR and converted to RGB.
            current_img = mmcv.bgr2rgb(current_img)

        original_height, original_width = current_img.shape[:2]

        # Center crop, then resize the crop to the detector's input size.
        cropped_img = A.CenterCrop(height=crop_size[1], width=crop_size[0])(image=current_img)['image']
        processed_img = A.Resize(height=input_size[1], width=input_size[0])(image=cropped_img)['image']

        # Top-left corner of the crop in original-image coordinates.
        crop_x_offset = max(0, (original_width - crop_size[0]) // 2)
        crop_y_offset = max(0, (original_height - crop_size[1]) // 2)

        crop_info = {
            'original_size': (original_width, original_height),
            'crop_size': crop_size,
            'input_size': input_size,
            'crop_offset': (crop_x_offset, crop_y_offset)
        }

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # Suppress torch JIT profiling warnings
            torch._C._jit_set_profiling_executor(False)
            torch._C._jit_set_profiling_mode(False)
            det_result = inference_detector(detector, processed_img)

        pred_instance = det_result.pred_instances.cpu().numpy()

        # Defaults used when nothing survives filtering.
        out_bboxes = np.empty((0, 4), dtype=np.float32)
        out_scores = np.empty((0,), dtype=np.float32)

        if len(pred_instance.bboxes) > 0:
            # Append scores as a 5th column: NMS expects [x1, y1, x2, y2, score].
            dets = np.concatenate(
                (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
            keep_mask = np.logical_and(pred_instance.labels == det_cat_id,
                                       pred_instance.scores > bbox_thr)
            dets = dets[keep_mask]

            if len(dets) > 0:
                keep_indices = nms(dets, nms_thr)
                # Advanced indexing copies, so in-place mapping below is safe.
                out_bboxes = dets[keep_indices, :4]
                out_scores = dets[keep_indices, 4]
                if len(out_bboxes) > 0:
                    _to_output_coordinates(out_bboxes, (crop_x_offset, crop_y_offset),
                                           original_width, original_height)

        if return_coordinates == 'model':
            return out_bboxes, out_scores, crop_info, processed_img
        if return_coordinates == 'crop':
            return out_bboxes, out_scores, crop_info, cropped_img
        return out_bboxes, out_scores, crop_info

    if isinstance(img, list):
        return [_process_single_image(single_img) for single_img in img]
    return _process_single_image(img)

def visualize_hand_detections(image, bboxes, scores=None, thickness=2, color=(0, 255, 0), 
                             display_scores=True, score_threshold=0.0, 
                             figsize=(12, 12), save_path=None, display=True):
    """
    Draw hand bounding boxes (and optionally their scores) on a copy of an image.

    Args:
        image: Image as a numpy array (RGB); the input array is not modified.
        bboxes: Numpy array of bounding boxes in [x1, y1, x2, y2] format.
        scores: Optional numpy array of confidence scores, one per box.
        thickness: Thickness of the box lines and label text.
        color: Box/label color in BGR order, since drawing is done on a BGR copy
            (default: green).
        display_scores: Whether to draw each box's score as a label.
        score_threshold: Boxes with a score below this value are not drawn.
        figsize: Size of the matplotlib figure used for display/saving.
        save_path: Optional path to save the visualization.
        display: Whether to show the plot (set to False to avoid showing plots).

    Returns:
        vis_image: RGB numpy array of the image with the boxes drawn.
    """
    # OpenCV drawing works on BGR images, so draw on a BGR copy.
    if image.shape[2] == 3:
        canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    else:
        canvas = image.copy()

    for i, bbox in enumerate(bboxes):
        x1, y1, x2, y2 = bbox.astype(np.int32)

        score = scores[i] if scores is not None else None
        if score is not None and score < score_threshold:
            continue

        cv2.rectangle(canvas, (x1, y1), (x2, y2), color, thickness)

        if display_scores and score is not None:
            # Put the label above the box, or just inside it when the box is too
            # close to the top edge (a baseline at y1 - 10 would be off-image).
            text_y = y1 - 10 if y1 >= 25 else y1 + 20
            cv2.putText(canvas, f"{score:.2f}", (x1, text_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness)

    # Back to RGB for matplotlib and for the caller.
    vis_image = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)

    if display or save_path:
        plt.figure(figsize=figsize)
        plt.imshow(vis_image)
        plt.axis('off')

        if save_path:
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
            print(f"Saved visualization to {save_path}")

        if display:
            plt.show()
        else:
            plt.close()  # Close the figure to avoid memory leaks when only saving

    return vis_image

def detect_and_visualize_hands(images, 
                              det_config='configs/rtmdet_nano_320-8xb32_hand.py',
                              det_checkpoint='https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_nano_8xb32-300e_hand-267f9c8f.pth',
                              device='cuda:0',
                              bbox_thr=0.3,
                              nms_thr=0.3,
                              visualize=True,  # Controls whether boxes are drawn into vis_images
                              display_plot=True,  # Controls whether plots are shown
                              save_path=None,
                              grid_display=False,
                              max_images_per_row=3,
                              return_coordinates='original',
                              crop_size=(600, 600)):
    """
    Detect and visualize hands in a single image or a list of images.

    Args:
        images: Path to an image, a single numpy array, a list of these, or a 4D
            array (frames first) which is split into individual frames.
        det_config: Path to detection config file
        det_checkpoint: Path or URL to detection checkpoint
        device: Device to run inference on
        bbox_thr: Threshold for bounding box confidence
        nms_thr: Threshold for non-maximum suppression
        visualize: Whether to draw bounding boxes into the returned vis_images.
            When False, vis_images holds the undecorated base images.
        display_plot: Whether to display the visualization plots
        save_path: Optional path to save the visualization. For a single image it
            is used as-is. For several images without a grid, ``_<index>`` is
            inserted before the extension (``.jpg`` is used when the path has no
            extension). With a grid, the grid figure is saved to this path.
            Images with no detections are not saved individually.
        grid_display: When processing multiple images, display them in a grid
        max_images_per_row: Maximum number of images per row in the grid
        return_coordinates: Coordinate system for returned bounding boxes:
                          - 'original': In original input image coordinates (default)
                          - 'crop': In center-cropped image coordinates
                          - 'model': In model input size coordinates
        crop_size: Size for center crop (width, height), defaults to (600, 600)

    Returns:
        For a single input image:
            (bboxes, scores, crop_info, vis_image) when return_coordinates == 'original',
            otherwise (bboxes, scores, crop_info, vis_image, processed_image).
        For a list / 4D array input: the same tuple, with every element replaced by
        a list holding one entry per image.

    Raises:
        FileNotFoundError: If an image path cannot be read by OpenCV.
    """
    # Normalise the input to a list of images; remember whether to unwrap at the end.
    if isinstance(images, list) or (isinstance(images, np.ndarray) and images.ndim == 4):
        is_list = True
        if isinstance(images, np.ndarray):
            # Iterating a 4D array yields its frames along axis 0
            images = list(images)
    else:
        is_list = False
        images = [images]  # Wrap for uniform processing

    # A grid is only drawn for more than one image; a single image is always
    # handled (displayed/saved) individually, even if grid_display is True.
    use_grid = grid_display and len(images) > 1

    bboxes_list = []
    scores_list = []
    vis_images = []
    crop_info_list = []
    processed_images = []

    for i, image in enumerate(images):
        # Load image if a path is provided
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                # cv2.imread signals failure by returning None instead of raising
                raise FileNotFoundError(f"Could not read image: {image}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        else:
            img = image.copy()

        # detect_hands returns an extra processed image for 'crop' / 'model'
        detect_result = detect_hands(
            img,
            det_config=det_config,
            det_checkpoint=det_checkpoint,
            device=device,
            bbox_thr=bbox_thr,
            nms_thr=nms_thr,
            crop_size=crop_size,
            return_coordinates=return_coordinates
        )

        if return_coordinates == 'original':
            bboxes, scores, crop_info = detect_result
            vis_base_img = img  # Boxes are in original-image coordinates
        else:  # 'crop' or 'model'
            bboxes, scores, crop_info, processed_img = detect_result
            vis_base_img = processed_img  # Boxes are in processed-image coordinates
            processed_images.append(processed_img)

        bboxes_list.append(bboxes)
        scores_list.append(scores)
        crop_info_list.append(crop_info)

        if not visualize:
            # No drawing requested: return the base image untouched
            vis_images.append(vis_base_img)
            continue

        # Per-image save path (grid mode saves the whole grid instead)
        individual_save_path = None
        if save_path and not use_grid:
            if len(images) > 1:
                base, ext = os.path.splitext(save_path) if '.' in os.path.basename(save_path) else (save_path, '.jpg')
                individual_save_path = f"{base}_{i}{ext}"
            else:
                individual_save_path = save_path

        # Only display during the loop when no grid will be drawn afterwards
        should_display = display_plot and not use_grid

        if len(bboxes) > 0:
            vis_result = visualize_hand_detections(
                vis_base_img,
                bboxes,
                scores,
                save_path=individual_save_path,
                display=should_display
            )
            vis_images.append(vis_result)

            if should_display:
                print(f"Image {i+1}: Detected {len(bboxes)} hands")
        else:
            # No hands detected, just use the base image
            vis_images.append(vis_base_img)
            if should_display:
                print(f"Image {i+1}: No hands detected")

    # Grid visualization for multiple images
    if visualize and use_grid:
        num_images = len(vis_images)
        # Guard against max_images_per_row <= 0, which would divide by zero below
        num_cols = max(1, min(max_images_per_row, num_images))
        num_rows = (num_images + num_cols - 1) // num_cols  # Ceiling division

        # squeeze=False always yields a 2D array of Axes, whatever the grid shape
        _, axes = plt.subplots(num_rows, num_cols, figsize=(5*num_cols, 5*num_rows), squeeze=False)

        for idx, ax in enumerate(axes.flat):
            if idx < num_images:
                ax.imshow(vis_images[idx])
                ax.set_title(f"Image {idx+1}: {len(bboxes_list[idx])} hands")
            ax.axis('off')  # Also hides the unused trailing subplots

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, bbox_inches='tight')
            print(f"Saved grid visualization to {save_path}")

        if display_plot:
            plt.show()
        else:
            plt.close()  # Close the figure to avoid memory leaks

    # Return shape depends on input type and return_coordinates
    if not is_list:
        if return_coordinates == 'original':
            return bboxes_list[0], scores_list[0], crop_info_list[0], vis_images[0]
        return bboxes_list[0], scores_list[0], crop_info_list[0], vis_images[0], processed_images[0]
    if return_coordinates == 'original':
        return bboxes_list, scores_list, crop_info_list, vis_images
    return bboxes_list, scores_list, crop_info_list, vis_images, processed_images

# recording = processed_data.get("recording_0")
# rgb_image = recording["color_images"]

# boxes_list, scores_list, crop_info_list, vis_images, processed_imgs = detect_and_visualize_hands(
#     rgb_image,
#     return_coordinates='crop',
#     crop_size=(800, 800),
#     visualize=True,
#     display_plot=False
# )

  from torch.distributed.optim import \


### 2.4. Cropping to Target Hand

In [4]:
from typing import Tuple, List, Optional, Union

def _save_image_sequence(frames: List[np.ndarray], save_path: str, label: str) -> None:
    """Stack *frames* into one array and write it to *save_path* with ``np.save``."""
    # Create the parent directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(save_path)), exist_ok=True)
    np_sequence = np.array(frames)
    np.save(save_path, np_sequence)
    print(f"Saved {label} sequence to {save_path}, shape: {np_sequence.shape}")


def _hand_crop_box(bboxes, scores, margin_percent: float,
                   img_w: int, img_h: int) -> Optional[Tuple[int, int, int, int]]:
    """Return the margin-expanded, image-clipped crop box of the best hand, or None.

    The best hand is the one with the highest score; if scores are unavailable or
    don't line up with the boxes, the first box is used. None is returned when
    there is no box, or when the clipped box is empty (e.g. it lies outside the frame).
    """
    if len(bboxes) == 0:
        return None

    best = 0
    if scores is not None and len(scores) == len(bboxes):
        best = int(np.argmax(scores))

    x1, y1, x2, y2 = bboxes[best]
    margin_w = (x2 - x1) * margin_percent
    margin_h = (y2 - y1) * margin_percent

    crop_x1 = max(0, int(x1 - margin_w))
    crop_y1 = max(0, int(y1 - margin_h))
    crop_x2 = min(img_w, int(x2 + margin_w))
    crop_y2 = min(img_h, int(y2 + margin_h))

    if crop_x2 <= crop_x1 or crop_y2 <= crop_y1:
        # A.Crop cannot handle an empty region
        return None
    return crop_x1, crop_y1, crop_x2, crop_y2


def crop_hands_sequence(images: Union[List[np.ndarray], np.ndarray],
                      output_size: Tuple[int, int] = (300, 300),
                      margin_percent: float = 0.2,
                      det_config: str = 'configs/rtmdet_nano_320-8xb32_hand.py',
                      det_checkpoint: str = 'https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_nano_8xb32-300e_hand-267f9c8f.pth',
                      device: str = 'cuda:0',
                      bbox_thr: float = 0.3,
                      nms_thr: float = 0.3,
                      suppress_warnings: bool = True,
                      save: bool = False,
                      save_path: Optional[str] = None) -> Tuple[List[np.ndarray], Tuple[int, int, int, int]]:
    """
    Crop hands from a sequence of images using the same crop coordinates for all frames.

    Hand detection runs only on the first frame. The highest-scoring box, expanded
    by ``margin_percent`` and clipped to the frame, is then applied to every frame.
    Frames whose size differs from the first frame are resized to it before cropping.
    If no usable hand box is found, every frame is simply resized to ``output_size``.

    Args:
        images: List of RGB images or a 4D array (frames, height, width, channels)
        output_size: Size of the output images (width, height), defaults to (300, 300)
        margin_percent: Extra margin to add around detected hand (as percentage of bbox dims)
        det_config: Path to detection config file
        det_checkpoint: Path or URL to detection checkpoint
        device: Device to run inference on
        bbox_thr: Threshold for bounding box confidence
        nms_thr: Threshold for non-maximum suppression
        suppress_warnings: Whether to suppress warning messages
        save: Whether to save the cropped sequence as a .npy file
        save_path: Path where to save the .npy file (required if save=True)

    Returns:
        Tuple containing:
            - List of cropped hand images of size output_size
            - Crop coordinates (x_min, y_min, x_max, y_max) used for all images;
              the full first-frame extent when no hand was found

    Raises:
        ValueError: If ``save`` is True without a ``save_path``, or no images are given.
    """
    # Fail fast, before running the (expensive) detector
    if save and save_path is None:
        raise ValueError("save_path must be provided when save=True")

    # Convert input to a list of images
    if isinstance(images, np.ndarray) and images.ndim == 4:
        images_list = list(images)
    elif not isinstance(images, list):
        images_list = [images]  # Single image case
    else:
        images_list = images

    if len(images_list) == 0:
        raise ValueError("crop_hands_sequence requires at least one image")

    first_image = images_list[0]
    img_h, img_w = first_image.shape[:2]

    # 1. Detect hands in the first frame only
    bboxes, scores, _ = detect_hands(
        first_image,
        det_config=det_config,
        det_checkpoint=det_checkpoint,
        device=device,
        bbox_thr=bbox_thr,
        nms_thr=nms_thr,
        suppress_warnings=suppress_warnings,
        crop_size=(800, 800),
        return_coordinates='original'
    )

    # 2. Margin-expanded box of the most confident hand
    crop_box = _hand_crop_box(bboxes, scores, margin_percent, img_w, img_h)

    if crop_box is None:
        # No usable hand: return the original images resized
        resize_transform = A.Resize(height=output_size[1], width=output_size[0])
        resized_imgs = [resize_transform(image=img)['image'] for img in images_list]
        if save:
            _save_image_sequence(resized_imgs, save_path, "resized")
        return resized_imgs, (0, 0, img_w, img_h)

    crop_x1, crop_y1, crop_x2, crop_y2 = crop_box

    # 3. One crop+resize transform shared by all frames
    crop_transform = A.Compose([
        A.Crop(x_min=crop_x1, y_min=crop_y1, x_max=crop_x2, y_max=crop_y2),
        A.Resize(height=output_size[1], width=output_size[0])
    ])

    cropped_imgs = []
    for img in images_list:
        if img.shape[:2] != first_image.shape[:2]:
            # Match the first frame's size so the crop box lands on the same region
            img = cv2.resize(img, (img_w, img_h))
        cropped_imgs.append(crop_transform(image=img)['image'])

    if save:
        _save_image_sequence(cropped_imgs, save_path, "cropped")

    # 4. Return the cropped images and crop coordinates
    return cropped_imgs, (crop_x1, crop_y1, crop_x2, crop_y2)

# recording = processed_data.get("recording_0")
# rgb_images = recording["color_images"]

# cropped_frames, _ = crop_hands_sequence(
#     rgb_images,
#     output_size=(300, 300),
#     margin_percent=0.5,
#     save=True,
#     save_path="cropped_hands_sequence.npy"
# )

### 2.5. Windowing Process

In [None]:
# Quick single-recording check of the crop -> window -> visualize pipeline.
# NOTE(review): `recording_0` is not defined in this cell; presumably it was created
# earlier (the commented examples above use `processed_data.get("recording_0")`) —
# confirm before running.
# Get the color images from recording_0
rgb_images = recording_0["color_images"]

# Crop every frame with the box detected on the first frame and resize to 300x300;
# the returned crop coordinates are discarded here.
cropped_frames, _ = crop_hands_sequence(
    rgb_images,
    output_size=(300, 300),
    margin_percent=0.5
)

# Get multiple windows of 16 frames each
# (stride=8 presumably means consecutive windows overlap by 8 frames — depends on
# sliding_window_sample, which is defined elsewhere; verify)
rgb_windows = sliding_window_sample(cropped_frames, window_size=16, stride=8)

# Visualize the windows
visualize_windows(rgb_windows)

### 2.6. Save Windowed Data into 📁 `rgb` and 📁 `depth`

In [None]:
# Prepare the recordings dictionary
recordings = prepare_recordings_dict(processed_data)

# Create a new dictionary to store the hand-cropped recordings
cropped_recordings = {}
output_size = (300, 300)  # Output size for cropped images (width, height)
margin_percent = 0.5      # Extra margin around detected hands


def _transform_depth_frame(depth_frame, transform, target_hw=None):
    """Apply an albumentations *transform* to one depth frame, preserving its rank.

    Args:
        depth_frame: Depth frame, either 2D (H, W) or with a channel axis.
        transform: Albumentations transform (crop and/or resize).
        target_hw: Optional (height, width) the crop coordinates were computed for.
            If the frame has a different size, it is first resized to it, mirroring
            the size-matching step in ``crop_hands_sequence``. Nearest-neighbour
            interpolation is used so no new depth values are invented by blending.
            NOTE(review): this assumes depth is registered to colour (same field
            of view) — TODO confirm for the capture setup.

    Returns:
        The transformed frame, 2D if the input was 2D.
    """
    if target_hw is not None and depth_frame.shape[:2] != tuple(target_hw):
        resized = cv2.resize(depth_frame, (target_hw[1], target_hw[0]),
                             interpolation=cv2.INTER_NEAREST)
        if resized.ndim < depth_frame.ndim:
            # cv2.resize drops a trailing singleton channel axis; restore it
            resized = resized[..., np.newaxis]
        depth_frame = resized

    if depth_frame.ndim == 2:
        # Replicate to 3 channels for albumentations, then keep a single channel
        depth_3ch = np.stack([depth_frame] * 3, axis=-1)
        return transform(image=depth_3ch)['image'][:, :, 0]
    return transform(image=depth_frame)['image']


# Crop each recording: the colour frames drive hand detection, and the same crop
# box is reused for the depth frames so both modalities cover the same region.
for recording_key, recording_data in recordings.items():
    print(f"Processing {recording_key} for hand cropping...")

    cropped_recording_data = {
        "dataset": recording_data["dataset"]  # Preserve dataset identifier
    }
    color_hw = None  # (height, width) of the colour frames the crop box refers to

    # Process color images (RGB) first
    if "color_images" in recording_data and recording_data["color_images"] is not None:
        color_images = recording_data["color_images"]
        print(f"  Found {len(color_images)} color frames")

        cropped_color_images, crop_coords = crop_hands_sequence(
            color_images,
            output_size=output_size,
            margin_percent=margin_percent,
            device='cuda:0',  # Adjust based on your setup
            bbox_thr=0.3,
            nms_thr=0.3
        )
        # crop_hands_sequence computes its box on the first colour frame
        color_hw = color_images[0].shape[:2]

        cropped_recording_data["color_images"] = np.array(cropped_color_images)
        cropped_recording_data["crop_info"] = {
            "coords": crop_coords,
            "output_size": output_size
        }

        print(f"  Cropped {len(cropped_color_images)} color frames to size {output_size}")
    else:
        print(f"  No color images found for {recording_key}")
        cropped_recording_data["color_images"] = None

    # Process depth images using the same crop coordinates
    if "depth_images" in recording_data and recording_data["depth_images"] is not None:
        depth_images = recording_data["depth_images"]
        print(f"  Found {len(depth_images)} depth frames")

        if "crop_info" in cropped_recording_data:
            x_min, y_min, x_max, y_max = cropped_recording_data["crop_info"]["coords"]
            crop_transform = A.Compose([
                A.Crop(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max),
                A.Resize(height=output_size[1], width=output_size[0])
            ])

            cropped_depth_images = [
                _transform_depth_frame(depth_frame, crop_transform, target_hw=color_hw)
                for depth_frame in depth_images
            ]

            cropped_recording_data["depth_images"] = np.array(cropped_depth_images)
            print(f"  Cropped {len(cropped_depth_images)} depth frames to size {output_size}")
        else:
            # No crop coordinates available (no color images), just resize
            print(f"  No crop coordinates available for {recording_key}, resizing depth images")
            resize_transform = A.Resize(height=output_size[1], width=output_size[0])
            resized_depth_images = [
                _transform_depth_frame(frame, resize_transform) for frame in depth_images
            ]
            cropped_recording_data["depth_images"] = np.array(resized_depth_images)
    else:
        print(f"  No depth images found for {recording_key}")
        cropped_recording_data["depth_images"] = None

    cropped_recordings[recording_key] = cropped_recording_data
    print(f"Completed processing {recording_key}")

save_windowed_data(cropped_recordings, bag_files, windowed_data_dir='windowed_hand_focused_data')

### 2.7. Data Class Generation

There are 5 classes to be picked from the windowed data: `Open Hand (OH)`, `Intrinsic Hand (IH)`, `Straight Fist (SF)`, `Hand Close (HC)`, and `Hook Hand (HH)`.

In [None]:
def create_dataset(recordings, bag_files, class_info_dict, output_dir="video", max_samples_per_recording=2, windowed_dir='windowed_data', modality_option="both"):
    """
    Copy per-class window files into a class-organised dataset directory.

    For every eligible recording, up to ``max_samples_per_recording`` windows per
    class are copied from ``<bag dir>/<windowed_dir>/<modality>/<stem>_window_NNN.npy``
    to ``<output_dir>/<modality>/<class>/<stem>_wNNN.npy``. If a class then has fewer
    than ``len(bag_files) * max_samples_per_recording`` samples, unused windows of that
    class from the eligible recordings are copied to top it up. A window counts as a
    sample only if at least one of its modality files was actually copied.

    Args:
        recordings: Dictionary of processed recordings. Only its keys are used:
            ``recording_<i>`` (``i`` indexing ``bag_files``) must be present for a
            recording to be eligible.
        bag_files: List of bag file paths; each file's stem is the recording name.
        class_info_dict: Mapping ``recording name -> {class acronym: [window indices]}``.
        output_dir: Output directory.
        max_samples_per_recording: Maximum number of windows per class taken from
            each recording in the first pass.
        windowed_dir: Name of the windowed-data folder next to each bag file.
        modality_option: ``"rgb"``, ``"depth"`` or ``"both"`` (case-insensitive).

    Raises:
        ValueError: If ``modality_option`` is not one of the supported values.
    """
    class_acronyms = ["OH", "IH", "SF", "HC", "HH"]
    class_full_names = {
        "OH": "Open Hand",
        "IH": "Intrinsic Hand",
        "SF": "Straight Fist",
        "HC": "Hand Close",
        "HH": "Hook Hand"
    }
    modality_labels = {"rgb": "RGB", "depth": "Depth"}

    option = modality_option.lower()
    if option == "rgb":
        modalities = ["rgb"]
    elif option == "depth":
        modalities = ["depth"]
    elif option == "both":
        modalities = ["rgb", "depth"]
    else:
        raise ValueError("Invalid modality_option. Choose 'rgb', 'depth', or 'both'.")

    # Text used in progress messages; names only the modalities actually copied
    modality_text = "both rgb and depth" if len(modalities) > 1 else modalities[0]

    output_path = Path(output_dir)
    for acronym in class_acronyms:
        for modality in modalities:
            os.makedirs(output_path / modality / acronym, exist_ok=True)

    def _copy_window(bag_path, recording_name, window_idx, acronym, missing_prefix=""):
        """Copy one window for every modality; return how many files were copied."""
        copied = 0
        for modality in modalities:
            source_path = bag_path.parent / windowed_dir / modality / f"{recording_name}_window_{window_idx:03d}.npy"
            if not source_path.exists():
                print(f"Warning: {missing_prefix}{modality} file {source_path} not found, skipping...")
                continue
            dest_path = output_path / modality / acronym / f"{recording_name}_w{window_idx:03d}.npy"
            shutil.copy2(source_path, dest_path)
            copied += 1
        return copied

    collected_samples = defaultdict(list)  # acronym -> [(recording_name, window_idx)]
    attempted = defaultdict(set)           # acronym -> windows already tried (copied or missing)
    eligible = []                          # (bag_path, recording_name) passing both checks

    # First pass: up to max_samples_per_recording windows per class per recording
    for recording_idx, bag_file_path in enumerate(bag_files):
        bag_path = Path(bag_file_path)
        recording_name = bag_path.stem
        recording_key = f"recording_{recording_idx}"

        if recording_key not in recordings:
            print(f"Warning: {recording_key} not found in recordings, skipping...")
            continue
        if recording_name not in class_info_dict:
            print(f"Warning: No class information for {recording_name}, skipping...")
            continue

        eligible.append((bag_path, recording_name))
        recording_class_info = class_info_dict[recording_name]

        for acronym in class_acronyms:
            for window_idx in recording_class_info.get(acronym, [])[:max_samples_per_recording]:
                attempted[acronym].add((recording_name, window_idx))
                if _copy_window(bag_path, recording_name, window_idx, acronym) == 0:
                    continue  # nothing copied -> not a sample
                collected_samples[acronym].append((recording_name, window_idx))
                print(f"Copied {class_full_names[acronym]} ({acronym}) sample from {recording_name}, window {window_idx} ({modality_text})")

    # Second pass: top up classes that fell short of the target
    target_count = len(bag_files) * max_samples_per_recording
    for acronym in class_acronyms:
        samples_count = len(collected_samples[acronym])
        print(f"Class {class_full_names[acronym]} ({acronym}): {samples_count}/{target_count} samples collected")

        if samples_count >= target_count:
            continue

        additional_needed = target_count - samples_count
        print(f"Need {additional_needed} more samples for {acronym}")

        added = 0
        for bag_path, recording_name in eligible:
            for window_idx in class_info_dict[recording_name].get(acronym, []):
                if added >= additional_needed:
                    break
                if (recording_name, window_idx) in attempted[acronym]:
                    continue  # already used, or already known to be missing
                attempted[acronym].add((recording_name, window_idx))
                if _copy_window(bag_path, recording_name, window_idx, acronym, "Additional ") == 0:
                    continue
                collected_samples[acronym].append((recording_name, window_idx))
                added += 1
                print(f"Added extra {class_full_names[acronym]} ({acronym}) sample from {recording_name}, window {window_idx} ({modality_text})")
            if added >= additional_needed:
                break

    # Final count and validation, limited to the modalities actually produced
    for acronym in class_acronyms:
        counts = {m: len(list((output_path / m / acronym).glob("*.npy"))) for m in modalities}

        print(f"Final count for {class_full_names[acronym]} ({acronym}):")
        for modality in modalities:
            print(f"  - {modality_labels[modality]}: {counts[modality]} samples")

        # Only meaningful when both modalities were requested
        if len(set(counts.values())) > 1:
            print(f"  - WARNING: RGB and depth counts don't match for {acronym}!")

# Use the provided dictionary directly
# Manually annotated class windows per recording:
#   outer key   -> recording name, matched against Path(bag_file).stem in create_dataset
#   inner key   -> class acronym (OH, IH, SF, HC, HH) as iterated by create_dataset
#   value       -> window indices, i.e. files named <stem>_window_NNN.npy in the
#                  recording's windowed directory
# Window indices that are not listed for any class are never copied.
info_dict = {
    "Record_20250402151124": {
        "OH": [0, 10, 11],
        "IH": [1, 2],
        "SF": [3, 4],
        "HC": [5, 6],
        "HH": [7, 8]
    },
    "Record_20250402151609": {
        "OH": [0, 9],
        "IH": [1, 2],
        "SF": [3, 4],
        "HC": [5],
        "HH": [6, 7]
    },
    "Record_20250402152331": {
        "OH": [0, 1, 2, 12],
        "IH": [3, 4],
        "SF": [5, 6],
        "HC": [7],
        "HH": [8, 9]
    },
    "Record_20250506145704": {
        "OH": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "IH": [11, 12, 13, 14, 15],
        "SF": [16, 17, 18, 19, 20, 21],
        "HC": [22, 23, 24, 25, 26],
        "HH": [27, 28, 29, 30, 31]
    },
    "Record_20250506152951": {
        "OH": [0, 1, 2, 3, 4],
        "IH": [5, 6, 7, 8],
        "SF": [10, 11, 12, 13],
        "HC": [14, 15, 16, 17],
        "HH": [18, 19, 20, 21]
    },
    "Record_20250506162630": {
        "OH": [0, 1, 2, 3, 4, 5],
        "IH": [6, 7],
        "SF": [8],
        "HC": [9, 10, 11, 12],
        "HH": [13, 14]
    }
}

# Create the dataset
# Only RGB windows are copied (modality_option="rgb"). create_dataset only checks
# processed_data for the presence of recording_<i> keys; the window files themselves
# are read from the 'windowed_hand_focused_data' folder next to each bag file.
create_dataset(processed_data, bag_files, info_dict, output_dir=r"./data/video_hand_focused_data", windowed_dir='windowed_hand_focused_data', modality_option="rgb")

### 2.8. Synthetic Data Generation

#### Splitting into Training and Validation

In [None]:
import glob
from sklearn.model_selection import train_test_split

# NOTE(review): absolute, machine-specific path; the dataset cell above writes to
# ./data/video_hand_focused_data — confirm both point at the same data.
video_dir = r'D:\RESEARCH ASSISTANT\6. Depth Camera\CODE\Orbbec Gemini 2XL\REMOTE\DEVELOPMENT\notebook\DATA\video_hand_focused_data\rgb'
save_dir = os.path.join(os.path.dirname(video_dir), 'split_rgb')

# Define subfolders
train_dir = os.path.join(save_dir, 'train')
val_dir = os.path.join(save_dir, 'val')

# Create directory structure
for target_dir in [train_dir, val_dir]:
    os.makedirs(target_dir, exist_ok=True)

# Split each class folder 75/25 into train/val
for class_name in sorted(os.listdir(video_dir)):
    class_path = os.path.join(video_dir, class_name)
    if not os.path.isdir(class_path):
        continue

    # Sort so the split is reproducible: glob order depends on the filesystem,
    # and random_state only fixes the shuffle, not the input order.
    npy_files = sorted(glob.glob(os.path.join(class_path, '*.npy')))

    # Destination subfolders
    train_class_dir = os.path.join(train_dir, class_name)
    val_class_dir = os.path.join(val_dir, class_name)
    os.makedirs(train_class_dir, exist_ok=True)
    os.makedirs(val_class_dir, exist_ok=True)

    # Remove .npy files from a previous split; otherwise a clip that moved between
    # splits would end up in both train and val (leakage).
    for dest_dir in (train_class_dir, val_class_dir):
        for stale in glob.glob(os.path.join(dest_dir, '*.npy')):
            os.remove(stale)

    if not npy_files:
        print(f"Skipping {class_name}: no .npy files found")
        continue

    if len(npy_files) < 2:
        # train_test_split cannot produce non-empty train and val sets from one file
        print(f"Warning: {class_name} has only 1 file; putting it in train only")
        train_files, val_files = npy_files, []
    else:
        train_files, val_files = train_test_split(npy_files, test_size=0.25, random_state=42)

    # Copy files to train
    for f in train_files:
        shutil.copy2(f, os.path.join(train_class_dir, os.path.basename(f)))

    # Copy files to val
    for f in val_files:
        shutil.copy2(f, os.path.join(val_class_dir, os.path.basename(f)))

print(f"Dataset split complete. Saved to: {save_dir}")

#### Data Generation

In [None]:
import os
import numpy as np
import shutil
import albumentations as A
from tqdm import tqdm

# Paths and settings for offline augmentation of the training split.
# NOTE(review): absolute, machine-specific path; adjust for your environment.
video_dir       = r'D:\RESEARCH ASSISTANT\6. Depth Camera\CODE\Orbbec Gemini 2XL\REMOTE\DEVELOPMENT\notebook\DATA\video_hand_focused_data'
train_dir       = os.path.join(video_dir, "split_rgb", "train")
# Augmented files go to a sibling of train/: <video_dir>/split_rgb/train_aug
aug_dir         = os.path.join(os.path.dirname(train_dir), "train_aug")
n_applications  = 6   # how many aug variants per original

# ReplayCompose returns the parameters it sampled ("replay"), which
# augment_numpy_video re-applies to every later frame of a clip, so a whole clip
# gets one consistent augmentation. Each transform is applied with probability 0.5.
# The two PixelDropout entries drop pixels to 0 (black) and to 255 (white) respectively.
transform = A.ReplayCompose([
    A.ElasticTransform(alpha=0.5, p=0.5),
    A.ShiftScaleRotate(scale_limit=0.05, rotate_limit=10, p=0.5),
    A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.CLAHE(p=0.5),
    A.PixelDropout(drop_value=0, dropout_prob=0.01, p=0.5),
    A.PixelDropout(drop_value=255, dropout_prob=0.01, p=0.5),
    A.Blur(blur_limit=(2, 4), p=0.5)
])

def augment_numpy_video(arr: np.ndarray):
    """Apply one randomly sampled augmentation identically to every frame of a clip.

    The module-level ``transform`` (an ``A.ReplayCompose``) is sampled once, on
    the first frame. Its recorded parameters are then replayed on each
    following frame, so the whole clip receives the same augmentation.

    Args:
        arr: video array laid out as (T, H, W, C). It must be 4-D; unpacking
            the shape raises ValueError otherwise. Each frame is cast to uint8
            before augmentation.

    Returns:
        np.ndarray of the augmented frames, stacked along a new first axis.
        An empty clip (T == 0) raises ValueError from ``np.stack``.
    """
    _n_frames, _height, _width, _channels = arr.shape

    augmented = []
    recorded_params = None
    for raw_frame in arr:
        frame_u8 = raw_frame.astype("uint8")
        if not augmented:
            # First frame: sample fresh random parameters and remember them
            sampled = transform(image=frame_u8)
            recorded_params = sampled["replay"]
            augmented.append(sampled["image"])
        else:
            # Later frames: reuse exactly the parameters sampled above
            replayed = A.ReplayCompose.replay(recorded_params, image=frame_u8)
            augmented.append(replayed["image"])

    # Frames stay (H, W, C); stacking restores the (T, H, W, C) layout
    return np.stack(augmented, axis=0)

# Walk each class folder (sorted so run order and logs are reproducible)
for cls in tqdm(sorted(os.listdir(train_dir)), desc="Processing classes"):
    src_cls_dir = os.path.join(train_dir, cls)
    # Skip stray non-directory entries next to the class folders;
    # os.listdir() on a file would raise NotADirectoryError below.
    if not os.path.isdir(src_cls_dir):
        continue
    dst_cls_dir = os.path.join(aug_dir, cls)

    # 1) Clear out any previous augmented files so a rerun never mixes stale
    #    variants with fresh ones
    if os.path.isdir(dst_cls_dir):
        for old in os.listdir(dst_cls_dir):
            if "_aug" in old or old.endswith("_orig.npy"):
                os.remove(os.path.join(dst_cls_dir, old))
    else:
        os.makedirs(dst_cls_dir, exist_ok=True)

    # 2) Generate fresh augmentations
    file_list = sorted(f for f in os.listdir(src_cls_dir) if f.endswith(".npy"))
    for fname in tqdm(file_list, desc=f"Augmenting {cls}", leave=False):
        src_path = os.path.join(src_cls_dir, fname)
        base, _  = os.path.splitext(fname)

        arr = np.load(src_path)

        # Save a base copy so the augmented folder also contains the original
        orig_path = os.path.join(dst_cls_dir, f"{base}_orig.npy")
        np.save(orig_path, arr)

        # Generate N augmentations; each call samples a new random variant
        for i in range(1, n_applications + 1):
            aug_arr = augment_numpy_video(arr)
            out_path = os.path.join(dst_cls_dir, f"{base}_aug{i}.npy")
            np.save(out_path, aug_arr)

    print(f"[{cls}] now has {len(os.listdir(dst_cls_dir))} files in {dst_cls_dir}")

print("✅ Done creating fresh synthetic training videos.")

#### Visualize Augmentation

In [None]:
import rerun.blueprint as rrb
import rerun as rr

def rerun_visualization_from_npy(npy_path: str):
    """Stream a saved .npy video into a freshly spawned Rerun viewer.

    Args:
        npy_path: path to an array saved with ``np.save``. Each element along
            axis 0 is logged as one image on the "frame" timeline. The
            augmentation step writes clips as (T, H, W, C), so each element is
            an (H, W, C) image.
    """
    frames = np.load(npy_path)
    print(f"Loaded {npy_path}, shape: {frames.shape}")

    recording = rr.new_recording("rerun_augmented_video", spawn=True)

    # Single 2D view bound to the "/color_image" entity
    layout = rrb.Blueprint(
        rrb.Grid(rrb.Vertical(rrb.Spatial2DView(origin="/color_image"))),
        collapse_panels=True,
    )

    frame_idx = 0
    for frame in frames:
        recording.set_time_sequence("frame", frame_idx)
        recording.log("color_image", rr.Image(frame))
        frame_idx += 1

    recording.send_blueprint(layout)

classes = ['HC', 'HH', 'IH', 'OH', 'SF']
target_dir = os.path.join(aug_dir, classes[0])

# Files with _aug{i} suffix; sorted because glob() order is filesystem-dependent
augmented_files = sorted(glob.glob(os.path.join(target_dir, '*_aug*.npy')))
if not augmented_files:
    raise FileNotFoundError(f"No augmented .npy files found in {target_dir}")

# Visualize the last augmented file in sorted order
rerun_visualization_from_npy(augmented_files[-1])

### 2.9. Extract Hand Landmarks `X` and `Y` Coordinates

#### Hand Landmark Detection

In [5]:
from typing import Union, List, Tuple, Dict, Any, Optional
import numpy as np
import torch
from mmpose.apis import init_model as init_pose_estimator
from mmpose.apis import inference_topdown
from mmpose.structures import merge_data_samples

def _empty_pose_result() -> Dict[str, Any]:
    """Return a pose result describing zero hands (21 keypoints per hand)."""
    return {
        'keypoints': np.empty((0, 21, 2), dtype=np.float32),
        'scores': np.empty((0, 21), dtype=np.float32),
        'boxes': np.empty((0, 4), dtype=np.float32),
        'box_scores': np.empty((0,), dtype=np.float32),
    }


def pose(
    images: Union[np.ndarray, List[np.ndarray]],
    boxes: Union[np.ndarray, List[np.ndarray]],
    scores: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
    pose_config: str = 'configs/rtmpose-m_8xb256-210e_hand5-256x256.py',
    pose_checkpoint: str = 'https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-hand5_pt-aic-coco_210e-256x256-74fb594_20230320.pth',
    device: str = 'cuda:0',
    normalize_keypoints: bool = False,
    return_full_samples: bool = False
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Estimate pose keypoints for hand(s) in a single image or a list of images.

    The pose model is built once per (config, checkpoint, device) combination
    and cached on the function object (``pose.MODEL_CACHE``).

    Args:
        images: A single image, or a list/tuple of images (numpy arrays in RGB format)
        boxes: Bounding box(es) for hand detection, either for a single image or one
            set per image in ``images``
        scores: Optional confidence scores for the bounding boxes (same nesting as boxes)
        pose_config: Path to pose estimation config file
        pose_checkpoint: Path or URL to pose estimation checkpoint
        device: Device to run inference on ('cuda:0', 'cpu', etc.)
        normalize_keypoints: Whether to normalize keypoint coordinates to 0-1 range.
            Normalization is applied to a copy; ``data_samples`` keeps pixel coordinates.
        return_full_samples: Whether to also return the merged mmpose data samples

    Returns:
        For a single image:
            A dictionary containing:
                - 'keypoints': numpy array of shape (num_hands, num_keypoints, 2) with x,y coordinates
                - 'scores': numpy array of shape (num_hands, num_keypoints) with confidence scores
                - 'boxes': the bounding boxes used (shape (0, 4) when none)
                - 'box_scores': bounding box confidence scores (ones when ``scores`` is None)
                - (Optional) 'data_samples': full pose estimation results, or None

        For multiple images:
            A list of dictionaries as described above, one for each image

    Raises:
        ValueError: if the number of images and box sets differ.
    """
    # Initialize or get cached pose estimator
    model_key = f"{pose_config}_{pose_checkpoint}_{device}"
    if not hasattr(pose, 'MODEL_CACHE'):
        pose.MODEL_CACHE = {}

    if model_key not in pose.MODEL_CACHE:
        pose.MODEL_CACHE[model_key] = init_pose_estimator(
            pose_config,
            pose_checkpoint,
            device=device
        )
    pose_estimator = pose.MODEL_CACHE[model_key]

    # Lists and tuples are batches; anything else (an ndarray) is one image
    is_single_image = not isinstance(images, (list, tuple))

    if is_single_image:
        images = [images]
        boxes = [boxes]
        if scores is not None:
            scores = [scores]

    # Make sure each image has corresponding boxes (explicit raise: asserts vanish under -O)
    if len(images) != len(boxes):
        raise ValueError(
            f"Number of images ({len(images)}) must match number of box sets ({len(boxes)})"
        )

    all_results = []

    for i, (image, image_boxes) in enumerate(zip(images, boxes)):
        # Skip inference entirely when no hands were detected
        if len(image_boxes) == 0:
            result = _empty_pose_result()
            if return_full_samples:
                result['data_samples'] = None
            all_results.append(result)
            continue

        # Run top-down pose estimation on the given boxes
        pose_results = inference_topdown(pose_estimator, image, image_boxes)
        data_samples = merge_data_samples(pose_results)

        if not hasattr(data_samples, 'pred_instances'):
            # No instances detected
            result = _empty_pose_result()
            if return_full_samples:
                result['data_samples'] = None
        elif not hasattr(data_samples.pred_instances, 'keypoints'):
            # Instances present but no keypoints: keep the input boxes
            result = _empty_pose_result()
            result['boxes'] = image_boxes
            result['box_scores'] = (
                scores[i] if scores is not None
                else np.ones(len(image_boxes), dtype=np.float32)
            )
            if return_full_samples:
                result['data_samples'] = data_samples
        else:
            keypoints = data_samples.pred_instances.keypoints
            keypoint_scores = data_samples.pred_instances.keypoint_scores

            if normalize_keypoints:
                img_height, img_width = image.shape[:2]
                # Copy first: dividing in place would also rewrite the arrays
                # stored inside data_samples.
                keypoints = keypoints.copy()
                keypoints[:, :, 0] /= img_width
                keypoints[:, :, 1] /= img_height

            result = {
                'keypoints': keypoints,
                'scores': keypoint_scores,
                'boxes': image_boxes,
                # Default box scores to 1.0 if not provided
                'box_scores': (
                    scores[i] if scores is not None
                    else np.ones(len(image_boxes), dtype=np.float32)
                ),
            }
            if return_full_samples:
                result['data_samples'] = data_samples

        all_results.append(result)

    # Return single result for single image input
    return all_results[0] if is_single_image else all_results

#### Visualize Hand Landmark Detection

In [None]:
from typing import Union, List, Dict, Any, Optional
import numpy as np
import cv2

def visualize_hand_landmarks(images: Union[np.ndarray, List[np.ndarray]], 
                           pose_results: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
                           keypoint_threshold: float = 0.3,
                           bbox_thickness: int = 2,
                           keypoint_radius: int = 4,
                           skeleton_thickness: int = 2,
                           show_bbox: bool = True,
                           show_keypoints: bool = True,
                           show_skeleton: bool = True,
                           show_keypoint_labels: bool = False,
                           show_confidence: bool = True,
                           show_handedness: bool = True,
                           backend: str = 'opencv') -> Union[np.ndarray, List[np.ndarray]]:
    """
    Visualize detected hand landmarks and bounding boxes on a single image or multiple images.

    All drawing colors below are BGR tuples. With backend='opencv', 3-channel
    inputs are converted RGB->BGR before drawing and back to RGB afterwards,
    so returned images are RGB. Any other backend value skips the conversion,
    so colors are drawn with R and B swapped. Non-3-channel images (e.g.
    grayscale (H, W)) are drawn on directly without conversion.

    Args:
        images: Input image (numpy array) or a list/tuple of images in RGB format
        pose_results: Results from the pose() function
                    For single image: Dictionary with 'keypoints', 'scores', 'boxes', 'box_scores'
                    For multiple images: List (or tuple) of dictionaries, one for each image
                    If None, no landmarks or boxes will be drawn
        keypoint_threshold: Minimum confidence threshold for displaying keypoints
        bbox_thickness: Thickness of bounding box lines
        keypoint_radius: Radius of keypoint circles
        skeleton_thickness: Thickness of skeleton lines
        show_bbox: Whether to show bounding boxes
        show_keypoints: Whether to show keypoints
        show_skeleton: Whether to show skeleton connections
        show_keypoint_labels: Whether to show keypoint labels
        show_confidence: Whether to show confidence scores
        show_handedness: Whether to show hand type (left/right prediction)
        backend: 'opencv' enables the RGB<->BGR round trip described above

    Returns:
        For single image: Visualized copy of the image
        For multiple images: List of visualized copies
    """
    # Define hand keypoint connections for skeleton based on onehand10k dataset
    skeleton_connections = [
        # Thumb connections
        (0, 1), (1, 2), (2, 3), (3, 4),  
        # Index finger (forefinger) connections
        (0, 5), (5, 6), (6, 7), (7, 8),  
        # Middle finger connections
        (0, 9), (9, 10), (10, 11), (11, 12),  
        # Ring finger connections
        (0, 13), (13, 14), (14, 15), (15, 16),  
        # Pinky connections
        (0, 17), (17, 18), (18, 19), (19, 20)  
    ]

    # Keypoint names from onehand10k dataset
    keypoint_names = [
        "wrist", 
        "thumb1", "thumb2", "thumb3", "thumb4",
        "forefinger1", "forefinger2", "forefinger3", "forefinger4",
        "middle_finger1", "middle_finger2", "middle_finger3", "middle_finger4",
        "ring_finger1", "ring_finger2", "ring_finger3", "ring_finger4",
        "pinky_finger1", "pinky_finger2", "pinky_finger3", "pinky_finger4"
    ]

    # Keypoint colors from onehand10k dataset (BGR format for OpenCV)
    keypoint_colors = [
        (255, 255, 255),       # wrist - white
        (0, 128, 255),         # thumb1 - orange
        (0, 128, 255),         # thumb2 - orange
        (0, 128, 255),         # thumb3 - orange
        (0, 128, 255),         # thumb4 - orange
        (255, 153, 255),       # forefinger1 - pink
        (255, 153, 255),       # forefinger2 - pink
        (255, 153, 255),       # forefinger3 - pink
        (255, 153, 255),       # forefinger4 - pink
        (255, 178, 102),       # middle_finger1 - light blue
        (255, 178, 102),       # middle_finger2 - light blue
        (255, 178, 102),       # middle_finger3 - light blue
        (255, 178, 102),       # middle_finger4 - light blue
        (51, 51, 255),         # ring_finger1 - red
        (51, 51, 255),         # ring_finger2 - red
        (51, 51, 255),         # ring_finger3 - red
        (51, 51, 255),         # ring_finger4 - red
        (0, 255, 0),           # pinky_finger1 - green
        (0, 255, 0),           # pinky_finger2 - green
        (0, 255, 0),           # pinky_finger3 - green
        (0, 255, 0)            # pinky_finger4 - green
    ]

    # Skeleton colors from onehand10k dataset (BGR format for OpenCV)
    skeleton_colors = [
        # Thumb connections
        (0, 128, 255), (0, 128, 255), (0, 128, 255), (0, 128, 255),
        # Index (forefinger) connections
        (255, 153, 255), (255, 153, 255), (255, 153, 255), (255, 153, 255),
        # Middle finger connections
        (255, 178, 102), (255, 178, 102), (255, 178, 102), (255, 178, 102),
        # Ring finger connections
        (51, 51, 255), (51, 51, 255), (51, 51, 255), (51, 51, 255),
        # Pinky connections
        (0, 255, 0), (0, 255, 0), (0, 255, 0), (0, 255, 0)
    ]

    # Default bbox color in BGR
    bbox_color = (0, 255, 0)  # Green

    # Lists and tuples are batches; anything else (an ndarray) is one image
    is_single_image = not isinstance(images, (list, tuple))

    # Normalize inputs to lists so single and batch inputs share one code path
    if is_single_image:
        images_list = [images]
        if pose_results is not None and isinstance(pose_results, dict):
            pose_results_list = [pose_results]
        else:
            pose_results_list = None
    else:
        images_list = list(images)
        pose_results_list = list(pose_results) if isinstance(pose_results, (list, tuple)) else None

    result_images = []

    for i, image in enumerate(images_list):
        # Work on a copy so the caller's image is never modified
        vis_image = image.copy()

        # Only 3-channel images take the RGB<->BGR round trip. Checking ndim
        # first keeps grayscale (H, W) inputs from failing on shape[2].
        swap_channels = (backend == 'opencv'
                         and vis_image.ndim == 3
                         and vis_image.shape[2] == 3)
        if swap_channels:
            vis_image = cv2.cvtColor(vis_image, cv2.COLOR_RGB2BGR)

        # Nothing to draw for this image: return the (unchanged) copy
        if pose_results_list is None or i >= len(pose_results_list) or pose_results_list[i] is None:
            if swap_channels:
                vis_image = cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB)
            result_images.append(vis_image)
            continue

        result = pose_results_list[i]

        # Extract keypoints, bboxes and scores
        keypoints_array = result.get('keypoints', None)
        keypoint_scores = result.get('scores', None)
        current_bboxes = result.get('boxes', None)
        current_scores = result.get('box_scores', None)

        # Skip if there are neither keypoints nor bboxes
        if (keypoints_array is None or len(keypoints_array) == 0) and (current_bboxes is None or len(current_bboxes) == 0):
            if swap_channels:
                vis_image = cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB)
            result_images.append(vis_image)
            continue

        # Draw bounding boxes if available and requested
        if show_bbox and current_bboxes is not None and len(current_bboxes) > 0:
            for idx, box in enumerate(current_bboxes):
                # np.asarray accepts list/tuple boxes as well as ndarray rows
                x1, y1, x2, y2 = np.asarray(box).astype(int)

                cv2.rectangle(vis_image, (x1, y1), (x2, y2), bbox_color, bbox_thickness)

                # Add confidence score if requested
                if show_confidence and current_scores is not None and idx < len(current_scores):
                    score = current_scores[idx]
                    conf_text = f"Conf: {score:.2f}"
                    text_size = cv2.getTextSize(conf_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                    cv2.rectangle(vis_image, 
                                (x1, y1 - text_size[1] - 5), 
                                (x1 + text_size[0], y1), 
                                bbox_color, -1)
                    cv2.putText(vis_image, conf_text, (x1, y1 - 5), 
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

        # Draw keypoints and skeleton if available and requested
        if (show_keypoints or show_skeleton) and keypoints_array is not None and len(keypoints_array) > 0:
            for hand_idx, hand_keypoints in enumerate(keypoints_array):
                # Skip hands with no keypoints
                if len(hand_keypoints) == 0:
                    continue

                # Get scores for this hand (all 1.0 when scores are missing)
                hand_scores = keypoint_scores[hand_idx] if keypoint_scores is not None else np.ones(hand_keypoints.shape[0])

                # Keep only keypoints at or above the confidence threshold
                valid_keypoints = []
                for j, (kp, score) in enumerate(zip(hand_keypoints, hand_scores)):
                    if score >= keypoint_threshold:
                        x, y = int(kp[0]), int(kp[1])
                        valid_keypoints.append((j, x, y, score))

                # Keypoint index -> pixel position, shared by skeleton and handedness
                kp_dict = {idx: (x, y) for idx, x, y, _ in valid_keypoints}

                # Draw skeleton if requested (only between two valid keypoints)
                if show_skeleton:
                    for j, (idx1, idx2) in enumerate(skeleton_connections):
                        if idx1 in kp_dict and idx2 in kp_dict:
                            color = skeleton_colors[j] if j < len(skeleton_colors) else (255, 255, 255)
                            cv2.line(vis_image, kp_dict[idx1], kp_dict[idx2], color, skeleton_thickness)

                # Draw keypoints if requested
                if show_keypoints:
                    for idx, x, y, conf in valid_keypoints:
                        color = keypoint_colors[idx] if idx < len(keypoint_colors) else (255, 255, 255)

                        # Filled circle plus a thin black outline for visibility
                        cv2.circle(vis_image, (x, y), keypoint_radius, color, -1)
                        cv2.circle(vis_image, (x, y), keypoint_radius, (0, 0, 0), 1)

                        # Add keypoint labels if requested
                        if show_keypoint_labels and idx < len(keypoint_names):
                            label = f"{keypoint_names[idx]}"
                            if show_confidence:
                                label += f" ({conf:.2f})"

                            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)[0]

                            # Dark background behind the label text
                            cv2.rectangle(vis_image, 
                                        (x + 5, y - text_size[1] - 5), 
                                        (x + 5 + text_size[0], y), 
                                        (0, 0, 0), -1)

                            cv2.putText(vis_image, label, (x + 5, y - 5), 
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

                # Determine handedness if requested
                if show_handedness and len(valid_keypoints) > 0:
                    # Simplified handedness estimation using keypoint positions
                    wrist_idx = 0
                    thumb_tip_idx = 4
                    index_tip_idx = 8

                    hand_type = "Unknown"

                    if wrist_idx in kp_dict and thumb_tip_idx in kp_dict and index_tip_idx in kp_dict:
                        wrist = kp_dict[wrist_idx]
                        thumb_tip = kp_dict[thumb_tip_idx]
                        index_tip = kp_dict[index_tip_idx]

                        # Vectors from wrist to fingertips
                        wrist_to_thumb = (thumb_tip[0] - wrist[0], thumb_tip[1] - wrist[1])
                        wrist_to_index = (index_tip[0] - wrist[0], index_tip[1] - wrist[1])

                        # Sign of the 2D cross product tells which side of the
                        # wrist->index line the thumb lies on.
                        # NOTE(review): the Left/Right mapping is a heuristic that
                        # depends on palm orientation relative to the camera.
                        cross_product = wrist_to_thumb[0] * wrist_to_index[1] - wrist_to_thumb[1] * wrist_to_index[0]

                        if cross_product > 0:
                            hand_type = "Left"
                        else:
                            hand_type = "Right"

                    hand_text = f"{hand_type} Hand"

                    if current_bboxes is not None and hand_idx < len(current_bboxes):
                        x1, y1, x2, y2 = np.asarray(current_bboxes[hand_idx]).astype(int)

                        text_size = cv2.getTextSize(hand_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]

                        # Place text above the bounding box (higher when the
                        # confidence label also occupies that space)
                        text_y = y1 - text_size[1] - 15 if show_confidence else y1 - text_size[1] - 5
                        text_x = x1

                        cv2.rectangle(vis_image, 
                                    (text_x, text_y - text_size[1]), 
                                    (text_x + text_size[0], text_y + 5), 
                                    (0, 0, 255), -1)  # Red background ((0, 0, 255) is red in BGR)
                        cv2.putText(vis_image, hand_text, (text_x, text_y), 
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                    elif wrist_idx in kp_dict:
                        # No bounding box: place the text at the wrist position
                        wrist_x, wrist_y = kp_dict[wrist_idx]
                        text_size = cv2.getTextSize(hand_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                        cv2.rectangle(vis_image, 
                                    (wrist_x, wrist_y - text_size[1] - 15), 
                                    (wrist_x + text_size[0], wrist_y - 10), 
                                    (0, 0, 255), -1)  # Red background ((0, 0, 255) is red in BGR)
                        cv2.putText(vis_image, hand_text, (wrist_x, wrist_y - 15), 
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        # Convert back to RGB if we converted on the way in
        if swap_channels:
            vis_image = cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB)

        result_images.append(vis_image)

    # Return a single image if input was a single image, otherwise return the list
    return result_images[0] if is_single_image else result_images

# Demo: crop -> detect -> pose -> draw on the first recording.
recording = processed_data.get("recording_0")
if recording is None:
    # Fail with a clear message instead of "'NoneType' object is not subscriptable"
    raise KeyError("processed_data has no 'recording_0' entry")
rgb_images = recording["color_images"]

cropped_frames, _ = crop_hands_sequence(
    rgb_images,
    output_size=(300, 300),
    margin_percent=0.5,
)

boxes_list = []
scores_list = []
crop_info_list = []

# Detect hands on every cropped frame.
# NOTE(review): confirm that return_coordinates='original' yields boxes in the
# cropped frame's pixel space, since pose() below runs on cropped_frames.
for frame in cropped_frames:
    boxes, scores, crop_info = detect_hands(
        frame,
        return_coordinates='original',
        crop_size=(300, 300),
    )
    boxes_list.append(boxes)
    scores_list.append(scores)
    crop_info_list.append(crop_info)

pose_results = pose(
    images=cropped_frames,
    boxes=boxes_list,
    scores=scores_list
)

# keypoint_threshold=0 draws every keypoint regardless of confidence
visualized_frames = visualize_hand_landmarks(
    images=cropped_frames,
    pose_results=pose_results,
    keypoint_threshold=0
)

#### Create Keypoint Dataset

##### **Class Definition**

In [6]:
import numpy as np
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import json
from tqdm import tqdm
from datetime import datetime

class HandPoseLSTMDatasetCreator:
    """
    Create an LSTM dataset of 2D hand keypoints from windowed hand video data.

    Input layout: ``<data_root>/rgb/<CLASS>/*.npy``, each file holding one
    window of RGB frames. Every frame is run through ``detect_hands`` and
    ``pose`` (defined in earlier notebook cells) and the first hand's
    ``(num_keypoints, 2)`` keypoints are kept; frames without a usable
    detection are filled with zeros.

    Class mapping:
    - OH: Open Hand (0)
    - IH: Intrinsic Plus (1)
    - SF: Straight Fist (2)
    - HH: Hook Hand (3)
    - HC: Hand Close (4)
    """

    # Accepted values for ``normalize_sequence(method=...)`` /
    # ``process_dataset(normalize=...)``.
    NORMALIZE_METHODS = ('wrist', 'bbox')

    def __init__(self, data_root: str, window_size: int = 16, num_keypoints: int = 21,
                 flatten_keypoints: bool = False, min_detection_ratio: float = 0.3):
        """
        Args:
            data_root: Dataset root containing an ``rgb/`` sub-directory.
            window_size: Expected number of frames per window; also the base
                of the minimum-detections threshold.
            num_keypoints: Number of keypoints per hand skeleton.
            flatten_keypoints: If True, each stored sequence is reshaped from
                (frames, num_keypoints, 2) to (frames, num_keypoints * 2).
            min_detection_ratio: Fraction of ``window_size`` frames that must
                contain a valid skeleton for a window to be kept.
        """
        self.data_root = Path(data_root)
        self.window_size = window_size
        self.num_keypoints = num_keypoints
        self.flatten_keypoints = flatten_keypoints
        self.min_detection_ratio = min_detection_ratio

        # Class abbreviation -> integer label
        self.class_mapping = {
            'OH': 0,  # Open Hand
            'IH': 1,  # Intrinsic Plus
            'SF': 2,  # Straight Fist
            'HH': 3,  # Hook Hand
            'HC': 4   # Hand Close
        }

        # Integer label -> class abbreviation
        self.class_names = {v: k for k, v in self.class_mapping.items()}

        # Accepted samples (index-aligned) and metadata of rejected windows
        self.sequences = []
        self.labels = []
        self.metadata = []
        self.failed_samples = []

    def load_window_data(self, file_path: str) -> Optional[np.ndarray]:
        """Load one window of frames from a ``.npy`` file.

        Returns the array, or ``None`` (after printing the error) when the file
        cannot be loaded, so a single bad file does not abort a dataset run.
        """
        try:
            data = np.load(file_path)
            print(f"Loaded {file_path}: shape {data.shape}")
            return data
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return None

    def _extract_keypoints(self, result, frame_idx: int) -> Optional[np.ndarray]:
        """Return a ``(num_keypoints, 2)`` copy of the first hand in one ``pose`` result, or None."""
        # pose() results are handled either as a dict with a 'keypoints' array
        # or as a list of such dicts (one per detection; the first is used).
        if isinstance(result, list):
            if not result:
                return None
            result = result[0]
        if not (isinstance(result, dict) and 'keypoints' in result):
            return None

        keypoints = result['keypoints']
        if not isinstance(keypoints, np.ndarray) or keypoints.size == 0:
            return None
        if keypoints.shape == (1, self.num_keypoints, 2):
            return keypoints[0].copy()
        if keypoints.shape == (self.num_keypoints, 2):
            return keypoints.copy()

        print(f"Frame {frame_idx}: Unexpected keypoints shape: {keypoints.shape}")
        return None

    def process_single_window(self, rgb_frames: np.ndarray, class_name: str,
                            file_name: str) -> Tuple[Optional[np.ndarray], Dict]:
        """
        Run hand detection + pose estimation on one window of frames.

        Args:
            rgb_frames: sequence of frames, e.g. an array of shape (T, H, W, 3)
            class_name: class abbreviation (OH, IH, SF, HH, HC)
            file_name: original file name, recorded in the metadata

        Returns:
            pose_sequence: array of shape (T, num_keypoints, 2), or None when the
                window is empty, has too few valid detections, or processing raised
            metadata: dict with processing information; 'processing_status' is one
                of 'success', 'empty_window', 'insufficient_detections', 'failed'
        """
        metadata = {
            'file_name': file_name,
            'class_name': class_name,
            'class_id': self.class_mapping[class_name],
            'timestamp': datetime.now().isoformat(),
            'num_frames': len(rgb_frames),
            'processing_status': 'started'
        }

        num_frames = len(rgb_frames)
        if num_frames == 0:
            # Nothing to detect on; avoids a division by zero in detection_rate.
            metadata['processing_status'] = 'empty_window'
            print(f"Empty window in {file_name}")
            return None, metadata

        try:
            # Detect hands frame by frame (crop info is not needed here)
            boxes_list = []
            scores_list = []
            for frame in rgb_frames:
                boxes, scores, _crop_info = detect_hands(
                    frame,
                    return_coordinates='original',
                    crop_size=(300, 300),
                )
                boxes_list.append(boxes)
                scores_list.append(scores)

            # pose() takes a list of images
            pose_results = pose(
                images=list(rgb_frames),
                boxes=boxes_list,
                scores=scores_list
            )

            # Keep the first hand per frame; zero-fill frames without one
            pose_sequence = []
            valid_frames = 0
            for frame_idx, result in enumerate(pose_results):
                frame_keypoints = self._extract_keypoints(result, frame_idx)
                if frame_keypoints is None:
                    frame_keypoints = np.zeros((self.num_keypoints, 2))
                else:
                    valid_frames += 1
                pose_sequence.append(frame_keypoints)

            pose_sequence = np.array(pose_sequence)

            metadata['valid_frames'] = valid_frames
            metadata['detection_rate'] = valid_frames / num_frames
            metadata['boxes_detected'] = [len(b) if b is not None else 0 for b in boxes_list]

            # Reject windows where too few frames produced a skeleton
            min_required_frames = int(self.window_size * self.min_detection_ratio)
            if valid_frames < min_required_frames:
                metadata['processing_status'] = 'insufficient_detections'
                print(f"Insufficient detections for {file_name}: {valid_frames}/{self.window_size} frames (need at least {min_required_frames})")
                return None, metadata

            metadata['processing_status'] = 'success'
            return pose_sequence, metadata

        except Exception as e:
            # Best-effort: record the failure and keep processing other windows
            metadata['processing_status'] = 'failed'
            metadata['error'] = str(e)
            print(f"Error processing {file_name}: {e}")
            import traceback
            traceback.print_exc()
            return None, metadata

    def normalize_sequence(self, pose_sequence: np.ndarray, method: str = 'wrist') -> np.ndarray:
        """
        Normalize a (T, num_keypoints, 2) pose sequence and return a new array.

        Args:
            pose_sequence: keypoints per frame; all-zero frames mean "no detection"
            method: 'wrist' -> coordinates relative to keypoint 0 of each frame;
                    'bbox'  -> each frame's non-zero keypoints scaled into [0, 1]
                               by their own min/max (zero rows are left as-is)

        Returns:
            A floating-point array of the same shape (integer input is promoted
            to float64 so bbox-normalized values are not truncated).

        Raises:
            ValueError: if ``method`` is not one of NORMALIZE_METHODS.
        """
        normalized = np.array(pose_sequence, copy=True)
        if not np.issubdtype(normalized.dtype, np.floating):
            normalized = normalized.astype(np.float64)

        if method == 'wrist':
            # All-zero (undetected) frames have a (0, 0) wrist, so subtracting
            # it leaves them unchanged; no per-frame mask is needed.
            normalized = normalized - normalized[:, :1, :]

        elif method == 'bbox':
            # Iterating yields row views, so writes below update `normalized`.
            for frame in normalized:
                valid_mask = ~np.all(frame == 0, axis=1)
                if not valid_mask.any():
                    continue
                valid_points = frame[valid_mask]
                min_vals = valid_points.min(axis=0)
                range_vals = valid_points.max(axis=0) - min_vals
                range_vals[range_vals == 0] = 1  # Avoid division by zero
                frame[valid_mask] = (valid_points - min_vals) / range_vals

        else:
            raise ValueError(
                f"Unknown normalization method {method!r}; expected one of {self.NORMALIZE_METHODS}"
            )

        return normalized

    def process_dataset(self, normalize: str = 'wrist', skip_failed: bool = True):
        """
        Process every ``.npy`` window under ``<data_root>/rgb/<CLASS>/``.

        Args:
            normalize: normalization method passed to ``normalize_sequence``;
                a falsy value disables normalization.
            skip_failed: not consulted; failed windows are always skipped and
                collected in ``self.failed_samples``. Kept for API compatibility.

        Raises:
            ValueError: if the rgb directory is missing or ``normalize`` is unknown.
        """
        rgb_root = self.data_root / 'rgb'

        if not rgb_root.exists():
            raise ValueError(f"RGB data directory not found: {rgb_root}")
        # Fail fast, before any expensive model inference
        if normalize and normalize not in self.NORMALIZE_METHODS:
            raise ValueError(
                f"Unknown normalization method {normalize!r}; expected one of {self.NORMALIZE_METHODS}"
            )

        for class_name, class_id in self.class_mapping.items():
            class_dir = rgb_root / class_name

            if not class_dir.exists():
                print(f"Warning: Class directory not found: {class_dir}")
                continue

            # Sorted: glob order is filesystem-dependent, and sample order feeds
            # the seeded split in split_dataset.
            npy_files = sorted(class_dir.glob("*.npy"))
            print(f"\nProcessing class {class_name}: {len(npy_files)} files")

            for npy_file in tqdm(npy_files, desc=f"Processing {class_name}"):
                rgb_frames = self.load_window_data(str(npy_file))
                if rgb_frames is None:
                    continue

                pose_sequence, metadata = self.process_single_window(
                    rgb_frames,
                    class_name,
                    npy_file.name
                )

                if pose_sequence is None:
                    self.failed_samples.append(metadata)
                    continue

                if normalize:
                    pose_sequence = self.normalize_sequence(pose_sequence, normalize)

                if self.flatten_keypoints:
                    # (T, num_keypoints, 2) -> (T, num_keypoints * 2); use the
                    # actual frame count so off-size windows are not garbled.
                    pose_sequence = pose_sequence.reshape(len(pose_sequence), -1)

                self.sequences.append(pose_sequence)
                self.labels.append(class_id)
                self.metadata.append(metadata)

        print(f"\nProcessing complete:")
        print(f"Total samples: {len(self.sequences)}")
        print(f"Failed samples: {len(self.failed_samples)}")

        if self.labels:
            unique, counts = np.unique(self.labels, return_counts=True)
            print("\nClass distribution:")
            for cls, count in zip(unique, counts):
                print(f"  {self.class_names[cls]}: {count} samples")

    def prepare_lstm_format(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Stack all samples into LSTM input format.

        Returns:
            X: (num_samples, T, num_keypoints * 2)
            y: (num_samples,) integer labels

        Raises:
            ValueError: if no sequences have been processed.
        """
        if not self.sequences:
            raise ValueError("No sequences to prepare. Run process_dataset first.")

        X = np.stack(self.sequences)
        # No-op for already-flattened (n, T, F) data; flattens (n, T, K, 2).
        X = X.reshape(X.shape[0], X.shape[1], -1)

        y = np.array(self.labels)

        return X, y

    def save_by_folder_structure(self, output_base_path: str, save_individual_files: bool = True):
        """
        Save processed landmarks in one sub-folder per class.

        Args:
            output_base_path: Output directory (if the path has a suffix, its
                parent directory is used instead). Created if missing.
            save_individual_files: If True, saves each sample as an individual
                ``<name>_landmarks.npy`` + ``.json`` pair; if False, saves all
                samples of a class in one ``<CLASS>_all_landmarks.npz``.
        """
        base = Path(output_base_path)
        landmark_root = base.parent if base.suffix else base
        # Ensure the summary file can be written even if no class folder is created
        landmark_root.mkdir(parents=True, exist_ok=True)

        # Group sequences and metadata by class
        class_sequences = {class_name: [] for class_name in self.class_mapping}
        class_metadata = {class_name: [] for class_name in self.class_mapping}
        for seq, label, meta in zip(self.sequences, self.labels, self.metadata):
            class_name = self.class_names[label]
            class_sequences[class_name].append(seq)
            class_metadata[class_name].append(meta)

        for class_name, sequences in class_sequences.items():
            if not sequences:
                continue

            class_dir = landmark_root / class_name
            class_dir.mkdir(parents=True, exist_ok=True)

            if save_individual_files:
                for seq, meta in zip(sequences, class_metadata[class_name]):
                    # "<stem>.npy" -> "<stem>_landmarks.npy"
                    landmark_name = f"{Path(meta['file_name']).stem}_landmarks.npy"
                    save_path = class_dir / landmark_name

                    np.save(save_path, seq)

                    # Metadata next to the array, as "<stem>_landmarks.json"
                    meta_path = save_path.with_suffix('.json')
                    with open(meta_path, 'w') as f:
                        json.dump(meta, f, indent=2)

                print(f"Saved {len(sequences)} individual files for class {class_name}")
            else:
                class_data = {
                    # (n_samples, T, K, 2), or (n_samples, T, K*2) when flattened
                    'sequences': np.array(sequences),
                    'metadata': class_metadata[class_name],
                    'class_name': class_name,
                    'class_id': self.class_mapping[class_name]
                }

                save_path = class_dir / f'{class_name}_all_landmarks.npz'
                np.savez_compressed(save_path, **class_data)
                print(f"Saved {len(sequences)} samples for class {class_name} in {save_path}")

        summary = {
            'total_samples': len(self.sequences),
            'failed_samples': len(self.failed_samples),
            'class_distribution': {cls: len(seqs) for cls, seqs in class_sequences.items()},
            'window_size': self.window_size,
            'num_keypoints': self.num_keypoints,
            'processing_date': datetime.now().isoformat()
        }

        summary_path = landmark_root / 'dataset_summary.json'
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        print(f"\nDataset saved to: {landmark_root}")
        print(f"Total processed samples: {len(self.sequences)}")

    def _report_failures(self):
        """Print a breakdown of failed samples when there is nothing to save."""
        print("\n!!! WARNING: No sequences were successfully processed !!!")
        print(f"Total failed samples: {len(self.failed_samples)}")
        if self.failed_samples:
            print("\nDetailed failure analysis:")
            failure_reasons = {}
            for failed in self.failed_samples:
                reason = failed.get('processing_status', 'unknown')
                failure_reasons[reason] = failure_reasons.get(reason, 0) + 1

            for reason, count in failure_reasons.items():
                print(f"  - {reason}: {count} samples")

            print("\nFirst 5 failed samples:")
            for failed in self.failed_samples[:5]:
                print(f"\n  File: {failed['file_name']}")
                print(f"  Status: {failed['processing_status']}")
                if 'error' in failed:
                    print(f"  Error: {failed['error']}")
                if 'detection_rate' in failed:
                    print(f"  Detection rate: {failed['detection_rate']:.2%}")
                if 'valid_frames' in failed:
                    print(f"  Valid frames: {failed['valid_frames']}/{self.window_size}")

        print("\nNo dataset file was created due to lack of valid samples.")

    def save_dataset(self, output_path: str, save_format: str = 'npz',
                 maintain_structure: bool = False, save_individual: bool = True):
        """
        Save processed dataset

        Args:
            output_path: Path for saving the dataset
            save_format: 'npz' or 'separate' (only used when maintain_structure=False)
            maintain_structure: If True, saves in folder structure; if False, saves as single file
            save_individual: If True and maintain_structure=True, saves each sample separately

        Raises:
            ValueError: if maintain_structure is False and save_format is unknown.
        """
        if not self.sequences:
            self._report_failures()
            return

        if maintain_structure:
            self.save_by_folder_structure(output_base_path=output_path, save_individual_files=save_individual)
            return

        # Previously an unknown format silently wrote nothing
        if save_format not in ('npz', 'separate'):
            raise ValueError(f"Unknown save_format {save_format!r}; expected 'npz' or 'separate'")

        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        X, y = self.prepare_lstm_format()

        if save_format == 'npz':
            # Dicts/lists are stored as object arrays (load with allow_pickle=True)
            save_dict = {
                'X': X,  # LSTM format
                'y': y,  # Labels
                'sequences_3d': np.stack(self.sequences),  # Per-sample arrays as stored
                'class_mapping': self.class_mapping,
                'class_names': self.class_names,
                'metadata': self.metadata,
                'failed_samples': self.failed_samples,
                'window_size': self.window_size,
                'num_keypoints': self.num_keypoints,
            }

            np.savez_compressed(str(output_path), **save_dict)
            print(f"\nDataset saved to: {output_path}")

        else:  # 'separate'
            base_path = output_path.with_suffix('')
            np.save(f"{base_path}_X.npy", X)
            np.save(f"{base_path}_y.npy", y)
            np.save(f"{base_path}_sequences_3d.npy", np.stack(self.sequences))

            meta_dict = {
                'class_mapping': self.class_mapping,
                'class_names': self.class_names,
                'window_size': self.window_size,
                'num_keypoints': self.num_keypoints,
                'total_samples': len(X),
                'failed_samples': len(self.failed_samples),
                'metadata': self.metadata,
                'failed_samples_info': self.failed_samples
            }

            with open(f"{base_path}_metadata.json", 'w') as f:
                json.dump(meta_dict, f, indent=2)

            print(f"\nDataset saved to: {base_path}_*.npy/json")

        print(f"Dataset shape - X: {X.shape}, y: {y.shape}")

    def split_dataset(self, train_ratio: float = 0.8, random_seed: int = 42):
        """
        Randomly split the dataset into train and validation sets.

        Uses a private ``RandomState`` seeded with ``random_seed``. This gives
        the same permutation as ``np.random.seed(random_seed)`` without
        clobbering the global NumPy RNG.

        Returns:
            dict with X_train, X_val, y_train, y_val, train_indices, val_indices
        """
        X, y = self.prepare_lstm_format()

        n_samples = len(y)
        indices = np.random.RandomState(random_seed).permutation(n_samples)

        train_size = int(n_samples * train_ratio)
        train_indices = indices[:train_size]
        val_indices = indices[train_size:]

        return {
            'X_train': X[train_indices],
            'X_val': X[val_indices],
            'y_train': y[train_indices],
            'y_val': y[val_indices],
            'train_indices': train_indices,
            'val_indices': val_indices
        }

##### **Execute Class**

In [7]:
# Input root: process_dataset reads windows from <data_root>/rgb/<CLASS>/*.npy.
data_root = r"data/video_hand_focused_data"
# With maintain_structure=True, a path without a file suffix is used as the
# output directory itself; one sub-folder per class is created inside it.
output_path = r"data/video_hand_focused_data/hand_landmark_flatten"

# Create dataset processor
# flatten_keypoints=True stores each 16-frame window as (16, 42) instead of (16, 21, 2).
dataset_creator = HandPoseLSTMDatasetCreator(
    data_root=data_root,
    window_size=16,
    num_keypoints=21, 
    flatten_keypoints=True
)

# Process all data
# normalize='wrist' makes every keypoint relative to keypoint 0 of its frame.
# NOTE(review): skip_failed is not consulted by process_dataset; failed windows
# are always skipped and collected in dataset_creator.failed_samples.
print("Starting dataset processing...")
dataset_creator.process_dataset(normalize='wrist', skip_failed=True)

# Save with folder structure - individual files for each sample
# (one <name>_landmarks.npy + .json pair per window, plus dataset_summary.json)
dataset_creator.save_dataset(
    output_path=output_path,  
    maintain_structure=True,
    save_individual=True  
)

Starting dataset processing...

Processing class OH: 12 files


Processing OH:   0%|          | 0/12 [00:00<?, ?it/s]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250402151124_w000.npy: shape (16, 300, 300, 3)
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmdet_nano_8xb32-300e_hand-267f9c8f.pth
Loads checkpoint by http backend from path: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-hand5_pt-aic-coco_210e-256x256-74fb594_20230320.pth


Processing OH:   8%|▊         | 1/12 [00:08<01:34,  8.55s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250402151124_w010.npy: shape (16, 300, 300, 3)


Processing OH:  17%|█▋        | 2/12 [00:11<00:50,  5.02s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250402151609_w000.npy: shape (16, 300, 300, 3)


Processing OH:  25%|██▌       | 3/12 [00:13<00:34,  3.86s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250402151609_w009.npy: shape (16, 300, 300, 3)


Processing OH:  33%|███▎      | 4/12 [00:15<00:25,  3.16s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250402152331_w000.npy: shape (16, 300, 300, 3)


Processing OH:  42%|████▏     | 5/12 [00:17<00:19,  2.85s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250402152331_w001.npy: shape (16, 300, 300, 3)


Processing OH:  50%|█████     | 6/12 [00:20<00:15,  2.64s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250506145704_w000.npy: shape (16, 300, 300, 3)


Processing OH:  58%|█████▊    | 7/12 [00:24<00:16,  3.22s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250506145704_w001.npy: shape (16, 300, 300, 3)


Processing OH:  67%|██████▋   | 8/12 [00:29<00:15,  3.82s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250506152951_w000.npy: shape (16, 300, 300, 3)


Processing OH:  75%|███████▌  | 9/12 [00:34<00:12,  4.24s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250506152951_w001.npy: shape (16, 300, 300, 3)


Processing OH:  83%|████████▎ | 10/12 [00:39<00:08,  4.40s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250506162630_w000.npy: shape (16, 300, 300, 3)


Processing OH:  92%|█████████▏| 11/12 [00:43<00:04,  4.37s/it]

Loaded data\video_hand_focused_data\rgb\OH\Record_20250506162630_w001.npy: shape (16, 300, 300, 3)


Processing OH: 100%|██████████| 12/12 [00:49<00:00,  4.11s/it]



Processing class IH: 12 files


Processing IH:   0%|          | 0/12 [00:00<?, ?it/s]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250402151124_w001.npy: shape (16, 300, 300, 3)


Processing IH:   8%|▊         | 1/12 [00:03<00:35,  3.21s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250402151124_w002.npy: shape (16, 300, 300, 3)


Processing IH:  17%|█▋        | 2/12 [00:06<00:31,  3.17s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250402151609_w001.npy: shape (16, 300, 300, 3)


Processing IH:  25%|██▌       | 3/12 [00:09<00:27,  3.03s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250402151609_w002.npy: shape (16, 300, 300, 3)


Processing IH:  33%|███▎      | 4/12 [00:12<00:24,  3.06s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250402152331_w003.npy: shape (16, 300, 300, 3)


Processing IH:  42%|████▏     | 5/12 [00:15<00:20,  2.97s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250402152331_w004.npy: shape (16, 300, 300, 3)


Processing IH:  50%|█████     | 6/12 [00:18<00:17,  2.95s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250506145704_w011.npy: shape (16, 300, 300, 3)


Processing IH:  58%|█████▊    | 7/12 [00:21<00:15,  3.08s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250506145704_w012.npy: shape (16, 300, 300, 3)


Processing IH:  67%|██████▋   | 8/12 [00:24<00:12,  3.12s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250506152951_w005.npy: shape (16, 300, 300, 3)


Processing IH:  75%|███████▌  | 9/12 [00:27<00:09,  3.05s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250506152951_w006.npy: shape (16, 300, 300, 3)


Processing IH:  83%|████████▎ | 10/12 [00:31<00:06,  3.26s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250506162630_w006.npy: shape (16, 300, 300, 3)


Processing IH:  92%|█████████▏| 11/12 [00:34<00:03,  3.20s/it]

Loaded data\video_hand_focused_data\rgb\IH\Record_20250506162630_w007.npy: shape (16, 300, 300, 3)


Processing IH: 100%|██████████| 12/12 [00:36<00:00,  3.08s/it]



Processing class SF: 12 files


Processing SF:   0%|          | 0/12 [00:00<?, ?it/s]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250402151124_w003.npy: shape (16, 300, 300, 3)


Processing SF:   8%|▊         | 1/12 [00:03<00:33,  3.00s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250402151124_w004.npy: shape (16, 300, 300, 3)


Processing SF:  17%|█▋        | 2/12 [00:06<00:33,  3.40s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250402151609_w003.npy: shape (16, 300, 300, 3)


Processing SF:  25%|██▌       | 3/12 [00:10<00:31,  3.45s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250402151609_w004.npy: shape (16, 300, 300, 3)


Processing SF:  33%|███▎      | 4/12 [00:13<00:27,  3.38s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250402152331_w005.npy: shape (16, 300, 300, 3)


Processing SF:  42%|████▏     | 5/12 [00:17<00:24,  3.46s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250402152331_w006.npy: shape (16, 300, 300, 3)


Processing SF:  50%|█████     | 6/12 [00:20<00:20,  3.47s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250506145704_w016.npy: shape (16, 300, 300, 3)


Processing SF:  58%|█████▊    | 7/12 [00:23<00:16,  3.33s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250506145704_w017.npy: shape (16, 300, 300, 3)


Processing SF:  67%|██████▋   | 8/12 [00:26<00:12,  3.19s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250506145704_w018.npy: shape (16, 300, 300, 3)


Processing SF:  75%|███████▌  | 9/12 [00:29<00:09,  3.12s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250506152951_w010.npy: shape (16, 300, 300, 3)


Processing SF:  83%|████████▎ | 10/12 [00:32<00:06,  3.10s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250506152951_w011.npy: shape (16, 300, 300, 3)


Processing SF:  92%|█████████▏| 11/12 [00:34<00:02,  2.89s/it]

Loaded data\video_hand_focused_data\rgb\SF\Record_20250506162630_w008.npy: shape (16, 300, 300, 3)


Processing SF: 100%|██████████| 12/12 [00:37<00:00,  3.11s/it]



Processing class HH: 12 files


Processing HH:   0%|          | 0/12 [00:00<?, ?it/s]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250402151124_w007.npy: shape (16, 300, 300, 3)


Processing HH:   8%|▊         | 1/12 [00:02<00:26,  2.45s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250402151124_w008.npy: shape (16, 300, 300, 3)


Processing HH:  17%|█▋        | 2/12 [00:05<00:25,  2.51s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250402151609_w006.npy: shape (16, 300, 300, 3)


Processing HH:  25%|██▌       | 3/12 [00:07<00:22,  2.49s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250402151609_w007.npy: shape (16, 300, 300, 3)


Processing HH:  33%|███▎      | 4/12 [00:10<00:20,  2.56s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250402152331_w008.npy: shape (16, 300, 300, 3)


Processing HH:  42%|████▏     | 5/12 [00:12<00:17,  2.48s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250402152331_w009.npy: shape (16, 300, 300, 3)


Processing HH:  50%|█████     | 6/12 [00:14<00:14,  2.41s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250506145704_w027.npy: shape (16, 300, 300, 3)


Processing HH:  58%|█████▊    | 7/12 [00:16<00:11,  2.35s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250506145704_w028.npy: shape (16, 300, 300, 3)


Processing HH:  67%|██████▋   | 8/12 [00:19<00:09,  2.26s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250506152951_w018.npy: shape (16, 300, 300, 3)


Processing HH:  75%|███████▌  | 9/12 [00:21<00:06,  2.20s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250506152951_w019.npy: shape (16, 300, 300, 3)


Processing HH:  83%|████████▎ | 10/12 [00:23<00:04,  2.24s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250506162630_w013.npy: shape (16, 300, 300, 3)


Processing HH:  92%|█████████▏| 11/12 [00:25<00:02,  2.32s/it]

Loaded data\video_hand_focused_data\rgb\HH\Record_20250506162630_w014.npy: shape (16, 300, 300, 3)


Processing HH: 100%|██████████| 12/12 [00:28<00:00,  2.36s/it]



Processing class HC: 12 files


Processing HC:   0%|          | 0/12 [00:00<?, ?it/s]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250402151124_w005.npy: shape (16, 300, 300, 3)


Processing HC:   8%|▊         | 1/12 [00:02<00:26,  2.40s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250402151124_w006.npy: shape (16, 300, 300, 3)


Processing HC:  17%|█▋        | 2/12 [00:04<00:23,  2.33s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250402151609_w005.npy: shape (16, 300, 300, 3)


Processing HC:  25%|██▌       | 3/12 [00:07<00:22,  2.52s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250402152331_w007.npy: shape (16, 300, 300, 3)


Processing HC:  33%|███▎      | 4/12 [00:09<00:20,  2.52s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506145704_w022.npy: shape (16, 300, 300, 3)


Processing HC:  42%|████▏     | 5/12 [00:12<00:17,  2.45s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506145704_w023.npy: shape (16, 300, 300, 3)


Processing HC:  50%|█████     | 6/12 [00:14<00:14,  2.45s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506145704_w024.npy: shape (16, 300, 300, 3)


Processing HC:  58%|█████▊    | 7/12 [00:17<00:12,  2.47s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506145704_w025.npy: shape (16, 300, 300, 3)


Processing HC:  67%|██████▋   | 8/12 [00:19<00:09,  2.36s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506152951_w014.npy: shape (16, 300, 300, 3)


Processing HC:  75%|███████▌  | 9/12 [00:21<00:06,  2.32s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506152951_w015.npy: shape (16, 300, 300, 3)


Processing HC:  83%|████████▎ | 10/12 [00:23<00:04,  2.25s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506162630_w009.npy: shape (16, 300, 300, 3)


Processing HC:  92%|█████████▏| 11/12 [00:25<00:02,  2.21s/it]

Loaded data\video_hand_focused_data\rgb\HC\Record_20250506162630_w010.npy: shape (16, 300, 300, 3)


Processing HC: 100%|██████████| 12/12 [00:27<00:00,  2.33s/it]



Processing complete:
Total samples: 60
Failed samples: 0

Class distribution:
  OH: 12 samples
  IH: 12 samples
  SF: 12 samples
  HH: 12 samples
  HC: 12 samples
Saved 12 individual files for class OH
Saved 12 individual files for class IH
Saved 12 individual files for class SF
Saved 12 individual files for class HH
Saved 12 individual files for class HC

Dataset saved to: data\video_hand_focused_data\hand_landmark_flatten
Total processed samples: 60


## 3. Model Development

### 3.1. Create Dataset Object

#### Using Non-Augmented Data

In [None]:
import os
import glob
import numpy as np
from datasets import Dataset, Features, Value, ClassLabel

# 1. Create lists of file paths and labels.
#    Expected layout: <image_rgb_dir>/<CLASS>/*.npy. Every entry of
#    image_rgb_dir is treated as a class name; a plain file simply yields no
#    glob matches and therefore contributes no samples.
image_rgb_dir = r'./data/video/rgb'
filepaths, labels = [], []
for class_name in os.listdir(image_rgb_dir):
    class_dir = os.path.join(image_rgb_dir, class_name)
    for npy_path in glob.glob(os.path.join(class_dir, '*.npy')):
        filepaths.append(npy_path)
        labels.append(class_name)

# 2. Create dataset with file paths (clips are loaded lazily later).
#    unique_labels (sorted class names) is what label2id/id2label are built from.
unique_labels = sorted(set(labels))
pr_ds = Dataset.from_dict({
    "image_path": filepaths,
    "label": labels
})

# 3. Encode the label column: string class names -> integer ClassLabel ids.
#    NOTE(review): label2id below assumes this numbering is also sorted by
#    class name — confirm against the installed `datasets` version.
pr_ds = pr_ds.class_encode_column("label")

# 4. Create a function to load images when needed
def load_image(example):
    """Attach the array stored at ``example["image_path"]`` as ``example["image"]``.

    The dict is modified in place and the same dict is returned.
    """
    clip_path = example["image_path"]
    example["image"] = np.load(clip_path)
    return example

# 5. Split the dataset: stratified 75/25 train/validation split with a fixed
#    seed. Stratifying by "label" relies on it being a ClassLabel (step 3).
split1 = pr_ds.train_test_split(test_size=0.25, shuffle=True, seed=42, stratify_by_column="label")
train_ds = split1["train"]
val_ds = split1["test"]

# 6. Build label2id / id2label from the sorted class names.
#    NOTE(review): this must agree with the ids assigned by
#    class_encode_column in step 3 (assumed sorted) — verify.
label2id = {lab: idx for idx, lab in enumerate(unique_labels)}
id2label = {idx: lab for lab, idx in label2id.items()}

#### Using Augmented Data

In [None]:
import os
import glob
import numpy as np
from datasets import Dataset, Features, Value, ClassLabel

# 1) Base directories
#    NOTE(review): machine-specific absolute Windows path; adjust per machine.
#    Expected layout: <base_dir>/train_aug/<CLASS>/*.npy and <base_dir>/val/<CLASS>/*.npy
base_dir     = os.path.join(
    r"D:\RESEARCH ASSISTANT\6. Depth Camera\CODE\Orbbec Gemini 2XL\REMOTE\DEVELOPMENT\notebook\DATA\video_hand_focused_data",
    "split_rgb"
)
# A list so several training directories could be combined; currently only
# the augmented split is used.
train_dirs   = [os.path.join(base_dir, "train_aug")]
val_dir      = os.path.join(base_dir, "val")

# 2) Collect train file paths & labels. Each sub-directory name is the class
#    label; loose files directly under a split directory are skipped.
train_paths, train_labels = [], []
for d in train_dirs:
    for class_name in os.listdir(d):
        class_dir = os.path.join(d, class_name)
        if not os.path.isdir(class_dir):
            continue
        for npy_path in glob.glob(os.path.join(class_dir, "*.npy")):
            train_paths.append(npy_path)
            train_labels.append(class_name)

# 3) Collect val file paths & labels (same rules as the training split)
val_paths, val_labels = [], []
for class_name in os.listdir(val_dir):
    class_dir = os.path.join(val_dir, class_name)
    if not os.path.isdir(class_dir):
        continue
    for npy_path in glob.glob(os.path.join(class_dir, "*.npy")):
        val_paths.append(npy_path)
        val_labels.append(class_name)

# 4) Define the ordered class list instead of using sorted set.
#    The position in this list is the class id used downstream
#    (OH, IH, SF, HC, HH -> 0..4).
ordered_labels = ["OH", "IH", "SF", "HC", "HH"]

# 5) Create custom class label feature with specified order
class_feature = ClassLabel(names=ordered_labels)

# 6) Create train dataset with custom class feature; string labels are
#    encoded through class_feature (presumably a folder name that is not in
#    ordered_labels makes this fail — confirm).
train_ds = Dataset.from_dict({
    "image_path": train_paths,
    "label": train_labels
}, features=Features({
    "image_path": Value("string"),
    "label": class_feature
}))

# 7) Create validation dataset with the same class feature, so ids match
val_ds = Dataset.from_dict({
    "image_path": val_paths,
    "label": val_labels
}, features=Features({
    "image_path": Value("string"),
    "label": class_feature
}))

# 8) Create a function to load images when needed
def load_image(example):
    """Attach the array stored at ``example["image_path"]`` as ``example["image"]``.

    The dict is modified in place and the same dict is returned.
    """
    clip_path = example["image_path"]
    example["image"] = np.load(clip_path)
    return example

# 9) Build label2id / id2label with the specified order; these match the ids
#    produced by class_feature because both enumerate ordered_labels.
label2id = {label: idx for idx, label in enumerate(ordered_labels)}
id2label = {idx: label for idx, label in enumerate(ordered_labels)}

# 10) Verify the label mappings
print(f"label2id: {label2id}")
print(f"id2label: {id2label}")

#### Ensure Correct Labelling

In [None]:
import matplotlib.pyplot as plt
import random
from collections import Counter

# Function to visualize sample images for each class
def visualize_class_samples(dataset, id2label):
    """Plot one randomly chosen sample per class in a single row of subplots.

    Args:
        dataset: Object whose ``"label"`` column holds integer class ids and
            whose ``"image_path"`` column holds paths to ``.npy`` files.
        id2label: Mapping from class id (``0 .. len(id2label) - 1``) to class
            name, used for the subplot titles.

    Classes without samples are skipped (their subplot slot stays empty).
    For 4-D arrays the middle entry along axis 0 (the middle frame of a clip)
    is shown. 3-channel images are drawn as colour, 1-channel or 2-D images in
    grayscale; any other shape leaves the panel blank apart from its title.
    """
    n_classes = len(id2label)
    plt.figure(figsize=(18, 10))  # wide enough for one row of all classes

    for cid in range(n_classes):
        candidates = [idx for idx, lbl in enumerate(dataset["label"]) if lbl == cid]
        if not candidates:
            continue

        picked = random.choice(candidates)
        img = np.load(dataset["image_path"][picked])

        # Video clips: show the frame in the middle of the clip.
        if img.ndim == 4:
            img = img[img.shape[0] // 2]

        plt.subplot(1, n_classes, cid + 1)

        is_rgb = img.ndim == 3 and img.shape[2] == 3
        is_gray = img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1)
        if is_rgb:
            plt.imshow(img)
        elif is_gray:
            plt.imshow(img, cmap='gray')

        plt.title(f"{id2label[cid]} (ID: {cid})")
        plt.axis('off')

    plt.subplots_adjust(wspace=0.3)  # extra horizontal gap between panels
    plt.show()

# Visualize a sample from each class
print("\nSample Images from Each Class:")
visualize_class_samples(train_ds, id2label)

# Print explicit mapping for verification, with per-split sample counts.
print("\nClass Mapping Verification:")
# Count every label once per split (Counter is imported at the top of this
# cell) instead of rescanning the full label column for each class.
train_label_counts = Counter(train_ds["label"])
val_label_counts = Counter(val_ds["label"])
for i in range(len(id2label)):
    class_name = id2label[i]
    print(f"ID {i}: {class_name}")

    # Count samples for this class in training set
    train_count = train_label_counts[i]
    print(f"  - Training samples: {train_count}")

    # Count samples for this class in validation set
    val_count = val_label_counts[i]
    print(f"  - Validation samples: {val_count}")

### 3.2. Model Initialization

#### Model Loading

In [None]:
from transformers import VideoMAEForVideoClassification, AutoImageProcessor
from torchinfo import summary
import torch

model_ckpt = 'MCG-NJU/videomae-base'
sw_processor = AutoImageProcessor.from_pretrained(model_ckpt, use_fast=True)

# Size the classification head from id2label: both dataset cells
# (non-augmented and augmented) define id2label/label2id, whereas
# `ordered_labels` only exists after the augmented-data cell has run.
# For the augmented path len(id2label) == len(ordered_labels).
sw_model = VideoMAEForVideoClassification.from_pretrained(
    model_ckpt,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

print(sw_model.config)
# Layer-by-layer summary for a single clip laid out as
# (batch, frames, channels, height, width).
print(summary(sw_model,
        input_size=(1, 16, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params"],
        row_settings=["depth"],
        device='cpu'))

#### Set Up Transformation Pipeline

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
import torchvision.transforms.functional as tvf
import einops

class Normalize(nn.Module):
    """
    Per-channel mean/std normalization for a clip laid out as (T, C, H, W).

    The time axis plays the role of the batch axis, so every frame is
    normalized with the same channel statistics. This is a thin ``nn.Module``
    wrapper around ``torchvision.transforms.functional.normalize`` so it can
    be placed inside a ``Compose`` pipeline.
    """

    def __init__(self,
                 mean: Tuple[float, float, float],
                 std: Tuple[float, float, float],
                 inplace: bool = False):
        super().__init__()
        self.mean, self.std, self.inplace = mean, std, inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # tvf.normalize accepts (B, C, H, W); T stands in for B here.
        return tvf.normalize(x, mean=self.mean, std=self.std, inplace=self.inplace)
    
# Preprocessing constants taken from the pretrained VideoMAE processor/model.
mean = sw_processor.image_mean
std = sw_processor.image_std
# NOTE(review): not used anywhere in this part of the notebook.
num_frames_to_samples = sw_model.config.num_frames
# Square target size: the processor's "shortest_edge" if present, otherwise
# its "height" (a separate "width" entry, if any, is ignored).
height = sw_processor.size.get("shortest_edge", sw_processor.size.get("height"))
width = height
resize_to = (height, width)

# RandomCrop / RandomHorizontalFlip are imported but not used by the pipeline below.
from torchvision.transforms import Compose, Lambda, RandomCrop, RandomHorizontalFlip

# Pipeline for one clip given as a float (T, C, H, W) tensor:
#   1) scale to [0, 1] (assumes raw pixel values in 0..255 — confirm for the data),
#   2) per-channel normalization with the processor's mean/std,
#   3) bilinear resize of every frame to `resize_to`.
# The last Lambda looks up the module-level `resize_to` each time it runs,
# so reassigning `resize_to` later changes this pipeline's output size.
transformation_pipeline = Compose([
    Lambda(lambda x: x / 255.0),                          # now x is float [0,1]
    Normalize(mean, std),                                 # per‑channel norm
    Lambda(lambda x: F.interpolate(
        x, size=resize_to, mode="bilinear", align_corners=False
    )),                                                          # resize to (H,W) = resize_to
])


#### Apply Transformation Pipeline to Dataset

In [None]:
def preprocess_train(batch):
    """
    Batched preprocessing callback for the training split.

    For every path in ``batch["image_path"]`` the ``.npy`` clip is loaded
    (expected layout (T, H, W, C), e.g. (16, 300, 300, 3) for the hand-focused
    clips), converted to a float32 (T, C, H, W) tensor and passed through
    ``transformation_pipeline``. No random augmentation is applied here, so
    this is currently the same processing as for validation.

    Adds ``pixel_values`` (list of tensors) and ``labels`` (same values as
    ``label``) to the batch dict and returns it.
    """
    def _clip_to_pixels(npy_path):
        frames = torch.as_tensor(np.load(npy_path), dtype=torch.float32)
        frames = einops.rearrange(frames, 't h w c -> t c h w')  # channels first
        return transformation_pipeline(frames)

    batch["pixel_values"] = [_clip_to_pixels(p) for p in batch["image_path"]]
    batch["labels"] = batch["label"]
    return batch

def preprocess_val(batch):
    """
    Batched preprocessing callback for the validation split.

    Each ``.npy`` clip referenced by ``batch["image_path"]`` (layout
    (T, H, W, C)) is loaded, turned into a float32 (T, C, H, W) tensor and run
    through ``transformation_pipeline``.

    The batch dict gains ``pixel_values`` (one tensor per clip) and
    ``labels`` (copied from ``label``) and is returned.
    """
    def _clip_to_pixels(npy_path):
        raw = np.load(npy_path)
        as_tensor = torch.as_tensor(raw, dtype=torch.float32)
        channels_first = einops.rearrange(as_tensor, 't h w c -> t c h w')
        return transformation_pipeline(channels_first)

    batch["pixel_values"] = list(map(_clip_to_pixels, batch["image_path"]))
    batch["labels"] = batch["label"]
    return batch

# Apply preprocessing to training dataset.
# map() runs once here, so pixel tensors are computed up front and stored in
# the resulting dataset. batch_size=4 is the map() chunk size, not the
# training batch size. Because `image_path`/`label` are removed, this cell
# cannot be re-run on the already-mapped train_ds/val_ds.
train_ds = train_ds.map(
    preprocess_train,
    batched=True,
    batch_size=4,
    remove_columns=["image_path", "label"]
)
# Return pixel_values/labels as torch tensors when indexed.
train_ds.set_format(type="torch", columns=["pixel_values", "labels"])

# Apply preprocessing to validation dataset (same treatment)
val_ds = val_ds.map(
    preprocess_val,
    batched=True,
    batch_size=4,
    remove_columns=["image_path", "label"]
)
val_ds.set_format(type="torch", columns=["pixel_values", "labels"])

## 4. Model Fine-Tuning

### 4.1. Set Up `TrainingArguments`, `Trainer`, and `evaluate`

In [None]:
from transformers import Trainer, TrainerCallback, TrainingArguments
import evaluate
from datetime import datetime
import pickle
import copy

# Experiment naming / output location.
model_name = model_ckpt.split("/")[-1]
# NOTE(review): the "-finetuned-without-aug" suffix is fixed here regardless
# of which dataset cell (non-augmented or augmented) produced train_ds — confirm.
new_model_name = f"{model_name}-finetuned-without-aug"
num_train_epochs = 50
EXPERIMENT_DATE = datetime.now().strftime("%Y%m%d")
SAVE_DIR = f"experiments/video/{EXPERIMENT_DATE}/{new_model_name}"
batch_size = 8

# Evaluate after every epoch, save a checkpoint whenever validation accuracy
# improves (save_strategy="best" needs a recent transformers release — confirm),
# and reload the best checkpoint when training ends.
args_pr = TrainingArguments(
    output_dir=SAVE_DIR,
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="best",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_train_epochs,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

# Metric objects used by compute_metrics.
accuracy  = evaluate.load("accuracy")
precision = evaluate.load("precision")
recall    = evaluate.load("recall")
f1        = evaluate.load("f1")
confusion = evaluate.load("confusion_matrix")


def compute_metrics(eval_pred):
    """
    Metric function for ``Trainer(compute_metrics=...)``.

    Args:
        eval_pred: ``(logits, labels)`` pair; the last axis of ``logits`` is
            the class axis and ``labels`` holds integer class ids.

    Returns:
        dict with ``accuracy``, macro-averaged ``precision``, ``recall`` and
        ``f1``, plus ``confusion_matrix`` as a nested list (JSON-serializable)
        of shape ``num_classes x num_classes``, where ``num_classes`` is the
        size of the logits' last axis.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # Pin the confusion-matrix axes to every class the model can output.
    # Without `labels`, a class missing from both predictions and references
    # in one evaluation is dropped, so the matrix shrinks and its rows/columns
    # no longer line up with id2label (or with matrices from other epochs).
    class_ids = list(range(np.shape(logits)[-1]))

    acc  = accuracy.compute(predictions=preds, references=labels)["accuracy"]
    prec = precision.compute(predictions=preds, references=labels, average="macro")["precision"]
    rec  = recall.compute(predictions=preds, references=labels, average="macro")["recall"]
    f1sc = f1.compute(predictions=preds, references=labels, average="macro")["f1"]

    cm = confusion.compute(
        predictions=preds, references=labels, labels=class_ids
    )["confusion_matrix"]

    return {
        "accuracy":         acc,
        "precision":        prec,
        "recall":           rec,
        "f1":               f1sc,
        "confusion_matrix": np.asarray(cm).tolist(),  # JSON-serializable
    }

class CustomTrainer(Trainer):
    """
    ``Trainer`` that also records per-batch training statistics.

    While the model is in training mode, every ``compute_loss`` call that has
    labels appends the batch loss to ``epoch_losses`` and the arg-max
    predictions / labels to ``epoch_preds`` / ``epoch_labels``. The lists are
    never cleared here; an epoch-end callback is expected to read and reset them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Accumulated over all *training* batches of the current epoch.
        self.epoch_losses = []
        self.epoch_preds = []
        self.epoch_labels = []

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Run the forward pass and return the loss, recording training batches.

        Args:
            model: Model being trained; called as ``model(**inputs)`` and its
                output must expose ``loss`` and ``logits``.
            inputs: Batch dict (``labels`` optional).
            return_outputs: If True, return ``(loss, outputs)``.
            **kwargs: Extra arguments newer ``Trainer`` versions pass
                (e.g. ``num_items_in_batch``); ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        labels = inputs.get("labels", None)
        outputs = model(**inputs)
        loss = outputs.loss
        logits = outputs.logits

        # Evaluation (eval_strategy="epoch") also goes through compute_loss,
        # with the model switched to eval mode. Recording those batches would
        # mix validation samples into the *training* loss/accuracy, so only
        # record while the model is training.
        if labels is not None and model.training:
            self.epoch_losses.append(loss.detach().item())
            self.epoch_preds.extend(logits.argmax(dim=-1).detach().cpu().tolist())
            self.epoch_labels.extend(labels.detach().cpu().tolist())

        return (loss, outputs) if return_outputs else loss

class MetricsCallback(TrainerCallback):
    """
    Collect per-epoch training curves and per-evaluation validation metrics.

    Training loss/accuracy are derived from the batch records the trainer
    accumulates (``epoch_losses``, ``epoch_preds``, ``epoch_labels``); this
    callback clears those lists at the end of every epoch. Validation loss,
    accuracy and confusion matrix are copied from the ``eval_*`` entries of
    the metrics passed to ``on_evaluate``.
    """

    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer

        # One entry per epoch (training) / per evaluation call (validation).
        self.train_losses = []
        self.train_accuracies = []
        self.eval_losses = []
        self.eval_accuracies = []
        self.eval_confusion_matrices = []

    def on_epoch_end(self, args, state, control, **kwargs):
        tr = self.trainer
        mean_loss = float(np.mean(tr.epoch_losses))
        hits = np.array(tr.epoch_preds) == np.array(tr.epoch_labels)

        self.train_losses.append(mean_loss)
        self.train_accuracies.append(np.mean(hits))

        # Reset the trainer's buffers so the next epoch starts empty.
        for buf in (tr.epoch_losses, tr.epoch_preds, tr.epoch_labels):
            buf.clear()

    def on_evaluate(self, args, state, control, metrics, **kwargs):
        # Trainer prefixes evaluation metrics with "eval_".
        for store, key in (
            (self.eval_losses, "eval_loss"),
            (self.eval_accuracies, "eval_accuracy"),
            (self.eval_confusion_matrices, "eval_confusion_matrix"),
        ):
            store.append(metrics[key])

# Train a deep copy so `sw_model` itself keeps its initial weights and this
# cell can be re-run from the same starting point.
trainer = CustomTrainer(
    model=copy.deepcopy(sw_model),
    args=args_pr,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    # NOTE(review): newer transformers releases deprecate `tokenizer=` in
    # favour of `processing_class=`; confirm against the installed version.
    tokenizer=sw_processor,
    compute_metrics=compute_metrics,
)

# The callback reads (and clears) the trainer's per-batch buffers at each
# epoch end and records the eval_* metrics after every evaluation.
metrics_cb = MetricsCallback(trainer)
trainer.add_callback(metrics_cb)

### 4.2. Training

In [None]:
# Run fine-tuning with the settings in args_pr; per-epoch curves are
# collected by metrics_cb.
train_results = trainer.train()

### 4.3. Save Training and Validation History

In [None]:
# Per-epoch curves gathered by MetricsCallback during training.
history_val_ds = {
    "train_loss":     metrics_cb.train_losses,
    "train_accuracy": metrics_cb.train_accuracies,
    "eval_loss":      metrics_cb.eval_losses,
    "eval_accuracy":  metrics_cb.eval_accuracies,
    "eval_confusion_matrix": metrics_cb.eval_confusion_matrices,
}
# Trainer presumably creates SAVE_DIR when it writes a checkpoint, but do not
# rely on that (e.g. if training was interrupted before the first save).
os.makedirs(SAVE_DIR, exist_ok=True)
history_train_val_ds_path = os.path.join(SAVE_DIR, "history_trainval.pkl")
with open(history_train_val_ds_path, "wb") as f:
    pickle.dump(history_val_ds, f)

### 4.4. Plot Training/Validation Loss, Accuracy, and Confusion Matrix

In [None]:
def visualize_train_val_losses(save_dir: str) -> None:
    """
    Plot loss and accuracy curves for every ``*.pkl`` history in *save_dir*
    whose file name contains ``"trainval"``.

    Each history must be a dict with ``train_loss``, ``eval_loss``,
    ``train_accuracy`` and ``eval_accuracy`` sequences. One figure with two
    stacked panels (loss on top, accuracy below) is shown per file.

    Side effect: sets the global Matplotlib font family and size.
    """
    import glob
    import os
    import pickle

    import matplotlib.pyplot as plt

    # Global font settings (affect all later figures as well).
    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 15

    candidates = glob.glob(os.path.join(save_dir, "*.pkl"))
    trainval_files = [p for p in candidates if "trainval" in os.path.basename(p)]

    for pkl_path in trainval_files:
        print(f"Visualizing: {os.path.basename(pkl_path)}")
        # NOTE: pickle can execute code on load; only open histories you wrote.
        with open(pkl_path, "rb") as fh:
            history = pickle.load(fh)

        fig, (loss_ax, acc_ax) = plt.subplots(2, 1, figsize=(15, 15))

        loss_ax.plot(history["train_loss"], label="Train Loss", color='red')
        loss_ax.plot(history["eval_loss"], label="Validation Loss", color='blue')
        loss_ax.set_title('Posture Recognition Training Loss', fontweight='bold')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()

        acc_ax.plot(history["train_accuracy"], label="Train Accuracy", color='orange')
        acc_ax.plot(history["eval_accuracy"], label="Validation Accuracy", color='green')
        acc_ax.set_title('Posture Recognition Training Accuracy', fontweight='bold')
        acc_ax.set_xlabel('Epoch')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.legend()

        plt.tight_layout()
        plt.show()

def visualize_confusion_matrix(save_dir: str, suffix: str, mode: str = "trainval",
                               class_names=None) -> None:
    """
    Plot, save and show a confusion matrix from a pickled training history.

    Args:
        save_dir: Directory searched for ``*.pkl`` history files. If exactly
            one exists it is used; otherwise the first whose file name
            contains ``mode`` (case-insensitive).
        suffix: Appended to the title when the history holds per-epoch matrices.
        mode: File-name filter and tag of the saved image
            (``confusion_matrix_{mode}.png`` inside ``save_dir``).
        class_names: Display names indexed by class id. Defaults to the five
            postures in the order of ``ordered_labels`` (OH, IH, SF, HC, HH).
            If the count does not match the matrix size, integer ids are shown.

    The history's ``eval_confusion_matrix`` may be a list of per-epoch 2-D
    matrices (the last one is plotted) or a single 2-D matrix. Problems are
    reported with ``[ERROR]`` prints and the function returns without plotting.
    """
    import glob
    import os
    import pickle

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import ConfusionMatrixDisplay

    history_files = glob.glob(os.path.join(save_dir, "*.pkl"))

    # A lone history file is used regardless of its name.
    if len(history_files) == 1:
        matched_file = history_files[0]
    else:
        matched_file = next(
            (p for p in history_files if mode.lower() in os.path.basename(p).lower()),
            None,
        )

    if not matched_file:
        print(f"[ERROR] No history file found in '{save_dir}' for mode '{mode}'")
        return

    print(f"[INFO] Loading confusion matrix from: {matched_file}")
    # NOTE: pickle can execute code on load; only open histories you wrote.
    with open(matched_file, "rb") as f:
        hist = pickle.load(f)

    if "eval_confusion_matrix" not in hist:
        print(f"[ERROR] No confusion matrix data found in history file")
        return

    raw = hist["eval_confusion_matrix"]
    # Decide by dimensionality, not container type: a single matrix stored as
    # a nested list is also a `list`, and raw[-1] would then be its last row.
    if isinstance(raw, (list, tuple)) and raw and np.ndim(raw[0]) == 2:
        cm = np.asarray(raw[-1])
        title = f'Confusion Matrix (Last Epoch - Train/Val) - {suffix}'
    else:
        cm = np.asarray(raw)
        title = f"Confusion Matrix ({mode.capitalize()})"

    if cm.ndim != 2 or cm.size == 0:
        print(f"[ERROR] Unexpected confusion matrix shape: {cm.shape}")
        return

    if class_names is None:
        # Must follow the label-id order used in training (ordered_labels:
        # OH, IH, SF, HC, HH) — Hand Close is id 3, Hook Hand is id 4.
        class_names = [
            "Hand Open",
            "Intrinsic Plus",
            "Straight Fist",
            "Hand Close",
            "Hook Hand",
        ]
    if len(class_names) == cm.shape[0]:
        display_labels = list(class_names)
    else:
        print(f"[WARN] {len(class_names)} class names for a {cm.shape[0]}x{cm.shape[1]} "
              f"matrix; showing class ids instead")
        display_labels = None

    _, ax = plt.subplots(figsize=(8, 8))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
    disp.plot(cmap=plt.cm.Blues, ax=ax, colorbar=False)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    ax.set_title(title, fontweight="bold")
    plt.tight_layout()

    # Save the figure before showing it (show() may clear the canvas).
    output_path = os.path.join(save_dir, f"confusion_matrix_{mode}.png")
    plt.savefig(output_path)
    print(f"[INFO] Confusion matrix saved to: {output_path}")

    plt.show()

# NOTE(review): the two calls read different experiment directories
# (20250516-aug-3 for the curves, 20250515 ...-without-aug-hand-focused for
# the confusion matrix) — confirm this pairing is intended.
visualize_train_val_losses(r'D:\RESEARCH ASSISTANT\6. Depth Camera\CODE\Orbbec Gemini 2XL\REMOTE\DEVELOPMENT\notebook\experiments\video\20250516\20250516-aug-3')
visualize_confusion_matrix(r'D:\RESEARCH ASSISTANT\6. Depth Camera\CODE\Orbbec Gemini 2XL\REMOTE\DEVELOPMENT\notebook\experiments\video\20250515\videomae-base-finetuned-without-aug-hand-focused', mode="trainval", suffix='Original')

## 5. Model Prediction

### 5.1. Video Classification

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from transformers import VideoMAEForVideoClassification, AutoImageProcessor
from torchvision.transforms import Compose, Lambda
import torchvision.transforms.functional as tvf
from typing import Tuple, List

class Normalize(nn.Module):
    """
    Per-channel mean/std normalization for a clip laid out as (T, C, H, W).

    The time axis plays the role of the batch axis, so every frame is
    normalized with the same channel statistics. This is a thin ``nn.Module``
    wrapper around ``torchvision.transforms.functional.normalize`` so it can
    be placed inside a ``Compose`` pipeline.
    """

    def __init__(self,
                 mean: Tuple[float, float, float],
                 std: Tuple[float, float, float],
                 inplace: bool = False):
        super().__init__()
        self.mean, self.std, self.inplace = mean, std, inplace

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # tvf.normalize accepts (B, C, H, W); T stands in for B here.
        return tvf.normalize(x, mean=self.mean, std=self.std, inplace=self.inplace)

def sliding_window_sample(images, window_size=16, stride=8):
    """
    Create sliding windows from a sequence of frames.

    Args:
        images: List of frames or numpy array of shape [T, H, W, C].
        window_size: Number of frames in each window (must be >= 1).
        stride: Step size between window starts (must be >= 1).

    Returns:
        List of C-contiguous numpy arrays, each with shape
        [window_size, H, W, C]:

        * windows start at 0, stride, 2*stride, ... as long as they fit;
        * if the last such window does not end on the final frame, one more
          window covering the last ``window_size`` frames is appended, so the
          tail of the sequence is always included;
        * a sequence shorter than ``window_size`` yields one window padded by
          repeating its last frame;
        * an empty sequence yields an empty list.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is smaller than 1.
    """
    if window_size < 1 or stride < 1:
        raise ValueError(
            f"window_size and stride must be >= 1, got {window_size} and {stride}"
        )

    total_frames = len(images)
    if total_frames == 0:
        return []

    if total_frames < window_size:
        # Duplicate the last frame to reach window_size.
        padded = list(images) + [images[-1]] * (window_size - total_frames)
        return [np.stack(padded)]

    last_start = total_frames - window_size
    starts = list(range(0, last_start + 1, stride))
    # Tail window. starts[-1] is the largest multiple of stride <= last_start,
    # so it differs from last_start exactly when frames would be left over.
    if starts[-1] != last_start:
        starts.append(last_start)

    # Contiguous copies so downstream torch.from_numpy gets positive strides.
    return [np.ascontiguousarray(np.stack(images[s:s + window_size])) for s in starts]

def load_model(model_path, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Load a fine-tuned VideoMAE classifier together with the base image processor.

    Args:
        model_path: Directory (or hub id) of the fine-tuned checkpoint.
        device: Torch device the model is moved to. The default is chosen once,
            when this function is defined.

    Returns:
        ``(model, processor, id2label)``: the classifier in eval mode on
        ``device``, the ``MCG-NJU/videomae-base`` processor, and the
        checkpoint's ``config.id2label`` mapping.
    """
    print(f"Loading model from {model_path}")

    # Preprocessing settings always come from the base checkpoint, not model_path.
    processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base", use_fast=True)

    classifier = VideoMAEForVideoClassification.from_pretrained(model_path)
    classifier.to(device)
    classifier.eval()

    return classifier, processor, classifier.config.id2label

def prepare_transforms(processor, model):
    """
    Build the per-window preprocessing pipeline used at inference time.

    Steps: scale to [0, 1], normalize per channel with the processor's
    mean/std, then bilinearly resize every frame.

    Args:
        processor: Image processor providing ``image_mean``, ``image_std`` and
            a dict-like ``size`` with either ``"shortest_edge"`` (square
            output) or ``"height"`` (and optionally ``"width"``).
        model: VideoMAE model; kept for interface compatibility, not read.

    Returns:
        transform: ``Compose`` mapping a float (T, C, H, W) tensor to a
        normalized, resized (T, C, H', W') tensor.

    Raises:
        ValueError: if ``processor.size`` has neither ``"shortest_edge"`` nor
            ``"height"``.
    """
    mean, std = processor.image_mean, processor.image_std

    size = processor.size
    shortest_edge = size.get("shortest_edge")
    if shortest_edge is not None:
        resize_to = (shortest_edge, shortest_edge)
    else:
        target_h = size.get("height")
        if target_h is None:
            raise ValueError(
                f"processor.size has neither 'shortest_edge' nor 'height': {size}"
            )
        target_w = size.get("width")
        # Honour an explicit width instead of silently forcing a square.
        resize_to = (target_h, target_w if target_w is not None else target_h)

    transform = Compose([
        Lambda(lambda x: x / 255.0),
        Normalize(mean, std),
        Lambda(lambda x: F.interpolate(
            x, size=resize_to, mode="bilinear", align_corners=False
        )),
    ])

    return transform

def process_window(window, transform, device, bgr_to_rgb=True):
    """
    Process a window of frames for model input.

    Args:
        window: Numpy array of shape [T, H, W, C]. Only a last axis of size 3
            is treated as channels-last (optionally reordered, then moved to
            axis 1); any other array is passed through unchanged.
        transform: Callable applied to the float32 [T, C, H, W] tensor.
        device: Device the returned tensor is moved to.
        bgr_to_rgb: Reverse the channel order of 3-channel input. Defaults to
            True because frames decoded by OpenCV are BGR; pass False for
            arrays that are already RGB.

    Returns:
        ``transform`` output with a leading batch axis of size 1, on
        ``device`` (i.e. [1, T, C, H', W'] for a layout-preserving transform).
    """
    if window.shape[-1] == 3:  # channels-last frames
        if bgr_to_rgb:
            window = window[..., ::-1]
        window = np.transpose(window, (0, 3, 1, 2))  # -> [T, C, H, W]

    # One fresh C-ordered copy: torch.from_numpy rejects the negative strides
    # produced by the channel flip, and the tensor must not alias the input.
    window = np.array(window, order="C")

    try:
        window_tensor = torch.from_numpy(window).float()
    except (TypeError, RuntimeError, ValueError):
        # dtypes torch.from_numpy cannot wrap: slower element-wise fallback.
        window_tensor = torch.tensor(window.tolist(), dtype=torch.float32)

    window_tensor = transform(window_tensor)

    # Add batch dimension
    return window_tensor.unsqueeze(0).to(device)

def predict_windows(windows, model, transform, device):
    """
    Classify each window independently, without gradient tracking.

    Args:
        windows: Sequence of numpy arrays, each with shape [T, H, W, C].
        model: VideoMAE classifier called as ``model(pixel_values=...)``.
        transform: Transform pipeline forwarded to ``process_window``.
        device: Device to run inference on.

    Returns:
        List of ``(class_id, confidence)`` tuples in window order, where
        ``class_id`` is the arg-max class index and ``confidence`` its
        softmax probability, both as Python scalars.
    """
    def _classify(window):
        batch = process_window(window, transform, device)
        logits = model(pixel_values=batch).logits
        probabilities = torch.softmax(logits, dim=-1)
        best = torch.argmax(logits, dim=-1).item()
        return best, probabilities[0, best].item()

    with torch.no_grad():
        return [_classify(w) for w in windows]

def predict_video(video_path, model_path, output_path=None, window_size=16, stride=8):
    """
    Predict the hand posture over a video file with the fine-tuned VideoMAE model.

    The video is split with ``sliding_window_sample``, each window is
    classified, and its prediction is assigned to every frame it covers
    (later windows overwrite earlier ones on overlapping frames). If
    ``output_path`` is given, every frame is written there with a black banner
    showing that frame's prediction.

    Args:
        video_path: Path to input video file.
        model_path: Path to the fine-tuned model directory.
        output_path: Path of the annotated output video (``mp4v``), or None to
            skip writing. Nothing is displayed on screen either way.
        window_size: Number of frames in each window.
        stride: Step size between window starts.

    Returns:
        dict with ``predictions`` (one dict per window with ``timestamp`` in
        seconds — None when the container reports no FPS — ``class_id``,
        ``class_name``, ``confidence``, ``start_frame`` and exclusive
        ``end_frame``), ``id2label``, ``num_frames`` and ``fps``.

    Raises:
        IOError: if the video cannot be opened.
        ValueError: if no frame can be read from it.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, processor, id2label = load_model(model_path, device)
    transform = prepare_transforms(processor, model)

    # Read all frames; release the capture even if reading raises.
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise IOError(f"Could not open video: {video_path}")
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        cap.release()
    print(f"Video loaded: {len(frames)} frames")
    if not frames:
        raise ValueError(f"No frames could be read from {video_path}")

    windows = sliding_window_sample(frames, window_size, stride)
    print(f"Created {len(windows)} windows")

    results = predict_windows(windows, model, transform, device)

    num_frames = len(frames)
    # Regular windows start at i * stride, which never exceeds last_start;
    # clamping keeps the mapping right for a final window anchored at the
    # end of the sequence.
    last_start = max(num_frames - window_size, 0)
    predictions = []
    frame_predictions = [None] * num_frames
    for i, (class_id, confidence) in enumerate(results):
        start_frame = min(i * stride, last_start)
        end_frame = min(start_frame + window_size, num_frames)

        prediction = {
            'timestamp': start_frame / fps if fps > 0 else None,
            'class_id': class_id,
            'class_name': id2label[class_id],
            'confidence': confidence,
            'start_frame': start_frame,
            'end_frame': end_frame
        }
        predictions.append(prediction)

        # Assign this prediction to all frames in the window
        frame_predictions[start_frame:end_frame] = [prediction] * (end_frame - start_frame)

    if output_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            for frame, pred in zip(frames, frame_predictions):
                # Black banner on top with the prediction text.
                overlay = frame.copy()
                cv2.rectangle(overlay, (0, 0), (width, 60), (0, 0, 0), -1)
                if pred:
                    text = f"{pred['class_name']} ({pred['confidence']:.2f})"
                else:
                    text = "Processing..."
                cv2.putText(overlay, text, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                out.write(overlay)
        finally:
            out.release()
        print(f"Output video saved to {output_path}")

    return {
        'predictions': predictions,
        'id2label': id2label,
        'num_frames': num_frames,
        'fps': fps
    }

def predict_from_array(video_array, model_path, window_size=16, stride=8):
    """
    Predict the hand posture from a numpy array of video frames.

    Args:
        video_array: Numpy array of shape [T, H, W, C].
        model_path: Path to the fine-tuned model directory.
        window_size: Number of frames in each window.
        stride: Step size between window starts.

    Returns:
        dict with ``predictions`` (one dict per window: ``window_idx``,
        ``class_id``, ``class_name``, ``confidence``, ``start_frame`` and
        exclusive ``end_frame``), ``id2label``, and ``processed_frames`` (one
        copy of every input frame with a black prediction banner drawn on top).

    Raises:
        ValueError: if ``video_array`` contains no frames (checked before the
            model is loaded).
    """
    # Ensure contiguous array to avoid stride issues
    video_array = np.ascontiguousarray(video_array)
    total_frames = len(video_array)
    if total_frames == 0:
        raise ValueError("video_array contains no frames")

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, processor, id2label = load_model(model_path, device)
    transform = prepare_transforms(processor, model)

    frame_width = video_array[0].shape[1]

    windows = sliding_window_sample(video_array, window_size, stride)
    print(f"Created {len(windows)} windows from {total_frames} frames")

    results = predict_windows(windows, model, transform, device)

    # Regular windows start at i * stride, which never exceeds last_start;
    # clamping keeps the mapping right for a final window anchored at the
    # end of the sequence.
    last_start = max(total_frames - window_size, 0)
    predictions = []
    frame_predictions = [None] * total_frames

    for i, (class_id, confidence) in enumerate(results):
        start_frame = min(i * stride, last_start)
        end_frame = min(start_frame + window_size, total_frames)

        prediction = {
            'window_idx': i,
            'class_id': class_id,
            'class_name': id2label[class_id],
            'confidence': confidence,
            'start_frame': start_frame,
            'end_frame': end_frame
        }
        predictions.append(prediction)

        # Assign this prediction to all frames in the window
        frame_predictions[start_frame:end_frame] = [prediction] * (end_frame - start_frame)

    processed_frames = []
    for frame, pred in zip(video_array, frame_predictions):
        overlay = frame.copy()
        # Black banner on top with the prediction text.
        cv2.rectangle(overlay, (0, 0), (frame_width, 60), (0, 0, 0), -1)
        if pred:
            text = f"{pred['class_name']} ({pred['confidence']:.2f})"
        else:
            text = "No prediction"
        cv2.putText(overlay, text, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        processed_frames.append(overlay)

    return {
        'predictions': predictions,
        'id2label': id2label,
        'processed_frames': processed_frames
    }

def main():
    """
    Demo: classify a saved hand-crop sequence and save the annotated frames.

    Loads ``./cropped_hands_sequence.npy``, runs ``predict_from_array`` with
    non-overlapping 16-frame windows, writes every annotated frame to
    ``./cropped_hands_sequence_output/frame_XXXX.png`` and prints one line
    per window prediction.

    Returns:
        List of annotated frames, or None if any step raised (the error and
        its traceback are printed instead of being propagated).
    """
    checkpoint = r'D:\RESEARCH ASSISTANT\6. Depth Camera\CODE\Orbbec Gemini 2XL\REMOTE\DEVELOPMENT\notebook\experiments\video\20250516\20250516-aug-6\checkpoint-240'
    try:
        print("Loading numpy array...")
        video_array = np.load('./cropped_hands_sequence.npy')
        print(f"Loaded array with shape: {video_array.shape}")

        if not video_array.flags.c_contiguous:
            print("Converting to C-contiguous array...")
            video_array = np.ascontiguousarray(video_array)

        result = predict_from_array(video_array, checkpoint, window_size=16, stride=16)

        vis = []
        if 'processed_frames' in result:
            output_dir = './cropped_hands_sequence_output'
            os.makedirs(output_dir, exist_ok=True)

            annotated = result['processed_frames']
            for idx, frame in enumerate(annotated):
                cv2.imwrite(f"{output_dir}/frame_{idx:04d}.png", frame)
                vis.append(frame)

            print(f"Saved {len(annotated)} processed frames to {output_dir}")

        segments = result['predictions']
        print(f"Found {len(segments)} activity segments")
        for seg_no, pred in enumerate(segments, start=1):
            print(f"Segment {seg_no}: {pred['class_name']} (confidence: {pred['confidence']:.2f})")

        return vis

    except Exception as e:
        # Top-level demo boundary: report and fall through (returns None).
        print(f"Error occurred: {str(e)}")
        import traceback
        traceback.print_exc()

# Script entry point. In a Jupyter kernel __name__ is normally "__main__" as
# well, so the demo also runs when this cell is executed; `vis` holds the
# annotated frames, or None if main() reported an error.
if __name__ == "__main__":
    vis = main()