Skip to content

Commit

Permalink
Merge b5f539b into a901c00
Browse files Browse the repository at this point in the history
  • Loading branch information
satyakesav committed Nov 13, 2018
2 parents a901c00 + b5f539b commit b0d7fa9
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 4 deletions.
4 changes: 4 additions & 0 deletions autokeras/constant.py
Expand Up @@ -52,6 +52,10 @@ class Constant:
PRE_TRAIN_FILE_LINK = "http://nlp.stanford.edu/data/glove.6B.zip"
PRE_TRAIN_FILE_NAME = "glove.6B.100d.txt"

# Image Resize

MAX_IMAGE_SIZE = 128 * 128

# SYS Constant

SYS_LINUX = 'linux'
Expand Down
29 changes: 26 additions & 3 deletions autokeras/image/image_supervised.py
Expand Up @@ -14,7 +14,7 @@
from autokeras.preprocessor import OneHotEncoder, ImageDataTransformer
from autokeras.supervised import Supervised, PortableClass
from autokeras.utils import has_file, pickle_from_file, pickle_to_file, temp_folder_generator, validate_xy, \
read_csv_file, read_image
read_csv_file, read_image, compute_image_resize_params, resize_image_data


def read_images(img_file_names, images_dir_path):
Expand Down Expand Up @@ -116,6 +116,9 @@ def __init__(self, verbose=False, path=None, resume=False, searcher_args=None, a
self.augment = augment
self.cnn = CnnModule(self.loss, self.metric, searcher_args, path, verbose)

self.resize_height = None
self.resize_width = None

@property
@abstractmethod
def metric(self):
Expand All @@ -128,6 +131,13 @@ def loss(self):

def fit(self, x, y, x_test=None, y_test=None, time_limit=None):
x = np.array(x)

if len(x.shape) != 0 and len(x[0].shape) == 3:
self.resize_height, self.resize_width = compute_image_resize_params(x)
x = resize_image_data(x, self.resize_height, self.resize_width)
if x_test is not None:
x_test = resize_image_data(x_test, self.resize_height, self.resize_width)

y = np.array(y).flatten()
validate_xy(x, y)
y = self.transform_y(y)
Expand Down Expand Up @@ -192,6 +202,8 @@ def inverse_transform_y(self, output):

def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
if len(x_test.shape) != 0 and len(x_test[0].shape) == 3:
x_test = resize_image_data(x_test, self.resize_height, self.resize_width)
y_predict = self.predict(x_test)
return self.metric().evaluate(y_test, y_predict)

Expand All @@ -209,6 +221,11 @@ def final_fit(self, x_train, y_train, x_test, y_test, trainer_args=None, retrain
if trainer_args is None:
trainer_args = {'max_no_improvement_num': 30}

if len(x_train.shape) != 0 and len(x_train[0].shape) == 3:
x_train = resize_image_data(x_train, self.resize_height, self.resize_width)
if x_test is not None:
x_test = resize_image_data(x_test, self.resize_height, self.resize_width)

y_train = self.transform_y(y_train)
y_test = self.transform_y(y_test)

Expand All @@ -230,7 +247,8 @@ def export_autokeras_model(self, model_file_name):
y_encoder=self.y_encoder,
data_transformer=self.data_transformer,
metric=self.metric,
inverse_transform_y_method=self.inverse_transform_y)
inverse_transform_y_method=self.inverse_transform_y,
resize_params=(self.resize_height, self.resize_width))
pickle_to_file(portable_model, model_file_name)


Expand Down Expand Up @@ -302,7 +320,7 @@ def __init__(self, **kwargs):


