In [1]:
import json
import os
import re
import shutil
# import random  <-- Removed random as we want deterministic sequential order

from tqdm import tqdm

def _natural_sort_key(file_name):
    """Return a sort key that orders runs of digits numerically.

    Plain string sorting puts ``frame_10.jpg`` before ``frame_2.jpg``; this key
    puts it after. Zero-padded names of equal width (``B_001``, ``B_002``) sort
    exactly as they would with plain ``str`` ordering. The raw name is appended
    as a tie-breaker so names like ``B_1`` and ``B_001`` still order stably.
    """
    # With a capturing group, re.split keeps the digit runs and they always land
    # on the odd indices, so two keys never compare a str against an int.
    parts = re.split(r'(\d+)', file_name)
    return [int(p) if i % 2 else p for i, p in enumerate(parts)], file_name


def _gather_action_samples(json_path, img_folder, action_name, action_label):
    """Load one action's COCO-style JSON and return ``(samples, num_skipped)``.

    A sample is produced for every image id that has at least one annotation,
    has an entry in ``images``, and whose file exists under ``img_folder``.
    ``num_skipped`` counts annotated image ids that failed either of the last
    two checks. Images without annotations are not considered at all.
    """
    # JSON is UTF-8 by spec; don't depend on the platform default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    image_id_to_anns = {}
    for ann in data.get('annotations', []):
        image_id_to_anns.setdefault(ann['image_id'], []).append(ann)

    image_id_to_info = {img['id']: img for img in data.get('images', [])}

    samples = []
    num_skipped = 0
    for img_id, annotations in image_id_to_anns.items():
        img_info = image_id_to_info.get(img_id)
        if img_info is None:
            num_skipped += 1
            continue

        file_name = img_info['file_name']
        img_path = os.path.join(img_folder, file_name)
        if not os.path.exists(img_path):
            num_skipped += 1
            continue

        samples.append({
            'img_path': img_path,
            'file_name': file_name,
            'annotations': annotations,
            'action_name': action_name,
            'action_label': action_label,
            'width': img_info['width'],
            'height': img_info['height'],
        })
    return samples, num_skipped


def _yolo_label_line(ann, class_id, img_w, img_h, include_keypoints):
    """Format one COCO annotation as a YOLO label line.

    The bbox is converted from COCO ``[x_min, y_min, w, h]`` pixels to YOLO
    ``class cx cy w h``, normalized by the image size. When
    ``include_keypoints`` is true, each ``(x, y, v)`` triplet of the
    annotation's flat ``keypoints`` list is appended as normalized x, y and
    the visibility written as an integer.
    """
    x_min, y_min, box_w, box_h = ann['bbox'][:4]
    fields = [
        str(class_id),
        f"{(x_min + box_w / 2) / img_w:.6f}",
        f"{(y_min + box_h / 2) / img_h:.6f}",
        f"{box_w / img_w:.6f}",
        f"{box_h / img_h:.6f}",
    ]

    if include_keypoints:
        keypoints = ann['keypoints']
        for j in range(0, len(keypoints), 3):
            kpt_x, kpt_y, kpt_vis = keypoints[j:j + 3]
            fields.extend([f"{kpt_x / img_w:.6f}", f"{kpt_y / img_h:.6f}", f"{kpt_vis:.0f}"])

    return " ".join(fields)


