In [None]:
import time
import traceback

import jax

# Report which accelerators JAX can see; later cells replicate/shard work
# across all of them.
num_devices, device_type = jax.device_count(), jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")

# Optional guard for TPU runtimes — uncomment to fail fast on other hardware:
# assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

In [None]:
import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

# from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
from diffusers import EulerAncestralDiscreteScheduler

In [None]:
# Load Stable Diffusion v1.5 as a Flax pipeline from the Hub's "bf16"
# revision, with dtype=jnp.bfloat16. The weights come back separately as
# ``params``; they are passed explicitly to every pipeline call below.
#
# Alternative that is not used here: dtype=jnp.float16 together with
# from_pt=True (i.e. converting from the PyTorch weights).
#
# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo is reported to
# have been taken down; a mirror exists at
# "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm it still provides a
# "bf16" revision before switching.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    dtype=jnp.bfloat16,
    revision="bf16",
)

In [None]:
# pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)

In [None]:
# Benchmark configuration.
# Powers of two from 1 to 128. NOTE: the benchmark cell further down
# re-assigns batch_size_list before looping, so this value is not what it uses.
batch_size_list = [1 << exponent for exponent in range(8)]
steps = 50          # step count used in the it/s throughput calculation
cfg_scale = 15.0    # passed to the pipeline as guidance_scale
prompt = "postapocalyptic steampunk city, exploration, cinematic, realistic, hyper detailed, photorealistic maximum detail, volumetric light, (((focus))), wide-angle, (((brightly lit))), (((vegetation))), lightning, vines, destruction, devastation, wartorn, ruins"
negative_prompt = "(((blurry))), ((foggy)), (((dark))), ((monochrome)), sun, (((depth of field)))"

In [None]:
def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key


# Split one root key into jax.device_count() keys; this array is what the
# pipeline calls below receive as ``prng_seed`` (with jit=True).
rng = jax.random.split(create_key(0), jax.device_count())

In [None]:
# Pre-warm: run the jitted/pmapped pipeline once so XLA compilation is paid
# here rather than inside the timed benchmark.
# num_inference_steps, height and width match the benchmark call. In diffusers'
# Flax pipeline these appear to be static (compile-time) arguments of the
# pmapped generator, so a different step count would compile a different
# program and warm nothing. TODO confirm for the installed diffusers version.
# One prompt per device: shard() splits the leading batch axis across devices,
# so the batch must be a multiple of num_devices (this is 1 on a single device).
batch_size = num_devices
_prompt = [prompt] * batch_size
prompt_ids = pipeline.prepare_inputs(_prompt)

_neg_prompt = [negative_prompt] * batch_size
neg_prompt_ids = pipeline.prepare_inputs(_neg_prompt)

# Copy the weights to every device and split the token ids across devices.
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
neg_prompt_ids = shard(neg_prompt_ids)

images = pipeline(
    prompt_ids=prompt_ids,
    neg_prompt_ids=neg_prompt_ids,
    num_inference_steps=steps,
    guidance_scale=cfg_scale,
    height=512,
    width=512,
    jit=True,
    params=p_params,
    prng_seed=rng,
)[0]

In [None]:
def image_grid(imgs, rows, cols):
    """Paste ``imgs`` row-major into a single ``rows`` x ``cols`` RGB grid image.

    Every cell is sized like ``imgs[0]``; presumably all images share that
    size (TODO confirm for callers — larger images would overlap neighbours).

    Args:
        imgs: sequence of PIL images.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``Image`` of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: if ``imgs`` is empty, or holds more images than the grid
            has cells (those would otherwise be pasted off-canvas and lost).
    """
    if len(imgs) == 0:
        raise ValueError("image_grid needs at least one image")
    if len(imgs) > rows * cols:
        raise ValueError(
            f"{len(imgs)} images do not fit in a {rows}x{cols} grid"
        )
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
# Show the warm-up output. The pipeline output is assumed to be
# (devices, per-device batch, H, W, C) — TODO confirm. Reshaping with -1
# collapses every leading axis into one image axis, so per-device batches > 1
# are kept rather than failing the reshape. New names leave ``images``
# untouched, so this cell can be re-run safely.
preview_arrays = images.reshape((-1,) + images.shape[-3:])
preview_images = pipeline.numpy_to_pil(preview_arrays)
image_grid(preview_images, 1, len(preview_images))

In [None]:
# Benchmark: time one full `steps`-step generation per batch size and record
# throughput as steps * images / wall-clock seconds (0 for failed/skipped sizes).
# Each size runs twice, presumably so the second run excludes the XLA
# compilation triggered by the first call at a new input shape.
result = []
batch_size_list = [1, 1, 2, 2, 4, 4, 8, 8]

# Replicating the weights does not depend on the batch size, so do it once
# instead of re-copying them to every device on every iteration.
p_params = replicate(params)

for batch_size in batch_size_list:
    # shard() splits the leading batch axis across devices; the batch must
    # divide evenly or the reshape inside shard() fails.
    if batch_size % num_devices:
        print("batch_size {}, skipped (not divisible by {} devices)".format(batch_size, num_devices))
        its = 0
        result.append(its)
        continue
    try:
        # Re-seed every run so results are reproducible across batch sizes.
        rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

        _prompt = [prompt] * batch_size
        prompt_ids = shard(pipeline.prepare_inputs(_prompt))

        _neg_prompt = [negative_prompt] * batch_size
        neg_prompt_ids = shard(pipeline.prepare_inputs(_neg_prompt))

        t0 = time.perf_counter()
        images = pipeline(
            prompt_ids=prompt_ids,
            neg_prompt_ids=neg_prompt_ids,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            height=512,
            width=512,
            jit=True,
            params=p_params,
            prng_seed=rng,
        )[0]
        # JAX dispatches work asynchronously; wait for the devices to finish
        # before stopping the clock (a no-op if the result is already on host).
        jax.block_until_ready(images)
        t1 = time.perf_counter()
        its = steps * batch_size / (t1 - t0)
        print("batch_size {}, it/s: {}, time: {}".format(batch_size, round(its, 2), round((t1 - t0), 2)))
    except Exception as exc:
        # print_exc() writes the traceback itself (and returns None).
        traceback.print_exc()
        message = str(exc)
        # XLA reports device OOM as RESOURCE_EXHAUSTED; label other failures by type.
        is_oom = "RESOURCE_EXHAUSTED" in message or "out of memory" in message.lower()
        print("batch_size {}, {}".format(batch_size, "OOM" if is_oom else type(exc).__name__))
        its = 0
    result.append(round(its, 2))
result