Skip to content

Commit

Permalink
Merge 3e269f2 into 1567122
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Nov 27, 2018
2 parents 1567122 + 3e269f2 commit 1a0e3b4
Show file tree
Hide file tree
Showing 21 changed files with 465 additions and 489 deletions.
43 changes: 36 additions & 7 deletions autokeras/bayesian.py
Expand Up @@ -13,11 +13,34 @@

from autokeras.constant import Constant
from autokeras.net_transformer import transform
from autokeras.nn.layers import is_layer


def layer_distance(a, b):
"""The distance between two layers."""
return abs(a - b) * 1.0 / max(a, b)
if type(a) != type(b):
return 1.0
if is_layer(a, 'Conv'):
att_diff = [(a.filters, b.filters),
(a.kernel_size, b.kernel_size),
(a.stride, b.stride)]
return attribute_difference(att_diff)
if is_layer(a, 'Pooling'):
att_diff = [(a.padding, b.padding),
(a.kernel_size, b.kernel_size),
(a.stride, b.stride)]
return attribute_difference(att_diff)
return 0.0


def attribute_difference(att_diff):
ret = 0
for a_value, b_value in att_diff:
if max(a_value, b_value) == 0:
ret += 0
else:
ret += abs(a_value - b_value) * 1.0 / max(a_value, b_value)
return ret * 1.0 / len(att_diff)


def layers_distance(list_a, list_b):
Expand Down Expand Up @@ -64,9 +87,7 @@ def edit_distance(x, y):
The edit-distance between x and y.
"""

ret = 0
ret += layers_distance(x.conv_widths, y.conv_widths)
ret += layers_distance(x.dense_widths, y.dense_widths)
ret = layers_distance(x.layers, y.layers)
ret += Constant.KERNEL_LAMBDA * skip_connections_distance(x.skip_connections, y.skip_connections)
return ret

Expand All @@ -77,6 +98,7 @@ class IncrementalGaussianProcess:
Attributes:
alpha: A hyperparameter.
"""

def __init__(self):
self.alpha = 1e-10
self._distance_matrix = None
Expand Down Expand Up @@ -266,6 +288,7 @@ class BayesianOptimizer:
beta: The beta in acquisition function. (refer to our paper)
search_tree: The network morphism search tree.
"""

def __init__(self, searcher, t_min, metric, beta):
self.searcher = searcher
self.t_min = t_min
Expand All @@ -284,12 +307,13 @@ def fit(self, x_queue, y_queue):
"""
self.gpr.fit(x_queue, y_queue)

