Skip to content

Commit

Permalink
Merge f7cc397 into 847c194
Browse files Browse the repository at this point in the history
  • Loading branch information
satyakesav committed Nov 26, 2018
2 parents 847c194 + f7cc397 commit de67ca1
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 77 deletions.
27 changes: 8 additions & 19 deletions autokeras/image/image_supervised.py
Expand Up @@ -95,8 +95,7 @@ def __init__(self, augment=None, **kwargs):
if augment is None:
augment = Constant.DATA_AUGMENTATION
self.augment = augment
self.resize_height = None
self.resize_width = None
self.resize_shape = []

super().__init__(**kwargs)

Expand All @@ -107,15 +106,10 @@ def fit(self, x, y, x_test=None, y_test=None, time_limit=None):
if self.verbose:
print("Preprocessing the images.")

if x is not None and (len(x.shape) == 4 or len(x.shape) == 1 and len(x[0].shape) == 3):
self.resize_height, self.resize_width = compute_image_resize_params(x)
self.resize_shape = compute_image_resize_params(x)

if self.resize_height is not None:
x = resize_image_data(x, self.resize_height, self.resize_width)
print("x is ", x.shape)

if self.resize_height is not None:
x_test = resize_image_data(x_test, self.resize_height, self.resize_width)
x = resize_image_data(x, self.resize_shape)
x_test = resize_image_data(x_test, self.resize_shape)

if self.verbose:
print("Preprocessing finished.")
Expand All @@ -133,14 +127,11 @@ def export_autokeras_model(self, model_file_name):
data_transformer=self.data_transformer,
metric=self.metric,
inverse_transform_y_method=self.inverse_transform_y,
resize_params=(self.resize_height, self.resize_width))
resize_params=self.resize_shape)
pickle_to_file(portable_model, model_file_name)

def preprocess(self, x):
if len(x.shape) != 0 and len(x[0].shape) == 3:
if self.resize_height is not None:
return resize_image_data(x, self.resize_height, self.resize_width)
return x
return resize_image_data(x, self.resize_shape)


