Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTQ][OV] BF16 support #2307

Merged
merged 38 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
26b6c73
Added BF16 & ov.Tensor support
KodiaqQ Dec 7, 2023
5f02d99
Add FQ params dtype conversion
KodiaqQ Dec 7, 2023
2bcfca9
Update tests for BF16
KodiaqQ Dec 7, 2023
65ce6dc
Fix tests
KodiaqQ Dec 7, 2023
4f91018
Fix bf16 tests
KodiaqQ Dec 7, 2023
69e2297
Added const with types
KodiaqQ Dec 8, 2023
de85c1d
Apply comment
KodiaqQ Dec 12, 2023
058a6e1
Disable tests
KodiaqQ Dec 12, 2023
12447d9
Added PrePostProcessor for FP32 outputs
KodiaqQ Dec 12, 2023
627ff67
Remove BF16 from testing
KodiaqQ Dec 13, 2023
6f011d9
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jan 16, 2024
f3c8ed8
Adjust to develop
KodiaqQ Jan 16, 2024
c97c616
Adjust BF16 suport in tests
KodiaqQ Jan 16, 2024
ccd0b91
Added opset.constant with shared_memory option
KodiaqQ Jan 18, 2024
7f670a0
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jan 22, 2024
91bb312
Added cast to fp32
KodiaqQ Jan 23, 2024
ff8f0ca
Merge openvinotoolkit/develop into nm/bf16_support
KodiaqQ Apr 17, 2024
ca6ff73
Removed PrePostProcessor usage
KodiaqQ Apr 17, 2024
11f4929
Adapt F-/BC algos to BF16
KodiaqQ Apr 18, 2024
c67ee84
Change get_const_value data output
KodiaqQ Apr 19, 2024
833f7c9
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ May 27, 2024
097e938
Change get_const_value behavior
KodiaqQ May 28, 2024
106ff8a
Update implementation
KodiaqQ Jun 17, 2024
d2e5556
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 17, 2024
5636d1a
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 18, 2024
d2df92e
Fix pipeline tests
KodiaqQ Jun 18, 2024
4f9cd37
Tensor names set update
KodiaqQ Jun 18, 2024
87ea12e
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 19, 2024
1cdc747
Extend OutputInsertionCommand
KodiaqQ Jun 19, 2024
7dfd1c1
Apply comments
KodiaqQ Jun 19, 2024
9471ac8
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 20, 2024
8ffe8ae
Limit .get_data usage
KodiaqQ Jul 10, 2024
5405bc9
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jul 10, 2024
5f4062b
Limit shared_memory usage
KodiaqQ Jul 10, 2024
cfa7ce9
Fix WC
KodiaqQ Jul 10, 2024
f2add1f
Fix test_get_const_value
KodiaqQ Jul 10, 2024
5725636
Apply comment
KodiaqQ Jul 11, 2024
3e531c4
Apply minor comments
KodiaqQ Jul 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nncf/openvino/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, compiled_model: ov.CompiledModel, stateful: bool):
self.reset_state = stateful and hasattr(self.infer_request, "reset_state")

def infer(
self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray], ov.Tensor]
) -> Union[Dict[str, np.ndarray], ov.Tensor]:
andrey-churkin marked this conversation as resolved.
Show resolved Hide resolved
"""
Runs model on the provided input via OpenVINO Runtime.
Returns the dictionary of model outputs by node names.
Expand Down Expand Up @@ -73,8 +73,8 @@ def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.C
self.engine = OVCompiledModelEngine(compiled_model, stateful)

def infer(
self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray], ov.Tensor]
) -> Union[Dict[str, np.ndarray], ov.Tensor]:
"""
Runs model on the provided input via OpenVINO Runtime.
Returns the dictionary of model outputs by node names.
Expand Down
127 changes: 60 additions & 67 deletions nncf/openvino/graph/model_transformer.py

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions nncf/openvino/graph/node_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,48 +101,52 @@ def cnt_if_op(model: ov.Model, cnt: int) -> int:
return cnt_if_op(model, 0)


def get_const_value(const_node: ov.Node, dtype: Optional[np.dtype] = None) -> np.ndarray:
def get_const_value(const_node: ov.Node, dtype: ov.Type = ov.Type.f32) -> np.ndarray:
"""
Returns the constant tensor for the node.

:param const_node: OpenVINO node.
:param dtype: Destination type.
:param dtype: Value return type.
:return: The constant value.
"""
if dtype is None:
return const_node.data
return const_node.get_data(dtype=dtype)
return const_node.get_data(dtype=dtype.to_dtype())


