diff --git a/combinatorial_opt/linear_sum_assignment.hpp b/combinatorial_opt/linear_sum_assignment.hpp index 969f6841..e9b6d03a 100644 --- a/combinatorial_opt/linear_sum_assignment.hpp +++ b/combinatorial_opt/linear_sum_assignment.hpp @@ -1,13 +1,24 @@ #pragma once #include +#include #include #include namespace linear_sum_assignment { +template struct Result { + T opt; + std::vector mate; + std::vector f, g; // dual variables +}; + template -T augment(int nr, int nc, const std::vector> &C, std::vector &f, - std::vector &g, int s, std::vector &mate, std::vector &mate_inv) { +T augment(int nr, int nc, const std::vector> &C, std::vector &f, std::vector &g, + int s, // source row + std::vector &mate, + std::vector &mate_inv, // duplicates are allowed (used for k-best algorithms) + int fixed_rows = 0 // Ignore first rows and corresponding columns (used for k-best algorithms) +) { assert(0 <= s and s < nr); assert(mate.at(s) < 0); @@ -17,15 +28,29 @@ T augment(int nr, int nc, const std::vector> &C, std::vector & dist.resize(nc); prv.resize(nc); - f.at(s) = C.at(s).at(0) - g.at(0); - for (int j = 1; j < nc; ++j) f.at(s) = std::min(f.at(s), C.at(s).at(j) - g.at(j)); + std::vector done(nc); - for (int j = 0; j < nc; ++j) { - dist.at(j) = C.at(s).at(j) - f.at(s) - g.at(j); - prv.at(j) = s; + for (int i = 0; i < fixed_rows; ++i) { + if (int j = mate.at(i); j >= 0) done.at(j) = 1; } - std::vector done(nc); + { + int h = 0; + while (done.at(h)) ++h; + + f.at(s) = C.at(s).at(h) - g.at(h); + for (int j = h + 1; j < nc; ++j) { + if (done.at(j)) continue; + f.at(s) = std::min(f.at(s), C.at(s).at(j) - g.at(j)); + } + } + + for (int j = 0; j < nc; ++j) { + if (!done.at(j)) { + dist.at(j) = C.at(s).at(j) - f.at(s) - g.at(j); + prv.at(j) = -1; + } + } int t = -1; std::vector stk; @@ -40,8 +65,6 @@ T augment(int nr, int nc, const std::vector> &C, std::vector & } } - if (j1 == -1) return false; - if (mate_inv.at(j1) < 0) { t = j1; break; @@ -51,7 +74,8 @@ T augment(int nr, int nc, const std::vector> &C, std::vector & stk = {j1}; while (!stk.empty()) { - const int i = mate_inv.at(stk.back()); + const int j2 = stk.back(); + const int i = mate_inv.at(j2); if (i < 0) { t = stk.back(); break; @@ -65,7 +89,7 @@ T augment(int nr, int nc, const std::vector> &C, std::vector & if (dist.at(j) > dist.at(j1) + len) { dist.at(j) = dist.at(j1) + len; - prv.at(j) = i; + prv.at(j) = j2; } if (len == T()) { @@ -80,33 +104,46 @@ T augment(int nr, int nc, const std::vector> &C, std::vector & f.at(s) += len; - T ret = len; + for (int i = 0; i < fixed_rows; ++i) { + if (const int j = mate.at(i); j >= 0) done.at(j) = 0; + } for (int j = 0; j < nc; ++j) { if (!done.at(j)) continue; g.at(j) -= len - dist.at(j); - if (mate_inv.at(j) >= 0) { - f.at(mate_inv.at(j)) += len - dist.at(j); - } else { - ret -= len - dist.at(j); - } } + for (int i = fixed_rows; i < nr; ++i) { + const int j = mate.at(i); + if (j < 0 or !done.at(j) or j >= nc) continue; + f.at(i) += len - dist.at(j); + } + + T ret = T(); + for (int cur = t; cur >= 0;) { - const int i = prv.at(cur); + const int nxt = prv.at(cur); + if (nxt < 0) { + mate_inv.at(cur) = s; + mate.at(s) = cur; + ret += C.at(s).at(cur); + break; + } + const int i = mate_inv.at(nxt); + + ret += C.at(i).at(cur) - C.at(i).at(nxt); + mate_inv.at(cur) = i; - if (i == -1) break; - std::swap(cur, mate.at(i)); + mate.at(i) = cur; + cur = nxt; } return ret; } // Complexity: O(nr^2 nc) -template -std::tuple, std::vector, std::vector> -_solve(int nr, int nc, const std::vector> &C) { +template Result _solve(int nr, int nc, const std::vector> &C) { assert(nr <= nc); @@ -178,9 +215,7 @@ _solve(int nr, int nc, const std::vector> &C) { // Jonker–Volgenant algorithm: find minimum weight assignment // Dual problem (nr == nc): maximize sum(f) + sum(g) s.t. f_i + g_j <= C_ij // Complexity: O(nr nc min(nr, nc)) -template -std::tuple, std::vector, std::vector> -solve(int nr, int nc, const std::vector> &C) { +template Result solve(int nr, int nc, const std::vector> &C) { const bool transpose = (nr > nc); @@ -203,3 +238,118 @@ solve(int nr, int nc, const std::vector> &C) { } } // namespace linear_sum_assignment + +template struct best_assignments { + + struct Node { + T opt; + std::vector mate; + std::vector f, g; // dual variables + int fixed_rows; + std::vector banned_js; // C[fixed_rows][j] が inf となる j の集合 + + // for priority queue + // NOTE: reverse order + bool operator<(const Node &rhs) const { return opt > rhs.opt; } + + linear_sum_assignment::Result to_output(bool transpose) const { + if (transpose) { + std::vector mate2(g.size(), -1); + for (int i = 0; i < (int)mate.size(); ++i) mate2.at(mate.at(i)) = i; + return {opt, mate2, g, f}; + } else { + return {opt, mate, f, g}; + } + } + }; + + bool transpose; + int nr_, nc_; + T inf; + std::vector> C_, Ctmp_; + std::priority_queue pq; + + best_assignments(int nr, int nc, const std::vector> &C, T inf) + : transpose(nr > nc), inf(inf) { + + assert((int)C.size() == nr); + for (int i = 0; i < nr; ++i) assert((int)C.at(i).size() == nc); + + nr_ = transpose ? nc : nr; + nc_ = transpose ? nr : nc; + + C_.assign(nr_ + (nr_ != nc_), std::vector(nc_, T())); + for (int i = 0; i < nr; ++i) { + for (int j = 0; j < nc; ++j) { + C_.at(transpose ? j : i).at(transpose ? i : j) = C.at(i).at(j); + } + } + + Ctmp_ = C_; + + auto [opt, mate, f, g] = linear_sum_assignment::solve(C_.size(), nc, C_); + + pq.emplace(Node{opt, std::move(mate), std::move(f), std::move(g), 0, {}}); + } + + bool finished() const { return pq.empty(); } + + linear_sum_assignment::Result yield() { + assert(!pq.empty()); + + const Node ret = pq.top(); + pq.pop(); + + for (int fixed_rows = ret.fixed_rows; fixed_rows < nr_; ++fixed_rows) { + std::vector banned_js; + if (fixed_rows == ret.fixed_rows) banned_js = ret.banned_js; + + const int s = fixed_rows; + banned_js.push_back(ret.mate.at(s)); + + if ((int)banned_js.size() >= nc_) continue; + + auto f = ret.f; + auto g = ret.g; + auto mate = ret.mate; + + std::vector mate_inv(nc_, nr_); + for (int i = 0; i < nr_; ++i) mate_inv.at(mate.at(i)) = i; + + std::vector iscoldone(nc_); + for (int i = 0; i < fixed_rows; ++i) iscoldone.at(mate.at(i)) = 1; + + for (int j : banned_js) Ctmp_.at(s).at(j) = inf; + + mate_inv.at(mate.at(s)) = -1; + mate.at(s) = -1; + + auto aug = linear_sum_assignment::augment( + nr_, nc_, Ctmp_, f, g, s, mate, mate_inv, fixed_rows); + + for (int j = 0; j < nc_; ++j) { + if (mate_inv.at(j) < 0) { // nrows < ncols + g.at(j) = -f.back(); + for (int i = fixed_rows; i < nr_; ++i) { + g.at(j) = std::min(g.at(j), Ctmp_.at(i).at(j) - f.at(i)); + } + } + } + + if (Ctmp_.at(s).at(mate.at(s)) < inf) { + pq.emplace(Node{ + ret.opt + aug - C_.at(s).at(ret.mate.at(s)), + std::move(mate), + std::move(f), + std::move(g), + fixed_rows, + banned_js, + }); + } + + for (int j : banned_js) Ctmp_.at(s).at(j) = C_.at(s).at(j); + } + + return ret.to_output(transpose); + } +}; diff --git a/combinatorial_opt/linear_sum_assignment.md b/combinatorial_opt/linear_sum_assignment.md index 6c4e8703..19721510 100644 --- a/combinatorial_opt/linear_sum_assignment.md +++ b/combinatorial_opt/linear_sum_assignment.md @@ -7,6 +7,8 @@ $r$ 行 $c$ 列の行列を入力とした割当問題(二部グラフの最 オーソドックスな Hungarian algorithm の実装ではなく, Jonker–Volgenant algorithm の工夫を一部取り入れることで定数倍高速化を試みている. +また,割当問題の上位 $k$ 個の解を効率的に列挙するクラス `best_assignments` も提供している.このクラスのコンストラクタや `yield()` を呼び出し毎に $O(rc \min{r, c})$ の時間計算量が発生する. + ## 解いてくれる問題 主問題として,オーソドックスな線形重み割当問題を解く. @@ -22,6 +24,8 @@ $ ## 使用方法 +### 割当問題(最適解の計算) + ```cpp vector C(r, vector(c)); @@ -36,6 +40,23 @@ std::tie(min_weight, mate, f, g) = linear_sum_assignment::solve(r, c, C); また, `f[i]` および `g[j]` は最適解における双対変数の一例を示す.すなわち,任意の $i, j$ について $f\_i + g\_j \le C\_{ij}$ が成立し,特に第 $i$ 行と第 $j$ 列が対応する場合は等号が成立する.この双対変数は,行列の一部要素に更新を加えた場合の最適解の変化を効率的に追うために利用できる. +### 割当問題の $k$-best 解列挙 + +```cpp +int r, c; + +vector> C; +int inf; + +best_assignments gen(r, c, cost, inf); + +// 解の生成 +for (int t = 0; t < k; ++t) { + if (ba.finished()) break; + auto [opt, mate, f, g] = ba.yield(); +} +``` + ## 問題例 - [Library Checker: Assignment Problem](https://judge.yosupo.jp/problem/assignment) @@ -43,3 +64,5 @@ std::tie(min_weight, mate, f, g) = linear_sum_assignment::solve(r, c, C); ## 文献・リンク集 - [Lecture 8: Assignment Algorithms](https://cyberlab.engr.uconn.edu/wp-content/uploads/sites/2576/2018/09/Lecture_8.pdf) +- [1] K. G. Murty, "An algorithm for ranking all the assignments in order of increasing cost," Operations Research, 16(3), 682–687, 1968. +- [2] M.L. Miller, H.S. Stone, I.J. Cox, "Optimizing Murty's ranked assignment method," IEEE Transactions on Aerospace and Electronic Systems, 33(3), 851-862, 1997.