[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/notebooks/blob/main/camenduru's_flax_to_pt_converter.ipynb)

In [None]:
!pip install -q torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 torchtext==0.14.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu116 -U
!pip -q install transformers accelerate flax jax
!pip -q install git+https://github.com/camenduru/diffusers@from_flax_v2 -U

In [None]:
# Set up Git LFS for this user so that later `git clone` calls of repositories
# that store large files through LFS download the actual file contents.
!git lfs install

In [None]:
# Clone the `flax` branch of runwayml/stable-diffusion-v1-5 into /content/flax.
# NOTE(review): none of the cells visible in this notebook read from /content/flax —
# confirm this (large) clone is still needed.
!git clone -b flax https://huggingface.co/runwayml/stable-diffusion-v1-5 /content/flax

In [None]:
import torch
from diffusers import StableDiffusionPipeline

# Build a PyTorch StableDiffusionPipeline from the `main` revision of the
# camenduru/plushies repo, asking diffusers to read the Flax weights
# (from_flax=True) and to skip the safety checker.
pipe = StableDiffusionPipeline.from_pretrained(
    "camenduru/plushies",
    revision="main",
    safety_checker=None,
    low_cpu_mem_usage=False,
    from_flax=True,
)
# Keep every module on the CPU.
pipe = pipe.to("cpu")

In [None]:
# Write the converted pipeline to /content/plushies-pt; this directory is what the
# upload cell further down pushes to the Hub.
pipe.save_pretrained("/content/plushies-pt")

In [None]:
# Generate a single 512x512 image for the prompt "cat" with 20 inference steps and
# render it inline (presumably a quick visual check of the converted weights).
image = pipe("cat", num_inference_steps=20, height=512, width=512).images[0]
display(image)

In [None]:
# Delete the Hugging Face cache directory (everything downloaded by from_pretrained
# so far), presumably to free disk space before the next download.
# NOTE(review): a token stored by `huggingface-cli login` normally lives under this
# directory too — confirm that wiping it is intended.
!rm -rf /root/.cache/huggingface

In [None]:
import jax
from diffusers import FlaxStableDiffusionPipeline

# Load the same repo with the native Flax pipeline in bfloat16; from_pretrained
# returns the pipeline object together with its parameter tree.
pipe2, params = FlaxStableDiffusionPipeline.from_pretrained(
    "camenduru/plushies",
    revision="main",
    dtype=jax.numpy.bfloat16,
    safety_checker=None,
)

In [None]:
import os

from huggingface_hub import create_repo, upload_folder

# Never paste an access token into the notebook: take it from the HF_TOKEN
# environment variable. When the variable is unset (or empty) pass None so that
# huggingface_hub falls back to a locally stored login token, if there is one.
hf_token = os.environ.get("HF_TOKEN") or None

# exist_ok=True keeps the cell re-runnable after the repository has been created.
create_repo("camenduru/plushies-pt", private=True, token=hf_token, exist_ok=True)

# Upload the directory written by pipe.save_pretrained(); the absolute path makes
# the cell independent of the current working directory.
upload_folder(
    folder_path="/content/plushies-pt",
    path_in_repo="",
    repo_id="camenduru/plushies-pt",
    commit_message="plushies flax to pt",
    token=hf_token,
)

In [None]:
!pip -q install flax

In [None]:
# Download the Flax weight files (VAE, UNet, text encoder) of camenduru/plushies
# into the ./vae, ./unet and ./clip directories (`-P` sets the target directory).
!wget https://huggingface.co/camenduru/plushies/resolve/main/vae/diffusion_flax_model.msgpack -P vae
!wget https://huggingface.co/camenduru/plushies/resolve/main/unet/diffusion_flax_model.msgpack -P unet
!wget https://huggingface.co/camenduru/plushies/resolve/main/text_encoder/flax_model.msgpack -P clip

In [None]:
# Download the Flax weight files (VAE, UNet, text encoder) from the `flax` branch of
# runwayml/stable-diffusion-v1-5 into ./vae, ./unet and ./clip.
# NOTE(review): these use the same directories and file names as the camenduru/plushies
# downloads in the preceding cell. By default wget does not overwrite an existing file
# but saves the new one with a `.1` suffix, so if both cells are run the later cells
# that open `diffusion_flax_model.msgpack` read whichever file was downloaded first.
# Confirm which model is meant to be inspected and run only that cell.
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/flax/vae/diffusion_flax_model.msgpack -P vae
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/flax/unet/diffusion_flax_model.msgpack -P unet
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/flax/text_encoder/flax_model.msgpack -P clip

In [None]:
# Download the PyTorch weight files (VAE, UNet, text encoder) from the `main` branch of
# runwayml/stable-diffusion-v1-5 into ./vae, ./unet and ./clip; the cells below load
# them from /content/vae, /content/unet and /content/clip.
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/diffusion_pytorch_model.bin -P vae
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin -P unet
!wget https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/pytorch_model.bin -P clip

In [None]:
import torch

# Load the three reference PyTorch checkpoints downloaded above.
# map_location="cpu" places every tensor on the CPU, so the load also works on a
# runtime without a GPU regardless of the device the checkpoint was saved from.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from sources you trust.
pytorch_vae_nested_state_dict = torch.load("/content/vae/diffusion_pytorch_model.bin", map_location="cpu")
pytorch_unet_nested_state_dict = torch.load("/content/unet/diffusion_pytorch_model.bin", map_location="cpu")
pytorch_clip_nested_state_dict = torch.load("/content/clip/pytorch_model.bin", map_location="cpu")

def print_state_dict_keys(state_dict, indent=0, file=None):
    """Write the keys of a (possibly nested) dict as an indented tree.

    Keys are written in sorted order, one per line. Values that are dicts are
    recursed into with two additional spaces of indentation.

    Args:
        state_dict: Dict with string keys; values may themselves be dicts.
        indent: Number of spaces written before each key at this level.
        file: Writable text stream. When None, ``sys.stdout`` is used.
    """
    if file is None:
        # Fall back to stdout so that calling without `file` works instead of
        # failing on `None.write`.
        import sys
        file = sys.stdout
    for key in sorted(state_dict):
        file.write(" " * indent + key + "\n")
        value = state_dict[key]
        if isinstance(value, dict):
            print_state_dict_keys(value, indent + 2, file)

# with open("/content/pytorch_vae_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(pytorch_vae_nested_state_dict, file=f)

# with open("/content/pytorch_unet_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(pytorch_unet_nested_state_dict, file=f)

# with open("/content/pytorch_clip_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(pytorch_clip_nested_state_dict, file=f)

import pprint

# Pretty-print each loaded PyTorch checkpoint into its own text file so the
# contents can be inspected outside the notebook.
for dump_path, checkpoint in (
    ("vae-original.txt", pytorch_vae_nested_state_dict),
    ("unet-original.txt", pytorch_unet_nested_state_dict),
    ("clip-original.txt", pytorch_clip_nested_state_dict),
):
    with open(dump_path, "w") as f:
        pprint.pprint(checkpoint, f)

In [None]:
import numpy as np
import flax
import jax.numpy as jnp
import torch

# with open("/content/vae/diffusion_flax_model.msgpack", "rb") as state_vae_f:
#     try:
#         flax_vae_nested_state_dict = flax.serialization.from_bytes(None, state_vae_f.read())
#     except e:
#         raise print(e)

# Deserialize the Flax UNet weights from their msgpack file.
with open("/content/unet/diffusion_flax_model.msgpack", "rb") as state_unet_f:
    try:
        # No target template is passed (None); the result is used as a nested dict
        # by the cells below.
        flax_unet_nested_state_dict = flax.serialization.from_bytes(None, state_unet_f.read())
    except Exception as e:
        # Show the error message, then re-raise the original exception with its
        # traceback. (`except e:` looked up an undefined name and
        # `raise print(e)` tried to raise None.)
        print(e)
        raise

# with open("/content/clip/flax_model.msgpack", "rb") as state_clip_f:
#     try:
#         flax_clip_nested_state_dict = flax.serialization.from_bytes(None, state_clip_f.read())
#     except e:
#         raise print(e)

def print_state_dict_keys(state_dict, indent=0, file=None):
    """Write the keys of a (possibly nested) dict as an indented tree.

    Keys are written in sorted order, one per line. Values that are dicts are
    recursed into with two additional spaces of indentation.

    Args:
        state_dict: Dict with string keys; values may themselves be dicts.
        indent: Number of spaces written before each key at this level.
        file: Writable text stream. When None, ``sys.stdout`` is used.
    """
    if file is None:
        # Fall back to stdout so that calling without `file` works instead of
        # failing on `None.write`.
        import sys
        file = sys.stdout
    for key in sorted(state_dict):
        file.write(" " * indent + key + "\n")
        value = state_dict[key]
        if isinstance(value, dict):
            print_state_dict_keys(value, indent + 2, file)

# with open("vae/flax_vae_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(flax_vae_nested_state_dict, file=f)

# with open("unet/flax_unet_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(flax_unet_nested_state_dict, file=f)

# with open("clip/flax_clip_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(flax_clip_nested_state_dict, file=f)

In [None]:
 # Print every top-level value of the nested Flax VAE parameter dict, in key order.
 # NOTE(review): among the cells visible here, `flax_vae_nested_state_dict` is only
 # assigned in a later cell (the earlier assignment is commented out), so on a fresh
 # kernel run top-to-bottom this cell would raise NameError — confirm the cell order.
 # NOTE(review): despite the names, the keys iterated here are the top-level keys of
 # the nested dict and the values are presumably sub-dicts rather than single
 # tensors, so this prints the full contents of the checkpoint.
 for flax_key_tuple, flax_tensor in sorted(flax_vae_nested_state_dict.items()):
   print(flax_tensor)

In [None]:
# Load the reference PyTorch VAE checkpoint.
# map_location="cpu" places every tensor on the CPU, so the load also works on a
# runtime without a GPU regardless of the device the checkpoint was saved from.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from sources you trust.
pytorch_vae_nested_state_dict = torch.load("/content/vae/diffusion_pytorch_model.bin", map_location="cpu")
# pytorch_unet_nested_state_dict = torch.load("/content/unet/diffusion_pytorch_model.bin")
# pytorch_clip_nested_state_dict = torch.load("/content/clip/pytorch_model.bin")

def print_state_dict_keys(state_dict, indent=0, file=None):
    """Write the keys of a (possibly nested) dict as an indented tree.

    Keys are written in sorted order, one per line. Values that are dicts are
    recursed into with two additional spaces of indentation.

    Args:
        state_dict: Dict with string keys; values may themselves be dicts.
        indent: Number of spaces written before each key at this level.
        file: Writable text stream. When None, ``sys.stdout`` is used.
    """
    if file is None:
        # Fall back to stdout so that calling without `file` works instead of
        # failing on `None.write`.
        import sys
        file = sys.stdout
    for key in sorted(state_dict):
        file.write(" " * indent + key + "\n")
        value = state_dict[key]
        if isinstance(value, dict):
            print_state_dict_keys(value, indent + 2, file)

# Write the key tree of the PyTorch VAE checkpoint to a text file inside ./vae
# (the path is relative to the current working directory).
with open("vae/pytorch_vae_nested_state_dict_keys.txt", "w") as f:
    print_state_dict_keys(pytorch_vae_nested_state_dict, file=f)

# with open("unet/pytorch_unet_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(pytorch_unet_nested_state_dict, file=f)

# with open("clip/pytorch_clip_nested_state_dict_keys.txt", "w") as f:
#     print_state_dict_keys(pytorch_clip_nested_state_dict, file=f)

In [None]:
import numpy as np
import flax
import jax.numpy as jnp
import torch

# Deserialize the Flax VAE weights from their msgpack file.
with open("/content/vae/diffusion_flax_model.msgpack", "rb") as state_vae_f:
    try:
        # No target template is passed (None); the result is used as a nested dict
        # by the other cells of this notebook.
        flax_vae_nested_state_dict = flax.serialization.from_bytes(None, state_vae_f.read())
    except Exception as e:
        # Show the error message, then re-raise the original exception with its
        # traceback. (`except e:` looked up an undefined name and
        # `raise print(e)` tried to raise None.)
        print(e)
        raise

# with open("/content/unet/diffusion_flax_model.msgpack", "rb") as state_unet_f:
#     try:
#         flax_unet_nested_state_dict = flax.serialization.from_bytes(None, state_unet_f.read())
#     except e:
#         raise print(e)

# with open("/content/clip/flax_model.msgpack", "rb") as state_clip_f:
#     try:
#         flax_clip_nested_state_dict = flax.serialization.from_bytes(None, state_clip_f.read())
#     except e:
#         raise print(e)

# flax_vae_nested_state_dict_flattened = flax.traverse_util.flatten_dict(flax_vae_nested_state_dict, sep=".")
# flax_unet_nested_state_dict_flattened = flax.traverse_util.flatten_dict(flax_unet_nested_state_dict, sep=".")
# flax_clip_nested_state_dict_flattened = flax.traverse_util.flatten_dict(flax_clip_nested_state_dict, sep=".")

# with open("vae/flax_vae_nested_state_dict_keys_flattened.txt", "w") as f:
#     for key in sorted(flax_vae_nested_state_dict_flattened.keys()):
#         f.write(key + "\n")
# with open("unet/flax_unet_nested_state_dict_keys_flattened.txt", "w") as f:
#     for key in sorted(flax_unet_nested_state_dict_flattened.keys()):
#         f.write(key + "\n")
# with open("clip/flax_clip_nested_state_dict_keys_flattened.txt", "w") as f:
#     for key in sorted(flax_clip_nested_state_dict_flattened.keys()):
#         f.write(key + "\n")

In [None]:
# Flatten the nested Flax VAE params. No `sep` is given, so every key is a tuple of
# path components rather than a dotted string.
flax_vae_nested_state_dict_flattened = flax.traverse_util.flatten_dict(flax_vae_nested_state_dict)

import jax.numpy as jnp

# Walk the parameters in key order and derive the PyTorch-style name (and layout)
# for each one. The converted name/tensor are only rebound to the loop variables;
# nothing is collected yet.
for flax_key_tuple, flax_tensor in sorted(flax_vae_nested_state_dict_flattened.items()):
  # Copy the tuple key into a list: the renaming below concatenates with lists and
  # assigns by index, both of which raise TypeError on a tuple.
  flax_key_tuple_array = list(flax_key_tuple)
  # flax_key_tuple_array = flax_key_tuple.split('.')
  if flax_key_tuple_array[-1] == "kernel" and flax_tensor.ndim == 4:
    # 4-D kernel: rename to "weight" and permute the axes to (3, 2, 0, 1)
    # (presumably Flax conv layout HWIO -> PyTorch OIHW — TODO confirm).
    flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]
    flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
  elif flax_key_tuple_array[-1] == "kernel":
    # Any other kernel: rename to "weight" and transpose.
    flax_key_tuple_array = flax_key_tuple_array[:-1] +  ["weight"]
    flax_tensor = flax_tensor.T
    print(flax_key_tuple)
  elif flax_key_tuple_array[-1] == "scale":
    # "scale" parameters are only renamed to "weight"; the values are untouched.
    flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]

  # Turn the "_<index>" suffixes 0-3 into ".<index>" path components.
  for i, flax_key_tuple_string in enumerate(flax_key_tuple_array):
    flax_key_tuple_array[i] = flax_key_tuple_string.replace('_0', '.0').replace('_1', '.1').replace('_2', '.2').replace('_3', '.3')

  flax_key_tuple = ".".join(flax_key_tuple_array)
  # print(flax_key_tuple)
  # print(flax_key_tuple)

