Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto-detect SD webui endpoint #1588

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions backend/src/nodes/impl/external_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,42 @@
import io
import os
from enum import Enum
from typing import Dict, Union
from typing import Dict, Optional, Union

import cv2
import numpy as np
import requests
from PIL import Image
from sanic.log import logger

from ..utils.utils import get_h_w_c
from .image_utils import normalize

STABLE_DIFFUSION_PROTOCOL = os.environ.get("STABLE_DIFFUSION_PROTOCOL", None)
STABLE_DIFFUSION_HOST = os.environ.get("STABLE_DIFFUSION_HOST", "127.0.0.1")
STABLE_DIFFUSION_PORT = os.environ.get("STABLE_DIFFUSION_PORT", "7860")
STABLE_DIFFUSION_PORT = os.environ.get("STABLE_DIFFUSION_PORT", None)

STABLE_DIFFUSION_REQUEST_TIMEOUT = float(
os.environ.get("STABLE_DIFFUSION_REQUEST_TIMEOUT", "600")
) # 10 minutes

STABLE_DIFFUSION_TEXT2IMG_URL = (
f"http://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}/sdapi/v1/txt2img"
)
STABLE_DIFFUSION_IMG2IMG_URL = (
f"http://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}/sdapi/v1/img2img"
)
STABLE_DIFFUSION_INTERROGATE_URL = (
f"http://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}/sdapi/v1/interrogate"
)
STABLE_DIFFUSION_OPTIONS_URL = (
f"http://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}/sdapi/v1/options"
)
STABLE_DIFFUSION_TEXT2IMG_PATH = f"/sdapi/v1/txt2img"
STABLE_DIFFUSION_IMG2IMG_PATH = f"/sdapi/v1/img2img"
STABLE_DIFFUSION_INTERROGATE_PATH = f"/sdapi/v1/interrogate"
STABLE_DIFFUSION_OPTIONS_PATH = f"/sdapi/v1/options"


def _stable_diffusion_url(path):
return f"{STABLE_DIFFUSION_PROTOCOL}://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}{path}"


ERROR_MSG = f"""
If you want to use external stable diffusion nodes, run the Automatic1111 web ui with the --api flag, like so:

./webui.sh --api

ChaiNNer is currently configured to look for the API at http://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}. If you
have it running somewhere else, you can change this using the STABLE_DIFFUSION_HOST and STABLE_DIFFUSION_PORT
To manually set where ChaiNNer looks for the API, use the
STABLE_DIFFUSION_PROTOCOL, STABLE_DIFFUSION_HOST, and STABLE_DIFFUSION_PORT
environment variables.
"""

Expand All @@ -58,20 +57,48 @@ class ExternalServiceTimeout(Exception):
pass


def get(url, timeout: float = STABLE_DIFFUSION_REQUEST_TIMEOUT) -> Dict:
def _auto_detect_endpoint(timeout=0.5):
global STABLE_DIFFUSION_PROTOCOL, STABLE_DIFFUSION_PORT # pylint: disable=global-statement

protocols = (
[STABLE_DIFFUSION_PROTOCOL] if STABLE_DIFFUSION_PROTOCOL else ["http", "https"]
)
ports = [STABLE_DIFFUSION_PORT] if STABLE_DIFFUSION_PORT else ["7860", "7861"]

last_error: Optional[Exception] = None
for STABLE_DIFFUSION_PROTOCOL in protocols:
for STABLE_DIFFUSION_PORT in ports:
try:
get(STABLE_DIFFUSION_OPTIONS_PATH, timeout=timeout)
logger.info(
f"Found stable diffusion API at {STABLE_DIFFUSION_PROTOCOL}://{STABLE_DIFFUSION_HOST}:{STABLE_DIFFUSION_PORT}"
)
return
except Exception as error:
last_error = error

if last_error:
raise last_error
else:
raise RuntimeError


def get(path, timeout: float = STABLE_DIFFUSION_REQUEST_TIMEOUT) -> Dict:
try:
response = requests.get(url, timeout=timeout)
response = requests.get(_stable_diffusion_url(path), timeout=timeout)
except requests.ConnectionError as exc:
raise ExternalServiceConnectionError(ERROR_MSG) from exc
except requests.exceptions.ReadTimeout as exc:
raise ExternalServiceTimeout(TIMEOUT_MSG) from exc
return response.json()


