diff --git a/backend/src/nodes/properties/inputs/file_inputs.py b/backend/src/nodes/properties/inputs/file_inputs.py index 16499bf6e..7c0532d2d 100644 --- a/backend/src/nodes/properties/inputs/file_inputs.py +++ b/backend/src/nodes/properties/inputs/file_inputs.py @@ -7,7 +7,6 @@ # pylint: disable=relative-beyond-top-level from ...impl.image_formats import get_available_image_formats -from .generic_inputs import DropDownInput FileInputKind = Union[ Literal["bin"], @@ -145,48 +144,6 @@ def enforce(self, value): return value -def ImageExtensionDropdown() -> DropDownInput: - """Input for selecting file type from dropdown""" - return DropDownInput( - input_type="ImageExtension", - label="Image Extension", - options=[ - { - "option": "PNG", - "value": "png", - }, - { - "option": "JPG", - "value": "jpg", - }, - { - "option": "GIF", - "value": "gif", - }, - { - "option": "TIFF", - "value": "tiff", - }, - { - "option": "WEBP", - "value": "webp", - }, - { - "option": "WEBP (Lossless)", - "value": "webp-lossless", - }, - { - "option": "TGA", - "value": "tga", - }, - { - "option": "DDS", - "value": "dds", - }, - ], - ) - - def BinFileInput(primary_input: bool = False) -> FileInput: """Input for submitting a local .bin file""" return FileInput( diff --git a/backend/src/packages/chaiNNer_standard/image/io/save_image.py b/backend/src/packages/chaiNNer_standard/image/io/save_image.py index 241ebd01a..223a5d59a 100644 --- a/backend/src/packages/chaiNNer_standard/image/io/save_image.py +++ b/backend/src/packages/chaiNNer_standard/image/io/save_image.py @@ -2,6 +2,7 @@ import os from enum import Enum +from typing import Dict, List import cv2 import numpy as np @@ -26,7 +27,6 @@ DdsMipMapsDropdown, DirectoryInput, EnumInput, - ImageExtensionDropdown, ImageInput, SliderInput, TextInput, @@ -35,6 +35,38 @@ from .. import io_group + +class ImageFormat(Enum): + PNG = "png" + JPG = "jpg" + GIF = "gif" + BMP = "bmp" + TIFF = "tiff" + WEBP = "webp" + WEBP_LOSSLESS = "webp-lossless" + TGA = "tga" + DDS = "dds" + + @property + def extension(self) -> str: + if self == ImageFormat.WEBP_LOSSLESS: + return ImageFormat.WEBP.value + return self.value + + +IMAGE_FORMAT_LABELS: Dict[ImageFormat, str] = { + ImageFormat.PNG: "PNG", + ImageFormat.JPG: "JPG", + ImageFormat.GIF: "GIF", + ImageFormat.BMP: "BMP", + ImageFormat.TIFF: "TIFF", + ImageFormat.WEBP: "WEBP", + ImageFormat.WEBP_LOSSLESS: "WEBP (Lossless)", + ImageFormat.TGA: "TGA", + ImageFormat.DDS: "DDS", +} + + SUPPORTED_FORMATS = {f for f, _ in SUPPORTED_DDS_FORMATS} SUPPORTED_BC7_FORMATS = list(SUPPORTED_FORMATS.intersection(BC7_FORMATS)) SUPPORTED_BC123_FORMATS = list(SUPPORTED_FORMATS.intersection(BC123_FORMATS)) @@ -77,8 +109,13 @@ class JpegSubsampling(Enum): "The name of the image file **without** the file extension. If the file already exists, it will be overwritten.", "Example: `my-image`", ), - ImageExtensionDropdown().with_id(4), - if_enum_group(4, ["jpg", "webp"])( + EnumInput( + ImageFormat, + "Image Format", + default_value=ImageFormat.PNG, + option_labels=IMAGE_FORMAT_LABELS, + ).with_id(4), + if_enum_group(4, [ImageFormat.JPG, ImageFormat.WEBP])( SliderInput( "Quality", minimum=0, @@ -87,7 +124,7 @@ class JpegSubsampling(Enum): slider_step=1, ), ), - if_enum_group(4, "jpg")( + if_enum_group(4, ImageFormat.JPG)( EnumInput( JpegSubsampling, label="Chroma Subsampling", @@ -101,7 +138,7 @@ class JpegSubsampling(Enum): ).with_id(11), BoolInput("Progressive", default=False).with_id(12), ), - if_enum_group(4, "dds")( + if_enum_group(4, ImageFormat.DDS)( DdsFormatDropdown().with_id(6), if_enum_group(6, SUPPORTED_BC7_FORMATS)( EnumInput( @@ -132,7 +169,7 @@ def save_image_node( base_directory: str, relative_path: str | None, filename: str, - extension: str, + image_format: ImageFormat, quality: int, chroma_subsampling: JpegSubsampling, progressive: bool, @@ -145,25 +182,17 @@ def save_image_node( ) -> None: """Write an image to the specified path and return write status""" - lossless = False - if extension == "webp-lossless": - extension = "webp" - lossless = True - - full_file = f"{filename}.{extension}" - if relative_path and relative_path != ".": - base_directory = os.path.join(base_directory, relative_path) - full_path = os.path.join(base_directory, full_file) - + full_path = get_full_path(base_directory, relative_path, filename, image_format) logger.debug(f"Writing image to path: {full_path}") + # Create directory if it doesn't exist + os.makedirs(base_directory, exist_ok=True) + # Put image back in int range img = to_uint8(img, normalized=True) - os.makedirs(base_directory, exist_ok=True) - # DDS files are handled separately - if extension == "dds": + if image_format == ImageFormat.DDS: # remap legacy DX9 formats legacy_dds = dds_format in LEGACY_TO_DXGI @@ -181,8 +210,8 @@ def save_image_node( ) return - # Any image not supported by cv2, will be handled by pillow. - if extension not in ["png", "jpg", "tiff", "webp"]: + # Some formats are handled by PIL + if image_format == ImageFormat.GIF or image_format == ImageFormat.TGA: channels = get_h_w_c(img)[2] if channels == 1: # PIL supports grayscale images just fine, so we don't need to do any conversion @@ -193,13 +222,16 @@ def save_image_node( img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) else: raise RuntimeError( - f"Unsupported number of channels. Saving .{extension} images is only supported for " + f"Unsupported number of channels. Saving .{image_format.extension} images is only supported for " f"grayscale, RGB, and RGBA images." ) + with Image.fromarray(img) as image: image.save(full_path) + else: - if extension == "jpg": + params: List[int] + if image_format == ImageFormat.JPG: params = [ cv2.IMWRITE_JPEG_QUALITY, quality, @@ -208,9 +240,24 @@ def save_image_node( cv2.IMWRITE_JPEG_PROGRESSIVE, int(progressive), ] - elif extension == "webp": - params = [cv2.IMWRITE_WEBP_QUALITY, 101 if lossless else quality] + elif image_format == ImageFormat.WEBP: + params = [cv2.IMWRITE_WEBP_QUALITY, quality] + elif image_format == ImageFormat.WEBP_LOSSLESS: + params = [cv2.IMWRITE_WEBP_QUALITY, 101] else: params = [] cv_save_image(full_path, img, params) + + +def get_full_path( + base_directory: str, + relative_path: str | None, + filename: str, + image_format: ImageFormat, +) -> str: + file = f"{filename}.{image_format.extension}" + if relative_path and relative_path != ".": + base_directory = os.path.join(base_directory, relative_path) + full_path = os.path.join(base_directory, file) + return full_path