[DML EP] Add EmbedLayerNorm #13868

Merged
merged 5 commits on Dec 13, 2022
Changes from 4 commits
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1123,6 +1123,7 @@ Do not modify directly.*
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
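The new table row above documents the full contrib-op signature the DML EP now claims. As a rough orientation (not the DML implementation; the function below is an illustrative reference with an assumed two-segment embedding table), EmbedLayerNormalization gathers per-token word, position, and segment embeddings, sums them, and layer-normalizes the result over the hidden dimension; the second output, mask_index, is the per-batch count of ones in the optional mask input:

```cpp
// Illustrative reference only (assumed shapes): input_ids/segment_ids [B, S],
// word_embedding [V, H], position_embedding [S, H], segment_embedding [2, H],
// gamma/beta [H]; output is [B, S, H].
#include <cmath>
#include <cstdint>
#include <vector>

void EmbedLayerNormReference(
    const std::vector<int32_t>& inputIds, const std::vector<int32_t>& segmentIds,
    const std::vector<float>& wordEmb, const std::vector<float>& posEmb,
    const std::vector<float>& segEmb, const std::vector<float>& gamma,
    const std::vector<float>& beta, uint32_t batchSize, uint32_t seqLen,
    uint32_t hiddenSize, float epsilon, std::vector<float>& output)
{
    output.resize(static_cast<size_t>(batchSize) * seqLen * hiddenSize);
    for (uint32_t b = 0; b < batchSize; ++b)
    for (uint32_t s = 0; s < seqLen; ++s)
    {
        const int32_t wordId = inputIds[b * seqLen + s];
        const int32_t segId = segmentIds[b * seqLen + s];
        float* out = &output[(static_cast<size_t>(b) * seqLen + s) * hiddenSize];

        // Sum the three embeddings for this token, tracking the mean as we go.
        float mean = 0.0f;
        for (uint32_t h = 0; h < hiddenSize; ++h)
        {
            out[h] = wordEmb[wordId * hiddenSize + h]
                   + posEmb[s * hiddenSize + h]
                   + segEmb[segId * hiddenSize + h];
            mean += out[h];
        }
        mean /= hiddenSize;

        // Layer-normalize over the hidden axis, then apply gamma/beta.
        float variance = 0.0f;
        for (uint32_t h = 0; h < hiddenSize; ++h)
        {
            variance += (out[h] - mean) * (out[h] - mean);
        }
        const float invStdDev = 1.0f / std::sqrt(variance / hiddenSize + epsilon);
        for (uint32_t h = 0; h < hiddenSize; ++h)
        {
            out[h] = (out[h] - mean) * invStdDev * gamma[h] + beta[h];
        }
    }
}
```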
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -266,7 +266,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
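The one-line change above passes cpu_cuda_dml_rocm_eps instead of cpu_cuda_rocm_eps to EmbedLayerNormFusion, so the fusion that produces the EmbedLayerNormalization contrib node now also runs on graphs partitioned to the DML execution provider. For reference, a sketch of how these compatible-EP sets are typically declared earlier in the same file (exact contents assumed, not shown in this diff):

```cpp
// Assumed shape of the EP-compatibility sets used above; each transformer only
// rewrites nodes assigned to one of the listed execution providers.
const InlinedHashSet<std::string_view> cpu_cuda_rocm_eps = {
    onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider,
    onnxruntime::kRocmExecutionProvider};
const InlinedHashSet<std::string_view> cpu_cuda_dml_rocm_eps = {
    onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider,
    onnxruntime::kDmlExecutionProvider, onnxruntime::kRocmExecutionProvider};
```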


Large diffs are not rendered by default.

@@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
DML_OP_EXTERN_CREATION_FUNCTION(EmbedLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
};

template<typename T>
@@ -2518,14 +2518,43 @@ namespace OperatorHelper
m_sliceEnd = std::max<uint32_t>(static_cast<uint32_t>(trueEnd), m_sliceStart);
}

std::vector<EdgeShapes> ShapeHelper::GetOutputShapes(const MLShapeInferenceContext & shapeInfo) const
std::vector<EdgeShapes> ShapeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
return { EdgeShapes({m_sliceEnd - m_sliceStart}) };
}

std::vector<EdgeShapes> SizeHelper::GetOutputShapes(const MLShapeInferenceContext & shapeInfo) const
std::vector<EdgeShapes> SizeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
return { EdgeShapes({}) };
}

std::vector<EdgeShapes> EmbedLayerNormalizationHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);

auto inputIdsShape = shapeInfo.GetInputTensorShape(0);
auto wordEmbeddingShape = shapeInfo.GetInputTensorShape(2);

// input_ids and word_embedding are 2D tensors
ML_CHECK_VALID_ARGUMENT(inputIdsShape.size() == 2);
ML_CHECK_VALID_ARGUMENT(wordEmbeddingShape.size() == 2);

uint32_t batchSize = inputIdsShape[0];
uint32_t sequenceLength = inputIdsShape[1];
uint32_t hiddenSize = wordEmbeddingShape[1];

std::vector<EdgeShapes> outputShapes;
outputShapes.reserve(3);

outputShapes.push_back(EdgeShapes({batchSize, sequenceLength, hiddenSize}));
outputShapes.push_back(EdgeShapes({batchSize}));

if (shapeInfo.GetOutputCount() == 3)
{
outputShapes.push_back(EdgeShapes({batchSize, sequenceLength, hiddenSize}));
}

return outputShapes;
}

} // namespace OperatorHelper
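Restated outside the helper, the shape rule implemented above is: for input_ids [batch, sequence] and word_embedding [vocab, hidden], produce output [batch, sequence, hidden], mask_index [batch], and, only when a third output is declared, embedding_sum [batch, sequence, hidden]. A standalone sketch of that rule (illustrative only, not part of the PR):

```cpp
#include <cstdint>
#include <vector>

// Mirrors EmbedLayerNormalizationHelper::GetOutputShapes without the DML
// adapter types, purely for sanity-checking the expected output shapes.
std::vector<std::vector<uint32_t>> EmbedLayerNormOutputShapes(
    uint32_t batchSize, uint32_t sequenceLength, uint32_t hiddenSize, bool hasEmbeddingSum)
{
    std::vector<std::vector<uint32_t>> shapes;
    shapes.push_back({batchSize, sequenceLength, hiddenSize});  // output
    shapes.push_back({batchSize});                              // mask_index
    if (hasEmbeddingSum)
    {
        shapes.push_back({batchSize, sequenceLength, hiddenSize});  // embedding_sum
    }
    return shapes;
}
```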
@@ -1383,6 +1383,19 @@ class SizeHelper {
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};

class EmbedLayerNormalizationHelper
{
void Initialize(
const IKernelInformationAdapter& kernelInformation,
const IShapeInformationAdapter& shapeInformation
);

public:
template <typename Info_t, typename Shape_t>
EmbedLayerNormalizationHelper(const Info_t& info, const Shape_t& shapeInfo) { }
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};

using ShapeInferenceHelper_Conv = ConvHelper;
using ShapeInferenceHelper_ConvTranspose = ConvTransposeHelper;
using ShapeInferenceHelper_ConvTransposeWithDynamicPads = ConvTransposeWithDynamicPadsHelper;
@@ -1405,6 +1418,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_EmbedLayerNormalization = EmbedLayerNormalizationHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;
@@ -395,6 +395,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
} // namespace MsftOperatorSet1

} // namespace OperatorHelper
7 changes: 6 additions & 1 deletion onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc
@@ -17,9 +17,10 @@ static void RunTest(const embedlayernorm::OpData& data,
int min_cuda_architecture = use_float16 ? 530 : 0;

bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_dml = DefaultDmlExecutionProvider().get() != nullptr;
bool enable_cpu = !use_float16;

if (enable_cpu || enable_cuda) {
if (enable_cpu || enable_cuda || enable_dml) {
// Input and output shapes
// Input 0 - input_ids : (batch_size, sequence_size)
// Input 1 - segment_ids : (batch_size, sequence_size)
@@ -142,6 +143,10 @@ static void RunTest(const embedlayernorm::OpData& data,
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else if (enable_dml) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultDmlExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else {
tester.Run();
}
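With the new enable_dml branch, the shared RunTest body can also be pinned to the DML provider alone. As a usage sketch (hypothetical test name and values — the constant embeddings make the normalized output exactly zero, and with no mask input mask_index should come back as zero per batch; the real coverage comes from the existing embedlayernorm test data):

```cpp
// Sketch only: a minimal DML-only OpTester run for the contrib op, mirroring
// the else-if branch added above.
TEST(EmbedLayerNormTest, EmbedLayerNorm_Dml_Sketch) {
  OpTester tester("EmbedLayerNormalization", 1, onnxruntime::kMSDomain);
  tester.AddInput<int32_t>("input_ids", {1, 2}, {0, 2});
  tester.AddInput<int32_t>("segment_ids", {1, 2}, {0, 1});
  tester.AddInput<float>("word_embedding", {3, 4}, std::vector<float>(12, 0.1f));
  tester.AddInput<float>("position_embedding", {2, 4}, std::vector<float>(8, 0.2f));
  tester.AddInput<float>("segment_embedding", {2, 4}, std::vector<float>(8, 0.3f));
  tester.AddInput<float>("gamma", {4}, std::vector<float>(4, 1.0f));
  tester.AddInput<float>("beta", {4}, std::vector<float>(4, 0.0f));
  tester.AddOutput<float>("output", {1, 2, 4}, std::vector<float>(8, 0.0f));
  tester.AddOutput<int32_t>("mask_index", {1}, {0});

  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  execution_providers.push_back(DefaultDmlExecutionProvider());
  tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
```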