Skip to content

Commit

Permalink
update autoslim_greedy_search
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyang07 committed Jan 13, 2023
1 parent 02de63d commit 7866460
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions mmrazor/engine/runner/autoslim_greedy_search_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ def __init__(self,
else:
self.model = runner.model

assert hasattr(self.model, 'mutator')
search_groups = self.model.mutator.search_groups
assert hasattr(self.model, 'search_space')
search_space = self.model.search_space

self.candidate_choices = {}
for group_id, modules in search_groups.items():
self.candidate_choices[group_id] = modules[0].candidate_choices
for name, mutables in search_space.items():
self.candidate_choices[name] = mutables[0].candidate_choices

self.max_subnet = {}
for group_id, candidate_choices in self.candidate_choices.items():
self.max_subnet[group_id] = len(candidate_choices)
for name, candidate_choices in self.candidate_choices.items():
self.max_subnet[name] = len(candidate_choices)
self.current_subnet = self.max_subnet

current_subnet_choices = self._channel_bins2choices(
Expand Down Expand Up @@ -106,7 +107,8 @@ def run(self) -> None:
continue

while self.current_flops > target:
best_score, best_subnet = 0., None
best_score, best_subnet = None, None

for unit_name in sorted(self.current_subnet.keys()):
if self.current_subnet[unit_name] == 1:
# The number of channel_bin has reached the minimum
Expand All @@ -123,7 +125,7 @@ def run(self) -> None:
self.runner.logger.info(
f'Slimming unit {unit_name}, {self.score_key}: {score}'
)
if score >= best_score:
if best_score is None or score > best_score:
best_score = score
best_subnet = pruned_subnet

Expand Down

0 comments on commit 7866460

Please sign in to comment.