# MP-100 CAPE Training on Google Colab

This notebook trains Category-Agnostic Pose Estimation (CAPE) on the MP-100 dataset using Google Colab's GPU.

## Setup Instructions
1. Enable GPU: Runtime → Change runtime type → GPU (T4 or better)
2. Run all cells in order
3. The notebook will:
   - Clone code from GitHub
   - Install dependencies
   - Authenticate to GCP
   - Mount GCS bucket with data
   - Run training (the full 300-epoch configuration, plus a single-image debug run)


## 1. Check GPU Availability


In [1]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️  No GPU detected! Please enable GPU in Runtime > Change runtime type > GPU")


CUDA available: True
GPU: NVIDIA A100-SXM4-40GB
CUDA version: 12.6
GPU Memory: 42.47 GB


## 2. Clone Repository


In [2]:
# Clone repository from GitHub
import os
from getpass import getpass

REPO_URL = "https://github.com/nkkrnkl/category-agnostic-pose-estimation.git"
BRANCH = "pavlos-topic-copy"
PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Remove existing directory if it exists
if os.path.exists(PROJECT_ROOT):
    print(f"Removing existing directory: {PROJECT_ROOT}")
    !rm -rf {PROJECT_ROOT}

# For private repositories, you need to authenticate
# Option 1: Use Personal Access Token (recommended)
# Get token from: https://github.com/settings/tokens
# Create a token with 'repo' scope
print("For private repositories, you need to authenticate.")
print("Option 1: Enter your GitHub Personal Access Token")
print("  (Get one from: https://github.com/settings/tokens)")
print("Option 2: Press Enter to try without token (will fail if repo is private)")
print()

GITHUB_TOKEN = getpass("Enter GitHub Personal Access Token (or press Enter to skip): ")

if GITHUB_TOKEN.strip():
    # Use token in URL
    # Format: https://TOKEN@github.com/username/repo.git
    AUTH_REPO_URL = REPO_URL.replace("https://github.com/", f"https://{GITHUB_TOKEN}@github.com/")
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {AUTH_REPO_URL} {PROJECT_ROOT}
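    # Pull from inside the clone; run from /content, git reports "not a git repository"
    !cd {PROJECT_ROOT} && git pull origin {BRANCH}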
else:
    # Try without token (will work if repo is public)
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {REPO_URL} {PROJECT_ROOT}

# Verify clone
if os.path.exists(PROJECT_ROOT) and os.path.exists(os.path.join(PROJECT_ROOT, ".git")):
    print(f"✅ Repository cloned successfully to {PROJECT_ROOT}")
    !cd {PROJECT_ROOT} && git branch
else:
    print("❌ Failed to clone repository")
    print("\nIf the repository is private, you need to:")
    print("1. Create a Personal Access Token at: https://github.com/settings/tokens")
    print("2. Select 'repo' scope")
    print("3. Run this cell again and paste the token when prompted")


For private repositories, you need to authenticate.
Option 1: Enter your GitHub Personal Access Token
  (Get one from: https://github.com/settings/tokens)
Option 2: Press Enter to try without token (will fail if repo is private)

Enter GitHub Personal Access Token (or press Enter to skip): ··········
Cloning repository from https://github.com/nkkrnkl/category-agnostic-pose-estimation.git (branch: pavlos-topic-copy)...
Cloning into '/content/category-agnostic-pose-estimation'...
remote: Enumerating objects: 1193, done.
remote: Counting objects: 100% (419/419), done.
remote: Compressing objects: 100% (305/305), done.
remote: Total 1193 (delta 193), reused 326 (delta 110), pack-reused 774 (from 2)
Receiving objects: 100% (1193/1193), 73.46 MiB | 39.61 MiB/s, done.
Resolving deltas: 100% (443/443), done.
Updating files: 100% (250/250), done.
fatal: not a git repository (or any of the parent directories): .git
✅ Repository cloned successfully to /content/category-agn
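
Embedding the token in the clone URL, as the cell above does, means the token can end up in the clone's git configuration and in any command that is echoed to the output. A minimal sketch of building the authenticated URL and masking it before printing follows. `with_token` and `redact` are hypothetical helper names, not part of the repository, and the token shown is a placeholder.

```python
from urllib.parse import urlsplit, urlunsplit


def with_token(repo_url: str, token: str) -> str:
    """Return repo_url with the token in the userinfo part (https://TOKEN@host/...)."""
    parts = urlsplit(repo_url)
    return urlunsplit(parts._replace(netloc=f"{token}@{parts.hostname}"))


def redact(url: str) -> str:
    """Mask any userinfo in a URL so it is safe to print."""
    parts = urlsplit(url)
    if parts.username is None:
        return url
    return urlunsplit(parts._replace(netloc=f"***@{parts.hostname}"))


repo = "https://github.com/nkkrnkl/category-agnostic-pose-estimation.git"
auth = with_token(repo, "ghp_example")  # placeholder token, not a real credential
print(redact(auth))
```

Printing `redact(auth)` rather than `auth` keeps the credential out of the saved notebook output.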

In [3]:
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BRANCH = "pavlos-topic-copy"

print(f"Pulling latest changes from branch {BRANCH}...")
!cd {PROJECT_ROOT} && git pull origin {BRANCH}

print("✅ Git pull complete!")

Pulling latest changes from branch pavlos-topic-copy...
From https://github.com/nkkrnkl/category-agnostic-pose-estimation
 * branch            pavlos-topic-copy -> FETCH_HEAD
Already up to date.
✅ Git pull complete!


## 3. Install Requirements


In [4]:
# Install additional dependencies needed for plot_utils and other utilities
# (descartes, shapely, etc. - these are in requirements.txt but not requirements_cape.txt)
print("Installing additional dependencies (descartes, shapely, etc.)...")
# Quote the version specifier: unquoted, the shell treats ">" as an output redirection
!pip install -q descartes 'shapely>=1.8.0'
print("✅ Additional dependencies installed!")


Installing additional dependencies (descartes, shapely, etc.)...
✅ Additional dependencies installed!
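
In a shell, an unquoted `>` starts an output redirection, so a version specifier such as `shapely>=1.8.0` needs quoting before it reaches pip. The sketch below uses the standard library's `shlex` module to show how quoting changes what the shell sees. The `pip_command` helper is hypothetical and exists only for illustration.

```python
import shlex


def pip_command(requirements):
    """Build a pip install command line with each requirement safely quoted."""
    quoted = " ".join(shlex.quote(req) for req in requirements)
    return f"pip install -q {quoted}"


cmd = pip_command(["descartes", "shapely>=1.8.0"])
print(cmd)
# Splitting the quoted command back gives the specifier as one intact argument.
print(shlex.split(cmd))
```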


In [5]:
# Install requirements
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements_cape.txt")

print("Installing requirements from requirements_cape.txt...")
!cd {PROJECT_ROOT} && pip install -q -r {REQUIREMENTS_FILE}

# Install detectron2 from source (it builds against the PyTorch/CUDA versions already in the runtime)
print("\nInstalling detectron2...")
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'

print("✅ All dependencies installed!")


Installing requirements from requirements_cape.txt...

Installing detectron2...
  Preparing metadata (setup.py) ... done
  Preparing metadata (setup.py) ... done

## 4. Authenticate to GCP


In [6]:
# Authenticate to GCP
from google.colab import auth

print("Authenticating to GCP...")
auth.authenticate_user()

# Set GCP project
GCP_PROJECT = "dl-category-agnostic-pose-est"
!gcloud config set project {GCP_PROJECT}

print(f"✅ Authenticated to GCP project: {GCP_PROJECT}")


Authenticating to GCP...
Are you sure you wish to set property [core/project] to 
dl-category-agnostic-pose-est?

Do you want to continue (Y/n)?  Y

Updated property [core/project].
✅ Authenticated to GCP project: dl-category-agnostic-pose-est


## 5. Mount GCS Bucket


In [7]:
# Mount GCS bucket using gcsfuse
import os
import subprocess
import time

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BUCKET_NAME = "dl-category-agnostic-pose-mp100-data"
MOUNT_POINT = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")

# Install gcsfuse from Google's official repository
print("Installing gcsfuse...")
# Add Google's gcsfuse repository (updated method for newer Ubuntu versions)
!export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` && \
echo "deb http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
sudo apt-get update && \
sudo apt-get install -y gcsfuse

# Verify installation
!which gcsfuse
print("✅ gcsfuse installed")

# Create mount point directory and parent directories
print(f"Creating mount point: {MOUNT_POINT}")
os.makedirs(os.path.dirname(MOUNT_POINT), exist_ok=True)
os.makedirs(MOUNT_POINT, exist_ok=True)

# Check if already mounted
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    if result.returncode == 0:
        print(f"✅ Already mounted at {MOUNT_POINT}")
    else:
        # Try to unmount if exists but not properly mounted
        try:
            subprocess.run(['fusermount', '-u', MOUNT_POINT], capture_output=True, timeout=5)
        except Exception:
            try:
                subprocess.run(['umount', MOUNT_POINT], capture_output=True, timeout=5)
            except Exception:
                pass
except Exception:
    pass

# Mount the bucket
print(f"Mounting gs://{BUCKET_NAME} to {MOUNT_POINT}...")
print("This may take a moment...")

# Run gcsfuse in background
# Note: In Colab, we need to run gcsfuse in background using shell &
print(f"Running: gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT}")
!nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &

# Wait a moment for mount to initialize
print("Waiting for mount to initialize...")
time.sleep(8)  # Give it more time to mount

# Check mount status
print("\nChecking mount status...")
# Check mount log for errors
if os.path.exists("/tmp/gcsfuse.log"):
    with open("/tmp/gcsfuse.log", "r") as f:
        log_content = f.read()
        if log_content:
            print("Mount log:")
            print(log_content[-500:])  # Last 500 chars
        else:
            print("Mount log is empty (mount might still be initializing)")

# Also verify we can access the bucket directly with gsutil
print("\nVerifying bucket access with gsutil...")
!gsutil ls gs://{BUCKET_NAME}/ | head -10

# Verify mount
print(f"\nVerifying mount at: {MOUNT_POINT}")
print(f"Path exists: {os.path.exists(MOUNT_POINT)}")

# Check if actually mounted using mountpoint command
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    is_mounted = (result.returncode == 0)
    print(f"Is mounted: {is_mounted}")
except Exception:
    # Fallback: check mount table
    result = subprocess.run(['mount'], capture_output=True, text=True)
    is_mounted = MOUNT_POINT in result.stdout
    print(f"Is mounted (from mount table): {is_mounted}")

if os.path.exists(MOUNT_POINT) and is_mounted:
    try:
        # Try to list contents
        items = os.listdir(MOUNT_POINT)
        if len(items) > 0:
            print(f"✅ GCS bucket mounted successfully!")
            print(f"Mount point: {MOUNT_POINT}")
            print(f"Found {len(items)} items in bucket")
            # List a few items to verify
            for item in items[:10]:
                item_path = os.path.join(MOUNT_POINT, item)
                item_type = "directory" if os.path.isdir(item_path) else "file"
                print(f"   - {item} ({item_type})")
        else:
            print(f"⚠️  Mount point exists but is empty (0 items)")
            print(f"   This might indicate:")
            print(f"   1. Bucket is empty")
            print(f"   2. Mount didn't work correctly")
            print(f"   3. Permission issues")
    except PermissionError as e:
        print(f"⚠️  Permission error accessing mount: {e}")
        print("   Mount might still be initializing, wait a moment and try again")
    except Exception as e:
        print(f"⚠️  Mount point exists but cannot list contents: {e}")
        print("   This might indicate a mount issue")
        import traceback
        traceback.print_exc()
elif os.path.exists(MOUNT_POINT) and not is_mounted:
    print(f"⚠️  Directory exists but is not mounted")
    print(f"   The directory exists but gcsfuse mount is not active")
    print(f"   Trying to mount again...")
    # Try mounting again
    !nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &
    time.sleep(5)
    # Re-check
    items = os.listdir(MOUNT_POINT) if os.path.exists(MOUNT_POINT) else []
    if len(items) > 0:
        print(f"✅ Mount successful after retry! Found {len(items)} items")
    else:
        print(f"❌ Mount still not working")
else:
    print("❌ Failed to mount GCS bucket")
    print(f"   Mount point: {MOUNT_POINT}")
    print(f"   Check:")
    print(f"   1. GCP authentication (run the GCP auth cell)")
    print(f"   2. Bucket name is correct: {BUCKET_NAME}")
    print(f"   3. You have read access to the bucket")
    print(f"   4. Check mount log: /tmp/gcsfuse.log")


Installing gcsfuse...
deb http://packages.cloud.google.com/apt gcsfuse-jammy main
deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt gcsfuse-jammy main
Get:1 http://packages.cloud.google.com/apt gcsfuse-jammy InRelease [1,227 B]
Hit:2 https://cli.github.com/packages stable InRelease
Get:3 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:5 http://packages.cloud.google.com/apt gcsfuse-jammy/main amd64 Packages [49.8 kB]
Get:6 http://packages.cloud.google.com/apt gcsfuse-jammy/main all Packages [750 B]
Get:7 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:8 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:9 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:10 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [2,153 kB]
Get:11 http://archi
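
The mount cell waits a fixed 8 seconds before checking whether the bucket is mounted. Polling with a timeout is one alternative that returns as soon as the mount is ready. The sketch below is generic: `wait_until` is a hypothetical helper, and the stand-in predicate simply becomes true on its third call.

```python
import time


def wait_until(predicate, timeout=30.0, interval=0.5):
    """Poll predicate() until it returns True or timeout seconds pass.

    Returns True if the predicate succeeded, False on timeout.
    """
    deadline = time.monotonic() + timeout
    while True:
        if predicate():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval)


# Stand-in predicate that becomes true on the third call.
calls = {"n": 0}


def ready():
    calls["n"] += 1
    return calls["n"] >= 3


print(wait_until(ready, timeout=5.0, interval=0.01))
```

In the notebook, the predicate could wrap the same `mountpoint -q` check the cell already uses, for example `lambda: subprocess.run(["mountpoint", "-q", MOUNT_POINT]).returncode == 0`.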

## 6. Create Data Symlink


In [8]:
# Create symlink from data to mounted GCS bucket (as expected by START_TRAINING.sh)
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
MOUNTED_DATA = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")
DATA_SYMLINK = os.path.join(PROJECT_ROOT, "data")

print(f"Checking mount point: {MOUNTED_DATA}")
print(f"  Exists: {os.path.exists(MOUNTED_DATA)}")
if os.path.exists(MOUNTED_DATA):
    print(f"  Is directory: {os.path.isdir(MOUNTED_DATA)}")
    try:
        items = os.listdir(MOUNTED_DATA)
        print(f"  Can list contents: Yes ({len(items)} items)")
    except Exception as e:
        print(f"  Can list contents: No ({e})")

# Remove existing symlink or directory if it exists
if os.path.lexists(DATA_SYMLINK):  # lexists also catches a dangling symlink
    if os.path.islink(DATA_SYMLINK):
        print(f"Removing existing symlink: {DATA_SYMLINK}")
        os.unlink(DATA_SYMLINK)
    elif os.path.isdir(DATA_SYMLINK):
        print(f"Warning: {DATA_SYMLINK} exists as a directory (not a symlink)")
        print("   Removing it to create symlink...")
        import shutil
        shutil.rmtree(DATA_SYMLINK)
    else:
        print(f"Warning: {DATA_SYMLINK} exists and is not a symlink or directory")
        os.remove(DATA_SYMLINK)

# Create symlink
if os.path.exists(MOUNTED_DATA) and os.path.isdir(MOUNTED_DATA):
    try:
        # Use absolute path for symlink target
        MOUNTED_DATA_ABS = os.path.abspath(MOUNTED_DATA)
        print(f"\nCreating symlink:")
        print(f"  From: {DATA_SYMLINK}")
        print(f"  To: {MOUNTED_DATA_ABS}")
        os.symlink(MOUNTED_DATA_ABS, DATA_SYMLINK)
        print(f"✅ Created symlink: {DATA_SYMLINK} -> {MOUNTED_DATA_ABS}")

        # Verify symlink
        if os.path.exists(DATA_SYMLINK):
            print(f"✅ Symlink verified: {DATA_SYMLINK}")
            print(f"  Is symlink: {os.path.islink(DATA_SYMLINK)}")
            # Try to list contents through symlink
            try:
                items = os.listdir(DATA_SYMLINK)
                print(f"✅ Can access {len(items)} items through symlink")
                print(f"   First 5 items: {items[:5]}")
            except Exception as e:
                print(f"⚠️  Symlink exists but cannot access contents: {e}")
        else:
            print(f"❌ Symlink creation failed - path does not exist after creation")
    except Exception as e:
        print(f"❌ Error creating symlink: {e}")
        print(f"   Link target: {MOUNTED_DATA}")
        print(f"   Link path:   {DATA_SYMLINK}")
        import traceback
        traceback.print_exc()
else:
    print(f"❌ Mounted data not found at {MOUNTED_DATA}")
    print(f"   Please check that GCS bucket is mounted correctly")
    print(f"   Run the mount cell above and check for errors")
    print(f"   Mount point should exist and be accessible")


Checking mount point: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
  Exists: True
  Is directory: True
  Can list contents: Yes (100 items)

Creating symlink:
  From: /content/category-agnostic-pose-estimation/data
  To: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
✅ Created symlink: /content/category-agnostic-pose-estimation/data -> /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
✅ Symlink verified: /content/category-agnostic-pose-estimation/data
  Is symlink: True
✅ Can access 100 items through symlink
   First 5 items: ['amur_tiger_body', 'annotations', 'antelope_body', 'arcticwolf_face', 'beaver_body']
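
The replace-then-link pattern from the symlink cell can be exercised on its own in a temporary directory. In the sketch below, `relink` is a hypothetical helper. Unlike the cell above, it only ever removes an existing symlink and refuses to touch a real file or directory, which is a deliberately more conservative choice.

```python
import os
import tempfile


def relink(target: str, link: str) -> None:
    """Point link at target, replacing an existing symlink but refusing other paths."""
    if os.path.islink(link):  # true for dangling symlinks as well
        os.unlink(link)
    elif os.path.exists(link):
        raise FileExistsError(f"{link} exists and is not a symlink")
    os.symlink(os.path.abspath(target), link)


with tempfile.TemporaryDirectory() as tmp:
    data_dir = os.path.join(tmp, "mounted_data")
    os.makedirs(data_dir)
    link = os.path.join(tmp, "data")
    relink(data_dir, link)
    relink(data_dir, link)  # second call is safe: the old link is replaced
    print(os.path.islink(link), os.path.realpath(link) == os.path.realpath(data_dir))
```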


## 7. Run Full Model Training (300 Epochs)

This section trains the **full model** on all training categories for 300 epochs.
This is the main training configuration for the complete CAPE model.

**Training Configuration:**
- Full training mode: All training categories
- Epochs: 300
- Geometric encoder: Enabled
- GCN pre-encoding: Enabled
- Batch size: 2 episodes
- Accumulation steps: 4 (effective batch size = 8)
- Episodes per epoch: 500
- Early stopping patience: 20 epochs (stops if PCK doesn't improve)
- All logs will be saved to output directory
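
The effective batch size quoted above follows from gradient accumulation: gradients from several batches are accumulated before each optimizer step. A short worked calculation with the values from this configuration follows. The images-per-batch figure assumes each episode loads all of its support and query images (5 and 2, as set in the configuration cell); the training script may count these differently.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int) -> int:
    """Episodes contributing to one optimizer step under gradient accumulation."""
    return batch_size * accumulation_steps


batch_size = 2            # episodes per batch
accumulation_steps = 4    # batches accumulated per optimizer step
support_per_episode = 5   # from the configuration cell
queries_per_episode = 2   # from the configuration cell

episodes_per_step = effective_batch_size(batch_size, accumulation_steps)
# Assumption: every episode loads all of its support and query images.
images_per_batch = batch_size * (support_per_episode + queries_per_episode)

print(episodes_per_step)  # 8
print(images_per_batch)   # 14
```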


# Configure full model training parameters


In [None]:

# Configure full model training
import os
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Training configuration
EPOCHS = 300
BATCH_SIZE = 2
NUM_QUERIES_PER_EPISODE = 2
NUM_SUPPORT_PER_EPISODE = 5  # 5-shot learning (change to 1 for 1-shot)
EPISODES_PER_EPOCH = 500
ACCUMULATION_STEPS = 4
EARLY_STOPPING_PATIENCE = 20  # Stop if PCK doesn't improve for 20 epochs

# Model configuration
USE_GEOMETRIC_ENCODER = True
USE_GCN_PREENC = True

# Output directories
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "outputs", "test_geometric_1")
LOG_FILE = os.path.join(OUTPUT_DIR, f"output_log_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("=" * 80)
print("Full Model Training Configuration (300 Epochs)")
print("=" * 80)
print(f"Epochs:                {EPOCHS}")
print(f"Batch size:            {BATCH_SIZE} episodes")
print(f"Accumulation steps:    {ACCUMULATION_STEPS}")
print(f"Effective batch size:  {BATCH_SIZE * ACCUMULATION_STEPS} episodes")
print(f"Support/episode:       {NUM_SUPPORT_PER_EPISODE} ({NUM_SUPPORT_PER_EPISODE}-shot)")
print(f"Queries/episode:       {NUM_QUERIES_PER_EPISODE}")
print(f"Episodes/epoch:        {EPISODES_PER_EPOCH}")
print(f"Early stopping:        {EARLY_STOPPING_PATIENCE} epochs")
print(f"Geometric encoder:     {USE_GEOMETRIC_ENCODER}")
print(f"GCN pre-encoding:      {USE_GCN_PREENC}")
print(f"Output directory:      {OUTPUT_DIR}")
print(f"Log file:              {LOG_FILE}")
print("=" * 80)
print()




Single Image Training Configuration
Category ID:        40
Epochs:            50
Batch size:        1
Queries/episode:   1
Episodes/epoch:    20
Output directory:  /content/category-agnostic-pose-estimation/output/single_image_colab
Log file:          /content/category-agnostic-pose-estimation/output/single_image_colab/training_logs.txt
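
The full-training configuration sets an early-stopping patience of 20 epochs on PCK. How the training script implements this is not shown in the notebook, so the following is only a generic sketch of patience-based stopping. `EarlyStopper` is a hypothetical class and the PCK values are made up for illustration.

```python
class EarlyStopper:
    """Stop when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True if training should stop."""
        if metric > self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=3)
pck_history = [0.50, 0.55, 0.54, 0.55, 0.53, 0.60]  # made-up values for illustration
stopped_at = None
for epoch, pck_value in enumerate(pck_history):
    if stopper.step(pck_value):
        stopped_at = epoch
        break
print(stopped_at)
```

Note that a tie with the best value counts as no improvement in this sketch, so the run stops before reaching the final, higher value.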



# Run Full Training with Logging


In [None]:
# Run full training on all categories
import os
import subprocess
import sys
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "outputs", "test_geometric_1")
LOG_FILE = os.path.join(OUTPUT_DIR, f"output_log_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.log")

# Build training command
cmd = [
    sys.executable, "-m", "models.train_cape_episodic",
    "--dataset_root", PROJECT_ROOT,
    "--category_split_file", os.path.join(PROJECT_ROOT, "category_splits.json"),
    "--output_dir", OUTPUT_DIR,
    "--device", "cuda:0",
    "--epochs", str(EPOCHS),
    "--batch_size", str(BATCH_SIZE),
    "--accumulation_steps", str(ACCUMULATION_STEPS),
    "--num_support_per_episode", str(NUM_SUPPORT_PER_EPISODE),
    "--num_queries_per_episode", str(NUM_QUERIES_PER_EPISODE),
    "--episodes_per_epoch", str(EPISODES_PER_EPOCH),
    "--early_stopping_patience", str(EARLY_STOPPING_PATIENCE),
    "--lr", "1e-4",
    "--lr_backbone", "1e-5",
    "--weight_decay", "1e-4",
    "--clip_max_norm", "0.1",
    "--support_encoder_layers", "3",
    "--support_fusion_method", "cross_attention",
    "--backbone", "resnet50",
    "--hidden_dim", "256",
    "--nheads", "8",
    "--enc_layers", "6",
    "--dec_layers", "6",
    "--dim_feedforward", "1024",
    "--dropout", "0.1",
    "--image_size", "512",  # Changed from 256 to 512 for better quality
    "--vocab_size", "2000",
    "--seq_len", "200",
    "--num_queries", "200",
    "--num_polys", "1",
    "--cls_loss_coef", "2.0",
    "--coords_loss_coef", "5.0",
    "--room_cls_loss_coef", "0.0",
    "--semantic_classes", "70",
    "--num_feature_levels", "4",
    "--dec_n_points", "4",
    "--enc_n_points", "4",
    "--aux_loss",
    "--with_poly_refine",
    "--num_workers", "2",
    "--seed", "42",
    "--print_freq", "10",
    "--use_amp",
    "--cudnn_benchmark",
    "--job_name", f"full_training_geometric_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
]

# Add geometric encoder flags
if USE_GEOMETRIC_ENCODER:
    cmd.append("--use_geometric_encoder")
if USE_GCN_PREENC:
    cmd.append("--use_gcn_preenc")
    cmd.extend(["--num_gcn_layers", "2"])

print("Starting full model training...")
print(f"Command: {' '.join(cmd)}")
print(f"Logging to: {LOG_FILE}")
print("=" * 80)
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Run training with logging to both stdout and file
with open(LOG_FILE, 'w') as log_file:
    # Write header to log file
    log_file.write("=" * 80 + "\n")
    log_file.write(f"Full Model Training Log (300 Epochs)\n")
    log_file.write(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    log_file.write(f"Epochs: {EPOCHS}\n")
    log_file.write(f"Batch size: {BATCH_SIZE} (accumulation: {ACCUMULATION_STEPS})\n")
    log_file.write(f"Support per episode: {NUM_SUPPORT_PER_EPISODE} ({NUM_SUPPORT_PER_EPISODE}-shot)\n")
    log_file.write(f"Queries per episode: {NUM_QUERIES_PER_EPISODE}\n")
    log_file.write(f"Episodes per epoch: {EPISODES_PER_EPOCH}\n")
    log_file.write(f"Geometric encoder: {USE_GEOMETRIC_ENCODER}\n")
    log_file.write(f"GCN pre-encoding: {USE_GCN_PREENC}\n")
    log_file.write("=" * 80 + "\n\n")
    log_file.flush()

    # Run process and stream output to both stdout and file
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )

    # Stream output in real-time
    for line in process.stdout:
        print(line, end='')  # Print to notebook
        log_file.write(line)  # Write to log file
        log_file.flush()  # Ensure immediate write

    # Wait for process to complete
    return_code = process.wait()

    # Write footer to log file
    log_file.write("\n" + "=" * 80 + "\n")
    log_file.write(f"Training completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    log_file.write(f"Return code: {return_code}\n")
    log_file.write("=" * 80 + "\n")

if return_code == 0:
    print("\n" + "=" * 80)
    print("✅ Training completed successfully!")
    print(f"Checkpoints saved to: {OUTPUT_DIR}")
    print(f"Full logs saved to: {LOG_FILE}")
    print("=" * 80)
else:
    print("\n" + "=" * 80)
    print(f"❌ Training failed with return code: {return_code}")
    print(f"Check logs at: {LOG_FILE}")
    print("=" * 80)


Finding second valid image from cleaned train annotations...
Data directory: /content/category-agnostic-pose-estimation/data
Annotation file: /content/category-agnostic-pose-estimation/annotations/mp100_split1_train.json

Found 12816 images in train annotations
✅ Found second valid image:
   File name: camel_face/camel_16.jpg
   Full path: /content/category-agnostic-pose-estimation/data/camel_face/camel_16.jpg
   Image ID: 3700000000004990
   (Skipped 99 previous valid image(s))

✅ Set SINGLE_IMAGE_PATH = /content/category-agnostic-pose-estimation/data/camel_face/camel_16.jpg
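
Both training cells stream the child process's output to the notebook and to a log file at the same time. The pattern is shown here in isolation, with a trivial child command standing in for the training script. `run_and_tee` is a hypothetical helper name.

```python
import os
import subprocess
import sys
import tempfile


def run_and_tee(cmd, log_path):
    """Run cmd, echoing each output line to stdout and writing it to log_path."""
    with open(log_path, "w") as log_file:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # merge stderr into the same stream
            text=True,
            bufsize=1,                 # line-buffered in text mode
        )
        for line in process.stdout:
            print(line, end="")
            log_file.write(line)
            log_file.flush()           # keep the log usable if the run is interrupted
        return process.wait()


log_path = os.path.join(tempfile.mkdtemp(), "demo.log")
code = run_and_tee([sys.executable, "-c", "print('epoch 1 done')"], log_path)
print(code)
```

Merging stderr into stdout keeps tracebacks in the same log, in the order they were produced.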


In [18]:
# Run training on single image with full logging
import os
import subprocess
import sys
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")

# Create the output directory so the log file can be opened below
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Build training command
cmd = [
    sys.executable, "-m", "models.train_cape_episodic",
    "--dataset_root", PROJECT_ROOT,
    "--category_split_file", os.path.join(PROJECT_ROOT, "category_splits.json"),
    "--output_dir", OUTPUT_DIR,
    "--device", "cuda:0",
]

# Add single image mode argument (path takes precedence over category)
if 'SINGLE_IMAGE_PATH' in globals() and SINGLE_IMAGE_PATH:
    cmd.extend(["--debug_single_image_path", SINGLE_IMAGE_PATH])
elif 'SINGLE_IMAGE_CATEGORY' in globals() and SINGLE_IMAGE_CATEGORY is not None:
    cmd.extend(["--debug_single_image", str(SINGLE_IMAGE_CATEGORY)])

# Add remaining training arguments
cmd.extend([
    "--epochs", str(EPOCHS),
    "--batch_size", str(BATCH_SIZE),
    "--num_queries_per_episode", str(NUM_QUERIES_PER_EPISODE),
    "--episodes_per_epoch", str(EPISODES_PER_EPOCH),
    "--lr", "1e-4",
    "--lr_backbone", "1e-5",
    "--weight_decay", "1e-4",
    "--clip_max_norm", "0.1",
    "--support_encoder_layers", "3",
    "--support_fusion_method", "cross_attention",

    # Geometric encoder and GCN pre-encoding
    "--use_geometric_encoder",
    "--use_gcn_preenc",
    "--num_gcn_layers", "2",

    "--stop_when_loss_zero",
    "--early_stopping_patience", "0",
    "--loss_zero_threshold", "1e-5",
    "--backbone", "resnet50",
    "--hidden_dim", "256",
    "--nheads", "8",
    "--enc_layers", "6",
    "--dec_layers", "6",
    "--dim_feedforward", "1024",
    "--dropout", "0.1",
    "--image_size", "256",
    "--vocab_size", "2000",
    "--seq_len", "200",
    "--num_queries", "200",
    "--num_polys", "1",
    "--cls_loss_coef", "2.0",
    "--coords_loss_coef", "5.0",
    "--room_cls_loss_coef", "0.0",
    "--semantic_classes", "70",
    "--num_feature_levels", "4",
    "--dec_n_points", "4",
    "--enc_n_points", "4",
    "--aux_loss",
    "--with_poly_refine",
    "--num_workers", "2",
    "--seed", "42",
    "--print_freq", "5",
    "--use_amp",
    "--cudnn_benchmark",
    "--job_name", f"single_image_colab_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
])

print("Starting training...")
print(f"Command: {' '.join(cmd)}")
print(f"Logging to: {LOG_FILE}")
print("=" * 80)
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Run training with logging to both stdout and file
with open(LOG_FILE, 'w') as log_file:
    # Write header to log file
    log_file.write("=" * 80 + "\n")
    log_file.write(f"Single Image Training Log\n")
    log_file.write(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    if 'SINGLE_IMAGE_PATH' in globals() and SINGLE_IMAGE_PATH:
        log_file.write(f"Image path: {SINGLE_IMAGE_PATH}\n")
    elif 'SINGLE_IMAGE_CATEGORY' in globals() and SINGLE_IMAGE_CATEGORY is not None:
        log_file.write(f"Category: {SINGLE_IMAGE_CATEGORY}\n")
    log_file.write(f"Epochs: {EPOCHS}\n")
    log_file.write("=" * 80 + "\n\n")
    log_file.flush()

    # Run process and stream output to both stdout and file
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )

    # Stream output in real-time
    for line in process.stdout:
        print(line, end='')  # Print to notebook
        log_file.write(line)  # Write to log file
        log_file.flush()  # Ensure immediate write

    # Wait for process to complete
    return_code = process.wait()

    # Write footer to log file
    log_file.write("\n" + "=" * 80 + "\n")
    log_file.write(f"Training completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    log_file.write(f"Return code: {return_code}\n")
    log_file.write("=" * 80 + "\n")

if return_code == 0:
    print("\n" + "=" * 80)
    print("✅ Training completed successfully!")
    print(f"Checkpoints saved to: {OUTPUT_DIR}")
    print(f"Full logs saved to: {LOG_FILE}")
    print("=" * 80)
else:
    print("\n" + "=" * 80)
    print(f"❌ Training failed with return code: {return_code}")
    print(f"Check logs at: {LOG_FILE}")
    print("=" * 80)


Streaming output truncated to the last 5000 lines.
Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, loss=2.9580, loss_ce=0.0040, loss_coords=0.3169, lr=0.000100]
Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, loss=2.9843, loss_ce=0.0060, loss_coords=0.3283, lr=0.000100]
Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, loss=2.8205, loss_ce=0.0161, loss_coords=0.3096, lr=0.000100]
Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, loss=2.8344, loss_ce=0.0039, loss_coords=0.3135, lr=0.000100]
Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, loss=2.7039, loss_ce=0.0064, loss_coords=0.2781, lr=0.000100]
Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, loss=2.7839, loss_ce=0.0041, loss_coords=0.3030, lr=0.000100]
Epoch 11: 100%|██████████| 20/20 [00:03<00:00,  5.12it/s, loss=2.7839, loss_ce=0.0041, loss_coords=0.3030, lr=0.000100]
Epoch 11: 100%|██████████| 20/20 [00:03<00:00,  5.06it/s, loss=2.7839, loss_ce=0.0041, loss_coords=0.3030, lr
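
Progress lines like the ones above can be parsed to recover the logged loss values, for example to plot them after a run. A minimal sketch with regular expressions follows. `parse_progress_line` is a hypothetical helper, and the patterns assume the `key=value` layout seen in this output.

```python
import re

METRIC_RE = re.compile(r"(\w+)=([0-9.eE+-]+)")
EPOCH_RE = re.compile(r"Epoch (\d+):")


def parse_progress_line(line: str):
    """Return (epoch, {metric: value}) for a progress line, or None if it has no epoch."""
    epoch_match = EPOCH_RE.search(line)
    if epoch_match is None:
        return None
    metrics = {name: float(value) for name, value in METRIC_RE.findall(line)}
    return int(epoch_match.group(1)), metrics


line = ("Epoch 11:   0%|          | 0/20 [00:03<?, ?it/s, "
        "loss=2.9580, loss_ce=0.0040, loss_coords=0.3169, lr=0.000100]")
print(parse_progress_line(line))
```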

## 8. Evaluate Best Checkpoint on Test Data

This section evaluates the best checkpoint (highest PCK) on the **test set** to get final performance metrics.

The evaluation script:
- Loads the best checkpoint automatically
- Runs autoregressive inference (no teacher forcing)
- Computes the PCK@0.2 metric
- Generates visualizations comparing GT vs predictions
- Saves metrics to JSON file
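
PCK@0.2 (Percentage of Correct Keypoints) counts a predicted keypoint as correct when its distance to the ground truth is within 0.2 times a reference size. A pure-Python sketch follows. Normalizing by the longer side of the object's bounding box is an assumption made here for illustration, the evaluation script may normalize differently, and `pck` is a hypothetical function name. The coordinates are made up.

```python
import math


def pck(pred, gt, bbox_wh, threshold=0.2, visible=None):
    """Fraction of visible keypoints with error <= threshold * max(bbox width, height)."""
    if visible is None:
        visible = [True] * len(gt)
    limit = threshold * max(bbox_wh)
    correct = total = 0
    for (px, py), (gx, gy), vis in zip(pred, gt, visible):
        if not vis:
            continue  # keypoints marked not visible are skipped
        total += 1
        if math.hypot(px - gx, py - gy) <= limit:
            correct += 1
    return correct / total if total else 0.0


gt = [(10.0, 10.0), (50.0, 50.0), (90.0, 20.0)]
pred = [(12.0, 11.0), (80.0, 50.0), (90.0, 25.0)]
print(pck(pred, gt, bbox_wh=(100.0, 80.0)))
```

With a 100 × 80 box the tolerance is 20 pixels, so the first and third keypoints count as correct and the second, which is 30 pixels off, does not.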


In [None]:
# Find best checkpoint and evaluate on test set
import os
import glob
import subprocess
import sys
import re
from pathlib import Path

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "outputs", "test_geometric_1")
EVAL_OUTPUT_DIR = os.path.join(OUTPUT_DIR, "test_evaluation")

# Find best PCK checkpoint
# Checkpoint names are like: checkpoint_best_pck_e035_pck0.8500_meanpck0.8200.pth
best_pck_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
if not best_pck_ckpts:
    print("❌ No best PCK checkpoint found! Please run training first.")
    print(f"   Looking in: {OUTPUT_DIR}")
    CHECKPOINT = None
else:
    def extract_pck(checkpoint_path):
        """Extract PCK value from checkpoint filename."""
        basename = os.path.basename(checkpoint_path)
        # Match pattern: checkpoint_best_pck_e035_pck0.8500_meanpck0.8200.pth
        match = re.search(r'_pck([\d.]+)_', basename)
        if match:
            return float(match.group(1))
        return 0.0
    
    # Sort checkpoints by PCK value (highest PCK = best)
    sorted_best_pck_ckpts = sorted(best_pck_ckpts, key=extract_pck, reverse=True)
    CHECKPOINT = sorted_best_pck_ckpts[0]  # Best PCK checkpoint
    best_pck_value = extract_pck(CHECKPOINT)
    
    print("=" * 80)
    print("FINDING BEST CHECKPOINT")
    print("=" * 80)
    print(f"✅ Using best PCK checkpoint: {os.path.basename(CHECKPOINT)}")
    print(f"   PCK value: {best_pck_value:.4f}")
    print()
    
    # Show top 5 best PCK checkpoints for reference
    print(f"Top best PCK checkpoints ({len(best_pck_ckpts)} total):")
    for i, ckpt in enumerate(sorted_best_pck_ckpts[:5]):
        pck_val = extract_pck(ckpt)
        marker = " ← USING THIS" if ckpt == CHECKPOINT else ""
        print(f"  {i+1}. {os.path.basename(ckpt)} (PCK: {pck_val:.4f}){marker}")
    if len(best_pck_ckpts) > 5:
        print(f"  ... and {len(best_pck_ckpts) - 5} more")
    print("=" * 80)
    print()

if CHECKPOINT:
    # Create evaluation output directory
    os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)
    
    print("=" * 80)
    print("TEST SET EVALUATION CONFIGURATION")
    print("=" * 80)
    print(f"Checkpoint:     {os.path.basename(CHECKPOINT)}")
    print(f"Split:          test")
    print(f"Episodes:       200 (comprehensive test evaluation)")
    print(f"Visualizations: 100 examples")
    print(f"Output dir:     {EVAL_OUTPUT_DIR}")
    print("=" * 80)
    print()


In [None]:
# Run evaluation on test set
if CHECKPOINT:
    # Change to project directory
    os.chdir(PROJECT_ROOT)
    
    # Build evaluation command
    # Evaluate with both 1-shot and 5-shot for comprehensive evaluation
    EVAL_SHOTS = [1, 5]  # Evaluate with both 1-shot and 5-shot
    
    for num_support in EVAL_SHOTS:
        eval_output_dir = os.path.join(EVAL_OUTPUT_DIR, f"{num_support}shot")
        os.makedirs(eval_output_dir, exist_ok=True)
        
        cmd = [
            sys.executable, "scripts/eval_cape_checkpoint.py",
            "--checkpoint", CHECKPOINT,
            "--split", "test",
            "--num-episodes", "200",  # Comprehensive test evaluation
            "--num-support-per-episode", str(num_support),  # K-shot evaluation
            "--num-visualizations", "100",  # Generate 100 visualizations
            "--pck-threshold", "0.2",
            "--output-dir", eval_output_dir,
            "--device", "cuda:0",
            "--num-workers", "2",
            "--draw-skeleton",  # Draw skeleton edges in visualizations
            "--save-all-queries"  # Save all queries in each episode
        ]
        
        print(f"\n{'=' * 80}")
        print(f"Running {num_support}-shot evaluation on test set...")
        print(f"{'=' * 80}")
        
        try:
            result = subprocess.run(cmd, check=True, capture_output=True, text=True)
            print(result.stdout)
            if result.stderr:
                print("Warnings/Errors:")
                print(result.stderr)
            
            # Check for metrics file
            metrics_file = os.path.join(eval_output_dir, "metrics_test.json")
            if os.path.exists(metrics_file):
                import json
                with open(metrics_file, 'r') as f:
                    metrics = json.load(f)
                
                # Format floats to 4 decimals; fall back to 'N/A' when a key is missing
                # (applying ':.4f' to the string 'N/A' would raise ValueError)
                fmt = lambda v: f"{v:.4f}" if isinstance(v, (int, float)) else "N/A"
                print(f"\n📊 {num_support}-shot Test Set Results:")
                print(f"  PCK@0.2 (Overall):     {fmt(metrics.get('pck_overall'))}")
                print(f"  Mean PCK (Categories):  {fmt(metrics.get('mean_pck_categories'))}")
                print(f"  Correct keypoints:      {metrics.get('total_correct', 'N/A')}")
                print(f"  Visible keypoints:     {metrics.get('total_visible', 'N/A')}")
                print(f"  Episodes evaluated:    {metrics.get('num_episodes', 'N/A')}")
                print(f"📄 Metrics saved to: {metrics_file}")
            
        except subprocess.CalledProcessError as e:
            print(f"\n❌ {num_support}-shot evaluation failed with return code: {e.returncode}")
            print("STDOUT:")
            print(e.stdout)
            print("\nSTDERR:")
            print(e.stderr)
    
    # Summary comparison
    print(f"\n{'=' * 80}")
    print("EVALUATION SUMMARY - Comparing 1-shot vs 5-shot")
    print("=" * 80)
    
    for num_support in EVAL_SHOTS:
        metrics_file = os.path.join(EVAL_OUTPUT_DIR, f"{num_support}shot", "metrics_test.json")
        if os.path.exists(metrics_file):
            import json
            with open(metrics_file, 'r') as f:
                metrics = json.load(f)
            pck = metrics.get('pck_overall', 0)
            print(f"  {num_support}-shot PCK@0.2: {pck:.4f}")
    
    print("=" * 80)
    
    print("\n💡 Evaluation Results:")
    print("   - 1-shot metrics: outputs/test_geometric_1/test_evaluation/1shot/metrics_test.json")
    print("   - 5-shot metrics: outputs/test_geometric_1/test_evaluation/5shot/metrics_test.json")
    print("   - Visualizations saved in respective subdirectories")
    print("\n💡 Next steps:")
    print("   1. Compare 1-shot vs 5-shot PCK scores")
    print("   2. Review per-category breakdowns in metrics JSON files")
    print("   3. Check visualizations to see qualitative differences")
    print("   4. Cross-category evaluation: Test set uses UNSEEN categories (20 categories)")
    print("      not seen during training (69 categories) or validation (10 categories)")
else:
    print("⚠️  Cannot run evaluation: No checkpoint found")


In [19]:
## 8. Check Training Results


In [20]:
# Check training results and find best checkpoint
import os
import glob

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")

print("Checking training results...")
print(f"Output directory: {OUTPUT_DIR}")
print()

# List all checkpoints
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "*.pth"))
if checkpoints:
    print(f"Found {len(checkpoints)} checkpoint(s):")
    for ckpt in sorted(checkpoints):
        size_mb = os.path.getsize(ckpt) / (1024 * 1024)
        print(f"  - {os.path.basename(ckpt)} ({size_mb:.1f} MB)")

    # Find best checkpoint
    best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
    if best_ckpts:
        # Names sort by zero-padded epoch, so the last one is the most recent
        best_ckpt = sorted(best_ckpts)[-1]
        print(f"\n✅ Best checkpoint: {os.path.basename(best_ckpt)}")
        BEST_CHECKPOINT = best_ckpt
    else:
        # Use most recent checkpoint
        best_ckpt = sorted(checkpoints, key=os.path.getmtime)[-1]
        print(f"\n⚠️  No 'best' checkpoint found, using most recent: {os.path.basename(best_ckpt)}")
        BEST_CHECKPOINT = best_ckpt
else:
    print("❌ No checkpoints found!")
    BEST_CHECKPOINT = None

# Show log file info
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")
if os.path.exists(LOG_FILE):
    size_mb = os.path.getsize(LOG_FILE) / (1024 * 1024)
    print(f"\n📄 Training log: {os.path.basename(LOG_FILE)} ({size_mb:.2f} MB)")

    # Show last few lines of log
    print("\nLast 20 lines of training log:")
    print("-" * 80)
    with open(LOG_FILE, 'r') as f:
        lines = f.readlines()
        for line in lines[-20:]:
            print(line.rstrip())
else:
    print("\n⚠️  Training log not found")

print("\n" + "=" * 80)


Checking training results...
Output directory: /content/category-agnostic-pose-estimation/output/single_image_colab

Found 56 checkpoint(s):
  - checkpoint_best_pck_e000_pck0.2222_meanpck0.2222.pth (580.8 MB)
  - checkpoint_best_pck_e000_pck0.3333_meanpck0.3333.pth (585.1 MB)
  - checkpoint_best_pck_e001_pck0.5556_meanpck0.5556.pth (580.8 MB)
  - checkpoint_best_pck_e009_pck0.7778_meanpck0.7778.pth (585.1 MB)
  - checkpoint_best_pck_e009_pck0.8889_meanpck0.8889.pth (580.8 MB)
  - checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth (585.1 MB)
  - checkpoint_e000_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e001_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e002_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e003_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e004_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e005_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e006_lr1e-04_bs1_acc4_qpe1.pth (580.8 MB)
  - checkpoint_e007_lr1e-04_bs1_acc4_qpe1.pth (580.8

## 9. Visualize Predictions

This section visualizes the model's predictions on the single training image,
using the best-PCK checkpoint saved during training.
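
The best checkpoint is chosen from the PCK value embedded in its filename (for example `checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth`). A minimal sketch of parsing those names follows. `parse_best_ckpt` and `BestCkpt` are hypothetical helpers, and breaking PCK ties in favor of the later epoch is an assumption, not something the training script is known to do.

```python
import re
from typing import NamedTuple, Optional

class BestCkpt(NamedTuple):
    epoch: int
    pck: float
    mean_pck: float

# Matches names like checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth
_PATTERN = re.compile(
    r"checkpoint_best_pck_e(\d+)_pck([\d.]+)_meanpck([\d.]+)\.pth$"
)

def parse_best_ckpt(name: str) -> Optional[BestCkpt]:
    """Return (epoch, pck, mean_pck) or None if the name is not a best-PCK checkpoint."""
    m = _PATTERN.search(name)
    if m is None:
        return None
    return BestCkpt(int(m.group(1)), float(m.group(2)), float(m.group(3)))

# Filenames taken from the training results listing in this notebook.
names = [
    "checkpoint_best_pck_e009_pck0.7778_meanpck0.7778.pth",
    "checkpoint_best_pck_e009_pck0.8889_meanpck0.8889.pth",
    "checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth",
    "checkpoint_e007_lr1e-04_bs1_acc4_qpe1.pth",
]
parsed = {n: parse_best_ckpt(n) for n in names}
candidates = [n for n in names if parsed[n] is not None]

# Highest PCK wins; ties go to the later epoch (assumed tie-break rule).
best = max(candidates, key=lambda n: (parsed[n].pck, parsed[n].epoch))
print(best)
```

Anchoring the pattern on the full filename keeps regular per-epoch checkpoints out of the candidate list instead of silently scoring them as 0.0.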

In [None]:
# Visualize predictions using the trained model
import os
import sys
import subprocess
from datetime import datetime
import glob

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")
VISUALIZATION_DIR = os.path.join(OUTPUT_DIR, "visualizations")

# Find best PCK checkpoint (highest PCK value)
# Checkpoint names are like: checkpoint_best_pck_e035_pck1.0000_meanpck1.0000.pth
best_pck_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
if not best_pck_ckpts:
    print("❌ No best PCK checkpoint found! Please run training first.")
    sys.exit(1)

def extract_pck(checkpoint_path):
    """Extract PCK value from checkpoint filename."""
    import re
    basename = os.path.basename(checkpoint_path)
    # Match pattern: checkpoint_best_pck_e035_pck1.0000_meanpck1.0000.pth
    match = re.search(r'_pck([\d.]+)_', basename)
    if match:
        return float(match.group(1))
    return 0.0

# Sort checkpoints by PCK value (highest PCK = best)
sorted_best_pck_ckpts = sorted(best_pck_ckpts, key=extract_pck, reverse=True)
CHECKPOINT = sorted_best_pck_ckpts[0]  # Best PCK checkpoint
best_pck_value = extract_pck(CHECKPOINT)

print(f"✅ Using best PCK checkpoint: {os.path.basename(CHECKPOINT)}")
print(f"   PCK value: {best_pck_value:.4f}")

# Show top 5 best PCK checkpoints for reference
print(f"\nTop best PCK checkpoints ({len(best_pck_ckpts)} total):")
for i, ckpt in enumerate(sorted_best_pck_ckpts[:5]):
    pck_val = extract_pck(ckpt)
    marker = " ← USING THIS" if ckpt == CHECKPOINT else ""
    print(f"  {i+1}. {os.path.basename(ckpt)} (PCK: {pck_val:.4f}){marker}")
if len(best_pck_ckpts) > 5:
    print(f"  ... and {len(best_pck_ckpts) - 5} more")
print()

# Use the SAME image that was used for training and validation
# Check if SINGLE_IMAGE_PATH was set in an earlier cell
DATA_DIR = os.path.join(PROJECT_ROOT, "data")

if 'SINGLE_IMAGE_PATH' not in globals() or SINGLE_IMAGE_PATH is None:
    print("⚠️  SINGLE_IMAGE_PATH not found in global scope!")
    print("   Trying to find it from training logs...")

    # Try to extract from training logs
    LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")
    if os.path.exists(LOG_FILE):
        with open(LOG_FILE, 'r') as f:
            log_content = f.read()
            # Look for image path in logs
            import re
            pattern = r'Training on SINGLE IMAGE with path: (.+)'
            match = re.search(pattern, log_content)
            if match:
                SINGLE_IMAGE_PATH = match.group(1).strip()
                print(f"   ✅ Found in logs: {SINGLE_IMAGE_PATH}")
            else:
                # Try alternative patterns
                pattern2 = r'Image path: (.+)'
                match2 = re.search(pattern2, log_content)
                if match2:
                    SINGLE_IMAGE_PATH = match2.group(1).strip()
                    print(f"   ✅ Found in logs (alt pattern): {SINGLE_IMAGE_PATH}")
                else:
                    print("   ❌ Could not find image path in logs.")
                    print("   Please run the 'Find Single Image' cell first.")
                    sys.exit(1)
    else:
        print("   ❌ Training logs not found.")
        print("   Please run the training cell first, or set SINGLE_IMAGE_PATH manually.")
        sys.exit(1)

# Handle relative image paths - try to make absolute if needed
if not os.path.isabs(SINGLE_IMAGE_PATH):
    # Try relative to data directory
    abs_path = os.path.join(DATA_DIR, SINGLE_IMAGE_PATH)
    if os.path.exists(abs_path):
        SINGLE_IMAGE_PATH = abs_path
        print(f"   ✅ Converted relative path to absolute: {SINGLE_IMAGE_PATH}")
    else:
        # Try relative to project root
        abs_path = os.path.join(PROJECT_ROOT, SINGLE_IMAGE_PATH)
        if os.path.exists(abs_path):
            SINGLE_IMAGE_PATH = abs_path
            print(f"   ✅ Converted relative path to absolute: {SINGLE_IMAGE_PATH}")

# Verify the image file exists
if not os.path.exists(SINGLE_IMAGE_PATH):
    print(f"⚠️  Image file not found: {SINGLE_IMAGE_PATH}")
    print("   Trying to find the image in the data directory...")
    
    # Try to find the image by filename in data directory
    image_filename = os.path.basename(SINGLE_IMAGE_PATH)
    for root, dirs, files in os.walk(DATA_DIR):
        if image_filename in files:
            SINGLE_IMAGE_PATH = os.path.join(root, image_filename)
            print(f"   ✅ Found image at: {SINGLE_IMAGE_PATH}")
            break
    else:
        print("   ❌ Could not find the image file.")
        print("   Please re-run the 'Find Single Image' cell to find a valid image.")
        sys.exit(1)

print("=" * 80)
print("CAPE Prediction Visualization")
print("=" * 80)
print(f"Checkpoint: {os.path.basename(CHECKPOINT)}")
print(f"Image:      {SINGLE_IMAGE_PATH}")
print(f"Output:     {VISUALIZATION_DIR}")
print("=" * 80)
print()

# Create visualization directory
os.makedirs(VISUALIZATION_DIR, exist_ok=True)

# Build visualization command
cmd = [
    sys.executable, "-m", "models.visualize_cape_predictions",
    "--checkpoint", CHECKPOINT,
    "--dataset_root", PROJECT_ROOT,
    "--device", "cuda",  # Use "cuda" not "cuda:0" - script only accepts cpu, cuda, or mps
    "--single_image_path", SINGLE_IMAGE_PATH,  # Visualize the same image used for training
    "--output_dir", VISUALIZATION_DIR
]

print("Running visualization...")
print(f"Command: {' '.join(cmd)}")
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Run visualization
try:
    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    print(result.stdout)
    if result.stderr:
        print("Warnings/Errors:")
        print(result.stderr)

    print("\n" + "=" * 80)
    print("✅ Visualization complete!")
    print(f"Results saved to: {VISUALIZATION_DIR}")
    print("=" * 80)

    # List generated visualizations
    vis_files = glob.glob(os.path.join(VISUALIZATION_DIR, "*.png"))
    if vis_files:
        print(f"\nGenerated {len(vis_files)} visualization(s):")
        for vis_file in sorted(vis_files):
            print(f"  - {os.path.basename(vis_file)}")
        print("\n💡 Visualizations show:")
        print("   - Query image (the single training image)")
        print("   - Ground truth keypoints (what the model should predict)")
        print("   - Predicted keypoints (what the model actually predicted)")
        print("   - PCK score (percentage of correctly predicted keypoints)")
    else:
        print("\n⚠️  No visualization files found. Check the script output above for errors.")

except subprocess.CalledProcessError as e:
    print(f"\n❌ Visualization failed with return code: {e.returncode}")
    print("STDOUT:")
    print(e.stdout)
    print("\nSTDERR:")
    print(e.stderr)
except Exception as e:
    print(f"\n❌ Error during visualization: {e}")
    import traceback
    traceback.print_exc()



✅ Using last checkpoint: checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth

Available checkpoints (56 total):
  - checkpoint_e045_lr1e-04_bs1_acc4_qpe1.pth (epoch 45)
  - checkpoint_e046_lr1e-04_bs1_acc4_qpe1.pth (epoch 46)
  - checkpoint_e047_lr1e-04_bs1_acc4_qpe1.pth (epoch 47)
  - checkpoint_e048_lr1e-04_bs1_acc4_qpe1.pth (epoch 48)
  - checkpoint_e049_lr1e-04_bs1_acc4_qpe1.pth (epoch 49)
  ... and 51 more

CAPE Prediction Visualization
Checkpoint: checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth
Image:      /content/category-agnostic-pose-estimation/data/camel_face/camel_16.jpg
Output:     /content/category-agnostic-pose-estimation/output/single_image_colab/visualizations

Running visualization...
Command: /usr/bin/python3 -m models.visualize_cape_predictions --checkpoint /content/category-agnostic-pose-estimation/output/single_image_colab/checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth --dataset_root /content/category-agnostic-pose-estimation --device cuda --single_i

In [None]:
# Download results from Colab
from google.colab import files
import os
import shutil
import zipfile
import glob

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")

print("Preparing results for download...")
print(f"Output directory: {OUTPUT_DIR}")
print()

# Create a zip file with all results
ZIP_FILE = "/content/single_image_training_results.zip"

with zipfile.ZipFile(ZIP_FILE, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add training logs
    log_file = os.path.join(OUTPUT_DIR, "training_logs.txt")
    if os.path.exists(log_file):
        zipf.write(log_file, "training_logs.txt")
        print("✓ Added training logs")

    # Add checkpoints (only best one to save space)
    best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
    if best_ckpts:
        best_ckpt = sorted(best_ckpts)[-1]
        zipf.write(best_ckpt, f"checkpoints/{os.path.basename(best_ckpt)}")
        print(f"✓ Added best checkpoint: {os.path.basename(best_ckpt)}")

    # Add visualizations
    vis_dir = os.path.join(OUTPUT_DIR, "visualizations")
    if os.path.exists(vis_dir):
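        # Use 'filenames' (not 'files') so the google.colab 'files' module
        # imported above is not shadowed before files.download() is called
        for root, dirs, filenames in os.walk(vis_dir):
            for filename in filenames:
                file_path = os.path.join(root, filename)
                arcname = os.path.relpath(file_path, OUTPUT_DIR)
                zipf.write(file_path, arcname)
        print("✓ Added visualizations")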

print(f"\n✅ Created zip file: {ZIP_FILE}")
print(f"Size: {os.path.getsize(ZIP_FILE) / (1024*1024):.2f} MB")
print("\nDownloading...")
files.download(ZIP_FILE)

print("\n✅ Download complete!")
print("\nThe zip file contains:")
print("  - training_logs.txt (full training output)")
print("  - checkpoints/ (best model checkpoint)")
print("  - visualizations/ (prediction visualizations)")
