Skip to content

Commit

Permalink
Merge eee7459 into bfee0c2
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Dec 15, 2019
2 parents bfee0c2 + eee7459 commit 8767f29
Show file tree
Hide file tree
Showing 19 changed files with 634 additions and 460 deletions.
7 changes: 6 additions & 1 deletion autokeras/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from autokeras.auto_model import AutoModel
from autokeras.auto_model import GraphAutoModel
from autokeras.const import Constant
from autokeras.hypermodel.base import Block
from autokeras.hypermodel.base import Head
from autokeras.hypermodel.base import HyperBlock
from autokeras.hypermodel.base import Node
from autokeras.hypermodel.base import Preprocessor
from autokeras.hypermodel.block import ConvBlock
from autokeras.hypermodel.block import DenseBlock
from autokeras.hypermodel.block import EmbeddingBlock
Expand All @@ -21,7 +26,7 @@
from autokeras.hypermodel.node import TextInput
from autokeras.hypermodel.preprocessor import FeatureEngineering
from autokeras.hypermodel.preprocessor import ImageAugmentation
from autokeras.hypermodel.preprocessor import LightGBMBlock
from autokeras.hypermodel.preprocessor import LightGBM
from autokeras.hypermodel.preprocessor import Normalization
from autokeras.hypermodel.preprocessor import TextToIntSequence
from autokeras.hypermodel.preprocessor import TextToNgramVector
Expand Down
12 changes: 10 additions & 2 deletions autokeras/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class AutoModel(object):
or maximize, e.g. 'val_accuracy'. Defaults to 'val_loss'.
tuner: String. It should be one of 'greedy', 'bayesian', 'hyperband' or
'random'. Defaults to 'greedy'.
overwrite: Boolean. Defaults to `False`. If `False`, reloads an existing
project of the same name if one is found. Otherwise, overwrites the
project.
seed: Int. Random seed.
"""

Expand All @@ -44,6 +47,7 @@ def __init__(self,
directory=None,
objective='val_loss',
tuner='greedy',
overwrite=False,
seed=None):
self.inputs = nest.flatten(inputs)
self.outputs = nest.flatten(outputs)
Expand All @@ -55,6 +59,7 @@ def __init__(self,
self.objective = objective
# TODO: Support passing a tuner instance.
self.tuner = tuner_module.get_tuner_class(tuner)
self.overwrite = overwrite
self._split_dataset = False
if all([isinstance(output_node, base.Head)
for output_node in self.outputs]):
Expand Down Expand Up @@ -119,7 +124,8 @@ def fit(self,
validation_split=validation_split)

# Initialize the hyper_graph.
self._meta_build(dataset)
if not self.hyper_graph:
self._meta_build(dataset)

# Initialize the Tuner.
# The hypermodel needs input_shape, which can only be known after
Expand All @@ -136,6 +142,7 @@ def fit(self,
hyper_graph=self.hyper_graph,
hypermodel=keras_graph,
fit_on_val_data=self._split_dataset,
overwrite=self.overwrite,
objective=self.objective,
max_trials=self.max_trials,
directory=self.directory,
Expand Down Expand Up @@ -323,4 +330,5 @@ def __init__(self,
)

def _meta_build(self, dataset):
self.hyper_graph = graph.HyperGraph(self.inputs, self.outputs)
self.hyper_graph = graph.HyperGraph(inputs=self.inputs,
outputs=self.outputs)
25 changes: 20 additions & 5 deletions autokeras/encoder.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import kerastuner
import numpy as np
import tensorflow as tf

from autokeras.hypermodel import base

class Encoder(kerastuner.engine.stateful.Stateful):

class Encoder(base.Picklable):
"""Base class for encoders of the prediction targets.
# Arguments
Expand Down Expand Up @@ -44,13 +46,14 @@ def decode(self, data):
"""
raise NotImplementedError

def get_config(self):
return {'num_classes': self.num_classes}

def get_state(self):
return {'num_classes': self.num_classes,
'labels': self._labels,
return {'labels': self._labels,
'int_to_label': self._int_to_label}

def set_state(self, state):
self.num_classes = state['num_classes']
self._labels = state['labels']
self._int_to_label = state['int_to_label']

Expand Down Expand Up @@ -210,3 +213,15 @@ def decode(self, data):
"""
return np.array(list(map(lambda x: self._int_to_label[int(round(x[0]))],
np.array(data)))).reshape(-1, 1)


