Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Output normalization #1717

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions backend/src/nodes/properties/inputs/numpy_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from ...impl.image_utils import get_h_w_c, normalize
from ...impl.image_utils import get_h_w_c
from ...utils.format import format_image_with_channels
from .. import expression
from .base_input import BaseInput
Expand Down Expand Up @@ -47,10 +47,12 @@ def enforce(self, value):
f"The input {self.label} only supports {expected} but was given {actual}."
)

assert value.dtype == np.float32, "Expected the input image to be normalized."

if c == 1 and value.ndim == 3:
value = value[:, :, 0]

return normalize(value)
return value


class VideoInput(BaseInput):
Expand Down
9 changes: 5 additions & 4 deletions backend/src/nodes/properties/outputs/base_output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Literal, Union
from typing import Literal

from base_types import OutputId

Expand All @@ -20,7 +20,7 @@ def __init__(
self.output_type: expression.ExpressionJson = output_type
self.label: str = label
self.id: OutputId = OutputId(-1)
self.never_reason: Union[str, None] = None
self.never_reason: str | None = None
self.kind: OutputKind = kind
self.has_handle: bool = has_handle

Expand All @@ -34,7 +34,7 @@ def toDict(self):
"hasHandle": self.has_handle,
}

def with_id(self, output_id: Union[OutputId, int]):
def with_id(self, output_id: OutputId | int):
self.id = OutputId(output_id)
return self

Expand All @@ -54,5 +54,6 @@ def get_broadcast_data(self, _value):
def get_broadcast_type(self, _value) -> expression.ExpressionJson | None:
return None

def validate(self, value) -> None:
def enforce(self, value: object) -> object:
assert value is not None
return value
6 changes: 4 additions & 2 deletions backend/src/nodes/properties/outputs/file_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ def __init__(self, file_type: expression.ExpressionJson, label: str):
def get_broadcast_data(self, value: str):
return value

def validate(self, value) -> None:
def enforce(self, value) -> str:
assert isinstance(value, str)
return value


class DirectoryOutput(BaseOutput):
Expand All @@ -32,5 +33,6 @@ def __init__(self, label: str = "Directory", of_input: int | None = None):
def get_broadcast_type(self, value: str):
return expression.named("Directory", {"path": expression.literal(value)})

def validate(self, value) -> None:
def enforce(self, value) -> str:
assert isinstance(value, str)
return value
9 changes: 6 additions & 3 deletions backend/src/nodes/properties/outputs/generic_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ def __init__(
def get_broadcast_type(self, value: int | float):
return expression.literal(value)

def validate(self, value) -> None:
def enforce(self, value) -> int | float:
assert isinstance(value, (int, float))
return value


class TextOutput(BaseOutput):
Expand All @@ -31,8 +32,9 @@ def __init__(
def get_broadcast_type(self, value: str):
return expression.literal(value)

def validate(self, value) -> None:
def enforce(self, value) -> str:
assert isinstance(value, str)
return value


def FileNameOutput(label: str = "Name", of_input: int | None = None):
Expand All @@ -49,5 +51,6 @@ class SeedOutput(BaseOutput):
def __init__(self, label: str = "Seed"):
super().__init__(output_type="Seed", label=label, kind="generic")

def validate(self, value) -> None:
def enforce(self, value) -> Seed:
assert isinstance(value, Seed)
return value
25 changes: 21 additions & 4 deletions backend/src/nodes/properties/outputs/numpy_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import cv2
import numpy as np

from ...impl.image_utils import to_uint8
from ...impl.image_utils import normalize, to_uint8
from ...impl.pil_utils import InterpolationMethod, resize
from ...utils.format import format_image_with_channels
from ...utils.utils import get_h_w_c
Expand All @@ -24,8 +24,9 @@ def __init__(
):
super().__init__(output_type, label, kind=kind, has_handle=has_handle)

def validate(self, value) -> None:
def enforce(self, value) -> np.ndarray:
assert isinstance(value, np.ndarray)
return value


def AudioOutput():
Expand All @@ -41,6 +42,7 @@ def __init__(
kind: OutputKind = "image",
has_handle: bool = True,
channels: Optional[int] = None,
assume_normalized: bool = False,
):
super().__init__(
expression.intersect(image_type, expression.Image(channels=channels)),
Expand All @@ -50,6 +52,7 @@ def __init__(
)

self.channels: Optional[int] = channels
self.assume_normalized: bool = assume_normalized

def get_broadcast_data(self, value: np.ndarray):
h, w, c = get_h_w_c(value)
Expand All @@ -63,9 +66,8 @@ def get_broadcast_type(self, value: np.ndarray):
h, w, c = get_h_w_c(value)
return expression.Image(width=w, height=h, channels=c)

def validate(self, value) -> None:
def enforce(self, value) -> np.ndarray:
assert isinstance(value, np.ndarray)
assert value.dtype == np.float32

_, _, c = get_h_w_c(value)

Expand All @@ -78,6 +80,21 @@ def validate(self, value) -> None:
f" Please report this bug."
)

# flatting 3D single-channel images to 2D
if c == 1 and value.ndim == 3:
value = value[:, :, 0]

if self.assume_normalized:
assert value.dtype == np.float32, (
f"The output {self.label} did not return a normalized image."
f" This is a bug in the implementation of the node."
f" Please report this bug."
f"\n\nTo the author of this node: Either use `normalize` or remove `assume_normalized=True` from this output."
)
return value

return normalize(value)


def preview_encode(
img: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from sanic.log import logger

from nodes.impl.image_utils import normalize, to_uint8
from nodes.impl.image_utils import to_uint8
from nodes.properties.inputs import (
BoolInput,
DirectoryInput,
Expand Down Expand Up @@ -73,7 +73,7 @@ class Writer:
def VideoFrameIteratorFrameLoaderNode(
img: np.ndarray, idx: int, video_dir: str, video_name: str
) -> Tuple[np.ndarray, int, str, str]:
return normalize(img), idx, video_dir, video_name
return img, idx, video_dir, video_name


@batch_processing_group.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
width="Input0",
height="Input1",
channels="1",
)
),
assume_normalized=True,
)
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
width="Input0",
height="Input1",
channels="3",
)
),
assume_normalized=True,
)
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
width="Input0",
height="Input1",
channels="4",
)
),
assume_normalized=True,
)
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,4 @@ def create_noise_node(
kwargs["seed"] = (kwargs["seed"] + 1) % (2**32)
img /= total_brightness

return np.clip(img, 0, 1)
return img
3 changes: 0 additions & 3 deletions backend/src/packages/chaiNNer_standard/image/io/load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from nodes.impl.dds.texconv import dds_to_png_texconv
from nodes.impl.image_formats import get_opencv_formats, get_pil_formats
from nodes.impl.image_utils import normalize
from nodes.properties.inputs import ImageFileInput
from nodes.properties.outputs import DirectoryOutput, FileNameOutput, LargeImageOutput
from nodes.utils.utils import get_h_w_c, split_file_path
Expand Down Expand Up @@ -146,6 +145,4 @@ def load_image_node(path: str) -> Tuple[np.ndarray, str, str]:
f'The image "{path}" you are trying to read cannot be read by chaiNNer.'
)

