Skip to content

Commit

Permalink
Add the linear squeeze activation pattern (#1665)
Browse files Browse the repository at this point in the history
### Changes

The linear squeeze activation pattern was added

### Reason for changes

Align with the OpenVINO runtime

### Related tickets

#1631

### Tests

N/A
  • Loading branch information
alexsu52 committed Mar 28, 2023
1 parent ff2533c commit 6df0b67
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 0 deletions.
1 change: 1 addition & 0 deletions nncf/common/graph/patterns/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,7 @@ class PatternNames(Enum):
LINEAR_BATCH_NORM_SCALE_SHIFT_ACTIVATIONS = PatternDesc('linear_batch_norm_scale_shift_activations')
LINEAR_SCALE_SHIFT_ACTIVATIONS = PatternDesc('linear_scale_shift_activations')
LINEAR_CONST_MULTIPLY = PatternDesc('linear_const_multiply')
LINEAR_SQUEEZE_ACTIVATIONS = PatternDesc('linear_squeeze_activations')
SCALE_SHIFT_ACTIVATIONS = PatternDesc('scale_shift_activations')
MVN_SCALE_SHIFT_ACTIVATIONS = PatternDesc('mvn_scale_shift_activations')

Expand Down
18 changes: 18 additions & 0 deletions nncf/experimental/openvino_native/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,17 @@ def create_linear_arithmetic_activations():
return linear


@OPENVINO_HW_FUSED_PATTERNS.register(PatternNames.LINEAR_SQUEEZE_ACTIVATIONS)
def create_linear_squeeze_activation():
    """Build the fused OpenVINO pattern: linear operation -> squeeze -> activation."""
    pattern = linear_operations()
    # Chain the squeeze and the activation onto the linear sub-pattern, in order.
    for tail in (squeeze_operation(), atomic_activations_operations()):
        pattern.join_patterns(tail)
    return pattern


@OPENVINO_HW_FUSED_PATTERNS.register(PatternNames.MVN_SCALE_SHIFT_ACTIVATIONS)
def create_mvn_scale_shift_activations():
pattern = GraphPattern()
Expand Down Expand Up @@ -804,6 +815,13 @@ def arithmetic_operations():
return pattern


def squeeze_operation():
    """Return a single-node GraphPattern matching an OpenVINO Squeeze operation."""
    node_attrs = {
        GraphPattern.LABEL_ATTR: 'SQUEEZE',
        GraphPattern.METATYPE_ATTR: om.OVSqueezeMetatype,
    }
    pattern = GraphPattern()
    pattern.add_node(**node_attrs)
    return pattern


def create_input_convert_transpose():
pattern = GraphPattern()
model_input = pattern.add_node(**{GraphPattern.LABEL_ATTR: 'MODEL_INPUT',
Expand Down
18 changes: 18 additions & 0 deletions nncf/onnx/hardware/fused_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ def create_linear_bn_scale_shift_activation() -> GraphPattern:
return linear_batch_norm


@ONNX_HW_FUSED_PATTERNS.register(PatternNames.LINEAR_SQUEEZE_ACTIVATIONS)
def create_linear_squeeze_activation():
    """Build the fused ONNX pattern: linear operation -> squeeze -> activation."""
    pattern = linear_operations()
    # Chain the squeeze and the activation onto the linear sub-pattern, in order.
    for tail in (squeeze_operation(), atomic_activations_operations()):
        pattern.join_patterns(tail)
    return pattern


@ONNX_HW_FUSED_PATTERNS.register(PatternNames.BATCH_NORM_SCALE_SHIFT_ACTIVATIONS)
def create_bn_scale_shift_activation() -> GraphPattern:
batch_norm = batch_normalization_operations()
Expand Down Expand Up @@ -357,3 +368,10 @@ def arithmetic_operations():
pattern = GraphPattern()
pattern.add_node(**ARITHMETIC_OPERATIONS)
return pattern


def squeeze_operation():
    """Return a single-node GraphPattern matching an ONNX Squeeze operation."""
    node_attrs = {
        GraphPattern.LABEL_ATTR: 'SQUEEZE',
        GraphPattern.METATYPE_ATTR: om.ONNXSqueezeMetatype,
    }
    pattern = GraphPattern()
    pattern.add_node(**node_attrs)
    return pattern
1 change: 1 addition & 0 deletions tests/torch/test_pattern_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
PatternNames.LINEAR_ACTIVATION_ELEMENTWISE: 'Not relevant for Torch.',
PatternNames.LINEAR_BIASED_ACTIVATION_ELEMENTWISE: 'Not relevant for Torch.',
PatternNames.MVN_SCALE_SHIFT_ACTIVATIONS: 'Not relevant for Torch.',
PatternNames.LINEAR_SQUEEZE_ACTIVATIONS: 'Not relevant for Torch.'
}


Expand Down

0 comments on commit 6df0b67

Please sign in to comment.