Skip to content

Commit

Permalink
Support BMP in Save Image (#2077)
Browse files Browse the repository at this point in the history
  • Loading branch information
RunDevelopment committed Aug 12, 2023
1 parent 57138ca commit 281a662
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 68 deletions.
43 changes: 0 additions & 43 deletions backend/src/nodes/properties/inputs/file_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# pylint: disable=relative-beyond-top-level
from ...impl.image_formats import get_available_image_formats
from .generic_inputs import DropDownInput

FileInputKind = Union[
Literal["bin"],
Expand Down Expand Up @@ -145,48 +144,6 @@ def enforce(self, value):
return value


def ImageExtensionDropdown() -> DropDownInput:
"""Input for selecting file type from dropdown"""
return DropDownInput(
input_type="ImageExtension",
label="Image Extension",
options=[
{
"option": "PNG",
"value": "png",
},
{
"option": "JPG",
"value": "jpg",
},
{
"option": "GIF",
"value": "gif",
},
{
"option": "TIFF",
"value": "tiff",
},
{
"option": "WEBP",
"value": "webp",
},
{
"option": "WEBP (Lossless)",
"value": "webp-lossless",
},
{
"option": "TGA",
"value": "tga",
},
{
"option": "DDS",
"value": "dds",
},
],
)


def BinFileInput(primary_input: bool = False) -> FileInput:
"""Input for submitting a local .bin file"""
return FileInput(
Expand Down
97 changes: 72 additions & 25 deletions backend/src/packages/chaiNNer_standard/image/io/save_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
from enum import Enum
from typing import Dict, List

import cv2
import numpy as np
Expand All @@ -26,7 +27,6 @@
DdsMipMapsDropdown,
DirectoryInput,
EnumInput,
ImageExtensionDropdown,
ImageInput,
SliderInput,
TextInput,
Expand All @@ -35,6 +35,38 @@

from .. import io_group


class ImageFormat(Enum):
PNG = "png"
JPG = "jpg"
GIF = "gif"
BMP = "bmp"
TIFF = "tiff"
WEBP = "webp"
WEBP_LOSSLESS = "webp-lossless"
TGA = "tga"
DDS = "dds"

@property
def extension(self) -> str:
if self == ImageFormat.WEBP_LOSSLESS:
return ImageFormat.WEBP.value
return self.value


IMAGE_FORMAT_LABELS: Dict[ImageFormat, str] = {
ImageFormat.PNG: "PNG",
ImageFormat.JPG: "JPG",
ImageFormat.GIF: "GIF",
ImageFormat.BMP: "BMP",
ImageFormat.TIFF: "TIFF",
ImageFormat.WEBP: "WEBP",
ImageFormat.WEBP_LOSSLESS: "WEBP (Lossless)",
ImageFormat.TGA: "TGA",
ImageFormat.DDS: "DDS",
}


SUPPORTED_FORMATS = {f for f, _ in SUPPORTED_DDS_FORMATS}
SUPPORTED_BC7_FORMATS = list(SUPPORTED_FORMATS.intersection(BC7_FORMATS))
SUPPORTED_BC123_FORMATS = list(SUPPORTED_FORMATS.intersection(BC123_FORMATS))
Expand Down Expand Up @@ -77,8 +109,13 @@ class JpegSubsampling(Enum):
"The name of the image file **without** the file extension. If the file already exists, it will be overwritten.",
"Example: `my-image`",
),
ImageExtensionDropdown().with_id(4),
if_enum_group(4, ["jpg", "webp"])(
EnumInput(
ImageFormat,
"Image Format",
default_value=ImageFormat.PNG,
option_labels=IMAGE_FORMAT_LABELS,
).with_id(4),
if_enum_group(4, [ImageFormat.JPG, ImageFormat.WEBP])(
SliderInput(
"Quality",
minimum=0,
Expand All @@ -87,7 +124,7 @@ class JpegSubsampling(Enum):
slider_step=1,
),
),
if_enum_group(4, "jpg")(
if_enum_group(4, ImageFormat.JPG)(
EnumInput(
JpegSubsampling,
label="Chroma Subsampling",
Expand All @@ -101,7 +138,7 @@ class JpegSubsampling(Enum):
).with_id(11),
BoolInput("Progressive", default=False).with_id(12),
),
if_enum_group(4, "dds")(
if_enum_group(4, ImageFormat.DDS)(
DdsFormatDropdown().with_id(6),
if_enum_group(6, SUPPORTED_BC7_FORMATS)(
EnumInput(
Expand Down Expand Up @@ -132,7 +169,7 @@ def save_image_node(
base_directory: str,
relative_path: str | None,
filename: str,
extension: str,
image_format: ImageFormat,
quality: int,
chroma_subsampling: JpegSubsampling,
progressive: bool,
Expand All @@ -145,25 +182,17 @@ def save_image_node(
) -> None:
"""Write an image to the specified path and return write status"""

lossless = False
if extension == "webp-lossless":
extension = "webp"
lossless = True

full_file = f"{filename}.{extension}"
if relative_path and relative_path != ".":
base_directory = os.path.join(base_directory, relative_path)
full_path = os.path.join(base_directory, full_file)

full_path = get_full_path(base_directory, relative_path, filename, image_format)
logger.debug(f"Writing image to path: {full_path}")

# Create directory if it doesn't exist
os.makedirs(base_directory, exist_ok=True)

# Put image back in int range
img = to_uint8(img, normalized=True)

os.makedirs(base_directory, exist_ok=True)

# DDS files are handled separately
if extension == "dds":
if image_format == ImageFormat.DDS:
# remap legacy DX9 formats
legacy_dds = dds_format in LEGACY_TO_DXGI

Expand All @@ -181,8 +210,8 @@ def save_image_node(
)
return

# Any image not supported by cv2, will be handled by pillow.
if extension not in ["png", "jpg", "tiff", "webp"]:
# Some formats are handled by PIL
if image_format == ImageFormat.GIF or image_format == ImageFormat.TGA:
channels = get_h_w_c(img)[2]
if channels == 1:
# PIL supports grayscale images just fine, so we don't need to do any conversion
Expand All @@ -193,13 +222,16 @@ def save_image_node(
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
else:
raise RuntimeError(
f"Unsupported number of channels. Saving .{extension} images is only supported for "
f"Unsupported number of channels. Saving .{image_format.extension} images is only supported for "
f"grayscale, RGB, and RGBA images."
)

with Image.fromarray(img) as image:
image.save(full_path)

else:
if extension == "jpg":
params: List[int]
if image_format == ImageFormat.JPG:
params = [
cv2.IMWRITE_JPEG_QUALITY,
quality,
Expand All @@ -208,9 +240,24 @@ def save_image_node(
cv2.IMWRITE_JPEG_PROGRESSIVE,
int(progressive),
]
elif extension == "webp":
params = [cv2.IMWRITE_WEBP_QUALITY, 101 if lossless else quality]
elif image_format == ImageFormat.WEBP:
params = [cv2.IMWRITE_WEBP_QUALITY, quality]
elif image_format == ImageFormat.WEBP_LOSSLESS:
params = [cv2.IMWRITE_WEBP_QUALITY, 101]
else:
params = []

cv_save_image(full_path, img, params)


def get_full_path(
base_directory: str,
relative_path: str | None,
filename: str,
image_format: ImageFormat,
) -> str:
file = f"{filename}.{image_format.extension}"
if relative_path and relative_path != ".":
base_directory = os.path.join(base_directory, relative_path)
full_path = os.path.join(base_directory, file)
return full_path

0 comments on commit 281a662

Please sign in to comment.