Skip to content

Commit

Permalink
Adding some CTS variants for indirect command buffers.
Browse files Browse the repository at this point in the history
This is a simple start that exercises a few commands to test the
infrastructure for record/replay of indirect command buffers and the
validation mechanism. We don't yet have a good way in the CTS to
conditionally run tests based on whether validation is compiled into the
build, but I'll be looking into that for testing failure cases in future
PRs (for now I've just tested manually to verify errors are propagated).

Fixed a few bugs found on the first runs.
  • Loading branch information
benvanik committed Jul 10, 2024
1 parent 6f25718 commit 7356efd
Show file tree
Hide file tree
Showing 11 changed files with 525 additions and 202 deletions.
5 changes: 5 additions & 0 deletions runtime/src/iree/hal/command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,11 @@ iree_hal_buffer_binding_table_empty(void) {
return table;
}

// Returns true when |binding_table| holds no bindings (its count is zero).
static inline bool iree_hal_buffer_binding_table_is_empty(
    iree_hal_buffer_binding_table_t binding_table) {
  const bool has_no_bindings = (0 == binding_table.count);
  return has_no_bindings;
}

// Returns an unretained buffer specified in |buffer_ref| or from
// |binding_table| with the slot specified if indirect. If the caller needs to
// preserve the buffer for longer than the (known) lifetime of the binding table
Expand Down
52 changes: 31 additions & 21 deletions runtime/src/iree/hal/command_buffer_validation.c
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,16 @@ static iree_status_t iree_hal_command_buffer_validate_binding_requirements(
// it are in range.
if (requirements.max_byte_offset > 0) {
iree_device_size_t end = binding.offset + requirements.max_byte_offset;
if (IREE_UNLIKELY(end > binding.length)) {
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"at least one command attempted to access an "
"address outside of the valid bound buffer "
"range (length=%" PRIdsz ", end(inc)=%" PRIdsz
", binding offset=%" PRIdsz
", binding length=%" PRIdsz ")",
requirements.max_byte_offset, end - 1,
binding.offset, binding.length);
if (IREE_UNLIKELY(end > binding.offset + binding.length)) {
return iree_make_status(
IREE_STATUS_OUT_OF_RANGE,
"at least one command attempted to access an "
"address outside of the valid bound buffer "
"range (length=%" PRIdsz ", end(inc)=%" PRIdsz
", binding offset=%" PRIdsz ", binding length=%" PRIdsz
", binding end(inc)=%" PRIdsz ")",
requirements.max_byte_offset, end - 1, binding.offset, binding.length,
binding.offset + binding.length - 1);
}
}

Expand Down Expand Up @@ -188,8 +189,11 @@ static iree_status_t iree_hal_command_buffer_validate_buffer_requirements(
table_requirements->type |= requirements.type;
table_requirements->max_byte_offset = iree_max(
table_requirements->max_byte_offset, requirements.max_byte_offset);
table_requirements->min_byte_alignment = iree_device_size_lcm(
table_requirements->min_byte_alignment, requirements.min_byte_alignment);
if (requirements.min_byte_alignment) {
table_requirements->min_byte_alignment =
iree_device_size_lcm(table_requirements->min_byte_alignment,
requirements.min_byte_alignment);
}

return iree_ok_status();
}
Expand Down Expand Up @@ -430,14 +434,19 @@ iree_status_t iree_hal_command_buffer_copy_buffer_validation(
// Check for overlap - just like memcpy we don't handle that.
// Note that it's only undefined behavior if violated so we are ok if tricky
// situations (subspans of subspans of binding table subranges etc) make it
// through.
if (iree_hal_buffer_test_overlap(source_ref.buffer, source_ref.offset,
source_ref.length, target_ref.buffer,
target_ref.offset, target_ref.length) !=
IREE_HAL_BUFFER_OVERLAP_DISJOINT) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"source and target ranges overlap within the same buffer");
// through. This is only possible if both buffers are directly referenced -
// we _could_ try to catch this for indirect references by stashing the
// overlap check metadata for validation when the binding table is available
// but that's too costly to be worth it.
if (source_ref.buffer && target_ref.buffer) {
if (iree_hal_buffer_test_overlap(source_ref.buffer, source_ref.offset,
source_ref.length, target_ref.buffer,
target_ref.offset, target_ref.length) !=
IREE_HAL_BUFFER_OVERLAP_DISJOINT) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"source and target ranges overlap within the same buffer");
}
}

