-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir] Support emit fp16 and bf16 type to cpp #105803
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-emitc Author: Jianjian Guan (jacquesguan) ChangesFull diff: https://github.com/llvm/llvm-project/pull/105803.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e6f1618cc26116..8555e82002d56b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -116,6 +116,7 @@ bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
+ case 16:
case 32:
case 64:
return true;
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index c043582b7be9c6..aa45e7c9d7f757 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1640,6 +1640,11 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
}
if (auto fType = dyn_cast<FloatType>(type)) {
switch (fType.getWidth()) {
+ case 16:
+ if (llvm::isa<Float16Type>(type))
+ return (os << "_Float16"), success();
+ else
+ return (os << "__bf16"), success();
case 32:
return (os << "float"), success();
case 64:
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
index ef0e71ee8673b7..b3eebaf8a1ef1e 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
@@ -16,39 +16,6 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
// -----
-func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
- // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
- %t = arith.fptosi %arg0 : bf16 to i32
- return %t: i32
-}
-
-// -----
-
-func.func @arith_cast_f16(%arg0: f16) -> i32 {
- // expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
- %t = arith.fptosi %arg0 : f16 to i32
- return %t: i32
-}
-
-
-// -----
-
-func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
- // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
- %t = arith.sitofp %arg0 : i32 to bf16
- return %t: bf16
-}
-
-// -----
-
-func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
- // expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
- %t = arith.sitofp %arg0 : i32 to f16
- return %t: f16
-}
-
-// -----
-
func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f32 to i1
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 836d8aedefc1f0..14977bfb3e2fd9 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -46,14 +46,6 @@ memref.global "nested" constant @nested_global : memref<3x7xf32>
// -----
-func.func @unsupported_type_f16() {
- // expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
- %0 = memref.alloca() : memref<4xf16>
- return
-}
-
-// -----
-
func.func @unsupported_type_i4() {
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
%0 = memref.alloca() : memref<4xi4>
diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir
index deda383b3b0a72..e7f935c7374382 100644
--- a/mlir/test/Target/Cpp/types.mlir
+++ b/mlir/test/Target/Cpp/types.mlir
@@ -22,6 +22,10 @@ func.func @ptr_types() {
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i32>]} : () -> ()
// CHECK-NEXT: f<int64_t*>();
emitc.call_opaque "f"() {template_args = [!emitc.ptr<i64>]} : () -> ()
+ // CHECK-NEXT: f<_Float16*>();
+ emitc.call_opaque "f"() {template_args = [!emitc.ptr<f16>]} : () -> ()
+ // CHECK-NEXT: f<__bf16*>();
+ emitc.call_opaque "f"() {template_args = [!emitc.ptr<bf16>]} : () -> ()
// CHECK-NEXT: f<float*>();
emitc.call_opaque "f"() {template_args = [!emitc.ptr<f32>]} : () -> ()
// CHECK-NEXT: f<double*>();
|
case 16: | ||
if (llvm::isa<Float16Type>(type)) | ||
return (os << "_Float16"), success(); | ||
else |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please explicitly check for bfloat16; at the pace of current floating point development, I wouldn't be surprised if we get another 16 bit float type in the future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Outdated
bool mlir::emitc::isSupportedFloatType(Type type) { | ||
if (auto floatType = llvm::dyn_cast<FloatType>(type)) { | ||
switch (floatType.getWidth()) { | ||
case 16: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please explicitly check for Float16Type and BFloat16Type here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Thanks for your PR! Could you please update |
Can you please update the attribute emission to print the correct literal suffix: godbolt, clang documentation |
|
||
// ----- | ||
|
||
func.func @arith_cast_bf16(%arg0: bf16) -> i32 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you keep the tests and replace the types with unsupported floating point types, like f128
and f80
(and tf32
is a thing as well, I think)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
19f7d82
to
488bd12
Compare
Added fp16 and bf16 to the literal emit logic. |
Add to the doc. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One small nit, LGTM
|
||
// ----- | ||
|
||
func.func @unsupported_type_f16() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, keep the tests and replace the types with unsupported floating point types, like f128
. But it's not that important. It's up to you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Just two minor suggestions.
mlir/docs/Dialects/emitc.md
Outdated
* If `__bf16` is used, the code requires complier that supports it, such as: | ||
GCC and Clang. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would probably rephrase to
* If `__bf16` is used, the code requires complier that supports it, such as: | |
GCC and Clang. | |
* If `__bf16` is used, the code requires a complier that supports it, such as | |
GCC or Clang. |
488bd12
to
d0adb5c
Compare
Thanks for comment, all addressed. |
LGTM! |
mlir/docs/Dialects/emitc.md
Outdated
or any of the C++ headers in which the type is defined. | ||
* If `_Float16` is used, the code requires the support of C additional | ||
floating types. | ||
* If `__bf16` is used, the code requires a complier that supports it, such as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* If `__bf16` is used, the code requires a complier that supports it, such as | |
* If `__bf16` is used, the code requires a compiler that supports it, such as |
typo slipped through
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this look good to me. Please fix the typo as pointed to by @cferry-AMD. Beside this the PR should be ready to land. If you need assistance and someone who hits the Squash and merge
button for you, let me know.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small comment to the doc only, looks good to me otherwise!
* If `_Float16` is used, the code requires the support of C additional | ||
floating types. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think usually this documentation specifies the type in the mlir world and then relates it to something in C/C++. Maybe
If
f16
types are used, the code requires a compiler that supports_Float16
.
Same for bf16
below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Addressed
d0adb5c
to
5cd6285
Compare
No description provided.