Add QLinearConcat for DML EP#16971
Conversation
|
|
||
| // broadcast y_scale and y_zero_point to output shape | ||
| m_inputTensorDescs[OnnxInputIndex::yScale] = TensorDesc( | ||
| kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::yScale).tensorDataType, |
| ); | ||
|
|
||
| m_inputTensorDescs[OnnxInputIndex::yZeroPoint] = TensorDesc( | ||
| kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::yZeroPoint).tensorDataType, |
|
|
||
| // broadcast x_scale and x_zero_point to shape of corresponding x | ||
| m_inputTensorDescs[tuple_start + 1] = TensorDesc( | ||
| kernelCreationContext.GetInputEdgeDescription(tuple_start + 1).tensorDataType, |
| ); | ||
|
|
||
| m_inputTensorDescs[tuple_start + 2] = TensorDesc( | ||
| kernelCreationContext.GetInputEdgeDescription(tuple_start + 2).tensorDataType, |
| // Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced. | ||
| // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case). | ||
| uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex); | ||
| uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount); |
There was a problem hiding this comment.
Should we assign 0 as a default value to firstInputIndex parameter and remove the 2nd overloaded method? #Resolved
| { | ||
| // QLinearConcat = Dequantize + Join + Quantize | ||
| // This kernel is the first usage of graph based implementation | ||
| class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper |
There was a problem hiding this comment.
[nit] Should we remove this comment? #Resolved
| std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(input_count); | ||
| std::vector<DML_OPERATOR_DESC> dmlOpDesc(input_count); | ||
| std::vector<const DML_OPERATOR_DESC*> opDescs = {}; | ||
| for (uint32_t input_index = 0; input_index < input_count; ++input_index) |
There was a problem hiding this comment.
[nit] Is there any specific reason we have used = {} to initialize this particular std::vector? #Closed
| static const int sc_sinceVer_QuickGelu = 1; | ||
| static const int sc_sinceVer_GroupNorm = 1; | ||
| static const int sc_sinceVer_DynamicQuantizeMatMul = 1; | ||
| static const int sc_sinceVer_QLinearConcat = 1; |
Check warning
Code scanning / PREfast
The const variable 'OperatorHelper::MsftOperatorSet1::sc_sinceVer_QLinearConcat' can be computed at compile-time. Consider using constexpr (con.5).
| constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearConcat= { | ||
| SupportedTensorDataTypes::Float32, | ||
| SupportedTensorDataTypes::Ints8Bit, | ||
| SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32, |
There was a problem hiding this comment.
contribop mentions TF supports any float tensor type, should we also consider supporting fp16?
There was a problem hiding this comment.
I think only tensor(float) is specified, tensor(float16) is for fp16
Type Constraints
T8 : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit signed and unsigned tensors.
TF : tensor(float)
Constrain scale types to any float tensor type.
TV : tensor(uint8), tensor(int8), tensor(float)
Sequence of (Tensor, Scale, ZeroPoint) tuples. The type is sequence of (T8, TF, T8).
| // This order matches the ONNX schema. | ||
| enum OnnxInputIndex | ||
| { | ||
| yScale, |
| QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) | ||
| { | ||
|
|
||
| DmlOperator::Initialize(kernelCreationContext); |
There was a problem hiding this comment.
[nit] extra blank line #Closed
| auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0); | ||
|
|
||
| // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)} | ||
| uint32_t input_def_count = kernelCreationContext.GetInputCount(); |
| std::vector<const DML_OPERATOR_DESC*> opDescs; | ||
| for (uint32_t input_index = 0; input_index < input_count; ++input_index) | ||
| { | ||
| auto tuple_start = 2 + input_index * 3; |
| std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(input_count); | ||
| std::vector<DML_OPERATOR_DESC> dmlOpDesc(input_count); | ||
| std::vector<const DML_OPERATOR_DESC*> opDescs; | ||
| for (uint32_t input_index = 0; input_index < input_count; ++input_index) |
| // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)} | ||
| uint32_t input_def_count = kernelCreationContext.GetInputCount(); | ||
| ML_CHECK_VALID_ARGUMENT(input_def_count >= 5 && (input_def_count - 2) % 3 == 0, | ||
| "Each input must be (tensor, scale, zero_point) tuple!"); |
There was a problem hiding this comment.
[](http://example.com/codeflow?start=8&length=6)
[nit] 6->4 space indent.
Alternately, consider splitting the long line more readably.
ML_CHECK_VALID_ARGUMENT(
input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
"Each input must be (tensor, scale, zero_point) tuple!"
);
Or better yet, splitting it into two separate conditions, which would make it much clearer the specific error to the user. Also, then the lines are not so long and don't need to wrap:
ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!");
Which is like what you do below, two separate ones rather than &&ing them:
ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");
``` #Resolved
| TensorAxis::W, | ||
| TensorAxis::RightAligned, | ||
| NchwDimensionCount, // minDimensionCount | ||
| 0 // guaranteedBaseOffsetAlignment) |
There was a problem hiding this comment.
[nit] Indent also inconsistent from all the other TensorDesc calls (8 instead of 4)
| joinDesc.OutputTensor = &namedJoinOutputTensorDesc; | ||
| joinDesc.Axis = dmlAxis; | ||
|
|
||
| const DML_OPERATOR_DESC opJoinDesc{DML_OPERATOR_JOIN, &joinDesc}; |
There was a problem hiding this comment.
| const DML_OPERATOR_DESC opJoinDesc{DML_OPERATOR_JOIN, &joinDesc}; | |
| const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc}; |
[nit] Consistent assignment style with nearby MLOperatorGraphDesc operatorGraphDesc = {};. #Closed
fdwr
left a comment
There was a problem hiding this comment.
Nits, but otherwise looks good Xiang.
| ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point"); | ||
|
|
||
| // broadcast x_scale and x_zero_point to shape of corresponding x | ||
| m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc( |
Check warning
Code scanning / PREfast
Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).
| 0 // guaranteedBaseOffsetAlignment | ||
| ); | ||
|
|
||
| m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc( |
Check warning
Code scanning / PREfast
Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).
| namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc(); | ||
|
|
||
| dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex]; | ||
| dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; |
Check warning
Code scanning / PREfast
Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).
|
|
||
| dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex]; | ||
| dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1]; | ||
| dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2]; |
Check warning
Code scanning / PREfast
Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).
| @@ -65,7 +65,7 @@ namespace Dml | |||
|
|
|||
| // Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced. | |||
| // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case). | |||
### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. ``` [#16971](#16971) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Xiang Zhang <xianz@microsoft.com>
### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. ``` [#16971](#16971) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Xiang Zhang <xianz@microsoft.com>
### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. ``` [microsoft#16971](microsoft#16971) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Xiang Zhang <xianz@microsoft.com>
### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. ``` [#16971](microsoft/onnxruntime#16971) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Xiang Zhang <xianz@microsoft.com>
### Description [Cherry Pick Reviewed] ``` [ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms) [ RUN ] QLinearConcatS8.InputOne_Dynamic [ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms) [ RUN ] QLinearConcatS8.InputOne_Const [ OK ] QLinearConcatS8.InputOne_Const (255 ms) [----------] 11 tests from QLinearConcatS8 (3385 ms total) [----------] Global test environment tear-down [==========] 21 tests from 3 test suites ran. (9355 ms total) [ PASSED ] 21 tests. ``` [#16971](microsoft/onnxruntime#16971) ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Co-authored-by: Xiang Zhang <xianz@microsoft.com>
[ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms)
[ RUN ] QLinearConcatS8.InputOne_Dynamic
[ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms)
[ RUN ] QLinearConcatS8.InputOne_Const
[ OK ] QLinearConcatS8.InputOne_Const (255 ms)
[----------] 11 tests from QLinearConcatS8 (3385 ms total)
[----------] Global test environment tear-down
[==========] 21 tests from 3 test suites ran. (9355 ms total)
[ PASSED ] 21 tests.