Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 177 additions & 27 deletions combinatorial_opt/linear_sum_assignment.hpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
#pragma once
#include <cassert>
#include <queue>
#include <tuple>
#include <vector>

namespace linear_sum_assignment {

template <class T> struct Result {
T opt;
std::vector<int> mate;
std::vector<T> f, g; // dual variables
};

template <class T>
T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &f,
std::vector<T> &g, int s, std::vector<int> &mate, std::vector<int> &mate_inv) {
T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &f, std::vector<T> &g,
int s, // source row
std::vector<int> &mate,
std::vector<int> &mate_inv, // duplicates are allowed (used for k-best algorithms)
int fixed_rows = 0 // Ignore first rows and corresponding columns (used for k-best algorithms)
) {

assert(0 <= s and s < nr);
assert(mate.at(s) < 0);
Expand All @@ -17,15 +28,29 @@ T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &
dist.resize(nc);
prv.resize(nc);

f.at(s) = C.at(s).at(0) - g.at(0);
for (int j = 1; j < nc; ++j) f.at(s) = std::min(f.at(s), C.at(s).at(j) - g.at(j));
std::vector<bool> done(nc);

for (int j = 0; j < nc; ++j) {
dist.at(j) = C.at(s).at(j) - f.at(s) - g.at(j);
prv.at(j) = s;
for (int i = 0; i < fixed_rows; ++i) {
if (int j = mate.at(i); j >= 0) done.at(j) = 1;
}

std::vector<bool> done(nc);
{
int h = 0;
while (done.at(h)) ++h;

f.at(s) = C.at(s).at(h) - g.at(h);
for (int j = h + 1; j < nc; ++j) {
if (done.at(j)) continue;
f.at(s) = std::min(f.at(s), C.at(s).at(j) - g.at(j));
}
}

for (int j = 0; j < nc; ++j) {
if (!done.at(j)) {
dist.at(j) = C.at(s).at(j) - f.at(s) - g.at(j);
prv.at(j) = -1;
}
}

int t = -1;
std::vector<int> stk;
Expand All @@ -40,8 +65,6 @@ T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &
}
}

if (j1 == -1) return false;

