[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/stable-diffusion-diffusers-colab/blob/main/flax_stable_diffusion_3.ipynb)

In [None]:
#@title Install - Step 1
!pip install -q jax==0.3.25 jaxlib==0.3.25 flax==0.6.2 transformers accelerate omegaconf
!pip install -q git+https://github.com/huggingface/diffusers

!apt -y install -qq aria2
!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M 'https://civitai.com/api/download/models/138176?type=Model&format=SafeTensor&size=pruned&fp=fp16' -d /content/models -o cyberrealistic_v33.safetensors
# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M 'https://huggingface.co/lllyasviel/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny.pth' -d /content/models/controlnet -o control_v11p_sd15_canny.pth
# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M 'https://huggingface.co/ckpt/ControlNet-v1-1/resolve/main/control_v11p_sd15_canny_fp16.safetensors' -d /content/models/controlnet -o control_v11p_sd15_canny_fp16.safetensors

In [None]:
#@title Install - Step 2
import jax, random, gc, torch
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import StableDiffusionPipeline

import numpy as np
from PIL import Image

import jax.tools.colab_tpu
# Attach JAX to the Colab TPU runtime.
jax.tools.colab_tpu.setup_tpu('tpu_driver_nightly')
# A bare `jax.device_count()` mid-cell is never displayed; print it explicitly.
# The generation cells produce one sample per device, so this is also the batch size.
print(f"JAX device count: {jax.device_count()}")

def image_grid(imgs, rows, cols):
    """Tile images row-major into a single ``rows`` x ``cols`` RGB grid.

    Every cell is sized like ``imgs[0]``; images of a different size are pasted
    at the same offsets and are not resized.

    Args:
        imgs: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * w, rows * h)``, where ``(w, h)``
        is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``imgs`` is empty, or holds more images than the
            ``rows * cols`` cells (extra images would land outside the canvas).
    """
    if not imgs:
        raise ValueError("image_grid() needs at least one image")
    if len(imgs) > rows * cols:
        raise ValueError(f"{len(imgs)} images do not fit in a {rows}x{cols} grid")
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

In [None]:
#@title Vanilla - Step 1
# Convert the single-file checkpoint into a diffusers-format PyTorch directory that the
# Flax pipeline loads with `from_pt=True` in the next cell.
torch_pipeline = StableDiffusionPipeline.from_single_file('/content/models/cyberrealistic_v33.safetensors', torch_dtype=torch.float16, safety_checker=None)
# NOTE(review): safe_serialization=False writes .bin weights — presumably because the Flax
# `from_pt=True` loader expects PyTorch .bin files; confirm before switching to safetensors.
torch_pipeline.save_pretrained('/content/models/torch/cyberrealistic', safe_serialization=False)
# Only the files on disk are needed from here on; drop the PyTorch copy so it does not
# occupy host RAM while the Flax pipeline is being loaded.
del torch_pipeline
gc.collect()

In [None]:
#@title Vanilla - Step 2
from diffusers import FlaxStableDiffusionPipeline

# Load the diffusers-format PyTorch weights written in Step 1, converting them to Flax
# params on the fly (`from_pt=True`), with the pipeline dtype set to bfloat16.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    '/content/models/torch/cyberrealistic',
    dtype=jax.numpy.bfloat16,
    from_pt=True,
    safety_checker=None,
)

# Give every local device its own copy of the weights for the jit=True (pmapped) call.
params = replicate(params)

In [None]:
#@title Vanilla - Step 3
gc.collect()

prompt = "cute duck"
negative_prompt = "blurry"

# One sample per device.
num_samples = jax.device_count()

# Tokenize num_samples copies of each prompt, then shard along the leading axis so the
# pmapped pipeline (jit=True) receives one slice per device.
prompt_n = num_samples * [prompt]
prompt_ids = shard(pipeline.prepare_inputs(prompt_n))

negative_prompt_n = num_samples * [negative_prompt]
negative_prompt_ids = shard(pipeline.prepare_inputs(negative_prompt_n))

# Fresh seed each run; printed so a good result can be reproduced by hard-coding it.
real_seed = random.randint(0, 2147483647)
print(f"seed: {real_seed}")
prng_seed = jax.random.split(jax.random.PRNGKey(real_seed), num_samples)

