# Stable Diffusion v2 Demo with Torch Compile

## Prerequisites

Install the required packages.

In [None]:
%pip install -q "diffusers>=0.14.0" openvino-nightly "datasets>=2.14.6" "transformers>=4.25.1" "gradio>=4.19" "torch>=2.1" Pillow opencv-python --extra-index-url https://download.pytorch.org/whl/cpu
%pip install -q git+https://github.com/openvinotoolkit/nncf.git
%pip install -q accelerate

## Stable Diffusion v2 for Text-to-Image Generation

To start, let's look at the text-to-image process for Stable Diffusion v2. We will use the [Stable Diffusion v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) model for this purpose. The main difference between Stable Diffusion v2 and Stable Diffusion v2.1 is the use of more data, more training, and less restrictive filtering of the dataset, which gives promising results across a wide range of input text prompts. More details about the model can be found in the [Stability AI blog post](https://stability.ai/blog/stablediffusion2-1-release7-dec-2022) and the original model [repository](https://github.com/Stability-AI/stablediffusion).

### Stable Diffusion in Diffusers library
To work with Stable Diffusion v2, we will use the Hugging Face [Diffusers](https://github.com/huggingface/diffusers) library. To experiment with Stable Diffusion models, Diffusers exposes the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/using-diffusers/conditional_image_generation), similar to the [other Diffusers pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview). The code below creates a pipeline with the generic `DiffusionPipeline.from_pretrained` loader. Note that the checkpoint it loads is `stabilityai/stable-diffusion-xl-base-1.0`, not `stable-diffusion-2-1`:

In [1]:
from torch._export import capture_pre_autograd_graph
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
import numpy as np

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino


In [2]:
import random

import torch
from diffusers import DiffusionPipeline

# Seed every random number generator in use so the run is repeatable
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
generator = torch.Generator(device="cpu").manual_seed(42)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.to("cpu")

# if using torch < 2.0
# pipe.enable_xformers_memory_efficient_attention()

prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
negative_prompt = "frames, borderline, text, charachter, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"

images = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
# images.show()
# images.save("/home/user/Downloads/stable-diffusion-xl-experiments/stable_diffusion_xl_all_pytorch_model.png")
# del images

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/25 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None

In [3]:
images.show()

In [3]:
inp = torch.ones((1, 4, 128, 128))
pipe_vae_out = pipe.vae.decode(inp)

torch.Size([1, 4, 128, 128])


In [5]:
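# `vae_decoder` is the graph captured with capture_pre_autograd_graph later in this notebook;
# that cell has to be run before this one.
vae_out = vae_decoder(inp)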

In [11]:
pipe_vae_out["sample"].shape

torch.Size([1, 3, 1024, 1024])
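
The two shapes printed above are related by the VAE's spatial scale: a 4-channel 128x128 latent decodes to a 3-channel 1024x1024 image, a factor of 8 per side. The sketch below captures that arithmetic in plain Python. `decoded_shape` is a hypothetical helper, not a Diffusers function, and the factor of 8 is an assumption inferred from the shapes shown here.

```python
def decoded_shape(latent_shape, scale_factor=8, image_channels=3):
    """Return the image shape produced by decoding a latent of the given shape.

    Assumes an (N, C, H, W) layout and a VAE decoder that upsamples each
    spatial dimension by `scale_factor` (8 is an assumption taken from the
    128 -> 1024 shapes printed in this notebook).
    """
    batch, _latent_channels, height, width = latent_shape
    return (batch, image_channels, height * scale_factor, width * scale_factor)


# Same latent shape as the `inp` tensor used above
print(decoded_shape((1, 4, 128, 128)))
```

The same helper can be used to pick a latent size for a different output resolution, for example a 64x64 latent for a 512x512 image under the same assumption.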

In [126]:
import openvino.torch  # registers the "openvino" backend for torch.compile

pipe.text_encoder = torch.compile(pipe.text_encoder.eval(), backend='openvino')
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2.eval(), backend='openvino')
# pipe.vae.encoder = torch.compile(pipe.vae.encoder.eval(), backend='openvino')
pipe.vae.decoder = torch.compile(pipe.vae.decoder.eval(), backend='openvino')
pipe.unet = torch.compile(pipe.unet.eval(), backend='openvino')

In [5]:
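torch.manual_seed(42)
# Re-create the generator: the earlier PyTorch run already advanced its state
generator = torch.Generator(device="cpu").manual_seed(42)
with torch.no_grad():
    # One-step warm-up call so that compilation happens before the real run
    pipe(prompt, num_inference_steps=1)
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
image.show()
image.save("/home/user/Downloads/stable-diffusion-xl-experiments/stable_diffusion_xl_all_torch_compile_model.png")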

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None
Reached decoding
torch.Size([1, 4, 128, 128])


  0%|          | 0/25 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None

### Convert Models to Torch FX Graph

In [6]:
# Example inputs used only to trace each model
text_encoder_input = torch.ones((1, 77), dtype=torch.long)
text_encoder_2_input = torch.ones((1, 77), dtype=torch.long)
vae_encoder_input = torch.ones((1, 3, 256, 256))
vae_decoder_input = torch.ones((1, 4, 128, 128))
vae_decoder_kwargs = {"return_dict": False}

# UNet inputs use batch size 2, matching the shapes printed during generation above
latents_shape = (2, 4, 128, 128)
latents = torch.randn(latents_shape)
t = torch.from_numpy(np.array(1, dtype=np.float32))
added_cond_kwargs = {
    "text_embeds": torch.ones((2, 1280)),
    "time_ids": torch.ones((2, 6)),
}
unet_kwargs = {
    "encoder_hidden_states": torch.ones((2, 77, 2048)),
    "added_cond_kwargs": added_cond_kwargs,
    "return_dict": False,
}
unet_input = (latents, t)

text_encoder_kwargs = {"output_hidden_states": True}

with torch.no_grad():
    with disable_patching():
        text_encoder = capture_pre_autograd_graph(pipe.text_encoder.eval(), args=(text_encoder_input,), kwargs=(text_encoder_kwargs))
        text_encoder_2 = capture_pre_autograd_graph(pipe.text_encoder_2.eval(), args=(text_encoder_2_input,), kwargs=(text_encoder_kwargs))
        # vae_encoder = capture_pre_autograd_graph(pipe.vae.encoder, args=(vae_encoder_input,))
        vae_decoder = capture_pre_autograd_graph(pipe.vae.decoder, args=(vae_decoder_input,))
        unet = capture_pre_autograd_graph(pipe.unet.eval(), args=(*unet_input,), kwargs=(unet_kwargs))
del added_cond_kwargs
del unet_kwargs
del unet_input
del latents
del t
del vae_encoder_input
del vae_decoder_input
del text_encoder_2_input
del text_encoder_input
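
The example tensors in the cell above mirror the shapes printed during generation. A small sketch of how those shapes can be derived from the target resolution is given below. `sdxl_unet_example_shapes` is a hypothetical helper. It assumes that the batch of 2 comes from classifier-free guidance (negative and positive prompt run together), that latents are 8 times smaller than the image per side, and that the sizes 77, 2048, 1280 and 6 are fixed by the SDXL base checkpoint as seen in the printed output.

```python
def sdxl_unet_example_shapes(height=1024, width=1024, use_guidance=True):
    """Return the tensor shapes needed to trace the UNet for a given image size.

    Assumptions (taken from the shapes printed in this notebook, not from an API):
    - batch is 2 when guidance is on (negative + positive prompt), else 1
    - latents have 4 channels and are 8x smaller than the image per side
    - 77 tokens, 2048-wide hidden states, 1280-wide pooled text embeddings,
      and 6 time ids per sample
    """
    batch = 2 if use_guidance else 1
    return {
        "latents": (batch, 4, height // 8, width // 8),
        "encoder_hidden_states": (batch, 77, 2048),
        "text_embeds": (batch, 1280),
        "time_ids": (batch, 6),
    }


for name, shape in sdxl_unet_example_shapes().items():
    print(name, shape)
```

Keeping the shapes in one place like this makes it easier to re-trace the UNet for another resolution without editing each tensor by hand.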


### Weights Compression

In [7]:
class VAEDecoderWrapper(torch.nn.Module):
    def __init__(self, vae):
        super().__init__()
        self.vae = vae

    def forward(self, latents):
        return self.vae.decode(latents)


vae_decoder = VAEDecoderWrapper(pipe.vae)

In [129]:
import nncf

# Only the UNet weights are compressed (to 8 bits, per the statistics below);
# the calls for the other components are left disabled.
# compressed_text_encoder = nncf.compress_weights(text_encoder)
# compressed_text_encoder_2 = nncf.compress_weights(text_encoder_2)
# compressed_vae_encoder = nncf.compress_weights(vae_encoder)
# compressed_vae_decoder = nncf.compress_weights(vae_decoder)
compressed_unet = nncf.compress_weights(unet)
del text_encoder
del text_encoder_2
del vae_decoder
del unet

INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (794 / 794)            │ 100% (794 / 794)                       │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙


Output()
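
The statistics above report that all 794 weight layers were stored with 8 bits. To give an intuition for what 8-bit weight compression means, here is a minimal pure-Python sketch of asymmetric 8-bit quantization of a single weight row. It is an illustration only and not NNCF's implementation; the helper names and the per-row scale and zero-point scheme are assumptions made for this example.

```python
def quantize_row_uint8(weights):
    """Map one row of float weights to integers in [0, 255].

    Illustrative asymmetric scheme: one scale and one zero point per row.
    """
    lo, hi = min(weights), max(weights)
    scale = (hi - lo) / 255 or 1.0  # fall back to 1.0 for a constant row
    zero_point = round(-lo / scale)
    quantized = [max(0, min(255, round(w / scale) + zero_point)) for w in weights]
    return quantized, scale, zero_point


def dequantize_row(quantized, scale, zero_point):
    """Recover approximate float weights from the 8-bit representation."""
    return [(q - zero_point) * scale for q in quantized]


row = [-0.50, -0.10, 0.00, 0.25, 0.75]
q, scale, zero_point = quantize_row_uint8(row)
restored = dequantize_row(q, scale, zero_point)
max_error = max(abs(a - b) for a, b in zip(row, restored))
print("quantized:", q)
print("max reconstruction error:", max_error)
```

Each weight is stored in one byte instead of four, at the cost of a small reconstruction error that stays within about half a quantization step for this row.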

### Compile Models with OV Backend

In [130]:
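import openvino.torch  # registers the "openvino" backend for torch.compile

# Note: only compressed_unet is created by the compression cell as written;
# the other compress_weights calls must be enabled before compiling the rest.
compiled_compressed_text_encoder = torch.compile(compressed_text_encoder, backend='openvino')
compiled_compressed_text_encoder_2 = torch.compile(compressed_text_encoder_2, backend='openvino')
compiled_compressed_unet = torch.compile(compressed_unet, backend='openvino')
compiled_compressed_vae_encoder = torch.compile(compressed_vae_encoder, backend='openvino')
compiled_compressed_vae_decoder = torch.compile(compressed_vae_decoder, backend='openvino')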

In [21]:
del compressed_text_encoder
del compressed_text_encoder_2
del compressed_vae_decoder
del compressed_unet

In [131]:
pipe.text_encoder = compiled_compressed_text_encoder
pipe.text_encoder_2 = compiled_compressed_text_encoder_2
pipe.vae.encoder = compiled_compressed_vae_encoder
pipe.vae.decoder = compiled_compressed_vae_decoder
pipe.unet = compiled_compressed_unet

In [None]:
pipe.text_encoder = compressed_text_encoder
pipe.text_encoder_2 = compressed_text_encoder_2
pipe.vae.decoder = compressed_vae_decoder
pipe.unet = compressed_unet

In [132]:
del compiled_compressed_text_encoder
del compiled_compressed_text_encoder_2
del compiled_compressed_vae_decoder
del compiled_compressed_vae_encoder
del compiled_compressed_unet

In [8]:
pipe.text_encoder = text_encoder
pipe.text_encoder_2 = text_encoder_2
pipe.vae.decoder = vae_decoder
pipe.unet = unet

In [None]:
for node in pipe.unet.graph.nodes:
    print(node)

### Inference for Compilation

In [135]:
# Warm up the model to trigger the initial compilation
prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
negative_prompt = "frames, borderline, text, charachter, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"
num_steps = 1
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps, generator=generator).images[0]

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None
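
The warm-up call exists because a module wrapped with `torch.compile` is compiled lazily on its first invocation, so the first call is much slower than later ones. The toy sketch below reproduces that pattern without any model. `LazyCompiled` and `timed` are hypothetical helpers, and the compilation cost is simulated with `time.sleep`.

```python
import time


def timed(fn, *args):
    """Call fn(*args) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args)
    return result, time.perf_counter() - start


class LazyCompiled:
    """Toy stand-in for a lazily compiled callable: pays a one-time cost on first use."""

    def __init__(self, fn, compile_cost_s=0.05):
        self._fn = fn
        self._compile_cost_s = compile_cost_s
        self._compiled = False

    def __call__(self, *args):
        if not self._compiled:
            time.sleep(self._compile_cost_s)  # simulated one-time compilation
            self._compiled = True
        return self._fn(*args)


double = LazyCompiled(lambda x: 2 * x)
_, warmup_time = timed(double, 1)   # first call includes the simulated compile
_, steady_time = timed(double, 1)   # later calls skip it
print(f"warm-up: {warmup_time:.4f}s, steady state: {steady_time:.4f}s")
```

Timing the real pipeline the same way, with the warm-up call excluded from the measurement, gives a fairer picture of steady-state latency.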



In [33]:
torch.seed()

13185493284991089338
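
`torch.seed()` reseeds the default generator with a non-deterministic value and returns it, so results after this call are not repeatable until a fixed seed is set again. A generator object that has already produced numbers has the same problem, because it carries state between calls. The sketch below shows the effect with Python's `random.Random`, used here purely as a stand-in for `torch.Generator`; `sample` is a hypothetical helper.

```python
import random


def sample(rng, n=3):
    """Draw n numbers from the given generator, advancing its state."""
    return [rng.random() for _ in range(n)]


rng = random.Random(42)
first = sample(rng)    # starts from seed 42
second = sample(rng)   # same object, state already advanced
rng.seed(42)           # reset to the original seed
third = sample(rng)    # reproduces the first draw

print(first == second)  # state carried over, so the draws differ
print(first == third)   # reseeding restores the original sequence
```

For the image comparisons in this notebook, the same reasoning suggests resetting the generator to seed 42 before every run that is meant to be compared.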

In [None]:
import random

torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
# Reset the generator so the result is comparable with the earlier runs
generator = torch.Generator(device="cpu").manual_seed(42)
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25, generator=generator).images[0]
image.show()
image.save("stable_diffusion_xl_unet_compressed_rest_torch_fx_compiled_model.png")

  0%|          | 0/25 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None

## Running Inference
Generate an image with the same parameters as the original OpenVINO Stable Diffusion model so that the results can be compared.

In [18]:
num_steps = 25

with torch.no_grad():
    image = pipe(prompt, num_inference_steps=num_steps).images[0]
image.show()


  0%|          | 0/25 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None