diff --git a/backend/src/packages/chaiNNer_standard/image/io/load_image.py b/backend/src/packages/chaiNNer_standard/image/io/load_image.py index f40add486..9e8cc270f 100644 --- a/backend/src/packages/chaiNNer_standard/image/io/load_image.py +++ b/backend/src/packages/chaiNNer_standard/image/io/load_image.py @@ -35,6 +35,25 @@ def get_ext(path: str) -> str: return split_file_path(path)[2].lower() +def remove_unnecessary_alpha(img: np.ndarray) -> np.ndarray: + """ + Removes the alpha channel from an image if it is not used. + """ + if get_h_w_c(img)[2] != 4: + return img + + unnecessary = ( + (img.dtype == np.uint8 and np.all(img[:, :, 3] == 255)) + or (img.dtype == np.uint16 and np.all(img[:, :, 3] == 65536)) + or (img.dtype == np.float32 and np.all(img[:, :, 3] == 1.0)) + or (img.dtype == np.float64 and np.all(img[:, :, 3] == 1.0)) + ) + + if unnecessary: + return img[:, :, :3] + return img + + def _read_cv(path: str) -> np.ndarray | None: if get_ext(path) not in get_opencv_formats(): # not supported @@ -88,7 +107,10 @@ def _read_dds(path: str) -> np.ndarray | None: png = dds_to_png_texconv(path) try: - return _read_cv(png) + img = _read_cv(png) + if img is not None: + img = remove_unnecessary_alpha(img) + return img finally: os.remove(png)