diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index c3b3a78abe7f7..8b6c62ca2e36d 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -624,7 +624,8 @@ LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, const LLVMTypeConverter &converter) { TypeRange operandTypes(operands); if (llvm::any_of(operandTypes, llvm::IsaPred)) { - VectorType vectorType = cast(op->getResultTypes()[0]); + VectorType vectorType = + cast(converter.convertType(op->getResultTypes()[0])); rewriter.replaceOp(op, scalarizeVectorOpHelper(op, operands, vectorType, rewriter, converter)); return success(); diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index 9448304f11dbd..313d7b086731e 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -516,6 +516,20 @@ module { // ----- +module @test_module { + // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 + // CHECK-LABEL: func @math_sin_vector_0d + func.func @math_sin_vector_0d(%arg : vector) -> vector { + // CHECK: llvm.extractelement {{.*}} : vector<1xf16> + // CHECK: llvm.call @__ocml_sin_f16(%{{.*}}) : (f16) -> f16 + // CHECK: llvm.insertelement {{.*}} : vector<1xf16> + %result = math.sin %arg : vector + func.return %result : vector + } +} + +// ----- + module @test_module { // CHECK: llvm.func @__ocml_sin_f16(f16) -> f16 // CHECK-LABEL: func @math_sin_vector_1d