diff --git a/src/main/scala/io/github/acl4s/Dsu.scala b/src/main/scala/io/github/acl4s/Dsu.scala new file mode 100644 index 0000000..ef120bb --- /dev/null +++ b/src/main/scala/io/github/acl4s/Dsu.scala @@ -0,0 +1,80 @@ +package io.github.acl4s + +import scala.collection.mutable + +/** + * Implement (union by size) + (path compression) + * Reference: + * Zvi Galil and Giuseppe F. Italiano, + * Data structures and algorithms for disjoint set union problems + * + * @param n + */ +case class Dsu(n: Int) { + + /** + * root node: -1 * component size + * otherwise: parent + */ + private val parentOrSize: Array[Int] = Array.fill(n)(-1) + + def merge(a: Int, b: Int): Int = { + assert(0 <= a && a < n) + assert(0 <= b && b < n) + var x = leader(a) + var y = leader(b) + if (x == y) { return x } + if (-parentOrSize(x) < -parentOrSize(y)) { + // std::swap(x, y); + val z = x + x = y + y = z + } + parentOrSize(x) += parentOrSize(y) + parentOrSize(y) = x + x + } + + def same(a: Int, b: Int): Boolean = { + assert(0 <= a && a < n) + assert(0 <= b && b < n) + leader(a) == leader(b) + } + + def leader(a: Int): Int = { + assert(0 <= a && a < n) + if (parentOrSize(a) < 0) { + a + } else { + parentOrSize(a) = leader(parentOrSize(a)) + parentOrSize(a) + } + } + + def size(a: Int): Int = { + assert(0 <= a && a < n) + -parentOrSize(leader(a)) + } + + def groups(): Seq[Seq[Int]] = { + val leader_buf = new Array[Int](n) + val group_size = new Array[Int](n) + (0 until n).foreach(i => { + leader_buf(i) = leader(i) + group_size(leader_buf(i)) += 1 + }) + + val result = new Array[mutable.ArrayBuffer[Int]](n) + (0 until n).foreach(i => { + result(i) = new mutable.ArrayBuffer(group_size(i)) + }) + (0 until n).foreach(i => { + result(leader_buf(i)) += i + }) + + result.collect { + case buf if buf.nonEmpty => buf.toSeq + }.toSeq + } + +} diff --git a/src/test/scala/io/github/acl4s/DsuSuite.scala b/src/test/scala/io/github/acl4s/DsuSuite.scala new file mode 100644 index 0000000..14d234b --- /dev/null +++ b/src/test/scala/io/github/acl4s/DsuSuite.scala @@ -0,0 +1,68 @@ +package io.github.acl4s + +class DsuSuite extends munit.FunSuite { + + /** + * @see https://atcoder.jp/contests/practice2/tasks/practice2_a + */ + test("AtCoder Library Practice Contest A - Disjoint Set Union") { + val uf = Dsu(4) + + assertEquals(uf.same(0, 1), false) + + uf.merge(0, 1) + uf.merge(2, 3) + + assertEquals(uf.same(0, 1), true) + assertEquals(uf.same(1, 2), false) + assertEquals(uf.groups(), Seq(Seq(0, 1), Seq(2, 3))) + + uf.merge(0, 2) + + assertEquals(uf.same(1, 3), true) + assertEquals(uf.groups(), Seq(Seq(0, 1, 2, 3))) + } + + test("zero") { + val uf = Dsu(0) + + assertEquals(uf.groups(), Seq()) + } + + test("simple") { + val uf = Dsu(2) + + assertEquals(uf.same(0, 1), false) + + val x = uf.merge(0, 1) + assertEquals(uf.leader(0), x) + assertEquals(uf.leader(1), x) + assertEquals(uf.same(0, 1), true) + assertEquals(uf.size(0), 2) + } + + test("line") { + val n = 500_000 + val uf = Dsu(n) + + (0 until n - 1).foreach(i => { + uf.merge(i, i + 1) + }) + + assertEquals(uf.size(0), n) + assertEquals(uf.groups().size, 1) + } + + test("line reverse") { + val n = 500_000 + val uf = Dsu(n) + + (0 until n - 1).reverse.foreach(i => { + uf.merge(i, i + 1) + }) + + assertEquals(uf.size(0), n) + assertEquals(uf.groups().size, 1) + } + +}