# üöÄ MP-100 CAPE Training on Google Colab

This notebook trains Category-Agnostic Pose Estimation (CAPE) on the full MP-100 dataset using Google Colab's GPU.

## üìã Setup Instructions
1. **Enable GPU**: Runtime ‚Üí Change runtime type ‚Üí GPU (A100 recommended, T4/V100 works)
2. **Run all cells in order** (takes ~8-12 hours)
3. The notebook will:
   - Clone code from GitHub (`pavlos-topic` branch with CUDA fix)
   - Install dependencies
   - Authenticate to GCP
   - Mount GCS bucket with MP-100 dataset (read-only)
   - Mount Google Drive for checkpoint storage (persistent)
   - Run full CAPE training (300 epochs with early stopping)

## üíæ Data Storage Strategy
- **Input Data**: GCS Bucket `dl-category-agnostic-pose-mp100-data` (read-only, ~100 categories)
- **Output Checkpoints**: Google Drive `MyDrive/cape_training_output/` (persistent, survives Colab session)
- **Why?**: GCS for large shared dataset, Google Drive for your personal model files

## üì¶ What Gets Saved
All files are saved to **Google Drive** (`MyDrive/cape_training_output/`):
- ‚úÖ `checkpoint_e###_*.pth` - Every epoch (for resume/analysis)
- ‚úÖ `checkpoint_best_pck_*.pth` - Best validation PCK model (for evaluation)
- ‚úÖ `training_logs.txt` - Complete training output
- ‚úÖ `TRAINING_SUMMARY.txt` - Quick summary with best model info

**‚Üí You can download these from Google Drive after training completes!**


## 1. Check GPU Availability


In [1]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("‚ö†Ô∏è  No GPU detected! Please enable GPU in Runtime > Change runtime type > GPU")


CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
CUDA version: 12.6
GPU Memory: 85.17 GB


In [2]:
# Clone repository from GitHub
import os
import subprocess
from getpass import getpass

