In [None]:
# -----------------------------
# Colab setup
# -----------------------------
!pip install -q pycocotools opencv-python wget matplotlib

import os, wget, zipfile
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
from google.colab import drive

In [None]:
# -----------------------------
# 1. Define paths and URLs
# -----------------------------
# Mount Google Drive; the dataset directory below lives inside it.
drive.mount('/content/drive')

# Root folder that holds the downloaded archives and the extracted data.
data_dir = "/content/drive/MyDrive/ColabDataset/coco_full"
os.makedirs(data_dir, exist_ok=True)

# Archive name -> download URL.
urls = dict(
    train2017="http://images.cocodataset.org/zips/train2017.zip",
    val2017="http://images.cocodataset.org/zips/val2017.zip",
    annotations="http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
)

In [None]:
# -----------------------------
# 2. Download and extract function
# -----------------------------
def download_and_extract(name, url):
    """Download the archive for `name` into `data_dir` and unpack it there.

    Nothing is done when `<data_dir>/<name>` already exists. The download step
    is skipped when `<data_dir>/<name>.zip` is already present, so a previously
    downloaded archive is only extracted.

    Args:
        name: Key used to build the zip file name and the path that is checked
            for an existing extraction.
        url: Address the archive is downloaded from.

    Raises:
        RuntimeError: If the download fails or the zip file is corrupt. The
            underlying exception is attached as ``__cause__``.
    """
    zip_path = os.path.join(data_dir, f"{name}.zip")
    extract_path = os.path.join(data_dir, name)

    # NOTE(review): the folder's existence is the only completeness check, so an
    # extraction that was interrupted half-way is treated as finished.
    if os.path.exists(extract_path):
        print(f"{name} already exists, skipping download/extraction.")
        return

    # Download if zip doesn't exist
    if not os.path.exists(zip_path):
        print(f"Downloading {name}...")
        try:
            wget.download(url, zip_path)
        except Exception as e:
            # Chain explicitly so the original failure stays visible as the cause.
            raise RuntimeError(f"Failed to download {name}: {e}") from e

    # Extract into data_dir; the archive is expected to create its own sub-folder.
    print(f"\nExtracting {name}...")
    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
    except zipfile.BadZipFile as e:
        raise RuntimeError(f"Bad zip file for {name}. Delete {zip_path} and retry.") from e
    print(f"{name} ready.")

# Fetch and unpack every archive listed in `urls`, one after another
for archive_name in urls:
    download_and_extract(archive_name, urls[archive_name])

In [None]:
# -----------------------------
# 3. Paths
# -----------------------------
# Image folders and the annotation folder all sit directly under data_dir
train_images_dir, val_images_dir, annotations_dir = (
    os.path.join(data_dir, folder) for folder in ("train2017", "val2017", "annotations")
)
# Person-keypoint annotation file for each split
train_ann_file, val_ann_file = (
    os.path.join(annotations_dir, json_name)
    for json_name in ("person_keypoints_train2017.json", "person_keypoints_val2017.json")
)

