
Converting pytorch stable diffusion weights to jax doesn't always work #2081

@entrpn

Description


Describe the bug

I am trying to convert PyTorch weights to JAX as follows:

from diffusers import FlaxStableDiffusionPipeline
model_name='riffusion/riffusion-model-v1'
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_name, from_pt=True)
pipeline.save_pretrained('riffusion_jax', params=params)

This works with model_name='CompVis/stable-diffusion-v1-4' and model_name='runwayml/stable-diffusion-v1-5', but with model_name='riffusion/riffusion-model-v1' save_pretrained fails with ValueError('FrozenDict is immutable.') (full traceback under Logs below).

As a workaround, I modified to_json_string in configuration_utils.py as follows:

    def to_json_string(self) -> str:
        """
        Serializes this instance to a JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}

        ######## Modified
        try:
            config_dict["_class_name"] = self.__class__.__name__
            config_dict["_diffusers_version"] = __version__
        except ValueError:
            # The config can be a flax FrozenDict here, whose __setitem__ raises
            # ValueError("FrozenDict is immutable."), so copy it into a plain dict first.
            from flax.core.frozen_dict import unfreeze
            config_dict = unfreeze(config_dict)
            config_dict["_class_name"] = self.__class__.__name__
            config_dict["_diffusers_version"] = __version__
        ######## Modified

        def to_json_saveable(value):
            if isinstance(value, np.ndarray):
                value = value.tolist()
            return value

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
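A variant of the same workaround, sketched without the try/except: always copy the stored config into a plain, mutable dict before adding the metadata keys. `to_mutable` and `config_to_json_string` are hypothetical names; `to_mutable` plays the role of `flax.core.frozen_dict.unfreeze` but accepts any read-only `Mapping`, and `types.MappingProxyType` stands in for a FrozenDict so the sketch runs without Flax. The NumPy-array handling from the original is left out for brevity.

```python
import json
from collections.abc import Mapping
from types import MappingProxyType


def to_mutable(config):
    """Recursively copy any Mapping (e.g. a read-only FrozenDict) into plain dicts.

    Hypothetical helper: similar in effect to flax's unfreeze, but works for any
    Mapping so this sketch does not need Flax installed.
    """
    if isinstance(config, Mapping):
        return {key: to_mutable(value) for key, value in config.items()}
    return config


def config_to_json_string(internal_dict, class_name, version):
    """Hypothetical stand-in for to_json_string that never mutates the stored config."""
    # Work on a mutable copy, so dicts and frozen mappings take the same code path.
    config_dict = to_mutable(internal_dict)
    config_dict["_class_name"] = class_name
    config_dict["_diffusers_version"] = version
    return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"


# MappingProxyType stands in for FrozenDict: item assignment on it raises an error.
frozen = MappingProxyType({"sample_size": 32, "block_out_channels": (320, 640)})
print(config_to_json_string(frozen, "FlaxUNet2DConditionModel", "0.11.1"))
```

Copying up front also means the stored config is never modified as a side effect of serializing it.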

With this modified to_json_string I was able to convert the weights. However, when I try to use them, I get another error:

Traceback (most recent call last):
  File "/Users/hidden/apps.noindex/stable-diffusion-experiments/riffusion-weights-to-jax/infer.py", line 11, in <module>
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained('riffusion_jax',dtype=jnp.bfloat16, ignore_mismatched_sizes=True)
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/pipeline_flax_utils.py", line 466, in from_pretrained
    loaded_sub_model, loaded_params = load_method(loadable_folder, _do_init=False)
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/transformers/modeling_flax_utils.py", line 864, in from_pretrained
    raise ValueError(
ValueError: Trying to load the pretrained weight for ('text_model', 'embeddings', 'position_embedding', 'embedding') failed: checkpoint has shape (768,) which is incompatible with the model shape (77, 768). Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this model.

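Passing ignore_mismatched_sizes=True (see the inference code below) does not help; I still get the same error. Judging from the traceback, the argument may not reach the text encoder at all: pipeline_flax_utils.py calls load_method(loadable_folder, _do_init=False) without it.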
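Since the loader stops at the first mismatch, it may help to list every entry whose shape differs between the saved params and what the model expects. A NumPy-only sketch with hypothetical helper names; the toy trees reuse the key path and shapes from the error message. For real use, the checkpoint tree could come from `flax.serialization.msgpack_restore` on the text encoder's saved weights file and the expected tree from a freshly initialised `FlaxCLIPTextModel`, but that wiring is an assumption and is not shown.

```python
from collections.abc import Mapping

import numpy as np


def flatten(tree, prefix=()):
    """Flatten a nested mapping of arrays into {("a", "b", ...): array}.

    Similar in spirit to flax.traverse_util.flatten_dict; reimplemented so the
    sketch only needs NumPy.
    """
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, Mapping):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def find_shape_mismatches(checkpoint, expected):
    """Hypothetical helper: return (path, checkpoint_shape, expected_shape) for
    every parameter present in both trees whose shapes differ."""
    ckpt_flat = flatten(checkpoint)
    exp_flat = flatten(expected)
    mismatches = []
    for path in sorted(ckpt_flat.keys() & exp_flat.keys()):
        ckpt_shape = np.shape(ckpt_flat[path])
        exp_shape = np.shape(exp_flat[path])
        if ckpt_shape != exp_shape:
            mismatches.append((path, ckpt_shape, exp_shape))
    return mismatches


# Toy trees mirroring the error: a (768,) checkpoint entry where (77, 768) is expected.
checkpoint = {"text_model": {"embeddings": {"position_embedding": {"embedding": np.zeros(768)}}}}
expected = {"text_model": {"embeddings": {"position_embedding": {"embedding": np.zeros((77, 768))}}}}
print(find_shape_mismatches(checkpoint, expected))
```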

Inference code:

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

num_inference_steps = 50
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "riffusion_jax", dtype=jnp.bfloat16, ignore_mismatched_sizes=True
)

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
# One prompt per device.
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)

# Copy the params to every device and split the batch across devices.
params = replicate(params)
prompt_ids = shard(prompt_ids)

# One PRNG key per device.
rng = jax.random.split(jax.random.PRNGKey(0), jax.device_count())

images = pipeline(
    prompt_ids, params, rng,
    num_inference_steps=num_inference_steps, height=512, width=512, jit=True,
)[0]
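The inference code relies on a couple of shape conventions: `shard` splits the leading batch axis across devices, and with `jit=True` the images come back with an extra leading device axis. A NumPy-only sketch of that bookkeeping, with hypothetical helper names. `shard_leading_axis` mirrors the per-array reshape that `flax.training.common_utils.shard` performs, and `unshard_images` assumes the pipeline returns float images in [0, 1] laid out as (devices, per_device_batch, H, W, C).

```python
import numpy as np


def shard_leading_axis(x, num_devices):
    """Reshape (batch, ...) to (num_devices, batch // num_devices, ...).

    Mirrors the reshape flax's shard applies to each array (using the local
    device count); written with NumPy for illustration.
    """
    if x.shape[0] % num_devices != 0:
        raise ValueError(f"batch {x.shape[0]} is not divisible by {num_devices} devices")
    return x.reshape((num_devices, -1) + x.shape[1:])


def unshard_images(images):
    """Flatten (num_devices, per_device_batch, H, W, C) float images in [0, 1]
    into an (N, H, W, C) uint8 array."""
    images = np.asarray(images)
    flat = images.reshape((-1,) + images.shape[2:])
    return (np.clip(flat, 0.0, 1.0) * 255).round().astype(np.uint8)


prompt_ids = np.zeros((8, 77), dtype=np.int32)  # e.g. 8 devices, one prompt each
print(shard_leading_axis(prompt_ids, 8).shape)  # (8, 1, 77)
fake_images = np.random.rand(8, 1, 64, 64, 3)
print(unshard_images(fake_images).shape)  # (8, 64, 64, 3)
```

Each image in the flattened uint8 array can then be passed to `PIL.Image.fromarray`.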

Thank you!

Reproduction

from diffusers import FlaxStableDiffusionPipeline
model_name='riffusion/riffusion-model-v1'
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_name, from_pt=True)
pipeline.save_pretrained('riffusion_jax', params=params)

Logs

Fetching 15 files: 100%|███████████████████████████████████████████████████████████████████| 15/15 [00:00<00:00, 8840.04it/s]
Some weights of the model checkpoint at /Users/hidden/.cache/huggingface/diffusers/models--riffusion--riffusion-model-v1/snapshots/ee6dd541d8d283ed37b4f61ee948f1d58b8687ae/safety_checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py:195: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.If you were trying to load a model, please use <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.load_config(...) followed by <class 'diffusers.models.unet_2d_condition_flax.FlaxUNet2DConditionModel'>.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary instead. This functionality will be removed in v1.0.0.
  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
Some weights of the model checkpoint at /Users/hidden/.cache/huggingface/diffusers/models--riffusion--riffusion-model-v1/snapshots/ee6dd541d8d283ed37b4f61ee948f1d58b8687ae/text_encoder were not used when initializing FlaxCLIPTextModel: {('text_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py:195: FutureWarning: It is deprecated to pass a pretrained model name or path to `from_config`.
  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py:120: FutureWarning: The configuration file of the unet has set the default `sample_size` to smaller than 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the following: 
- CompVis/stable-diffusion-v1-4 
- CompVis/stable-diffusion-v1-3 
- CompVis/stable-diffusion-v1-2 
- CompVis/stable-diffusion-v1-1 
- runwayml/stable-diffusion-v1-5 
- runwayml/stable-diffusion-inpainting 
 you should change 'sample_size' to 64 in the configuration file. Please make sure to update the config accordingly as leaving `sample_size=32` in the config might lead to incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for the `unet/config.json` file
  deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
Some weights of the model checkpoint at CompVis/stable-diffusion-safety-checker were not used when initializing FlaxStableDiffusionSafetyChecker: {('vision_model', 'vision_model', 'embeddings', 'position_ids')}
- This IS expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxStableDiffusionSafetyChecker from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Traceback (most recent call last):
  File "/Users/hidden/apps.noindex/stable-diffusion-experiments/riffusion-weights-to-jax/main.py", line 72, in <module>
    new_pipeline.save_pretrained(
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/pipeline_flax_utils.py", line 189, in save_pretrained
    save_method(
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/modeling_flax_utils.py", line 518, in save_pretrained
    model_to_save.save_config(save_directory)
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 137, in save_config
    self.to_json_file(output_config_file)
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 527, in to_json_file
    writer.write(self.to_json_string())
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/diffusers/configuration_utils.py", line 507, in to_json_string
    config_dict["_class_name"] = self.__class__.__name__
  File "/Users/hidden/.local/share/virtualenvs/stable-diffusion-experiments-IsQz2u7U/lib/python3.10/site-packages/flax/core/frozen_dict.py", line 72, in __setitem__
    raise ValueError('FrozenDict is immutable.')
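The FutureWarning in these logs asks for `sample_size` to be changed to 64 in the unet's `unet/config.json`, provided the checkpoint is a fine-tune of one of the listed models. For 512×512 output and the SD v1 VAE's 8× spatial downsampling, the latent side is 512 / 8 = 64. A minimal sketch of making that edit to a local copy of a checkpoint with the standard `json` module; the helper name is hypothetical and the `<dir>/unet/config.json` layout follows the path named in the warning.

```python
import json
from pathlib import Path


def patch_unet_sample_size(pipeline_dir, image_size=512, vae_scale_factor=8):
    """Hypothetical helper: set sample_size in <pipeline_dir>/unet/config.json
    to image_size // vae_scale_factor and return the previous value.

    vae_scale_factor=8 assumes the SD v1 VAE (three 2x downsampling stages).
    """
    if image_size % vae_scale_factor != 0:
        raise ValueError("image_size must be a multiple of vae_scale_factor")
    config_path = Path(pipeline_dir) / "unet" / "config.json"
    config = json.loads(config_path.read_text())
    old_size = config.get("sample_size")
    config["sample_size"] = image_size // vae_scale_factor  # 512 // 8 == 64
    config_path.write_text(json.dumps(config, indent=2, sort_keys=True) + "\n")
    return old_size


# Example (hypothetical local path):
# patch_unet_sample_size("path/to/local/riffusion-model-v1")
```

Returning the previous value makes it easy to log or revert the change.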

System Info

  • diffusers version: 0.11.1
  • Platform: macOS-13.1-x86_64-i386-64bit
  • Python version: 3.10.1
  • PyTorch version (GPU?): 1.13.1 (False)
  • Huggingface_hub version: 0.11.1
  • Transformers version: 4.25.1
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working)
