Skip to content

Commit

Permalink
unify inference code, fix sea-of-nodes set
Browse files Browse the repository at this point in the history
  • Loading branch information
martty committed Apr 19, 2024
1 parent 886b2e7 commit 31ce739
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
69 changes: 38 additions & 31 deletions include/vuk/Value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@ namespace vuk {
if (def.node->kind == Node::ACQUIRE_NEXT_IMAGE) {
return;
}
def.node->construct.args[1] = get_render_graph()->make_extract(src.get_def(), 0);
def.node->construct.args[2] = get_render_graph()->make_extract(src.get_def(), 1);
def.node->construct.args[3] = get_render_graph()->make_extract(src.get_def(), 2);
node->deps.push_back(src.node);
replace_arg_with_extract_or_constant(def, src.get_def(), 0);
replace_arg_with_extract_or_constant(def, src.get_def(), 1);
replace_arg_with_extract_or_constant(def, src.get_def(), 2);
}

/// @brief Inference target has the same width & height as the source
Expand All @@ -131,8 +132,9 @@ namespace vuk {
if (def.node->kind == Node::ACQUIRE_NEXT_IMAGE) {
return;
}
def.node->construct.args[1] = get_render_graph()->make_extract(src.get_def(), 0);
def.node->construct.args[2] = get_render_graph()->make_extract(src.get_def(), 1);
node->deps.push_back(src.node);
replace_arg_with_extract_or_constant(def, src.get_def(), 0);
replace_arg_with_extract_or_constant(def, src.get_def(), 1);
}

/// @brief Inference target has the same format as the source
Expand All @@ -142,7 +144,8 @@ namespace vuk {
if (def.node->kind == Node::ACQUIRE_NEXT_IMAGE) {
return;
}
def.node->construct.args[4] = get_render_graph()->make_extract(src.get_def(), 3);
node->deps.push_back(src.node);
replace_arg_with_extract_or_constant(def, src.get_def(), 3);
}

/// @brief Inference target has the same shape(extent, layers, levels) as the source
Expand All @@ -155,7 +158,7 @@ namespace vuk {
same_extent_as(src);

for (auto i = 6; i < 10; i++) { /* 6 - 9 : layers, levels */
def.node->construct.args[i] = get_render_graph()->make_extract(src.get_def(), i - 1);
replace_arg_with_extract_or_constant(def, src.get_def(), i - 1);
}
}

Expand All @@ -168,36 +171,16 @@ namespace vuk {
}
same_shape_as(src);
same_format_as(src);
def.node->construct.args[5] = get_render_graph()->make_extract(src.get_def(), 4);
replace_arg_with_extract_or_constant(def, src.get_def(), 4);
}

// Buffer inferences

void same_size(const Value<Buffer>& src)
requires std::is_same_v<T, Buffer>
{
uint64_t index = 0;
Type* cty = get_render_graph()->u64();
auto constant_node = Node{ .kind = Node::CONSTANT, .type = std::span{ &cty, 1 } };
constant_node.constant.value = &index; // writing these out for clang workaround

auto composite = src.get_def();
Type* ty;
auto stripped = Type::stripped(composite.type());
if (stripped->kind == Type::ARRAY_TY) {
ty = stripped->array.T;
} else if (stripped->kind == Type::COMPOSITE_TY) {
ty = stripped->composite.types[index];
}
auto candidate_node = Node{ .kind = Node::EXTRACT, .type = std::span{ &ty, 1 } };
candidate_node.extract.composite = composite; // writing these out for clang workaround
candidate_node.extract.index = first(&constant_node);
try {
auto result = eval<uint64_t>(first(&candidate_node));
def.node->construct.args[1] = get_render_graph()->make_constant<uint64_t>(result);
} catch (...) {
def.node->construct.args[1] = get_render_graph()->make_extract(src.get_def(), 0);
}
node->deps.push_back(src.node);
replace_arg_with_extract_or_constant(def, src.get_def(), 0);
}

Value<uint64_t> get_size()
Expand All @@ -210,7 +193,7 @@ namespace vuk {
void set_size(Value<uint64_t> arg)
requires std::is_same_v<T, Buffer>
{
get_render_graph()->subgraphs.push_back(arg.get_render_graph());
node->deps.push_back(arg.node);
def.node->construct.args[1] = arg.get_head();
}

Expand Down Expand Up @@ -250,6 +233,30 @@ namespace vuk {
node->module->make_constant(1u));
return Value(ExtRef(std::make_shared<ExtNode>(get_render_graph(), item.node, node), item), item_def);
}

void replace_arg_with_extract_or_constant(Ref construct, Ref src_composite, uint64_t index) {
Type* cty = get_render_graph()->u64();
auto constant_node = Node{ .kind = Node::CONSTANT, .type = std::span{ &cty, 1 } };
constant_node.constant.value = &index; // writing these out for clang workaround

auto composite = src_composite;
Type* ty;
auto stripped = Type::stripped(composite.type());
if (stripped->kind == Type::ARRAY_TY) {
ty = stripped->array.T;
} else if (stripped->kind == Type::COMPOSITE_TY) {
ty = stripped->composite.types[index];
}
auto candidate_node = Node{ .kind = Node::EXTRACT, .type = std::span{ &ty, 1 } };
candidate_node.extract.composite = composite; // writing these out for clang workaround
candidate_node.extract.index = first(&constant_node);
try {
auto result = eval<uint64_t>(first(&candidate_node));
construct.node->construct.args[index + 1] = get_render_graph()->make_constant<uint64_t>(result);
} catch (...) {
construct.node->construct.args[index + 1] = get_render_graph()->make_extract(composite, index);
}
}
};

inline Value<uint64_t> operator*(Value<uint64_t> a, uint64_t b) {
Expand Down
2 changes: 1 addition & 1 deletion src/RenderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ namespace vuk {
// linked-sea-of-nodes to list of nodes
std::vector<IRModule*, short_alloc<IRModule*>> work_queue(*impl->arena_);
std::unordered_set<IRModule*> visited;
for (auto& ref : impl->refs) {
for (auto& ref : impl->depnodes) {
auto mod = ref->module.get();
if (!visited.count(mod)) {
work_queue.push_back(mod);
Expand Down

0 comments on commit 31ce739

Please sign in to comment.