# 🎨 FLUX Image Generator - Fixed Version

This notebook runs a FLUX model API server that works with George's Dream Factory frontend.

**Important**: Make sure to use a GPU runtime (Runtime → Change runtime type → T4 GPU).

In [None]:
# Install all required dependencies (IPython `!` shell commands). Versions are
# unpinned, so each run installs whatever releases are current at that time.
!pip install torch diffusers transformers accelerate fastapi uvicorn pyngrok pillow -q
# Upgrade pydantic to its newest available release.
!pip install pydantic --upgrade -q

# NOTE(review): `!` commands do not stop the cell when they fail, so this message
# prints even if an install above failed. Also, the ngrok cell imports
# `nest_asyncio`, which is not installed here (presumably preinstalled on
# Colab; verify on other hosts).
print("✅ Dependencies installed successfully!")

In [None]:
# Import all required libraries and create the complete API server
import asyncio
import base64
import io
import logging
import os
import time
from datetime import datetime
from threading import Lock, Thread
from typing import Optional

import torch
import uvicorn
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel, Field, root_validator, validator
from pyngrok import ngrok

# Root logging: INFO and above, each record prefixed with time, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hugging Face repository of the pipeline this server loads (the fast "schnell" variant).
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

# Shared state assigned by initialize_model(); both remain None until it has run.
pipe = device = None

