In [None]:
# Install required packages for Google Colab
# Reinstalling mediapipe, tensorflow, and scipy with compatible versions to resolve numpy conflict
# NOTE(review): only mediapipe, tensorflow and scipy are pinned below; numpy, opencv-python
# and every other package resolve to whatever pip selects at run time, so re-running this
# cell later may not reproduce the same environment.
!pip install mediapipe==0.10.21 tensorflow==2.19.0 scipy==1.13.1 opencv-python numpy matplotlib scikit-learn
!pip install tensorflow-hub
!pip install torch torchvision
# Ultralytics provides the YOLO person detector loaded by AdvancedMultiPersonPoseEstimator
!pip install ultralytics
!pip install supervision

# For OpenPose (alternative approach)
# NOTE(review): opencv-contrib-python and opencv-python both ship a `cv2` module;
# confirm that installing both is intended.
!pip install opencv-contrib-python

# Additional packages for 3D visualization
!pip install plotly
!pip install trimesh

# Printed unconditionally: nothing in this cell checks whether the installs above succeeded.
print("✅ All packages installed successfully!")

In [None]:
# Import all necessary libraries
import cv2
import numpy as np
import mediapipe as mp
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import json
import time
from collections import deque
from scipy import signal
from scipy.spatial.distance import cdist
from sklearn.cluster import DBSCAN
import warnings
warnings.filterwarnings('ignore')

# TensorFlow and TensorFlow Hub for additional models
import tensorflow as tf
import tensorflow_hub as hub

# PyTorch for advanced models
import torch
import torch.nn as nn

print("✅ All libraries imported successfully!")
# One "<library> version: <x.y.z>" line per core dependency; the generator keeps
# its loop names out of the notebook namespace.
print("\n".join(
    f"{label} version: {module.__version__}"
    for label, module in (
        ("OpenCV", cv2),
        ("MediaPipe", mp),
        ("TensorFlow", tf),
        ("PyTorch", torch),
    )
))


In [None]:
# Tunable settings for pose estimation, tracking, smoothing and drawing
class PoseConfig:
    """Class-level settings consumed by the pose estimators in this notebook.

    Every value is a plain class attribute, so it can be read from the class
    itself or from any instance of it.
    """

    # --- MediaPipe Pose graph -------------------------------------------
    MEDIAPIPE_MODEL_COMPLEXITY = 2               # one of 0, 1, 2; larger means more accurate
    MEDIAPIPE_SMOOTH_LANDMARKS = True
    MEDIAPIPE_ENABLE_SEGMENTATION = False
    MEDIAPIPE_SMOOTH_SEGMENTATION = True
    MEDIAPIPE_MIN_DETECTION_CONFIDENCE = 0.5
    MEDIAPIPE_MIN_TRACKING_CONFIDENCE = 0.5

    # --- Person tracking across frames -----------------------------------
    MAX_PERSONS = 3
    PERSON_SIMILARITY_THRESHOLD = 0.7            # minimum similarity to reuse an existing ID
    TRACKING_HISTORY_SIZE = 10                   # frames of landmark history kept per person

    # --- Temporal smoothing ----------------------------------------------
    SMOOTHING_WINDOW_SIZE = 5                    # frames in the smoothing window
    SMOOTHING_SIGMA = 1.0                        # width of the Gaussian weighting

    # --- 3D estimation ---------------------------------------------------
    ENABLE_3D_ESTIMATION = True
    DEPTH_ESTIMATION_METHOD = "mediapipe"        # "mediapipe", "stereo" or "monocular"

    # --- Drawing ---------------------------------------------------------
    DRAW_CONNECTIONS = True
    DRAW_LANDMARKS = True
    LANDMARK_RADIUS = 3
    CONNECTION_THICKNESS = 2

# Pose landmark connections for MediaPipe
POSE_CONNECTIONS = mp.solutions.pose.POSE_CONNECTIONS

# Reduced skeleton for martial-arts / shooting analysis: (start, end) landmark-index pairs
MARTIAL_ARTS_CONNECTIONS = [
    # Core stability connections
    (11, 12),  # Shoulders
    (11, 23),  # Left shoulder to left hip
    (12, 24),  # Right shoulder to right hip
    (23, 24),  # Hips

    # Arm connections for shooting
    (11, 13),  # Left shoulder to left elbow
    (12, 14),  # Right shoulder to right elbow
    (13, 15),  # Left elbow to left wrist
    (14, 16),  # Right elbow to right wrist

    # Leg connections for stance
    (23, 25),  # Left hip to left knee
    (24, 26),  # Right hip to right knee
    (25, 27),  # Left knee to left ankle
    (26, 28),  # Right knee to right ankle
]

# Landmark indices grouped by the aspect of shooting technique they inform
SHOOTING_ANALYSIS_POINTS = {
    'stance': [23, 24, 25, 26, 27, 28],  # Hip, knee and ankle points
    'grip': [15, 16, 19, 20],  # Wrist and hand points
    'sight_alignment': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],  # Face points
    'trigger_control': [15, 16],  # Wrist points
    'follow_through': [11, 12, 13, 14, 15, 16]  # Shoulder, elbow and wrist points
}

print("✅ Configuration loaded successfully!")

