Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement "Add Image Diff" node #2385

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 102 additions & 0 deletions backend/src/nodes/impl/diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import cv2
import numpy as np

from ..utils.utils import get_h_w_c


def diff_images(img1: np.ndarray, img2: np.ndarray) -> np.ndarray:
"""Calculates diff of input images"""

h1, w1, c1 = get_h_w_c(img1)
h2, w2, c2 = get_h_w_c(img2)

if h1 != h2 or w1 != w2:
raise ValueError("Diff inputs must have identical size")

# adjust channels
alpha1 = None
alpha2 = None
if c1 > 3:
alpha1 = img1[:, :, 3:4]
img1 = img1[:, :, :3]
if c2 > 3:
alpha2 = img2[:, :, 3:4]
img2 = img2[:, :, :3]

# Get difference between the images
diff = img1 - img2 # type: ignore

alpha_diff = None
if alpha1 is not None or alpha2 is not None:
# Don't alter RGB pixels if either input pixel is fully transparent,
# since RGB diff is indeterminate for those pixels.
if alpha1 is not None and alpha2 is not None:
invalid_alpha_mask = (alpha1 == 0) | (alpha2 == 0)
elif alpha1 is not None:
invalid_alpha_mask = alpha1 == 0
else:
invalid_alpha_mask = alpha2 == 0
invalid_alpha_indices = np.nonzero(invalid_alpha_mask)
diff[invalid_alpha_indices] = 0

if alpha1 is not None and alpha2 is not None:
alpha_diff = alpha1 - alpha2 # type: ignore

# add alpha back in
if alpha_diff is not None:
diff = np.concatenate([diff, alpha_diff], axis=2)

return diff


def sum_images(
input_img: np.ndarray,
diff: np.ndarray,
) -> np.ndarray:
"""Calculates sum of input images"""

input_h, input_w, input_c = get_h_w_c(input_img)
diff_h, diff_w, diff_c = get_h_w_c(diff)

# adjust channels
alpha = None
alpha_diff = None
if input_c > 3:
alpha = input_img[:, :, 3:4]
input_img = input_img[:, :, :3]
if diff_c > 3:
alpha_diff = diff[:, :, 3:4]
diff = diff[:, :, :3]

if input_h != diff_h or input_w != diff_w:
# Upsample the difference
diff = cv2.resize(
diff,
(input_w, input_h),
interpolation=cv2.INTER_CUBIC,
)

if alpha_diff is not None:
alpha_diff = cv2.resize(
alpha_diff,
(input_w, input_h),
interpolation=cv2.INTER_CUBIC,
)
alpha_diff = np.expand_dims(alpha_diff, 2)

if alpha_diff is not None:
# Don't alter alpha pixels if the input pixel is fully transparent, since
# doing so would expose indeterminate RGB data.
invalid_rgb_mask = alpha == 0
invalid_rgb_indices = np.nonzero(invalid_rgb_mask)
alpha_diff[invalid_rgb_indices] = 0

result = input_img + diff
if alpha_diff is not None:
alpha = alpha + alpha_diff # type: ignore

# add alpha back in
if alpha is not None:
result = np.concatenate([result, alpha], axis=2)

return result
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from math import ceil

import cv2
import numpy as np

from nodes.impl.diff import diff_images, sum_images
from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.properties.inputs import ImageInput, NumberInput
from nodes.properties.outputs import ImageOutput
Expand Down Expand Up @@ -71,67 +71,10 @@ def average_color_fix_node(
interpolation=InterpolationMethod.BOX,
)

# adjust channels
alpha = None
downscaled_alpha = None
ref_alpha = None
if input_c > 3:
alpha = input_img[:, :, 3:4]
input_img = input_img[:, :, :3]
downscaled_alpha = downscaled_input[:, :, 3:4]
downscaled_input = downscaled_input[:, :, :3]
if ref_c > 3:
ref_alpha = ref_img[:, :, 3:4]
ref_img = ref_img[:, :, :3]

# Get difference between the reference image and downscaled input
downscaled_diff = ref_img - downscaled_input # type: ignore

