# MP-100 CAPE Training on Google Colab

This notebook trains Category-Agnostic Pose Estimation (CAPE) on the MP-100 dataset using Google Colab's GPU.

## Setup Instructions
1. Enable GPU: Runtime → Change runtime type → GPU (T4 or better)
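2. Run all cells in order (the optional annotation-cleaning cell in section 6.5 can take 1-3 hours on mounted GCS data; skip it if you plan to use `--skip_missing_at_runtime` during training)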
3. The notebook will:
   - Clone code from GitHub
   - Install dependencies
   - Authenticate to GCP
   - Mount GCS bucket with data
   - Run training with "tiny" mode


## 1. Check GPU Availability


In [None]:
# Report whether PyTorch can see a CUDA device and, if so, describe device 0.
import torch

print(f"CUDA available: {torch.cuda.is_available()}")

if not torch.cuda.is_available():
    print("⚠️  No GPU detected! Please enable GPU in Runtime > Change runtime type > GPU")
else:
    gpu_index = 0
    print(f"GPU: {torch.cuda.get_device_name(gpu_index)}")
    print(f"CUDA version: {torch.version.cuda}")
    # total_memory is reported in bytes; dividing by 1e9 shows decimal gigabytes
    total_memory_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
    print(f"GPU Memory: {total_memory_gb:.2f} GB")


In [None]:
# Clone repository from GitHub
import os
from getpass import getpass

REPO_URL = "https://github.com/nkkrnkl/category-agnostic-pose-estimation.git"
BRANCH = "teo-branch-copy"
PROJECT_ROOT = "/content/category-agnostic-pose-estimation"

# Remove existing directory if it exists
if os.path.exists(PROJECT_ROOT):
    print(f"Removing existing directory: {PROJECT_ROOT}")
    !rm -rf {PROJECT_ROOT}

# For private repositories, you need to authenticate
# Option 1: Use Personal Access Token (recommended)
# Get token from: https://github.com/settings/tokens
# Create a token with 'repo' scope
print("For private repositories, you need to authenticate.")
print("Option 1: Enter your GitHub Personal Access Token")
print("  (Get one from: https://github.com/settings/tokens)")
print("Option 2: Press Enter to try without token (will fail if repo is private)")
print()

GITHUB_TOKEN = getpass("Enter GitHub Personal Access Token (or press Enter to skip): ")

if GITHUB_TOKEN.strip():
    # Use token in URL
    # Format: https://TOKEN@github.com/username/repo.git
    AUTH_REPO_URL = REPO_URL.replace("https://github.com/", f"https://{GITHUB_TOKEN}@github.com/")
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {AUTH_REPO_URL} {PROJECT_ROOT}
else:
    # Try without token (will work if repo is public)
    print(f"Cloning repository from {REPO_URL} (branch: {BRANCH})...")
    !git clone -b {BRANCH} {REPO_URL} {PROJECT_ROOT}

# Verify clone
if os.path.exists(PROJECT_ROOT) and os.path.exists(os.path.join(PROJECT_ROOT, ".git")):
    print(f"✅ Repository cloned successfully to {PROJECT_ROOT}")
    !cd {PROJECT_ROOT} && git branch
else:
    print("❌ Failed to clone repository")
    print("\nIf the repository is private, you need to:")
    print("1. Create a Personal Access Token at: https://github.com/settings/tokens")
    print("2. Select 'repo' scope")
    print("3. Run this cell again and paste the token when prompted")


## 3. Install Requirements


In [None]:
# Install additional dependencies needed for plot_utils and other utilities
# (descartes, shapely, etc. - these are in requirements.txt but not requirements_cape.txt)
print("Installing additional dependencies (descartes, shapely, etc.)...")
!pip install -q descartes shapely>=1.8.0
print("✅ Additional dependencies installed!")


In [None]:
# Install requirements
# Installs everything listed in the checkout's requirements_cape.txt, then
# installs detectron2 straight from its GitHub repository.
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
REQUIREMENTS_FILE = os.path.join(PROJECT_ROOT, "requirements_cape.txt")

print("Installing requirements from requirements_cape.txt...")
# pip is run from inside the checkout; REQUIREMENTS_FILE itself is an absolute path
!cd {PROJECT_ROOT} && pip install -q -r {REQUIREMENTS_FILE}

# Install detectron2 from the repository's default branch (no version or commit is pinned).
# NOTE(review): an earlier comment tied this to CUDA 11.8, but nothing in this
# command selects a CUDA version - confirm the build matches the runtime's torch/CUDA.
print("\nInstalling detectron2...")
!pip install -q 'git+https://github.com/facebookresearch/detectron2.git'

# Printed unconditionally: the exit status of the pip commands above is not checked
print("✅ All dependencies installed!")


## 4. Authenticate to GCP


In [None]:
# Authenticate to GCP
# Runs Colab's user authentication for this runtime and then selects the
# project that subsequent gcloud commands should use.
from google.colab import auth

print("Authenticating to GCP...")
# presumably opens an interactive Google sign-in prompt - see google.colab.auth docs
auth.authenticate_user()

# Set GCP project
GCP_PROJECT = "dl-category-agnostic-pose-est"
!gcloud config set project {GCP_PROJECT}

# Printed unconditionally: the result of the gcloud command above is not checked
print(f"✅ Authenticated to GCP project: {GCP_PROJECT}")


## 5. Mount GCS Bucket


In [None]:
# Sanity-check that the dataset is reachable through PROJECT_ROOT/data.
# The checks only pass once the GCS bucket is mounted and the data symlink
# exists; otherwise the cell prints what to look at and moves on.
import os
from pathlib import Path

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
DATA_DIR = os.path.join(PROJECT_ROOT, "data")

print("Verifying data access...")
print(f"Data directory: {DATA_DIR}")
print(f"Exists: {os.path.exists(DATA_DIR)}")
print(f"Is symlink: {os.path.islink(DATA_DIR)}")

if not os.path.exists(DATA_DIR):
    print(f"❌ Data directory does not exist: {DATA_DIR}")
    print("   Please check:")
    print("   1. GCS bucket is mounted")
    print("   2. Data symlink is created")
else:
    try:
        # Every sub-directory of the data root is counted as one category
        category_names = []
        for entry_name in os.listdir(DATA_DIR):
            if os.path.isdir(os.path.join(DATA_DIR, entry_name)):
                category_names.append(entry_name)
        print(f"✅ Found {len(category_names)} category directories")

        if category_names:
            print(f"   First 5 categories: {category_names[:5]}")

            # Probe one image in the first category to prove files are readable
            sample_category = category_names[0]
            sample_dir = os.path.join(DATA_DIR, sample_category)
            image_names = []
            for file_name in os.listdir(sample_dir):
                if file_name.endswith(('.jpg', '.png', '.jpeg')):
                    image_names.append(file_name)

            if not image_names:
                print(f"   ⚠️  No image files found in {sample_category}")
            else:
                sample_image = os.path.join(sample_dir, image_names[0])
                print(f"   Test file exists: {os.path.exists(sample_image)}")
                print(f"   Test file: {sample_image}")
    except Exception as error:
        print(f"❌ Error accessing data directory: {error}")
        print("   This might indicate the GCS mount is not working properly")


In [None]:
# Mount GCS bucket using gcsfuse
import os
import subprocess
import time

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
BUCKET_NAME = "dl-category-agnostic-pose-mp100-data"
MOUNT_POINT = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")

# Install gcsfuse from Google's official repository
print("Installing gcsfuse...")
# Add Google's gcsfuse repository (updated method for newer Ubuntu versions)
!export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s` && \
echo "deb http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] http://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list && \
sudo apt-get update && \
sudo apt-get install -y gcsfuse

# Verify installation
!which gcsfuse
print("✅ gcsfuse installed")

# Create mount point directory and parent directories
print(f"Creating mount point: {MOUNT_POINT}")
os.makedirs(os.path.dirname(MOUNT_POINT), exist_ok=True)
os.makedirs(MOUNT_POINT, exist_ok=True)

# Check if already mounted
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    if result.returncode == 0:
        print(f"✅ Already mounted at {MOUNT_POINT}")
    else:
        # Try to unmount if exists but not properly mounted
        try:
            subprocess.run(['fusermount', '-u', MOUNT_POINT], capture_output=True, timeout=5)
        except:
            try:
                subprocess.run(['umount', MOUNT_POINT], capture_output=True, timeout=5)
            except:
                pass
except:
    pass

# Mount the bucket
print(f"Mounting gs://{BUCKET_NAME} to {MOUNT_POINT}...")
print("This may take a moment...")

# Run gcsfuse in background
# Note: In Colab, we need to run gcsfuse in background using shell &
print(f"Running: gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT}")
!nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &

# Wait a moment for mount to initialize
print("Waiting for mount to initialize...")
time.sleep(8)  # Give it more time to mount

# Check mount status
print("\nChecking mount status...")
# Check mount log for errors
if os.path.exists("/tmp/gcsfuse.log"):
    with open("/tmp/gcsfuse.log", "r") as f:
        log_content = f.read()
        if log_content:
            print("Mount log:")
            print(log_content[-500:])  # Last 500 chars
        else:
            print("Mount log is empty (mount might still be initializing)")

# Also verify we can access the bucket directly with gsutil
print("\nVerifying bucket access with gsutil...")
!gsutil ls gs://{BUCKET_NAME}/ | head -10

# Verify mount
print(f"\nVerifying mount at: {MOUNT_POINT}")
print(f"Path exists: {os.path.exists(MOUNT_POINT)}")

# Check if actually mounted using mountpoint command
try:
    result = subprocess.run(['mountpoint', '-q', MOUNT_POINT], capture_output=True)
    is_mounted = (result.returncode == 0)
    print(f"Is mounted: {is_mounted}")
except:
    # Fallback: check mount table
    result = subprocess.run(['mount'], capture_output=True, text=True)
    is_mounted = MOUNT_POINT in result.stdout
    print(f"Is mounted (from mount table): {is_mounted}")

if os.path.exists(MOUNT_POINT) and is_mounted:
    try:
        # Try to list contents
        items = os.listdir(MOUNT_POINT)
        if len(items) > 0:
            print(f"✅ GCS bucket mounted successfully!")
            print(f"Mount point: {MOUNT_POINT}")
            print(f"Found {len(items)} items in bucket")
            # List a few items to verify
            for item in items[:10]:
                item_path = os.path.join(MOUNT_POINT, item)
                item_type = "directory" if os.path.isdir(item_path) else "file"
                print(f"   - {item} ({item_type})")
        else:
            print(f"⚠️  Mount point exists but is empty (0 items)")
            print(f"   This might indicate:")
            print(f"   1. Bucket is empty")
            print(f"   2. Mount didn't work correctly")
            print(f"   3. Permission issues")
    except PermissionError as e:
        print(f"⚠️  Permission error accessing mount: {e}")
        print("   Mount might still be initializing, wait a moment and try again")
    except Exception as e:
        print(f"⚠️  Mount point exists but cannot list contents: {e}")
        print("   This might indicate a mount issue")
        import traceback
        traceback.print_exc()
elif os.path.exists(MOUNT_POINT) and not is_mounted:
    print(f"⚠️  Directory exists but is not mounted")
    print(f"   The directory exists but gcsfuse mount is not active")
    print(f"   Trying to mount again...")
    # Try mounting again
    !nohup gcsfuse --implicit-dirs {BUCKET_NAME} {MOUNT_POINT} > /tmp/gcsfuse.log 2>&1 &
    time.sleep(5)
    # Re-check
    items = os.listdir(MOUNT_POINT) if os.path.exists(MOUNT_POINT) else []
    if len(items) > 0:
        print(f"✅ Mount successful after retry! Found {len(items)} items")
    else:
        print(f"❌ Mount still not working")
else:
    print("❌ Failed to mount GCS bucket")
    print(f"   Mount point: {MOUNT_POINT}")
    print(f"   Check:")
    print(f"   1. GCP authentication (run the GCP auth cell)")
    print(f"   2. Bucket name is correct: {BUCKET_NAME}")
    print(f"   3. You have read access to the bucket")
    print(f"   4. Check mount log: /tmp/gcsfuse.log")


## 6. Create Data Symlink


In [None]:
# Create symlink from data to mounted GCS bucket (as expected by START_TRAINING.sh)
#
# Replaces whatever currently sits at PROJECT_ROOT/data with a symlink to the
# gcsfuse mount point, then lists the link's contents as a smoke test.
import os
import shutil
import traceback

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
MOUNTED_DATA = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")
DATA_SYMLINK = os.path.join(PROJECT_ROOT, "data")

print(f"Checking mount point: {MOUNTED_DATA}")
print(f"  Exists: {os.path.exists(MOUNTED_DATA)}")
if os.path.exists(MOUNTED_DATA):
    print(f"  Is directory: {os.path.isdir(MOUNTED_DATA)}")
    try:
        items = os.listdir(MOUNTED_DATA)
        print(f"  Can list contents: Yes ({len(items)} items)")
    except Exception as e:
        print(f"  Can list contents: No ({e})")

if not os.path.isdir(MOUNTED_DATA):
    # Nothing is removed in this case, so an existing data path is never
    # destroyed without a replacement being available.
    print(f"❌ Mounted data not found at {MOUNTED_DATA}")
    print(f"   Please check that GCS bucket is mounted correctly")
    print(f"   Run the mount cell above and check for errors")
    print(f"   Mount point should exist and be accessible")
else:
    # Remove existing symlink or directory if it exists.
    # lexists (unlike exists) is also True for a dangling symlink, which would
    # otherwise survive this step and make os.symlink raise FileExistsError.
    if os.path.lexists(DATA_SYMLINK):
        if os.path.islink(DATA_SYMLINK):
            print(f"Removing existing symlink: {DATA_SYMLINK}")
            os.unlink(DATA_SYMLINK)
        elif os.path.isdir(DATA_SYMLINK):
            print(f"Warning: {DATA_SYMLINK} exists as a directory (not a symlink)")
            print("   Removing it to create symlink...")
            shutil.rmtree(DATA_SYMLINK)
        else:
            print(f"Warning: {DATA_SYMLINK} exists and is not a symlink or directory")
            os.remove(DATA_SYMLINK)

    # Create symlink
    try:
        # Use absolute path for symlink target
        MOUNTED_DATA_ABS = os.path.abspath(MOUNTED_DATA)
        print(f"\nCreating symlink:")
        print(f"  From: {DATA_SYMLINK}")
        print(f"  To: {MOUNTED_DATA_ABS}")
        os.symlink(MOUNTED_DATA_ABS, DATA_SYMLINK)
        print(f"✅ Created symlink: {DATA_SYMLINK} -> {MOUNTED_DATA_ABS}")

        # Verify symlink: exists() follows the link, so this also checks the target
        if os.path.exists(DATA_SYMLINK):
            print(f"✅ Symlink verified: {DATA_SYMLINK}")
            print(f"  Is symlink: {os.path.islink(DATA_SYMLINK)}")
            # Try to list contents through symlink
            try:
                items = os.listdir(DATA_SYMLINK)
                print(f"✅ Can access {len(items)} items through symlink")
                print(f"   First 5 items: {items[:5]}")
            except Exception as e:
                print(f"⚠️  Symlink exists but cannot access contents: {e}")
        else:
            print(f"❌ Symlink creation failed - path does not exist after creation")
    except Exception as e:
        print(f"❌ Error creating symlink: {e}")
        print(f"   Source: {MOUNTED_DATA}")
        print(f"   Target: {DATA_SYMLINK}")
        traceback.print_exc()


## 6.1. Create Cleaned Annotations Symlink

Create a symlink to access cleaned_annotations from the GCS bucket.


In [None]:
# Create symlink for cleaned_annotations from GCS bucket
#
# Links PROJECT_ROOT/cleaned_annotations to the cleaned_annotations folder
# inside the gcsfuse mount. When the source is missing, the cell lists what the
# mount does contain and leaves any existing path untouched.
import os
import shutil
import traceback

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
MOUNTED_DATA = os.path.join(PROJECT_ROOT, "Raster2Seq_internal-main", "data")
CLEANED_ANNOTATIONS_SOURCE = os.path.join(MOUNTED_DATA, "cleaned_annotations")
CLEANED_ANNOTATIONS_SYMLINK = os.path.join(PROJECT_ROOT, "cleaned_annotations")

print("=" * 80)
print("CREATING CLEANED_ANNOTATIONS SYMLINK")
print("=" * 80)

# Check if source exists
if not os.path.exists(MOUNTED_DATA):
    print(f"❌ Mount point not found: {MOUNTED_DATA}")
    print("   Please run the 'Mount GCS Bucket' cell first.")
elif not os.path.exists(CLEANED_ANNOTATIONS_SOURCE):
    print(f"⚠️  cleaned_annotations not found at: {CLEANED_ANNOTATIONS_SOURCE}")
    print("   Checking what's available in the mount...")
    try:
        items = os.listdir(MOUNTED_DATA)
        print(f"   Found {len(items)} items in mount:")
        for item in sorted(items)[:10]:
            item_path = os.path.join(MOUNTED_DATA, item)
            item_type = "directory" if os.path.isdir(item_path) else "file"
            print(f"     - {item} ({item_type})")
        if len(items) > 10:
            print(f"     ... and {len(items) - 10} more")
    except Exception as e:
        print(f"   Error listing mount contents: {e}")
    print("\n   Please ensure cleaned_annotations exists in the GCS bucket.")
else:
    # Remove existing symlink or directory if it exists.
    # lexists (unlike exists) is also True for a dangling symlink, which would
    # otherwise survive this step and make os.symlink raise FileExistsError.
    if os.path.lexists(CLEANED_ANNOTATIONS_SYMLINK):
        if os.path.islink(CLEANED_ANNOTATIONS_SYMLINK):
            print(f"Removing existing symlink: {CLEANED_ANNOTATIONS_SYMLINK}")
            os.unlink(CLEANED_ANNOTATIONS_SYMLINK)
        elif os.path.isdir(CLEANED_ANNOTATIONS_SYMLINK):
            print(f"Warning: {CLEANED_ANNOTATIONS_SYMLINK} exists as a directory (not a symlink)")
            print("   Removing it to create symlink...")
            shutil.rmtree(CLEANED_ANNOTATIONS_SYMLINK)
        else:
            print(f"Warning: {CLEANED_ANNOTATIONS_SYMLINK} exists and is not a symlink or directory")
            os.remove(CLEANED_ANNOTATIONS_SYMLINK)

    # Create symlink
    try:
        CLEANED_ANNOTATIONS_SOURCE_ABS = os.path.abspath(CLEANED_ANNOTATIONS_SOURCE)
        print(f"\nCreating symlink:")
        print(f"  From: {CLEANED_ANNOTATIONS_SYMLINK}")
        print(f"  To: {CLEANED_ANNOTATIONS_SOURCE_ABS}")
        os.symlink(CLEANED_ANNOTATIONS_SOURCE_ABS, CLEANED_ANNOTATIONS_SYMLINK)
        print(f"✅ Created symlink: {CLEANED_ANNOTATIONS_SYMLINK} -> {CLEANED_ANNOTATIONS_SOURCE_ABS}")

        # Verify symlink: exists() follows the link, so this also checks the target
        if os.path.exists(CLEANED_ANNOTATIONS_SYMLINK):
            print(f"✅ Symlink verified: {CLEANED_ANNOTATIONS_SYMLINK}")
            print(f"  Is symlink: {os.path.islink(CLEANED_ANNOTATIONS_SYMLINK)}")
            # Try to list contents through symlink
            try:
                items = os.listdir(CLEANED_ANNOTATIONS_SYMLINK)
                print(f"✅ Can access {len(items)} annotation files through symlink")
                print(f"   Files: {', '.join(sorted(items)[:5])}")
                if len(items) > 5:
                    print(f"   ... and {len(items) - 5} more files")
            except Exception as e:
                print(f"⚠️  Symlink exists but cannot access contents: {e}")
        else:
            print(f"❌ Symlink creation failed - path does not exist after creation")
    except Exception as e:
        print(f"❌ Error creating symlink: {e}")
        print(f"   Source: {CLEANED_ANNOTATIONS_SOURCE}")
        print(f"   Target: {CLEANED_ANNOTATIONS_SYMLINK}")
        traceback.print_exc()

print("\n" + "=" * 80)


## 6.5. Clean Annotations (Optional but Recommended)

This step removes entries for non-existent images from annotation files. 
**⚠️ WARNING: This can take 1-3 hours on mounted GCS data due to network latency.**

**Options:**
- Clean all annotation files (default)
- Clean only a specific split (e.g., split 2) - much faster!
  - Set `CLEAN_SPLIT = 2` in the code cell to clean only split 2 files

You can skip this if you prefer to use `--skip_missing_at_runtime` during training.


In [None]:
# Clean annotation files by removing entries for non-existent images
#
# Runs the repository's clean_annotations.py as a child process from the
# project root and streams its output into this cell. Per the notes shipped
# with this notebook, the script rewrites the JSON files in annotations/ and
# keeps *.json.backup copies - that behavior lives in the script, not here.

import os
import subprocess
import sys
import time
from pathlib import Path

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
CLEAN_SCRIPT = os.path.join(PROJECT_ROOT, "clean_annotations.py")

print("=" * 80)
print("CLEANING ANNOTATION FILES")
print("=" * 80)
print("\n⚠️  WARNING: This process can take 1-3 hours on mounted GCS data.")
print("   The script needs to:")
print("   1. Scan all image files in the data directory (slow on GCS mount)")
print("   2. Check each annotation entry against existing files")
print("\n   Progress will be shown below. You can monitor the output.")
print("   The script creates backups (*.json.backup) before modifying files.")
print("\n" + "=" * 80 + "\n")

# Preconditions: each failure stops the cell via sys.exit(1)

# Verify data symlink exists
DATA_SYMLINK = os.path.join(PROJECT_ROOT, "data")
if not os.path.exists(DATA_SYMLINK):
    print(f"❌ Data symlink not found: {DATA_SYMLINK}")
    print("   Please run the 'Create Data Symlink' cell above first.")
    sys.exit(1)

# Verify clean_annotations.py exists
if not os.path.exists(CLEAN_SCRIPT):
    print(f"❌ clean_annotations.py not found: {CLEAN_SCRIPT}")
    sys.exit(1)

# Check if annotations directory exists
ANNOTATIONS_DIR = os.path.join(PROJECT_ROOT, "annotations")
if not os.path.exists(ANNOTATIONS_DIR):
    print(f"❌ Annotations directory not found: {ANNOTATIONS_DIR}")
    sys.exit(1)

# Count annotation files. "*.json" cannot match "*.json.backup" names, so no
# extra filtering is needed; sorting keeps the listing stable between runs.
annotation_files = sorted(Path(ANNOTATIONS_DIR).glob("*.json"))
print(f"Found {len(annotation_files)} annotation file(s) in total:")
for f in annotation_files:
    print(f"  - {f.name}")

# Option to clean specific split (set to None to clean all files)
# Examples: CLEAN_SPLIT = 2  (cleans only split 2 files)
#           CLEAN_SPLIT = None  (cleans all files - default)
CLEAN_SPLIT = None  # Change this to clean only a specific split (e.g., 2 for split2)

# Build command arguments
clean_args = [CLEAN_SCRIPT]
if CLEAN_SPLIT is not None:
    clean_args.extend(["--split", str(CLEAN_SPLIT)])
    print(f"\n⚠️  Will clean only split {CLEAN_SPLIT} files")
    print("   (Change CLEAN_SPLIT variable above to clean a different split or set to None for all)")
else:
    print(f"\n⚠️  Will clean ALL annotation files")
    print("   (Set CLEAN_SPLIT variable above to clean only a specific split)")

# Use the interpreter running this kernel so the script sees the packages
# installed earlier in the notebook; fall back to "python" if it is unknown.
command = [sys.executable or "python"] + clean_args

# No confirmation prompt: the cleanup starts as soon as the cell is executed
print("\n" + "=" * 80)
print("Starting annotation cleanup...")
print("=" * 80 + "\n")

# Record start time
start_time = time.time()

# Run the cleanup script
try:
    print(f"Running: {' '.join(command)}\n")
    # The context manager closes the pipe and waits for the child on exit,
    # including when the cell is interrupted while streaming output.
    with subprocess.Popen(
        command,
        cwd=PROJECT_ROOT,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # interleave stderr with stdout
        universal_newlines=True,
        bufsize=1,  # line-buffered, so progress appears as it is produced
    ) as process:
        # Print output in real-time
        for line in process.stdout:
            print(line, end='')

    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)

    # Calculate elapsed time
    elapsed_time = time.time() - start_time
    hours = int(elapsed_time // 3600)
    minutes = int((elapsed_time % 3600) // 60)
    seconds = int(elapsed_time % 60)

    print("\n" + "=" * 80)
    print(f"✅ Annotation cleanup completed!")
    print(f"   Time taken: {hours}h {minutes}m {seconds}s")
    print("=" * 80)
    print("\nNext steps:")
    print("  1. Check the cleanup report: annotation_cleanup_report.txt")
    print("  2. Backups are saved as *.json.backup in the annotations/ folder")
    print("  3. You can now run training without --skip_missing_at_runtime")
    print("     (or keep using it if you prefer)")

except Exception as e:
    # Deliberately non-fatal: training can still proceed without the cleanup
    print(f"\n❌ Error running clean_annotations.py: {e}")
    import traceback
    traceback.print_exc()
    print("\nYou can continue with training using --skip_missing_at_runtime instead.")


In [None]:
## 7. Run Training


# Run training using START_TRAINING.sh with "tiny" mode


In [None]:
# Run START_TRAINING.sh with "tiny" mode
# Marks the checkout's START_TRAINING.sh as executable and runs it with bash
# from the project root, passing the single argument "tiny".
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
TRAINING_SCRIPT = os.path.join(PROJECT_ROOT, "START_TRAINING.sh")

# Make script executable
!chmod +x {TRAINING_SCRIPT}

# Change to project directory and run training
# NOTE(review): the epoch count, batch size and duration printed below are
# hard-coded in this cell, not read from START_TRAINING.sh - confirm they still
# match what the script's "tiny" mode does.
print("Starting training with 'tiny' mode...")
print("This will run 5 epochs with batch_size 8 (~30-60 min)")
print("=" * 80)

!cd {PROJECT_ROOT} && bash {TRAINING_SCRIPT} tiny


In [None]:
## 8. Monitor Training (Optional)


In [None]:
# Check training logs
#
# Reads LOG_FILE, treating every line as one JSON record, and prints the epoch
# and losses of the last three records. When the log is missing, the contents
# of the output folder are listed instead to help locate the right run.
import json
import os
from pathlib import Path

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_DIR = os.path.join(PROJECT_ROOT, "output", "tiny_test", "tiny_test")
LOG_FILE = os.path.join(OUTPUT_DIR, "log.txt")


def summarize_log_entry(raw_line):
    """Format one log line for display.

    Args:
        raw_line: A single line of the log file; surrounding whitespace is ignored.

    Returns:
        A list with three display strings (epoch, train loss, validation loss),
        or None when the line is not valid JSON or is not a JSON object.
        Missing keys are shown as 'N/A'; 'loss' is used when 'train_loss' is absent.
    """
    try:
        stats = json.loads(raw_line.strip())
    except json.JSONDecodeError:
        return None
    if not isinstance(stats, dict):
        return None
    train_loss = stats.get('train_loss', stats.get('loss', 'N/A'))
    return [
        f"  Epoch {stats.get('epoch', 'N/A')}: ",
        f"    Train Loss: {train_loss}",
        f"    Val Loss: {stats.get('test_loss', 'N/A')}",
    ]


if os.path.exists(LOG_FILE):
    print(f"Reading log file: {LOG_FILE}")
    with open(LOG_FILE, 'r') as f:
        lines = f.readlines()
    print(f"Total log entries: {len(lines)}")
    if lines:
        print("\nLast 3 entries:")
        skipped_count = 0
        for line in lines[-3:]:
            summary = summarize_log_entry(line)
            if summary is None:
                # Keep going, but do not hide the fact that a line was unusable
                skipped_count += 1
                continue
            print("\n".join(summary))
        if skipped_count:
            print(f"  ⚠️  Skipped {skipped_count} line(s) that are not JSON objects")
else:
    print(f"Log file not found: {LOG_FILE}")
    print("\nAvailable output directories:")
    output_base = os.path.join(PROJECT_ROOT, "output")
    if os.path.exists(output_base):
        for d in os.listdir(output_base):
            print(f"  - {os.path.join(output_base, d)}")


In [None]:
## 9. Download Results (Optional)


In [None]:
# Bundle every checkpoint*.pth found under the output folder into one zip
# archive and hand it to the browser for download.
from google.colab import files
from pathlib import Path
import zipfile
import os

PROJECT_ROOT = "/content/category-agnostic-pose-estimation"
OUTPUT_BASE = os.path.join(PROJECT_ROOT, "output")

# Recursive search, so checkpoints in nested run folders are included
checkpoint_paths = [path for path in Path(OUTPUT_BASE).rglob("checkpoint*.pth")]

if not checkpoint_paths:
    print("No checkpoints found yet.")
    print(f"Output directory: {OUTPUT_BASE}")
else:
    print(f"Found {len(checkpoint_paths)} checkpoint(s)")

    archive_path = "/content/checkpoints.zip"
    with zipfile.ZipFile(archive_path, 'w') as archive:
        for checkpoint_path in checkpoint_paths:
            # Store each file under its path relative to the project root so
            # the folder layout survives inside the archive
            archive_name = os.path.relpath(checkpoint_path, PROJECT_ROOT)
            archive.write(checkpoint_path, archive_name)

    print(f"\nDownloading {archive_path}...")
    files.download(archive_path)
    print("✅ Download complete!")
