From 3ebfc5c815008cec5c1d145538e565e77df294eb Mon Sep 17 00:00:00 2001 From: "gskesav@tamu.edu" Date: Sun, 25 Nov 2018 18:54:40 -0600 Subject: [PATCH 1/2] Resize 1-D, 2-D and 3-D images --- autokeras/image/image_supervised.py | 27 +++++------------ autokeras/utils.py | 41 +++++++++++++------------ setup.py | 1 - tests/test_utils.py | 46 ++++++++++++++--------------- 4 files changed, 51 insertions(+), 64 deletions(-) diff --git a/autokeras/image/image_supervised.py b/autokeras/image/image_supervised.py index 30f1960a1..7b17879f7 100644 --- a/autokeras/image/image_supervised.py +++ b/autokeras/image/image_supervised.py @@ -95,8 +95,7 @@ def __init__(self, augment=None, **kwargs): if augment is None: augment = Constant.DATA_AUGMENTATION self.augment = augment - self.resize_height = None - self.resize_width = None + self.resize_shape = [] super().__init__(**kwargs) @@ -107,15 +106,10 @@ def fit(self, x, y, x_test=None, y_test=None, time_limit=None): if self.verbose: print("Preprocessing the images.") - if x is not None and (len(x.shape) == 4 or len(x.shape) == 1 and len(x[0].shape) == 3): - self.resize_height, self.resize_width = compute_image_resize_params(x) + self.resize_shape = compute_image_resize_params(x) - if self.resize_height is not None: - x = resize_image_data(x, self.resize_height, self.resize_width) - print("x is ", x.shape) - - if self.resize_height is not None: - x_test = resize_image_data(x_test, self.resize_height, self.resize_width) + x = resize_image_data(x, self.resize_shape) + x_test = resize_image_data(x_test, self.resize_shape) if self.verbose: print("Preprocessing finished.") @@ -133,14 +127,11 @@ def export_autokeras_model(self, model_file_name): data_transformer=self.data_transformer, metric=self.metric, inverse_transform_y_method=self.inverse_transform_y, - resize_params=(self.resize_height, self.resize_width)) + resize_params=self.resize_shape) pickle_to_file(portable_model, model_file_name) def preprocess(self, x): - if len(x.shape) != 0 and len(x[0].shape) == 3: - if self.resize_height is not None: - return resize_image_data(x, self.resize_height, self.resize_width) - return x + return resize_image_data(x, self.resize_shape) class ImageClassifier(ImageSupervised): @@ -257,8 +248,7 @@ def __init__(self, graph, data_transformer, y_encoder, metric, inverse_transform self.y_encoder = y_encoder self.metric = metric self.inverse_transform_y_method = inverse_transform_y_method - self.resize_height = resize_params[0] - self.resize_width = resize_params[1] + self.resize_shape = resize_params def predict(self, x_test): """Return predict results for the testing data. @@ -288,7 +278,6 @@ def inverse_transform_y(self, output): def evaluate(self, x_test, y_test): """Return the accuracy score between predict value and `y_test`.""" - if self.resize_height is not None: - x_test = resize_image_data(x_test, self.resize_height, self.resize_width) + x_test = resize_image_data(x_test, self.resize_shape) y_predict = self.predict(x_test) return self.metric().evaluate(y_test, y_predict) diff --git a/autokeras/utils.py b/autokeras/utils.py index 7af557636..372cee5c1 100644 --- a/autokeras/utils.py +++ b/autokeras/utils.py @@ -7,14 +7,14 @@ import warnings import imageio -import numpy +import numpy as np import requests -from skimage.transform import resize import torch import subprocess import string import random from autokeras.constant import Constant +from scipy.ndimage import zoom class NoImprovementError(Exception): @@ -206,20 +206,24 @@ def compute_image_resize_params(data): median height: Median height of all images in the data. median width: Median width of all images in the data. """ - if len(data.shape) == 1 and len(data[0].shape) != 3: - return None, None + if data is None or len(data.shape) == 0: + return [] - median_height, median_width = numpy.median(numpy.array(list(map(lambda x: x.shape, data))), axis=0)[:2] + image_shapes = [] + for x in data: + image_shapes.append(x.shape) - if median_height * median_width > Constant.MAX_IMAGE_SIZE: - reduction_factor = numpy.sqrt(median_height * median_width / Constant.MAX_IMAGE_SIZE) - median_height = median_height / reduction_factor - median_width = median_width / reduction_factor + median_shape = np.median(np.array(image_shapes), axis=0) + median_size = np.prod(median_shape[:-1]) - return int(median_height), int(median_width) + if median_size > Constant.MAX_IMAGE_SIZE: + reduction_factor = np.power(Constant.MAX_IMAGE_SIZE / median_size, 1 / (len(median_shape) - 1)) + median_shape[:-1] = median_shape[:-1] * reduction_factor + return median_shape.astype(int) -def resize_image_data(data, height, width): + +def resize_image_data(data, resize_shape): """Resize images to provided height and width. Resize all images in data to size h x w x c, where h is the height, w is the width and c is the number of channels. @@ -233,22 +237,17 @@ def resize_image_data(data, height, width): Returns: data: Resize data. """ - if data is None: + if data is None or len(resize_shape) == 0: return data - if len(data.shape) == 4 and data[0].shape[0] == height and data[0].shape[1] == width: + if len(data.shape) > 1 and np.array_equal(data[0].shape, resize_shape): return data output_data = [] for im in data: - if len(im.shape) != 3: - return data - output_data.append(resize(image=im, - output_shape=(height, width, im.shape[-1]), - mode='edge', - preserve_range=True)) - - return numpy.array(output_data) + output_data.append(zoom(im, np.divide(resize_shape, im.shape))) + + return np.array(output_data) def get_system(): diff --git a/setup.py b/setup.py index a177766e0..7e835b9c0 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,6 @@ 'numpy==1.14.5', 'keras==2.2.2', 'scikit-learn==0.19.1', - 'scikit-image==0.13.1', 'tqdm==4.25.0', 'tensorflow==1.10.0', 'imageio==2.4.1', diff --git a/tests/test_utils.py b/tests/test_utils.py index 13058a62d..8c8deb2f5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -import numpy +import numpy as np import subprocess from unittest.mock import patch @@ -56,30 +56,30 @@ def test_fetch(_): def test_compute_image_resize_params(): - # Case-1: Compute median height and width for smaller images. - data = numpy.array([numpy.random.randint(256, size=(10, 10, 3)), - numpy.random.randint(256, size=(20, 20, 3)), - numpy.random.randint(256, size=(30, 30, 3)), - numpy.random.randint(256, size=(40, 40, 3))]) - resize_height, resize_width = compute_image_resize_params(data) - assert resize_height == 25 - assert resize_width == 25 - - modified_data = resize_image_data(data, resize_height, resize_width) + # Case-1: Compute median shape for smaller (3-D) images. + data = np.array([np.random.randint(256, size=(5, 5, 5, 3)), + np.random.randint(256, size=(6, 6, 6, 3)), + np.random.randint(256, size=(7, 7, 7, 3)), + np.random.randint(256, size=(8, 8, 8, 3))]) + resize_shape = compute_image_resize_params(data) + assert np.array_equal(resize_shape, [6, 6, 6, 3]) + + modified_data = resize_image_data(data, resize_shape) for image in modified_data: - assert image.shape == (25, 25, 3) - - # Case-2: Resize to max size for larger images. - data = numpy.array([numpy.random.randint(256, size=(int(numpy.sqrt(Constant.MAX_IMAGE_SIZE) + 1), - int(numpy.sqrt(Constant.MAX_IMAGE_SIZE) + 1), - 3))]) - resize_height, resize_width = compute_image_resize_params(data) - assert resize_height == int(numpy.sqrt(Constant.MAX_IMAGE_SIZE)) - assert resize_width == int(numpy.sqrt(Constant.MAX_IMAGE_SIZE)) - - modified_data = resize_image_data(data, resize_height, resize_width) + assert np.array_equal(image.shape, [6, 6, 6, 3]) + + # Case-2: Resize to max size for larger (2-D) images. + data = np.array([np.random.randint(256, size=(int(np.sqrt(Constant.MAX_IMAGE_SIZE) + 1), + int(np.sqrt(Constant.MAX_IMAGE_SIZE) + 1), + 3))]) + resize_shape = compute_image_resize_params(data) + assert np.array_equal(resize_shape, (int(np.sqrt(Constant.MAX_IMAGE_SIZE)), + int(np.sqrt(Constant.MAX_IMAGE_SIZE)), + 3)) + + modified_data = resize_image_data(data, resize_shape) for image in modified_data: - assert image.shape == (resize_height, resize_width, 3) + assert np.array_equal(image.shape, resize_shape) def test_get_system(): From f7cc3973bb83ef03af611c68f1b9de461d44f7e5 Mon Sep 17 00:00:00 2001 From: "gskesav@tamu.edu" Date: Mon, 26 Nov 2018 00:47:19 -0600 Subject: [PATCH 2/2] Update Docstrings --- autokeras/utils.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/autokeras/utils.py b/autokeras/utils.py index 372cee5c1..64024face 100644 --- a/autokeras/utils.py +++ b/autokeras/utils.py @@ -194,26 +194,24 @@ def read_image(img_path): def compute_image_resize_params(data): - """Compute median height and width of all images in data. + """Compute median dimension of all images in data. - These values are used to resize the images at later point. Number of channels do not change from the original - images. Currently, only 2-D images are supported. + It used to resize the images later. Number of channels do not change from the original data. Args: - data: 2-D Image data with shape N x H x W x C. + data: 1-D, 2-D or 3-D images. The Images are expected to have channel last configuration. Returns: - median height: Median height of all images in the data. - median width: Median width of all images in the data. + median shape. """ if data is None or len(data.shape) == 0: return [] - image_shapes = [] + data_shapes = [] for x in data: - image_shapes.append(x.shape) + data_shapes.append(x.shape) - median_shape = np.median(np.array(image_shapes), axis=0) + median_shape = np.median(np.array(data_shapes), axis=0) median_size = np.prod(median_shape[:-1]) if median_size > Constant.MAX_IMAGE_SIZE: @@ -224,18 +222,14 @@ def compute_image_resize_params(data): def resize_image_data(data, resize_shape): - """Resize images to provided height and width. - - Resize all images in data to size h x w x c, where h is the height, w is the width and c is the number of channels. - The number of channels c does not change from data. The function supports only 2-D image data. + """Resize images to given dimension. Args: - data: 2-D Image data with shape N x H x W x C. - height: Image resize height. - width: Image resize width. + data: 1-D, 2-D or 3-D images. The Images are expected to have channel last configuration. + resize_shape: Image resize dimension. Returns: - data: Resize data. + data: Reshaped data. """ if data is None or len(resize_shape) == 0: return data @@ -245,7 +239,7 @@ def resize_image_data(data, resize_shape): output_data = [] for im in data: - output_data.append(zoom(im, np.divide(resize_shape, im.shape))) + output_data.append(zoom(input=im, zoom=np.divide(resize_shape, im.shape))) return np.array(output_data)