class ImageClassifier(ImageSupervised):
Expand Down Expand Up @@ -257,8 +248,7 @@ def __init__(self, graph, data_transformer, y_encoder, metric, inverse_transform
self.y_encoder = y_encoder
self.metric = metric
self.inverse_transform_y_method = inverse_transform_y_method
self.resize_height = resize_params[0]
self.resize_width = resize_params[1]
self.resize_shape = resize_params

def predict(self, x_test):
"""Return predict results for the testing data.
Expand Down Expand Up @@ -288,7 +278,6 @@ def inverse_transform_y(self, output):

def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
if self.resize_height is not None:
x_test = resize_image_data(x_test, self.resize_height, self.resize_width)
x_test = resize_image_data(x_test, self.resize_shape)
y_predict = self.predict(x_test)
return self.metric().evaluate(y_test, y_predict)
61 changes: 27 additions & 34 deletions autokeras/utils.py
Expand Up @@ -7,14 +7,14 @@

import warnings
import imageio
import numpy
import numpy as np
import requests
from skimage.transform import resize
import torch
import subprocess
import string
import random
from autokeras.constant import Constant
from scipy.ndimage import zoom


class NoImprovementError(Exception):
Expand Down Expand Up @@ -194,61 +194,54 @@ def read_image(img_path):


def compute_image_resize_params(data):
"""Compute median height and width of all images in data.
"""Compute median dimension of all images in data.
These values are used to resize the images at later point. Number of channels do not change from the original
images. Currently, only 2-D images are supported.
It used to resize the images later. Number of channels do not change from the original data.
Args:
data: 2-D Image data with shape N x H x W x C.
data: 1-D, 2-D or 3-D images. The Images are expected to have channel last configuration.
Returns:
median height: Median height of all images in the data.
median width: Median width of all images in the data.
median shape.
"""
if len(data.shape) == 1 and len(data[0].shape) != 3:
return None, None
if data is None or len(data.shape) == 0:
return []

median_height, median_width = numpy.median(numpy.array(list(map(lambda x: x.shape, data))), axis=0)[:2]
data_shapes = []
for x in data:
data_shapes.append(x.shape)

if median_height * median_width > Constant.MAX_IMAGE_SIZE:
reduction_factor = numpy.sqrt(median_height * median_width / Constant.MAX_IMAGE_SIZE)
median_height = median_height / reduction_factor
median_width = median_width / reduction_factor
median_shape = np.median(np.array(data_shapes), axis=0)
median_size = np.prod(median_shape[:-1])

return int(median_height), int(median_width)
if median_size > Constant.MAX_IMAGE_SIZE:
reduction_factor = np.power(Constant.MAX_IMAGE_SIZE / median_size, 1 / (len(median_shape) - 1))
median_shape[:-1] = median_shape[:-1] * reduction_factor

return median_shape.astype(int)

def resize_image_data(data, height, width):
"""Resize images to provided height and width.

Resize all images in data to size h x w x c, where h is the height, w is the width and c is the number of channels.
The number of channels c does not change from data. The function supports only 2-D image data.
def resize_image_data(data, resize_shape):
"""Resize images to given dimension.
Args:
data: 2-D Image data with shape N x H x W x C.
height: Image resize height.
width: Image resize width.
data: 1-D, 2-D or 3-D images. The Images are expected to have channel last configuration.
resize_shape: Image resize dimension.
Returns:
data: Resize data.
data: Reshaped data.
"""
if data is None:
if data is None or len(resize_shape) == 0:
return data

if len(data.shape) == 4 and data[0].shape[0] == height and data[0].shape[1] == width:
if len(data.shape) > 1 and np.array_equal(data[0].shape, resize_shape):
return data

output_data = []
for im in data:
if len(im.shape) != 3:
return data
output_data.append(resize(image=im,
output_shape=(height, width, im.shape[-1]),
mode='edge',
preserve_range=True))

return numpy.array(output_data)
output_data.append(zoom(input=im, zoom=np.divide(resize_shape, im.shape)))

return np.array(output_data)


def get_system():
Expand Down
1 change: 0 additions & 1 deletion setup.py
Expand Up @@ -10,7 +10,6 @@
'numpy==1.14.5',
'keras==2.2.2',
'scikit-learn==0.19.1',
'scikit-image==0.13.1',
'tqdm==4.25.0',
'tensorflow==1.10.0',
'imageio==2.4.1',
Expand Down
46 changes: 23 additions & 23 deletions tests/test_utils.py
@@ -1,4 +1,4 @@
import numpy
import numpy as np
import subprocess
from unittest.mock import patch

Expand Down Expand Up @@ -56,30 +56,30 @@ def test_fetch(_):


def test_compute_image_resize_params():
# Case-1: Compute median height and width for smaller images.
data = numpy.array([numpy.random.randint(256, size=(10, 10, 3)),
numpy.random.randint(256, size=(20, 20, 3)),
numpy.random.randint(256, size=(30, 30, 3)),
numpy.random.randint(256, size=(40, 40, 3))])
resize_height, resize_width = compute_image_resize_params(data)
assert resize_height == 25
assert resize_width == 25

modified_data = resize_image_data(data, resize_height, resize_width)
# Case-1: Compute median shape for smaller (3-D) images.
data = np.array([np.random.randint(256, size=(5, 5, 5, 3)),
np.random.randint(256, size=(6, 6, 6, 3)),
np.random.randint(256, size=(7, 7, 7, 3)),
np.random.randint(256, size=(8, 8, 8, 3))])
resize_shape = compute_image_resize_params(data)
assert np.array_equal(resize_shape, [6, 6, 6, 3])

modified_data = resize_image_data(data, resize_shape)
for image in modified_data:
assert image.shape == (25, 25, 3)

# Case-2: Resize to max size for larger images.
data = numpy.array([numpy.random.randint(256, size=(int(numpy.sqrt(Constant.MAX_IMAGE_SIZE) + 1),
int(numpy.sqrt(Constant.MAX_IMAGE_SIZE) + 1),
3))])
resize_height, resize_width = compute_image_resize_params(data)
assert resize_height == int(numpy.sqrt(Constant.MAX_IMAGE_SIZE))
assert resize_width == int(numpy.sqrt(Constant.MAX_IMAGE_SIZE))

modified_data = resize_image_data(data, resize_height, resize_width)
assert np.array_equal(image.shape, [6, 6, 6, 3])

# Case-2: Resize to max size for larger (2-D) images.
data = np.array([np.random.randint(256, size=(int(np.sqrt(Constant.MAX_IMAGE_SIZE) + 1),
int(np.sqrt(Constant.MAX_IMAGE_SIZE) + 1),
3))])
resize_shape = compute_image_resize_params(data)
assert np.array_equal(resize_shape, (int(np.sqrt(Constant.MAX_IMAGE_SIZE)),
int(np.sqrt(Constant.MAX_IMAGE_SIZE)),
3))

modified_data = resize_image_data(data, resize_shape)
for image in modified_data:
assert image.shape == (resize_height, resize_width, 3)
assert np.array_equal(image.shape, resize_shape)


def test_get_system():
Expand Down

0 comments on commit de67ca1

Please sign in to comment.