# Stable Diffusion XL on TPU

This notebook runs SDXL image generation directly on Google Colab's TPU using JAX/Flax.

## Instructions

1. **Set Runtime to TPU**: Go to `Runtime` > `Change runtime type` > Select `TPU` > `Save`
2. **Run all cells**: Click `Runtime` > `Run all`
3. **Wait for warmup**: The first compilation takes ~3 minutes
4. **Generate images**: Use the Gradio interface that appears

The model generates one image per TPU core in parallel (8 images on Colab's 8-core TPU).

In [None]:
# Fail fast, before installing anything, if this runtime has no TPU attached.
# NOTE(review): COLAB_TPU_ADDR is the variable Colab's legacy TPU runtime
# exposed; the install cell below pulls libtpu (TPU VM style). Confirm the
# variable is actually set on the TPU runtime this notebook targets.
import os

if 'COLAB_TPU_ADDR' not in os.environ:
    raise RuntimeError(
        "TPU not found! Please go to Runtime > Change runtime type > "
        "Select 'TPU' and restart the notebook."
    )

print(f"TPU available at: {os.environ['COLAB_TPU_ADDR']}")

In [None]:
# Install dependencies
!pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q diffusers transformers flax gradio

In [None]:
# Imports
import time
from typing import Optional, Tuple

import gradio as gr
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionXLPipeline
from flax.jax_utils import replicate, unreplicate
from jax import pmap

# Verify JAX actually selected the TPU backend. When no accelerator is usable,
# JAX falls back to the CPU backend with only a warning; without this check the
# cell would print "1 TPU cores" and the rest of the notebook would run on CPU.
devices = jax.devices()
if jax.default_backend() != "tpu":
    raise RuntimeError(
        f"JAX is using the '{jax.default_backend()}' backend, not TPU. "
        "Make sure the runtime type is TPU and the jax[tpu] install succeeded, "
        "then restart the runtime."
    )
print(f"JAX devices: {len(devices)} TPU cores")
for d in devices:
    print(f"  - {d}")

In [None]:
# Style system
# Prompt-style presets offered in the UI. Each entry has:
#   "name"            -- label shown in the "Image Style" radio (via STYLE_NAMES)
#   "prompt"          -- template; apply_style substitutes "{prompt}" with the
#                        user's prompt
#   "negative_prompt" -- style-specific negative prompt that apply_style
#                        combines with the user's negative prompt
style_list = [
    {
        "name": "(No style)",
        "prompt": "{prompt}",
        "negative_prompt": "",
    },
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
        "negative_prompt": "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
        "negative_prompt": "drawing, painting, crayon, sketch, graphite, impressionist, noisy, blurry, soft, deformed, ugly",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
        "negative_prompt": "photo, deformed, black and white, realism, disfigured, low contrast",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
        "negative_prompt": "ugly, deformed, noisy, blurry, low contrast, realism, photorealistic, Western comic style",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
        "negative_prompt": "photo, photorealistic, realism, ugly",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
        "negative_prompt": "sloppy, messy, blurry, noisy, highly detailed, ultra textured, photo, realistic",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
        # NOTE(review): "deformed" appears twice in this negative prompt.
        "negative_prompt": "photographic, realistic, realism, 35mm film, dslr, cropped, frame, text, deformed, glitch, noise, noisy, off-center, deformed, cross-eyed, closed eyes, bad anatomy, ugly, disfigured, sloppy, duplicate, mutated, black and white",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
        "negative_prompt": "painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
        "negative_prompt": "ugly, deformed, noisy, low poly, blurry, painting",
    },
]

# Look-up table: display name -> (prompt template, style negative prompt).
styles = {
    entry["name"]: (entry["prompt"], entry["negative_prompt"])
    for entry in style_list
}
STYLE_NAMES = [*styles]  # dict insertion order, i.e. the order of style_list
DEFAULT_STYLE_NAME = "(No style)"


def apply_style(style_name: str, positive: str, negative: str = "") -> Tuple[str, str]:
    """Wrap a user prompt in a named style preset.

    Args:
        style_name: Key into ``styles``; unknown names fall back to
            ``DEFAULT_STYLE_NAME``.
        positive: User prompt, substituted for ``{prompt}`` in the template.
        negative: User negative prompt, placed after the style's own.

    Returns:
        ``(styled_prompt, styled_negative)``. The style and user negative
        prompts are joined with ", ", and empty parts are skipped. Plain
        concatenation fused the last style term with the first user term
        (e.g. "...disfigured" + "low quality" -> "...disfiguredlow quality").
    """
    template, style_negative = styles.get(style_name, styles[DEFAULT_STYLE_NAME])
    styled_negative = ", ".join(part for part in (style_negative, negative) if part)
    return template.replace("{prompt}", positive), styled_negative


