DML EP add einsum MatMul NHCW ops #13440

Merged
merged 3 commits on Oct 26, 2022
@@ -10,7 +10,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
{
public:
DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
: DmlOperator(kernelCreationContext),
EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count.");
@@ -30,7 +30,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

-static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(8), "Update this switch.");
+static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(11), "Update this switch.");
switch (m_recognizedOperatorType)
{
case RecognizedOperatorType::Multiply:
@@ -62,6 +62,82 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
}
break;
case RecognizedOperatorType::MatMulNhcw:
case RecognizedOperatorType::MatMulNhcwTransposeA:
case RecognizedOperatorType::MatMulNhcwTransposeB:
{
// Transpose via input strides. The output tensor is not strided. Support only 4D for now.
assert(m_components.size() == 3);
assert(m_components[0].GetDimensionCount() == m_components[2].GetDimensionCount());
assert(m_components[1].GetDimensionCount() == m_components[2].GetDimensionCount());
assert(m_components[2].GetDimensionCount() == 4);

// Remap transposed strides from NCHW to NHCW
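// e.g. packed NCHW sizes {2,3,4,5} with strides {60,20,5,1} become the
// NHCW view with sizes {2,4,3,5} and strides {60,5,20,1}, i.e. the same
// data reinterpreted with no copy.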
constexpr std::array<uint32_t, 4> labelIndices = {0, 2, 1, 3};

assert(m_inputTensorDescs.size() >= 2);
for (uint32_t i = 0; i < 2; ++i)
{
TensorDesc& tensorDesc = m_inputTensorDescs[i];
auto originalStrides = tensorDesc.GetStrides();
std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(i);
std::vector<uint32_t> inputStrides(inputSizes.size());

// If there were no strides, compute them in descending packed order
// based on the input sizes.
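// e.g. sizes {1,3,1,2} yield packed strides {6,2,2,1}.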
if (originalStrides.empty())
{
Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides);
}
else // Copy the original strides.
{
assert(originalStrides.size() >= inputStrides.size());
size_t offset = originalStrides.size() - inputStrides.size();
inputStrides.assign(originalStrides.begin() + offset, originalStrides.end());
}

std::vector<uint32_t> newStrides(inputStrides.size());
std::vector<uint32_t> newSizes(inputStrides.size());
for (size_t i = 0, dimensionCount = inputStrides.size(); i < dimensionCount; ++i)
{
uint32_t labelIndex = labelIndices[i];
assert(labelIndex < inputStrides.size());
newSizes[i] = inputSizes[labelIndex];
newStrides[i] = inputStrides[labelIndex];
}

// Override the initial input tensor with the new strides.
tensorDesc = TensorDesc(tensorDesc.GetDmlDataType(), newSizes, newStrides, 0);
tensorDesc.GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
}

std::vector<uint32_t> outputSizes = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
std::vector<uint32_t> newOutputSizes(outputSizes.size());
assert(outputSizes.size() == labelIndices.size());

for (size_t i = 0; i < outputSizes.size(); ++i)
{
uint32_t labelIndex = labelIndices[i];
newOutputSizes[i] = outputSizes[labelIndex];
}

m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newOutputSizes, std::nullopt, 0);
m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.

DML_GEMM_OPERATOR_DESC operatorDesc = {};
operatorDesc.ATensor = &inputDescs[0];
operatorDesc.BTensor = &inputDescs[1];
// No operatorDesc.CTensor
operatorDesc.OutputTensor = &outputDescs[0];
operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
operatorDesc.Alpha = 1.0;
operatorDesc.Beta = 0.0;
operatorDesc.FusedActivation = nullptr;
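// With Alpha = 1, Beta = 0, and no CTensor, the GEMM reduces to a plain
// matrix multiply, batched over the leading dimensions of the 4D tensors.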

SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
}
break;

case RecognizedOperatorType::ReduceSum:
{
@@ -176,7 +252,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool*
EinSumHelper helper(attributes);
auto recognizedOperatorType = helper.GetRecognizedOperatorType();
@fdwr (Contributor) commented on Oct 25, 2022:

I forget if we have shape information at this time, but it would be useful to check if available. There's a comment above saying we don't support >4D in the DML EP, which means hypothetically an existing 5D EinSum that used to run (albeit falling back to CPU) will now fail to run. #Resolved

Contributor:

Actually, IIRC, the shape of the tensor must be compatible in the expression, and so we do already have enough info to proceed safely, and the existing code should work as-is.


-static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(8), "Verify this test still matches the switch above.");
+static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(11), "Update this function.");
*isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None);
}
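For context, the transpose-via-strides approach in DmlOperatorEinSum above can be illustrated with a minimal standalone sketch (plain C++, not the DML EP code; the sizes here are illustrative only):

#include <array>
#include <cstdint>
#include <cstdio>

// Minimal sketch: reinterpret a packed NCHW tensor as NHCW purely by
// permuting sizes and strides, without moving any data.
int main()
{
    const std::array<uint32_t, 4> sizes = {2, 3, 4, 5}; // NCHW
    std::array<uint32_t, 4> strides{};                  // packed, descending
    strides[3] = 1;
    for (int d = 2; d >= 0; --d)
    {
        strides[d] = strides[d + 1] * sizes[d + 1];
    }
    // strides == {60, 20, 5, 1}

    constexpr std::array<uint32_t, 4> labelIndices = {0, 2, 1, 3}; // NCHW -> NHCW
    std::array<uint32_t, 4> newSizes{};
    std::array<uint32_t, 4> newStrides{};
    for (size_t d = 0; d < 4; ++d)
    {
        newSizes[d] = sizes[labelIndices[d]];
        newStrides[d] = strides[labelIndices[d]];
    }

    // newSizes == {2, 4, 3, 5} and newStrides == {60, 5, 20, 1}: element
    // (n, h, c, w) of the NHCW view addresses the same offset as element
    // (n, c, h, w) of the original NCHW tensor.
    printf("{%u, %u, %u, %u}\n", newStrides[0], newStrides[1], newStrides[2], newStrides[3]);
    return 0;
}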

@@ -609,7 +609,7 @@ namespace OperatorHelper
// `transBatch` needs to be applied first and then `transpose`.
if (transBatch)
{
ML_CHECK_VALID_ARGUMENT(dimensionCount > 2,
"FusedMatMul operator: Tensor size should be more than 2, if attribute transBatch is true");

std::rotate(newSizes.begin(), newSizes.end() - 2, newSizes.end() - 1);
@@ -702,7 +702,7 @@ namespace OperatorHelper
if (inputShape0 != inputShape1)
{
ML_CHECK_VALID_ARGUMENT(
inputShape0.size() == inputShape1.size() &&
inputShape0.size() == inputStride0.size() &&
inputStride0.size() == inputStride1.size(),
"Size of inputShape0, inputStride0, inputShape1 and inputStride1 should be same while broadcasting");
@@ -715,7 +715,7 @@ namespace OperatorHelper

auto inStride0Iter = inputStride0.rbegin();
auto inStride1Iter = inputStride1.rbegin();

while (rank-- > 0)
{
DimensionType inDimension0 = *inDim0Iter;
@@ -1503,18 +1503,21 @@ namespace OperatorHelper
};

const RecognizedOperatorInfo recognizedOperators[] = {
-{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
-{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
-{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
-{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
-{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
-{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
-{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
-{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
-{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
-{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
-{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
-{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
+{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
+{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
+{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
+{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
+{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
+{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
+{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
+{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
+{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
+{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
+{RecognizedOperatorType::MatMulNhcw, {4,4,4},{0,1,2,3, 0,3,2,4, 0,1,2,4}}, // aibj,ajbk->aibk
+{RecognizedOperatorType::MatMulNhcwTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,3,2,4}}, // ajbi,ajbk->aibk
+{RecognizedOperatorType::MatMulNhcwTransposeB, {4,4,4},{0,1,2,3, 0,4,2,3, 0,1,2,4}}, // aibj,akbj->aibk
+{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
+{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
};

// For each recognized operator, compare the labels-per-component and label indices.
@@ -1595,7 +1598,10 @@ namespace OperatorHelper
{
return m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA ||
-m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB;
+m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB ||
+m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw ||
+m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA ||
+m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB;
}

std::vector<EdgeShapes> MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
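To connect the equations in the comments to the label-index rows in the table above, here is a minimal sketch of first-appearance label indexing for "aibj,ajbk->aibk" (the MatMulNhcw row). This mirrors how the table rows read; it is an assumption about the derivation, not the actual EinSumHelper parsing code:

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main()
{
    // Assign each distinct label an index in order of first appearance across
    // the whole equation, then concatenate the per-component index lists.
    const std::vector<std::string> components = {"aibj", "ajbk", "aibk"};
    std::map<char, uint32_t> labelToIndex;
    for (const std::string& component : components)
    {
        for (char label : component)
        {
            auto it = labelToIndex.try_emplace(label, static_cast<uint32_t>(labelToIndex.size())).first;
            printf("%u ", it->second);
        }
        printf(" ");
    }
    printf("\n"); // Prints: 0 1 2 3  0 3 2 4  0 1 2 4  (the MatMulNhcw row above)
    return 0;
}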
@@ -234,7 +234,7 @@ void FusedMatMulShapeMapping(
std::vector<DimensionType>& outputShape);

std::pair<std::vector<uint32_t>, std::vector<uint32_t>> GetFusedMatMulSizesAndStrides(
gsl::span<const uint32_t> sizes,
int32_t transBatch = 0,
int32_t transpose = 0);

@@ -437,15 +437,15 @@ class ConvolutionHelperBase
enum InputDims { N, C, H, W };

public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
template<typename Info_t, typename Shape_t>
ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads, uint32_t inputTensorIndex, uint32_t filterTensorIndex) :
m_inputTensorIndex(inputTensorIndex),
m_filterTensorIndex(filterTensorIndex),
m_kernel(InitializeKernel(info, shape.GetInputTensorDimensionCount(inputTensorIndex), shape.GetInputTensorShape(filterTensorIndex)))
{
m_groupCount = info.template GetOptionalAttribute<uint32_t>(AttrName::Group, 1);

if (!transpose)
{
InitializeKernelAndShapes(ShapeInformationAdapter(shape));
@@ -507,8 +507,8 @@ class QLinearConvHelper : public ConvolutionHelperBase
class GemmHelper
{
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template<typename Info_t, typename Shape_t>
GemmHelper(const Info_t& info, const Shape_t& shape)
{
@@ -591,8 +591,8 @@ class SliceHelper
);

public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template<typename Info_t, typename Shape_t>
SliceHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
{
@@ -722,6 +722,9 @@ class EinSumHelper
MatMul,
MatMulTransposeA,
MatMulTransposeB,
MatMulNhcw,
MatMulNhcwTransposeA,
MatMulNhcwTransposeB,
ReduceSum,
Transpose,
Total,
@@ -740,7 +743,7 @@ class EinSumHelper
{
uint32_t labelIndexBegin;
uint32_t labelIndexEnd;

uint32_t GetDimensionCount() const noexcept
{
return labelIndexEnd - labelIndexBegin;
@@ -1037,8 +1040,8 @@ class PoolingHelperBase
class UnpoolingHelper
{
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template<typename Info_t, typename Shape_t>
UnpoolingHelper(
const Info_t& info,
27 changes: 27 additions & 0 deletions onnxruntime/test/providers/cpu/math/einsum_test.cc
@@ -179,6 +179,33 @@ TEST(Einsum, ExplicitEinsumAsMatmul) {
test.Run();
}

TEST(Einsum, ExplicitEinsumAsMatmulNhcw) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "aibj,ajbk->aibk");
test.AddInput<float>("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddInput<float>("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {1, 3, 1, 3}, {9.f, 12.f, 15.f, 19.f, 26.f, 33.f, 29.f, 40.f, 51.f});
test.Run();
}

TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeA) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ajbi,ajbk->aibk");
test.AddInput<float>("x", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddInput<float>("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {1, 3, 1, 3}, {17.f, 22.f, 27.f, 22.f, 29.f, 36.f, 27.f, 36.f, 45.f});
test.Run();
}

TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeB) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "aibj,akbj->aibk");
test.AddInput<float>("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddInput<float>("y", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {1, 3, 1, 3}, {5.f, 11.f, 17.f, 11.f, 25.f, 39.f, 17.f, 39.f, 61.f});
test.Run();
}

TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
// 'K' != 'k' (and dim values differ too), and Einsum should be able to handle that
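As a hand-check of the expected values in ExplicitEinsumAsMatmulNhcw above, a naive reference computation of "aibj,ajbk->aibk" with a = b = 1 (a minimal standalone sketch, independent of OpTester):

#include <cstdio>

int main()
{
    // x has shape {a=1, i=3, b=1, j=2}; y has shape {a=1, j=2, b=1, k=3}.
    // With a == b == 1, o[i][k] = sum over j of x[i][j] * y[j][k].
    const float x[3][2] = {{1, 2}, {3, 4}, {5, 6}};
    const float y[2][3] = {{1, 2, 3}, {4, 5, 6}};
    for (int i = 0; i < 3; ++i)
    {
        for (int k = 0; k < 3; ++k)
        {
            float sum = 0.0f;
            for (int j = 0; j < 2; ++j)
            {
                sum += x[i][j] * y[j][k];
            }
            printf("%g ", sum);
        }
    }
    printf("\n"); // Prints: 9 12 15 19 26 33 29 40 51, matching the test.
    return 0;
}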