Skip to content

Commit

Permalink
Merge pull request #167 from google:jbaileyhandle_integration_functio…
Browse files Browse the repository at this point in the history
…n_class

PiperOrigin-RevId: 339957938
  • Loading branch information
Copybara-Service committed Oct 30, 2020
2 parents d8ad02d + 9328f1a commit 4c08f75
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 0 deletions.
1 change: 1 addition & 0 deletions xls/contrib/integrator/BUILD
Expand Up @@ -35,6 +35,7 @@ cc_test(
":ir_integrator",
"//xls/common/status:matchers",
"//xls/ir",
"//xls/ir:ir_matcher",
"//xls/ir:ir_parser",
"//xls/ir:ir_test_base",
"@com_google_googletest//:gtest_main",
Expand Down
101 changes: 101 additions & 0 deletions xls/contrib/integrator/ir_integrator.cc
Expand Up @@ -18,6 +18,107 @@

namespace xls {

absl::StatusOr<std::unique_ptr<IntegrationFunction>>
IntegrationFunction::MakeIntegrationFunctionWithParamTuples(
Package* package, absl::Span<const Function* const> source_functions,
std::string function_name) {
// Create integration function object.
auto integration_function = absl::make_unique<IntegrationFunction>();
integration_function->package_ = package;

// Create ir function.
integration_function->function_ =
absl::make_unique<Function>(function_name, package);

// Package source function parameters as tuple parameters to integration
// function.
for (const auto* source_func : source_functions) {
// Add tuple paramter for source function.
std::vector<Type*> arg_types;
for (const Node* param : source_func->params()) {
arg_types.push_back(param->GetType());
}
Type* args_tuple_type = package->GetTupleType(arg_types);
std::string tuple_name = source_func->name() + std::string("ParamTuple");
XLS_ASSIGN_OR_RETURN(
Node * args_tuple,
integration_function->function_->MakeNodeWithName<Param>(
/*loc=*/std::nullopt, tuple_name, args_tuple_type));

// Add TupleIndex nodes inside function to unpack tuple parameter.
int64 parameter_index = 0;
for (const Node* param : source_func->params()) {
XLS_ASSIGN_OR_RETURN(
Node * tuple_index,
integration_function->function_->MakeNode<TupleIndex>(
/*loc=*/std::nullopt, args_tuple, parameter_index));
XLS_RETURN_IF_ERROR(
integration_function->SetNodeMapping(param, tuple_index));
parameter_index++;
}
}

return std::move(integration_function);
}

absl::Status IntegrationFunction::SetNodeMapping(const Node* source,
Node* map_target) {
// Validate map pairing.
XLS_RET_CHECK_NE(source, map_target);
XLS_RET_CHECK(IntegrationFunctionOwnsNode(map_target));
XLS_RET_CHECK(
!(IntegrationFunctionOwnsNode(source) && !IsMappingTarget(source)));

// 'original' is itself a member of the integrated function.
if (IntegrationFunctionOwnsNode(source)) {
absl::flat_hash_set<const Node*>& nodes_that_map_to_source =
integrated_node_to_original_nodes_map_.at(source);

// Nodes that previously mapped to original now map to map_target.
for (const Node* original_node : nodes_that_map_to_source) {
integrated_node_to_original_nodes_map_[map_target].insert(original_node);
XLS_RET_CHECK(HasMapping(original_node));
original_node_to_integrated_node_map_[original_node] = map_target;
}

// No nodes map to source anymore.
integrated_node_to_original_nodes_map_.erase(source);

// 'source' is an external node.
} else {
original_node_to_integrated_node_map_[source] = map_target;
integrated_node_to_original_nodes_map_[map_target].insert(source);
}

return absl::OkStatus();
}

absl::StatusOr<Node*> IntegrationFunction::GetNodeMapping(
const Node* original) const {
XLS_RET_CHECK(!IntegrationFunctionOwnsNode(original));
if (!HasMapping(original)) {
return absl::InternalError("No mapping found for original node");
}
return original_node_to_integrated_node_map_.at(original);
}

absl::StatusOr<const absl::flat_hash_set<const Node*>*>
IntegrationFunction::GetNodesMappedToNode(const Node* map_target) const {
XLS_RET_CHECK(IntegrationFunctionOwnsNode(map_target));
if (!IsMappingTarget(map_target)) {
return absl::InternalError("No mappings found for map target node");
}
return &integrated_node_to_original_nodes_map_.at(map_target);
}

bool IntegrationFunction::HasMapping(const Node* node) const {
return original_node_to_integrated_node_map_.contains(node);
}

bool IntegrationFunction::IsMappingTarget(const Node* node) const {
return integrated_node_to_original_nodes_map_.contains(node);
}

absl::StatusOr<Function*> IntegrationBuilder::CloneFunctionRecursive(
const Function* function,
absl::flat_hash_map<const Function*, Function*>* call_remapping) {
Expand Down
57 changes: 57 additions & 0 deletions xls/contrib/integrator/ir_integrator.h
Expand Up @@ -21,6 +21,63 @@

namespace xls {

// Class that represents an integration function i.e. a function combining the
// IR of other functions. This class tracks which original function nodes are
// mapped to which integration function nodes. It also provides some utilities
// that are useful for constructing the integrated function.
class IntegrationFunction {
public:
IntegrationFunction() {}

IntegrationFunction(const IntegrationFunction& other) = delete;
void operator=(const IntegrationFunction& other) = delete;

// Create an IntegrationFunction object that is empty expect for
// parameters. Each initial parameter of the function is a tuple
// which holds inputs corresponding to the paramters of one
// of the source_functions.
static absl::StatusOr<std::unique_ptr<IntegrationFunction>>
MakeIntegrationFunctionWithParamTuples(
Package* package, absl::Span<const Function* const> source_functions,
std::string function_name = "IntegrationFunction");

Function* function() const { return function_.get(); }

// Declares that node 'source' from a source function maps
// to node 'map_target' in the integrated_function.
absl::Status SetNodeMapping(const Node* source, Node* map_target);

// Returns the integrated node that 'original' maps to, if it
// exists. Otherwise, return an error status.
absl::StatusOr<Node*> GetNodeMapping(const Node* original) const;

// Returns the original nodes that map to 'map_target' in the integrated
// function.
absl::StatusOr<const absl::flat_hash_set<const Node*>*> GetNodesMappedToNode(
const Node* map_target) const;

// Returns true if 'node' is mapped to a node in the integrated function.
bool HasMapping(const Node* node) const;

// Returns true if other nodes map to 'node'
bool IsMappingTarget(const Node* node) const;

// Returns true if 'node' is in the integrated function.
bool IntegrationFunctionOwnsNode(const Node* node) const {
return function_.get() == node->function_base();
}

private:
// Track mapping of original function nodes to integrated function nodes.
absl::flat_hash_map<const Node*, Node*> original_node_to_integrated_node_map_;
absl::flat_hash_map<const Node*, absl::flat_hash_set<const Node*>>
integrated_node_to_original_nodes_map_;

// Integrated function.
std::unique_ptr<Function> function_;
Package* package_;
};

// Class used to integrate separate functions into a combined, reprogrammable
// circuit that can be configured to have the same functionality as the
// input functions. The builder will attempt to construct the integrated
Expand Down

0 comments on commit 4c08f75

Please sign in to comment.