Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVLegalizePointerCast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,81 @@ class SPIRVLegalizePointerCast : public FunctionPass {
return LI;
}

// Builds a value of vector type TargetType by reading its elements, one at a
// time, from the array that Source points to.
//
// Emission happens in two phases, and the order is preserved exactly:
//   1. for every lane, an spv_gep with indices (0, lane) followed by a load
//      of the element type;
//   2. a chain of spv_insertelt calls, seeded with poison, that places each
//      loaded element into its lane.
// Returns the last spv_insertelt (or the poison seed if the vector has no
// elements).
Value *loadVectorFromArray(IRBuilder<> &B, FixedVectorType *TargetType,
                           Value *Source) {
  Type *const EltTy = TargetType->getElementType();
  const unsigned NumElts = TargetType->getNumElements();

  // Phase 1: fetch every element through its own element pointer.
  SmallVector<Value *, 4> Lanes;
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    // Address of the Idx-th array element. The leading i1 operand of spv_gep
    // is passed as false.
    SmallVector<Type *, 2> GepTypes = {Source->getType(), Source->getType()};
    SmallVector<Value *, 4> GepArgs = {B.getInt1(false), Source,
                                       B.getInt32(0), B.getInt32(Idx)};
    auto *EltPtr = B.CreateIntrinsic(Intrinsic::spv_gep, {GepTypes}, {GepArgs});
    GR->buildAssignPtr(B, EltTy, EltPtr);

    // Read the element and record its type for later lowering.
    Value *Elt = B.CreateLoad(EltTy, EltPtr);
    buildAssignType(B, EltTy, Elt);
    Lanes.push_back(Elt);
  }

  // Phase 2: assemble the vector, starting from poison.
  Value *Result = PoisonValue::get(TargetType);
  buildAssignType(B, TargetType, Result);

  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    Value *Lane = B.getInt32(Idx);
    SmallVector<Type *, 4> InsertTypes = {TargetType, TargetType, EltTy,
                                          Lane->getType()};
    SmallVector<Value *> InsertArgs = {Result, Lanes[Idx], Lane};
    Result = B.CreateIntrinsic(Intrinsic::spv_insertelt, {InsertTypes},
                               {InsertArgs});
    buildAssignType(B, TargetType, Result);
  }
  return Result;
}

// Stores a vector into an array, one element at a time.
//
// For every lane i of SrcVector this emits, in order: an spv_gep on
// DstArrayPtr with indices (0, i), an spv_extractelt of lane i, and an
// spv_store of the extracted value through the element pointer.
//
// \param B           Builder used to create the new instructions.
// \param SrcVector   Value to store; must have a FixedVectorType.
// \param DstArrayPtr Pointer to the destination array.
// \param ArrTy       Type of the destination array; its element type must be
//                    identical to the element type of SrcVector (asserted).
// \param Alignment   Alignment known for the address DstArrayPtr (presumably
//                    the alignment of the store being replaced).
void storeArrayFromVector(IRBuilder<> &B, Value *SrcVector,
                          Value *DstArrayPtr, ArrayType *ArrTy,
                          Align Alignment) {
  auto *VecTy = cast<FixedVectorType>(SrcVector->getType());
  Type *EltTy = ArrTy->getElementType();

  // Ensure the element types of the array and vector are the same.
  assert(VecTy->getElementType() == EltTy &&
         "Element types of array and vector must be the same.");

  // Alignment only describes the start of the array. Element i lives at byte
  // offset i * EltStride, so the alignment that can be claimed for it is the
  // common alignment of Alignment and that offset; reusing Alignment for
  // every element would over-state the alignment of all elements but the
  // first whenever Alignment exceeds the element stride.
  const DataLayout &DL = B.GetInsertBlock()->getDataLayout();
  const uint64_t EltStride = DL.getTypeAllocSize(EltTy).getFixedValue();

  for (unsigned i = 0; i < VecTy->getNumElements(); ++i) {
    // Create a GEP to access the i-th element of the array.
    SmallVector<Type *, 2> GepTypes = {DstArrayPtr->getType(),
                                       DstArrayPtr->getType()};
    SmallVector<Value *, 4> GepArgs = {B.getInt1(false), DstArrayPtr,
                                       B.getInt32(0), B.getInt32(i)};
    auto *ElementPtr =
        B.CreateIntrinsic(Intrinsic::spv_gep, {GepTypes}, {GepArgs});
    GR->buildAssignPtr(B, EltTy, ElementPtr);

    // Extract the element from the vector.
    Value *Index = B.getInt32(i);
    SmallVector<Type *, 3> EltTypes = {EltTy, VecTy, Index->getType()};
    SmallVector<Value *, 2> EltArgs = {SrcVector, Index};
    Value *Element =
        B.CreateIntrinsic(Intrinsic::spv_extractelt, {EltTypes}, {EltArgs});
    buildAssignType(B, EltTy, Element);

    // Store it with the alignment valid at this element's offset.
    // NOTE(review): the i16 operand 2 presumably selects the SPIR-V
    // `Aligned` memory operand, which the following i8 operand qualifies;
    // confirm against the spv_store lowering.
    const Align EltAlign = commonAlignment(Alignment, i * EltStride);
    SmallVector<Type *, 2> StoreTypes = {Element->getType(),
                                         ElementPtr->getType()};
    SmallVector<Value *, 4> StoreArgs = {Element, ElementPtr, B.getInt16(2),
                                         B.getInt8(EltAlign.value())};
    B.CreateIntrinsic(Intrinsic::spv_store, {StoreTypes}, {StoreArgs});
  }
}