REPO_URL = "https://github.com/nkkrnkl/category-agnostic-pose-estimation.git"
BRANCH = "main"
PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Remove existing directory if it exists (SAFE method - won't hang)
if os.path.exists(PROJECT_ROOT):
    print(f"Directory exists: {PROJECT_ROOT}")
    print("Cleaning up safely...")

    # Step 1: Remove symlinks first (prevent rm -rf from following them into GCS mounts)
    try:
        result = subprocess.run(
            ['find', PROJECT_ROOT, '-maxdepth', '2', '-type', 'l', '-delete'],
            timeout=10,
            capture_output=True,
            text=True
        )
        print("  ‚úì Removed symlinks")
    except subprocess.TimeoutExpired:
        print("  ‚ö†Ô∏è  Symlink removal timed out (continuing anyway)")
    except Exception as e:
        print(f"  ‚ö†Ô∏è  Symlink removal warning: {e}")

    # Step 2: Now safely remove directory with timeout
    try:
        result = subprocess.run(
            ['rm', '-rf', PROJECT_ROOT],
            timeout=30,  # 30 second timeout prevents hanging
            capture_output=True,
            text=True,
            check=True
        )
        print(f"  ‚úì Removed {PROJECT_ROOT}")
    except subprocess.TimeoutExpired:
        print("  ‚ö†Ô∏è  rm -rf timed out, using Python fallback...")
        import shutil
        shutil.rmtree(PROJECT_ROOT, ignore_errors=True)
        print(f"  ‚úì Removed with shutil")
    except Exception as e:
        print(f"  ‚ö†Ô∏è  Removal error: {e}")
        print("  ‚Üí Attempting to continue anyway")

# For private repositories, you need to authenticate
# Option 1: Use Personal Access Token (recommended)
# Get token from: https://github.com/settings/tokens
# Create a token with 'repo' scope
print("For private repositories, you need to authenticate.")
print("Option 1: Enter your GitHub Personal Access Token")
print("  (Get one from: https://github.com/settings/tokens)")
print("Option 2: Press Enter to try without token (will fail if repo is private)")
print()

GITHUB_TOKEN = getpass("Enter GitHub Personal Access Token (or press Enter to skip): ")

if GITHUB_TOKEN.strip():
    # Use token in URL
    # Format: https://TOKEN@github.com/username/repo.git
    AUTH_REPO_URL = REPO_URL.replace("https://github.com/", f"https://{GITHUB_TOKEN}@github.com/")
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {AUTH_REPO_URL} {PROJECT_ROOT}
    !git pull origin {BRANCH}
else:
    # Try without token (will work if repo is public)
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {REPO_URL} {PROJECT_ROOT}

# Verify clone
if os.path.exists(PROJECT_ROOT) and os.path.exists(os.path.join(PROJECT_ROOT, ".git")):
    print(f"‚úÖ Repository cloned successfully to {PROJECT_ROOT}")
    !cd {PROJECT_ROOT} && git branch
else:
    print("‚ùå Failed to clone repository")
    print("\nIf the repository is private, you need to:")
    print("1. Create a Personal Access Token at: https://github.com/settings/tokens")
    print("2. Select 'repo' scope")
    print("3. Run this cell again and paste the token when prompted")


For private repositories, you need to authenticate.
Option 1: Enter your GitHub Personal Access Token
  (Get one from: https://github.com/settings/tokens)
Option 2: Press Enter to try without token (will fail if repo is private)

Enter GitHub Personal Access Token (or press Enter to skip): ¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑¬∑
Cloning repository from https://github.com/nkkrnkl/category-agnostic-pose-estimation.git (branch: main)...
Cloning into '/content/category-agnostic-pose-estimation'...
remote: Enumerating objects: 1428, done.[K
remote: Counting objects: 100% (137/137), done.[K
remote: Compressing objects: 100% (99/99), done.[K
remote: Total 1428 (delta 66), reused 79 (delta 38), pack-reused 1291 (from 1)[K
Receiving objects: 100% (1428/1428), 73.80 MiB | 17.17 MiB/s, done.
Resolving deltas: 100% (610/610), done.
fatal: not a git repository (or any of the parent directories): .git
‚úÖ Repository cloned successfully to /content/category-agnostic-pose-estimation
* [32mmain[m


In [41]:
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BRANCH = "main"

print(f"Pulling latest changes from branch {BRANCH}...")
!cd {PROJECT_ROOT} && git pull origin {BRANCH}

print("‚úÖ Git pull complete!")

Pulling latest changes from branch main...
remote: Enumerating objects: 5, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (1/1), done.[K
remote: Total 3 (delta 2), reused 3 (delta 2), pack-reused 0 (from 0)[K
Unpacking objects: 100% (3/3), 1.26 KiB | 117.00 KiB/s, done.
From https://github.com/nkkrnkl/category-agnostic-pose-estimation
 * branch            main       -> FETCH_HEAD
   7272279..a0a9d49  main       -> origin/main
Updating d33076e..a0a9d49
error: Your local changes to the following files would be overwritten by merge:
	datasets/episodic_sampler.py
Please commit your changes or stash them before you merge.
Aborting
‚úÖ Git pull complete!


## 2.5 Apply CUDA Fixes (Important!)

The GitHub repo may not have the latest CUDA-specific fixes. This cell patches two critical files:

1. **`geometric_support_encoder.py`**: Adds safety check for all-masked batches (prevents `to_padded_tensor` crash)
2. **`episodic_sampler.py`**: Fixes mask convention (True=ignore, False=use)

**Why is this needed?**
- CUDA uses nested tensor optimization that crashes when all keypoints are masked
- MPS (Mac) doesn't use this optimization, so it works locally but crashes on Colab
- Early EOS prediction can cause temporary all-masked batches in epoch 1


In [4]:
# ============================================================================
# üîß CRITICAL: Apply CUDA Nested Tensor Fixes
# ============================================================================
# These patches fix PyTorch nested tensor issues that crash on CUDA when
# all keypoints in a batch are masked (happens with early EOS prediction).
#
# Fixes applied:
# 1. geometric_support_encoder.py: Add safety check for all-masked batches
# 2. episodic_sampler.py: Correct mask convention (True=ignore, False=use)
# ============================================================================

import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# ============================================================================
# Fix 1: Add safety check to geometric_support_encoder.py
# ============================================================================
GEOMETRIC_ENCODER_FILE = os.path.join(PROJECT_ROOT, "models/geometric_support_encoder.py")

# Read the file
with open(GEOMETRIC_ENCODER_FILE, 'r') as f:
    content = f.read()

# Check if the fix is already applied
if "CRITICAL SAFETY CHECK" not in content:
    print("Applying Fix 1: Adding safety check to geometric_support_encoder.py...")

    # Find the location to insert the safety check
    # We need to modify the forward method where self.transformer_encoder is called
    old_code = '''        # 6. Transformer self-attention
        support_features = self.transformer_encoder(
            embeddings,
            src_key_padding_mask=support_mask
        )'''

    new_code = '''        # 6. Transformer self-attention
        # support_mask: True = positions to ignore (mask out)
        # PyTorch convention: True = ignore

        # CRITICAL SAFETY CHECK: Handle edge case where ALL keypoints are masked
        # This can happen with invalid data where all keypoints have visibility==0
        # PyTorch's nested tensor conversion fails when all elements are masked
        # Check if any batch element has all keypoints masked (all True)
        all_masked_per_batch = support_mask.all(dim=1)  # [bs]

        if all_masked_per_batch.any():
            # At least one batch element has all keypoints masked (invalid data)
            # This causes nested tensor conversion to fail
            # Workaround: temporarily unmask the first keypoint for those batches
            temp_mask = support_mask.clone()
            for b in range(support_mask.shape[0]):
                if all_masked_per_batch[b]:
                    # Unmask the first keypoint to avoid nested tensor error
                    # (This shouldn't happen with proper data validation, but we handle it gracefully)
                    temp_mask[b, 0] = False

            support_features = self.transformer_encoder(
                embeddings,
                src_key_padding_mask=temp_mask
            )

            # Zero out features for fully-masked batches (they contain invalid data)
            support_features[all_masked_per_batch] = 0.0
        else:
            # Normal case - process as usual
            support_features = self.transformer_encoder(
                embeddings,
                src_key_padding_mask=support_mask
            )'''

    if old_code in content:
        content = content.replace(old_code, new_code)
        with open(GEOMETRIC_ENCODER_FILE, 'w') as f:
            f.write(content)
        print("  ‚úÖ Fix 1 applied successfully!")
    else:
        # Try alternative pattern (might have slight differences)
        print("  ‚ö†Ô∏è  Could not find exact pattern. Checking if fix is already present...")
        if "all_masked_per_batch" in content:
            print("  ‚úÖ Fix appears to be already applied!")
        else:
            print("  ‚ùå Could not apply fix - pattern not found. Manual intervention needed.")
else:
    print("Fix 1: Safety check already present in geometric_support_encoder.py ‚úÖ")

# ============================================================================
# Fix 2: Correct mask convention in episodic_sampler.py
# ============================================================================
EPISODIC_SAMPLER_FILE = os.path.join(PROJECT_ROOT, "datasets/episodic_sampler.py")

with open(EPISODIC_SAMPLER_FILE, 'r') as f:
    sampler_content = f.read()

# Check if the WRONG mask convention is present
if "[v > 0 for v in support_visibility]" in sampler_content:
    print("Applying Fix 2: Correcting mask convention in episodic_sampler.py...")

    # Fix the mask convention: True should mean "ignore", not "visible"
    old_mask_code = "[v > 0 for v in support_visibility]"
    new_mask_code = "[v == 0 for v in support_visibility]"

    sampler_content = sampler_content.replace(old_mask_code, new_mask_code)

    # Also add a comment explaining the convention
    if "# CRITICAL: Mask should be True=ignore" not in sampler_content:
        sampler_content = sampler_content.replace(
            "support_mask = torch.tensor(",
            "# CRITICAL: Mask should be True=ignore, False=use\n        # So we want True when visibility == 0 (not labeled)\n        support_mask = torch.tensor("
        )

    with open(EPISODIC_SAMPLER_FILE, 'w') as f:
        f.write(sampler_content)
    print("  ‚úÖ Fix 2 applied successfully!")
elif "[v == 0 for v in support_visibility]" in sampler_content:
    print("Fix 2: Correct mask convention already in episodic_sampler.py ‚úÖ")
else:
    print("  ‚ö†Ô∏è  Could not find mask pattern in episodic_sampler.py - manual check needed")

# ============================================================================
# Verification: Check both fixes are present
# ============================================================================
print("\n" + "="*60)
print("VERIFICATION:")
print("="*60)

# Re-read files and verify
with open(GEOMETRIC_ENCODER_FILE, 'r') as f:
    ge_content = f.read()

with open(EPISODIC_SAMPLER_FILE, 'r') as f:
    es_content = f.read()

fix1_ok = "CRITICAL SAFETY CHECK" in ge_content or "all_masked_per_batch" in ge_content
fix2_ok = "[v == 0 for v in support_visibility]" in es_content

print(f"  Fix 1 (Safety check in geometric_support_encoder.py): {'‚úÖ OK' if fix1_ok else '‚ùå MISSING'}")
print(f"  Fix 2 (Mask convention in episodic_sampler.py):       {'‚úÖ OK' if fix2_ok else '‚ùå MISSING'}")

if fix1_ok and fix2_ok:
    print("\n‚úÖ All CUDA fixes applied successfully!")
    print("   Training should now work on Colab with A100/V100/T4 GPUs.")
else:
    print("\n‚ö†Ô∏è  Some fixes may be missing. Check the output above.")
    print("   If training fails with 'to_padded_tensor' error, manually apply the fixes.")

print("="*60)


Fix 1: Safety check already present in geometric_support_encoder.py ‚úÖ
Applying Fix 2: Correcting mask convention in episodic_sampler.py...
  ‚úÖ Fix 2 applied successfully!

VERIFICATION:
  Fix 1 (Safety check in geometric_support_encoder.py): ‚úÖ OK
  Fix 2 (Mask convention in episodic_sampler.py):       ‚úÖ OK

‚úÖ All CUDA fixes applied successfully!
   Training should now work on Colab with A100/V100/T4 GPUs.


## 3. Install Requirements


In [5]:
# Install additional dependencies needed for plot_utils and other utilities
# (descartes, shapely, etc. - these are in requirements.txt but not requirements_cape.txt)
print("Installing additional dependencies (descartes, shapely, etc.)...")
!pip install -q descartes shapely>=1.8.0
print("‚úÖ Additional dependencies installed!")


Installing additional dependencies (descartes, shapely, etc.)...
‚úÖ Additional dependencies installed!


In [6]:
# Install requirements
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements_cape.txt")

print("Installing requirements from requirements_cape.txt...")
!cd {PROJECT_ROOT} && pip install -q -r {REQUIREMENTS_FILE}

# Install detectron2 for CUDA 11.8 (Colab typically has CUDA 11.8)
print("\nInstalling detectron2...")
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'

print("‚úÖ All dependencies installed!")


Installing requirements from requirements_cape.txt...
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m85.2/85.2 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.7/1.7 MB[0m [31m96.4 MB/s[0m eta [36m0:00:00[0m
[?25h
Installing detectron2...
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m50.2/50.2 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m154.5/154.5 kB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[?2

## 4. Authenticate to GCP


In [7]:
# Authenticate to GCP
from google.colab import auth

print("Authenticating to GCP...")
auth.authenticate_user()

# Set GCP project
GCP_PROJECT = "dl-category-agnostic-pose-est"
!gcloud config set project {GCP_PROJECT}

print(f"‚úÖ Authenticated to GCP project: {GCP_PROJECT}")


Authenticating to GCP...
Are you sure you wish to set property [core/project] to 
dl-category-agnostic-pose-est?

Do you want to continue (Y/n)?  Y

Updated property [core/project].
‚úÖ Authenticated to GCP project: dl-category-agnostic-pose-est


## 5. Mount GCS Bucket


In [16]:
# Mount GCS bucket using gcsfuse
import os
import subprocess
import time

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BUCKET_NAME = "dl-category-agnostic-pose-mp100-data"
MOUNT_POINT = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")

# Install gcsfuse from Google's official repository
print("Installing gcsfuse...")
# Add Google's gcsfuse repository (updated method for newer Ubuntu versions)
!export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` && \
echo "deb http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
sudo apt-get update && \
sudo apt-get install -y gcsfuse

# Verify installation
!which gcsfuse
print("‚úÖ gcsfuse installed")

# Create mount point directory and parent directories
print(f"Creating mount point: {MOUNT_POINT}")
os.makedirs(os.path.dirname(MOUNT_POINT), exist_ok=True)
os.makedirs(MOUNT_POINT, exist_ok=True)

# Check if already mounted
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    if result.returncode == 0:
        print(f"‚úÖ Already mounted at {MOUNT_POINT}")
    else:
        # Try to unmount if exists but not properly mounted
        try:
            subprocess.run(['fusermount', '-u', MOUNT_POINT], capture_output=True, timeout=5)
        except:
            try:
                subprocess.run(['umount', MOUNT_POINT], capture_output=True, timeout=5)
            except:
                pass
except:
    pass

# Mount the bucket
print(f"Mounting gs://{BUCKET_NAME} to {MOUNT_POINT}...")
print("This may take a moment...")

# Run gcsfuse in background
# Note: In Colab, we need to run gcsfuse in background using shell &
print(f"Running: gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT}")
!nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &

# Wait a moment for mount to initialize
print("Waiting for mount to initialize...")
time.sleep(8)  # Give it more time to mount

# Check mount status
print("\nChecking mount status...")
# Check mount log for errors
if os.path.exists("/tmp/gcsfuse.log"):
    with open("/tmp/gcsfuse.log", "r") as f:
        log_content = f.read()
        if log_content:
            print("Mount log:")
            print(log_content[-500:])  # Last 500 chars
        else:
            print("Mount log is empty (mount might still be initializing)")

# Also verify we can access the bucket directly with gsutil
print("\nVerifying bucket access with gsutil...")
!gsutil ls gs://{BUCKET_NAME}/ | head -10

# Verify mount
print(f"\nVerifying mount at: {MOUNT_POINT}")
print(f"Path exists: {os.path.exists(MOUNT_POINT)}")

# Check if actually mounted using mountpoint command
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    is_mounted = (result.returncode == 0)
    print(f"Is mounted: {is_mounted}")
except:
    # Fallback: check mount table
    result = subprocess.run(['mount'], capture_output=True, text=True)
    is_mounted = MOUNT_POINT in result.stdout
    print(f"Is mounted (from mount table): {is_mounted}")

if os.path.exists(MOUNT_POINT) and is_mounted:
    try:
        # Try to list contents
        items = os.listdir(MOUNT_POINT)
        if len(items) > 0:
            print(f"‚úÖ GCS bucket mounted successfully!")
            print(f"Mount point: {MOUNT_POINT}")
            print(f"Found {len(items)} items in bucket")
            # List a few items to verify
            for item in items[:10]:
                item_path = os.path.join(MOUNT_POINT, item)
                item_type = "directory" if os.path.isdir(item_path) else "file"
                print(f"   - {item} ({item_type})")
        else:
            print(f"‚ö†Ô∏è  Mount point exists but is empty (0 items)")
            print(f"   This might indicate:")
            print(f"   1. Bucket is empty")
            print(f"   2. Mount didn't work correctly")
            print(f"   3. Permission issues")
    except PermissionError as e:
        print(f"‚ö†Ô∏è  Permission error accessing mount: {e}")
        print("   Mount might still be initializing, wait a moment and try again")
    except Exception as e:
        print(f"‚ö†Ô∏è  Mount point exists but cannot list contents: {e}")
        print("   This might indicate a mount issue")
        import traceback
        traceback.print_exc()
elif os.path.exists(MOUNT_POINT) and not is_mounted:
    print(f"‚ö†Ô∏è  Directory exists but is not mounted")
    print(f"   The directory exists but gcsfuse mount is not active")
    print(f"   Trying to mount again...")
    # Try mounting again
    !nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &
    time.sleep(5)
    # Re-check
    items = os.listdir(MOUNT_POINT) if os.path.exists(MOUNT_POINT) else []
    if len(items) > 0:
        print(f"‚úÖ Mount successful after retry! Found {len(items)} items")
    else:
        print(f"‚ùå Mount still not working")
else:
    print("‚ùå Failed to mount GCS bucket")
    print(f"   Mount point: {MOUNT_POINT}")
    print(f"   Check:")
    print(f"   1. GCP authentication (run the GCP auth cell)")
    print(f"   2. Bucket name is correct: {BUCKET_NAME}")
    print(f"   3. You have read access to the bucket")
    print(f"   4. Check mount log: /tmp/gcsfuse.log")


Installing gcsfuse...
deb http://packages.cloud.google.com/apt gcsfuse-jammy main
gpg: cannot open '/dev/tty': No such device or address
curl: (23) Failed writing body
/usr/bin/gcsfuse
‚úÖ gcsfuse installed
Creating mount point: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
‚úÖ Already mounted at /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
Mounting gs://dl-category-agnostic-pose-mp100-data to /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data...
This may take a moment...
Running: gcsfuse --implicit-dirs dl-category-agnostic-pose-mp100-data /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
Waiting for mount to initialize...

Checking mount status...
Mount log:
":4,"RandomSeekThreshold":3,"StartBlocksPerHandle":1},"WorkloadInsight":{"ForwardMergeThresholdMb":0,"OutputFile":"","Visualize":false},"Write":{"BlockSizeMb":32,"CreateEmptyFile":false,"EnableRapidAppends":true,"EnableStreamingWri

## 6. Create Data Symlink


In [29]:
# Create symlinks from data and annotations to mounted GCS bucket (as expected by START_TRAINING.sh)
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
MOUNTED_DATA = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")
MOUNTED_ANNOTATIONS = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "annotations")
DATA_SYMLINK = os.path.join(PROJECT_ROOT, "data")
ANNOTATIONS_SYMLINK = os.path.join(PROJECT_ROOT, "annotations")

print(f"Checking mount point: {MOUNTED_DATA}")
print(f"  Exists: {os.path.exists(MOUNTED_DATA)}")
if os.path.exists(MOUNTED_DATA):
    print(f"  Is directory: {os.path.isdir(MOUNTED_DATA)}")
    try:
        items = os.listdir(MOUNTED_DATA)
        print(f"  Can list contents: Yes ({len(items)} items)")
    except Exception as e:
        print(f"  Can list contents: No ({e})")

# Remove existing symlink or directory if it exists
if os.path.exists(DATA_SYMLINK):
    if os.path.islink(DATA_SYMLINK):
        print(f"Removing existing symlink: {DATA_SYMLINK}")
        os.unlink(DATA_SYMLINK)
    elif os.path.isdir(DATA_SYMLINK):
        print(f"Warning: {DATA_SYMLINK} exists as a directory (not a symlink)")
        print("   Removing it to create symlink...")
        import shutil
        shutil.rmtree(DATA_SYMLINK)
    else:
        print(f"Warning: {DATA_SYMLINK} exists and is not a symlink or directory")
        os.remove(DATA_SYMLINK)

# Create symlink
if os.path.exists(MOUNTED_DATA) and os.path.isdir(MOUNTED_DATA):
    try:
        # Use absolute path for symlink target
        MOUNTED_DATA_ABS = os.path.abspath(MOUNTED_DATA)
        print(f"\nCreating symlink:")
        print(f"  From: {DATA_SYMLINK}")
        print(f"  To: {MOUNTED_DATA_ABS}")
        os.symlink(MOUNTED_DATA_ABS, DATA_SYMLINK)
        print(f"‚úÖ Created symlink: {DATA_SYMLINK} -> {MOUNTED_DATA_ABS}")

        # Verify symlink
        if os.path.exists(DATA_SYMLINK):
            print(f"‚úÖ Symlink verified: {DATA_SYMLINK}")
            print(f"  Is symlink: {os.path.islink(DATA_SYMLINK)}")
            # Try to list contents through symlink
            try:
                items = os.listdir(DATA_SYMLINK)
                print(f"‚úÖ Can access {len(items)} items through symlink")
                print(f"   First 5 items: {items[:5]}")
            except Exception as e:
                print(f"‚ö†Ô∏è  Symlink exists but cannot access contents: {e}")
        else:
            print(f"‚ùå Symlink creation failed - path does not exist after creation")
    except Exception as e:
        print(f"‚ùå Error creating symlink: {e}")
        print(f"   Source: {MOUNTED_DATA}")
        print(f"   Target: {DATA_SYMLINK}")
        import traceback
        traceback.print_exc()
else:
    print(f"‚ùå Mounted data not found at {MOUNTED_DATA}")
    print(f"   Please check that GCS bucket is mounted correctly")
    print(f"   Run the mount cell above and check for errors")
    print(f"   Mount point should exist and be accessible")

# ============================================================================
# Create annotations symlink
# ============================================================================
print("\n" + "=" * 80)
print("Creating Annotations Symlink")
print("=" * 80)

print(f"Checking mounted annotations: {MOUNTED_ANNOTATIONS}")
print(f"  Exists: {os.path.exists(MOUNTED_ANNOTATIONS)}")
if os.path.exists(MOUNTED_ANNOTATIONS):
    print(f"  Is directory: {os.path.isdir(MOUNTED_ANNOTATIONS)}")
    try:
        items = os.listdir(MOUNTED_ANNOTATIONS)
        print(f"  Can list contents: Yes ({len(items)} items)")
        if items:
            print(f"   First 5 items: {items[:5]}")
    except Exception as e:
        print(f"  Can list contents: No ({e})")

# Remove existing symlink or directory if it exists
if os.path.exists(ANNOTATIONS_SYMLINK):
    if os.path.islink(ANNOTATIONS_SYMLINK):
        print(f"Removing existing symlink: {ANNOTATIONS_SYMLINK}")
        os.unlink(ANNOTATIONS_SYMLINK)
    elif os.path.isdir(ANNOTATIONS_SYMLINK):
        print(f"Warning: {ANNOTATIONS_SYMLINK} exists as a directory (not a symlink)")
        print("   Removing it to create symlink...")
        import shutil
        shutil.rmtree(ANNOTATIONS_SYMLINK)
    else:
        print(f"Warning: {ANNOTATIONS_SYMLINK} exists and is not a symlink or directory")
        os.remove(ANNOTATIONS_SYMLINK)

# Create symlink
if os.path.exists(MOUNTED_ANNOTATIONS) and os.path.isdir(MOUNTED_ANNOTATIONS):
    try:
        # Use absolute path for symlink target
        MOUNTED_ANNOTATIONS_ABS = os.path.abspath(MOUNTED_ANNOTATIONS)
        print(f"\nCreating annotations symlink:")
        print(f"  From: {ANNOTATIONS_SYMLINK}")
        print(f"  To: {MOUNTED_ANNOTATIONS_ABS}")
        os.symlink(MOUNTED_ANNOTATIONS_ABS, ANNOTATIONS_SYMLINK)
        print(f"‚úÖ Created symlink: {ANNOTATIONS_SYMLINK} -> {MOUNTED_ANNOTATIONS_ABS}")

        # Verify symlink
        if os.path.exists(ANNOTATIONS_SYMLINK):
            print(f"‚úÖ Annotations symlink verified: {ANNOTATIONS_SYMLINK}")
            print(f"  Is symlink: {os.path.islink(ANNOTATIONS_SYMLINK)}")
            # Try to list contents through symlink
            try:
                items = os.listdir(ANNOTATIONS_SYMLINK)
                print(f"‚úÖ Can access {len(items)} annotation files through symlink")
                # Check for expected files
                expected_files = ["mp100_split1_train.json", "mp100_split1_test.json"]
                for exp_file in expected_files:
                    if exp_file in items:
                        print(f"   ‚úì Found: {exp_file}")
                    else:
                        print(f"   ‚ö†Ô∏è  Missing: {exp_file}")
            except Exception as e:
                print(f"‚ö†Ô∏è  Symlink exists but cannot access contents: {e}")
        else:
            print(f"‚ùå Annotations symlink creation failed - path does not exist after creation")
    except Exception as e:
        print(f"‚ùå Error creating annotations symlink: {e}")
        print(f"   Source: {MOUNTED_ANNOTATIONS}")
        print(f"   Target: {ANNOTATIONS_SYMLINK}")
        import traceback
        traceback.print_exc()
else:
    print(f"‚ö†Ô∏è  Mounted annotations not found at {MOUNTED_ANNOTATIONS}")
    print(f"   This might be OK if annotations are in a different location")
    print(f"   Training will look for annotations in: {ANNOTATIONS_SYMLINK}")


Checking mount point: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
  Exists: True
  Is directory: True
  Can list contents: Yes (93 items)
Removing existing symlink: /content/category-agnostic-pose-estimation/data

Creating symlink:
  From: /content/category-agnostic-pose-estimation/data
  To: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
‚úÖ Created symlink: /content/category-agnostic-pose-estimation/data -> /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/data
‚úÖ Symlink verified: /content/category-agnostic-pose-estimation/data
  Is symlink: True
‚úÖ Can access 93 items through symlink
   First 5 items: ['beaver_body', 'bed', 'bighornsheep_face', 'bison_body', 'blackbuck_face']

Creating Annotations Symlink
Checking mounted annotations: /content/category-agnostic-pose-estimation/Raster2Seq_internal-main/annotations
  Exists: False
‚ö†Ô∏è  Mounted annotations not found at /content/category-agnostic-pose-estimat

In [None]:
## 7. Run Fast Training (1 Epoch, Split 1)

# This section trains the CAPE model for a quick test run.

# **Training Configuration:**
# - Full episodic training mode
# - Epochs: 1 (fast test run)
# - Split: 1 (default)
# - Episodes per epoch: 500 (reduced for speed)
# - Batch size: 10
# - Queries per episode: 2
# - All logs will be saved to `training_logs.txt`


# Configure full training parameters


In [46]:

# Configure output directories
import os
from datetime import datetime
from google.colab import drive

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Mount Google Drive for persistent checkpoint storage
print("Mounting Google Drive for checkpoint persistence...")
drive.mount('/content/drive')

# Output directory - save to Google Drive
OUTPUT_DIR = "/content/drive/MyDrive/cape_training_output"
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("\n" + "=" * 80)
print("Fast CAPE Training Configuration (1 Epoch, Split 1)")
print("=" * 80)
print(f"Mode:              Fast test run (1 epoch)")
print(f"\nüìÅ STORAGE (Google Drive - Persistent):")
print(f"  Output directory:  {OUTPUT_DIR}")
print(f"  Log file:          {LOG_FILE}")
print(f"\nüí° All files will be saved to Google Drive")
print(f"   You can access them after the Colab session ends")
print(f"\nTraining configuration:")
print(f"  - Epochs: 1 (fast test run)")
print(f"  - Split: 1 (default)")
print(f"  - Batch size: 10")
print(f"  - Episodes per epoch: 500 (reduced for speed)")
print(f"  - Learning rate: 1e-4 (backbone: 1e-5)")
print(f"\nNon-default settings:")
print(f"  - Geometric support encoder: ENABLED")
print(f"  - GCN pre-encoding: ENABLED")
print(f"  - Classification loss weight: 2.0 (default: 1.0)")
print(f"  - Fixed validation episodes: ENABLED (stable curves)")
print(f"  - GPU optimizations: Mixed precision + cuDNN tuning")
print("=" * 80)
print()




Mounting Google Drive for checkpoint persistence...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Fast CAPE Training Configuration (1 Epoch, Split 1)
Mode:              Fast test run (1 epoch)

üìÅ STORAGE (Google Drive - Persistent):
  Output directory:  /content/drive/MyDrive/cape_training_output
  Log file:          /content/drive/MyDrive/cape_training_output/training_logs.txt

üí° All files will be saved to Google Drive
   You can access them after the Colab session ends

Training configuration:
  - Epochs: 1 (fast test run)
  - Split: 1 (default)
  - Batch size: 10
  - Episodes per epoch: 500 (reduced for speed)
  - Learning rate: 1e-4 (backbone: 1e-5)

Non-default settings:
  - Geometric support encoder: ENABLED
  - GCN pre-encoding: ENABLED
  - Classification loss weight: 2.0 (default: 1.0)
  - Fixed validation episodes: ENABLED (stable curves)
  - GPU optimizations: Mixed precision + cuDNN tun

In [None]:
# Ensure we have the latest code before training
# This cell stashes local changes, pulls latest code, then reapplies CUDA fixes
import os
import subprocess

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BRANCH = "main"  # Use main branch

print("=" * 80)
print("Pulling Latest Code from GitHub")
print("=" * 80)
print(f"Repository: {PROJECT_ROOT}")
print(f"Branch: {BRANCH}")
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Check for local changes
result = subprocess.run(
    ['git', 'status', '--porcelain'],
    capture_output=True,
    text=True,
    timeout=10
)

if result.stdout.strip():
    print("‚ö†Ô∏è  Local changes detected. Stashing them before pull...")
    print("   (These will be reapplied after pull if needed)")
    stash_result = subprocess.run(
        ['git', 'stash'],
        capture_output=True,
        text=True,
        timeout=10
    )
    if stash_result.returncode == 0:
        print("‚úÖ Local changes stashed")
    else:
        print(f"‚ö†Ô∏è  Stash had issues: {stash_result.stderr}")

# Pull latest changes
try:
    result = subprocess.run(
        ['git', 'pull', 'origin', BRANCH],
        capture_output=True,
        text=True,
        timeout=30
    )
    
    if result.returncode == 0:
        print("‚úÖ Successfully pulled latest code!")
        if result.stdout.strip():
            print("\nChanges:")
            print(result.stdout)
    else:
        print(f"‚ö†Ô∏è  Git pull had issues (return code: {result.returncode})")
        if result.stderr:
            print(f"Error: {result.stderr}")
        print("\nContinuing anyway - using existing code...")
except subprocess.TimeoutExpired:
    print("‚ö†Ô∏è  Git pull timed out - continuing with existing code...")
except Exception as e:
    print(f"‚ö†Ô∏è  Error during git pull: {e}")
    print("Continuing with existing code...")

print("\n" + "=" * 80)
print("Ready to proceed with training")
print("=" * 80)
print("\nüí° Note: CUDA fixes will be reapplied in the next cell if needed")


Pulling Latest Code from GitHub
Repository: /content/category-agnostic-pose-estimation
Branch: main

‚ö†Ô∏è  Git pull had issues (return code: 1)
Error: From https://github.com/nkkrnkl/category-agnostic-pose-estimation
 * branch            main       -> FETCH_HEAD
error: Your local changes to the following files would be overwritten by merge:
	datasets/episodic_sampler.py
Please commit your changes or stash them before you merge.
Aborting


Continuing anyway - using existing code...

Ready to proceed with training


# Run Training with Logging


In [50]:
# Verify data is accessible before training
import os
import json

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
DATA_DIR = os.path.join(PROJECT_ROOT, "data")
ANNOTATION_FILE = '/content/category-agnostic-pose-estimation/data/cleaned_annotations/mp100_split1_train.json'

print("Verifying dataset for full training...")
print(f"Data directory: {DATA_DIR}")
print(f"Annotation file: {ANNOTATION_FILE}")
print()

# Verify annotations exist
if os.path.exists(ANNOTATION_FILE):
    with open(ANNOTATION_FILE, 'r') as f:
        coco_data = json.load(f)

    num_images = len(coco_data.get('images', []))
    num_annotations = len(coco_data.get('annotations', []))
    num_categories = len(coco_data.get('categories', []))

    print(f"‚úÖ Training annotations loaded:")
    print(f"   Images: {num_images}")
    print(f"   Annotations: {num_annotations}")
    print(f"   Categories: {num_categories}")
else:
    print(f"‚ùå Annotation file not found: {ANNOTATION_FILE}")
    print("Training will fail!")

# Verify data directory is accessible
if os.path.exists(DATA_DIR) and os.path.islink(DATA_DIR):
    items = os.listdir(DATA_DIR)
    print(f"‚úÖ Data directory accessible: {len(items)} items found")
else:
    print(f"‚ùå Data directory not accessible: {DATA_DIR}")
    print("Training will fail!")

print("\n‚úÖ Dataset verification complete. Ready to train on full MP-100 dataset.")


Verifying dataset for full training...
Data directory: /content/category-agnostic-pose-estimation/data
Annotation file: /content/category-agnostic-pose-estimation/data/cleaned_annotations/mp100_split1_train.json

‚úÖ Training annotations loaded:
   Images: 12816
   Annotations: 13712
   Categories: 70
‚úÖ Data directory accessible: 93 items found

‚úÖ Dataset verification complete. Ready to train on full MP-100 dataset.


In [None]:
## 8. Check Training Results


In [51]:
# Run full CAPE training with logging
import subprocess
import sys
from datetime import datetime

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Build training command - only non-default flags
cmd = [
    sys.executable, "-m", "models.train_cape_episodic",

    # Required paths (Colab-specific)
    "--dataset_root", PROJECT_ROOT,
    "--output_dir", OUTPUT_DIR,

    # Architecture improvements (better performance)
    "--use_geometric_encoder",  # Use CapeX-inspired geometric encoder
    "--use_gcn_preenc",          # Add GCN pre-processing for better keypoint relationships

    # Loss tuning (faster convergence)
    "--cls_loss_coef", "2.0",    # Weight token classification loss 2x (helps sequence learning)

    # Validation stability (reproducible metrics across epochs)
    "--fixed_val_episodes",      # Cache validation episodes for stable curves

    # GPU optimizations (2x faster on Colab)
    "--use_amp",                 # Mixed precision training (FP16/FP32)
    "--cudnn_benchmark",         # Auto-tune cuDNN convolutions

    # Logging
    "--print_freq", "10",        # Print stats every 10 batches
    "--early_stopping_patience", "300",
    "--accumulation_steps", "4",
    "--episodes_per_epoch", "500",
    "--batch_size", "10",
    "--epochs", "1",             # Fast test: 1 epoch
    "--mp100_split", "1",        # Use split 1
    "--val_seed", "42"
]

print("=" * 80)
print("Starting CAPE Full Training")
print("=" * 80)
print(f"Output directory: {OUTPUT_DIR}")
print(f"Logging to: {LOG_FILE}")
print(f"\nTraining parameters:")
print(f"  - Epochs: 1 (fast test run)")
print(f"  - Split: 1")
print(f"  - Batch size: 10 (with accumulation_steps=4 ‚Üí effective=40)")
print(f"  - Episodes per epoch: 500 (train), 200 (val)")
print(f"  - Learning rate: 1e-4 (backbone: 1e-5)")
print(f"\nValidation stability:")
print(f"  - Fixed validation episodes: YES")
print(f"  - Same 200 episodes reused each epoch for reproducible curves")
print(f"\nGPU optimizations enabled:")
print(f"  - Mixed precision (AMP)")
print(f"  - cuDNN auto-tuning")
print("=" * 80)
print()

# Change to project directory
os.chdir(PROJECT_ROOT)

# Run training with logging to both stdout and file
with open(LOG_FILE, 'w') as log_file:
    # Write header to log file
    log_file.write("=" * 80 + "\n")
    log_file.write(f"CAPE Full Training Log\n")
    log_file.write(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    log_file.write(f"Using default training parameters (see train_cape_episodic.py)\n")
    log_file.write(f"Non-default flags: geometric_encoder, gcn_preenc, cls_loss_coef=2.0, fixed_val_episodes\n")
    log_file.write(f"Validation: 200 fixed episodes per epoch (stable curves)\n")
    log_file.write(f"GPU optimizations: AMP, cuDNN benchmark\n")
    log_file.write("=" * 80 + "\n\n")
    log_file.flush()

    # Run process and stream output to both stdout and file
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )

    # Stream output in real-time
    for line in process.stdout:
        print(line, end='')  # Print to notebook
        log_file.write(line)  # Write to log file
        log_file.flush()  # Ensure immediate write

    # Wait for process to complete
    return_code = process.wait()

    # Write footer to log file
    log_file.write("\n" + "=" * 80 + "\n")
    log_file.write(f"Training completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
    log_file.write(f"Return code: {return_code}\n")
    log_file.write("=" * 80 + "\n")

if return_code == 0:
    print("\n" + "=" * 80)
    print("‚úÖ Training completed successfully!")
    print(f"Checkpoints saved to: {OUTPUT_DIR}")
    print(f"Full logs saved to: {LOG_FILE}")

    # Create training summary file for easy recovery
    import glob
    summary_file = os.path.join(OUTPUT_DIR, "TRAINING_SUMMARY.txt")
    with open(summary_file, 'w') as f:
        f.write("=" * 80 + "\n")
        f.write("CAPE FULL TRAINING - RESULTS SUMMARY\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Training completed: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Location: Google Drive/cape_training_output/\n\n")

        # Count checkpoints
        all_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_e*.pth"))
        best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))

        f.write(f"CHECKPOINTS:\n")
        f.write(f"  Total epoch checkpoints: {len(all_ckpts)}\n")
        f.write(f"  Best PCK checkpoints: {len(best_ckpts)}\n\n")

        if best_ckpts:
            # Find the actual best checkpoint (highest PCK)
            best_ckpt = sorted(best_ckpts)[-1]  # Last one alphabetically has highest PCK
            f.write(f"BEST MODEL:\n")
            f.write(f"  {os.path.basename(best_ckpt)}\n\n")

        f.write(f"FILES IN THIS DIRECTORY:\n")
        f.write(f"  - checkpoint_e***.pth         : Per-epoch checkpoints\n")
        f.write(f"  - checkpoint_best_pck_*.pth   : Best validation PCK checkpoints\n")
        f.write(f"  - training_logs.txt           : Complete training logs\n")
        f.write(f"  - TRAINING_SUMMARY.txt        : This file\n\n")

        f.write(f"TO EVALUATE A CHECKPOINT:\n")
        f.write(f"  1. Download checkpoint from Google Drive\n")
        f.write(f"  2. Run: python scripts/eval_cape_checkpoint.py \\\n")
        f.write(f"             --checkpoint <path_to_checkpoint.pth> \\\n")
        f.write(f"             --dataset_root . \\\n")
        f.write(f"             --split test\n\n")

        f.write(f"GOOGLE DRIVE PATH:\n")
        f.write(f"  {OUTPUT_DIR}\n\n")
        f.write("=" * 80 + "\n")

    print(f"Training summary saved to: {summary_file}")
    print("=" * 80)
else:
    print("\n" + "=" * 80)
    print(f"‚ùå Training failed with return code: {return_code}")
    print(f"Check logs at: {LOG_FILE}")
    print("=" * 80)


Starting CAPE Full Training
Output directory: /content/drive/MyDrive/cape_training_output
Logging to: /content/drive/MyDrive/cape_training_output/training_logs.txt

Training parameters:
  - Epochs: 1 (fast test run)
  - Split: 1
  - Batch size: 10 (with accumulation_steps=4 ‚Üí effective=40)
  - Episodes per epoch: 500 (train), 200 (val)
  - Learning rate: 1e-4 (backbone: 1e-5)

Validation stability:
  - Fixed validation episodes: YES
  - Same 200 episodes reused each epoch for reproducible curves

GPU optimizations enabled:
  - Mixed precision (AMP)
  - cuDNN auto-tuning

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/content/category-agnostic-pose-estimation/models/train_cape_episodic.py", line 1002, in <module>
    main(args)
  File "/content/category-agnostic-pose-estimation/models/train_cape_episodic.py", line 397, in main
    from datasets.episodic_sampler import build_episodic_d

## üìÅ Accessing Your Trained Models

All training results are **automatically saved to Google Drive** and will persist after the Colab session ends.

### üîç What's Saved:

1. **Checkpoints** (`*.pth` files):
   - `checkpoint_e###_lr*_bs*_acc*_qpe*.pth` - Every epoch
   - `checkpoint_best_pck_*.pth` - Best validation PCK models
   
2. **Logs**:
   - `training_logs.txt` - Complete training output
   - `TRAINING_SUMMARY.txt` - Quick summary with best model info

### üìÇ Location:
**Google Drive ‚Üí My Drive ‚Üí cape_training_output/**

### üíæ To Download for Local Evaluation:

1. **On Google Drive Web:**
   - Go to: https://drive.google.com/drive/my-drive
   - Navigate to `cape_training_output/`
   - Download the `checkpoint_best_pck_*.pth` file

2. **On Your Computer:**
   ```bash
   # Evaluate on test set (unseen categories)
   python scripts/eval_cape_checkpoint.py \
       --checkpoint path/to/checkpoint_best_pck_*.pth \
       --dataset_root . \
       --split test \
       --device cuda
   ```

### ‚ö†Ô∏è Important Notes:
- Checkpoints are **~580MB each** - only download what you need
- The **best PCK checkpoint** is usually what you want for evaluation
- All results persist in Google Drive even after Colab session ends


In [36]:
# Check training results and find best checkpoint
import os
import glob

OUTPUT_DIR = "/content/drive/MyDrive/cape_training_output"

print("=" * 80)
print("TRAINING RESULTS - Saved to Google Drive")
print("=" * 80)
print(f"üìÅ Location: {OUTPUT_DIR}")
print()

# Check if training summary exists
summary_file = os.path.join(OUTPUT_DIR, "TRAINING_SUMMARY.txt")
if os.path.exists(summary_file):
    print("‚úÖ Training Summary:")
    print("‚îÄ" * 80)
    with open(summary_file, 'r') as f:
        print(f.read())
    print("‚îÄ" * 80)
    print()
else:
    print("‚ö†Ô∏è  No TRAINING_SUMMARY.txt found - training may not be complete yet")
    print()

# List all checkpoints
checkpoints = glob.glob(os.path.join(OUTPUT_DIR, "*.pth"))
if checkpoints:
    print(f"üì¶ Found {len(checkpoints)} checkpoint file(s):")

    # Separate into epoch and best checkpoints
    epoch_ckpts = [c for c in checkpoints if "checkpoint_e" in os.path.basename(c) and "best" not in os.path.basename(c)]
    best_ckpts = [c for c in checkpoints if "best_pck" in os.path.basename(c)]

    if best_ckpts:
        print(f"\n  üèÜ Best PCK Checkpoints ({len(best_ckpts)}):")
        for ckpt in sorted(best_ckpts):
            size_mb = os.path.getsize(ckpt) / (1024 * 1024)
            print(f"    - {os.path.basename(ckpt)} ({size_mb:.1f} MB)")

    if epoch_ckpts:
        print(f"\n  üìä Epoch Checkpoints ({len(epoch_ckpts)}):")
        # Show first 3 and last 3
        for ckpt in sorted(epoch_ckpts)[:3]:
            size_mb = os.path.getsize(ckpt) / (1024 * 1024)
            print(f"    - {os.path.basename(ckpt)} ({size_mb:.1f} MB)")
        if len(epoch_ckpts) > 6:
            print(f"    ... ({len(epoch_ckpts) - 6} more) ...")
        for ckpt in sorted(epoch_ckpts)[-3:]:
            size_mb = os.path.getsize(ckpt) / (1024 * 1024)
            print(f"    - {os.path.basename(ckpt)} ({size_mb:.1f} MB)")

    # Calculate total size
    total_size = sum(os.path.getsize(c) for c in checkpoints) / (1024 * 1024 * 1024)
    print(f"\n  üíæ Total checkpoint size: {total_size:.2f} GB")
    print()

    # Find best checkpoint
    best_ckpts = glob.glob(os.path.join(OUTPUT_DIR, "checkpoint_best_pck*.pth"))
    if best_ckpts:
        best_ckpt = sorted(best_ckpts)[-1]  # Get most recent
        print(f"\n‚úÖ Best checkpoint: {os.path.basename(best_ckpt)}")
        BEST_CHECKPOINT = best_ckpt
    else:
        # Use most recent checkpoint
        best_ckpt = sorted(checkpoints, key=os.path.getmtime)[-1]
        print(f"\n‚ö†Ô∏è  No 'best' checkpoint found, using most recent: {os.path.basename(best_ckpt)}")
        BEST_CHECKPOINT = best_ckpt
else:
    print("‚ùå No checkpoints found!")
    BEST_CHECKPOINT = None

# Show log file info
LOG_FILE = os.path.join(OUTPUT_DIR, "training_logs.txt")
if os.path.exists(LOG_FILE):
    size_mb = os.path.getsize(LOG_FILE) / (1024 * 1024)
    print(f"\nüìÑ Training log: {os.path.basename(LOG_FILE)} ({size_mb:.2f} MB)")

    # Show last few lines of log
    print("\nLast 20 lines of training log:")
    print("-" * 80)
    with open(LOG_FILE, 'r') as f:
        lines = f.readlines()
        for line in lines[-20:]:
            print(line.rstrip())
else:
    print("\n‚ö†Ô∏è  Training log not found")

print("\n" + "=" * 80)


TRAINING RESULTS - Saved to Google Drive
üìÅ Location: /content/drive/MyDrive/cape_training_output

‚ö†Ô∏è  No TRAINING_SUMMARY.txt found - training may not be complete yet

‚ùå No checkpoints found!

üìÑ Training log: training_logs.txt (0.00 MB)

Last 20 lines of training log:
--------------------------------------------------------------------------------
  - GCN layers: 2
Support encoder layers: 3
Fusion method: cross_attention
Queries per episode: 2
Train episodes per epoch: 500
Val episodes per epoch: 200
Fixed validation episodes: YES (seed=42) - stable curves

Using device: cuda:0
  GPU: NVIDIA A100-SXM4-80GB
  CUDA Version: 12.6
  GPU Memory: 79.32 GB
  cuDNN benchmark: Enabled (auto-tuning convolution algorithms)
  Mixed Precision (AMP): Enabled (FP16/FP32 training)


Training completed: 2025-11-30 22:21:50
Return code: 1



## 9. Evaluate Best Checkpoint on Test Set

This section evaluates the best trained checkpoint on the **test set** (unseen categories) using the curated `eval_cape_checkpoint.py` script.

The script will:
- Run autoregressive inference on test episodes
- Compute PCK@0.2 metrics (overall + per-category)
- Generate visualizations (support + GT + predicted keypoints)
- Save all results to Google Drive for later recovery

Results are saved to: `Google Drive/cape_training_output/evaluate_checkpoint_<epoch>_<date>/`

In [None]:
# Evaluate best checkpoint using curated eval_cape_checkpoint.py script
# Run both 1-shot and 5-shot evaluations
import os
import sys
import subprocess
from datetime import datetime
import glob
import re

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
TRAINING_OUTPUT_DIR = "/content/drive/MyDrive/cape_training_output"

print("=" * 80)
print("CHECKPOINT EVALUATION ON TEST SET (1-Shot and 5-Shot)")
print("=" * 80)
print()

# Find best PCK checkpoint
print("Finding best checkpoint...")
best_ckpts = glob.glob(os.path.join(TRAINING_OUTPUT_DIR, "checkpoint_best_pck*.pth"))

if not best_ckpts:
    # Fallback: use most recent epoch checkpoint
    epoch_ckpts = glob.glob(os.path.join(TRAINING_OUTPUT_DIR, "checkpoint_e*.pth"))
    if epoch_ckpts:
        best_checkpoint = sorted(epoch_ckpts, key=os.path.getmtime)[-1]
        print(f"‚ö†Ô∏è  No best checkpoint found, using most recent: {os.path.basename(best_checkpoint)}")
    else:
        print("‚ùå No checkpoint found!")
        print(f"   Looking in: {TRAINING_OUTPUT_DIR}")
        print("   Please run training first (Cell 21)")
        sys.exit(1)
else:
    # Use the most recent best checkpoint (highest PCK if multiple)
    best_checkpoint = sorted(best_ckpts)[-1]
    print(f"‚úÖ Found best checkpoint: {os.path.basename(best_checkpoint)}")

# Extract epoch number from checkpoint filename
# Format: checkpoint_best_pck_e###_pck0.####_meanpck0.####.pth
match = re.search(r'_e(\d+)_', os.path.basename(best_checkpoint))
if match:
    epoch_num = int(match.group(1))
else:
    # Fallback: use 'unknown' if can't parse
    epoch_num = 'unknown'

print(f"Epoch number: {epoch_num}")
print()

# Change to project root
os.chdir(PROJECT_ROOT)

# Run evaluations for both 1-shot and 5-shot
for num_support in [1, 5]:
    shot_name = f"{num_support}-shot"
    print("\n" + "=" * 80)
    print(f"EVALUATION: {shot_name.upper()}")
    print("=" * 80)

    # Create evaluation output directory with shot info
    date_str = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    eval_folder_name = f"evaluate_checkpoint_e{epoch_num}_{shot_name}_{date_str}"
    EVAL_OUTPUT_DIR = os.path.join(TRAINING_OUTPUT_DIR, eval_folder_name)
    os.makedirs(EVAL_OUTPUT_DIR, exist_ok=True)

    # Build evaluation command
    cmd = [
        sys.executable,
        "scripts/eval_cape_checkpoint.py",

        # Required: checkpoint to evaluate
        "--checkpoint", best_checkpoint,

        # Evaluate on test set (unseen categories)
        "--split", "test",

        # Dataset location
        "--dataset-root", PROJECT_ROOT,

        # Output directory (in Google Drive)
        "--output-dir", EVAL_OUTPUT_DIR,

        # Evaluation parameters
        "--num-episodes", "200",         # 200 test episodes
        "--num-support-per-episode", str(num_support),  # 1-shot or 5-shot
        "--num-visualizations", "50",   # Save 50 example visualizations
        "--pck-threshold", "0.2",       # PCK@0.2 (standard metric)
        "--draw-skeleton",               # Draw skeleton edges
        "--eval_seed", "42",            # Reproducible evaluation

        # System
        "--device", "cuda",
        "--num-workers", "2",
    ]

    print(f"Checkpoint:  {os.path.basename(best_checkpoint)}")
    print(f"Split:       test (unseen categories)")
    print(f"Shot:        {shot_name} ({num_support} support image(s) per episode)")
    print(f"Episodes:    200")
    print(f"Output:      {eval_folder_name}")
    print()

    # Run evaluation
    try:
        result = subprocess.run(cmd, check=True, capture_output=False, text=True)

        print(f"\n‚úÖ {shot_name.upper()} EVALUATION COMPLETE")
        print(f"Results saved to: {EVAL_OUTPUT_DIR}")

    except subprocess.CalledProcessError as e:
        print(f"\n‚ùå {shot_name} evaluation failed with return code: {e.returncode}")
        print(f"Check output above for error details")
    except Exception as e:
        print(f"\n‚ùå Error during {shot_name} evaluation: {e}")
        import traceback
        traceback.print_exc()

print("\n" + "=" * 80)
print("‚úÖ ALL EVALUATIONS COMPLETE")
print("=" * 80)
print(f"Results saved to Google Drive:")
print(f"  MyDrive/cape_training_output/")
print()
print("Look for folders:")
print(f"  - evaluate_checkpoint_e{epoch_num}_1-shot_*")
print(f"  - evaluate_checkpoint_e{epoch_num}_5-shot_*")
print()
print("Each folder contains:")
print(f"  - metrics_test.json       : PCK metrics (overall + per-category)")
print(f"  - visualizations/*.png    : 50 example visualizations")
print("=" * 80)



‚úÖ Using last checkpoint: checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth

Available checkpoints (56 total):
  - checkpoint_e045_lr1e-04_bs1_acc4_qpe1.pth (epoch 45)
  - checkpoint_e046_lr1e-04_bs1_acc4_qpe1.pth (epoch 46)
  - checkpoint_e047_lr1e-04_bs1_acc4_qpe1.pth (epoch 47)
  - checkpoint_e048_lr1e-04_bs1_acc4_qpe1.pth (epoch 48)
  - checkpoint_e049_lr1e-04_bs1_acc4_qpe1.pth (epoch 49)
  ... and 51 more

CAPE Prediction Visualization
Checkpoint: checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth
Image:      /content/category-agnostic-pose-estimation/data/camel_face/camel_16.jpg
Output:     /content/category-agnostic-pose-estimation/output/single_image_colab/visualizations

Running visualization...
Command: /usr/bin/python3 -m models.visualize_cape_predictions --checkpoint /content/category-agnostic-pose-estimation/output/single_image_colab/checkpoint_best_pck_e010_pck0.8889_meanpck0.8889.pth --dataset_root /content/category-agnostic-pose-estimation --device cuda --single_i

In [None]:
# View evaluation results summary
import os
import json
import glob

# Evaluation output directory (from previous cell)
if 'EVAL_OUTPUT_DIR' not in globals():
    # Fallback: find most recent evaluation folder
    TRAINING_OUTPUT_DIR = "/content/drive/MyDrive/cape_training_output"
    eval_folders = glob.glob(os.path.join(TRAINING_OUTPUT_DIR, "evaluate_checkpoint_*"))
    if eval_folders:
        EVAL_OUTPUT_DIR = sorted(eval_folders)[-1]  # Most recent
        print(f"Using most recent evaluation: {os.path.basename(EVAL_OUTPUT_DIR)}")
    else:
        print("‚ùå No evaluation results found. Run Cell 24 first.")
        import sys
        sys.exit(1)

print("=" * 80)
print("EVALUATION RESULTS SUMMARY")
print("=" * 80)
print(f"üìÅ Location: {EVAL_OUTPUT_DIR}")
print()

# Load and display evaluation metrics
metrics_file = os.path.join(EVAL_OUTPUT_DIR, "metrics_test.json")

if os.path.exists(metrics_file):
    print("üìä Test Set Evaluation Metrics:")
    print("‚îÄ" * 80)

    with open(metrics_file, 'r') as f:
        metrics = json.load(f)

    # Display overall metrics
    pck_overall = metrics.get('pck_overall', 0.0)
    mean_pck = metrics.get('mean_pck_categories', 0.0)
    total_correct = metrics.get('total_correct', 0)
    total_visible = metrics.get('total_visible', 0)

    print(f"  Overall PCK@0.2:          {pck_overall:.4f} ({pck_overall*100:.2f}%)")
    print(f"  Mean PCK (categories):    {mean_pck:.4f} ({mean_pck*100:.2f}%)")
    print(f"  Correct keypoints:        {total_correct} / {total_visible}")
    print()

    # Display per-category results (top 5 and bottom 5)
    if 'per_category' in metrics and len(metrics['per_category']) > 0:
        cat_pcks = [(cat_id, cat_data['pck']) for cat_id, cat_data in metrics['per_category'].items()]
        cat_pcks_sorted = sorted(cat_pcks, key=lambda x: x[1], reverse=True)

        print(f"  Top 5 Categories:")
        for cat_id, pck in cat_pcks_sorted[:5]:
            print(f"    Category {cat_id}: {pck:.2%}")

        print()
        print(f"  Bottom 5 Categories:")
        for cat_id, pck in cat_pcks_sorted[-5:]:
            print(f"    Category {cat_id}: {pck:.2%}")

    print("‚îÄ" * 80)
    print()
else:
    print("‚ö†Ô∏è  Metrics file not found - evaluation may have failed")
    print(f"   Expected: {metrics_file}")
    print()

# Count visualization files
vis_dir = os.path.join(EVAL_OUTPUT_DIR, "visualizations")
if os.path.exists(vis_dir):
    vis_files = glob.glob(os.path.join(vis_dir, "*.png"))
    print(f"üñºÔ∏è  Visualizations: {len(vis_files)} files generated")
    print(f"   Location: {vis_dir}")
    print()

# Final summary
print("=" * 80)
print("üì• ACCESS YOUR RESULTS")
print("=" * 80)
print()
print("All results are saved in Google Drive and persist after Colab session ends!")
print()
print("Google Drive path:")
print(f"  MyDrive/cape_training_output/{os.path.basename(EVAL_OUTPUT_DIR)}/")
print()
print("Contains:")
print(f"  - metrics_test.json       : Full PCK metrics + per-category breakdown")
print(f"  - visualizations/*.png    : Example predictions with GT comparison")
print()
print("To download:")
print(f"  1. Go to https://drive.google.com/drive/my-drive")
print(f"  2. Navigate to cape_training_output/{os.path.basename(EVAL_OUTPUT_DIR)}/")
print(f"  3. Download the entire folder or specific files")
print()
print("=" * 80)
