Skip to content

Commit

Permalink
Docs (#760)
Browse files Browse the repository at this point in the history
* temp commit

* tutorial

* add io

* temp commit

* temp commit

* bug fix

* tutorial
  • Loading branch information
haifeng-jin committed Aug 30, 2019
1 parent 8b1a2eb commit 13c2850
Show file tree
Hide file tree
Showing 12 changed files with 233 additions and 13 deletions.
3 changes: 0 additions & 3 deletions .travis.yml
Expand Up @@ -6,7 +6,6 @@ jobs:
- stage: test
script:
- pip install -e .[tests] --progress-bar off
- pip install git+git://github.com/keras-team/keras-tuner@master#egg=keras-tuner
- pip install codacy-coverage
- pytest tests --cov=autokeras --cov-report xml:coverage.xml
- if ! [[ -z $CODACY_PROJECT_TOKEN ]]; then
Expand All @@ -16,12 +15,10 @@ jobs:
-
script:
- pip install -e .[tests] --progress-bar off
- pip install git+git://github.com/keras-team/keras-tuner@master#egg=keras-tuner
- flake8
-
script:
- pip install -e .[tests] --progress-bar off
- pip install git+git://github.com/keras-team/keras-tuner@master#egg=keras-tuner
- pip install mkdocs
- pip install mkdocs-material
- sh shell/docs.sh
Expand Down
8 changes: 5 additions & 3 deletions README.md
Expand Up @@ -14,14 +14,14 @@ It is developed by <a href="http://faculty.cs.tamu.edu/xiahu/index.html" target=
The ultimate goal of AutoML is to provide easily accessible deep learning tools to domain experts with limited data science or machine learning background.
Auto-Keras provides functions to automatically search for architecture and hyperparameters of deep learning models.

**Now we are refactoring the code on `master` branch for the next release.
Please use the `legacy` branch if you want to checkout the 0.4 version.**
# AutoKeras 1.0 is coming soon!

## Installation

To install the package, please use the `pip` installation as follows:

pip3 install autokeras
pip3 install autokeras # for 0.4 version
pip3 install git+git://github.com/keras-team/autokeras@master#egg=autokeras # for 1.0 version

**Note:** currently, Auto-Keras is only compatible with: **Python 3.6**.

Expand All @@ -37,6 +37,8 @@ clf.fit(x_train, y_train)
results = clf.predict(x_test)
```

For detailed tutorial, please check [here](https://autokeras.com/tutorial/).

## Cite this work

Haifeng Jin, Qingquan Song, and Xia Hu. "Auto-keras: An efficient neural architecture search system." Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2019. ([Download](https://www.kdd.org/kdd2019/accepted-papers/view/auto-keras-an-efficient-neural-architecture-search-system))
Expand Down
1 change: 1 addition & 0 deletions autokeras/__init__.py
Expand Up @@ -16,6 +16,7 @@
from autokeras.hypermodel.hyperblock import LightGBMRegressorBlock
from autokeras.hypermodel.node import ImageInput
from autokeras.hypermodel.node import Input
from autokeras.hypermodel.node import StructuredInput
from autokeras.hypermodel.node import TextInput
from autokeras.hypermodel.preprocessor import Normalization
from autokeras.hypermodel.preprocessor import TextToIntSequence
Expand Down
9 changes: 8 additions & 1 deletion autokeras/hypermodel/hyperblock.py
Expand Up @@ -185,7 +185,14 @@ def build(self, hp, inputs=None):


class GeneralBlock(HyperBlock):
"""A general neural network block when the input type is unknown. """
"""A general neural network block when the input type is unknown.
When the input type is unknown. The GeneralBlock would search in a large space
for a good model.
# Arguments
name: String.
"""

def build(self, hp, inputs=None):
raise NotImplementedError
31 changes: 29 additions & 2 deletions docs/autogen.py
Expand Up @@ -11,7 +11,10 @@
import six
import autokeras
from autokeras import auto_model
from autokeras import task
from autokeras.hypermodel import block
from autokeras.hypermodel import head
from autokeras.hypermodel import hyperblock

try:
import pathlib
Expand Down Expand Up @@ -114,15 +117,39 @@
auto_model.AutoModel.fit,
auto_model.AutoModel.predict,
]),
]
},
{
'page': 'graph_auto_model.md',
'classes': [
(auto_model.GraphAutoModel, [
auto_model.GraphAutoModel.fit,
auto_model.GraphAutoModel.predict,
]),
]
},
{
'page': 'hypermodel/block.md',
'all_module_classes': [block],
'page': 'block.md',
'classes': [hyperblock.ImageBlock,
hyperblock.TextBlock,
hyperblock.StructuredDataBlock,
block.ResNetBlock,
block.XceptionBlock,
block.ConvBlock,
block.RNNBlock,
block.Merge],
},
{
'page': 'task.md',
'classes': [task.ImageClassifier,
task.ImageRegressor,
task.TextClassifier,
task.TextRegressor],
},
{
'page': 'head.md',
'classes': [head.ClassificationHead,
head.RegressionHead],
},
# {
# 'page': 'hypermodel/block.md',
Expand Down
10 changes: 6 additions & 4 deletions docs/mkdocs.yml
Expand Up @@ -17,11 +17,13 @@ google_analytics: ['UA-44322747-3', 'autokeras.com']
nav:
- Home: index.md
- Getting Started: start.md
- Tutorial for 1.0: tutorial.md
- Docker: docker.md
- Contributing Guide: contributing.md
- Neural Architecture Search: nas.md
- Documentation:
- auto_model: auto_model.md
- hypermodel:
- block: hypermodel/block.md
- Task API: task.md
- AutoModel: auto_model.md
- GraphAutoModel: graph_auto_model.md
- Block: block.md
- Head: head.md
- About: about.md
13 changes: 13 additions & 0 deletions docs/templates/task.md
@@ -0,0 +1,13 @@
# Task API

AutoKeras support the following task APIs.

{{autogenerated}}

### Coming Soon:

StructuredDataClassifier

StructuredDataRegressor

TimeSeriesForecaster
102 changes: 102 additions & 0 deletions docs/templates/tutorial.md
@@ -0,0 +1,102 @@
# AutoKeras 1.0 Tutorial

In AutoKeras, there are 3 levels of APIs: task API, IO API, and functional API.

## Task API
We have designed an extremely simple interface for a series of tasks.
The following code example shows how to do image classification with the task API.

```python
import autokeras as ak
from keras.datasets import mnist

# Prepare the data.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

# Search and train the classifier.
clf = ak.ImageClassifier(max_trials=100)
clf.fit(x_train, y_train)
y = clf.predict(x_test, y_test)
```

See the [documentation of Task API](/task) for more details.



## IO API

The following code example shows how to use IO API for multi-modal and multi-task scenarios using [AutoModel](/auto_model)

```python
import numpy as np
import autokeras as ak
from keras.datasets import mnist

# Prepare the data.
(x_train, y_classification), (x_test, y_test) = mnist.load_data()
x_image = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

x_structured = np.random.rand(x_train.shape[0], 100)
y_regression = np.random.rand(x_train.shape[0], 1)

# Build model and train.
automodel = ak.AutoModel(
inputs=[ak.ImageInput(),
ak.StructuredDataInput()],
outputs=[ak.RegressionHead(metrics=['mae']),
ak.ClassificationHead(loss='categorical_crossentropy',
metrics=['accuracy'])])
automodel.fit([x_image, x_structured],
[y_regression, y_classification])

```

Now we support `ImageInput`, `TextInput`, and `StructuredDataInput`.

## Functional API

You can also define your own neural architecture with the predefined blocks and [GraphAutoModel](/graph_auto_model).

```python
import autokeras as ak
import numpy as np
import tensorflow as tf
from keras.datasets import mnist

# Prepare the data.
(x_train, y_classification), (x_test, y_test) = mnist.load_data()
x_image = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

x_structured = np.random.rand(x_train.shape[0], 100)
y_regression = np.random.rand(x_train.shape[0], 1)

# Build model and train.
inputs = ak.ImageInput(shape=(28, 28, 1))
outputs1 = ak.ResNetBlock(version='next')(inputs)
outputs2 = ak.XceptionBlock()(inputs)
image_outputs = ak.Merge()((outputs1, outputs2))

structured_inputs = ak.StructuredInput()
structured_outputs = ak.DenseBlock()(structured_inputs)
merged_outputs = ak.Merge()((image_outputs, structured_outputs))

classification_outputs = ak.ClassificationHead()(merged_outputs)
regression_outputs = ak.RegressionHead()(merged_outputs)
automodel = ak.GraphAutoModel(inputs=inputs,
outputs=[regression_outputs,
classification_outputs])

automodel.fit((x_image, x_structured),
(y_regression, y_classification),
trials=100,
epochs=200,
callbacks=[tf.keras.callbacks.EarlyStopping(),
tf.keras.callbacks.LearningRateScheduler()])

```

For complete list of blocks, please checkout the documentation [here](/block).
35 changes: 35 additions & 0 deletions examples/functional_api.py
@@ -0,0 +1,35 @@
import autokeras as ak
import numpy as np
import tensorflow as tf
from keras.datasets import mnist

# Prepare the data.
(x_train, y_classification), (x_test, y_test) = mnist.load_data()
x_image = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

x_structured = np.random.rand(x_train.shape[0], 100)
y_regression = np.random.rand(x_train.shape[0], 1)

# Build model and train.
inputs = ak.ImageInput(shape=(28, 28, 1))
outputs1 = ak.ResNetBlock(version='next')(inputs)
outputs2 = ak.XceptionBlock()(inputs)
image_outputs = ak.Merge()((outputs1, outputs2))

structured_inputs = ak.StructuredInput()
structured_outputs = ak.DenseBlock()(structured_inputs)
merged_outputs = ak.Merge()((image_outputs, structured_outputs))

classification_outputs = ak.ClassificationHead()(merged_outputs)
regression_outputs = ak.RegressionHead()(merged_outputs)
automodel = ak.GraphAutoModel(inputs=inputs,
outputs=[regression_outputs,
classification_outputs])

automodel.fit((x_image, x_structured),
(y_regression, y_classification),
trials=100,
epochs=200,
callbacks=[tf.keras.callbacks.EarlyStopping(),
tf.keras.callbacks.LearningRateScheduler()])
21 changes: 21 additions & 0 deletions examples/io_api.py
@@ -0,0 +1,21 @@
import numpy as np
import autokeras as ak
from keras.datasets import mnist

# Prepare the data.
(x_train, y_classification), (x_test, y_test) = mnist.load_data()
x_image = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

x_structured = np.random.rand(x_train.shape[0], 100)
y_regression = np.random.rand(x_train.shape[0], 1)

# Build model and train.
automodel = ak.AutoModel(
inputs=[ak.ImageInput(),
ak.StructuredInput()],
outputs=[ak.RegressionHead(metrics=['mae']),
ak.ClassificationHead(loss='categorical_crossentropy',
metrics=['accuracy'])])
automodel.fit([x_image, x_structured],
[y_regression, y_classification])
12 changes: 12 additions & 0 deletions examples/task_api.py
@@ -0,0 +1,12 @@
import autokeras as ak
from keras.datasets import mnist

# Prepare the data.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape + (1,))
x_test = x_test.reshape(x_test.shape + (1,))

# Search and train the classifier.
clf = ak.ImageClassifier(max_trials=100)
clf.fit(x_train, y_train)
y = clf.predict(x_test, y_test)
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -13,6 +13,7 @@
keywords=['AutoML', 'keras'],
install_requires=[
'tensorflow>=2.0.0b1',
'keras-tuner',
'scikit-learn',
'numpy',
'lightgbm',
Expand Down

0 comments on commit 13c2850

Please sign in to comment.