Skip to content
This repository has been archived by the owner on Nov 3, 2022. It is now read-only.

add keep_aspect_ratio flag to load_img #81

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
45 changes: 40 additions & 5 deletions keras_preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def save_img(path,


def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
interpolation='nearest'):
interpolation='nearest', keep_aspect_ratio=False, cval=0):
"""Loads an image into PIL format.

# Arguments
Expand All @@ -490,6 +490,11 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "nearest" is used.
keep_aspect_ratio: if `True`, the resized image will have the
same aspect ratio as the original, centered and padded
with `cval` to respect `target_size`
cval: integer in [0, 255]. value to pad the output image with if
`keep_aspect_ratio` is `True`

# Returns
A PIL Image instance.
Expand Down Expand Up @@ -527,7 +532,21 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
if not keep_aspect_ratio:
img = img.resize(width_height_tuple, resample)
else:
img.thumbnail(width_height_tuple, resample)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If size is smaller than width_height_tuple, then thumbnail will contain small image with edges on the right and at the bottom. I'm sure it is possible to resize an image to minimal height/width keeping aspect ratio.


final_img = pil_image.new(img.mode, width_height_tuple,
(cval if img.mode == 'L'
else (cval, cval, cval)))

final_img.paste(
img,
((width_height_tuple[0] - img.size[0]) // 2,
(width_height_tuple[1] - img.size[1]) // 2)
)
return final_img
return img


Expand Down Expand Up @@ -943,7 +962,10 @@ def flow_from_directory(self, directory,
save_format='png',
follow_links=False,
subset=None,
interpolation='nearest'):
interpolation='nearest',
keep_aspect_ratio=False,
cval=0,
):
"""Takes the path to a directory & generates batches of augmented data.

# Arguments
Expand Down Expand Up @@ -1029,7 +1051,10 @@ class subdirectories (default: False).
save_format=save_format,
follow_links=follow_links,
subset=subset,
interpolation=interpolation)
interpolation=interpolation,
keep_aspect_ratio=keep_aspect_ratio,
cval=cval,
)

def flow_from_dataframe(self, dataframe, directory,
x_col="filename", y_col="class", has_ext=True,
Expand Down Expand Up @@ -1857,6 +1882,8 @@ def __init__(self, directory, image_data_generator,
follow_links=False,
subset=None,
interpolation='nearest',
keep_aspect_ratio=False,
cval=0,
dtype='float32'):
super(DirectoryIterator, self).common_init(image_data_generator,
target_size,
Expand All @@ -1879,6 +1906,11 @@ def __init__(self, directory, image_data_generator,
self.dtype = dtype
white_list_formats = {'png', 'jpg', 'jpeg', 'bmp',
'ppm', 'tif', 'tiff'}
self.keep_aspect_ratio = keep_aspect_ratio
if not (0 <= cval < 255):
raise ValueError('cval {} not valid, must be in [0, 255]'
.format(cval))
self.cval = cval
# First, count the number of samples and classes.
self.samples = 0

Expand Down Expand Up @@ -1932,7 +1964,10 @@ def _get_batches_of_transformed_samples(self, index_array):
img = load_img(os.path.join(self.directory, fname),
color_mode=self.color_mode,
target_size=self.target_size,
interpolation=self.interpolation)
interpolation=self.interpolation,
keep_aspect_ratio=self.keep_aspect_ratio,
cval=self.cval,
)
x = img_to_array(img, data_format=self.data_format)
# Pillow images should be closed after `load_img`,
# but not PIL images.
Expand Down
22 changes: 22 additions & 0 deletions tests/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,5 +1036,27 @@ def test_load_img(self, tmpdir):
loaded_im = image.load_img(filename_rgb, target_size=(25, 25),
interpolation="unsupported")

# Test aspect ratio preservation
with tempfile.NamedTemporaryFile(suffix='.png') as f:
fn = f.name
white_square = Image.new(
mode='RGB', size=(10, 10), color=(255, 255, 255)
)
white_square.save(fn)
f.flush()

i1 = image.load_img(fn, target_size=(5, 7), keep_aspect_ratio=True)
i2 = image.load_img(
fn, target_size=(10, 5), keep_aspect_ratio=True,
cval=0, grayscale=True)

assert i1.size == (7, 5)
assert i2.size == (5, 10)
i2arr = np.array(i2)
# expect a (5, 5) white square over a (5, 10) black canvas
assert i2arr.shape == (10, 5)
assert 120 < np.mean(i2arr) < 140


if __name__ == '__main__':
pytest.main([__file__])