def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
def get_bias_value(
node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model, dtype: ov.Type = ov.Type.f32
) -> np.ndarray:
"""
Returns the bias tensor for the biased node.

:param node_with_bias: The node that corresponds to the operation with bias.
:param nncf_graph: NNCFGraph instance.
:param model: The model that contains this operation.
:param dtype: Value return type.
:return: The bias value that is applied to the output tensor of the node's operation.
"""
ops_dict = {op.get_friendly_name(): op for op in model.get_ops()}
bias_constant = get_node_with_bias_value(get_add_bias_node(node_with_bias, nncf_graph), nncf_graph)
ov_bias_constant = ops_dict[bias_constant.node_name]
return get_const_value(ov_bias_constant)
return get_const_value(ov_bias_constant, dtype)


def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) -> np.ndarray:
def get_weight_value(
node_with_weight: NNCFNode, model: ov.Model, port_id: int, dtype: ov.Type = ov.Type.f32
) -> np.ndarray:
"""
Returns a weight value for the node with weight.

:param node_with_weight: Node with weight.
:param nncf_graph: NNCF graph.
:param model: The model that contains this operation.
:param port_id: The input port ID to get weight input.
:param dtype: Value return type.
:return: The weight value.
"""
const_op_friendly_name = node_with_weight.layer_attributes.constant_attributes[port_id]["name"]
friendly_name_to_op_map = {op.get_friendly_name(): op for op in model.get_ops()}
const_op = friendly_name_to_op_map[const_op_friendly_name]
weight_tensor = get_const_value(const_op)
weight_tensor = get_const_value(const_op, dtype)
return weight_tensor


Expand Down
6 changes: 6 additions & 0 deletions nncf/openvino/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(self, target_point: OVTargetPoint):


class OVOutputInsertionCommand(OVInsertionCommand):
def __init__(self, target_point: OVTargetPoint, output_dtype: ov.Type = ov.Type.f32):
super().__init__(target_point)
self.output_dtype = output_dtype

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
raise NotImplementedError()
Expand All @@ -60,11 +64,13 @@ def __init__(
inplace_op_fn: InplaceInsertionFnType,
fn_output_port_id: int,
last_inplace_node_name: str,
output_dtype: ov.Type = ov.Type.f32,
):
super().__init__(target_point)
self.inplace_op_fn = inplace_op_fn
self.fn_output_port_id = fn_output_port_id
self.last_inplace_node_name = last_inplace_node_name
self.output_dtype = output_dtype

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
Expand Down
8 changes: 4 additions & 4 deletions nncf/openvino/quantization/quantize_ifmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,8 @@ def create_output_insertion_commands_if_node(model: ov.Model, if_node: NNCFNode)
commands = []
name_to_node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
ov_node = name_to_node_mapping[if_node.node_name]
for port_id in range(len(ov_node.inputs())):
commands.append(
OVOutputInsertionCommand(OVTargetPoint(TargetType.PRE_LAYER_OPERATION, if_node.node_name, port_id))
)
for port_id, ov_input in enumerate(ov_node.inputs()):
target_point = OVTargetPoint(TargetType.PRE_LAYER_OPERATION, if_node.node_name, port_id)
ov_input_dtype = ov_input.get_element_type()
commands.append(OVOutputInsertionCommand(target_point, output_dtype=ov_input_dtype))
return commands
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
# limitations under the License.
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

