[Model Compression] Add global sort for taylor pruner #3896
Conversation
@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):
    """

-    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
+    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
I think it may be better to implement it for Taylor alone without changing the underlying interface, because no other masker uses these interfaces, and slim already does global sort naturally.
But based on our discussion before, we agreed that we should modify the condition there.
Maybe we can add an assert to tell users that we only support Taylor currently?
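A minimal sketch of the kind of check being suggested; the 'taylorfo' algorithm key and the exact wording are assumptions, not part of the actual patch:

# Hypothetical early check in the pruner constructor: reject global_sort for
# pruning algorithms other than Taylor. The 'taylorfo' key is an assumption.
if global_sort:
    assert pruning_algorithm == 'taylorfo', \
        'global_sort is currently only supported by TaylorFOWeightFilterPruner'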
The current implementation is fine with me; it was just a minor opinion, no need to change it. But we do need to update the doc for Taylor.
Have updated the doc.
So you put global sort in StructuredWeightMasker but not in TaylorFOWeightFilterPrunerMasker. What is the reason for that? Is it because putting it there would give TaylorFOWeightFilterPrunerMasker different init arguments from the other structured weight maskers?
Because here we modify the conditional statement that distinguishes the different situations, including global-sort and dependency-aware. The reason we do this check here is that we think global-sort should be at the same level as dependency-aware. Since this conditional check is done in StructuredWeightMasker, we put global_sort here.
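For context, a rough sketch of the dispatch being described, assuming the masker exposes one helper per mode; the helper names below are illustrative placeholders, not necessarily the actual method names:

def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
    # Sketch only: global-sort sits at the same level as dependency-aware.
    if self.global_sort:
        # prune against a threshold computed over all layers' contributions
        return self._global_calc_mask(wrapper, wrapper_idx)
    elif not self.dependency_aware:
        # ordinary per-layer masking
        return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
    else:
        # channel-dependency-aware masking
        return self._dependency_calc_mask(sparsity, wrapper, wrapper_idx, **depen_kwargs)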
weight = wrapper.module.weight.data
filters = weight.size(0)
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
Just a dumb question, do we want < or <= here?
It is a key point and I thought about it before. We can either use '<=', which may make num_prune larger than the specified sparsity, or '<', which may make num_prune smaller than the specified sparsity. I finally chose '<', since a slightly lower sparsity is safer, and the iterative pruning process will do further pruning anyway.
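To make the trade-off concrete, a toy illustration with made-up values: with '<', contributions that exactly tie the threshold are kept, so the realized sparsity may fall slightly below the target; with '<=' they would be pruned, and the sparsity may overshoot.

import torch

channel_contribution = torch.tensor([0.125, 0.25, 0.25, 0.5, 0.75])
global_threshold = 0.25  # suppose this is the k-th smallest contribution

num_prune_strict = (channel_contribution < global_threshold).sum().item()   # 1: ties are kept
num_prune_loose = (channel_contribution <= global_threshold).sum().item()   # 3: ties are pruned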
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
channel_contribution_list.append(channel_contribution)
all_channel_contributions = torch.cat(channel_contribution_list)
k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
What if the filters' sizes are different?
That's indeed a key problem; I will fix it.
Have fixed.
Seems we need view(-1) on the contributions, or can they be concatenated even if their sizes differ?
That's a problem. Have fixed.
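A small sketch of the kind of fix being discussed (the shapes are made up): flattening each layer's contribution with view(-1) lets tensors with different filter counts or shapes be concatenated for the global ranking.

import torch

channel_contribution_list = [
    torch.rand(64),      # e.g. contributions of a conv layer with 64 filters
    torch.rand(128, 1),  # a layer whose contribution tensor has a different shape
]
# Without flattening, torch.cat would fail because the dimensions differ.
all_channel_contributions = torch.cat([c.view(-1) for c in channel_contribution_list])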
@@ -27,7 +27,7 @@ class DependencyAwarePruner(Pruner):
    """

    def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
-                 dummy_input=None, **algo_kwargs):
+                 dummy_input=None, global_sort=False, **algo_kwargs):
Do all the dependency pruners support global_sort mode? If not, I'm a little concerned.
Agree. How many pruners support global_sort?
Not all of them support global_sort; I have deleted the parameter from the ones that don't.
Please add some unit tests for the global mode.
@@ -51,10 +51,13 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim',
        dummy_input: torch.Tensor
            The dummy input to analyze the topology constraints. Note that,
            the dummy_input should on the same device with the model.
+       global_sort: bool
+           If prune the model in a global-sort way.
+           Only support TaylorFOWeightFilterPruner currently.
If only TaylorFOWeightFilterPruner supports global_sort, I suggest putting this argument in TaylorFOWeightFilterPruner.
Done.
k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
self.global_threshold = torch.topk(
    all_channel_contributions.view(-1), k, largest=False)[0].max()
print(f'set global threshold to {self.global_threshold}')
better to remove this print
Have removed.
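For reference, a small numerical check of what that snippet computes, with made-up values: torch.topk with largest=False returns the k smallest contributions, and the max of those is the k-th smallest value overall, which then serves as the global threshold.

import torch

contributions = torch.tensor([0.9, 0.2, 0.5, 0.1, 0.7, 0.4])
k = 3
threshold = torch.topk(contributions, k, largest=False)[0].max()
# threshold == tensor(0.4000), the 3rd smallest contribution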
Will add it later.
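For what it's worth, a rough sketch of what such a test could look like; this is not the test that was eventually added, and the import path, constructor arguments, and trainer signature follow the NNI v2.x pruning API, which may differ between releases.

import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        return self.fc(x)


def trainer(model, optimizer, criterion, epoch):
    # a single dummy batch is enough to collect the Taylor contributions
    data = torch.rand(4, 3, 16, 16)
    target = torch.randint(0, 10, (4,))
    optimizer.zero_grad()
    criterion(model(data), target).backward()
    optimizer.step()


def test_taylor_global_sort():
    model = TinyNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = TaylorFOWeightFilterPruner(
        model, config_list, optimizer, trainer=trainer,
        criterion=F.cross_entropy, global_sort=True)
    pruner.compress()
    # Under global sort, per-layer pruning ratios may differ, but the overall
    # ratio of masked filters should be close to the configured sparsity.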
This PR supports global sort for the Taylor pruner.
The constraint is set to 1 to prevent all channels from being pruned, so that at least one channel is kept.
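An illustrative sketch of that safeguard, not necessarily the exact mechanism used in the patch: when the global threshold would otherwise remove every filter of a layer, the number of pruned filters is capped so that at least one filter survives.

# Continuing the earlier per-layer snippet; illustrative only.
filters = weight.size(0)
num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
num_prune = min(num_prune, filters - 1)  # always keep at least one filter per layer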