class PortableImageSupervised(PortableClass):
def __init__(self, graph, data_transformer, y_encoder, metric, inverse_transform_y_method):
def __init__(self, graph, data_transformer, y_encoder, metric, inverse_transform_y_method, resize_params):
"""Initialize the instance.
Args:
graph: The graph form of the learned model
Expand All @@ -312,6 +330,8 @@ def __init__(self, graph, data_transformer, y_encoder, metric, inverse_transform
self.y_encoder = y_encoder
self.metric = metric
self.inverse_transform_y_method = inverse_transform_y_method
self.resize_height = resize_params[0]
self.resize_width = resize_params[1]

def predict(self, x_test):
"""Return predict results for the testing data.
Expand All @@ -324,6 +344,7 @@ def predict(self, x_test):
"""
if Constant.LIMIT_MEMORY:
pass

test_loader = self.data_transformer.transform_test(x_test)
model = self.graph.produce_model()
model.eval()
Expand All @@ -340,5 +361,7 @@ def inverse_transform_y(self, output):

def evaluate(self, x_test, y_test):
"""Return the accuracy score between predict value and `y_test`."""
if len(x_test.shape) != 0 and len(x_test.shape) == 3:
x_test = resize_image_data(x_test, self.resize_height, self.resize_width)
y_predict = self.predict(x_test)
return self.metric().evaluate(y_test, y_predict)
52 changes: 52 additions & 0 deletions autokeras/utils.py
Expand Up @@ -7,9 +7,12 @@

import warnings
import imageio
import numpy
import requests
from skimage.transform import resize
import torch
import subprocess

from autokeras.constant import Constant


Expand Down Expand Up @@ -171,6 +174,55 @@ def read_image(img_path):
return img


def compute_image_resize_params(data):
"""Compute median height and width of all images in data. These
values are used to resize the images at later point. Number of
channels do not change from the original images. Currently, only
2-D images are supported.
Args:
data: 2-D Image data with shape N x H x W x C.
Returns:
median height: Median height of all images in the data.
median width: Median width of all images in the data.
"""
median_height, median_width = numpy.median(numpy.array(list(map(lambda x: x.shape, data))), axis=0)[:2]

if median_height * median_width > Constant.MAX_IMAGE_SIZE:
reduction_factor = numpy.sqrt(median_height * median_width / Constant.MAX_IMAGE_SIZE)
median_height = median_height / reduction_factor
median_width = median_width / reduction_factor

return int(median_height), int(median_width)


def resize_image_data(data, h, w):
"""Resize all images in data to size h x w x c, where h is the height,
w is the width and c is the number of channels. The number of channels
c does not change from data. The function supports only 2-D image data.
Args:
data: 2-D Image data with shape N x H x W x C.
h: Image resize height.
w: Image resize width.
Returns:
data: Resize data.
"""

output_data = []
for im in data:
if len(im.shape) != 3:
return data
output_data.append(resize(image=im,
output_shape=(h, w, im.shape[-1]),
mode='edge',
preserve_range=True))

return numpy.array(output_data)


def get_system():
"""
Get the current system environment. If the current system is not supported,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -3,6 +3,7 @@ torch==0.4.1
torchvision==0.2.1
numpy==1.14.5
scikit-learn==0.19.1
scikit-image==0.13.1
keras==2.2.2
tqdm==4.25.0
tensorflow==1.10.0
Expand Down
51 changes: 51 additions & 0 deletions tests/image/temp_test.py
@@ -0,0 +1,51 @@
from unittest.mock import patch

import pytest

from autokeras.image.image_supervised import *
from tests.common import clean_dir, MockProcess, simple_transform, mock_train, TEST_TEMP_DIR


@patch('torch.multiprocessing.get_context', side_effect=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=mock_train)
def test_fit_predict(_, _1):
Constant.MAX_ITER_NUM = 1
Constant.MAX_MODEL_NUM = 4
Constant.SEARCH_MAX_ITER = 1
Constant.T_MIN = 0.8
Constant.DATA_AUGMENTATION = False

clf = ImageClassifier(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert all(map(lambda result: result in train_y, results))

clf = ImageClassifier1D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert all(map(lambda result: result in train_y, results))

clf = ImageClassifier3D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert all(map(lambda result: result in train_y, results))

clf = ImageRegressor1D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert len(results) == len(train_y)

clf = ImageRegressor3D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert len(results) == len(train_y)
31 changes: 30 additions & 1 deletion tests/test_utils.py
@@ -1,8 +1,10 @@
import numpy
import subprocess
from unittest.mock import patch

from autokeras.constant import Constant
from autokeras.utils import temp_folder_generator, download_file, get_system, get_device
from autokeras.utils import temp_folder_generator, download_file, get_system, get_device, compute_image_resize_params, \
resize_image_data
from tests.common import clean_dir, TEST_TEMP_DIR, mock_nvidia_smi_output


Expand Down Expand Up @@ -47,6 +49,33 @@ def test_fetch(_):
clean_dir(TEST_TEMP_DIR)


def test_compute_image_resize_params():
# Case-1: Compute median height and width for smaller images.
data = numpy.array([numpy.random.randint(256, size=(10, 10, 3)),
numpy.random.randint(256, size=(20, 20, 3)),
numpy.random.randint(256, size=(30, 30, 3)),
numpy.random.randint(256, size=(40, 40, 3))])
resize_height, resize_width = compute_image_resize_params(data)
assert resize_height == 25
assert resize_width == 25

modified_data = resize_image_data(data, resize_height, resize_width)
for image in modified_data:
assert image.shape == (25, 25, 3)

# Case-2: Resize to max size for larger images.
data = numpy.array([numpy.random.randint(256, size=(int(numpy.sqrt(Constant.MAX_IMAGE_SIZE)+1),
int(numpy.sqrt(Constant.MAX_IMAGE_SIZE)+1),
3))])
resize_height, resize_width = compute_image_resize_params(data)
assert resize_height == int(numpy.sqrt(Constant.MAX_IMAGE_SIZE))
assert resize_width == int(numpy.sqrt(Constant.MAX_IMAGE_SIZE))

modified_data = resize_image_data(data, resize_height, resize_width)
for image in modified_data:
assert image.shape == (resize_height, resize_width, 3)


def test_get_system():
sys_name = get_system()
assert \
Expand Down

0 comments on commit b0d7fa9

Please sign in to comment.