Skip to content

Commit

Permalink
initial gpr
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Apr 24, 2018
1 parent ef16d7e commit aa00a12
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 80 deletions.
6 changes: 5 additions & 1 deletion autokeras/bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ def incremental_fit(self, train_x, train_y):

return self

@property
def first_fitted(self):
return self._first_fitted

def first_fit(self, train_x, train_y):
train_x, train_y = np.array([train_x]), np.array([train_y])
train_x, train_y = np.array(train_x), np.array(train_y)

self._x = np.copy(train_x)
self._y = np.copy(train_y)
Expand Down
20 changes: 6 additions & 14 deletions autokeras/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,20 +53,12 @@ def read_csv_file(csv_file_path):
"""
img_file_names = []
img_labels = []
try:
with open(csv_file_path, 'r') as images_path:
path_list = csv.DictReader(images_path)
fieldnames = path_list.fieldnames
for path in path_list:
img_file_names.append(path[fieldnames[0]])
img_labels.append(path[fieldnames[1]])
except IOError as e:
if e.errno == errno.EACCES:
raise IOError('File not accessible')
elif e.errno == errno.ENOENT:
raise IOError('No such file or directory exist')
else:
raise ValueError("Illegal file type")
with open(csv_file_path, 'r') as images_path:
path_list = csv.DictReader(images_path)
fieldnames = path_list.fieldnames
for path in path_list:
img_file_names.append(path[fieldnames[0]])
img_labels.append(path[fieldnames[1]])
return img_file_names, img_labels


Expand Down
3 changes: 2 additions & 1 deletion autokeras/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Searcher

MAX_MODEL_NUM = 1000
ACQ_EXPLOITATION_DEPTH = 4
ACQ_EXPLOITATION_DEPTH = 2

# Model Defaults

Expand All @@ -29,3 +29,4 @@
EPOCHS_EACH = 1
MAX_BATCH_SIZE = 32
LIMIT_MEMORY = False
SEARCH_MAX_ITER = 30
19 changes: 12 additions & 7 deletions autokeras/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,25 +555,30 @@ def produce_model(self):
output_id = self.node_to_id[self.output]

new_to_old_layer = {}
self.node_list[input_id] = input_tensor
self.node_to_id[input_tensor] = input_id

node_list = deepcopy(self.node_list)
node_list[input_id] = input_tensor

node_to_id = deepcopy(self.node_to_id)
node_to_id[input_tensor] = input_id

for v in self._topological_order():
for u, layer_id in self.reverse_adj_list[v]:
layer = self.layer_list[layer_id]

if isinstance(layer, (StubWeightedAdd, StubConcatenate)):
edge_input_tensor = list(map(lambda x: self.node_list[x],
edge_input_tensor = list(map(lambda x: node_list[x],
self.layer_id_to_input_node_ids[layer_id]))
else:
edge_input_tensor = self.node_list[u]
edge_input_tensor = node_list[u]

new_layer = to_real_layer(layer)
new_to_old_layer[new_layer] = layer

temp_tensor = new_layer(edge_input_tensor)
self.node_list[v] = temp_tensor
self.node_to_id[temp_tensor] = v
model = Model(input_tensor, self.node_list[output_id])
node_list[v] = temp_tensor
node_to_id[temp_tensor] = v
model = Model(input_tensor, node_list[output_id])
for layer in model.layers[1:]:
if not isinstance(layer, (Activation, Dropout, Concatenate)):
old_layer = new_to_old_layer[layer]
Expand Down
2 changes: 1 addition & 1 deletion autokeras/net_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def transform(graph):
graph: the model from which we get new model
Returns:
The new model
A list of graphs.
"""
graphs = []
for target_id in graph.wide_layer_ids():
Expand Down
51 changes: 37 additions & 14 deletions autokeras/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Searcher:
path: place that store searcher
model_count: the id of model
"""

def __init__(self, n_classes, input_shape, path, verbose):
"""Init Searcher class with n_classes, input_shape, path, verbose
Expand Down Expand Up @@ -62,15 +63,20 @@ def get_best_model_id(self):
def replace_model(self, model, model_id):
model.save(os.path.join(self.path, str(model_id) + '.h5'))

