# Extract LoRA from a Fine-tuned Stable Diffusion Model and Deploy the Model with Multiple Adapters

In this tutorial, we will convert a fine-tuned Stable Diffusion model to ONNX and extract the LoRA adapters from the model. 
The resulting model can be deployed with multiple adapters for different tasks.

## Prerequisites

Before running this tutorial, please ensure you have already installed olive-ai. Please refer to the [installation guide](https://github.com/microsoft/Olive?tab=readme-ov-file#installation) for more information.

### Install Dependencies
We will optimize for `CUDAExecutionProvider`, so `onnxruntime-gpu>=1.20` should also be installed along with the other dependencies:

In [None]:
# install required packages
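# (`-r` expects a requirements file, so the packages are named directly; optimum is imported later in this notebook)
!pip install diffusers optimum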

# install onnxruntime-gpu>=1.20; if it is not available, install the onnxruntime-gpu nightly build
# ort-nightly has been renamed to onnxruntime: https://github.com/microsoft/onnxruntime/issues/22541
!pip uninstall -y onnxruntime onnxruntime-gpu ort-nightly ort-nightly-gpu
!pip install "onnxruntime-gpu>=1.20" || pip install --pre onnxruntime-gpu --extra-index-url=https://pkgs.dev.azure.com/aiinfra/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
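
After installation, it can be useful to confirm that the installed `onnxruntime-gpu` satisfies the `>=1.20` requirement. The following is a minimal sketch using only the standard library; the helper names (`release_tuple`, `meets_minimum`, `installed_version`) are illustrative and not part of Olive or ONNX Runtime, and pre-release suffixes such as `.dev` or `rc` are ignored in the comparison.

```python
import re
from importlib import metadata


def release_tuple(version):
    """Return the leading numeric release, e.g. '1.21.0.dev1' -> (1, 21, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """Compare numeric releases only; pre-release tags are ignored (an assumption)."""
    a, b = release_tuple(installed), release_tuple(minimum)
    width = max(len(a), len(b))
    # pad with zeros so that (1, 20) and (1, 20, 0) compare as equal
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(distribution):
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


version = installed_version("onnxruntime-gpu")
if version is None:
    print("onnxruntime-gpu is not installed")
elif meets_minimum(version, "1.20"):
    print(f"onnxruntime-gpu {version} satisfies >=1.20")
else:
    print(f"onnxruntime-gpu {version} is older than 1.20")
```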

## Workflow

Let's try the [wolf plushie LoRA](https://huggingface.co/lora-library/B-LoRA-wolf_plushie) and the [pen sketch LoRA](https://huggingface.co/lora-library/B-LoRA-pen_sketch) in this example.

Olive provides command line tools to run the export and extract adapters workflow. This workflow includes the following steps:
- `capture-onnx-graph`: Convert the fine-tuned model to ONNX.
- `generate-adapter`: Extract the adapter weights from the ONNX model and expose them as model inputs.

In [None]:
# run this cell to see the available options for the capture-onnx-graph and generate-adapter commands
!olive capture-onnx-graph --help
!olive generate-adapter --help

We first convert the VAE encoder, VAE decoder, text encoder, text encoder 2, and the UNet with each LoRA to ONNX using the `capture-onnx-graph` command:

In [None]:
# convert vae encoder
!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script vae_encoder.py -o onnx_model/vae_encoder
# convert vae decoder
!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script vae_decoder.py -o onnx_model/vae_decoder
# convert text encoder
!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script text_encoder.py -o onnx_model/text_encoder
# convert text encoder 2
!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script text_encoder2.py -o onnx_model/text_encoder_2
# convert unet model with wolf plushie lora
!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script unet_wolf_plushie.py -o onnx_model/unet_wolf_plushie
# convert unet model with pen sketch lora
!olive capture-onnx-graph -m stabilityai/stable-diffusion-xl-base-1.0 --model_script unet_pen_sketch.py -o onnx_model/unet_pen_sketch
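
The six commands above differ only in the model script and the output directory, so they can also be driven from Python. The sketch below builds the same command lines from a table and, by default, only prints them; the helper names (`build_capture_command`, `capture_all`) and the `dry_run` switch are illustrative, while the flags and paths are copied from the cell above.

```python
import subprocess

MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# (model script, output directory) pairs, copied from the commands above
COMPONENTS = [
    ("vae_encoder.py", "onnx_model/vae_encoder"),
    ("vae_decoder.py", "onnx_model/vae_decoder"),
    ("text_encoder.py", "onnx_model/text_encoder"),
    ("text_encoder2.py", "onnx_model/text_encoder_2"),
    ("unet_wolf_plushie.py", "onnx_model/unet_wolf_plushie"),
    ("unet_pen_sketch.py", "onnx_model/unet_pen_sketch"),
]


def build_capture_command(model_script, output_dir, model_id=MODEL_ID):
    """Return the argument list for one `olive capture-onnx-graph` call."""
    return [
        "olive", "capture-onnx-graph",
        "-m", model_id,
        "--model_script", model_script,
        "-o", output_dir,
    ]


def capture_all(components=COMPONENTS, dry_run=True):
    """Print (dry_run=True) or execute (dry_run=False) one command per component."""
    commands = [build_capture_command(script, out) for script, out in components]
    for command in commands:
        if dry_run:
            print(" ".join(command))
        else:
            # check=True stops at the first failing conversion
            subprocess.run(command, check=True)
    return commands


commands = capture_all(dry_run=True)
```

Calling `capture_all(dry_run=False)` would run the conversions one after another and raise on the first failure.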

Let's try pen sketch LoRA first:

In [None]:
import onnxruntime as ort
import torch
from pathlib import Path
from optimum.onnxruntime import ORTStableDiffusionXLPipeline
from diffusers import DiffusionPipeline, OnnxRuntimeModel

ort.set_default_logger_severity(3)

sess_options = ort.SessionOptions()
sess_options.enable_mem_pattern = False

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
onnx_model_path = Path("onnx_model")
provider = "CUDAExecutionProvider"
pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)

vae_encoder_session = OnnxRuntimeModel.load_model(onnx_model_path / "vae_encoder" / "model.onnx", provider=provider)
vae_decoder_session = OnnxRuntimeModel.load_model(onnx_model_path / "vae_decoder" / "model.onnx", provider=provider)
text_encoder_session = OnnxRuntimeModel.load_model(onnx_model_path / "text_encoder" / "model.onnx", provider=provider)
text_encoder_2_session = OnnxRuntimeModel.load_model(onnx_model_path / "text_encoder_2" / "model" / "model.onnx", provider=provider)

pen_sketch_unet_session = OnnxRuntimeModel.load_model(onnx_model_path / "unet_pen_sketch" / "model" / "model.onnx", provider=provider)

onnx_pipeline = ORTStableDiffusionXLPipeline(
    vae_encoder_session=vae_encoder_session,
    vae_decoder_session=vae_decoder_session,
    text_encoder_session=text_encoder_session,
    unet_session=pen_sketch_unet_session,
    text_encoder_2_session=text_encoder_2_session,
    tokenizer=pipeline.tokenizer,
    tokenizer_2=pipeline.tokenizer_2,
    scheduler=pipeline.scheduler,
    feature_extractor=pipeline.feature_extractor,
    config=dict(pipeline.config),
)

In [None]:
prompt = "A woman is dancing [v30]"
batch_size = 1
result = onnx_pipeline(
    [prompt] * batch_size,
    num_inference_steps=20,
    height=512,
    width=512,
)

for image_index in range(batch_size):
    output_path = f"result_pen_{image_index}.png"
    result.images[image_index].save(output_path)

In [None]:
del pen_sketch_unet_session
del onnx_pipeline

The generated image should look like this:   
![pen sketch](image/result_pen.png)

Let's try wolf plushie LoRA:

In [3]:
# Update onnx pipeline with wolf plushie unet
wolf_plushie_unet_session = OnnxRuntimeModel.load_model(onnx_model_path / "unet_wolf_plushie" / "model" / "model.onnx", provider=provider)

onnx_pipeline = ORTStableDiffusionXLPipeline(
    vae_encoder_session=vae_encoder_session,
    vae_decoder_session=vae_decoder_session,
    text_encoder_session=text_encoder_session,
    unet_session=wolf_plushie_unet_session,
    text_encoder_2_session=text_encoder_2_session,
    tokenizer=pipeline.tokenizer,
    tokenizer_2=pipeline.tokenizer_2,
    scheduler=pipeline.scheduler,
    feature_extractor=pipeline.feature_extractor,
    config=dict(pipeline.config),
)

In [None]:
prompt = "A running wolf"
batch_size = 1
result = onnx_pipeline(
    [prompt] * batch_size,
    num_inference_steps=20,
    height=512,
    width=512,
)

for image_index in range(batch_size):
    output_path = f"result_wolf_{image_index}.png"
    result.images[image_index].save(output_path)

In [None]:
del wolf_plushie_unet_session
del onnx_pipeline

The generated image should look like this:   
![wolf plushie](image/result_wolf.png)

Finally, extract the adapters from both ONNX UNet models.

In [None]:
# Extract adapters from pen sketch unet model
!olive generate-adapter -m onnx_model/unet_pen_sketch -o adapters/pen_sketch
# Extract adapters from wolf plushie unet model
!olive generate-adapter -m onnx_model/unet_wolf_plushie -o adapters/wolf_plushie

## Deploy Model with Multiple Adapters

We can now deploy the same model with multiple adapters for different tasks by loading the adapter weights independently of the model and providing the relevant weights as input at inference time.
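
To make the idea concrete, here is a toy sketch in plain Python of a single linear layer whose LoRA matrices are passed in as inputs instead of being baked into the weights. It only illustrates the principle: the function names, shapes, values, and the `scale` parameter are made up for this example and do not describe the actual graph that Olive produces.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_linear(x, base_weight, lora_a=None, lora_b=None, scale=1.0):
    """Compute x @ W, plus scale * (x @ A) @ B when adapter matrices are supplied."""
    out = matmul(x, base_weight)
    if lora_a is None or lora_b is None:
        return out  # no adapter supplied: plain base layer
    delta = matmul(matmul(x, lora_a), lora_b)
    return [
        [o + scale * d for o, d in zip(out_row, delta_row)]
        for out_row, delta_row in zip(out, delta)
    ]


base_weight = [[1.0, 0.0], [0.0, 1.0]]  # shared base weight, loaded once
x = [[2.0, 3.0]]

# two rank-1 adapters (A is 2x1, B is 1x2), selected per call
toy_adapters = {
    "style_a": ([[1.0], [0.0]], [[0.0, 1.0]]),
    "style_b": ([[0.0], [1.0]], [[1.0, 0.0]]),
}

base_out = lora_linear(x, base_weight)
out_a = lora_linear(x, base_weight, *toy_adapters["style_a"])
out_b = lora_linear(x, base_weight, *toy_adapters["style_b"])
zero_out = lora_linear(x, base_weight, [[0.0], [0.0]], [[0.0, 0.0]])

print(base_out)  # [[2.0, 3.0]]
print(out_a)     # [[2.0, 5.0]]
print(out_b)     # [[5.0, 3.0]]
print(zero_out)  # [[2.0, 3.0]], all-zero adapter matrices reproduce the base output
```

The base weight is never modified; only the adapter matrices supplied with each call change the result.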

In [37]:
import shutil

file_list = ["model.onnx", "model.onnx.data"]
old_model_path = Path("adapters/pen_sketch/model")
base_model_path = Path("model")

# create the destination folder itself (not just its parent) before copying
base_model_path.mkdir(parents=True, exist_ok=True)

# copy base onnx model to new folder
for file in file_list:
    shutil.copy2(old_model_path / file, base_model_path / file)

base_model_path = base_model_path / "model.onnx"

adapters = {
    "pen_sketch":  "adapters/pen_sketch/model/adapter_weights.onnx_adapter",
    "wolf_plushie": "adapters/wolf_plushie/model/adapter_weights.onnx_adapter"
}

Here, we implement a `UNetSessionWrapper` to inject adapter weights as inputs to the UNet model.

In [40]:
# olive-ai is installed (see Prerequisites), so load_weights can be imported from the package
from olive.common.utils import load_weights

class UNetSessionWrapper:
    def __init__(self, unet_session):
        self.unet_session = unet_session
        self.adapters = {}
        self.active_adapter = None

    def load_adapter(self, adapter_name, adapter_path):
        self.adapters[adapter_name] = load_weights(adapter_path)

    def set_adapter(self, adapter_name):
        assert adapter_name in self.adapters, f"Adapter {adapter_name} not found"
        self.active_adapter = adapter_name

    def unset_adapter(self):
        self.active_adapter = None

    def run(self, output_names, input_feed):
        # running when adapter is not set is equivalent to running the base model
        inputs = {**input_feed, **self.adapters.get(self.active_adapter, {})}
        return self.unet_session.run(output_names, inputs)

    def __getattr__(self, name):
        return getattr(self.unet_session, name)


# load unet session from base model
unet_session = OnnxRuntimeModel.load_model(base_model_path, provider=provider)
unet_session_wrapped = UNetSessionWrapper(unet_session)

unet_session_wrapped.load_adapter("pen_sketch", adapters["pen_sketch"])
unet_session_wrapped.load_adapter("wolf_plushie", adapters["wolf_plushie"])

onnx_pipeline = ORTStableDiffusionXLPipeline(
    vae_encoder_session=vae_encoder_session,
    vae_decoder_session=vae_decoder_session,
    text_encoder_session=text_encoder_session,
    unet_session=unet_session_wrapped,
    text_encoder_2_session=text_encoder_2_session,
    tokenizer=pipeline.tokenizer,
    tokenizer_2=pipeline.tokenizer_2,
    scheduler=pipeline.scheduler,
    feature_extractor=pipeline.feature_extractor,
    config=dict(pipeline.config),
)
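
The input-merging behavior of the wrapper can be checked without loading any model. The sketch below repeats the same merge logic with a hypothetical stand-in session (`RecordingSession`) that only records which input names it receives; the class names, the adapter input names (`lora_A`, `lora_B`), and the values are made up for illustration.

```python
class RecordingSession:
    """Hypothetical stand-in for an inference session: records input names per call."""

    def __init__(self):
        self.calls = []

    def run(self, output_names, input_feed):
        self.calls.append(sorted(input_feed))
        return []


class AdapterInputWrapper:
    """Same merge logic as UNetSessionWrapper, with adapters passed in as plain dicts."""

    def __init__(self, session):
        self.session = session
        self.adapters = {}
        self.active_adapter = None

    def load_adapter(self, name, weights):
        self.adapters[name] = dict(weights)

    def set_adapter(self, name):
        if name not in self.adapters:
            raise KeyError(f"Adapter {name} not found")
        self.active_adapter = name

    def unset_adapter(self):
        self.active_adapter = None

    def run(self, output_names, input_feed):
        # with no active adapter, .get(None, {}) adds nothing to the feed
        inputs = {**input_feed, **self.adapters.get(self.active_adapter, {})}
        return self.session.run(output_names, inputs)


session = RecordingSession()
wrapper = AdapterInputWrapper(session)
wrapper.load_adapter("pen_sketch", {"lora_A": [0.1], "lora_B": [0.2]})

wrapper.run(None, {"sample": [1.0]})  # no adapter active
wrapper.set_adapter("pen_sketch")
wrapper.run(None, {"sample": [1.0]})  # adapter weights merged into the feed
wrapper.unset_adapter()
wrapper.run(None, {"sample": [1.0]})  # back to the plain feed

print(session.calls)
# [['sample'], ['lora_A', 'lora_B', 'sample'], ['sample']]
```

Because the adapter entries are merged last, an adapter weight would override a caller-supplied input of the same name.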

#### Generate with Pen Sketch Adapters
Let's test the pen sketch adapter first.

In [None]:
unet_session_wrapped.set_adapter("pen_sketch")

prompt = "A dancing woman"
batch_size = 1
result = onnx_pipeline(
    [prompt] * batch_size,
    num_inference_steps=20,
    height=512,
    width=512,
)

for image_index in range(batch_size):
    output_path = f"result_pen_merge_{image_index}.png"
    result.images[image_index].save(output_path)

The image generated by the base model with the pen sketch adapter weights supplied as inputs should look like this:   
![pen sketch](image/result_pen_merge.png)

#### Generate with Wolf Plushie Adapters

In [None]:
unet_session_wrapped.set_adapter("wolf_plushie")

prompt = "A running wolf"
batch_size = 1
result = onnx_pipeline(
    [prompt] * batch_size,
    num_inference_steps=20,
    height=512,
    width=512,
)

for image_index in range(batch_size):
    output_path = f"result_wolf_merge_{image_index}.png"
    result.images[image_index].save(output_path)

The image generated by the base model with the wolf plushie adapter weights supplied as inputs should look like this:   
![wolf plushie](image/result_wolf_merge.png)