Skip to content

Commit

Permalink
Merge 4946b9b into b3f47e4
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Aug 6, 2018
2 parents b3f47e4 + 4946b9b commit d0dd57c
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 4 deletions.
55 changes: 54 additions & 1 deletion autokeras/graph.py
Expand Up @@ -3,12 +3,15 @@
from queue import Queue
import numpy as np
import torch
import keras
from keras import layers

from autokeras.constant import Constant
from autokeras.layer_transformer import wider_bn, wider_next_conv, wider_next_dense, wider_pre_dense, wider_pre_conv, \
deeper_conv_block, dense_to_deeper_block, add_noise
from autokeras.layers import StubConcatenate, StubAdd, StubConv, is_layer, layer_width, to_real_layer, \
set_torch_weight_to_stub, set_stub_weight_to_torch, StubBatchNormalization, StubReLU, StubDropout
to_real_keras_layer, set_torch_weight_to_stub, set_stub_weight_to_torch, set_stub_weight_to_keras, \
set_keras_weight_to_stub, StubBatchNormalization, StubReLU, StubDropout


class NetworkDescriptor:
Expand Down Expand Up @@ -536,6 +539,10 @@ def produce_model(self):
"""Build a new model based on the current graph."""
return TorchModel(self)

def produce_keras_model(self):
"""Build a new keras model based on the current graph."""
return KerasModel(self).model

def _layer_ids_in_order(self, layer_ids):
node_id_to_order_index = {}
for index, node_id in enumerate(self.topological_order):
Expand Down Expand Up @@ -605,3 +612,49 @@ def set_weight_to_graph(self):
self.graph.weighted = True
for index, layer in enumerate(self.layers):
set_torch_weight_to_stub(layer, self.graph.layer_list[index])


class KerasModel():
def __init__(self, graph):
super(KerasModel, self).__init__()
self.graph = graph
self._layers = []
for layer in graph.layer_list:
self._layers.append(to_real_keras_layer(layer))

# Construct the keras graph.
# Input
topo_node_list = self.graph.topological_order
output_id = topo_node_list[-1]
input_id = topo_node_list[0]
input_tensor = keras.layers.Input(shape=graph.node_list[input_id].shape)

node_list = deepcopy(self.graph.node_list)
node_list[input_id] = input_tensor

# Output
for v in topo_node_list:
for u, layer_id in self.graph.reverse_adj_list[v]:
layer = self.graph.layer_list[layer_id]
keras_layer = self._layers[layer_id]

if isinstance(layer, (StubAdd, StubConcatenate)):
edge_input_tensor = list(map(lambda x: node_list[x],
self.graph.layer_id_to_input_node_ids[layer_id]))
else:
edge_input_tensor = node_list[u]

temp_tensor = keras_layer(edge_input_tensor)
node_list[v] = temp_tensor

output_tensor = node_list[output_id]
self.model = keras.models.Model(inputs=input_tensor, outputs=output_tensor)

if graph.weighted:
for index, layer in enumerate(self._layers):
set_stub_weight_to_keras(self.graph.layer_list[index], layer)

def set_weight_to_graph(self):
self.graph.weighted = True
for index, layer in enumerate(self._layers):
set_keras_weight_to_stub(layer, self.graph.layer_list[index])
66 changes: 65 additions & 1 deletion autokeras/layers.py
@@ -1,6 +1,6 @@
import torch
from torch import nn

from keras import layers

class StubLayer:
def __init__(self, input_node=None, output_node=None):
Expand All @@ -17,9 +17,15 @@ def set_weights(self, weights):
def import_weights(self, torch_layer):
pass

def import_weights_keras(self, keras_layer):
pass

def export_weights(self, torch_layer):
pass

def export_weights_keras(self, keras_layer):
pass

def get_weights(self):
return self.weights

Expand All @@ -32,10 +38,16 @@ class StubWeightBiasLayer(StubLayer):
def import_weights(self, torch_layer):
self.set_weights((torch_layer.weight.data.cpu().numpy(), torch_layer.bias.data.cpu().numpy()))

def import_weights_keras(self, keras_layer):
self.set_weights(keras_layer.get_weights())

def export_weights(self, torch_layer):
torch_layer.weight.data = torch.Tensor(self.weights[0])
torch_layer.bias.data = torch.Tensor(self.weights[1])

def export_weights_keras(self, keras_layer):
keras_layer.set_weights(self.weights)


class StubBatchNormalization(StubWeightBiasLayer):
def __init__(self, num_features, input_node=None, output_node=None):
Expand Down Expand Up @@ -66,6 +78,12 @@ def __init__(self, input_units, units, input_node=None, output_node=None):
def output_shape(self):
return self.units,

def import_weights_keras(self, keras_layer):
self.set_weights((keras_layer.weights[0].T, keras_layer.weights[1]))

def export_weights_keras(self, keras_layer):
keras_layer.set_weights((self.weights[0].T, self.weights[1]))


class StubConv(StubWeightBiasLayer):
def __init__(self, input_channel, filters, kernel_size, input_node=None, output_node=None):
Expand Down Expand Up @@ -202,6 +220,18 @@ def forward(self, input_tensor):
return input_tensor.view(input_tensor.size(0), -1)


def KerasDropout(layer, rate):
input_dim = len(layer.input.shape)
if input_dim == 2:
return layers.SpatialDropout1D(rate)
elif input_dim == 3:
return layers.SpatialDropout2D(rate)
elif input_dim == 4:
return layers.SpatialDropout3D(rate)
else:
return layers.Dropout(rate)


def to_real_layer(layer):
if is_layer(layer, 'Dense'):
return torch.nn.Linear(layer.input_units, layer.units)
Expand All @@ -228,9 +258,43 @@ def to_real_layer(layer):
return TorchFlatten()


def to_real_keras_layer(layer):
if is_layer(layer, 'Dense'):
return layers.Dense(layer.units, input_shape=(layer.input_units, ))
if is_layer(layer, 'Conv'):
return layers.Conv2D(layer.filters,
layer.kernel_size,
input_shape=layer.input.shape,
padding='same') # padding
if is_layer(layer, 'Pooling'):
return layers.MaxPool2D(2)
if is_layer(layer, 'BatchNormalization'):
return layers.BatchNormalization(input_shape=layer.input.shape)
if is_layer(layer, 'Concatenate'):
return layers.Concatenate()
if is_layer(layer, 'Add'):
return layers.Add()
if is_layer(layer, 'Dropout'):
return KerasDropout(layer, layer.rate)
if is_layer(layer, 'ReLU'):
return layers.Activation('relu')
if is_layer(layer, 'Softmax'):
return layers.Activation('softmax')
if is_layer(layer, 'Flatten'):
return layers.Flatten()


def set_torch_weight_to_stub(torch_layer, stub_layer):
stub_layer.import_weights(torch_layer)


def set_keras_weight_to_stub(keras_layer, stub_layer):
stub_layer.import_weights_keras(keras_layer)


def set_stub_weight_to_torch(stub_layer, torch_layer):
stub_layer.export_weights(torch_layer)


def set_stub_weight_to_keras(stub_layer, keras_layer):
stub_layer.export_weights_keras(keras_layer)
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -3,7 +3,7 @@ torch==0.4.0
torchvision==0.2.1
numpy==1.14.5
scikit-learn==0.19.1
keras
keras==2.2.2
tensorflow
pytest
pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -3,7 +3,7 @@
setup(
name='autokeras',
packages=['autokeras'], # this must be the same as the name above
install_requires=['torch==0.4.0', 'torchvision==0.2.1', 'numpy==1.14.5', 'keras', 'scikit-learn==0.19.1', 'tensorflow'],
install_requires=['torch==0.4.0', 'torchvision==0.2.1', 'numpy==1.14.5', 'keras==2.2.2', 'scikit-learn==0.19.1', 'tensorflow'],
version='0.2.0',
description='Automated Machine Learning with Keras',
author='Haifeng Jin',
Expand Down

0 comments on commit d0dd57c

Please sign in to comment.