Skip to content

Commit

Permalink
1d3d image classifier and regressor added and tested
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Nov 5, 2018
1 parent 280b5ac commit b35895f
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 14 deletions.
11 changes: 4 additions & 7 deletions autokeras/cnn_module.py
Expand Up @@ -7,11 +7,6 @@
from autokeras.utils import pickle_from_file


def _run_searcher_once(train_data, test_data, path, timeout):
if Constant.LIMIT_MEMORY:
pass
searcher = pickle_from_file(os.path.join(path, 'searcher'))
searcher.search(train_data, test_data, timeout)


class CnnModule(object):
Expand All @@ -28,7 +23,8 @@ def fit(self, n_output_node, input_shape, train_data, test_data, time_limit=24 *
Args:
n_output_node: A integer value represent the number of output node in the final layer.
input_shape: A tuple to express the shape of every train entry. For example,MNIST dataset would be (28,28,1)
input_shape: A tuple to express the shape of every train entry. For example,
MNIST dataset would be (28,28,1)
train_data: A PyTorch DataLoader instance represents the training data
test_data: A PyTorch DataLoader instance represents the testing data
time_limit: A integer value represents the time limit on searching for models.
Expand All @@ -50,7 +46,8 @@ def fit(self, n_output_node, input_shape, train_data, test_data, time_limit=24 *
time_remain = time_limit
try:
while time_remain > 0:
_run_searcher_once(train_data, test_data, self.path, int(time_remain))
searcher = pickle_from_file(os.path.join(self.path, 'searcher'))
searcher.search(train_data, test_data, int(time_remain))
if len(self._load_searcher().history) >= Constant.MAX_MODEL_NUM:
break
time_elapsed = time.time() - start_time
Expand Down
26 changes: 24 additions & 2 deletions autokeras/image/image_supervised.py
Expand Up @@ -74,8 +74,6 @@ class ImageSupervised(Supervised):
path: A path to the directory to save the classifier.
y_encoder: An instance of OneHotEncoder for `y_train` (array of categorical labels).
verbose: A boolean value indicating the verbosity mode.
searcher: An instance of BayesianSearcher. It searches different
neural architecture to find the best model.
searcher_args: A dictionary containing the parameters for the searcher's __init__ function.
augment: A boolean value indicating whether the data needs augmentation. If not define, then it
will use the value of Constant.DATA_AUGMENTATION which is True by default.
Expand Down Expand Up @@ -262,6 +260,18 @@ def metric(self):
return Accuracy


class ImageClassifier1D(ImageClassifier):
def __init__(self, **kwargs):
kwargs['augment'] = False
super().__init__(**kwargs)


class ImageClassifier3D(ImageClassifier):
def __init__(self, **kwargs):
kwargs['augment'] = False
super().__init__(**kwargs)


class ImageRegressor(ImageSupervised):
@property
def loss(self):
Expand All @@ -281,6 +291,18 @@ def inverse_transform_y(self, output):
return output.flatten()


class ImageRegressor1D(ImageRegressor):
def __init__(self, **kwargs):
kwargs['augment'] = False
super().__init__(**kwargs)


class ImageRegressor3D(ImageRegressor):
def __init__(self, **kwargs):
kwargs['augment'] = False
super().__init__(**kwargs)


class PortableImageSupervised(PortableClass):
def __init__(self, graph, data_transformer, y_encoder, metric, inverse_transform_y_method):
"""Initialize the instance.
Expand Down
4 changes: 2 additions & 2 deletions autokeras/nn/generator.py
Expand Up @@ -20,9 +20,9 @@ class CnnGenerator(NetworkGenerator):
def __init__(self, n_output_node, input_shape):
super(CnnGenerator, self).__init__(n_output_node, input_shape)
self.n_dim = len(self.input_shape) - 1
if len(self.input_shape) > 3:
if len(self.input_shape) > 4:
raise ValueError('The input dimension is too high.')
if len(self.input_shape) < 1:
if len(self.input_shape) < 2:
raise ValueError('The input dimension is too low.')
self.conv = get_conv_class(self.n_dim)
self.dropout = get_dropout_class(self.n_dim)
Expand Down
3 changes: 2 additions & 1 deletion autokeras/nn/graph.py
Expand Up @@ -483,7 +483,8 @@ def to_concat_skip_model(self, start_id, end_id):
weights = np.zeros((filters_end, filters_end) + filter_shape)
for i in range(filters_end):
filter_weight = np.zeros((filters_end,) + filter_shape)
filter_weight[(i, 0, 0)] = 1
center_index = (i,) + (0,) * self.n_dim
filter_weight[center_index] = 1
weights[i, ...] = filter_weight
weights = np.concatenate((weights,
np.zeros((filters_end, filters_start) + filter_shape)), axis=1)
Expand Down
2 changes: 1 addition & 1 deletion autokeras/nn/layers.py
Expand Up @@ -290,7 +290,7 @@ def __init__(self, input_node=None, output_node=None):

@property
def output_shape(self):
return self.input.shape[2:]
return self.input.shape[-1],

@abstractmethod
def to_real_layer(self):
Expand Down
8 changes: 7 additions & 1 deletion autokeras/preprocessor.py
Expand Up @@ -148,6 +148,9 @@ def transform_train(self, data, targets=None, batch_size=None):
else:
compose_list = common_list

if len(data.shape) != 4:
compose_list = []

dataset = self._transform(compose_list, data, targets)

if batch_size is None:
Expand All @@ -159,6 +162,8 @@ def transform_train(self, data, targets=None, batch_size=None):
def transform_test(self, data, targets=None, batch_size=None):
common_list = [Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))]
compose_list = common_list
if len(data.shape) != 4:
compose_list = []

dataset = self._transform(compose_list, data, targets)

Expand All @@ -170,7 +175,8 @@ def transform_test(self, data, targets=None, batch_size=None):

def _transform(self, compose_list, data, targets):
data = data / self.max_val
data = torch.Tensor(data.transpose(0, 3, 1, 2))
args = [0, len(data.shape) - 1] + list(range(1, len(data.shape) - 1))
data = torch.Tensor(data.transpose(*args))
data_transforms = Compose(compose_list)
return MultiTransformDataset(data, targets, data_transforms)

Expand Down
30 changes: 30 additions & 0 deletions tests/image/test_image_supervised.py
Expand Up @@ -36,12 +36,42 @@ def test_fit_predict(_):
Constant.T_MIN = 0.8
Constant.DATA_AUGMENTATION = False
clean_dir(TEST_TEMP_DIR)

clf = ImageClassifier(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert all(map(lambda result: result in train_y, results))

clf = ImageClassifier1D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert all(map(lambda result: result in train_y, results))

clf = ImageClassifier3D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert all(map(lambda result: result in train_y, results))

clf = ImageRegressor1D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert len(results) == len(train_y)

clf = ImageRegressor3D(path=TEST_TEMP_DIR, verbose=True)
train_x = np.random.rand(100, 25, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y)
results = clf.predict(train_x)
assert len(results) == len(train_y)

clean_dir(TEST_TEMP_DIR)


Expand Down

0 comments on commit b35895f

Please sign in to comment.