In [1]:
from typing import Optional, List
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from diffusers import (
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
)

import json
from optimum.pipelines.diffusers.pipeline_stable_diffusion import (
    StableDiffusionPipelineMixin,
)
from optimum.pipelines.diffusers.pipeline_utils import rescale_noise_cfg
from typing import Callable, List, Optional, Union
import inspect
import numpy as np
import torch
from abc import abstractmethod
from tritonclient.grpc import InferenceServerClient, InferInput, InferRequestedOutput
from tritonclient.utils import np_to_triton_dtype
from typing import Optional, List
import numpy as np

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

from optimum.modeling_base import OptimizedModel

import importlib
import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

from diffusers import (
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
)
from transformers import CLIPFeatureExtractor, CLIPTokenizer

from optimum.utils import (
    DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
    DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
    DIFFUSION_MODEL_UNET_SUBFOLDER,
    DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER,
    DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
)
from optimum.pipelines.diffusers.pipeline_utils import VaeImageProcessor

from optimum.exporters import TasksManager

from pathlib import Path
from typing import Any, Dict, Optional, Union
from transformers import (
    AutoModel,
    GenerationMixin,
)

from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


# Modeling Triton Model parts

In [2]:
class _RemoteTritonDiffusionModelPart:
    def __init__(
        self, channel: str, model_name: str, model_version: Optional[int] = None
    ):
        self._client: InferenceServerClient = InferenceServerClient(channel)
        self._model_name: str = model_name
        self._model_version: str = model_version if model_version is not None else ""

        if not self._client.is_server_live():
            raise ConnectionError("Triton server is not live")

        if not self._client.is_server_ready():
            raise ConnectionError("Triton server is not ready")

        if not self._client.is_model_ready(self._model_name, self._model_version):
            raise ConnectionError(f"Model {self._model_name} is not ready")

        self._metadata = self._client.get_model_metadata(
            self._model_name, self._model_version
        )

        self._input_dtypes = {
            input_.name: input_.datatype for input_ in self._metadata.inputs
        }

        self.output_names = [output_.name for output_ in self._metadata.outputs]

    def make_infer_input(self, input_name: str, input_data: np.ndarray) -> InferInput:
        if input_name not in self._input_dtypes.keys():
            raise ValueError(f"Input '{input_name}' is not found in the model")

        expected_data_type = self._input_dtypes[input_name]

        actual_data_type = np_to_triton_dtype(input_data.dtype)

        if actual_data_type != expected_data_type:
            raise ValueError(
                f"Input '{input_name}' has dtype '{actual_data_type}' but expected '{expected_data_type}'"
            )

        infer_input = InferInput(input_name, input_data.shape, actual_data_type)
        infer_input.set_data_from_numpy(input_data)

        return infer_input

    def make_infer_requested_outputs(
        self,
    ) -> List[InferRequestedOutput]:
        return [InferRequestedOutput(name) for name in self.output_names]

    @abstractmethod
    def forward(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class RemoteTritonTextEncoder(_RemoteTritonDiffusionModelPart):
    def forward(self, input_ids: np.ndarray):
        results = self._client.infer(
            model_name=self._model_name,
            inputs=[self.make_infer_input("input_ids", input_ids)],
            outputs=self.make_infer_requested_outputs(),
        )
        return [results.as_numpy("last_hidden_state")]


class RemoteTritonModelUnet(_RemoteTritonDiffusionModelPart):
    def forward(
        self,
        sample: np.ndarray,
        timestep: np.ndarray,
        encoder_hidden_states: np.ndarray,
        text_embeds: Optional[np.ndarray] = None,
        time_ids: Optional[np.ndarray] = None,
        timestep_cond: Optional[np.ndarray] = None,
    ):
        inputs = [
            self.make_infer_input("sample", sample),
            self.make_infer_input("timestep", timestep),
            self.make_infer_input("encoder_hidden_states", encoder_hidden_states),
        ]

        if text_embeds is not None:
            inputs.append(self.make_infer_input("text_embeds", text_embeds))
        if time_ids is not None:
            inputs.append(self.make_infer_input("time_ids", time_ids))
        if timestep_cond is not None:
            inputs.append(self.make_infer_input("timestep_cond", timestep_cond))

        results = self._client.infer(
            model_name=self._model_name,
            inputs=inputs,
            outputs=self.make_infer_requested_outputs(),
        )
        return [results.as_numpy(name) for name in self.output_names]


class RemoteTritonModelVaeDecoder(_RemoteTritonDiffusionModelPart):
    def forward(self, latent_sample: np.ndarray):
        results = self._client.infer(
            model_name=self._model_name,
            inputs=[self.make_infer_input("latent_sample", latent_sample)],
            outputs=self.make_infer_requested_outputs(),
        )
        return [results.as_numpy(name) for name in self.output_names]


class RemoteTritonModelVaeEncoder(_RemoteTritonDiffusionModelPart):
    def forward(self, sample: np.ndarray):
        results = self._client.infer(
            model_name=self._model_name,
            inputs=[self.make_infer_input("sample", sample)],
            outputs=self.make_infer_requested_outputs(),
        )
        return [results.as_numpy(name) for name in self.output_names]

In [3]:
text_encoder = RemoteTritonTextEncoder("localhost:8001", "text_encoder")
with open("./exported-models/onnx/stable-diffusion-v1-5/text_encoder/config.json") as f:
    text_encoder.config = json.load(f)

unet = RemoteTritonModelUnet("localhost:8001", "unet")
with open("./exported-models/onnx/stable-diffusion-v1-5/unet/config.json") as f:
    unet.config = json.load(f)

vae_decoder = RemoteTritonModelVaeDecoder("localhost:8001", "vae_decoder")
with open("./exported-models/onnx/stable-diffusion-v1-5/vae_decoder/config.json") as f:
    vae_decoder.config = json.load(f)


tokenizer = AutoTokenizer.from_pretrained("./exported-models/onnx/stable-diffusion-v1-5/tokenizer")


In [13]:
class ModifiedStableDiffusionPipelineMixin(StableDiffusionPipelineMixin):
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[np.random.RandomState] = None,
        latents: Optional[np.ndarray] = None,
        prompt_embeds: Optional[np.ndarray] = None,
        negative_prompt_embeds: Optional[np.ndarray] = None,
        output_type: str = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
        callback_steps: int = 1,
        guidance_rescale: float = 0.0,
    ):
        height = (
            height or self.unet.config.get("sample_size", 64) * self.vae_scale_factor
        )
        width = width or self.unet.config.get("sample_size", 64) * self.vae_scale_factor

        # check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if generator is None:
            generator = np.random
        do_classifier_free_guidance = guidance_scale > 1.0

        prompt_embeds = self._encode_prompt(
            prompt,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.unet.config.get("in_channels", 4),
            height,
            width,
            prompt_embeds.dtype,
            generator,
            latents,
        )
        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # Adapted from diffusers to extend it for other runtimes than ORT
        timestep_dtype = self.unet.input_dtype.get("timestep", np.float32)

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                np.concatenate([latents] * 2)
                if do_classifier_free_guidance
                else latents
            )
            latent_model_input = self.scheduler.scale_model_input(
                torch.from_numpy(latent_model_input), t
            )
            latent_model_input = latent_model_input.cpu().numpy()
            # predict the noise residual

            timestep = np.array([t], dtype=timestep_dtype).reshape(1, 1).repeat(
                batch_size * 2, 0
            )
            print(latent_model_input.shape, timestep.shape, prompt_embeds.shape)
            noise_pred = self.unet(
                sample=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
            )
            noise_pred = noise_pred[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                if guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
                    )

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_output = self.scheduler.step(
                torch.from_numpy(noise_pred),
                t,
                torch.from_numpy(latents),
                **extra_step_kwargs
            )
            latents = scheduler_output.prev_sample.numpy()

            # call the callback, if provided
            if i == len(timesteps) - 1 or (
                (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
            ):
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        else:
            latents /= self.vae_decoder.config.get("scaling_factor", 0.18215)

            image = np.concatenate(
                [
                    self.vae_decoder(latent_sample=latents[i : i + 1])[0]
                    for i in range(latents.shape[0])
                ]
            )
            image, has_nsfw_concept = self.run_safety_checker(image)

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )



# Modeling Remote triton pipeline

In [14]:
class RemoteTritonModel(OptimizedModel):
    model_type = "remote_triton_model"
    auto_model_class = AutoModel


    @classmethod
    def _auto_model_to_task(cls, auto_model_class):
        """
        Get the task corresponding to a class (for example AutoModelForXXX in transformers).
        """
        return TasksManager.infer_task_from_model(auto_model_class)

    def __init__(
        self,
        model: Any,
        config: "PretrainedConfig",
        **kwargs,
    ):
        super().__init__(model, config)

    def forward(self, *args, **kwargs):
        raise NotImplementedError

    def can_generate(self) -> bool:
        """
        Returns whether this model can generate sequences with `.generate()`.
        """
        return isinstance(self, GenerationMixin)



In [15]:
class RemoteTritonStableDiffusionPipelineBase(RemoteTritonModel):
    auto_model_class = StableDiffusionPipeline
    main_input_name = "input_ids"
    base_model_prefix = "onnx_model"
    config_name = "model_index.json"
    sub_component_config_name = "config.json"

    def __init__(
        self,
        vae_decoder: RemoteTritonModelVaeDecoder,
        text_encoder: RemoteTritonTextEncoder,
        unet: RemoteTritonModelUnet,
        config: Dict[str, Any],
        tokenizer: CLIPTokenizer,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        feature_extractor: Optional[CLIPFeatureExtractor] = None,
        vae_encoder: Optional[RemoteTritonModelVaeEncoder] = None,
        text_encoder_2: Optional[RemoteTritonTextEncoder] = None,
        tokenizer_2: Optional[CLIPTokenizer] = None,
    ):
        self._internal_dict = config
        self.vae_decoder = vae_decoder
        self.unet = unet

        self.text_encoder = text_encoder
        self.vae_encoder = vae_encoder
        self.text_encoder_2 = text_encoder_2

        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.scheduler = scheduler
        self.feature_extractor = feature_extractor
        self.safety_checker = None

        sub_models = {
            DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER: self.text_encoder,
            DIFFUSION_MODEL_UNET_SUBFOLDER: self.unet,
            DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER: self.vae_decoder,
            DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER: self.vae_encoder,
            DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER: self.text_encoder_2,
        }

        for name in sub_models.keys():
            self._internal_dict[name] = (
                ("diffusers", "OnnxRuntimeModel")
                if sub_models[name] is not None
                else (None, None)
            )
        self._internal_dict.pop("vae", None)

        if "block_out_channels" in self.vae_decoder.config:
            self.vae_scale_factor = 2 ** (
                len(self.vae_decoder.config["block_out_channels"]) - 1
            )
        else:
            self.vae_scale_factor = 8

        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    @staticmethod
    def load_model():
        text_encoder = RemoteTritonTextEncoder("localhost:8001", "text_encoder")
        with open(
            "./exported-models/onnx/stable-diffusion-v1-5/text_encoder/config.json"
        ) as f:
            text_encoder.config = json.load(f)

        unet = RemoteTritonModelUnet("localhost:8001", "unet")
        with open("./exported-models/onnx/stable-diffusion-v1-5/unet/config.json") as f:
            unet.config = json.load(f)
        unet.input_dtype = {"timestep": np.dtype(np.int32)}


        vae_decoder = RemoteTritonModelVaeDecoder("localhost:8001", "vae_decoder")
        with open(
            "./exported-models/onnx/stable-diffusion-v1-5/vae_decoder/config.json"
        ) as f:
            vae_decoder.config = json.load(f)

        return vae_decoder, text_encoder, unet, None, None

    def _save_pretrained(self, save_directory: Union[str, Path]):
        # save_directory = Path(save_directory)
        # src_to_dst_path = {
        #     self.vae_decoder_model_path: save_directory / DIFFUSION_MODEL_VAE_DECODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
        #     self.text_encoder_model_path: save_directory / DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER / ONNX_WEIGHTS_NAME,
        #     self.unet_model_path: save_directory / DIFFUSION_MODEL_UNET_SUBFOLDER / ONNX_WEIGHTS_NAME,
        # }

        # sub_models_to_save = {
        #     self.vae_encoder_model_path: DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER,
        #     self.text_encoder_2_model_path: DIFFUSION_MODEL_TEXT_ENCODER_2_SUBFOLDER,
        # }
        # for path, subfolder in sub_models_to_save.items():
        #     if path is not None:
        #         src_to_dst_path[path] = save_directory / subfolder / ONNX_WEIGHTS_NAME

        # # TODO: Modify _get_external_data_paths to give dictionnary
        # src_paths = list(src_to_dst_path.keys())
        # dst_paths = list(src_to_dst_path.values())
        # # Add external data paths in case of large models
        # src_paths, dst_paths = _get_external_data_paths(src_paths, dst_paths)

        # for src_path, dst_path in zip(src_paths, dst_paths):
        #     dst_path.parent.mkdir(parents=True, exist_ok=True)
        #     shutil.copyfile(src_path, dst_path)
        #     config_path = src_path.parent / self.sub_component_config_name
        #     if config_path.is_file():
        #         shutil.copyfile(config_path, dst_path.parent / self.sub_component_config_name)

        # self.scheduler.save_pretrained(save_directory / "scheduler")

        # if self.feature_extractor is not None:
        #     self.feature_extractor.save_pretrained(save_directory / "feature_extractor")
        # if self.tokenizer is not None:
        #     self.tokenizer.save_pretrained(save_directory / "tokenizer")
        # if self.tokenizer_2 is not None:
        #     self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")

        pass

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: Dict[str, Any],
        **kwargs,
    ):        
        model_id = str(model_id)
        patterns = set(config.keys())
        sub_models_to_load = patterns.intersection({"feature_extractor", "tokenizer", "tokenizer_2", "scheduler"})

        if not os.path.isdir(model_id):
            raise ValueError(f"Model {model_id} is not a directory")
        new_model_save_dir = Path(model_id)

        sub_models = {}
        for name in sub_models_to_load:
            library_name, library_classes = config[name]
            if library_classes is not None:
                library = importlib.import_module(library_name)
                class_obj = getattr(library, library_classes)

                load_method = getattr(class_obj, "from_pretrained")

                if (new_model_save_dir / name).is_dir():
                    sub_models[name] = load_method(new_model_save_dir / name)
                else:
                    sub_models[name] = load_method(new_model_save_dir)

        vae_decoder, text_encoder, unet, vae_encoder, text_encoder_2 = cls.load_model()

        return cls(
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            unet=unet,
            config=config,
            tokenizer=sub_models.get("tokenizer", None),
            scheduler=sub_models.get("scheduler"),
            feature_extractor=sub_models.get("feature_extractor", None),
            tokenizer_2=sub_models.get("tokenizer_2", None),
            vae_encoder=vae_encoder,
            text_encoder_2=text_encoder_2,
        )

    @classmethod
    def _load_config(cls, config_name_or_path: Union[str, os.PathLike], **kwargs):
        return cls.load_config(config_name_or_path, **kwargs)

    def _save_config(self, save_directory):
        self.save_config(save_directory)


