Skip to content

Commit

Permalink
DML EP add einsum MatMul NHCW ops (#13440)
Browse files Browse the repository at this point in the history
### Description
This adds the "NHCW" format support for einsum MatMul. The logic is
basically a merge of the existing Transpose and MatMul Einsum
implementations.



### Motivation and Context
Some transformer models that I'm tracking use Einsum quite often during
a single inference, and about half of those were "NHCW" MatMul Einsums.
Supporting them will reduce the number of copies to the CPU.
  • Loading branch information
PatriceVignola committed Oct 26, 2022
1 parent d5e8d59 commit ac48bde
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
{
public:
DmlOperatorEinSum(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion)
: DmlOperator(kernelCreationContext),
: DmlOperator(kernelCreationContext),
EinSumHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion)
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() + 1 == m_components.size(), "EinSum input tensor count is inconsistent with the equation component count.");
Expand All @@ -30,7 +30,7 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(8), "Update this switch.");
static_assert(RecognizedOperatorType::Total == static_cast<RecognizedOperatorType>(11), "Update this switch.");
switch (m_recognizedOperatorType)
{
case RecognizedOperatorType::Multiply:
Expand Down Expand Up @@ -62,6 +62,82 @@ class DmlOperatorEinSum : public DmlOperator, public EinSumHelper
SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
}
break;
case RecognizedOperatorType::MatMulNhcw:
case RecognizedOperatorType::MatMulNhcwTransposeA:
case RecognizedOperatorType::MatMulNhcwTransposeB:
{
    // NHCW MatMul: fold the transpose into the input tensor strides so DML reads the
    // inputs as if they were already transposed, then issue a plain GEMM. The output
    // tensor is not strided. Support only 4D for now.
    assert(m_components.size() == 3);
    assert(m_components[0].GetDimensionCount() == m_components[2].GetDimensionCount());
    assert(m_components[1].GetDimensionCount() == m_components[2].GetDimensionCount());
    assert(m_components[2].GetDimensionCount() == 4);

    // Remap transposed strides from NCHW to NHCW (swap dimensions 1 and 2).
    constexpr std::array<uint32_t, 4> labelIndices = {0, 2, 1, 3};

    assert(m_inputTensorDescs.size() >= 2);
    for (uint32_t inputIndex = 0; inputIndex < 2; ++inputIndex)
    {
        TensorDesc& tensorDesc = m_inputTensorDescs[inputIndex];
        auto originalStrides = tensorDesc.GetStrides();
        std::vector<uint32_t> inputSizes = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(inputIndex);
        std::vector<uint32_t> inputStrides(inputSizes.size());

        // If there were no strides, compute them in descending packed order
        // based on the input sizes (i.e. a contiguous tensor).
        if (originalStrides.empty())
        {
            Dml::GetDescendingPackedStrides(inputSizes, /*out*/ inputStrides);
        }
        else // Copy the original strides, keeping only the trailing dimensions.
        {
            assert(originalStrides.size() >= inputStrides.size());
            size_t offset = originalStrides.size() - inputStrides.size();
            inputStrides.assign(originalStrides.begin() + offset, originalStrides.end());
        }

        // Permute sizes and strides according to labelIndices. Note `dim` is used for
        // the inner loop index to avoid shadowing the outer input-index loop variable.
        std::vector<uint32_t> newStrides(inputStrides.size());
        std::vector<uint32_t> newSizes(inputStrides.size());
        for (size_t dim = 0, dimensionCount = inputStrides.size(); dim < dimensionCount; ++dim)
        {
            uint32_t labelIndex = labelIndices[dim];
            assert(labelIndex < inputStrides.size());
            newSizes[dim] = inputSizes[labelIndex];
            newStrides[dim] = inputStrides[labelIndex];
        }

        // Override the initial input tensor with the new strides.
        tensorDesc = TensorDesc(tensorDesc.GetDmlDataType(), newSizes, newStrides, 0);
        tensorDesc.GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.
    }

    // Apply the same NHCW permutation to the logical output shape; the output itself
    // stays unstrided (packed), so only the sizes are remapped.
    std::vector<uint32_t> outputSizes = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
    std::vector<uint32_t> newOutputSizes(outputSizes.size());
    assert(outputSizes.size() == labelIndices.size());

    for (size_t dim = 0; dim < outputSizes.size(); ++dim)
    {
        uint32_t labelIndex = labelIndices[dim];
        newOutputSizes[dim] = outputSizes[labelIndex];
    }

    m_outputTensorDescs.front() = TensorDesc(m_outputTensorDescs.front().GetDmlDataType(), newOutputSizes, std::nullopt, 0);
    m_outputTensorDescs.front().GetDmlDesc(); // Discard value, but keep side effect of refreshing the DML view.

    // Plain GEMM over the stride-transposed inputs. TransA/TransB handle the
    // remaining matrix transpose requested by the einsum equation.
    DML_GEMM_OPERATOR_DESC operatorDesc = {};
    operatorDesc.ATensor = &inputDescs[0];
    operatorDesc.BTensor = &inputDescs[1];
    // No operatorDesc.CTensor
    operatorDesc.OutputTensor = &outputDescs[0];
    operatorDesc.TransA = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
    operatorDesc.TransB = (m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB) ? DML_MATRIX_TRANSFORM_TRANSPOSE : DML_MATRIX_TRANSFORM_NONE;
    operatorDesc.Alpha = 1.0f; // FLOAT field — avoid double->float narrowing.
    operatorDesc.Beta = 0.0f;
    operatorDesc.FusedActivation = nullptr;

    SetDmlOperatorDesc({ DML_OPERATOR_GEMM, &operatorDesc }, kernelCreationContext);
}
break;

