diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index 65bd714d788182..4d88aa8e99dadd 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2601,6 +2601,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() < fb.getWidth(); + if (auto va = a.dyn_cast()) + if (auto vb = b.dyn_cast()) + return va.getShape().equals(vb.getShape()) && + areCastCompatible(va.getElementType(), vb.getElementType()); return false; } @@ -2612,6 +2616,10 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() > fb.getWidth(); + if (auto va = a.dyn_cast()) + if (auto vb = b.dyn_cast()) + return va.getShape().equals(vb.getShape()) && + areCastCompatible(va.getElementType(), vb.getElementType()); return false; } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir index 68aeef8a2e1fd6..660fc79fb3102b 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -485,6 +485,18 @@ func @fpext(%arg0 : f16, %arg1 : f32) { return } +// Checking conversion of integer types to floating point. +// CHECK-LABEL: @fpext +func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) { +// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x float>"> + %0 = fpext %arg0: vector<2xf16> to vector<2xf32> +// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x double>"> + %1 = fpext %arg0: vector<2xf16> to vector<2xf64> +// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x double>"> + %2 = fpext %arg1: vector<2xf32> to vector<2xf64> + return +} + // Checking conversion of integer types to floating point. // CHECK-LABEL: @fptrunc func @fptrunc(%arg0 : f32, %arg1 : f64) { @@ -497,6 +509,18 @@ func @fptrunc(%arg0 : f32, %arg1 : f64) { return } +// Checking conversion of integer types to floating point. +// CHECK-LABEL: @fptrunc +func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) { +// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x half>"> + %0 = fptrunc %arg0: vector<2xf32> to vector<2xf16> +// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x half>"> + %1 = fptrunc %arg1: vector<2xf64> to vector<2xf16> +// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x float>"> + %2 = fptrunc %arg1: vector<2xf64> to vector<2xf32> + return +} + // Check sign and zero extension and truncation of integers. // CHECK-LABEL: @integer_extension_and_truncation func @integer_extension_and_truncation() { diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir index c07931f01f8c25..28ef33edba06a3 100644 --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -506,6 +506,12 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) { // CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32> %142 = sqrt %t : tensor<4x4x?xf32> + // CHECK: = fpext {{.*}} : vector<4xf32> to vector<4xf64> + %143 = fpext %vcf32 : vector<4xf32> to vector<4xf64> + + // CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16> + %144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16> + return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 7cc0331bd48401..5b43103e9018d2 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -563,6 +563,46 @@ func @fpext_f32_to_i32(%arg0 : f32) { // ----- +func @fpext_vec(%arg0 : vector<2xf16>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fpext %arg0 : vector<2xf16> to vector<3xf32> + return +} + +// ----- + +func @fpext_vec_f32_to_f16(%arg0 : vector<2xf32>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fpext %arg0 : vector<2xf32> to vector<2xf16> + return +} + +// ----- + +func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fpext %arg0 : vector<2xf16> to vector<2xf16> + return +} + +// ----- + +func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fpext %arg0 : vector<2xi32> to vector<2xf32> + return +} + +// ----- + +func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fpext %arg0 : vector<2xf32> to vector<2xi32> + return +} + +// ----- + func @fptrunc_f16_to_f32(%arg0 : f16) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : f16 to f32 @@ -595,6 +635,46 @@ func @fptrunc_f32_to_i32(%arg0 : f32) { // ----- +func @fptrunc_vec(%arg0 : vector<2xf16>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32> + return +} + +// ----- + +func @fptrunc_vec_f16_to_f32(%arg0 : vector<2xf16>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fptrunc %arg0 : vector<2xf16> to vector<2xf32> + return +} + +// ----- + +func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fptrunc %arg0 : vector<2xf32> to vector<2xf32> + return +} + +// ----- + +func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fptrunc %arg0 : vector<2xi32> to vector<2xf32> + return +} + +// ----- + +func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) { + // expected-error@+1 {{are cast incompatible}} + %0 = fptrunc %arg0 : vector<2xf32> to vector<2xi32> + return +} + +// ----- + func @sexti_index_as_operand(%arg0 : index) { // expected-error@+1 {{'index' is not a valid operand type}} %0 = sexti %arg0 : index to i128