def convert_to_yolo_pose(annotations_root, images_root, output_root, action_names,
                         train_split=0.8, include_keypoints=False):
    """Convert per-action COCO-style pose annotations into a YOLO dataset.

    For each name in ``action_names`` this reads ``<annotations_root>/<name>.json``
    and the images in ``<images_root>/<name>/``. Every annotated image becomes one
    sample. Its label file holds one line per annotation, all tagged with the
    action's class id (its index in ``action_names``).

    Splitting is sequential and stratified by class. Each action's samples are
    sorted by file name (numerically aware, see ``_natural_sort_key``). The first
    ``int(n * train_split)`` go to ``train`` and the rest to ``val``. Nothing is
    shuffled, so (assuming file names encode frame order) the later frames of
    each action end up in ``val``.

    Args:
        annotations_root: Directory with one ``<action>.json`` per action.
            Missing files are reported and skipped.
        images_root: Directory with one image sub-folder per action.
        output_root: Destination directory. ``images/{train,val}`` and
            ``labels/{train,val}`` are created under it. Existing files are
            overwritten but never deleted.
        action_names: Ordered action names; the index becomes the class id.
        train_split: Fraction of each action's samples assigned to ``train``.
        include_keypoints: If true, append normalized keypoints after the bbox
            (YOLO-pose labels). Defaults to False, which writes bbox-only
            detection labels as before.

    Raises:
        ValueError: If ``train_split`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be in [0, 1], got {train_split}")

    # --- 1. Create Output Directories ---
    yolo_images_train = os.path.join(output_root, 'images', 'train')
    yolo_images_val = os.path.join(output_root, 'images', 'val')
    yolo_labels_train = os.path.join(output_root, 'labels', 'train')
    yolo_labels_val = os.path.join(output_root, 'labels', 'val')
    for directory in (yolo_images_train, yolo_images_val, yolo_labels_train, yolo_labels_val):
        os.makedirs(directory, exist_ok=True)

    # --- 2. Create Class Mapping ---
    action_to_label = {name: i for i, name in enumerate(action_names)}
    print(f"Class mapping: {action_to_label}")

    # --- 3. Gather and split samples per action ---
    print("Gathering and splitting samples...")
    all_samples = []
    for action_name in action_names:
        json_path = os.path.join(annotations_root, f"{action_name}.json")
        img_folder = os.path.join(images_root, action_name)

        if not os.path.exists(json_path):
            print(f"Warning: JSON file not found at {json_path}. Skipping.")
            continue

        samples, num_skipped = _gather_action_samples(
            json_path, img_folder, action_name, action_to_label[action_name])
        if num_skipped:
            print(f"Warning: {action_name}: skipped {num_skipped} annotated image(s) "
                  f"with no 'images' entry or no file on disk.")

        # Sequential split: sort into file-name order, then cut per class so
        # every action contributes the same train/val proportion.
        samples.sort(key=lambda s: _natural_sort_key(s['file_name']))
        num_train = int(len(samples) * train_split)
        for i, sample in enumerate(samples):
            sample['split'] = 'train' if i < num_train else 'val'
        all_samples.extend(samples)

    print(f"Found {len(all_samples)} total samples.")

    # --- 4. Copy images and write label files ---
    print("Processing files...")
    # YOLO pairs an image with its label by stem, so uniqueness must be per stem
    # (x.jpg and x.png would share x.txt) and per split directory.
    used_stems = {'train': set(), 'val': set()}
    num_renamed = 0

    for sample in tqdm(all_samples):
        split = sample['split']

        # basename keeps the output flat even if file_name has a sub-directory.
        out_name = os.path.basename(sample['file_name'])
        if os.path.splitext(out_name)[0] in used_stems[split]:
            # Another action already produced this name in this split; prefix
            # so neither image/label pair silently overwrites the other.
            out_name = f"{sample['action_name']}_{out_name}"
            num_renamed += 1
        stem = os.path.splitext(out_name)[0]
        used_stems[split].add(stem)

        yolo_img_path = os.path.join(output_root, 'images', split, out_name)
        yolo_label_path = os.path.join(output_root, 'labels', split, stem + '.txt')

        shutil.copy2(sample['img_path'], yolo_img_path)

        yolo_lines = [
            _yolo_label_line(ann, sample['action_label'], sample['width'],
                             sample['height'], include_keypoints)
            for ann in sample['annotations']
        ]
        with open(yolo_label_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(yolo_lines))

    print("\nDone!")
    kind = "YOLOv8-pose" if include_keypoints else "YOLOv8 detection (bbox-only)"
    print(f"{kind} dataset created at: {output_root}")
    if num_renamed:
        print(f"Warning: {num_renamed} file(s) collided with an earlier sample in the "
              f"same split and were prefixed with their action name.")

    # Report what this run wrote; a directory listing would also count files
    # left behind by earlier runs (e.g. with a different train_split).
    train_count = sum(1 for s in all_samples if s['split'] == 'train')
    val_count = len(all_samples) - train_count
    print(f"Train images: {train_count}")
    print(f"Val images: {val_count}")

    on_disk = len(os.listdir(yolo_images_train)) + len(os.listdir(yolo_images_val))
    if on_disk != len(all_samples):
        print(f"Warning: image folders hold {on_disk} files but {len(all_samples)} were "
              f"written this run; stale files from a previous run may remain.")

# --- Constants & Execution ---
# Each action name is both the annotation JSON stem (<ANNOTATIONS_DIR>/<name>.json)
# and the image sub-folder (<IMAGES_DIR>/<name>/); its position in ACTIONS
# becomes the YOLO class id.
ACTIONS = ['backhand', 'forehand']

ANNOTATIONS_DIR = './dataset/annotations'
IMAGES_DIR = './dataset/images'
OUTPUT_DIR = './yolo_tennis_dataset_seq_strat_corrected_v3_only_bbox_2_classes'

# 80% of each class (in file-name order) to train, the remainder to val.
convert_to_yolo_pose(ANNOTATIONS_DIR, IMAGES_DIR, OUTPUT_DIR, ACTIONS, train_split=0.8)

Class mapping: {'backhand': 0, 'forehand': 1}
Gathering and splitting samples...
Found 1000 total samples.
Processing files...


100%|██████████| 1000/1000 [04:05<00:00,  4.07it/s]



Done!
YOLOv8-pose dataset created at: ./yolo_tennis_dataset_seq_strat_corrected_v3_only_bbox_2_classes
Train images: 800
Val images: 200
