[DML EP] Add EmbedLayerNorm (microsoft#13868)
### Description
Add EmbedLayerNorm to the DML EP
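EmbedLayerNormalization is the com.microsoft contrib op that fuses the embedding lookups and the first layer normalization of BERT-style models: it gathers word, position, and segment embeddings, sums them, layer-normalizes the sum with gamma/beta, and also reports how many valid tokens each batch row has. Below is a minimal sketch of those semantics (float only, ignoring the optional position_ids input and embedding_sum output); all names and shapes are illustrative and this is not the DML kernel added by this commit.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative reference for EmbedLayerNormalization semantics; not the DML implementation.
void EmbedLayerNormReference(
    const std::vector<int32_t>& input_ids,        // [batch, seq]
    const std::vector<int32_t>& segment_ids,      // [batch, seq]
    const std::vector<float>& word_embedding,     // [vocab, hidden]
    const std::vector<float>& position_embedding, // [max_seq, hidden]
    const std::vector<float>& segment_embedding,  // [num_segments, hidden]
    const std::vector<float>& gamma,              // [hidden]
    const std::vector<float>& beta,               // [hidden]
    const std::vector<int32_t>& mask,             // [batch, seq]
    int batch, int seq, int hidden, float epsilon,
    std::vector<float>& output,                   // [batch, seq, hidden]
    std::vector<int32_t>& mask_index)             // [batch]
{
    output.resize(static_cast<size_t>(batch) * seq * hidden);
    mask_index.assign(batch, 0);

    for (int b = 0; b < batch; ++b)
    {
        for (int s = 0; s < seq; ++s)
        {
            // Sum the word, position, and segment embedding rows for this token.
            std::vector<float> sum(hidden);
            for (int h = 0; h < hidden; ++h)
            {
                sum[h] = word_embedding[input_ids[b * seq + s] * hidden + h]
                       + position_embedding[s * hidden + h]
                       + segment_embedding[segment_ids[b * seq + s] * hidden + h];
            }

            // Layer-normalize the sum, then scale and shift with gamma/beta.
            float mean = 0.0f;
            float variance = 0.0f;
            for (float v : sum) mean += v;
            mean /= hidden;
            for (float v : sum) variance += (v - mean) * (v - mean);
            variance /= hidden;

            for (int h = 0; h < hidden; ++h)
            {
                output[(b * seq + s) * hidden + h] =
                    (sum[h] - mean) / std::sqrt(variance + epsilon) * gamma[h] + beta[h];
            }
        }

        // mask_index counts the valid (non-zero) mask entries in each batch row.
        for (int s = 0; s < seq; ++s)
        {
            mask_index[b] += (mask[b * seq + s] != 0) ? 1 : 0;
        }
    }
}
```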
PatriceVignola authored and fuhengwu2021 committed Dec 26, 2022
1 parent d636c0b commit 56690f8
Showing 8 changed files with 551 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1124,6 +1124,7 @@ Do not modify directly.*
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -266,7 +266,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<LayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<AttentionFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<EmbedLayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
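With the one-line change above, EmbedLayerNormFusion is also enabled when the DML execution provider is registered, so the BERT-style embedding subgraph (the word/position/segment Gather nodes, the Adds that sum them, and the trailing LayerNormalization) can be fused into a single EmbedLayerNormalization node that the new DML kernel then executes.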


Large diffs are not rendered by default.

@@ -101,6 +101,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(EmbedLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@@ -750,6 +751,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
};

template<typename T>
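Following the pattern of the SkipLayerNormalization entry just above it, the new REG_INFO_MS entry appears to register EmbedLayerNormalization in the com.microsoft contrib domain at since-version 1 (matching sc_sinceVer_EmbedLayerNormalization below), with float16/float32 support and DmlGraphSupport::Supported so the node can participate in fused DML graphs.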
@@ -2518,14 +2518,43 @@ namespace OperatorHelper
m_sliceEnd = std::max<uint32_t>(static_cast<uint32_t>(trueEnd), m_sliceStart);
}

std::vector<EdgeShapes> ShapeHelper::GetOutputShapes(const MLShapeInferenceContext & shapeInfo) const
std::vector<EdgeShapes> ShapeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
return { EdgeShapes({m_sliceEnd - m_sliceStart}) };
}

std::vector<EdgeShapes> SizeHelper::GetOutputShapes(const MLShapeInferenceContext & shapeInfo) const
std::vector<EdgeShapes> SizeHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
return { EdgeShapes({}) };
}

std::vector<EdgeShapes> EmbedLayerNormalizationHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 3);

auto inputIdsShape = shapeInfo.GetInputTensorShape(0);
auto wordEmbeddingShape = shapeInfo.GetInputTensorShape(2);

// input_ids and word_embedding are 2D tensors
ML_CHECK_VALID_ARGUMENT(inputIdsShape.size() == 2);
ML_CHECK_VALID_ARGUMENT(wordEmbeddingShape.size() == 2);

uint32_t batchSize = inputIdsShape[0];
uint32_t sequenceLength = inputIdsShape[1];
uint32_t hiddenSize = wordEmbeddingShape[1];

std::vector<EdgeShapes> outputShapes;
outputShapes.reserve(3);

outputShapes.push_back(EdgeShapes({batchSize, sequenceLength, hiddenSize}));
outputShapes.push_back(EdgeShapes({batchSize}));

if (shapeInfo.GetOutputCount() == 3)
{
outputShapes.push_back(EdgeShapes({batchSize, sequenceLength, hiddenSize}));
}

return outputShapes;
}

} // namespace OperatorHelper
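As a quick illustrative check of the shape rule above: with input_ids of shape {8, 128} and word_embedding of shape {30522, 768} (hypothetical values), the helper reports output as {8, 128, 768} and mask_index as {8}; when the optional third output is requested, embedding_sum is likewise {8, 128, 768}.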
@@ -1383,6 +1383,19 @@ class SizeHelper {
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};

class EmbedLayerNormalizationHelper
{
void Initialize(
const IKernelInformationAdapter& kernelInformation,
const IShapeInformationAdapter& shapeInformation
);

public:
template <typename Info_t, typename Shape_t>
EmbedLayerNormalizationHelper(const Info_t& info, const Shape_t& shapeInfo) { }
std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};

using ShapeInferenceHelper_Conv = ConvHelper;
using ShapeInferenceHelper_ConvTranspose = ConvTransposeHelper;
using ShapeInferenceHelper_ConvTransposeWithDynamicPads = ConvTransposeWithDynamicPadsHelper;
@@ -1406,6 +1419,7 @@ using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_EmbedLayerNormalization = EmbedLayerNormalizationHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;
@@ -396,6 +396,7 @@ namespace OperatorHelper
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
} // namespace MsftOperatorSet1

} // namespace OperatorHelper
7 changes: 6 additions & 1 deletion onnxruntime/test/contrib_ops/embed_layer_norm_op_test.cc
@@ -17,9 +17,10 @@ static void RunTest(const embedlayernorm::OpData& data,
int min_cuda_architecture = use_float16 ? 530 : 0;

bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_dml = DefaultDmlExecutionProvider().get() != nullptr;
bool enable_cpu = !use_float16;

if (enable_cpu || enable_cuda) {
if (enable_cpu || enable_cuda || enable_dml) {
// Input and output shapes
// Input 0 - input_ids : (batch_size, sequence_size)
// Input 1 - segment_ids : (batch_size, sequence_size)
@@ -142,6 +143,10 @@ static void RunTest(const embedlayernorm::OpData& data,
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else if (enable_dml) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultDmlExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
} else {
tester.Run();
}
