From e2a8f0bf0dc0077ffc38a7d85cb72492573eba38 Mon Sep 17 00:00:00 2001 From: Haifeng Jin Date: Wed, 28 Nov 2018 17:27:53 -0600 Subject: [PATCH] bug_fix --- autokeras/constant.py | 2 +- autokeras/net_transformer.py | 13 +++++++++++-- autokeras/nn/graph.py | 4 +++- autokeras/nn/layers.py | 3 ++- autokeras/search.py | 1 + autokeras/text/text_preprocessor.py | 12 ++++++------ tests/nn/test_graph.py | 1 + 7 files changed, 25 insertions(+), 11 deletions(-) diff --git a/autokeras/constant.py b/autokeras/constant.py index 7a6c28e86..d36338e4d 100644 --- a/autokeras/constant.py +++ b/autokeras/constant.py @@ -37,7 +37,7 @@ class Constant: MAX_NO_IMPROVEMENT_NUM = 5 MAX_BATCH_SIZE = 128 LIMIT_MEMORY = False - SEARCH_MAX_ITER = 200 + SEARCH_MAX_ITER = 1 # text preprocessor diff --git a/autokeras/net_transformer.py b/autokeras/net_transformer.py index 2f78639fb..ffd917968 100644 --- a/autokeras/net_transformer.py +++ b/autokeras/net_transformer.py @@ -45,9 +45,18 @@ def to_skip_connection_graph(graph): return graph -def create_new_layer(input_shape, n_dim): +def create_new_layer(layer, n_dim): + input_shape = layer.output.shape dense_deeper_classes = [StubDense, get_dropout_class(n_dim), StubReLU] conv_deeper_classes = [get_conv_class(n_dim), get_batch_norm_class(n_dim), StubReLU] + if is_layer(layer, 'ReLU'): + conv_deeper_classes = [get_conv_class(n_dim), get_batch_norm_class(n_dim)] + dense_deeper_classes = [StubDense, get_dropout_class(n_dim)] + elif is_layer(layer, 'Dropout'): + dense_deeper_classes = [StubDense, StubReLU] + elif is_layer(layer, 'BatchNormalization'): + conv_deeper_classes = [get_conv_class(n_dim), StubReLU] + if len(input_shape) == 1: # It is in the dense layer part. layer_class = sample(dense_deeper_classes, 1)[0] @@ -85,7 +94,7 @@ def to_deeper_graph(graph): for layer_id in deeper_layer_ids: layer = graph.layer_list[layer_id] - new_layer = create_new_layer(layer.output.shape, graph.n_dim) + new_layer = create_new_layer(layer, graph.n_dim) graph.to_deeper_model(layer_id, new_layer) return graph diff --git a/autokeras/nn/graph.py b/autokeras/nn/graph.py index 306abc387..8c9f5c1b7 100644 --- a/autokeras/nn/graph.py +++ b/autokeras/nn/graph.py @@ -2,7 +2,6 @@ from copy import deepcopy, copy from queue import Queue -import keras import numpy as np import torch @@ -505,6 +504,8 @@ def _insert_pooling_layer_chain(self, start_node_id, end_node_id): if is_layer(new_layer, 'Conv'): filters = self.node_list[start_node_id].shape[-1] new_layer = get_conv_class(self.n_dim)(filters, filters, 1, layer.stride) + if self.weighted: + init_conv_weight(new_layer) else: new_layer = deepcopy(layer) skip_output_id = self.add_layer(new_layer, skip_output_id) @@ -685,6 +686,7 @@ def set_weight_to_graph(self): class KerasModel: def __init__(self, graph): + import keras self.graph = graph self.layers = [] for layer in graph.layer_list: diff --git a/autokeras/nn/layers.py b/autokeras/nn/layers.py index 50c3e68cd..b3b7da581 100644 --- a/autokeras/nn/layers.py +++ b/autokeras/nn/layers.py @@ -2,7 +2,6 @@ import torch from torch import nn -from keras import layers from torch.nn import functional from autokeras.constant import Constant @@ -417,6 +416,7 @@ def forward(self, input_tensor): def keras_dropout(layer, rate): + from keras import layers input_dim = len(layer.input.shape) if input_dim == 2: return layers.SpatialDropout1D(rate) @@ -429,6 +429,7 @@ def keras_dropout(layer, rate): def to_real_keras_layer(layer): + from keras import layers if is_layer(layer, 'Dense'): return layers.Dense(layer.units, input_shape=(layer.input_units,)) if is_layer(layer, 'Conv'): diff --git a/autokeras/search.py b/autokeras/search.py index e6bf36f85..38077571d 100644 --- a/autokeras/search.py +++ b/autokeras/search.py @@ -113,6 +113,7 @@ def add_model(self, metric_value, loss, graph, model_id): if self.verbose: print('\nSaving model.') + graph.clear_operation_history() pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.graph')) ret = {'model_id': model_id, 'loss': loss, 'metric_value': metric_value} diff --git a/autokeras/text/text_preprocessor.py b/autokeras/text/text_preprocessor.py index 1d44f3e5a..3fdf1cf27 100644 --- a/autokeras/text/text_preprocessor.py +++ b/autokeras/text/text_preprocessor.py @@ -3,12 +3,6 @@ import GPUtil import numpy as np -import tensorflow as tf -from keras import Input, Model -from keras import backend -from keras.layers import Embedding -from keras_preprocessing.sequence import pad_sequences -from keras_preprocessing.text import Tokenizer from autokeras.constant import Constant from autokeras.utils import download_file_with_extract, temp_path_generator, ensure_dir @@ -61,6 +55,8 @@ def tokenlize_text(max_num_words, max_seq_length, x_train): x_train: Tokenlized input data. word_index: Dictionary contains word with tokenlized index. """ + from keras_preprocessing.sequence import pad_sequences + from keras_preprocessing.text import Tokenizer print("tokenlizing texts...") tokenizer = Tokenizer(num_words=max_num_words) tokenizer.fit_on_texts(x_train) @@ -138,6 +134,7 @@ def processing(path, word_index, input_length, x_train): Returns: x_train: Numpy array as processed x_train. """ + import tensorflow as tf embedding_matrix = load_pretrain(path=path, word_index=word_index) @@ -149,6 +146,9 @@ def processing(path, word_index, input_length, x_train): os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) device = '/gpu:0' with tf.device(device): + from keras import Input, Model + from keras import backend + from keras.layers import Embedding config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.allow_growth = True sess = tf.Session(config=config) diff --git a/tests/nn/test_graph.py b/tests/nn/test_graph.py index b6884c979..99726b818 100644 --- a/tests/nn/test_graph.py +++ b/tests/nn/test_graph.py @@ -163,6 +163,7 @@ def test_node_consistency(): def test_produce_keras_model(): + import keras for graph in [get_conv_dense_model(), get_add_skip_model(), get_pooling_model(),