In [None]:
# Cell: Fixing Gradio Installation for Colab Compatibility

print("⚠️ Fixing Gradio installation...")

# Uninstall problematic package without touching numpy
!pip uninstall -y gradio

# Install specific Gradio version without dependencies
!pip install gradio==3.41.2 --no-deps

# Install some common dependencies that Gradio needs
!pip install pydub markdown-it-py mdit-py-plugins

# Force reload importlib to ensure we're using the new version
import importlib
import sys
if 'gradio' in sys.modules:
    del sys.modules['gradio']  # Force complete reload

# Try importing and check version
try:
    import gradio
    print(f"✅ Successfully installed Gradio version: {gradio.__version__}")
except Exception as e:
    print(f"❌ Error importing Gradio: {e}")
    print("   You may need to restart the runtime (Runtime > Restart runtime)")

# Verify other key dependencies are still working
print("\n--- Checking core dependencies ---")
libs_to_check = ["torch", "diffusers", "transformers", "numpy"]
for lib_name in libs_to_check:
    try:
        lib = importlib.import_module(lib_name)
        version = getattr(lib, '__version__', "unknown")
        print(f"✅ {lib_name}: {version}")
    except Exception as e:
        print(f"❌ {lib_name} error: {e}")

print("\nIf all dependencies show ✅, you're ready to proceed with the notebook.")
print("If there are still errors, try running: Runtime > Restart runtime, then run all cells from the beginning.")

In [None]:
# Cell 2: Environment Setup - Dependencies (Attempt 10 - Force Reinstall NumPy/Diffusers)

# --- Upgrade pip ---
print("Upgrading pip...")
!pip install --upgrade pip
print("Pip upgrade complete.")

# --- Install PyTorch ---
print("Installing PyTorch 2.1.0...")
# Sticking with 2.1.0 as it seemed okay before the numpy issue
!pip install -q torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0+cu121 --index-url https://download.pytorch.org/whl/cu121

# --- Install Core Libs ---
print("Installing core libraries (forcing reinstall for numpy and diffusers)...")

# Force reinstall numpy and diffusers, keep others as specified
# Using -q for quiet, but errors should still show. Remove -q if detailed logs needed.
!pip install -q \
    --force-reinstall numpy==1.24.3 \
    Pillow==10.4.0 \
    gradio==3.41.2 \
    huggingface_hub==0.23.3 \
    transformers==4.42.4 \
    accelerate==0.32.1 \
    --force-reinstall diffusers==0.25.1 \
    torchsde==0.2.6 \
    einops==0.8.0 \
    safetensors==0.4.3 \
    pyyaml==6.0.1 \
    scipy==1.14.0 \
    tqdm==4.66.4 \
    psutil==6.0.0 \
    pytorch_lightning==2.3.3 \
    omegaconf==2.3.0 \
    pygit2==1.15.1 \
    opencv-contrib-python-headless==4.10.0.84 \
    httpx==0.27.0 \
    onnxruntime==1.18.1 \
    timm==1.0.7 \
    tokenizers==0.19.1 \
    packaging==24.1 \
    piexif \
    sentencepiece \
    requests

# --- Excluded Dependencies ---
# insightface, onnxruntime-gpu (for Woop - deferred)
# segment-anything-hq, supervision (for SAM - removed)
# invisible_watermark, etc. (Keeping it minimal)

# --- Clear Output & Verify ---
import IPython
import importlib
import torch
import os
import sys
import numpy # Import numpy again after reinstall for verification

# Comment out clear_output to see full install log if needed
# IPython.display.clear_output()

print("\n--- Dependency Installation Summary (Attempt 10 - Forced Reinstall) ---")
# Check libraries needed for this simplified plan
libs_to_check = {
    "numpy": "numpy",
    "torch": "torch",
    "PIL": "PIL", # Pillow
    "huggingface_hub": "huggingface_hub",
    "diffusers": "diffusers",
    "transformers": "transformers",
    "accelerate": "accelerate",
    "gradio": "gradio",
    "safetensors": "safetensors",
    "cv2": "cv2",
    "onnxruntime": "onnxruntime",
    "requests": "requests",
}
error_found = False
successful_imports = []
failed_imports = []

# Check imports carefully
print("Python version:", sys.version)
print("Running import checks...")
for display_name, lib_name in libs_to_check.items():
    try:
        # Use reload if module already imported (like numpy above)
        if lib_name in sys.modules:
             mod = importlib.reload(sys.modules[lib_name])
        else:
             mod = importlib.import_module(lib_name)

        version = "N/A"
        if hasattr(mod, '__version__'):
            version = mod.__version__
        elif lib_name == "PIL":
             try:
                 import Pillow
                 version = Pillow.__version__
             except ImportError: pass
             except AttributeError: pass

        # --- Specific NumPy dtypes check ---
        if lib_name == "numpy":
             has_dtypes_check = hasattr(mod, 'dtypes')
             print(f"✅ {display_name}: {version} (dtypes exists: {has_dtypes_check})")
             if not has_dtypes_check:
                  print(f"   🚨 CRITICAL WARNING: NumPy {version} still missing 'dtypes' after reinstall!")
                  error_found = True # Treat this as an error
        else:
             print(f"✅ {display_name}: {version}")

        successful_imports.append(display_name)
    except ImportError as e:
        print(f"❌ {display_name} not found. Error: {e}")
        failed_imports.append(display_name)
        error_found = True
    except Exception as e:
        print(f"❌ Error importing/checking {display_name}. Error: {e}")
        failed_imports.append(display_name)
        error_found = True


print("\n--- GPU Check ---")
# Check torch import status from the loop above
torch_failed = any(item == "torch" for item in failed_imports)

if torch_failed:
     print("❌ Torch import failed, cannot check GPU.")
     error_found = True # Ensure error is flagged
elif torch.cuda.is_available():
    try:
        print(f"✅ GPU Found: {torch.cuda.get_device_name(0)}")
        print(f"   CUDA Version Used by PyTorch: {torch.version.cuda}")
        print(f"   PyTorch Version: {torch.__version__}")
    except Exception as e:
        print(f"❌ Error during GPU check: {e}")
        error_found = True
else:
    print("❌ No GPU detected. This notebook requires a GPU runtime.")
    error_found = True

print("\n--- Summary ---")
if error_found:
     print(f"⚠️ Errors encountered during setup or import checks. Failed/Problematic imports: {failed_imports}")
     print(f"   If NumPy still shows 'dtypes exists: False', the environment conflict persists.")
     print(f"   You might need to try different versions of libraries (diffusers, torch, numpy) or factory reset the runtime.")
else:
     print("✅ Core environment setup cell completed successfully (Attempt 10 - Forced Reinstall).")
     print(f"   Successfully imported: {successful_imports}")
     print(f"   NumPy dtypes check passed.")



In [None]:
# Cell 3: Google Drive Integration

from google.colab import drive
import os
import json
from datetime import datetime

# --- Mount Drive (fatal on failure: every later cell reads/writes there) ---
print("🔄 Mounting Google Drive...")
try:
    # force_remount=True re-attaches the mount even if /content/drive is already mounted.
    drive.mount('/content/drive', force_remount=True)
except Exception as e:
    print(f"❌ Error mounting Google Drive: {e}")
    raise Exception("Drive mounting failed. Please check permissions and try again.")
else:
    print("✅ Google Drive mounted successfully.")

# --- Project layout: one root folder in MyDrive, everything else beneath it ---
BASE_DRIVE_PATH = '/content/drive/MyDrive/AI_Studio_Toolkit_v3'
MODELS_PATH = os.path.join(BASE_DRIVE_PATH, 'models')
SAM_MODELS_PATH = os.path.join(MODELS_PATH, 'sam_models')     # SAM checkpoints
WOOP_MODELS_PATH = os.path.join(MODELS_PATH, 'woop_models')   # Woop models
OUTPUT_PATH = os.path.join(BASE_DRIVE_PATH, 'outputs')
CONFIG_PATH = os.path.join(BASE_DRIVE_PATH, 'config')         # tokens and other config files

# Parents come before children so each makedirs call is cheap.
paths_to_create = [BASE_DRIVE_PATH, MODELS_PATH, SAM_MODELS_PATH, WOOP_MODELS_PATH, OUTPUT_PATH, CONFIG_PATH]

print("\n--- Creating Project Directories ---")
for path in paths_to_create:
    try:
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        # Non-fatal by design: report and continue; the write probe below
        # surfaces a genuinely unusable Drive.
        print(f"❌ Error creating directory {path}: {e}")
    else:
        print(f"✓ Directory ensured: {path}")

# --- Write probe: create and delete a small timestamp file in the project root ---
print("\n--- Verifying Drive Write Access ---")
test_file_path = os.path.join(BASE_DRIVE_PATH, '.write_test')
try:
    with open(test_file_path, 'w') as f:
        f.write(datetime.now().isoformat())
    os.remove(test_file_path)
except Exception as e:
    # Reported only; uncomment a raise here if write access must be guaranteed.
    print(f"❌ Write access test failed in {BASE_DRIVE_PATH}: {e}")
    print("   Please ensure Colab has write permissions for your Google Drive.")
else:
    print("✅ Google Drive is writable.")

# Token store location, consumed by the token-configuration cell.
TOKENS_FILE = os.path.join(CONFIG_PATH, 'tokens.json')

print(f"\n✅ Google Drive setup complete. Project base path: {BASE_DRIVE_PATH}")

In [None]:
# Cell 4: Configuration & Token Management (with Colab Forms)

import os
import json
from google.colab import output # Used for managing Colab Forms output

# --- Configuration ---
# CONFIG_PATH normally comes from Cell 3. When this cell runs on its own (e.g.
# right after a runtime restart), fall back to the same default layout instead
# of failing with a NameError.
if 'CONFIG_PATH' not in globals():
    BASE_DRIVE_PATH = globals().get('BASE_DRIVE_PATH', '/content/drive/MyDrive/AI_Studio_Toolkit_v3')
    CONFIG_PATH = os.path.join(BASE_DRIVE_PATH, 'config')
