Removing legacy stream binding recording. (#6146)
benvanik committed Jun 8, 2021
1 parent da141ef commit 5b958e7
Showing 2 changed files with 71 additions and 106 deletions.
@@ -771,84 +771,6 @@ static void recordInterfaceBindings(Value device, Value commandBuffer,
}
}

// DEPRECATED: this is going away with the legacy linalg-on-buffers path.
static LogicalResult recordLegacyBindings(
    Value device, Value commandBuffer, IREE::Flow::DispatchOp &dispatchOp,
    IREE::HAL::InterfaceOp &interfaceOp, Value executableLayout,
    StreamSchedulingState &schedulingState,
    ConversionPatternRewriter &rewriter) {
  SmallVector<Value, 4> pushConstantValues;
  for (auto inputValue : dispatchOp.operands()) {
    if (inputValue.getType().isa<IndexType>() ||
        inputValue.getType().isa<IntegerType>()) {
      auto pushConstantValue = rewriter.getRemappedValue(inputValue);
      // Need an explicit index cast to i32 since the
      // CommandBufferPushConstantsOp is intrinsically i32 based.
      if (inputValue.getType().isa<IndexType>()) {
        pushConstantValue = rewriter.create<IndexCastOp>(
            dispatchOp.getLoc(), rewriter.getIntegerType(32),
            pushConstantValue);
      }
      pushConstantValues.push_back(pushConstantValue);
    }
  }
  if (!pushConstantValues.empty()) {
    uint64_t maxPushConstants =
        interfaceOp.push_constants().hasValue()
            ? interfaceOp.push_constants().getValue().getZExtValue()
            : 0;
    (void)maxPushConstants;
    assert(pushConstantValues.size() <= maxPushConstants &&
           "uniform buffer spilling not yet implemented");

    rewriter.create<IREE::HAL::CommandBufferPushConstantsOp>(
        dispatchOp.getLoc(), commandBuffer, executableLayout,
        rewriter.getIndexAttr(0), pushConstantValues);
  }

  LLVM_DEBUG(llvm::dbgs() << "HAL recordPushBindings: "
                          << *dispatchOp.getOperation() << "\n");
  uint32_t setOrdinal = 0;
  uint32_t bindingOrdinal = 0;
  SmallVector<IREE::HAL::DescriptorSetBindingValue, 4> bindings;
  auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter);
  auto pushBinding =
      [&](Value tensorValue,
          IREE::HAL::MemoryAccessBitfield accessType) -> LogicalResult {
    auto bufferRange = schedulingState.lookupTensorBufferRange(tensorValue);
    bindings.push_back(std::make_tuple(
        schedulingState.lookupOrCreateIndex(bindingOrdinal++, rewriter),
        bufferRange.buffer, zeroOffset, bufferRange.length));
    return success();
  };
  for (auto it : llvm::enumerate(dispatchOp.operands())) {
    LLVM_DEBUG(llvm::dbgs()
               << " + OPERAND(" << it.index() << "): " << it.value() << "\n");
    if (it.value().getType().isa<TensorType>()) {
      if (failed(
              pushBinding(it.value(), IREE::HAL::MemoryAccessBitfield::Read))) {
        return failure();
      }
    }
  }
  for (auto it : llvm::enumerate(dispatchOp.results())) {
    LLVM_DEBUG(llvm::dbgs()
               << " + RESULT(" << it.index() << "): " << it.value() << "\n");
    if (dispatchOp.getTiedResultOperandIndex(it.index())) {
      LLVM_DEBUG(llvm::dbgs() << " TIED TO OPERAND; SKIP\n");
    } else {
      if (failed(pushBinding(it.value(),
                             IREE::HAL::MemoryAccessBitfield::DiscardWrite))) {
        return failure();
      }
    }
  }
  rewriter.create<IREE::HAL::CommandBufferPushDescriptorSetOp>(
      dispatchOp.getLoc(), commandBuffer, executableLayout, setOrdinal,
      bindings);
  return success();
}

// Records a dispatch operation.
static LogicalResult recordDispatch(Value device, Value commandBuffer,
IREE::Flow::DispatchOp &dispatchOp,
@@ -892,20 +814,11 @@ static LogicalResult recordDispatch(Value device, Value commandBuffer,
interfaceOp.push_constantsAttr(),
interfaceOp.getExecutableSetLayoutsAttr(), rewriter);

// HACK: switch off for new-style bindings population if the metadata is
// available.
if (auto bindingsAttr =
dispatchOp->getAttrOfType<ArrayAttr>("hal.bindings")) {
recordInterfaceBindings(device, commandBuffer, dispatchOp, interfaceOp,
executableLayout, bindingsAttr, schedulingState,
rewriter);
} else {
if (failed(recordLegacyBindings(device, commandBuffer, dispatchOp,
interfaceOp, executableLayout,
schedulingState, rewriter))) {
return failure();
}
}
auto bindingsAttr = dispatchOp->getAttrOfType<ArrayAttr>("hal.bindings");
assert(bindingsAttr);
recordInterfaceBindings(device, commandBuffer, dispatchOp, interfaceOp,
executableLayout, bindingsAttr, schedulingState,
rewriter);

dispatchState.entryPointOp = entryPointOp;
dispatchState.interfaceOp = interfaceOp;
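Note: the surviving recordInterfaceBindings path (body elided above) is driven entirely by the hal.bindings array attribute that every flow.dispatch op is now required to carry. A rough sketch of what attribute-driven binding recording could look like follows; the attribute class names and accessors (ExOperandBufferAttr, ExPushConstantAttr, getOperandIndex) are illustrative assumptions, not the actual API, and the real implementation is not part of this diff.

// Sketch only: the actual recordInterfaceBindings implementation is elided
// above. Attribute class names and accessors here are hypothetical stand-ins
// for whatever backs the #hal.ex.* attributes used in the tests below.
static LogicalResult recordBindingsFromAttrSketch(
    Value commandBuffer, IREE::Flow::DispatchOp &dispatchOp,
    Value executableLayout, ArrayAttr bindingsAttr,
    StreamSchedulingState &schedulingState,
    ConversionPatternRewriter &rewriter) {
  SmallVector<Value, 4> pushConstantValues;
  SmallVector<IREE::HAL::DescriptorSetBindingValue, 4> bindings;
  auto zeroOffset = schedulingState.lookupOrCreateIndex(0, rewriter);
  uint32_t bindingOrdinal = 0;
  for (auto attr : bindingsAttr) {
    if (auto operandBuffer =
            attr.dyn_cast<IREE::HAL::ExOperandBufferAttr>()) {  // hypothetical
      // Tensor operands become descriptor set bindings, much as the deleted
      // legacy path did, but the ordering comes from the attribute.
      Value tensorValue =
          dispatchOp.operands()[operandBuffer.getOperandIndex()];
      auto bufferRange = schedulingState.lookupTensorBufferRange(tensorValue);
      bindings.push_back(std::make_tuple(
          schedulingState.lookupOrCreateIndex(bindingOrdinal++, rewriter),
          bufferRange.buffer, zeroOffset, bufferRange.length));
    } else if (auto pushConstant =
                   attr.dyn_cast<IREE::HAL::ExPushConstantAttr>()) {  // hypothetical
      // index/i32 operands are recorded as push constants rather than buffer
      // bindings (index values would still need a cast to i32).
      pushConstantValues.push_back(rewriter.getRemappedValue(
          dispatchOp.operands()[pushConstant.getOperandIndex()]));
    }
    // #hal.ex.result_buffer entries would be handled like operand buffers,
    // pulling the buffer range from the dispatch result instead.
  }
  if (!pushConstantValues.empty()) {
    rewriter.create<IREE::HAL::CommandBufferPushConstantsOp>(
        dispatchOp.getLoc(), commandBuffer, executableLayout,
        rewriter.getIndexAttr(0), pushConstantValues);
  }
  rewriter.create<IREE::HAL::CommandBufferPushDescriptorSetOp>(
      dispatchOp.getLoc(), commandBuffer, executableLayout,
      /*setOrdinal=*/0, bindings);
  return success();
}

The practical difference from the deleted recordLegacyBindings is that push constant and buffer assignments are no longer inferred by re-scanning operand types; they are read from metadata attached earlier in the pipeline, which is why recordDispatch can now assert that hal.bindings is present.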
iree/compiler/Dialect/HAL/Conversion/FlowToHAL/test/stream_ops.mlir: 80 changes (66 additions, 14 deletions)
@@ -43,12 +43,22 @@ func @multipleDispatches(%input: tensor<128xf32>) -> tensor<128xf32> {
// CHECK: hal.command_buffer.dispatch.symbol
// CHECK-SAME: target(@ex0::@vmvx::@entry0)
// CHECK: hal.command_buffer.execution_barrier
%1 = flow.dispatch @ex0::@entry0[%arg1](%arg2) : (tensor<128xf32>) -> tensor<128xf32>
%1 = flow.dispatch @ex0::@entry0[%arg1](%arg2) {
hal.bindings = [
#hal.ex.operand_buffer<"s0b0", 0 : index>,
#hal.ex.result_buffer<"s0b1", 0 : index>
]
} : (tensor<128xf32>) -> tensor<128xf32>
// CHECK: hal.command_buffer.push_descriptor_set
// CHECK: hal.command_buffer.dispatch.symbol
// CHECK-SAME: target(@ex0::@vmvx::@entry0)
// CHECK: hal.command_buffer.execution_barrier
%2 = flow.dispatch @ex0::@entry0[%arg1](%1) : (tensor<128xf32>) -> tensor<128xf32>
%2 = flow.dispatch @ex0::@entry0[%arg1](%1) {
hal.bindings = [
#hal.ex.operand_buffer<"s0b0", 0 : index>,
#hal.ex.result_buffer<"s0b1", 0 : index>
]
} : (tensor<128xf32>) -> tensor<128xf32>
flow.return %2 : tensor<128xf32>
}
// CHECK: hal.command_buffer.end<%[[CMD]]
@@ -237,7 +247,7 @@ func @tensorUpdate(%arg0 : tensor<1x1x10xf32>, %arg1 : tensor<5x1x10xf32>) -> te
// -----

hal.executable @ex0 {
hal.interface @interface attributes {push_constants = 2 : index} {
hal.interface @interface attributes {push_constants = 1 : index} {
hal.interface.binding @s0b0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @s0b1, set=0, binding=1, type="StorageBuffer", access="Read|Write"
}
@@ -267,9 +277,27 @@ func @dispatchWithShapeTies(%arg0: tensor<?x128xf32>, %bs : index) -> tensor<?x1
// allocation is covering all ops.
%0 = flow.ex.stream.fragment(%cst, %arg0, %bs) : (index, tensor<?x128xf32>{%cst}, index) -> tensor<?x128xf32>{%cst} =
(%arg1: index, %arg2: tensor<?x128xf32>, %arg3: index) -> tensor<?x128xf32> {
%3 = flow.dispatch @ex0::@entry0[%arg1](%arg2, %arg3) : (tensor<?x128xf32>{%arg3}, index) -> tensor<?x128xf32>{%arg3}
%5 = flow.dispatch @ex0::@entry0[%arg1](%3, %arg3) : (tensor<?x128xf32>{%arg3}, index) -> tensor<?x128xf32>{%arg3}
%7 = flow.dispatch @ex0::@entry0[%arg1](%5, %arg3) : (tensor<?x128xf32>{%arg3}, index) -> tensor<?x128xf32>{%arg3}
%3 = flow.dispatch @ex0::@entry0[%arg1](%arg2, %arg3) {
hal.bindings = [
#hal.ex.operand_buffer<"s0b0", 0 : index>,
#hal.ex.push_constant<0 : index, 1 : index>,
#hal.ex.result_buffer<"s0b1", 0 : index>
]
} : (tensor<?x128xf32>{%arg3}, index) -> tensor<?x128xf32>{%arg3}
%5 = flow.dispatch @ex0::@entry0[%arg1](%3, %arg3) {
hal.bindings = [
#hal.ex.operand_buffer<"s0b0", 0 : index>,
#hal.ex.push_constant<0 : index, 1 : index>,
#hal.ex.result_buffer<"s0b1", 0 : index>
]
} : (tensor<?x128xf32>{%arg3}, index) -> tensor<?x128xf32>{%arg3}
%7 = flow.dispatch @ex0::@entry0[%arg1](%5, %arg3) {
hal.bindings = [
#hal.ex.operand_buffer<"s0b0", 0 : index>,
#hal.ex.push_constant<0 : index, 1 : index>,
#hal.ex.result_buffer<"s0b1", 0 : index>
]
} : (tensor<?x128xf32>{%arg3}, index) -> tensor<?x128xf32>{%arg3}
flow.return %7 : tensor<?x128xf32>
}
return %0 : tensor<?x128xf32>
@@ -307,7 +335,12 @@ func @staticTiledDispatch(%input: tensor<7x4x24xf32>) -> tensor<4x7x1024xf32> {
// CHECK-NEXT: %c1 = (%{{.+}} : !hal.buffer)[%c0, %c114688]
// CHECK: hal.command_buffer.dispatch.symbol
// CHECK-SAME: target(@ex::@tgt::@entry)
%0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3) : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32>
%0 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3) {
hal.bindings = [
#hal.ex.operand_buffer<"arg0", 0 : index>,
#hal.ex.result_buffer<"ret0", 0 : index>
]
} : (tensor<7x4x24xf32>) -> tensor<4x7x1024xf32>
flow.return %0 : tensor<4x7x1024xf32>
}
// CHECK: hal.command_buffer.end<%[[CMD]]
@@ -339,15 +372,15 @@ func @dynamicTiledDispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: ind
// CHECK-NEXT: hal.command_buffer.begin<%[[CMD]]
%2 = flow.ex.stream.fragment(%arg0, %arg1, %arg2, %c1024, %c512) : (tensor<7x?x24x?xf32>{%arg1, %arg2}, index, index, index, index) -> tensor<?x?x1024xf32>{%arg2, %arg1} =
(%arg3: tensor<7x?x24x?xf32>, %arg4: index, %arg5: index, %arg6: index, %arg7: index) -> tensor<?x?x1024xf32> {
// CHECK: hal.command_buffer.push_constants<%[[CMD]]
// CHECK-SAME: layout(%executable_layout
// CHECK-SAME: offset(0)
// CHECK-SAME: values([%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}]) : i32, i32, i32, i32
// CHECK: hal.command_buffer.push_descriptor_set<%[[CMD]]
// CHECK-SAME: layout(%executable_layout : !hal.executable_layout)[%c0]
// CHECK-SAME: bindings([
// CHECK-NEXT: %c0 = (%[[INPUT]] : !hal.buffer)[%c0, %2],
// CHECK-NEXT: %c1 = (%{{.+}} : !hal.buffer)[%c0, %5]
// CHECK: hal.command_buffer.push_constants<%[[CMD]]
// CHECK-SAME: layout(%executable_layout
// CHECK-SAME: offset(0)
// CHECK-SAME: values([%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}]) : i32, i32, i32, i32

// CHECK: #hal.device.match.id<"dylib*">(
// CHECK-SAME: %[[CMD_INNER:.+]] = %cmd : !hal.command_buffer,
@@ -358,7 +391,16 @@ func @dynamicTiledDispatch(%arg0: tensor<7x?x24x?xf32>, %arg1: index, %arg2: ind
// CHECK: hal.command_buffer.dispatch.symbol<%[[CMD_INNER]]
// CHECK-SAME: target(@ex::@tgt::@entry)
// CHECK-SAME: workgroups([%[[COUNT_X]], %[[COUNT_Y]], %[[COUNT_Z]]])
%6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3, %arg4, %arg5, %arg5, %arg4) : (tensor<7x?x24x?xf32>{%arg4, %arg5}, index, index, index, index) -> tensor<?x?x1024xf32>{%arg5, %arg4}
%6 = flow.dispatch @ex::@entry[%arg6, %arg7, %arg7](%arg3, %arg4, %arg5, %arg5, %arg4) {
hal.bindings = [
#hal.ex.operand_buffer<"arg0", 0 : index>,
#hal.ex.push_constant<0 : index, 1 : index>,
#hal.ex.push_constant<1 : index, 2 : index>,
#hal.ex.push_constant<2 : index, 3 : index>,
#hal.ex.push_constant<3 : index, 4 : index>,
#hal.ex.result_buffer<"ret0", 0 : index>
]
} : (tensor<7x?x24x?xf32>{%arg4, %arg5}, index, index, index, index) -> tensor<?x?x1024xf32>{%arg5, %arg4}
flow.return %6 : tensor<?x?x1024xf32>
}
// CHECK: hal.command_buffer.end<%[[CMD]]
@@ -415,7 +457,12 @@ func @dispatchTiedBuffer(%fill: tensor<i32>, %input: tensor<2x3xi32>) -> tensor<
// CHECK-SAME: bindings([
// CHECK-NEXT: %c0 = (%[[FILL]] : !hal.buffer)[%c0, %c4],
// CHECK-NEXT: %c1 = (%[[OUTPUT]] : !hal.buffer)[%c0, %c108]
%3 = flow.dispatch @pad_dispatch_0::@pad_dispatch_0[%c9, %c3, %c1](%arg0) : (tensor<i32>) -> tensor<3x9xi32>
%3 = flow.dispatch @pad_dispatch_0::@pad_dispatch_0[%c9, %c3, %c1](%arg0) {
hal.bindings = [
#hal.ex.operand_buffer<"ro0", 0 : index>,
#hal.ex.result_buffer<"wo1", 0 : index>
]
} : (tensor<i32>) -> tensor<3x9xi32>
// CHECK: %[[LAYOUT1:.+]] = hal.executable_layout.lookup
// CHECK-SAME: layouts([
// CHECK-SAME: #hal.descriptor_set_layout_binding<0, "StorageBuffer", R>,
@@ -425,7 +472,12 @@ func @dispatchTiedBuffer(%fill: tensor<i32>, %input: tensor<2x3xi32>) -> tensor<
// CHECK-SAME: bindings([
// CHECK-NEXT: %c0 = (%[[INPUT]] : !hal.buffer)[%c0, %c24],
// CHECK-NEXT: %c1 = (%[[OUTPUT]] : !hal.buffer)[%c0, %c108]
%4 = flow.dispatch @pad_dispatch_1::@pad_dispatch_1[%c9, %c3, %c1](%arg1, %3) : (tensor<2x3xi32>, tensor<3x9xi32>) -> %3
%4 = flow.dispatch @pad_dispatch_1::@pad_dispatch_1[%c9, %c3, %c1](%arg1, %3) {
hal.bindings = [
#hal.ex.operand_buffer<"ro0", 0 : index>,
#hal.ex.operand_buffer<"rw1", 1 : index>
]
} : (tensor<2x3xi32>, tensor<3x9xi32>) -> %3
flow.return %4 : tensor<3x9xi32>
}
return %0 : tensor<3x9xi32>