In [None]:
import jax.numpy as jnp
import flax

# Deserialize the Flax UNet weights from their msgpack file (no target template).
with open("/content/unet/diffusion_flax_model.msgpack", "rb") as state_unet_f:
  flax_unet_nested_state_dict = flax.serialization.from_bytes(None, state_unet_f.read())

# Flatten the nested dict; with sep="." every key becomes a single dotted string,
# which the conversion loop in the next cell splits again on ".".
flax_unet_nested_state_dict_flattened = flax.traverse_util.flatten_dict(flax_unet_nested_state_dict, sep=".")

In [None]:
# Derive and print the PyTorch-style name of every Flax UNet parameter, in key order.
for flax_key_tuple, flax_tensor in sorted(flax_unet_nested_state_dict_flattened.items()):
  flax_key_tuple_array = flax_key_tuple.split('.')

  if flax_key_tuple_array[-1] in ("kernel", "scale"):
    if flax_key_tuple_array[-1] == "kernel":
      # 4-D kernels get their axes permuted to (3, 2, 0, 1); every other kernel
      # is simply transposed. "scale" values are left as they are.
      if flax_tensor.ndim == 4:
        flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
      else:
        flax_tensor = flax_tensor.T
    # Both "kernel" and "scale" leaves are called "weight" on the PyTorch side.
    flax_key_tuple_array = flax_key_tuple_array[:-1] + ["weight"]

  # Keys with a "time_embedding" component keep their "_<index>" suffixes; for all
  # other keys the suffixes 0-3 become ".<index>" path components.
  if "time_embedding" not in flax_key_tuple_array:
    flax_key_tuple_array = [
      part.replace('_0', '.0').replace('_1', '.1').replace('_2', '.2').replace('_3', '.3')
      for part in flax_key_tuple_array
    ]

  flax_key_tuple = ".".join(flax_key_tuple_array)
  # if "time_embedding" in flax_key_tuple:
  #    print(flax_key_tuple)
  print(flax_key_tuple)

In [None]:
import jax.numpy as jnp
# Permute the axes of `flax_tensor` to (3, 2, 0, 1) and rebind the name to the result.
# NOTE(review): `flax_tensor` is not defined in this cell — it is whatever value the
# last loop that was executed left behind, and the 4-axis permutation only works if
# that value is 4-D. The statement is also not idempotent: running the cell again
# permutes the already-permuted array. Confirm whether this scratch cell is still needed.
flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))