# VitPose: Vision Transformer for Human Pose Estimation

This notebook demonstrates how to use VitPose (Vision Transformer for Pose Estimation) following the official Hugging Face documentation. VitPose is a **top-down** pose estimation model that requires two stages:

1. **Object Detection**: First detect people in the image using RT-DETR
2. **Pose Estimation**: Extract keypoints from detected person regions using VitPose

## Key Features:
- State-of-the-art pose estimation using Vision Transformers
- Support for multiple model sizes (base, large, huge)
- COCO keypoint format with 17 body landmarks
- Batch processing and memory optimization

In [None]:
# Install Required Dependencies
# Each line is an IPython shell escape (`!`) that runs pip inside the notebook environment.
!pip install transformers torch torchvision
!pip install Pillow matplotlib requests numpy
!pip install supervision  # Optional: for advanced visualization (not imported by the cells in this notebook)
!pip install opencv-python  # Required: provides cv2, whose line/circle drawing calls are used by the visualization cell

In [None]:
import torch
import requests
import numpy as np
import matplotlib.pyplot as plt
import json
from PIL import Image, ImageDraw
from typing import List, Dict, Tuple, Any

# Import Transformers components for VitPose and RT-DETR
from transformers import (
    AutoProcessor, 
    RTDetrForObjectDetection, 
    VitPoseForPoseEstimation
)

# Select the compute device: use CUDA whenever PyTorch reports a usable GPU,
# otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Seed NumPy and PyTorch with the same fixed value so repeated runs are reproducible
np.random.seed(42)
torch.manual_seed(42)

In [None]:
# Load the VitPose checkpoint together with its matching processor
# Checkpoints listed for this notebook: usyd-community/vitpose-base-simple, vitpose-large-simple, vitpose-huge-simple

model_name = "usyd-community/vitpose-base-simple"
print(f"Loading VitPose model: {model_name}")

# The processor is used later for pre- and post-processing; the model weights
# are placed on the device chosen earlier
vitpose_processor = AutoProcessor.from_pretrained(model_name)
vitpose_model = VitPoseForPoseEstimation.from_pretrained(model_name, device_map=device)

print("✓ VitPose model loaded successfully!")
vitpose_config = vitpose_model.config
print(f"Model config - Hidden size: {vitpose_config.backbone_config.hidden_size}")
# The keypoint count is reported only when the config exposes that attribute
if hasattr(vitpose_config, 'num_keypoints'):
    print(f"Number of keypoints: {vitpose_config.num_keypoints}")

# ViTPose++ checkpoints (Mixture of Experts) additionally accept a dataset_index
# that selects the expert head. Index -> dataset:
#   0: COCO validation 2017 dataset
#   1: AiC dataset
#   2: MPII dataset
#   3: AP-10K dataset
#   4: APT-36K dataset
#   5: COCO-WholeBody dataset

In [None]:
# Load and Preprocess Input Image
# Using the same image as in the official documentation

url = "http://images.cocodataset.org/val2017/000000000139.jpg"
print(f"Loading image from: {url}")

# Download the image. The timeout keeps the cell from hanging on a stalled
# connection, and raise_for_status() surfaces HTTP errors (404, 500, ...)
# instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()

# convert("RGB") forces the pixel data to be read now and guarantees a
# 3-channel image even if the source file is grayscale, palette or RGBA.
image = Image.open(response.raw).convert("RGB")
print(f"✓ Image loaded successfully! Size: {image.size}")

# Display the original image
plt.figure(figsize=(10, 8))
plt.imshow(image)
plt.title("Original Image")
plt.axis('off')
plt.show()

# Alternative: Load local image
# image = Image.open("path/to/your/image.jpg").convert("RGB")

In [None]:
# Stage 1: Person Detection using RT-DETR
# VitPose works top-down, so a bounding box is needed for every person first

print("Stage 1: Detecting people in the image...")

# Detector checkpoint and its matching processor
person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
person_processor = AutoProcessor.from_pretrained(person_detector_name)
person_model = RTDetrForObjectDetection.from_pretrained(person_detector_name, device_map=device)