def add_model(self, model, x_train, y_train, x_test, y_test):
def add_model(self, model, x_train, y_train, x_test, y_test, max_iter=constant.MAX_ITER_NUM):
"""add one model while will be trained to history list
Returns:
History object.
"""
if self.verbose:
model.summary()
ModelTrainer(model, x_train, y_train, x_test, y_test, self.verbose).train_model()
ModelTrainer(model,
x_train,
y_train,
x_test,
y_test,
self.verbose).train_model(max_iter_num=max_iter)
loss, accuracy = model.evaluate(x_test, y_test, verbose=self.verbose)
model.save(os.path.join(self.path, str(self.model_count) + '.h5'))
plot_model(model, to_file=os.path.join(self.path, str(self.model_count) + '.png'), show_shapes=True)
Expand All @@ -88,6 +94,7 @@ class RandomSearcher(Searcher):
RandomSearcher implements its search function with random strategy
"""

def __init__(self, n_classes, input_shape, path, verbose):
"""Init RandomSearcher with n_classes, input_shape, path, verbose"""
super().__init__(n_classes, input_shape, path, verbose)
Expand All @@ -108,6 +115,7 @@ class HillClimbingSearcher(Searcher):
HillClimbing Searcher implements its search function with hill climbing strategy
"""

def __init__(self, n_classes, input_shape, path, verbose):
"""Init HillClimbing Searcher with n_classes, input_shape, path, verbose"""
super().__init__(n_classes, input_shape, path, verbose)
Expand Down Expand Up @@ -155,27 +163,42 @@ def __init__(self, n_classes, input_shape, path, verbose):
super().__init__(n_classes, input_shape, path, verbose)
self.gpr = IncrementalGaussianProcess()
self.search_tree = SearchTree()
self.init_search_queue = None
self.init_gpr_x = []
self.init_gpr_y = []

def search(self, x_train, y_train, x_test, y_test):
if not self.history:
model = DefaultClassifierGenerator(self.n_classes, self.input_shape).generate()
history_item = self.add_model(model, x_train, y_train, x_test, y_test)
self.search_tree.add_child(-1, history_item['model_id'])
self.gpr.first_fit(Graph(model).extract_descriptor(), history_item['accuracy'])
pickle_to_file(self, os.path.join(self.path, 'searcher'))
del model
backend.clear_session()

else:
model_ids = self.search_tree.get_leaves()
new_model, father_id = self.maximize_acq(model_ids)
graph = Graph(model)
self.init_search_queue = transform(graph)
self.init_gpr_x.append(graph.extract_descriptor())
self.init_gpr_y.append(history_item['accuracy'])
pickle_to_file(self, os.path.join(self.path, 'searcher'))
return

history_item = self.add_model(new_model, x_train, y_train, x_test, y_test)
self.search_tree.add_child(father_id, history_item['model_id'])
self.gpr.incremental_fit(Graph(new_model).extract_descriptor(), history_item['accuracy'])
if self.init_search_queue:
graph = self.init_search_queue.pop()
model = graph.produce_model()
history_item = self.add_model(model, x_train, y_train, x_test, y_test, constant.SEARCH_MAX_ITER)
self.init_gpr_x.append(graph.extract_descriptor())
self.init_gpr_y.append(history_item['accuracy'])
pickle_to_file(self, os.path.join(self.path, 'searcher'))
del new_model
backend.clear_session()
return

if not self.init_search_queue and not self.gpr.first_fitted:
self.gpr.first_fit(self.init_gpr_x, self.init_gpr_y)

model_ids = self.search_tree.get_leaves()
new_model, father_id = self.maximize_acq(model_ids)

