-
-
Notifications
You must be signed in to change notification settings - Fork 268
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for AUTOMATIC1111's extras - image upscaling
- Loading branch information
Showing
2 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
189 changes: 189 additions & 0 deletions
189
backend/src/nodes/nodes/external_stable_diffusion/upscaling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
from __future__ import annotations | ||
|
||
from enum import Enum | ||
from typing import Optional | ||
|
||
import numpy as np | ||
|
||
from ...group import group | ||
from ...groups import Condition, if_enum_group, if_group | ||
from ...impl.external_stable_diffusion import (STABLE_DIFFUSION_UPSCALE_PATH, | ||
decode_base64_image, | ||
encode_base64_image, post, | ||
verify_api_connection) | ||
from ...node_base import NodeBase | ||
from ...node_cache import cached | ||
from ...node_factory import NodeFactory | ||
from ...properties.inputs import (BoolInput, EnumInput, ImageInput, | ||
NumberInput, SliderInput) | ||
from ...properties.outputs import ImageOutput | ||
from ...utils.utils import get_h_w_c | ||
from . import category as ExternalStableDiffusionCategory | ||
|
||
verify_api_connection() | ||
|
||
class UpscalerMode(Enum): | ||
SCALE_BY = "ScaleBy" | ||
SCALE_TO = "ScaleTo" | ||
|
||
UPSCALER_MODE_LABELS = { | ||
UpscalerMode.SCALE_BY: "Scale by", | ||
UpscalerMode.SCALE_TO: "Scale to" | ||
} | ||
|
||
class UpscalerName(Enum): | ||
LANCZOS = "Lanczos" | ||
NEAREST = "Nearest" | ||
ESRGAN_4X = "ESRGAN_4x" | ||
LDSR = "LDSR" | ||
SCUNET = "ScuNET" | ||
SCUNET_PNSR = "ScuNET_PSNR" | ||
SWINIR_4x = "SwinIR_4x" | ||
|
||
UPSCALER_NAME_LABELS = { | ||
UpscalerName.LANCZOS: "Lanczos", | ||
UpscalerName.NEAREST: "Nearest", | ||
UpscalerName.ESRGAN_4X: "ESRGAN_4x", | ||
UpscalerName.LDSR: "LDSR", | ||
UpscalerName.SCUNET: "ScuNET", | ||
UpscalerName.SCUNET_PNSR: "ScuNET PNSR", | ||
UpscalerName.SWINIR_4x: "SwinIR_4x", | ||
} | ||
|
||
@NodeFactory.register("chainner:external_stable_diffusion:upscaling") | ||
class Extras(NodeBase): | ||
def __init__(self): | ||
super().__init__() | ||
self.description = "Upscale image using Automatic1111" | ||
self.inputs = [ | ||
ImageInput(channels=3), | ||
EnumInput(UpscalerMode, default_value=UpscalerMode.SCALE_BY, option_labels=UPSCALER_MODE_LABELS).with_id(1), | ||
if_enum_group(1, UpscalerMode.SCALE_BY)( | ||
SliderInput( | ||
"Resize multiplier", | ||
minimum=1.0, | ||
default=4.0, | ||
maximum=8.0, | ||
slider_step=0.1, | ||
controls_step=0.1, | ||
precision=1, | ||
).with_id(2), | ||
), | ||
if_enum_group(1, UpscalerMode.SCALE_TO)( | ||
NumberInput("Width", controls_step=1, default=512).with_id(3), | ||
NumberInput("Height", controls_step=1, default=512).with_id(4), | ||
BoolInput("Crop to fit", default=True).with_id(5), | ||
), | ||
EnumInput( | ||
UpscalerName, | ||
label="Upscaler 1", | ||
default_value=UpscalerName.LANCZOS, | ||
option_labels=UPSCALER_NAME_LABELS, | ||
), | ||
BoolInput("Use second upscaler", default=False).with_id(7), | ||
if_group(Condition.bool(7, True)) ( | ||
EnumInput( | ||
UpscalerName, | ||
label="Upscaler 2", | ||
default_value=UpscalerName.LANCZOS, | ||
option_labels=UPSCALER_NAME_LABELS, | ||
), | ||
SliderInput( | ||
"Upscaler 2 visibility", | ||
minimum=0.0, | ||
default=0.0, | ||
maximum=1.0, | ||
slider_step=0.001, | ||
controls_step=0.001, | ||
precision=3, | ||
), | ||
) | ||
] | ||
|
||
self.outputs = [ | ||
ImageOutput( | ||
image_type=""" | ||
def nearest_valid(n: number) = int & floor(n); | ||
let in_w = Input0.width; | ||
let in_h = Input0.height; | ||
let ratio_w = width/in_w; | ||
let ratio_h = height/in_h; | ||
let larger_ratio = if ratio_w>ratio_h { ratio_w } else { ratio_h }; | ||
let mode = Input1; | ||
let factor = Input2; | ||
let crop = Input5; | ||
let width = Input3; | ||
let height = Input4; | ||
match mode==UpscalerMode::ScaleTo { | ||
true => match crop==true { | ||
true => Image { | ||
width: width, | ||
height: height | ||
}, | ||
false => Image { | ||
width: nearest_valid(in_w*larger_ratio), | ||
height: nearest_valid(in_h*larger_ratio) | ||
} | ||
}, | ||
false => Image{ | ||
width: nearest_valid(in_w*factor), | ||
height: nearest_valid(in_h*factor) | ||
} | ||
} | ||
""", | ||
channels=3, | ||
) | ||
] | ||
self.category = ExternalStableDiffusionCategory | ||
self.name = "Upscale" | ||
self.icon = "MdChangeCircle" | ||
self.sub = "Automatic1111" | ||
|
||
@cached | ||
def run( | ||
self, | ||
image: np.ndarray, | ||
mode: UpscalerMode, | ||
upscaling_resize: float, | ||
width: int, | ||
height: int, | ||
crop: bool, | ||
upscaler_1: UpscalerName, | ||
use_second_upscaler: bool, | ||
upscaler_2: UpscalerName, | ||
upscaler_2_visibility: float, | ||
) -> np.ndarray: | ||
if mode==UpscalerMode.SCALE_BY: | ||
resize_mode = 0 | ||
else: | ||
resize_mode = 1 | ||
|
||
if use_second_upscaler: | ||
u2 = upscaler_2.value | ||
else: | ||
u2 = "None" | ||
|
||
request_data = { | ||
"resize_mode": resize_mode, | ||
"show_extras_results": False, | ||
"gfpgan_visibility": 0.0, | ||
"codeformer_visibility": 0.0, | ||
"codeformer_weight": 0.0, | ||
"upscaling_resize": upscaling_resize, | ||
"upscaling_resize_w": width, | ||
"upscaling_resize_h": height, | ||
"upscaling_crop": crop, | ||
"upscaler_1": upscaler_1.value, | ||
"upscaler_2": u2, | ||
"extras_upscaler_2_visibility": upscaler_2_visibility, | ||
"upscale_first": False, | ||
"image": encode_base64_image(image), | ||
|
||
} | ||
response = post(path=STABLE_DIFFUSION_UPSCALE_PATH, json_data=request_data) | ||
result = decode_base64_image(response["image"]) | ||
return result |