print(f"✓ RT-DETR model loaded: {person_detector_name}")

# Preprocess the image and move the tensors to the compute device
detection_inputs = person_processor(images=image, return_tensors="pt").to(device)

# Inference only, so gradients are not tracked
with torch.no_grad():
    detection_outputs = person_model(**detection_inputs)

# Post-process against the original (height, width) and drop detections
# scoring below 0.3
original_size = torch.tensor([(image.height, image.width)])
detection_results = person_processor.post_process_object_detection(
    detection_outputs,
    target_sizes=original_size,
    threshold=0.3
)
result = detection_results[0]  # single input image -> single result entry

# Keep only detections labelled 0 (person in the COCO dataset)
person_mask = result["labels"] == 0
person_boxes = result["boxes"][person_mask]
person_scores = result["scores"][person_mask]

print(f"✓ Detected {len(person_boxes)} person(s) in the image")

# VOC (x1, y1, x2, y2) -> COCO (x, y, w, h): subtract the top-left corner
# from the bottom-right corner in place to obtain width and height
person_boxes_coco = person_boxes.cpu().numpy().copy()
person_boxes_coco[:, 2:] -= person_boxes_coco[:, :2]

print(f"Person boxes (COCO format): {person_boxes_coco}")
print(f"Detection scores: {person_scores.cpu().numpy()}")

In [None]:
# Stage 2: Pose Estimation using VitPose
# Run every detected person box through VitPose

if len(person_boxes_coco) > 0:
    print(f"Stage 2: Running pose estimation on {len(person_boxes_coco)} detected person(s)...")

    # The processor needs the boxes argument: one list of boxes per image
    pose_inputs = vitpose_processor(
        image,
        boxes=[person_boxes_coco],
        return_tensors="pt"
    ).to(device)
    print(f"Pose input shape: {pose_inputs.pixel_values.shape}")

    # Inference only, so gradients are not tracked
    with torch.no_grad():
        pose_outputs = vitpose_model(**pose_inputs)
    print(f"Raw heatmaps shape: {pose_outputs.heatmaps.shape}")

    # Turn the heatmaps into per-person keypoints, using the same boxes
    pose_results = vitpose_processor.post_process_pose_estimation(
        pose_outputs,
        boxes=[person_boxes_coco]
    )

    # Only one image was passed in, so take its entry
    image_pose_results = pose_results[0]

    print("✓ Pose estimation completed!")
    print(f"✓ Detected poses for {len(image_pose_results)} person(s)")

    # Per-person summary of keypoint shapes and confidences
    for person_number, pose_result in enumerate(image_pose_results, start=1):
        keypoints, scores = pose_result['keypoints'], pose_result['scores']
        confident_count = (scores > 0.5).sum()
        print(f"\nPerson {person_number}:")
        print(f"  Keypoints shape: {keypoints.shape}")
        print(f"  Scores shape: {scores.shape}")
        print(f"  Average confidence: {scores.mean():.3f}")
        print(f"  High confidence keypoints: {confident_count}/{len(scores)}")
else:
    print("❌ No people detected. Cannot proceed with pose estimation.")

In [None]:
# Visualize Keypoints and Skeleton
# Following the official documentation visualization approach

# cv2 is imported here because this cell calls cv2.line / cv2.circle; relying
# on an import in a later cell raises NameError when cells run top to bottom.
import cv2

