Optimize computation orders #13672

Merged
pengwa merged 21 commits into main from pengwa/optimize_compute
Dec 22, 2022

@pengwa
Contributor

@pengwa pengwa commented Nov 16, 2022

Optimize computation orders

In Roberta/Electra, when ClassificationHead is used, the features are sliced along the sequence_length dimension, and the loss calculation depends only on the sliced data. The slice is taken at axis 1: before slicing the shape is [batch, sequence_length, hidden]; after slicing it becomes [batch, hidden].

There is an opportunity to move this slicing as early as possible by passing it through simple elementwise ops (like Add/Div), through LayerNorm/Softmax (if their reduce axis comes after the slicing axis), or even through the left operand of MatMul (as long as the slice does not affect the last dims).
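To make the claim concrete, here is a minimal pure-Python sketch (not the ONNX Runtime implementation) showing that slicing at axis 1 gives the same result whether it is applied after or before a chain of Add, Softmax over the last axis, and MatMul on the left operand. The helper names (`map_last_dim`, `slice_axis1`, `pipeline`) and the toy data are assumptions made for illustration only.

```python
import math

def map_last_dim(fn, x):
    """Apply fn to every innermost vector of a nested-list tensor."""
    if isinstance(x[0], list):
        return [map_last_dim(fn, sub) for sub in x]
    return fn(x)

def slice_axis1(x, index):
    """[batch, seq, hidden] -> [batch, hidden], like features[:, index, :]."""
    return [per_batch[index] for per_batch in x]

def add_bias(bias):
    # Elementwise Add of a [hidden] bias.
    return lambda v: [a + b for a, b in zip(v, bias)]

def softmax(v):
    # Softmax over the last axis, which comes after the slicing axis.
    peak = max(v)
    exps = [math.exp(a - peak) for a in v]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(weight):
    # weight has shape [hidden, out]; v is one row of the left operand.
    return lambda v: [sum(a * w_row[j] for a, w_row in zip(v, weight))
                      for j in range(len(weight[0]))]

def pipeline(x, bias, weight):
    x = map_last_dim(add_bias(bias), x)
    x = map_last_dim(softmax, x)
    return map_last_dim(matmul(weight), x)

batch, seq, hidden = 2, 4, 3
features = [[[0.1 * (b + s + h) for h in range(hidden)]
             for s in range(seq)] for b in range(batch)]
bias = [0.5, -1.0, 2.0]
weight = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]

# Original order: compute on all seq positions, then slice.
late = slice_axis1(pipeline(features, bias, weight), 0)
# Upstreamed order: slice first, then compute on 1/seq of the data.
early = pipeline(slice_axis1(features, 0), bias, weight)
assert late == early
```

In the upstreamed order every op processes one sequence position instead of `seq` of them, which is where the saving would come from.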

Operators like Reshape/Transpose are special: Reshape carries shape data that must be updated after slicing, and Transpose has perm specified, which requires the input rank to remain unchanged. For these operators we keep the original rank and leave the sliced dim as 1; after the computation completes, we do a Squeeze.
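The rank-preserving trick can be sketched in pure Python as follows. This is an illustrative sketch under assumed names (`transpose_021`, `slice_axis1_keepdim`, `squeeze_last`), not the optimizer's actual code: a Transpose with perm [0, 2, 1] still needs a rank-3 input, so the slice keeps a dimension of size 1, and a Squeeze removes it afterwards.

```python
def transpose_021(x):
    """Transpose with perm [0, 2, 1]: [batch, a, b] -> [batch, b, a]."""
    return [[list(col) for col in zip(*mat)] for mat in x]

def slice_axis1_keepdim(x, index):
    """[batch, seq, hidden] -> [batch, 1, hidden]; rank stays 3."""
    return [[mat[index]] for mat in x]

def squeeze_last(x):
    """[batch, hidden, 1] -> [batch, hidden]."""
    return [[row[0] for row in mat] for mat in x]

batch, seq, hidden = 2, 4, 3
x = [[[b * 100 + s * 10 + h for h in range(hidden)]
      for s in range(seq)] for b in range(batch)]

# Original order: Transpose the full tensor, then slice position 0
# (the sequence axis has moved to axis 2 after the Transpose).
late = [[row[0] for row in mat] for mat in transpose_021(x)]

# Upstreamed order: slice first but keep the sliced dim as 1 so the
# same perm still applies, then Squeeze the size-1 dim at the end.
early = squeeze_last(transpose_021(slice_axis1_keepdim(x, 0)))

assert early == late
```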

```python
class RobertaClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x
```

src\transformers\models\roberta\modeling_roberta.py
src\transformers\models\electra\modeling_electra.py

Benchmark

A simple benchmark shows Roberta training latency dropped from 208 ms to 199 ms, a 4.5+% reduction.
More comprehensive tests are on the way.
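As a quick check of the arithmetic, the percentage depends on which latency is used as the baseline. The PR does not state the baseline, and the millisecond values are presumably rounded, so the snippet below is only a sketch of the two possible readings.

```python
before_ms, after_ms = 208.0, 199.0
saved_ms = before_ms - after_ms

# Reduction relative to the original latency.
reduction_vs_before = saved_ms / before_ms * 100
# Improvement relative to the new latency.
improvement_vs_after = saved_ms / after_ms * 100

print(f"relative to 208 ms: {reduction_vs_before:.1f}%")
print(f"relative to 199 ms: {improvement_vs_after:.1f}%")
```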

Motivation and Context

@pengwa pengwa added the training label (issues related to ONNX Runtime training; typically submitted using template) Nov 16, 2022
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h Outdated
Comment thread orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
Comment thread orttraining/orttraining/core/optimizer/graph_transformer_utils.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/compute_optimizer.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.h Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.cc Outdated
Comment thread onnxruntime/core/optimizer/compute_optimizer/passthrough_actors.cc Outdated
baijumeswani previously approved these changes Dec 12, 2022
Contributor

@baijumeswani baijumeswani left a comment

Looks good. Sorry for the delay in the review. Please update the branch to address failing pipelines.

Comment thread docs/ORTModule_Training_Guidelines.md
transformers.emplace_back(std::make_unique<ReshapeFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<ConcatSliceElimination>(compatible_eps));
#if defined(USE_CUDA) || defined(USE_ROCM)
transformers.emplace_back(std::make_unique<ComputationReductionTransformer>(compatible_eps));
Contributor

Why are we deleting this transformer?

Contributor Author

Oh, this is a renaming; it is now called ComputeOptimizer.

"${ONNXRUNTIME_INCLUDE_DIR}/core/optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/*.h"
"${ONNXRUNTIME_ROOT}/core/optimizer/*.cc"
"${ONNXRUNTIME_ROOT}/core/optimizer/compute_optimizer/*.h"
Contributor

All the code in these files and in the test files is wrapped in ENABLE_TRAINING, so why add these files here? You can add them within the if (onnxruntime_ENABLE_TRAINING) condition.

Contributor Author

Because the optimizer is also applicable to inference. I intend to make it easier to enable for inference later without changing file locations and CMake macros.

Comment thread docs/ORTModule_Training_Guidelines.md
// 3. Should all inputs be allowed when track back further (bottom-up);
// if not, add the input index restriction as MatMul did.
{GetFullQualifiedOpName("Add", kOnnxDomain),
OpPassThroughConfig({}, std::make_shared<SimplePassThroughActor>(), opset_14_13_7_6_1)},
Contributor

What is the benefit of creating these const initializer lists, given that any opportunity for reuse across the ops is purely coincidental?

Contributor Author

When the list of opsets was defined as a temporary, on Windows some of its values were always read back as zero later. So I had to make it a constant for it to run correctly.

Contributor

@askhade askhade left a comment

LGTM

@pengwa pengwa merged commit 2f5bf75 into main Dec 22, 2022
@pengwa pengwa deleted the pengwa/optimize_compute branch December 22, 2022 07:12
simon-moo pushed a commit to simon-moo/onnxruntime that referenced this pull request Dec 26, 2022
pengwa added a commit that referenced this pull request Mar 13, 2023
### Slice op upstream refactor

Refactoring work for #13672.

### Motivation and Context

There are similar optimization opportunities in upstreaming other
operators to reduce compute FLOPs, so this PR refactors the existing code
base to make it easier to support other ops.

The changes in this PR are mainly renames and moves.
- Move common logic (from compute_optimizer.h/cc) into
upstream_transformer_base.h/cc and shared_utils.h/cc.
   - Common upstream logic is moved into
upstream_transformer_base.h/cc.
   - Shared utilities are moved into shared_utils.h/cc.
- After the move, compute_optimizer.h/cc mainly holds the Gather
upstreaming implementation (inheriting from
upstream_transformer_base.h/cc). Ideally it should be renamed, but to
make this review easier I kept its name.
Labels

training issues related to ONNX Runtime training; typically submitted using template

5 participants