bug_fix
haifeng-jin committed Nov 28, 2018
1 parent 925ec74 · commit e2a8f0b
Showing 7 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion autokeras/constant.py
@@ -37,7 +37,7 @@ class Constant:
     MAX_NO_IMPROVEMENT_NUM = 5
     MAX_BATCH_SIZE = 128
     LIMIT_MEMORY = False
-    SEARCH_MAX_ITER = 200
+    SEARCH_MAX_ITER = 1
 
     # text preprocessor

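This hunk lowers SEARCH_MAX_ITER from 200 to 1, so whatever loop the constant bounds now makes at most one pass per call. A purely hypothetical sketch of how such a cap typically drives a search loop (searcher, generate, and update are illustrative names, not autokeras internals):

    for _ in range(Constant.SEARCH_MAX_ITER):        # now at most one pass
        graph = searcher.generate()                  # propose an architecture
        metric = searcher.train_and_evaluate(graph)  # train and score it
        searcher.update(graph, metric)               # feed the result back
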
13 changes: 11 additions & 2 deletions autokeras/net_transformer.py
@@ -45,9 +45,18 @@ def to_skip_connection_graph(graph):
     return graph
 
 
-def create_new_layer(input_shape, n_dim):
+def create_new_layer(layer, n_dim):
+    input_shape = layer.output.shape
     dense_deeper_classes = [StubDense, get_dropout_class(n_dim), StubReLU]
     conv_deeper_classes = [get_conv_class(n_dim), get_batch_norm_class(n_dim), StubReLU]
+    if is_layer(layer, 'ReLU'):
+        conv_deeper_classes = [get_conv_class(n_dim), get_batch_norm_class(n_dim)]
+        dense_deeper_classes = [StubDense, get_dropout_class(n_dim)]
+    elif is_layer(layer, 'Dropout'):
+        dense_deeper_classes = [StubDense, StubReLU]
+    elif is_layer(layer, 'BatchNormalization'):
+        conv_deeper_classes = [get_conv_class(n_dim), StubReLU]
+
     if len(input_shape) == 1:
         # It is in the dense layer part.
         layer_class = sample(dense_deeper_classes, 1)[0]
@@ -85,7 +94,7 @@ def to_deeper_graph(graph):

     for layer_id in deeper_layer_ids:
         layer = graph.layer_list[layer_id]
-        new_layer = create_new_layer(layer.output.shape, graph.n_dim)
+        new_layer = create_new_layer(layer, graph.n_dim)
         graph.to_deeper_model(layer_id, new_layer)
     return graph

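The changes in this file make create_new_layer receive the layer being deepened rather than just its output shape, so it can prune its candidate pools: the sampled layer never repeats the type of its neighbour (no ReLU directly after a ReLU, no Dropout after a Dropout, no BatchNormalization after a BatchNormalization). A self-contained sketch of that filtering pattern, using illustrative names rather than autokeras code:

    from random import sample

    def pick_deeper_layer(prev_kind):
        # Prune the pool so the inserted layer never duplicates the kind
        # of layer it lands next to.
        candidates = ['Conv', 'BatchNormalization', 'ReLU']
        if prev_kind in candidates:
            candidates.remove(prev_kind)
        return sample(candidates, 1)[0]

    print(pick_deeper_layer('ReLU'))  # 'Conv' or 'BatchNormalization', never 'ReLU'
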
4 changes: 3 additions & 1 deletion autokeras/nn/graph.py
@@ -2,7 +2,6 @@
 from copy import deepcopy, copy
 from queue import Queue
 
-import keras
 import numpy as np
 import torch

@@ -505,6 +504,8 @@ def _insert_pooling_layer_chain(self, start_node_id, end_node_id):
             if is_layer(new_layer, 'Conv'):
                 filters = self.node_list[start_node_id].shape[-1]
                 new_layer = get_conv_class(self.n_dim)(filters, filters, 1, layer.stride)
+                if self.weighted:
+                    init_conv_weight(new_layer)
             else:
                 new_layer = deepcopy(layer)
             skip_output_id = self.add_layer(new_layer, skip_output_id)
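
The guard added here handles weighted graphs: the 1x1 convolution created for the skip connection is brand new, so when self.weighted is set it must be given weights rather than left uninitialized. The diff does not show what init_conv_weight actually does; one plausible scheme for a pass-through conv is an identity-style init, sketched below as a hypothetical helper over plain torch:

    import torch

    def init_passthrough_conv(conv):
        # Hypothetical stand-in for init_conv_weight: start the 1x1 conv as
        # a (near-)identity map so it initially passes features through.
        torch.nn.init.dirac_(conv.weight)
        if conv.bias is not None:
            torch.nn.init.zeros_(conv.bias)
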
@@ -685,6 +686,7 @@ def set_weight_to_graph(self):

 class KerasModel:
     def __init__(self, graph):
+        import keras
         self.graph = graph
         self.layers = []
         for layer in graph.layer_list:
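
The KerasModel hunk above is the first of several in this commit that move "import keras" from module scope into the function or constructor that actually uses it; the same pattern recurs below in layers.py, text_preprocessor.py, and tests/nn/test_graph.py. The apparent intent is that importing autokeras alone no longer loads Keras (and, with it, TensorFlow); they are loaded only when a Keras model is actually produced. The pattern in miniature, with an illustrative class rather than the real one:

    class LazyKerasModel:
        def __init__(self, graph):
            import keras  # deferred: loads on first instantiation, not at package import
            self.graph = graph
            self.model = keras.models.Sequential()
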
3 changes: 2 additions & 1 deletion autokeras/nn/layers.py
@@ -2,7 +2,6 @@

 import torch
 from torch import nn
-from keras import layers
 from torch.nn import functional
 
 from autokeras.constant import Constant
@@ -417,6 +416,7 @@ def forward(self, input_tensor):


 def keras_dropout(layer, rate):
+    from keras import layers
     input_dim = len(layer.input.shape)
     if input_dim == 2:
         return layers.SpatialDropout1D(rate)
@@ -429,6 +429,7 @@ def keras_dropout(layer, rate):


 def to_real_keras_layer(layer):
+    from keras import layers
     if is_layer(layer, 'Dense'):
         return layers.Dense(layer.units, input_shape=(layer.input_units,))
     if is_layer(layer, 'Conv'):
1 change: 1 addition & 0 deletions autokeras/search.py
@@ -113,6 +113,7 @@ def add_model(self, metric_value, loss, graph, model_id):
         if self.verbose:
             print('\nSaving model.')
 
+        graph.clear_operation_history()
         pickle_to_file(graph, os.path.join(self.path, str(model_id) + '.graph'))
 
         ret = {'model_id': model_id, 'loss': loss, 'metric_value': metric_value}
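
Calling clear_operation_history before pickling presumably keeps the record of past graph mutations out of the serialized .graph file, so each saved snapshot stays small and self-contained. A rough sketch of the save path under that assumption, with Graph as an illustrative stand-in:

    import pickle

    class Graph:
        def __init__(self):
            self.operation_history = []    # mutations recorded during search

        def clear_operation_history(self):
            self.operation_history = []

    graph = Graph()
    graph.clear_operation_history()        # drop history before serializing
    payload = pickle.dumps(graph)          # history-free snapshot
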
12 changes: 6 additions & 6 deletions autokeras/text/text_preprocessor.py
@@ -3,12 +3,6 @@

 import GPUtil
 import numpy as np
-import tensorflow as tf
-from keras import Input, Model
-from keras import backend
-from keras.layers import Embedding
-from keras_preprocessing.sequence import pad_sequences
-from keras_preprocessing.text import Tokenizer
 
 from autokeras.constant import Constant
 from autokeras.utils import download_file_with_extract, temp_path_generator, ensure_dir
@@ -61,6 +55,8 @@ def tokenlize_text(max_num_words, max_seq_length, x_train):
         x_train: Tokenlized input data.
         word_index: Dictionary contains word with tokenlized index.
     """
+    from keras_preprocessing.sequence import pad_sequences
+    from keras_preprocessing.text import Tokenizer
     print("tokenlizing texts...")
     tokenizer = Tokenizer(num_words=max_num_words)
     tokenizer.fit_on_texts(x_train)
@@ -138,6 +134,7 @@ def processing(path, word_index, input_length, x_train):
     Returns:
         x_train: Numpy array as processed x_train.
     """
+    import tensorflow as tf
 
     embedding_matrix = load_pretrain(path=path, word_index=word_index)

@@ -149,6 +146,9 @@ def processing(path, word_index, input_length, x_train):
         os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
         device = '/gpu:0'
     with tf.device(device):
+        from keras import Input, Model
+        from keras import backend
+        from keras.layers import Embedding
         config = tf.ConfigProto(allow_soft_placement=True)
         config.gpu_options.allow_growth = True
         sess = tf.Session(config=config)
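
Deferring these imports has an ordering payoff beyond import cost: CUDA_VISIBLE_DEVICES must be set before TensorFlow first initializes CUDA, and in TF 1.x that initialization happens when the first session is created, not at import time. The hunks keep that order inside processing(): GPUtil picks a device and sets the variable, and only then is the session built. The constraint in miniature:

    import os
    import tensorflow as tf                   # importing TF does not yet claim a GPU

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # must precede the first CUDA init...
    sess = tf.Session()                       # ...which happens here in TF 1.x
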
1 change: 1 addition & 0 deletions tests/nn/test_graph.py
@@ -163,6 +163,7 @@ def test_node_consistency():


 def test_produce_keras_model():
+    import keras
     for graph in [get_conv_dense_model(),
                   get_add_skip_model(),
                   get_pooling_model(),