if (mate_inv.at(j1) < 0) {
t = j1;
break;
Expand All @@ -51,7 +74,8 @@ T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &
stk = {j1};

while (!stk.empty()) {
const int i = mate_inv.at(stk.back());
const int j2 = stk.back();
const int i = mate_inv.at(j2);
if (i < 0) {
t = stk.back();
break;
Expand All @@ -65,7 +89,7 @@ T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &

if (dist.at(j) > dist.at(j1) + len) {
dist.at(j) = dist.at(j1) + len;
prv.at(j) = i;
prv.at(j) = j2;
}

if (len == T()) {
Expand All @@ -80,33 +104,46 @@ T augment(int nr, int nc, const std::vector<std::vector<T>> &C, std::vector<T> &

f.at(s) += len;

T ret = len;
for (int i = 0; i < fixed_rows; ++i) {
if (const int j = mate.at(i); j >= 0) done.at(j) = 0;
}

for (int j = 0; j < nc; ++j) {
if (!done.at(j)) continue;

g.at(j) -= len - dist.at(j);
if (mate_inv.at(j) >= 0) {
f.at(mate_inv.at(j)) += len - dist.at(j);
} else {
ret -= len - dist.at(j);
}
}

for (int i = fixed_rows; i < nr; ++i) {
const int j = mate.at(i);
if (j < 0 or !done.at(j) or j >= nc) continue;
f.at(i) += len - dist.at(j);
}

T ret = T();

for (int cur = t; cur >= 0;) {
const int i = prv.at(cur);
const int nxt = prv.at(cur);
if (nxt < 0) {
mate_inv.at(cur) = s;
mate.at(s) = cur;
ret += C.at(s).at(cur);
break;
}
const int i = mate_inv.at(nxt);

ret += C.at(i).at(cur) - C.at(i).at(nxt);

mate_inv.at(cur) = i;
if (i == -1) break;
std::swap(cur, mate.at(i));
mate.at(i) = cur;
cur = nxt;
}

return ret;
}

// Complexity: O(nr^2 nc)
template <class T>
std::tuple<T, std::vector<int>, std::vector<T>, std::vector<T>>
_solve(int nr, int nc, const std::vector<std::vector<T>> &C) {
template <class T> Result<T> _solve(int nr, int nc, const std::vector<std::vector<T>> &C) {

assert(nr <= nc);

Expand Down Expand Up @@ -178,9 +215,7 @@ _solve(int nr, int nc, const std::vector<std::vector<T>> &C) {
// Jonker–Volgenant algorithm: find minimum weight assignment
// Dual problem (nr == nc): maximize sum(f) + sum(g) s.t. f_i + g_j <= C_ij
// Complexity: O(nr nc min(nr, nc))
template <class T>
std::tuple<T, std::vector<int>, std::vector<T>, std::vector<T>>
solve(int nr, int nc, const std::vector<std::vector<T>> &C) {
template <class T> Result<T> solve(int nr, int nc, const std::vector<std::vector<T>> &C) {

const bool transpose = (nr > nc);

Expand All @@ -203,3 +238,118 @@ solve(int nr, int nc, const std::vector<std::vector<T>> &C) {
}

} // namespace linear_sum_assignment

template <class T> struct best_assignments {

struct Node {
T opt;
std::vector<int> mate;
std::vector<T> f, g; // dual variables
int fixed_rows;
std::vector<int> banned_js; // C[fixed_rows][j] が inf となる j の集合

// for priority queue
// NOTE: reverse order
bool operator<(const Node &rhs) const { return opt > rhs.opt; }

linear_sum_assignment::Result<T> to_output(bool transpose) const {
if (transpose) {
std::vector<int> mate2(g.size(), -1);
for (int i = 0; i < (int)mate.size(); ++i) mate2.at(mate.at(i)) = i;
return {opt, mate2, g, f};
} else {
return {opt, mate, f, g};
}
}
};

bool transpose;
int nr_, nc_;
T inf;
std::vector<std::vector<T>> C_, Ctmp_;
std::priority_queue<Node> pq;

best_assignments(int nr, int nc, const std::vector<std::vector<T>> &C, T inf)
: transpose(nr > nc), inf(inf) {

assert((int)C.size() == nr);
for (int i = 0; i < nr; ++i) assert((int)C.at(i).size() == nc);

nr_ = transpose ? nc : nr;
nc_ = transpose ? nr : nc;

C_.assign(nr_ + (nr_ != nc_), std::vector<T>(nc_, T()));
for (int i = 0; i < nr; ++i) {
for (int j = 0; j < nc; ++j) {
C_.at(transpose ? j : i).at(transpose ? i : j) = C.at(i).at(j);
}
}

Ctmp_ = C_;

auto [opt, mate, f, g] = linear_sum_assignment::solve(C_.size(), nc, C_);

pq.emplace(Node{opt, std::move(mate), std::move(f), std::move(g), 0, {}});
}

bool finished() const { return pq.empty(); }

linear_sum_assignment::Result<T> yield() {
assert(!pq.empty());

const Node ret = pq.top();
pq.pop();

for (int fixed_rows = ret.fixed_rows; fixed_rows < nr_; ++fixed_rows) {
std::vector<int> banned_js;
if (fixed_rows == ret.fixed_rows) banned_js = ret.banned_js;

const int s = fixed_rows;
banned_js.push_back(ret.mate.at(s));

if ((int)banned_js.size() >= nc_) continue;

auto f = ret.f;
auto g = ret.g;
auto mate = ret.mate;

std::vector<int> mate_inv(nc_, nr_);
for (int i = 0; i < nr_; ++i) mate_inv.at(mate.at(i)) = i;

std::vector<int> iscoldone(nc_);
for (int i = 0; i < fixed_rows; ++i) iscoldone.at(mate.at(i)) = 1;

for (int j : banned_js) Ctmp_.at(s).at(j) = inf;

mate_inv.at(mate.at(s)) = -1;
mate.at(s) = -1;

auto aug = linear_sum_assignment::augment(
nr_, nc_, Ctmp_, f, g, s, mate, mate_inv, fixed_rows);

for (int j = 0; j < nc_; ++j) {
if (mate_inv.at(j) < 0) { // nrows < ncols
g.at(j) = -f.back();
for (int i = fixed_rows; i < nr_; ++i) {
g.at(j) = std::min(g.at(j), Ctmp_.at(i).at(j) - f.at(i));
}
}
}

if (Ctmp_.at(s).at(mate.at(s)) < inf) {
pq.emplace(Node{
ret.opt + aug - C_.at(s).at(ret.mate.at(s)),
std::move(mate),
std::move(f),
std::move(g),
fixed_rows,
banned_js,
});
}

for (int j : banned_js) Ctmp_.at(s).at(j) = C_.at(s).at(j);
}

return ret.to_output(transpose);
}
};
23 changes: 23 additions & 0 deletions combinatorial_opt/linear_sum_assignment.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ $r$ 行 $c$ 列の行列を入力とした割当問題(二部グラフの最

オーソドックスな Hungarian algorithm の実装ではなく, Jonker–Volgenant algorithm の工夫を一部取り入れることで定数倍高速化を試みている.

また,割当問題の上位 $k$ 個の解を効率的に列挙するクラス `best_assignments` も提供している.このクラスのコンストラクタや `yield()` を呼び出し毎に $O(rc \min{r, c})$ の時間計算量が発生する.

## 解いてくれる問題

主問題として,オーソドックスな線形重み割当問題を解く.
Expand All @@ -22,6 +24,8 @@ $

## 使用方法

### 割当問題(最適解の計算)

```cpp
vector C(r, vector<long long>(c));

Expand All @@ -36,10 +40,29 @@ std::tie(min_weight, mate, f, g) = linear_sum_assignment::solve(r, c, C);

また, `f[i]` および `g[j]` は最適解における双対変数の一例を示す.すなわち,任意の $i, j$ について $f\_i + g\_j \le C\_{ij}$ が成立し,特に第 $i$ 行と第 $j$ 列が対応する場合は等号が成立する.この双対変数は,行列の一部要素に更新を加えた場合の最適解の変化を効率的に追うために利用できる.

### 割当問題の $k$-best 解列挙

```cpp
int r, c;

vector<vector<int>> C;
int inf;

best_assignments<int> gen(r, c, cost, inf);

// 解の生成
for (int t = 0; t < k; ++t) {
if (ba.finished()) break;
auto [opt, mate, f, g] = ba.yield();
}
```

## 問題例

- [Library Checker: Assignment Problem](https://judge.yosupo.jp/problem/assignment)

## 文献・リンク集

- [Lecture 8: Assignment Algorithms](https://cyberlab.engr.uconn.edu/wp-content/uploads/sites/2576/2018/09/Lecture_8.pdf)
- [1] K. G. Murty, "An algorithm for ranking all the assignments in order of increasing cost," Operations Research, 16(3), 682–687, 1968.
- [2] M.L. Miller, H.S. Stone, I.J. Cox, "Optimizing Murty's ranked assignment method," IEEE Transactions on Aerospace and Electronic Systems, 33(3), 851-862, 1997.
Loading