diff --git a/paddleslim/nas/one_shot/super_mnasnet.py b/paddleslim/nas/one_shot/super_mnasnet.py index 852b40383af52..169d1050ba180 100644 --- a/paddleslim/nas/one_shot/super_mnasnet.py +++ b/paddleslim/nas/one_shot/super_mnasnet.py @@ -2,7 +2,7 @@ from paddle import fluid from paddle.fluid.layer_helper import LayerHelper import numpy as np -from one_shot_nas import OneShotSuperNet +from .one_shot_nas import OneShotSuperNet __all__ = ['SuperMnasnet'] @@ -209,14 +209,14 @@ def __init__(self, def init_tokens(self): return [ - 3, 3, 6, 6, 6, 6, 3, 3, 3, 6, 6, 6, 3, 3, 3, 3, 6, 6, 3, 3, 3, 6, - 6, 6, 3, 3, 3, 6, 6, 6, 3, 6, 6, 6, 6, 6 + 3, 3, 6, 6, 6, 6, 3, 3, 3, 6, 6, 6, 3, 3, 3, 3, 6, 6, 3, 3, 3, 6, 6, + 6, 3, 3, 3, 6, 6, 6, 3, 6, 6, 6, 6, 6 ] def range_table(self): max_v = [ - 6, 6, 10, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 6, 6, 6, 10, 10, 6, - 6, 6, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 10, 10, 10, 10, 10 + 6, 6, 10, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 6, 6, 6, 10, 10, 6, 6, + 6, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 10, 10, 10, 10, 10 ] return (len(max_v) * [0], max_v)