In [None]:
import os
import logging
import sys
import torch
import matplotlib.pyplot as plt
from PIL import Image
import gc
from pathlib import Path
from diffusers import DiffusionPipeline
import textwrap
import requests
import shutil
import subprocess
from tqdm import tqdm

In [None]:
# Let PyTorch's CUDA caching allocator grow memory segments on demand, which
# reduces fragmentation-related OOMs. Only takes effect if set before CUDA is
# first initialised in this process.
os.environ.update({'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True'})

# Module-wide logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

def check_dependencies(required=None):
    """Verify that the required Python packages can be imported.

    Every module is tried before anything is raised, so a single run reports
    all missing packages at once.

    Args:
        required: Iterable of top-level module names to check. When None, the
            notebook's default package list is used.
            NOTE(review): 'imageio' is not imported by any visible cell and
            'xformers' is only an optional speed-up; confirm they must be hard
            requirements.

    Raises:
        ImportError: naming every module that failed to import.
    """
    import importlib

    if required is None:
        required = ['torch', 'diffusers', 'transformers', 'matplotlib', 'imageio', 'xformers']

    missing = []
    last_error = None
    for module in required:
        try:
            importlib.import_module(module)
        except ImportError as err:
            missing.append(module)
            last_error = err

    if missing:
        raise ImportError(
            f"Required module(s) {', '.join(missing)} not installed. Please install them."
        ) from last_error

def clean_memory():
    """Free unreferenced Python objects, then release cached CUDA memory.

    Safe to call on CPU-only machines: the CUDA step is skipped there.
    """
    # Collect first: tensors released by the garbage collector hand their blocks
    # back to PyTorch's caching allocator, which empty_cache() can then return to
    # the driver. In the reverse order those blocks stay cached until a later call.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def check_model_availability(url: str, timeout: float = 5) -> bool:
    """Return True if `url` answers an HTTP HEAD request with a 2xx status.

    Redirects are followed, because `requests.head` does not follow them by
    default and a CDN redirect would otherwise be reported as "unavailable".

    Args:
        url: URL to probe.
        timeout: Seconds to wait for the server (default 5).

    Returns:
        False on any network error or non-2xx final status.
    """
    try:
        response = requests.head(url, timeout=timeout, allow_redirects=True)
    except requests.RequestException:
        return False
    return 200 <= response.status_code < 300

In [None]:
# Device setup: run on the GPU in half precision when one is present,
# otherwise on the CPU in full precision.
device, dtype = (
    (torch.device("cuda"), torch.float16)
    if torch.cuda.is_available()
    else (torch.device("cpu"), torch.float32)
)
logger.info(f"Using device: {device} ({dtype})")

