Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: openapi stuff #6454

Merged
merged 9 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"

# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
Expand Down Expand Up @@ -70,3 +71,6 @@ installer-zip:
tag-release:
cd installer && ./tag_release.sh

# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py
92 changes: 2 additions & 90 deletions invokeai/app/api_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,26 @@
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available

# for PyCharm:
# noinspection PyUnresolvedReferences
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.backend.util.devices import TorchDevice

from ..backend.util.logging import InvokeAILogger
Expand All @@ -45,11 +39,6 @@
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra

app_config = get_config()

Expand Down Expand Up @@ -119,84 +108,7 @@ async def lifespan(app: FastAPI):
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")


# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)

# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)

output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"

# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json

openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}

# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"

# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema

app.openapi_schema = openapi_schema
return app.openapi_schema


app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
app.openapi = get_openapi_func(app)


@app.get("/docs", include_in_schema=False)
Expand Down
30 changes: 19 additions & 11 deletions invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,13 @@ class BaseInvocationOutput(BaseModel):

_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
psychedelicious marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True

@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
Expand All @@ -112,11 +114,12 @@ def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
return cls._typeadapter

@classmethod
Expand All @@ -125,12 +128,13 @@ def get_output_types(cls) -> Iterable[str]:
return (i.get_type() for i in BaseInvocationOutput.get_outputs())

@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"])

@classmethod
Expand Down Expand Up @@ -167,6 +171,7 @@ class BaseInvocation(ABC, BaseModel):

_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False

@classmethod
def get_type(cls) -> str:
Expand All @@ -177,15 +182,17 @@ def get_type(cls) -> str:
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True

@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(InvocationsUnion)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
return cls._typeadapter

@classmethod
Expand Down Expand Up @@ -221,7 +228,7 @@ def get_output_annotation(cls) -> BaseInvocationOutput:
return signature(cls.invoke).return_annotation

@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
Expand All @@ -237,6 +244,7 @@ def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *arg
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"])

@abstractmethod
Expand Down Expand Up @@ -310,7 +318,7 @@ def invoke_internal(self, context: InvocationContext, services: "InvocationServi
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=True,
json_schema_serialization_defaults_required=False,
coerce_numbers_to_str=True,
)

Expand Down
32 changes: 8 additions & 24 deletions invokeai/app/services/events/events_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@

from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
from pydantic import BaseModel, ConfigDict, Field

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Expand All @@ -14,6 +13,7 @@
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
Expand Down Expand Up @@ -98,17 +98,9 @@ class InvocationEventBase(QueueItemEventBase):
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")

@field_validator("invocation", mode="plain")
@classmethod
def validate_invocation(cls, v: Any):
"""Validates the invocation using the dynamic type adapter."""

invocation = BaseInvocation.get_typeadapter().validate_python(v)
return invocation


@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
Expand All @@ -117,7 +109,7 @@ class InvocationStartedEvent(InvocationEventBase):
__event_name__ = "invocation_started"

@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
Expand All @@ -144,7 +136,7 @@ class InvocationDenoiseProgressEvent(InvocationEventBase):
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
Expand Down Expand Up @@ -182,19 +174,11 @@ class InvocationCompleteEvent(InvocationEventBase):

__event_name__ = "invocation_complete"

result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")

@field_validator("result", mode="plain")
@classmethod
def validate_results(cls, v: Any):
"""Validates the invocation result using the dynamic type adapter."""

result = BaseInvocationOutput.get_typeadapter().validate_python(v)
return result
result: AnyInvocationOutput = Field(description="The result of the invocation")

@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
Expand Down Expand Up @@ -223,7 +207,7 @@ class InvocationErrorEvent(InvocationEventBase):
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
invocation: AnyInvocation,
error_type: str,
error_message: str,
error_traceback: str,
Expand Down
Loading
Loading