Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support custom scales for PyTorch upscale #2613

Merged
merged 1 commit into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 41 additions & 0 deletions backend/src/nodes/impl/upscale/custom_scale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import math

import numpy as np

from nodes.impl.image_op import ImageOp
from nodes.impl.resize import ResizeFilter, resize
from nodes.utils.utils import get_h_w_c


def custom_scale_upscale(
img: np.ndarray,
upscale: ImageOp,
natural_scale: int,
custom_scale: int,
separate_alpha: bool,
) -> np.ndarray:
if custom_scale == natural_scale:
return upscale(img)

# number of iterations we need to do to reach the desired scale
# e.g. if the model is 2x and the desired scale is 13x, we need to do 4 iterations
iterations = max(1, math.ceil(math.log(custom_scale, natural_scale)))
org_h, org_w, _ = get_h_w_c(img)
for _ in range(iterations):
img = upscale(img)

# resize, if necessary
target_size = (
org_w * custom_scale,
org_h * custom_scale,
)
h, w, _ = get_h_w_c(img)
if (w, h) != target_size:
img = resize(
img,
target_size,
ResizeFilter.BOX,
separate_alpha=separate_alpha,
)

return img
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
parse_tile_size_input,
)
from nodes.impl.upscale.convenient_upscale import convenient_upscale
from nodes.impl.upscale.custom_scale import custom_scale_upscale
from nodes.impl.upscale.tiler import MaxTileSize
from nodes.properties.inputs import (
BoolInput,
ImageInput,
NumberInput,
SrModelInput,
TileSizeDropdown,
)
Expand Down Expand Up @@ -108,6 +110,40 @@ def estimate():
inputs=[
ImageInput().with_id(1),
SrModelInput().with_id(0),
if_group(
Condition.type(0, "PyTorchModel { scale: int(2..) }", if_not_connected=True)
& (
Condition.type(
0,
"PyTorchModel { inputChannels: 1, outputChannels: 1 }",
if_not_connected=True,
)
| Condition.type(
0, "PyTorchModel { inputChannels: 3, outputChannels: 3 }"
)
| Condition.type(
0, "PyTorchModel { inputChannels: 4, outputChannels: 4 }"
)
)
)(
BoolInput("Custom Scale", default=False)
.with_id(4)
.with_docs(
"If enabled, the scale factor can be manually set. This makes it possible to e.g. upscale 4x with a 2x model.",
"Custom scales are **not** supported for 1x models and colorization models.",
"Under the hood, this will repeatedly apply the model to the image, effectively upscaling by the given factor."
" E.g. if the model is 2x and the desired scale is 4x, the model will be applied 2 times."
" If the desired scale cannot be reached exactly, the image will be downscaled to the desired scale after upscaling."
" E.g. if the model is 2x and the desired scale is 6x, the model will be applied 3 times (8x) and the image will be downscaled to 6x.",
"If the desired scale is less than the model's scale, the image will be downscaled to the desired scale after upscaling.",
hint=True,
),
if_group(Condition.bool(4, True))(
NumberInput(
"Scale", default=4, minimum=1, maximum=32, label_style="hidden"
).with_id(5),
),
),
if_group(
Condition.type(
0,
Expand Down Expand Up @@ -142,7 +178,9 @@ def estimate():
)
)
)(
BoolInput("Separate Alpha", default=False).with_docs(
BoolInput("Separate Alpha", default=False)
.with_id(3)
.with_docs(
"Upscale alpha separately from color. Enabling this option will cause the alpha of"
" the upscaled image to be less noisy and more accurate to the alpha of the original"
" image, but the image may suffer from dark borders near transparency edges"
Expand All @@ -156,7 +194,24 @@ def estimate():
outputs=[
ImageOutput(
"Image",
image_type="""convenientUpscale(Input0, Input1)""",
image_type="""
let img = Input1;
let model = Input0;
let useCustomScale = Input4;
let customScale = Input5;

let singleUpscale = convenientUpscale(model, img);

if bool::and(useCustomScale, model.scale >= 2, model.inputChannels == model.outputChannels) {
Image {
width: img.width * customScale,
height: img.height * customScale,
channels: singleUpscale.channels,
}
} else {
singleUpscale
}
""",
assume_normalized=True, # pytorch_auto_split already does clipping internally
)
],
Expand All @@ -166,27 +221,35 @@ def upscale_image_node(
context: NodeContext,
img: np.ndarray,
model: ImageModelDescriptor,
use_custom_scale: bool,
custom_scale: int,
tile_size: TileSize,
separate_alpha: bool,
) -> np.ndarray:
exec_options = get_settings(context)

logger.debug("Upscaling image...")

in_nc = model.input_channels
out_nc = model.output_channels
scale = model.scale
h, w, c = get_h_w_c(img)
logger.debug(
f"Upscaling a {h}x{w}x{c} image with a {scale}x model (in_nc: {in_nc}, out_nc:"
f" {out_nc})"
)

return convenient_upscale(
img,
in_nc,
out_nc,
lambda i: upscale(i, model, tile_size, exec_options),
separate_alpha,
clip=False, # pytorch_auto_split already does clipping internally
)

def inner_upscale(img: np.ndarray) -> np.ndarray:
h, w, c = get_h_w_c(img)
logger.debug(
f"Upscaling a {h}x{w}x{c} image with a {scale}x model (in_nc: {in_nc}, out_nc:"
f" {out_nc})"
)

return convenient_upscale(
img,
in_nc,
out_nc,
lambda i: upscale(i, model, tile_size, exec_options),
separate_alpha,
clip=False, # pytorch_auto_split already does clipping internally
)

if not use_custom_scale or scale == 1 or in_nc != out_nc:
# no custom scale
custom_scale = scale

return custom_scale_upscale(img, inner_upscale, scale, custom_scale, separate_alpha)