if len(person_boxes_coco) > 0:
    # COCO keypoint names for reference (also used when exporting results)
    COCO_KEYPOINT_NAMES = [
        'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
        'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
        'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
        'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
    ]

    # Color palette from the official documentation
    palette = np.array([
        [255, 128, 0], [255, 153, 51], [255, 178, 102], [230, 230, 0],
        [255, 153, 255], [153, 204, 255], [255, 102, 255], [255, 51, 255],
        [102, 178, 255], [51, 153, 255], [255, 153, 153], [255, 102, 102],
        [255, 51, 51], [153, 255, 153], [102, 255, 102], [51, 255, 51],
        [0, 255, 0], [0, 0, 255], [255, 0, 0], [255, 255, 255]
    ])

    # One palette entry per skeleton link / per keypoint (official docs)
    link_colors = palette[[0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16]]
    keypoint_colors = palette[[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]]

    # Pairs of keypoint indices that form the skeleton, taken from the model config
    keypoint_edges = vitpose_model.config.edges

    # Create visualization
    plt.figure(figsize=(15, 10))

    # Left panel: original image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Original Image")
    plt.axis('off')

    # Right panel: image with pose overlay, drawn on a NumPy copy so the
    # original PIL image stays untouched
    plt.subplot(1, 2, 2)
    numpy_image = np.array(image.copy())
    image_height, image_width = numpy_image.shape[0], numpy_image.shape[1]

    # Draw poses for each detected person
    for pose_result in image_pose_results:
        keypoints = np.array(pose_result["keypoints"])
        scores = np.array(pose_result["scores"])

        # Draw keypoint connections (skeleton)
        if keypoint_edges is not None:
            for sk_id, (joint1, joint2) in enumerate(keypoint_edges):
                # Skip edges that point past the keypoints actually returned
                if joint1 < len(keypoints) and joint2 < len(keypoints):
                    x1, y1, score1 = keypoints[joint1][0], keypoints[joint1][1], scores[joint1]
                    x2, y2, score2 = keypoints[joint2][0], keypoints[joint2][1], scores[joint2]

                    # Only draw if both keypoints are confident enough
                    if score1 > 0.3 and score2 > 0.3:
                        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
                        # Both endpoints must lie inside the image
                        if (0 <= x1 < image_width and 0 <= y1 < image_height and
                            0 <= x2 < image_width and 0 <= y2 < image_height):

                            # Modulo keeps the lookup valid if the model has
                            # more edges than the palette selection has entries
                            color = tuple(int(c) for c in link_colors[sk_id % len(link_colors)])
                            cv2.line(numpy_image, (x1, y1), (x2, y2), color, thickness=2)

        # Draw keypoints
        for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
            x_coord, y_coord = int(kpt[0]), int(kpt[1])
            if kpt_score > 0.3:  # Only draw confident keypoints
                if (0 <= x_coord < image_width and 0 <= y_coord < image_height):
                    color = tuple(int(c) for c in keypoint_colors[kid % len(keypoint_colors)])
                    # Filled dot in the keypoint color ...
                    cv2.circle(numpy_image, (x_coord, y_coord), 4, color, -1)
                    # ... with a 1-pixel white border so it stands out
                    cv2.circle(numpy_image, (x_coord, y_coord), 4, (255, 255, 255), 1)

    plt.imshow(numpy_image)
    plt.title(f"Detected Poses ({len(image_pose_results)} person(s))")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    print("✓ Pose visualization completed!")

In [None]:
# Import cv2 for drawing functions (needed for visualization)
import cv2

In [None]:
# Save Results to JSON Format
# Writes image, detection and per-person keypoint data to vitpose_results.json