img = normalize(img)

return img, dirname, basename
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ def brightness_and_contrast_node(
axis=2,
)

return np.clip(img, 0, 1)
return img
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ def color_levels_node(
out_white_all - out_black_all # type: ignore
) + out_black_all

return np.clip(img, 0, 1)
return img
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
description="Inverts all colors in an image.",
icon="MdInvertColors",
inputs=[ImageInput()],
outputs=[ImageOutput(image_type="Input0")],
outputs=[ImageOutput(image_type="Input0", assume_normalized=True)],
)
def invert_node(img: np.ndarray) -> np.ndarray:
c = get_h_w_c(img)[2]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
unit="%",
),
],
outputs=[ImageOutput(image_type=expression.Image(size_as="Input0"), channels=4)],
outputs=[
ImageOutput(
image_type=expression.Image(size_as="Input0"),
channels=4,
assume_normalized=True,
)
],
)
def opacity_node(img: np.ndarray, opacity: float) -> np.ndarray:
"""Apply opacity adjustment to alpha channel"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import cv2
import numpy as np

from nodes.impl.image_utils import normalize, to_uint8
from nodes.impl.image_utils import to_uint8
from nodes.properties.inputs import (
AdaptiveMethodInput,
AdaptiveThresholdInput,
Expand Down Expand Up @@ -61,4 +61,4 @@ def adaptive_threshold_node(
c,
)

return normalize(result)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
height="Input0.height & Input1.height & Input2.height & match Input3 { Image as i => i.height, _ => any }",
),
channels=4,
assume_normalized=True,
).with_never_reason(
"The input channels have different sizes but must all be the same size."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,28 @@
inputs=[ImageInput()],
outputs=[
ImageOutput(
"R Channel", image_type=expression.Image(size_as="Input0"), channels=1
"R Channel",
image_type=expression.Image(size_as="Input0"),
channels=1,
assume_normalized=True,
).with_id(2),
ImageOutput(
"G Channel", image_type=expression.Image(size_as="Input0"), channels=1
"G Channel",
image_type=expression.Image(size_as="Input0"),
channels=1,
assume_normalized=True,
).with_id(1),
ImageOutput(
"B Channel", image_type=expression.Image(size_as="Input0"), channels=1
"B Channel",
image_type=expression.Image(size_as="Input0"),
channels=1,
assume_normalized=True,
).with_id(0),
ImageOutput(
"A Channel", image_type=expression.Image(size_as="Input0"), channels=1
"A Channel",
image_type=expression.Image(size_as="Input0"),
channels=1,
assume_normalized=True,
),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
height="Input0.height & Input1.height",
),
channels=4,
assume_normalized=True,
).with_never_reason(
"The RGB and alpha channels have different sizes but must have the same size."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
"RGB Channels",
image_type=expression.Image(size_as="Input0"),
channels=3,
assume_normalized=True,
),
ImageOutput(
"Alpha Channel",
image_type=expression.Image(size_as="Input0"),
channels=1,
assume_normalized=True,
),
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
width="Input0.width + Input2 * 2",
height="Input0.height + Input2 * 2",
channels="BorderType::getOutputChannels(Input1, Input0.channels)",
)
),
assume_normalized=True,
)
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
width="Input0.width + Input3 + Input4",
height="Input0.height + Input2 + Input5",
channels="BorderType::getOutputChannels(Input1, Input0.channels)",
)
),
assume_normalized=True,
)
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
width="(Input0.width - Input1 * 2) & int(1..)",
height="(Input0.height - Input1 * 2) & int(1..)",
channels_as="Input0",
)
),
assume_normalized=True,
).with_never_reason(
"The cropped area would result in an image with no width or no height."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
channels: Input0.channels
}
}
"""
""",
assume_normalized=True,
)
],
)
Expand Down
Loading