diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
index 65dffc7908b78..4ce871b6f5e5d 100644
--- a/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
@@ -116,6 +116,114 @@ class SPIRVLegalizePointerCast : public FunctionPass {
     return LI;
   }
 
+  // Loads the first TargetType->getNumElements() elements of the array that
+  // Source points to and assembles them into a vector of type TargetType.
+  //
+  // For every element this emits an spv_gep into the array followed by a
+  // scalar load; the loaded scalars are then inserted one at a time into a
+  // poison vector with spv_insertelt. Returns the fully populated vector.
+  //
+  // The array type is not passed in, so the caller is responsible for making
+  // sure the pointee array has the same element type as TargetType and at
+  // least as many elements.
+  Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
+                             Value *Source) {
+    Type *ElemTy = TargetType->getElementType();
+    const unsigned NumElems = TargetType->getNumElements();
+
+    // Load each element of the array.
+    SmallVector<Value *, 4> LoadedElements;
+    for (unsigned i = 0; i < NumElems; ++i) {
+      // Create a GEP to access the i-th element of the array.
+      SmallVector<Type *, 2> Types = {Source->getType(), Source->getType()};
+      SmallVector<Value *, 4> Args;
+      Args.push_back(B.getInt1(false));
+      Args.push_back(Source);
+      Args.push_back(B.getInt32(0));
+      Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+      auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+      GR->buildAssignPtr(B, ElemTy, ElementPtr);
+
+      // Load the value from the element pointer.
+      Value *Load = B.CreateLoad(ElemTy, ElementPtr);
+      buildAssignType(B, ElemTy, Load);
+      LoadedElements.push_back(Load);
+    }
+
+    // Build the vector from the loaded elements, starting from poison so that
+    // no lane has a defined value before it is inserted.
+    Value *NewVector = PoisonValue::get(TargetType);
+    buildAssignType(B, TargetType, NewVector);
+
+    for (unsigned i = 0; i < NumElems; ++i) {
+      Value *Index = B.getInt32(i);
+      SmallVector<Type *, 4> Types = {TargetType, TargetType, ElemTy,
+                                      Index->getType()};
+      SmallVector<Value *, 3> Args = {NewVector, LoadedElements[i], Index};
+      NewVector = B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
+      buildAssignType(B, TargetType, NewVector);
+    }
+    return NewVector;
+  }
+
+  // Stores the elements of SrcVector one by one into the array that
+  // DstArrayPtr points to.
+  //
+  // ArrTy is the type of the destination array. Its element type must match
+  // the vector's, and it must have at least as many elements as the vector.
+  // Alignment is the alignment of the vector store being replaced, i.e. the
+  // alignment of element 0; each scalar store is tagged with the alignment
+  // that still holds at that element's byte offset.
+  void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
+                            Value *DstArrayPtr, ArrayType *ArrTy,
+                            Align Alignment) {
+    auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
+
+    // Ensure the element types of the array and vector are the same.
+    assert(VecTy->getElementType() == ArrTy->getElementType() &&
+           "Element types of array and vector must be the same.");
+    // Every vector lane is written to its own array slot, so the array must
+    // be long enough; otherwise the GEPs below would index past its end.
+    assert(VecTy->getNumElements() <= ArrTy->getNumElements() &&
+           "Array must have at least as many elements as the vector.");
+
+    // Array elements are laid out at multiples of the element alloc size.
+    const DataLayout &DL = B.GetInsertBlock()->getModule()->getDataLayout();
+    const uint64_t EltSize =
+        DL.getTypeAllocSize(ArrTy->getElementType()).getFixedValue();
+
+    for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
+      // Create a GEP to access the i-th element of the array.
+      SmallVector<Type *, 2> Types = {DstArrayPtr->getType(),
+                                      DstArrayPtr->getType()};
+      SmallVector<Value *, 4> Args;
+      Args.push_back(B.getInt1(false));
+      Args.push_back(DstArrayPtr);
+      Args.push_back(B.getInt32(0));
+      Args.push_back(ConstantInt::get(B.getInt32Ty(), i));
+      auto *ElementPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
+      GR->buildAssignPtr(B, ArrTy->getElementType(), ElementPtr);
+
+      // Extract the element from the vector and store it.
+      Value *Index = B.getInt32(i);
+      SmallVector<Type *, 3> EltTypes = {VecTy->getElementType(), VecTy,
+                                         Index->getType()};
+      SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
+      Value *Element =
+          B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
+      buildAssignType(B, VecTy->getElementType(), Element);
+
+      // Alignment only describes element 0. Element i sits i * EltSize bytes
+      // further on, so only the common alignment of the two is guaranteed.
+      const Align EltAlign =
+          commonAlignment(Alignment, static_cast<uint64_t>(i) * EltSize);
+
+      Types = {Element->getType(), ElementPtr->getType()};
+      Args = {Element, ElementPtr, B.getInt16(2), B.getInt8(EltAlign.value())};
+      B.CreateIntrinsic(Intrinsic::spv_store, {Types}, {Args});
+    }
+  }
+
   // Replaces the load instruction to get rid of the ptrcast used as source
   // operand.
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand, @@ -154,6 +229,8 @@ class SPIRVLegalizePointerCast : public FunctionPass { // - float v = s.m; else if (SST && SST->getTypeAtIndex(0u) == ToTy) Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI); + else if (SAT && DVT && SAT->getElementType() == DVT->getElementType()) + Output = loadVectorFromArray(B, DVT, OriginalOperand); else llvm_unreachable("Unimplemented implicit down-cast from load."); @@ -288,6 +365,7 @@ class SPIRVLegalizePointerCast : public FunctionPass { auto *S_VT = dyn_cast(FromTy); auto *D_ST = dyn_cast(ToTy); auto *D_VT = dyn_cast(ToTy); + auto *D_AT = dyn_cast(ToTy); B.SetInsertPoint(BadStore); if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST)) @@ -296,6 +374,8 @@ class SPIRVLegalizePointerCast : public FunctionPass { storeVectorFromVector(B, Src, Dst, Alignment); else if (D_VT && !S_VT && FromTy == D_VT->getElementType()) storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment); + else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType()) + storeArrayFromVector(B, Src, Dst, D_AT, Alignment); else llvm_unreachable("Unsupported ptrcast use in store. 
Please fix."); diff --git a/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll b/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll new file mode 100644 index 0000000000000..917bb27afad00 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll @@ -0,0 +1,54 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %} + +; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32 +; CHECK-DAG: [[VEC4FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 4 +; CHECK-DAG: [[UINT_TYPE:%[0-9]+]] = OpTypeInt 32 0 +; CHECK-DAG: [[UINT4:%[0-9]+]] = OpConstant [[UINT_TYPE]] 4 +; CHECK-DAG: [[ARRAY4FLOAT:%[0-9]+]] = OpTypeArray [[FLOAT]] [[UINT4]] +; CHECK-DAG: [[PTR_ARRAY4FLOAT:%[0-9]+]] = OpTypePointer Private [[ARRAY4FLOAT]] +; CHECK-DAG: [[G_IN:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private +; CHECK-DAG: [[G_OUT:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private +; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT_TYPE]] 0 +; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT_TYPE]] 1 +; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT_TYPE]] 2 +; CHECK-DAG: [[UINT3:%[0-9]+]] = OpConstant [[UINT_TYPE]] 3 +; CHECK-DAG: [[PTR_FLOAT:%[0-9]+]] = OpTypePointer Private [[FLOAT]] +; CHECK-DAG: [[UNDEF_VEC:%[0-9]+]] = OpUndef [[VEC4FLOAT]] + +@G_in = internal addrspace(10) global [4 x float] zeroinitializer +@G_out = internal addrspace(10) global [4 x float] zeroinitializer + +define spir_func void @main() { +entry: +; CHECK: [[GEP0:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT0]] +; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP0]] +; CHECK-NEXT: [[GEP1:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT1]] +; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP1]] +; CHECK-NEXT: [[GEP2:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT2]] +; CHECK-NEXT: [[LOAD2:%[0-9]+]] = 
OpLoad [[FLOAT]] [[GEP2]] +; CHECK-NEXT: [[GEP3:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT3]] +; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP3]] +; CHECK-NEXT: [[VEC_INSERT0:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD0]] [[UNDEF_VEC]] 0 +; CHECK-NEXT: [[VEC_INSERT1:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD1]] [[VEC_INSERT0]] 1 +; CHECK-NEXT: [[VEC_INSERT2:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD2]] [[VEC_INSERT1]] 2 +; CHECK-NEXT: [[VEC:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD3]] [[VEC_INSERT2]] 3 + %0 = load <4 x float>, ptr addrspace(10) @G_in, align 64 + +; CHECK-NEXT: [[GEP_OUT0:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT0]] +; CHECK-NEXT: [[VEC_EXTRACT0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 0 +; CHECK-NEXT: OpStore [[GEP_OUT0]] [[VEC_EXTRACT0]] +; CHECK-NEXT: [[GEP_OUT1:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT1]] +; CHECK-NEXT: [[VEC_EXTRACT1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 1 +; CHECK-NEXT: OpStore [[GEP_OUT1]] [[VEC_EXTRACT1]] +; CHECK-NEXT: [[GEP_OUT2:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT2]] +; CHECK-NEXT: [[VEC_EXTRACT2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 2 +; CHECK-NEXT: OpStore [[GEP_OUT2]] [[VEC_EXTRACT2]] +; CHECK-NEXT: [[GEP_OUT3:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT3]] +; CHECK-NEXT: [[VEC_EXTRACT3:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 3 +; CHECK-NEXT: OpStore [[GEP_OUT3]] [[VEC_EXTRACT3]] + store <4 x float> %0, ptr addrspace(10) @G_out, align 64 + +; CHECK-NEXT: OpReturn + ret void +}