Skip to content

Commit

Permalink
[CUDA] Remove assumptions that bindings are a dense set (#6044)
Browse files Browse the repository at this point in the history
The previous implementation relied on the bindings being a dense set
being in order.
The new logic implements a new convention that the binding are mapped
are ordered and compressed based on binding index and mapped to kernel
arguments.
  • Loading branch information
ThomasRaoux committed Jun 6, 2021
1 parent ea65a2e commit 0dbd1ac
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 24 deletions.
46 changes: 37 additions & 9 deletions iree/compiler/Conversion/LinalgToLLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,34 @@ class ScalarizationTestPass
}
};

// Convention with the HAL side to pass kernel arguments.
// The bindings are ordered based on binding index then compressed and mapped to
// dense set of arguments.
// This function looks at the symbols and return the mapping between binding
// index and kernel argument index. For instance if the kernel has bindings 1,
// 5, 6 it will return the mapping [1, 0], [5, 1], [6, 2]
static llvm::SmallDenseMap<uint64_t, size_t> getKernelArgMapping(
Operation *func) {
llvm::SmallDenseMap<uint64_t, size_t> mapBindingArgIndex;
llvm::SmallVector<uint64_t> bindingUsed;
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(func);
SymbolTable::walkSymbolTables(symbolTableOp, true, [&](Operation *op, bool) {
if (auto interface = dyn_cast<IREE::HAL::InterfaceOp>(op)) {
interface.walk([&](Operation *symbolOp) {
if (auto binding = dyn_cast<IREE::HAL::InterfaceBindingOp>(symbolOp)) {
uint64_t bindingIndex = binding.binding().getZExtValue();
bindingUsed.push_back(bindingIndex);
}
});
}
});
std::sort(bindingUsed.begin(), bindingUsed.end());
for (auto binding : llvm::enumerate(bindingUsed)) {
mapBindingArgIndex[binding.value()] = binding.index();
}
return mapBindingArgIndex;
}

class ConvertFunc : public ConvertToLLVMPattern {
public:
explicit ConvertFunc(MLIRContext *context, LLVMTypeConverter &converter)
Expand All @@ -88,13 +116,16 @@ class ConvertFunc : public ConvertToLLVMPattern {
assert(fnType.getNumInputs() == 0 && fnType.getNumResults() == 0);

TypeConverter::SignatureConversion signatureConverter(/*numOrigInputs=*/0);
SmallVector<Type, 8> llvmInputTypes;
llvm::SmallDenseMap<uint64_t, size_t> argMapping =
getKernelArgMapping(funcOp);
SmallVector<Type, 8> llvmInputTypes(argMapping.size());
funcOp.walk([&](IREE::HAL::InterfaceBindingSubspanOp input) {
auto memrefType = input.getType().cast<MemRefType>();
Type elType = memrefType.getElementType();
auto llvmType =
LLVM::LLVMPointerType::get(elType, memrefType.getMemorySpaceAsInt());
llvmInputTypes.push_back(llvmType);
uint64_t binding = input.queryBindingOp().binding().getZExtValue();
llvmInputTypes[argMapping[binding]] = llvmType;
});
signatureConverter.addInputs(llvmInputTypes);

Expand Down Expand Up @@ -142,18 +173,15 @@ class ConvertIREEBindingOp : public ConvertToLLVMPattern {
if (!llvmFuncOp) return failure();
assert(llvmFuncOp.getNumArguments() > 0);

llvm::SmallDenseMap<uint64_t, size_t> argMapping =
getKernelArgMapping(llvmFuncOp);
Location loc = op->getLoc();
auto ireeBindingOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor(operands);
MemRefType memrefType =
ireeBindingOp.getResult().getType().dyn_cast<MemRefType>();

// Fetch the interface binding op and extract the buffer index from void**.
auto symbol = SymbolTable::lookupNearestSymbolFrom(
op, op->getAttrOfType<SymbolRefAttr>("binding"));
auto interfaceBindingOp = cast<IREE::HAL::InterfaceBindingOp>(symbol);
Value llvmBufferBasePtr =
llvmFuncOp.getArgument(interfaceBindingOp.binding().getZExtValue());
uint64_t binding = ireeBindingOp.queryBindingOp().binding().getZExtValue();
Value llvmBufferBasePtr = llvmFuncOp.getArgument(argMapping[binding]);
if (memrefType.hasStaticShape()) {
auto desc = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), memrefType, llvmBufferBasePtr);
Expand Down
19 changes: 10 additions & 9 deletions iree/compiler/Conversion/LinalgToLLVMGPU/test/convert_to_nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,27 @@
func @abs_ex_dispatch_0() {
%c0 = constant 0 : index
%0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<16xf32>
%1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<16xf32>
%1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<16xi32>
%2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<16xf32>
%3 = "gpu.block_id"() {dimension = "x"} : () -> index
%4 = "gpu.block_dim"() {dimension = "x"} : () -> index
%5 = "gpu.thread_id"() {dimension = "x"} : () -> index
%6 = muli %3, %4 : index
%7 = addi %6, %5 : index
%9 = memref.load %1[%7] : memref<16xf32>
%10 = memref.load %2[%7] : memref<16xf32>
%11 = addf %9, %10 : f32
memref.store %11, %0[%7] : memref<16xf32>
%9 = memref.load %0[%7] : memref<16xf32>
%10 = memref.load %1[%7] : memref<16xi32>
%11 = sitofp %10 : i32 to f32
%12 = addf %9, %11 : f32
memref.store %12, %2[%7] : memref<16xf32>
return
}
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
hal.interface.binding @arg0, set=0, binding=4, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=7, type="StorageBuffer", access="Write|Discard"
}

// CHECK-LABEL: llvm.func @abs_ex_dispatch_0
// CHECK-SAME: (%{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>)
// CHECK-SAME: (%{{.*}}: !llvm.ptr<i32>, %{{.*}}: !llvm.ptr<f32>, %{{.*}}: !llvm.ptr<f32>)
// CHECK: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.fadd
39 changes: 33 additions & 6 deletions iree/hal/cuda/graph_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -320,22 +320,49 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_push_constants(
"need cuda implementation");
}

