This repository has been archived by the owner on Jul 3, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 129
/
api.py
408 lines (325 loc) · 15.7 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""User interface for the NAS Benchmark dataset.
Before using this API, download the data files from the links in the README.
Usage:
# Load the data from file (this will take some time)
nasbench = api.NASBench('/path/to/nasbench.tfrecord')
# Create an Inception-like module (5x5 convolution replaced with two 3x3
# convolutions).
model_spec = api.ModelSpec(
# Adjacency matrix of the module
matrix=[[0, 1, 1, 1, 0, 1, 0], # input layer
[0, 0, 0, 0, 0, 0, 1], # 1x1 conv
[0, 0, 0, 0, 0, 0, 1], # 3x3 conv
[0, 0, 0, 0, 1, 0, 0], # 5x5 conv (replaced by two 3x3's)
[0, 0, 0, 0, 0, 0, 1], # 5x5 conv (replaced by two 3x3's)
[0, 0, 0, 0, 0, 0, 1], # 3x3 max-pool
[0, 0, 0, 0, 0, 0, 0]], # output layer
# Operations at the vertices of the module, matches order of matrix
ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT])
# Query this model from dataset
data = nasbench.query(model_spec)
Adjacency matrices are expected to be upper-triangular 0-1 matrices within the
defined search space (7 vertices, 9 edges, 3 allowed ops). The first and last
operations must be 'input' and 'output'. The other operations should be from
config['available_ops']. Currently, the available operations are:
CONV3X3 = "conv3x3-bn-relu"
CONV1X1 = "conv1x1-bn-relu"
MAXPOOL3X3 = "maxpool3x3"
When querying a spec, the spec will first be automatically pruned (removing
unused vertices and edges along with ops). If the pruned spec is still out of
the search space, an OutOfDomainError will be raised, otherwise the data is
returned.
The returned data object is a dictionary with the following keys:
- module_adjacency: numpy array for the adjacency matrix
- module_operations: list of operation labels
- trainable_parameters: number of trainable parameters in the model
- training_time: the total training time in seconds up to this point
- train_accuracy: training accuracy
- validation_accuracy: validation_accuracy
- test_accuracy: testing accuracy
Instead of querying the dataset for a single run of a model, it is also possible
to retrieve all metrics for a given spec, using:
fixed_stats, computed_stats = nasbench.get_metrics_from_spec(model_spec)
The fixed_stats is a dictionary with the keys:
- module_adjacency
- module_operations
- trainable_parameters
The computed_stats is a dictionary from epoch count to a list of metric
dicts. For example, computed_stats[108][0] contains the metrics for the first
repeat of the provided model trained to 108 epochs. The available keys are:
- halfway_training_time
- halfway_train_accuracy
- halfway_validation_accuracy
- halfway_test_accuracy
- final_training_time
- final_train_accuracy
- final_validation_accuracy
- final_test_accuracy
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import copy
import json
import os
import random
import time
from nasbench.lib import config
from nasbench.lib import evaluate
from nasbench.lib import model_metrics_pb2
from nasbench.lib import model_spec as _model_spec
import numpy as np
import tensorflow as tf
# Bring ModelSpec to top-level for convenience. See lib/model_spec.py.
ModelSpec = _model_spec.ModelSpec
class OutOfDomainError(Exception):
"""Indicates that the requested graph is outside of the search domain."""
class NASBench(object):
"""User-facing API for accessing the NASBench dataset."""
def __init__(self, dataset_file, seed=None):
"""Initialize dataset, this should only be done once per experiment.
Args:
dataset_file: path to .tfrecord file containing the dataset.
seed: random seed used for sampling queried models. Two NASBench objects
created with the same seed will return the same data points when queried
with the same models in the same order. By default, the seed is randomly
generated.
"""
self.config = config.build_config()
random.seed(seed)
print('Loading dataset from file... This may take a few minutes...')
start = time.time()
# Stores the fixed statistics that are independent of evaluation (i.e.,
# adjacency matrix, operations, and number of parameters).
# hash --> metric name --> scalar
self.fixed_statistics = {}
# Stores the statistics that are computed via training and evaluating the
# model on CIFAR-10. Statistics are computed for multiple repeats of each
# model at each max epoch length.
# hash --> epochs --> repeat index --> metric name --> scalar
self.computed_statistics = {}
# Valid queriable epoch lengths. {4, 12, 36, 108} for the full dataset or
# {108} for the smaller dataset with only the 108 epochs.
self.valid_epochs = set()
for serialized_row in tf.python_io.tf_record_iterator(dataset_file):
# Parse the data from the data file.
module_hash, epochs, raw_adjacency, raw_operations, raw_metrics = (
json.loads(serialized_row.decode('utf-8')))
dim = int(np.sqrt(len(raw_adjacency)))
adjacency = np.array([int(e) for e in list(raw_adjacency)], dtype=np.int8)
adjacency = np.reshape(adjacency, (dim, dim))
operations = raw_operations.split(',')
metrics = model_metrics_pb2.ModelMetrics.FromString(
base64.b64decode(raw_metrics))
if module_hash not in self.fixed_statistics:
# First time seeing this module, initialize fixed statistics.
new_entry = {}
new_entry['module_adjacency'] = adjacency
new_entry['module_operations'] = operations
new_entry['trainable_parameters'] = metrics.trainable_parameters
self.fixed_statistics[module_hash] = new_entry
self.computed_statistics[module_hash] = {}
self.valid_epochs.add(epochs)
if epochs not in self.computed_statistics[module_hash]:
self.computed_statistics[module_hash][epochs] = []
# Each data_point consists of the metrics recorded from a single
# train-and-evaluation of a model at a specific epoch length.
data_point = {}
# Note: metrics.evaluation_data[0] contains the computed metrics at the
# start of training (step 0) but this is unused by this API.
# Evaluation statistics at the half-way point of training
half_evaluation = metrics.evaluation_data[1]
data_point['halfway_training_time'] = half_evaluation.training_time
data_point['halfway_train_accuracy'] = half_evaluation.train_accuracy
data_point['halfway_validation_accuracy'] = (
half_evaluation.validation_accuracy)
data_point['halfway_test_accuracy'] = half_evaluation.test_accuracy
# Evaluation statistics at the end of training
final_evaluation = metrics.evaluation_data[2]
data_point['final_training_time'] = final_evaluation.training_time
data_point['final_train_accuracy'] = final_evaluation.train_accuracy
data_point['final_validation_accuracy'] = (
final_evaluation.validation_accuracy)
data_point['final_test_accuracy'] = final_evaluation.test_accuracy
self.computed_statistics[module_hash][epochs].append(data_point)
elapsed = time.time() - start
print('Loaded dataset in %d seconds' % elapsed)
self.history = {}
self.training_time_spent = 0.0
self.total_epochs_spent = 0
def query(self, model_spec, epochs=108, stop_halfway=False):
"""Fetch one of the evaluations for this model spec.
Each call will sample one of the config['num_repeats'] evaluations of the
model. This means that repeated queries of the same model (or isomorphic
models) may return identical metrics.
This function will increment the budget counters for benchmarking purposes.
See self.training_time_spent, and self.total_epochs_spent.
This function also allows querying the evaluation metrics at the halfway
point of training using stop_halfway. Using this option will increment the
budget counters only up to the halfway point.
Args:
model_spec: ModelSpec object.
epochs: number of epochs trained. Must be one of the evaluated number of
epochs, [4, 12, 36, 108] for the full dataset.
stop_halfway: if True, returned dict will only contain the training time
and accuracies at the halfway point of training (num_epochs/2).
Otherwise, returns the time and accuracies at the end of training
(num_epochs).
Returns:
dict containing the evaluated data for this object.
Raises:
OutOfDomainError: if model_spec or num_epochs is outside the search space.
"""
if epochs not in self.valid_epochs:
raise OutOfDomainError('invalid number of epochs, must be one of %s'
% self.valid_epochs)
fixed_stat, computed_stat = self.get_metrics_from_spec(model_spec)
sampled_index = random.randint(0, self.config['num_repeats'] - 1)
computed_stat = computed_stat[epochs][sampled_index]
data = {}
data['module_adjacency'] = fixed_stat['module_adjacency']
data['module_operations'] = fixed_stat['module_operations']
data['trainable_parameters'] = fixed_stat['trainable_parameters']
if stop_halfway:
data['training_time'] = computed_stat['halfway_training_time']
data['train_accuracy'] = computed_stat['halfway_train_accuracy']
data['validation_accuracy'] = computed_stat['halfway_validation_accuracy']
data['test_accuracy'] = computed_stat['halfway_test_accuracy']
else:
data['training_time'] = computed_stat['final_training_time']
data['train_accuracy'] = computed_stat['final_train_accuracy']
data['validation_accuracy'] = computed_stat['final_validation_accuracy']
data['test_accuracy'] = computed_stat['final_test_accuracy']
self.training_time_spent += data['training_time']
if stop_halfway:
self.total_epochs_spent += epochs // 2
else:
self.total_epochs_spent += epochs
return data
def is_valid(self, model_spec):
"""Checks the validity of the model_spec.
For the purposes of benchmarking, this does not increment the budget
counters.
Args:
model_spec: ModelSpec object.
Returns:
True if model is within space.
"""
try:
self._check_spec(model_spec)
except OutOfDomainError:
return False
return True
def get_budget_counters(self):
"""Returns the time and budget counters."""
return self.training_time_spent, self.total_epochs_spent
def reset_budget_counters(self):
"""Reset the time and epoch budget counters."""
self.training_time_spent = 0.0
self.total_epochs_spent = 0
def evaluate(self, model_spec, model_dir):
"""Trains and evaluates a model spec from scratch (does not query dataset).
This function runs the same procedure that was used to generate each
evaluation in the dataset. Because we are not querying the generated
dataset of trained models, there are no limitations on number of vertices,
edges, operations, or epochs. Note that the results will not exactly match
the dataset due to randomness. By default, this uses TPUs for evaluation but
CPU/GPU can be used by setting --use_tpu=false (GPU will require installing
tensorflow-gpu).
Args:
model_spec: ModelSpec object.
model_dir: directory to store the checkpoints, summaries, and logs.
Returns:
dict contained the evaluated data for this object, same structure as
returned by query().
"""
# Metadata contains additional metrics that aren't reported normally.
# However, these are stored in the JSON file at the model_dir.
metadata = evaluate.train_and_evaluate(model_spec, self.config, model_dir)
metadata_file = os.path.join(model_dir, 'metadata.json')
with tf.gfile.Open(metadata_file, 'w') as f:
json.dump(metadata, f, cls=_NumpyEncoder)
data_point = {}
data_point['module_adjacency'] = model_spec.matrix
data_point['module_operations'] = model_spec.ops
data_point['trainable_parameters'] = metadata['trainable_params']
final_evaluation = metadata['evaluation_results'][-1]
data_point['training_time'] = final_evaluation['training_time']
data_point['train_accuracy'] = final_evaluation['train_accuracy']
data_point['validation_accuracy'] = final_evaluation['validation_accuracy']
data_point['test_accuracy'] = final_evaluation['test_accuracy']
return data_point
def hash_iterator(self):
"""Returns iterator over all unique model hashes."""
return self.fixed_statistics.keys()
def get_metrics_from_hash(self, module_hash):
"""Returns the metrics for all epochs and all repeats of a hash.
This method is for dataset analysis and should not be used for benchmarking.
As such, it does not increment any of the budget counters.
Args:
module_hash: MD5 hash, i.e., the values yielded by hash_iterator().
Returns:
fixed stats and computed stats of the model spec provided.
"""
fixed_stat = copy.deepcopy(self.fixed_statistics[module_hash])
computed_stat = copy.deepcopy(self.computed_statistics[module_hash])
return fixed_stat, computed_stat
def get_metrics_from_spec(self, model_spec):
"""Returns the metrics for all epochs and all repeats of a model.
This method is for dataset analysis and should not be used for benchmarking.
As such, it does not increment any of the budget counters.
Args:
model_spec: ModelSpec object.
Returns:
fixed stats and computed stats of the model spec provided.
"""
self._check_spec(model_spec)
module_hash = self._hash_spec(model_spec)
return self.get_metrics_from_hash(module_hash)
def _check_spec(self, model_spec):
"""Checks that the model spec is within the dataset."""
if not model_spec.valid_spec:
raise OutOfDomainError('invalid spec, provided graph is disconnected.')
num_vertices = len(model_spec.ops)
num_edges = np.sum(model_spec.matrix)
if num_vertices > self.config['module_vertices']:
raise OutOfDomainError('too many vertices, got %d (max vertices = %d)'
% (num_vertices, config['module_vertices']))
if num_edges > self.config['max_edges']:
raise OutOfDomainError('too many edges, got %d (max edges = %d)'
% (num_edges, self.config['max_edges']))
if model_spec.ops[0] != 'input':
raise OutOfDomainError('first operation should be \'input\'')
if model_spec.ops[-1] != 'output':
raise OutOfDomainError('last operation should be \'output\'')
for op in model_spec.ops[1:-1]:
if op not in self.config['available_ops']:
raise OutOfDomainError('unsupported op %s (available ops = %s)'
% (op, self.config['available_ops']))
def _hash_spec(self, model_spec):
"""Returns the MD5 hash for a provided model_spec."""
return model_spec.hash_spec(self.config['available_ops'])
class _NumpyEncoder(json.JSONEncoder):
"""Converts numpy objects to JSON-serializable format."""
def default(self, obj):
if isinstance(obj, np.ndarray):
# Matrices converted to nested lists
return obj.tolist()
elif isinstance(obj, np.generic):
# Scalars converted to closest Python type
return np.asscalar(obj)
return json.JSONEncoder.default(self, obj)