Skip to content

Commit

Permalink
fastapi fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Feb 8, 2024
1 parent 4161f91 commit 9e8f9bc
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
14 changes: 13 additions & 1 deletion modules/onnx_impl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, Callable, Optional
import numpy as np
import torch
import diffusers
import onnxruntime as ort
import optimum.onnxruntime

# fastapi fix
from fastapi import encoders as fastapi_encoders


initialized = False
run_olive_workflow = None
Expand Down Expand Up @@ -201,6 +204,13 @@ def ORTDiffusionModelPart_to(self, *args, **kwargs):
return self


fastapi_jsonable_encoder = fastapi_encoders.jsonable_encoder
def jsonable_encoder(obj: Any, *args, **kwargs):
if isinstance(obj, Callable):
return {}
return fastapi_jsonable_encoder(obj, *args, **kwargs)


def initialize():
global initialized # pylint: disable=global-statement

Expand Down Expand Up @@ -253,6 +263,8 @@ def initialize():

optimum.onnxruntime.modeling_diffusion._ORTDiffusionModelPart.to = ORTDiffusionModelPart_to # pylint: disable=protected-access

fastapi_encoders.jsonable_encoder = jsonable_encoder

print(f'ONNX: selected={opts.onnx_execution_provider}, available={available_execution_providers}')

initialized = True
Expand Down
3 changes: 1 addition & 2 deletions modules/onnx_impl/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import json
import shutil
import tempfile
from abc import ABCMeta
from typing import Type, Tuple, List, Any, Dict
from packaging import version
import onnx
Expand All @@ -26,7 +25,7 @@
SUBMODELS_SDXL_REFINER = ("text_encoder_2", "unet", "vae_encoder", "vae_decoder",)


class PipelineBase(TorchCompatibleModule, diffusers.DiffusionPipeline, metaclass=ABCMeta):
class PipelineBase(TorchCompatibleModule, diffusers.DiffusionPipeline):
model_type: str
sd_model_hash: str
sd_checkpoint_info: CheckpointInfo
Expand Down
1 change: 0 additions & 1 deletion modules/shared_cmd_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@

cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name])
cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access
cmd_opts.no_download_sd_model = True

0 comments on commit 9e8f9bc

Please sign in to comment.