
[TOSA] Add torch.prim.NumToTensor.Scalar float support #1802

Merged (1 commit) on Apr 18, 2023

Conversation

@AmosLewis (Collaborator) commented Jan 16, 2023

Found this in the OPT model: nod-ai/SHARK-Studio#865

Also found this f64 support issue in GPT2 and distilGPT2 (nod-ai/SHARK-Studio#494) with transformers 4.25.1.

This patch needs f64 support in TOSA: https://reviews.llvm.org/D142599

f64 support test file and debug output: https://gist.github.com/AmosLewis/fca6b0d16ee325fcf7ee400459f4fd40

Related issues: #1615

@AmosLewis (Collaborator, Author) commented Jan 16, 2023

@eric-k256 @ramiro050 @silvasean, is there any possibility of adding f64 to TOSA? I got this from the GPT2/distilGPT2 TorchScript with the transformers package 4.25.1. I have no idea how to work around it other than downgrading transformers to 4.21.2.

@ramiro050 (Collaborator) left a comment


I would've expected the e2e test for NumToTensor to pass on TOSA now

@eric-k256 (Collaborator):

For TOSA, we're looking into adding something like a 'compiler profile' that would allow a wider set of data types, with a corresponding validation pass that would fail networks that are using f64/i64 with devices that don't expect it. If we added this to TOSA, and expanded the types in the dialect, would this be enough for you? Also, are your passes that consume TOSA capable of handling f64? I'd still recommend against f64 unless it is requested by the network developer, as it's likely to be slower on almost every device.
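A rough illustration of the kind of validation pass described here (this is not an existing pass; the pass name, structure, and diagnostic text are hypothetical) could simply walk TOSA ops and reject 64-bit tensor types that the selected profile does not allow:

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

// Hypothetical sketch: fail compilation when TOSA ops use f64/i64 tensors
// and the target profile does not support them.
struct RejectWideTosaTypesPass
    : public PassWrapper<RejectWideTosaTypesPass, OperationPass<>> {
  void runOnOperation() override {
    getOperation()->walk([&](Operation *op) {
      // Only look at ops from the tosa dialect.
      if (!op->getDialect() || op->getDialect()->getNamespace() != "tosa")
        return;
      auto isWide = [](Type type) {
        auto tensorTy = type.dyn_cast<TensorType>();
        if (!tensorTy)
          return false;
        Type elemTy = tensorTy.getElementType();
        return elemTy.isF64() || elemTy.isInteger(64);
      };
      if (llvm::any_of(op->getOperandTypes(), isWide) ||
          llvm::any_of(op->getResultTypes(), isWide)) {
        op->emitError("64-bit tensor type not allowed by the selected TOSA profile");
        signalPassFailure();
      }
    });
  }
};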

@AmosLewis marked this pull request as ready for review January 18, 2023 03:27
@AmosLewis (Collaborator, Author) commented Jan 18, 2023

For TOSA, we're looking into adding something like a 'compiler profile' that would allow a wider set of data types, with a corresponding validation pass that would fail networks that are using f64/i64 with devices that don't expect it. If we added this to TOSA, and expanded the types in the dialect, would this be enough for you? Also, are your passes that consume TOSA capable of handling f64? I'd still recommend against f64 unless it is requested by the network developer, as it's likely to be slower on almost every device.

What I need, at a minimum, is for a cast to work around this issue. For i64 I cast from i64 to i32 with tosa::cast, and at the end of the rewrite pattern I cast back from i32 to i64; that works (see the sketch after the bug output below). But for f64 the cast fails, as you can see in the following.
With the tosa::cast f32-to-f64 code (since deleted from this patch):

  auto outElemTy = resultType.getElementType();
  if (outElemTy.isF64()) {
    auto resultF32 = tosa::getConstTensor<float>(rewriter, op, floatValue, {}).value();
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, resultF32);
  }

Bug Output:

// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface::Trait<Empty>)
'tosa.cast' op result #0 must be tensor of number values, but got 'tensor<f64>'
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
  %0 = "torch.constant.float"() {value = 8.000000e+00 : f64} : () -> !torch.float
  %1 = "builtin.unrealized_conversion_cast"(%0) : (!torch.float) -> f64
  %2 = "tosa.const"() {value = dense<8.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %3 = "tosa.cast"(%2) : (tensor<f32>) -> tensor<f64>
  %4 = "torch.prim.NumToTensor.Scalar"(%0) : (!torch.float) -> !torch.vtensor<[],f64>
  "func.return"(%4) : (!torch.vtensor<[],f64>) -> ()
}) {function_type = () -> !torch.vtensor<[],f64>, sym_name = "torch.prim.NumToTensor.Scalar"} : () -> ()


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x55a53fab0c30) {
  "func.return"(%4) : (!torch.vtensor<[],f64>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.from_builtin_tensor'(0x55a53fb17c60)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
/tmp/NumToTensor.mlir:3:8: error:   'tosa.cast' op result #0 must be tensor of number values, but got 'tensor<f64>'
  %1 = "torch.prim.NumToTensor.Scalar"(%float8.000000e00) : (!torch.float) -> !torch.vtensor<[],f64>
       ^
/tmp/NumToTensor.mlir:3:8: note: see current operation: %2 = "tosa.cast"(%1) : (tensor<f32>) -> tensor<f64>
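For reference, the i64 round-trip mentioned above looks roughly like this (a minimal sketch; intValue is a hypothetical variable holding the matched integer constant, and the other names follow the snippet above):

  // Sketch of the i64 workaround: materialize the constant as i32, which
  // TOSA accepts, then tosa::cast back to the i64 result type. Note the
  // static_cast can lose precision for values outside the i32 range.
  if (outElemTy.isInteger(64)) {
    auto resultI32 =
        tosa::getConstTensor<int32_t>(rewriter, op,
                                      {static_cast<int32_t>(intValue)}, {})
            .value();
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, resultI32);
  }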

@AmosLewis (Collaborator, Author):

I would've expected the e2e test for NumToTensor to pass on TOSA now

I guess they pass because, by default, f32 is generated, so there is no f64 issue.

@AmosLewis (Collaborator, Author) commented Jan 18, 2023

Let's at least enable the f32 for this patch.

@AmosLewis changed the title from "[TOSA] Add torch.prim.NumToTensor.Scalar float support" to "[TOSA] Add torch.prim.NumToTensor.Scalar float32 support" on Jan 18, 2023
@AmosLewis marked this pull request as draft January 18, 2023 20:32
@AmosLewis (Collaborator, Author) commented Jan 24, 2023

For TOSA, we're looking into adding something like a 'compiler profile' that would allow a wider set of data types, with a corresponding validation pass that would fail networks that are using f64/i64 with devices that don't expect it. If we added this to TOSA, and expanded the types in the dialect, would this be enough for you? Also, are your passes that consume TOSA capable of handling f64? I'd still recommend against f64 unless it is requested by the network developer, as it's likely to be slower on almost every device.

@eric-k256 Is there any possibility of adding f64, at least in tosa.cast? Something like:

%2 = "tosa.cast"(%1) : (tensor<f32>) -> tensor<f64>
%3 = "tosa.cast"(%2) : (tensor<f64>) -> tensor<f32>

For the new Hugging Face model facebook/opt-125m that I am working on, I definitely have no way to work around the missing f64 support.
Here is the opt_torchbackend.mlir link: https://gist.github.com/AmosLewis/3faccdf32c91d30f21daed12bf7197b4#file-opt_torchbackend-mlir-L233:~:text=%25210%20%3D%20torch.prim.NumToTensor.Scalar%20%25float%2D3.402820e38%20%3A%20!torch.float%20%2D%3E%20!torch.vtensor%3C%5B%5D%2Cf64%3E%20loc(%23loc31)

I see that the TOSA dialect definition already includes I64. Is there any reason we cannot add F64 as well, right here:
https://github.com/llvm/llvm-project/blob/9dea83d4af0b532373f8a0384ce7a873ebf18e41/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td#L88

def Tosa_Float : AnyTypeOf<[
                            F64,
                            F32,
                            F16,
                            BF16]>;

@AmosLewis changed the title from "[TOSA] Add torch.prim.NumToTensor.Scalar float32 support" to "[TOSA] Add torch.prim.NumToTensor.Scalar float support" on Jan 24, 2023
@eric-k256 (Collaborator):

I would prefer to start with the smallest possible scope for f64. If we only want to support it for cast, we could create a new type used for the inputs and outputs of tosa.cast instead of Tosa_Tensor; it could be something like Tosa_CastableTensor, which would include all the types from Tosa_Tensor plus f64. That would allow us to manage the use of f64, avoiding it in most cases.

@AmosLewis (Collaborator, Author):

The f64 support patch for TOSA: https://reviews.llvm.org/D142599

@AmosLewis (Collaborator, Author):

I would prefer to start with the smallest possible scope for f64. If we only want to support it for cast, we could create a new type used for the inputs and outputs of tosa.cast instead of Tosa_Tensor; it could be something like Tosa_CastableTensor, which would include all the types from Tosa_Tensor plus f64. That would allow us to manage the use of f64, avoiding it in most cases.

Done. Please review it in LLVM.

@AmosLewis (Collaborator, Author) commented Feb 2, 2023

module attributes {torch.debug_module_name = "NumToTensorFloatModule"} {
  func.func @forward() -> !torch.vtensor<[],f64> {
    %float1.000000e00 = torch.constant.float 1.000000e+00 loc(#loc1)
    %0 = torch.prim.NumToTensor.Scalar %float1.000000e00 : !torch.float -> !torch.vtensor<[],f64> loc(#loc2)
    return %0 : !torch.vtensor<[],f64> loc(#loc)
  } loc(#loc)
} loc(#loc)

Running

torch-mlir-opt -convert-torch-to-tosa /tmp/NumToTensorFloatModule.mlir -mlir-print-ir-after-all -mlir-disable-threading -debug

produces:

module attributes {torch.debug_module_name = "NumToTensorFloatModule"} {
  func.func @forward() -> !torch.vtensor<[],f64> {
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %0 = "tosa.const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
    %1 = "tosa.cast"(%0) : (tensor<f32>) -> tensor<f64>
    %2 = torch_c.from_builtin_tensor %1 : tensor<f64> -> !torch.vtensor<[],f64>
    return %2 : !torch.vtensor<[],f64>
  }
}
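For reference, a conversion pattern that produces IR like this could look roughly like the following sketch. This is illustrative only, not a copy of the patch: the actual code in lib/Conversion/TorchToTosa/TorchToTosa.cpp may differ, and the fragment assumes it sits inside a conversion pattern for Torch::PrimNumToTensorScalarOp with the usual type converter available.

  // Sketch: lower torch.prim.NumToTensor.Scalar with a constant float operand
  // to a rank-0 tosa.const (f32), casting to the result element type if needed.
  LogicalResult matchAndRewrite(Torch::PrimNumToTensorScalarOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
    auto resultType =
        getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();

    double floatValue;
    if (!matchPattern(op.getA(), m_TorchConstantFloat(&floatValue)))
      return rewriter.notifyMatchFailure(op, "only constant float scalars are handled here");

    // Rank-0 f32 constant holding the scalar value.
    Value constF32 =
        tosa::getConstTensor<float>(rewriter, op,
                                    {static_cast<float>(floatValue)}, {})
            .value();

    if (resultType.getElementType().isF32()) {
      rewriter.replaceOp(op, constF32);
    } else {
      // e.g. f64 results rely on tosa.cast accepting f64
      // (https://reviews.llvm.org/D142599).
      rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, constF32);
    }
    return success();
  }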

@AmosLewis force-pushed the numtotensor branch 2 times, most recently from 1b425f4 to 07c391c on February 14, 2023 17:53
@AmosLewis force-pushed the numtotensor branch 2 times, most recently from 1245a18 to b386270 on March 3, 2023 20:08
@AmosLewis marked this pull request as ready for review March 3, 2023 20:33
@AmosLewis (Collaborator, Author) commented Mar 3, 2023

Since https://reviews.llvm.org/rGa2dcd994a7f8cc33640f58105276b78acf3483e5 has been merged and torch-mlir has been updated, this patch should be OK to merge now. @eric-k256 @ramiro050, please review and approve.

@sjw36 commented Mar 3, 2023

I am trying to solve a similar problem with f64 vtensor.literals in the dynamo flow. The rewrite already has to convert signed integers to unsigned, so it would be pretty simple to also convert f64 to f32.
And in general, if TOSA is not going to support f64, would it be valid to make the typeConverter convert it to f32?
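A rough sketch of what such a type-converter rule could look like (illustrative only; this is not what the present patch does, and the function name is made up):

#include <optional>

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical: teach the Torch-to-TOSA type converter to narrow f64 tensor
// types to f32, so f64 literals become f32 tensors on the TOSA side.
static void addF64ToF32TensorConversion(TypeConverter &typeConverter) {
  typeConverter.addConversion([](RankedTensorType type) -> std::optional<Type> {
    if (!type.getElementType().isF64())
      return type;
    return RankedTensorType::get(type.getShape(),
                                 Float32Type::get(type.getContext()));
  });
}

Note that downstream consumers would then see f32 wherever the original program had f64, so precision-sensitive models would be affected by such a narrowing.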

@AmosLewis (Collaborator, Author) commented Mar 4, 2023

I am trying to solve a similar problem with f64 vtensor.literals in the dynamo flow. The rewrite already has to convert signed integers to unsigned, so it would be pretty simple to also convert f64 to f32. And in general, if TOSA is not going to support f64, would it be valid to make the typeConverter convert it to f32?

A one-line change to the tosa::ConstOp definition, making its result use Tosa_Tensor_Cast (which includes f64), should be enough:

def Tosa_ConstOp : Tosa_Op<"const", ...
  let results = (outs
    Tosa_Tensor_Cast:$output
  );

https://reviews.llvm.org/D145336

@AmosLewis (Collaborator, Author):

@eric-k256 Please review and approve https://reviews.llvm.org/D145336.

@AmosLewis force-pushed the numtotensor branch 2 times, most recently from 1840b15 to 594808e on April 18, 2023 02:08
@AmosLewis (Collaborator, Author):

https://reviews.llvm.org/D145336 has been merged and torch-mlir has been updated. Please review, @ramiro050.

lib/Conversion/TorchToTosa/TorchToTosa.cpp: 3 review threads (outdated, resolved)
@AmosLewis merged commit 8d25dd4 into llvm:main on Apr 18, 2023
mgehre-amd pushed a commit to Xilinx/torch-mlir that referenced this pull request May 11, 2023
gpetters94 pushed a commit to gpetters94/mlir-npcomp that referenced this pull request Jul 7, 2023
@AmosLewis deleted the numtotensor branch January 19, 2024 19:15