# Text-to-Image AI Service in Google Colab

This notebook sets up and runs a text-to-image generation service using Stable Diffusion with TensorFlow and KerasCV.

## Requirements
- Enable GPU runtime: Runtime → Change runtime type → Hardware accelerator → GPU
- High-RAM runtime recommended for better performance

## Features
- 🎨 Stable Diffusion image generation
- 🚀 FastAPI REST API
- 🌐 Public URL via ngrok
- 📱 Interactive web interface

## Step 1: Check GPU and Setup Environment

In [None]:
# Check GPU availability
!nvidia-smi

# Check Python version
import sys
print(f"Python version: {sys.version}")

# Set up environment variables
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['COLAB_GPU'] = '1'

print("✅ Environment check complete")

## Step 2: Install Dependencies

In [None]:
# Install system dependencies
!apt-get update -qq
!apt-get install -y -qq libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1

print("✅ System dependencies installed")

In [None]:
# Install Python dependencies
# Quote each requirement so the shell does not treat ">=" as output redirection
!pip install -q "tensorflow>=2.15.0"
!pip install -q "keras-cv>=0.6.0"
!pip install -q "fastapi>=0.104.0"
!pip install -q "uvicorn[standard]>=0.24.0"
!pip install -q "python-multipart>=0.0.6"
!pip install -q "pillow>=10.0.0"
!pip install -q "pydantic>=2.5.0"
!pip install -q "pydantic-settings>=2.1.0"
!pip install -q "python-dotenv>=1.0.0"
!pip install -q "structlog>=23.2.0"
!pip install -q nest-asyncio
!pip install -q pyngrok

print("✅ Python dependencies installed")
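
After the install cell finishes, it can help to confirm which versions were actually resolved. The following is a minimal sketch using only the standard library; the helper name `installed_version` is hypothetical and not part of the original notebook.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Report the packages this notebook depends on most directly.
for name in ("tensorflow", "keras-cv", "fastapi", "uvicorn", "pyngrok"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

A package reported as `not installed` usually means the corresponding `pip install` line failed quietly because of the `-q` flag.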

## Step 3: Setup ngrok Authentication (REQUIRED)

**IMPORTANT**: ngrok now requires authentication even for free usage.

1. **Sign up for free**: Go to https://dashboard.ngrok.com/signup
2. **Get your auth token**: Visit https://dashboard.ngrok.com/get-started/your-authtoken  
3. **Copy the token**: It looks like `2abc123_def456ghi789jkl...`
4. **Replace the token below**: Change `YOUR_NGROK_AUTH_TOKEN_HERE` to your actual token

In [None]:
# Set up ngrok authentication (REQUIRED)
# Get your auth token from: https://dashboard.ngrok.com/get-started/your-authtoken

from pyngrok import ngrok, conf

# REPLACE THIS with your actual ngrok auth token:
NGROK_AUTH_TOKEN = "YOUR_NGROK_AUTH_TOKEN_HERE"

if NGROK_AUTH_TOKEN == "YOUR_NGROK_AUTH_TOKEN_HERE":
    print("❌ ERROR: Please set your ngrok auth token!")
    print("1. Go to: https://dashboard.ngrok.com/signup (free signup)")
    print("2. Get token: https://dashboard.ngrok.com/get-started/your-authtoken")
    print("3. Replace YOUR_NGROK_AUTH_TOKEN_HERE above with your token")
    raise ValueError("ngrok auth token required")
else:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)
    print("✅ ngrok authentication configured successfully!")

## Step 4: Create and Run the Text-to-Image Service

In [None]:
import os
import time
import asyncio
import nest_asyncio
from typing import Optional, List
import tensorflow as tf
import keras_cv
import numpy as np
from PIL import Image
import io
import base64
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn
from pyngrok import ngrok
import threading

# Enable nested asyncio for Colab
nest_asyncio.apply()

print("✅ Imports complete")

In [None]:
# Configure TensorFlow for Colab
def setup_tensorflow():
    """Configure TensorFlow for Colab environment."""
    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            tf.keras.mixed_precision.set_global_policy("mixed_float16")
            print(f"✅ Configured {len(gpus)} GPU(s) with mixed precision")
        except RuntimeError as e:
            print(f"⚠️ GPU setup error: {e}")
    else:
        print("⚠️ No GPU detected, using CPU (will be slow)")
    return len(gpus) > 0

gpu_available = setup_tensorflow()

In [None]:
# Define API models
class ImageRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=500, description="Text description of the image")
    num_steps: int = Field(25, ge=10, le=50, description="Number of diffusion steps (lower = faster)")
    guidance_scale: float = Field(7.5, ge=1.0, le=15.0, description="How closely to follow the prompt")
    seed: Optional[int] = Field(None, ge=0, description="Random seed for reproducible results")

class ImageResponse(BaseModel):
    image_base64: str
    prompt: str
    generation_time: float
    parameters: dict

print("✅ API models defined")
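
Clients can check a payload against the same bounds before sending it, which avoids a round trip that would end in a validation error. The sketch below mirrors the `ImageRequest` constraints in plain Python. It is an approximation for client-side use (Pydantic also coerces some types), and `check_generate_payload` is a hypothetical helper name.

```python
def check_generate_payload(payload: dict) -> list:
    """Return a list of problems found in a /generate payload (empty if none)."""
    errors = []

    prompt = payload.get("prompt")
    if not isinstance(prompt, str) or not (1 <= len(prompt) <= 500):
        errors.append("prompt must be a string of 1 to 500 characters")

    steps = payload.get("num_steps", 25)
    if isinstance(steps, bool) or not isinstance(steps, int) or not (10 <= steps <= 50):
        errors.append("num_steps must be an integer from 10 to 50")

    scale = payload.get("guidance_scale", 7.5)
    if isinstance(scale, bool) or not isinstance(scale, (int, float)) or not (1.0 <= scale <= 15.0):
        errors.append("guidance_scale must be a number from 1.0 to 15.0")

    seed = payload.get("seed")
    if seed is not None and (isinstance(seed, bool) or not isinstance(seed, int) or seed < 0):
        errors.append("seed must be a non-negative integer or omitted")

    return errors


print(check_generate_payload({"prompt": "a cute cat", "num_steps": 5}))
```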

In [None]:
# Create the Stable Diffusion model class
class ColabStableDiffusion:
    """Simplified Stable Diffusion model for Colab."""

    def __init__(self):
        self.model = None
        self.load_model()

    def load_model(self):
        """Load Stable Diffusion model."""
        print("🔄 Loading Stable Diffusion model (this may take a few minutes)...")

        try:
            self.model = keras_cv.models.StableDiffusion(
                img_width=512,
                img_height=512,
                jit_compile=False,  # Disable for Colab compatibility
            )
            print("✅ Stable Diffusion model loaded successfully!")
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            raise

    def generate_image(self, prompt: str, num_steps: int = 25,
                      guidance_scale: float = 7.5, seed: Optional[int] = None):
        """Generate image from text prompt."""
        if self.model is None:
            raise RuntimeError("Model not loaded")

        if seed is not None:
            tf.random.set_seed(seed)
            np.random.seed(seed)

        try:
            print(f"🎨 Generating image for: {prompt[:50]}...")
            start_time = time.time()

            # KerasCV's text_to_image names the guidance keyword
            # `unconditional_guidance_scale` and accepts `seed` directly.
            try:
                generated_images = self.model.text_to_image(
                    prompt=prompt,
                    batch_size=1,
                    num_steps=num_steps,
                    unconditional_guidance_scale=guidance_scale,
                    seed=seed,
                )
            except TypeError as e:
                # Only fall back when the installed version rejects a keyword
                if "unexpected keyword argument" not in str(e):
                    raise
                print("⚠️ Installed KerasCV rejected a keyword; using basic parameters")
                generated_images = self.model.text_to_image(
                    prompt=prompt,
                    batch_size=1,
                    num_steps=num_steps,
                )

            # Convert to PIL Image. KerasCV's text_to_image already returns
            # uint8 pixels in [0, 255], so the array is not rescaled here.
            img_array = np.asarray(generated_images[0])
            img_array = np.clip(img_array, 0, 255).astype(np.uint8)
            pil_image = Image.fromarray(img_array)

            generation_time = time.time() - start_time
            print(f"✅ Image generated in {generation_time:.1f} seconds")

            return pil_image

        except Exception as e:
            print(f"❌ Generation failed: {e}")
            raise

# Initialize the model (this will take a few minutes)
print("Initializing Stable Diffusion model...")
sd_model = ColabStableDiffusion()
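
Keyword names can differ between versions of a library, as with the guidance-scale argument above. One way to find out which keyword a callable accepts without calling it is `inspect.signature`. The sketch below is illustrative only: `accepts_keyword` and the two stand-in functions are hypothetical and do not reflect any specific KerasCV release.

```python
import inspect


def accepts_keyword(func, name: str) -> bool:
    """Return True if `func` can be called with `name=` as a keyword argument."""
    params = inspect.signature(func).parameters
    if name in params:
        return params[name].kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    # A **kwargs parameter accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Stand-in functions used only to demonstrate the check.
def variant_a(prompt, num_steps=50, unconditional_guidance_scale=7.5):
    return "a"


def variant_b(prompt, num_steps=50, guidance_scale=7.5):
    return "b"


for fn in (variant_a, variant_b):
    print(fn.__name__, accepts_keyword(fn, "unconditional_guidance_scale"))
```

Checking the signature up front avoids masking an unrelated `TypeError` raised inside the call itself.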

In [None]:
# Create FastAPI application
app = FastAPI(
    title="Text-to-Image AI (Colab)",
    description="Stable Diffusion text-to-image generation running in Google Colab",
    version="1.0.0"
)

@app.post("/generate", response_model=ImageResponse)
async def generate_image_endpoint(request: ImageRequest):
    """Generate image from text prompt."""
    start_time = time.time()

    try:
        # Run generation in thread pool to avoid blocking
        loop = asyncio.get_running_loop()
        pil_image = await loop.run_in_executor(
            None,
            sd_model.generate_image,
            request.prompt,
            request.num_steps,
            request.guidance_scale,
            request.seed,
        )

        # Convert to base64
        buffer = io.BytesIO()
        pil_image.save(buffer, format="PNG")
        img_base64 = base64.b64encode(buffer.getvalue()).decode()

        generation_time = time.time() - start_time

        return ImageResponse(
            image_base64=img_base64,
            prompt=request.prompt,
            generation_time=generation_time,
            parameters={
                "num_steps": request.num_steps,
                "guidance_scale": request.guidance_scale,
                "seed": request.seed
            }
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "model_loaded": sd_model.model is not None,
        "gpu_available": gpu_available,
        "environment": "Google Colab"
    }

@app.get("/")
async def root():
    """Root endpoint with usage instructions."""
    return {
        "message": "Text-to-Image AI Service running in Google Colab",
        "docs_url": "/docs",
        "health_url": "/health",
        "generate_url": "/generate",
        "gpu_available": gpu_available
    }

print("✅ FastAPI application created")
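
The `/generate` endpoint returns the image as a base64 string in the `image_base64` field. A client has to decode it before it can be viewed or saved. The following is a minimal sketch of that step using only the standard library; `save_generated_image` is a hypothetical helper name.

```python
import base64
from pathlib import Path

# Every PNG file starts with this fixed 8-byte signature.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def save_generated_image(result: dict, path: str) -> int:
    """Decode `result["image_base64"]`, write it to `path`, and return the byte count."""
    data = base64.b64decode(result["image_base64"])
    if not data.startswith(PNG_SIGNATURE):
        raise ValueError("decoded data is not a PNG image")
    return Path(path).write_bytes(data)


# Intended use with a parsed /generate response (not executed here):
#   save_generated_image(response.json(), "cat.png")
```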

## Step 5: Start the Server with Public URL

In [None]:
# Start the server with ngrok tunnel
import threading
import time
import asyncio
from uvicorn import Config, Server

def start_server():
    """Start the FastAPI server with proper async handling."""
    print("🔄 Starting FastAPI server...")
    try:
        # Create a new event loop for this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        
        # Configure and start the server
        config = Config(app, host="0.0.0.0", port=8000, log_level="info")
        server = Server(config)
        
        # Run the server
        loop.run_until_complete(server.serve())
    except Exception as e:
        print(f"❌ Server startup failed: {e}")

# Start the server in a separate thread BEFORE creating ngrok tunnel
print("🚀 Starting the FastAPI server...")
server_thread = threading.Thread(target=start_server)
server_thread.daemon = True
server_thread.start()

# Wait for server to start
print("⏳ Waiting for server to initialize...")
time.sleep(15)  # Increased wait time

# Test if server is running locally
import requests
server_ready = False
for attempt in range(5):  # Try 5 times
    try:
        test_response = requests.get("http://localhost:8000/health", timeout=5)
        if test_response.status_code == 200:
            print("✅ Server is running locally!")
            server_ready = True
            break
        else:
            print(f"⚠️ Server responded with status: {test_response.status_code}")
    except Exception as e:
        print(f"🔄 Attempt {attempt + 1}/5: Server not ready yet...")
        time.sleep(5)

if not server_ready:
    print("❌ Server failed to start properly. Please restart runtime and try again.")
else:
    # Now start ngrok tunnel (uses the auth token set earlier)
    print("🌐 Creating ngrok tunnel...")
    try:
        public_tunnel = ngrok.connect(8000)
        # str(public_tunnel) is a description of the tunnel, not a usable URL
        public_url = public_tunnel.public_url
        
        print("\n" + "="*60)
        print("🚀 TEXT-TO-IMAGE AI SERVICE IS LIVE!")
        print("="*60)
        print(f"🌐 Public URL: {public_url}")
        print(f"📖 API Documentation: {public_url}/docs")
        print(f"❤️ Health Check: {public_url}/health")
        print(f"🎨 Generate Images: {public_url}/generate")
        print("="*60)
        print("\n📱 How to use:")
        print("1. Click the API Documentation link above")
        print("2. Try the /generate endpoint with a text prompt")
        print("3. Or use curl/Python requests to call the API")
        print("\n⚠️ Note: Keep the notebook runtime connected to maintain the service")
        print("="*60)
        
    except Exception as e:
        print(f"❌ Failed to create ngrok tunnel: {e}")

if server_ready:
    print("\n✅ Setup complete! The service should now be accessible via the public URL above.")
else:
    print("\n❌ Setup failed. Please restart runtime and try again.")
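
The cell above waits a fixed 15 seconds and then polls `/health`. A possible alternative is to poll from the start and stop as soon as the server answers. The sketch below uses only the standard library; `probe_health` and `wait_until_ready` are hypothetical helper names, and the attempt count and delay are arbitrary defaults.

```python
import time
import urllib.error
import urllib.request


def probe_health(url: str, timeout: float = 5.0) -> bool:
    """Return True if a GET request to `url` answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        return False


def wait_until_ready(probe, attempts: int = 10, delay: float = 1.0, sleep=time.sleep) -> bool:
    """Call `probe()` up to `attempts` times, sleeping `delay` seconds between failures."""
    for attempt in range(attempts):
        if probe():
            return True
        if attempt < attempts - 1:
            sleep(delay)
    return False


# Intended use after starting the server thread (not executed here):
#   server_ready = wait_until_ready(lambda: probe_health("http://localhost:8000/health"))
```

Passing the probe as a callable keeps the retry loop independent of the HTTP library, so the same loop works with `requests` as well.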

## Step 6: Troubleshooting & Server Check

If you're getting ngrok connection errors, run this cell to diagnose the issue:

In [None]:
# Troubleshooting: Check if server is running properly
import requests
import time
import subprocess
import threading
import asyncio
from uvicorn import Config, Server

print("🔍 Diagnosing server status...")

# Check if port 8000 is in use
try:
    result = subprocess.run(['netstat', '-tuln'], capture_output=True, text=True)
    if ':8000' in result.stdout:
        print("✅ Port 8000 is in use (server might be running)")
    else:
        print("❌ Port 8000 is not in use (server not running)")
except Exception:
    print("⚠️ Could not check port status")

# Test local server connection
print("\n🔍 Testing local server connection...")
try:
    response = requests.get("http://localhost:8000/health", timeout=10)
    print(f"✅ Local server is responding! Status: {response.status_code}")
    print(f"Response: {response.json()}")
except requests.exceptions.ConnectionError:
    print("❌ Cannot connect to local server on port 8000")
    print("🔄 Trying to restart the server...")
    
    # Try to restart server with proper async handling
    def restart_server():
        try:
            # Create a new event loop for this thread
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            
            # Configure and start the server
            config = Config(app, host="0.0.0.0", port=8000, log_level="info")
            server = Server(config)
            
            # Run the server
            loop.run_until_complete(server.serve())
        except Exception as e:
            print(f"❌ Server restart failed: {e}")
    
    server_thread = threading.Thread(target=restart_server)
    server_thread.daemon = True
    server_thread.start()
    
    print("⏳ Waiting 20 seconds for server restart...")
    time.sleep(20)
    
    # Test again
    for attempt in range(3):
        try:
            response = requests.get("http://localhost:8000/health", timeout=5)
            print(f"✅ Server restarted successfully! Status: {response.status_code}")
            break
        except Exception:
            print(f"🔄 Restart attempt {attempt + 1}/3...")
            time.sleep(5)
    else:
        print("❌ Server restart failed")

except Exception as e:
    print(f"❌ Error testing server: {e}")

# Test ngrok tunnel if it exists
if 'public_url' in globals():
    print(f"\n🔍 Testing ngrok tunnel: {public_url}")
    try:
        response = requests.get(f"{public_url}/health", timeout=30)
        print(f"✅ Ngrok tunnel is working! Status: {response.status_code}")
    except Exception as e:
        print(f"❌ Ngrok tunnel test failed: {e}")
        print("💡 Try running the server setup cell again")
else:
    print("\n⚠️ No ngrok tunnel found. Run the server setup cell first.")

print("\n" + "="*50)
print("💡 TROUBLESHOOTING TIPS:")
print("="*50)
print("1. If local server fails: Restart runtime and run all cells again")
print("2. If ngrok fails: Check your auth token is set correctly")  
print("3. If still failing: Try running cells one by one with delays")
print("4. For memory issues: Enable High-RAM runtime")
print("5. For async errors: Make sure nest_asyncio is installed")
print("="*50)
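
The port check above depends on the `netstat` binary, which may not be present in every runtime image. A sketch of an alternative that needs only the standard library is to attempt a TCP connection to the port; `is_port_open` is a hypothetical helper name.

```python
import socket


def is_port_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if something accepts TCP connections on (host, port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success and an error code otherwise
        return sock.connect_ex((host, port)) == 0


print("Port 8000 open:", is_port_open("127.0.0.1", 8000))
```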

## Step 7: Test the API (Optional)

In [None]:
# Test the API directly from the notebook
import base64
import time
import requests
from IPython.display import Image as IPImage, display

def test_generation(prompt, num_steps=25):
    """Test image generation directly."""
    print(f"Testing generation with prompt: {prompt}")
    
    # Check if we have the public URL
    if 'public_url' not in globals():
        print("❌ Error: No public URL found. Run the server setup cell first.")
        return None
    
    # Ensure we have a proper URL string
    test_url = public_url if isinstance(public_url, str) else str(public_url)
    if not test_url.startswith(('http://', 'https://')):
        print("❌ Error: Invalid URL format")
        return None

    try:
        print(f"🌐 Making request to: {test_url}")
        
        # Make API request
        response = requests.post(
            f"{test_url}/generate",
            json={
                "prompt": prompt,
                "num_steps": num_steps,
                "guidance_scale": 7.5
            },
            timeout=300  # 5 minute timeout for generation
        )

        if response.status_code == 200:
            result = response.json()
            print(f"✅ Generation completed in {result['generation_time']:.1f} seconds")

            # Decode and display image
            img_data = base64.b64decode(result['image_base64'])
            return IPImage(img_data)
        else:
            print(f"❌ Error: {response.status_code} - {response.text}")
            return None
            
    except requests.exceptions.RequestException as e:
        print(f"❌ Request failed: {e}")
        return None
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        return None

# Wait a moment for everything to be ready
print("⏳ Waiting for services to be ready...")
time.sleep(5)

# Test with a simple prompt
print("🎨 Testing image generation...")
test_image = test_generation("a cute cat sitting in a garden", num_steps=20)
if test_image:
    display(test_image)
    print("🎉 Success! Your text-to-image service is working!")
else:
    print("🔄 If the test failed, try running the troubleshooting cell above.")
    print("💡 Common solutions:")
    print("   - Wait a bit longer (model might still be loading)")
    print("   - Run the troubleshooting cell to check server status")
    print("   - Restart runtime and run all cells again")
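
The same request can be made from a machine outside Colab that does not have `requests` installed. The sketch below builds the POST with `urllib` only. `build_generate_request` is a hypothetical helper, and `https://example.invalid` is a placeholder for the public URL printed by the server cell.

```python
import json
import urllib.request


def build_generate_request(base_url, prompt, num_steps=25, guidance_scale=7.5, seed=None):
    """Build (but do not send) a POST request for the /generate endpoint."""
    payload = {"prompt": prompt, "num_steps": num_steps, "guidance_scale": guidance_scale}
    if seed is not None:
        payload["seed"] = seed
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_generate_request("https://example.invalid/", "a cute cat sitting in a garden", num_steps=20)
print(req.get_method(), req.full_url)

# Sending it (not executed here); generation can take minutes, so use a long timeout:
#   with urllib.request.urlopen(req, timeout=300) as resp:
#       result = json.load(resp)
```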

## Step 8: Keep the Service Running