# Baseline estimator: one MediaPipe pose per frame, with ID assignment and smoothing
class MultiPersonPoseEstimator:
    """Pose tracker built on MediaPipe Pose with ID assignment and temporal smoothing.

    ``process_frame`` takes at most one ``pose_landmarks`` result per frame from
    ``mp.solutions.pose.Pose``, so despite the name this class follows a single
    detection per frame. Each detection is matched to an existing track by
    landmark similarity (or given a new ID), and its landmarks are smoothed with
    a history kept separately for every person ID.
    """

    def __init__(self, config=None):
        """Build the MediaPipe Pose graph and empty tracking state.

        Args:
            config: settings object exposing the ``PoseConfig`` attributes.
                ``None`` (the default) creates a new ``PoseConfig()`` for this
                instance, so instances never share one default config object.
        """
        if config is None:
            config = PoseConfig()
        self.config = config
        self.mp_pose = mp.solutions.pose
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles

        # MediaPipe Pose in video mode (static_image_mode=False), configured from `config`
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=config.MEDIAPIPE_MODEL_COMPLEXITY,
            smooth_landmarks=config.MEDIAPIPE_SMOOTH_LANDMARKS,
            enable_segmentation=config.MEDIAPIPE_ENABLE_SEGMENTATION,
            smooth_segmentation=config.MEDIAPIPE_SMOOTH_SEGMENTATION,
            min_detection_confidence=config.MEDIAPIPE_MIN_DETECTION_CONFIDENCE,
            min_tracking_confidence=config.MEDIAPIPE_MIN_TRACKING_CONFIDENCE
        )

        # Person tracking: person_id -> track dict (see update_tracking_history)
        self.person_tracks = {}
        self.next_person_id = 0
        self.tracking_history = deque(maxlen=config.TRACKING_HISTORY_SIZE)

        # Temporal smoothing: person_id -> {landmark index -> deque of recent [x, y, z]}
        self.smoothing_buffer = {}

    def calculate_person_similarity(self, landmarks1, landmarks2):
        """Similarity in [0, 1] between two landmark sequences.

        Compares the 2-D positions of the key landmarks below that are visible
        (visibility > 0.5) in both sequences. Returns 0.0 when either sequence
        is empty or fewer than 3 key landmarks qualify; otherwise
        ``max(0, 1 - 2 * mean_distance)``.
        """
        if not landmarks1 or not landmarks2:
            return 0.0

        # Index 0 plus shoulders, elbows, wrists, hips, knees and ankles
        key_points = [0, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28]

        points1 = []
        points2 = []

        for i in key_points:
            if i < len(landmarks1) and i < len(landmarks2):
                if landmarks1[i].visibility > 0.5 and landmarks2[i].visibility > 0.5:
                    points1.append([landmarks1[i].x, landmarks1[i].y])
                    points2.append([landmarks2[i].x, landmarks2[i].y])

        if len(points1) < 3 or len(points2) < 3:
            return 0.0

        points1 = np.array(points1)
        points2 = np.array(points2)

        # Landmark x/y are compared as-is; no extra normalization is applied here
        distances = np.linalg.norm(points1 - points2, axis=1)
        avg_distance = np.mean(distances)

        # Map distance to a similarity score, clamped at 0
        similarity = max(0, 1 - avg_distance * 2)
        return similarity

    def assign_person_ids(self, detected_poses):
        """Give each detected pose a persistent person ID.

        A pose reuses the ID of the most similar existing track when the
        similarity exceeds ``PERSON_SIMILARITY_THRESHOLD``; otherwise it gets a
        new ID. A track can be claimed by at most one detection per frame.

        Args:
            detected_poses: items exposing ``.landmark`` (``None`` items are skipped).

        Returns:
            list of dicts with 'id', 'landmarks', 'world_landmarks', 'similarity'.
        """
        current_poses = []
        claimed_ids = set()  # track IDs already assigned in this frame

        for pose in detected_poses:
            if pose is None:
                continue

            best_match_id = None
            best_similarity = 0

            # Find the best matching existing person not yet claimed this frame
            for person_id, track in self.person_tracks.items():
                if person_id in claimed_ids or not track['landmarks_history']:
                    continue
                last_landmarks = track['landmarks_history'][-1]
                similarity = self.calculate_person_similarity(pose.landmark, last_landmarks)

                if similarity > best_similarity and similarity > self.config.PERSON_SIMILARITY_THRESHOLD:
                    best_similarity = similarity
                    best_match_id = person_id

            if best_match_id is not None:
                person_id = best_match_id
            else:
                person_id = self.next_person_id
                self.next_person_id += 1
            claimed_ids.add(person_id)

            current_poses.append({
                'id': person_id,
                'landmarks': pose.landmark,
                # NOTE(review): MediaPipe reports world landmarks on
                # results.pose_world_landmarks rather than on the landmark list,
                # so this is likely always None — confirm.
                'world_landmarks': pose.world_landmark if hasattr(pose, 'world_landmark') else None,
                'similarity': best_similarity
            })

        return current_poses

    def update_tracking_history(self, poses):
        """Record this frame's poses and drop tracks unseen for more than 5 seconds.

        Args:
            poses: dicts with 'id', 'landmarks' and 'world_landmarks' keys, as
                returned by ``assign_person_ids``.
        """
        for pose in poses:
            person_id = pose['id']

            if person_id not in self.person_tracks:
                self.person_tracks[person_id] = {
                    'landmarks_history': deque(maxlen=self.config.TRACKING_HISTORY_SIZE),
                    'world_landmarks_history': deque(maxlen=self.config.TRACKING_HISTORY_SIZE),
                    'last_seen': time.time(),
                    'total_detections': 0
                }

            track = self.person_tracks[person_id]
            track['landmarks_history'].append(pose['landmarks'])
            if pose['world_landmarks']:
                track['world_landmarks_history'].append(pose['world_landmarks'])

            track['last_seen'] = time.time()
            track['total_detections'] += 1

        # Clean up old tracks
        current_time = time.time()
        stale_ids = [
            person_id for person_id, track in self.person_tracks.items()
            if current_time - track['last_seen'] > 5.0
        ]
        for person_id in stale_ids:
            del self.person_tracks[person_id]
            # Free the matching smoothing state so the buffer does not grow forever
            self.smoothing_buffer.pop(person_id, None)

    def apply_temporal_smoothing(self, landmarks, person_id=None):
        """Gaussian-weighted smoothing of landmarks over recent frames.

        Keeps up to ``SMOOTHING_WINDOW_SIZE`` positions per landmark index,
        weighting the newest most heavily (half-Gaussian, ``SMOOTHING_SIGMA``).
        History is kept per ``person_id`` so different people never blend; the
        default ``None`` uses one shared history (the previous behavior).

        Args:
            landmarks: sequence of landmarks with ``.x``, ``.y``, ``.z``, ``.visibility``.
            person_id: track the history belongs to.

        Returns:
            list of smoothed landmarks (the input landmark itself when it has no history yet).
        """
        if len(landmarks) == 0:
            return landmarks

        if person_id not in self.smoothing_buffer:
            self.smoothing_buffer[person_id] = {}
        person_buffer = self.smoothing_buffer[person_id]

        smoothed_landmarks = []

        for i, landmark in enumerate(landmarks):
            if i not in person_buffer:
                person_buffer[i] = deque(maxlen=self.config.SMOOTHING_WINDOW_SIZE)

            person_buffer[i].append([landmark.x, landmark.y, landmark.z])

            if len(person_buffer[i]) > 1:
                points = np.array(list(person_buffer[i]))

                # Offsets -(n-1)..0 from the newest frame -> newest gets the largest weight
                weights = np.exp(-0.5 * ((np.arange(len(points)) - len(points) + 1) / self.config.SMOOTHING_SIGMA) ** 2)
                weights = weights / np.sum(weights)

                smoothed_point = np.average(points, axis=0, weights=weights)

                # Create new landmark of the same type with the smoothed position
                smoothed_landmark = type(landmark)()
                smoothed_landmark.x = smoothed_point[0]
                smoothed_landmark.y = smoothed_point[1]
                smoothed_landmark.z = smoothed_point[2]
                smoothed_landmark.visibility = landmark.visibility

                smoothed_landmarks.append(smoothed_landmark)
            else:
                smoothed_landmarks.append(landmark)

        return smoothed_landmarks

    def process_frame(self, image):
        """Run pose estimation, ID assignment, tracking and smoothing on one BGR frame.

        Returns:
            dict with 'poses' (smoothed pose dicts), 'raw_results' (MediaPipe
            output) and 'image_shape'.
        """
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        results = self.pose.process(rgb_image)

        detected_poses = []
        if results.pose_landmarks:
            # MediaPipe Pose yields a single pose here; true multi-person
            # estimation needs a detector-based approach.
            detected_poses.append(results.pose_landmarks)

        poses_with_ids = self.assign_person_ids(detected_poses)

        self.update_tracking_history(poses_with_ids)

        smoothed_poses = []
        for pose in poses_with_ids:
            smoothed_landmarks = self.apply_temporal_smoothing(pose['landmarks'], person_id=pose['id'])
            smoothed_poses.append({
                'id': pose['id'],
                'landmarks': smoothed_landmarks,
                'world_landmarks': pose['world_landmarks'],
                'similarity': pose['similarity']
            })

        return {
            'poses': smoothed_poses,
            'raw_results': results,
            'image_shape': image.shape
        }

print("✅ MultiPersonPoseEstimator class created successfully!")

