Skip to content

Commit

Permalink
[MRG] Created MlpModule (#301)
Browse files Browse the repository at this point in the history
* Created MlpModule

* Changed generators to a list

* Changed name of subclass while calling super

* Fix for failed test cases
  • Loading branch information
droidadroit authored and haifeng-jin committed Nov 15, 2018
1 parent d30f350 commit 6f02446
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion autokeras/image/image_supervised.py
Expand Up @@ -7,7 +7,7 @@
import torch
from sklearn.model_selection import train_test_split

from autokeras.cnn_module import CnnModule
from autokeras.net_module import CnnModule
from autokeras.constant import Constant
from autokeras.nn.loss_function import classification_loss, regression_loss
from autokeras.nn.metric import Accuracy, MSE
Expand Down
17 changes: 16 additions & 1 deletion autokeras/cnn_module.py → autokeras/net_module.py
Expand Up @@ -4,16 +4,18 @@
from autokeras.constant import Constant
from autokeras.search import Searcher, train
from autokeras.utils import pickle_to_file
from autokeras.nn.generator import CnnGenerator, MlpGenerator


class CnnModule(object):
class NetworkModule:
def __init__(self, loss, metric, searcher_args, path, verbose=False):
self.searcher_args = searcher_args
self.searcher = None
self.path = path
self.verbose = verbose
self.loss = loss
self.metric = metric
self.generators = []

def fit(self, n_output_node, input_shape, train_data, test_data, time_limit=24 * 60 * 60):
""" Search the best CnnModule.
Expand All @@ -34,6 +36,7 @@ def fit(self, n_output_node, input_shape, train_data, test_data, time_limit=24 *
self.searcher_args['path'] = self.path
self.searcher_args['metric'] = self.metric
self.searcher_args['loss'] = self.loss
self.searcher_args['generators'] = self.generators
self.searcher_args['verbose'] = self.verbose
self.searcher = Searcher(**self.searcher_args)
pickle_to_file(self, os.path.join(self.path, 'module'))
Expand Down Expand Up @@ -85,3 +88,15 @@ def final_fit(self, train_data, test_data, trainer_args=None, retrain=False):
@property
def best_model(self):
return self.searcher.load_best_model()


class CnnModule(NetworkModule):
def __init__(self, loss, metric, searcher_args, path, verbose=False):
super(CnnModule, self).__init__(loss, metric, searcher_args, path, verbose)
self.generators.append(CnnGenerator)


class MlpModule(NetworkModule):
def __init__(self, loss, metric, searcher_args, path, verbose=False):
super(MlpModule, self).__init__(loss, metric, searcher_args, path, verbose)
self.generators.append(MlpGenerator)
29 changes: 16 additions & 13 deletions autokeras/search.py
Expand Up @@ -7,8 +7,8 @@

from autokeras.bayesian import edit_distance, BayesianOptimizer
from autokeras.constant import Constant
from autokeras.nn.generator import CnnGenerator, MlpGenerator
from autokeras.net_transformer import default_transform
from autokeras.nn.generator import CnnGenerator
from autokeras.nn.model_trainer import ModelTrainer
from autokeras.utils import pickle_to_file, pickle_from_file, verbose_print, get_system

Expand Down Expand Up @@ -42,7 +42,7 @@ class Searcher:
t_min: A float. The minimum temperature during simulated annealing.
"""

def __init__(self, n_output_node, input_shape, path, metric, loss, verbose,
def __init__(self, n_output_node, input_shape, path, metric, loss, generators, verbose,
trainer_args=None,
default_model_len=Constant.MODEL_LEN,
default_model_width=Constant.MODEL_WIDTH,
Expand Down Expand Up @@ -72,6 +72,7 @@ def __init__(self, n_output_node, input_shape, path, metric, loss, verbose,
self.metric = metric
self.loss = loss
self.path = path
self.generators = generators
self.model_count = 0
self.descriptors = []
self.trainer_args = trainer_args
Expand Down Expand Up @@ -145,18 +146,20 @@ def add_model(self, metric_value, loss, graph, model_id):
def init_search(self):
if self.verbose:
print('\nInitializing search.')
graph = CnnGenerator(self.n_classes,
self.input_shape).generate(self.default_model_len,
self.default_model_width)
model_id = self.model_count
self.model_count += 1
self.training_queue.append((graph, -1, model_id))
self.descriptors.append(graph.extract_descriptor())
for child_graph in default_transform(graph):
child_id = self.model_count
graph, model_id = None, None
for generator in self.generators:
graph = generator(self.n_classes, self.input_shape).\
generate(self.default_model_len, self.default_model_width)
model_id = self.model_count
self.model_count += 1
self.training_queue.append((child_graph, model_id, child_id))
self.descriptors.append(child_graph.extract_descriptor())
self.training_queue.append((graph, -1, model_id))
self.descriptors.append(graph.extract_descriptor())
if graph is not None and model_id is not None:
for child_graph in default_transform(graph):
child_id = self.model_count
self.model_count += 1
self.training_queue.append((child_graph, model_id, child_id))
self.descriptors.append(child_graph.extract_descriptor())
if self.verbose:
print('Initialization finished.')

Expand Down
2 changes: 1 addition & 1 deletion autokeras/text/text_supervised.py
Expand Up @@ -6,7 +6,7 @@
import torch
from sklearn.model_selection import train_test_split

from autokeras.cnn_module import CnnModule
from autokeras.net_module import CnnModule
from autokeras.constant import Constant
from autokeras.nn.loss_function import classification_loss, regression_loss
from autokeras.nn.metric import Accuracy, MSE
Expand Down
2 changes: 1 addition & 1 deletion tests/test_preprocessor.py
@@ -1,6 +1,6 @@
from unittest.mock import patch

from autokeras.cnn_module import CnnModule
from autokeras.net_module import CnnModule
from autokeras.nn.loss_function import classification_loss
from autokeras.nn.metric import Accuracy
from autokeras.preprocessor import *
Expand Down
9 changes: 5 additions & 4 deletions tests/test_search.py
Expand Up @@ -3,6 +3,7 @@
from autokeras.nn.loss_function import classification_loss
from autokeras.nn.metric import Accuracy
from autokeras.search import *
from autokeras.nn.generator import CnnGenerator

from tests.common import clean_dir, MockProcess, get_classification_data_loaders, get_add_skip_model, \
get_concat_skip_model, simple_transform, MockMemoryOutProcess, TEST_TEMP_DIR
Expand All @@ -19,7 +20,7 @@ def test_bayesian_searcher(_, _1, _2):
train_data, test_data = get_classification_data_loaders()
clean_dir(TEST_TEMP_DIR)
generator = Searcher(3, (28, 28, 3), verbose=False, path=TEST_TEMP_DIR, metric=Accuracy,
loss=classification_loss)
loss=classification_loss, generators=[CnnGenerator])
Constant.N_NEIGHBOURS = 1
Constant.T_MIN = 0.8
for _ in range(2):
Expand All @@ -44,7 +45,7 @@ def test_export_json(_, _1, _2):

clean_dir(TEST_TEMP_DIR)
generator = Searcher(3, (28, 28, 3), verbose=False, path=TEST_TEMP_DIR, metric=Accuracy,
loss=classification_loss)
loss=classification_loss, generators=[CnnGenerator])
Constant.N_NEIGHBOURS = 1
Constant.T_MIN = 0.8
for _ in range(3):
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_max_acq(_, _1, _2):
Constant.T_MIN = 0.8
Constant.BETA = 1
generator = Searcher(3, (28, 28, 3), verbose=False, path=TEST_TEMP_DIR, metric=Accuracy,
loss=classification_loss)
loss=classification_loss, generators=[CnnGenerator])
for _ in range(3):
generator.search(train_data, test_data)
for index1, descriptor1 in enumerate(generator.descriptors):
Expand All @@ -92,7 +93,7 @@ def test_out_of_memory(_, _1, _2):
train_data, test_data = get_classification_data_loaders()
clean_dir(TEST_TEMP_DIR)
searcher = Searcher(3, (28, 28, 3), verbose=False, path=TEST_TEMP_DIR, metric=Accuracy,
loss=classification_loss)
loss=classification_loss, generators=[CnnGenerator])
Constant.N_NEIGHBOURS = 1
Constant.T_MIN = 0.8
for _ in range(4):
Expand Down

0 comments on commit 6f02446

Please sign in to comment.