In [None]:
# -----------------------------
# 4. COCO Dataset class (human-only + existing images)
# -----------------------------
class COCODataset(Dataset):
    """COCO keypoint dataset yielding a resized image and one set of 2-D keypoints.

    Only images that have an annotation of category id 1 and whose file exists
    in `image_dir` are kept.

    Each item is a tuple ``(img_tensor, keypoints)``:

    * ``img_tensor``: float tensor of shape ``(3, IMG_SIZE, IMG_SIZE)``, RGB
      channel order, scaled by 1/255.
    * ``keypoints``: float32 array of shape ``(K, 2)`` holding ``(x, y)`` in the
      resized image's pixel coordinates. They come from the first annotation
      whose keypoint values sum to more than zero; when there is none, an
      all-zero ``(NUM_KEYPOINTS, 2)`` array is returned.

    When `transform` is given it is called as ``transform(img_tensor, keypoints)``
    and its return values replace the item.
    """

    IMG_SIZE = 256       # side length (pixels) every image is resized to
    NUM_KEYPOINTS = 17   # rows of the all-zero fallback keypoint array

    def __init__(self, annotation_file, image_dir, transform=None):
        """Load the annotation index and keep only images present on disk.

        Args:
            annotation_file: Path to the COCO annotation JSON.
            image_dir: Directory containing the image files.
            transform: Optional callable ``(img_tensor, keypoints) -> (img, keypoints)``.

        Raises:
            FileNotFoundError: If `annotation_file` or `image_dir` does not exist.
        """
        if not os.path.exists(annotation_file):
            raise FileNotFoundError(f"Annotation file {annotation_file} not found.")
        if not os.path.exists(image_dir):
            raise FileNotFoundError(f"Image directory {image_dir} not found.")

        self.coco = COCO(annotation_file)
        self.image_dir = image_dir
        self.transform = transform

        # Filter: only human images (category id 1) whose file actually exists
        self.img_ids = [
            img_id
            for img_id in self.coco.getImgIds(catIds=[1])
            if os.path.exists(
                os.path.join(self.image_dir, self.coco.loadImgs(img_id)[0]['file_name'])
            )
        ]

        print(f"Loaded {len(self.img_ids)} human images from {image_dir}")

    def __len__(self):
        """Return the number of usable images."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``(img_tensor, keypoints)`` for the image at position `idx`.

        Raises:
            RuntimeError: If the image file cannot be decoded.
        """
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.image_dir, img_info['file_name'])

        img = cv2.imread(img_path)
        if img is None:
            # imread signals failure by returning None instead of raising.
            raise RuntimeError(f"Failed to read image {img_path}")
        img = img[:, :, ::-1]  # BGR -> RGB

        orig_h, orig_w = img.shape[:2]
        size = self.IMG_SIZE
        img_resized = cv2.resize(img, (size, size))

        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=[1])
        anns = self.coco.loadAnns(ann_ids)

        # Use the first annotation that carries any non-zero keypoint value.
        keypoints = None
        for ann in anns:
            raw = ann.get('keypoints')
            if raw is not None and np.sum(raw) > 0:
                # Build a float array: with an integer array the scaling below
                # would be truncated when written back in place.
                keypoints = np.array(raw, dtype=np.float32).reshape(-1, 3)[:, :2].copy()
                break
        if keypoints is None:
            keypoints = np.zeros((self.NUM_KEYPOINTS, 2), dtype=np.float32)

        # Scale keypoints from the original resolution to the resized image
        keypoints[:, 0] *= size / orig_w
        keypoints[:, 1] *= size / orig_h

        img_tensor = torch.tensor(img_resized).permute(2, 0, 1).float() / 255.0

        # Apply transform if provided
        if self.transform is not None:
            img_tensor, keypoints = self.transform(img_tensor, keypoints)

        return img_tensor, keypoints

