# FLUX Krea API Server for Image Generation

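This notebook sets up a FastAPI server that exposes a FLUX model for image generation. By default it serves `black-forest-labs/FLUX.1-schnell`; change `MODEL_ID` to load a different FLUX checkpoint.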

## Setup Instructions:
1. Run all cells in order
2. The server cell ("Start the server with ngrok") will print a public URL (via ngrok)
3. Use this URL in your frontend application

In [None]:
# Install required packages
!pip install -q fastapi uvicorn pyngrok python-multipart pillow torch diffusers transformers accelerate

In [None]:
# Import required libraries
import os
import base64
import io
import time
import asyncio
from typing import Optional
from datetime import datetime

import torch
from PIL import Image
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from diffusers import FluxPipeline
import uvicorn
from pyngrok import ngrok
import nest_asyncio

# Allow nested event loops in Jupyter: notebook kernels already run an asyncio
# event loop, and nest_asyncio.apply() patches asyncio so event-loop based code
# (such as the uvicorn server started later) can be driven from notebook cells.
nest_asyncio.apply()

In [None]:
# Set up device and model.
# Prefer a CUDA GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Model configuration: Hugging Face repo id of the checkpoint to load.
# FLUX.1-schnell is the fast version of FLUX.
# For higher quality, use: "black-forest-labs/FLUX.1-dev" (requires auth)
MODEL_ID = "black-forest-labs/FLUX.1-schnell"

In [None]:
# Initialize the model
print("Loading FLUX model... This may take a few minutes on first run.")

# The diffusers FLUX examples load the weights in bfloat16; float16 is known to
# produce NaN / all-black images with FLUX. CPU inference stays in float32.
dtype = torch.bfloat16 if device == "cuda" else torch.float32

# The full FLUX pipeline (12B-parameter transformer + T5-XXL text encoder) does
# not fit in the VRAM of common Colab GPUs. Model CPU offload keeps weights in
# host RAM and moves each sub-model to the GPU only while it runs (needs
# `accelerate`). Set this to False on large GPUs for faster generation.
ENABLE_CPU_OFFLOAD = device == "cuda"

try:
    pipe = FluxPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)

    if ENABLE_CPU_OFFLOAD:
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to(device)

    # Attention slicing trades some speed for lower peak memory on components
    # that support it; for some model classes it may be a no-op.
    if hasattr(pipe, "enable_attention_slicing"):
        pipe.enable_attention_slicing()

    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Note: You may need to authenticate with Hugging Face for some models.")
    raise

In [None]:
# Create FastAPI app
app = FastAPI(title="FLUX Krea API", version="1.0.0")

# Configure CORS so a browser frontend served from another origin (e.g. the
# Vite app that reads VITE_API_ENDPOINT) can call this API.
# allow_credentials is False: no endpoint here relies on cookies or auth
# headers, and allowing credentials together with a wildcard origin would let
# any website send credentialed requests. In production, replace "*" with your
# frontend URL.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Request/Response models
class GenerateRequest(BaseModel):
    """Request body of POST /generate.

    The ``ge``/``le`` bounds below are enforced by request validation, so
    out-of-range values are rejected before the model is invoked.
    """

    prompt: str = Field(..., description="Text prompt for image generation")
    # FLUX.1-schnell is distilled for very few steps (1-4), so accept values
    # down to 1. The default and the upper bound are unchanged for existing clients.
    steps: int = Field(30, ge=1, le=50, description="Number of inference steps")
    cfg_scale: float = Field(4.0, ge=1.0, le=10.0, description="CFG guidance scale")
    seed: int = Field(-1, description="Random seed (-1 for random)")
    # NOTE(review): FLUX works on multiples of 16 px; other sizes are presumably
    # adjusted by the pipeline - confirm with the installed diffusers version.
    width: int = Field(1024, ge=512, le=2048, description="Image width")
    height: int = Field(1024, ge=512, le=2048, description="Image height")

class GenerateResponse(BaseModel):
    """Response body of POST /generate.

    On success ``image`` holds a PNG data URL; on failure ``error`` holds the
    exception message. ``duration`` is the elapsed time in seconds.
    """

    success: bool
    image: Optional[str] = Field(default=None)
    error: Optional[str] = Field(default=None)
    duration: Optional[float] = Field(default=None)

class HealthResponse(BaseModel):
    """Response body of GET /health."""

    # Status string; the /health handler in this notebook always sends "healthy".
    status: str
    # Hugging Face repo id of the served model (MODEL_ID).
    model: str
    # Torch device string the model was set up for: "cuda" or "cpu".
    device: str
    # API version string.
    version: str

In [None]:
# API endpoints
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report server status, the served model id and the torch device.

    This cell only runs after the model-loading cell succeeded, so reaching
    this handler implies the pipeline is loaded.
    """
    return HealthResponse(
        status="healthy",
        model=MODEL_ID,
        device=device,
        # Single source of truth: the version passed to FastAPI(...), so the
        # health payload cannot drift from the OpenAPI metadata.
        version=app.version,
    )

import threading

# Serializes access to the single shared `pipe`: diffusers pipelines keep
# mutable state (e.g. the scheduler's timesteps), so concurrent calls from
# several worker threads are not safe.
_PIPE_LOCK = threading.Lock()


def _render_png_base64(request: GenerateRequest) -> str:
    """Run the pipeline for *request* and return the image as base64-encoded PNG.

    Blocking (GPU/CPU bound): call it from a worker thread, never directly on
    the event loop.
    """
    # Seed -1 lets the pipeline sample its own noise; any other value makes the
    # result reproducible.
    generator = None
    if request.seed != -1:
        generator = torch.Generator(device=device).manual_seed(request.seed)

    # Grad mode is thread-local, so no_grad has to be entered in this thread.
    with _PIPE_LOCK, torch.no_grad():
        image = pipe(
            prompt=request.prompt,
            num_inference_steps=request.steps,
            guidance_scale=request.cfg_scale,
            generator=generator,
            width=request.width,
            height=request.height,
        ).images[0]

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


@app.post("/generate", response_model=GenerateResponse)
async def generate_image(request: GenerateRequest):
    """Generate an image from a text prompt.

    Failures are reported in-band (HTTP 200 with ``success=False`` and the
    exception message in ``error``) rather than as HTTP errors.
    """
    start_time = time.perf_counter()

    try:
        print(f"Generating image with prompt: {request.prompt[:50]}...")
        # Off-load the blocking inference so the event loop keeps serving other
        # requests (e.g. /health) while an image is being generated.
        img_base64 = await asyncio.to_thread(_render_png_base64, request)
    except Exception as e:
        # API boundary: convert any failure into the structured error response.
        print(f"Error generating image: {str(e)}")
        return GenerateResponse(
            success=False,
            error=str(e),
            duration=time.perf_counter() - start_time,
        )

    duration = time.perf_counter() - start_time
    print(f"Image generated successfully in {duration:.2f}s")

    return GenerateResponse(
        success=True,
        image=f"data:image/png;base64,{img_base64}",
        duration=duration,
    )

@app.get("/models")
async def get_models():
    """List the identifiers of the models this server can serve.

    Only the single model loaded at startup (``MODEL_ID``) is reported.
    """
    available_models = [MODEL_ID]
    return dict(models=available_models)

In [None]:
# Set your ngrok auth token before starting the tunnel. Current ngrok releases
# refuse to open tunnels for anonymous users, so in practice this is required.
# You can get a free auth token from https://dashboard.ngrok.com/signup
# Uncomment and add your token:
# ngrok.set_auth_token("YOUR_NGROK_AUTH_TOKEN")

In [None]:
# Start the server with ngrok
import threading

PORT = 7860

# Run uvicorn in a background thread so this cell returns and the notebook stays
# usable (e.g. for the test cell below).
config = uvicorn.Config(app, host="0.0.0.0", port=PORT, log_level="info")
server = uvicorn.Server(config)

thread = threading.Thread(target=server.run)
thread.start()

# Wait until uvicorn has bound the port, so requests sent right after this cell
# do not race the startup. A failed startup ends the thread, which exits the loop.
deadline = time.monotonic() + 60
while not server.started and thread.is_alive() and time.monotonic() < deadline:
    time.sleep(0.1)
if not server.started:
    raise RuntimeError("uvicorn did not start; see the log output above.")

# ngrok.connect() returns an NgrokTunnel object whose str() is a descriptive
# label ('NgrokTunnel: "<url>" -> "<addr>"'), not a bare URL, so use its
# .public_url attribute when building the links below.
public_url = ngrok.connect(PORT).public_url
print(f"\n🚀 API is running!")
print(f"\n📱 Public URL: {public_url}")
print(f"\n🔧 Use this URL in your frontend application's VITE_API_ENDPOINT")
print(f"\n📝 API Documentation: {public_url}/docs")
print(f"\n⚡ Health check: {public_url}/health")

# Interrupting the kernel only raises KeyboardInterrupt in the main thread; the
# server thread keeps running, so point users at uvicorn's shutdown flag instead.
print("\n✅ Server is running in the background. Keep this runtime alive to maintain the connection.")
print("\n⚠️  To stop the server, run: server.should_exit = True; ngrok.kill()")

## Testing the API

You can test the API using the cell below:

In [None]:
# Test the API (optional)
import base64

import requests
from IPython.display import Image as IPImage, display

# Use localhost for testing within Colab (bypasses the ngrok tunnel).
test_url = "http://localhost:7860"

# Test health endpoint. requests waits forever unless a timeout is given.
response = requests.get(f"{test_url}/health", timeout=10)
response.raise_for_status()
print("Health check:", response.json())

# Test image generation. Generation can take minutes on a small GPU, so use a
# generous (but finite) timeout.
test_prompt = "A serene mountain landscape at sunset, photorealistic"
response = requests.post(
    f"{test_url}/generate",
    json={
        "prompt": test_prompt,
        "steps": 20,
        "cfg_scale": 4.0,
        "seed": 42,
        "width": 1024,
        "height": 1024,
    },
    timeout=600,
)
# A validation error (HTTP 422) has no "success" key; fail loudly instead of
# raising a confusing KeyError below.
response.raise_for_status()

result = response.json()
if result["success"]:
    print(f"\n✅ Image generated successfully in {result['duration']:.2f}s")
    # "image" is a data URL ("data:image/png;base64,<payload>"); decode the payload.
    img_data = base64.b64decode(result["image"].split(",", 1)[1])
    display(IPImage(img_data))
else:
    print(f"\n❌ Error: {result['error']}")

## Important Notes:

1. **GPU Runtime**: For best performance, use GPU runtime in Colab (Runtime → Change runtime type → GPU)

2. **Model Options**:
   - `FLUX.1-schnell`: Fast generation (designed for 1-4 steps), good quality; the guidance scale has no effect on this model
   - `FLUX.1-dev`: Higher quality, requires more steps and HuggingFace authentication
   - `FLUX.1-pro`: Best quality, requires API access

3. **Memory Management**:
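   - FLUX.1 is large (a 12B-parameter transformer plus a T5-XXL text encoder): in 16-bit precision the full pipeline needs far more than the 16 GB of a Colab T4, so on such GPUs use `pipe.enable_model_cpu_offload()` instead of moving the whole pipeline to the GPU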
   - Free Colab might disconnect after some time
   - Consider Colab Pro for longer sessions

4. **ngrok Limitations**:
   - Free tier has request limits
   - URLs change on restart
   - An auth token is required; a URL that stays the same across restarts additionally needs a reserved (static) ngrok domain

5. **Security**:
   - The current CORS settings allow all origins
   - In production, restrict to your frontend domain only