
Flax Stable Diffusion 2.1 error #5224

@entrpn

Description


Describe the bug

When trying to load Stable Diffusion 2.1 using Flax, I am getting the following error:

Traceback (most recent call last):
  File "/home/jfacevedo/infer.py", line 120, in <module>
    run(opt)
  File "/home/jfacevedo/infer.py", line 30, in run
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
  File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
    raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
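The error compares two sets of component names. Taking their difference shows which components FlaxStableDiffusionPipeline expected but did not receive when loading "stabilityai/stable-diffusion-2-1" at revision "bf16". This sketch only reuses the names printed in the message above:

```python
# Component names copied verbatim from the ValueError message.
expected = {"vae", "scheduler", "feature_extractor", "text_encoder",
            "safety_checker", "tokenizer", "unet"}
passed = {"vae", "scheduler", "text_encoder", "tokenizer", "unet"}

# Components the pipeline class requires but that were not loaded.
missing = expected - passed
print(sorted(missing))  # ['feature_extractor', 'safety_checker']
```

So the failure comes down to `safety_checker` and `feature_extractor` not being passed; every model component needed for generation itself was loaded.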

Reproduction

Create a TPU VM and run the following installation steps:

git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install .
cd ..
pip install transformers flax
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Then run the script below (infer.py) as follows:

python infer.py --sd-version 2 --itters 3

import time
import argparse
import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def run(opt):
    if opt.sd_version == 1:
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            revision="bf16",
            dtype=jnp.bfloat16
        )
    else:
        pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            revision="bf16",
            dtype=jnp.bfloat16
        )
    
    p_params = replicate(params)
    rng = create_key(0)
    rng = jax.random.split(rng, jax.device_count())
    prompts = ["Labrador in the style of Hokusai"] * opt.batch_size
    print("prompts len:",len(prompts))

    prompt_ids = pipeline.prepare_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    # Default values https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L275
    num_inference_steps = 50
    height = opt.height 
    width = opt.width 
    guidance_scale = 7.5
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]  # shape: (num_devices, 1), since prompt_ids is already sharded

    # num_inference_steps, height, and width are static, so their positional indices in
    # _generate() (3, 4, 5) go to static_broadcasted_argnums; guidance_scale (g) is sharded per device
    p_generate = pmap(pipeline._generate, static_broadcasted_argnums=[3,4,5])

    print("Sharded prompt ids has shape:", prompt_ids.shape)
    print("Guidance shape:",g.shape)

    s = time.time()
    images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, g)
    images = images.block_until_ready()
    print("First inference time is:", time.time() - s)

    iters = opt.itters 
    s = time.time()
    for _ in range(iters):
        images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, g)
        images = images.block_until_ready()
    print("Average inference time over", iters, "iterations:", (time.time() - s)/iters)
    print("Shape of predictions is: ", images.shape)

    if opt.trace:
        trace_path = "/tmp/tensorboard"
        with jax.profiler.trace(trace_path):
            images = p_generate(prompt_ids, p_params, rng, num_inference_steps, height, width, g)
            images = images.block_until_ready()
            print(f"trace can be found: {trace_path}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch-size',
        type=int,
        default=4,
        help='Number of images to generate'
    )
    parser.add_argument(
        '--sd-version',
        type=int,
        default=1,
        help='Use 1 for SD 1.4, Use 2 for SD 2.1'
    )
    parser.add_argument(
        '--width',
        type=int,
        default=512,
        help='Width'
    )
    parser.add_argument(
        '--height',
        type=int,
        default=512,
        help='Height'
    )
    parser.add_argument(
        '--itters',
        type=int,
        default=15,
        help='Number of iterations to run the benchmark.'
    )
    parser.add_argument(
        '--trace',
        action="store_true", 
        default=False, 
        help="Run a trace"
    )

    opt = parser.parse_args()
    run(opt)
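`static_broadcasted_argnums` in `pmap` refers to positional indices. Mapping `[3, 4, 5]` onto the argument order used in the `p_generate(...)` call shows that `num_inference_steps`, `height`, and `width` are treated as static, while `guidance_scale` (index 6) is mapped over devices like the other arrays, which is why `g` is built with one row per device. The function below is a hypothetical stub that only mirrors the call order in infer.py; it is not the diffusers implementation of `_generate`.

```python
import inspect

# Hypothetical stand-in for pipeline._generate: only the positional order is
# taken from the p_generate(...) call in infer.py; the body does nothing.
def generate_stub(prompt_ids, params, prng_seed, num_inference_steps,
                  height, width, guidance_scale):
    pass

static_argnums = [3, 4, 5]  # value passed to pmap in infer.py
names = list(inspect.signature(generate_stub).parameters)

# Arguments pmap broadcasts as Python constants vs. arguments it splits per device.
static = [names[i] for i in static_argnums]
mapped = [n for i, n in enumerate(names) if i not in static_argnums]
print("static:", static)
print("mapped over devices:", mapped)
```

If `guidance_scale` were also meant to be static, index 6 would have to be added to the list; as written, passing it as a sharded array is consistent with the argnums.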

Logs

The config attributes {'act_fn': 'silu', 'center_input_sample': False, 'downsample_padding': 1, 'dual_cross_attention': False, 'mid_block_scale_factor': 1, 'norm_eps': 1e-05, 'norm_num_groups': 32, 'num_class_embeds': None} were passed to FlaxUNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Traceback (most recent call last):
  File "/home/jfacevedo/infer.py", line 120, in <module>
    run(opt)
  File "/home/jfacevedo/infer.py", line 30, in run
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
  File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
    raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
I0000 00:00:1695928082.495501    4918 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
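The first log line is a separate warning: the listed keys from the UNet's config.json are "not expected" by FlaxUNet2DConditionModel and, per the message, "will be ignored" rather than raising. A minimal sketch of that kind of signature-based filtering, assuming a hypothetical model class (an illustration of the idea, not the diffusers loader itself):

```python
import inspect

# Config attributes listed in the warning, plus one hypothetical accepted key.
config = {
    "act_fn": "silu", "center_input_sample": False, "downsample_padding": 1,
    "dual_cross_attention": False, "mid_block_scale_factor": 1,
    "norm_eps": 1e-05, "norm_num_groups": 32, "num_class_embeds": None,
    "sample_size": 64,  # hypothetical key/value, added for illustration only
}

class HypotheticalFlaxUNet:
    """Stand-in model that accepts only `sample_size`; not FlaxUNet2DConditionModel."""
    def __init__(self, sample_size=32):
        self.sample_size = sample_size

# Keep only keys the constructor accepts; report the rest as ignored.
accepted = set(inspect.signature(HypotheticalFlaxUNet.__init__).parameters) - {"self"}
used = {k: v for k, v in config.items() if k in accepted}
ignored = sorted(k for k in config if k not in accepted)

model = HypotheticalFlaxUNet(**used)
print("ignored:", ignored)
```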

System Info

  • diffusers version: 0.22.0.dev0
  • Platform: Linux-5.19.0-1030-gcp-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • PyTorch version (GPU?): 2.0.1+cu117 (False)
  • Huggingface_hub version: 0.17.3
  • Transformers version: 4.33.3
  • Accelerate version: not installed
  • xFormers version: not installed
  • Using GPU in script?: No
  • Using TPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Labels

bug (Something isn't working), stale (Issues that haven't received updates)