if len(person_boxes_coco) > 0:
    # Top-level document: image metadata, detection summary, then one entry per person
    results_data = {
        "image_info": {"width": image.width, "height": image.height, "url": url},
        "detection_info": {
            "num_persons_detected": len(person_boxes_coco),
            "detection_threshold": 0.3,
            "person_boxes": person_boxes_coco.tolist(),
            "detection_scores": person_scores.cpu().numpy().tolist(),
        },
        "pose_results": [],
    }

    for i, pose_result in enumerate(image_pose_results):
        keypoints = pose_result['keypoints'].numpy()
        scores = pose_result['scores'].numpy()

        # One record per keypoint; indices beyond the COCO name list get a
        # generic "keypoint_<n>" name
        keypoint_entries = [
            {
                "id": j,
                "name": COCO_KEYPOINT_NAMES[j] if j < len(COCO_KEYPOINT_NAMES) else f"keypoint_{j}",
                "x": float(kpt[0]),
                "y": float(kpt[1]),
                "confidence": float(score),
                "visible": bool(score > 0.3),
            }
            for j, (kpt, score) in enumerate(zip(keypoints, scores))
        ]

        # Plain Python floats/ints/lists only, so json can serialise the record
        results_data["pose_results"].append({
            "person_id": i,
            "bbox": person_boxes_coco[i].tolist(),
            "detection_score": float(person_scores[i].cpu().numpy()),
            "keypoints": keypoint_entries,
            "average_keypoint_confidence": float(scores.mean()),
            "high_confidence_keypoints": int((scores > 0.5).sum()),
        })

    # Write the document to disk
    output_file = "vitpose_results.json"
    with open(output_file, 'w') as f:
        json.dump(results_data, f, indent=2)

    total_keypoints = sum(len(pose['keypoints']) for pose in results_data['pose_results'])
    print(f"✓ Results saved to {output_file}")
    print(f"✓ Summary: {len(image_pose_results)} poses detected")
    print(f"✓ Total keypoints: {total_keypoints}")

    # Preview: first person's average confidence and first three keypoints
    print("\nSample JSON structure:")
    exported_poses = results_data["pose_results"]
    sample_person = exported_poses[0] if exported_poses else {}
    if sample_person:
        print(f"Person 1 - Avg confidence: {sample_person['average_keypoint_confidence']:.3f}")
        print("Sample keypoints (first 3):")
        for kpt in sample_person["keypoints"][:3]:
            print(f"  {kpt['name']}: ({kpt['x']:.1f}, {kpt['y']:.1f}) conf={kpt['confidence']:.3f}")
else:
    print("❌ No poses to save - no people were detected in the image")

In [None]:
# Batch Processing Multiple Images