# Request/Response models with all fixes
class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``cfg_scale`` and ``cfg_guidance`` are two names for the same guidance value.
    Clients may send either one, both, or neither. After validation both fields
    are always populated (see ``sync_cfg_aliases``).
    """

    prompt: str = Field(..., min_length=1, max_length=1000, description="Text prompt for image generation")
    # FLUX.1-schnell is distilled for very few steps (its model card uses 4), so
    # the lower bound admits small step counts. The default of 30 is kept so that
    # requests which omit `steps` behave as before.
    steps: int = Field(30, ge=1, le=50, description="Number of inference steps")
    cfg_scale: Optional[float] = Field(None, ge=1.0, le=10.0, description="CFG guidance scale (alias for cfg_guidance)")
    cfg_guidance: Optional[float] = Field(None, ge=1.0, le=10.0, description="CFG guidance scale")
    seed: int = Field(-1, ge=-1, le=999999999, description="Random seed (-1 for random)")
    width: int = Field(1024, ge=512, le=2048, description="Image width")
    height: int = Field(1024, ge=512, le=2048, description="Image height")

    @root_validator(pre=True)
    def sync_cfg_aliases(cls, values):
        """Make ``cfg_scale`` and ``cfg_guidance`` mirror each other.

        This runs on the raw input before field validation. It replaces an
        ``__init__`` override: under pydantic v2, FastAPI validates request
        bodies without calling ``__init__``, so the aliasing has to live in a
        validator to apply to real requests. An alias sent as ``null`` is
        treated like an omitted one. When both are missing, both default to 4.0.
        """
        if not isinstance(values, dict):
            # e.g. an already-constructed model instance: nothing to normalize.
            return values
        values = dict(values)  # never mutate the caller's mapping
        scale = values.get('cfg_scale')
        guidance = values.get('cfg_guidance')
        if scale is None and guidance is None:
            scale = guidance = 4.0
        elif guidance is None:
            guidance = scale
        elif scale is None:
            scale = guidance
        values['cfg_scale'] = scale
        values['cfg_guidance'] = guidance
        return values

    @validator('prompt')
    def validate_prompt(cls, v):
        """Reject whitespace-only prompts and return the prompt stripped.

        Only emptiness is checked; there is no content filtering here.
        """
        if not v or not v.strip():
            raise ValueError("Prompt cannot be empty")
        return v.strip()

    @validator('width', 'height')
    def validate_dimensions(cls, v):
        """Snap a dimension to the nearest multiple of 16, rounding ties up.

        NOTE(review): recent diffusers ``FluxPipeline`` versions work in
        multiples of 16 (``vae_scale_factor * 2``) and otherwise resize the
        output with a warning; confirm against the installed version. Snapping
        here keeps the generated image the size the request describes.

        The Field ge/le bounds (512 to 2048, both multiples of 16) are checked
        before this validator runs. The clamp is purely defensive.
        """
        snapped = (v + 8) // 16 * 16
        return max(512, min(2048, snapped))

class GenerateResponse(BaseModel):
    """Response body for ``POST /generate``.

    Generation failures are reported in-band (``success=False`` plus ``error``).
    The endpoint still raises HTTP 503 while the model is not loaded.
    """

    success: bool  # True when an image was produced
    image: Optional[str] = None  # on success, a "data:image/png;base64,..." URL
    error: Optional[str] = None  # human-readable failure reason when success is False
    duration: Optional[float] = None  # seconds spent handling the generation request

class HealthResponse(BaseModel):
    """Response body for ``GET /health``."""

    status: str  # "healthy" once the pipeline is loaded, otherwise "unhealthy"
    model: str  # model repository id being served (MODEL_ID)
    device: str  # str() of the torch device name; "None" before the model is initialized
    version: str  # API version string

# Initialize model
def initialize_model(offload_below_gb: float = 40.0):
    """Load the FLUX pipeline into the module-level ``pipe`` and ``device`` globals.

    Args:
        offload_below_gb: On CUDA, a GPU whose total memory (GiB) is below this
            uses ``enable_model_cpu_offload()`` instead of moving the whole
            pipeline onto the GPU. NOTE(review): the default assumes the full
            FLUX pipeline in 16-bit does not fit on smaller cards such as a
            16 GB T4; tune it for your hardware.

    Raises:
        Exception: any error while downloading, loading or placing the model is
            logged and re-raised. ``pipe`` is left unchanged (None on first load)
            because other code treats ``pipe is not None`` as "ready".
    """
    global pipe, device

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    total_memory_gb = None
    if device == "cuda":
        total_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU Memory: {total_memory_gb:.1f} GB")

    logger.info("Loading FLUX model... This may take a few minutes on first run.")

    try:
        # Build into a local; publish to the global only once fully configured.
        loaded = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        )

        # Enable memory efficient attention if available
        if hasattr(loaded, "enable_attention_slicing"):
            loaded.enable_attention_slicing()

        if total_memory_gb is not None and total_memory_gb < offload_below_gb:
            # Offload manages device placement itself. Moving the pipeline to
            # the GPU first would load every component onto the card at once,
            # which is exactly what offloading is meant to avoid.
            logger.warning("Low GPU memory detected, enabling CPU offload...")
            loaded.enable_model_cpu_offload()
        else:
            loaded = loaded.to(device)

        pipe = loaded
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        logger.error("Note: You may need to authenticate with Hugging Face for some models.")
        raise

# Create FastAPI app.
# NOTE(review): the title says "FLUX Krea" while MODEL_ID serves FLUX.1-schnell,
# and GET / returns the same "FLUX Krea API Server" wording. Check whether any
# client depends on these strings before renaming them.
app = FastAPI(title="FLUX Krea API", version="1.0.0")

# Middleware for request logging
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every HTTP request when it arrives and again when it finishes.

    The completion line reports elapsed time and status code. If a downstream
    handler raises instead of returning a response, the failure is logged with
    its traceback and the exception is re-raised unchanged.
    """
    # perf_counter is monotonic, so a system clock adjustment mid-request cannot
    # produce a negative or inflated duration (time.time() can).
    start_time = time.perf_counter()
    logger.info(f"Incoming {request.method} request to {request.url.path}")

    try:
        response = await call_next(request)
    except Exception:
        elapsed = time.perf_counter() - start_time
        logger.exception(f"Request to {request.url.path} failed after {elapsed:.2f}s")
        raise

    duration = time.perf_counter() - start_time
    logger.info(f"Request to {request.url.path} completed in {duration:.2f}s with status {response.status_code}")
    return response

# Configure CORS: any origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let any
# site make credentialed requests. Narrow allow_origins to the frontend URL for
# anything beyond local experimentation.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

@app.on_event("startup")
async def startup_event():
    """Start loading the model in a background thread so startup is not blocked.

    Loading takes minutes. Running it inline here would keep the server from
    accepting connections until it finished, so /status, /health and the 503
    branch of /generate could never report the "loading" state they are written
    for. A daemon thread lets the server come up immediately; endpoints observe
    readiness through the global ``pipe``.

    NOTE(review): this relies on initialize_model assigning the global ``pipe``
    only once the pipeline is fully placed on its device. Confirm before
    depending on it.
    """

    def _load_in_background():
        try:
            initialize_model()
        except Exception:
            # initialize_model has already logged the cause. Keep the server
            # alive (it keeps reporting "loading") rather than let the thread
            # die with an unhandled exception.
            logger.exception("Background model loading failed; the server will keep reporting 'loading'.")

    Thread(target=_load_in_background, name="flux-model-loader", daemon=True).start()