In [None]:
# -----------------------------
# 5. DataLoader helper (human-only)
# -----------------------------
def get_dataloader(annotation_file, image_dir, batch_size=4, transform=None, shuffle=True):
    """Build a DataLoader over a COCODataset.

    Args:
        annotation_file: Path to the COCO annotation JSON.
        image_dir: Directory containing the image files.
        batch_size: Number of samples per batch.
        transform: Optional callable forwarded to COCODataset.
        shuffle: Whether the loader reshuffles the samples. Defaults to True,
            which is the previous hard-coded behaviour; pass False for a
            deterministic order (e.g. for evaluation).

    Returns:
        A torch DataLoader wrapping the dataset.
    """
    dataset = COCODataset(annotation_file, image_dir, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Example usage: one loader per split, both with batches of four
train_loader = get_dataloader(
    annotation_file=train_ann_file,
    image_dir=train_images_dir,
    batch_size=4,
)
val_loader = get_dataloader(
    annotation_file=val_ann_file,
    image_dir=val_images_dir,
    batch_size=4,
)


In [None]:
# -----------------------------
# 4.5. Custom Transform Class for Keypoints
# -----------------------------
import random
import torchvision.transforms.functional as TF

class KeypointTransform:
    """Random flip, rotation and colour jitter that keeps keypoints in sync.

    Calling the instance with ``(image, keypoints)`` returns the augmented image
    as a tensor together with a new keypoint array; the input array is not
    modified.

    Keypoints whose x and y are both 0 are treated as "not labelled" (the
    convention used by COCO and by an all-zero fallback array). They stay at
    (0, 0) instead of being moved by the flip or the rotation.
    """

    # COCO keypoint order: 0=nose, 1=left_eye, 2=right_eye, 3=left_ear, 4=right_ear,
    # 5=left_shoulder, 6=right_shoulder, 7=left_elbow, 8=right_elbow,
    # 9=left_wrist, 10=right_wrist, 11=left_hip, 12=right_hip,
    # 13=left_knee, 14=right_knee, 15=left_ankle, 16=right_ankle
    SWAP_PAIRS = ((1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16))

    def __init__(self, flip_prob=0.5, rotation_degrees=10, brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1):
        """Store the augmentation strengths.

        Args:
            flip_prob: Probability of a horizontal flip.
            rotation_degrees: Rotation angle is drawn uniformly from
                ``[-rotation_degrees, rotation_degrees]``; 0 disables rotation.
            brightness, contrast, saturation: Factor is drawn uniformly from
                ``1 +/- value``.
            hue: Hue shift is drawn uniformly from ``[-hue, hue]``.
        """
        self.flip_prob = flip_prob
        self.rotation_degrees = rotation_degrees
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    def __call__(self, image, keypoints):
        """Augment `image` and apply the same geometry to `keypoints`.

        Args:
            image: Image tensor or PIL image.
            keypoints: Array of shape (K, 2) with (x, y) pixel coordinates.

        Returns:
            Tuple ``(image_tensor, keypoints)``; the keypoint array is a new
            floating-point array.
        """
        # Convert to PIL if it's a tensor
        if torch.is_tensor(image):
            image = TF.to_pil_image(image)

        # Get image dimensions
        img_width, img_height = image.size  # PIL format: (width, height)

        # Track transformations
        was_flipped = False
        rotation_angle = 0

        # Random horizontal flip
        if random.random() < self.flip_prob:
            image = TF.hflip(image)
            was_flipped = True

        # Random rotation
        if self.rotation_degrees > 0:
            rotation_angle = random.uniform(-self.rotation_degrees, self.rotation_degrees)
            image = TF.rotate(image, rotation_angle)

        # Color jitter
        if self.brightness > 0 or self.contrast > 0 or self.saturation > 0 or self.hue > 0:
            image = TF.adjust_brightness(image, 1 + random.uniform(-self.brightness, self.brightness))
            image = TF.adjust_contrast(image, 1 + random.uniform(-self.contrast, self.contrast))
            image = TF.adjust_saturation(image, 1 + random.uniform(-self.saturation, self.saturation))
            image = TF.adjust_hue(image, random.uniform(-self.hue, self.hue))

        # Convert back to tensor
        image = TF.to_tensor(image)

        # Work on a floating-point copy: writing rotated coordinates into an
        # integer array would silently truncate them.
        source = np.asarray(keypoints)
        transformed_keypoints = source.astype(np.result_type(source.dtype, np.float32))

        # Rows with x == y == 0 mark unlabelled keypoints and must stay there.
        missing = np.all(transformed_keypoints[:, :2] == 0, axis=1)

        # Apply horizontal flip to keypoints
        if was_flipped:
            # Flip x-coordinates: new_x = width - 1 - old_x
            labelled = ~missing
            transformed_keypoints[labelled, 0] = img_width - 1 - transformed_keypoints[labelled, 0]

            # For human pose, left/right keypoints also trade places
            for left_idx, right_idx in self.SWAP_PAIRS:
                if left_idx < len(transformed_keypoints) and right_idx < len(transformed_keypoints):
                    transformed_keypoints[[left_idx, right_idx]] = transformed_keypoints[[right_idx, left_idx]]
                    missing[[left_idx, right_idx]] = missing[[right_idx, left_idx]]

        # Apply rotation to keypoints (around center)
        if abs(rotation_angle) > 0:
            center_x, center_y = img_width / 2, img_height / 2
            # TF.rotate turns the image counter-clockwise. Pixel coordinates have
            # the y axis pointing down, so the matching rotation matrix uses the
            # negated angle.
            angle_rad = np.radians(-rotation_angle)
            cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)

            # Translate to origin, rotate, translate back
            x_centered = transformed_keypoints[:, 0] - center_x
            y_centered = transformed_keypoints[:, 1] - center_y

            x_rotated = x_centered * cos_angle - y_centered * sin_angle
            y_rotated = x_centered * sin_angle + y_centered * cos_angle

            transformed_keypoints[:, 0] = x_rotated + center_x
            transformed_keypoints[:, 1] = y_rotated + center_y

            # Clip to image bounds
            transformed_keypoints[:, 0] = np.clip(transformed_keypoints[:, 0], 0, img_width - 1)
            transformed_keypoints[:, 1] = np.clip(transformed_keypoints[:, 1], 0, img_height - 1)

            # Restore the "not labelled" marker that the rotation moved.
            transformed_keypoints[missing, :2] = 0

        return image, transformed_keypoints