return iree_ok_status();
Expand Down Expand Up @@ -571,10 +580,11 @@ iree_status_t iree_hal_command_buffer_push_descriptor_set_validation(
// TODO(benvanik): validate set index.

// TODO(benvanik): use pipeline layout to derive usage and access bits.
// For now we conservatively say _any_ access may be performed (read/write).
iree_hal_buffer_binding_requirements_t requirements = {
.required_compatibility = IREE_HAL_BUFFER_COMPATIBILITY_QUEUE_DISPATCH,
// .usage = IREE_HAL_BUFFER_USAGE_DISPATCH_...,
// .access = IREE_HAL_MEMORY_ACCESS_...,
.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE,
.access = IREE_HAL_MEMORY_ACCESS_ANY,
.type = IREE_HAL_MEMORY_TYPE_DEVICE_VISIBLE,
};
for (iree_host_size_t i = 0; i < binding_count; ++i) {
Expand Down
126 changes: 76 additions & 50 deletions runtime/src/iree/hal/cts/command_buffer_dispatch_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/base/api.h"
#include "iree/base/string_view.h"
#include "iree/hal/api.h"
#include "iree/hal/buffer_view_util.h"
#include "iree/hal/cts/cts_test_base.h"
#include "iree/testing/gtest.h"
#include "iree/testing/status_matchers.h"
Expand All @@ -18,7 +19,8 @@ namespace iree {
namespace hal {
namespace cts {

class command_buffer_dispatch_test : public CTSTestBase<> {
class CommandBufferDispatchTest
: public CTSTestBase<::testing::TestWithParam<RecordingType>> {
protected:
void PrepareAbsExecutable() {
IREE_ASSERT_OK(iree_hal_executable_cache_create(
Expand Down Expand Up @@ -76,58 +78,77 @@ class command_buffer_dispatch_test : public CTSTestBase<> {
iree_hal_executable_t* executable_ = NULL;
};

TEST_F(command_buffer_dispatch_test, DispatchAbs) {
// Dispatches absf(x) on a subrange (elements 1-2) of a 4 element input buffer.
// input_buffer = [-2.5 -2.5 -2.5 -2.5]
// output_buffer = [-9.0 2.5 2.5 -9.0]
TEST_P(CommandBufferDispatchTest, DispatchAbs) {
PrepareAbsExecutable();

iree_hal_command_buffer_t* command_buffer = NULL;
IREE_ASSERT_OK(iree_hal_command_buffer_create(
device_,
IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT |
IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION,
IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
/*binding_capacity=*/0, &command_buffer));

IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));
// Create input buffer.
iree_hal_buffer_t* input_buffer = NULL;
CreateFilledDeviceBuffer<float>(4 * sizeof(float), -2.5f, &input_buffer);

// Create input and output buffers.
iree_hal_buffer_params_t input_params = {0};
input_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
input_params.usage =
IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE | IREE_HAL_BUFFER_USAGE_TRANSFER;
iree_hal_buffer_view_t* input_buffer_view = NULL;
float input_data[1] = {-2.5f};
IREE_ASSERT_OK(iree_hal_buffer_view_allocate_buffer_copy(
device_, device_allocator_,
/*shape_rank=*/0, /*shape=*/NULL, IREE_HAL_ELEMENT_TYPE_FLOAT_32,
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR, input_params,
iree_make_const_byte_span((void*)input_data, sizeof(input_data)),
&input_buffer_view));
iree_hal_buffer_params_t output_params = {0};
output_params.type =
IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL | IREE_HAL_MEMORY_TYPE_HOST_VISIBLE;
output_params.usage = IREE_HAL_BUFFER_USAGE_DISPATCH_STORAGE |
IREE_HAL_BUFFER_USAGE_TRANSFER |
IREE_HAL_BUFFER_USAGE_MAPPING;
// Create output buffer.
iree_hal_buffer_t* output_buffer = NULL;
IREE_ASSERT_OK(iree_hal_allocator_allocate_buffer(
device_allocator_, output_params, sizeof(float), &output_buffer));

iree_hal_buffer_ref_t descriptor_set_bindings[] = {
{
CreateFilledDeviceBuffer<float>(4 * sizeof(float), -9.0f, &output_buffer);

iree_hal_buffer_ref_t descriptor_set_bindings[2];
iree_hal_buffer_binding_t bindings[2];
iree_hal_buffer_binding_table_t binding_table =
iree_hal_buffer_binding_table_empty();
switch (GetParam()) {
case RecordingType::kDirect:
descriptor_set_bindings[0] = {
/*binding=*/0,
/*buffer_slot=*/0,
iree_hal_buffer_view_buffer(input_buffer_view),
/*offset=*/0,
iree_hal_buffer_view_byte_length(input_buffer_view),
},
{
/*buffer=*/input_buffer,
/*offset=*/1 * sizeof(float),
/*length=*/2 * sizeof(float),
};
descriptor_set_bindings[1] = {
/*binding=*/1,
/*buffer_slot=*/0,
output_buffer,
iree_hal_buffer_byte_offset(output_buffer),
iree_hal_buffer_byte_length(output_buffer),
},
};
/*buffer=*/output_buffer,
/*offset=*/1 * sizeof(float),
/*length=*/2 * sizeof(float),
};
break;
case RecordingType::kIndirect:
binding_table.count = IREE_ARRAYSIZE(descriptor_set_bindings);
binding_table.bindings = bindings;
bindings[0] = {
/*buffer=*/input_buffer,
/*offset=*/1 * sizeof(float),
/*length=*/2 * sizeof(float),
};
descriptor_set_bindings[0] = {
/*binding=*/0,
/*buffer_slot=*/0,
/*buffer=*/NULL,
/*offset=*/0,
/*length=*/2 * sizeof(float),
};
bindings[1] = {
/*buffer=*/output_buffer,
/*offset=*/1 * sizeof(float),
/*length=*/2 * sizeof(float),
};
descriptor_set_bindings[1] = {
/*binding=*/1,
/*buffer_slot=*/1,
/*buffer=*/NULL,
/*offset=*/0,
/*length=*/2 * sizeof(float),
};
break;
}

iree_hal_command_buffer_t* command_buffer = NULL;
IREE_ASSERT_OK(iree_hal_command_buffer_create(
device_, IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT,
IREE_HAL_COMMAND_CATEGORY_DISPATCH, IREE_HAL_QUEUE_AFFINITY_ANY,
binding_table.count, &command_buffer));
IREE_ASSERT_OK(iree_hal_command_buffer_begin(command_buffer));

IREE_ASSERT_OK(iree_hal_command_buffer_push_descriptor_set(
command_buffer, pipeline_layout_, /*set=*/0,
Expand All @@ -149,21 +170,26 @@ TEST_F(command_buffer_dispatch_test, DispatchAbs) {

IREE_ASSERT_OK(iree_hal_command_buffer_end(command_buffer));

IREE_ASSERT_OK(SubmitCommandBufferAndWait(command_buffer));
IREE_ASSERT_OK(SubmitCommandBufferAndWait(command_buffer, binding_table));

float output_value = 0.0f;
float output_values[4] = {0.0f};
IREE_ASSERT_OK(iree_hal_device_transfer_d2h(
device_, output_buffer,
/*source_offset=*/0, &output_value, sizeof(output_value),
/*source_offset=*/0, output_values, sizeof(output_values),
IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, iree_infinite_timeout()));
EXPECT_EQ(2.5f, output_value);
EXPECT_THAT(output_values, ::testing::ElementsAre(-9.0f, 2.5f, 2.5f, -9.0f));

iree_hal_command_buffer_release(command_buffer);
iree_hal_buffer_release(output_buffer);
iree_hal_buffer_view_release(input_buffer_view);
iree_hal_buffer_release(input_buffer);
CleanupExecutable();
}

INSTANTIATE_TEST_SUITE_P(CommandBufferTest, CommandBufferDispatchTest,
::testing::Values(RecordingType::kDirect,
RecordingType::kIndirect),
GenerateTestName());

} // namespace cts
} // namespace hal
} // namespace iree
Expand Down
Loading

0 comments on commit 7356efd

Please sign in to comment.