-
Notifications
You must be signed in to change notification settings - Fork 34
Static basic indexing propagator #501
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: constraint-programming
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| // Copyright 2026 D-Wave Systems Inc. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
| #pragma once | ||
|
|
||
| #include "dwave-optimization/cp/core/cpvar.hpp" | ||
| #include "dwave-optimization/cp/core/propagator.hpp" | ||
| #include "dwave-optimization/nodes/indexing.hpp" | ||
|
|
||
| namespace dwave::optimization::cp { | ||
|
|
||
| struct BasicIndexingForwardTransform : IndexTransform { | ||
| BasicIndexingForwardTransform(const ArrayNode* array_ptr, const BasicIndexingNode* bi_ptr); | ||
|
|
||
| void affected(ssize_t i, std::vector<ssize_t>& out) override; | ||
|
|
||
| const ArrayNode* array_ptr_; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here too, members are public so should we remove |
||
| const BasicIndexingNode* bi_ptr_; | ||
| std::vector<BasicIndexingNode::slice_or_int> slices; | ||
| }; | ||
|
|
||
| class BasicIndexingPropagator : public Propagator { | ||
| public: | ||
| BasicIndexingPropagator(ssize_t index, CPVar* array, CPVar* basic_indexing); | ||
|
|
||
| void initialize_state(CPState& state) const override; | ||
| CPStatus propagate(CPPropagatorsState& p_state, CPVarsState& v_state) const override; | ||
|
|
||
| private: | ||
| CPVar* array_; | ||
| CPVar* basic_indexing_; | ||
| }; | ||
|
|
||
| } // namespace dwave::optimization::cp | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| // Copyright 2026 D-Wave Systems Inc. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #include "dwave-optimization/cp/propagators/indexing_propagators.hpp" | ||
|
|
||
| #include "dwave-optimization/nodes/indexing.hpp" | ||
|
|
||
| namespace dwave::optimization::cp { | ||
|
|
||
| BasicIndexingForwardTransform::BasicIndexingForwardTransform(const ArrayNode* array_ptr, | ||
| const BasicIndexingNode* bi_ptr) | ||
| : array_ptr_(array_ptr), bi_ptr_(bi_ptr) { | ||
| slices = bi_ptr->infer_indices(); | ||
| for (ssize_t axis = 0; axis < array_ptr->ndim(); ++axis) { | ||
| if (std::holds_alternative<Slice>(slices[axis])) { | ||
| if (std::get<Slice>(slices[axis]).step != 1) { | ||
| throw std::invalid_argument("step != 1 not supported"); | ||
| } | ||
|
|
||
| slices[axis] = std::get<Slice>(slices[axis]).fit(array_ptr->shape()[axis]); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void BasicIndexingForwardTransform::affected(ssize_t i, std::vector<ssize_t>& out) { | ||
| std::vector<ssize_t> in_multi_index = unravel_index(i, array_ptr_->shape()); | ||
| std::vector<ssize_t> out_multi_index; | ||
| bool belongs = true; | ||
| // Iterate through the axes to see if any index is outside the slice | ||
| for (ssize_t axis = 0; axis < array_ptr_->ndim(); ++axis) { | ||
| if (std::holds_alternative<ssize_t>(slices[axis])) { | ||
| if (in_multi_index[axis] == std::get<ssize_t>(slices[axis])) continue; | ||
| } else { | ||
| const auto& slice = std::get<Slice>(slices[axis]); | ||
| if (in_multi_index[axis] >= slice.start and in_multi_index[axis] < slice.stop) { | ||
| out_multi_index.push_back(in_multi_index[axis] - slice.start); | ||
| continue; | ||
| } | ||
| } | ||
|
|
||
| belongs = false; | ||
| break; | ||
| } | ||
|
|
||
| if (belongs) { | ||
| out.push_back(ravel_multi_index(out_multi_index, bi_ptr_->shape())); | ||
| } | ||
| } | ||
|
|
||
| BasicIndexingPropagator::BasicIndexingPropagator(ssize_t index, CPVar* array, CPVar* basic_indexing) | ||
| : Propagator(index) { | ||
| // TODO: not supporting dynamic variables for now | ||
| if (array->min_size() != array->max_size()) { | ||
| throw std::invalid_argument("dynamic arrays not supported"); | ||
| } | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should put the creation of the |
||
| array_ = array; | ||
| basic_indexing_ = basic_indexing; | ||
| } | ||
|
|
||
| void BasicIndexingPropagator::initialize_state(CPState& state) const { | ||
| CPPropagatorsState& p_state = state.get_propagators_state(); | ||
| assert(propagator_index_ >= 0); | ||
| assert(propagator_index_ < static_cast<ssize_t>(p_state.size())); | ||
| p_state[propagator_index_] = std::make_unique<PropagatorData>(state.get_state_manager(), | ||
| basic_indexing_->max_size()); | ||
| } | ||
|
|
||
| CPStatus BasicIndexingPropagator::propagate(CPPropagatorsState& p_state, | ||
| CPVarsState& v_state) const { | ||
| auto data = data_ptr<PropagatorData>(p_state); | ||
|
|
||
| const BasicIndexingNode* bi = dynamic_cast<const BasicIndexingNode*>(basic_indexing_->node_); | ||
| assert(bi); | ||
|
|
||
| // Not caching this for now as we may need to fit these at propagate time for | ||
| // dynamic arrays | ||
| std::vector<BasicIndexingNode::slice_or_int> slices = bi->infer_indices(); | ||
| for (ssize_t axis = 0; axis < array_->node_->ndim(); ++axis) { | ||
| if (std::holds_alternative<Slice>(slices[axis])) { | ||
| assert(std::get<Slice>(slices[axis]).step == 1); | ||
| slices[axis] = std::get<Slice>(slices[axis]).fit(array_->node_->shape()[axis]); | ||
| } | ||
| } | ||
|
|
||
| std::deque<ssize_t>& indices_to_process = data->indices_to_process(); | ||
|
|
||
| assert(indices_to_process.size() > 0); | ||
| while (indices_to_process.size() > 0) { | ||
| ssize_t bi_index = indices_to_process.front(); | ||
| indices_to_process.pop_front(); | ||
|
|
||
| // Derive the original array index based on the index of the basic indexing variable. | ||
| // We unravel the basic indexing variable index, transform the multi-index into | ||
| // one on the original array, and then ravel it to get the final linear index on | ||
| // the array. | ||
| std::vector<ssize_t> bi_multi_index = | ||
| unravel_index(bi_index, basic_indexing_->node_->shape()); | ||
| std::vector<ssize_t> arr_multi_index; | ||
| ssize_t bi_axis = 0; | ||
| for (ssize_t axis = 0; axis < array_->node_->ndim(); ++axis) { | ||
| if (std::holds_alternative<ssize_t>(slices[axis])) { | ||
| arr_multi_index.push_back(std::get<ssize_t>(slices[axis])); | ||
| continue; | ||
| } | ||
| assert(std::holds_alternative<Slice>(slices[axis])); | ||
| const auto& slice = std::get<Slice>(slices[axis]); | ||
| assert(slice.step == 1); | ||
| arr_multi_index.push_back(bi_multi_index[bi_axis] + slice.start); | ||
| bi_axis++; | ||
| } | ||
| ssize_t array_index = ravel_multi_index(arr_multi_index, array_->node_->shape()); | ||
|
|
||
| // Now we make the bounds of the array element and the basic indexing element equal | ||
|
|
||
| // Make the upper bounds consistent | ||
| if (CPStatus status = basic_indexing_->remove_above( | ||
| v_state, array_->max(v_state, array_index), bi_index); | ||
| not status) | ||
| return status; | ||
| if (CPStatus status = array_->remove_above(v_state, basic_indexing_->max(v_state, bi_index), | ||
| array_index); | ||
| not status) | ||
| return status; | ||
|
|
||
| // Make the lower bounds consistent | ||
| if (CPStatus status = basic_indexing_->remove_below( | ||
| v_state, array_->min(v_state, array_index), bi_index); | ||
| not status) | ||
| return status; | ||
| if (CPStatus status = array_->remove_below(v_state, basic_indexing_->min(v_state, bi_index), | ||
| array_index); | ||
| not status) | ||
| return status; | ||
|
|
||
| data->set_scheduled(false, bi_index); | ||
| } | ||
|
|
||
| return CPStatus::OK; | ||
| } | ||
|
|
||
| } // namespace dwave::optimization::cp | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |||||
| #include "dwave-optimization/cp/core/index_transform.hpp" | ||||||
| #include "dwave-optimization/cp/core/interval_array.hpp" | ||||||
| #include "dwave-optimization/cp/propagators/identity_propagator.hpp" | ||||||
| #include "dwave-optimization/cp/propagators/indexing_propagators.hpp" | ||||||
| #include "dwave-optimization/cp/state/copier.hpp" | ||||||
| #include "dwave-optimization/nodes.hpp" | ||||||
| #include "dwave-optimization/state.hpp" | ||||||
|
|
@@ -109,4 +110,159 @@ TEST_CASE("ElementWiseIdentityPropagator") { | |||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| TEST_CASE("BasicIndexingPropagator") { | ||||||
| using namespace dwave::optimization; | ||||||
|
|
||||||
| GIVEN("A dwopt graph with basic indexing") { | ||||||
| Graph graph; | ||||||
| auto i = graph.emplace_node<IntegerNode>(5, -3, 4); | ||||||
| auto b = graph.emplace_node<BasicIndexingNode>(i, Slice(1, 4)); | ||||||
|
|
||||||
| // Lock the graph | ||||||
| graph.topological_sort(); | ||||||
|
|
||||||
| // Construct the CP corresponding model | ||||||
| AND_GIVEN("The CP Model") { | ||||||
| CPModel model; | ||||||
|
|
||||||
| // Add the variabbles to the model | ||||||
| CPVar* i_var = model.emplace_variable<CPVar>(model, i, i->topological_index()); | ||||||
| CPVar* b_var = model.emplace_variable<CPVar>(model, b, b->topological_index()); | ||||||
|
|
||||||
| Propagator* p = model.emplace_propagator<BasicIndexingPropagator>( | ||||||
| model.num_propagators(), i_var, b_var); | ||||||
|
|
||||||
| // build the advisor for the propagator p aimed to the variable for i | ||||||
| Advisor advisor_i(p, 0, std::make_unique<BasicIndexingForwardTransform>(i, b)); | ||||||
| i_var->propagate_on_domain_change(std::move(advisor_i)); | ||||||
|
|
||||||
| // build the advisor for the propagator p aimed to the variable for b | ||||||
| Advisor advisor_b(p, 1, std::make_unique<ElementWiseTransform>()); | ||||||
| b_var->propagate_on_domain_change(std::move(advisor_b)); | ||||||
|
|
||||||
| REQUIRE(i_var->on_domain.size() == 1); | ||||||
| REQUIRE(b_var->on_domain.size() == 1); | ||||||
|
|
||||||
| WHEN("We initialize a state") { | ||||||
| CPState state = model.initialize_state<Copier>(); | ||||||
| CPVarsState& s_state = state.get_variables_state(); | ||||||
| CPPropagatorsState& p_state = state.get_propagators_state(); | ||||||
|
|
||||||
| REQUIRE(s_state.size() == 2); | ||||||
| REQUIRE(p_state.size() == 1); | ||||||
|
|
||||||
| i_var->initialize_state(state); | ||||||
| b_var->initialize_state(state); | ||||||
| p->initialize_state(state); | ||||||
|
|
||||||
| AND_WHEN("We alter the domain of the integer variable inside the slice") { | ||||||
| CPStatus status = i_var->assign(s_state, -2, 0); | ||||||
| REQUIRE(status == CPStatus::OK); | ||||||
| THEN("We see that the propagator is not triggered") { | ||||||
| CHECK(not p_state[0]->scheduled()); | ||||||
| CHECK(p_state[0]->indices_to_process().size() == 0); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| AND_WHEN("We alter the domain of the integer variable inside the slice") { | ||||||
| CPStatus status = i_var->assign(s_state, -2, 3); | ||||||
| REQUIRE(status == CPStatus::OK); | ||||||
| THEN("We see that the propagator is triggered to run on the same index") { | ||||||
| REQUIRE(p_state[0]->scheduled()); | ||||||
| REQUIRE(p_state[0]->indices_to_process().size() == 1); | ||||||
|
|
||||||
| CHECK(p_state[0]->scheduled(2)); | ||||||
| } | ||||||
|
|
||||||
| AND_WHEN("We call the fix point engine") { | ||||||
| CPEngine engine; | ||||||
| engine.fix_point(state); | ||||||
|
|
||||||
| THEN("The sum output variable 2 is correctly fixed") { | ||||||
| CHECK(b_var->min(s_state, 2) == -2); | ||||||
| CHECK(b_var->max(s_state, 2) == -2); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| GIVEN("A dwopt graph with basic indexing on a 2d array") { | ||||||
| Graph graph; | ||||||
| auto i = graph.emplace_node<IntegerNode>(std::initializer_list<ssize_t>{4, 7}, -3, 4); | ||||||
| auto b = graph.emplace_node<BasicIndexingNode>(i, 2, Slice(1, 4)); | ||||||
|
|
||||||
| // Lock the graph | ||||||
| graph.topological_sort(); | ||||||
|
|
||||||
| // Construct the CP corresponding model | ||||||
| AND_GIVEN("The CP Model") { | ||||||
| CPModel model; | ||||||
|
|
||||||
| // Add the variabbles to the model | ||||||
| CPVar* i_var = model.emplace_variable<CPVar>(model, i, i->topological_index()); | ||||||
| CPVar* b_var = model.emplace_variable<CPVar>(model, b, b->topological_index()); | ||||||
|
|
||||||
| Propagator* p = model.emplace_propagator<BasicIndexingPropagator>( | ||||||
| model.num_propagators(), i_var, b_var); | ||||||
|
|
||||||
| // build the advisor for the propagator p aimed to the variable for i | ||||||
| Advisor advisor_i(p, 0, std::make_unique<BasicIndexingForwardTransform>(i, b)); | ||||||
| i_var->propagate_on_domain_change(std::move(advisor_i)); | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here and in the following you are pushing on the list of the advisors for domain consistent propagators. In your propagator implementation however, you are imposing the weaker bound consistency, so here we should have
Suggested change
|
||||||
|
|
||||||
| // build the advisor for the propagator p aimed to the variable for b | ||||||
| Advisor advisor_b(p, 1, std::make_unique<ElementWiseTransform>()); | ||||||
| b_var->propagate_on_domain_change(std::move(advisor_b)); | ||||||
|
|
||||||
| REQUIRE(i_var->on_domain.size() == 1); | ||||||
| REQUIRE(b_var->on_domain.size() == 1); | ||||||
|
|
||||||
| WHEN("We initialize a state") { | ||||||
| CPState state = model.initialize_state<Copier>(); | ||||||
| CPVarsState& s_state = state.get_variables_state(); | ||||||
| CPPropagatorsState& p_state = state.get_propagators_state(); | ||||||
|
|
||||||
| REQUIRE(s_state.size() == 2); | ||||||
| REQUIRE(p_state.size() == 1); | ||||||
|
|
||||||
| i_var->initialize_state(state); | ||||||
| b_var->initialize_state(state); | ||||||
| p->initialize_state(state); | ||||||
|
|
||||||
| AND_WHEN("We alter the domain of the integer variable inside the slice") { | ||||||
| CPStatus status = i_var->assign(s_state, -2, 0); | ||||||
| REQUIRE(status == CPStatus::OK); | ||||||
| THEN("We see that the propagator is not triggered") { | ||||||
| CHECK(not p_state[0]->scheduled()); | ||||||
| CHECK(p_state[0]->indices_to_process().size() == 0); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| AND_WHEN("We alter the domain of the integer variable inside the slice") { | ||||||
| CPStatus status = i_var->assign(s_state, -2, 16); | ||||||
| REQUIRE(status == CPStatus::OK); | ||||||
| THEN("We see that the propagator is triggered to run on the same index") { | ||||||
| REQUIRE(p_state[0]->scheduled()); | ||||||
| REQUIRE(p_state[0]->indices_to_process().size() == 1); | ||||||
|
|
||||||
| CHECK(p_state[0]->scheduled(1)); | ||||||
| } | ||||||
|
|
||||||
| AND_WHEN("We call the fix point engine") { | ||||||
| CPEngine engine; | ||||||
| engine.fix_point(state); | ||||||
|
|
||||||
| THEN("The sum output variable 2 is correctly fixed") { | ||||||
| CHECK(b_var->min(s_state, 1) == -2); | ||||||
| CHECK(b_var->max(s_state, 1) == -2); | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| } // namespace dwave::optimization::cp | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove trailing
_as now it's a public member