In [None]:
# -----------------------------
# 6. Visualize a batch
# -----------------------------
# Pull one validation batch and draw every image with its keypoints side by side
images, keypoints_batch = next(iter(val_loader))
fig, axes = plt.subplots(1, len(images), figsize=(16, 4))

for ax, img, kpts in zip(axes, images, keypoints_batch):
    # Tensors are channel-first; imshow wants height x width x channel
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.scatter(kpts[:, 0], kpts[:, 1], c='r', s=40)
    ax.axis('off')

plt.show()

print(f"Train images: {len(train_loader.dataset)}, Val images: {len(val_loader.dataset)}")

In [None]:
# -----------------------------
# 7. Transform images and visualize with proper keypoint handling
# -----------------------------

# Augmentation that is applied to the image and its keypoints together
custom_transform = KeypointTransform(
    flip_prob=1.0,        # flip every sample so the effect is always visible
    rotation_degrees=20,
    brightness=0.2,
    contrast=0.2,
    saturation=0.2,
    hue=0.05,
)

# Two loaders over the same validation data: untouched vs. augmented
no_transform_loader = get_dataloader(val_ann_file, val_images_dir, batch_size=2, transform=None)
transform_loader = get_dataloader(val_ann_file, val_images_dir, batch_size=2, transform=custom_transform)

# Re-seed torch before drawing from each loader; this is intended to make both
# shuffled loaders return the same images
torch.manual_seed(42)
original_images, original_keypoints = next(iter(no_transform_loader))

torch.manual_seed(42)
transformed_images, transformed_keypoints = next(iter(transform_loader))

# Top row: originals (red). Bottom row: transformed (blue).
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
panels = (
    ("Original", original_images, original_keypoints, 'red'),
    ("Transformed", transformed_images, transformed_keypoints, 'blue'),
)
for row, (label, batch_imgs, batch_kpts, color) in enumerate(panels):
    for col in range(2):
        ax = axes[row, col]
        pts = batch_kpts[col]
        ax.imshow(batch_imgs[col].permute(1, 2, 0).numpy())
        ax.scatter(pts[:, 0], pts[:, 1], c=color, s=50, alpha=0.8, edgecolors='white', linewidth=1)
        ax.set_title(f'{label} {col+1}', fontsize=14)
        ax.axis('off')

plt.tight_layout()
plt.show()

print("🔴 Red dots: Original keypoints")
print("🔵 Blue dots: Transformed keypoints")
print("✅ The keypoints should now accurately follow the image transformations!")
print("📝 Note: Horizontal flip also swaps left/right body parts (e.g., left eye ↔ right eye)")

In [None]:
# -----------------------------
# 7.5. Debug keypoint transformation
# -----------------------------
# Let's create a simple test to verify keypoint transformations work correctly