def serialize(encoder):
return tf.keras.utils.serialize_keras_object(encoder)


def deserialize(config, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='encoder')
128 changes: 67 additions & 61 deletions autokeras/hypermodel/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,62 @@
import pickle

import kerastuner
import numpy as np
import pandas as pd
import tensorflow as tf
from kerastuner.engine import stateful
from tensorflow.python.util import nest

from autokeras import utils


class Node(kerastuner.engine.stateful.Stateful):
class Picklable(stateful.Stateful):
"""The mixin for saving and loading config and weights for HyperModels.
We define weights for any hypermodel as something that can only be know after
seeing the data. The rest of the states are configs.
"""

def get_config(self):
"""Returns the current config of this object.
# Returns
Dictionary.
"""
raise NotImplementedError

@classmethod
def from_config(cls, config):
"""Build an instance from the config of this object.
# Arguments
config: Dict. The config of the object.
"""
return cls(**config)

def save(self, fname):
"""Save state to file.
# Arguments
fname: String. The path to a file to save the state.
"""
state = self.get_state()
with tf.io.gfile.GFile(fname, 'wb') as f:
pickle.dump(state, f)
return str(fname)

def reload(self, fname):
"""Load state from file.
# Arguments
fname: String. The path to a file to load the state.
"""
with tf.io.gfile.GFile(fname, 'rb') as f:
state = pickle.load(f)
self.set_state(state)


class Node(Picklable):
"""The nodes in a network connecting the blocks."""

def __init__(self, shape=None):
Expand All @@ -25,14 +74,17 @@ def add_out_block(self, hypermodel):
def build(self):
return tf.keras.Input(shape=self.shape)

def get_config(self):
return {}

def get_state(self):
return {'shape': self.shape}

def set_state(self, state):
self.shape = state['shape']


class Block(kerastuner.HyperModel, kerastuner.engine.stateful.Stateful):
class Block(kerastuner.HyperModel, Picklable):
"""The base class for different Block.
The Block can be connected together to build the search space
Expand Down Expand Up @@ -99,21 +151,19 @@ def build(self, hp, inputs=None):
"""
return super().build(hp)

def get_state(self):
def get_config(self):
"""Get the configuration of the preprocessor.
# Returns
A dictionary of configurations of the preprocessor.
"""
return {'name': self.name}

def set_state(self, state):
"""Set the configuration of the preprocessor.
def get_state(self):
return {}

# Arguments
state: A dictionary of the configurations of the preprocessor.
"""
self.name = state['name']
def set_state(self, state):
pass


class Head(Block):
Expand All @@ -136,21 +186,20 @@ def __init__(self, loss=None, metrics=None, output_shape=None, **kwargs):
# Mark if the head should directly output the input tensor.
self.identity = False

def get_state(self):
state = super().get_state()
state.update({
'output_shape': self.output_shape,
def get_config(self):
config = super().get_config()
config.update({
'loss': self.loss,
'metrics': self.metrics,
'identity': self.identity
})
return state
return config

def get_state(self):
return {'output_shape': self.output_shape,
'identity': self.identity}

def set_state(self, state):
super().set_state(state)
self.output_shape = state['output_shape']
self.loss = state['loss']
self.metrics = state['metrics']
self.identity = state['identity']

def build(self, hp, inputs=None):
Expand Down Expand Up @@ -301,46 +350,3 @@ def output_shape(self):
def finalize(self):
"""Training process of the preprocessor after update with all instances."""
pass

def get_config(self):
"""Get the configuration of the preprocessor.
# Returns
A dictionary of configurations of the preprocessor.
"""
return {}

def set_config(self, config):
"""Set the configuration of the preprocessor.
# Arguments
config: A dictionary of the configurations of the preprocessor.
"""
pass

def get_weights(self):
"""Get the trained weights of the preprocessor.
# Returns
A dictionary of trained weights of the preprocessor.
"""
return {}

def set_weights(self, weights):
"""Set the trained weights of the preprocessor.
# Arguments
weights: A dictionary of trained weights of the preprocessor.
"""
pass

def get_state(self):
state = super().get_state()
state.update(self.get_config())
return {'config': state,
'weights': self.get_weights()}

def set_state(self, state):
self.set_config(state['config'])
super().set_state(state['config'])
self.set_weights(state['weights'])

0 comments on commit 8767f29

Please sign in to comment.