downscaled_alpha_diff = None
if ref_alpha is not None or downscaled_alpha is not None:
# Don't alter RGB pixels if either the input or reference pixel is
# fully transparent, since RGB diff is indeterminate for those pixels.
if ref_alpha is not None and downscaled_alpha is not None:
invalid_alpha_mask = (ref_alpha == 0) | (downscaled_alpha == 0)
elif ref_alpha is not None:
invalid_alpha_mask = ref_alpha == 0
else:
invalid_alpha_mask = downscaled_alpha == 0
invalid_alpha_indices = np.nonzero(invalid_alpha_mask)
downscaled_diff[invalid_alpha_indices] = 0

if ref_alpha is not None and downscaled_alpha is not None:
downscaled_alpha_diff = ref_alpha - downscaled_alpha # type: ignore

# Upsample the difference
diff = cv2.resize(
downscaled_diff,
(input_w, input_h),
interpolation=cv2.INTER_CUBIC,
)

alpha_diff = None
if downscaled_alpha_diff is not None:
alpha_diff = cv2.resize(
downscaled_alpha_diff,
(input_w, input_h),
interpolation=cv2.INTER_CUBIC,
)
alpha_diff = np.expand_dims(alpha_diff, 2)

if alpha_diff is not None:
# Don't alter alpha pixels if the input pixel is fully transparent, since
# doing so would expose indeterminate RGB data.
invalid_rgb_mask = alpha == 0
invalid_rgb_indices = np.nonzero(invalid_rgb_mask)
alpha_diff[invalid_rgb_indices] = 0

result = input_img + diff
if alpha_diff is not None:
alpha = alpha + alpha_diff # type: ignore
downscaled_diff = diff_images(ref_img, downscaled_input)

# add alpha back in
if alpha is not None:
result = np.concatenate([result, alpha], axis=2)
# Add the difference to the input image
result = sum_images(input_img, downscaled_diff)

return result
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

import cv2
import numpy as np

from nodes.impl.diff import diff_images, sum_images
from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.properties.inputs import ImageInput
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c

from .. import compositing_group


@compositing_group.register(
schema_id="chainner:image:diff",
name="Add Image Diff",
description="""Subtracts two reference images, and adds the diff to an input image.""",
icon="BsLayersHalf",
inputs=[
ImageInput("Image", channels=[3, 4]),
ImageInput("Reference Init", channels=[3, 4]),
ImageInput("Reference Goal", channels=[3, 4]),
],
outputs=[ImageOutput(image_type="Input0")],
limited_to_8bpc=True,
)
def add_image_diff_node(
input_img: np.ndarray,
ref_init_img: np.ndarray,
ref_goal_img: np.ndarray,
) -> np.ndarray:
"""Subtract two images, and add result to another image"""

diff = diff_images(ref_goal_img, ref_init_img)

result = sum_images(input_img, diff)

# Handle pixels that are fully transparent in input image but not goal image.
result_h, result_w, result_c = get_h_w_c(result)
ref_goal_h, ref_goal_w, ref_goal_c = get_h_w_c(ref_goal_img)
if result_c > 3 and ref_goal_c > 3:
if result_h != ref_goal_h or result_w != ref_goal_w:
# Scale the goal image to match input image.
ref_goal_img = resize(
ref_goal_img,
(result_w, result_h),
interpolation=InterpolationMethod.CUBIC,
)

# split channels
result_b, result_g, result_r, result_alpha = cv2.split(result)
ref_goal_b, ref_goal_g, ref_goal_r, ref_goal_alpha = cv2.split(ref_goal_img)

# For pixels that are fully transparent in input image, pass-through the goal pixel.
invalid_mask = result_alpha <= 0 # type: ignore
invalid_indices = np.nonzero(invalid_mask)
result_b[invalid_indices] = ref_goal_b[invalid_indices]
result_g[invalid_indices] = ref_goal_g[invalid_indices]
result_r[invalid_indices] = ref_goal_r[invalid_indices]
result_alpha[invalid_indices] = ref_goal_alpha[invalid_indices]

# merge channels
result = cv2.merge([result_b, result_g, result_r, result_alpha])

return result