
SoftmaxCrossEntropyLossInternalGrad and Sum Fusion #12746

Merged
Lafi7e merged 7 commits into main from weicwang/sce_grad_add
Sep 14, 2022

Conversation

Contributor

@Lafi7e Lafi7e commented Aug 26, 2022

We observed the SoftmaxCrossEntropyLossInternalGrad+Sum pattern in more than one customer model; both nodes need ~10 ms to compute when the input tensor shape is relatively large, especially with a large vocab size. This PR fuses the two ops into a single one. In the CUDA/ROCm EP, only one fused op is launched, so the total execution time is cut roughly in half. We also observe a >2% throughput gain for the whole model.

@Lafi7e Lafi7e added the training issues related to ONNX Runtime training; typically submitted using template label Aug 26, 2022
@Lafi7e Lafi7e requested review from askhade, pengwa and zhijxu-MS August 26, 2022 07:57
continue;
}

bool has_same_shape = true;
Contributor

I recall that Shape can be compared with ==.

>   if (input->Shape() != skip->Shape()) {
>     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
>                            "skip is expected to have same shape as input");
>   }

Contributor Author

There are different kinds of "shape" here. ORT's runtime tensor shape is onnxruntime::TensorShape, while the "shape" in the ONNX graph (which is what the transformer code works with) is onnx::TensorShapeProto.


namespace onnxruntime {

Status SceLossGradBiasFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Contributor

The fusion pattern looks like this?

SoftmaxCrossEntropyLossInternalGrad --> optional Reshape --> Add|Sum ?

Contributor Author

Right

d_logit->Reshape(new_shape);
}

// Bias.
Contributor

I am a bit surprised we did not use a parallel for here. For the newly added elementwise add at least, it would be simple to use a parallel loop.

Contributor Author

It may mean that nobody uses the CPU for training, so I didn't put effort into refactoring the code and just made it work.

Contributor Author

If it's critical for on-device training, we can optimize this in a new PR.

};

std::unique_ptr<GraphTransformer> transformer = std::make_unique<SceLossGradBiasFusion>();
TestGraphTransformer(build_test_case, 14, logger, std::move(transformer), TransformerLevel::Level2, 1,
Contributor

Shall we also run opsets 12 and 13? We had models onboarded running on opset 12; they will probably get refreshed and re-trained with the new ORT + new data, benefiting from this fusion.

Contributor Author

Added the test. But old ORT versions will not have this fusion anyway. If users want to re-train with a new ORT that has this fusion, note that the default opset version for ORTModule is now 15, unless the user sets the opset to an older version via environment variables, which is not recommended.

@pengwa pengwa requested a review from baijumeswani August 31, 2022 07:49
@pengwa
Contributor

pengwa commented Aug 31, 2022

FYI @baijumeswani . let's check whether on-device training models (GPU/CPU) have this pattern or not.

@Lafi7e Lafi7e merged commit da07c83 into main Sep 14, 2022
@Lafi7e Lafi7e deleted the weicwang/sce_grad_add branch September 14, 2022 06:45