[PTQ][OV] BF16 support #2307

Merged — 38 commits, Jul 12, 2024

Commits (the diff below shows changes from the first 28 commits)
26b6c73
Added BF16 & ov.Tensor support
KodiaqQ Dec 7, 2023
5f02d99
Add FQ params dtype conversion
KodiaqQ Dec 7, 2023
2bcfca9
Update tests for BF16
KodiaqQ Dec 7, 2023
65ce6dc
Fix tests
KodiaqQ Dec 7, 2023
4f91018
Fix bf16 tests
KodiaqQ Dec 7, 2023
69e2297
Added const with types
KodiaqQ Dec 8, 2023
de85c1d
Apply comment
KodiaqQ Dec 12, 2023
058a6e1
Disable tests
KodiaqQ Dec 12, 2023
12447d9
Added PrePostProcessor for FP32 outputs
KodiaqQ Dec 12, 2023
627ff67
Remove BF16 from testing
KodiaqQ Dec 13, 2023
6f011d9
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jan 16, 2024
f3c8ed8
Adjust to develop
KodiaqQ Jan 16, 2024
c97c616
Adjust BF16 suport in tests
KodiaqQ Jan 16, 2024
ccd0b91
Added opset.constant with shared_memory option
KodiaqQ Jan 18, 2024
7f670a0
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jan 22, 2024
91bb312
Added cast to fp32
KodiaqQ Jan 23, 2024
ff8f0ca
Merge openvinotoolkit/develop into nm/bf16_support
KodiaqQ Apr 17, 2024
ca6ff73
Removed PrePostProcessor usage
KodiaqQ Apr 17, 2024
11f4929
Adapt F-/BC algos to BF16
KodiaqQ Apr 18, 2024
c67ee84
Change get_const_value data output
KodiaqQ Apr 19, 2024
833f7c9
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ May 27, 2024
097e938
Change get_const_value behavior
KodiaqQ May 28, 2024
106ff8a
Update implementation
KodiaqQ Jun 17, 2024
d2e5556
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 17, 2024
5636d1a
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 18, 2024
d2df92e
Fix pipeline tests
KodiaqQ Jun 18, 2024
4f9cd37
Tensor names set update
KodiaqQ Jun 18, 2024
87ea12e
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 19, 2024
1cdc747
Extend OutputInsertionCommand
KodiaqQ Jun 19, 2024
7dfd1c1
Apply comments
KodiaqQ Jun 19, 2024
9471ac8
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jun 20, 2024
8ffe8ae
Limit .get_data usage
KodiaqQ Jul 10, 2024
5405bc9
Merge remote-tracking branch 'openvinotoolkit/develop' into nm/bf16_s…
KodiaqQ Jul 10, 2024
5f4062b
Limit shared_memory usage
KodiaqQ Jul 10, 2024
cfa7ce9
Fix WC
KodiaqQ Jul 10, 2024
f2add1f
Fix test_get_const_value
KodiaqQ Jul 10, 2024
5725636
Apply comment
KodiaqQ Jul 11, 2024
3e531c4
Apply minor comments
KodiaqQ Jul 12, 2024
8 changes: 4 additions & 4 deletions nncf/openvino/engine.py
@@ -33,8 +33,8 @@ def __init__(self, compiled_model: ov.CompiledModel, stateful: bool):
        self.reset_state = stateful and hasattr(self.infer_request, "reset_state")

    def infer(
-        self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]]
-    ) -> Dict[str, np.ndarray]:
+        self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray], ov.Tensor]
+    ) -> Union[Dict[str, np.ndarray], ov.Tensor]:
        """
        Runs model on the provided input via OpenVINO Runtime.
        Returns the dictionary of model outputs by node names.
@@ -73,8 +73,8 @@ def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.C
        self.engine = OVCompiledModelEngine(compiled_model, stateful)

    def infer(
-        self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]]
-    ) -> Dict[str, np.ndarray]:
+        self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray], ov.Tensor]
+    ) -> Union[Dict[str, np.ndarray], ov.Tensor]:
        """
        Runs model on the provided input via OpenVINO Runtime.
        Returns the dictionary of model outputs by node names.
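The widened `infer` signature lets `ov.Tensor` objects flow through the engine unchanged, which matters for BF16 because it has no native NumPy dtype. A minimal sketch of the calling pattern, assuming the OpenVINO Python API; the toy model below is illustrative, not from this PR:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

# Toy model; infer() on the compiled model accepts ov.Tensor inputs directly.
param = opset.parameter([1, 4], dtype=ov.Type.f32, name="x")
model = ov.Model([opset.result(opset.relu(param))], [param])
compiled = ov.Core().compile_model(model, "CPU")

inp = ov.Tensor(np.ones((1, 4), dtype=np.float32))
outputs = compiled.create_infer_request().infer([inp])
```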
116 changes: 54 additions & 62 deletions nncf/openvino/graph/model_transformer.py
@@ -67,11 +67,6 @@ def __init__(self, model: TModel, inplace: bool = False):
            (OVExtractIfBodyCommand, self._apply_extract_if_body_transformation),
        ]

-    @staticmethod
-    def _convert_to_fp16(data):
-        clip_data = np.clip(data, np.finfo(np.float16).min, np.finfo(np.float16).max)
-        return clip_data.astype(np.float16)
-
    @staticmethod
    def _get_name_to_node_mapping(model: ov.Model) -> Dict[str, ov.Node]:
        """
@@ -102,16 +97,16 @@ def _get_activation_node_names(model: ov.Model) -> List[str]:
        return list(activation_nodes)

    @staticmethod
-    def _update_tensor_name(tensors: List[DescriptorTensor], name: str) -> None:
+    def _update_tensor_name(tensors: List[DescriptorTensor], names: List[str]) -> None:

[Review] andrey-churkin (Collaborator) suggested renaming the helper to match its new signature:
    def _update_tensor_names(tensors: List[DescriptorTensor], names: List[str]) -> None:
[Review] KodiaqQ (Author): Done.

        """
        Updates tensors names in-place.

        :param model: List of the tensors.
-        :param name: New name for tensor.
+        :param names: List of the new names for tensors.
        """
        for tensor in tensors:
            current_names = tensor.get_names()
-            current_names.add(name)
+            current_names.update(names)
            tensor.set_names(current_names)

    def transform(self, transformation_layout: TransformationLayout) -> ov.Model:
@@ -195,16 +190,22 @@ def _insert_outputs(model: ov.Model, outputs: List[Tuple[ov.Output, int, Callabl
        :param outputs: list of tuples with ov.Output & port_id.
        :return: Model with new outputs.
        """
+        outputs_type = ov.Type.f32
        results = model.get_results()
        params = model.get_parameters()

        extra_model_outputs = []
        for output, port_id in outputs:
-            output_name = output.get_node().get_friendly_name()
+            # TODO: (KodiaqQ) check out the models with the Split
+            node_output = output
+            output_name = node_output.get_node().get_friendly_name()
            result_name = get_result_node_name(output_name, port_id)
-            result = opset.result(output, name=result_name)
-            OVModelTransformer._update_tensor_name([result.get_output_tensor(0)], result_name)

+            if node_output.get_element_type() != outputs_type:
+                node_output = opset.convert(output, destination_type=outputs_type)
[Review] KodiaqQ (Author): This affects outputs for the then/else bodies of the If model. The point is that we should collect all statistics in fp32 precision, but some outputs may be boolean rather than numeric. What should we do in this case, @kshpv, @alexsu52?

[Review] Contributor: I would suggest adding flexibility: the output-insertion transformation command should have a dtype parameter that specifies the return type, with FP32 as the default.

[Review] KodiaqQ (Author): Done.

+            result = opset.result(node_output, name=result_name)
+            result_tensor_names = [result_name] + list(output.get_names())
+            OVModelTransformer._update_tensor_name([result.get_output_tensor(0)], result_tensor_names)
            extra_model_outputs.append(result)

        model_with_outputs = ov.Model(
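To illustrate the hunk above: statistic outputs are collected in fp32, so any non-fp32 activation gets a Convert inserted before its Result (the merged revision later makes the dtype configurable, per the thread above). A sketch under those assumptions, using a hypothetical bf16 toy model:

```python
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

param = opset.parameter([1, 4], dtype=ov.Type.bf16, name="x")
relu = opset.relu(param)
model = ov.Model([opset.result(relu)], [param])

# Mirror the logic above: convert a non-fp32 activation before wrapping it in Result.
node_output = relu.output(0)
if node_output.get_element_type() != ov.Type.f32:
    node_output = opset.convert(node_output, destination_type=ov.Type.f32).output(0)
extended = ov.Model(model.get_results() + [opset.result(node_output)], model.get_parameters())
```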
@@ -284,15 +285,15 @@ def _create_fake_quantize(
        op_output: ov.Output,
        fake_quantize_params: FakeQuantizeParameters,
        fake_quantize_name: str,
-        convert_to_fp16: bool,
+        data_type: ov.Type,
    ) -> ov.Node:
        """
        Creates FakeQuantize node.

        :param op_output: Output of the previous node.
        :param fake_quantize_params: FakeQuantizeParameters instance.
        :param fake_quantize_name: New layer name.
-        :param convert_to_fp16: Whether convert parameters to FP16 or not.
+        :param data_type: ov.Type instance for data.
        :return: ov.Node instance.
        """

@@ -301,24 +302,18 @@
        output_low = fake_quantize_params.output_low.data
        output_high = fake_quantize_params.output_high.data
        levels = fake_quantize_params.levels
-        dtype = ov.Type.f32
-
-        if convert_to_fp16:
-            input_low = OVModelTransformer._convert_to_fp16(input_low)
-            input_high = OVModelTransformer._convert_to_fp16(input_high)
-            output_low = OVModelTransformer._convert_to_fp16(output_low)
-            output_high = OVModelTransformer._convert_to_fp16(output_high)
-            dtype = ov.Type.f16
-
-        input_low = OVModelTransformer._create_constant(input_low, dtype=dtype, name=f"{fake_quantize_name}/input_low")
+        input_low = OVModelTransformer._create_constant(
+            input_low, dtype=data_type, name=f"{fake_quantize_name}/input_low"
+        )
        input_high = OVModelTransformer._create_constant(
-            input_high, dtype=dtype, name=f"{fake_quantize_name}/input_high"
+            input_high, dtype=data_type, name=f"{fake_quantize_name}/input_high"
        )
        output_low = OVModelTransformer._create_constant(
-            output_low, dtype=dtype, name=f"{fake_quantize_name}/output_low"
+            output_low, dtype=data_type, name=f"{fake_quantize_name}/output_low"
        )
        output_high = OVModelTransformer._create_constant(
-            output_high, dtype=dtype, name=f"{fake_quantize_name}/output_high"
+            output_high, dtype=data_type, name=f"{fake_quantize_name}/output_high"
        )

        return opset.fake_quantize(
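With `_convert_to_fp16` gone, FakeQuantize parameters are created directly in the element type of the quantized tensor. A minimal sketch; the clamp values and names are hypothetical:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

data = opset.parameter([1, 3, 8, 8], dtype=ov.Type.bf16, name="x")
dtype = data.get_output_element_type(0)  # bf16 here; f16/f32 for other models

def _const(value, name):
    # The constant is materialized in the target element type; no clip-and-cast step.
    return opset.constant(np.array(value, dtype=np.float32), dtype=dtype, name=name)

fq = opset.fake_quantize(
    data,
    _const(0.0, "fq/input_low"),
    _const(6.0, "fq/input_high"),
    _const(0.0, "fq/output_low"),
    _const(6.0, "fq/output_high"),
    levels=256,
)
```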
@@ -330,30 +325,24 @@ def _create_fake_convert(
        op_output: ov.Output,
        fake_convert_params: FakeConvertParameters,
        fake_convert_name: str,
-        convert_to_fp16: bool,
+        data_type: ov.Type,
    ) -> ov.Node:
        """
        Creates FakeConvert node.

        :param op_output: Output of the previous node.
        :param fake_convert_params: FakeConvertParameters instance.
        :param fake_convert_name: New layer name.
-        :param convert_to_fp16: Whether convert parameters to FP16 or not.
+        :param data_type: ov.Type instance for data.
        :return: ov.Node instance.
        """

        scale = fake_convert_params.scale.data
        shift = fake_convert_params.shift.data
-        dtype = ov.Type.f32
-
-        if convert_to_fp16:
-            scale = OVModelTransformer._convert_to_fp16(scale)
-            shift = OVModelTransformer._convert_to_fp16(shift)
-            dtype = ov.Type.f16
-
        destination_type = fake_convert_params.destination_type.value
-        scale = OVModelTransformer._create_constant(scale, dtype=dtype, name=f"{fake_convert_name}/scale")
-        shift = OVModelTransformer._create_constant(shift, dtype=dtype, name=f"{fake_convert_name}/shift")
+        scale = OVModelTransformer._create_constant(scale, dtype=data_type, name=f"{fake_convert_name}/scale")
+        shift = OVModelTransformer._create_constant(shift, dtype=data_type, name=f"{fake_convert_name}/shift")

        return opset.fake_convert(
            data=op_output,
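The same dtype propagation applies to FakeConvert. A sketch, assuming the opset13 `fake_convert` factory takes `data`, `scale`, `shift`, and a string `destination_type` as in the hunk above:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

data = opset.parameter([1, 8], dtype=ov.Type.f16, name="x")
dtype = data.get_output_element_type(0)

# Scale/shift constants follow the input element type instead of an fp16 flag.
scale = opset.constant(np.array([1.0], dtype=np.float32), dtype=dtype, name="fc/scale")
shift = opset.constant(np.array([0.0], dtype=np.float32), dtype=dtype, name="fc/shift")
fc = opset.fake_convert(data=data, scale=scale, shift=shift, destination_type="f8e4m3")
```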
@@ -383,7 +372,6 @@ def _insert_fake_quantize_op(
        inp_node = target_node.input(port_id)
        input_node_output = inp_node.get_source_output()
        data_type = inp_node.get_element_type()
-        convert_to_fp16 = data_type == ov.Type(np.float16)
        name = "fq_weights" if transform_type == TargetType.OPERATION_WITH_WEIGHTS else "fq_input"
        fq_name = f"{node_name}/{name}_{port_id}"
@@ -398,20 +386,19 @@
                op_output=input_node_output,
                fake_quantize_params=fq_params,
                fake_quantize_name=fq_name,
-                convert_to_fp16=convert_to_fp16,
+                data_type=data_type,
            )
            inp_node.replace_source_output(fq.output(0))
        elif transform_type == TargetType.POST_LAYER_OPERATION:
            output = target_node.output(port_id)
            data_type = output.get_element_type()
-            convert_to_fp16 = data_type == ov.Type(np.float16)
            target_inputs = output.get_target_inputs()
            fq_name = f"{node_name}/fq_output_{port_id}"
            fq = OVModelTransformer._create_fake_quantize(
                op_output=output,
                fake_quantize_params=fq_params,
                fake_quantize_name=fq_name,
-                convert_to_fp16=convert_to_fp16,
+                data_type=data_type,
            )
            for inp_node in target_inputs:
                inp_node.replace_source_output(fq.output(0))
@@ -447,25 +434,25 @@
                if out.get_node().get_type_name() == "FakeConvert":
                    fc = out.get_node()
            if fc is None:
-                convert_to_fp16 = inp_node.get_element_type() == ov.Type(np.float16)
+                data_type = inp_node.get_element_type()
                fc_name = f"{node_name}/fc_{name}_{port_id}"
                fc = OVModelTransformer._create_fake_convert(
                    op_output=input_node_output,
                    fake_convert_params=fc_params,
                    fake_convert_name=fc_name,
-                    convert_to_fp16=convert_to_fp16,
+                    data_type=data_type,
                )
            inp_node.replace_source_output(fc.output(0))
        elif transform_type == TargetType.POST_LAYER_OPERATION:
            output = target_node.output(port_id)
-            convert_to_fp16 = output.get_element_type() == ov.Type(np.float16)
+            data_type = output.get_element_type()
            target_inputs = output.get_target_inputs()
            fc_name = f"{node_name}/fc_output_{port_id}"
            fc = OVModelTransformer._create_fake_convert(
                op_output=output,
                fake_convert_params=fc_params,
                fake_convert_name=fc_name,
-                convert_to_fp16=convert_to_fp16,
+                data_type=data_type,
            )
            for inp_node in target_inputs:
                inp_node.replace_source_output(fc.output(0))
@@ -517,13 +504,11 @@ def _set_const_value(node_with_const: ov.Node, const_port_id: int, const_value:
        if const_node is None:
            raise nncf.InternalError("Constant node was expected but could not find it.")

-        const_shape = const_node.data.shape
-        const_dtype = const_node.data.dtype
-        const_value = np.reshape(const_value, const_shape).astype(const_dtype)
+        const_value = np.reshape(const_value, const_node.data.shape)

-        # TODO(andrey-churkin): Replace on opset13.constant() in 2023.3 release
-        new_const_node = ov.op.Constant(const_value, shared_memory=True)
-        new_const_node.set_friendly_name(const_node.get_friendly_name())
+        new_const_node = opset.constant(
+            const_value, dtype=const_port.get_element_type(), name=const_node.get_friendly_name(), shared_memory=True
+        )
        const_port.replace_source_output(new_const_node.output(0))

    @staticmethod
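The rewritten `_set_const_value` hunk above keeps the consumer port's element type and shares the NumPy buffer where possible. A sketch of the replacement pattern with a toy constant and consumer; note that `shared_memory` only avoids a copy when the array dtype already matches the target element type:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

old_const = opset.constant(np.zeros((4,), dtype=np.float32), name="bias")
consumer = opset.relu(old_const)
const_port = consumer.input(0)

new_value = np.ones((4,), dtype=np.float32)
new_const = opset.constant(
    new_value, dtype=const_port.get_element_type(), name="bias", shared_memory=True
)
const_port.replace_source_output(new_const.output(0))
```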
@@ -553,6 +538,7 @@ def _apply_model_extraction_transformation(
        :param transformation: Model extraction transformation.
        :return: Extracted sub-model.
        """
+        outputs_type = ov.Type.f32
        transformation = transformations[-1]
        name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)

@@ -564,25 +550,34 @@
                continue

            input_port = input_node.input(input_port_id)
+            input_type = input_port.get_element_type()
            input_node_output = input_port.get_source_output()
            parameter_name = get_parameter_node_name(input_name, input_port_id)

            new_param = opset.parameter(
                shape=input_node_output.partial_shape,
-                dtype=input_node_output.get_element_type(),
+                dtype=outputs_type,
                name=parameter_name,
            )
-            input_port.replace_source_output(new_param.output(0))
+            new_input = new_param.output(0)
+
+            if input_type != outputs_type:
+                new_input = opset.convert(new_param, destination_type=input_type).output(0)
+
+            input_port.replace_source_output(new_input)
            new_param_tensors = [o.get_tensor() for o in new_param.outputs()]
-            OVModelTransformer._update_tensor_name(new_param_tensors, parameter_name)
+            OVModelTransformer._update_tensor_name(new_param_tensors, [parameter_name])
            params.append(new_param)

        for output_name, output_port_id in transformation.output_ids:
            output_node = name_to_node_mapping[output_name]

            output_port = output_node.output(output_port_id)
            result_name = get_result_node_name(output_name, output_port_id)
-            new_result = opset.result(output_port, name=result_name)
-            OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], result_name)
+            if output_node.get_element_type() != outputs_type:
+                output_node = opset.convert(output_node, destination_type=outputs_type)
+            new_result = opset.result(output_node, name=result_name)
+            result_tensor_names = [result_name] + list(output_node.output(0).get_names())
+            OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], result_tensor_names)
            results.append(new_result)

        if not results:
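In the extraction path, sub-model Parameters are now created in fp32 and a Convert restores the original activation type inside the sub-model. A toy sketch of that boundary handling, with a hypothetical bf16 input and parameter name:

```python
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

# Hypothetical bf16 boundary: the node whose input becomes the sub-model input.
orig_param = opset.parameter([1, 4], dtype=ov.Type.bf16, name="x")
relu = opset.relu(orig_param)

input_port = relu.input(0)
input_type = input_port.get_element_type()

new_param = opset.parameter(input_port.get_partial_shape(), dtype=ov.Type.f32, name="Parameter_x")
new_input = new_param.output(0)
if input_type != ov.Type.f32:
    # Cast back to the original type so the sub-model stays numerically identical.
    new_input = opset.convert(new_param, destination_type=input_type).output(0)
input_port.replace_source_output(new_input)

submodel = ov.Model([opset.result(relu)], [new_param])
```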
@@ -624,7 +619,7 @@ def _apply_stateless_model_extraction_transformation(
            for input_port in output_port.get_target_inputs():
                input_port.replace_source_output(new_param.output(0))
            new_param_tensors = [o.get_tensor() for o in new_param.outputs()]
-            OVModelTransformer._update_tensor_name(new_param_tensors, parameter_name)
+            OVModelTransformer._update_tensor_name(new_param_tensors, [parameter_name])
            params.append(new_param)

        for output_name, output_port_id in transformation.output_ids:
@@ -633,7 +628,7 @@
            output_port = output_node.output(output_port_id)
            result_name = get_result_node_name(output_name, output_port_id)
            new_result = opset.result(output_port, name=result_name)
-            OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], result_name)
+            OVModelTransformer._update_tensor_name([new_result.get_output_tensor(0)], [result_name])
            results.append(new_result)

        if not results:
@@ -759,10 +754,7 @@ def _apply_multiply_insertion_transformations(
                if target_node.get_friendly_name() in transformation.destination_node_names:
                    destination_ports.append(target_input_port)

-            scale_dtype = ov.Type(np.float32)
-            fp16_dtype = ov.Type(np.float16)
-            if all(p.get_element_type() == fp16_dtype for p in destination_ports):
-                scale_dtype = fp16_dtype
+            scale_dtype = node_output_port.get_element_type()

            scale_constant = OVModelTransformer._create_constant(
                transformation.scale_value, dtype=scale_dtype, name=f"{transformation.multiply_node_name}/scale"
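The multiply-insertion hunk drops the all-ports-fp16 heuristic: the scale constant simply inherits the element type of the producing port. A compact sketch with hypothetical node names:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

param = opset.parameter([1, 4], dtype=ov.Type.bf16, name="x")
relu = opset.relu(param)

scale_dtype = relu.output(0).get_element_type()  # bf16 follows the producer
scale = opset.constant(np.array(0.5, dtype=np.float32), dtype=scale_dtype, name="relu/scale")
mul = opset.multiply(relu, scale)
```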
22 changes: 13 additions & 9 deletions nncf/openvino/graph/node_utils.py
@@ -101,48 +101,52 @@ def cnt_if_op(model: ov.Model, cnt: int) -> int:
    return cnt_if_op(model, 0)


-def get_const_value(const_node: ov.Node, dtype: Optional[np.dtype] = None) -> np.ndarray:
+def get_const_value(const_node: ov.Node, dtype: ov.Type = ov.Type.f32) -> np.ndarray:
    """
    Returns the constant tensor for the node.

    :param const_node: OpenVINO node.
-    :param dtype: Destination type.
+    :param dtype: Value return type.
    :return: The constant value.
    """
-    if dtype is None:
-        return const_node.data
-    return const_node.get_data(dtype=dtype)
+    return const_node.get_data(dtype=dtype.to_dtype())
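`get_const_value` now always returns an upcast copy (fp32 by default), since BF16 has no NumPy counterpart to view directly. A small sketch of the underlying call:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

const = opset.constant(np.ones((2, 2), dtype=np.float32), dtype=ov.Type.bf16, name="w")
# ov.Type.f32.to_dtype() -> np.float32; get_data() returns a converted copy.
value = const.get_data(dtype=ov.Type.f32.to_dtype())
assert value.dtype == np.float32
```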


-def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
+def get_bias_value(
+    node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model, dtype: ov.Type = ov.Type.f32
+) -> np.ndarray:
    """
    Returns the bias tensor for the biased node.

    :param node_with_bias: The node that corresponds to the operation with bias.
    :param nncf_graph: NNCFGraph instance.
    :param model: The model that contains this operation.
+    :param dtype: Value return type.
    :return: The bias value that is applied to the output tensor of the node's operation.
    """
    ops_dict = {op.get_friendly_name(): op for op in model.get_ops()}
    bias_constant = get_node_with_bias_value(get_add_bias_node(node_with_bias, nncf_graph), nncf_graph)
    ov_bias_constant = ops_dict[bias_constant.node_name]
-    return get_const_value(ov_bias_constant)
+    return get_const_value(ov_bias_constant, dtype)


-def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) -> np.ndarray:
+def get_weight_value(
+    node_with_weight: NNCFNode, model: ov.Model, port_id: int, dtype: ov.Type = ov.Type.f32
+) -> np.ndarray:
    """
    Returns a weight value for the node with weight.

    :param node_with_weight: Node with weight.
    :param nncf_graph: NNCF graph.
    :param model: The model that contains this operation.
    :param port_id: The input port ID to get weight input.
+    :param dtype: Value return type.
    :return: The weight value.
    """
    const_op_friendly_name = node_with_weight.layer_attributes.constant_attributes[port_id]["name"]
    friendly_name_to_op_map = {op.get_friendly_name(): op for op in model.get_ops()}
    const_op = friendly_name_to_op_map[const_op_friendly_name]
-    weight_tensor = get_const_value(const_op)
+    weight_tensor = get_const_value(const_op, dtype)
    return weight_tensor
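`get_weight_value` resolves the constant by friendly name and now forwards the requested return type. A self-contained sketch of the same lookup pattern on a toy model; it assumes constant nodes returned by `get_ops()` expose `get_data`, as the helpers above rely on:

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset13 as opset

param = opset.parameter([1, 4], dtype=ov.Type.bf16, name="x")
weight = opset.constant(np.ones((4, 4), dtype=np.float32), dtype=ov.Type.bf16, name="W")
matmul = opset.matmul(param, weight, transpose_a=False, transpose_b=False)
model = ov.Model([opset.result(matmul)], [param])

# Friendly-name -> op mapping, as in get_weight_value above.
ops = {op.get_friendly_name(): op for op in model.get_ops()}
w_fp32 = ops["W"].get_data(dtype=np.float32)
```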

