diff --git a/keras/src/backend/mlx/image.py b/keras/src/backend/mlx/image.py index e8d38750c56..a3c08cea2a4 100644 --- a/keras/src/backend/mlx/image.py +++ b/keras/src/backend/mlx/image.py @@ -300,20 +300,28 @@ def resize( size, interpolation="bilinear", antialias=False, + crop_to_aspect_ratio=False, + pad_to_aspect_ratio=False, + fill_mode="constant", + fill_value=0.0, data_format="channels_last", ): if antialias: raise NotImplementedError( "Antialiasing not implemented for the MLX backend" ) - + if pad_to_aspect_ratio and crop_to_aspect_ratio: + raise ValueError( + "Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` " + "can be `True`." + ) if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys(): raise ValueError( "Invalid value for argument `interpolation`. Expected of one " f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: " f"interpolation={interpolation}" ) - + target_height, target_width = size size = tuple(size) image = convert_to_tensor(image) @@ -324,6 +332,127 @@ def resize( f"image.shape={image.shape}" ) + if crop_to_aspect_ratio: + shape = image.shape + if data_format == "channels_last": + height, width = shape[-3], shape[-2] + else: + height, width = shape[-2], shape[-1] + crop_height = int(float(width * target_height) / target_width) + crop_height = min(height, crop_height) + crop_width = int(float(height * target_width) / target_height) + crop_width = min(width, crop_width) + crop_box_hstart = int(float(height - crop_height) / 2) + crop_box_wstart = int(float(width - crop_width) / 2) + if data_format == "channels_last": + if len(image.shape) == 4: + image = image[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + image = image[ + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + :, + ] + else: + if len(image.shape) == 4: + image = image[ + :, + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + else: + image = image[ + :, + crop_box_hstart : crop_box_hstart + crop_height, + crop_box_wstart : crop_box_wstart + crop_width, + ] + elif pad_to_aspect_ratio: + shape = image.shape + batch_size = image.shape[0] + if data_format == "channels_last": + height, width, channels = shape[-3], shape[-2], shape[-1] + else: + channels, height, width = shape[-3], shape[-2], shape[-1] + pad_height = int(float(width * target_height) / target_width) + pad_height = max(height, pad_height) + pad_width = int(float(height * target_width) / target_height) + pad_width = max(width, pad_width) + img_box_hstart = int(float(pad_height - height) / 2) + img_box_wstart = int(float(pad_width - width) / 2) + if data_format == "channels_last": + if len(image.shape) == 4: + padded_img = ( + mx.ones( + ( + batch_size, + pad_height + height, + pad_width + width, + channels, + ), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + :, + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + :, + ] = image + else: + padded_img = ( + mx.ones( + (pad_height + height, pad_width + width, channels), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + :, + ] = image + else: + if len(image.shape) == 4: + padded_img = ( + mx.ones( + ( + batch_size, + channels, + pad_height + height, + pad_width + width, + ), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + :, + :, + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + ] = image + else: + padded_img = ( + mx.ones( + (channels, pad_height + height, pad_width + width), + dtype=image.dtype, + ) + * fill_value + ) + padded_img[ + :, + img_box_hstart : img_box_hstart + height, + img_box_wstart : img_box_wstart + width, + ] = image + image = padded_img + # Change to channels_last if data_format == "channels_first": image = ( diff --git a/keras/src/ops/image_test.py b/keras/src/ops/image_test.py index e7a3fe7face..330949c7b0e 100644 --- a/keras/src/ops/image_test.py +++ b/keras/src/ops/image_test.py @@ -273,6 +273,14 @@ def test_resize(self, interpolation, antialias, data_format): f"Received: interpolation={interpolation}, " f"antialias={antialias}." ) + if backend.backend() == "mlx": + if interpolation in ["lanczos3", "lanczos5", "bicubic"]: + self.skipTest( + f"Resizing with interpolation={interpolation} is " + "not supported by the mlx backend. " + ) + elif antialias: + self.skipTest("antialias=True not supported by mlx backend.") # Unbatched case if data_format == "channels_first": x = np.random.random((3, 50, 50)) * 255