# 🔬 OAI X-ray Inpainting - Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/johnreynolds3d/OAI-inpainting/blob/master/notebooks/OAI_Inpainting_Colab.ipynb)

**Comprehensive testing and training of 9+ inpainting model variants on OAI knee X-ray data**

---

## 🎯 What This Notebook Does

This notebook tests and trains state-of-the-art image inpainting models on OAI (Osteoarthritis Initiative) knee X-ray data:

### 📊 Model Variants Available:
- **AOT-GAN**: CelebA-HQ, Places2, OAI-trained
- **ICT**: FFHQ, ImageNet, Places2_Nature, OAI-trained  
- **RePaint**: CelebA-HQ, ImageNet, Places2

**Total: 9 different model variants** tested on your OAI data!

---

## 🚀 Quick Start (5 Simple Steps)

### 1️⃣ **Check GPU and Mount Google Drive**
- Run: "Check GPU availability"
- Run: "Mount Google Drive and setup data"
→ Makes your OAI_untracked folder accessible

### 2️⃣ **Setup Environment**
- Run: "Alternative setup method (if Git fails)"
→ Downloads the repo as a ZIP into Drive next to OAI_untracked, creates symlinks, installs dependencies

### 3️⃣ **Verify Setup**
- Run: "Verify installation and setup"
→ Quick check that imports work

### 4️⃣ **Generate Dataset Splits** (First time only)
- Run: "GENERATE DATASET SPLITS"
→ Creates **perfectly balanced** train/valid/test splits using **ALL 539 images**
→ 80/10/10 split with equal low/high BMD representation
→ Creates subset_4 (4 test images) for quick testing

### 5️⃣ **Test All Models**
- Run: "PIPELINE RUNNER SETUP" (imports functions)
- Run: "QUICK START - Test ALL Model Variants"
→ Tests all 9 models in ~30-60 minutes

**Done!** Results saved to Google Drive automatically.

---

## 📁 Data Organization

**Recommended setup:**
- ✅ **Code**: Cloned from GitHub (always up-to-date)
- ✅ **Data**: Stored in Google Drive (persistent across sessions)
- ✅ **Results**: Generated in Colab (downloadable as ZIP)

**Required data structure:**
```
OAI_untracked/
├── data/
│   ├── oai/
│   │   ├── img/          # 539 PNG X-ray images
│   │   ├── data.csv      # BMD values for each image
│   │   └── split.py      # Balanced split generator
│   └── pretrained/       # Model checkpoint files
│       ├── aot-gan/
│       ├── ict/
│       └── repaint/
```
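
Before running the cells that depend on this layout, it can help to check it programmatically. The following is a minimal sketch using only the standard library; the name `find_missing_paths` is hypothetical (it is not part of the repository), and the list of required paths is simply copied from the tree above.

```python
from pathlib import Path

# Paths expected under OAI_untracked/, copied from the tree above.
REQUIRED_PATHS = [
    "data/oai/img",
    "data/oai/data.csv",
    "data/oai/split.py",
    "data/pretrained/aot-gan",
    "data/pretrained/ict",
    "data/pretrained/repaint",
]


def find_missing_paths(untracked_root):
    """Return the required relative paths that do not exist under untracked_root."""
    root = Path(untracked_root)
    return [rel for rel in REQUIRED_PATHS if not (root / rel).exists()]
```

In Colab this might be called as `find_missing_paths("/content/drive/MyDrive/Colab Notebooks/OAI_untracked")`; an empty list means every expected entry is present.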

**📊 New Balanced Split (Oct 2025):**
- ✅ Uses **ALL 539 images** (previously only ~268)
- ✅ **80% train** (431 images), **10% val** (53 images), **10% test** (55 images)
- ✅ **Balanced**: Low and high BMD counts differ by at most one image in each split
- ✅ **Mutually exclusive**: No image overlap between splits

**💾 Results Persistence:**
Both `data/` and `results/` are symlinked to `OAI_untracked/` in Google Drive, ensuring:
- ✅ Results survive Colab session restarts
- ✅ Identical behavior on local machine and Colab
- ✅ All outputs automatically saved to Google Drive
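
The portability comes from the symlinks being relative: the link stores `../OAI_untracked/<name>` rather than an absolute path, so it resolves the same way wherever the two sibling folders live. A minimal sketch of that idea follows; the helper name `link_to_sibling` is hypothetical and this is not the notebook's own setup function.

```python
from pathlib import Path


def link_to_sibling(project_dir, name, sibling="OAI_untracked"):
    """Create project_dir/name as a relative symlink to ../sibling/name."""
    project_dir = Path(project_dir)
    project_dir.mkdir(parents=True, exist_ok=True)

    # Relative target: resolved from the directory that contains the link.
    target = Path("..") / sibling / name
    (project_dir / target).mkdir(parents=True, exist_ok=True)

    link = project_dir / name
    if not link.is_symlink() and not link.exists():
        link.symlink_to(target, target_is_directory=True)
    return link
```

Anything written through `OAI-inpainting/results` then ends up in `OAI_untracked/results`, which is the folder that lives in Google Drive.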

---

## 🎓 Usage Modes

### Mode 1: **Quick Testing** (Recommended for first time)
- ⏱️ Time: ~30-60 minutes
- 🎯 Purpose: Test all pretrained models on 4 sample images
- 📊 Output: Comparison of all 9 model variants
- 💻 Run: "QUICK START - Test ALL Model Variants"

### Mode 2: **Full Training** (Advanced)
- ⏱️ Time: 6-8 hours
- 🎯 Purpose: Train new models on full OAI dataset
- 📊 Output: New trained models + evaluation
- 💻 Run: "TRAINING PIPELINE - Train models on OAI data"

---


## 🔧 Environment Setup


In [None]:
# Check GPU availability
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

In [None]:
# Mount Google Drive and setup data
from pathlib import Path

try:
    from google.colab import drive

    drive.mount("/content/drive")
    print("✅ Google Drive mounted successfully")

    # Check for OAI data in Google Drive
    oai_drive_path = Path("/content/drive/MyDrive/Colab Notebooks/OAI_untracked")
    if oai_drive_path.exists():
        print(f"✅ OAI data found in Google Drive: {oai_drive_path}")

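        # Link data/oai to the OAI folder inside Google Drive
        # (the images live under OAI_untracked/data/oai, not OAI_untracked itself)
        data_dir = Path("data")
        data_dir.mkdir(exist_ok=True)

        oai_source = oai_drive_path / "data" / "oai"
        oai_link = data_dir / "oai"
        if not oai_source.exists():
            print(f"⚠️ Expected OAI data at: {oai_source}")
        elif not oai_link.exists() and not oai_link.is_symlink():
            oai_link.symlink_to(oai_source)
            print("🔗 Created symlink to Google Drive OAI data")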
    else:
        print("❌ OAI data not found in Google Drive")
        print(
            "💡 Please ensure your data is in: /content/drive/MyDrive/Colab Notebooks/OAI_untracked"
        )

except ImportError:
    print("⚠️ Google Colab not detected - skipping Drive mount")
except Exception as e:
    print(f"❌ Error mounting Google Drive: {e}")

In [None]:
# Alternative setup method (if Git fails)
import os
import subprocess
import urllib.request
import zipfile
from pathlib import Path


def setup_repository_alternative():
    """Alternative method to setup repository without Git"""
    print("🔄 Using alternative setup method...")

    # IMPORTANT: Navigate to Google Drive to maintain sibling directory structure
    # This ensures relative symlinks work (data -> ../OAI_untracked/data/)
    # Making it portable between local and Colab environments
    try:
        from google.colab import drive

        colab_notebooks = Path("/content/drive/MyDrive/Colab Notebooks")
        if colab_notebooks.exists():
            os.chdir(colab_notebooks)
            print(f"✅ Working directory: {Path.cwd()}")
            print("💡 Using relative paths for portability")
        else:
            print("⚠️ Google Drive not mounted or path not found")
            print(f"📁 Current directory: {Path.cwd()}")
    except ImportError:
        print("⚠️ Not in Colab environment - using current directory")
        print(f"📁 Current directory: {Path.cwd()}")

    if not Path("OAI-inpainting").exists():
        print("📥 Downloading repository as ZIP...")
        try:
            # Download the repository as ZIP
            zip_url = "https://github.com/johnreynolds3d/OAI-inpainting/archive/refs/heads/master.zip"
            zip_path = "OAI-inpainting-master.zip"

            print(f"📥 Downloading from: {zip_url}")
            urllib.request.urlretrieve(zip_url, zip_path)

            # Extract the ZIP file
            print("📦 Extracting repository...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(".")

            # Rename the extracted folder (GitHub names it <repo>-<branch>).
            # "OAI-inpainting" cannot exist here: this branch only runs when it is absent.
            if Path("OAI-inpainting-master").exists():
                Path("OAI-inpainting-master").rename("OAI-inpainting")

            # Clean up ZIP file
            Path(zip_path).unlink()
            print("✅ Repository downloaded and extracted successfully")

        except Exception as e:
            print(f"❌ Download failed: {e}")
            print("💡 Please check your internet connection and try again")
            return False

    # Change to project directory
    os.chdir("OAI-inpainting")
    print(f"📂 Current directory: {Path.cwd()}")

    # Create relative symlinks (portable between local and Colab)
    print("\n🔗 Setting up portable symlinks...")

    # Remove stale placeholders that might have come with the repository:
    # a plain file, or a symlink whose target does not exist
    for name in ["data", "results"]:
        path = Path(name)
        if path.is_file() or (path.is_symlink() and not path.exists()):
            path.unlink()
            print(f"🗑️  Removed broken {name} entry")

    # Create data and results symlinks using relative paths (portable!)
    data_dir = Path("data")
    results_dir = Path("results")
    relative_data_path = Path("../OAI_untracked/data")
    relative_results_path = Path("../OAI_untracked/results")

    # Setup data symlink
    if not data_dir.exists():
        if relative_data_path.exists():
            data_dir.symlink_to(relative_data_path)
            print("✅ Created data symlink -> ../OAI_untracked/data/ (relative)")
        else:
            print("⚠️ OAI_untracked/data not found as sibling directory")
            print(
                "💡 Expected structure: parent_dir/OAI-inpainting/ and parent_dir/OAI_untracked/"
            )
    elif data_dir.is_symlink():
        print("✅ Data symlink already exists (relative)")

    # Setup results symlink (same as local - results persist in Google Drive!)
    if not results_dir.exists():
        # Ensure parent results directory exists
        if not relative_results_path.exists():
            try:
                relative_results_path.mkdir(parents=True, exist_ok=True)
                print("📁 Created results directory in OAI_untracked")
            except OSError:
                pass  # May not have permissions; falls back to a local directory below
        
        # Try to create symlink
        if relative_results_path.exists():
            results_dir.symlink_to(relative_results_path)
            print("✅ Created results symlink -> ../OAI_untracked/results/ (relative)")
            print("💾 Results will persist in Google Drive!")
        else:
            # Fallback: create local directory
            results_dir.mkdir(parents=True, exist_ok=True)
            print("⚠️  Results created as local directory (won't persist in Drive)")
    elif results_dir.is_symlink():
        print("✅ Results symlink already exists (relative)")
    else:
        print("⚠️  Results exists as real directory (won't persist in Drive)")
        print("💡 Delete it to create symlink for Google Drive persistence")

    # Install core dependencies manually
    print("\n📦 Installing core dependencies...")
    core_deps = [
        "torch",
        "torchvision",
        "numpy",
        "opencv-python",
        "pillow",
        "scikit-image",
        "scipy",
        "pandas",
        "matplotlib",
        "seaborn",
        "pyyaml",
        "tqdm",
        "tensorboard",
        "wandb",
        "scikit-learn",
    ]

    for dep in core_deps:
        try:
            subprocess.run(["pip", "install", dep], check=True, capture_output=True)
            print(f"✅ Installed {dep}")
        except subprocess.CalledProcessError:
            print(f"⚠️ Failed to install {dep}")

    print("\n✅ Alternative setup complete!")
    print(f"📁 Project location: {Path.cwd()}")
    print(
        f"📁 Data location: {Path('data').resolve() if Path('data').exists() else 'Not linked'}"
    )
    print(
        f"📁 Results location: {Path('results').resolve() if Path('results').exists() else 'Not linked'}"
    )
    return True


# Run alternative setup
setup_repository_alternative()

## 🔧 Troubleshooting

If you encounter issues with the setup:

1. **Git not available**: The setup cell downloads the repository as a ZIP file, so Git is not required
2. **Dependency installation fails**: Core dependencies are installed one at a time, so a single failure does not stop the rest
3. **Permission errors**: Restart the runtime and try again
4. **Import errors**: Make sure you're in the correct directory (`OAI-inpainting`)
5. **NotADirectoryError**: Run the "FIX: Results Directory Issue" cell below

**Quick fixes:**
- **Fix results directory**: Run the next cell (Results Directory Fix)
- Restart runtime: `Runtime → Restart runtime`
- Check directory: `!pwd`
- List files: `!ls -la`
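
A `NotADirectoryError` can appear when `results` (or `data`) is a plain file rather than a directory, for example a symlink that was checked out as a regular file. The helper below is a small diagnostic sketch, not part of the repository (the name `describe_path` is hypothetical); it reports which state a path is in so you know whether the fix cell is needed.

```python
from pathlib import Path


def describe_path(path):
    """Classify a path as symlink, broken symlink, directory, file, or missing."""
    p = Path(path)
    if p.is_symlink():
        # exists() follows the link, so it is False for a dangling symlink.
        return "symlink" if p.exists() else "broken symlink"
    if p.is_dir():
        return "directory"
    if p.is_file():
        return "file"
    return "missing"
```

For example, `describe_path("results")` returning `"file"` or `"broken symlink"` suggests the entry should be removed and recreated.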


In [None]:
# 🔧 FIX: Results Directory Issue
# Run this cell if you get "NotADirectoryError" when running tests
# This fixes the results directory by converting it from a symlink to a real directory

import os
from pathlib import Path

print("🔧 Fixing results directory...")

# Try to navigate to project directory
try:
    project_dirs = [
        Path("/content/drive/MyDrive/Colab Notebooks/OAI-inpainting"),
        Path("/content/OAI-inpainting"),
        Path("OAI-inpainting"),
        Path.cwd(),
    ]

    for project_dir in project_dirs:
        if project_dir.exists() and (project_dir / "models").exists():
            os.chdir(project_dir)
            print(f"📁 Working directory: {os.getcwd()}")
            break
    else:
        print("⚠️ Could not find OAI-inpainting directory")
        print(f"📁 Current directory: {os.getcwd()}")

except Exception as e:
    print(f"⚠️ Error navigating: {e}")
    print(f"📁 Current directory: {os.getcwd()}")

# Check and fix results directory
results = Path("results")

if results.is_symlink():
    target = os.readlink(results)
    print(f"⚠️ Found symlink: results -> {target}")
    results.unlink()
    print("🗑️ Removed symlink")
elif results.is_file():
    print("⚠️ Found file instead of directory")
    results.unlink()
    print("🗑️ Removed file")
elif results.is_dir():
    print("✅ Results is already a real directory")
else:
    print("ℹ️ Results doesn't exist yet")

# Create as real directory
try:
    results.mkdir(parents=True, exist_ok=True)
    print(f"✅ Created real directory: {results.resolve()}")

    # Verify
    if results.is_dir() and not results.is_symlink():
        print("\n🎉 SUCCESS! Results is now a proper directory")
        print("📂 You can now run the comprehensive test!")
    else:
        print("\n⚠️ Results exists but may have issues")

except Exception as e:
    print(f"\n❌ Error creating directory: {e}")
    print("💡 Try: Runtime → Restart runtime, then re-run Cell 4")

In [None]:
# Verify installation and setup
from pathlib import Path

try:
    from src.paths import get_project_root

    print("✅ Core modules imported successfully")
    print(f"📁 Project root: {get_project_root()}")
except ImportError as e:
    print(f"❌ Import error: {e}")

# Check available models
models_dir = Path("models")
if models_dir.exists():
    print("\n📋 Available models:")
    for model in models_dir.iterdir():
        if model.is_dir():
            print(f"  - {model.name}")
else:
    print("❌ Models directory not found")

In [None]:
# Setup data from Google Drive
import os
import shutil
from pathlib import Path

print("📊 Setting up data from Google Drive...")

# Define paths
drive_data_path = Path("/content/drive/MyDrive/Colab Notebooks/OAI_untracked")
local_data_path = Path("data")

# Check if Google Drive data exists
if drive_data_path.exists():
    print(f"✅ Found OAI data in Google Drive: {drive_data_path}")

    # Check data structure
    oai_img_path = drive_data_path / "data" / "oai" / "img"
    pretrained_path = drive_data_path / "data" / "pretrained"

    if oai_img_path.exists():
        img_count = len(list(oai_img_path.glob("*.png")))
        print(f"📸 Found {img_count} OAI X-ray images")
    else:
        print("⚠️ OAI images not found in expected location")

    if pretrained_path.exists():
        model_files = list(pretrained_path.rglob("*.pt")) + list(
            pretrained_path.rglob("*.pth")
        )
        print(f"🤖 Found {len(model_files)} pretrained model files")
    else:
        print("⚠️ Pretrained models not found")

    # Check if data is already linked or setup
    if local_data_path.is_symlink():
        print("\n✅ Data already linked via symlink (from Cell 3)")
        print(f"   {local_data_path} → {os.readlink(local_data_path)}")
    elif (local_data_path / "oai").is_symlink():
        print("\n✅ Data already linked via symlink (from Cell 3)")
        print(f"   {local_data_path / 'oai'} → {os.readlink(local_data_path / 'oai')}")
    else:
        # Try to copy data for better performance
        print("\n📋 Attempting to copy data from Google Drive to local storage...")
        print("   (This improves performance but requires disk space)")

        # Copy only if not already copied (check for a marker file)
        marker_file = local_data_path / ".data_copied"
        if not marker_file.exists():
            try:
                # Create local data directory
                local_data_path.mkdir(parents=True, exist_ok=True)

                # Copy the entire data directory structure
                source_data_path = drive_data_path / "data"
                if source_data_path.exists():
                    # Use rsync-like approach with os.walk for better control
                    print("   Copying files... (this may take a few minutes)")
                    file_count = 0
                    for root, _dirs, files in os.walk(source_data_path):
                        rel_path = Path(root).relative_to(source_data_path)
                        dest_dir = local_data_path / rel_path
                        dest_dir.mkdir(parents=True, exist_ok=True)

                        for file in files:
                            src_file = Path(root) / file
                            dst_file = dest_dir / file
                            if not dst_file.exists():
                                shutil.copy2(src_file, dst_file)
                                file_count += 1
                                if file_count % 100 == 0:
                                    print(f"   Copied {file_count} files...")

                    print(f"✅ Data copied successfully ({file_count} files)")
                    marker_file.touch()
                else:
                    print("❌ Data directory not found in expected structure")
                    print("💡 Please check your Google Drive structure")

            except Exception as e:
                error_msg = str(e)[:200]  # Truncate to prevent massive output
                print(f"❌ Error copying data: {error_msg}...")
                print("💡 Falling back to symlink approach...")
                # Fallback to symlink
                try:
                    if local_data_path.exists() and not local_data_path.is_symlink():
                        shutil.rmtree(local_data_path)
                    if not local_data_path.exists():
                        local_data_path.symlink_to(drive_data_path / "data")
                        print("✅ Created symlink instead")
                except Exception as link_error:
                    print(f"❌ Symlink also failed: {link_error}")
        else:
            print("✅ Data already copied (skipping)")

    # Verify data structure
    print("\n📋 Verifying data structure...")
    oai_local_path = local_data_path / "oai"
    pretrained_local_path = local_data_path / "pretrained"

    # Check OAI data
    if oai_local_path.exists():
        img_path = oai_local_path / "img"
        if img_path.exists():
            img_count = len(list(img_path.glob("*.png")))
            print(f"  ✅ oai/img/ ({img_count} PNG files)")
        else:
            print("  ❌ oai/img/ (missing)")
    else:
        print("  ❌ oai/ directory (missing)")

    # Check pretrained models
    if pretrained_local_path.exists():
        model_files = list(pretrained_local_path.rglob("*.pt")) + list(
            pretrained_local_path.rglob("*.pth")
        )
        print(f"  ✅ pretrained/ ({len(model_files)} model files)")

        # Check specific model directories
        for model_dir in ["aot-gan", "ict", "repaint"]:
            model_path = pretrained_local_path / model_dir
            if model_path.exists():
                model_count = len(
                    list(model_path.rglob("*.pt")) + list(model_path.rglob("*.pth"))
                )
                print(f"    ✅ {model_dir}/ ({model_count} files)")
            else:
                print(f"    ❌ {model_dir}/ (missing)")
    else:
        print("  ❌ pretrained/ directory (missing)")

    # Check for generated directories (will be created by split.py)
    generated_dirs = ["train", "valid", "test"]
    for dir_name in generated_dirs:
        dir_path = oai_local_path / dir_name
        if dir_path.exists():
            file_count = len(list(dir_path.rglob("*")))
            print(f"  ✅ oai/{dir_name}/ ({file_count} files) - Generated")
        else:
            print(f"  ⏳ oai/{dir_name}/ (will be generated by split.py)")

else:
    print("❌ OAI data not found in Google Drive")
    print("💡 Please ensure your data is uploaded to:")
    print("   /content/drive/MyDrive/Colab Notebooks/OAI_untracked")
    print("\n📤 Expected structure:")
    print("   OAI_untracked/")
    print("   ├── data/")
    print("   │   ├── oai/")
    print("   │   │   └── img/          # 539 PNG files")
    print("   │   └── pretrained/       # Model files")
    print("   └── README.md")
    print("\n📤 To upload data:")
    print("1. Go to Google Drive")
    print("2. Navigate to 'Colab Notebooks' folder")
    print("3. Create 'OAI_untracked' folder")
    print("4. Upload your OAI dataset files there")

## 🔄 Generate Dataset Splits


In [None]:
# 🔄 GENERATE DATASET SPLITS (Required before testing!)
# This creates PERFECTLY BALANCED train/valid/test splits using ALL 539 images
# • 80% train (431 images), 10% validation (53 images), 10% test (55 images)
# • Equal low/high BMD representation in each split
# • Mutually exclusive (no image overlap between splits)
# • Also creates subset_4 (4 test images) for quick testing

import subprocess
from pathlib import Path

print("=" * 60)
print("🔄 GENERATING PERFECTLY BALANCED DATASET SPLITS")
print("Using ALL 539 images with 80/10/10 split")
print("=" * 60)

# Check if splits already exist
subset_4_path = Path("data/oai/test/img/subset_4")
if subset_4_path.exists() and any(subset_4_path.glob("*.png")):
    file_count = len(list(subset_4_path.glob("*.png")))
    print("✅ Splits already exist!")
    print(f"✅ Found {file_count} files in subset_4")
    print("\n💡 Skipping split generation - already done")
else:
    print("⚠️  Splits not found - generating now...")
    print("This will create:")
    print("  • Train/valid/test splits")
    print("  • Random masks")
    print("  • Edge maps")
    print("  • Inverted masks")
    print("  • subset_4 evaluation set (4 images)")

    # Run split.py
    result = subprocess.run(
        ["python", "split.py"],
        check=False,
        cwd="data/oai",
        capture_output=True,
        text=True,
        timeout=300,
    )

    if result.returncode == 0:
        print("\n✅ PERFECTLY BALANCED dataset splits generated successfully!")
        print("\n📊 Split Summary:")

        # Count actual files in each split
        train_count = (
            len(list(Path("data/oai/train/img").glob("*.png")))
            if Path("data/oai/train/img").exists()
            else 0
        )
        val_count = (
            len(list(Path("data/oai/valid/img").glob("*.png")))
            if Path("data/oai/valid/img").exists()
            else 0
        )
        test_count = (
            len(list(Path("data/oai/test/img").glob("*.png")))
            if Path("data/oai/test/img").exists()
            else 0
        )
        total_count = train_count + val_count + test_count
        denom = max(total_count, 1)  # guard against division by zero if no images were found

        print(f"  📁 Train: {train_count} images ({train_count / denom * 100:.1f}%)")
        print(f"  📁 Valid: {val_count} images ({val_count / denom * 100:.1f}%)")
        print(f"  📁 Test:  {test_count} images ({test_count / denom * 100:.1f}%)")
        print(f"  📁 TOTAL: {total_count} images")

        # Verify subset_4 was created
        if subset_4_path.exists():
            img_count = len(list(subset_4_path.glob("*.png")))
            print(
                f"\n✅ subset_4 created with {img_count} images (2 low BMD + 2 high BMD)"
            )

            # Verify all required directories
            required_dirs = [
                "data/oai/test/img/subset_4",
                "data/oai/test/mask/subset_4",
                "data/oai/test/edge/subset_4",
                "data/oai/test/mask_inv/subset_4",
            ]

            print("\n📋 Verification:")
            for dir_path in required_dirs:
                p = Path(dir_path)
                if p.exists():
                    count = len(list(p.glob("*.png")))
                    print(
                        f"  ✅ {dir_path.split('/')[-2]}/{dir_path.split('/')[-1]}/: {count} files"
                    )
                else:
                    print(f"  ❌ {dir_path}: missing")

            print("\n🎉 Ready for training, testing, and evaluation!")
        else:
            print("⚠️ subset_4 not created - check output above")
            if result.stdout:
                print(f"Output: {result.stdout}")
    else:
        print(f"❌ Split failed with exit code {result.returncode}")
        print(f"Error: {result.stderr}")
        print("\n💡 Try running manually:")
        print("  !cd data/oai && python split.py")

## 🚀 Training, Testing & Evaluation

Now that you have the perfectly balanced dataset splits, you can:
- **Test** pretrained models on the test set
- **Train** new models on the full 431-image training set
- **Evaluate** model performance with comprehensive metrics


In [None]:
# 🚀 PIPELINE RUNNER SETUP
# Import the pipeline runner functions

import sys

sys.path.append("scripts")

try:
    from colab_comprehensive_test import ModelTester
    from colab_pipeline import run_full_pipeline, run_phase

    print("✅ Pipeline runner imported successfully!")
    print("")
    print("🎯 Available functions:")
    print("")
    print("📊 COMPREHENSIVE TESTING (Recommended):")
    print("  - ModelTester().run_comprehensive_test() - Test ALL 9 model variants")
    print("     • AOT-GAN: CelebA-HQ, Places2, OAI")
    print("     • ICT: FFHQ, ImageNet, Places2_Nature, OAI")
    print("     • RePaint: CelebA-HQ, ImageNet, Places2")
    print("")
    print("🔄 PHASED PIPELINE (For training):")
    print("  - run_full_pipeline() - Run all 5 phases")
    print("  - run_phase(1) - Quick verification")
    print("  - run_phase(2) - AOT-GAN training")
    print("  - run_phase(3) - ICT training")
    print("  - run_phase(4) - RePaint inference")
    print("  - run_phase(5) - Evaluation")
    print("")
    print("💡 Ready to use! Go to the next cell to run commands.")
except ImportError as e:
    print(f"❌ Import error: {e}")
    print("💡 Make sure you've run the repository setup cell first")
    print("💡 The scripts should be in the scripts/ directory")

## 🎯 Quick Workflow: Train, Test & Evaluate on Balanced Dataset

The cells below provide different workflows based on your needs:


In [None]:
# 🚀 OPTION 1: Quick Test on Balanced Dataset (Recommended First Step)
# Tests all 9 pretrained model variants on the subset_4 (4 balanced test images)
# ⏱️ Time: ~30-60 minutes

print("=" * 60)
print("🎯 QUICK TEST: All Models on Balanced subset_4")
print("=" * 60)
print("Testing 9 model variants on 4 balanced images:")
print("  • 2 low BMD images")
print("  • 2 high BMD images")
print("")

tester = ModelTester(timeout_per_model=600, verbose=True)
results = tester.run_comprehensive_test(models=["all"])

print("\n" + "=" * 60)
print("🎉 QUICK TEST COMPLETE!")
print("=" * 60)
print(f"✅ Successful: {results['summary']['successful']}")
print(f"❌ Failed: {results['summary']['failed']}")
print(f"⏭️ Skipped: {results['summary']['skipped']}")
print("")
print("📁 Check results/ directory for output images")
print("📊 Check results/comprehensive_test_results.json for detailed results")

In [None]:
# 🎓 OPTION 2: Train on Full Balanced Dataset & Evaluate
# Trains models on the full 431-image training set with balanced BMD representation
# ⏱️ Time: 6-8 hours (runs all training phases)
#
# This will:
# 1. Train AOT-GAN on 431 balanced training images
# 2. Train ICT (Transformer) on 431 balanced training images
# 3. Run RePaint inference on test set
# 4. Evaluate all models on the 55-image balanced test set
# 5. Generate comprehensive metrics and visualizations

# Uncomment to run full training pipeline
# print("=" * 60)
# print("🎓 FULL TRAINING PIPELINE ON BALANCED DATASET")
# print("=" * 60)
# print("Training on:")
# print("  • 431 images (215 low BMD + 216 high BMD)")
# print("  • 53 validation images (26 low BMD + 27 high BMD)")
# print("  • 55 test images (28 low BMD + 27 high BMD)")
# print("")
#
# run_full_pipeline(timeout_hours=8)

print("💡 To run full training pipeline:")
print("   1. Uncomment the code above")
print("   2. Run this cell")
print("   3. Wait 6-8 hours for training to complete")
print("")
print("📊 Benefits of the new balanced split:")
print("   • About 2x the training data (431 vs ~216)")
print("   • Equal representation of low/high BMD cases")
print("   • Better model generalization")
print("   • More reliable evaluation metrics")

In [None]:
# 🔍 OPTION 3: Test Specific Model on Full Balanced Test Set
# Test a single model on all 55 balanced test images (instead of just subset_4)
# ⏱️ Time: ~5-15 minutes per model

import subprocess


def test_model_on_full_test_set(model_name, config_path):
    """Test a specific model on the full 55-image balanced test set."""
    print(f"🧪 Testing {model_name} on full balanced test set (55 images)...")
    print(f"   Using config: {config_path}")

    # All models share one test script; only the --model value differs
    model_key = next(
        (key for key in ("aot-gan", "ict", "repaint") if key in model_name.lower()),
        None,
    )
    if model_key is None:
        print(f"❌ Unknown model: {model_name}")
        return

    cmd = ["python", "scripts/test.py", "--model", model_key, "--config", config_path]

    try:
        result = subprocess.run(
            cmd, check=False, capture_output=True, text=True, timeout=1800
        )
    except subprocess.TimeoutExpired:
        print(f"❌ {model_name} testing timed out after 30 minutes")
        return

    if result.returncode == 0:
        print(f"✅ {model_name} testing completed!")
        print(f"📁 Results saved to results/{model_name}/")
    else:
        print(f"❌ {model_name} testing failed")
        print(f"Error: {result.stderr}")


# Example usage (uncomment to run):
# test_model_on_full_test_set("AOT-GAN-OAI", "configs/oai_config.yml")
# test_model_on_full_test_set("ICT-OAI", "models/ict/Guided_Upsample/subset_4_config.yml")

print("💡 To test a specific model on the full balanced test set:")
print("   1. Uncomment one of the examples above")
print("   2. Adjust the model name and config path as needed")
print("   3. Run this cell")
print("")
print("📊 Full test set: 55 images (28 low BMD + 27 high BMD)")
print("   vs. subset_4: 4 images (2 low BMD + 2 high BMD)")

In [None]:
# 🚀 QUICK START - Test ALL Model Variants (RECOMMENDED)
# This tests all 9 pretrained models on 4 OAI X-ray images
# Perfect for quick validation and comparison!

tester = ModelTester(timeout_per_model=600, verbose=True)
results = tester.run_comprehensive_test(models=["all"])

print("\n" + "=" * 60)
print("🎉 COMPREHENSIVE TEST COMPLETE!")
print("=" * 60)
print(f"✅ Successful: {results['summary']['successful']}")
print(f"❌ Failed: {results['summary']['failed']}")
print(f"⏭️ Skipped: {results['summary']['skipped']}")
print("")
print("📁 Check results/ directory for output images")
print("📊 Check results/comprehensive_test_results.json for detailed results")

In [None]:
# 🔄 ALTERNATIVE - Test Specific Model Types
# If you only want to test specific models:

# Test only AOT-GAN variants
# tester = ModelTester(timeout_per_model=600)
# results = tester.run_comprehensive_test(models=["aot-gan"])

# Test only ICT variants
# tester = ModelTester(timeout_per_model=600)
# results = tester.run_comprehensive_test(models=["ict"])

# Test only RePaint variants
# tester = ModelTester(timeout_per_model=600)
# results = tester.run_comprehensive_test(models=["repaint"])

# Test multiple specific models
# tester = ModelTester(timeout_per_model=600)
# results = tester.run_comprehensive_test(models=["aot-gan", "repaint"])

print("💡 Uncomment the code above to test specific model types")
print("💡 Or run the previous cell to test all models at once")

## 🎓 Training Pipeline (Optional)

If you want to train new models on OAI data instead of just testing pretrained models:


In [None]:
# 🎓 TRAINING PIPELINE - Train models on OAI data
# Uncomment to run the full training pipeline (6-8 hours)

# Run complete pipeline (all phases)
# run_full_pipeline(timeout_hours=8)

# Or run individual phases:
# run_phase(1)  # Quick verification (5 min)
# run_phase(2)  # AOT-GAN training (2-4 hours)
# run_phase(3)  # ICT training (1-3 hours)
# run_phase(4)  # RePaint inference (30 min)
# run_phase(5)  # Evaluation (15 min)

print("📋 Training Pipeline Options:")
print("")
print("⚡ Quick verification only:")
print("   run_phase(1)")
print("")
print("🔧 Run a specific phase:")
print("   run_phase(2)  # AOT-GAN only")
print("   run_phase(3)  # ICT only")
print("   run_phase(4)  # RePaint inference only")
print("")
print("🚀 Train all models:")
print("   run_full_pipeline(timeout_hours=8)")
print("")
print("⚠️  Warning: Training takes 6-8 hours total")
print("💡 Tip: Test with run_phase(1) first to verify setup")
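
The per-phase estimates in the cell above can be summed to sanity-check the overall time budget. The sketch below copies those estimates into a dictionary and totals the low and high ends. The names `PHASE_ESTIMATES_MIN`, `total_estimate`, and `fmt` are illustrative and not part of the pipeline.

```python
# Per-phase estimates in minutes (low, high), copied from the comments
# in the training cell. The dictionary itself is illustrative only.
PHASE_ESTIMATES_MIN = {
    1: (5, 5),      # quick verification
    2: (120, 240),  # AOT-GAN training
    3: (60, 180),   # ICT training
    4: (30, 30),    # RePaint inference
    5: (15, 15),    # evaluation
}


def total_estimate(phases=PHASE_ESTIMATES_MIN):
    """Return (low, high) total minutes across all phases."""
    low = sum(lo for lo, _ in phases.values())
    high = sum(hi for _, hi in phases.values())
    return low, high


def fmt(minutes):
    """Format a minute count as e.g. 3h50m."""
    hours, mins = divmod(minutes, 60)
    return f"{hours}h{mins:02d}m"


low, high = total_estimate()
print(f"Estimated total: {fmt(low)} to {fmt(high)}")
# Estimated total: 3h50m to 7h50m
```

The upper bound of 7h50m sits just under the 8-hour value passed to `run_full_pipeline(timeout_hours=8)`, so a slow run leaves very little headroom.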

## 🎨 Comparison Strips Generation


In [None]:
# 🎨 GENERATE COMPARISON STRIPS
# Creates horizontal comparison images: GT, GT+Mask, and all model outputs
# Perfect for visual assessment and thesis inclusion!

import sys

sys.path.append("scripts")

try:
    from generate_comparison_strips import main as generate_strips

    print("🎨 Generating visual comparison strips...")
    print("This creates horizontal strips showing all model outputs side-by-side")
    print("")

    strip_paths = generate_strips()

    if strip_paths:
        print(f"\n✅ Generated {len(strip_paths)} comparison strips!")
        print("📁 Location: results/comparison_strips/")
        print("")
        print("📸 Files created:")
        for path in strip_paths:
            print(f"  - {path.name}")
        print("  - all_comparisons_summary.png (all strips stacked)")
        print("")
        print("💡 Each strip shows:")
        print("   GT → GT+Mask → AOT-GAN variants → ICT variants → RePaint variants")
    else:
        print("⚠️  No strips generated - check that comprehensive test completed")

except ImportError as e:
    print(f"❌ Import error: {e}")
    print("💡 Make sure you've run the setup cells first")
except Exception as e:
    print(f"❌ Error generating strips: {e}")
    print("💡 Ensure comprehensive test completed successfully")
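
The internals of `generate_comparison_strips` are not shown in this notebook. As a rough illustration of what a horizontal strip involves, the sketch below computes the canvas size and paste offsets for a row of panels. `strip_layout` and its `padding` parameter are hypothetical. With Pillow, each offset could be passed to `Image.paste` on a blank canvas of the returned size.

```python
def strip_layout(tile_sizes, padding=10):
    """Compute canvas size and paste offsets for a horizontal strip.

    tile_sizes: list of (width, height) tuples, one per panel.
    Returns ((canvas_w, canvas_h), [(x, y), ...]).
    Hypothetical helper, not taken from the repository.
    """
    if not tile_sizes:
        raise ValueError("at least one tile is required")

    canvas_h = max(h for _, h in tile_sizes) + 2 * padding
    offsets = []
    x = padding
    for w, h in tile_sizes:
        # Centre each tile vertically on the canvas
        y = (canvas_h - h) // 2
        offsets.append((x, y))
        x += w + padding

    # After the loop, x already includes the trailing padding
    return (x, canvas_h), offsets


# Three 256x256 panels, e.g. GT, GT+Mask, and one model output
canvas, offsets = strip_layout([(256, 256)] * 3)
print(canvas)   # (808, 276)
print(offsets)  # [(10, 10), (276, 10), (542, 10)]
```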

## 📊 Results Visualization and Download


In [None]:
# 📊 Visualize Results
import json
from pathlib import Path

import matplotlib.pyplot as plt
from PIL import Image

# Check if results exist
results_json = Path("results/comprehensive_test_results.json")
if results_json.exists():
    with results_json.open() as f:
        test_results = json.load(f)

    print("=" * 60)
    print("📊 TEST RESULTS SUMMARY")
    print("=" * 60)
    print(f"Timestamp: {test_results['timestamp']}")
    print(f"Duration: {test_results['duration']}")
    print(f"Total tests: {test_results['summary']['total']}")
    print(f"Successful: {test_results['summary']['successful']}")
    print(f"Failed: {test_results['summary']['failed']}")
    print(f"Skipped: {test_results['summary']['skipped']}")
    print("")

    # Show which models passed
    print("✅ SUCCESSFUL MODELS:")
    for result in test_results["results"]:
        if result["success"]:
            print(f"  • {result['model']} ({result['elapsed']:.1f}s)")

    # Show which models were skipped (entries with a "reason"), then which failed
    if test_results["summary"]["skipped"] > 0:
        print("")
        print("⏭️ SKIPPED MODELS:")
        for result in test_results["results"]:
            if "reason" in result:
                print(f"  • {result['model']} ({result['reason']})")

    if test_results["summary"]["failed"] > 0:
        print("")
        print("❌ FAILED MODELS:")
        for result in test_results["results"]:
            if not result["success"] and "reason" not in result:
                print(f"  • {result['model']}")

else:
    print("⚠️  No test results found. Run the comprehensive test first!")
    print("💡 Execute the 'QUICK START' cell above to generate results")
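
The cell above treats an entry as skipped when it carries a `reason` key, and as failed when `success` is false and no `reason` is present. That rule can be captured in one helper, sketched below. `classify_results` and the sample model names are illustrative. Unlike the cell above, the sketch counts a successful entry only once even if it also has a `reason`.

```python
def classify_results(results):
    """Split result entries into successful, skipped, and failed model names."""
    buckets = {"successful": [], "skipped": [], "failed": []}
    for entry in results:
        if entry.get("success"):
            buckets["successful"].append(entry["model"])
        elif "reason" in entry:
            # A "reason" on a non-successful entry marks it as skipped
            buckets["skipped"].append(entry["model"])
        else:
            buckets["failed"].append(entry["model"])
    return buckets


# Made-up entries that mirror the fields read by the cell above
sample = [
    {"model": "model-a", "success": True, "elapsed": 42.0},
    {"model": "model-b", "success": False, "reason": "weights not found"},
    {"model": "model-c", "success": False},
]
print(classify_results(sample))
```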

In [None]:
# 📸 Display Sample Results (if available)
# Show inpainted images from successful tests

results_base = Path("results")

# Collect every result directory that contains images
sample_dirs = []
for model_type in ["AOT-GAN", "ICT", "RePaint"]:
    model_path = results_base / model_type
    if model_path.exists():
        for variant in model_path.iterdir():
            subset_4_path = variant / "subset_4"
            if subset_4_path.exists():
                # Check for images
                image_files = list(subset_4_path.rglob("*.png"))
                if image_files:
                    sample_dirs.append((f"{model_type} {variant.name}", subset_4_path))

if sample_dirs:
    print(f"📸 Found results from {len(sample_dirs)} model variants")
    print("")
    print("Available results:")
    for i, (name, path) in enumerate(sample_dirs):
        image_count = len(list(path.rglob("*.png")))
        print(f"  {i + 1}. {name} ({image_count} images)")

    # Display sample from first available model
    if len(sample_dirs) > 0:
        model_name, sample_path = sample_dirs[0]
        sample_images = sorted(sample_path.rglob("*.png"))[:4]

        if sample_images:
            print(f"\n🖼️  Displaying samples from: {model_name}")
            print(f"📁 Path: {sample_path}")

            # Create grid
            fig, axes = plt.subplots(1, min(4, len(sample_images)), figsize=(15, 4))
            if len(sample_images) == 1:
                axes = [axes]

            for idx, img_path in enumerate(sample_images[:4]):
                img = Image.open(img_path)
                if idx < len(axes):
                    axes[idx].imshow(img, cmap="gray")
                    axes[idx].set_title(img_path.name)
                    axes[idx].axis("off")

            plt.tight_layout()
            plt.show()

            print("")
            print("💡 To view more results, explore the results/ directory")
else:
    print("⚠️  No result images found yet")
    print("💡 Run the comprehensive test first to generate results")

In [None]:
# 💾 Download Results
# Package and download all results as a ZIP file

import zipfile
from datetime import datetime
from pathlib import Path

# Check if we're in Colab
try:
    from google.colab import files

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    results_dir = Path("results")

    if results_dir.exists() and any(results_dir.iterdir()):
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"oai_inpainting_results_{timestamp}.zip"

        print(f"📦 Creating results archive: {zip_filename}")
        print("This may take a few minutes...")

        # Create ZIP file
        with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
            for file_path in results_dir.rglob("*"):
                if file_path.is_file():
                    arcname = file_path.relative_to(results_dir.parent)
                    zipf.write(file_path, arcname)
                    if file_path.suffix in [".png", ".jpg", ".json", ".txt"]:
                        print(f"  Added: {arcname}")

        print(f"✅ Archive created: {zip_filename}")
        print(f"📦 Size: {Path(zip_filename).stat().st_size / 1024 / 1024:.1f} MB")
        print("")
        print("⬇️  Downloading...")

        # Download the file
        files.download(zip_filename)

        print("✅ Download initiated!")
        print("💡 Check your browser's download folder")
    else:
        print("⚠️  No results found to download")
        print("💡 Run the comprehensive test first to generate results")
else:
    print("⚠️  Not running in Google Colab")
    print("💡 This cell is designed for Colab environment")
    print("📁 Results are available locally in the results/ directory")
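
The archiving step can also be written as a reusable function that works outside Colab. The sketch below follows the same approach as the cell above: it walks the directory with `rglob` and stores paths relative to the parent directory. The name `zip_directory` is hypothetical.

```python
import zipfile
from pathlib import Path


def zip_directory(src_dir, zip_path):
    """Archive every file under src_dir and return the archived names.

    Names are stored relative to src_dir's parent, so the archive
    unpacks into a single top-level folder. Hypothetical helper.
    """
    src_dir = Path(src_dir)
    zip_path = Path(zip_path)
    added = []
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
        for file_path in sorted(src_dir.rglob("*")):
            # Skip directories, and skip the archive itself if it lives in src_dir
            if not file_path.is_file():
                continue
            if file_path.resolve() == zip_path.resolve():
                continue
            arcname = file_path.relative_to(src_dir.parent)
            zf.write(file_path, arcname)
            added.append(arcname.as_posix())
    return added
```

A call such as `zip_directory("results", "oai_inpainting_results.zip")` would mirror the cell above. In Colab, the resulting file can then be passed to `files.download`.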

In [None]:
# Update repository to latest version
import subprocess

print("🔄 Updating repository...")

try:
    # Fetch latest changes
    subprocess.run(["git", "fetch"], check=True)

    # Check for updates
    result = subprocess.run(
        ["git", "status", "-uno"], check=False, capture_output=True, text=True
    )
    print("📋 Repository status:")
    print(result.stdout)

    # Pull latest changes
    subprocess.run(["git", "pull"], check=True)
    print("✅ Repository updated to latest version")

    # Reinstall dependencies if needed
    print("📦 Reinstalling dependencies...")
    subprocess.run(["pip", "install", "-e", ".[dev,ml]"], check=True)
    print("✅ Dependencies updated")

except Exception as e:
    print(f"❌ Update failed: {e}")
    print("💡 You may need to restart the runtime and re-run the setup cell.")
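
Before pulling, it can help to know how far the local branch is behind its upstream. The sketch below uses `git rev-list --left-right --count HEAD...@{u}`, which prints two tab-separated counts: commits only on the local side, then commits only on the upstream side. The helper names are hypothetical, and the sketch assumes the current branch has an upstream configured.

```python
import subprocess


def parse_ahead_behind(output):
    """Parse 'AHEAD<tab>BEHIND' as printed by git rev-list --left-right --count."""
    ahead, behind = output.split()
    return int(ahead), int(behind)


def ahead_behind(repo_dir="."):
    """Return (ahead, behind) for HEAD relative to its upstream branch.

    Hypothetical helper; requires git and a configured upstream.
    """
    result = subprocess.run(
        ["git", "rev-list", "--left-right", "--count", "HEAD...@{u}"],
        cwd=repo_dir,
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_ahead_behind(result.stdout)


print(parse_ahead_behind("0\t3\n"))  # (0, 3)
```

Running `ahead_behind()` after `git fetch` returning `(0, 0)` would mean there is nothing to pull, so the dependency reinstall could be skipped.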

In [None]:
# Troubleshooting and Tips
print("🔧 Troubleshooting and Tips")
print("=" * 50)

print("\n📋 Common Issues and Solutions:")
print("1. ❌ 'run_phase' is not defined")
print("   💡 Solution: Run the 'Pipeline Runner Setup' cell first")
print("")
print("2. ❌ Import errors")
print("   💡 Solution: Restart runtime and re-run all setup cells")
print("")
print("3. ❌ Data not found")
print("   💡 Solution: Check Google Drive mounting and data structure")
print("")
print("4. ❌ GPU not available")
print("   💡 Solution: Enable GPU in Runtime → Change runtime type")
print("")
print("5. ❌ Out of memory")
print("   💡 Solution: Reduce batch size or restart runtime")

print("\n🚀 Quick Commands:")
print("• Check GPU: !nvidia-smi")
print("• Check directory: !pwd")
print("• List files: !ls -la")
print("• Check Python: !python --version")
print("• Check PyTorch: !python -c 'import torch; print(torch.__version__)'")

print("\n📞 Getting Help:")
print("• Check the README.md for detailed instructions")
print("• Review error messages carefully")
print("• Restart runtime if issues persist")
print("• Ensure all setup cells have run successfully")
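
Several of the quick commands above can be folded into a single Python check that also runs outside Colab. This is a minimal sketch using only the standard library. `environment_report` is a hypothetical name, and the check only tests whether `torch` is installed and `nvidia-smi` is on the PATH, not whether a GPU is actually usable.

```python
import importlib.util
import platform
import shutil


def environment_report():
    """Collect basic facts about the current environment (hypothetical helper)."""
    return {
        "python": platform.python_version(),
        # find_spec checks for the package without importing it
        "torch_installed": importlib.util.find_spec("torch") is not None,
        # Presence of nvidia-smi hints at an NVIDIA driver, nothing more
        "nvidia_smi_on_path": shutil.which("nvidia-smi") is not None,
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")
```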