TOKENS_FILE = os.path.join(CONFIG_PATH, 'tokens.json')

# --- TokenManager Class ---
class TokenManager:
    """Loads and saves API tokens as a JSON object in a single (plaintext) file."""

    def __init__(self, filepath):
        """
        Initializes the TokenManager.

        Args:
            filepath (str): Path of the JSON file where tokens are stored. Its
                parent directory is created if it does not exist yet.
        """
        self.filepath = filepath
        # Ensure the config directory exists (it should from Cell 3). A bare
        # filename has no directory part, and os.makedirs('') would raise.
        directory = os.path.dirname(filepath)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def load_tokens(self):
        """
        Loads tokens from the JSON file.

        Returns:
            dict: The stored tokens, or {} when the file is missing, unreadable,
            not valid JSON, or holds something other than a JSON object.
        """
        if not os.path.exists(self.filepath):
            print(f"ℹ️ Token file not found at: {self.filepath}. Will create if tokens are saved.")
            return {}
        try:
            with open(self.filepath, 'r') as f:
                tokens = json.load(f)
        except json.JSONDecodeError:
            print(f"⚠️ Warning: Token file '{self.filepath}' is corrupted. Starting fresh.")
            return {}
        except Exception as e:
            print(f"⚠️ Warning: Could not read token file '{self.filepath}'. Error: {e}")
            return {}
        # Callers use dict methods (.get/.copy) on the result, so reject lists etc.
        if not isinstance(tokens, dict):
            print(f"⚠️ Warning: Token file '{self.filepath}' does not contain a JSON object. Starting fresh.")
            return {}
        print(f"🔑 Tokens loaded from: {self.filepath}")
        return tokens

    def save_tokens(self, tokens):
        """
        Saves the provided tokens to the JSON file, replacing its contents.

        The JSON is written to a sibling ``.tmp`` file first and then moved over
        the target with os.replace(), so a failed write (e.g. a non-serializable
        value) cannot leave a truncated token file behind. Errors are printed,
        not raised.

        Args:
            tokens (dict): A dictionary containing the tokens to save.
        """
        tmp_path = self.filepath + ".tmp"
        try:
            # 0o600 = owner-only, on filesystems that honour POSIX modes.
            fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as f:
                json.dump(tokens, f, indent=4) # Use indent for readability
            os.replace(tmp_path, self.filepath)
            print(f"💾 Tokens successfully saved to: {self.filepath}")
        except Exception as e:
            print(f"❌ Error saving tokens to {self.filepath}: {e}")
            # Best effort: do not leave a partial temp file around.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

# --- Initialize Token Manager & Load Existing Tokens ---
# Clear previous form output BEFORE loading so the status line printed by
# load_tokens() stays visible.
output.clear()
token_manager = TokenManager(TOKENS_FILE)
existing_tokens = token_manager.load_tokens()
if not isinstance(existing_tokens, dict):
    # A token file holding valid JSON that is not an object is treated as empty.
    existing_tokens = {}