# Advanced Multi-Person Detection with Improved Tracking
class AdvancedMultiPersonPoseEstimator:
    def __init__(self, config=None):
        """Set up MediaPipe Pose, optional YOLO person detection and tracking state.

        Args:
            config: settings object exposing the ``PoseConfig`` attributes.
                ``None`` (the default) creates a new ``PoseConfig()`` per
                instance instead of sharing one default object. The MediaPipe
                settings below are fixed here and are not read from ``config``.
        """
        if config is None:
            config = PoseConfig()
        self.config = config
        self.mp_pose = mp.solutions.pose
        self.mp_drawing = mp.solutions.drawing_utils

        # Initialize MediaPipe Pose with higher accuracy settings
        # Disable segmentation smoothing to prevent dimension mismatch errors
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=2,  # Maximum complexity for better accuracy
            smooth_landmarks=True,
            enable_segmentation=False,  # Disable segmentation to prevent smoothing errors
            smooth_segmentation=False,  # Disable segmentation smoothing
            min_detection_confidence=0.6,  # Higher confidence threshold
            min_tracking_confidence=0.6
        )

        # Load YOLO for person detection; fall back gracefully if unavailable
        try:
            from ultralytics import YOLO
            # Use YOLOv8m for better accuracy with multiple people
            self.yolo_model = YOLO('yolov8m.pt')  # Medium version for better accuracy
            self.use_yolo = True
            print("✅ YOLO model loaded successfully")
        except Exception as e:
            print(f"⚠️ YOLO not available, using fallback method: {e}")
            self.use_yolo = False

        # Advanced person tracking
        self.person_tracks = {}
        self.next_person_id = 0
        self.smoothing_buffer = {}  # person_id -> per-person smoothing state dict
        self.tracking_history = deque(maxlen=30)  # Longer history for better tracking

        # Simple per-person motion model state (see predict_next_position)
        self.kalman_filters = {}

        # Overlap detection
        self.overlap_threshold = 0.3  # IoU threshold for overlap detection

        # Performance optimization: YOLO runs only when frame_count % frame_skip == 0.
        # NOTE(review): frame_count is not advanced in this section — presumably by the caller.
        self.frame_skip = 2
        self.frame_count = 0
        # Last YOLO result, reused on skipped frames by get_cached_detections()
        self.cached_detections = []

    def detect_persons_yolo(self, image):
        """Detect persons using YOLO with improved filtering"""
        if not self.use_yolo:
            return []

        # Only run YOLO detection every few frames for performance
        if self.frame_count % self.frame_skip != 0:
            return self.get_cached_detections()

        results = self.yolo_model(image, verbose=False)
        person_boxes = []

        for result in results:
            boxes = result.boxes
            if boxes is not None:
                for box in boxes:
                    # Check if it's a person (class 0 in COCO) with higher confidence
                    if int(box.cls[0]) == 0 and float(box.conf[0]) > 0.6:  # Higher confidence threshold
                        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()

                        # Filter out very small detections
                        width = x2 - x1
                        height = y2 - y1
                        if width > 50 and height > 100:  # Minimum size filter
                            person_boxes.append({
                                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                                'confidence': float(box.conf[0]),
                                'area': width * height
                            })

        # Sort by confidence and area
        person_boxes.sort(key=lambda x: (x['confidence'], x['area']), reverse=True)

        # Cache detections for frame skipping
        self.cached_detections = person_boxes
        return person_boxes

    def get_cached_detections(self):
        """Get cached detections when skipping YOLO processing"""
        return getattr(self, 'cached_detections', [])

    def calculate_iou(self, box1, box2):
        """Intersection over Union of two ``[x1, y1, x2, y2]`` boxes.

        Returns 0.0 when the boxes do not overlap (touching edges count as no
        overlap) or when the union area is not positive.
        """
        ax1, ay1, ax2, ay2 = box1
        bx1, by1, bx2, by2 = box2

        left, top = max(ax1, bx1), max(ay1, by1)
        right, bottom = min(ax2, bx2), min(ay2, by2)
        if right <= left or bottom <= top:
            return 0.0

        overlap = (right - left) * (bottom - top)
        union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - overlap
        return overlap / union if union > 0 else 0.0

    def detect_overlaps(self, person_boxes):
        """Detect overlapping persons and handle them"""
        overlaps = []
        processed = set()

        for i, box1_data in enumerate(person_boxes):
            if i in processed:
                continue

            box1 = box1_data['bbox']

            overlap_group = [i]
            for j, box2_data in enumerate(person_boxes[i+1:], i+1):
                if j in processed:
                    continue

                box2 = box2_data['bbox']

                iou = self.calculate_iou(box1, box2)
                if iou > self.overlap_threshold:
                    overlap_group.append(j)
                    processed.add(j)

            if len(overlap_group) > 1:
                overlaps.append(overlap_group)
                processed.update(overlap_group)

        return overlaps

    def extract_person_roi(self, image, bbox, padding=20):
        """Extract region of interest for a person with configurable padding using letterbox scaling"""
        x1, y1, x2, y2 = bbox
        h, w = image.shape[:2]

        x1 = max(0, x1 - padding)
        y1 = max(0, y1 - padding)
        x2 = min(w, x2 + padding)
        y2 = min(h, y2 + padding)

        # Extract ROI
        roi_image = image[y1:y2, x1:x2]
        original_roi_h, original_roi_w = roi_image.shape[:2]

        # Target dimensions for consistent processing
        target_height = 480
        target_width = 320

        # Calculate scale factor to fit ROI into target dimensions while preserving aspect ratio
        scale_x = target_width / original_roi_w
        scale_y = target_height / original_roi_h
        scale = min(scale_x, scale_y)  # Use smaller scale to fit entirely

        # Calculate new dimensions after scaling
        new_w = int(original_roi_w * scale)
        new_h = int(original_roi_h * scale)

        # Resize ROI with preserved aspect ratio
        if new_w != original_roi_w or new_h != original_roi_h:
            roi_image = cv2.resize(roi_image, (new_w, new_h))

        # Create letterbox (add padding to reach target dimensions)
        letterbox_image = np.zeros((target_height, target_width, 3), dtype=np.uint8)

        # Calculate padding offsets to center the image
        pad_x = (target_width - new_w) // 2
        pad_y = (target_height - new_h) // 2

        # Place the scaled ROI in the center of the letterbox
        letterbox_image[pad_y:pad_y + new_h, pad_x:pad_x + new_w] = roi_image

        # Store transformation parameters for proper landmark transformation
        roi_coords = (x1, y1, x2, y2, original_roi_w, original_roi_h, scale, pad_x, pad_y)

        return letterbox_image, roi_coords

    def apply_advanced_smoothing(self, poses):
        """Apply advanced temporal smoothing to poses"""
        smoothed_poses = []

        for pose_data in poses:
            person_id = pose_data['id']
            landmarks = pose_data['landmarks']

            if landmarks:
                # Apply smoothing to landmarks
                smoothed_landmarks = self.smooth_landmarks_temporally(person_id, landmarks)

                # Apply smoothing to bounding box
                smoothed_bbox = self.smooth_bbox_temporally(person_id, pose_data.get('bbox'))

                smoothed_poses.append({
                    'id': person_id,
                    'landmarks': smoothed_landmarks,
                    'world_landmarks': pose_data.get('world_landmarks'),
                    'bbox': smoothed_bbox,
                    'confidence': pose_data.get('confidence', 1.0),
                    'in_overlap': pose_data.get('in_overlap', False),
                    'similarity': pose_data.get('similarity', 0.0)
                })

        return smoothed_poses

    def smooth_landmarks_temporally(self, person_id, landmarks):
        """Apply temporal smoothing to landmarks using exponential moving average"""
        if person_id not in self.smoothing_buffer:
            self.smoothing_buffer[person_id] = {
                'landmarks_history': deque(maxlen=10),
                'alpha': 0.7  # Smoothing factor
            }

        buffer = self.smoothing_buffer[person_id]
        buffer['landmarks_history'].append(landmarks)

        if len(buffer['landmarks_history']) == 1:
            return landmarks

        # Apply exponential moving average
        alpha = buffer['alpha']
        smoothed_landmarks_list = []

        for i, landmark in enumerate(landmarks.landmark):
            # Get previous smoothed landmark from history
            if len(buffer['landmarks_history']) > 1:
                prev_landmarks = buffer['landmarks_history'][-2]
                if i < len(prev_landmarks.landmark):
                    prev_landmark = prev_landmarks.landmark[i]
                    if hasattr(prev_landmark, 'x'):
                        prev_x, prev_y, prev_z = prev_landmark.x, prev_landmark.y, prev_landmark.z
                    else:
                        prev_x = prev_landmark.get('x', 0)
                        prev_y = prev_landmark.get('y', 0)
                        prev_z = prev_landmark.get('z', 0)
                else:
                    prev_x = prev_y = prev_z = 0
            else:
                prev_x = prev_y = prev_z = 0


            # Handle both MediaPipe and simple landmarks
            if hasattr(landmark, 'x'):
                new_x = alpha * landmark.x + (1 - alpha) * prev_x
                new_y = alpha * landmark.y + (1 - alpha) * prev_y
                new_z = alpha * landmark.z + (1 - alpha) * prev_z
                new_visibility = landmark.visibility
            else:
                new_x = alpha * landmark.get('x', 0) + (1 - alpha) * prev_x
                new_y = alpha * landmark.get('y', 0) + (1 - alpha) * prev_y
                new_z = alpha * landmark.get('z', 0) + (1 - alpha) * prev_z
                new_visibility = landmark.get('visibility', 0)


            # Create smoothed landmark dictionary
            smoothed_landmark_dict = {
                'x': new_x,
                'y': new_y,
                'z': new_z,
                'visibility': new_visibility
            }
            smoothed_landmarks_list.append(smoothed_landmark_dict)

        # Create new landmarks object (MediaPipe or simple)
        if hasattr(landmarks, 'landmark'):
            from mediapipe.framework.formats import landmark_pb2
            new_landmarks = landmark_pb2.NormalizedLandmarkList()
            for smoothed_landmark_dict in smoothed_landmarks_list:
                new_landmark = new_landmarks.landmark.add()
                new_landmark.x = smoothed_landmark_dict['x']
                new_landmark.y = smoothed_landmark_dict['y']
                new_landmark.z = smoothed_landmark_dict['z']
                new_landmark.visibility = smoothed_landmark_dict['visibility']

            # Copy world landmarks if available
            if hasattr(landmarks, 'world_landmark') and landmarks.world_landmark:
                 for world_lm in landmarks.world_landmark:
                      new_world_landmark = new_landmarks.world_landmark.add()
                      new_world_landmark.x = world_lm.x
                      new_world_landmark.y = world_lm.y
                      new_world_landmark.z = world_lm.z
                      new_world_landmark.visibility = world_lm.visibility


            return new_landmarks
        else:
            # Simple landmarks object
            class SimpleLandmarks:
                def __init__(self, landmark_list):
                    self.landmark = landmark_list
                    self.world_landmark = None
            return SimpleLandmarks(smoothed_landmarks_list)

    def smooth_bbox_temporally(self, person_id, bbox):
        """Apply temporal smoothing to bounding box"""
        if not bbox:
            return bbox

        if person_id not in self.smoothing_buffer:
            self.smoothing_buffer[person_id] = {
                'bbox_history': deque(maxlen=5),
                'alpha': 0.8
            }

        buffer = self.smoothing_buffer[person_id]
        if 'bbox_history' not in buffer:
            buffer['bbox_history'] = deque(maxlen=5)

        buffer['bbox_history'].append(bbox)

        if len(buffer['bbox_history']) == 1:
            return bbox

        # Apply exponential moving average to bbox
        alpha = buffer['alpha']
        prev_bbox = buffer['bbox_history'][-2]

        smoothed_bbox = [
            alpha * bbox[0] + (1 - alpha) * prev_bbox[0],
            alpha * bbox[1] + (1 - alpha) * prev_bbox[1],
            alpha * bbox[2] + (1 - alpha) * prev_bbox[2],
            alpha * bbox[3] + (1 - alpha) * prev_bbox[3]
        ]

        return smoothed_bbox

    def process_person_roi(self, roi_image):
        """Run MediaPipe Pose on a single BGR person crop.

        Returns:
            ``results.pose_landmarks`` when MediaPipe finds a pose, else ``None``.
        """
        results = self.pose.process(cv2.cvtColor(roi_image, cv2.COLOR_BGR2RGB))
        found = results.pose_landmarks
        return found if found else None

    def transform_landmarks_to_full_image(self, landmarks, roi_coords, full_image_shape):
        """Transform landmarks from ROI coordinates to full image coordinates with letterbox support"""
        if not landmarks:
            return None

        full_h, full_w = full_image_shape[:2]
        target_height = 480
        target_width = 320

        # Handle different ROI coordinate formats
        if len(roi_coords) == 9:  # New format with letterbox parameters
            x1, y1, x2, y2, original_roi_w, original_roi_h, scale, pad_x, pad_y = roi_coords
        elif len(roi_coords) == 6:  # Old format without letterbox
            x1, y1, x2, y2, original_roi_w, original_roi_h = roi_coords
            scale = 1.0
            pad_x = pad_y = 0
        else:  # Very old format
            x1, y1, x2, y2 = roi_coords
            original_roi_w, original_roi_h = x2 - x1, y2 - y1
            scale = 1.0
            pad_x = pad_y = 0

        try:
            # Try using MediaPipe's proper method
            from mediapipe.framework.formats import landmark_pb2

            new_landmarks = landmark_pb2.NormalizedLandmarkList()

            for i, landmark in enumerate(landmarks.landmark):
                # Transform coordinates from letterbox ROI back to original ROI, then to full image

                # Step 1: Remove letterbox padding (convert from letterbox coordinates to scaled ROI coordinates)
                letterbox_x = landmark.x * target_width
                letterbox_y = landmark.y * target_height

                # Remove padding offset
                scaled_x = letterbox_x - pad_x
                scaled_y = letterbox_y - pad_y

                # Step 2: Convert from scaled ROI coordinates back to original ROI coordinates
                original_roi_x = scaled_x / scale
                original_roi_y = scaled_y / scale

                # Step 3: Transform to full image coordinates
                new_x = (original_roi_x + x1) / full_w
                new_y = (original_roi_y + y1) / full_h
                new_z = landmark.z  # Z coordinate doesn't need transformation

                # Add new landmark to the list
                new_landmark = new_landmarks.landmark.add()
                new_landmark.x = new_x
                new_landmark.y = new_y
                new_landmark.z = new_z
                new_landmark.visibility = landmark.visibility

            # Copy world landmarks if available
            if hasattr(landmarks, 'world_landmark') and landmarks.world_landmark:
                for world_lm in landmarks.world_landmark:
                    new_world_landmark = new_landmarks.world_landmark.add()
                    new_world_landmark.x = world_lm.x
                    new_world_landmark.y = world_lm.y
                    new_world_landmark.z = world_lm.z
                    new_world_landmark.visibility = world_lm.visibility

            return new_landmarks

        except Exception as e:
            print(f"⚠️ MediaPipe landmark creation failed: {e}")
            # Fallback: return a simple dictionary representation
            transformed_landmarks = []
            for i, landmark in enumerate(landmarks.landmark):
                # Transform coordinates from letterbox ROI back to original ROI, then to full image

                # Step 1: Remove letterbox padding
                letterbox_x = landmark.x * target_width
                letterbox_y = landmark.y * target_height

                # Remove padding offset
                scaled_x = letterbox_x - pad_x
                scaled_y = letterbox_y - pad_y

                # Step 2: Convert from scaled ROI coordinates back to original ROI coordinates
                original_roi_x = scaled_x / scale
                original_roi_y = scaled_y / scale

                # Step 3: Transform to full image coordinates
                new_x = (original_roi_x + x1) / full_w
                new_y = (original_roi_y + y1) / full_h
                new_z = landmark.z

                transformed_landmarks.append({
                    'x': new_x,
                    'y': new_y,
                    'z': new_z,
                    'visibility': landmark.visibility
                })

            # Create a simple object that mimics MediaPipe landmarks
            class SimpleLandmarks:
                def __init__(self, landmark_list):
                    self.landmark = landmark_list
                    self.world_landmark = None

            return SimpleLandmarks(transformed_landmarks)

    def calculate_person_similarity(self, landmarks1, landmarks2, bbox1=None, bbox2=None):
        """Calculate similarity between two pose landmarks with multiple features"""
        if not landmarks1 or not landmarks2:
            return 0.0

        similarity_score = 0.0
        weights = []

        # 1. Pose landmark similarity (weight: 0.6)
        pose_sim = self.calculate_pose_similarity(landmarks1, landmarks2)
        similarity_score += pose_sim * 0.6
        weights.append(0.6)

        # 2. Bounding box similarity (weight: 0.3)
        if bbox1 and bbox2:
            bbox_sim = self.calculate_bbox_similarity(bbox1, bbox2)
            similarity_score += bbox_sim * 0.3
            weights.append(0.3)

        # 3. Size similarity (weight: 0.1)
        if bbox1 and bbox2:
            size_sim = self.calculate_size_similarity(bbox1, bbox2)
            similarity_score += size_sim * 0.1
            weights.append(0.1)

        # Normalize by total weight
        total_weight = sum(weights)
        return similarity_score / total_weight if total_weight > 0 else 0.0

    def calculate_pose_similarity(self, landmarks1, landmarks2):
        """Score two poses by the mean 2-D distance between key joints.

        Only key landmarks present in both poses with visibility > 0.5 in both
        are compared. Fewer than 3 usable joints gives 0.0; otherwise the score
        is ``max(0, 1 - 3 * mean_distance)``.

        Args:
            landmarks1, landmarks2: objects with a ``.landmark`` sequence whose
                items are attribute-style landmarks or dicts with 'x', 'y',
                'visibility'.
        """
        def read(point):
            # Attribute-style (MediaPipe) landmarks vs. plain dict fallbacks
            if hasattr(point, 'visibility'):
                return point.visibility, point.x, point.y
            return point.get('visibility', 0), point.get('x', 0), point.get('y', 0)

        seq_a = landmarks1.landmark
        seq_b = landmarks2.landmark
        matched = []
        for idx in (0, 11, 12, 13, 14, 15, 16, 23, 24, 25, 26, 27, 28):
            if idx >= len(seq_a) or idx >= len(seq_b):
                continue
            vis_a, xa, ya = read(seq_a[idx])
            vis_b, xb, yb = read(seq_b[idx])
            if vis_a > 0.5 and vis_b > 0.5:
                matched.append(([xa, ya], [xb, yb]))

        if len(matched) < 3:
            return 0.0

        coords_a = np.array([a for a, _ in matched])
        coords_b = np.array([b for _, b in matched])
        mean_gap = np.mean(np.linalg.norm(coords_a - coords_b, axis=1))

        # Distance -> similarity; more sensitive (x3) than the basic estimator
        return max(0, 1 - mean_gap * 3)

    def calculate_bbox_similarity(self, bbox1, bbox2):
        """Bounding-box similarity, defined as the boxes' intersection-over-union.

        Delegates directly to ``self.calculate_iou`` and returns its result unchanged.
        """
        return self.calculate_iou(bbox1, bbox2)

    def calculate_size_similarity(self, bbox1, bbox2):
        """Ratio of the smaller box area to the larger one.

        Boxes are ``[x1, y1, x2, y2]``. The result is 1.0 for equal areas and
        approaches 0 as the sizes diverge. If either area is exactly zero,
        0.0 is returned.
        """
        def area(box):
            width = box[2] - box[0]
            height = box[3] - box[1]
            return width * height

        first, second = area(bbox1), area(bbox2)
        if first == 0 or second == 0:
            return 0.0
        return min(first, second) / max(first, second)

    def predict_next_position(self, person_id, current_bbox):
        """Extrapolate a person's bounding box one step ahead with a constant-velocity model.

        Despite the ``kalman_filters`` attribute name, this is not a Kalman
        filter. For each person it keeps the last bbox centre, the last update
        time (``time.time()``) and an exponentially smoothed centre velocity:
        70% previous estimate, 30% newest measurement.

        Args:
            person_id: Key of the person's motion state in ``self.kalman_filters``.
            current_bbox: Current box as ``[x1, y1, x2, y2]``.

        Returns:
            The predicted box ``[x1, y1, x2, y2]``. It has the same width and
            height as ``current_bbox``, with the centre moved by
            velocity * (time since the previous update). On the first call for
            ``person_id``, or when no time has elapsed since the last update,
            ``current_bbox`` is returned unchanged.
        """
        current_center = [(current_bbox[0] + current_bbox[2]) / 2,
                          (current_bbox[1] + current_bbox[3]) / 2]

        if person_id not in self.kalman_filters:
            # Seed with the box *centre*. Later calls difference centres, so
            # seeding with the top-left corner would inject a spurious
            # (width/2, height/2) / dt velocity on the next frame.
            self.kalman_filters[person_id] = {
                'velocity': [0, 0],
                'last_position': current_center,
                'last_time': time.time()
            }
            return current_bbox

        state = self.kalman_filters[person_id]
        current_time = time.time()
        dt = current_time - state['last_time']

        if dt <= 0:
            # Clock did not advance (or went backwards): no usable velocity sample
            return current_bbox

        last_center = state['last_position']
        measured_velocity = [(current_center[0] - last_center[0]) / dt,
                             (current_center[1] - last_center[1]) / dt]

        # Exponential smoothing damps jitter in per-frame detections
        state['velocity'] = [0.7 * state['velocity'][0] + 0.3 * measured_velocity[0],
                             0.7 * state['velocity'][1] + 0.3 * measured_velocity[1]]

        # Assume the next step takes as long as the last one
        predicted_center = [current_center[0] + state['velocity'][0] * dt,
                            current_center[1] + state['velocity'][1] * dt]

        width = current_bbox[2] - current_bbox[0]
        height = current_bbox[3] - current_bbox[1]
        predicted_bbox = [
            predicted_center[0] - width / 2,
            predicted_center[1] - height / 2,
            predicted_center[0] + width / 2,
            predicted_center[1] + height / 2
        ]

        state['last_position'] = current_center
        state['last_time'] = current_time

        return predicted_bbox

    def assign_person_ids(self, detected_poses, max_cost=0.5):
        """Match this frame's detections to existing tracks and assign person IDs.

        A (tracks x detections) cost matrix is built with
        ``cost = 1 - similarity``. Similarity comes from
        ``self.calculate_person_similarity`` against each track's most recent
        landmarks and bbox. The matrix is solved with the Hungarian algorithm
        (``scipy.optimize.linear_sum_assignment``); if SciPy cannot be
        imported, a greedy per-track match is used instead. A detection
        inherits a track's ID only when the matched cost is below
        ``max_cost``. Every other detection gets a fresh ID from
        ``self.next_person_id``, which is incremented.

        Args:
            detected_poses: Dicts with a 'landmarks' entry (entries whose
                landmarks are None are dropped) and optional
                'world_landmarks', 'bbox', 'confidence', 'in_overlap'.
            max_cost: Acceptance threshold on ``1 - similarity``. The default
                0.5 means similarity must exceed 0.5.

        Returns:
            One dict per kept detection, in input order, with keys 'id',
            'landmarks', 'world_landmarks', 'bbox', 'confidence' (default 1.0),
            'in_overlap' (default False) and 'similarity' (0.0 for new IDs).
        """
        if not detected_poses:
            return []

        existing_ids = list(self.person_tracks.keys())
        new_poses = [pose for pose in detected_poses if pose['landmarks'] is not None]

        if not new_poses:
            return []

        # Finite sentinel for "cannot match". linear_sum_assignment raises
        # ValueError on inf-only (infeasible) rows and on NaN entries. The
        # sentinel is always rejected by the max_cost check below.
        no_match_cost = 1e6
        cost_matrix = np.full((len(existing_ids), len(new_poses)), no_match_cost)

        for i, person_id in enumerate(existing_ids):
            track = self.person_tracks[person_id]
            if not track['landmarks_history']:
                continue
            last_landmarks = track['landmarks_history'][-1]
            last_bbox = track['bbox_history'][-1] if track['bbox_history'] else None

            for j, pose_data in enumerate(new_poses):
                similarity = self.calculate_person_similarity(
                    pose_data['landmarks'],
                    last_landmarks,
                    pose_data.get('bbox'),
                    last_bbox
                )
                cost = 1.0 - similarity
                if np.isfinite(cost):
                    cost_matrix[i, j] = cost

        # detection column -> matched track row
        matched_rows = {}
        try:
            from scipy.optimize import linear_sum_assignment
            row_indices, col_indices = linear_sum_assignment(cost_matrix)
            for i, j in zip(row_indices, col_indices):
                if cost_matrix[i, j] < max_cost:
                    matched_rows[j] = i
        except ImportError:
            # Greedy fallback: each track, in insertion order, claims its
            # cheapest detection not already taken by an earlier track.
            matched_rows = {}
            for i in range(len(existing_ids)):
                free_cols = [j for j in range(len(new_poses)) if j not in matched_rows]
                if not free_cols:
                    break
                best_j = min(free_cols, key=lambda j: cost_matrix[i, j])
                if cost_matrix[i, best_j] < max_cost:
                    matched_rows[best_j] = i

        current_poses = []
        for j, pose_data in enumerate(new_poses):
            if j in matched_rows:
                row = matched_rows[j]
                person_id = existing_ids[row]
                similarity = 1.0 - cost_matrix[row, j]
            else:
                person_id = self.next_person_id
                self.next_person_id += 1
                similarity = 0.0

            current_poses.append({
                'id': person_id,
                'landmarks': pose_data['landmarks'],
                'world_landmarks': pose_data.get('world_landmarks'),
                'bbox': pose_data.get('bbox'),
                'confidence': pose_data.get('confidence', 1.0),
                'in_overlap': pose_data.get('in_overlap', False),
                'similarity': similarity
            })

        return current_poses

    def update_tracking_history(self, poses, max_age=5.0):
        """Append this frame's detections to each person's track and drop stale tracks.

        New IDs get a fresh track whose histories are deques bounded by
        ``self.config.TRACKING_HISTORY_SIZE``. 'world_landmarks' and 'bbox'
        are recorded only when present. A track is removed when it has gone
        more than ``max_age`` seconds (wall clock, ``time.time()``) without a
        detection. Its motion-model state in ``self.kalman_filters``, if any,
        is released at the same time.

        Args:
            poses: Dicts with at least 'id' and 'landmarks'.
            max_age: Track expiry in seconds (default 5.0).
        """
        current_time = time.time()

        for pose in poses:
            person_id = pose['id']

            track = self.person_tracks.get(person_id)
            if track is None:
                history_size = self.config.TRACKING_HISTORY_SIZE
                track = {
                    'landmarks_history': deque(maxlen=history_size),
                    'world_landmarks_history': deque(maxlen=history_size),
                    'bbox_history': deque(maxlen=history_size),
                    'last_seen': current_time,
                    'total_detections': 0
                }
                self.person_tracks[person_id] = track

            track['landmarks_history'].append(pose['landmarks'])
            if pose.get('world_landmarks'):
                track['world_landmarks_history'].append(pose['world_landmarks'])

            # len() rather than truthiness: identical for lists/tuples/None,
            # and also works for numpy boxes, whose truth value is ambiguous.
            bbox = pose.get('bbox')
            if bbox is not None and len(bbox) > 0:
                track['bbox_history'].append(bbox)

            track['last_seen'] = current_time
            track['total_detections'] += 1

        stale_ids = [person_id for person_id, track in self.person_tracks.items()
                     if current_time - track['last_seen'] > max_age]

        # Motion state is keyed by the same person IDs. Without this it would
        # keep one entry for every ID ever seen.
        motion_state = getattr(self, 'kalman_filters', None)
        for person_id in stale_ids:
            del self.person_tracks[person_id]
            if motion_state is not None:
                motion_state.pop(person_id, None)

    def process_frame(self, image):
        """Detect, track and smooth every person's pose in one BGR frame.

        With ``self.use_yolo`` set, people are detected by
        ``detect_persons_yolo``. Each detection's ROI is cropped with
        ``extract_person_roi`` (``padding=40`` when the detection belongs to
        an overlap group) and run through ``process_person_roi``. The
        resulting landmarks are mapped back to full-image coordinates; if that
        mapping fails, the ROI landmarks are kept as a best-effort fallback.
        Without YOLO, the whole frame is passed to ``self.pose`` as a single
        person.

        Detections then receive persistent IDs (``assign_person_ids``), are
        smoothed (``apply_advanced_smoothing``) and recorded in the track
        history (``update_tracking_history``).

        Returns:
            dict with 'poses' (per-person dicts), 'image_shape',
            'num_persons_detected' and 'overlaps_detected' (number of groups
            ``detect_overlaps`` finds among detections that have a bbox).
        """
        self.frame_count += 1
        detected_poses = []

        if self.use_yolo:
            person_boxes = self.detect_persons_yolo(image)
            overlaps = self.detect_overlaps(person_boxes)

            for i, person_box in enumerate(person_boxes):
                bbox = person_box['bbox']
                in_overlap = any(i in group for group in overlaps)

                if in_overlap:
                    # Use larger ROI for overlapping persons
                    roi_image, roi_coords = self.extract_person_roi(image, bbox, padding=40)
                else:
                    roi_image, roi_coords = self.extract_person_roi(image, bbox)

                landmarks = self.process_person_roi(roi_image)
                if not landmarks:
                    continue

                # Only the coordinate transform is best-effort. Errors while
                # building the pose record are real bugs and must not be
                # reported as transformation failures.
                try:
                    pose_landmarks = self.transform_landmarks_to_full_image(
                        landmarks, roi_coords, image.shape
                    )
                except Exception as e:
                    print(f"⚠️ Landmark transformation failed: {e}")
                    pose_landmarks = landmarks

                detected_poses.append({
                    'landmarks': pose_landmarks,
                    'world_landmarks': getattr(landmarks, 'world_landmark', None),
                    'bbox': bbox,
                    'confidence': person_box['confidence'],
                    'in_overlap': in_overlap
                })
        else:
            # Fallback to single person detection on the full frame
            rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = self.pose.process(rgb_image)

            if results.pose_landmarks:
                detected_poses.append({
                    'landmarks': results.pose_landmarks,
                    'world_landmarks': getattr(results, 'pose_world_landmarks', None),
                    'bbox': None,
                    'confidence': 1.0,
                    'in_overlap': False
                })

        poses_with_ids = self.assign_person_ids(detected_poses)
        poses_with_ids = self.apply_advanced_smoothing(poses_with_ids)
        self.update_tracking_history(poses_with_ids)

        # Re-count overlaps among detections that actually yielded landmarks
        bboxes_for_overlap = [{'bbox': pose['bbox']} for pose in detected_poses if pose.get('bbox')]

        return {
            'poses': poses_with_ids,
            'image_shape': image.shape,
            'num_persons_detected': len(poses_with_ids),
            'overlaps_detected': len(self.detect_overlaps(bboxes_for_overlap))
        }

