# Stable Diffusion XL on Colab

This notebook runs SDXL image generation on Google Colab using either **TPU** (JAX/Flax) or **GPU** (PyTorch).

## Instructions

1. **Choose your runtime**: Go to `Runtime` > `Change runtime type`
   - **TPU**: Generates one image per TPU core in parallel, e.g. 8 images on an 8-core TPU (JAX/Flax)
   - **GPU (T4)**: Generates images sequentially, one at a time (PyTorch, fp16)
2. **Configure storage** (optional): Toggle `use_google_drive` in the next cell
   - **On**: Caches the model (~6.5GB) on Google Drive, so later sessions skip the download
   - **Off**: Downloads the model every session; no Drive access needed
3. **Run all cells**: Click `Runtime` > `Run all`
4. **Generate images**: Use the Gradio interface that appears

In [None]:
# @title Configuration { display-mode: "form" }
# @markdown ### Storage Settings
# @markdown Toggle Google Drive for model caching and image saving:

use_google_drive = True  # @param {type:"boolean"}

# @markdown ---
# @markdown **With Google Drive:**
# @markdown - Model cached for faster future runs
# @markdown - Can save generated images to Drive
# @markdown 
# @markdown **Without Google Drive:**
# @markdown - Model downloaded each session
# @markdown - No authorization required

# Resolves where the rest of the notebook reads and writes files:
#   USE_GOOGLE_DRIVE - copy of the form toggle above
#   CACHE_DIR        - passed as `cache_dir=` to from_pretrained in the load cell
#   OUTPUT_DIR       - where the UI saves images when "Save images" is ticked
import os

USE_GOOGLE_DRIVE = use_google_drive

# Folder name huggingface_hub uses for this repo inside a cache directory.
MODEL_REPO_FOLDER = 'models--stabilityai--stable-diffusion-xl-base-1.0'

if USE_GOOGLE_DRIVE:
    from google.colab import drive
    print("Mounting Google Drive...")
    drive.mount('/content/drive')

    # Set up cache directory on Google Drive
    CACHE_DIR = '/content/drive/MyDrive/.cache/huggingface'
    OUTPUT_DIR = '/content/drive/MyDrive/sdxl_outputs'
    os.makedirs(CACHE_DIR, exist_ok=True)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Also point the default Hugging Face caches at Drive, for any download
    # that is not given cache_dir explicitly.
    os.environ['HF_HOME'] = CACHE_DIR
    os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
    os.environ['HF_DATASETS_CACHE'] = CACHE_DIR

    print(f"\nCache directory: {CACHE_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")

    # The load cell passes cache_dir=CACHE_DIR, which stores the repo directly
    # inside CACHE_DIR. The HF_HOME default layout would put it in CACHE_DIR/hub.
    # Checking only hub/ missed models downloaded by the load cell, so check both.
    cache_candidates = [
        os.path.join(CACHE_DIR, MODEL_REPO_FOLDER),
        os.path.join(CACHE_DIR, 'hub', MODEL_REPO_FOLDER),
    ]
    model_cache_path = next(
        (path for path in cache_candidates if os.path.exists(path)),
        cache_candidates[0],
    )
    if os.path.exists(model_cache_path):
        print("\nSDXL model found in cache - loading will be fast!")
    else:
        print("\nSDXL model not cached yet - first run will download ~6.5GB")
else:
    # Use default local cache (lost when session ends)
    CACHE_DIR = '/root/.cache/huggingface'
    OUTPUT_DIR = '/content/outputs'
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print("Google Drive disabled.")
    print("Model will be downloaded each session (~6.5GB).")
    print(f"\nOutput directory: {OUTPUT_DIR} (local, lost when session ends)")

In [None]:
# Detect runtime type (TPU vs GPU)
import os
import subprocess

# NOTE(review): this notebook recognises a Colab TPU runtime by COLAB_TPU_ADDR;
# confirm that variable is still set on the TPU runtime type you use.
USE_TPU = 'COLAB_TPU_ADDR' in os.environ


