Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Rust implementation for image resizing #2387

Merged
merged 2 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 96 additions & 0 deletions backend/src/nodes/impl/resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

from enum import Enum

import numpy as np
from chainner_ext import ResizeFilter as NavtiveResizeFilter
from chainner_ext import resize as native_resize

from ..utils.utils import get_h_w_c


class ResizeFilter(Enum):
    """Resampling filters that callers can request from ``resize``.

    ``AUTO`` is a pseudo-filter: it has no native counterpart and is resolved
    to a concrete filter by ``resize`` based on the scaling direction.

    The numeric values are assigned explicitly and are not in declaration
    order. Declaration order is the iteration order of the enum, so keep it
    stable. NOTE(review): the values are presumably persisted or exchanged
    with other components -- confirm before renumbering.
    """

    AUTO = -1
    NEAREST = 0
    BOX = 4
    LINEAR = 2
    CATROM = 3
    LANCZOS = 1

    # Not enabled yet. The values below are presumably reserved for these
    # filters (TODO confirm); enabling one also requires adding its entry to
    # the native filter mapping in this module.
    # HERMITE = 5
    # MITCHELL = 6
    # BSPLINE = 7
    # HAMMING = 8
    # HANN = 9
    # LAGRANGE = 10
    # GAUSS = 11


# Translation table from the public ResizeFilter enum to the filter enum of
# the native (chainner_ext) implementation.
# ResizeFilter.AUTO is intentionally absent: it has to be resolved to one of
# the concrete filters below before this table is consulted.
# NOTE(review): the "NavtiveResizeFilter" spelling follows the import alias
# used at the top of this module.
_FILTER_MAP: dict[ResizeFilter, NavtiveResizeFilter] = {
    ResizeFilter.NEAREST: NavtiveResizeFilter.Nearest,
    ResizeFilter.BOX: NavtiveResizeFilter.Box,
    ResizeFilter.LINEAR: NavtiveResizeFilter.Linear,
    ResizeFilter.CATROM: NavtiveResizeFilter.CubicCatrom,
    ResizeFilter.LANCZOS: NavtiveResizeFilter.Lanczos,
    # Disabled together with the matching commented-out ResizeFilter members.
    # ResizeFilter.HERMITE: NavtiveResizeFilter.Hermite,
    # ResizeFilter.MITCHELL: NavtiveResizeFilter.CubicMitchell,
    # ResizeFilter.BSPLINE: NavtiveResizeFilter.CubicBSpline,
    # ResizeFilter.HAMMING: NavtiveResizeFilter.Hamming,
    # ResizeFilter.HANN: NavtiveResizeFilter.Hann,
    # ResizeFilter.LAGRANGE: NavtiveResizeFilter.Lagrange,
    # ResizeFilter.GAUSS: NavtiveResizeFilter.Gauss,
}


def resize(
    img: np.ndarray,
    out_dims: tuple[int, int],
    filter: ResizeFilter,
    separate_alpha: bool = False,
    gamma_correction: bool = False,
) -> np.ndarray:
    """Resize ``img`` to ``out_dims`` using the native resize implementation.

    Args:
        img: Image array. Its height, width and channel count are obtained
            via ``get_h_w_c``. The alpha handling below multiplies the array
            in place, so a floating-point image is assumed (TODO confirm).
        out_dims: Target size as ``(width, height)``.
        filter: Resampling filter. ``ResizeFilter.AUTO`` selects Lanczos when
            the output is larger than the input in at least one dimension and
            Box otherwise.
        separate_alpha: When ``False`` (the default) and the image has four
            channels, the first three channels are multiplied by the fourth
            (alpha) before resizing and divided by it afterwards. When
            ``True`` all channels are resized as they are.
        gamma_correction: Passed through unchanged to the native resize call.

    Returns:
        The resized image. If the size is unchanged and the effective filter
        is Nearest or Box, a copy of the input is returned instead. The input
        array is never modified.

    Raises:
        ValueError: If the target width or height is smaller than 1.
        RuntimeError: If the estimated output size exceeds the 16 GB limit.
    """
    h, w, c = get_h_w_c(img)
    new_w, new_h = out_dims

    # Reject degenerate targets up front. Without this, two negative
    # dimensions multiply to a positive size and slip past the memory check.
    if new_w < 1 or new_h < 1:
        raise ValueError(
            f"Cannot resize to {new_w}x{new_h}: width and height must be at least 1."
        )

    # check memory (the estimate assumes 4 bytes per output value)
    GB: int = 2**30  # noqa: N806
    MAX_MEMORY = 16 * GB  # noqa: N806
    new_memory = new_w * new_h * c * 4
    if new_memory > MAX_MEMORY:
        raise RuntimeError(
            f"Resize would require {round(new_memory / GB, 3)} GB of memory, but only {MAX_MEMORY//GB} GB are allowed."
        )

    if filter == ResizeFilter.AUTO:
        # automatically choose a method that works
        if new_w > w or new_h > h:
            filter = ResizeFilter.LANCZOS
        else:
            filter = ResizeFilter.BOX

    # Compare element-wise so that the shortcut also triggers when out_dims
    # is passed as a list rather than a tuple.
    same_size = (w, h) == (new_w, new_h)
    if same_size and filter in (ResizeFilter.NEAREST, ResizeFilter.BOX):
        # no resize needed
        # (the shortcut is limited to these two filters; the others still go
        # through the native call even at identical size)
        return img.copy()

    native_filter = _FILTER_MAP[filter]

    premultiply = not separate_alpha and c == 4

    if premultiply:
        # pre-multiply alpha on a copy so the caller's array stays untouched
        img = img.copy()
        img[..., :3] *= img[..., 3:4]

    img = native_resize(img, out_dims, native_filter, gamma_correction)
    # native_resize guarantees that the output is float32 in the range [0, 1]
    # so no need to normalize

    if premultiply:
        # undo pre-multiply alpha; alpha is clamped from below to avoid
        # dividing by zero for fully transparent pixels
        alpha_r = 1 / np.maximum(img[..., 3:4], 0.0001)
        img[..., :3] *= alpha_r
        # the division can push values above 1, so clamp back into range
        np.minimum(img, 1, out=img)

    return img