def _mask_token(token):
    """Returns a display-safe tail of ``token``: at most 4 chars and never more than half of it."""
    keep = min(4, len(token) // 2)
    # token[-0:] would be the WHOLE string, so very short tokens show nothing.
    return token[-keep:] if keep else ""


def _apply_token_input(tokens, key, raw_value, save, label):
    """
    Merges one form field into ``tokens`` in place.

    Args:
        tokens (dict): Token mapping to update (persisted later if changed).
        key (str): Key in ``tokens``, e.g. 'hf_token'.
        raw_value (str): Text typed into the form; surrounding whitespace is ignored.
        save (bool): Whether the matching "save to Drive" box is checked.
        label (str): Human-readable name used in status messages.

    Returns:
        bool: True if ``tokens[key]`` was set from the form.
    """
    # Strip BEFORE testing, so whitespace-only input can never overwrite a
    # previously saved token with an empty string.
    value = (raw_value or "").strip()
    if value and save:
        tokens[key] = value
        print(f"ℹ️ {label} provided and marked for saving.")
        return True
    if value:
        print(f"ℹ️ {label} provided but *not* marked for saving to Drive.")
    elif save and tokens.get(key):
        print(f"ℹ️ No new {label} entered; keeping the saved one.")
    elif save:
        print(f"⚠️ Warning: 'Save {label}' checked, but nothing was entered.")
    return False


print("--- API Token Configuration ---")
print("Enter your API tokens below. They are required for downloading models.")
print("Tokens are saved to your Google Drive and are *not* printed in the output.")
print("❗ Important: Do not share your notebook with tokens saved if you are concerned about security.")

# Describe each field, noting whether a token is already saved.
hf_token_status = "(Optional, Recommended) - Found in saved file." if existing_tokens.get('hf_token') else "(Optional, Recommended)"
civitai_token_status = "(Optional) - Found in saved file." if existing_tokens.get('civitai_token') else "(Optional)"
print(f"Hugging Face Token {hf_token_status}")
print(f"Civitai API Key {civitai_token_status}")

#@markdown ---
#@markdown ### Hugging Face Token
#@markdown Get yours from: https://huggingface.co/settings/tokens
#@markdown Leave blank to keep using the token already saved on Drive.
# NOTE(review): never hard-code a real token as the default below -- the value
# is stored inside the notebook file. Any token ever committed here should be
# treated as leaked and revoked.
hf_token_input = "" #@param {type:"string"}
#@markdown **Save Hugging Face Token to Google Drive?** (Overwrites existing if checked)
save_hf_token = True #@param {type:"boolean"}

#@markdown ---
#@markdown ### Civitai API Key
#@markdown Get yours from: https://civitai.com/user/account (Create API Key section)
#@markdown Leave blank to keep using the key already saved on Drive.
civitai_token_input = "" #@param {type:"string"}
#@markdown **Save Civitai API Key to Google Drive?** (Overwrites existing if checked)
save_civitai_token = True #@param {type:"boolean"}
#@markdown ---

# --- Process and Save Tokens ---
tokens_to_save = existing_tokens.copy() # Start with existing tokens
hf_changed = _apply_token_input(tokens_to_save, 'hf_token', hf_token_input, save_hf_token, "Hugging Face token")
civitai_changed = _apply_token_input(tokens_to_save, 'civitai_token', civitai_token_input, save_civitai_token, "Civitai API key")

# Persist only when a checked box actually supplied a new value.
if hf_changed or civitai_changed:
    token_manager.save_tokens(tokens_to_save)
elif save_hf_token or save_civitai_token:
    print("ℹ️ No new tokens were provided to save, even though a save box was checked.")
else:
    print("ℹ️ No tokens were marked for saving to Google Drive in this run.")

# Display final status (only a short masked tail of each token).
print("\n--- Current Token Status ---")
loaded_hf = tokens_to_save.get('hf_token')
loaded_civitai = tokens_to_save.get('civitai_token')
for _label, _token in (("Hugging Face Token", loaded_hf), ("Civitai API Key", loaded_civitai)):
    if _token:
        print(f"🔑 {_label}: Loaded (...{_mask_token(_token)})")
    else:
        print(f"❓ {_label}: Not configured.")

print("\n✅ Token configuration cell complete.")


In [None]:
# Cell 5: Model Management (with Pre-Download Form)

import os
import requests
import shutil
import json
from tqdm.notebook import tqdm # For progress bars
from huggingface_hub import snapshot_download, HfApi, hf_hub_url, hf_hub_download
from urllib.parse import urlparse, parse_qs
import time
from google.colab import output # To potentially clear form output if needed

# --- Configuration Flags ---
# Set these flags to True to enable the corresponding functionality
# NOTE: Enabling these also requires installing the necessary dependencies (e.g., in Cell 2)
ENABLE_SAM = False
ENABLE_WOOP = False

# --- Configuration & Paths ---
# Normally provided by Cell 3 (paths) and Cell 4 (tokens_to_save). Each name
# falls back to the same default layout only when it is missing, so this cell
# also runs on its own (e.g. after a runtime restart), and re-running it leaves
# existing values untouched.
BASE_DRIVE_PATH = globals().get('BASE_DRIVE_PATH', '/content/drive/MyDrive/AI_Studio_Toolkit_v3')
MODELS_PATH = globals().get('MODELS_PATH', os.path.join(BASE_DRIVE_PATH, 'models'))
SAM_MODELS_PATH = globals().get('SAM_MODELS_PATH', os.path.join(MODELS_PATH, 'sam_models'))
WOOP_MODELS_PATH = globals().get('WOOP_MODELS_PATH', os.path.join(MODELS_PATH, 'woop_models'))
CONFIG_PATH = globals().get('CONFIG_PATH', os.path.join(BASE_DRIVE_PATH, 'config'))
TOKENS_FILE = globals().get('TOKENS_FILE', os.path.join(CONFIG_PATH, 'tokens.json'))

if 'tokens_to_save' not in globals():
    # Standalone run: read the saved token file directly (no form, no saving).
    tokens_to_save = {}
    if os.path.exists(TOKENS_FILE):
        try:
            with open(TOKENS_FILE, 'r') as f:
                _loaded_tokens = json.load(f)
            if isinstance(_loaded_tokens, dict):
                tokens_to_save = _loaded_tokens
        except (OSError, ValueError) as e:
            print(f"⚠️ Could not read token file '{TOKENS_FILE}': {e}")
    # Report which tokens exist -- never their values.
    print(f"[Standalone] Loaded token keys: {sorted(tokens_to_save)}")


print("--- Initializing Model Manager ---")
print(f"SAM Functionality Enabled: {ENABLE_SAM}")
print(f"Woop Functionality Enabled: {ENABLE_WOOP}")


# --- Predefined Model List ---
# Format: 'Display Name': ('type', 'identifier', 'optional_subfolder', 'optional_filename_filter')
PREDEFINED_MODELS = {
    # --- Hugging Face Models ---
    "Deliberate V3 (HF)": ('hf', 'stablediffusionapi/deliberate-v3', 'deliberate-v3', None),
    "Anything V5 (HF)": ('hf', 'stablediffusionapi/anything-v5', 'anything-v5', None),
    "Realistic Vision V5.1 (HF)": ('hf', 'SG161222/Realistic_Vision_V5.1_noVAE', 'Realistic_Vision_V5.1_noVAE', None),
    "DreamShaper 8 (HF)": ('hf', 'Lykon/dreamshaper-8', 'dreamshaper-8', None),
    "Absolute Reality V1.8.1 (HF)": ('hf', 'stablediffusionapi/absolute-reality-v1.8.1', 'absolute-reality-v1.8.1', None),
    # --- Civitai Models ---
    "ChilloutMix (Civitai)": ('civitai', '11745', 'chilloutmix', '.safetensors'),
    "OrangeMixs (Civitai)": ('civitai', '49829', 'OrangeMixs', '.safetensors'),
    "Perfect Deliberate (Civitai)": ('civitai', '128456', 'perfectdeliberate-v5', '.safetensors'),
    # --- SAM Models (Only active if ENABLE_SAM=True) ---
    "SAM ViT-H (Default)": ('sam', 'sam_vit_h_4b8939.pth', 'sam_vit_h', None), # Added subfolder
    # --- Woop Models (Only active if ENABLE_WOOP=True) ---
    "Woop Insighter FaceAnalysis": ('woop', 'inswapper_128.onnx', 'inswapper', None), # Added subfolder
}

# --- ModelManager Class ---
class ModelManager:
    """Handles downloading, organizing, and listing models, with conditional SAM/Woop."""

    def __init__(self, models_base_path, sam_models_path, woop_models_path, tokens):
        """
        Stores paths and credentials and creates the enabled model folders.

        Args:
            models_base_path (str): Root folder shared by HF and Civitai models;
                always created.
            sam_models_path (str): SAM folder; created only when ENABLE_SAM is set.
            woop_models_path (str): Woop folder; created only when ENABLE_WOOP is set.
            tokens (dict): Token mapping; 'hf_token' and 'civitai_token' are read.
        """
        # Feature paths are kept even when disabled so they can be reported.
        self.models_base_path = models_base_path
        self.sam_models_path = sam_models_path
        self.woop_models_path = woop_models_path

        self.tokens = tokens
        self.hf_token = tokens.get('hf_token')
        self.civitai_token = tokens.get('civitai_token')
        # Authenticated Hub client only when an HF token is configured.
        if self.hf_token:
            self.hf_api = HfApi(token=self.hf_token)
        else:
            self.hf_api = HfApi()

        os.makedirs(self.models_base_path, exist_ok=True)
        print(f"✓ Diffusion Model Base Path: {self.models_base_path}")

        # Optional feature folders are only created while their flag is on.
        if not ENABLE_SAM:
            print(f"ℹ️ SAM Models Path (Disabled): {self.sam_models_path}")
        else:
            os.makedirs(self.sam_models_path, exist_ok=True)
            print(f"✓ SAM Models Path (Enabled): {self.sam_models_path}")

        if not ENABLE_WOOP:
            print(f"ℹ️ Woop Models Path (Disabled): {self.woop_models_path}")
        else:
            os.makedirs(self.woop_models_path, exist_ok=True)
            print(f"✓ Woop Models Path (Enabled): {self.woop_models_path}")


    def _get_target_dir(self, model_type, identifier, subfolder_name=None):
        """
        Returns the folder a model of ``model_type`` should live in.

        'hf' and 'civitai' share models_base_path; 'sam' and 'woop' use their own
        base folders and yield None while the matching feature flag is off. The
        leaf folder is ``subfolder_name`` when given, otherwise ``identifier``
        with '/' and '\\' replaced by '_'. Unknown types print a warning and
        yield None.
        """
        if model_type in ['hf', 'civitai']:
            root = self.models_base_path
        elif model_type == 'sam':
            if not ENABLE_SAM:
                return None  # SAM disabled
            root = self.sam_models_path
        elif model_type == 'woop':
            if not ENABLE_WOOP:
                return None  # Woop disabled
            root = self.woop_models_path
        else:
            print(f"⚠️ Unknown model type '{model_type}'. Cannot determine target directory.")
            return None

        # Without an explicit subfolder, derive a path-safe name from the identifier.
        leaf = subfolder_name or identifier.replace('/', '_').replace('\\', '_')
        return os.path.join(root, leaf)


    def list_local_models(self, model_type='all'):
        """
        Lists downloaded models, respecting the ENABLE_SAM / ENABLE_WOOP flags.

        HF and Civitai models share ``models_base_path``, so that folder is
        scanned ONCE and each entry is matched against both types. Scanning it
        once per type listed every model twice: once under its display name and
        once under its raw folder name from the non-matching pass.

        Args:
            model_type (str): 'hf', 'civitai', 'sam', 'woop' or 'all'. Both 'hf'
                and 'civitai' list the shared diffusion-model folder.

        Returns:
            dict: display name -> path. Entries that match no PREDEFINED_MODELS
            item are keyed by their raw file/folder name.
        """
        print(f"\n--- Listing Local Models (Type: {model_type}, SAM: {ENABLE_SAM}, Woop: {ENABLE_WOOP}) ---")
        model_extensions = ('.safetensors', '.ckpt', '.pth', '.onnx', '.bin')

        def display_name_for(item, type_labels):
            # Map a folder/file name back to its PREDEFINED_MODELS display name.
            for name, (p_type, p_id, p_subfolder, _p_filter) in PREDEFINED_MODELS.items():
                if p_type not in type_labels:
                    continue
                if (item == p_subfolder  # usual case: the download subfolder
                        or item == p_id  # e.g. a bare .pth / .onnx file
                        or (p_type == 'hf' and item == p_id.replace('/', '_'))):  # folder derived from repo id
                    return name
            return item

        def find_models(path, type_labels):
            if not os.path.exists(path):
                return {}
            models = {}
            try:
                for item in os.listdir(path):
                    item_path = os.path.join(path, item)
                    if os.path.isdir(item_path) or item.endswith(model_extensions):
                        # If two entries map to the same display name, the last one wins.
                        models[display_name_for(item, type_labels)] = item_path
            except Exception as e:
                print(f"Error listing models in {path}: {e}")
            return models

        found_models = {}
        if model_type in ['hf', 'civitai', 'all']:
            found_models.update(find_models(self.models_base_path, ('hf', 'civitai')))
        if ENABLE_SAM and model_type in ['sam', 'all']:
            found_models.update(find_models(self.sam_models_path, ('sam',)))
        if ENABLE_WOOP and model_type in ['woop', 'all']:
            found_models.update(find_models(self.woop_models_path, ('woop',)))

        if not found_models:
            print("No local models found matching the criteria and enabled features.")
        else:
            print("Found:")
            # Sort by name for consistent output
            for name in sorted(found_models):
                path = found_models[name]
                type_indicator = "[Dir]" if os.path.isdir(path) else "[File]"
                print(f"  - {name} ({path}) {type_indicator}")
        return found_models


    def get_model_path(self, model_name):
        """
        Resolves a predefined model name to its location on disk.

        Args:
            model_name (str): A key of PREDEFINED_MODELS.

        Returns:
            str | None: the repo directory for 'hf' entries (once it contains
            config.json, model_index.json or a top-level .safetensors file); the
            single model file for 'civitai' / 'sam' / 'woop' entries; None when
            the name is unknown, its feature flag is off, or nothing suitable
            is found.
        """
        if model_name not in PREDEFINED_MODELS:
            return None
        model_type, identifier, subfolder, filename_filter = PREDEFINED_MODELS[model_name]

        # SAM / Woop lookups are disabled unless their feature flag is on.
        if (model_type == 'sam' and not ENABLE_SAM) or (model_type == 'woop' and not ENABLE_WOOP):
            return None

        target_dir = self._get_target_dir(model_type, identifier, subfolder)
        if not target_dir:
            return None

        if model_type == 'hf':
            if not os.path.isdir(target_dir):
                return None
            has_config = os.path.exists(os.path.join(target_dir, "config.json"))
            has_model_index = os.path.exists(os.path.join(target_dir, "model_index.json"))
            has_safetensors = any(
                entry.endswith(".safetensors")
                for entry in os.listdir(target_dir)
                if os.path.isfile(os.path.join(target_dir, entry))
            )
            return target_dir if (has_config or has_model_index or has_safetensors) else None

        if model_type not in ['civitai', 'sam', 'woop']:
            print(f"❌ Unknown model type '{model_type}' for model '{model_name}'.")
            return None

        # Single-file models: look inside the model's subfolder.
        if not os.path.isdir(target_dir):
            return None
        try:
            plain_files = [entry for entry in os.listdir(target_dir) if os.path.isfile(os.path.join(target_dir, entry))]
        except FileNotFoundError:
            return None

        # Civitai entries match on the filename filter; SAM/Woop identifiers are the filename itself.
        id_is_filename = model_type in ['sam', 'woop']
        chosen = next(
            (entry for entry in plain_files
             if (filename_filter and filename_filter.lower() in entry.lower())
             or (id_is_filename and identifier == entry)),
            None,
        )
        if chosen is None:
            # Fallback: a lone model-looking file is assumed to be the one.
            candidates = [entry for entry in plain_files if entry.endswith(('.safetensors', '.ckpt', '.pth', '.onnx', '.bin'))]
            if len(candidates) == 1:
                chosen = candidates[0]

        if chosen is None:
            return None
        expected_file = os.path.join(target_dir, chosen)
        return expected_file if os.path.exists(expected_file) else None


    def download_model(self, model_name):
        """Downloads a model based on its predefined name, respecting enabled features."""
        if model_name not in PREDEFINED_MODELS:
            print(f"❌ Model '{model_name}' not found in predefined list.")
            return None

        model_type, identifier, subfolder, filename_filter = PREDEFINED_MODELS[model_name]

        # --- Feature Enablement Checks ---
        if model_type == 'sam' and not ENABLE_SAM:
            print(f"❌ Cannot download '{model_name}': SAM functionality is disabled.")
            return None
        if model_type == 'woop' and not ENABLE_WOOP:
            print(f"❌ Cannot download '{model_name}': Woop functionality is disabled.")
            return None

        target_dir = self._get_target_dir(model_type, identifier, subfolder)
        if not target_dir:
             print(f"❌ Could not determine target directory for '{model_name}'.")
             return None

        # --- Check if Already Exists (before printing download message) ---
        existing_path = self.get_model_path(model_name)
        if existing_path:
             print(f"✅ Model '{model_name}' already downloaded. Path: {existing_path}")
             return existing_path # Return existing path

        # --- Proceed with Download ---
        print(f"\n--- Downloading Model: {model_name} ---")
        print(f"Type: {model_type}, Identifier: {identifier}, Target Dir: {target_dir}")
        os.makedirs(target_dir, exist_ok=True) # Ensure target dir exists

        downloaded_path = None # Variable to store the final path of the downloaded item

        try:
            # --- Hugging Face Download ---
            if model_type == 'hf':
                print(f"Downloading Hugging Face model repo '{identifier}'...")
                # snapshot_download returns the path to the downloaded directory
                downloaded_path = snapshot_download(
                    repo_id=identifier,
                    local_dir=target_dir,
                    token=self.hf_token,
                    local_dir_use_symlinks=False,
                    resume_download=True,
                )
                print(f"✅ Successfully downloaded HF model '{model_name}' to {downloaded_path}")

            # --- Civitai Download ---
            elif model_type == 'civitai':
                print(f"Attempting to download Civitai model ID '{identifier}'...")
                if not self.civitai_token:
                    print("❌ Civitai API Key required. Please configure in Cell 4.")
                    return None

                # Fetch model version details
                api_url = f"https://civitai.com/api/v1/models/{identifier}"
                headers = {"Authorization": f"Bearer {self.civitai_token}"}
                response = requests.get(api_url, headers=headers, timeout=30)
                response.raise_for_status()
                model_data = response.json()

                download_url = None
                selected_filename = None
                # Find suitable file URL (logic unchanged from previous version)
                if 'modelVersions' in model_data and model_data['modelVersions']:
                    for version in model_data['modelVersions']:
                         if 'files' in version and version['files']:
                             primary_file = next((f for f in version['files'] if f.get('primary')), None)
                             if primary_file and (not filename_filter or filename_filter.lower() in primary_file['name'].lower()):
                                 download_url = primary_file['downloadUrl']
                                 selected_filename = primary_file['name']
                                 print(f"Found primary file: {selected_filename} in version {version.get('name', 'N/A')}")
                                 break
                             if not download_url:
                                 for file_info in version['files']:
                                     if filename_filter and filename_filter.lower() in file_info['name'].lower():
                                         download_url = file_info['downloadUrl']
                                         selected_filename = file_info['name']
                                         print(f"Found matching file: {selected_filename} in version {version.get('name', 'N/A')}")
                                         break
                         if download_url: break

                if not download_url or not selected_filename:
                    print(f"❌ Could not find suitable download URL for Civitai model ID {identifier} matching filter '{filename_filter}'.")
                    return None

                # Download the file
                final_download_url = f"{download_url}?token={self.civitai_token}"
                filepath = os.path.join(target_dir, selected_filename)
                print(f"Downloading Civitai file: {selected_filename}")
                with requests.Session() as session:
                     session.headers.update(headers)
                     response = session.get(final_download_url, stream=True, timeout=120) # Longer timeout for large files
                     response.raise_for_status()
                     total_size = int(response.headers.get('content-length', 0))
                     block_size = 1024 * 1024
                     with open(filepath, 'wb') as f, tqdm(
                         desc=selected_filename, total=total_size, unit='iB',
                         unit_scale=True, unit_divisor=1024, leave=False # leave=False hides bar on completion
                     ) as bar:
                         for chunk in response.iter_content(chunk_size=block_size):
                             if chunk:
                                 size = f.write(chunk)
                                 bar.update(size)
                downloaded_path = filepath # Store the file path
                print(f"✅ Successfully downloaded Civitai model '{model_name}' to {downloaded_path}")


            # --- SAM Download (Placeholder) ---
            elif model_type == 'sam':
                print("ℹ️ SAM model download - Attempting generic download...")
                SAM_BASE_URL = "https://dl.fbaipublicfiles.com/segment_anything/"
                sam_url = f"{SAM_BASE_URL}{identifier}"
                filepath = os.path.join(target_dir, identifier)
                print(f"Attempting download from: {sam_url}")
                try:
                    response = requests.get(sam_url, stream=True, timeout=120)
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))
                    block_size = 1024 * 1024
                    with open(filepath, 'wb') as f, tqdm(
                        desc=identifier, total=total_size, unit='iB',
                        unit_scale=True, unit_divisor=1024, leave=False
                    ) as bar:
                        for chunk in response.iter_content(chunk_size=block_size):
                            if chunk:
                                size = f.write(chunk)
                                bar.update(size)
                    downloaded_path = filepath
                    print(f"✅ Successfully downloaded SAM model '{model_name}' to {downloaded_path}")
                except requests.exceptions.RequestException as e:
                     print(f"❌ Failed to download SAM model from {sam_url}. Error: {e}")
                     if os.path.exists(filepath): os.remove(filepath)
                     return None # Explicitly return None on failure

            # --- Woop Download (Placeholder) ---
            elif model_type == 'woop':
                print("⚠️ Woop (insightface) model download - Placeholder attempt...")
                woop_url = identifier # Assume identifier is URL
                filepath = os.path.join(target_dir, os.path.basename(identifier))
                print(f"Attempting download from: {woop_url}")
                try:
                    response = requests.get(woop_url, stream=True, timeout=60)
                    response.raise_for_status()
                    # Simplified download without progress for placeholder
                    with open(filepath, 'wb') as f:
                         for chunk in response.iter_content(chunk_size=8192):
                             if chunk: f.write(chunk)
                    downloaded_path = filepath
                    print(f"✅ Placeholder download complete for Woop model '{model_name}' to {downloaded_path}")
                except Exception as e:
                     print(f"❌ Failed placeholder download for Woop model. Error: {e}")
                     return None # Explicitly return None on failure

            else:
                print(f"❌ Unknown model type '{model_type}' for download.")
                return None

            # Return the path of the downloaded item (directory for HF, file for others)
            return downloaded_path

        except requests.exceptions.RequestException as e:
            print(f"❌ Download failed for '{model_name}': Network error {e}")
            return None
        except Exception as e:
            print(f"❌ An error occurred during download for '{model_name}': {e}")
            return None