@app.get("/")
async def root():
    """Identify the service and list the paths of its public endpoints."""
    endpoints = dict(
        health="/health",
        status="/status",
        generate="/generate",
        models="/models",
        docs="/docs",
    )
    return {"message": "FLUX Krea API Server", "endpoints": endpoints}

@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report "healthy" once the pipeline object exists, otherwise "unhealthy"."""
    is_loaded = pipe is not None
    return HealthResponse(
        version="1.0.0",
        device=str(device),
        model=MODEL_ID,
        status="healthy" if is_loaded else "unhealthy",
    )

@app.get("/status")
async def status_check():
    """Summarize model readiness for clients that poll while the server starts."""
    ready = pipe is not None
    if ready:
        state, message = "ok", "Model is ready"
    else:
        state, message = "loading", "Model is loading, please wait..."
    return {
        "status": state,
        "model_loaded": ready,
        "models_loaded": ready,  # duplicate flag kept for older clients
        "models_loading": not ready,
        "message": message,
        "device": str(device) if device else "unknown",
        "model": MODEL_ID,
    }

# Serializes use of the shared pipeline. Each request runs `pipe(...)` on the
# event loop's default thread pool. Without this lock, concurrent requests
# would drive the same pipeline object (and its scheduler state) at the same
# time and compete for GPU memory.
_generation_lock = Lock()


def _run_pipeline(request: GenerateRequest, guidance_scale, generator):
    """Run the FLUX pipeline for *request* and return the first PIL image.

    This call is blocking and is meant to run in a worker thread. PyTorch's
    grad mode is thread-local, so ``torch.no_grad()`` is entered here, in the
    thread doing the work, not in the coroutine that schedules it.
    """
    with _generation_lock, torch.no_grad():
        result = pipe(
            prompt=request.prompt,
            num_inference_steps=request.steps,
            guidance_scale=guidance_scale,
            generator=generator,
            width=request.width,
            height=request.height,
        )
    return result.images[0]


@app.post("/generate", response_model=GenerateResponse)
async def generate_image(request: GenerateRequest):
    """Generate an image from a text prompt.

    Returns:
        GenerateResponse with ``image`` set to a PNG data URL on success.
        Generation failures are reported in-band with ``success=False`` and a
        human-readable ``error``.

    Raises:
        HTTPException: 503 while the model has not finished loading.
    """
    if pipe is None:
        raise HTTPException(
            status_code=503,
            detail="Model is still loading. Please wait a moment and try again."
        )

    start_time = time.perf_counter()

    # cfg_guidance is normally populated by GenerateRequest. Fall back to the
    # cfg_scale alias, then to the request model's default of 4.0, so that a
    # missing value never reaches the pipeline as None.
    guidance_scale = request.cfg_guidance
    if guidance_scale is None:
        guidance_scale = request.cfg_scale if request.cfg_scale is not None else 4.0

    try:
        # Seed a generator only when the client asked for a reproducible result.
        generator = None
        if request.seed != -1:
            generator = torch.Generator(device=device).manual_seed(request.seed)

        logger.info("Starting image generation")
        logger.info(f"Prompt: {request.prompt[:100]}...")
        logger.info(f"Parameters: steps={request.steps}, cfg_guidance={guidance_scale}, size={request.width}x{request.height}, seed={request.seed}")

        # The pipeline call blocks for a long time. Run it off the event loop so
        # other requests (e.g. /status polling) are still served meanwhile.
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(
            None, _run_pipeline, request, guidance_scale, generator
        )

        # Convert to base64
        buffered = io.BytesIO()
        image.save(buffered, format="PNG", optimize=True)
        img_base64 = base64.b64encode(buffered.getvalue()).decode()

        duration = time.perf_counter() - start_time
        logger.info(f"Image generated successfully in {duration:.2f}s")
        logger.info(f"Image size: {len(img_base64)} bytes (base64)")

        return GenerateResponse(
            success=True,
            image=f"data:image/png;base64,{img_base64}",
            duration=duration
        )

    except torch.cuda.OutOfMemoryError:
        logger.error("GPU out of memory error")
        # Return cached blocks left by the failed run to the allocator so the
        # next request starts with as much free memory as possible.
        torch.cuda.empty_cache()
        return GenerateResponse(
            success=False,
            error="GPU out of memory. Try reducing image size or restarting the backend.",
            duration=time.perf_counter() - start_time
        )
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Error generating image: {type(e).__name__}: {str(e)}")
        error_msg = str(e)

        # Map common failure text to more actionable messages.
        if "CUDA" in error_msg:
            error_msg = "GPU error. Please check CUDA availability and try again."
        elif "dimension" in error_msg.lower() or "size" in error_msg.lower():
            error_msg = f"Invalid image dimensions. Please use sizes between 512 and 2048. Error: {error_msg}"
        elif "prompt" in error_msg.lower():
            error_msg = "Invalid prompt. Please check your input text."

        return GenerateResponse(
            success=False,
            error=error_msg,
            duration=time.perf_counter() - start_time
        )

@app.get("/models")
async def get_models():
    """List the model identifiers served by this instance (a single entry)."""
    available = [MODEL_ID]
    return {"models": available}

# Banner summarizing the changes in this build of the server, one item per line.
print(
    "\n🎨 FLUX Image Generator API - Fixed Version",
    "=" * 50,
    "✅ All fixes applied:",
    "   - Added /status endpoint",
    "   - Fixed parameter mapping (cfg_scale/cfg_guidance)",
    "   - Enhanced error handling",
    "   - Added request/response logging",
    "   - Improved input validation",
    "=" * 50,
    sep="\n",
)


In [None]:
# Setup ngrok for public access
from pyngrok import ngrok
import nest_asyncio

# Allow nested event loops in Jupyter: uvicorn.run() in the next cell starts an
# event loop while the notebook kernel's own loop is already running.
nest_asyncio.apply()

# Kill any existing ngrok processes (e.g. left over from a previous run of this cell)
ngrok.kill()

# Start an ngrok tunnel to the port the server cell binds (7860).
# ngrok.connect() returns an NgrokTunnel object whose str() is a description
# such as 'NgrokTunnel: "https://..." -> "http://localhost:7860"', not a bare
# URL. Use .public_url so the printed address and the /docs link are valid.
# NOTE(review): current ngrok versions require an account authtoken; call
# ngrok.set_auth_token(...) before connecting if the tunnel fails to start.
public_url = ngrok.connect(7860).public_url
print(f"\n🚀 Public URL: {public_url}")
print("\n📋 Copy this URL to your frontend settings!")
print(f"\n📝 API Documentation: {public_url}/docs")
print("\n" + "=" * 50)

In [None]:
# Run the server: print usage hints, then serve on the port the ngrok tunnel targets.
print(
    "\n🌐 Starting FLUX API Server...",
    "\n⏳ First image generation will take 2-3 minutes while model loads.",
    "\n✨ Server is running! Use the ngrok URL above in your frontend.",
    "\nPress 'Stop' button to shutdown the server.",
    sep="\n",
)

# Blocks this cell until it is stopped (the 'Stop' button mentioned above).
uvicorn.run(app, host="0.0.0.0", port=7860)

## 📌 How to Use:

1. **Run all cells** in order (Runtime → Run all)
2. **Copy the ngrok URL** from the output (an `https://` address on an ngrok domain, e.g. `https://xxxx-xxxx.ngrok-free.app` or `https://xxxx-xxxx.ngrok.io`)
3. **Open your frontend** at http://localhost:5173
4. **Go to Settings** and paste the ngrok URL
5. **Click Test Connection** - it should show "Dreams Ready!"
6. **Start generating images!**

## 🛠️ Troubleshooting:

- **Connection refused**: Make sure all cells have run successfully
- **Model loading**: The model is loaded when the server starts, which can take several minutes; wait until loading has finished before generating images
- **GPU memory**: If you get memory errors, restart runtime and try smaller images
- **Ngrok expired**: Re-run the ngrok cell to get a new URL