def post(url, json_data: Dict) -> Dict:
def post(path, json_data: Dict) -> Dict:
try:
response = requests.post(
url, json=json_data, timeout=STABLE_DIFFUSION_REQUEST_TIMEOUT
_stable_diffusion_url(path),
json=json_data,
timeout=STABLE_DIFFUSION_REQUEST_TIMEOUT,
)
except requests.ConnectionError as exc:
raise ExternalServiceConnectionError(ERROR_MSG) from exc
Expand All @@ -96,7 +123,7 @@ def verify_api_connection():
global has_api_connection # pylint: disable=global-statement
if has_api_connection is None:
has_api_connection = False
get(STABLE_DIFFUSION_OPTIONS_URL, timeout=0.5)
_auto_detect_endpoint()
has_api_connection = True

if not has_api_connection:
Expand Down
4 changes: 2 additions & 2 deletions backend/src/nodes/nodes/external_stable_diffusion/img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ...impl.external_stable_diffusion import (
RESIZE_MODE_LABELS,
SAMPLER_NAME_LABELS,
STABLE_DIFFUSION_IMG2IMG_URL,
STABLE_DIFFUSION_IMG2IMG_PATH,
ResizeMode,
SamplerName,
decode_base64_image,
Expand Down Expand Up @@ -141,7 +141,7 @@ def run(
"resize_mode": resize_mode.value,
"tiling": tiling,
}
response = post(url=STABLE_DIFFUSION_IMG2IMG_URL, json_data=request_data)
response = post(path=STABLE_DIFFUSION_IMG2IMG_PATH, json_data=request_data)
result = decode_base64_image(response["images"][0])
h, w, _ = get_h_w_c(result)
assert (w, h) == (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from ...impl.external_stable_diffusion import (
STABLE_DIFFUSION_INTERROGATE_URL,
STABLE_DIFFUSION_INTERROGATE_PATH,
encode_base64_image,
post,
verify_api_connection,
Expand Down Expand Up @@ -40,5 +40,5 @@ def run(self, image: np.ndarray) -> str:
request_data = {
"image": encode_base64_image(image),
}
response = post(url=STABLE_DIFFUSION_INTERROGATE_URL, json_data=request_data)
response = post(path=STABLE_DIFFUSION_INTERROGATE_PATH, json_data=request_data)
return response["caption"]
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ...impl.external_stable_diffusion import (
RESIZE_MODE_LABELS,
SAMPLER_NAME_LABELS,
STABLE_DIFFUSION_IMG2IMG_URL,
STABLE_DIFFUSION_IMG2IMG_PATH,
InpaintingFill,
ResizeMode,
SamplerName,
Expand Down Expand Up @@ -261,7 +261,7 @@ def run(
}
)

response = post(url=STABLE_DIFFUSION_IMG2IMG_URL, json_data=request_data)
response = post(path=STABLE_DIFFUSION_IMG2IMG_PATH, json_data=request_data)
result = decode_base64_image(response["images"][0])
h, w, _ = get_h_w_c(result)
assert (w, h) == (
Expand Down
4 changes: 2 additions & 2 deletions backend/src/nodes/nodes/external_stable_diffusion/txt2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ...impl.external_stable_diffusion import (
SAMPLER_NAME_LABELS,
STABLE_DIFFUSION_TEXT2IMG_URL,
STABLE_DIFFUSION_TEXT2IMG_PATH,
SamplerName,
decode_base64_image,
nearest_valid_size,
Expand Down Expand Up @@ -116,7 +116,7 @@ def run(
"height": height,
"tiling": tiling,
}
response = post(url=STABLE_DIFFUSION_TEXT2IMG_URL, json_data=request_data)
response = post(path=STABLE_DIFFUSION_TEXT2IMG_PATH, json_data=request_data)
result = decode_base64_image(response["images"][0])
h, w, _ = get_h_w_c(result)
assert (w, h) == (
Expand Down