# --- Instantiate and Run ---
print("\nInstantiating ModelManager...")
model_manager = None  # Remains None unless construction below succeeds.
if 'tokens_to_save' in globals():
    try:
        # Storage paths come from Cell 3; API tokens come from Cell 4.
        model_manager = ModelManager(MODELS_PATH, SAM_MODELS_PATH, WOOP_MODELS_PATH, tokens_to_save)
        print("\n--- Initial Model Scan ---")
        model_manager.list_local_models()  # List all models respecting enabled features
    except Exception as err:
        print(f"❌ Failed to initialize ModelManager or scan models: {err}")
else:
    print("❌ ERROR: 'tokens_to_save' dictionary not found. Please run Cell 4 first.")

# --- Colab Form for Pre-Download ---
if model_manager:  # Only show form if manager initialized successfully
    # Models that may be downloaded with the current feature flags:
    # HF/Civitai always, SAM/Woop only when ENABLE_SAM/ENABLE_WOOP are set.
    available_models_for_download = ["(None - Skip Pre-Download)"]
    for name, details in PREDEFINED_MODELS.items():
        model_type = details[0]
        if model_type in ['hf', 'civitai']:
            available_models_for_download.append(name)
        elif model_type == 'sam' and ENABLE_SAM:
            available_models_for_download.append(name)
        elif model_type == 'woop' and ENABLE_WOOP:
            available_models_for_download.append(name)

    # Sort the list alphabetically, keeping "(None...)" at the top
    available_models_for_download = [available_models_for_download[0]] + sorted(available_models_for_download[1:])

    print("\n--- Optional: Pre-Download a Model ---")
    print("Select a model to download now, ensuring it's ready when Gradio starts.")
    # The Colab dropdown below is static, so show the options that are actually
    # valid right now. This helps keep the @param list in sync with PREDEFINED_MODELS.
    print("Models available for download with the current settings:")
    for option in available_models_for_download[1:]:
        print(f"  - {option}")

    #@markdown Select a model to download to Google Drive:
    selected_model_to_download = "Anything V5 (HF)" #@param {type:"string"} ["(None - Skip Pre-Download)", "Deliberate V3 (HF)", "Anything V5 (HF)", "Realistic Vision V5.1 (HF)", "DreamShaper 8 (HF)", "Absolute Reality V1.8.1 (HF)", "ChilloutMix (Civitai)", "OrangeMixs (Civitai)", "Perfect Deliberate (Civitai)"]
    # Note: Manually update the list in the @param line if PREDEFINED_MODELS changes significantly
    # or regenerate this cell. For dynamic updates based on flags, manual update is easiest in Colab.

    # --- Trigger Download Based on Form Selection ---
    if selected_model_to_download != "(None - Skip Pre-Download)":
        print(f"\nUser selected '{selected_model_to_download}' for pre-download.")
        if selected_model_to_download not in PREDEFINED_MODELS:
            print(f"⚠️ Warning: Selected model '{selected_model_to_download}' not found in PREDEFINED_MODELS dictionary.")
        elif selected_model_to_download not in available_models_for_download:
            print(f"⚠️ Warning: '{selected_model_to_download}' is not available for download with the current settings (check ENABLE_SAM / ENABLE_WOOP).")
        else:
            # download_model skips models that already exist and returns None on failure.
            downloaded_path = model_manager.download_model(selected_model_to_download)
            if downloaded_path is None:
                print(f"⚠️ Pre-download of '{selected_model_to_download}' did not complete. See the messages above.")
            print("\n--- Model Scan After Pre-Download ---")
            model_manager.list_local_models()
    else:
        print("\nℹ️ No model selected for pre-download.")


# --- Final Status ---
# Summarise the outcome of this cell so failures are visible at a glance.
if not model_manager:
    print("\n⚠️ Model Management cell finished, but ModelManager could not be initialized (likely missing tokens or other error).")
