[TCP] Enable type promotion and scalar support for elementwise ops #2224
Conversation
if (isa<AtenAddScalarOp>(op) || isa<AtenSubScalarOp>(op)) {
  RankedTensorType tensorResultType =
      RankedTensorType::get({}, adaptor.getOther().getType());
  rhs = torch_to_tcp::scalarToTcpTensor(rewriter, op, tensorResultType,
We should move scalarToTcpTensor to be its own OpConversionPattern that converts torch.const.int -> tcp.const, and just use adaptor.getOther.
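For concreteness, a pattern along those lines might look roughly like the sketch below; the builders and attribute accessors used here are assumptions for illustration, not code from this PR.

// Hedged sketch only: the Torch::ConstantIntOp accessors and the tcp::ConstOp
// builder are assumed, not taken from this PR.
class ConvertConstantIntOp : public OpConversionPattern<Torch::ConstantIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // tcp.const only produces tensors, so the 64-bit integer value is wrapped
    // in a rank-0 tensor constant.
    auto resultType = RankedTensorType::get({}, rewriter.getIntegerType(64));
    auto valueAttr = DenseElementsAttr::get(
        resultType, op.getValueAttr().cast<IntegerAttr>().getInt());
    rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, valueAttr);
    return success();
  }
};

Note that the tensor-typed result here is exactly what the next reply flags: consumers that still expect a Torch scalar would see a type mismatch.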
Currently, tcp.const only outputs RankedTensorType. For ops that still require the scalar input, always converting torch.const -> tcp.const will cause problems in the default backendTypeConversion pass due to the type mismatch. For elementwise ops with a scalar input, we use the getConstTensor util, which creates a tcp::ConstOp to get the const tensor first, then cast and broadcast its type and shape. For other places that still require the torch.const input, it will be automatically converted to an arith.const op. Please let me know if this makes sense or if I have misunderstood something.
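Roughly, the flow for the elementwise-with-scalar case is sketched below; castToElementType and broadcastToResultShape are placeholder names for illustration, not the exact utilities used in this PR.

// Illustrative outline of the described flow (rank-0 const -> dtype cast ->
// broadcast). The two helpers below are placeholders, not the PR's utilities.
Value lowerScalarOperand(ConversionPatternRewriter &rewriter, Operation *op,
                         Value torchScalar, Type scalarType,
                         RankedTensorType resultType) {
  // 1. Materialize the scalar as a rank-0 tensor backed by tcp.const.
  auto rank0Type = RankedTensorType::get({}, scalarType);
  Value rhs =
      torch_to_tcp::scalarToTcpTensor(rewriter, op, rank0Type, torchScalar);
  // 2. Cast the element type to match the result dtype, if they differ.
  rhs = castToElementType(rewriter, op, rhs, resultType.getElementType());
  // 3. Broadcast the rank-0 tensor to the result shape before the elementwise op.
  return broadcastToResultShape(rewriter, op, rhs, resultType);
}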
Just to make sure I understand correctly: in which cases do we not want to lower torch.const.int -> tcp.const?
I guess any op that still requires a scalar (torch.const) as its input will have this problem. I tried to implement the torch.const -> tcp.const conversion pattern for both int and fp, but it fails a bunch of e2e tests. For example, if we convert torch.const.int -> tcp.const, aten.add.scalar (which requires a scalar as its second input) will produce the following error:
BackendTypeConversion.cpp:118: auto setupTorchFloatToF64Conversion(mlir::ConversionTarget &, mlir::TypeConverter &)::(anonymous class)::operator()(mlir::OpBuilder &, mlir::Float64Type, mlir::ValueRange, mlir::Location) const: Assertion `inputs[0].getType().isa<Torch::FloatType>()' failed.
Aren't we converting aten.add.scalar -> tcp.add and broadcasting the rhs scalar to the same shape as the lhs?
I agree with @zezhang's point that converting all torch.const -> tcp.const would not work because tcp.const always results in a tensor.
This would be a problem for any op that takes in a scalar. For example, the tcp.broadcast op takes in broadcast sizes as a variadic list of scalars (https://github.com/llvm/torch-mlir/blob/mlir-tcp/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpOps.td#L202). If they are constants, we can't be using tcp.const to return index types (at least with the current specification of tcp.const).
Are you suggesting we change tcp.const to allow returning scalars as well?
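To make this concrete, the broadcast sizes are typically index-typed SSA values (e.g. tensor.dim results), which a tensor-only tcp.const cannot supply. A rough sketch follows; the tcp::BroadcastOp builder arguments and the axes attribute kind are assumptions based on the linked TcpOps.td.

// Illustrative only: the new dim sizes are index-typed values taken from the
// lhs; the tcp::BroadcastOp builder signature is assumed, not verified.
Value broadcastRank0ToLhsShape(OpBuilder &builder, Location loc, Value lhs,
                               Value rank0Rhs, RankedTensorType resultType) {
  SmallVector<Value> dimSizes;
  SmallVector<int64_t> axes;
  for (int64_t i = 0, rank = resultType.getRank(); i < rank; ++i) {
    dimSizes.push_back(builder.create<tensor::DimOp>(loc, lhs, i));
    axes.push_back(i);
  }
  return builder.create<tcp::BroadcastOp>(loc, resultType, rank0Rhs, dimSizes,
                                          builder.getI64ArrayAttr(axes));
}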
Are you suggesting we change tcp.const to allow returning scalars as well?
Yes, I can see us writing less conversion pattern code if we allow tcp.const to return scalar or tensor types, e.g.
%0 = tcp.constant 2 : f32
%1 = tensor.from_elements %0 : tensor<f32>
We would add a folding pattern to turn this into:
%0 = tcp.constant 2 : tensor<f32>
and so on. Otherwise, we will need to handle the scalar operand case for every op.
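For illustration, such a folding pattern could be sketched as below; it assumes the proposed (not yet existing) scalar-returning form of tcp.const and a getValueAttr() accessor on it.

// Hedged sketch of the proposed fold: a scalar tcp.const feeding
// tensor.from_elements collapses into a single tensor-typed tcp.const.
struct FoldScalarConstIntoFromElements
    : public OpRewritePattern<tensor::FromElementsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    // Only the single-element case from the example above.
    if (op.getElements().size() != 1)
      return failure();
    auto constOp = op.getElements()[0].getDefiningOp<tcp::ConstOp>();
    if (!constOp)
      return failure();
    auto resultType = op.getType().cast<RankedTensorType>();
    // Re-materialize the scalar constant directly as a tensor constant.
    SmallVector<Attribute> elements{constOp.getValueAttr()};
    auto denseValue = DenseElementsAttr::get(resultType, elements);
    rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseValue);
    return success();
  }
};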
Agree that it would help reduce the conversion pattern code. But can we base design decisions on that factor?
So far, all TCP ops return only tensors; none of them produce scalars. My worry is that it might be a bigger change than just updating tcp.const, which might have other design implications. We should probably carefully consider this and do it in a separate PR, if there is an agreement. Wdyt?
But can we base design decisions on that factor?
My concern is that the current design for converting TorchOp -> TCPOp requires special handling of TCPOp based on whether its operands are scalars or tensors, which seems like a gap in the TCP dialect and how it models constants.
We should probably carefully consider this and do it in a separate PR, if there is an agreement. Wdyt?
Yes, my comment here is more about understanding the current Torch -> TCP approach so I can provide useful feedback on this PR. It's non-blocking; indeed, we shouldn't address design changes here!
We should probably discuss this further separately. I just want to add one more point here:
- One of the design decisions was that scalars used to represent indices / dims must be "scalars", and all other scalars must be "rank-0 tensors".
Given this design, we wouldn't know, when processing a torch.const, whether it is going to represent an index or a value (maybe we can classify based on type: the index type always represents indices and everything else is a value, though I'm not sure whether that is too fragile). So, I'm not sure it is possible to have a conversion pattern for torch.const.
The alternative is to question that design decision itself and have a discussion around that.
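Purely to make that heuristic concrete (not something this PR does), the classification would amount to something like:

// Illustrative sketch of the type-based classification mentioned above:
// index-typed constants stay scalars (indices / dims), everything else would
// be modeled as a rank-0 tensor, per the stated design decision.
static Type classifyConstantType(Type convertedType) {
  if (convertedType.isa<IndexType>())
    return convertedType; // keep as a scalar index
  return RankedTensorType::get({}, convertedType); // value -> rank-0 tensor
}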
Made one pass over it. PTAL.
rhs = torch_to_tcp::scalarToTcpTensor(rewriter, op, tensorResultType,
                                      op.getOther());
if (adaptor.getOther().getType().template isa<mlir::FloatType>())
  // FP rhs is treated as fp64
Why is this the case? What about f32 additions, for example?
Looks like torch.constant.float will be treated as a 64-bit floating point value. So I just use fp64 as its type and perform the cast if necessary.
                                      outputType, rhs,
                                      resultType.getElementType());
else if (adaptor.getOther().getType().template isa<mlir::IntegerType>())
  // INT rhs is treated as si64
Why is this getting cast to si64? There are no signed integers in TCP, right?
Looks like torch.constant.int will be treated as a 64-bit signed value. So I just use si64 as its type and perform the cast if necessary.
Mostly LGTM. Just one final comment.
LGTM. Thanks for addressing all the comments.
Please take a look at the build failures.
Testing: