Describe the bug
I am trying to convert PyTorch weights to JAX/Flax as follows:
from diffusers import FlaxStableDiffusionPipeline
model_name='riffusion/riffusion-model-v1'
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_name, from_pt=True)
pipeline.save_pretrained('riffusion_jax', params=params)
This works with model_name='CompVis/stable-diffusion-v1-4' and model_name='runwayml/stable-diffusion-v1-5'. With model_name='riffusion/riffusion-model-v1', however, save_pretrained fails with ValueError('FrozenDict is immutable.'). The full traceback is under Logs.
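The traceback under Logs shows where this fails. In diffusers/modeling_flax_utils.py, save_pretrained calls to_json_string, which writes _class_name into the model's config mapping, and that mapping is a flax FrozenDict at that point. Below is a minimal stdlib illustration of the failure mode. It uses types.MappingProxyType as a stand-in for FrozenDict so it runs without Flax. That substitution is an assumption: the real FrozenDict raises ValueError, not TypeError.

```python
from types import MappingProxyType

# Stand-in for flax's FrozenDict: a read-only view over a plain dict.
frozen_config = MappingProxyType({"sample_size": 64, "in_channels": 4})

try:
    # The same kind of write that to_json_string performs on the config.
    frozen_config["_class_name"] = "FlaxUNet2DConditionModel"
    write_failed = False
except TypeError:
    # MappingProxyType raises TypeError; flax's FrozenDict raises ValueError instead.
    write_failed = True

# Copying into a plain dict first yields a mutable mapping; the original stays untouched.
config_dict = dict(frozen_config)
config_dict["_class_name"] = "FlaxUNet2DConditionModel"
print(write_failed, config_dict)
```

dict(...) makes only a shallow copy. flax.core.frozen_dict.unfreeze also converts nested FrozenDicts, which matters only if a config value is itself a mapping.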
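As a workaround, I modified to_json_string in diffusers' configuration_utils.py as follows. It is a method of the config mixin, and json, np and __version__ are already imported at module level.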
def to_json_string(self) -> str:
    """
    Serializes this instance to a JSON string.

    Returns:
        `str`: String containing all the attributes that make up this configuration instance in JSON format.
    """
    config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
    print(config_dict)
    ######## Modified
    try:
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__
    except ValueError:  # raised by flax's FrozenDict.__setitem__
        from flax.core.frozen_dict import unfreeze

        # Work on a mutable copy instead of the frozen config.
        config_dict = unfreeze(config_dict)
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__
    ######## Modified

    def to_json_saveable(value):
        # NumPy arrays are not JSON-serializable; convert them to lists.
        if isinstance(value, np.ndarray):
            value = value.tolist()
        return value

    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
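A variant of the same workaround avoids the try/except entirely. It always serializes from a plain dict copy, so it behaves the same whether _internal_dict is a dict or a frozen mapping. This is a hedged sketch written as a free function, not the actual diffusers method. __version__ is a placeholder, the _DemoConfig class is hypothetical, and the np.ndarray check is replaced by a duck-typed tolist() check so the sketch runs without NumPy.

```python
import json
from types import MappingProxyType

__version__ = "0.11.1"  # placeholder for diffusers.__version__


def to_json_string(config) -> str:
    """Serialize config._internal_dict to JSON without writing into it."""
    # dict(...) makes a mutable shallow copy of any Mapping, frozen or not.
    config_dict = dict(getattr(config, "_internal_dict", {}))
    config_dict["_class_name"] = type(config).__name__
    config_dict["_diffusers_version"] = __version__

    def to_json_saveable(value):
        # Duck-typed stand-in for the np.ndarray check: arrays expose tolist().
        if hasattr(value, "tolist"):
            value = value.tolist()
        return value

    config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"


class _DemoConfig:
    """Hypothetical model stand-in whose config was frozen after construction."""

    def __init__(self, **kwargs):
        self._internal_dict = MappingProxyType(dict(kwargs))


demo = _DemoConfig(sample_size=64, in_channels=4)
serialized = to_json_string(demo)
print(serialized)
```

Unlike the original method, this version never writes _class_name or _diffusers_version back into the stored config. It only adds them to the serialized copy.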
This allowed me to convert the weights. However, when I try to load them, I get another error:
Traceback (most recent call last):
File "/Users/hidden/apps.noindex/stable-diffusion-experiments/riffusion-weights-to-jax/infer.py", line 11, in <module>
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained('riffusion_jax',dtype=jnp.bfloat16, ignore_mismatched_sizes=True)
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/pipeline_flax_utils.py", line 466, in from_pretrained
loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 864, in from_pretrained
raise ValueError(
ValueError: Trying to load the pretrained weight for ('text_model', 'embeddings', 'position_embedding', 'embedding') failed: checkpoint has shape (768,) which is incompatible with the model shape (77, 768). Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this model.
I tried passing ignore_mismatched_sizes=True to from_pretrained, but I still get this error.
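The traceback may explain why the flag has no effect. The pipeline loads the text encoder with load_method(loadable_folder, _do_init=False), so ignore_mismatched_sizes=True given to FlaxStableDiffusionPipeline.from_pretrained apparently never reaches FlaxCLIPTextModel.from_pretrained. To narrow down which saved parameters are wrong, it can help to list every path whose saved shape differs from the shape the model expects. The helper below is hypothetical and not part of diffusers or transformers. It compares two nested mappings of shapes. With real Flax params, such shape trees could be built with jax.tree_util.tree_map(lambda x: x.shape, params), for example on a tree restored with flax.serialization.msgpack_restore.

```python
from collections.abc import Mapping


def flatten_shapes(tree, prefix=()):
    """Flatten a nested mapping of shapes into {key_path: shape}."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            flat.update(flatten_shapes(value, path))
        else:
            flat[path] = tuple(value)
    return flat


def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return (path, checkpoint_shape, model_shape) for every shared path whose shapes differ."""
    ckpt = flatten_shapes(checkpoint_shapes)
    model = flatten_shapes(model_shapes)
    return sorted(
        (path, ckpt[path], model[path])
        for path in ckpt.keys() & model.keys()
        if ckpt[path] != model[path]
    )


# Shapes taken from the error message above.
checkpoint = {"text_model": {"embeddings": {"position_embedding": {"embedding": (768,)}}}}
expected = {"text_model": {"embeddings": {"position_embedding": {"embedding": (77, 768)}}}}
mismatches = find_shape_mismatches(checkpoint, expected)
print(mismatches)
```

The recursion checks Mapping rather than dict so that it also descends into flax FrozenDicts, which are not dict subclasses.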
Inference code:
from diffusers import FlaxStableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker
import jax
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
import jax.numpy as jnp
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained('riffusion_jax',dtype=jnp.bfloat16, ignore_mismatched_sizes=True)
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
params = replicate(params)
prng_seed = jax.random.split(prng_seed, 8)
prompt_ids = shard(prompt_ids)
def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
images = pipeline(prompt_ids, params, rng, height=512, width=512, jit=True)[0]
Thank you!
Reproduction
from diffusers import FlaxStableDiffusionPipeline
model_name='riffusion/riffusion-model-v1'
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_name, from_pt=True)
pipeline.save_pretrained('riffusion_jax', params=params)
Logs
Fetching 15 files: 100%|███████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 8840.04it/s]
Some weights of the model checkpoint at /Users/hidden/.cache/huggingface/diffusers/models--riffusion--riffusion-model-v1/snapshots/ee6dd541d8d283ed37b4f61ee948f1d58b8687ae/safety_checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py:195: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.If you were trying to load a model, please use <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.load_config(...) followed by <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0.
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
Some weights of the model checkpoint at /Users/hidden/.cache/huggingface/diffusers/models--riffusion--riffusion-model-v1/snapshots/ee6dd541d8d283ed37b4f61ee948f1d58b8687ae/text_encoder were not used when initializing FlaxCLIPTextModel: {('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py:195: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.
deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py:120: FutureWarning: The configuration file of the unet has set the default `sample_size` to smaller than 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the following:
- CompVis/stable-diffusion-v1-4
- CompVis/stable-diffusion-v1-3
- CompVis/stable-diffusion-v1-2
- CompVis/stable-diffusion-v1-1
- runwayml/stable-diffusion-v1-5
- runwayml/stable-diffusion-inpainting
you should change 'sample_size' to 64 in the configuration file. Please make sure to update the config accordingly as leaving `sample_size=32` in the config might lead to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `unet/config.json` file
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
Some weights of the model checkpoint at CompVis/stable-diffusion-safety-checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Traceback (most recent call last):
File "/Users/hidden/apps.noindex/stable-diffusion-experiments/riffusion-weights-to-jax/main.py", line 72, in <module>
new_pipeline.save_pretrained(
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/pipeline_flax_utils.py", line 189, in save_pretrained
save_method(
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/modeling_flax_utils.py", line 518, in save_pretrained
model_to_save.save_config(save_directory)
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 137, in save_config
self.to_json_file(output_config_file)
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 527, in to_json_file
writer.write(self.to_json_string())
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 507, in to_json_string
config_dict["_class_name"] = self.__class__.__name__
File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/flax/core/frozen_dict.py", line 72, in __setitem__
raise ValueError('FrozenDict is immutable.')
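The FutureWarning in the logs above recommends changing sample_size to 64 in unet/config.json when the checkpoint is a fine-tune of one of the listed models. That applies here if riffusion-model-v1 is such a fine-tune. Below is a minimal sketch that applies that change to a local copy of the checkpoint. The helper name set_unet_sample_size is hypothetical. Editing files inside the Hugging Face cache directly is best avoided. Whether this also avoids the FrozenDict error is an assumption: the warning and the error occur in the same run, but the logs do not show that they are connected.

```python
import json
import tempfile
from pathlib import Path


def set_unet_sample_size(model_dir, sample_size=64):
    """Rewrite unet/config.json under model_dir and return the previous sample_size."""
    config_path = Path(model_dir) / "unet" / "config.json"
    config = json.loads(config_path.read_text())
    previous = config.get("sample_size")
    config["sample_size"] = sample_size
    config_path.write_text(json.dumps(config, indent=2, sort_keys=True) + "\n")
    return previous


# Demo on a throwaway directory standing in for a local copy of the checkpoint.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "unet").mkdir()
    (Path(tmp) / "unet" / "config.json").write_text(json.dumps({"sample_size": 32}))
    previous = set_unet_sample_size(tmp)
    patched = json.loads((Path(tmp) / "unet" / "config.json").read_text())

print(previous, patched)
```

After patching, the local directory would be passed to FlaxStableDiffusionPipeline.from_pretrained(..., from_pt=True) in place of the Hub model name.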
System Info
- diffusers version: 0.11.1
- Platform: macOS-13.1-x86_64-i386-64bit
- Python version: 3.10.1
- PyTorch version (GPU?): 1.13.1 (False)
- Huggingface_hub version: 0.11.1
- Transformers version: 4.25.1
- Using GPU in script?: no
- Using distributed or parallel set-up in script?: no