diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp index 993de9e9f64ec..85ea9e156cb97 100644 --- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp @@ -148,7 +148,10 @@ static const std::map> SPIRV::Extension::Extension::SPV_KHR_float_controls2}, {"SPV_INTEL_tensor_float32_conversion", SPIRV::Extension::Extension::SPV_INTEL_tensor_float32_conversion}, - {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}}; + {"SPV_KHR_bfloat16", SPIRV::Extension::Extension::SPV_KHR_bfloat16}, + {"SPV_EXT_relaxed_printf_string_address_space", + SPIRV::Extension::Extension:: + SPV_EXT_relaxed_printf_string_address_space}}; bool SPIRVExtensionsParser::parse(cl::Option &O, StringRef ArgName, StringRef ArgValue, diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp index a95f393b75605..0a122fc994810 100644 --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -1222,6 +1222,31 @@ static void AddDotProductRequirements(const MachineInstr &MI, } } +void addPrintfRequirements(const MachineInstr &MI, + SPIRV::RequirementHandler &Reqs, + const SPIRVSubtarget &ST) { + SPIRVGlobalRegistry *GR = ST.getSPIRVGlobalRegistry(); + const SPIRVType *PtrType = GR->getSPIRVTypeForVReg(MI.getOperand(4).getReg()); + if (PtrType) { + MachineOperand ASOp = PtrType->getOperand(1); + if (ASOp.isImm()) { + unsigned AddrSpace = ASOp.getImm(); + if (AddrSpace != SPIRV::StorageClass::UniformConstant) { + if (!ST.canUseExtension( + SPIRV::Extension:: + SPV_EXT_relaxed_printf_string_address_space)) { + report_fatal_error("SPV_EXT_relaxed_printf_string_address_space is " + "required because printf uses a format string not " + "in constant address space.", + false); + } + Reqs.addExtension( + SPIRV::Extension::SPV_EXT_relaxed_printf_string_address_space); + } + } + } +} + static bool isBFloat16Type(const SPIRVType *TypeDef) { return TypeDef && TypeDef->getNumOperands() == 3 && TypeDef->getOpcode() == SPIRV::OpTypeFloat && @@ -1321,6 +1346,12 @@ void addInstrRequirements(const MachineInstr &MI, static_cast( SPIRV::InstructionSet::NonSemantic_Shader_DebugInfo_100)) { Reqs.addExtension(SPIRV::Extension::SPV_KHR_non_semantic_info); + break; + } + if (MI.getOperand(3).getImm() == + static_cast(SPIRV::OpenCLExtInst::printf)) { + addPrintfRequirements(MI, Reqs, ST); + break; } break; } diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_relaxed_printf_string_address_space/builtin_printf.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_relaxed_printf_string_address_space/builtin_printf.ll new file mode 100644 index 0000000000000..093d172c5c1b1 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_relaxed_printf_string_address_space/builtin_printf.ll @@ -0,0 +1,24 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_EXT_relaxed_printf_string_address_space %s -o - | FileCheck %s +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK: OpExtension "SPV_EXT_relaxed_printf_string_address_space" +; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] printf + +; CHECK-ERROR: LLVM ERROR: SPV_EXT_relaxed_printf_string_address_space is required because printf uses a format string not in constant address space. + +@.str = private unnamed_addr addrspace(1) constant [4 x i8] c"%d\0A\00", align 1 + +declare spir_func i32 @printf(ptr addrspace(4), ...) + +define spir_kernel void @test_kernel() { +entry: + ; Format string in addrspace(1) → cast to addrspace(4) + %format = addrspacecast ptr addrspace(1) @.str to ptr addrspace(4) + %val = alloca i32, align 4 + store i32 123, ptr %val, align 4 + %loaded = load i32, ptr %val, align 4 + + ; Call printf with non-constant format string + %call = call spir_func i32 (ptr addrspace(4), ...) @printf(ptr addrspace(4) %format, i32 %loaded) + ret void +} diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_relaxed_printf_string_address_space/non-constant-printf.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_relaxed_printf_string_address_space/non-constant-printf.ll new file mode 100644 index 0000000000000..b54d59b30309f --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_EXT_relaxed_printf_string_address_space/non-constant-printf.ll @@ -0,0 +1,48 @@ +; RUN: llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_EXT_relaxed_printf_string_address_space %s -o - | FileCheck %s +; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +; CHECK: OpExtension "SPV_EXT_relaxed_printf_string_address_space" +; CHECK: %[[#ExtInstSetId:]] = OpExtInstImport "OpenCL.std" +; CHECK-DAG: %[[#TypeInt32Id:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#TypeInt8Id:]] = OpTypeInt 8 0 +; CHECK-DAG: %[[#TypeInt64Id:]] = OpTypeInt 64 0 +; CHECK-DAG: %[[#TypeArrayId:]] = OpTypeArray %[[#TypeInt8Id]] %[[#]] +; CHECK-DAG: %[[#ConstantStorClassGlobalPtrTy:]] = OpTypePointer UniformConstant %[[#TypeArrayId]] +; CHECK-DAG: %[[#WGStorClassGlobalPtrTy:]] = OpTypePointer Workgroup %[[#TypeArrayId]] +; CHECK-DAG: %[[#CrossWFStorClassGlobalPtrTy:]] = OpTypePointer CrossWorkgroup %[[#TypeArrayId]] +; CHECK-DAG: %[[#FunctionStorClassPtrTy:]] = OpTypePointer Function %[[#TypeInt8Id]] +; CHECK-DAG: %[[#WGStorClassPtrTy:]] = OpTypePointer Workgroup %[[#TypeInt8Id]] +; CHECK-DAG: %[[#CrossWFStorClassPtrTy:]] = OpTypePointer CrossWorkgroup %[[#TypeInt8Id]] +; CHECK: %[[#ConstantCompositeId:]] = OpConstantComposite %[[#TypeArrayId]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] %[[#]] +; CHECK: %[[#]] = OpVariable %[[#ConstantStorClassGlobalPtrTy]] UniformConstant %[[#ConstantCompositeId]] +; CHECK: %[[#]] = OpVariable %[[#CrossWFStorClassGlobalPtrTy]] CrossWorkgroup %[[#ConstantCompositeId]] +; CHECK: %[[#]] = OpVariable %[[#WGStorClassGlobalPtrTy]] Workgroup %[[#ConstantCompositeId]] +; CHECK: %[[#GEP1:]] = OpInBoundsPtrAccessChain %[[#FunctionStorClassPtrTy]] %[[#]] %[[#]] %[[#]] +; CHECK: %[[#]] = OpExtInst %[[#TypeInt32Id]] %[[#ExtInstSetId:]] printf %[[#GEP1]] +; CHECK: %[[#GEP2:]] = OpInBoundsPtrAccessChain %[[#CrossWFStorClassPtrTy]] %[[#]] %[[#]] %[[#]] +; CHECK: %[[#]] = OpExtInst %[[#TypeInt32Id]] %[[#ExtInstSetId:]] printf %[[#GEP2]] +; CHECK: %[[#GEP3:]] = OpInBoundsPtrAccessChain %[[#WGStorClassPtrTy]] %[[#]] %[[#]] %[[#]] +; CHECK: %[[#]] = OpExtInst %[[#TypeInt32Id]] %[[#ExtInstSetId:]] printf %[[#GEP3]] + +; CHECK-ERROR: LLVM ERROR: SPV_EXT_relaxed_printf_string_address_space is required because printf uses a format string not in constant address space. + +@0 = internal unnamed_addr addrspace(2) constant [6 x i8] c"Test\0A\00", align 1 +@1 = internal unnamed_addr addrspace(1) constant [6 x i8] c"Test\0A\00", align 1 +@2 = internal unnamed_addr addrspace(3) constant [6 x i8] c"Test\0A\00", align 1 + +define spir_kernel void @test() { + %tmp1 = alloca [6 x i8], align 1 + call void @llvm.memcpy.p0.p2.i64(ptr align 1 %tmp1, ptr addrspace(2) align 1 @0, i64 6, i1 false) + %1 = getelementptr inbounds [6 x i8], ptr %tmp1, i32 0, i32 0 + %2 = call spir_func i32 @_Z18__spirv_ocl_printfPc(ptr %1) + %3 = getelementptr inbounds [6 x i8], ptr addrspace(1) @1, i32 0, i32 0 + %4 = call spir_func i32 @_Z18__spirv_ocl_printfPU3AS1c(ptr addrspace(1) %3) + %5 = getelementptr inbounds [6 x i8], ptr addrspace(3) @2, i32 0, i32 0 + %6 = call spir_func i32 @_Z18__spirv_ocl_printfPU3AS3c(ptr addrspace(3) %5) + ret void +} + +declare spir_func i32 @_Z18__spirv_ocl_printfPc(ptr) +declare spir_func i32 @_Z18__spirv_ocl_printfPU3AS1c(ptr addrspace(1)) +declare spir_func i32 @_Z18__spirv_ocl_printfPU3AS3c(ptr addrspace(3)) +declare void @llvm.memcpy.p0.p2.i64(ptr captures(none), ptr addrspace(2) captures(none) readonly, i64, i1)