# Create a deterministic transform for testing
class DebugTransform:
    """Deterministic flip and/or rotation for checking keypoint handling by eye.

    Calling the instance with ``(image, keypoints)`` returns the transformed
    image as a tensor and a new floating-point keypoint array; the input array
    is not modified. Keypoints whose x and y are both 0 are treated as "not
    labelled" and stay at (0, 0).
    """

    # Index pairs of left/right body parts that trade places under a flip
    SWAP_PAIRS = ((1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16))

    def __init__(self, flip=False, rotation=0):
        """Store the fixed transformation.

        Args:
            flip: Whether to flip horizontally.
            rotation: Rotation angle in degrees passed to TF.rotate; 0 disables it.
        """
        self.flip = flip
        self.rotation = rotation

    def __call__(self, image, keypoints):
        """Apply the configured flip/rotation to `image` and `keypoints`."""
        if torch.is_tensor(image):
            image = TF.to_pil_image(image)

        img_width, img_height = image.size

        # Floating-point copy: an integer array would truncate rotated coordinates.
        source = np.asarray(keypoints)
        transformed_keypoints = source.astype(np.result_type(source.dtype, np.float32))

        # Rows with x == y == 0 mark unlabelled keypoints and must stay there.
        missing = np.all(transformed_keypoints[:, :2] == 0, axis=1)

        if self.flip:
            image = TF.hflip(image)
            # Flip x-coordinates of the labelled keypoints only
            labelled = ~missing
            transformed_keypoints[labelled, 0] = img_width - 1 - transformed_keypoints[labelled, 0]
            # Swap left/right keypoints (and their "missing" flags with them)
            for left_idx, right_idx in self.SWAP_PAIRS:
                if left_idx < len(transformed_keypoints) and right_idx < len(transformed_keypoints):
                    transformed_keypoints[[left_idx, right_idx]] = transformed_keypoints[[right_idx, left_idx]]
                    missing[[left_idx, right_idx]] = missing[[right_idx, left_idx]]

        if self.rotation != 0:
            image = TF.rotate(image, self.rotation)
            center_x, center_y = img_width / 2, img_height / 2
            # TF.rotate is counter-clockwise; with the y axis pointing down the
            # matching rotation matrix uses the negated angle.
            angle_rad = np.radians(-self.rotation)
            cos_angle, sin_angle = np.cos(angle_rad), np.sin(angle_rad)

            x_centered = transformed_keypoints[:, 0] - center_x
            y_centered = transformed_keypoints[:, 1] - center_y

            x_rotated = x_centered * cos_angle - y_centered * sin_angle
            y_rotated = x_centered * sin_angle + y_centered * cos_angle

            transformed_keypoints[:, 0] = x_rotated + center_x
            transformed_keypoints[:, 1] = y_rotated + center_y

            transformed_keypoints[:, 0] = np.clip(transformed_keypoints[:, 0], 0, img_width - 1)
            transformed_keypoints[:, 1] = np.clip(transformed_keypoints[:, 1], 0, img_height - 1)

            # Restore the "not labelled" marker that the rotation moved.
            transformed_keypoints[missing, :2] = 0

        return TF.to_tensor(image), transformed_keypoints

# Eight cases in total, laid out as 2 rows x 4 columns
transforms_to_test = [
    ("Original", None),
    ("Horizontal Flip", DebugTransform(flip=True)),
    ("Rotation 15°", DebugTransform(rotation=15)),
    ("Rotation 30°", DebugTransform(rotation=30)),
    ("Rotation -15°", DebugTransform(rotation=-15)),
    ("Flip + Rotate 15°", DebugTransform(flip=True, rotation=15)),
    ("Flip + Rotate 30°", DebugTransform(flip=True, rotation=30)),
    ("Flip + Rotate -15°", DebugTransform(flip=True, rotation=-15))
]

# Two untransformed samples, one per row of the figure
sample_dataset = COCODataset(val_ann_file, val_images_dir, transform=None)
sample_img1, sample_kpts1 = sample_dataset[0]
sample_img2, sample_kpts2 = sample_dataset[1]

fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 0 shows cases 0-3 on sample 1 (red); row 1 shows cases 4-7 on sample 2 (blue)
row_specs = (
    (sample_img1, sample_kpts1, 'red', 'Sample 1'),
    (sample_img2, sample_kpts2, 'blue', 'Sample 2'),
)
for row, (base_img, base_kpts, color, sample_tag) in enumerate(row_specs):
    for col in range(4):
        title, transform = transforms_to_test[row * 4 + col]
        if transform is None:
            shown_img, shown_kpts = base_img, base_kpts
        else:
            shown_img, shown_kpts = transform(base_img, base_kpts)

        ax = axes[row, col]
        ax.imshow(shown_img.permute(1, 2, 0).numpy())
        ax.scatter(shown_kpts[:, 0], shown_kpts[:, 1], c=color, s=50, alpha=0.8, edgecolors='white', linewidth=1)
        ax.set_title(f'{title} - {sample_tag}', fontsize=12)
        ax.axis('off')

plt.tight_layout()
plt.show()

print("🔴 Red dots: Keypoints on Sample 1 (top row)")
print("🔵 Blue dots: Keypoints on Sample 2 (bottom row)")
print("✅ This shows how keypoints accurately follow different transformations")
print("📝 Notice how flips swap left/right body parts and rotations preserve relative positions")