import cv2
import numpy as np

from ..utils.utils import get_h_w_c


def diff_images(img1: np.ndarray, img2: np.ndarray) -> np.ndarray:
    """Calculates diff of input images.

    Returns ``img1 - img2`` per pixel. Both images must have identical
    spatial size. Alpha (4th) channels are split off before subtracting;
    the result carries an alpha-diff channel only when BOTH inputs have
    alpha. Assumes float images (the diff may be negative) — TODO confirm.

    Raises:
        ValueError: if the two images differ in width or height.
    """
    h1, w1, c1 = get_h_w_c(img1)
    h2, w2, c2 = get_h_w_c(img2)

    if h1 != h2 or w1 != w2:
        raise ValueError("Diff inputs must have identical size")

    # Split alpha off so the RGB subtraction operates on matching channel counts.
    alpha1 = None
    alpha2 = None
    if c1 > 3:
        alpha1 = img1[:, :, 3:4]
        img1 = img1[:, :, :3]
    if c2 > 3:
        alpha2 = img2[:, :, 3:4]
        img2 = img2[:, :, :3]

    # Get difference between the images
    diff = img1 - img2

    if alpha1 is not None or alpha2 is not None:
        # Don't alter RGB pixels if either input pixel is fully transparent,
        # since RGB diff is indeterminate for those pixels.
        if alpha1 is not None and alpha2 is not None:
            invalid_alpha_mask = (alpha1 == 0) | (alpha2 == 0)
        elif alpha1 is not None:
            invalid_alpha_mask = alpha1 == 0
        else:
            invalid_alpha_mask = alpha2 == 0
        # FIX: squeeze the (H, W, 1) mask to (H, W) before indexing.
        # `np.nonzero` on the 3-D mask pinned the channel index to 0, so
        # only channel 0 was zeroed; 2-D boolean indexing broadcasts the
        # assignment across all three RGB channels.
        diff[invalid_alpha_mask[:, :, 0]] = 0

    # An alpha diff only makes sense when both inputs supplied alpha.
    alpha_diff = None
    if alpha1 is not None and alpha2 is not None:
        alpha_diff = alpha1 - alpha2

    # add alpha back in
    if alpha_diff is not None:
        diff = np.concatenate([diff, alpha_diff], axis=2)

    return diff


def sum_images(
    input_img: np.ndarray,
    diff: np.ndarray,
) -> np.ndarray:
    """Calculates sum of input images.

    Adds ``diff`` (as produced by `diff_images`) onto ``input_img``,
    bicubically upsampling the diff first when the sizes differ. The
    result keeps the input image's channel count: a 4th (alpha-diff)
    channel in ``diff`` is applied only when the input also has alpha.
    """
    input_h, input_w, input_c = get_h_w_c(input_img)
    diff_h, diff_w, diff_c = get_h_w_c(diff)

    # Split alpha off both operands; it is resized and re-attached separately.
    alpha = None
    alpha_diff = None
    if input_c > 3:
        alpha = input_img[:, :, 3:4]
        input_img = input_img[:, :, :3]
    if diff_c > 3:
        alpha_diff = diff[:, :, 3:4]
        diff = diff[:, :, :3]

    if input_h != diff_h or input_w != diff_w:
        # Upsample the difference
        diff = cv2.resize(
            diff,
            (input_w, input_h),
            interpolation=cv2.INTER_CUBIC,
        )

        if alpha_diff is not None:
            alpha_diff = cv2.resize(
                alpha_diff,
                (input_w, input_h),
                interpolation=cv2.INTER_CUBIC,
            )
            # cv2.resize drops the trailing singleton channel axis; restore it.
            alpha_diff = np.expand_dims(alpha_diff, 2)

    if alpha_diff is not None:
        if alpha is None:
            # FIX: the input image has no alpha channel, so there is nothing
            # to apply the alpha diff to. Previously this fell through to
            # `alpha + alpha_diff` and raised TypeError (None + ndarray) —
            # reachable when an RGB input is combined with an RGBA diff.
            alpha_diff = None
        else:
            # Don't alter alpha pixels if the input pixel is fully transparent,
            # since doing so would expose indeterminate RGB data.
            invalid_alpha_indices = np.nonzero(alpha == 0)
            alpha_diff[invalid_alpha_indices] = 0

    result = input_img + diff

    # add alpha back in
    if alpha is not None:
        if alpha_diff is not None:
            alpha = alpha + alpha_diff
        result = np.concatenate([result, alpha], axis=2)

    return result
