-
Notifications
You must be signed in to change notification settings - Fork 0
/
search_strategy.h
264 lines (239 loc) · 7.82 KB
/
search_strategy.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
/**
 * The search strategy defines which node to explore next. Two common
 * strategies are DFS (depth-first search) and BFS (here in the sense of
 * always picking the cheapest open node). DFS quickly dives down (following
 * the best child at every node) to obtain a feasible solution. BFS raises the
 * lower bound faster but may take a long time to find a feasible solution.
 * It is also common to combine the two: dive depth-first until a feasible
 * solution is found, then jump to the cheapest open node (BFS) and repeat.
 */
#ifndef CETSP_SEARCH_STRATEGY_H
#define CETSP_SEARCH_STRATEGY_H
#include "branching_strategy.h"
#include "cetsp/node.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <memory>
#include <random>
#include <tuple>
#include <vector>
namespace cetsp {
/**
 * Abstract node-selection rule for the branch-and-bound search.
 *
 * Implementations keep the open nodes themselves. They are seeded with the
 * root via init(), receive branching/feasibility/pruning events through the
 * notify_of_* callbacks, and hand out the next node to explore via next().
 */
class SearchStrategy {
public:
  /**
   * Seeds the strategy with the root node of the search tree.
   * @param root The root node.
   */
  virtual void init(std::shared_ptr<Node> &root) = 0;
  /**
   * Gets called after a node has branched. Implementations usually enqueue
   * the node's children here.
   * @param node The node that just branched.
   */
  virtual void notify_of_branch(Node &node) = 0;
  /**
   * Returns the next node to be explored. This node has to be explored and
   * must not be returned again.
   * @return The next node to explore, or nullptr if no node is left.
   */
  virtual std::shared_ptr<Node> next() = 0;
  /**
   * Checks if there is a node left to explore.
   * @return True if there is a node that requires exploration. False if the
   * tree has been fully explored.
   */
  virtual bool has_next() = 0;
  /**
   * Called when the (last) node has been feasible. This allows, for example,
   * switching the strategy every time a new solution has been found.
   * Does nothing by default.
   * @param node The last node.
   */
  virtual void notify_of_feasible(Node & /*node*/) {}
  /**
   * Called when the (last) node gets pruned. Does nothing by default.
   * @param node The last node.
   */
  virtual void notify_of_prune(Node & /*node*/) {}
  virtual ~SearchStrategy() = default;
};
class DfsBfs : public SearchStrategy {
public:
void init(std::shared_ptr<Node> &root) override {
std::cout << "Using DfsBfs search" << std::endl;
queue.emplace_back(root, root->obj(),
root->get_relaxed_solution().obj());
}
void notify_of_branch(Node &node) override {
auto children = node.get_children();
std::sort(children.begin(), children.end(),
[](std::shared_ptr<Node> &a, std::shared_ptr<Node> &b) {
const auto lb_a = a->obj();
const auto lb_b = b->obj();
if (std::abs(lb_a - lb_b) < 0.001) { // approx equal
return a->get_relaxed_solution().obj() >
b->get_relaxed_solution().obj();
}
return a->obj() > b->obj();
});
for (auto &child : children) {
queue.emplace_back(child, child->obj(),
child->get_relaxed_solution().obj());
}
}
void notify_of_feasible(Node &node) override {
sort_to_priotize_lowest_value();
}
void notify_of_prune(Node &node) override { sort_to_priotize_lowest_value(); }
std::shared_ptr<Node> next() override {
if (!has_next()) {
return nullptr;
}
auto n = queue.back();
queue.pop_back();
return std::get<0>(n);
}
bool has_next() override {
// remove all pruned entries from the back
while (!queue.empty() && std::get<0>(queue.back())->is_pruned()) {
queue.pop_back();
}
return !queue.empty();
}
private:
void sort_to_priotize_lowest_value() {
std::sort(queue.begin(), queue.end(), [](auto &a, auto &b) {
const auto lb_a = std::get<1>(a);
const auto lb_b = std::get<1>(b);
if (std::abs(lb_a - lb_b) < 0.001) { // approx equal
return std::get<2>(a) > std::get<2>(b);
}
return lb_a > lb_b;
});
}
std::vector<std::tuple<std::shared_ptr<Node>, double, double>> queue;
};
/**
 * Depth-first search that always dives into the cheapest child.
 *
 * Open nodes live in a vector used as a stack (the back is explored next).
 * After a branch, the children are pushed so that the one with the smallest
 * lower bound ends up on top. Children whose lower bounds fall into the same
 * 0.001-wide cell are ordered by their relaxed objective instead.
 */
class CheapestChildDepthFirst : public SearchStrategy {
public:
  void init(std::shared_ptr<Node> &root) override { queue.push_back(root); }
  void notify_of_branch(Node &node) override {
    auto children = node.get_children();
    // Descending order: the most promising child is pushed last and is
    // therefore popped first.
    std::sort(children.begin(), children.end(),
              [](const std::shared_ptr<Node> &a,
                 const std::shared_ptr<Node> &b) {
                const double bucket_a =
                    CheapestChildDepthFirst::lb_bucket(a->get_lower_bound());
                const double bucket_b =
                    CheapestChildDepthFirst::lb_bucket(b->get_lower_bound());
                if (bucket_a == bucket_b) { // approx. equal -> tie-break
                  return a->get_relaxed_solution().obj() >
                         b->get_relaxed_solution().obj();
                }
                return bucket_a > bucket_b;
              });
    for (auto &child : children) {
      queue.push_back(child);
    }
  }
  std::shared_ptr<Node> next() override {
    if (!has_next()) {
      return nullptr;
    }
    auto n = queue.back();
    queue.pop_back();
    return n;
  }
  bool has_next() override {
    // remove all pruned entries from the back
    while (!queue.empty() && queue.back()->is_pruned()) {
      queue.pop_back();
    }
    return !queue.empty();
  }

private:
  /**
   * Maps a lower bound onto a grid of width 0.001. Bounds in the same cell
   * count as ties. A plain |a - b| < 0.001 test is not transitive, so a
   * comparator built on it is not a strict weak ordering and std::sort would
   * have undefined behavior; grid cells are proper equivalence classes.
   */
  static double lb_bucket(double value) {
    constexpr double tolerance = 0.001;
    return std::floor(value / tolerance);
  }
  std::vector<std::shared_ptr<Node>> queue;
};
/**
 * Cheapest-first search over all open nodes.
 *
 * After every branch the whole open list is re-sorted so that its back (the
 * node returned next) holds the smallest lower bound. Lower bounds that fall
 * into the same 0.001-wide cell are ordered by their relaxed objective
 * instead.
 */
class CheapestBreadthFirst : public SearchStrategy {
public:
  void init(std::shared_ptr<Node> &root) override { queue.push_back(root); }
  void notify_of_branch(Node &node) override {
    for (auto &child : node.get_children()) {
      queue.push_back(child);
    }
    // Descending order: the cheapest open node ends up at the back.
    std::sort(queue.begin(), queue.end(),
              [](const std::shared_ptr<Node> &a,
                 const std::shared_ptr<Node> &b) {
                const double bucket_a =
                    CheapestBreadthFirst::lb_bucket(a->get_lower_bound());
                const double bucket_b =
                    CheapestBreadthFirst::lb_bucket(b->get_lower_bound());
                if (bucket_a == bucket_b) { // approx. equal -> tie-break
                  return a->get_relaxed_solution().obj() >
                         b->get_relaxed_solution().obj();
                }
                return bucket_a > bucket_b;
              });
  }
  std::shared_ptr<Node> next() override {
    if (!has_next()) {
      return nullptr;
    }
    auto n = queue.back();
    queue.pop_back();
    return n;
  }
  bool has_next() override {
    // remove all pruned entries from the back
    while (!queue.empty() && queue.back()->is_pruned()) {
      queue.pop_back();
    }
    return !queue.empty();
  }

private:
  /**
   * Maps a lower bound onto a grid of width 0.001. Bounds in the same cell
   * count as ties. A plain |a - b| < 0.001 test is not transitive, so a
   * comparator built on it is not a strict weak ordering and std::sort would
   * have undefined behavior; grid cells are proper equivalence classes.
   */
  static double lb_bucket(double value) {
    constexpr double tolerance = 0.001;
    return std::floor(value / tolerance);
  }
  std::vector<std::shared_ptr<Node>> queue;
};
/**
 * Explores the open nodes in a pseudo-random order. Probably not the best
 * idea but a useful baseline.
 *
 * The engine is default-seeded, so runs remain reproducible.
 */
class RandomNextNode : public SearchStrategy {
public:
  void init(std::shared_ptr<Node> &root) override { queue.push_back(root); }
  void notify_of_branch(Node &node) override {
    for (auto &child : node.get_children()) {
      queue.push_back(child);
    }
    // Shuffle the open list with the persistent engine.
    // TODO: It would be more efficient to just take a random element
    // from the queue.
    std::shuffle(queue.begin(), queue.end(), rng);
  }
  std::shared_ptr<Node> next() override {
    if (!has_next()) {
      return nullptr;
    }
    auto n = queue.back();
    queue.pop_back();
    return n;
  }
  bool has_next() override {
    // remove all pruned entries from the back
    while (!queue.empty() && queue.back()->is_pruned()) {
      queue.pop_back();
    }
    return !queue.empty();
  }

private:
  std::vector<std::shared_ptr<Node>> queue;
  // Kept as a member so consecutive shuffles continue one random sequence.
  // A temporary engine would restart from the same default seed on every call
  // and repeat the identical permutation for a given queue size.
  std::default_random_engine rng;
};
TEST_CASE("Search Strategy") {
  // The strategy should choose the triangle and implicitly cover the
  // second circle.
  Instance instance({{{0, 0}, 1}, {{3, 0}, 1}, {{6, 0}, 1}, {{3, 6}, 1}});
  FarthestCircle bs;
  auto root = std::make_shared<Node>(std::vector<int>{0, 1, 2, 3}, &instance);
  bs.setup(&instance, root, nullptr);
  CheapestChildDepthFirst ss;
  ss.init(root);
  auto node = ss.next();
  // REQUIRE stops the test case on failure instead of dereferencing nullptr.
  REQUIRE(node != nullptr);
  CHECK(bs.branch(*node) == false);
  ss.notify_of_branch(*node);
  CHECK(ss.next() == nullptr);

  // Root covering only circles 0-2: branch() is expected to succeed and the
  // strategy is expected to hand out exactly three open nodes afterwards.
  auto root2 = std::make_shared<Node>(std::vector<int>{0, 1, 2}, &instance);
  CheapestChildDepthFirst ss2;
  ss2.init(root2);
  node = ss2.next();
  REQUIRE(node != nullptr);
  CHECK(bs.branch(*node) == true);
  ss2.notify_of_branch(*node);
  CHECK(ss2.next() != nullptr);
  CHECK(ss2.next() != nullptr);
  CHECK(ss2.next() != nullptr);
  CHECK(ss2.next() == nullptr);
}
} // namespace cetsp
#endif // CETSP_SEARCH_STRATEGY_H