else:
    print("\n✅ Model Management cell setup complete.")
    print(f"   SAM Enabled: {ENABLE_SAM}, Woop Enabled: {ENABLE_WOOP}")



In [None]:
# Cell 6: Image Generation & Segmentation Pipelines (with NumPy Check)

# --- NumPy Version Check ---
# Runs before diffusers/transformers are imported, so a broken or mismatched
# NumPy install is reported up front with an actionable hint.
import sys
print("--- Checking NumPy Version ---")
try:
    import numpy
    np = numpy  # Later code in this cell can use the conventional `np` alias.
    print(f"✅ Found NumPy version: {numpy.__version__}")
    # The `numpy.dtypes` submodule was introduced in NumPy 1.25, so any code
    # that references it fails on older releases.
    has_dtypes = hasattr(numpy, 'dtypes')
    print(f"   numpy.dtypes attribute exists: {has_dtypes}")
    if not has_dtypes:
        print(f"   ⚠️ The loaded NumPy ({numpy.__version__}) is missing the 'dtypes' attribute!")
        print("      This is likely the cause of the error.")
        print("      `numpy.dtypes` only exists in NumPy >= 1.25, so libraries that use it")
        print("      cannot work with this version.")
        print("      Expected version from Cell 2: 1.24.3")  # Adjust if Cell 2 changes
        print("      If Cell 2 pins a NumPy older than 1.25, a restart alone will not help:")
        print("      align that pin with your other libraries first.")
        print("      RECOMMENDATION: Restart Runtime and run all cells from the beginning.")

except ImportError:
    print("❌ NumPy not found. Please ensure it's installed correctly in Cell 2.")
    # NumPy is required by everything below, so stop the cell here.
    raise ImportError("NumPy is required but could not be imported.")
except Exception as e:
    print(f"❌ An unexpected error occurred while checking NumPy: {e}")
print("--- End NumPy Check ---")
# --- End NumPy Version Check ---


import torch
# import numpy as np # Already imported above for check
from PIL import Image, ImageOps
import os
import time
import random
from diffusers import (
    DiffusionPipeline, StableDiffusionPipeline, StableDiffusionInpaintPipeline,
    StableDiffusionImg2ImgPipeline, # Potentially useful for variations
    DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler, DDIMScheduler, UniPCMultistepScheduler
)
import gc # Garbage collector for VRAM management
from datetime import datetime
import traceback # For detailed error printing if needed


# --- Conditional SAM Imports ---
# Only attempted when Cell 5 set ENABLE_SAM=True (and the SAM packages were
# installed in Cell 2). A failed core import flips ENABLE_SAM back to False,
# so the rest of the notebook can rely on the flag.
if not ('ENABLE_SAM' in globals() and ENABLE_SAM):
    # Either the flag is off, or this cell is run standalone without Cell 5.
    # Make sure the flag exists either way.
    if 'ENABLE_SAM' not in globals():
        ENABLE_SAM = False
    print("ℹ️ SAM functionality is disabled (ENABLE_SAM=False or flag not found).")
else:
    try:
        # Prefer the HQ variant of Segment Anything; fall back to the standard package.
        try:
            from segment_anything_hq import SamPredictor, sam_model_registry
            print("✅ Imported Segment Anything HQ")
        except ImportError:
            try:
                from segment_anything import SamPredictor, sam_model_registry
                print("✅ Imported Segment Anything (standard)")
            except ImportError:
                print("⚠️ WARNING: ENABLE_SAM is True, but 'segment_anything' or 'segment_anything_hq' library not found.")
                print("   SAM functionality will be unavailable. Please install dependencies.")
                ENABLE_SAM = False

        # Supervision: optional helpers for mask processing/visualization.
        try:
            import supervision as sv
            print("✅ Imported Supervision")
        except ImportError:
            print("⚠️ Warning: 'supervision' library not found. Mask processing helpers might be limited.")

        # OpenCV: only a warning is printed when missing; ENABLE_SAM is left as-is.
        try:
            import cv2
            print("✅ Imported OpenCV (cv2)")
        except ImportError:
            print("⚠️ WARNING: ENABLE_SAM is True, but 'opencv-python' library not found.")
            print("   SAM functionality might be impaired. Please install dependencies.")
            # Set ENABLE_SAM = False here if the implementation strictly needs cv2.

    except Exception as import_error:
        print(f"❌ Error during conditional SAM/related imports: {import_error}")
        ENABLE_SAM = False


# --- Constants and Configuration ---
# Sampler used when the caller does not pick one (must be a key of AVAILABLE_SCHEDULERS).
DEFAULT_SCHEDULER = "DPMSolverMultistepScheduler"
# Scheduler name -> diffusers scheduler class, keyed by each class's own name.
# Dict insertion order follows the tuple below.
AVAILABLE_SCHEDULERS = {
    scheduler_cls.__name__: scheduler_cls
    for scheduler_cls in (
        DPMSolverMultistepScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        UniPCMultistepScheduler,
        LMSDiscreteScheduler,
        DDIMScheduler,
    )
}
# Display name of the SAM checkpoint loaded when SAM is enabled.
# It must match an entry in PREDEFINED_MODELS (Cell 5).
DEFAULT_SAM_MODEL_NAME = "SAM ViT-H (Default)"


