Skip to content

Commit

Permalink
[compression] fix mask conflict v2 (#5592)
Browse files Browse the repository at this point in the history
  • Loading branch information
super-dainiu committed Jul 3, 2023
1 parent b9d9492 commit 8dc1a83
Show file tree
Hide file tree
Showing 7 changed files with 530 additions and 101 deletions.
3 changes: 2 additions & 1 deletion nni/compression/pytorch/speedup/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .model_speedup import ModelSpeedup
from .dependency import auto_set_denpendency_group_ids
from .model_speedup import ModelSpeedup
97 changes: 7 additions & 90 deletions nni/compression/pytorch/speedup/v2/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import torch
from torch.fx.node import Node
Expand All @@ -22,92 +22,9 @@ def __init__(self, node: Node):
self.mask_updater: 'MaskUpdater' = None
self.replaced = False

self._output_origin = None
self._output_inplace = None
self._output_randomize = None
self._output_grad = None
self._output_masks = None
self._param_masks = None
self.assignment_status = {
'output_origin': 0,
'output_inplace': 0,
'output_randomize': 0,
'output_grad': 0,
'output_masks': 0,
'param_masks': 0,
}

@property
def output_origin(self):
"""
The original output of a node.
"""
# assert self.assignment_status['output_origin'] == 1, \
# f"NodeInfo error: bad output_origin({self.assignment_status['output_origin']})"
return self._output_origin

@property
def output_inplace(self):
"""
        A clone of the original output, used as the input of the successor node to get the original output of the successor node.
"""
# assert self.assignment_status['output_inplace'] == 1, \
# f"NodeInfo error: bad output_inplace({self.assignment_status['output_inplace']})"
return self._output_inplace

@property
def output_randomize(self):
"""
        A randomized version of the original output, used to directly propagate masks.
"""
# assert self.assignment_status['output_randomize'] == 1, \
# f"NodeInfo error: bad output_randomize({self.assignment_status['output_randomize']})"
return self._output_randomize

@property
def output_grad(self):
"""
        The sum of the gradients given by successor nodes during indirect propagation.
"""
# assert self.assignment_status['output_grad'] == 1, f"NodeInfo error: bad output_grad({self.assignment_status['output_grad']})"
return self._output_grad

@property
def output_masks(self):
# assert self.assignment_status['output_masks'] <= 3, f"NodeInfo error: bad output_masks({self.assignment_status['output_masks']})"
return self._output_masks

@property
def param_masks(self):
# assert self.assignment_status['param_masks'] <= 2, f"NodeInfo error: bad param_masks({self.assignment_status['param_masks']})"
return self._param_masks

@output_origin.setter
def output_origin(self, val: Any):
self._output_origin = val
self.assignment_status['output_origin'] += 1

@output_inplace.setter
def output_inplace(self, val: Any):
self._output_inplace = val
self.assignment_status['output_inplace'] += 1

@output_randomize.setter
def output_randomize(self, val: Any):
self._output_randomize = val
self.assignment_status['output_randomize'] += 1

@output_grad.setter
def output_grad(self, val: Any):
self._output_grad = val
self.assignment_status['output_grad'] += 1

@output_masks.setter
def output_masks(self, val: Any):
self._output_masks = val
self.assignment_status['output_masks'] += 1

@param_masks.setter
def param_masks(self, val: Any):
self._param_masks = val
self.assignment_status['param_masks'] += 1
self.output_origin = None
self.output_inplace = None
self.output_randomize = None
self.output_grad = None
self.output_masks = None
self.param_masks = None

0 comments on commit 8dc1a83

Please sign in to comment.