Expand Down Expand Up @@ -162,7 +161,7 @@ def transform_model(
should_add_convert_node = True
break

weight = Tensor(get_const_value(const_node, np.float32 if const_dtype == ov.Type.bf16 else None))
weight = Tensor(get_const_value(const_node, const_dtype))
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
original_shape = weight.shape
compressed_weight = compress_weight(
weight,
Expand Down
3 changes: 2 additions & 1 deletion tests/openvino/native/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def get_dataset_for_test(model):
input_data = {}
for param in model.get_parameters():
input_shape = param.partial_shape.get_max_shape()
input_data[param.get_output_tensor(0).get_any_name()] = rng.uniform(0, 1, input_shape)
tensor = param.get_output_tensor(0)
input_data[tensor.get_any_name()] = rng.uniform(0, 1, input_shape).astype(tensor.get_element_type().to_dtype())

dataset = Dataset([input_data])
return dataset
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
strict digraph {
"0 Parameter_MatMul.0" [id=0, type=Parameter];
"1 Convert_430" [id=1, type=Convert];
"2 MatMul" [id=2, type=MatMul];
"3 Convert_431" [id=3, type=Convert];
"4 Result_MatMul.0" [id=4, type=Result];
"5 MatMul_const" [id=5, type=Constant];
"0 Parameter_MatMul.0" -> "1 Convert_430" [label="[1, 3, 4, 2]", style=solid];
"1 Convert_430" -> "2 MatMul" [label="[1, 3, 4, 2]", style=solid];
"2 MatMul" -> "3 Convert_431" [label="[1, 3, 2, 5]", style=solid];
"3 Convert_431" -> "4 Result_MatMul.0" [label="[1, 3, 2, 5]", style=solid];
"5 MatMul_const" -> "2 MatMul" [label="[1, 3, 4, 5]", style=solid];
}
14 changes: 7 additions & 7 deletions tests/openvino/native/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,21 +282,21 @@ def _create_ov_model(self):


class FPModel(OVReferenceModel):
def __init__(self, const_dtype="FP32", input_dtype="FP32"):
self.const_dtype = np.float32 if const_dtype == "FP32" else np.float16
self.input_dtype = np.float32 if input_dtype == "FP32" else np.float16
def __init__(self, const_dtype: ov.Type = ov.Type.f32, input_dtype: ov.Type = ov.Type.f32):
self.const_dtype = const_dtype
self.input_dtype = input_dtype
super().__init__()

def _create_ov_model(self):
input_shape = [1, 3, 4, 2]
input_1 = opset.parameter(input_shape, name="Input", dtype=self.input_dtype)
data = self._rng.random((1, 3, 4, 5)).astype(self.const_dtype)
data = opset.constant(value=self._rng.random((1, 3, 4, 5)), dtype=self.const_dtype, name="MatMul_const")
if self.const_dtype != self.input_dtype:
data = opset.convert(data, self.input_dtype)
data = opset.convert(data, self.input_dtype.to_string())
matmul = opset.matmul(input_1, data, transpose_a=True, transpose_b=False, name="MatMul")
bias = self._rng.random((1, 3, 1, 1)).astype(self.const_dtype)
bias = opset.constant(value=self._rng.random((1, 3, 1, 1)), dtype=self.const_dtype, name="MatMul_bias")
if self.const_dtype != self.input_dtype:
bias = opset.convert(bias, self.input_dtype)
bias = opset.convert(bias, self.input_dtype.to_string())
add = opset.add(matmul, bias, name="Add")
result = opset.result(add, name="Result_Add")
result.get_output_tensor(0).set_names(set(["Result_Add"]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from pathlib import Path

import numpy as np
import openvino as ov
import pytest
import torch
Expand Down Expand Up @@ -163,8 +162,8 @@ def test_synthetic_models_fq_shapes(model_creator_func, ref_shapes, inplace_stat
assert node["output_high"].shape == ref_shapes[node_name]


@pytest.mark.parametrize("const_dtype", ["FP16", "FP32"])
@pytest.mark.parametrize("input_dtype", ["FP16", "FP32"])
@pytest.mark.parametrize("const_dtype", [ov.Type.f16, ov.Type.f32, ov.Type.bf16])
@pytest.mark.parametrize("input_dtype", [ov.Type.f16, ov.Type.f32, ov.Type.bf16])
def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistics):
model = FPModel(const_dtype, input_dtype)
quantized_model = quantize_model(
Expand All @@ -174,10 +173,10 @@ def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistic
if op.get_type_name() == "FakeQuantize":
inp_node = op.input(0)
fq_input_node = inp_node.get_source_output().get_node()
if fq_input_node.get_element_type() == "Constant":
assert op.get_element_type() == ov.Type(np.float32 if input_dtype == "FP32" else np.float16)
if fq_input_node.get_type_name() == "Constant":
assert op.get_element_type() == const_dtype
elif op.get_type_name() == "Convert":
inp_node = op.input(0)
fq_input_node = inp_node.get_source_output().get_node()
if fq_input_node.get_element_type() == "Constant":
assert op.get_element_type() == ov.Type(np.float32 if const_dtype == "FP32" else np.float16)
if fq_input_node.get_type_name() == "Constant":
assert op.get_element_type() == input_dtype
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def test_compress_weights(model_creator_func, ref_nodes):

fq_nodes = get_nodes_by_type(quantized_model, type_name="FakeQuantize")
assert len(fq_nodes) == len(ref_fqs_names)
for fq_name in fq_nodes:
for fq_node in fq_nodes:
fq_name = fq_node.get_friendly_name()
assert fq_name in ref_fqs_names

for op in quantized_model.get_ops():
Expand Down Expand Up @@ -76,7 +77,8 @@ def test_overflow_fix_applied(model_creator_func, ref_nodes):

fq_nodes = get_nodes_by_type(quantized_model, type_name="FakeQuantize")
assert len(fq_nodes) == len(ref_fqs_names)
for fq_name in fq_nodes:
for fq_node in fq_nodes:
fq_name = fq_node.get_friendly_name()
assert fq_name in ref_fqs_names

for op in quantized_model.get_ops():
Expand Down
Loading
Loading