Add QLinearConcat for DML EP #16971

Merged
zhangxiang1993 merged 6 commits into DmlPrototype from user/xianz/QLinearConcat
Aug 17, 2023

Conversation

@zhangxiang1993
Contributor

[ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms)
[ RUN ] QLinearConcatS8.InputOne_Dynamic
[ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms)
[ RUN ] QLinearConcatS8.InputOne_Const
[ OK ] QLinearConcatS8.InputOne_Const (255 ms)
[----------] 11 tests from QLinearConcatS8 (3385 ms total)

[----------] Global test environment tear-down
[==========] 21 tests from 3 test suites ran. (9355 ms total)
[ PASSED ] 21 tests.

@zhangxiang1993 zhangxiang1993 requested review from AnaghaRaoAMD, PatriceVignola and fdwr and removed request for PatriceVignola August 2, 2023 15:58

// broadcast y_scale and y_zero_point to output shape
m_inputTensorDescs[OnnxInputIndex::yScale] = TensorDesc(
kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::yScale).tensorDataType,
Contributor Author

@zhangxiang1993 zhangxiang1993 Aug 2, 2023

kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::yScale).tensorDataType

yScaleDataType #Resolved

);

m_inputTensorDescs[OnnxInputIndex::yZeroPoint] = TensorDesc(
kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::yZeroPoint).tensorDataType,
Contributor Author

@zhangxiang1993 zhangxiang1993 Aug 2, 2023

kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::yZeroPoint).tensorDataType

yZeroPointDataType #Closed


// broadcast x_scale and x_zero_point to shape of corresponding x
m_inputTensorDescs[tuple_start + 1] = TensorDesc(
kernelCreationContext.GetInputEdgeDescription(tuple_start + 1).tensorDataType,
Contributor Author

@zhangxiang1993 zhangxiang1993 Aug 2, 2023

kernelCreationContext.GetInputEdgeDescription(tuple_start + 1).tensorDataType

xScaleDataType #Closed

);

m_inputTensorDescs[tuple_start + 2] = TensorDesc(
kernelCreationContext.GetInputEdgeDescription(tuple_start + 2).tensorDataType,
Contributor Author

@zhangxiang1993 zhangxiang1993 Aug 2, 2023

kernelCreationContext.GetInputEdgeDescription(tuple_start + 2).tensorDataType

xZeroPointDataType #Resolved

// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
// Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex);
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount);
Contributor

@sumitsays sumitsays Aug 2, 2023

Should we assign 0 as a default value to firstInputIndex parameter and remove the 2nd overloaded method? #Resolved
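
One way to read this suggestion is that a single declaration with a default argument covers both call patterns. The sketch below is a simplified, hypothetical stand-in: `AdjustAxisSketch` and the `inputRanks` vector are not from the PR (the real function reads tensor ranks from the kernel creation context), and it assumes right-aligned coercion, where the DML axis is the ONNX axis shifted by the number of leading dimensions that were added.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified stand-in for GetDmlAdjustedAxis.
// inputRanks[i] is the ONNX rank of input i (assumption: stands in for the
// information normally read from MLOperatorKernelCreationContext).
// The default argument replaces the second overload: existing callers that
// pass three arguments keep measuring the axis against input 0.
uint32_t AdjustAxisSketch(
    int32_t onnxAxis,
    const std::vector<uint32_t>& inputRanks,
    uint32_t dmlDimCount,
    uint32_t firstInputIndex = 0)
{
    const uint32_t onnxDimCount = inputRanks.at(firstInputIndex);
    if (dmlDimCount < onnxDimCount)
    {
        throw std::invalid_argument("DML dimension count is smaller than the ONNX rank.");
    }

    // Negative ONNX axes count from the last dimension.
    const int32_t rank = static_cast<int32_t>(onnxDimCount);
    if (onnxAxis < -rank || onnxAxis >= rank)
    {
        throw std::out_of_range("Axis is outside the tensor rank.");
    }
    const uint32_t positiveAxis = static_cast<uint32_t>(onnxAxis < 0 ? onnxAxis + rank : onnxAxis);

    // Assumption: sizes are right-aligned, so the axis shifts by the padding.
    return positiveAxis + (dmlDimCount - onnxDimCount);
}
```

For QLinearConcat the first data tensor is not input 0 (the first two inputs are `y_scale` and `y_zero_point`), which is why a caller would pass a non-zero `firstInputIndex`.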

{
// QLinearConcat = Dequantize + Join + Quantize
// This kernel is the first usage of graph based implementation
class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
Contributor

@sumitsays sumitsays Aug 2, 2023

[nit] Should we remove this comment? #Resolved
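
For readers unfamiliar with the decomposition named in the code comment (Dequantize + Join + Quantize), here is a minimal CPU reference sketch for the one-dimensional `uint8` case. It is not the DML graph from the PR: `QuantizedInput` and `QLinearConcat1D` are hypothetical names, and the rounding mode (`std::nearbyint`, round-half-to-even under the default floating-point environment) is an assumption.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical container for one (tensor, scale, zero_point) tuple.
struct QuantizedInput
{
    std::vector<uint8_t> values;
    float scale;
    uint8_t zeroPoint;
};

// Reference for 1-D uint8 inputs: dequantize each element with its own
// input's parameters, append it to the output (the join), and requantize
// with the output parameters.
std::vector<uint8_t> QLinearConcat1D(
    const std::vector<QuantizedInput>& inputs,
    float yScale,
    uint8_t yZeroPoint)
{
    std::vector<uint8_t> output;
    for (const QuantizedInput& input : inputs)
    {
        for (uint8_t q : input.values)
        {
            // Dequantize: real = (q - x_zero_point) * x_scale
            const float real =
                (static_cast<int32_t>(q) - static_cast<int32_t>(input.zeroPoint)) * input.scale;

            // Quantize: q' = saturate(round(real / y_scale) + y_zero_point)
            const float requantized = std::nearbyint(real / yScale) + static_cast<float>(yZeroPoint);
            output.push_back(static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, requantized))));
        }
    }
    return output;
}
```

Because each input carries its own scale and zero point, the inputs cannot simply be copied byte-for-byte; they have to pass through a common real-valued representation first.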

std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(input_count);
std::vector<DML_OPERATOR_DESC> dmlOpDesc(input_count);
std::vector<const DML_OPERATOR_DESC*> opDescs = {};
for (uint32_t input_index = 0; input_index < input_count; ++input_index)
Contributor

@sumitsays sumitsays Aug 2, 2023

[nit] Is there any specific reason we have used = {} to initialize this particular std::vector? #Closed

static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_DynamicQuantizeMatMul = 1;
static const int sc_sinceVer_QLinearConcat = 1;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::MsftOperatorSet1::sc_sinceVer_QLinearConcat' can be computed at compile-time. Consider using constexpr (con.5).
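
What the analyzer is asking for can be sketched as follows. The namespace name is hypothetical, and the thread does not say whether the PR adopted this form.

```cpp
// Hypothetical namespace, used only to keep this sketch self-contained.
namespace MsftOperatorSet1Sketch
{
    // Before: static const int sc_sinceVer_QLinearConcat = 1;
    // constexpr states explicitly that the value is a compile-time constant.
    static constexpr int sc_sinceVer_QLinearConcat = 1;
}

// A constexpr value can be used wherever a constant expression is required.
static_assert(MsftOperatorSet1Sketch::sc_sinceVer_QLinearConcat == 1, "QLinearConcat is in opset version 1");
```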

constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListQLinearConcat= {
SupportedTensorDataTypes::Float32,
SupportedTensorDataTypes::Ints8Bit,
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
Contributor

@AnaghaRaoAMD AnaghaRaoAMD Aug 2, 2023

contribop mentions TF supports any float tensor type, should we also consider supporting fp16?

Contributor Author

@zhangxiang1993 zhangxiang1993 Aug 4, 2023

I think only tensor(float) is specified; fp16 would appear as tensor(float16):
Type Constraints
T8 : tensor(uint8), tensor(int8)
Constrain input and output types to 8 bit signed and unsigned tensors.
TF : tensor(float)
Constrain scale types to any float tensor type.
TV : tensor(uint8), tensor(int8), tensor(float)
Sequence of (Tensor, Scale, ZeroPoint) tuples. The type is sequence of (T8, TF, T8).

// This order matches the ONNX schema.
enum OnnxInputIndex
{
yScale,
Contributor

@fdwr fdwr Aug 4, 2023

yScale

Suggested change
yScale,
YScale

Consistent casing with Count. #Closed

QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{

DmlOperator::Initialize(kernelCreationContext);
Contributor

@fdwr fdwr Aug 4, 2023

[nit] extra blank line #Closed

auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);

// inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)}
uint32_t input_def_count = kernelCreationContext.GetInputCount();
Contributor

@fdwr fdwr Aug 4, 2023

input_def_count

inputDefinitionCount

Naming intrafile consistency. (also, DML and DML EP always use 🐪, not 🐍) #Closed

std::vector<const DML_OPERATOR_DESC*> opDescs;
for (uint32_t input_index = 0; input_index < input_count; ++input_index)
{
auto tuple_start = 2 + input_index * 3;
Contributor

@fdwr fdwr Aug 4, 2023

tuple_start

tupleStartIndex #Resolved
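
The index arithmetic behind `2 + input_index * 3` follows from the input layout `{y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)...}`. A small sketch, using hypothetical names (`TupleIndices`, `GetTupleIndices`, and the two constants are illustrative and not from the PR):

```cpp
#include <cstdint>

// Flat input positions of one (x_tensor, x_scale, x_zero_point) tuple.
struct TupleIndices
{
    uint32_t tensor;
    uint32_t scale;
    uint32_t zeroPoint;
};

constexpr uint32_t FixedInputCount = 2; // y_scale and y_zero_point come first
constexpr uint32_t TupleSize = 3;       // x_tensor, x_scale, x_zero_point

// Maps the n-th concatenated input to its three flat input indices.
constexpr TupleIndices GetTupleIndices(uint32_t inputIndex)
{
    const uint32_t tupleStartIndex = FixedInputCount + inputIndex * TupleSize;
    return {tupleStartIndex, tupleStartIndex + 1, tupleStartIndex + 2};
}
```

With this layout the first tuple occupies inputs 2 to 4 and the second occupies inputs 5 to 7.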

std::vector<DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_DESC> dequantizeOperatorDescs(input_count);
std::vector<DML_OPERATOR_DESC> dmlOpDesc(input_count);
std::vector<const DML_OPERATOR_DESC*> opDescs;
for (uint32_t input_index = 0; input_index < input_count; ++input_index)
Contributor

@fdwr fdwr Aug 4, 2023

input_count

inputCount

etcetera... #Closed

// inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point)}
uint32_t input_def_count = kernelCreationContext.GetInputCount();
ML_CHECK_VALID_ARGUMENT(input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
"Each input must be (tensor, scale, zero_point) tuple!");
Contributor

@fdwr fdwr Aug 4, 2023

[nit] 6->4 space indent.

Alternately, consider splitting the long line more readably.

        ML_CHECK_VALID_ARGUMENT(
            input_def_count >= 5 && (input_def_count - 2) % 3 == 0,
            "Each input must be (tensor, scale, zero_point) tuple!"
        );

Or better yet, split it into two separate conditions, which would make the specific error much clearer to the user. The lines are then also short enough that they don't need to wrap:

        ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
        ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!");

Which is like what you do below, two separate ones rather than &&ing them:

            ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
            ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");
#Resolved
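
The two-check version can be tried outside the EP with a standalone sketch. `GetTupleCount` is a hypothetical helper, and `std::invalid_argument` stands in for `ML_CHECK_VALID_ARGUMENT`, whose exact failure behavior is not shown in this thread.

```cpp
#include <cstdint>
#include <stdexcept>

// Validates the flat input count and returns the number of
// (tensor, scale, zero_point) tuples. Each condition reports its own message.
uint32_t GetTupleCount(uint32_t inputDefinitionCount)
{
    // Two fixed inputs (y_scale, y_zero_point) plus at least one tuple of three.
    if (inputDefinitionCount < 5)
    {
        throw std::invalid_argument("Require at least 5 inputs.");
    }
    // Everything after the two fixed inputs must come in groups of three.
    if ((inputDefinitionCount - 2) % 3 != 0)
    {
        throw std::invalid_argument("Each input must be (tensor, scale, zero_point) tuple!");
    }
    return (inputDefinitionCount - 2) / 3;
}
```

Checking the lower bound first also keeps `inputDefinitionCount - 2` from wrapping around for unsigned counts below 2.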

TensorAxis::W,
TensorAxis::RightAligned,
NchwDimensionCount, // minDimensionCount
0 // guaranteedBaseOffsetAlignment)
Contributor

@fdwr fdwr Aug 4, 2023

)

0 // guaranteedBaseOffsetAlignment) ->
0 // guaranteedBaseOffsetAlignment #Resolved

Contributor

[nit] Indent is also inconsistent with all the other TensorDesc calls (8 instead of 4)

joinDesc.OutputTensor = &namedJoinOutputTensorDesc;
joinDesc.Axis = dmlAxis;

const DML_OPERATOR_DESC opJoinDesc{DML_OPERATOR_JOIN, &joinDesc};
Contributor

@fdwr fdwr Aug 4, 2023

Suggested change
const DML_OPERATOR_DESC opJoinDesc{DML_OPERATOR_JOIN, &joinDesc};
const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc};

[nit] Consistent assignment style with nearby MLOperatorGraphDesc operatorGraphDesc = {};. #Closed

Contributor

@fdwr fdwr left a comment

Nits, but otherwise looks good Xiang.

ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");

// broadcast x_scale and x_zero_point to shape of corresponding x
m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc(

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).
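
The fix the analyzer describes is to widen the 32-bit index before the addition, so the sum is computed in the container's index type instead of being widened afterwards. A minimal sketch with a hypothetical helper name (`ScaleIndex`); the thread does not show how, or whether, the PR addressed this warning.

```cpp
#include <cstddef>
#include <cstdint>

// Flagged pattern: descs[tupleStartIndex + 1]
//   The addition happens in 32 bits and only the result is widened.
// Suggested pattern: cast first, then add in the wider type.
inline std::size_t ScaleIndex(uint32_t tupleStartIndex)
{
    return static_cast<std::size_t>(tupleStartIndex) + 1;
}
```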

0 // guaranteedBaseOffsetAlignment
);

m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc(

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).

namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc();

dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).


dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];

Check warning

Code scanning / PREfast

Arithmetic overflow: Using operator '+' on a 4 byte value and then casting the result to a 8 byte value. Cast the value to the wider type before calling operator '+' to avoid overflow (io.2).

@@ -65,7 +65,7 @@ namespace Dml

// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
// Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
Contributor

@fdwr fdwr Aug 17, 2023

Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).

Stale comment. It's not always the case anymore. #Resolved

Contributor

@fdwr fdwr left a comment

:shipit:

Contributor

@fdwr fdwr left a comment

TY, XZ.

@zhangxiang1993 zhangxiang1993 merged commit d3345f3 into DmlPrototype Aug 17, 2023
@zhangxiang1993 zhangxiang1993 deleted the user/xianz/QLinearConcat branch August 17, 2023 22:15
AnaghaRaoAMD pushed a commit that referenced this pull request Nov 3, 2023
AnaghaRaoAMD added a commit that referenced this pull request Nov 3, 2023
### Description
[Cherry Pick Reviewed]
```
[ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms)
[ RUN ] QLinearConcatS8.InputOne_Dynamic
[ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms)
[ RUN ] QLinearConcatS8.InputOne_Const
[ OK ] QLinearConcatS8.InputOne_Const (255 ms)
[----------] 11 tests from QLinearConcatS8 (3385 ms total)

[----------] Global test environment tear-down
[==========] 21 tests from 3 test suites ran. (9355 ms total)
[ PASSED ] 21 tests.
```
[#16971](#16971)


Co-authored-by: Xiang Zhang <xianz@microsoft.com>
jeffbloo pushed a commit that referenced this pull request Jan 4, 2024
jslap-ubi pushed a commit to cgaudreau-ubisoft/onnxruntime that referenced this pull request Apr 5, 2024
rohan11235813 pushed a commit to quadric-io/onnxruntime that referenced this pull request Aug 19, 2025
rohan11235813 pushed a commit to quadric-io/onnxruntime that referenced this pull request Sep 15, 2025
5 participants