Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Apr 2, 2019
1 parent bb07524 commit 5b31ca5
Show file tree
Hide file tree
Showing 8 changed files with 535 additions and 0 deletions.
124 changes: 124 additions & 0 deletions .gitignore
@@ -0,0 +1,124 @@
# vim swp files
*.swp
# caffe/pytorch model files
*.pth

# Mkdocs
/docs/
/mkdocs/docs/temp

.DS_Store
.idea
.pytest_cache
/experiments

# resource temp folder
tests/resources/temp/*
!tests/resources/temp/.gitkeep

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
.static_storage/
.media/
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

examples/text_cnn/glove_embedding/
1 change: 1 addition & 0 deletions README.md
@@ -1,2 +1,3 @@
# autokeras-algorithm
Some other AutoML algorithms as baselines.
Refer to: https://autokeras.com/temp/nas/
60 changes: 60 additions & 0 deletions examples/cifar10_tutorial.py
@@ -0,0 +1,60 @@
"""
Run NAS baseline methods
========================
We currently provide 4 NAS baseline methods; the default one is Bayesian optimization.
Here is a tutorial about running NAS baseline methods.
Generally, to run a non-default NAS method, we do the following steps in order:
1. Prepare the dataset in the form of torch.utils.data.DataLoader.
2. Initialize the CnnModule/MlpModule with the class name of the NAS Searcher.
3. Start the search by running the fit function.
Refer to the CIFAR-10 example below for more details.
"""
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.nn.functional import cross_entropy

from autokeras import CnnModule
from autokeras.nn.metric import Accuracy
from nas.greedy import GreedySearcher

if __name__ == '__main__':
    print('==> Preparing data..')

    # The same normalization constants are applied to both splits, so the
    # transform is built once and reused.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training images get random crop/flip augmentation before conversion.
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    transform_train = transforms.Compose(augmentation + [transforms.ToTensor(), normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    # Options shared by the training and test loaders.
    loader_options = dict(batch_size=4, num_workers=2)

    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_options)

    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_options)

    # Derive the input shape from one sample: move axis 0 to the end, then
    # prepend a leading batch dimension of size 1.
    sample_image, _ = trainset[0]
    sample_image = np.transpose(np.array(sample_image), (1, 2, 0))
    input_shape = (1,) + sample_image.shape
    num_classes = 10

    # GreedySearcher is used as the example here; to plug in your own searcher,
    # implement it and pass its class via search_type=YOUR_SEARCHER.
    cnn_module = CnnModule(
        loss=cross_entropy,
        metric=Accuracy,
        searcher_args={},
        verbose=True,
        search_type=GreedySearcher,
    )

    cnn_module.fit(
        n_output_node=num_classes,
        input_shape=input_shape,
        train_data=trainloader,
        test_data=testloader,
    )
Empty file added nas/__init__.py
Empty file.
97 changes: 97 additions & 0 deletions nas/greedy.py
@@ -0,0 +1,97 @@
import time
from copy import deepcopy

from autokeras.custom_queue import Queue
from autokeras.bayesian import contain, SearchTree
from autokeras.net_transformer import transform
from autokeras.search import Searcher

class GreedyOptimizer:
    """Proposes unseen neighbour architectures of the searcher's best model.

    Attributes:
        searcher: Object queried for the best model id and for loading its graph.
        metric: The metric handed in at construction; stored but not read here.
    """

    def __init__(self, searcher, metric):
        self.metric = metric
        self.searcher = searcher

    def generate(self, descriptors, timeout, sync_message):
        """Build the not-yet-searched neighbours of the current best model.

        Args:
            descriptors: All the searched neural architectures (copied before use).
            timeout: An integer. The time limit in seconds.
            sync_message: the Queue for multiprocessing return value.

        Returns:
            A list of 2-element tuples, each holding a deep copy of a
            transformed graph and the id of the model it was derived from.
            The list is empty when `sync_message` is a non-empty Queue.

        Raises:
            TimeoutError: If generating the candidates took longer than `timeout`.
        """
        started_at = time.time()
        known_descriptors = deepcopy(descriptors)

        # A non-empty synchronisation queue short-circuits generation.
        if isinstance(sync_message, Queue) and sync_message.qsize() != 0:
            return []

        best_id = self.searcher.get_neighbour_best_model_id()
        best_graph = self.searcher.load_model_by_id(best_id)

        # Keep only transformations whose descriptor has not been seen yet.
        candidates = [
            (deepcopy(neighbour), best_id)
            for neighbour in transform(best_graph)
            if not contain(known_descriptors, neighbour.extract_descriptor())
        ]

        # The budget is only checked once all candidates have been produced.
        elapsed = time.time() - started_at
        if timeout - elapsed < 0:
            raise TimeoutError
        return candidates


class GreedySearcher(Searcher):
    """Searcher that explores neural architectures with a greedy strategy.

    Attributes:
        optimizer: An instance of GreedyOptimizer bound to this searcher.
    """

    def __init__(self, n_output_node, input_shape, path, metric, loss, generators, verbose,
                 trainer_args=None,
                 default_model_len=None,
                 default_model_width=None):
        super().__init__(n_output_node, input_shape, path, metric, loss, generators, verbose,
                         trainer_args, default_model_len, default_model_width)
        self.optimizer = GreedyOptimizer(self, metric)

    def generate(self, multiprocessing_queue):
        """Generate the next neural architecture(s) to train.

        Args:
            multiprocessing_queue: the Queue for multiprocessing return value,
                passed into the search algorithm for synchronizing.

        Returns:
            A list of 2-element tuples. Each tuple contains an instance of Graph
            and the extra info stored alongside it in the training queue (here,
            the father model id). When the optimizer proposes nothing, the list
            holds a single graph built by the first generator with father id 0.
        """
        # NOTE(review): this subtraction assumes self._timeout is an absolute
        # deadline timestamp set by the base class -- confirm.
        time_left = self._timeout - time.time()
        results = self.optimizer.generate(self.descriptors, time_left, multiprocessing_queue)
        if results:
            return results

        # Nothing proposed: fall back to a default-sized model from the first generator.
        generator = self.generators[0](self.n_classes, self.input_shape)
        default_graph = generator.generate(self.default_model_len, self.default_model_width)
        results.append((default_graph, 0))
        return results

    def update(self, other_info, model_id, graph, metric_value):
        """Do nothing; this searcher keeps no state to update per trained model."""

    def load_neighbour_best_model(self):
        """Load the model whose id get_neighbour_best_model_id() returns."""
        best_id = self.get_neighbour_best_model_id()
        return self.load_model_by_id(best_id)

    def get_neighbour_best_model_id(self):
        """Return the 'model_id' of the best entry in self.neighbour_history.

        "Best" is the largest 'metric_value' when the metric reports that
        higher is better, otherwise the smallest.
        """
        select = max if self.metric.higher_better() else min
        best_entry = select(self.neighbour_history, key=lambda entry: entry['metric_value'])
        return best_entry['model_id']
97 changes: 97 additions & 0 deletions nas/grid.py
@@ -0,0 +1,97 @@
import itertools

from autokeras.constant import Constant
from autokeras.search import Searcher


def assert_search_space(search_space):
    """Fill in and normalise a grid search space, then enumerate its points.

    The dictionary passed in is modified in place: a missing length or width
    dimension is set to its default from Constant, and a dimension whose first
    value is not an int has all of its values converted with int().

    Args:
        search_space: A dictionary mapping dimension keys to lists of values.

    Returns:
        A tuple (grid, dimension): `grid` is the same dictionary object that was
        passed in, and `dimension` is the list of tuples from the cartesian
        product of the value lists, taken in sorted key order.
    """
    grid = search_space

    # (key, default values, message when missing, message when converting)
    dimension_specs = (
        (Constant.LENGTH_DIM,
         Constant.DEFAULT_LENGTH_SEARCH,
         'No length dimension found in search Space. Using default values',
         'Converting String to integers. Next time please make sure to enter integer values for Length Dimension'),
        (Constant.WIDTH_DIM,
         Constant.DEFAULT_WIDTH_SEARCH,
         'No width dimension found in search Space. Using default values',
         'Converting String to integers. Next time please make sure to enter integer values for Width Dimension'),
    )

    for dim_key, default_values, missing_msg, convert_msg in dimension_specs:
        if dim_key not in grid:
            print(missing_msg)
            grid[dim_key] = default_values
        elif not isinstance(grid[dim_key][0], int):
            # Only the first value is inspected to decide whether to convert.
            print(convert_msg)
            grid[dim_key] = [int(value) for value in grid[dim_key]]

    # Sorted key order fixes which position each dimension takes in the tuples.
    ordered_values = [grid[key] for key in sorted(grid)]
    return grid, list(itertools.product(*ordered_values))


class GridSearcher(Searcher):
    """Class to search for neural architectures using a grid search strategy.

    Attributes:
        search_space: A dictionary. Specifies the search dimensions and their
            possible values, as returned by assert_search_space.
        search_dimensions: A list of tuples, one per grid point, as returned by
            assert_search_space.
        search_space_counter: An integer. Index into search_dimensions of the
            next grid point to generate.
    """

    def __init__(self, n_output_node, input_shape, path, metric, loss, generators, verbose, search_space=None,
                 trainer_args=None, default_model_len=None, default_model_width=None):
        super().__init__(n_output_node, input_shape, path, metric, loss, generators, verbose,
                         trainer_args, default_model_len, default_model_width)
        # assert_search_space fills in missing dimensions in place, so a shared
        # `{}` default would be mutated and reused by every searcher created
        # without an explicit search space. Build a fresh dict per instance.
        if search_space is None:
            search_space = {}
        self.search_space, self.search_dimensions = assert_search_space(search_space)
        self.search_space_counter = 0

    def get_search_dimensions(self):
        """Return the list of grid points to be searched."""
        return self.search_dimensions

    def search_space_exhausted(self):
        """ Check if Grid search has exhausted the search space """
        # `>=` rather than `==`: generate() advances the counter before it
        # indexes the grid point, so a call made after exhaustion pushes the
        # counter past the end and must still count as exhausted.
        return self.search_space_counter >= len(self.search_dimensions)

    def search(self, train_data, test_data, timeout=60 * 60 * 24):
        """Run the search loop of training, generating and updating once.

        Delegates to the base class implementation unless the grid has been
        exhausted, in which case nothing is done.

        Args:
            train_data: An instance of DataLoader.
            test_data: An instance of Dataloader.
            timeout: An integer, time limit in seconds.
        """
        if self.search_space_exhausted():
            return
        super().search(train_data, test_data, timeout)

    def update(self, other_info, model_id, graph, metric_value):
        """Do nothing; grid search does not adapt to training results."""
        return

    def generate(self, multiprocessing_queue):
        """Generate the next neural architecture.

        Must not be called once the search space is exhausted: get_grid()
        then returns None and indexing it fails with a TypeError.

        Args:
            multiprocessing_queue: the Queue for multiprocessing return value.

        Returns:
            list of 2-element tuples: generated_graph and other_info,
            for grid searcher the length of list is 1.
                generated_graph: An instance of Graph.
                other_info: Always 0.
        """
        grid = self.get_grid()
        self.search_space_counter += 1
        generator = self.generators[0](self.n_classes, self.input_shape)
        generated_graph = generator.generate(grid[Constant.LENGTH_DIM], grid[Constant.WIDTH_DIM])
        return [(generated_graph, 0)]

    def get_grid(self):
        """ Return the next grid to be searched, or None if none is left """
        if self.search_space_counter < len(self.search_dimensions):
            return self.search_dimensions[self.search_space_counter]
        return None

0 comments on commit 5b31ca5

Please sign in to comment.