<a href="https://colab.research.google.com/github/g-lam/colab-samples/blob/main/Text-to-Image/Stable_Diffusion_TPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu('tpu_driver_20221011')
jax.devices()

%pip install --quiet --upgrade diffusers transformers ftfy mediapy

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import time
import torch

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from diffusers import FlaxStableDiffusionPipeline

def image_grid(imgs, rows, cols):
    """Tile PIL images into a single ``rows`` x ``cols`` grid image.

    Images are placed left-to-right, top-to-bottom. Every cell is sized after
    the first image, so all images are expected to share that size.

    Args:
        imgs: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``imgs`` is empty or holds more images than the grid has cells.
    """
    if len(imgs) == 0:
        raise ValueError("image_grid needs at least one image")
    if len(imgs) > rows * cols:
        # Pasting past the canvas edge is clipped silently, so surplus images would
        # just disappear from the result; fail loudly instead.
        raise ValueError(
            f"{len(imgs)} images do not fit in a {rows}x{cols} grid"
        )
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
# Load the Stable Diffusion 2.1 Flax pipeline from its "bf16" revision in bfloat16,
# then copy the parameter tree onto every local device for the parallel call later on.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    revision="bf16",
    dtype=jnp.bfloat16,
)
p_params = replicate(params)

In [None]:
prompt = "beautiful mountain by the lake, panoramic, professional photography" #@param {type:"string"}
neg_prompt = "bad anatomy, ugly, missing arms, missing legs, extra limbs, extra fingers, mutation, mutilated, distorted face, deformed body features, poor quality" #@param {type:"string"}
num_inference_steps = 50 #@param {type:"integer"}
seed = -1 #@param {type:"integer"}
#@markdown `-1` picks a random seed. Replace it with any other integer for reproducible results.

# Resolve the seed: -1 means "draw a fresh random seed" for this run.
if seed == -1:
    import random
    real_seed = random.randint(0, 2**31 - 1)
else:
    real_seed = seed

# One sample per device: each device gets its own PRNG key and its own copy of the prompts.
num_samples = jax.device_count()
prng_seed = jax.random.split(jax.random.PRNGKey(real_seed), num_samples)
prompt = num_samples * [prompt]
neg_prompt = num_samples * [neg_prompt]
# Tokenize, then shard so the leading axis is split across devices.
prompt_ids = shard(pipeline.prepare_inputs(prompt))
neg_prompt_ids = shard(pipeline.prepare_inputs(neg_prompt))

# NOTE: with jit=True the first run also includes compilation time, so it is
# much slower than subsequent reruns of this cell.
start = time.perf_counter()
images = pipeline(prompt_ids, p_params, prng_seed, num_inference_steps, neg_prompt_ids=neg_prompt_ids, jit=True).images
# Presumably images is (devices, per-device batch, H, W, C) — flatten the leading
# axes into a single batch of num_samples images before converting to PIL.
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
end = time.perf_counter()
print(f"Seed:\t{real_seed}\nTime:\t{(end - start)}s")

# Size the grid from the real sample count instead of assuming 8 devices (2x4);
# with 8 devices this still yields 2 rows x 4 columns.
grid_cols = min(num_samples, 4)
grid_rows = -(-num_samples // grid_cols)  # ceiling division
image_grid(images_pil, grid_rows, grid_cols)