
Add tanh, add, and broadcast ops + Lowering to linalg #1595

Merged
merged 8 commits into from
Nov 22, 2022

Conversation


@sanjoy sanjoy commented Nov 16, 2022

This PR adds tanh, add, and broadcast ops to TCP and implements a lowering to linalg.

This PR was carved out of a larger set of commits that includes a full Torch->TCP->Linalg e2e flow for these ops. This larger flow is here for reference.

(@navahgar is the original author of this code, I'm sending it out on his behalf.)

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Thanks! First round of comments. Mostly small things. I don't fully follow the logic of the broadcast lowering; some comments there would be nice.

} // namespace mlir


#endif // TORCH_MLIR_DIALECTS_CONVERSION_PASSES_H
Contributor

Nit: Add a newline at EOF here and everywhere else.

AllElementTypesMatch<["in", "out"]>,
PredOpTrait<
"attribute `axes` must be in increasing order",
CPred<"::llvm::is_sorted(getAxes(), "
Contributor

I'd move this into utility functions and just call those from here, instead of inlining it in tablegen.

Contributor

Flagging that this is still here.


// Add indexing maps for all the tensor operands and for the result.
for (size_t i = 0; i < tensorOperands.size(); ++i) {
indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
Contributor

All AffineMaps here are uniqued, so you could just do:

SmallVector<AffineMap> indexingMaps(tensorOperands.size() + 1, b.getMultiDimIdentityMap(resultRank));

}

template<typename TcpOpT>
Value createLinalgPayloadForElementwiseOp(
Contributor

I am not entirely sure why this is templated. It would probably work the same if you used Operation *op instead of TcpOpT op.

Contributor

Maybe also make this return FailureOr&lt;Value&gt;. I'm not a fan of passing nullptrs around.

ConversionTarget &target) {
MLIRContext *context = patterns.getContext();

target.addIllegalOp<TanhOp>();
Contributor

Wouldn't it be better to just mark the whole TCP dialect as illegal?


namespace {

void getValuesFromIndexArrayAttribute(ArrayAttr attr,
Contributor

You could just create a SmallVector&lt;int64_t&gt; within this method and return it by value instead of returning by reference. Most compilers optimize the return into a move, and NRVO often elides it entirely.

SmallVector<Value> resultDimSizes;
SmallVector<AffineExpr> exprs;
int64_t pos = 0;
for (int64_t i = 0; i < resultRank; ++i) {
Contributor

Some comments here would be nice.

Author

Didn't add explicit comments, but refactored the code to be more understandable. LMK if this still needs comments.

#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace tcp_to_linalg {
Contributor

If you plan to move this into MLIR core, it should probably follow LLVM coding conventions, i.e. camelCase instead of snake_case.

Author

sanjoy commented Nov 19, 2022

@MaheshRavishankar Addressed comments, PTAL.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Thanks! Next round of comments, but mostly looks fine now.
Please make sure all files end with a newline.

AllElementTypesMatch<["in", "out"]>,
PredOpTrait<
"attribute `axes` must be in increasing order",
CPred<"::llvm::is_sorted(getAxes(), "
Contributor

Flagging that this is still here.

// the tensorOperands, since all the operands are expected to have the same
// shape.
auto tensorOperand = tensorOperands[0];
for (int64_t i = 0; i < resultRank; ++i)
Contributor

@MaheshRavishankar MaheshRavishankar Nov 19, 2022

Please add { } around bodies that span multiple lines (even if it is a single statement).

Also, you can just use the createDimValues method (here: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h#L31). That returns a SmallVector&lt;OpFoldResult&gt; you can pass directly to the build method of tensor.empty().

Author

Nice! Changed to use createDimValues.

pos < static_cast<int64_t>(axes.size()) && axes[pos] == i;
if (isOutputDimBroadcasted) {
resultDimSizes.push_back(op->getOperands()[pos + 1]);
inputIndexingMapExprs.push_back(b.getAffineConstantExpr(0));
Contributor

This is the broadcast design chosen here, but just pointing out that having a constant 0 in the indexing maps of linalg ops is kind of bad form.

Author

Given the current definition of tcp.broadcast, is there a better way to lower it to linalg?

Contributor

No, there isn't. Just pointing out a (minor) downside of the current design.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Looks good to me. Don't know if anyone else needs to sign off as well.

TorchMLIRTcpDialect
)

torch_mlir_dialects_target_includes(TorchMLIRTcpToLinalg)
Contributor

Nit: newline at EOF.


Collaborator

sjarus commented Nov 22, 2022

Thanks @MaheshRavishankar for your time and support here!

@sanjoy sanjoy merged commit f5844cb into llvm:mlir-tcp Nov 22, 2022
navahgar added a commit that referenced this pull request Nov 30, 2022
* [Tcp] Add tanh, add, and broadcast ops
* [Tcp] Add tcp-to-linalg pass

Co-authored-by: Raghavan Raman <raghavan.raman@getcruise.com>
Co-authored-by: Sanjoy Das <sanjoy.das@getcruise.com>
navahgar added a commit that referenced this pull request Nov 30, 2022
* [Tcp] Add tanh, add, and broadcast ops
* [Tcp] Add tcp-to-linalg pass

Co-authored-by: Raghavan Raman <raghavan.raman@getcruise.com>
Co-authored-by: Sanjoy Das <sanjoy.das@getcruise.com>
navahgar added a commit that referenced this pull request Dec 8, 2022
* [Tcp] Add tanh, add, and broadcast ops
* [Tcp] Add tcp-to-linalg pass

Co-authored-by: Raghavan Raman <raghavan.raman@getcruise.com>
Co-authored-by: Sanjoy Das <sanjoy.das@getcruise.com>
navahgar added a commit that referenced this pull request Dec 12, 2022
* [Tcp] Add tanh, add, and broadcast ops
* [Tcp] Add tcp-to-linalg pass

Co-authored-by: Raghavan Raman <raghavan.raman@getcruise.com>
Co-authored-by: Sanjoy Das <sanjoy.das@getcruise.com>
navahgar added a commit that referenced this pull request Jan 5, 2023
* [Tcp] Add tanh, add, and broadcast ops
* [Tcp] Add tcp-to-linalg pass

Co-authored-by: Raghavan Raman <raghavan.raman@getcruise.com>
Co-authored-by: Sanjoy Das <sanjoy.das@getcruise.com>