# MP-100 CAPE Training on Google Colab

This notebook trains Category-Agnostic Pose Estimation (CAPE) on the MP-100 dataset using Google Colab's GPU.

## Setup Instructions
1. Enable GPU: Runtime ‚Üí Change runtime type ‚Üí GPU (T4 or better)
2. Run all cells in order
3. The notebook will:
   - Clone code from GitHub
   - Install dependencies
   - Authenticate to GCP
   - Mount GCS bucket with data
   - Run training with "tiny" mode


## 1. Check GPU Availability


In [17]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("‚ö†Ô∏è  No GPU detected! Please enable GPU in Runtime > Change runtime type > GPU")


CUDA available: True
GPU: NVIDIA A100-SXM4-40GB
CUDA version: 12.6
GPU Memory: 42.47 GB


In [19]:
# Clone repository from GitHub
import os
from getpass import getpass

REPO_URL = "https://github.com/nkkrnkl/category-agnostic-pose-estimation.git"
BRANCH = "pavlos-topic-copy"
PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Remove existing directory if it exists
if os.path.exists(PROJECT_ROOT):
    print(f"Removing existing directory: {PROJECT_ROOT}")
    !rm -rf {PROJECT_ROOT}

# For private repositories, you need to authenticate
# Option 1: Use Personal Access Token (recommended)
# Get token from: https://github.com/settings/tokens
# Create a token with 'repo' scope
print("For private repositories, you need to authenticate.")
print("Option 1: Enter your GitHub Personal Access Token")
print("  (Get one from: https://github.com/settings/tokens)")
print("Option 2: Press Enter to try without token (will fail if repo is private)")
print()

GITHUB_TOKEN = getpass("Enter GitHub Personal Access Token (or press Enter to skip): ")

if GITHUB_TOKEN.strip():
    # Use token in URL
    # Format: https://TOKEN@github.com/username/repo.git
    AUTH_REPO_URL = REPO_URL.replace("https://github.com/", f"https://{GITHUB_TOKEN}@github.com/")
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {AUTH_REPO_URL} {PROJECT_ROOT}
    !git pull origin {BRANCH}
else:
    # Try without token (will work if repo is public)
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {REPO_URL} {PROJECT_ROOT}

# Verify clone
if os.path.exists(PROJECT_ROOT) and os.path.exists(os.path.join(PROJECT_ROOT, ".git")):
    print(f"‚úÖ Repository cloned successfully to {PROJECT_ROOT}")
    !cd {PROJECT_ROOT} && git branch
else:
    print("‚ùå Failed to clone repository")
    print("\nIf the repository is private, you need to:")
    print("1. Create a Personal Access Token at: https://github.com/settings/tokens")
    print("2. Select 'repo' scope")
    print("3. Run this cell again and paste the token when prompted")


