Add tanh, add, and broadcast ops + Lowering to linalg #1595
Conversation
Thanks! First round of comments. Mostly small things. I don't fully follow the logic of the broadcast lowering; some comments there would be nice.
} // namespace mlir

#endif // TORCH_MLIR_DIALECTS_CONVERSION_PASSES_H
Nit: Add a newline at end of file, here and everywhere else.
AllElementTypesMatch<["in", "out"]>,
PredOpTrait<
    "attribute `axes` must be in increasing order",
    CPred<"::llvm::is_sorted(getAxes(), "
I'd move this into utility functions and just call those from here, instead of doing it in TableGen.
Flagging that this is still here.
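For illustration, a minimal sketch of that refactor; the helper name isSortedAxes and its location are hypothetical, and it assumes getAxes() yields an llvm::ArrayRef<int64_t> (if it returns an ArrayAttr, the helper would take that instead):

// In a shared C++ utils header (name and location hypothetical).
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

namespace mlir {
namespace tcp {

// Returns true if `axes` is sorted in increasing order.
inline bool isSortedAxes(llvm::ArrayRef<int64_t> axes) {
  return llvm::is_sorted(axes);
}

} // namespace tcp
} // namespace mlir

The TableGen predicate would then shrink to something like CPred<"::mlir::tcp::isSortedAxes(getAxes())">.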
// Add indexing maps for all the tensor operands and for the result.
for (size_t i = 0; i < tensorOperands.size(); ++i) {
  indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank));
All AffineMaps here are uniqued. You could just do:

SmallVector<AffineMap> indexingMaps(tensorOperands.size() + 1, b.getMultiDimIdentityMap(resultRank));
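For context on why that is safe: AffineMap is a small, context-uniqued value handle, so the fill constructor just copies a pointer-sized object N+1 times. A commented version of the same one-liner (names follow the excerpt above):

// One identity map per tensor operand plus one for the result.
// Copying the map is cheap: it is a pointer-sized, context-uniqued handle.
SmallVector<AffineMap> indexingMaps(
    tensorOperands.size() + 1, b.getMultiDimIdentityMap(resultRank));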
}

template <typename TcpOpT>
Value createLinalgPayloadForElementwiseOp(
I am not entirely sure why this is templated. It would probably be the same if you use Operation *op instead of TcpOpT op?
Maybe also make this return FailureOr<Value>. Not a fan of returning nullptrs around.
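A sketch of the two suggestions combined; the parameter list and the per-op dispatch below are illustrative, not the patch's actual code:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Support/LogicalResult.h"

// Untemplated payload builder: dispatch on the op kind at runtime and signal
// failure explicitly instead of returning a null Value.
static FailureOr<Value>
createLinalgPayloadForElementwiseOp(Operation *op, OpBuilder &b, Location loc,
                                    ValueRange payloadArgs) {
  if (isa<tcp::TanhOp>(op))
    return b.create<math::TanhOp>(loc, payloadArgs[0]).getResult();
  if (isa<tcp::AddOp>(op))
    return b.create<arith::AddFOp>(loc, payloadArgs[0], payloadArgs[1])
        .getResult();
  op->emitError("unsupported op in elementwise payload");
  return failure();
}

Callers can then branch on failed(result) instead of checking for a null Value.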
                           ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();

  target.addIllegalOp<TanhOp>();
Wouldn't it be better to just add the TCP dialect as illegal?
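That is, something like this one-liner, assuming the dialect class is tcp::TcpDialect (a sketch, not the patch's actual code):

// Mark every op in the TCP dialect illegal instead of listing ops one by one.
target.addIllegalDialect<tcp::TcpDialect>();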
namespace {

void getValuesFromIndexArrayAttribute(ArrayAttr attr,
You can just create a SmallVector<int64_t> within this method and return it, instead of returning by reference? I think most compilers optimize the return as a move.
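A sketch of the return-by-value shape, assuming the ArrayAttr holds IntegerAttrs as the helper's name suggests:

// Collect the integer values of an ArrayAttr of IntegerAttrs.
// Returning the vector by value is cheap: NRVO or a move applies.
static SmallVector<int64_t> getValuesFromIndexArrayAttribute(ArrayAttr attr) {
  SmallVector<int64_t> values;
  for (Attribute a : attr)
    values.push_back(a.cast<IntegerAttr>().getInt());
  return values;
}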
SmallVector<Value> resultDimSizes;
SmallVector<AffineExpr> exprs;
int64_t pos = 0;
for (int64_t i = 0; i < resultRank; ++i) {
Some comments here would be nice.
Didn't add explicit comments, but refactored the code to be more understandable. LMK if this still needs comments.
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
namespace tcp_to_linalg { |
If you plan to move this into MLIR core, maybe this should follow LLVM coding conventions, i.e. camel case instead of snake case.
@MaheshRavishankar Addressed comments, PTAL.
Thanks! Next round of comments, but it mostly looks fine now.
Please make sure all files have a newline at end of file.
// the tensorOperands, since all the operands are expected to have the same
// shape.
auto tensorOperand = tensorOperands[0];
for (int64_t i = 0; i < resultRank; ++i)
Please add { } around bodies that span multiple lines (even if it is a single statement).
Also, you can just use the createDimValues method (here: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h#L31). It returns a SmallVector<OpFoldResult> that you can pass directly to the build method of tensor.empty().
Nice! Changed to use createDimValues.
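For reference, a sketch of the resulting pattern as of the linked revision; it assumes tensorOperand is a ranked tensor Value, b is an OpBuilder, and elementType is the result element type:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"

// createDimValues produces one OpFoldResult per dimension (a constant for
// static dims, a tensor.dim op for dynamic ones), which the tensor.empty
// builder accepts directly.
SmallVector<OpFoldResult> sizes =
    tensor::createDimValues(b, loc, tensorOperand);
Value emptyTensor = b.create<tensor::EmptyOp>(loc, sizes, elementType);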
    pos < static_cast<int64_t>(axes.size()) && axes[pos] == i;
if (isOutputDimBroadcasted) {
  resultDimSizes.push_back(op->getOperands()[pos + 1]);
  inputIndexingMapExprs.push_back(b.getAffineConstantExpr(0));
This is the design for broadcast chosen here, but just pointing out that having 0 in the indexing maps of linalg ops is kind of bad form.
Given the current definition of tcp.broadcast, is there a better way to lower it to linalg?
No, there isn't. Just pointing out a (minor) downside of the current design.
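To make the downside concrete: broadcasting along axis 0 of a rank-2 result gives the input an indexing map like (d0, d1) -> (0, d1). The constant 0 means the map is no longer a pure permutation/projection, which some linalg transforms handle less gracefully. A sketch of how such a map is assembled, mirroring the excerpt above:

// Input indexing map for a broadcast along dim 0: (d0, d1) -> (0, d1).
SmallVector<AffineExpr> inputExprs = {b.getAffineConstantExpr(0),
                                      b.getAffineDimExpr(1)};
AffineMap inputMap = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
                                    inputExprs, b.getContext());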
Looks good to me. Don't know if anyone else needs to sign off as well.
  TorchMLIRTcpDialect
)

torch_mlir_dialects_target_includes(TorchMLIRTcpToLinalg)
Nit: newline at EOF.
Thanks @MaheshRavishankar for your time and support here!
* [Tcp] Add tanh, add, and broadcast ops
* [Tcp] Add tcp-to-linalg pass

Co-authored-by: Raghavan Raman <raghavan.raman@getcruise.com>
Co-authored-by: Sanjoy Das <sanjoy.das@getcruise.com>
This PR adds tanh, add, and broadcast ops to TCP and implements a lowering to linalg.
This PR was carved out of a larger set of commits that includes a full Torch->TCP->Linalg e2e flow for these ops. This larger flow is here for reference.
(@navahgar is the original author of this code, I'm sending it out on his behalf.)