def generate(self, descriptors, timeout):
def generate(self, descriptors, timeout, multiprocessing_queue):
"""Generate new architecture.
Args:
descriptors: All the searched neural architectures.
timeout: An integer. The time limit in seconds.
multiprocessing_queue: the Queue for multiprocessing return value.
Returns:
graph: An instance of Graph. A morphed neural network with weights.
Expand Down Expand Up @@ -318,11 +342,13 @@ def generate(self, descriptors, timeout):
pq.put(elem_class(metric_value, model_id, graph))

t = 1.0
t_min = self.t_min
# t_min = self.t_min
alpha = 0.9
opt_acq = self._get_init_opt_acq_value()
remaining_time = timeout
while not pq.empty() and t > t_min and remaining_time > 0:
while not pq.empty() and remaining_time > 0:
if multiprocessing_queue.qsize() != 0:
break
elem = pq.get()
if self.metric.higher_better():
temp_exp = min((elem.metric_value - opt_acq) / t, 1.0)
Expand Down Expand Up @@ -379,6 +405,7 @@ def add_child(self, father_id, model_id):
@total_ordering
class Elem:
"""Elements to be sorted according to metric value."""

def __init__(self, metric_value, father_id, graph):
self.father_id = father_id
self.graph = graph
Expand All @@ -393,6 +420,7 @@ def __lt__(self, other):

class ReverseElem(Elem):
"""Elements to be reversely sorted according to metric value."""

def __lt__(self, other):
return self.metric_value > other.metric_value

Expand All @@ -407,6 +435,7 @@ def contain(descriptors, target_descriptor):

class SearchTree:
"""The network morphism search tree."""

def __init__(self):
self.root = None
self.adj_list = {}
Expand Down
2 changes: 1 addition & 1 deletion autokeras/constant.py
Expand Up @@ -14,7 +14,7 @@ class Constant:
N_NEIGHBOURS = 8
MAX_MODEL_SIZE = (1 << 25)
MAX_LAYER_WIDTH = 4096
MAX_LAYERS = 100
MAX_LAYERS = 500

# Model Defaults

Expand Down
63 changes: 37 additions & 26 deletions autokeras/net_transformer.py
Expand Up @@ -5,7 +5,9 @@
from autokeras.nn.graph import NetworkDescriptor

from autokeras.constant import Constant
from autokeras.nn.layers import is_layer
from autokeras.nn.layer_transformer import init_dense_weight, init_conv_weight, init_bn_weight
from autokeras.nn.layers import is_layer, StubDense, get_dropout_class, StubReLU, get_conv_class, \
get_batch_norm_class, get_pooling_class


def to_wider_graph(graph):
Expand Down Expand Up @@ -53,6 +55,37 @@ def to_skip_connection_graph(graph):
return graph


def create_new_layer(input_shape, n_dim):
dense_deeper_classes = [StubDense, get_dropout_class(n_dim), StubReLU]
conv_deeper_classes = [get_conv_class(n_dim), get_batch_norm_class(n_dim), StubReLU]
if len(input_shape) == 1:
# It is in the dense layer part.
layer_class = sample(dense_deeper_classes, 1)[0]
else:
# It is in the conv layer part.
layer_class = sample(conv_deeper_classes, 1)[0]

if layer_class == StubDense:
new_layer = StubDense(input_shape[0], input_shape[0])

elif layer_class == get_dropout_class(n_dim):
new_layer = layer_class(Constant.DENSE_DROPOUT_RATE)

elif layer_class == get_conv_class(n_dim):
new_layer = layer_class(input_shape[-1], input_shape[-1], sample((1, 3, 5), 1)[0], stride=1)

elif layer_class == get_batch_norm_class(n_dim):
new_layer = layer_class(input_shape[-1])

elif layer_class == get_pooling_class(n_dim):
new_layer = layer_class(sample((1, 3, 5), 1)[0])

else:
new_layer = layer_class()

return new_layer


def to_deeper_graph(graph):
weighted_layer_ids = graph.deep_layer_ids()
if len(weighted_layer_ids) >= Constant.MAX_LAYERS:
Expand All @@ -62,21 +95,11 @@ def to_deeper_graph(graph):

for layer_id in deeper_layer_ids:
layer = graph.layer_list[layer_id]
if is_layer(layer, 'Conv'):
graph.to_conv_deeper_model(layer_id, 3)
else:
graph.to_dense_deeper_model(layer_id)
new_layer = create_new_layer(layer.output.shape, graph.n_dim)
graph.to_deeper_model(layer_id, new_layer)
return graph


def legal_graph(graph):
descriptor = graph.extract_descriptor()
skips = descriptor.skip_connections
if len(skips) != len(set(skips)):
return False
return True


def transform(graph):
graphs = []
for i in range(Constant.N_NEIGHBOURS * 2):
Expand All @@ -95,16 +118,4 @@ def transform(graph):
if len(graphs) >= Constant.N_NEIGHBOURS:
break

return list(filter(lambda x: legal_graph(x), graphs))


def default_transform(graph):
graph = deepcopy(graph)
graph.to_conv_deeper_model(1, 3)
graph.to_conv_deeper_model(1, 3)
graph.to_conv_deeper_model(5, 3)
graph.to_conv_deeper_model(9, 3)
graph.to_add_skip_model(1, 18)
graph.to_add_skip_model(18, 24)
graph.to_add_skip_model(24, 27)
return [graph]
return graphs
74 changes: 42 additions & 32 deletions autokeras/nn/generator.py
Expand Up @@ -15,6 +15,7 @@ class NetworkGenerator:
n_output_node: Number of output nodes in the network.
input_shape: A tuple to represent the input shape.
"""