def _load_rgb_image(source: str, timeout: float = 30.0):
    """Load an image from an HTTP(S) URL or a local path and return it as RGB.

    Args:
        source: URL (anything starting with ``http``) or local file path.
        timeout: Seconds to wait for the HTTP server before giving up.

    Returns:
        A fully loaded PIL image in RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        OSError: If the file cannot be opened or decoded.
    """
    if source.startswith('http'):
        # The context manager releases the connection; convert() reads all
        # pixel data before the response is closed.
        with requests.get(source, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            return Image.open(response.raw).convert("RGB")

    # Closing the lazily opened file is safe because convert() returns a
    # new, fully loaded image.
    with Image.open(source) as local_image:
        return local_image.convert("RGB")


def process_multiple_images(image_urls: List[str], confidence_threshold: float = 0.3) -> List[Dict]:
    """
    Run person detection and pose estimation on several images, one at a time.

    Relies on the notebook-level ``person_processor``, ``person_model``,
    ``vitpose_processor``, ``vitpose_model`` and ``device`` objects.

    Args:
        image_urls: List of image URLs or file paths
        confidence_threshold: Threshold for person detection

    Returns:
        One dict per input image, in input order. Every dict has
        ``image_index``, ``image_url``, ``persons_detected`` and ``poses``;
        successfully loaded images also have ``image_size``, and failed
        images carry an ``error`` message instead of aborting the batch.
    """
    all_results = []

    for i, img_url in enumerate(image_urls):
        print(f"\n--- Processing Image {i+1}/{len(image_urls)} ---")

        try:
            img = _load_rgb_image(img_url)
            print(f"Loaded image: {img.size}")

            # Stage 1: Person Detection
            detection_inputs = person_processor(images=img, return_tensors="pt").to(device)

            with torch.no_grad():
                detection_outputs = person_model(**detection_inputs)

            detection_results = person_processor.post_process_object_detection(
                detection_outputs,
                target_sizes=torch.tensor([(img.height, img.width)]),
                threshold=confidence_threshold,
            )

            # Keep only label 0 (person); the mask is computed once and reused
            result = detection_results[0]
            person_mask = result["labels"] == 0
            person_boxes = result["boxes"][person_mask]
            person_scores = result["scores"][person_mask]

            image_results = {
                "image_index": i,
                "image_url": img_url,
                "image_size": img.size,
                "persons_detected": len(person_boxes),
                "poses": [],
            }

            if len(person_boxes) == 0:
                print("No people detected in this image")
                all_results.append(image_results)
                continue

            # Convert (x1, y1, x2, y2) boxes to COCO (x, y, w, h) format
            person_boxes_coco = person_boxes.cpu().numpy().copy()
            person_boxes_coco[:, 2] = person_boxes_coco[:, 2] - person_boxes_coco[:, 0]
            person_boxes_coco[:, 3] = person_boxes_coco[:, 3] - person_boxes_coco[:, 1]

            # Stage 2: Pose Estimation
            pose_inputs = vitpose_processor(
                img,
                boxes=[person_boxes_coco],
                return_tensors="pt",
            ).to(device)

            with torch.no_grad():
                pose_outputs = vitpose_model(**pose_inputs)

            pose_results = vitpose_processor.post_process_pose_estimation(
                pose_outputs,
                boxes=[person_boxes_coco],
            )

            # Move the detection scores to plain Python floats once
            detection_confidences = person_scores.cpu().tolist()

            for j, pose_result in enumerate(pose_results[0]):
                keypoints = pose_result['keypoints'].numpy()
                scores = pose_result['scores'].numpy()

                image_results["poses"].append({
                    "person_id": j,
                    "bbox": person_boxes_coco[j].tolist(),
                    "detection_confidence": float(detection_confidences[j]),
                    "avg_keypoint_confidence": float(scores.mean()),
                    "visible_keypoints": int((scores > 0.3).sum()),
                    "keypoints": keypoints.tolist(),
                    "scores": scores.tolist(),
                })

            all_results.append(image_results)
            print(f"✓ Processed: {len(person_boxes_coco)} person(s) detected")

        except Exception as e:
            # Deliberate best-effort boundary: one bad image must not abort
            # the batch, so the failure is recorded and processing continues.
            print(f"❌ Error processing image {i+1}: {e}")
            all_results.append({
                "image_index": i,
                "image_url": img_url,
                "error": str(e),
                "persons_detected": 0,
                "poses": [],
            })

        finally:
            # Release cached GPU blocks after every image, whatever the outcome
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return all_results

# Example input for the batch helper
sample_urls = [
    "http://images.cocodataset.org/val2017/000000000139.jpg",
    "http://images.cocodataset.org/val2017/000000039769.jpg",
    # Add more URLs as needed
]

print("Demo: Batch processing multiple images...")
# Only the first URL is processed to keep the demo short
batch_results = process_multiple_images(sample_urls[:1])

# Aggregate counts over the per-image result dicts
print("\n=== Batch Processing Summary ===")
total_images = len(batch_results)
total_persons = sum(result.get("persons_detected", 0) for result in batch_results)
# An image counts as successful when its result has no "error" key
successful_images = sum(1 for r in batch_results if "error" not in r)
# max(..., 1) avoids dividing by zero when every image failed
average_persons = total_persons / max(successful_images, 1)

print(f"Images processed: {total_images}")
print(f"Successful: {successful_images}")
print(f"Total persons detected: {total_persons}")
print(f"Average persons per image: {average_persons:.1f}")

In [None]:
# Performance Optimization and Memory Management

# 1. Memory cleanup function
def cleanup_gpu_memory() -> None:
    """Run Python garbage collection, then release PyTorch's cached GPU memory.

    Collection runs first on purpose: ``torch.cuda.empty_cache()`` can only
    return blocks that are no longer occupied, so tensors kept alive by
    unreachable reference cycles must be collected before the cache is
    emptied. On machines without CUDA only the garbage collection runs.
    """
    import gc

    gc.collect()
    print("✓ Garbage collection completed")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("✓ GPU memory cache cleared")

# 2. Quantization for memory efficiency (optional)
def load_quantized_models():
    """Load models with quantization for reduced memory usage."""
    try:
        from transformers import BitsAndBytesConfig
        
        # 4-bit quantization config
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
        
        # Load quantized models
        quantized_vitpose = VitPoseForPoseEstimation.from_pretrained(
            "usyd-community/vitpose-base-simple",
            quantization_config=bnb_config,
            device_map="auto"
        )
        
        print("✓ Quantized VitPose model loaded")
        return quantized_vitpose
        
    except ImportError:
        print("⚠️ BitsAndBytesConfig not available. Install with: pip install bitsandbytes")
        return None

# 3. Performance monitoring
def monitor_performance() -> None:
    """Print the name and memory statistics of the current CUDA device.

    All figures refer to the same device (PyTorch's current CUDA device)
    rather than mixing the current device with device 0. "Free" is the total
    device memory minus the memory reserved by PyTorch's allocator in this
    process. On machines without CUDA a single informational line is printed.
    """
    if not torch.cuda.is_available():
        print("Using CPU - no GPU memory to monitor")
        return

    gpu_index = torch.cuda.current_device()
    total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
    allocated_bytes = torch.cuda.memory_allocated(gpu_index)
    reserved_bytes = torch.cuda.memory_reserved(gpu_index)

    # Byte counts are reported in decimal gigabytes (1e9 bytes)
    print(f"GPU Device: {torch.cuda.get_device_name(gpu_index)}")
    print(f"Total GPU memory: {total_bytes / 1e9:.1f} GB")
    print(f"Allocated GPU memory: {allocated_bytes / 1e9:.2f} GB")
    print(f"Cached GPU memory: {reserved_bytes / 1e9:.2f} GB")
    print(f"Free GPU memory: {(total_bytes - reserved_bytes) / 1e9:.2f} GB")

# 4. Model comparison function
def compare_model_sizes() -> None:
    """Print a qualitative comparison table of the VitPose model sizes.

    The rows describe the base, large and huge variants (checkpoint ids
    ``usyd-community/vitpose-{base,large,huge}-simple``). The figures are
    rough, hard-coded guidance; no model is loaded or measured.
    """
    # (size, approximate parameters, relative accuracy, relative memory use)
    rows = (
        ("Base", "~100M", "Good", "Low"),
        ("Large", "~300M", "Better", "Medium"),
        ("Huge", "~600M", "Best", "High"),
    )

    print("VitPose Model Comparison:")
    print("Model Size | Parameters | Performance | Memory Usage")
    print("-" * 50)
    for size, parameters, performance, memory in rows:
        # Column widths match the header labels so the separators line up
        print(f"{size:<10} | {parameters:<10} | {performance:<11} | {memory}")
    print("\nChoose based on your hardware capabilities and accuracy requirements.")

# Walk through the helpers defined in this cell, then print the tip lists
print("=== Performance Optimization Demo ===")

# Step 1: memory statistics before any cleanup
print("\n1. Current Memory Usage:")
monitor_performance()

# Step 2: clean up, then show the statistics again for comparison
print("\n2. Cleaning up memory:")
cleanup_gpu_memory()
monitor_performance()

# Step 3: qualitative model size table
print("\n3. Model Size Comparison:")
compare_model_sizes()

# Tips grouped by topic; each heading is printed followed by its bullet lines
print("\n=== Optimization Tips ===")
tip_sections = (
    ("🚀 Performance Tips:", (
        "  • Use GPU when available (CUDA)",
        "  • Enable mixed precision with torch.bfloat16",
        "  • Use quantization for memory-constrained environments",
        "  • Process images in batches when possible",
        "  • Clear GPU cache between processing sessions",
        "  • Choose appropriate model size for your hardware",
    )),
    ("\n💾 Memory Management:", (
        "  • Monitor GPU memory usage regularly",
        "  • Use torch.cuda.empty_cache() after processing",
        "  • Consider using CPU for very large batches",
        "  • Implement proper error handling and cleanup",
    )),
    ("\n⚡ Speed Optimization:", (
        "  • Use torch.no_grad() for inference",
        "  • Enable attention optimization with 'sdpa'",
        "  • Preload models to avoid repeated loading",
        "  • Use appropriate batch sizes for your GPU",
    )),
)
for heading, tips in tip_sections:
    print(heading)
    for tip in tips:
        print(tip)

print("\n✓ Optimization setup completed!")