class RemoteTritonStableDiffusionPipeline(
    RemoteTritonStableDiffusionPipelineBase, ModifiedStableDiffusionPipelineMixin
):
    __call__ = ModifiedStableDiffusionPipelineMixin.__call__


pipeline = RemoteTritonStableDiffusionPipeline.from_pretrained(
    "./exported-models/onnx/stable-diffusion-v1-5"
)

In [17]:
prompt = "sailing ship in storm by Leonardo da Vinci"
image = pipeline(prompt, num_inference_steps=2, num_images_per_prompt=2, guidance_scale=0).images[0]

  0%|          | 0/3 [00:00<?, ?it/s]

 33%|███▎      | 1/3 [00:00<00:00,  6.39it/s]

(2, 4, 64, 64) (2, 1) (2, 77, 768)
(2, 4, 64, 64) (2, 1) (2, 77, 768)


100%|██████████| 3/3 [00:00<00:00,  6.05it/s]


(2, 4, 64, 64) (2, 1) (2, 77, 768)


InferenceServerException: [StatusCode.INTERNAL] request specifies invalid shape for input 'latent_sample' for vae_decoder_0. Error details: model expected the shape of dimension 0 to be between 2 and 2 but received 1

In [None]:
np.array([1], dtype=np.int32).reshape(1, 1).repeat(2, 0).shape