class FaceGenerator:
    """Generate portrait images with RealVisXL, SDXL and StyleGAN2 for comparison.

    Loaded models are cached in ``self.models`` keyed by model name. A model whose
    load failed is cached as None and is skipped (not retried) by generate_images().
    """

    # Hugging Face repos and loading options for the diffusion models.
    # "low_memory": load with sequential CPU offload when low_memory_mode is on.
    MODEL_CONFIG = {
        "realvisxl": {
            "repo": "SG161222/RealVisXL_V5.0",
            "variant": "fp16",
            "enable_xformers": True,
        },
        "sdxl": {
            "repo": "stabilityai/stable-diffusion-xl-base-1.0",
            "variant": "fp16",
            "enable_xformers": True,
            "low_memory": True,
        },
    }

    STYLEGAN2_URL = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"

    # Longest image side used when retrying a diffusion generation after CUDA OOM.
    OOM_FALLBACK_SIZE = 512

    # Display titles; dict order fixes each model's subplot column.
    DISPLAY_NAMES = {
        "realvisxl": "RealVisXL V5.0",
        "sdxl": "Stable Diffusion XL",
        "stylegan2": "StyleGAN2-ADA",
    }

    def __init__(self, low_memory_mode: bool = True):
        """Check dependencies and prepare StyleGAN2 imports; models load lazily.

        Args:
            low_memory_mode: when True, models flagged "low_memory" in MODEL_CONFIG
                (SDXL) are loaded with sequential CPU offload on CUDA instead of
                being moved wholly onto the GPU.

        Raises:
            ImportError: a required package or the StyleGAN2 modules are missing.
            FileNotFoundError: the stylegan2-ada-pytorch checkout is missing.
        """
        check_dependencies()
        self.models = {}
        self.low_memory_mode = low_memory_mode
        self._setup_stylegan_paths()

    def _setup_stylegan_paths(self):
        """Make ./stylegan2-ada-pytorch importable and import its helper modules.

        Binds ``dnnlib`` and ``legacy`` as module globals for the loading code.

        Raises:
            FileNotFoundError: if the repository directory does not exist.
            ImportError: if dnnlib/legacy cannot be imported (also logged).
        """
        stylegan2_path = Path.cwd() / "stylegan2-ada-pytorch"
        if not stylegan2_path.exists():
            raise FileNotFoundError("StyleGAN2 repository not found. Ensure it is cloned.")
        repo_dir = str(stylegan2_path)
        # Re-instantiating the class (e.g. re-running the cell) must not stack
        # duplicate entries onto sys.path.
        if repo_dir not in sys.path:
            sys.path.insert(0, repo_dir)
        global dnnlib, legacy
        try:
            import dnnlib
            import legacy
        except ImportError as e:
            logger.error(f"StyleGAN setup error: {str(e)}")
            raise

    def _uses_offload(self, model_name: str) -> bool:
        """Whether `model_name` runs in the low-memory (CPU offload) configuration."""
        config = self.MODEL_CONFIG.get(model_name, {})
        return bool(self.low_memory_mode and config.get("low_memory", False))

    @staticmethod
    def _try_enable_xformers(pipe, model_name: str) -> None:
        """Enable xFormers attention on `pipe`; log and keep going if unavailable.

        xFormers is only an optimisation, so failing to enable it must not
        discard an otherwise usable pipeline.
        """
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as e:
            logger.warning(f"xFormers unavailable for {model_name}, using default attention: {e}")

    def _load_diffusion_model(self, model_name: str):
        """Load (or return the cached) diffusion pipeline for `model_name`.

        Returns:
            The pipeline, or None if the name is unknown or loading failed
            (the error is logged). On success the pipeline is cached in self.models.
        """
        cached = self.models.get(model_name)
        if cached is not None:
            return cached

        config = self.MODEL_CONFIG.get(model_name)
        if config is None:
            logger.error(f"Failed to load {model_name}: unknown model name")
            return None

        clean_memory()
        logger.info(f"Loading {model_name}...")

        try:
            pipe = DiffusionPipeline.from_pretrained(
                config["repo"],
                torch_dtype=dtype,
                variant=config["variant"],
                use_safetensors=True
            )

            on_cuda = device.type == 'cuda'
            if on_cuda and self._uses_offload(model_name):
                # diffusers documents model CPU offload and sequential CPU offload
                # as alternative strategies; enable only the lower-memory
                # sequential one rather than stacking both.
                pipe.enable_sequential_cpu_offload()
            else:
                # On CPU there is nothing to offload to, so just place the pipeline.
                pipe = pipe.to(device)

            if on_cuda and config.get("enable_xformers", False):
                self._try_enable_xformers(pipe, model_name)

            self.models[model_name] = pipe
            return pipe
        except Exception as e:
            logger.error(f"Failed to load {model_name}: {str(e)}")
            return None

    def _load_gan_model(self):
        """Download and load the StyleGAN2-ADA FFHQ generator (``G_ema``).

        Returns:
            The cached or freshly loaded generator on ``device``; None if the URL
            is unreachable or loading fails (the error is logged).
        """
        model_name = "stylegan2"
        cached = self.models.get(model_name)
        if cached is not None:
            return cached

        model_url = self.STYLEGAN2_URL

        if not check_model_availability(model_url):
            logger.error("StyleGAN2 model URL is not accessible.")
            return None

        clean_memory()
        logger.info("Loading StyleGAN2...")

        try:
            # SECURITY NOTE(review): load_network_pkl unpickles the downloaded
            # file, and unpickling can execute arbitrary code. Only point this at
            # a trusted URL.
            with dnnlib.util.open_url(model_url) as f:
                model = legacy.load_network_pkl(f)["G_ema"].to(device)
            self.models[model_name] = model
            return model
        except Exception as e:
            logger.error(f"Failed to load StyleGAN2: {str(e)}")
            return None

    def generate_images(self,
                        prompt: str = "Professional portrait photo, detailed facial features, 8k",
                        width: int = 768,
                        height: int = 768,
                        num_inference_steps: int = 25,
                        guidance_scale: float = 7.5) -> dict:
        """Generate one image per model: RealVisXL, then SDXL, then StyleGAN2.

        Models load lazily and are cached in self.models. A model whose load
        failed (cached as None) is skipped. Prompt, size, steps and guidance only
        affect the two diffusion models; StyleGAN2 samples an unconditioned face.

        Returns:
            dict mapping model key -> PIL image, or None if that model's
            generation failed. Models that could not be loaded have no key.
        """
        def diffusion_step(name):
            return lambda: self._generate_diffusion_image(
                name, prompt, width, height, num_inference_steps, guidance_scale)

        # (model key, progress message, loader, generator) in execution order.
        # Temporaries are freed after each model runs.
        plan = (
            ("realvisxl", "Generating image with RealVisXL...",
             lambda: self._load_diffusion_model("realvisxl"), diffusion_step("realvisxl")),
            ("sdxl", "Generating image with SDXL (this may take more memory)...",
             lambda: self._load_diffusion_model("sdxl"), diffusion_step("sdxl")),
            ("stylegan2", "Generating image with StyleGAN2...",
             self._load_gan_model, self._generate_gan_image),
        )

        results = {}
        for model_key, message, load, generate in plan:
            if model_key not in self.models:
                self.models[model_key] = load()
            if self.models[model_key] is None:
                continue
            logger.info(message)
            results[model_key] = generate()
            clean_memory()
        return results

    @classmethod
    def _fallback_size(cls, width: int, height: int) -> "tuple[int, int] | None":
        """Smaller (width, height) to retry with after OOM, or None if already small.

        The longest side is scaled down to OOM_FALLBACK_SIZE and the aspect ratio
        is preserved. Each side is rounded down to a multiple of 8 (minimum 8),
        since SD-family pipelines expect sizes divisible by 8. Integer arithmetic
        keeps the result exact.
        """
        longest = max(width, height)
        limit = cls.OOM_FALLBACK_SIZE
        if longest <= limit:
            return None

        def shrink(side: int) -> int:
            return max(8, side * limit // longest // 8 * 8)

        return shrink(width), shrink(height)

    def _run_pipeline(self, model_name: str, prompt: str, width: int, height: int,
                      num_inference_steps: int, guidance_scale: float) -> "Image.Image":
        """Run one diffusion call for `model_name` and return its first image."""
        extra = {}
        if self._uses_offload(model_name):
            # Offloaded pipelines are driven with a CPU generator. A freshly
            # constructed torch.Generator starts from torch's fixed default seed,
            # so output in this mode repeats from run to run.
            extra["generator"] = torch.Generator(device="cpu")
        return self.models[model_name](
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            **extra
        ).images[0]

    def _generate_diffusion_image(self,
                                  model_name: str,
                                  prompt: str,
                                  width: int,
                                  height: int,
                                  num_inference_steps: int,
                                  guidance_scale: float) -> "Image.Image":
        """Generate an image with a diffusion model, retrying smaller on CUDA OOM.

        Returns:
            A PIL image, or None if generation failed (the error is logged).
        """
        try:
            return self._run_pipeline(model_name, prompt, width, height,
                                      num_inference_steps, guidance_scale)
        except torch.cuda.OutOfMemoryError:
            # Recover outside the except clause: while the exception is live, its
            # traceback keeps the failed attempt's frames (and their tensors)
            # referenced, so memory could not actually be freed here.
            pass
        except Exception as e:
            logger.error(f"Generation failed for {model_name}: {str(e)}")
            return None

        # Only reached after a CUDA out-of-memory error.
        fallback = self._fallback_size(width, height)
        if fallback is None:
            clean_memory()
            logger.error(f"Generation failed for {model_name}: out of memory at {width}x{height}")
            return None

        logger.warning(f"Out of memory during {model_name} generation. Retrying with lower resolution...")
        clean_memory()
        try:
            return self._run_pipeline(model_name, prompt, fallback[0], fallback[1],
                                      num_inference_steps, guidance_scale)
        except Exception as e:
            logger.error(f"Generation failed for {model_name}: {str(e)}")
            return None

    def _generate_gan_image(self, truncation_psi: float = 0.7) -> "Image.Image":
        """Sample one face from the cached StyleGAN2 generator.

        Args:
            truncation_psi: StyleGAN truncation strength passed to the generator.

        Returns:
            A PIL image, or None if generation failed (the error is logged).
        """
        try:
            model = self.models["stylegan2"]
            z = torch.randn([1, model.z_dim]).to(device)
            with torch.no_grad():
                img = model(z, None, truncation_psi=truncation_psi)
                # Map generator output to uint8 NHWC. This assumes roughly [-1, 1]
                # NCHW output, the scaling used by the official StyleGAN2-ADA scripts.
                img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
            return Image.fromarray(img[0].cpu().numpy())
        except Exception as e:
            logger.error(f"Generation failed for StyleGAN2: {str(e)}")
            return None

    def visualize_results(self, results: dict, prompt: str):
        """Show the generated images side by side and save each one as a PNG.

        Args:
            results: model key -> PIL image or None, as returned by generate_images().
            prompt: generation prompt, shown (shortened) as the figure title.

        Side effects:
            Writes ``<model_key>_output.png`` to the working directory for every
            image shown. Columns for missing models are left blank.
        """
        valid_results = {k: v for k, v in results.items() if v is not None}
        if not valid_results:
            logger.warning("No images generated. Skipping visualization.")
            return

        fig, axes = plt.subplots(1, len(self.DISPLAY_NAMES), figsize=(15, 5), squeeze=False)
        for ax, (model_key, model_name) in zip(axes[0], self.DISPLAY_NAMES.items()):
            ax.axis('off')
            image = valid_results.get(model_key)
            if image is None:
                continue
            ax.imshow(image)
            ax.set_title(model_name, fontsize=12)
            image.save(f"{model_key}_output.png")

        fig.suptitle(textwrap.shorten(prompt, width=100, placeholder="..."), y=1.05)
        fig.tight_layout()
        plt.show()

In [None]:
if __name__ == "__main__":
    try:
        # Low memory mode loads SDXL with CPU offload to reduce GPU memory use
        generator = FaceGenerator(low_memory_mode=True)

        prompt = "High-quality portrait of man photo with detailed facial features, professional photography, 8k"

        # Diffusion generations are retried at a lower resolution on CUDA OOM
        results = generator.generate_images(
            prompt=prompt,
            width=768,
            height=768,
            num_inference_steps=25,
            guidance_scale=7.5
        )

        # Show the images side by side (also saves <model>_output.png files)
        generator.visualize_results(results, prompt)

    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which logging only str(e) would discard.
        logger.exception(f"Fatal error: {str(e)}")
        logger.info("Please check your setup and try again")