def _detect_gpu_name():
    """Return the first GPU name reported by nvidia-smi, or None if no GPU is usable.

    Returns None when:
    - nvidia-smi is missing or cannot start (OSError),
    - it does not finish within 30 s,
    - it exits non-zero (its error text may land on stdout and must not be
      mistaken for a GPU name), or
    - it lists no GPUs.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
            capture_output=True, text=True, timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return None
    if result.returncode != 0:
        return None
    # One line per GPU; report the first non-empty one.
    names = [line.strip() for line in result.stdout.splitlines() if line.strip()]
    return names[0] if names else None


if USE_TPU:
    print(f"TPU detected at: {os.environ['COLAB_TPU_ADDR']}")
    print("Using JAX/Flax backend for parallel generation across 8 TPU cores.")
    RUNTIME = "TPU"
else:
    gpu_name = _detect_gpu_name()
    if gpu_name is None:
        raise RuntimeError(
            "No TPU or GPU found! Please go to Runtime > Change runtime type > "
            "Select 'TPU' or 'GPU' and restart the notebook."
        )
    print(f"GPU detected: {gpu_name}")
    print("Using PyTorch backend with fp16.")
    RUNTIME = "GPU"

print(f"\nRuntime: {RUNTIME}")

In [None]:
# Install dependencies based on runtime
if USE_TPU:
    !pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    !pip install -q diffusers transformers flax gradio
else:
    !pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
    !pip install -q diffusers transformers accelerate gradio

In [None]:
# Imports: shared UI/utility modules first, then only the backend that the
# runtime-detection cell selected (USE_TPU).
import time
from typing import Tuple

import gradio as gr

if not USE_TPU:
    # PyTorch backend on a single CUDA device.
    import torch
    from diffusers import StableDiffusionXLPipeline

    NUM_DEVICES = 1
    print(f"PyTorch device: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 2**30:.1f} GB")
else:
    # JAX/Flax backend; every visible JAX device (TPU core) gets one replica.
    import jax
    import jax.numpy as jnp
    from diffusers import FlaxStableDiffusionXLPipeline
    from flax.jax_utils import replicate

    devices = jax.devices()
    NUM_DEVICES = len(devices)
    print(f"JAX devices: {NUM_DEVICES} TPU cores")
    for device in devices:
        print(f"  - {device}")

In [None]:
# Style system
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

# name -> (prompt template, style negative prompt)
styles = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"


def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Wrap a user prompt pair in the templates of a named style.

    Args:
        style_name: Key into ``styles``. Unknown names fall back to
            ``DEFAULT_STYLE_NAME``, which applies no styling.
        positive: User prompt, substituted for ``{prompt}`` in the style template.
        negative: User negative prompt, appended after the style's negative prompt.

    Returns:
        ``(styled_prompt, styled_negative_prompt)``.

    The style and user negative prompts are joined with ", " so that their
    boundary terms stay separate (plain concatenation produced
    "...disfiguredlow quality"). Empty parts are skipped so no stray comma appears.
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    styled_negative = ", ".join(part for part in (style_negative, negative) if part)
    return template.replace("{prompt}", positive), styled_negative

print(f"Loaded {len(style_list)} styles")

In [None]:
# Load model: downloads into (or reuses) CACHE_DIR and times the load.
print("Loading SDXL model...")
print(f"Cache directory: {CACHE_DIR}")

load_start = time.time()

if not USE_TPU:
    # GPU: PyTorch pipeline using the repo's fp16 weight variant, moved to CUDA.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        cache_dir=CACHE_DIR,
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")

    # Compute attention in slices to lower peak VRAM on the T4.
    pipeline.enable_attention_slicing()
else:
    # TPU: the Flax pipeline returns the module and its parameters separately.
    pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        cache_dir=CACHE_DIR,
        dtype=jnp.bfloat16,
        split_head_dim=True,  # TPU optimization
    )

    # Cast every parameter except the scheduler state to bfloat16. The scheduler
    # state is set aside first and put back unchanged (intent: keep it in float32
    # for numerical stability).
    scheduler_state = params.pop("scheduler")
    params = jax.tree_util.tree_map(lambda leaf: leaf.astype(jnp.bfloat16), params)
    params["scheduler"] = scheduler_state

    # One copy of the parameters per TPU core.
    print(f"Replicating model across {NUM_DEVICES} TPU cores...")
    p_params = replicate(params)

load_elapsed = time.time() - load_start
print(f"\nModel loaded in {load_elapsed:.1f}s")

In [None]:
# Inference function
# Both branches define generate_images(...) with the same parameters and return a
# list of PIL images; only the backend differs.
from typing import Optional

if USE_TPU:
    def generate_images(
        prompt: str,
        negative_prompt: str = "low quality",
        guidance_scale: float = 7.5,
        style_name: str = "(No style)",
        num_steps: int = 30,
        seed: Optional[int] = None,
        num_images: int = 4,
    ):
        """Generate images using SDXL on TPU.

        Each pipeline call yields one image per TPU core (NUM_DEVICES images).
        Enough calls are made to cover ``num_images``. Previously, requests above
        NUM_DEVICES were silently truncated.

        Args:
            prompt: User prompt, styled via ``apply_style``.
            negative_prompt: User negative prompt, combined with the style's.
            guidance_scale: Classifier-free guidance scale.
            style_name: Name of a style defined in the style cell.
            num_steps: Number of denoising steps. It is passed to the jitted
                pipeline; presumably a new value triggers recompilation
                (TODO confirm).
            seed: Base RNG seed. Defaults to the current Unix time.
            num_images: Number of images to return.

        Returns:
            List of up to ``num_images`` PIL images.
        """
        num_images = int(num_images)
        styled_prompt, styled_negative = apply_style(style_name, prompt, negative_prompt)

        # The same tokenized prompts go to every core; only the RNG keys differ.
        prompt_ids = replicate(pipeline.prepare_inputs(styled_prompt))
        neg_prompt_ids = replicate(pipeline.prepare_inputs(styled_negative))

        if seed is None:
            seed = int(time.time())

        # Ceil division; always run at least one batch.
        num_batches = max(1, -(-num_images // NUM_DEVICES))

        pil_images = []
        start_time = time.time()
        for batch_idx in range(num_batches):
            # Batch 0 uses PRNGKey(seed), as before; later batches offset the seed.
            rngs = jax.random.split(jax.random.PRNGKey(seed + batch_idx), NUM_DEVICES)
            images = pipeline(
                prompt_ids=prompt_ids,
                params=p_params,
                prng_seed=rngs,
                num_inference_steps=num_steps,
                # The jit path presumably only shards guidance_scale across
                # devices when it is a Python float, so normalise ints (e.g. 7).
                guidance_scale=float(guidance_scale),
                neg_prompt_ids=neg_prompt_ids,
                jit=True,
            ).images

            images = images.block_until_ready()
            # Collapse all leading (device, per-device batch) axes into one image axis.
            images = images.reshape((-1,) + tuple(images.shape[-3:]))
            pil_images.extend(pipeline.numpy_to_pil(jax.device_get(images)))

        elapsed = time.time() - start_time
        print(f"Generated {len(pil_images)} images in {elapsed:.2f}s")

        return pil_images[:num_images]

else:
    def generate_images(
        prompt: str,
        negative_prompt: str = "low quality",
        guidance_scale: float = 7.5,
        style_name: str = "(No style)",
        num_steps: int = 30,
        seed: Optional[int] = None,
        num_images: int = 1,
    ):
        """Generate images using SDXL on GPU, one pipeline call per image.

        Args:
            prompt: User prompt, styled via ``apply_style``.
            negative_prompt: User negative prompt, combined with the style's.
            guidance_scale: Classifier-free guidance scale.
            style_name: Name of a style defined in the style cell.
            num_steps: Number of denoising steps.
            seed: Base RNG seed; image ``i`` uses ``seed + i``. Defaults to the
                current Unix time.
            num_images: Number of images to generate sequentially.

        Returns:
            List of ``num_images`` PIL images.
        """
        styled_prompt, styled_negative = apply_style(style_name, prompt, negative_prompt)

        if seed is None:
            seed = int(time.time())

        pil_images = []
        start_time = time.time()

        for i in range(num_images):
            # A fresh generator per image keeps each image reproducible from seed + i.
            generator = torch.Generator(device="cuda").manual_seed(seed + i)
            image = pipeline(
                prompt=styled_prompt,
                negative_prompt=styled_negative,
                guidance_scale=guidance_scale,
                num_inference_steps=num_steps,
                generator=generator,
            ).images[0]
            pil_images.append(image)

        elapsed = time.time() - start_time
        print(f"Generated {num_images} image(s) in {elapsed:.2f}s")

        return pil_images

print("Inference function defined.")

In [None]:
# Warmup: one throwaway generation so compilation (TPU) or first-call setup
# happens before the UI is launched, not on the first user request.
print("Running warmup...")
if USE_TPU:
    print("(JIT compilation on TPU takes ~3 minutes, subsequent runs are fast)")

warmup_start = time.time()
_ = generate_images(num_images=1, prompt="a photo of a cat")
warmup_elapsed = time.time() - warmup_start

print(f"\nWarmup complete in {warmup_elapsed:.1f}s. Model is ready!")

In [None]:
# Gradio interface
import re

# TPU: a single generate_images batch yields one image per core, so cap the
# slider at the detected core count (previously hard-coded to 8).
# GPU: images are generated one after another, so keep the cap small.
MAX_IMAGES = NUM_DEVICES if USE_TPU else 4


def _filename_slug(text: str, max_len: int = 50) -> str:
    """Return a filename-safe fragment of ``text`` (word characters and '-' only)."""
    slug = re.sub(r"[^\w-]+", "_", text[:max_len]).strip("_")
    return slug or "image"


def infer(prompt, negative_prompt, guidance_scale, style_name, num_images, save_images):
    """Gradio inference wrapper.

    Args:
        prompt: Prompt text from the UI.
        negative_prompt: Negative prompt text; falls back to "low quality" when empty.
        guidance_scale: Guidance-scale slider value.
        style_name: Selected style name.
        num_images: Image-count slider value (cast to int).
        save_images: When true, each image is also written as a PNG to OUTPUT_DIR.

    Returns:
        List of PIL images for the gallery.

    Raises:
        gr.Error: If the prompt is empty or whitespace only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt")

    images = generate_images(
        prompt=prompt,
        negative_prompt=negative_prompt or "low quality",
        guidance_scale=guidance_scale,
        style_name=style_name,
        num_images=int(num_images),
    )

    # Save images if requested
    if save_images:
        timestamp = time.strftime('%Y%m%d_%H%M%S')
        # Computed once. Characters other than word chars and '-' (spaces, '/',
        # ':', quotes, newlines, ...) become '_'.
        safe_prompt = _filename_slug(prompt)
        for i, img in enumerate(images):
            filename = f"{timestamp}_{safe_prompt}_{i}.png"
            img.save(os.path.join(OUTPUT_DIR, filename))
        print(f"Saved {len(images)} image(s) to {OUTPUT_DIR}")

    return images


