Change argument/type signature of Call::prefetch (Fixes #4211)
As noted in the issue above, the return type of this Call was being misused: it wasn't the return type of the intrinsic (which is always int32, and always ignored), but rather the type of the elements to prefetch. This didn't seem to be causing any obvious errors, but it meant we needed some weird special cases in the code. We now insert another element at the front of the args to carry the needed type (with a value of constant zero).
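
For illustration, here is a minimal sketch of the signature change (t, base, offset, extent0, and stride0 are placeholders, not code from this commit):

    // Before: the element type rode along as the call's return type.
    Expr old_call = Call::make(t, Call::prefetch,
                               {base, offset, extent0, stride0}, Call::Intrinsic);
    // After: the call always returns Int(32), and a constant zero of the element
    // type is prepended to the args to carry that type instead.
    Expr new_call = Call::make(Int(32), Call::prefetch,
                               {make_zero(t), base, offset, extent0, stride0}, Call::Intrinsic);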
steven-johnson committed Jun 8, 2023
1 parent 67eaff3 commit 5373f6a
Showing 9 changed files with 87 additions and 75 deletions.
src/CodeGen_C.cpp (19 changes: 11 additions & 8 deletions)
@@ -1825,20 +1825,23 @@ void CodeGen_C::visit(const Call *op) {
} else if (op->is_intrinsic(Call::undef)) {
user_error << "undef not eliminated before code generation. Please report this as a Halide bug.\n";
} else if (op->is_intrinsic(Call::prefetch)) {
- user_assert((op->args.size() == 4) && is_const_one(op->args[2]))
+ internal_assert(op->type == Int(32));
+ user_assert((op->args.size() == 5) && is_const_one(op->args[3]))
<< "Only prefetch of 1 cache line is supported in C backend.\n";

- const Expr &base_address = op->args[0];
- const Expr &base_offset = op->args[1];
- // const Expr &extent0 = op->args[2]; // unused
- // const Expr &stride0 = op->args[3]; // unused
+ const Type prefetch_element_type = op->args[0].type();
+ const Expr &base_address = op->args[1];
+ const Expr &base_offset = op->args[2];
+ // const Expr &extent0 = op->args[3]; // unused
+ // const Expr &stride0 = op->args[4]; // unused

const Variable *base = base_address.as<Variable>();
internal_assert(base && base->type.is_handle());
// TODO: provide some way to customize the rw and locality?
rhs << "__builtin_prefetch("
<< "((" << print_type(op->type) << " *)" << print_name(base->name)
<< " + " << print_expr(base_offset) << "), /*rw*/0, /*locality*/0)";
// __builtin_prefetch() returns void, so use comma operator to satisfy assignment
rhs << "(__builtin_prefetch("
<< "((" << print_type(prefetch_element_type) << " *)" << print_name(base->name)
<< " + " << print_expr(base_offset) << "), /*rw*/0, /*locality*/0), 0)";
} else if (op->is_intrinsic(Call::size_of_halide_buffer_t)) {
rhs << "(sizeof(halide_buffer_t))";
} else if (op->is_intrinsic(Call::strict_float)) {
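
For context, the C this branch now emits for a prefetch reads roughly as follows (an illustrative sketch; the element type, _buf, and _offset names are invented for the example):

    // __builtin_prefetch() returns void, so the comma operator supplies a
    // dummy int32 result that can be assigned like any other expression.
    int32_t _0 = (__builtin_prefetch(((float *)_buf + _offset), /*rw*/0, /*locality*/0), 0);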
src/CodeGen_Hexagon.cpp (23 changes: 13 additions & 10 deletions)
@@ -1937,20 +1937,22 @@ void CodeGen_Hexagon::visit(const Call *op) {
}

if (op->is_intrinsic(Call::prefetch)) {
- internal_assert((op->args.size() == 4) || (op->args.size() == 6))
+ internal_assert(op->type == Int(32));
+ internal_assert((op->args.size() == 5) || (op->args.size() == 7))
<< "Hexagon only supports 1D or 2D prefetch\n";

- const int elem_size = op->type.bytes();
- const Expr &base_address = op->args[0];
- const Expr &base_offset = op->args[1];
- const Expr &extent0 = op->args[2];
- const Expr &stride0 = op->args[3];
+ const Type prefetch_element_type = op->args[0].type();
+ const Expr &base_address = op->args[1];
+ const Expr &base_offset = op->args[2];
+ const Expr &extent0 = op->args[3];
+ const Expr &stride0 = op->args[4];

+ const int elem_size = prefetch_element_type.bytes();
Expr width_bytes = extent0 * stride0 * elem_size;
Expr height, stride_bytes;
- if (op->args.size() == 6) {
-     const Expr &extent1 = op->args[4];
-     const Expr &stride1 = op->args[5];
+ if (op->args.size() == 7) {
+     const Expr &extent1 = op->args[5];
+     const Expr &stride1 = op->args[6];
height = extent1;
stride_bytes = stride1 * elem_size;
} else {
@@ -1959,7 +1961,7 @@
}

vector<llvm::Value *> args;
- args.push_back(codegen_buffer_pointer(codegen(base_address), op->type, base_offset));
+ args.push_back(codegen_buffer_pointer(codegen(base_address), prefetch_element_type, base_offset));
args.push_back(codegen(width_bytes));
args.push_back(codegen(height));
args.push_back(codegen(stride_bytes));
@@ -1974,6 +1976,7 @@
args[0] = builder->CreateBitCast(args[0], ptr_type);

value = builder->CreateCall(prefetch_fn, args);
+ internal_assert(value->getType() == i32_t);
return;
}
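
As a reminder, the argument layouts these asserts encode are (the names below are descriptive placeholders, not identifiers from the code):

    // 1D prefetch, 5 args:
    //   {zero_of_element_type, base, offset, extent0, stride0}
    // 2D prefetch, 7 args:
    //   {zero_of_element_type, base, offset, extent0, stride0, extent1, stride1}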

src/CodeGen_LLVM.cpp (22 changes: 10 additions & 12 deletions)
@@ -1334,12 +1334,7 @@ Value *CodeGen_LLVM::codegen(const Expr &e) {
// (eg OpenCL, HVX, WASM); for now we're just ignoring the assert, but
// in the long run we should improve the smarts. See https://github.com/halide/Halide/issues/4194.
const bool is_bool_vector = e.type().is_bool() && e.type().lanes() > 1;
- // TODO: skip this correctness check for prefetch, because the return type
- // of prefetch indicates the type being prefetched, which does not match the
- // implementation of prefetch.
- // See https://github.com/halide/Halide/issues/4211.
- const bool is_prefetch = Call::as_intrinsic(e, {Call::prefetch});
- bool types_match = is_bool_vector || is_prefetch ||
+ bool types_match = is_bool_vector ||
e.type().is_handle() ||
value->getType()->isVoidTy() ||
value->getType() == llvm_type_of(e.type());
@@ -3282,26 +3277,29 @@ void CodeGen_LLVM::visit(const Call *op) {
llvm::CallInst *call = builder->CreateCall(base_fn->getFunctionType(), phi, call_args);
value = call;
} else if (op->is_intrinsic(Call::prefetch)) {
- user_assert((op->args.size() == 4) && is_const_one(op->args[2]))
+ internal_assert(op->type == Int(32));
+ user_assert((op->args.size() == 5) && is_const_one(op->args[3]))
<< "Only prefetch of 1 cache line is supported.\n";

- const Expr &base_address = op->args[0];
- const Expr &base_offset = op->args[1];
- // const Expr &extent0 = op->args[2]; // unused
- // const Expr &stride0 = op->args[3]; // unused
+ const Type prefetch_element_type = op->args[0].type();
+ const Expr &base_address = op->args[1];
+ const Expr &base_offset = op->args[2];
+ // const Expr &extent0 = op->args[3]; // unused
+ // const Expr &stride0 = op->args[4]; // unused

llvm::Function *prefetch_fn = module->getFunction("_halide_prefetch");
internal_assert(prefetch_fn);

vector<llvm::Value *> args;
- args.push_back(codegen_buffer_pointer(codegen(base_address), op->type, base_offset));
+ args.push_back(codegen_buffer_pointer(codegen(base_address), prefetch_element_type, base_offset));
// The first argument is a pointer, which has type i8*. We
// need to cast the argument, which might be a pointer to a
// different type.
llvm::Type *ptr_type = prefetch_fn->getFunctionType()->params()[0];
args[0] = builder->CreateBitCast(args[0], ptr_type);

value = builder->CreateCall(prefetch_fn, args);
+ internal_assert(value->getType() == i32_t);
} else if (op->is_intrinsic(Call::signed_integer_overflow)) {
user_error << "Signed integer overflow occurred during constant-folding. Signed"
" integer overflow for int32 and int64 is undefined behavior in"
src/IR.cpp (7 changes: 5 additions & 2 deletions)
@@ -703,8 +703,11 @@ Expr Call::make(Type type, const std::string &name, const std::vector<Expr> &arg
FunctionPtr func, int value_index,
Buffer<> image, Parameter param) {
if (name == intrinsic_op_names[Call::prefetch] && call_type == Call::Intrinsic) {
- internal_assert(args.size() % 2 == 0)
-     << "Number of args to a prefetch call should be even: {base, offset, extent0, stride0, extent1, stride1, ...}\n";
+ internal_assert(type == Int(32))
+     << "The return type of a prefetch call must be Int(32)";
+ internal_assert(args.size() >= 5 && (args.size() % 2) == 1) // Prefetch: {prefetch_element_type(0), base, offset, extent0, stride0, ...}
+     << "Number of args to a prefetch call should be odd: {prefetch_element_type(0), base, offset, extent0, stride0, extent1, stride1, ...}\n";
+ internal_assert(is_const_zero(args[0])) << "The first arg to a prefetch call should be a constant zero of the type to prefetch";
}
for (size_t i = 0; i < args.size(); i++) {
internal_assert(args[i].defined()) << "Call of " << name << " with argument " << i << " undefined.\n";
src/Prefetch.cpp (30 changes: 17 additions & 13 deletions)
@@ -278,10 +278,12 @@ class ReducePrefetchDimension : public IRMutator {
// the dimensions with larger strides and keep the smaller ones in
// the prefetch call.

- const size_t max_arg_size = 2 + 2 * max_dim; // Prefetch: {base, offset, extent0, stride0, extent1, stride1, ...}
+ const size_t max_arg_size = 3 + 2 * max_dim; // Prefetch: {prefetch_element_type(0), base, offset, extent0, stride0, extent1, stride1, ...}
if (prefetch && (prefetch->args.size() > max_arg_size)) {
- const Expr &base_address = prefetch->args[0];
- const Expr &base_offset = prefetch->args[1];
+ internal_assert(prefetch->type == Int(32));
+ // const Type prefetch_element_type = prefetch->args[0].type(); // unused
+ const Expr &base_address = prefetch->args[1];
+ const Expr &base_offset = prefetch->args[2];

const Variable *base = base_address.as<Variable>();
internal_assert(base && base->type.is_handle());
@@ -296,14 +298,14 @@
new_offset += Variable::make(Int(32), index_name) * stride;
}

- vector<Expr> args = {base, new_offset};
- for (size_t i = 2; i < max_arg_size; ++i) {
+ vector<Expr> args = {prefetch->args[0], base, new_offset};
+ for (size_t i = 3; i < max_arg_size; ++i) {
args.push_back(prefetch->args[i]);
}

- stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic));
+ stmt = Evaluate::make(Call::make(Int(32), Call::prefetch, args, Call::Intrinsic));
for (size_t i = 0; i < index_names.size(); ++i) {
- stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 2],
+ stmt = For::make(index_names[i], 0, prefetch->args[(i + max_dim) * 2 + 3],
ForType::Serial, DeviceAPI::None, stmt);
}
debug(5) << "\nReduce prefetch to " << max_dim << " dim:\n"
@@ -333,18 +335,20 @@ class SplitPrefetch : public IRMutator {
op = stmt.as<Evaluate>();
internal_assert(op);
if (const Call *prefetch = Call::as_intrinsic(op->value, {Call::prefetch})) {
- const Expr &base_address = prefetch->args[0];
- const Expr &base_offset = prefetch->args[1];
+ internal_assert(prefetch->type == Int(32));
+ const Type prefetch_element_type = prefetch->args[0].type();
+ const Expr &base_address = prefetch->args[1];
+ const Expr &base_offset = prefetch->args[2];

const Variable *base = base_address.as<Variable>();
internal_assert(base && base->type.is_handle());

- int elem_size = prefetch->type.bytes();
+ const int elem_size = prefetch_element_type.bytes();

vector<string> index_names;
vector<Expr> extents;
Expr new_offset = base_offset;
- for (size_t i = 2; i < prefetch->args.size(); i += 2) {
+ for (size_t i = 3; i < prefetch->args.size(); i += 2) {
Expr extent = prefetch->args[i];
Expr stride = prefetch->args[i + 1];
Expr stride_bytes = stride * elem_size;
@@ -371,8 +375,8 @@

Expr new_extent = 1;
Expr new_stride = simplify(max_byte_size / elem_size);
- vector<Expr> args = {base, std::move(new_offset), std::move(new_extent), std::move(new_stride)};
- stmt = Evaluate::make(Call::make(prefetch->type, Call::prefetch, args, Call::Intrinsic));
+ vector<Expr> args = {prefetch->args[0], base, std::move(new_offset), std::move(new_extent), std::move(new_stride)};
+ stmt = Evaluate::make(Call::make(Int(32), Call::prefetch, args, Call::Intrinsic));
for (size_t i = 0; i < index_names.size(); ++i) {
stmt = For::make(index_names[i], 0, extents[i],
ForType::Serial, DeviceAPI::None, stmt);
src/Simplify_Call.cpp (4 changes: 2 additions & 2 deletions)
@@ -422,15 +422,15 @@ Expr Simplify::visit(const Call *op, ExprInfo *bounds) {
// Collapse the prefetched region into lower dimension whenever is possible.
// TODO(psuriana): Deal with negative strides and overlaps.

- internal_assert(op->args.size() % 2 == 0); // Prefetch: {base, offset, extent0, stride0, ...}
+ internal_assert(op->args.size() >= 5 && (op->args.size() % 2) == 1); // Prefetch: {prefetch_element_type(0), base, offset, extent0, stride0, ...}

auto [args, changed] = mutate_with_changes(op->args, nullptr);

// The {extent, stride} args in the prefetch call are sorted
// based on the storage dimension in ascending order (i.e. innermost
// first and outermost last), so, it is enough to check for the upper
// triangular pairs to see if any contiguous addresses exist.
- for (size_t i = 2; i < args.size(); i += 2) {
+ for (size_t i = 3; i < args.size(); i += 2) {
Expr extent_0 = args[i];
Expr stride_0 = args[i + 1];
for (size_t j = i + 2; j < args.size(); j += 2) {
src/StorageFlattening.cpp (4 changes: 2 additions & 2 deletions)
@@ -356,7 +356,7 @@ class FlattenDimensions : public IRMutator {

Expr base_offset = mutate(flatten_args(op->name, prefetch_min, Buffer<>(), op->prefetch.param));
Expr base_address = Variable::make(Handle(), op->name);
- vector<Expr> args = {base_address, base_offset};
+ vector<Expr> args = {make_zero(op->types[0]), base_address, base_offset};

auto iter = env.find(op->name);
if (iter != env.end()) {
@@ -391,7 +391,7 @@
}

// TODO: Consider generating a prefetch call for each tuple element.
- Stmt prefetch_call = Evaluate::make(Call::make(op->types[0], Call::prefetch, args, Call::Intrinsic));
+ Stmt prefetch_call = Evaluate::make(Call::make(Int(32), Call::prefetch, args, Call::Intrinsic));
if (!is_const_one(condition)) {
prefetch_call = IfThenElse::make(condition, prefetch_call);
}
(Diffs for the remaining 2 of the 9 changed files are not shown.)
