# OAI X-ray Inpainting - Google Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/johnreynolds3d/OAI-inpainting/blob/master/notebooks/OAI_Inpainting_Colab.ipynb)

This notebook provides a complete environment for training and testing inpainting models on the OAI dataset.

## 🚀 Quick Start

1. **Setup Environment** - Run the setup cell below
2. **Connect Google Drive** - Mount your Google Drive with OAI data
3. **Train Models** - Run training for AOT-GAN, ICT, or RePaint
4. **Test Models** - Evaluate model performance
5. **Download Results** - Save results to your local machine

## 📁 Data Setup

**Recommended approach:**
- **Code**: Cloned from GitHub (always up-to-date)
- **Data**: Stored in Google Drive at `/content/drive/MyDrive/Colab Notebooks/OAI_untracked`
- **Results**: Generated in Colab, downloadable when complete

This hybrid approach gives you the best of both worlds: always-current code with persistent data storage.

---


## 🔧 Environment Setup


In [None]:
# Report the installed PyTorch build and any visible CUDA GPU
import torch

print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if has_cuda:
    device_name = torch.cuda.get_device_name(0)  # first visible device only
    print(f"GPU: {device_name}")
    print(f"CUDA version: {torch.version.cuda}")

In [None]:
# Alternative setup method (if Git fails)
import os
import subprocess
import urllib.request
import zipfile
from pathlib import Path


def setup_repository_alternative():
    """Fetch the repository as a ZIP archive (no Git) and install core packages.

    Fallback for when ``git`` cannot be used. If no ``OAI-inpainting``
    directory is present, the ``master`` branch archive is downloaded from
    GitHub and extracted into ``OAI-inpainting``. The working directory is then
    changed into that directory and each package in ``core_deps`` is installed
    with pip, one at a time (a failing package is reported, not fatal).

    Safe to re-run: if the working directory already is ``OAI-inpainting`` the
    download is skipped instead of looking for a nested copy.

    Returns:
        True once the project directory is the working directory and the
        install loop has run; False if downloading or extracting failed.
    """
    import sys  # install into this kernel's interpreter, not whatever `pip` is on PATH

    print("🔄 Using alternative setup method...")

    repo_name = "OAI-inpainting"
    # After a previous run we may already be inside the project; a relative
    # lookup would then search for (and create) a nested copy.
    repo_dir = Path.cwd() if Path.cwd().name == repo_name else Path(repo_name)

    if not repo_dir.exists():
        print("📥 Downloading repository as ZIP...")
        zip_url = "https://github.com/johnreynolds3d/OAI-inpainting/archive/refs/heads/master.zip"
        zip_path = Path("OAI-inpainting-master.zip")
        # The archive is expected to unpack into this top-level folder.
        extracted_dir = Path("OAI-inpainting-master")
        try:
            print(f"📥 Downloading from: {zip_url}")
            urllib.request.urlretrieve(zip_url, zip_path)

            print("📦 Extracting repository...")
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(".")

            # Without this check a surprising archive layout would only surface
            # later as an uncaught error from os.chdir().
            if not extracted_dir.is_dir():
                raise FileNotFoundError(
                    f"archive did not contain the expected folder {extracted_dir}"
                )
            extracted_dir.rename(repo_dir)
            print("✅ Repository downloaded and extracted successfully")
        except (OSError, zipfile.BadZipFile) as e:
            # OSError covers network errors (URLError) and filesystem errors.
            print(f"❌ Download failed: {e}")
            print("💡 Please check your internet connection and try again")
            return False
        finally:
            # Remove the archive on success and on failure alike.
            zip_path.unlink(missing_ok=True)

    # Change to project directory
    os.chdir(repo_dir)
    print(f"📂 Current directory: {Path.cwd()}")

    # Install core dependencies manually
    print("📦 Installing core dependencies...")
    core_deps = [
        "torch",
        "torchvision",
        "numpy",
        "opencv-python",
        "pillow",
        "scikit-image",
        "scipy",
        "pandas",
        "matplotlib",
        "seaborn",
        "pyyaml",
        "tqdm",
        "tensorboard",
        "wandb",
        "scikit-learn",
    ]

    # One pip call per package so a single failure does not block the rest.
    for dep in core_deps:
        try:
            subprocess.run([sys.executable, "-m", "pip", "install", dep], check=True)
            print(f"✅ Installed {dep}")
        except subprocess.CalledProcessError:
            print(f"⚠️ Failed to install {dep}")

    print("✅ Alternative setup complete!")
    return True


# Run the ZIP-based setup only as a fallback when git is missing; otherwise the
# clone cell below performs the setup. If cloning later fails for another
# reason, call setup_repository_alternative() by hand.
import shutil

if shutil.which("git") is None:
    if not setup_repository_alternative():
        print("⚠️ Alternative setup did not complete - see the messages above")
else:
    print("ℹ️ git is available - skipping alternative setup (run the clone cell below)")

## 🔧 Troubleshooting

If you encounter issues with the setup:

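1. **Git not available or clone fails**: `setup_repository_alternative()` (defined in the setup cell above) downloads the repository as a ZIP file instead; you can also call it manually
2. **Dependency installation fails**: The alternative setup installs core dependencies one at a time and reports any that fail, so one bad package does not stop the rest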
3. **Permission errors**: Restart the runtime and try again
4. **Import errors**: Make sure you're in the correct directory (`OAI-inpainting`)

**Quick fixes:**
- Restart runtime: `Runtime → Restart runtime`
- Check directory: `!pwd`
- List files: `!ls -la`


In [None]:
# Clone repository (or update an existing clone) and install dependencies
import shutil
import sys

REPO_NAME = "OAI-inpainting"
REPO_URL = "https://github.com/johnreynolds3d/OAI-inpainting.git"

# Re-running this cell, or running it after the ZIP fallback, leaves the working
# directory inside the project already; resolve the repo from there instead of
# cloning a nested copy.
repo_dir = Path.cwd() if Path.cwd().name == REPO_NAME else Path.cwd() / REPO_NAME
git_available = shutil.which("git") is not None

if not repo_dir.exists():
    if not git_available:
        raise RuntimeError(
            "git is not available - run setup_repository_alternative() instead"
        )
    print("📥 Cloning repository...")
    subprocess.run(["git", "clone", REPO_URL, str(repo_dir)], check=True)
elif git_available and (repo_dir / ".git").exists():
    print("📁 Repository already exists, updating...")
    # cwd= instead of chdir-in/chdir-out: a failed pull used to leave the
    # notebook sitting inside the repository.
    subprocess.run(["git", "pull"], check=True, cwd=repo_dir)
else:
    # e.g. a ZIP download, which has no .git directory to pull into
    print("📁 Repository already exists (no git checkout to update)")

# Change to project directory
os.chdir(repo_dir)
print(f"📂 Current directory: {Path.cwd()}")

# Install dependencies into this kernel's interpreter (a bare `pip` on PATH may
# belong to a different Python)
print("📦 Installing dependencies...")
subprocess.run([sys.executable, "-m", "pip", "install", "-e", ".[dev,ml]"], check=True)

print("✅ Setup complete!")

In [None]:
# Verify installation
try:
    from src.paths import get_project_root

    print("✅ Core modules imported successfully")
    print(f"📁 Project root: {get_project_root()}")
except ImportError as e:
    print(f"❌ Import error: {e}")

# Check available models (sorted: iterdir() order is filesystem-dependent)
models_dir = Path("models")
if models_dir.exists():
    print("\n📋 Available models:")
    for model in sorted(models_dir.iterdir()):
        if model.is_dir():
            print(f"  - {model.name}")
else:
    print("❌ Models directory not found")

# Check Google Drive connection
try:
    from google.colab import drive

    drive.mount("/content/drive")
    print("\n✅ Google Drive mounted successfully")

    # Check for OAI data in Google Drive
    oai_drive_path = Path("/content/drive/MyDrive/Colab Notebooks/OAI_untracked")
    if oai_drive_path.exists():
        print(f"✅ OAI data found in Google Drive: {oai_drive_path}")

        # The OAI data may sit under OAI_untracked/data/oai/ or directly in
        # OAI_untracked/; link to whichever holds it.
        nested_oai = oai_drive_path / "data" / "oai"
        link_target = nested_oai if nested_oai.exists() else oai_drive_path

        # Create symlinks to Google Drive data
        data_dir = Path("data")
        data_dir.mkdir(exist_ok=True)

        # Link to Google Drive data. exists() follows links, so a dangling
        # symlink reports False; check is_symlink() too or symlink_to() raises.
        oai_link = data_dir / "oai"
        if not (oai_link.exists() or oai_link.is_symlink()):
            oai_link.symlink_to(link_target)
            print("🔗 Created symlink to Google Drive OAI data")
    else:
        print("❌ OAI data not found in Google Drive")
        print(
            "💡 Please ensure your data is in: /content/drive/MyDrive/Colab Notebooks/OAI_untracked"
        )

except ImportError:
    print("⚠️ Google Colab not detected - skipping Drive mount")
except Exception as e:
    # Notebook boundary: report any mount/linking failure instead of aborting.
    print(f"❌ Error mounting Google Drive: {e}")

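## 🔄 Update Repository

The cell that pulls the latest code and reinstalls dependencies appears after the Data Management cells below.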


## 📊 Data Management


In [None]:
# Setup data from Google Drive
import shutil
from pathlib import Path

print("📊 Setting up data from Google Drive...")

# Define paths
drive_data_path = Path("/content/drive/MyDrive/Colab Notebooks/OAI_untracked")
local_data_path = Path("data/oai")

# Check if Google Drive data exists
if drive_data_path.exists():
    print(f"✅ Found OAI data in Google Drive: {drive_data_path}")

    # The OAI data may sit under OAI_untracked/data/oai/ or directly in
    # OAI_untracked/; copy (or link) whichever holds it.
    nested_source = drive_data_path / "data" / "oai"
    source_path = nested_source if nested_source.exists() else drive_data_path

    if local_data_path.is_symlink():
        # data/oai already points into Drive (e.g. linked by the verification
        # cell). Copying "into" it would write the files back into Drive.
        print(
            f"🔗 {local_data_path} already links to {local_data_path.resolve()} "
            "(skipping copy)"
        )
    else:
        # Create local data directory
        local_data_path.mkdir(parents=True, exist_ok=True)

        # Copy data from Google Drive to local (for better performance)
        print("📋 Copying data from Google Drive to local storage...")

        # Copy only if not already copied (check for a marker file)
        marker_file = local_data_path / ".data_copied"
        if not marker_file.exists():
            try:
                shutil.copytree(source_path, local_data_path, dirs_exist_ok=True)

                # Mark as copied only once the whole tree is in place
                marker_file.touch()
                print("✅ Data copied successfully")
            except OSError as e:  # shutil.Error subclasses OSError
                print(f"❌ Error copying data: {e}")
                print("💡 Falling back to symlink approach...")
                # Link to the same directory we tried to copy, so data/oai/
                # has the same layout either way.
                shutil.rmtree(local_data_path)
                local_data_path.symlink_to(source_path)
        else:
            print("✅ Data already copied (skipping)")

    # Verify data structure
    print("\n📋 Verifying data structure...")
    required_dirs = ["img", "train", "valid", "test"]
    for dir_name in required_dirs:
        dir_path = local_data_path / dir_name
        if dir_path.exists():
            # Count regular files only; rglob("*") also yields sub-directories.
            file_count = sum(1 for p in dir_path.rglob("*") if p.is_file())
            print(f"  ✅ {dir_name}/ ({file_count} files)")
        else:
            print(f"  ❌ {dir_name}/ (missing)")

else:
    print("❌ OAI data not found in Google Drive")
    print("💡 Please ensure your data is uploaded to:")
    print("   /content/drive/MyDrive/Colab Notebooks/OAI_untracked")
    print("\n📤 To upload data:")
    print("1. Go to Google Drive")
    print("2. Navigate to 'Colab Notebooks' folder")
    print("3. Create 'OAI_untracked' folder")
    print("4. Upload your OAI dataset files there")

In [None]:
# Enhanced data setup from Google Drive


def _count_model_files(root):
    """Return how many ``*.pt`` / ``*.pth`` entries exist anywhere under *root*."""
    return sum(1 for pattern in ("*.pt", "*.pth") for _ in root.rglob(pattern))


def _count_files(root):
    """Return how many regular files (not directories) exist under *root*."""
    return sum(1 for p in root.rglob("*") if p.is_file())


print("📊 Setting up data from Google Drive...")

# Define paths
drive_data_path = Path("/content/drive/MyDrive/Colab Notebooks/OAI_untracked")
local_data_path = Path("data")

# Check if Google Drive data exists
if drive_data_path.exists():
    print(f"✅ Found OAI data in Google Drive: {drive_data_path}")

    # Check data structure
    oai_img_path = drive_data_path / "data" / "oai" / "img"
    pretrained_path = drive_data_path / "data" / "pretrained"

    if oai_img_path.exists():
        img_count = len(list(oai_img_path.glob("*.png")))
        print(f"📸 Found {img_count} OAI X-ray images")
    else:
        print("⚠️ OAI images not found in expected location")

    if pretrained_path.exists():
        print(f"🤖 Found {_count_model_files(pretrained_path)} pretrained model files")
    else:
        print("⚠️ Pretrained models not found")

    source_data_path = drive_data_path / "data"

    if local_data_path.is_symlink():
        # A previous run fell back to linking data/ to Drive; copying into it
        # would copy Drive onto itself.
        print(
            f"🔗 {local_data_path}/ already links to {local_data_path.resolve()} "
            "(skipping copy)"
        )
    else:
        # Create local data directory
        local_data_path.mkdir(parents=True, exist_ok=True)

        # Copy data from Google Drive to local (for better performance)
        print("\n📋 Copying data from Google Drive to local storage...")

        # Copy only if not already copied (check for a marker file)
        marker_file = local_data_path / ".data_copied"
        if marker_file.exists():
            print("✅ Data already copied (skipping)")
        elif not source_data_path.exists():
            # Leave the marker unset so a later run retries once Drive is fixed.
            print("❌ Data directory not found in expected structure")
            print("💡 Please check your Google Drive structure")
        else:
            try:
                # Copy child by child so entries that are symlinks locally
                # (e.g. data/oai linked to Drive by an earlier cell) are not
                # written through, which would copy files into Drive itself.
                for child in source_data_path.iterdir():
                    target = local_data_path / child.name
                    if target.is_symlink():
                        print(f"🔗 {target} is a symlink (skipping copy)")
                    elif child.is_dir():
                        shutil.copytree(child, target, dirs_exist_ok=True)
                    else:
                        shutil.copy2(child, target)

                # Mark as copied only after every entry succeeded
                marker_file.touch()
                print("✅ Data copied successfully")
            except OSError as e:  # shutil.Error subclasses OSError
                print(f"❌ Error copying data: {e}")
                print("💡 Falling back to symlink approach...")
                # NOTE(review): this removes the entire local data/ directory,
                # including anything the repository itself keeps there -
                # confirm that is intended.
                shutil.rmtree(local_data_path)
                local_data_path.symlink_to(source_data_path)

    # Verify data structure
    print("\n📋 Verifying data structure...")
    oai_local_path = local_data_path / "oai"
    pretrained_local_path = local_data_path / "pretrained"

    # Check OAI data
    if oai_local_path.exists():
        img_path = oai_local_path / "img"
        if img_path.exists():
            img_count = len(list(img_path.glob("*.png")))
            print(f"  ✅ oai/img/ ({img_count} PNG files)")
        else:
            print("  ❌ oai/img/ (missing)")
    else:
        print("  ❌ oai/ directory (missing)")

    # Check pretrained models
    if pretrained_local_path.exists():
        print(
            f"  ✅ pretrained/ ({_count_model_files(pretrained_local_path)} model files)"
        )

        # Check specific model directories
        for model_dir in ["aot-gan", "ict", "repaint"]:
            model_path = pretrained_local_path / model_dir
            if model_path.exists():
                print(f"    ✅ {model_dir}/ ({_count_model_files(model_path)} files)")
            else:
                print(f"    ❌ {model_dir}/ (missing)")
    else:
        print("  ❌ pretrained/ directory (missing)")

    # Check for generated directories (will be created by split.py)
    for dir_name in ["train", "valid", "test"]:
        dir_path = oai_local_path / dir_name
        if dir_path.exists():
            print(f"  ✅ oai/{dir_name}/ ({_count_files(dir_path)} files) - Generated")
        else:
            print(f"  ⏳ oai/{dir_name}/ (will be generated by split.py)")

else:
    print("❌ OAI data not found in Google Drive")
    print("💡 Please ensure your data is uploaded to:")
    print("   /content/drive/MyDrive/Colab Notebooks/OAI_untracked")
    print("\n📤 Expected structure:")
    print("   OAI_untracked/")
    print("   ├── data/")
    print("   │   ├── oai/")
    print("   │   │   └── img/          # 539 PNG files")
    print("   │   └── pretrained/       # Model files")
    print("   └── README.md")
    print("\n📤 To upload data:")
    print("1. Go to Google Drive")
    print("2. Navigate to 'Colab Notebooks' folder")
    print("3. Create 'OAI_untracked' folder")
    print("4. Upload your OAI dataset files there")

In [None]:
# Update repository to latest version
import sys

print("🔄 Updating repository...")

try:
    # Fetch latest changes
    subprocess.run(["git", "fetch"], check=True)

    # Show how the local branch compares with upstream (untracked files omitted)
    result = subprocess.run(
        ["git", "status", "-uno"], check=False, capture_output=True, text=True
    )
    print("📋 Repository status:")
    print(result.stdout)

    # Pull latest changes
    subprocess.run(["git", "pull"], check=True)
    print("✅ Repository updated to latest version")

    # Reinstall into this kernel's interpreter (a bare `pip` on PATH may belong
    # to a different Python)
    print("📦 Reinstalling dependencies...")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-e", ".[dev,ml]"], check=True
    )
    print("✅ Dependencies updated")

except (subprocess.CalledProcessError, OSError) as e:
    # CalledProcessError: git or pip exited non-zero (e.g. cwd is not a git
    # checkout). OSError: e.g. the git executable could not be found.
    print(f"❌ Update failed: {e}")
    print("💡 You may need to restart the runtime and re-run the setup cell.")