This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] Add global sort for taylor pruner #3896

Merged
merged 11 commits into from
Jul 20, 2021

Conversation

linbinskn
Contributor

@linbinskn linbinskn commented Jul 2, 2021

This PR adds global sort support for the Taylor pruner.
The per-layer constraint is set to 1 to prevent all channels from being pruned, so at least one channel is always kept in each layer.
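For reference, a minimal usage sketch of the new option, assuming the NNI v2.x pruning API; constructor arguments other than global_sort vary across NNI versions (some require trainer/criterion or statistics settings) and are simplified here:

import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# One sparsity target for all Conv2d layers; with global_sort=True the pruner
# ranks channel contributions across all configured layers instead of pruning
# every layer to the same per-layer sparsity.
config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]

# NOTE: version-specific arguments (e.g. trainer, criterion, statistics_batch_num)
# are omitted in this sketch.
pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, global_sort=True)
model = pruner.compress()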

@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):

"""

def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
Contributor

I think it may be better to implement it for Taylor alone without changing the underlying interface, because no other masker uses these interfaces, and Slim naturally uses global sort.

Contributor Author

But based on our earlier discussion, we agreed that we should modify the condition in

https://github.com/microsoft/nni/blob/master/nni/algorithms/compression/pytorch/pruning/structured_pruning_masker.py#L63

Maybe we can add an assert to tell users that only Taylor is supported currently?

Contributor

The current implementation is good for me; it was just a small opinion, no need to modify. And we need to update the doc for Taylor.

Contributor Author

Have updated the doc.

Contributor

So you put global sort in StructuredWeightMasker but not in TaylorFOWeightFilterPrunerMasker. What is the reason for that? Is it because otherwise TaylorFOWeightFilterPrunerMasker would have different init arguments from the other structured weight maskers?

Contributor Author

Because here we modify the conditional judgement that distinguishes the different situations, including global-sort and dependency-aware. We do this conditional judgement here because we think global-sort should sit at the same level as dependency-aware, and since that judgement happens in StructuredWeightMasker, we put global_sort there.
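A rough, hypothetical sketch of that dispatch (class and helper names here are illustrative, not the exact NNI code); the only point is that global-sort is selected at the same branching level as dependency-aware:

class MaskerDispatchSketch:
    """Illustration only: how calc_mask could pick a masking strategy."""

    def __init__(self, global_sort=False, dependency_aware=False):
        self.global_sort = global_sort
        self.dependency_aware = dependency_aware

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
        if self.global_sort:
            # Global sort: compare channels against one threshold over all layers.
            return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
        elif not self.dependency_aware:
            # Default: prune each layer independently to its own sparsity.
            return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
        # Dependency-aware: respect channel dependencies across related layers.
        return self._dependency_calc_mask(sparsity, wrapper, wrapper_idx, **depen_kwargs)

    def _global_calc_mask(self, sparsity, wrapper, wrapper_idx):
        ...  # threshold-based masking across all layers (see the snippets below)

    def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx):
        ...  # per-layer masking

    def _dependency_calc_mask(self, sparsity, wrapper, wrapper_idx, **depen_kwargs):
        ...  # dependency-aware masking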

weight = wrapper.module.weight.data
filters = weight.size(0)
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
Contributor

Just a dumb question, do we want < or <= here?

Contributor Author

It is a key point and I thought about it before: we could use '<=', which may make num_prune larger than the specified sparsity, or '<', which may make num_prune smaller than the specified sparsity. I finally chose '<' since a slightly lower sparsity is much safer, and the iterative pruning process will do further pruning anyway.
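A tiny, self-contained illustration of the boundary case being discussed, with made-up contribution values that tie with the threshold:

import torch

contributions = torch.tensor([0.1, 0.2, 0.2, 0.5, 0.9])
global_threshold = 0.2  # pretend topk selected this value

# Strict '<' prunes fewer channels when values tie with the threshold ...
pruned_lt = (contributions < global_threshold).sum().item()   # 1
# ... while '<=' prunes more and can overshoot the target sparsity.
pruned_le = (contributions <= global_threshold).sum().item()  # 3
print(pruned_lt, pruned_le)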

channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
channel_contribution_list.append(channel_contribution)
all_channel_contributions = torch.cat(channel_contribution_list)
k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
Contributor

what if the filters' sizes are different?

Contributor Author

That's truly a key problem; I will fix it.

Contributor Author

Have fixed.

Contributor

It seems we need to apply view(-1) to the contributions, or can they be concatenated together with different sizes?

Contributor Author

That's a problem. Have fixed.
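For reference, a small stand-alone check of the concatenation question above: 1-D contribution tensors of different lengths concatenate fine, but anything with extra dimensions needs view(-1) first (tensor shapes here are made up):

import torch

a = torch.rand(16)      # contributions of a layer with 16 filters
b = torch.rand(32)      # contributions of a layer with 32 filters
print(torch.cat([a, b]).shape)                     # torch.Size([48])

c = torch.rand(16, 9)   # a contribution tensor with an extra dimension
# torch.cat([a, c]) raises RuntimeError (mismatched number of dimensions);
# flattening everything with view(-1) makes the concatenation safe.
print(torch.cat([a.view(-1), c.view(-1)]).shape)   # torch.Size([160])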

@linbinskn linbinskn requested a review from J-shang July 9, 2021 08:28
@@ -27,7 +27,7 @@ class DependencyAwarePruner(Pruner):
"""

def __init__(self, model, config_list, optimizer=None, pruning_algorithm='level', dependency_aware=False,
dummy_input=None, **algo_kwargs):
dummy_input=None, global_sort=False, **algo_kwargs):
Contributor

Do all the dependency pruners support global_sort mode? If not, I'm a little concerned.

Contributor

Agree. How many pruners support global_sort?

Contributor Author

Not all of them support global_sort; I have deleted the parameter from those that do not.

@zheng-ningxin
Contributor

Please add some unit tests for the global mode.

@@ -51,10 +51,13 @@ def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim',
dummy_input: torch.Tensor
The dummy input to analyze the topology constraints. Note that,
the dummy_input should on the same device with the model.
global_sort: bool
Whether to prune the model in a global-sort way.
Only TaylorFOWeightFilterPruner is supported currently.
Contributor

If only TaylorFOWeightFilterPruner supports global_sort, I suggest putting this argument in TaylorFOWeightFilterPruner.

Contributor Author

Done.

k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
self.global_threshold = torch.topk(
all_channel_contributions.view(-1), k, largest=False)[0].max()
print(f'set global threshold to {self.global_threshold}')
Contributor

better to remove this print

Contributor Author

Have removed.
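Putting the reviewed snippets together, a hedged sketch of the final global-sort logic (function names and the guard-free handling of k are illustrative, not the exact merged code):

import torch

def get_global_threshold(channel_contribution_list, sparsity):
    # Pool the flattened per-layer contributions into a single tensor.
    all_contrib = torch.cat([c.view(-1) for c in channel_contribution_list])
    # The k smallest contributions across all layers fall below the threshold.
    k = int(all_contrib.shape[0] * sparsity)
    return torch.topk(all_contrib, k, largest=False)[0].max()

def num_prune_for_layer(channel_contribution, global_threshold):
    # Strict '<', as discussed above: prefer slightly under-pruning a layer.
    return (channel_contribution < global_threshold).sum().item()

threshold = get_global_threshold([torch.rand(16), torch.rand(32)], sparsity=0.5)
print(num_prune_for_layer(torch.rand(16), threshold))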

@linbinskn
Contributor Author

Please add some unit tests for the global mode.

Will add it later.

@QuanluZhang QuanluZhang merged commit d8127e0 into microsoft:master Jul 20, 2021