Skip to content

Commit

Permalink
[AutoTVM] Add index boundary check in ConfigSpace.get() (apache#7234)
Browse files Browse the repository at this point in the history
* [AutoTVM] Add index boundary check in ConfigSpace.get()

* Fix unit test

Co-authored-by: Yanming Wang <yanmwang@amazon.com>
  • Loading branch information
2 people authored and trevor-m committed Jan 21, 2021
1 parent c7f3aaf commit a78d88d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions python/tvm/autotvm/task/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"""
Template configuration space.
Each template function can be parametrized by a ConfigSpace.
Each template function can be parameterized by a ConfigSpace.
The space is declared when we invoke the template function with ConfigSpace.
During evaluation, we pass in a ConfigEntity, which contains a specific
entity in the space. This entity contains deterministic parameters.
Expand Down Expand Up @@ -63,7 +63,7 @@ class TransformSpace(object):
Each operator has some tunable parameters (e.g. the split factor).
Then the tuning process is just to find good parameters of these op.
So the all the combinations of the parameters of these op forms our search space.
So all the combinations of the parameters of these op form our search space.
Naming convention:
We call the set of all possible values as XXXSpace. (XXX can be Split, Reorder, Config ...)
Expand Down Expand Up @@ -797,7 +797,7 @@ def add_flop(self, flop):

def raise_error(self, msg):
"""register error in config
Using this to actively detect error when scheudling.
Using this to actively detect error when scheduling.
Otherwise these error will occur during runtime, which
will cost more time.
Expand Down Expand Up @@ -848,6 +848,8 @@ def get(self, index):
index: int
index in the space
"""
if index < 0 or index >= len(self):
raise IndexError("Index out of range: size {}, got index {}".format(len(self), index))
entities = OrderedDict()
t = index
for name, space in self.space_map.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_autotvm_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,6 @@ def get_sample_records(n):

inps, ress = [], []
for i in range(n):
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i)))
inps.append(MeasureInput(target, tsk, tsk.config_space.get(i % len(tsk.config_space))))
ress.append(MeasureResult((i + 1,), 0, i, time.time()))
return list(zip(inps, ress))

0 comments on commit a78d88d

Please sign in to comment.