print("✅ AdvancedMultiPersonPoseEstimator class created successfully!")

# 3D Pose Estimation and Analysis
class Pose3DAnalyzer:
    """Geometric analysis of a single 33-landmark pose.

    Every analysis method takes a container exposing ``.landmark``: a sequence
    of points with ``x``, ``y``, ``z`` and ``visibility`` attributes, indexed
    in the order of ``self.landmark_names`` (MediaPipe Pose order). Methods
    that need the full skeleton return ``{}`` when the container is falsy or
    holds fewer than 33 points. A landmark counts as *visible* when
    ``visibility > visibility_threshold`` and as *hidden* when
    ``visibility < visibility_threshold``.
    """

    # (angle name, proximal landmark, joint landmark, distal landmark)
    _JOINT_TRIPLETS = (
        ('left_elbow', 11, 13, 15),    # shoulder - elbow - wrist
        ('right_elbow', 12, 14, 16),
        ('left_knee', 23, 25, 27),     # hip - knee - ankle
        ('right_knee', 24, 26, 28),
    )

    def __init__(self, visibility_threshold=0.5):
        """
        Args:
            visibility_threshold: Cut-off on landmark ``visibility`` used by
                every method (default 0.5).
        """
        self.visibility_threshold = visibility_threshold
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer', 'left_ear',
            'right_ear', 'mouth_left', 'mouth_right', 'left_shoulder',
            'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist',
            'right_wrist', 'left_pinky', 'right_pinky', 'left_index',
            'right_index', 'left_thumb', 'right_thumb', 'left_hip',
            'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle',
            'left_heel', 'right_heel', 'left_foot_index', 'right_foot_index'
        ]

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _has_full_pose(landmarks):
        """True when ``landmarks`` is truthy and holds all 33 points."""
        return bool(landmarks) and len(landmarks.landmark) >= 33

    def _visible(self, lm, *indices):
        """True when every listed landmark is above the visibility threshold."""
        return all(lm[idx].visibility > self.visibility_threshold for idx in indices)

    @staticmethod
    def _vector(p1, p2):
        """3D vector pointing from ``p1`` to ``p2``."""
        return np.array([p2.x - p1.x, p2.y - p1.y, p2.z - p1.z])

    @staticmethod
    def _angle_between(v1, v2):
        """Angle between two 3D vectors in degrees, or None if either has zero length."""
        norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
        # Coincident landmarks give a zero-length vector and the angle is
        # undefined (dividing would yield NaN). `not > 0` also rejects NaN norms.
        if not norm_product > 0:
            return None
        cos_angle = np.clip(np.dot(v1, v2) / norm_product, -1.0, 1.0)
        return float(np.degrees(np.arccos(cos_angle)))

    # ----------------------------------------------------------------- analysis

    def extract_3d_coordinates(self, landmarks):
        """Return a list of ``{'x', 'y', 'z', 'visibility'}`` dicts, one per landmark.

        Returns None when ``landmarks`` is falsy.
        """
        if not landmarks:
            return None

        return [
            {'x': point.x, 'y': point.y, 'z': point.z, 'visibility': point.visibility}
            for point in landmarks.landmark
        ]

    def calculate_angles_3d(self, landmarks):
        """Compute joint angles in degrees for limbs whose landmarks are visible.

        The returned dict may contain:
            'left_elbow', 'right_elbow', 'left_knee', 'right_knee':
                interior joint angle, i.e. the angle at the joint between the
                segments towards the proximal and the distal landmark
                (180 = fully straight limb, smaller = more bent).
            'shoulder_hip_alignment':
                angle between the left->right shoulder line and the
                left->right hip line (0 = parallel).
        Angles whose vectors are degenerate (coincident landmarks) are omitted.
        """
        if not self._has_full_pose(landmarks):
            return {}

        lm = landmarks.landmark
        angles = {}

        for name, proximal, joint, distal in self._JOINT_TRIPLETS:
            if self._visible(lm, proximal, joint, distal):
                # Both vectors start at the joint, so the result is the
                # interior angle. Chaining proximal->joint->distal instead
                # yields its supplement (0 for a straight limb), which
                # contradicts the 120-160 degree "good knee bend" band used in
                # analyze_martial_arts_stance.
                angle = self._angle_between(self._vector(lm[joint], lm[proximal]),
                                            self._vector(lm[joint], lm[distal]))
                if angle is not None:
                    angles[name] = angle

        # Shoulder/hip line alignment (for shooting stance analysis)
        if self._visible(lm, 11, 12, 23, 24):
            angle = self._angle_between(self._vector(lm[11], lm[12]),
                                        self._vector(lm[23], lm[24]))
            if angle is not None:
                angles['shoulder_hip_alignment'] = angle

        return angles

    def analyze_shooting_stance(self, landmarks):
        """Heuristic shooting-stance metrics; each is reported only when its landmarks are visible.

        Keys produced:
            'stance_width' / 'stance_width_rating': horizontal hip distance,
                'good' inside (0.1, 0.3).
            'shoulder_hip_alignment' / 'alignment_rating': horizontal offset
                between the shoulder and hip midpoints, 'good' below 0.05.
            'arm_extension' / 'extension_rating': 3D right shoulder-to-wrist
                distance, 'good' above 0.3.
            'grip_stability' / 'grip_rating': L1 distance between the right
                wrist and right index finger, 'stable' below 0.1.
        NOTE(review): the thresholds look tuned for normalized [0, 1] image
        coordinates — confirm before using with other coordinate systems.
        """
        if not self._has_full_pose(landmarks):
            return {}

        analysis = {}
        lm = landmarks.landmark

        # Stance width
        if self._visible(lm, 23, 24):
            hip_width = abs(lm[24].x - lm[23].x)
            analysis['stance_width'] = hip_width
            analysis['stance_width_rating'] = 'good' if 0.1 < hip_width < 0.3 else 'needs_adjustment'

        # Shoulder alignment over the hips
        if self._visible(lm, 11, 12, 23, 24):
            shoulder_center = (lm[11].x + lm[12].x) / 2
            hip_center = (lm[23].x + lm[24].x) / 2
            alignment_offset = abs(shoulder_center - hip_center)
            analysis['shoulder_hip_alignment'] = alignment_offset
            analysis['alignment_rating'] = 'good' if alignment_offset < 0.05 else 'needs_adjustment'

        # Right-arm extension (shoulder to wrist distance)
        if self._visible(lm, 12, 14, 16):
            arm_length = float(np.linalg.norm(self._vector(lm[12], lm[16])))
            analysis['arm_extension'] = arm_length
            analysis['extension_rating'] = 'good' if arm_length > 0.3 else 'needs_extension'

        # Grip (right wrist to right index finger)
        if self._visible(lm, 16, 20):
            grip_stability = abs(lm[20].x - lm[16].x) + abs(lm[20].y - lm[16].y)
            analysis['grip_stability'] = grip_stability
            analysis['grip_rating'] = 'stable' if grip_stability < 0.1 else 'unstable'

        return analysis

    def analyze_martial_arts_stance(self, landmarks):
        """Heuristic martial-arts stance metrics; each is reported only when its landmarks are visible.

        Keys produced:
            'center_of_gravity': mean (x, y) of both hips and knees.
            'cog_balance' / 'balance_rating': horizontal offset of that point
                from the ankle midpoint, 'balanced' below 0.1.
            'stance_depth' / 'depth_rating': vertical knee offset,
                'good' inside (0.05, 0.2).
            'knee_bend' / 'knee_rating': mean interior knee angle from
                ``calculate_angles_3d``, 'good' inside (120, 160) degrees.
        """
        if not self._has_full_pose(landmarks):
            return {}

        analysis = {}
        lm = landmarks.landmark

        # Centre of gravity approximated from hip and knee positions
        if self._visible(lm, 23, 24, 25, 26):
            cog_x = (lm[23].x + lm[24].x + lm[25].x + lm[26].x) / 4
            cog_y = (lm[23].y + lm[24].y + lm[25].y + lm[26].y) / 4
            analysis['center_of_gravity'] = {'x': cog_x, 'y': cog_y}

            # Is the centre of gravity horizontally centred between the feet?
            if self._visible(lm, 27, 28):
                foot_center_x = (lm[27].x + lm[28].x) / 2
                cog_offset = abs(cog_x - foot_center_x)
                analysis['cog_balance'] = cog_offset
                analysis['balance_rating'] = 'balanced' if cog_offset < 0.1 else 'unbalanced'

        # NOTE(review): described as front-to-back depth, but this measures the
        # vertical (y) offset between the knees — confirm intent.
        if self._visible(lm, 25, 26):
            stance_depth = abs(lm[26].y - lm[25].y)
            analysis['stance_depth'] = stance_depth
            analysis['depth_rating'] = 'good' if 0.05 < stance_depth < 0.2 else 'needs_adjustment'

        angles = self.calculate_angles_3d(landmarks)
        if 'left_knee' in angles and 'right_knee' in angles:
            knee_bend_avg = (angles['left_knee'] + angles['right_knee']) / 2
            analysis['knee_bend'] = knee_bend_avg
            # Interior angle: 180 = straight leg, so (120, 160) is a moderate bend
            analysis['knee_rating'] = 'good' if 120 < knee_bend_avg < 160 else 'needs_adjustment'

        return analysis

    def detect_hidden_body_parts(self, landmarks, previous_landmarks=None):
        """Report poorly visible body keypoints with a best-guess position.

        For each of 12 limb keypoints (shoulders through ankles) whose
        visibility is below the threshold, the entry records its visibility,
        an estimated position and a confidence label. The estimate is the
        current position ('low') unless ``previous_landmarks`` had that point
        visible, in which case the previous position is used ('medium').
        """
        if not self._has_full_pose(landmarks):
            return {}

        hidden_parts = {}
        lm = landmarks.landmark

        key_points = {
            'left_shoulder': 11, 'right_shoulder': 12,
            'left_elbow': 13, 'right_elbow': 14,
            'left_wrist': 15, 'right_wrist': 16,
            'left_hip': 23, 'right_hip': 24,
            'left_knee': 25, 'right_knee': 26,
            'left_ankle': 27, 'right_ankle': 28
        }

        for part_name, idx in key_points.items():
            point = lm[idx]
            if point.visibility < self.visibility_threshold:
                estimate = {'x': point.x, 'y': point.y, 'z': point.z}
                confidence = 'low'

                # Prefer the last frame's position when it was confidently seen there
                if previous_landmarks and len(previous_landmarks.landmark) > idx:
                    prev_point = previous_landmarks.landmark[idx]
                    if prev_point.visibility > self.visibility_threshold:
                        estimate = {'x': prev_point.x, 'y': prev_point.y, 'z': prev_point.z}
                        confidence = 'medium'

                hidden_parts[part_name] = {
                    'visibility': point.visibility,
                    'estimated_position': estimate,
                    'confidence': confidence
                }

        return hidden_parts