def __init__(self, n_output_node, input_shape):
"""Initialize the instance.
Expand Down Expand Up @@ -77,10 +78,16 @@ def generate(self, model_len=Constant.MODEL_LEN, model_width=Constant.MODEL_WIDT
graph = Graph(self.input_shape, False)
temp_input_channel = self.input_shape[-1]
output_node_id = 0
stride = 1
for i in range(model_len):
output_node_id = graph.add_layer(StubReLU(), output_node_id)
output_node_id = graph.add_layer(self.conv(temp_input_channel, model_width, kernel_size=3), output_node_id)
output_node_id = graph.add_layer(self.batch_norm(model_width), output_node_id)
output_node_id = graph.add_layer(self.batch_norm(graph.node_list[output_node_id].shape[-1]), output_node_id)
output_node_id = graph.add_layer(self.conv(temp_input_channel,
model_width,
kernel_size=3,
stride=stride), output_node_id)
# if stride == 1:
# stride = 2
temp_input_channel = model_width
if pooling_len == 0 or ((i + 1) % pooling_len == 0 and i != model_len - 1):
output_node_id = graph.add_layer(self.pooling(), output_node_id)
Expand Down Expand Up @@ -143,60 +150,63 @@ def generate(self, model_len=Constant.MLP_MODEL_LEN, model_width=Constant.MLP_MO
class ResNetGenerator(NetworkGenerator):
def __init__(self, n_output_node, input_shape):
super(ResNetGenerator, self).__init__(n_output_node, input_shape)
self.layers = [3, 4, 6, 3]
# self.layers = [2, 2, 2, 2]
self.in_planes = 64
self.block_expansion = 1
self.n_dim = len(self.input_shape) - 1
if len(self.input_shape) > 4:
raise ValueError('The input dimension is too high.')
elif len(self.input_shape) < 2:
raise ValueError('The input dimension is too low.')
self.inplanes = 64
self.conv = get_conv_class(self.n_dim)
self.dropout = get_dropout_class(self.n_dim)
self.global_avg_pooling = get_global_avg_pooling_class(self.n_dim)
self.adaptive_avg_pooling = get_global_avg_pooling_class(self.n_dim)
self.pooling = get_pooling_class(self.n_dim)
self.batch_norm = get_batch_norm_class(self.n_dim)

def generate(self, model_len, model_width):
def generate(self, model_len=Constant.MODEL_LEN, model_width=Constant.MODEL_WIDTH):
graph = Graph(self.input_shape, False)
temp_input_channel = self.input_shape[-1]
output_node_id = 0
output_node_id = graph.add_layer(StubReLU(), output_node_id)
output_node_id = graph.add_layer(self.conv(temp_input_channel, model_width, kernel_size=7), output_node_id)
# output_node_id = graph.add_layer(StubReLU(), output_node_id)
output_node_id = graph.add_layer(self.conv(temp_input_channel, model_width, kernel_size=3), output_node_id)
output_node_id = graph.add_layer(self.batch_norm(model_width), output_node_id)
output_node_id = graph.add_layer(self.pooling(kernel_size=3, stride=2, padding=1), output_node_id)
for layer in self.layers:
output_node_id = self._make_layer(graph, model_width, layer, output_node_id)
model_width *= 2
# output_node_id = graph.add_layer(self.pooling(kernel_size=3, stride=2, padding=1), output_node_id)

output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 1)
model_width *= 2
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2)
model_width *= 2
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2)
model_width *= 2
output_node_id = self._make_layer(graph, model_width, 2, output_node_id, 2)

output_node_id = graph.add_layer(self.global_avg_pooling(), output_node_id)
graph.add_layer(StubDense(int(model_width / 2) * self.block_expansion, self.n_output_node), output_node_id)
graph.add_layer(StubDense(model_width * self.block_expansion, self.n_output_node), output_node_id)
return graph

def _make_layer(self, graph, planes, blocks, node_id):
downsample = None
if self.inplanes != planes * self.block_expansion:
downsample = [
self.conv(self.inplanes, planes * self.block_expansion, kernel_size=1),
self.batch_norm(planes * self.block_expansion),
]
out = self._make_block(graph, self.inplanes, planes, node_id, downsample)
self.inplanes = planes * self.block_expansion
for _ in range(1, blocks):
out = self._make_block(graph, self.inplanes, planes, out)
def _make_layer(self, graph, planes, blocks, node_id, stride):
strides = [stride] + [1] * (blocks - 1)
out = node_id
for current_stride in strides:
out = self._make_block(graph, self.in_planes, planes, out, current_stride)
self.in_planes = planes * self.block_expansion
return out

def _make_block(self, graph, inplanes, planes, node_id, downsample=None):
residual_node_id = node_id
out = graph.add_layer(StubReLU(), node_id)
out = graph.add_layer(self.conv(inplanes, planes, kernel_size=1), out)
def _make_block(self, graph, in_planes, planes, node_id, stride=1):
out = graph.add_layer(self.batch_norm(in_planes), node_id)
out = graph.add_layer(StubReLU(), out)
residual_node_id = out
out = graph.add_layer(self.conv(in_planes, planes, kernel_size=3, stride=stride), out)
out = graph.add_layer(self.batch_norm(planes), out)
out = graph.add_layer(StubReLU(), out)
out = graph.add_layer(self.conv(planes, planes, kernel_size=3), out)
out = graph.add_layer(self.batch_norm(planes), out)
if downsample is not None:
downsample_out = graph.add_layer(StubReLU(), node_id)
downsample_out = graph.add_layer(downsample[0], downsample_out)
residual_node_id = graph.add_layer(downsample[1], downsample_out)

residual_node_id = graph.add_layer(StubReLU(), residual_node_id)
residual_node_id = graph.add_layer(self.conv(in_planes,
planes * self.block_expansion,
kernel_size=1,
stride=stride), residual_node_id)
out = graph.add_layer(StubAdd(), (out, residual_node_id))
return out

0 comments on commit 1a0e3b4

Please sign in to comment.