
[TCP] Enable type promotion and scalar support for elementwise ops #2224

Merged: 8 commits merged into llvm:mlir-tcp on Jun 20, 2023

Conversation

@zezhang (Collaborator) commented on Jun 10, 2023:

  • Add scalar support for existing binary elementwise ops
  • Enable type promotion for existing binary elementwise ops (directly cast the input type to the output type when necessary)
  • Drop the alpha == 1 constraint in add/sub ops

Testing:

./build/bin/llvm-lit externals/llvm-external-projects/torch-mlir-dialects/test/Conversion/TcpToLinalg/ -v
./build/bin/llvm-lit externals/llvm-external-projects/torch-mlir-dialects/test/Dialect/Tcp/ -v
./build/bin/llvm-lit test/Conversion/TorchToTcp/ -v
python -m e2e_testing.main --config=tcp -v -s

@zezhang changed the title from "[TCP] Enable Type promotion and scalar support for elementwise ops" to "[TCP] Enable type promotion and scalar support for elementwise ops" on Jun 10, 2023
@zezhang requested a review from navahgar on June 10, 2023 00:11
if (isa<AtenAddScalarOp>(op) || isa<AtenSubScalarOp>(op)) {
  RankedTensorType tensorResultType =
      RankedTensorType::get({}, adaptor.getOther().getType());
  rhs = torch_to_tcp::scalarToTcpTensor(rewriter, op, tensorResultType,
@asaadaldien (Member) commented on Jun 14, 2023:

We should move scalarToTcpTensor to be its own OpConversionPattern that converts torch.const.int -> tcp.const, and just use adaptor.getOther().
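
For illustration only, here is a minimal sketch of what such a standalone pattern might look like. This is not code from this PR; the tcp::ConstOp builder and the Torch::ConstantIntOp value accessor used below are assumptions.

// Hypothetical sketch: lower torch.constant.int to a rank-0 tcp.const so that
// elementwise patterns can consume adaptor.getOther() directly.
class ConvertTorchConstantInt
    : public OpConversionPattern<Torch::ConstantIntOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Torch::ConstantIntOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // torch.constant.int carries a 64-bit signed value, so materialize it as
    // a rank-0 si64 tensor constant. (Accessor/builder names are illustrative.)
    auto elemType = rewriter.getIntegerType(64, /*isSigned=*/true);
    auto tensorType = RankedTensorType::get({}, elemType);
    auto valueAttr = rewriter.getIntegerAttr(elemType, op.getValue());
    auto value = DenseElementsAttr::get(tensorType, valueAttr);
    rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, tensorType, value);
    return success();
  }
};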

@zezhang (Collaborator, Author) commented on Jun 14, 2023:

Currently, tcp.const only outputs RankedTensorType. For ops that still require a scalar input, always converting torch.const -> tcp.const will cause problems in the default backendTypeConversion pass due to the type mismatch. For elementwise ops with a scalar input, we use the getConstTensor utils, which create a tcp::ConstOp to get the const tensor first, and then cast and broadcast its type and shape. For other places that still require the torch.const input, it will be automatically converted to an arith.const op. Please let me know if this makes sense or if I have misunderstood something.
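
To make the flow just described concrete, here is a rough sketch of the elementwise path. scalarToTcpTensor is from this PR; castToElementType and broadcastToMatch are placeholder names for whatever the cast/broadcast utilities in lib/Conversion/TorchToTcp/Utils.cpp actually expose.

// Sketch only (helper names after step 1 are illustrative placeholders).
// 1. Materialize the scalar operand as a rank-0 tcp.const (via getConstTensor).
Value rhs = torch_to_tcp::scalarToTcpTensor(
    rewriter, op, RankedTensorType::get({}, adaptor.getOther().getType()),
    op.getOther());
// 2. Cast the rank-0 tensor's element type to the result element type.
rhs = castToElementType(rewriter, op, rhs, resultType.getElementType());
// 3. Broadcast the rank-0 tensor to the shape of the tensor operand (lhs).
rhs = broadcastToMatch(rewriter, op, rhs, /*reference=*/lhs);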

Member:

Just to make sure I understand correctly: in which cases do we not want to lower torch.const.int -> tcp.const?

@zezhang (Collaborator, Author) commented on Jun 15, 2023:

I guess any op that still requires a scalar (torch.const) as its input will have this problem. I tried to implement the torch.const -> tcp.const conversion pattern for both int and fp, but it fails a bunch of e2e tests. For example, if we convert torch.const.int -> tcp.const, aten.add.Scalar (which requires a scalar as its second input) will produce the following error:

BackendTypeConversion.cpp:118: auto setupTorchFloatToF64Conversion(mlir::ConversionTarget &, mlir::TypeConverter &)::(anonymous class)::operator()(mlir::OpBuilder &, mlir::Float64Type, mlir::ValueRange, mlir::Location) const: Assertion `inputs[0].getType().isa<Torch::FloatType>()' failed.

@asaadaldien (Member) commented on Jun 15, 2023:

Aren't we converting aten.add.Scalar -> tcp.add and broadcasting the rhs scalar to the same shape as the lhs?

Collaborator:

I agree with @zezhang's point that converting all torch.const -> tcp.const would not work because tcp.const always results in a tensor.

This would be a problem for any op that takes in a scalar. For example, the tcp.broadcast op takes in broadcast sizes as a variadic list of scalars (https://github.com/llvm/torch-mlir/blob/mlir-tcp/externals/llvm-external-projects/torch-mlir-dialects/include/torch-mlir-dialects/Dialect/Tcp/IR/TcpOps.td#L202). If they are constants, we can't use tcp.const to return index types (at least with the current specification of tcp.const).

Are you suggesting we change tcp.const to allow returning scalars as well?

Member:

Are you suggesting we change tcp.const to allow returning scalars as well?

Yes, I can see us writing less conversion pattern code if we allow tcp.const to return scalar or tensor types, e.g.

%0 = tcp.constant 2 : f32
%1 = tensor.from_elements %0 : tensor<f32>

and we will add a folding pattern to produce:

%0 = tcp.constant 2 : tensor<f32>

and so on. Otherwise, every op will need to handle the scalar operand case.
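
Purely to illustrate the folding idea above: this assumes a hypothetical tcp.const that can also return scalars, and the pattern below (including the tcp::ConstOp accessor and builder it uses) is a sketch, not existing code.

// Hypothetical canonicalization: tensor.from_elements(%s), where %s is a
// scalar tcp.const, folds into a single tensor-typed tcp.const.
struct FoldScalarConstIntoFromElements
    : public OpRewritePattern<tensor::FromElementsOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::FromElementsOp op,
                                PatternRewriter &rewriter) const override {
    // Only handle the single-element case sketched in the comment above.
    if (op.getElements().size() != 1)
      return failure();
    auto constOp = op.getElements()[0].getDefiningOp<tcp::ConstOp>();
    if (!constOp)
      return failure();
    auto resultType = op.getType().cast<RankedTensorType>();
    // getValueAttr() stands in for however the extended tcp.const would
    // expose its scalar value.
    auto splat = DenseElementsAttr::get(resultType, constOp.getValueAttr());
    rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, splat);
    return success();
  }
};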

Collaborator:

Agree that it would help reduce the conversion pattern code. But can we base design decisions on that factor?

So far, all TCP ops return only tensors; none of them produce scalars. My worry is that this might be a bigger change than just updating tcp.const, and it might have other design implications. We should probably consider this carefully and do it in a separate PR, if there is an agreement. Wdyt?

Member:

But can we base design decisions on that factor?

My concern is that the current design for converting TorchOp -> TCPOp requires special handling based on whether operands are scalars or tensors, which seems like a gap in the TCP dialect and how it models constants.

We should probably carefully consider this and do it in a separate PR, if there is an agreement. Wdyt?

Yes, my comment here is more about understanding the current Torch -> TCP approach so I can provide useful feedback on this PR. It's non-blocking; indeed, we shouldn't address design changes here!

Collaborator:

We should probably discuss this further separately. I just want to add one more point here:

  • One of the design decisions was that scalars used to represent indices / dims must be "scalars", and all other scalars must be "rank-0 tensors".
    Given this design, we wouldn't know when processing a torch.const whether it is going to represent an index or a value (maybe we can classify based on type: an index type always represents an index and other types represent values, but I'm not sure whether that is too fragile). So, I'm not sure it is possible to have a conversion pattern for torch.const.

The alternative is to question that design decision itself and have a discussion around that.

(Resolved review thread on lib/Conversion/TorchToTcp/Utils.cpp)

@navahgar (Collaborator) left a comment:

Made one pass over it. PTAL.

(Five resolved review threads on lib/Conversion/TorchToTcp/Utils.cpp)
@zezhang requested a review from navahgar on June 16, 2023 04:42
(Resolved review threads on lib/Conversion/TorchToTcp/Utils.cpp and lib/Conversion/TorchToTcp/Elementwise.cpp)
rhs = torch_to_tcp::scalarToTcpTensor(rewriter, op, tensorResultType,
                                      op.getOther());
if (adaptor.getOther().getType().template isa<mlir::FloatType>())
  // FP rhs is treated as fp64
Collaborator:

Why is this the case? What about f32 additions, for example?

@zezhang (Collaborator, Author) commented on Jun 17, 2023:

Looks like torch.constant.float is treated as a 64-bit floating-point value, so I just use fp64 as its type to perform the cast if necessary.

    outputType, rhs,
    resultType.getElementType());
else if (adaptor.getOther().getType().template isa<mlir::IntegerType>())
  // INT rhs is treated as si64
Collaborator:

Why is this getting cast to si64? There are no signed integers in TCP, right?

@zezhang (Collaborator, Author) commented on Jun 17, 2023:

Looks like torch.constant.int is treated as a 64-bit signed value, so I just use si64 as its type to perform the cast if necessary.
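
In other words, the scalar operand arrives already widened by the frontend (f64 for torch.constant.float, si64 for torch.constant.int) and is cast to the elementwise result element type only when they differ. A rough sketch, where castScalarTensorToType is a placeholder rather than the exact utility this PR adds:

// Sketch: pick the source element type based on the Torch scalar kind, then
// cast the rank-0 tensor only when it differs from the result element type.
Type srcElemType;
if (adaptor.getOther().getType().isa<mlir::FloatType>())
  srcElemType = rewriter.getF64Type();                             // torch.constant.float
else
  srcElemType = rewriter.getIntegerType(64, /*isSigned=*/true);    // torch.constant.int

if (srcElemType != resultType.getElementType())
  rhs = castScalarTensorToType(rewriter, op, rhs, srcElemType,
                               resultType.getElementType());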

(Resolved review thread on lib/Conversion/TorchToTcp/Elementwise.cpp)
@zezhang requested a review from navahgar on June 17, 2023 03:17
@navahgar (Collaborator) left a comment:

Mostly LGTM. Just one final comment.

(Two resolved review threads on lib/Conversion/TorchToTcp/Elementwise.cpp)
@navahgar (Collaborator) left a comment:

LGTM. Thanks for addressing all the comments.

Please take a look at the build failures.

@zezhang force-pushed the ze.zhang/scalar_binary_ops branch from df2e5a3 to 268f0e9 on June 20, 2023 22:00
@zezhang merged commit 19515a5 into llvm:mlir-tcp on Jun 20, 2023
@zezhang deleted the ze.zhang/scalar_binary_ops branch on June 20, 2023 22:12