history_item = self.add_model(new_model, x_train, y_train, x_test, y_test, constant.SEARCH_MAX_ITER)
self.search_tree.add_child(father_id, history_item['model_id'])
self.gpr.incremental_fit(Graph(new_model).extract_descriptor(), history_item['accuracy'])
pickle_to_file(self, os.path.join(self.path, 'searcher'))

def maximize_acq(self, model_ids):
overall_max_acq_value = -1
Expand Down
11 changes: 4 additions & 7 deletions autokeras/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,10 @@ def __init__(self, model, x_train, y_train, x_test, y_test, verbose):
else:
self.datagen = None

def _converged(self, loss):
"""Return whether the training is converged"""

def train_model(self):
def train_model(self, max_iter_num=constant.MAX_ITER_NUM, max_no_improvement_num=constant.MAX_NO_IMPROVEMENT_NUM):
"""Train the model with dataset and return the minimum_loss"""
batch_size = min(self.x_train.shape[0], constant.MAX_BATCH_SIZE)
terminator = EarlyStop()
terminator = EarlyStop(max_no_improvement_num=max_no_improvement_num)
lr_scheduler = LearningRateScheduler(lr_schedule)

lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
Expand All @@ -167,14 +164,14 @@ def train_model(self):
if constant.DATA_AUGMENTATION:
flow = self.datagen.flow(self.x_train, self.y_train, batch_size)
self.model.fit_generator(flow,
epochs=constant.MAX_ITER_NUM,
epochs=max_iter_num,
validation_data=(self.x_test, self.y_test),
callbacks=callbacks,
verbose=self.verbose)
else:
self.model.fit(self.x_train, self.y_train,
batch_size=batch_size,
epochs=constant.MAX_ITER_NUM,
epochs=max_iter_num,
validation_data=(self.x_test, self.y_test),
callbacks=callbacks,
verbose=self.verbose)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def test_edit_distance():

def test_gpr():
gpr = IncrementalGaussianProcess()
gpr.first_fit(Graph(get_add_skip_model()).extract_descriptor(), 0.5)
gpr.first_fit([Graph(get_add_skip_model()).extract_descriptor()], [0.5])
assert gpr.first_fitted

gpr.incremental_fit(Graph(get_concat_skip_model()).extract_descriptor(), 0.6)
assert abs(gpr.predict(np.array([Graph(get_concat_skip_model()).extract_descriptor()]))[0] - 0.6) < 1e-4
59 changes: 28 additions & 31 deletions tests/test_classifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from time import sleep
from copy import deepcopy
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -41,13 +41,17 @@ def start(self):
self.target(*self.args)


def simple_transform(graph):
return [deepcopy(graph), deepcopy(graph)]


@patch('multiprocessing.Process', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_fit_predict(_):
constant.MAX_ITER_NUM = 2
constant.MAX_MODEL_NUM = 2
constant.EPOCHS_EACH = 1
constant.N_NEIGHBORS = 1
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model')
def test_fit_predict(_, _1):
constant.MAX_ITER_NUM = 1
constant.MAX_MODEL_NUM = 4
constant.SEARCH_MAX_ITER = 1
constant.DATA_AUGMENTATION = False
path = 'tests/resources/temp'
clean_dir(path)
Expand All @@ -60,23 +64,20 @@ def test_fit_predict(_):
clean_dir(path)


@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: sleep(6))
def test_timout(_):
def test_timout():
path = 'tests/resources/temp'
clean_dir(path)
clf = ImageClassifier(path=path, verbose=False)
constant.MAX_ITER_NUM = 1
constant.MAX_MODEL_NUM = 1
constant.EPOCHS_EACH = 1
constant.N_NEIGHBORS = 1
train_x = np.random.rand(100, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
clf.fit(train_x, train_y, time_limit=5)
clf.fit(train_x, train_y, time_limit=1)
clean_dir(path)


@patch('multiprocessing.Process', new=MockProcess)
def test_final_fit():
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model')
def test_final_fit(_, _1):
constant.LIMIT_MEMORY = True
path = 'tests/resources/temp'
clean_dir(path)
Expand All @@ -85,6 +86,7 @@ def test_final_fit():
constant.MAX_MODEL_NUM = 1
constant.EPOCHS_EACH = 1
constant.N_NEIGHBORS = 1
constant.SEARCH_MAX_ITER = 1
train_x = np.random.rand(100, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
test_x = np.random.rand(100, 25, 25, 1)
Expand All @@ -96,12 +98,14 @@ def test_final_fit():


@patch('multiprocessing.Process', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_save_continue(_):
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model')
def test_save_continue(_, _1):
constant.MAX_ITER_NUM = 1
constant.MAX_MODEL_NUM = 1
constant.EPOCHS_EACH = 1
constant.N_NEIGHBORS = 1
constant.SEARCH_MAX_ITER = 1
train_x = np.random.rand(100, 25, 25, 1)
train_y = np.random.randint(0, 5, 100)
test_x = np.random.rand(100, 25, 25, 1)
Expand Down Expand Up @@ -136,12 +140,14 @@ def test_save_continue(_):


@patch('multiprocessing.Process', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_fit_csv_file_1(_):
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model')
def test_fit_csv_file(_, _1):
constant.MAX_ITER_NUM = 1
constant.MAX_MODEL_NUM = 1
constant.EPOCHS_EACH = 1
constant.N_NEIGHBORS = 1
constant.SEARCH_MAX_ITER = 1
path = 'tests/resources'
clf = ImageClassifier(verbose=False, path=os.path.join(path, "temp"), resume=False)
clf.fit(csv_file_path=os.path.join(path, "images_test/images_name.csv"),
Expand All @@ -153,20 +159,11 @@ def test_fit_csv_file_1(_):
assert len(results) == 5
clean_dir(os.path.join(path, "temp"))

clf = ImageClassifier(verbose=False, path=os.path.join(path, "temp"), resume=True)
clf.fit(csv_file_path=os.path.join(path, "images_test/images_name.csv"),
images_path=os.path.join(path, "images_test/Black_white_images"))
img_file_name, y_train = read_csv_file(csv_file_path=os.path.join(path, "images_test/images_name.csv"))
x_test = read_images(img_file_name, images_dir_path=os.path.join(path, "images_test/Black_white_images"))
results = clf.predict(x_test)
assert len(clf.load_searcher().history) == 1
assert len(results) == 5
clean_dir(os.path.join(path, "temp"))


@patch('multiprocessing.Process', new=MockProcess)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
def test_cross_validate(_):
@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model')
def test_cross_validate(_, _1):
constant.MAX_ITER_NUM = 2
constant.MAX_MODEL_NUM = 2
constant.EPOCHS_EACH = 1
Expand Down
6 changes: 3 additions & 3 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def simple_transform(graph):


@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
@patch('autokeras.search.ModelTrainer.train_model')
def test_hill_climbing_searcher(_, _1):
# def test_hill_climbing_searcher(_):
x_train = np.random.rand(2, 28, 28, 1)
Expand All @@ -34,7 +34,7 @@ def test_hill_climbing_searcher(_, _1):
assert len(generator.history) == len(generator.history_configs)


@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
@patch('autokeras.search.ModelTrainer.train_model')
def test_random_searcher(_):
x_train = np.random.rand(2, 28, 28, 1)
y_train = np.random.rand(2, 3)
Expand All @@ -50,7 +50,7 @@ def test_random_searcher(_):


@patch('autokeras.search.transform', side_effect=simple_transform)
@patch('autokeras.search.ModelTrainer.train_model', side_effect=lambda: None)
@patch('autokeras.search.ModelTrainer.train_model')
def test_bayesian_searcher(_, _1):
x_train = np.random.rand(2, 28, 28, 1)
y_train = np.random.rand(2, 3)
Expand Down

0 comments on commit aa00a12

Please sign in to comment.