Skip to content

Commit

Permalink
[Compression] huggingface transformers extension (#5137)
Browse files Browse the repository at this point in the history
  • Loading branch information
J-shang committed Mar 25, 2023
1 parent 6cc9ecc commit e101717
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
18 changes: 15 additions & 3 deletions nni/compression/pytorch/utils/external/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
PreTrainedModel,
BartConfig,
BertConfig,
T5Config
T5Config,
ViTConfig
)
except ImportError:
TRANSFORMERS_INSTALLED = False
Expand Down Expand Up @@ -123,18 +124,29 @@ class HuggingfaceT5Parser(HuggingfaceModelParser):
ATTENTION = ('SelfAttention', 'EncDecAttention')


class HuggingfaceViTParser(HuggingfaceModelParser):
    """Module-path patterns for huggingface ViT models.

    Transformer layers live under ``vit.encoder.layer.<i>``; the tuples below
    name the query/key/value projections, the attention output projection,
    the two feed-forward linears, and the self-attention module relative to
    that prefix (mirrors the other ``Huggingface*Parser`` classes).
    """
    TRANSFORMER_PREFIX = r'vit\.encoder\.layer\.[0-9]+\.'
    # Q/K/V projection linears inside ViTSelfAttention.
    QKV = (
        'attention.attention.query',
        'attention.attention.key',
        'attention.attention.value',
    )
    # Q/K/V plus the attention output projection.
    QKVO = QKV + ('attention.output.dense',)
    FFN1 = ('intermediate.dense',)
    FFN2 = ('output.dense',)
    ATTENTION = ('attention.attention',)


# huggingface transformers pretrained model parsers supported: bart, bert, t5, vit
def parser_factory(model: Module) -> HuggingfaceModelParser | None:
if TRANSFORMERS_INSTALLED and isinstance(model, PreTrainedModel):
cls2parser = {
BartConfig: HuggingfaceBartParser,
BertConfig: HuggingfaceBertParser,
T5Config: HuggingfaceT5Parser
T5Config: HuggingfaceT5Parser,
ViTConfig: HuggingfaceViTParser
}
type2parser = {
'bart': HuggingfaceBartParser,
'bert': HuggingfaceBertParser,
't5': HuggingfaceT5Parser
't5': HuggingfaceT5Parser,
'vit': HuggingfaceViTParser
}

if hasattr(model, 'config_class'):
Expand Down
2 changes: 1 addition & 1 deletion nni/compression/pytorch/utils/shape_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _conv_condition(self, node_group):
leaf_module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
group = leaf_module.groups
n_filter = leaf_module.out_channels
return n_filter == group
return n_filter == group and group != 1

def _group_norm_condition(self, node_group) -> int:
node_name = node_group.name
Expand Down
2 changes: 1 addition & 1 deletion nni/contrib/compression/base/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def _apply_quant_helper(self, target: Tensor, target_space: QuantizationTargetSp

def _distil_observe_helper(self, target: Tensor, target_space: DistillationTargetSpace) -> Tensor:
# NOTE: here will have a risk, we don't know if target will be inplace changed in the following.
target_space.hidden_state = target.clone().detach()
target_space.hidden_state = target
return target

def _track_info(self, target_name: str, target: Tensor):
Expand Down

0 comments on commit e101717

Please sign in to comment.