# NOTE(review): guidance_scale is passed as a float on purpose — the Flax pipeline only
# turns a Python float into the per-device array that pmap maps over; an int such as 7
# is passed through unconverted.
images = pipeline(prompt_ids, params, prng_seed, neg_prompt_ids=negative_prompt_ids, num_inference_steps=50, height=512, width=512, guidance_scale=7.0, jit=True).images
# Flatten the per-device leading axes into a single batch of H x W x C images.
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

# At most 4 images per row (2 x 4 with 8 devices), so the grid matches the sample count.
grid_cols = min(len(images), 4)
grid_rows = -(-len(images) // grid_cols)  # ceiling division
image = image_grid(images, grid_rows, grid_cols)
image

In [None]:
#@title ControlNet #Canny - Step 1
import jax, torch
from diffusers import ControlNetModel, FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline

# The Flax pipeline needs a *Flax* ControlNet and its params; a PyTorch ControlNetModel
# cannot be used directly. Mirror the Vanilla flow instead:
# single-file PyTorch checkpoint -> diffusers-format directory -> Flax via from_pt=True.
# NOTE(review): loading the local /content/models/controlnet/*.pth failed with
# "invalid load key, '<'" (presumably an HTML page was downloaded), so the checkpoint
# is read straight from its Hugging Face URL.
torch_controlnet = ControlNetModel.from_single_file('https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth', torch_dtype=torch.float16)
torch_controlnet.save_pretrained('/content/models/torch/controlnet_canny', safe_serialization=False)
del torch_controlnet
gc.collect()

controlnet, controlnet_params = FlaxControlNetModel.from_pretrained('/content/models/torch/controlnet_canny', from_pt=True, dtype=jax.numpy.bfloat16)
pipeline, params = FlaxStableDiffusionControlNetPipeline.from_pretrained('/content/models/torch/cyberrealistic', controlnet=controlnet, dtype=jax.numpy.bfloat16, from_pt=True, safety_checker=None)
# A component passed in explicitly is not loaded by from_pretrained, so its params are
# absent from `params`; attach them before replicating across devices.
params["controlnet"] = controlnet_params
params = replicate(params)

In [None]:
#@title ControlNet #Canny - Step 2
import cv2
import numpy as np
from PIL import Image
from diffusers.utils import load_image

# Reference photo from which the Canny edge map is extracted.
image_raw = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")

# Hysteresis thresholds for cv2.Canny.
low_threshold = 100
high_threshold = 200
edges = cv2.Canny(np.array(image_raw), low_threshold, high_threshold)

# The edge map has no channel axis; add one and repeat it three times so the result
# can be used as an RGB conditioning image.
canny_image = Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))

# Show the source photo and its edge map side by side.
image = image_grid([image_raw, canny_image], 1, 2)
image

In [None]:
#@title ControlNet #Canny - Step 3
gc.collect()

prompt = "cute duck"
negative_prompt = "blurry, ugly"

# One sample per device.
num_samples = jax.device_count()

# The Flax ControlNet pipeline tokenizes with prepare_text_inputs (it has no
# prepare_inputs). Every input is sharded along its leading axis, one slice per device.
prompt_n = num_samples * [prompt]
prompt_ids = shard(pipeline.prepare_text_inputs(prompt_n))

negative_prompt_n = num_samples * [negative_prompt]
negative_prompt_ids = shard(pipeline.prepare_text_inputs(negative_prompt_n))

# The conditioning image must be preprocessed and sharded like the prompts; a bare PIL
# image cannot be split across devices.
processed_image = shard(pipeline.prepare_image_inputs(num_samples * [canny_image]))

# Fresh seed each run; printed so a good result can be reproduced by hard-coding it.
real_seed = random.randint(0, 2147483647)
print(f"seed: {real_seed}")
prng_seed = jax.random.split(jax.random.PRNGKey(real_seed), num_samples)

# This pipeline takes no height/width arguments: the output size follows the preprocessed
# conditioning image.
# NOTE(review): guidance_scale is a float on purpose — only Python floats are converted
# into the per-device array that pmap maps over.
images = pipeline(prompt_ids, processed_image, params, prng_seed, neg_prompt_ids=negative_prompt_ids, num_inference_steps=50, guidance_scale=7.0, jit=True).images
# Flatten the per-device leading axes into a single batch of H x W x C images.
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))

# At most 4 images per row (2 x 4 with 8 devices), so the grid matches the sample count.
grid_cols = min(len(images), 4)
grid_rows = -(-len(images) // grid_cols)  # ceiling division
image = image_grid(images, grid_rows, grid_cols)
image