# --- Image Generation Class ---
class ImageGenerator:
    """Handles loading models and generating images using various pipelines."""

    def __init__(self, model_manager, output_path):
        """
        Initializes the ImageGenerator.

        Args:
            model_manager (ModelManager): Instance of the ModelManager from Cell 5.
            output_path (str): Path to the directory where generated images will be saved.
        """
        self.model_manager = model_manager
        self.output_path = output_path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"ImageGenerator initialized on device: {self.device}")
        if self.device == "cpu":
            print("⚠️ Warning: No GPU detected. Image generation will be very slow.")

        self.current_pipeline = None
        self.current_model_name = None # Store the display name of the loaded model
        self.current_pipeline_type = None # 'txt2img', 'inpaint', etc.

        # --- Conditional SAM Initialization ---
        self.sam_predictor = None
        self.sam_model_name = None
        # Use globals().get() for safer check in case ENABLE_SAM wasn't defined somehow
        if globals().get('ENABLE_SAM', False):
            self._initialize_sam()

    def _initialize_sam(self):
        """Loads the SAM model if enabled and available."""
        # Double-check flag just in case
        if not globals().get('ENABLE_SAM', False):
            print("ℹ️ SAM is disabled, skipping initialization.")
            return
        if self.sam_predictor:
             print("ℹ️ SAM predictor already initialized.")
             return

        print("--- Initializing SAM Predictor ---")
        # Ensure model_manager is valid
        if not hasattr(self, 'model_manager') or self.model_manager is None:
             print("❌ Cannot initialize SAM: ModelManager is not available.")
             globals()['ENABLE_SAM'] = False
             print("⚠️ SAM functionality has been disabled.")
             return

        sam_model_path = self.model_manager.get_model_path(DEFAULT_SAM_MODEL_NAME)

        if not sam_model_path or not os.path.exists(sam_model_path):
            print(f"❌ SAM model file not found at expected path derived from '{DEFAULT_SAM_MODEL_NAME}'. Looked for: {sam_model_path}")
            print(f"   Please download it using the Model Manager (Cell 5).")
            globals()['ENABLE_SAM'] = False # Update global flag
            print("⚠️ SAM functionality has been disabled due to missing model.")
            return

        # Determine SAM model type from filename (common convention)
        sam_filename = os.path.basename(sam_model_path)
        if "vit_h" in sam_filename: model_type = "vit_h"
        elif "vit_l" in sam_filename: model_type = "vit_l"
        elif "vit_b" in sam_filename: model_type = "vit_b"
        else:
            print(f"⚠️ Could not determine SAM model type from filename: {sam_filename}. Assuming 'vit_h'.")
            model_type = "vit_h" # Default guess

        try:
            print(f"Loading SAM model (type: {model_type}) from: {sam_model_path}")
            # Ensure sam_model_registry is available
            if 'sam_model_registry' not in globals():
                 raise NameError("sam_model_registry not found. SAM library import likely failed.")

            sam = sam_model_registry[model_type](checkpoint=sam_model_path)
            sam.to(device=self.device)
            self.sam_predictor = SamPredictor(sam)
            self.sam_model_name = DEFAULT_SAM_MODEL_NAME # Store name of loaded SAM model
            print(f"✅ SAM predictor initialized successfully with '{self.sam_model_name}'.")

        except NameError as ne:
             print(f"❌ Error initializing SAM: {ne}. Library might not be imported correctly.")
             globals()['ENABLE_SAM'] = False
             print("⚠️ SAM functionality has been disabled.")
        except Exception as e:
            print(f"❌ Error loading SAM model: {e}")
            print(f"   Ensure the model file at {sam_model_path} is valid and dependencies are installed.")
            traceback.print_exc() # Print detailed traceback for debugging
            self.sam_predictor = None
            self.sam_model_name = None
            globals()['ENABLE_SAM'] = False # Disable SAM if model loading fails
            print("⚠️ SAM functionality has been disabled due to model loading error.")


    def _unload_pipeline(self):
        """Moves the current pipeline to CPU and clears memory."""
        if self.current_pipeline is not None:
            print(f"Unloading pipeline: {self.current_model_name} ({self.current_pipeline_type})...")
            try:
                # Check if pipeline has 'to' method before calling
                if hasattr(self.current_pipeline, 'to'):
                    self.current_pipeline.to("cpu")
                else:
                    print("Warning: Pipeline object doesn't have 'to' method for CPU transfer.")
            except Exception as e:
                print(f"Warning: Error moving pipeline to CPU: {e}")

            # Delete references
            pipeline_ref = self.current_pipeline
            self.current_pipeline = None
            self.current_model_name = None
            self.current_pipeline_type = None
            del pipeline_ref # Explicitly delete the reference

            # Force garbage collection and clear CUDA cache
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            print("Pipeline unloaded and memory cleared.")
        else:
             # If no pipeline loaded, still try to clear cache just in case
             gc.collect()
             if self.device == "cuda":
                 torch.cuda.empty_cache()


    def _load_pipeline(self, model_name, task_type):
        """
        Loads the appropriate diffusion pipeline for the given model and task.
        Manages VRAM by unloading the previous pipeline.

        Args:
            model_name (str): The display name of the model from PREDEFINED_MODELS.
            task_type (str): The type of task ('txt2img', 'inpaint', 'img2img').

        Returns:
            bool: True if the pipeline was loaded successfully, False otherwise.
        """
        # Ensure model_manager is valid
        if not hasattr(self, 'model_manager') or self.model_manager is None:
             print("❌ Cannot load pipeline: ModelManager is not available.")
             return False

        # Check if requested pipeline is already loaded
        if model_name == self.current_model_name and task_type == self.current_pipeline_type and self.current_pipeline:
            print(f"Pipeline '{model_name}' for task '{task_type}' already loaded.")
            # Ensure it's on the correct device (might have been moved to CPU)
            if hasattr(self.current_pipeline, 'device') and str(self.current_pipeline.device) != self.device:
                 print(f"Moving existing pipeline back to {self.device}...")
                 try:
                     self.current_pipeline.to(self.device)
                     print("✅ Pipeline moved back to active device.")
                 except Exception as e:
                     print(f"❌ Failed to move existing pipeline to {self.device}: {e}")
                     # Force unload/reload if moving fails
                     self._unload_pipeline()
                     # Continue to reload logic below...
                 else:
                     return True # Already loaded and on correct device
            else:
                 return True # Already loaded and on correct device (or device check not possible)


        # Get model path using model manager
        model_path = self.model_manager.get_model_path(model_name)
        if not model_path:
            print(f"❌ Model '{model_name}' not found locally via ModelManager. Please download it first.")
            return False
        # Further check if path actually exists (get_model_path might return theoretical path)
        if not os.path.exists(model_path):
             print(f"❌ Model path found by manager but does not exist on disk: {model_path}")
             return False


        # Unload previous pipeline before loading new one
        self._unload_pipeline()

        print(f"\n--- Loading Pipeline ---")
        print(f"Model: {model_name}")
        print(f"Task: {task_type}")
        print(f"Path: {model_path}")

        pipeline_class = None
        load_method_name = None # Store name for logging
        load_args = {}

        # Determine pipeline class based on task
        if task_type == 'txt2img': pipeline_class = StableDiffusionPipeline
        elif task_type == 'inpaint': pipeline_class = StableDiffusionInpaintPipeline
        elif task_type == 'img2img': pipeline_class = StableDiffusionImg2ImgPipeline
        else:
            print(f"❌ Unsupported task type: {task_type}")
            return False

        # Determine load method based on path type (directory vs file)
        if os.path.isdir(model_path):
            load_method = pipeline_class.from_pretrained
            load_method_name = "from_pretrained"
            load_args['pretrained_model_name_or_path'] = model_path
        elif os.path.isfile(model_path) and (model_path.endswith(".safetensors") or model_path.endswith(".ckpt")):
            load_method = pipeline_class.from_single_file
            load_method_name = "from_single_file"
            load_args['pretrained_model_link_or_path'] = model_path
            # Add safety_checker=None for single file loads if needed, depends on diffusers version
            # load_args['safety_checker'] = None
        else:
            print(f"❌ Cannot determine load method for path type: {model_path}")
            return False

        try:
            print(f"Loading pipeline using: {load_method_name}...")
            start_time = time.time()
            # Common arguments for loading
            common_load_args = {
                "torch_dtype": torch.float16,
                # Add safety checker args if needed, often disabled for custom models
                # "safety_checker": None,
                # "requires_safety_checker": False,
            }
            # Merge specific args with common args
            final_load_args = {**load_args, **common_load_args}

            pipeline = load_method(**final_load_args)
            pipeline.to(self.device)

            # Optional: Enable optimizations if available and desired
            # try:
            #     pipeline.enable_xformers_memory_efficient_attention()
            #     print("Enabled xformers memory efficient attention.")
            # except Exception:
            #     try:
            #          # Fallback for newer diffusers/torch versions
            #          import torch.nn.functional as F
            #          pipeline.enable_attention_slicing()
            #          print("Enabled attention slicing (fallback).")
            #     except Exception as e_opt:
            #          print(f"Could not enable memory optimizations: {e_opt}")


            self.current_pipeline = pipeline
            self.current_model_name = model_name
            self.current_pipeline_type = task_type
            end_time = time.time()
            print(f"✅ Pipeline loaded successfully in {end_time - start_time:.2f} seconds.")
            return True

        except Exception as e:
            print(f"❌ Failed to load pipeline '{model_name}' (Task: {task_type}, Method: {load_method_name}). Error:")
            traceback.print_exc() # Print detailed traceback
            self.current_pipeline = None
            self.current_model_name = None
            self.current_pipeline_type = None
            # Attempt cleanup again
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
            return False

    def _get_scheduler(self, scheduler_name):
        """Gets and configures a scheduler instance."""
        if not self.current_pipeline:
            print("❌ Cannot get scheduler, no pipeline loaded.")
            return None
        # Ensure pipeline has scheduler attribute
        if not hasattr(self.current_pipeline, 'scheduler'):
             print("❌ Current pipeline does not have a scheduler attribute.")
             return None


        scheduler_class = AVAILABLE_SCHEDULERS.get(scheduler_name)
        if not scheduler_class:
            print(f"⚠️ Scheduler '{scheduler_name}' not found in AVAILABLE_SCHEDULERS. Using default {DEFAULT_SCHEDULER}.")
            scheduler_class = AVAILABLE_SCHEDULERS.get(DEFAULT_SCHEDULER)
            if not scheduler_class: # Should not happen if default is in dict
                 print(f"❌ Default scheduler {DEFAULT_SCHEDULER} also not found!")
                 return None

        try:
            # Load scheduler config from the pipeline's current scheduler
            # This preserves settings like beta schedules etc.
            scheduler = scheduler_class.from_config(self.current_pipeline.scheduler.config)
            # Assign the new scheduler instance to the pipeline
            self.current_pipeline.scheduler = scheduler
            print(f"Using scheduler: {scheduler_name}")
            return scheduler
        except Exception as e:
            print(f"❌ Error setting scheduler '{scheduler_name}': {e}")
            traceback.print_exc()
            return None

    def _preprocess_image(self, image, target_size=512, divisible_by=8):
        """Return an RGB version of ``image`` whose sides are multiples of ``divisible_by``.

        Each side is rounded *down* to the nearest multiple of ``divisible_by``. A
        side that would round to zero becomes ``divisible_by``. The image is
        resized with LANCZOS only if that changes its size. ``target_size`` is
        accepted for interface compatibility but is not used by this
        implementation.

        Raises:
            ValueError: If ``image`` is not a PIL Image.
        """
        if not isinstance(image, Image.Image):
            raise ValueError("Input 'image' must be a PIL Image object.")

        def _snap(side):
            # Largest multiple of divisible_by not above `side`, never zero.
            rounded = side - side % divisible_by
            return rounded or divisible_by

        try:
            if image.mode != "RGB":
                print(f"Converting image from mode {image.mode} to RGB.")
                image = image.convert("RGB")

            width, height = image.size
            new_width, new_height = _snap(width), _snap(height)

            if (new_width, new_height) != (width, height):
                print(f"Resizing image from ({width}, {height}) to ({new_width}, {new_height}) to be divisible by {divisible_by}.")
                image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
            return image
        except Exception as prep_error:
            print(f"❌ Error during image preprocessing: {prep_error}")
            raise


    def _preprocess_mask(self, mask, target_image):
        """Return a binary RGB inpainting mask sized like ``target_image``.

        Steps:
        1. If the sizes differ, resize the mask (NEAREST, to keep edges hard).
        2. Convert it to grayscale.
        3. Threshold it: pixels > 127 become 255 (white = area to inpaint),
           all others become 0.
        4. Convert the result back to RGB.

        Raises:
            ValueError: If ``mask`` or ``target_image`` is not a PIL Image.
        """
        for candidate, label in ((mask, 'mask'), (target_image, 'target_image')):
            if not isinstance(candidate, Image.Image):
                raise ValueError(f"Input '{label}' must be a PIL Image object.")

        try:
            wanted_size = target_image.size
            if mask.size != wanted_size:
                print(f"Resizing mask from {mask.size} to {target_image.size} to match image.")
                mask = mask.resize(wanted_size, Image.Resampling.NEAREST)

            gray = mask if mask.mode == 'L' else mask.convert('L')
            cutoff = 127
            binary = gray.point(lambda px: 255 if px > cutoff else 0)
            return binary.convert('RGB')
        except Exception as prep_error:
            print(f"❌ Error during mask preprocessing: {prep_error}")
            raise

    def _save_image(self, image, prompt=""):
        """Saves the generated PIL image to the output directory."""
        # Ensure output path exists
        try:
             os.makedirs(self.output_path, exist_ok=True)
        except OSError as e:
             print(f"❌ Error creating output directory {self.output_path}: {e}")
             return None # Cannot save if directory fails

        try:
            # Create a filename
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Sanitize prompt for filename
            safe_prompt = "".join(c if c.isalnum() or c in (' ', '_', '-') else '_' for c in prompt).strip()
            safe_prompt = safe_prompt[:50] # Limit length
            if not safe_prompt: safe_prompt = "generated_image" # Fallback if prompt is empty/unusable

            # Add random element to prevent collisions on identical prompts/timestamps
            rand_id = random.randint(1000, 9999)
            filename = f"{timestamp}_{safe_prompt}_{rand_id}.png"
            filepath = os.path.join(self.output_path, filename)

            # Save the image
            image.save(filepath, "PNG")
            print(f"✅ Image saved to: {filepath}")
            return filepath
        except Exception as e:
            print(f"❌ Error saving image to {self.output_path}: {e}")
            traceback.print_exc()
            return None

    # --- Core Generation Methods ---

    def text_to_image(self, model_name, prompt, negative_prompt="", guidance_scale=7.5,
                      num_inference_steps=30, seed=None, width=512, height=512,
                      scheduler_name=DEFAULT_SCHEDULER):
        """Generates an image from text prompts.

        Args:
            model_name: Model key, loaded via ``self._load_pipeline(model_name, 'txt2img')``.
            prompt (str): Positive prompt.
            negative_prompt (str): Negative prompt.
            guidance_scale (float): Classifier-free guidance (CFG) scale.
            num_inference_steps (int): Number of denoising steps.
            seed: RNG seed.
                - ``None``, ``-1``, a non-numeric or non-integral value, or a value
                  outside torch's seed range selects a random seed.
                - Integral floats such as ``42.0`` (what numeric UI widgets commonly
                  return) are accepted as the seed they represent.
            width, height: Requested output size. Converted to int and rounded
                down to a multiple of 8; must stay positive.
            scheduler_name (str): Scheduler applied via ``self._get_scheduler``.
                Failure is non-fatal.

        Returns:
            tuple: ``(image, saved_path, seed)``.
            - ``image`` and ``saved_path`` are ``None`` on failure.
            - ``seed`` is the seed actually used, or the caller's value if
              generation never started.
        """
        print(f"\n--- Task: Text-to-Image ---")
        if not self._load_pipeline(model_name, 'txt2img'):
            return None, None, seed # Return consistent tuple on failure

        if not self._get_scheduler(scheduler_name):
            print("⚠️ Failed to set scheduler. Proceeding with pipeline's default.")
            # Decide if this is critical - for now, we proceed

        # Sizes must be positive multiples of 8. Cast to int first so float sizes
        # (e.g. 512.0 from a UI slider) never reach the pipeline as floats.
        try:
            width, height = int(width), int(height)
        except (TypeError, ValueError, OverflowError):
            print(f"❌ Invalid dimensions: {width}x{height}")
            return None, None, seed
        proc_width = (width // 8) * 8
        proc_height = (height // 8) * 8
        # `<= 0` (not `== 0`): negative sizes floor to negative multiples of 8.
        if proc_width <= 0 or proc_height <= 0:
            print(f"❌ Invalid dimensions after ensuring divisibility by 8: {width}x{height} -> {proc_width}x{proc_height}")
            return None, None, seed
        if proc_width != width or proc_height != height:
            print(f"Adjusting dimensions to be divisible by 8: {width}x{height} -> {proc_width}x{proc_height}")
            width, height = proc_width, proc_height

        # Resolve the seed. UI widgets often deliver floats (42.0), and numpy
        # integers are not `int` instances. Accept any integral value instead of
        # silently replacing the user's seed with a random one.
        try:
            seed = int(seed) if seed is not None and float(seed).is_integer() else None
        except (TypeError, ValueError, OverflowError):
            seed = None
        # -1 means "random". Values outside the range accepted by
        # torch.Generator.manual_seed would raise there, so re-randomise them too.
        if seed is None or seed == -1 or not (-2**63 <= seed <= 2**64 - 1):
            seed = random.randint(0, 2**32 - 1) # Generate a valid seed
        generator = torch.Generator(device=self.device)
        generator.manual_seed(seed)
        print(f"Using Seed: {seed}")
        print(f"Parameters: Steps={num_inference_steps}, CFG={guidance_scale}, Size={width}x{height}, Scheduler={scheduler_name}")

        output_image = None
        saved_path = None
        start_time = time.time()
        try:
            print("Generating text-to-image...")
            # Ensure pipeline is on the correct device before inference
            self.current_pipeline.to(self.device)
            with torch.inference_mode(): # Use inference mode for efficiency
                result = self.current_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    width=width,
                    height=height,
                )
            output_image = result.images[0]
            end_time = time.time()
            print(f"Image generated in {end_time - start_time:.2f} seconds.")

            # Save the image
            saved_path = self._save_image(output_image, prompt)

        except Exception as e:
            print(f"❌ Error during text-to-image generation: {e}")
            traceback.print_exc()
            # Attempt to unload pipeline to free memory after error
            self._unload_pipeline()

        # Return image, path, and seed used (even on failure, return seed)
        return output_image, saved_path, seed


    def inpaint(self, model_name, prompt, negative_prompt="", base_image=None, mask_image=None,
                guidance_scale=7.5, num_inference_steps=50, seed=None,
                scheduler_name=DEFAULT_SCHEDULER, strength=0.8):
        """Fills masked areas of an image based on prompts.

        Args:
            model_name: Model key, loaded via ``self._load_pipeline(model_name, 'inpaint')``.
            prompt (str): Positive prompt.
            negative_prompt (str): Negative prompt.
            base_image (PIL.Image): Image to edit. Required.
            mask_image (PIL.Image): Mask, normalised by ``self._preprocess_mask``.
                White marks the area to repaint. Required.
            guidance_scale (float): CFG scale.
            num_inference_steps (int): Number of denoising steps.
            seed: RNG seed.
                - ``None``, ``-1``, a non-numeric or non-integral value, or a value
                  outside torch's seed range selects a random seed.
                - Integral floats such as ``42.0`` are accepted.
            scheduler_name (str): Scheduler applied via ``self._get_scheduler``.
                Failure is non-fatal.
            strength (float): Forwarded to the pipeline as ``strength``.

        Returns:
            tuple: ``(image, saved_path, seed)``.
            - ``image`` and ``saved_path`` are ``None`` on failure.
            - ``seed`` is the seed actually used, or the caller's value if
              generation never started.
        """
        print(f"\n--- Task: Inpainting ---")
        if base_image is None or mask_image is None:
            print("❌ Base image and mask image are required for inpainting.")
            return None, None, seed

        if not self._load_pipeline(model_name, 'inpaint'):
            return None, None, seed

        if not self._get_scheduler(scheduler_name):
            print("⚠️ Failed to set scheduler. Proceeding with pipeline's default.")

        # Preprocess images
        processed_image = None
        processed_mask = None
        try:
            print("Preprocessing images for inpainting...")
            # Ensure base_image is divisible by 8
            processed_image = self._preprocess_image(base_image, divisible_by=8)
            # Ensure mask matches processed image size and format (white=inpaint)
            processed_mask = self._preprocess_mask(mask_image, processed_image)
            print(f"Processed image size: {processed_image.size}, Mask size: {processed_mask.size}")
        except Exception as e:
            print(f"❌ Error during image/mask preprocessing for inpainting: {e}")
            traceback.print_exc()
            return None, None, seed

        # Resolve the seed. UI widgets often deliver floats (42.0), and numpy
        # integers are not `int` instances. Accept any integral value instead of
        # silently replacing the user's seed with a random one.
        try:
            seed = int(seed) if seed is not None and float(seed).is_integer() else None
        except (TypeError, ValueError, OverflowError):
            seed = None
        # -1 means "random". Values outside the range accepted by
        # torch.Generator.manual_seed would raise there, so re-randomise them too.
        if seed is None or seed == -1 or not (-2**63 <= seed <= 2**64 - 1):
            seed = random.randint(0, 2**32 - 1)
        generator = torch.Generator(device=self.device)
        generator.manual_seed(seed)
        print(f"Using Seed: {seed}")
        print(f"Parameters: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}, Scheduler={scheduler_name}")

        output_image = None
        saved_path = None
        start_time = time.time()
        try:
            print("Generating inpainting...")
            # Ensure pipeline is on the correct device
            self.current_pipeline.to(self.device)
            with torch.inference_mode():
                result = self.current_pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=processed_image, # Use preprocessed image
                    mask_image=processed_mask, # Use preprocessed mask
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    generator=generator,
                    strength=strength,
                    # width/height are inferred from the preprocessed image
                )
            output_image = result.images[0]
            end_time = time.time()
            print(f"Inpainting generated in {end_time - start_time:.2f} seconds.")

            # _save_image keeps only the first 50 sanitised characters of its
            # prompt argument. Shorten the prompt *before* appending the tag so
            # "_inpainted" survives long prompts. `or ''` also avoids a TypeError
            # on a None prompt.
            saved_path = self._save_image(output_image, f"{(prompt or '')[:40]}_inpainted")

        except Exception as e:
            print(f"❌ Error during inpainting generation: {e}")
            traceback.print_exc()
            self._unload_pipeline() # Attempt to free memory

        return output_image, saved_path, seed


    def outpaint(self, model_name, prompt, negative_prompt="", base_image=None,
                 pixels_to_expand=128, expand_direction="all", # all, top, bottom, left, right
                 guidance_scale=7.5, num_inference_steps=50, seed=None,
                 scheduler_name=DEFAULT_SCHEDULER, strength=0.95): # Often need high strength for outpaint
        """Expands an image outwards using the inpainting pipeline.

        How it works:
        1. The original image is pasted onto a larger neutral-gray canvas.
        2. A mask is built: white (regenerate) over the added border, black over
           the original pixels.
        3. Canvas and mask are handed to ``self.inpaint``.

        The canvas is grown to the next multiple of 8; the rounding slack goes to
        the right/bottom edges.

        Args:
            model_name: Model key forwarded to ``inpaint``.
            prompt (str): Positive prompt forwarded to ``inpaint``.
            negative_prompt (str): Negative prompt forwarded to ``inpaint``.
            base_image (PIL.Image): Image to expand. Required.
            pixels_to_expand (int): Border width added on each selected side.
                Converted to int; must be > 0.
            expand_direction (str): One of "all", "top", "bottom", "left", "right".
                Case and surrounding whitespace are ignored.
            guidance_scale, num_inference_steps, seed, scheduler_name, strength:
                Forwarded unchanged to ``inpaint``.

        Returns:
            tuple: ``(image, saved_path, seed)`` from ``inpaint``, or
            ``(None, None, seed)`` if the inputs are invalid or preparation fails.
        """
        print(f"\n--- Task: Outpainting ---")
        if base_image is None:
            print("❌ Base image is required for outpainting.")
            return None, None, seed

        print(f"Direction: {expand_direction}, Pixels: {pixels_to_expand}")

        # Sides padded by each direction, as (left, right, top, bottom).
        sides_by_direction = {
            "all": (True, True, True, True),
            "left": (True, False, False, False),
            "right": (False, True, False, False),
            "top": (False, False, True, False),
            "bottom": (False, False, False, True),
        }
        direction = str(expand_direction).strip().lower()
        try:
            pixels = int(pixels_to_expand)
        except (TypeError, ValueError, OverflowError):
            pixels = 0
        # An unknown direction or a non-positive pixel count would not grow the
        # canvas (negative values would even crop it). That leaves an all-black
        # mask, i.e. nothing to generate, so reject it up front.
        if direction not in sides_by_direction or pixels <= 0:
            print(f"❌ Invalid outpainting parameters: expand_direction must be one of "
                  f"{list(sides_by_direction)} and pixels_to_expand must be a positive integer.")
            return None, None, seed

        expanded_image = None
        mask_rgb = None
        try:
            print("Preparing canvas and mask for outpainting...")
            # 1. Ensure base image is RGB
            image = base_image.convert("RGB")
            orig_width, orig_height = image.size

            # 2. Per-side padding. The canvas is then rounded up to a multiple of 8
            #    (at least 8x8); the extra pixels land on the right/bottom edges.
            pad_left, pad_right, pad_top, pad_bottom = (
                pixels if padded else 0 for padded in sides_by_direction[direction]
            )
            new_width = max(((orig_width + pad_left + pad_right + 7) // 8) * 8, 8)
            new_height = max(((orig_height + pad_top + pad_bottom + 7) // 8) * 8, 8)
            print(f"Expanding to: {new_width}x{new_height}")

            # 3. Create expanded canvas (simple gray fill) and place the original
            expanded_image = Image.new("RGB", (new_width, new_height), (127, 127, 127))
            expanded_image.paste(image, (pad_left, pad_top))

            # 4. Create the mask (white in the expanded areas, black in original area)
            mask = Image.new("L", (new_width, new_height), 255)
            mask.paste(Image.new("L", (orig_width, orig_height), 0), (pad_left, pad_top))

            # Convert mask to RGB for pipeline
            mask_rgb = mask.convert("RGB")
            print("Canvas and mask prepared.")

        except Exception as e:
            print(f"❌ Error during outpainting preparation: {e}")
            traceback.print_exc()
            return None, None, seed

        # --- Call Inpainting Pipeline ---
        print("Calling inpaint pipeline for outpainting task...")
        return self.inpaint(
            model_name=model_name,
            prompt=prompt,
            negative_prompt=negative_prompt,
            base_image=expanded_image, # Pass the expanded canvas
            mask_image=mask_rgb,       # Pass the mask covering new areas
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            scheduler_name=scheduler_name,
            strength=strength # Strength is important for outpainting
        )


    # --- SAM Mask Generation Method ---

    def predict_sam_mask(self, input_image, input_points):
        """
        Generates a segmentation mask using SAM based on input points.

        Every click is treated as a foreground point (label 1). SAM is asked for
        several candidate masks (``multimask_output=True``). All candidates are
        then merged with a logical OR, which favours an inclusive selection.

        Args:
            input_image (PIL.Image): The image to segment.
            input_points (list): A list of (x, y) tuples representing click
                points. They are passed to SAM as ``point_coords``.

        Returns:
            PIL.Image: A binary mask (mode 'L', 0=background, 255=mask), or None if
            SAM is unavailable or prediction failed. An all-black mask is returned
            when no points are given or SAM finds no mask.
        """
        print(f"\n--- Task: SAM Mask Prediction ---")
        sam_enabled = globals().get('ENABLE_SAM', False)
        # Check if SAM is enabled *and* predictor is initialized
        if not sam_enabled or not self.sam_predictor:
            print("❌ SAM is not enabled or initialized. Cannot generate mask.")
            # Enabled but not initialised yet: attempt a lazy initialisation once.
            if sam_enabled and not self.sam_predictor:
                print("Attempting to initialize SAM predictor now...")
                self._initialize_sam()
                if not self.sam_predictor:
                    print("❌ SAM predictor initialization failed. Cannot proceed.")
                    return None
            else:
                return None # SAM not enabled

        if not input_points:
            print("ℹ️ No points provided for SAM prediction.")
            # Return an empty mask matching input image size
            try:
                return Image.new('L', input_image.size, 0)
            except Exception:
                return None # Cannot even create empty mask

        print(f"Generating SAM Mask ({len(input_points)} points)")
        final_mask_pil = None
        try:
            # segment_anything's SamPredictor.set_image expects an HWC uint8 array.
            # Its channel order is given by `image_format`, which defaults to "RGB".
            # PIL already yields RGB, so pass it through unchanged. Converting to
            # BGR (as OpenCV-style code does) would feed SAM channel-swapped colours.
            image_np = np.array(input_image.convert("RGB"))

            # Set the image in the SAM predictor
            print("Setting image in SAM predictor...")
            start_time = time.time()
            self.sam_predictor.set_image(image_np)
            print(f"Image set in {time.time() - start_time:.2f}s")

            # Format points and labels for SAM predictor
            points_np = np.array(input_points)
            # Labels: 1 = foreground point (target object), 0 = background point
            labels_np = np.ones(len(input_points), dtype=int) # Assume all clicks are targets

            print(f"Predicting mask with points: {input_points}")
            start_time = time.time()
            masks, scores, logits = self.sam_predictor.predict(
                point_coords=points_np,
                point_labels=labels_np,
                multimask_output=True, # Get multiple candidate masks
            )
            # masks: (num_masks, H, W) boolean array; scores: (num_masks,) quality scores
            print(f"Prediction done in {time.time() - start_time:.2f}s.")
            print(f"Found {len(masks)} masks with scores: {[f'{s:.2f}' for s in scores]}")

            # --- Mask Selection/Merging Logic ---
            if len(masks) == 0:
                print("⚠️ No masks found by SAM for the given points.")
                # Return an empty (all black) mask matching the input's H, W
                final_mask_np = np.zeros(image_np.shape[:2], dtype=np.uint8)
            else:
                # Alternative: keep only the best candidate via np.argmax(scores).
                # Here all candidates are combined with logical OR (more inclusive).
                print("Combining all predicted masks using logical OR.")
                final_mask_np = np.logical_or.reduce(masks, axis=0)

                # Convert boolean mask to uint8 (0 or 255)
                final_mask_np = final_mask_np.astype(np.uint8) * 255

            # A 2-D uint8 array maps to mode 'L' automatically. The explicit `mode`
            # argument is deprecated in recent Pillow releases.
            final_mask_pil = Image.fromarray(final_mask_np)
            print("✅ SAM mask generated successfully.")

        except NameError as ne:
            print(f"❌ SAM prediction failed: {ne}. Required libraries (numpy, segment_anything) might be missing or failed to import.")
            traceback.print_exc()
        except Exception as e:
            print(f"❌ Error during SAM mask prediction: {e}")
            traceback.print_exc() # Print detailed traceback for debugging

        return final_mask_pil # Return PIL mask or None

# --- Instantiate Generator ---
# Build the shared `image_generator` only when both prerequisites from earlier
# cells exist:
#   - a non-None `model_manager` (Cell 5)
#   - a string `OUTPUT_PATH` (Cell 3)
# Otherwise report which one is missing.
print("\nInstantiating ImageGenerator...")
image_generator = None
if globals().get('model_manager') is None:
    print("❌ ERROR: 'model_manager' not found or not initialized (expected from Cell 5). Cannot instantiate ImageGenerator.")
elif not isinstance(globals().get('OUTPUT_PATH'), str):
    print("❌ ERROR: 'OUTPUT_PATH' not found or invalid (expected string path from Cell 3). Cannot instantiate ImageGenerator.")
else:
    try:
        image_generator = ImageGenerator(model_manager, OUTPUT_PATH)
        print("✅ ImageGenerator instantiated successfully.")
    except Exception as e:
        print(f"❌ Failed to instantiate ImageGenerator: {e}")
        traceback.print_exc()

# --- Optional: Quick Test ---
# (Keep commented out unless specifically testing this cell)
# print("\n--- Optional Quick Test ---")
# if image_generator:
#      if image_generator.current_pipeline and image_generator.current_model_name:
#           print(f"\n--- Quick Test: Text-to-Image using pre-loaded model '{image_generator.current_model_name}' ---")
#           test_prompt = "A photo of an astronaut riding a horse on the moon"
#           img, path, seed = image_generator.text_to_image(
#                model_name=image_generator.current_model_name,
#                prompt=test_prompt, num_inference_steps=15, width=512, height=512
#           )
#           if img: print(f"Quick test successful. Seed: {seed}, Path: {path}")
#           else: print("Quick test failed.")
#      else:
#           print("\n--- Quick Test: Text-to-Image (requires model download if not pre-loaded) ---")
#           # Select a model known to be defined in PREDEFINED_MODELS
#           test_model = "Deliberate V3 (HF)" # Or another model name
#           print(f"Attempting test with model: {test_model}")
#           # Ensure model is downloaded first (optional, download_model handles check)
#           # model_manager.download_model(test_model)
#           test_prompt = "A watercolor painting of a cozy cabin in the woods, autumn"
#           img, path, seed = image_generator.text_to_image(
#                model_name=test_model, prompt=test_prompt, num_inference_steps=15, width=512, height=512
#           )
#           if img: print(f"Quick test successful. Seed: {seed}, Path: {path}")
#           else: print(f"Quick test failed for model {test_model}.")
# else:
#      print("\nℹ️ ImageGenerator not instantiated, skipping quick test.")


# --- Status report ---
print("\n✅ Image Generation & Segmentation cell setup complete.")
if not image_generator:
    print("   ⚠️ ImageGenerator instance could not be created.")
else:
    print("   'image_generator' instance is ready to use.")
    # SAM status combines the ENABLE_SAM flag with whether a predictor exists.
    if not globals().get('ENABLE_SAM', False):
        sam_status = "Disabled"
    elif image_generator.sam_predictor:
        sam_status = "Enabled and Initialized"
    else:
        sam_status = "Enabled but NOT Initialized (Check Logs)"
    print(f"   SAM Status: {sam_status}")