print(f"Loaded {len(style_list)} styles")

In [None]:
# Load SDXL base and cast its model weights to bfloat16 for TPU efficiency.
print("Loading SDXL model (this may take a few minutes)...")

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    dtype=jnp.bfloat16,
    split_head_dim=True,  # TPU optimization
)

# Take the scheduler state out before the cast so it keeps its loaded
# precision (float32, for numerical stability); only the model weights
# are converted to bf16. The scheduler state is then put back unchanged.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(
    lambda leaf: leaf.astype(jnp.bfloat16),
    params,
)
params["scheduler"] = scheduler_state

print("Model loaded successfully!")

In [None]:
# Put a full copy of the (bf16) parameters on every local device so the
# pmapped pipeline can run one generation per core.
num_devices = len(jax.local_devices())
print(f"Replicating model across {num_devices} TPU cores...")

p_params = replicate(params)  # flax.jax_utils.replicate: one copy per device

print("Model replicated!")

In [None]:
# Inference function
def tokenize_prompt(prompt: str, negative_prompt: str):
    """Tokenize a prompt and its negative prompt with the pipeline's tokenizers.

    This does NOT replicate anything across devices. Callers such as
    ``generate_images`` replicate the returned arrays before passing them
    to the pipeline.

    Args:
        prompt: The (already styled) positive prompt.
        negative_prompt: The (already styled) negative prompt.

    Returns:
        ``(prompt_ids, neg_prompt_ids)``: the token-id arrays produced by
        ``pipeline.prepare_inputs`` for each prompt.
    """
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(negative_prompt)
    return prompt_ids, neg_prompt_ids


def generate_images(
    prompt: str,
    negative_prompt: str = "low quality",
    guidance_scale: float = 7.5,
    style_name: str = "(No style)",
    num_steps: int = 30,
    seed: Optional[int] = None,
    num_images: int = 4,
):
    """Generate images for one prompt with SDXL, in parallel across TPU cores.

    The pmapped pipeline produces one image per local device (``num_devices``
    in total). The first ``num_images`` of them are returned.

    Args:
        prompt: User prompt, wrapped by the selected style template.
        negative_prompt: User negative prompt, combined with the style's own.
        guidance_scale: Classifier-free guidance scale; coerced to ``float``.
        style_name: Style preset name; unknown names fall back to the
            default style (see ``apply_style``).
        num_steps: Number of denoising steps.
        seed: PRNG seed. When ``None``, the current Unix time in whole
            seconds is used.
        num_images: How many of the generated images to return. Values above
            ``num_devices`` still yield only ``num_devices`` images.

    Returns:
        A list of PIL images.
    """
    styled_prompt, styled_negative = apply_style(style_name, prompt, negative_prompt)

    # Tokenize once, then give every device an identical copy of the inputs.
    prompt_ids, neg_prompt_ids = tokenize_prompt(styled_prompt, styled_negative)
    prompt_ids = replicate(prompt_ids)
    neg_prompt_ids = replicate(neg_prompt_ids)

    # int(time.time()) has one-second resolution, so two unseeded calls in the
    # same second reuse the same seed and therefore produce the same images.
    if seed is None:
        seed = int(time.time())

    # One PRNG key per device so each core samples a different image.
    rngs = jax.random.split(jax.random.PRNGKey(seed), num_devices)

    start_time = time.time()
    images = pipeline(
        prompt_ids=prompt_ids,
        params=p_params,
        prng_seed=rngs,
        num_inference_steps=num_steps,
        # NOTE(review): diffusers' Flax pipelines only expand guidance_scale
        # into a per-device array when it is a Python float; coercing lets
        # an int (e.g. from a UI widget) work too. Confirm against the
        # installed diffusers version.
        guidance_scale=float(guidance_scale),
        neg_prompt_ids=neg_prompt_ids,
        jit=True,
    ).images
    # JAX dispatch is asynchronous; wait for the result so the timing is real.
    images = images.block_until_ready()
    elapsed = time.time() - start_time

    # Collapse all leading axes (presumably device and per-device batch) into
    # a single image axis, keeping the trailing three (presumably H, W, C).
    images = images.reshape((-1,) + tuple(images.shape[-3:]))
    print(f"Generated {images.shape[0]} images in {elapsed:.2f}s")

    pil_images = pipeline.numpy_to_pil(jax.device_get(images))
    return pil_images[:num_images]


print("Inference function defined.")

In [None]:
# Warmup: run one throwaway generation so JIT compilation happens here,
# before the Gradio UI is launched.
print(
    "Running warmup to compile model (this takes ~3 minutes)...",
    "Subsequent generations will be much faster.",
    sep="\n",
)

warmup_start = time.time()
_ = generate_images(num_images=1, prompt="a photo of a cat")
warmup_elapsed = time.time() - warmup_start