downscaled_input # type: ignore - - downscaled_alpha_diff = None - if ref_alpha is not None or downscaled_alpha is not None: - # Don't alter RGB pixels if either the input or reference pixel is - # fully transparent, since RGB diff is indeterminate for those pixels. - if ref_alpha is not None and downscaled_alpha is not None: - invalid_alpha_mask = (ref_alpha == 0) | (downscaled_alpha == 0) - elif ref_alpha is not None: - invalid_alpha_mask = ref_alpha == 0 - else: - invalid_alpha_mask = downscaled_alpha == 0 - invalid_alpha_indices = np.nonzero(invalid_alpha_mask) - downscaled_diff[invalid_alpha_indices] = 0 - - if ref_alpha is not None and downscaled_alpha is not None: - downscaled_alpha_diff = ref_alpha - downscaled_alpha # type: ignore - - # Upsample the difference - diff = cv2.resize( - downscaled_diff, - (input_w, input_h), - interpolation=cv2.INTER_CUBIC, - ) - - alpha_diff = None - if downscaled_alpha_diff is not None: - alpha_diff = cv2.resize( - downscaled_alpha_diff, - (input_w, input_h), - interpolation=cv2.INTER_CUBIC, - ) - alpha_diff = np.expand_dims(alpha_diff, 2) - - if alpha_diff is not None: - # Don't alter alpha pixels if the input pixel is fully transparent, since - # doing so would expose indeterminate RGB data. 
from __future__ import annotations

import cv2
import numpy as np

from nodes.impl.diff import diff_images, sum_images
from nodes.impl.pil_utils import InterpolationMethod, resize
from nodes.properties.inputs import ImageInput
from nodes.properties.outputs import ImageOutput
from nodes.utils.utils import get_h_w_c

from .. import compositing_group


@compositing_group.register(
    schema_id="chainner:image:diff",
    name="Add Image Diff",
    description="""Subtracts two reference images, and adds the diff to an input image.""",
    icon="BsLayersHalf",
    inputs=[
        ImageInput("Image", channels=[3, 4]),
        ImageInput("Reference Init", channels=[3, 4]),
        ImageInput("Reference Goal", channels=[3, 4]),
    ],
    outputs=[ImageOutput(image_type="Input0")],
    limited_to_8bpc=True,
)
def add_image_diff_node(
    input_img: np.ndarray,
    ref_init_img: np.ndarray,
    ref_goal_img: np.ndarray,
) -> np.ndarray:
    """Subtract two images, and add result to another image.

    Computes ``ref_goal_img - ref_init_img`` via `diff_images`, then adds
    that diff onto ``input_img`` via `sum_images` (which upsamples the diff
    if the sizes differ). When both the result and the goal image carry
    alpha, pixels whose resulting alpha is non-positive are replaced
    wholesale by the corresponding goal pixel.
    """

    diff = diff_images(ref_goal_img, ref_init_img)

    result = sum_images(input_img, diff)

    # Handle pixels that are fully transparent in input image but not goal image.
    result_h, result_w, result_c = get_h_w_c(result)
    ref_goal_h, ref_goal_w, ref_goal_c = get_h_w_c(ref_goal_img)
    # Pass-through only applies when both sides have an alpha channel;
    # otherwise the result is returned as-is (it keeps the input's channels,
    # per the `Input0` output type).
    if result_c > 3 and ref_goal_c > 3:
        if result_h != ref_goal_h or result_w != ref_goal_w:
            # Scale the goal image to match input image.
            ref_goal_img = resize(
                ref_goal_img,
                (result_w, result_h),
                interpolation=InterpolationMethod.CUBIC,
            )

        # split channels
        result_b, result_g, result_r, result_alpha = cv2.split(result)
        ref_goal_b, ref_goal_g, ref_goal_r, ref_goal_alpha = cv2.split(ref_goal_img)

        # For pixels that are fully transparent in input image, pass-through the goal pixel.
        # NOTE(review): this masks on the RESULT alpha, so it also catches pixels
        # whose alpha was driven to <= 0 by a negative alpha diff, not only pixels
        # that were transparent in the input — confirm this is intended.
        invalid_mask = result_alpha <= 0  # type: ignore
        invalid_indices = np.nonzero(invalid_mask)
        result_b[invalid_indices] = ref_goal_b[invalid_indices]
        result_g[invalid_indices] = ref_goal_g[invalid_indices]
        result_r[invalid_indices] = ref_goal_r[invalid_indices]
        result_alpha[invalid_indices] = ref_goal_alpha[invalid_indices]

        # merge channels
        result = cv2.merge([result_b, result_g, result_r, result_alpha])

    return result