// Replaces the load instruction to get rid of the ptrcast used as source
// operand.
void transformLoad(IRBuilder<> &B, LoadInst *LI, Value *CastedOperand,
Expand Down Expand Up @@ -154,6 +229,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
// - float v = s.m;
else if (SST && SST->getTypeAtIndex(0u) == ToTy)
Output = loadFirstValueFromAggregate(B, ToTy, OriginalOperand, LI);
else if (SAT && DVT && SAT->getElementType() == DVT->getElementType())
Output = loadVectorFromArray(B, DVT, OriginalOperand);
else
llvm_unreachable("Unimplemented implicit down-cast from load.");

Expand Down Expand Up @@ -288,6 +365,7 @@ class SPIRVLegalizePointerCast : public FunctionPass {
auto *S_VT = dyn_cast<FixedVectorType>(FromTy);
auto *D_ST = dyn_cast<StructType>(ToTy);
auto *D_VT = dyn_cast<FixedVectorType>(ToTy);
auto *D_AT = dyn_cast<ArrayType>(ToTy);

B.SetInsertPoint(BadStore);
if (D_ST && isTypeFirstElementAggregate(FromTy, D_ST))
Expand All @@ -296,6 +374,8 @@ class SPIRVLegalizePointerCast : public FunctionPass {
storeVectorFromVector(B, Src, Dst, Alignment);
else if (D_VT && !S_VT && FromTy == D_VT->getElementType())
storeToFirstValueAggregate(B, Src, Dst, D_VT, Alignment);
else if (D_AT && S_VT && S_VT->getElementType() == D_AT->getElementType())
storeArrayFromVector(B, Src, Dst, D_AT, Alignment);
else
llvm_unreachable("Unsupported ptrcast use in store. Please fix.");

Expand Down
54 changes: 54 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/load-store-vec-from-array.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -verify-machineinstrs -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: [[FLOAT:%[0-9]+]] = OpTypeFloat 32
; CHECK-DAG: [[VEC4FLOAT:%[0-9]+]] = OpTypeVector [[FLOAT]] 4
; CHECK-DAG: [[UINT_TYPE:%[0-9]+]] = OpTypeInt 32 0
; CHECK-DAG: [[UINT4:%[0-9]+]] = OpConstant [[UINT_TYPE]] 4
; CHECK-DAG: [[ARRAY4FLOAT:%[0-9]+]] = OpTypeArray [[FLOAT]] [[UINT4]]
; CHECK-DAG: [[PTR_ARRAY4FLOAT:%[0-9]+]] = OpTypePointer Private [[ARRAY4FLOAT]]
; CHECK-DAG: [[G_IN:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
; CHECK-DAG: [[G_OUT:%[0-9]+]] = OpVariable [[PTR_ARRAY4FLOAT]] Private
; CHECK-DAG: [[UINT0:%[0-9]+]] = OpConstant [[UINT_TYPE]] 0
; CHECK-DAG: [[UINT1:%[0-9]+]] = OpConstant [[UINT_TYPE]] 1
; CHECK-DAG: [[UINT2:%[0-9]+]] = OpConstant [[UINT_TYPE]] 2
; CHECK-DAG: [[UINT3:%[0-9]+]] = OpConstant [[UINT_TYPE]] 3
; CHECK-DAG: [[PTR_FLOAT:%[0-9]+]] = OpTypePointer Private [[FLOAT]]
; CHECK-DAG: [[UNDEF_VEC:%[0-9]+]] = OpUndef [[VEC4FLOAT]]

; Source and destination arrays: both are [4 x float] globals in address
; space 10.
@G_in = internal addrspace(10) global [4 x float] zeroinitializer
@G_out = internal addrspace(10) global [4 x float] zeroinitializer

; Loads a <4 x float> through a pointer to the [4 x float] global @G_in and
; stores it through a pointer to the [4 x float] global @G_out, so both the
; load and the store access an array as a vector.
define spir_func void @main() {
entry:
; Expected lowering of the vector load: four OpAccessChain + OpLoad pairs on
; G_in (indices 0..3), then four OpCompositeInsert that build the vector
; starting from an OpUndef.
; CHECK: [[GEP0:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT0]]
; CHECK-NEXT: [[LOAD0:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP0]]
; CHECK-NEXT: [[GEP1:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT1]]
; CHECK-NEXT: [[LOAD1:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP1]]
; CHECK-NEXT: [[GEP2:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT2]]
; CHECK-NEXT: [[LOAD2:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP2]]
; CHECK-NEXT: [[GEP3:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_IN]] [[UINT3]]
; CHECK-NEXT: [[LOAD3:%[0-9]+]] = OpLoad [[FLOAT]] [[GEP3]]
; CHECK-NEXT: [[VEC_INSERT0:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD0]] [[UNDEF_VEC]] 0
; CHECK-NEXT: [[VEC_INSERT1:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD1]] [[VEC_INSERT0]] 1
; CHECK-NEXT: [[VEC_INSERT2:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD2]] [[VEC_INSERT1]] 2
; CHECK-NEXT: [[VEC:%[0-9]+]] = OpCompositeInsert [[VEC4FLOAT]] [[LOAD3]] [[VEC_INSERT2]] 3
%0 = load <4 x float>, ptr addrspace(10) @G_in, align 64

; Expected lowering of the vector store: for each index 0..3, an
; OpAccessChain on G_out, an OpCompositeExtract from the loaded vector, and
; an OpStore of the extracted element.
; CHECK-NEXT: [[GEP_OUT0:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT0]]
; CHECK-NEXT: [[VEC_EXTRACT0:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 0
; CHECK-NEXT: OpStore [[GEP_OUT0]] [[VEC_EXTRACT0]]
; CHECK-NEXT: [[GEP_OUT1:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT1]]
; CHECK-NEXT: [[VEC_EXTRACT1:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 1
; CHECK-NEXT: OpStore [[GEP_OUT1]] [[VEC_EXTRACT1]]
; CHECK-NEXT: [[GEP_OUT2:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT2]]
; CHECK-NEXT: [[VEC_EXTRACT2:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 2
; CHECK-NEXT: OpStore [[GEP_OUT2]] [[VEC_EXTRACT2]]
; CHECK-NEXT: [[GEP_OUT3:%[0-9]+]] = OpAccessChain [[PTR_FLOAT]] [[G_OUT]] [[UINT3]]
; CHECK-NEXT: [[VEC_EXTRACT3:%[0-9]+]] = OpCompositeExtract [[FLOAT]] [[VEC]] 3
; CHECK-NEXT: OpStore [[GEP_OUT3]] [[VEC_EXTRACT3]]
store <4 x float> %0, ptr addrspace(10) @G_out, align 64

; CHECK-NEXT: OpReturn
ret void
}