# Stable Diffusion v2 Demo with Torch Compile

## Stable Diffusion v2 for Text-to-Image Generation

To start, let's look at the text-to-image process for Stable Diffusion v2. We will use the [Stable Diffusion v2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1) model for this purpose. The main difference between Stable Diffusion v2 and Stable Diffusion v2.1 is that v2.1 was trained on more data, for longer, and with less restrictive filtering of the dataset, which gives promising results across a wide range of input text prompts. More details about the model can be found in the [Stability AI blog post](https://stability.ai/blog/stablediffusion2-1-release7-dec-2022) and the original model [repository](https://github.com/Stability-AI/stablediffusion).

### Stable Diffusion in Diffusers library

To work with Stable Diffusion v2, we will use the Hugging Face [Diffusers](https://github.com/huggingface/diffusers) library. To experiment with Stable Diffusion models, Diffusers exposes the [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/using-diffusers/conditional_image_generation), similar to the [other Diffusers pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview). The code below creates the pipeline with the generic `DiffusionPipeline` loader. Note that the cell as written loads the `stabilityai/stable-diffusion-xl-base-1.0` checkpoint rather than `stable-diffusion-2-1`, and the input shapes used in the rest of this notebook match that model.

In [1]:
from torch._export import capture_pre_autograd_graph
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
import numpy as np

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, onnx, openvino


In [2]:
from diffusers import DiffusionPipeline
import torch
import random
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
generator = torch.Generator(device="cpu").manual_seed(42)
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipe.to("cpu")

prompt = "valley in the Alps at sunset, epic vista, beautiful landscape, 4k, 8k"
negative_prompt = "frames, borderline, text, charachter, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"


### Convert Models to Torch Fx Graph

In [3]:
# Example inputs used only to trace each model into a Torch FX graph.
# Both text encoders take a batch of 77 token IDs.
text_encoder_input = torch.ones((1, 77), dtype=torch.long)
text_encoder_2_input = torch.ones((1, 77), dtype=torch.long)
text_encoder_kwargs = {"output_hidden_states": True}

# The VAE encoder takes an RGB image; the decoder takes a 4-channel latent.
vae_encoder_input = torch.ones((1, 3, 256, 256))
vae_decoder_input = torch.ones((1, 4, 128, 128))
vae_decoder_kwargs = {"return_dict": False}  # not passed to the decoder below

# UNet inputs: latents, timestep, text embeddings, and the additional conditioning.
latents_shape = (2, 4, 128, 128)
latents = torch.randn(latents_shape)
t = torch.from_numpy(np.array(1, dtype=np.float32))
added_cond_kwargs = {
    "text_embeds": torch.ones((2, 1280)),
    "time_ids": torch.ones((2, 6)),
}
unet_kwargs = {
    "encoder_hidden_states": torch.ones((2, 77, 2048)),
    "added_cond_kwargs": added_cond_kwargs,
    "return_dict": False,
}
unet_input = (latents, t)

# Capture each model as an FX graph with NNCF's PyTorch patching disabled.
with torch.no_grad():
    with disable_patching():
        text_encoder = capture_pre_autograd_graph(pipe.text_encoder.eval(), args=(text_encoder_input,), kwargs=text_encoder_kwargs)
        text_encoder_2 = capture_pre_autograd_graph(pipe.text_encoder_2.eval(), args=(text_encoder_2_input,), kwargs=text_encoder_kwargs)
        vae_encoder = capture_pre_autograd_graph(pipe.vae.encoder.eval(), args=(vae_encoder_input,))
        vae_decoder = capture_pre_autograd_graph(pipe.vae.decoder.eval(), args=(vae_decoder_input,))
        unet = capture_pre_autograd_graph(pipe.unet.eval(), args=unet_input, kwargs=unet_kwargs)
del added_cond_kwargs
del unet_kwargs
del unet_input
del latents
del t
del vae_encoder_input
del vae_decoder_input
del text_encoder_2_input
del text_encoder_input
del text_encoder_kwargs
del vae_decoder_kwargs
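
The example shapes used for tracing follow from the SDXL architecture. The sketch below reconstructs them in plain Python, assuming a 1024x1024 output image, a VAE downscaling factor of 8, text encoders with hidden sizes 768 and 1280, and a doubled batch for classifier-free guidance. The helper name `sdxl_unet_input_shapes` and its parameters are hypothetical and not part of Diffusers.

```python
def sdxl_unet_input_shapes(height=1024, width=1024, batch_size=1,
                           guidance=True, vae_scale_factor=8,
                           latent_channels=4, max_tokens=77,
                           hidden_sizes=(768, 1280), pooled_dim=1280):
    """Return the tensor shapes the SDXL UNet receives for one denoising step."""
    # With classifier-free guidance the batch holds the negative and the positive prompt.
    unet_batch = batch_size * 2 if guidance else batch_size
    return {
        "latents": (unet_batch, latent_channels,
                    height // vae_scale_factor, width // vae_scale_factor),
        # Hidden states of both text encoders are concatenated along the last axis.
        "encoder_hidden_states": (unet_batch, max_tokens, sum(hidden_sizes)),
        "text_embeds": (unet_batch, pooled_dim),
        # original size (2) + crop top-left (2) + target size (2)
        "time_ids": (unet_batch, 6),
    }

shapes = sdxl_unet_input_shapes()
for name, shape in shapes.items():
    print(name, shape)
```

With the default arguments this yields `(2, 4, 128, 128)` for the latents, `(2, 77, 2048)` for the encoder hidden states, `(2, 1280)` for the text embeddings, and `(2, 6)` for the time IDs, which are the shapes used for the UNet example inputs in this cell.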

In [4]:
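# Compile the text encoders and the VAE with the OpenVINO backend.
pipe.text_encoder = torch.compile(text_encoder, backend='openvino')
pipe.text_encoder_2 = torch.compile(text_encoder_2, backend='openvino')
pipe.vae.decoder = torch.compile(vae_decoder, backend='openvino')
pipe.vae.encoder = torch.compile(vae_encoder, backend='openvino')
# The UNet is compressed and quantized before it is compiled, so keep the captured graph for now.
pipe.unet = unet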

### Weights Compression

In [5]:
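import pickle

# Load previously collected UNet calibration samples from the pickle file "test".
with open("test", "rb") as fp:
    unet_calibration_data = pickle.load(fp)

def collect_ops_with_weights(graph_module):
    """Return the names of all graph nodes whose name contains "linear"."""
    return [node.name for node in graph_module.graph.nodes if "linear" in node.name]

# These node names are passed to nncf.quantize as the ignored scope below.
unet_ignored_scope = collect_ops_with_weights(pipe.unet)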
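
The substring test `"linear" in node.name` selects every node whose name contains `linear`, which would also match unrelated operations such as a bilinear upsampling node. The sketch below contrasts it with a stricter pattern on a plain list of names. The names `linear_109` and `linear_110` appear in the NNCF log in this notebook; `conv2d_3` and `upsample_bilinear2d` are hypothetical and included only for illustration, as is the helper `select_linear_nodes`.

```python
import re

# Matches "linear" or "linear_<number>", the naming pattern seen for linear nodes in the log.
LINEAR_NAME = re.compile(r"linear(_\d+)?")

def select_linear_nodes(node_names, strict=True):
    """Pick linear node names either by exact pattern or by substring."""
    if strict:
        return [n for n in node_names if LINEAR_NAME.fullmatch(n)]
    return [n for n in node_names if "linear" in n]

names = ["linear_109", "linear_110", "conv2d_3", "upsample_bilinear2d"]
loose_matches = select_linear_nodes(names, strict=False)
strict_matches = select_linear_nodes(names, strict=True)
print(loose_matches)   # also contains "upsample_bilinear2d"
print(strict_matches)  # only the two linear nodes
```

For the UNet in this notebook the two variants may well select the same nodes; the stricter form only matters if other operation names happen to contain `linear`.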

In [6]:
import nncf
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.quantization.range_estimator import RangeEstimatorParametersSet

def disable_progress_bar(pipeline, disable=True):
    if not hasattr(pipeline, "_progress_bar_config"):
        pipeline._progress_bar_config = {'disable': disable}
    else:
        pipeline._progress_bar_config['disable'] = disable

with disable_patching():
    with torch.no_grad():
        # Compress the weights of all supported layers except convolutions to 8 bits.
        compressed_unet = nncf.compress_weights(unet, ignored_scope=nncf.IgnoredScope(types=['conv2d']))
        # quantized_unet = nncf.quantize( #1
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=-1, disable_bias_correction=True)
        # )
        quantized_unet = nncf.quantize( #2
            model=compressed_unet,
            calibration_dataset=nncf.Dataset(unet_calibration_data),
            subset_size=len(unet_calibration_data),
            model_type=nncf.ModelType.TRANSFORMER,
            fast_bias_correction=False,
            ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
            advanced_parameters=nncf.AdvancedQuantizationParameters(
                smooth_quant_alpha=-1,
                disable_bias_correction=True,
                weights_range_estimator_params=RangeEstimatorParametersSet.MINMAX,
                activations_range_estimator_params=RangeEstimatorParametersSet.MINMAX,
            ),
        )
        # quantized_unet = nncf.quantize( #3
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alpha=-1)
        # )
        # quantized_unet = nncf.quantize( #4.1
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alphas=AdvancedSmoothQuantParameters(convolution=0.95, matmul=-1))
        # )
        # quantized_unet = nncf.quantize( #4.2
        #     model=pipe.unet,
        #     calibration_dataset=nncf.Dataset(unet_calibration_data),
        #     subset_size=len(unet_calibration_data),
        #     model_type=nncf.ModelType.TRANSFORMER,
        #     fast_bias_correction=False,
        #     ignored_scope=nncf.IgnoredScope(names=unet_ignored_scope),
        #     advanced_parameters=nncf.AdvancedQuantizationParameters(smooth_quant_alphas=AdvancedSmoothQuantParameters(convolution=0.95))
        # )
# del text_encoder
# del text_encoder_2
# del vae_decoder
del unet

INFO:nncf:51 ignored nodes were found by types in the NNCFGraph
INFO:nncf:Statistics of the bitwidth distribution:
┍━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┯━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┑
│   Num bits (N) │ % all parameters (layers)   │ % ratio-defining parameters (layers)   │
┝━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┿━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┥
│              8 │ 100% (743 / 743)            │ 100% (743 / 743)                       │
┕━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┙


INFO:nncf:743 ignored nodes were found by names in the NNCFGraph
INFO:nncf:Not adding activation input quantizer for operation: 1051 linear_109
INFO:nncf:Not adding activation input quantizer for operation: 1054 linear_110
INFO:nncf:Not adding activation input quantizer for operation: 1131 linear_119
INFO:nncf:Not adding activation input quantizer for operation: 1134 linear_120
INFO:nncf:Not adding activation input quantizer for operation: 1211 linear_129
INFO:nncf:Not adding activation input quantizer for operation: 1214 linear_130
INFO:nncf:Not adding activation input quantizer for operation: 1291 linear_139
INFO:nncf:Not adding activation input quantizer for operation: 1294 linear_140
INFO:nncf:Not adding activation input quantizer for operation: 1371 linear_149
INFO:nncf:Not adding activation input quantizer for operation: 1374 linear_150
INFO:nncf:Not adding activation input quantizer for operation: 1500 linear_162
INFO:nncf:Not adding activation input quantizer for operation: 150
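
To make the idea behind 8-bit weight compression concrete, here is a simplified sketch of asymmetric quantization of a single weight row in plain Python. It illustrates the general technique under simple assumptions (one scale and one zero point per row, rounding to the nearest integer) and is not NNCF's actual implementation; the function names and the sample values are hypothetical.

```python
def quantize_row(weights, num_bits=8):
    """Map floats to integers in [0, 2**num_bits - 1] with one scale and zero point."""
    levels = 2 ** num_bits - 1
    w_min, w_max = min(weights), max(weights)
    if w_max == w_min:
        raise ValueError("constant row: the quantization range would be zero")
    scale = (w_max - w_min) / levels
    zero_point = round(-w_min / scale)
    # Clamp so that rounding can never leave the representable integer range.
    quantized = [min(levels, max(0, round(w / scale) + zero_point)) for w in weights]
    return quantized, scale, zero_point

def dequantize_row(quantized, scale, zero_point):
    """Recover approximate float weights from the stored integers."""
    return [(q - zero_point) * scale for q in quantized]

row = [-0.62, -0.1, 0.0, 0.25, 0.91]
q, scale, zp = quantize_row(row)
restored = dequantize_row(q, scale, zp)
max_error = max(abs(a - b) for a, b in zip(row, restored))
print(q, scale, zp)
print(max_error)
```

For this sample row the smallest weight maps to 0, the largest to 255, and the round-trip error stays below half a quantization step. Each weight is then stored in one byte instead of four, at the cost of that small reconstruction error.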


### Compile Models with OV Backend

In [8]:
pipe.unet = torch.compile(quantized_unet, backend='openvino')
del quantized_unet

In [11]:
def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    model_size_mb = (param_size + buffer_size) / 1024**2

    return model_size_mb

get_model_size(pipe.unet)

3410.9036712646484
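
The value returned by `get_model_size` is the sum of parameter and buffer bytes divided by `1024**2`, so it is expressed in mebibytes. As a rough aid for interpreting such numbers, the sketch below applies the same arithmetic to a hypothetical tensor with one million elements at different element sizes. The helper `size_mib` and the element count are assumptions for illustration, not measurements from this notebook.

```python
def size_mib(num_elements, bytes_per_element):
    """Size in MiB of a tensor with the given element count and element size."""
    return num_elements * bytes_per_element / 1024**2

fp32 = size_mib(1_000_000, 4)  # float32: 4 bytes per element
fp16 = size_mib(1_000_000, 2)  # float16: 2 bytes per element
int8 = size_mib(1_000_000, 1)  # int8: 1 byte per element
print(fp32, fp16, int8)
```

The same element count takes a quarter of the space in int8 compared with float32, which is the effect weight compression aims for on the layers it covers.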

### Inference for Compilation

In [9]:
# Warm up the model with a single step to trigger the initial compilation
prompt = "cute cat 4k, high-res, masterpiece, best quality, soft lighting, dynamic angle"
# negative_prompt = "frames, borderline, text, charachter, duplicate, error, out of frame, watermark, low quality, ugly, deformed, blur"
num_steps = 1
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_steps, generator=generator).images[0]

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None
Reached
torch.Size([1, 4, 128, 128])


In [12]:
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)
with torch.no_grad():
    image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=25).images[0]
image.show()
image.save("experiment_2_compressed.png")

  0%|          | 0/25 [00:00<?, ?it/s]

torch.Size([2, 1280]) torch.Size([2, 6]) torch.Size([2, 4, 128, 128]) torch.Size([]) encoder_hidden_states= torch.Size([2, 77, 2048]) cross_attention_kwarg None
torch.Size([2, 1280]) torch.Size([