Describe the bug
When trying to load Stable Diffusion 2.1 using Flax, I am getting the following error:
Traceback (most recent call last):
File "/home/jfacevedo/infer.py", line 120, in <module>
run(opt)
File "/home/jfacevedo/infer.py", line 30, in run
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
Reproduction
Create a TPU VM and run the following installation:
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install .
cd ..
pip install transformers flax
pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Then run the script as follows:
python infer.py --sd-version 2 --itters 3
import time
import argparse
import numpy as np
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
def create_key(seed=0):
    """Build a JAX PRNG key from an integer seed (defaults to 0)."""
    key = jax.random.PRNGKey(seed)
    return key
def image_grid(imgs, rows, cols):
    """Tile images row by row into a single RGB canvas.

    The cell size is taken from the first image; image ``i`` is pasted at
    column ``i % cols`` and row ``i // cols``.

    Args:
        imgs: Sequence of PIL images; must be non-empty, since the first
            image's size defines the cell size.
        rows: Number of rows in the canvas.
        cols: Number of columns in the canvas.

    Returns:
        A new PIL image of size ``(cols * w, rows * h)``.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
def run(opt):
    """Benchmark Flax Stable Diffusion inference with a pmapped generator.

    Loads the bf16 Flax weights for the requested Stable Diffusion version,
    replicates the parameters, then times one initial call (which includes
    pmap compilation) followed by ``opt.itters`` further calls. Optionally
    records a JAX profiler trace of one extra call.

    Args:
        opt: Parsed command-line options. Reads ``sd_version``,
            ``batch_size``, ``height``, ``width``, ``itters`` and ``trace``.
    """
    # Any value other than 1 selects SD 2.1 (same as the original if/else).
    if opt.sd_version == 1:
        model_id = "CompVis/stable-diffusion-v1-4"
    else:
        model_id = "stabilityai/stable-diffusion-2-1"
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        model_id,
        revision="bf16",
        dtype=jnp.bfloat16,
    )
    p_params = replicate(params)

    # One PRNG key per device, as required by the pmapped call below.
    rng = jax.random.split(create_key(0), jax.device_count())

    prompts = ["Labrador in the style of Hokusai"] * opt.batch_size
    print("prompts len:", len(prompts))
    prompt_ids = shard(pipeline.prepare_inputs(prompts))

    # Default values https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py#L275
    num_inference_steps = 50
    height = opt.height
    width = opt.width
    guidance_scale = 7.5

    # After shard(), prompt_ids.shape[0] is the leading (mapped) axis, so this
    # builds one guidance value per entry along that axis, shape (shape[0], 1).
    g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
    g = g[:, None]

    # num_inference_steps, height and width (positional args 3, 4, 5) are
    # static and broadcast to every device; guidance is NOT static - it is
    # passed as the mapped array `g` in position 6.
    p_generate = pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5])
    print("Sharded prompt ids has shape:", prompt_ids.shape)
    print("Guidance shape:", g.shape)

    def generate():
        # Run one generation and wait for the result so timings are accurate.
        out = p_generate(
            prompt_ids, p_params, rng, num_inference_steps, height, width, g
        )
        return out.block_until_ready()

    # perf_counter is monotonic and intended for measuring elapsed time.
    start = time.perf_counter()
    images = generate()
    print("First inference time is:", time.perf_counter() - start)

    iters = opt.itters
    # Skip the averaged benchmark when no iterations were requested; dividing
    # by iters would otherwise raise ZeroDivisionError.
    if iters > 0:
        start = time.perf_counter()
        for _ in range(iters):
            images = generate()
        # Reported value is the mean time per call over `iters` calls.
        print("Second inference time is:", (time.perf_counter() - start) / iters)
    print("Shape of predictions is: ", images.shape)

    if opt.trace:
        trace_path = "/tmp/tensorboard"
        with jax.profiler.trace(trace_path):
            images = generate()
        print(f"trace can be found: {trace_path}")
if __name__ == "__main__":
    # Command-line entry point: parse the benchmark options and hand them to run().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Number of images to generate",
    )
    parser.add_argument(
        "--sd-version",
        type=int,
        # Reject unsupported versions up front instead of silently loading
        # SD 2.1 for any value other than 1.
        choices=(1, 2),
        default=1,
        help="Use 1 for SD 1.4, Use 2 for SD 2.1",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=512,
        help="Width",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=512,
        help="Height",
    )
    parser.add_argument(
        # "--itters" is kept for backward compatibility; "--iters" is the
        # correctly spelled alias. Both store into opt.itters.
        "--itters",
        "--iters",
        dest="itters",
        type=int,
        default=15,
        help="Number of iterations to run the benchmark.",
    )
    parser.add_argument(
        "--trace",
        action="store_true",  # defaults to False when the flag is absent
        help="Run a trace",
    )
    opt = parser.parse_args()
    run(opt)
Logs
The config attributes {'act_fn': 'silu', 'center_input_sample': False, 'downsample_padding': 1, 'dual_cross_attention': False, 'mid_block_scale_factor': 1, 'norm_eps': 1e-05, 'norm_num_groups': 32, 'num_class_embeds': None} were passed to FlaxUNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Traceback (most recent call last):
File "/home/jfacevedo/infer.py", line 120, in <module>
run(opt)
File "/home/jfacevedo/infer.py", line 30, in run
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
I0000 00:00:1695928082.495501 4918 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
System Info
diffusers version: 0.22.0.dev0
- Platform: Linux-5.19.0-1030-gcp-x86_64-with-glibc2.35
- Python version: 3.10.12
- PyTorch version (GPU?): 2.0.1+cu117 (False)
- Huggingface_hub version: 0.17.3
- Transformers version: 4.33.3
- Accelerate version: not installed
- xFormers version: not installed
- Using GPU in script?: No
- using TPU in script?: Yes
- Using distributed or parallel set-up in script?: No
Who can help?
No response
Describe the bug
When trying to load Stable Diffusion 2.1 using Flax, I am getting the following error:
Reproduction
Create a TPU VM and run the following installation:
Then run the script as follows:
python infer.py --sd-version 2 --itters 3
Logs
The config attributes {'act_fn': 'silu', 'center_input_sample': False, 'downsample_padding': 1, 'dual_cross_attention': False, 'mid_block_scale_factor': 1, 'norm_eps': 1e-05, 'norm_num_groups': 32, 'num_class_embeds': None} were passed to FlaxUNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
Traceback (most recent call last):
File "/home/jfacevedo/infer.py", line 120, in <module>
run(opt)
File "/home/jfacevedo/infer.py", line 30, in run
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
File "/home/jfacevedo/.local/lib/python3.10/site-packages/diffusers/pipelines/pipeline_flax_utils.py", line 535, in from_pretrained
raise ValueError(
ValueError: Pipeline <class 'diffusers.pipelines.stable_diffusion.pipeline_flax_stable_diffusion.FlaxStableDiffusionPipeline'> expected {'vae', 'scheduler', 'feature_extractor', 'text_encoder', 'safety_checker', 'tokenizer', 'unet'}, but only {'vae', 'scheduler', 'text_encoder', 'tokenizer', 'unet'} were passed.
I0000 00:00:1695928082.495501 4918 tfrt_cpu_pjrt_client.cc:352] TfrtCpuClient destroyed.
System Info
diffusers version: 0.22.0.dev0
Who can help?
No response