14 changes: 8 additions & 6 deletions backend/src/nodes/properties/inputs/image_dropdown_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

# pylint: disable=relative-beyond-top-level
from ...impl.image_utils import BorderType
from ...impl.pil_utils import InterpolationMethod, RotationInterpolationMethod
from ...impl.pil_utils import RotationInterpolationMethod
from ...impl.resize import ResizeFilter
from .generic_inputs import DropDownInput, EnumInput


Expand Down Expand Up @@ -50,13 +51,14 @@ def ColorSpaceInput(label: str = "Color Space") -> DropDownInput:
)


def InterpolationInput() -> DropDownInput:
"""Resize interpolation dropdown"""
def ResizeFilterInput() -> DropDownInput:
return EnumInput(
InterpolationMethod,
ResizeFilter,
label="Interpolation Method",
option_labels={
InterpolationMethod.NEAREST: "Nearest Neighbor",
InterpolationMethod.BOX: "Area (Box)",
ResizeFilter.NEAREST: "Nearest Neighbor",
ResizeFilter.BOX: "Area (Box)",
ResizeFilter.CATROM: "Cubic",
},
)

Expand Down
8 changes: 2 additions & 6 deletions backend/src/nodes/properties/outputs/numpy_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from api import BaseOutput, OutputKind

from ...impl.image_utils import normalize, to_uint8
from ...impl.pil_utils import InterpolationMethod, resize
from ...impl.resize import ResizeFilter, resize
from ...utils.format import format_image_with_channels
from ...utils.utils import get_h_w_c, round_half_up

Expand Down Expand Up @@ -131,11 +131,7 @@ def preview_encode(
if w > max_size or h > max_size:
f = max(w / target_size, h / target_size)
t = (max(1, round_half_up(w / f)), max(1, round_half_up(h / f)))
if c == 4:
# https://github.com/chaiNNer-org/chaiNNer/issues/1321
img = resize(img, t, InterpolationMethod.BOX)
else:
img = cv2.resize(img, t, interpolation=cv2.INTER_AREA)
img = resize(img, t, ResizeFilter.BOX)

image_format = "png" if c > 3 or lossless else "jpg"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import numpy as np

from nodes.groups import if_enum_group
from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.impl.resize import ResizeFilter, resize
from nodes.properties.inputs import (
EnumInput,
ImageInput,
InterpolationInput,
NumberInput,
ResizeFilterInput,
)
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c, round_half_up
Expand Down Expand Up @@ -50,7 +50,7 @@ class ImageResizeMode(Enum):
NumberInput("Width", minimum=1, default=1, unit="px").with_id(3),
NumberInput("Height", minimum=1, default=1, unit="px").with_id(4),
),
InterpolationInput().with_id(5),
ResizeFilterInput().with_id(5),
],
outputs=[
ImageOutput(
Expand Down Expand Up @@ -78,15 +78,14 @@ class ImageResizeMode(Enum):
assume_normalized=True,
)
],
limited_to_8bpc=True,
)
def resize_node(
img: np.ndarray,
mode: ImageResizeMode,
scale: float,
width: int,
height: int,
interpolation: InterpolationMethod,
filter: ResizeFilter,
) -> np.ndarray:
h, w, _ = get_h_w_c(img)

