# 🔬 OAI X-ray Inpainting - Streamlined Workflow

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/johnreynolds3d/OAI-inpainting/blob/master/notebooks/OAI_Inpainting_Streamlined.ipynb)

**Efficient testing, training, and evaluation of inpainting models on OAI wrist X-ray data**

---

## 🚀 Quick Start Guide (Just 7 Cells!)

### ✅ **Step 1**: Setup (Cell 1) - *Run Once* ⏱️ 5-10 min
Complete environment setup: GPU, Drive, repo, dependencies

### ✅ **Step 2**: Prepare Data (Cell 2) - *Run Once* ⏱️ 2-5 min
Generate balanced dataset splits (539 images → 80/10/10 split)

### ✅ **Step 3**: Choose Your Testing/Training Path

**Option A: Quick Testing** (Recommended First) - Total ~1 hour
```
Cell 3 (Quick Test) → Cell 5 → Cell 6 → Cell 7
  30-60 min           8 min    1 min    3 min
```
Tests 9 pretrained models, no training required

**Option B: Full Training** (Advanced) - Total ~6-8 hours
```
Cell 4 (Training) → Cell 5 → Cell 6 → Cell 7
  6-8 hours          8 min    1 min    3 min
```
Trains custom models on 431 balanced images

**Note**: Cells 3 & 4 are alternatives - pick one based on your goal!

### ✅ **Step 4**: Evaluate & Download (Cells 5-7)
- **Cell 5**: 📊 Classification ⏱️ 5-10 min - Osteoporosis classification accuracy
- **Cell 6**: 🎨 Visualizations ⏱️ 1-2 min - Generate comparison strips
- **Cell 7**: 💾 Download ⏱️ 2-5 min - Package everything as ZIP
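Cell 5 compares classifier accuracy on ground-truth images with accuracy on each model's inpainted outputs. It reports the change in percentage points, alongside confusion matrices. The arithmetic behind those numbers is sketched below in plain Python. The helper names, the `0 = low BMD, 1 = high BMD` encoding, and the predictions are all hypothetical; they are not taken from `colab_classification_evaluation`.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the reference labels."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty label lists of equal length")
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def confusion_matrix(y_true, y_pred, labels=(0, 1)):
    """Rows = true label, columns = predicted label (0 = low BMD, 1 = high BMD)."""
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


# Hypothetical predictions for 4 test images (2 low BMD + 2 high BMD)
y_true = [0, 0, 1, 1]
pred_gt = [0, 0, 1, 0]         # classifier run on ground-truth images
pred_inpainted = [0, 1, 1, 0]  # classifier run on one model's inpainted images

gt_acc = accuracy(y_true, pred_gt)
model_acc = accuracy(y_true, pred_inpainted)
print(f"GT: {gt_acc:.2%}, model: {model_acc:.2%} ({model_acc - gt_acc:+.1%} vs GT)")
# GT: 75.00%, model: 50.00% (-25.0% vs GT)
print(confusion_matrix(y_true, pred_inpainted))
# [[1, 1], [1, 1]]
```

The "vs GT" figure is a plain difference of two accuracies, so it is an absolute change in percentage points, not a relative change.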

---

## 📊 Dataset: BMD-Balanced Split
- ✅ **ALL 539 images** used (previously only ~268)
- ✅ **Train**: 431 images (80%) - 215 low BMD + 216 high BMD
- ✅ **Valid**: 53 images (10%) - 26 low BMD + 27 high BMD
- ✅ **Test**: 55 images (10%) - 28 low BMD + 27 high BMD
- ✅ **Mutually exclusive** - no image overlap
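You get these counts by splitting each BMD class separately with floor division. The low-BMD class has 215 + 26 + 28 = 269 images, and the high-BMD class has 216 + 27 + 27 = 270. Below is a minimal sketch of such a stratified 80/10/10 split. It is not necessarily how `data/oai/split.py` works. The file names, the seeded shuffle, and the helper names are placeholders.

```python
import random


def split_class(items, seed=0):
    """Shuffle one BMD class, then cut it 80/10/10 with floor division."""
    items = sorted(items)               # deterministic starting order
    random.Random(seed).shuffle(items)  # seeded so the split is reproducible
    n = len(items)
    n_train = n * 8 // 10               # floor(0.8 * n)
    n_valid = n // 10                   # floor(0.1 * n); test gets the remainder
    return (items[:n_train],
            items[n_train:n_train + n_valid],
            items[n_train + n_valid:])


def balanced_split(low_ids, high_ids, seed=0):
    """Split each class separately so every split keeps the low/high ratio."""
    splits = {"train": [], "valid": [], "test": []}
    for ids in (low_ids, high_ids):
        for name, part in zip(splits, split_class(ids, seed)):
            splits[name].extend(part)
    return splits


# Placeholder file names: 269 low-BMD and 270 high-BMD images
low = [f"low_{i:03d}.png" for i in range(269)]
high = [f"high_{i:03d}.png" for i in range(270)]
splits = balanced_split(low, high)
print({name: len(ids) for name, ids in splits.items()})
# {'train': 431, 'valid': 53, 'test': 55}
```

Floor division sends each class's rounding remainder to the test split. For the 269 low-BMD images, 10% is 26.9, which floors to 26, so the remainder goes to test. That is why test ends up with 55 images and valid with 53.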

## 🎯 Models Available
- **AOT-GAN** (3 variants): CelebA-HQ, Places2, OAI-trained
- **ICT** (4 variants): FFHQ, ImageNet, Places2_Nature, OAI-trained
- **RePaint** (3 variants): CelebA-HQ, ImageNet, Places2

**Total: 9 pretrained model variants + ability to train custom models**

---

## 💡 Why This Streamlined Version?
- ⚡ **77% fewer cells** (7 vs 30+)
- 🎯 **One cell = one action** - clear and focused
- 🔄 **Smart auto-skip** - Cell 2 won't regenerate if splits exist
- 🆕 **Integrated classification** - osteoporosis evaluation built-in
- 📊 **Better output** - rich progress indicators and clear results
- 💾 **Results persist** - Both data/ and results/ symlinked to Google Drive
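The auto-skip in Cell 2 is a filesystem check. If `data/oai/test/img/subset_4` already contains a PNG, the split script is not re-run. Below is a minimal sketch of that check, plus the per-split PNG count that Cell 2 prints, run against a temporary directory. The helper names are hypothetical, not functions in the repo.

```python
import tempfile
from pathlib import Path

SPLITS = ("train", "valid", "test")


def splits_exist(oai_root: Path) -> bool:
    """Same test Cell 2 uses: splits count as generated if subset_4 holds any PNG."""
    subset_4 = oai_root / "test" / "img" / "subset_4"
    return subset_4.is_dir() and any(subset_4.glob("*.png"))


def count_pngs(oai_root: Path) -> dict:
    """Count PNGs directly inside each <split>/img folder.

    glob("*.png") is not recursive, so images inside test/img/subset_4
    are not added to the test count.
    """
    counts = {}
    for split in SPLITS:
        img_dir = oai_root / split / "img"
        counts[split] = len(list(img_dir.glob("*.png"))) if img_dir.is_dir() else 0
    return counts


# Demo on an empty temporary tree standing in for data/oai
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    print(splits_exist(root))  # False: Cell 2 would run split.py
    subset_4 = root / "test" / "img" / "subset_4"
    subset_4.mkdir(parents=True)
    (subset_4 / "example.png").touch()
    print(splits_exist(root))  # True: Cell 2 would skip regeneration
    print(count_pngs(root))    # {'train': 0, 'valid': 0, 'test': 0}
```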

## 📁 Directory Structure (Same for Local & Colab!)

Both environments use the same relative symlink structure:
```
Parent Directory/
├── OAI-inpainting/          # Git repository
│   ├── data -> ../OAI_untracked/data/       ✅ Symlink
│   ├── results -> ../OAI_untracked/results/ ✅ Symlink
│   └── [scripts, notebooks, etc.]
└── OAI_untracked/           # Your data (persistent)
    ├── data/                # OAI images & pretrained models
    └── results/             # All outputs (persists across sessions!)
```

**Benefits**:
- ✅ Results persist in Google Drive (survive session restarts)
- ✅ Identical behavior on local machine and Colab
- ✅ Portable code (works anywhere!)
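Cell 1 builds this layout with *relative* symlinks. Each link target is interpreted from the repository folder, not from an absolute Drive path. Below is a minimal sketch of the same idea, run in a temporary directory on a POSIX filesystem. `link_untracked` is a hypothetical helper, not part of the repo.

```python
import os
import tempfile
from pathlib import Path


def link_untracked(parent: Path, repo_name: str = "OAI-inpainting") -> None:
    """Create repo/data and repo/results as relative links into OAI_untracked/."""
    repo = parent / repo_name
    repo.mkdir(parents=True, exist_ok=True)
    for name in ("data", "results"):
        (parent / "OAI_untracked" / name).mkdir(parents=True, exist_ok=True)
        link = repo / name
        if link.is_symlink() and not link.exists():
            link.unlink()  # remove a dangling link left by a moved folder
        if not link.exists():
            # The target is resolved relative to the link's own directory,
            # so the pair keeps working if `parent` is moved or remounted.
            link.symlink_to(Path("..") / "OAI_untracked" / name, target_is_directory=True)


with tempfile.TemporaryDirectory() as tmp:
    parent = Path(tmp)
    link_untracked(parent)
    link_untracked(parent)  # idempotent: a second call changes nothing
    (parent / "OAI-inpainting" / "results" / "run.txt").write_text("ok")
    print((parent / "OAI_untracked" / "results" / "run.txt").read_text())  # ok
    print(os.readlink(parent / "OAI-inpainting" / "results"))  # ../OAI_untracked/results
```

The stored target is `../OAI_untracked/results`, not an absolute path. So the same link works whether the parent folder is `/content/drive/MyDrive/Colab Notebooks` or a local checkout.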


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 1: COMPLETE SETUP (Run Once)
# ═══════════════════════════════════════════════════════════════════
# This cell does everything: GPU check, Drive mount, repo download, dependencies
# ⏱️ Time: ~5-10 minutes

import os
import subprocess
import sys
import urllib.request
import zipfile
from pathlib import Path

print("=" * 70)
print("🚀 COMPLETE ENVIRONMENT SETUP")
print("=" * 70)

# 1. Check GPU
print("\n[1/5] 🖥️  Checking GPU availability...")
try:
    import torch

    print(f"  ✅ PyTorch {torch.__version__}")
    if torch.cuda.is_available():
        print(f"  ✅ GPU: {torch.cuda.get_device_name(0)}")
        print(f"  ✅ CUDA: {torch.version.cuda}")
    else:
        print("  ⚠️  No GPU detected - training will be SLOW!")
        print("  💡 Enable GPU: Runtime → Change runtime type → GPU")
except ImportError:
    print("  ⚠️  PyTorch not found, will install...")

# 2. Mount Google Drive
print("\n[2/5] 📂 Mounting Google Drive...")
try:
    from google.colab import drive

    drive.mount("/content/drive", force_remount=False)

    # Navigate to Colab Notebooks directory
    colab_dir = Path("/content/drive/MyDrive/Colab Notebooks")
    if colab_dir.exists():
        os.chdir(colab_dir)
        print(f"  ✅ Working directory: {Path.cwd()}")
    else:
        print("  ⚠️  Colab Notebooks not found, using default directory")
except ImportError:
    print("  ℹ️  Not in Colab environment, skipping Drive mount")

# 3. Download/Setup Repository
print("\n[3/5] 📥 Setting up repository...")
if not Path("OAI-inpainting").exists():
    print("  📥 Downloading from GitHub...")
    zip_url = "https://github.com/johnreynolds3d/OAI-inpainting/archive/refs/heads/master.zip"
    zip_path = "repo.zip"
    urllib.request.urlretrieve(zip_url, zip_path)

    with zipfile.ZipFile(zip_path, "r") as z:
        z.extractall(".")

    if Path("OAI-inpainting-master").exists():
        Path("OAI-inpainting-master").rename("OAI-inpainting")
    Path(zip_path).unlink()
    print("  ✅ Repository downloaded")
else:
    print("  ✅ Repository already exists")

os.chdir("OAI-inpainting")
sys.path.insert(0, str(Path.cwd()))

# 4. Setup data and results symlinks
print("\n[4/5] 🔗 Setting up data and results links...")
data_dir = Path("data")
results_dir = Path("results")
relative_data = Path("../OAI_untracked/data")
relative_results = Path("../OAI_untracked/results")

# Remove dangling symlinks (is_file() follows the link, so test is_symlink() instead)
for name in ["data", "results"]:
    p = Path(name)
    if p.is_symlink() and not p.exists():
        p.unlink()

# Setup data symlink
if not data_dir.exists():
    if relative_data.exists():
        data_dir.symlink_to(relative_data)
        print(f"  ✅ Data linked: {data_dir.resolve()}")
    else:
        print(f"  ⚠️  OAI_untracked/data not found. Expected at: {relative_data.resolve()}")
else:
    print("  ✅ Data already linked")

# Setup results symlink (same as local - results persist in Google Drive!)
if not results_dir.exists():
    # Ensure parent results directory exists
    if not relative_results.exists():
        try:
            relative_results.mkdir(parents=True, exist_ok=True)
            print("  📁 Created results directory in OAI_untracked")
        except OSError:
            pass  # May lack permissions; fall back to a local directory below

    # Try to create symlink
    if relative_results.exists():
        results_dir.symlink_to(relative_results)
        print(f"  ✅ Results linked: {results_dir.resolve()}")
        print("  💾 Results will persist in Google Drive!")
    else:
        # Fallback: create local directory
        results_dir.mkdir(exist_ok=True)
        print("  ⚠️  Results created as local directory (won't persist)")
elif results_dir.is_symlink():
    print(f"  ✅ Results already linked: {results_dir.resolve()}")
else:
    print("  ⚠️  Results exists as real directory (won't persist in Drive)")
    print("  💡 Delete it to create symlink for persistence")

# 5. Install dependencies
print("\n[5/5] 📦 Installing dependencies...")
deps = [
    "torch",
    "torchvision",
    "numpy",
    "opencv-python",
    "pillow",
    "scikit-image",
    "scipy",
    "pandas",
    "matplotlib",
    "seaborn",
    "pyyaml",
    "tqdm",
    "scikit-learn",
]

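# Install with the running interpreter's pip; record failures instead of hiding them
failed = []
for dep in deps:
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "-q", dep],
            check=True,
            capture_output=True,
            timeout=120,
        )
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
        failed.append(dep)  # Continue with the remaining packages

if failed:
    print(f"  ⚠️  Failed to install: {', '.join(failed)}")
else:
    print("  ✅ Dependencies installed")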

# Verify data structure
print("\n" + "=" * 70)
print("📊 DATA VERIFICATION")
print("=" * 70)

oai_img = Path("data/oai/img")
if oai_img.exists():
    img_count = len(list(oai_img.glob("*.png")))
    print(f"✅ OAI images: {img_count} files")
else:
    print(f"❌ OAI images not found at: {oai_img.resolve()}")

pretrained = Path("data/pretrained")
if pretrained.exists():
    model_files = list(pretrained.rglob("*.pth")) + list(pretrained.rglob("*.pt"))
    print(f"✅ Pretrained models: {len(model_files)} files")
else:
    print("⚠️  Pretrained models not found")

print("\n" + "=" * 70)
print("🎉 SETUP COMPLETE! Ready to proceed to Cell 2")
print("=" * 70)


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 2: GENERATE BALANCED DATASET SPLITS (Run Once)
# ═══════════════════════════════════════════════════════════════════
# Creates an 80/10/10 split with low/high BMD balanced in every split, using ALL 539 images
# ⏱️ Time: ~2-5 minutes

import subprocess
from pathlib import Path

print("=" * 70)
print("🔄 GENERATING PERFECTLY BALANCED DATASET SPLITS")
print("=" * 70)

subset_4 = Path("data/oai/test/img/subset_4")

# Check if already done
if subset_4.exists() and len(list(subset_4.glob("*.png"))) > 0:
    print("\n✅ Splits already exist!")

    # Show existing split info
    train_count = len(list(Path("data/oai/train/img").glob("*.png")))
    val_count = len(list(Path("data/oai/valid/img").glob("*.png")))
    test_count = len(list(Path("data/oai/test/img").glob("*.png")))
    total = train_count + val_count + test_count

    print(f"\n📊 Current Split:")
    print(f"  Train: {train_count} images ({train_count / total * 100:.1f}%)")
    print(f"  Valid: {val_count} images ({val_count / total * 100:.1f}%)")
    print(f"  Test:  {test_count} images ({test_count / total * 100:.1f}%)")
    print(f"  Total: {total} images")
    print(f"\n💡 To regenerate, delete: {subset_4}")

else:
    print("\n🔄 Generating new splits...")
    print("This will:")
    print("  • Split 539 images into 80/10/10 (train/val/test)")
    print("  • Balance low/high BMD in each split")
    print("  • Generate masks, edge maps, inverted masks")
    print("  • Create subset_4 (4 balanced test images)")
    print("")

    result = subprocess.run(
        ["python", "split.py"],
        cwd="data/oai",
        capture_output=True,
        text=True,
        timeout=300,
    )

    if result.returncode == 0:
        print("\n" + "=" * 70)
        print("✅ SPLITS GENERATED SUCCESSFULLY!")
        print("=" * 70)

        # Show split summary
        train_count = len(list(Path("data/oai/train/img").glob("*.png")))
        val_count = len(list(Path("data/oai/valid/img").glob("*.png")))
        test_count = len(list(Path("data/oai/test/img").glob("*.png")))
        total = train_count + val_count + test_count

        print(f"\n📊 Split Summary:")
        print(f"  Train: {train_count} images ({train_count / total * 100:.1f}%)")
        print(f"  Valid: {val_count} images ({val_count / total * 100:.1f}%)")
        print(f"  Test:  {test_count} images ({test_count / total * 100:.1f}%)")
        print(f"  Total: {total} images")
        print(
            f"\n✅ subset_4: {len(list(subset_4.glob('*.png')))} images (2 low BMD + 2 high BMD)"
        )
        print(f"\n🎯 Each split maintains equal low/high BMD balance")
        print(f"🎯 All splits are mutually exclusive (no overlap)")

    else:
        print(f"\n❌ ERROR: Split generation failed")
        print(f"\nError output:\n{result.stderr}")
        print(f"\n💡 Try running manually: !cd data/oai && python split.py")

print("\n" + "=" * 70)
print("🎉 READY FOR TESTING/TRAINING! Proceed to Cell 3 (quick test) or Cell 4 (training)")
print("=" * 70)


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 3: QUICK TEST - All 9 Models on Balanced Subset
# ═══════════════════════════════════════════════════════════════════
# Tests all pretrained models on 4 balanced test images
# ⏱️ Time: ~30-60 minutes
# 📊 Output: Results in results/ directory + JSON summary

import json
import sys
import time
from datetime import datetime
from pathlib import Path

sys.path.append("scripts")

print("=" * 70)
print("🧪 QUICK TEST: Testing All 9 Model Variants")
print("=" * 70)
print(f"\n📊 Test Configuration:")
print(f"  • Dataset: subset_4 (4 balanced images)")
print(f"  • Models: 9 variants (AOT-GAN, ICT, RePaint)")
print(f"  • Estimated time: 30-60 minutes")
print(f"  • Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print("")

try:
    from colab_comprehensive_test import ModelTester

    # Run comprehensive test
    start_time = time.time()
    tester = ModelTester(timeout_per_model=600, verbose=True)
    results = tester.run_comprehensive_test(models=["all"])
    elapsed = time.time() - start_time

    # Display summary
    print("\n" + "=" * 70)
    print("🎉 QUICK TEST COMPLETE!")
    print("=" * 70)
    print(f"\n⏱️  Total time: {elapsed / 60:.1f} minutes")
    print(f"\n📊 Results Summary:")
    print(f"  ✅ Successful: {results['summary']['successful']}")
    print(f"  ❌ Failed: {results['summary']['failed']}")
    print(f"  ⏭️  Skipped: {results['summary']['skipped']}")

    # List successful models
    if results["summary"]["successful"] > 0:
        print(f"\n✅ Successful Models:")
        for r in results["results"]:
            if r["success"]:
                print(f"  • {r['model']} ({r['elapsed']:.1f}s)")

    print(f"\n📁 Results saved to: results/")
    print(f"📄 JSON summary: results/comprehensive_test_results.json")
    print(f"\n💡 Next steps:")
    print(f"  • Run Cell 5 for classification evaluation")
    print(f"  • Run Cell 6 for visual comparison strips")
    print(f"  • Run Cell 7 to download all results")

except ImportError as e:
    print(f"\n❌ ERROR: Could not import ModelTester")
    print(f"Error: {e}")
    print(f"\n💡 Make sure Cell 1 (setup) completed successfully")
except Exception as e:
    print(f"\n❌ ERROR during testing: {e}")
    import traceback

    traceback.print_exc()


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 4: FULL TRAINING - Train Models on Balanced Dataset
# ═══════════════════════════════════════════════════════════════════
# Trains models from scratch on 431 balanced training images
# ⏱️ Time: ~6-8 hours total
# 📊 Output: Trained models + evaluation on 55-image test set

import sys
import time
from datetime import datetime

sys.path.append("scripts")

# ─────────────────────────────────────────────────────────────────
# CONFIGURATION: Set to True to enable each training phase
# ─────────────────────────────────────────────────────────────────

TRAIN_AOT_GAN = False  # Set to True to train AOT-GAN (~2-4 hours)
TRAIN_ICT = False  # Set to True to train ICT (~1-3 hours)
RUN_REPAINT = False  # Set to True to run RePaint inference (~30 min)
RUN_EVALUATION = False  # Set to True to evaluate all models (~15 min)

print("=" * 70)
print("🎓 FULL TRAINING PIPELINE ON BALANCED DATASET")
print("=" * 70)
print(f"\n📊 Training Configuration:")
print(f"  • Training set: 431 images (215 low BMD + 216 high BMD)")
print(f"  • Validation set: 53 images (26 low BMD + 27 high BMD)")
print(f"  • Test set: 55 images (28 low BMD + 27 high BMD)")
print(f"")
print(f"Enabled phases:")
print(f"  {'✅' if TRAIN_AOT_GAN else '⏭️ '} AOT-GAN training (~2-4 hours)")
print(f"  {'✅' if TRAIN_ICT else '⏭️ '} ICT training (~1-3 hours)")
print(f"  {'✅' if RUN_REPAINT else '⏭️ '} RePaint inference (~30 min)")
print(f"  {'✅' if RUN_EVALUATION else '⏭️ '} Evaluation (~15 min)")
print(f"")

if not any([TRAIN_AOT_GAN, TRAIN_ICT, RUN_REPAINT, RUN_EVALUATION]):
    print("⚠️  No training phases enabled!")
    print("\n💡 To enable training:")
    print("  1. Set TRAIN_AOT_GAN = True (or other flags)")
    print("  2. Re-run this cell")
    print("\n⏱️  Estimated times:")
    print("  • AOT-GAN only: ~2-4 hours")
    print("  • ICT only: ~1-3 hours")
    print("  • Full pipeline: ~6-8 hours")
    print("\n📊 Benefits of training on balanced split:")
    print("  • 2x more training data (431 vs ~216)")
    print("  • Equal representation of low/high BMD cases")
    print("  • Better model generalization")
    print("  • More reliable evaluation metrics")
else:
    try:
        from colab_pipeline import run_phase

        start_time = time.time()
        results = []

        # Phase 2: AOT-GAN Training
        if TRAIN_AOT_GAN:
            print("\n" + "─" * 70)
            print("[1/4] 🔧 Training AOT-GAN...")
            print("─" * 70)
            phase_start = time.time()
            success = run_phase(2)
            phase_time = time.time() - phase_start
            results.append(("AOT-GAN", success, phase_time))
            print(
                f"{'✅' if success else '❌'} AOT-GAN completed in {phase_time / 60:.1f} min"
            )

        # Phase 3: ICT Training
        if TRAIN_ICT:
            print("\n" + "─" * 70)
            print("[2/4] 🔧 Training ICT...")
            print("─" * 70)
            phase_start = time.time()
            success = run_phase(3)
            phase_time = time.time() - phase_start
            results.append(("ICT", success, phase_time))
            print(
                f"{'✅' if success else '❌'} ICT completed in {phase_time / 60:.1f} min"
            )

        # Phase 4: RePaint Inference
        if RUN_REPAINT:
            print("\n" + "─" * 70)
            print("[3/4] 🎨 Running RePaint inference...")
            print("─" * 70)
            phase_start = time.time()
            success = run_phase(4)
            phase_time = time.time() - phase_start
            results.append(("RePaint", success, phase_time))
            print(
                f"{'✅' if success else '❌'} RePaint completed in {phase_time / 60:.1f} min"
            )

        # Phase 5: Evaluation
        if RUN_EVALUATION:
            print("\n" + "─" * 70)
            print("[4/4] 📊 Running evaluation...")
            print("─" * 70)
            phase_start = time.time()
            success = run_phase(5)
            phase_time = time.time() - phase_start
            results.append(("Evaluation", success, phase_time))
            print(
                f"{'✅' if success else '❌'} Evaluation completed in {phase_time / 60:.1f} min"
            )

        # Final summary
        total_time = time.time() - start_time
        print("\n" + "=" * 70)
        print("🎉 TRAINING PIPELINE COMPLETE!")
        print("=" * 70)
        print(f"\n⏱️  Total time: {total_time / 3600:.2f} hours")
        print(f"\n📊 Phase Results:")
        for name, success, phase_time in results:
            status = "✅" if success else "❌"
            print(f"  {status} {name}: {phase_time / 60:.1f} minutes")

        print(f"\n💡 Next steps:")
        print(f"  • Run Cell 5 for classification evaluation")
        print(f"  • Run Cell 6 for visual comparison strips")
        print(f"  • Run Cell 7 to download all results")

    except ImportError as e:
        print(f"\n❌ ERROR: Could not import training functions")
        print(f"Error: {e}")
        print(f"\n💡 Make sure Cell 1 (setup) completed successfully")
    except Exception as e:
        print(f"\n❌ ERROR during training: {e}")
        import traceback

        traceback.print_exc()


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 5: CLASSIFICATION EVALUATION
# ═══════════════════════════════════════════════════════════════════
# Evaluates osteoporosis classification performance on inpainted images
# ⏱️ Time: ~5-10 minutes
# 📊 Output: Classification accuracy, confusion matrices, detailed metrics

import sys
from pathlib import Path

sys.path.append("scripts")

print("=" * 70)
print("📊 CLASSIFICATION EVALUATION")
print("=" * 70)
print(f"\nEvaluating osteoporosis classification on inpainted images...")
print(f"This will:")
print(f"  • Load pretrained ResNet50 osteoporosis classifier")
print(f"  • Test classification on ground truth images")
print(f"  • Test classification on all model outputs")
print(f"  • Compare accuracy: GT vs. inpainted images")
print(f"  • Generate confusion matrices")
print(f"")

try:
    # Note: colab_classification_evaluation may not exist yet; if the import
    # fails, the fallback below prints manual instructions instead of crashing.
    try:
        from colab_classification_evaluation import run_classification_evaluation

        # Run evaluation
        results = run_classification_evaluation()

        print("\n" + "=" * 70)
        print("🎉 CLASSIFICATION EVALUATION COMPLETE!")
        print("=" * 70)

        # Display results
        if results:
            print(f"\n📊 Classification Performance:")
            print(f"\nGround Truth:")
            print(f"  Accuracy: {results.get('gt_accuracy', 0):.2%}")

            print(f"\nModel Comparisons:")
            for model_name, metrics in results.get("models", {}).items():
                accuracy = metrics.get("accuracy", 0)
                change = metrics.get("accuracy_change", 0)
                print(f"  • {model_name}: {accuracy:.2%} ({change:+.1%} vs GT)")

            print(f"\n📁 Results saved to: results/classification/")
            print(
                f"📄 CSV summary: results/classification/classification_results.csv"
            )
            print(f"📊 Confusion matrices: results/classification/confusion_matrices/")

        print(f"\n💡 Next steps:")
        print(f"  • Run Cell 6 for visual comparison strips")
        print(f"  • Run Cell 7 to download all results")

    except ImportError:
        # Fallback: Manual classification using scripts
        print("\n⚠️  Integrated classification not available")
        print("\n💡 To run classification evaluation manually:")
        print("  !python scripts/colab_classification_evaluation.py")
        print("\nOr implement the colab_classification_evaluation module")

except Exception as e:
    print(f"\n❌ ERROR during classification evaluation: {e}")
    import traceback

    traceback.print_exc()


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 6: GENERATE VISUAL COMPARISON STRIPS
# ═══════════════════════════════════════════════════════════════════
# Creates horizontal comparison images for visual assessment
# ⏱️ Time: ~1-2 minutes
# 📊 Output: Comparison strips in results/comparison_strips/

import sys
from pathlib import Path

sys.path.append("scripts")

print("=" * 70)
print("🎨 GENERATING VISUAL COMPARISON STRIPS")
print("=" * 70)
print(f"\nCreating horizontal strips showing:")
print(f"  GT → GT+Mask → AOT-GAN variants → ICT variants → RePaint variants")
print(f"")

try:
    from generate_comparison_strips import main as generate_strips

    # Generate strips
    strip_paths = generate_strips()

    print("\n" + "=" * 70)
    print("🎉 COMPARISON STRIPS GENERATED!")
    print("=" * 70)

    if strip_paths:
        print(f"\n✅ Generated {len(strip_paths)} comparison strips")
        print(f"\n📁 Location: results/comparison_strips/")
        print(f"\n📸 Files created:")
        for path in strip_paths:
            print(f"  • {path.name}")

        print(f"\n📊 Also created:")
        print(f"  • all_comparisons_summary.png (all strips stacked)")
        print(f"\n💡 Perfect for visual assessment and publication!")
    else:
        print(f"\n⚠️  No strips generated")
        print(f"\n💡 Make sure you've run:")
        print(f"  • Cell 3 (quick test) or Cell 4 (training)")
        print(f"  • Results should be in results/ directory")

    print(f"\n💡 Next step:")
    print(f"  • Run Cell 7 to download all results as ZIP")

except ImportError as e:
    print(f"\n❌ ERROR: Could not import comparison strip generator")
    print(f"Error: {e}")
except Exception as e:
    print(f"\n❌ ERROR generating strips: {e}")
    import traceback

    traceback.print_exc()


In [None]:
# ═══════════════════════════════════════════════════════════════════
# CELL 7: DOWNLOAD ALL RESULTS
# ═══════════════════════════════════════════════════════════════════
# Packages and downloads all results as a ZIP file
# ⏱️ Time: ~2-5 minutes (depends on result size)
# 📊 Output: ZIP file downloaded to your local machine

import zipfile
from datetime import datetime
from pathlib import Path

print("=" * 70)
print("💾 PACKAGING RESULTS FOR DOWNLOAD")
print("=" * 70)

# Check if we're in Colab
try:
    from google.colab import files

    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print("\n⚠️  Not running in Google Colab")
    print("Results are available locally in: results/")

if IN_COLAB:
    results_dir = Path("results")

    if not results_dir.exists() or not any(results_dir.iterdir()):
        print("\n⚠️  No results found to download")
        print("\n💡 Make sure you've run:")
        print("  • Cell 3 (quick test) or Cell 4 (training)")
        print("  • Results should be in results/ directory")
    else:
        # Create ZIP filename with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        zip_filename = f"oai_inpainting_results_{timestamp}.zip"

        print(f"\n📦 Creating ZIP archive: {zip_filename}")
        print("This may take a few minutes...")
        print("")

        # Create ZIP file
        file_count = 0
        with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
            for file_path in results_dir.rglob("*"):
                if file_path.is_file():
                    arcname = file_path.relative_to(results_dir.parent)
                    zipf.write(file_path, arcname)
                    file_count += 1

                    # Report progress every 50 files
                    if file_count % 50 == 0:
                        print(f"  Packed {file_count} files...")

        zip_size = Path(zip_filename).stat().st_size / 1024 / 1024

        print("\n" + "=" * 70)
        print("✅ ARCHIVE CREATED!")
        print("=" * 70)
        print(f"\n📦 File: {zip_filename}")
        print(f"📊 Size: {zip_size:.1f} MB")
        print(f"📁 Files: {file_count} total")
        print(f"")
        print(f"⬇️  Initiating download...")

        # Download the file
        try:
            files.download(zip_filename)
            print(f"\n✅ Download initiated!")
            print(f"💡 Check your browser's download folder")
            print(f"")
            print(f"📊 Your ZIP contains:")
            print(f"  • Inpainted images from all models")
            print(f"  • Comprehensive test results (JSON)")
            print(f"  • Classification evaluation (if run)")
            print(f"  • Visual comparison strips (if run)")
            print(f"  • All metrics and evaluations")
        except Exception as e:
            print(f"\n❌ Download failed: {e}")
            print(f"💡 Try downloading manually: files.download('{zip_filename}')")

print("\n" + "=" * 70)
print("🎉 WORKFLOW COMPLETE!")
print("=" * 70)
print(f"\n📊 Summary of what you accomplished:")
print("  ✅ Set up the complete environment")
print(f"  ✅ Generated balanced dataset splits (539 images)")
print(f"  ✅ Tested/trained models on balanced data")
print(f"  ✅ Evaluated performance and classification")
print(f"  ✅ Created visualizations for analysis")
print(f"  ✅ Downloaded all results")
print(f"\n💡 Your results are ready for analysis!")
print(f"\n📚 Next steps:")
print(f"  • Analyze results in your downloaded ZIP")
print(f"  • Use comparison strips in your publication")
print(f"  • Review classification metrics")
print(f"  • Consider training with different hyperparameters (Cell 4)")
print(f"\n🔄 To run again:")
print(f"  • Cell 3: Test with different models")
print(f"  • Cell 4: Train with different configurations")
print("  • Cells 5-7: Re-evaluate, visualize, and download")