css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
    max-width: 730px !important;
    margin: auto;
    padding-top: 1.5rem;
}
.gr-button {
    color: white;
    border-color: black;
    background: black;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
}
"""

# Example prompts offered under the prompt box.
examples = [
    ["A serious capybara at work, wearing a suit"],
    ["A Squirtle fine dining with a view to the London Eye"],
    ["A tamale food cart in front of a Japanese Castle"],
    ["a graffiti of a robot serving meals to people"],
    ["a beautiful cabin in Attersee, Austria, 3d animation style"],
]

# Header summary, rendered inside gr.HTML. It therefore uses an HTML tag for
# emphasis: markdown "**" would show up as literal asterisks. The TPU image count
# follows the detected core count instead of a hard-coded 8.
backend_info = f"JAX/Flax, {NUM_DEVICES} parallel images" if USE_TPU else "PyTorch fp16"
runtime_info = f"Running on <strong>{RUNTIME}</strong> ({backend_info})"
storage_info = "Google Drive" if USE_GOOGLE_DRIVE else "Local (session only)"

with gr.Blocks(css=css) as demo:
    # Header banner: title plus the runtime/storage summary.
    gr.HTML(
        f"""
        <div style="text-align: center; margin: 0 auto;">
            <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px">
                Stable Diffusion XL on Colab
            </h1>
            <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
                {runtime_info} | Storage: {storage_info}
            </p>
        </div>
        """
    )

    # Prompt box with the Generate button beside it.
    with gr.Row():
        prompt = gr.Textbox(max_lines=1, label="Prompt", placeholder="Enter your prompt")
        btn = gr.Button("Generate", scale=0)

    gallery = gr.Gallery(elem_id="gallery", label="Generated images", height="auto", columns=2)

    # Less common controls, collapsed by default.
    with gr.Accordion("Advanced settings", open=False):
        style_selection = gr.Radio(
            label="Image Style",
            choices=STYLE_NAMES,
            value=DEFAULT_STYLE_NAME,
        )
        negative = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter a negative prompt",
        )
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            value=7.5,
            minimum=0,
            maximum=50,
            step=0.1,
        )
        num_images = gr.Slider(
            label="Number of images",
            value=min(4, MAX_IMAGES),
            minimum=1,
            maximum=MAX_IMAGES,
            step=1,
        )
        save_images = gr.Checkbox(
            label=f"Save images to {'Google Drive' if USE_GOOGLE_DRIVE else 'local storage'}",
            info=f"Saves to: {OUTPUT_DIR}",
            value=False,
        )

    gr.Examples(inputs=[prompt], examples=examples)

    gr.HTML(
        """
        <div style="text-align: center; margin-top: 20px; font-size: 0.8rem; color: #666;">
            <p>Model: <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">stabilityai/stable-diffusion-xl-base-1.0</a></p>
        </div>
        """
    )

    # Clicking Generate and pressing Enter in the prompt box both run `infer`
    # with the same inputs (registered in that order).
    infer_inputs = [prompt, negative, guidance_scale, style_selection, num_images, save_images]
    for trigger in (btn.click, prompt.submit):
        trigger(infer, inputs=infer_inputs, outputs=[gallery])

print("Gradio interface built.")
print(f"Output directory: {OUTPUT_DIR}")

In [None]:
# Launch the Gradio app. share=True requests a public share link; debug=True
# keeps the cell running so errors are printed in the Colab output.
demo.launch(debug=True, share=True)