print(f"\nWarmup complete in {warmup_elapsed:.1f}s. Model is ready!")

In [None]:
# Gradio interface
def infer(prompt, negative_prompt, guidance_scale, style_name, num_images):
    """Gradio callback: validate the UI inputs and forward them to generate_images.

    Args:
        prompt: Text from the prompt box.
        negative_prompt: Text from the negative-prompt box; empty means
            "low quality".
        guidance_scale: Slider value, coerced to ``float``.
        style_name: Selected style preset name.
        num_images: Slider value, coerced to ``int``.

    Returns:
        The list of PIL images returned by ``generate_images``.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only (shown in the UI).
    """
    # Whitespace-only prompts are rejected too; they would otherwise trigger
    # a full TPU generation from an effectively empty prompt.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt")

    return generate_images(
        prompt=prompt,
        negative_prompt=negative_prompt or "low quality",
        guidance_scale=float(guidance_scale),
        style_name=style_name,
        num_images=int(num_images),
    )


# Page-level CSS passed to gr.Blocks.
# NOTE(review): `.gr-button` is a Gradio 3.x class name; confirm it still
# matches the Generate button in the installed Gradio version.
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
    max-width: 730px !important;
    margin: auto;
    padding-top: 1.5rem;
}
.gr-button {
    color: white;
    border-color: black;
    background: black;
}
#gallery {
    min-height: 22rem;
    margin-bottom: 15px;
}
"""

# Example prompts for gr.Examples; each row supplies the single `prompt` input.
examples = [
    [example_prompt]
    for example_prompt in (
        "A serious capybara at work, wearing a suit",
        "A Squirtle fine dining with a view to the London Eye",
        "A tamale food cart in front of a Japanese Castle",
        "a graffiti of a robot serving meals to people",
        "a beautiful cabin in Attersee, Austria, 3d animation style",
    )
]

# Build interface
with gr.Blocks(css=css) as demo:
    # The image count is tied to num_devices: the pmapped pipeline yields
    # exactly one image per local TPU core.
    gr.HTML(
        f"""
        <div style="text-align: center; margin: 0 auto;">
            <h1 style="font-weight: 900; margin-bottom: 7px; margin-top: 5px">
                Stable Diffusion XL on TPU
            </h1>
            <p style="margin-bottom: 10px; font-size: 94%; line-height: 23px;">
                SDXL running on Google Colab TPU with JAX/Flax.
                Generates up to {num_devices} images in parallel across TPU cores.
            </p>
        </div>
        """
    )

    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt",
            max_lines=1,
        )
        btn = gr.Button("Generate", scale=0)

    gallery = gr.Gallery(
        label="Generated images",
        elem_id="gallery",
        columns=2,
        height="auto",
    )

    with gr.Accordion("Advanced settings", open=False):
        style_selection = gr.Radio(
            choices=STYLE_NAMES,
            value=DEFAULT_STYLE_NAME,
            label="Image Style",
        )
        negative = gr.Textbox(
            label="Negative prompt",
            placeholder="Enter a negative prompt",
            max_lines=1,
        )
        guidance_scale = gr.Slider(
            label="Guidance Scale",
            minimum=0,
            maximum=50,
            value=7.5,
            step=0.1,
        )
        # Capped at num_devices: asking for more than one image per core
        # would still only return num_devices images.
        num_images = gr.Slider(
            label="Number of images",
            minimum=1,
            maximum=num_devices,
            value=min(4, num_devices),
            step=1,
        )

    gr.Examples(
        examples=examples,
        inputs=[prompt],
    )

    gr.HTML(
        """
        <div style="text-align: center; margin-top: 20px; font-size: 0.8rem; color: #666;">
            <p>Model: <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">stabilityai/stable-diffusion-xl-base-1.0</a></p>
        </div>
        """
    )

    # Event handlers. Clicking "Generate" and pressing Enter in the prompt box
    # run the same TPU-bound job. Giving both listeners one concurrency_id puts
    # them in a single concurrency group, so they no longer each get their own
    # limit and cannot start two SDXL generations on the TPU at once.
    # NOTE(review): relies on Gradio 4+'s default per-event concurrency
    # limit of 1; confirm for the installed Gradio version.
    generate_inputs = [prompt, negative, guidance_scale, style_selection, num_images]
    btn.click(
        infer,
        inputs=generate_inputs,
        outputs=[gallery],
        concurrency_id="sdxl_tpu",
    )
    prompt.submit(
        infer,
        inputs=generate_inputs,
        outputs=[gallery],
        concurrency_id="sdxl_tpu",
    )

print("Gradio interface built.")

In [None]:
# Start the Gradio server.
#   share=True -> also create a public share link for the app
#   debug=True -> block this cell and surface server errors in its output
demo.launch(debug=True, share=True)