// Tie together the binding index and its index in |bindings| array.
typedef struct {
uint32_t index;
uint32_t binding;
} iree_hal_cuda_binding_mapping_t;

// Helper to sort the binding based on their binding index.
static int compare_binding_index(const void* a, const void* b) {
const iree_hal_cuda_binding_mapping_t buffer_a =
*(const iree_hal_cuda_binding_mapping_t*)a;
const iree_hal_cuda_binding_mapping_t buffer_b =
*(const iree_hal_cuda_binding_mapping_t*)b;
return buffer_a.binding < buffer_b.binding ? -1 : 1;
}

static iree_status_t iree_hal_cuda_graph_command_buffer_push_descriptor_set(
iree_hal_command_buffer_t* base_command_buffer,
iree_hal_executable_layout_t* executable_layout, uint32_t set,
iree_host_size_t binding_count,
const iree_hal_descriptor_set_binding_t* bindings) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
// Convention with the compiler side. We map bindings to kernel argument.
// We compact the bindings to get a dense set of arguments and keep them order
// based on the binding index.
// Sort the binding based on the binding index and map the array index to the
// argument index.
iree_hal_cuda_binding_mapping_t binding_used[max_binding_count];
for (iree_host_size_t i = 0; i < binding_count; i++) {
iree_hal_cuda_binding_mapping_t buffer = {i, bindings[i].binding};
binding_used[i] = buffer;
}
qsort(binding_used, binding_count, sizeof(iree_hal_cuda_binding_mapping_t),
compare_binding_index);
assert(binding_count < max_binding_count &&
"binding count larger than the max expected.");
for (iree_host_size_t i = 0; i < binding_count; i++) {
uint32_t arg_index = bindings[i].binding;
assert(arg_index < max_binding_count &&
"binding index larger than the max expected.");
iree_hal_descriptor_set_binding_t binding = bindings[binding_used[i].index];
CUdeviceptr device_ptr =
iree_hal_cuda_buffer_device_pointer(
iree_hal_buffer_allocated_buffer(bindings[i].buffer)) +
iree_hal_buffer_byte_offset(bindings[i].buffer) + bindings[i].offset;
*((CUdeviceptr*)command_buffer->current_descriptor[arg_index]) = device_ptr;
iree_hal_buffer_allocated_buffer(binding.buffer)) +
iree_hal_buffer_byte_offset(binding.buffer) + binding.offset;
*((CUdeviceptr*)command_buffer->current_descriptor[i]) = device_ptr;
}
return iree_ok_status();
}
Expand Down

0 comments on commit 0dbd1ac

Please sign in to comment.