case RecognizedOperatorType::ReduceSum:
{
Expand Down Expand Up @@ -176,7 +252,7 @@ void CALLBACK QueryEinSum(IMLOperatorSupportQueryContextPrivate* context, bool*
EinSumHelper helper(attributes);
auto recognizedOperatorType = helper.GetRecognizedOperatorType();

static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(8), "Verify this test still matches the switch above.");
static_assert(EinSumHelper::RecognizedOperatorType::Total == static_cast<EinSumHelper::RecognizedOperatorType>(11), "Update this function.");
*isSupported = (recognizedOperatorType != EinSumHelper::RecognizedOperatorType::None);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ namespace OperatorHelper
// `transBatch` needs to be applied first and then `transpose`.
if (transBatch)
{
ML_CHECK_VALID_ARGUMENT(dimensionCount > 2,
ML_CHECK_VALID_ARGUMENT(dimensionCount > 2,
"FusedMatMul operator: Tensor size should be more than 2, if attribute transBatch is true");

std::rotate(newSizes.begin(), newSizes.end() - 2, newSizes.end() - 1);
Expand Down Expand Up @@ -702,7 +702,7 @@ namespace OperatorHelper
if (inputShape0 != inputShape1)
{
ML_CHECK_VALID_ARGUMENT(
inputShape0.size() == inputShape1.size() &&
inputShape0.size() == inputShape1.size() &&
inputShape0.size() == inputStride0.size() &&
inputStride0.size() == inputStride1.size(),
"Size of inputShape0, inputStride0, inputShape1 and inputStride1 should be same while broadcasting");
Expand All @@ -715,7 +715,7 @@ namespace OperatorHelper

auto inStride0Iter = inputStride0.rbegin();
auto inStride1Iter = inputStride1.rbegin();

while (rank-- > 0)
{
DimensionType inDimension0 = *inDim0Iter;
Expand Down Expand Up @@ -1503,18 +1503,21 @@ namespace OperatorHelper
};

const RecognizedOperatorInfo recognizedOperators[] = {
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
{RecognizedOperatorType::MatMul, {2,2,2},{0,1, 1,2, 0,2}}, // ij,jk->ik
{RecognizedOperatorType::MatMul, {3,3,3},{0,1,2, 0,2,3, 0,1,3}}, // bij,bjk->bik
{RecognizedOperatorType::MatMul, {4,4,4},{0,1,2,3, 0,1,3,4, 0,1,2,4}}, // abij,abjk->abik
{RecognizedOperatorType::MatMulTransposeA, {2,2,2},{0,1, 0,2, 1,2}}, // ji,jk->ik
{RecognizedOperatorType::MatMulTransposeA, {3,3,3},{0,1,2, 0,1,3, 0,2,3}}, // bji,bjk->bik
{RecognizedOperatorType::MatMulTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,1,3,4}}, // abji,abjk->abik
{RecognizedOperatorType::MatMulTransposeB, {2,2,2},{0,1, 2,1, 0,2}}, // ij,kj->ik
{RecognizedOperatorType::MatMulTransposeB, {3,3,3},{0,1,2, 0,3,2, 0,1,3}}, // bij,bkj->bik
{RecognizedOperatorType::MatMulTransposeB, {4,4,4},{0,1,2,3, 0,1,4,3, 0,1,2,4}}, // abij,abkj->abik
{RecognizedOperatorType::MatMulTransposeB, {1,1,0},{0,0,}}, // i,i-> (1D inner_prod)
{RecognizedOperatorType::MatMulNhcw, {4,4,4},{0,1,2,3, 0,3,2,4, 0,1,2,4}}, // aibj,ajbk->aibk
{RecognizedOperatorType::MatMulNhcwTransposeA, {4,4,4},{0,1,2,3, 0,1,2,4, 0,3,2,4}}, // ajbi,ajbk->aibk
{RecognizedOperatorType::MatMulNhcwTransposeB, {4,4,4},{0,1,2,3, 0,4,2,3, 0,1,2,4}}, // aibj,akbj->aibk
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 0}}, // ij->i
{RecognizedOperatorType::ReduceSum, {2,1 },{0,1, 1}}, // ij->j
};

// For each recognized operator, compare the labels-per-component and label indices.
Expand Down Expand Up @@ -1595,7 +1598,10 @@ namespace OperatorHelper
{
return m_recognizedOperatorType == RecognizedOperatorType::MatMul ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeA ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB;
m_recognizedOperatorType == RecognizedOperatorType::MatMulTransposeB ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcw ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeA ||
m_recognizedOperatorType == RecognizedOperatorType::MatMulNhcwTransposeB;
}

std::vector<EdgeShapes> MatMulHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ void FusedMatMulShapeMapping(
std::vector<DimensionType>& outputShape);

std::pair<std::vector<uint32_t>, std::vector<uint32_t>> GetFusedMatMulSizesAndStrides(
gsl::span<const uint32_t> sizes,
gsl::span<const uint32_t> sizes,
int32_t transBatch = 0,
int32_t transpose = 0);

Expand Down Expand Up @@ -437,15 +437,15 @@ class ConvolutionHelperBase
enum InputDims { N, C, H, W };

public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
template<typename Info_t, typename Shape_t>
ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads, uint32_t inputTensorIndex, uint32_t filterTensorIndex) :
m_inputTensorIndex(inputTensorIndex),
m_filterTensorIndex(filterTensorIndex),
m_kernel(InitializeKernel(info, shape.GetInputTensorDimensionCount(inputTensorIndex), shape.GetInputTensorShape(filterTensorIndex)))
{
m_groupCount = info.template GetOptionalAttribute<uint32_t>(AttrName::Group, 1);

if (!transpose)
{
InitializeKernelAndShapes(ShapeInformationAdapter(shape));
Expand Down Expand Up @@ -507,8 +507,8 @@ class QLinearConvHelper : public ConvolutionHelperBase
class GemmHelper
{
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template<typename Info_t, typename Shape_t>
GemmHelper(const Info_t& info, const Shape_t& shape)
{
Expand Down Expand Up @@ -591,8 +591,8 @@ class SliceHelper
);

public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template<typename Info_t, typename Shape_t>
SliceHelper(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
{
Expand Down Expand Up @@ -722,6 +722,9 @@ class EinSumHelper
MatMul,
MatMulTransposeA,
MatMulTransposeB,
MatMulNhcw,
MatMulNhcwTransposeA,
MatMulNhcwTransposeB,
ReduceSum,
Transpose,
Total,
Expand All @@ -740,7 +743,7 @@ class EinSumHelper
{
uint32_t labelIndexBegin;
uint32_t labelIndexEnd;

uint32_t GetDimensionCount() const noexcept
{
return labelIndexEnd - labelIndexBegin;
Expand Down Expand Up @@ -1037,8 +1040,8 @@ class PoolingHelperBase
class UnpoolingHelper
{
public:
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template<typename Info_t, typename Shape_t>
UnpoolingHelper(
const Info_t& info,
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/test/providers/cpu/math/einsum_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,33 @@ TEST(Einsum, ExplicitEinsumAsMatmul) {
test.Run();
}

// NHCW-layout batched MatMul: equation "aibj,ajbk->aibk" contracts over 'j' with the
// batch dims 'a'/'b' interleaved between the matrix dims (i x j) * (j x k) -> (i x k).
// With x = [[1,2],[3,4],[5,6]] and y = [[1,2,3],[4,5,6]] (batch dims of size 1),
// the product is [[9,12,15],[19,26,33],[29,40,51]].
TEST(Einsum, ExplicitEinsumAsMatmulNhcw) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "aibj,ajbk->aibk");
test.AddInput<float>("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddInput<float>("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {1, 3, 1, 3}, {9.f, 12.f, 15.f, 19.f, 26.f, 33.f, 29.f, 40.f, 51.f});
test.Run();
}

// NHCW-layout batched MatMul with transposed A: equation "ajbi,ajbk->aibk" contracts
// over 'j', which is the row dim of BOTH inputs, i.e. o = transpose(x) * y.
// With x = y = [[1,2,3],[4,5,6]] (j x i/k), the result is the 3x3 Gram-style matrix
// [[17,22,27],[22,29,36],[27,36,45]].
TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeA) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "ajbi,ajbk->aibk");
test.AddInput<float>("x", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddInput<float>("y", {1, 2, 1, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {1, 3, 1, 3}, {17.f, 22.f, 27.f, 22.f, 29.f, 36.f, 27.f, 36.f, 45.f});
test.Run();
}

// NHCW-layout batched MatMul with transposed B: equation "aibj,akbj->aibk" contracts
// over 'j', which is the column dim of BOTH inputs, i.e. o = x * transpose(y).
// With x = y = [[1,2],[3,4],[5,6]] (i/k x j), the result is the symmetric matrix
// [[5,11,17],[11,25,39],[17,39,61]].
TEST(Einsum, ExplicitEinsumAsMatmulNhcwTransposeB) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
test.AddAttribute<std::string>("equation", "aibj,akbj->aibk");
test.AddInput<float>("x", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddInput<float>("y", {1, 3, 1, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
test.AddOutput<float>("o", {1, 3, 1, 3}, {5.f, 11.f, 17.f, 11.f, 25.f, 39.f, 17.f, 39.f, 61.f});
test.Run();
}

TEST(Einsum, ExplicitEinsumAsMatmulWithUpperCasedLabel) {
OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
// 'K' != 'k' (and dim values differ too) and Einsum should be able to handle that
Expand Down

0 comments on commit ac48bde

Please sign in to comment.