print("✅ Pose3DAnalyzer class created successfully!")

# Visualization and Analysis Tools
class PoseVisualizer:
    """2D (OpenCV) and 3D (Plotly) rendering of tracked multi-person poses.

    Pose dicts carry 'id', 'landmarks' (an object with a ``.landmark``
    sequence) and optionally 'bbox' as ``[x1, y1, x2, y2]`` in pixels.
    Individual landmarks may be MediaPipe-style objects or plain dicts with
    'x', 'y', optional 'z' and 'visibility' keys. A dict without 'visibility'
    counts as fully visible. Each person's colour is
    ``self.colors[person_id % len(self.colors)]``, so it stays stable across
    frames.
    """

    def __init__(self):
        # Names describe the tuples read as RGB, which is how the Plotly plots
        # use them. NOTE(review): cv2 drawing writes the tuples into the
        # image's channels as-is, so on a BGR frame "Red" renders blue.
        self.colors = [
            (255, 0, 0),    # Red
            (0, 255, 0),    # Green
            (0, 0, 255),    # Blue
            (255, 255, 0),  # Yellow
            (255, 0, 255),  # Magenta
            (0, 255, 255),  # Cyan
        ]

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _landmark_values(landmark):
        """Return ``(x, y, z, visibility)`` for a landmark object or a plain dict."""
        if hasattr(landmark, 'visibility'):
            return landmark.x, landmark.y, getattr(landmark, 'z', 0.0), landmark.visibility
        return (landmark.get('x', 0), landmark.get('y', 0),
                landmark.get('z', 0), landmark.get('visibility', 1.0))

    @staticmethod
    def _plotly_color(color):
        """Format an ``(r, g, b)`` tuple as a Plotly ``'rgb(r, g, b)'`` colour string.

        Plotly treats a bare tuple as an array of numbers (colour-scale
        values), not as a single colour.
        """
        return 'rgb({}, {}, {})'.format(*color)

    def _person_color(self, person_id):
        """Stable palette colour for a person ID."""
        return self.colors[person_id % len(self.colors)]

    # --------------------------------------------------------------- rendering

    def draw_pose_2d(self, image, poses, draw_connections=True, draw_landmarks=True,
                     connections=None):
        """Return a copy of ``image`` with each person's box, label, keypoints and skeleton drawn.

        Args:
            image: Frame to annotate; it is not modified. Landmark coordinates
                are multiplied by its width/height, so they are expected to be
                normalized to [0, 1].
            poses: Pose dicts (see class docstring).
            draw_connections: Draw skeleton edges between visible landmarks.
            draw_landmarks: Draw a dot on every visible landmark.
            connections: Iterable of ``(start_idx, end_idx)`` landmark pairs;
                defaults to ``MARTIAL_ARTS_CONNECTIONS``.
        """
        annotated_image = image.copy()

        for pose_data in poses:
            person_id = pose_data['id']
            landmarks = pose_data['landmarks']
            bbox = pose_data.get('bbox')
            color = self._person_color(person_id)

            # len() instead of truthiness: same result for lists/None, and
            # also safe for numpy boxes.
            if bbox is not None and len(bbox) > 0:
                x1, y1, x2, y2 = bbox
                cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                cv2.putText(annotated_image, f"Person {person_id}", (int(x1), int(y1-10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

            if not landmarks:
                continue

            h, w = image.shape[:2]
            points = [self._landmark_values(lm) for lm in landmarks.landmark]

            if draw_landmarks:
                for x, y, _, vis in points:
                    if vis > 0.5:
                        cv2.circle(annotated_image, (int(x * w), int(y * h)), 3, color, -1)

            if draw_connections:
                edges = MARTIAL_ARTS_CONNECTIONS if connections is None else connections
                for start_idx, end_idx in edges:
                    if start_idx < len(points) and end_idx < len(points):
                        sx, sy, _, s_vis = points[start_idx]
                        ex, ey, _, e_vis = points[end_idx]
                        if s_vis > 0.5 and e_vis > 0.5:
                            cv2.line(annotated_image,
                                     (int(sx * w), int(sy * h)),
                                     (int(ex * w), int(ey * h)),
                                     color, 2)

        return annotated_image

    def create_3d_plot(self, poses, title="3D Pose Visualization", connections=None):
        """Build a Plotly 3D scatter of each person's visible landmarks plus skeleton lines.

        Hover labels show each point's real landmark index. ``connections``
        defaults to ``MARTIAL_ARTS_CONNECTIONS``.
        """
        fig = go.Figure()

        for pose_data in poses:
            person_id = pose_data['id']
            landmarks = pose_data['landmarks']
            if not landmarks:
                continue

            color = self._plotly_color(self._person_color(person_id))
            points = [self._landmark_values(lm) for lm in landmarks.landmark]
            # Keep the original indices so hover text names the actual keypoint
            visible = [(idx, p) for idx, p in enumerate(points) if p[3] > 0.5]

            fig.add_trace(go.Scatter3d(
                x=[p[0] for _, p in visible],
                y=[p[1] for _, p in visible],
                z=[p[2] for _, p in visible],
                mode='markers',
                marker=dict(
                    size=5,
                    color=color,
                    opacity=0.8
                ),
                name=f'Person {person_id}',
                text=[f'Landmark {idx}' for idx, _ in visible],
                hovertemplate='<b>%{text}</b><br>X: %{x:.3f}<br>Y: %{y:.3f}<br>Z: %{z:.3f}<extra></extra>'
            ))

            edges = MARTIAL_ARTS_CONNECTIONS if connections is None else connections
            for start_idx, end_idx in edges:
                if start_idx < len(points) and end_idx < len(points):
                    start, end = points[start_idx], points[end_idx]
                    if start[3] > 0.5 and end[3] > 0.5:
                        fig.add_trace(go.Scatter3d(
                            x=[start[0], end[0]],
                            y=[start[1], end[1]],
                            z=[start[2], end[2]],
                            mode='lines',
                            line=dict(
                                color=color,
                                width=3
                            ),
                            showlegend=False,
                            hoverinfo='skip'
                        ))

        fig.update_layout(
            title=title,
            scene=dict(
                xaxis_title='X',
                yaxis_title='Y',
                zaxis_title='Z',
                aspectmode='data'
            ),
            width=800,
            height=600
        )

        return fig

    def create_analysis_dashboard(self, poses, analysis_data):
        """2x2 Plotly dashboard: placeholder scatter, 3D landmarks, joint angles, stance metrics.

        Args:
            poses: Pose dicts (see class docstring). Returns None when empty.
            analysis_data: Dict that may hold 'angles' (name -> degrees) and
                'stance_analysis' (metric -> value). Non-numeric stance values
                are plotted as 0.
        """
        if not poses:
            return None

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Pose Analysis', '3D Visualization', 'Angles', 'Stance Quality'),
            specs=[[{"type": "scatter"}, {"type": "scatter3d"}],
                   [{"type": "bar"}, {"type": "bar"}]]
        )

        # Pose analysis panel (placeholder for now)
        fig.add_trace(
            go.Scatter(x=[0, 1], y=[0, 1], mode='markers', name='Pose Points'),
            row=1, col=1
        )

        for pose_data in poses:
            landmarks = pose_data['landmarks']
            if not landmarks:
                continue
            visible = [p for p in (self._landmark_values(lm) for lm in landmarks.landmark)
                       if p[3] > 0.5]

            fig.add_trace(
                go.Scatter3d(
                    x=[p[0] for p in visible],
                    y=[p[1] for p in visible],
                    z=[p[2] for p in visible],
                    mode='markers',
                    name=f'Person {pose_data["id"]}',
                    marker=dict(size=3)
                ),
                row=1, col=2
            )

        if 'angles' in analysis_data:
            angles = analysis_data['angles']
            fig.add_trace(
                go.Bar(x=list(angles.keys()), y=list(angles.values()), name='Joint Angles'),
                row=2, col=1
            )

        if 'stance_analysis' in analysis_data:
            stance = analysis_data['stance_analysis']
            metrics = list(stance.keys())
            values = [stance[m] if isinstance(stance[m], (int, float)) else 0 for m in metrics]

            fig.add_trace(
                go.Bar(x=metrics, y=values, name='Stance Metrics'),
                row=2, col=2
            )

        fig.update_layout(
            title="Pose Analysis Dashboard",
            height=800,
            showlegend=True
        )

        return fig

print("✅ PoseVisualizer class created successfully!")

# Video Processing and Testing
class VideoProcessor:
    def __init__(self, pose_estimator, pose_analyzer, visualizer):
        self.pose_estimator = pose_estimator
        self.pose_analyzer = pose_analyzer
        self.visualizer = visualizer
        self.frame_count = 0
        self.results_history = []

    def process_video_file(self, video_path, output_path=None, max_frames=None):
        """Process a video file and return results"""
        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            print(f"Error: Could not open video file {video_path}")
            return None

        fps = int(cap.get(cv2.CAP_PROP_FPS))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        print(f"Video Info: {width}x{height}, {fps} FPS, {total_frames} frames")

        # Setup video writer if output path is provided
        out = None
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        frame_results = []
        frame_count = 0

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if max_frames and frame_count >= max_frames:
                break

            # Process frame
            start_time = time.time()
            result = self.pose_estimator.process_frame(frame)
            processing_time = time.time() - start_time

            # Analyze poses
            analysis_results = []
            for pose_data in result['poses']:
                landmarks = pose_data['landmarks']
                if landmarks:
                    # 3D analysis
                    coords_3d = self.pose_analyzer.extract_3d_coordinates(landmarks)
                    angles = self.pose_analyzer.calculate_angles_3d(landmarks)
                    shooting_analysis = self.pose_analyzer.analyze_shooting_stance(landmarks)
                    martial_arts_analysis = self.pose_analyzer.analyze_martial_arts_stance(landmarks)
                    hidden_parts = self.pose_analyzer.detect_hidden_body_parts(landmarks)

                    analysis_results.append({
                        'person_id': pose_data['id'],
                        'coords_3d': coords_3d,
                        'angles': angles,
                        'shooting_analysis': shooting_analysis,
                        'martial_arts_analysis': martial_arts_analysis,
                        'hidden_parts': hidden_parts
                    })

            # Draw visualization
            annotated_frame = self.visualizer.draw_pose_2d(frame, result['poses'])

            # Add processing info
            cv2.putText(annotated_frame, f"Frame: {frame_count}", (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(annotated_frame, f"Persons: {len(result['poses'])}", (10, 70),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(annotated_frame, f"Time: {processing_time:.3f}s", (10, 110),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Write frame if output is specified
            if out:
                out.write(annotated_frame)

            # Store results
            frame_results.append({
                'frame_number': frame_count,
                'poses': result['poses'],
                'analysis': analysis_results,
                'processing_time': processing_time,
                'timestamp': frame_count / fps
            })

            frame_count += 1

            # Print progress
            if frame_count % 30 == 0:
                print(f"Processed {frame_count}/{total_frames} frames")

        cap.release()
        if out:
            out.release()

        print(f"Video processing complete. Processed {frame_count} frames.")
        return frame_results

    def process_webcam(self, duration=10):
        """Process webcam feed for real-time testing"""
        cap = cv2.VideoCapture(0)

        if not cap.isOpened():
            print("Error: Could not open webcam")
            return

        start_time = time.time()
        frame_count = 0

        print("Starting webcam processing. Press 'q' to quit.")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Check duration
            if time.time() - start_time > duration:
                break

            # Process frame
            result = self.pose_estimator.process_frame(frame)

            # Draw visualization
            annotated_frame = self.visualizer.draw_pose_2d(frame, result['poses'])

            # Add info
            cv2.putText(annotated_frame, f"Frame: {frame_count}", (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(annotated_frame, f"Persons: {len(result['poses'])}", (10, 70),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

            # Show frame
            cv2.imshow('Pose Estimation', annotated_frame)

            # Break on 'q' key
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

            frame_count += 1

        cap.release()
        cv2.destroyAllWindows()
        print(f"Webcam processing complete. Processed {frame_count} frames.")

    def create_sample_video(self, output_path="sample_poses.mp4", duration=5):
        """Create a sample video with pose estimation for testing"""
        # Create a simple test video with moving rectangles (simulating people)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 30, (640, 480))

        for frame_num in range(duration * 30):
            # Create a black frame
            frame = np.zeros((480, 640, 3), dtype=np.uint8)

            # Add some moving rectangles to simulate people
            x1 = int(100 + 50 * np.sin(frame_num * 0.1))
            y1 = int(200 + 30 * np.cos(frame_num * 0.1))
            x2 = int(400 + 40 * np.sin(frame_num * 0.15))
            y2 = int(150 + 25 * np.cos(frame_num * 0.12))

            cv2.rectangle(frame, (x1, y1), (x1+100, y1+200), (255, 255, 255), -1)
            cv2.rectangle(frame, (x2, y2), (x2+100, y2+200), (255, 255, 255), -1)

            out.write(frame)

        out.release()
        print(f"Sample video created: {output_path}")

print("✅ VideoProcessor class created successfully!")

In [None]:
# Initialize the complete system: build every pipeline component, report the
# effective configuration, and run one smoke-test frame through the estimator.
print("🚀 Initializing Multi-Person 3D Pose Estimation System...")


def _print_item(label, value):
    """Print one indented ``   - label: value`` line of the status report."""
    print(f"   - {label}: {value}")


# Single configuration object shared by all components below.
config = PoseConfig()

# Pipeline wiring: estimator -> 3D analyzer -> visualizer -> video driver.
pose_estimator = AdvancedMultiPersonPoseEstimator(config)
pose_analyzer = Pose3DAnalyzer()
visualizer = PoseVisualizer()
video_processor = VideoProcessor(pose_estimator, pose_analyzer, visualizer)

print("✅ System initialized successfully!")
print("📊 Configuration:")
_print_item("Max persons", config.MAX_PERSONS)
_print_item("Model complexity", config.MEDIAPIPE_MODEL_COMPLEXITY)
_print_item("3D estimation", config.ENABLE_3D_ESTIMATION)
_print_item("YOLO enabled", pose_estimator.use_yolo)
_print_item("Smoothing window", config.SMOOTHING_WINDOW_SIZE)

# Smoke test: a black 640x480 frame with two solid white boxes. These are not
# real people, so a detection count of 0 is an acceptable outcome here.
print("\n🧪 Testing with sample image...")

test_image = np.zeros((480, 640, 3), dtype=np.uint8)
for top_left, bottom_right in (((100, 100), (300, 400)), ((400, 150), (600, 450))):
    cv2.rectangle(test_image, top_left, bottom_right, (255, 255, 255), -1)

result = pose_estimator.process_frame(test_image)
print(f"✅ Test completed. Detected {len(result['poses'])} persons.")

# Final status report.
print("\n📈 System Status:")
for component in ("Pose estimator", "3D analyzer", "Visualizer", "Video processor"):
    _print_item(component, "✅ Ready")
_print_item("Person tracks", len(pose_estimator.person_tracks))