For private repositories, you need to authenticate.
Option 1: Enter your GitHub Personal Access Token
  (Get one from: https://github.com/settings/tokens)
Option 2: Press Enter to try without token (will fail if repo is private)

Enter GitHub Personal Access Token (or press Enter to skip): ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
Cloning repository from https://github.com/nkkrnkl/category-agnostic-pose-estimation.git (branch: pavlos-topic-copy)...
Cloning into '/content/category-agnostic-pose-estimation'...
remote: Enumerating objects: 1015, done.[K
remote: Counting objects: 100% (241/241), done.[K
remote: Compressing objects: 100% (187/187), done.[K
remote: Total 1015 (delta 75), reused 208 (delta 51), pack-reused 774 (from 2)[K
Receiving objects: 100% (1015/1015), 73.26 MiB | 15.66 MiB/s, done.
Resolving deltas: 100% (325/325), done.
Updating files: 100% (215/215), done.
remote: Enumerating objects: 13, done.[K
remote: Counting objects: 100% (13/13), done.[K
remote: Compressing objects: 100% (1/1)

In [20]:
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BRANCH = "pavlos-topic-copy"

print(f"Pulling latest changes from branch {BRANCH}...")
!cd {PROJECT_ROOT} && git pull origin {BRANCH}

print("‚úÖ Git pull complete!")

Pulling latest changes from branch pavlos-topic-copy...
From https://github.com/nkkrnkl/category-agnostic-pose-estimation
 * branch            pavlos-topic-copy -> FETCH_HEAD
Already up to date.
‚úÖ Git pull complete!


## 3. Install Requirements


In [21]:
# Install additional dependencies needed for plot_utils and other utilities
# (descartes, shapely, etc. - these are in requirements.txt but not requirements_cape.txt)
print("Installing additional dependencies (descartes, shapely, etc.)...")
!pip install -q descartes shapely>=1.8.0
print("‚úÖ Additional dependencies installed!")


Installing additional dependencies (descartes, shapely, etc.)...
‚úÖ Additional dependencies installed!


In [22]:
# Install requirements
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements_cape.txt")

print("Installing requirements from requirements_cape.txt...")
!cd {PROJECT_ROOT} && pip install -q -r {REQUIREMENTS_FILE}

# Install detectron2 for CUDA 11.8 (Colab typically has CUDA 11.8)
print("\nInstalling detectron2...")
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'

print("‚úÖ All dependencies installed!")


Installing requirements from requirements_cape.txt...

Installing detectron2...
  Preparing metadata (setup.py) ... [?25l[?25hdone
‚úÖ All dependencies installed!


## 4. Authenticate to GCP


In [23]:
# Authenticate to GCP
from google.colab import auth

print("Authenticating to GCP...")
auth.authenticate_user()

# Set GCP project
GCP_PROJECT = "dl-category-agnostic-pose-est"
!gcloud config set project {GCP_PROJECT}

print(f"‚úÖ Authenticated to GCP project: {GCP_PROJECT}")


Authenticating to GCP...
Are you sure you wish to set property [core/project] to 
dl-category-agnostic-pose-est?

Do you want to continue (Y/n)?  Y

Updated property [core/project].
‚úÖ Authenticated to GCP project: dl-category-agnostic-pose-est


## 5. Mount GCS Bucket


In [24]:
# Mount GCS bucket using gcsfuse
import os
import subprocess
import time

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BUCKET_NAME = "dl-category-agnostic-pose-mp100-data"
MOUNT_POINT = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")

# Install gcsfuse from Google's official repository
print("Installing gcsfuse...")
# Add Google's gcsfuse repository (updated method for newer Ubuntu versions)
!export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` && \
echo "deb http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
sudo apt-get update && \
sudo apt-get install -y gcsfuse

# Verify installation
!which gcsfuse
print("‚úÖ gcsfuse installed")

# Create mount point directory and parent directories
print(f"Creating mount point: {MOUNT_POINT}")
os.makedirs(os.path.dirname(MOUNT_POINT), exist_ok=True)
os.makedirs(MOUNT_POINT, exist_ok=True)

# Check if already mounted
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    if result.returncode == 0:
        print(f"‚úÖ Already mounted at {MOUNT_POINT}")
    else:
        # Try to unmount if exists but not properly mounted
        try:
            subprocess.run(['fusermount', '-u', MOUNT_POINT], capture_output=True, timeout=5)
        except:
            try:
                subprocess.run(['umount', MOUNT_POINT], capture_output=True, timeout=5)
            except:
                pass
except:
    pass

# Mount the bucket
print(f"Mounting gs://{BUCKET_NAME} to {MOUNT_POINT}...")
print("This may take a moment...")

# Run gcsfuse in background
# Note: In Colab, we need to run gcsfuse in background using shell &
print(f"Running: gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT}")
!nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &

# Wait a moment for mount to initialize
print("Waiting for mount to initialize...")
time.sleep(8)  # Give it more time to mount

# Check mount status
print("\nChecking mount status...")
# Check mount log for errors
if os.path.exists("/tmp/gcsfuse.log"):
    with open("/tmp/gcsfuse.log", "r") as f:
        log_content = f.read()
        if log_content:
            print("Mount log:")
            print(log_content[-500:])  # Last 500 chars
        else:
            print("Mount log is empty (mount might still be initializing)")

# Also verify we can access the bucket directly with gsutil
print("\nVerifying bucket access with gsutil...")
!gsutil ls gs://{BUCKET_NAME}/ | head -10

# Verify mount
print(f"\nVerifying mount at: {MOUNT_POINT}")
print(f"Path exists: {os.path.exists(MOUNT_POINT)}")

# Check if actually mounted using mountpoint command
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    is_mounted = (result.returncode == 0)
    print(f"Is mounted: {is_mounted}")
except:
    # Fallback: check mount table
    result = subprocess.run(['mount'], capture_output=True, text=True)
    is_mounted = MOUNT_POINT in result.stdout
    print(f"Is mounted (from mount table): {is_mounted}")

if os.path.exists(MOUNT_POINT) and is_mounted:
    try:
        # Try to list contents
        items = os.listdir(MOUNT_POINT)
        if len(items) > 0:
            print(f"‚úÖ GCS bucket mounted successfully!")
            print(f"Mount point: {MOUNT_POINT}")
            print(f"Found {len(items)} items in bucket")
            # List a few items to verify
            for item in items[:10]:
                item_path = os.path.join(MOUNT_POINT, item)
                item_type = "directory" if os.path.isdir(item_path) else "file"
                print(f"   - {item} ({item_type})")
        else:
            print(f"‚ö†Ô∏è  Mount point exists but is empty (0 items)")
            print(f"   This might indicate:")
            print(f"   1. Bucket is empty")
            print(f"   2. Mount didn't work correctly")
            print(f"   3. Permission issues")
    except PermissionError as e:
        print(f"‚ö†Ô∏è  Permission error accessing mount: {e}")
        print("   Mount might still be initializing, wait a moment and try again")
    except Exception as e:
        print(f"‚ö†Ô∏è  Mount point exists but cannot list contents: {e}")
        print("   This might indicate a mount issue")
        import traceback
        traceback.print_exc()
elif os.path.exists(MOUNT_POINT) and not is_mounted:
    print(f"‚ö†Ô∏è  Directory exists but is not mounted")
    print(f"   The directory exists but gcsfuse mount is not active")
    print(f"   Trying to mount again...")
    # Try mounting again
    !nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &
    time.sleep(5)
    # Re-check
    items = os.listdir(MOUNT_POINT) if os.path.exists(MOUNT_POINT) else []
    if len(items) > 0:
        print(f"‚úÖ Mount successful after retry! Found {len(items)} items")
    else:
        print(f"‚ùå Mount still not working")
else:
    print("‚ùå Failed to mount GCS bucket")
    print(f"   Mount point: {MOUNT_POINT}")
    print(f"   Check:")
    print(f"   1. GCP authentication (run the GCP auth cell)")
    print(f"   2. Bucket name is correct: {BUCKET_NAME}")
    print(f"   3. You have read access to the bucket")
    print(f"   4. Check mount log: /tmp/gcsfuse.log")


Installing gcsfuse...
deb http://packages.cloud.google.com/apt gcsfuse-jammy main
gpg: cannot open '/dev/tty': No such device or address
curl: (23) Failed writing body
/usr/bin/gcsfuse
‚úÖ gcsfuse installed
Creating mount point: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
Mounting gs://dl-category-agnostic-pose-mp100-data to /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data...
This may take a moment...
Running: gcsfuse --implicit-dirs dl-category-agnostic-pose-mp100-data /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
Waiting for mount to initialize...

Checking mount status...
Mount log:
:4,"RandomSeekThreshold":3,"StartBlocksPerHandle":1},"WorkloadInsight":{"ForwardMergeThresholdMb":0,"OutputFile":"","Visualize":false},"Write":{"BlockSizeMb":32,"CreateEmptyFile":false,"EnableRapidAppends":true,"EnableStreamingWrites":true,"FinalizeFileForRapid":false,"GlobalMaxBlocks":4,"MaxBlocksPerFile":1}}}
{"timestamp":{

## 6. Create Data Symlink


In [25]:
# Create symlink from data to mounted GCS bucket (as expected by START_TRAINING.sh)
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
MOUNTED_DATA = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")
DATA_SYMLINK = os.path.join(PROJECT_ROOT, "data")

print(f"Checking mount point: {MOUNTED_DATA}")
print(f"  Exists: {os.path.exists(MOUNTED_DATA)}")
if os.path.exists(MOUNTED_DATA):
    print(f"  Is directory: {os.path.isdir(MOUNTED_DATA)}")
    try:
        items = os.listdir(MOUNTED_DATA)
        print(f"  Can list contents: Yes ({len(items)} items)")
    except Exception as e:
        print(f"  Can list contents: No ({e})")

# Remove existing symlink or directory if it exists
if os.path.exists(DATA_SYMLINK):
    if os.path.islink(DATA_SYMLINK):
        print(f"Removing existing symlink: {DATA_SYMLINK}")
        os.unlink(DATA_SYMLINK)
    elif os.path.isdir(DATA_SYMLINK):
        print(f"Warning: {DATA_SYMLINK} exists as a directory (not a symlink)")
        print("   Removing it to create symlink...")
        import shutil
        shutil.rmtree(DATA_SYMLINK)
    else:
        print(f"Warning: {DATA_SYMLINK} exists and is not a symlink or directory")
        os.remove(DATA_SYMLINK)

# Create symlink
if os.path.exists(MOUNTED_DATA) and os.path.isdir(MOUNTED_DATA):
    try:
        # Use absolute path for symlink target
        MOUNTED_DATA_ABS = os.path.abspath(MOUNTED_DATA)
        print(f"\nCreating symlink:")
        print(f"  From: {DATA_SYMLINK}")
        print(f"  To: {MOUNTED_DATA_ABS}")
        os.symlink(MOUNTED_DATA_ABS, DATA_SYMLINK)
        print(f"‚úÖ Created symlink: {DATA_SYMLINK} -> {MOUNTED_DATA_ABS}")

        # Verify symlink
        if os.path.exists(DATA_SYMLINK):
            print(f"‚úÖ Symlink verified: {DATA_SYMLINK}")
            print(f"  Is symlink: {os.path.islink(DATA_SYMLINK)}")
            # Try to list contents through symlink
            try:
                items = os.listdir(DATA_SYMLINK)
                print(f"‚úÖ Can access {len(items)} items through symlink")
                print(f"   First 5 items: {items[:5]}")
            except Exception as e:
                print(f"‚ö†Ô∏è  Symlink exists but cannot access contents: {e}")
        else:
            print(f"‚ùå Symlink creation failed - path does not exist after creation")
    except Exception as e:
        print(f"‚ùå Error creating symlink: {e}")
        print(f"   Source: {MOUNTED_DATA}")
        print(f"   Target: {DATA_SYMLINK}")
        import traceback
        traceback.print_exc()
else:
    print(f"‚ùå Mounted data not found at {MOUNTED_DATA}")
    print(f"   Please check that GCS bucket is mounted correctly")
    print(f"   Run the mount cell above and check for errors")
    print(f"   Mount point should exist and be accessible")


Checking mount point: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
  Exists: True
  Is directory: True
  Can list contents: Yes (99 items)

Creating symlink:
  From: /content/category-agnostic-pose-estimation/data
  To: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
‚úÖ Created symlink: /content/category-agnostic-pose-estimation/data -> /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
‚úÖ Symlink verified: /content/category-agnostic-pose-estimation/data
  Is symlink: True
‚úÖ Can access 99 items through symlink
   First 5 items: ['amur_tiger_body', 'annotations', 'antelope_body', 'arcticwolf_face', 'beaver_body']


In [26]:
## 7. Run Single Image Training

# This section trains the model on a **single image** from a specific category for 20 epochs.
# This is useful for:
# - Quick overfitting test to verify the model can learn
# - Debugging the training pipeline
# - Testing on Colab GPU

# **Training Configuration:**
# - Single image mode: Enabled
# - Category: 40 (zebra) - you can change this
# - Epochs: 20
# - All logs will be saved to `training_logs.txt`


# Configure single image training parameters


In [27]:
# Configure single image training
import os
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Training configuration
SINGLE_IMAGE_CATEGORY = 40  # Category ID (40 = zebra, change if needed)
EPOCHS = 20
BATCH_SIZE = 1
NUM_QUERIES_PER_EPISODE = 1
EPISODES_PER_EPOCH = 20

# Output directories
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("=" * 80)
print("Single Image Training Configuration")
print("=" * 80)
print(f"Category ID:        {SINGLE_IMAGE_CATEGORY}")
print(f"Epochs:            {EPOCHS}")
print(f"Batch size:        {BATCH_SIZE}")
print(f"Queries/episode:   {NUM_QUERIES_PER_EPISODE}")
print(f"Episodes/epoch:    {EPISODES_PER_EPOCH}")
print(f"Output directory:  {OUTPUT_DIR}")
print(f"Log file:          {LOG_FILE}")
print("=" * 80)
print()



Single Image Training Configuration
Category ID:        40
Epochs:            20
Batch size:        1
Queries/episode:   1
Episodes/epoch:    20
Output directory:  /content/category-agnostic-pose-estimation/output/single_image_colab
Log file:          /content/category-agnostic-pose-estimation/output/single_image_colab/training_logs.txt



# Run Training with Logging


In [None]:
# Find first image from cleaned train annotations that exists in mounted data
import os
import json

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
ANNOTATION_FILE = os.path.join(PROJECT_ROOT, "annotations", "mp100_split1_train.json")

print("Finding first valid image from cleaned train annotations...")
print(f"Data directory: {DATA_DIR}")
print(f"Annotation file: {ANNOTATION_FILE}")
print()

# Load train annotations
if os.path.exists(ANNOTATION_FILE):
    with open(ANNOTATION_FILE, 'r') as f:
        coco_data = json.load(f)
    
    images = coco_data.get('images', [])
    print(f"Found {len(images)} images in train annotations")
    
    # Find first image that exists
    found_image = None
    for img_info in images:
        file_name = img_info['file_name']  # e.g., "camel_face/camel_133.jpg"
        full_path = os.path.join(DATA_DIR, file_name)
        
        if os.path.exists(full_path):
            found_image = full_path
            print(f"‚úÖ Found first valid image:")
            print(f"   File name: {file_name}")
            print(f"   Full path: {full_path}")
            print(f"   Image ID: {img_info.get('id', 'N/A')}")
            break
    
    if found_image:
        SINGLE_IMAGE_PATH = found_image
        print(f"\n‚úÖ Set SINGLE_IMAGE_PATH = {SINGLE_IMAGE_PATH}")
    else:
        print("‚ùå No images found in mounted data!")
        print("   Check that:")
        print("   1. GCS bucket is mounted correctly")
        print("   2. Data symlink is created")
        print("   3. Images exist in the data directory")
        SINGLE_IMAGE_PATH = None
else:
    print(f"‚ùå Annotation file not found: {ANNOTATION_FILE}")
    SINGLE_IMAGE_PATH = None


In [None]:
# SINGLE_IMAGE_PATH is now set automatically by the previous cell
# It finds the first image from cleaned train annotations that exists in mounted data
# If you want to use a specific image instead, uncomment and set it here:
# SINGLE_IMAGE_PATH = '/content/category-agnostic-pose-estimation/data/bison_body/000000001113.jpg'

# Verify SINGLE_IMAGE_PATH is set
if 'SINGLE_IMAGE_PATH' not in globals() or SINGLE_IMAGE_PATH is None:
    print("‚ö†Ô∏è  SINGLE_IMAGE_PATH not set! Please run the previous cell to find a valid image.")
else:
    print(f"‚úÖ Using image: {SINGLE_IMAGE_PATH}")

In [31]:
# Run training on single image with full logging
import subprocess
import sys
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")

# Build training command
cmd = [
    sys.executable, "-m", "models.train_cape_episodic",
    "--dataset_root", PROJECT_ROOT,
    "--category_split_file", os.path.join(PROJECT_ROOT, "category_splits.json"),
    "--output_dir", OUTPUT_DIR,
    "--device", "cuda:0",
]

# Add single image mode argument (path takes precedence over category)
if SINGLE_IMAGE_PATH:
    cmd.extend(["--debug_single_image_path", SINGLE_IMAGE_PATH])
elif 'SINGLE_IMAGE_CATEGORY' in globals() and SINGLE_IMAGE_CATEGORY is not None:
    cmd.extend(["--debug_single_image", str(SINGLE_IMAGE_CATEGORY)])

# Add remaining training arguments
cmd.extend([
    "--epochs", str(EPOCHS),
    "--batch_size", str(BATCH_SIZE),
    "--num_queries_per_episode", str(NUM_QUERIES_PER_EPISODE),
    "--episodes_per_epoch", str(EPISODES_PER_EPOCH),
    "--lr", "1e-4",
    "--lr_backbone", "1e-5",
    "--weight_decay", "1e-4",
    "--clip_max_norm", "0.1",
    "--support_encoder_layers", "3",
    "--support_fusion_method", "cross_attention",
    "--backbone", "resnet50",
    "--hidden_dim", "256",
    "--nheads", "8",
    "--enc_layers", "6",
    "--dec_layers", "6",
    "--dim_feedforward", "1024",
    "--dropout", "0.1",
    "--image_size", "256",
    "--vocab_size", "2000",
    "--seq_len", "200",
    "--num_queries", "200",
    "--num_polys", "1",
    "--cls_loss_coef", "2.0",
    "--coords_loss_coef", "5.0",
    "--room_cls_loss_coef", "0.0",
    "--semantic_classes", "70",
    "--num_feature_levels", "4",
    "--dec_n_points", "4",
    "--enc_n_points", "4",
    "--aux_loss",
    "--with_poly_refine",
    "--num_workers", "2",
    "--seed", "42",
    "--print_freq", "5",
    "--use_amp",  # Enable mixed precision for faster training
    "--cudnn_benchmark",  # Enable cuDNN benchmark
    "--job_name", f"single_image_colab_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
]
)

print("Starting training...")
print(f"Command: {' '.join(cmd)}")
print(f"Logging to: {LOG_FILE}")
print("=" * 80)
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Run training with logging to both stdout and file
with open(LOG_FILE, 'w') as log_file:
    # Write header to log file
    log_file.write("=" * 80 + "\n")
    log_file.write(f"Single Image Training Log\n")
    log_file.write(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    if 'SINGLE_IMAGE_PATH' in globals() and SINGLE_IMAGE_PATH:
        log_file.write(f"Image path: {SINGLE_IMAGE_PATH}\n")
    elif 'SINGLE_IMAGE_CATEGORY' in globals() and SINGLE_IMAGE_CATEGORY is not None:
        log_file.write(f"Category: {SINGLE_IMAGE_CATEGORY}\n")
    log_file.write(f"Epochs: {EPOCHS}\n")
    log_file.write("=" * 80 + "\n\n")
    log_file.flush()

    # Run process and stream output to both stdout and file
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )

    # Stream output in real-time
    for line in process.stdout:
        print(line, end='')  # Print to notebook
        log_file.write(line)  # Write to log file
        log_file.flush()  # Ensure immediate write

    # Wait for process to complete
    return_code = process.wait()

    # Write footer to log file
    log_file.write("\n" + "=" * 80 + "\n")
    log_file.write(f"Training completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    log_file.write(f"Return code: {return_code}\n")
    log_file.write("=" * 80 + "\n")

if return_code == 0:
    print("\n" + "=" * 80)
    print("‚úÖ Training completed successfully!")
    print(f"Checkpoints saved to: {OUTPUT_DIR}")
    print(f"Full logs saved to: {LOG_FILE}")
    print("=" * 80)
else:
    print("\n" + "=" * 80)
    print(f"‚ùå Training failed with return code: {return_code}")
    print(f"Check logs at: {LOG_FILE}")
    print("=" * 80)


Starting training...
Command: /usr/bin/python3 -m models.train_cape_episodic --dataset_root /content/category-agnostic-pose-estimation --category_split_file /content/category-agnostic-pose-estimation/category_splits.json --output_dir /content/category-agnostic-pose-estimation/output/single_image_colab --device cuda:0 --debug_single_image_path /content/category-agnostic-pose-estimation/data/bison_body/000000001113.jpg --epochs 20 --batch_size 1 --num_queries_per_episode 1 --episodes_per_epoch 20 --lr 1e-4 --lr_backbone 1e-5 --weight_decay 1e-4 --clip_max_norm 0.1 --support_encoder_layers 3 --support_fusion_method cross_attention --backbone resnet50 --hidden_dim 256 --nheads 8 --enc_layers 6 --dec_layers 6 --dim_feedforward 1024 --dropout 0.1 --image_size 256 --vocab_size 2000 --seq_len 200 --num_queries 200 --num_polys 1 --cls_loss_coef 2.0 --coords_loss_coef 5.0 --room_cls_loss_coef 0.0 --semantic_classes 70 --num_feature_levels 4 --dec_n_points 4 --enc_n_points 4 --aux_loss --with_pol

In [None]:
## 8. Check Training Results


In [None]:
# Check training results and find best checkpoint
import os
import glob

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")

print("Checking training results...")
print(f"Output directory: {OUTPUT_DIR}")
print()

# List all checkpoints
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "*.pth"))
if checkpoints:
    print(f"Found {len(checkpoints)} checkpoint(s):")
    for ckpt in sorted(checkpoints):
        size_mb = os.path.getsize(ckpt) / (1024 * 1024)
        print(f"  - {os.path.basename(ckpt)} ({size_mb:.1f} MB)")

    # Find best checkpoint
    best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
    if best_ckpts:
        best_ckpt = sorted(best_ckpts)[-1]  # Get most recent
        print(f"\n‚úÖ Best checkpoint: {os.path.basename(best_ckpt)}")
        BEST_CHECKPOINT = best_ckpt
    else:
        # Use most recent checkpoint
        best_ckpt = sorted(checkpoints, key=os.path.getmtime)[-1]
        print(f"\n‚ö†Ô∏è  No 'best' checkpoint found, using most recent: {os.path.basename(best_ckpt)}")
        BEST_CHECKPOINT = best_ckpt
else:
    print("‚ùå No checkpoints found!")
    BEST_CHECKPOINT = None

# Show log file info
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")
if os.path.exists(LOG_FILE):
    size_mb = os.path.getsize(LOG_FILE) / (1024 * 1024)
    print(f"\nüìÑ Training log: {os.path.basename(LOG_FILE)} ({size_mb:.2f} MB)")

    # Show last few lines of log
    print("\nLast 20 lines of training log:")
    print("-" * 80)
    with open(LOG_FILE, 'r') as f:
        lines = f.readlines()
        for line in lines[-20:]:
            print(line.rstrip())
else:
    print("\n‚ö†Ô∏è  Training log not found")

print("\n" + "=" * 80)


## 9. Visualize Predictions

This section visualizes the model's predictions on the single training image.
We'll use the trained checkpoint to generate predictions and visualize them.

In [None]:
# Visualize predictions using the trained model
import os
import sys
import subprocess
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")
VISUALIZATION_DIR = os.path.join(OUTPUT_DIR, "visualizations")

# Find best checkpoint
import glob
best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
if not best_ckpts:
    # Try any checkpoint
    all_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "*.pth"))
    if all_ckpts:
        best_ckpts = [sorted(all_ckpts, key=os.path.getmtime)[-1]]

if not best_ckpts:
    print("‚ùå No checkpoint found! Please run training first.")
    sys.exit(1)

CHECKPOINT = sorted(best_ckpts)[-1]

# Single image to visualize (the overfitted training image)
SINGLE_IMAGE_PATH = "/content/category-agnostic-pose-estimation/data/camel_face/camel_133.jpg"

print("=" * 80)
print("CAPE Prediction Visualization")
print("=" * 80)
print(f"Checkpoint: {os.path.basename(CHECKPOINT)}")
print(f"Image:      {SINGLE_IMAGE_PATH}")
print(f"Output:     {VISUALIZATION_DIR}")
print("=" * 80)
print()

# Create visualization directory
os.makedirs(VISUALIZATION_DIR, exist_ok=True)

# Build visualization command
cmd = [
    sys.executable, "-m", "models.visualize_cape_predictions",
    "--checkpoint", CHECKPOINT,
    "--dataset_root", PROJECT_ROOT,
    "--device", "cuda",  # Use "cuda" not "cuda:0" - script only accepts cpu, cuda, or mps
    "--single_image_path", SINGLE_IMAGE_PATH,  # Visualize specific overfitted image
    "--output_dir", VISUALIZATION_DIR
]

print("Running visualization...")
print(f"Command: {' '.join(cmd)}")
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Run visualization
try:
    result = subprocess.run(cmd, check=True, capture_output=True, text=True)
    print(result.stdout)
    if result.stderr:
        print("Warnings/Errors:")
        print(result.stderr)

    print("\n" + "=" * 80)
    print("‚úÖ Visualization complete!")
    print(f"Results saved to: {VISUALIZATION_DIR}")
    print("=" * 80)

    # List generated visualizations
    vis_files = glob.glob(os.path.join(VISUALIZATION_DIR, "*.png"))
    if vis_files:
        print(f"\nGenerated {len(vis_files)} visualization(s):")
        for vis_file in sorted(vis_files):
            print(f"  - {os.path.basename(vis_file)}")

except subprocess.CalledProcessError as e:
    print(f"\n‚ùå Visualization failed with return code: {e.returncode}")
    print("STDOUT:")
    print(e.stdout)
    print("\nSTDERR:")
    print(e.stderr)
except Exception as e:
    print(f"\n‚ùå Error during visualization: {e}")
    import traceback
    traceback.print_exc()


In [None]:
# Download results from Colab
from google.colab import files
import os
import shutil
import zipfile
import glob

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "single_image_colab")

print("Preparing results for download...")
print(f"Output directory: {OUTPUT_DIR}")
print()

# Create a zip file with all results
ZIP_FILE = "/content/single_image_training_results.zip"

with zipfile.ZipFile(ZIP_FILE, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add training logs
    log_file = os.path.join(OUTPUT_DIR, "training_logs.txt")
    if os.path.exists(log_file):
        zipf.write(log_file, "training_logs.txt")
        print(f"‚úì Added training logs")

    # Add checkpoints (only best one to save space)
    best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
    if best_ckpts:
        best_ckpt = sorted(best_ckpts)[-1]
        zipf.write(best_ckpt, f"checkpoints/{os.path.basename(best_ckpt)}")
        print(f"‚úì Added best checkpoint: {os.path.basename(best_ckpt)}")

    # Add visualizations
    vis_dir = os.path.join(OUTPUT_DIR, "visualizations")
    if os.path.exists(vis_dir):
        for root, dirs, files in os.walk(vis_dir):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, OUTPUT_DIR)
                zipf.write(file_path, arcname)
        print(f"‚úì Added visualizations")

print(f"\n‚úÖ Created zip file: {ZIP_FILE}")
print(f"Size: {os.path.getsize(ZIP_FILE) / (1024*1024):.2f} MB")
print("\nDownloading...")
files.download(ZIP_FILE)

print("\n‚úÖ Download complete!")
print("\nThe zip file contains:")
print("  - training_logs.txt (full training output)")
print("  - checkpoints/ (best model checkpoint)")
print("  - visualizations/ (prediction visualizations)")