Expand All @@ -99,4 +98,10 @@ def resize_node(
else:
out_dims = (width, height)

return resize(img, out_dims, interpolation)
return resize(
img,
out_dims,
filter,
separate_alpha=False,
gamma_correction=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from enum import Enum

import numpy as np
from sanic.log import logger

from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.impl.resize import ResizeFilter, resize
from nodes.properties.inputs import (
EnumInput,
ImageInput,
InterpolationInput,
NumberInput,
ResizeFilterInput,
)
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c, round_half_up
Expand Down Expand Up @@ -99,7 +98,7 @@ def compare_conditions(b: int) -> bool:
unit="px",
),
EnumInput(SideSelection, label="Resize To"),
InterpolationInput(),
ResizeFilterInput(),
EnumInput(
ResizeCondition,
option_labels={
Expand Down Expand Up @@ -166,20 +165,15 @@ def compareCondition(b: uint): bool {
assume_normalized=True,
)
],
limited_to_8bpc=True,
)
def resize_to_side_node(
img: np.ndarray,
target: int,
side: SideSelection,
interpolation: InterpolationMethod,
filter: ResizeFilter,
condition: ResizeCondition,
) -> np.ndarray:
"""Takes an image and resizes it"""

logger.debug(f"Resizing image to {side} via {interpolation}")

h, w, _ = get_h_w_c(img)
out_dims = resize_to_side_conditional(w, h, target, side, condition)

return resize(img, out_dims, interpolation)
return resize(img, out_dims, filter)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import cv2
import numpy as np

from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.impl.resize import ResizeFilter, resize
from nodes.properties.inputs import ImageInput, NumberInput
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c
Expand Down Expand Up @@ -34,7 +34,6 @@
),
],
outputs=[ImageOutput(image_type="Input0")],
limited_to_8bpc=True,
)
def average_color_fix_node(
input_img: np.ndarray, ref_img: np.ndarray, scale_factor: float
Expand All @@ -49,11 +48,7 @@ def average_color_fix_node(
max(ceil(h * (scale_factor / 100)), 1),
)

ref_img = resize(
ref_img,
out_dims,
interpolation=InterpolationMethod.BOX,
)
ref_img = resize(ref_img, out_dims, filter=ResizeFilter.BOX)

input_h, input_w, input_c = get_h_w_c(input_img)
ref_h, ref_w, ref_c = get_h_w_c(ref_img)
Expand All @@ -65,11 +60,7 @@ def average_color_fix_node(
# Find the diff of both images

# Downscale the input image
downscaled_input = resize(
input_img,
(ref_w, ref_h),
interpolation=InterpolationMethod.BOX,
)
downscaled_input = resize(input_img, (ref_w, ref_h), filter=ResizeFilter.BOX)

# adjust channels
alpha = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import navi
from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.impl.resize import ResizeFilter, resize
from nodes.properties.inputs import ImageInput
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c
Expand Down Expand Up @@ -34,13 +34,13 @@ def metal_to_spec(
if metal_size == albedo_size:
metal3_inv_scaled = metal3_inv
else:
metal3_inv_scaled = resize(metal3_inv, albedo_size, InterpolationMethod.LANCZOS)
metal3_inv_scaled = resize(metal3_inv, albedo_size, ResizeFilter.LANCZOS)
diff = albedo * metal3_inv_scaled

if metal_size == albedo_size:
scaled_albedo = albedo
else:
scaled_albedo = resize(albedo, metal_size, InterpolationMethod.LANCZOS)
scaled_albedo = resize(albedo, metal_size, ResizeFilter.LANCZOS)
spec = metal3 * scaled_albedo + metal3_inv * 0.22

if roughness is None:
Expand Down Expand Up @@ -79,7 +79,6 @@ def metal_to_spec(
channels=1,
),
],
limited_to_8bpc=True,
)
def metal_to_specular_node(
albedo: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import navi
from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.impl.resize import ResizeFilter, resize
from nodes.properties.inputs import ImageInput, SliderInput
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c
Expand Down Expand Up @@ -47,9 +47,7 @@ def spec_to_metal(
else:
# to prevent color bleeding from non-metal parts of the specular map,
# we apply the metal map as alpha and resize before combining with diffuse
scaled = resize(
np.dstack((spec, metal)), diff_size, InterpolationMethod.LANCZOS
)
scaled = resize(np.dstack((spec, metal)), diff_size, ResizeFilter.LANCZOS)
sped_scaled: np.ndarray = scaled[:, :, 0:3]
metal_scaled: np.ndarray = scaled[:, :, 3]
metal3_scaled = np.dstack((metal_scaled,) * 3)
Expand Down Expand Up @@ -107,7 +105,6 @@ def spec_to_metal(
channels=1,
),
],
limited_to_8bpc=True,
)
def specular_to_metal_node(
diff: np.ndarray,
Expand Down