Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/main/scala/io/github/acl4s/Dsu.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,23 @@ case class Dsu(n: Int) {
-parentOrSize(leader(a))
}

def groups(): Seq[Seq[Int]] = {
def groups(): collection.Seq[collection.Seq[Int]] = {
val leader_buf = new Array[Int](n)
val group_size = new Array[Int](n)
(0 until n).foreach(i => {
leader_buf(i) = leader(i)
group_size(leader_buf(i)) += 1
})

val result = new Array[mutable.ArrayBuffer[Int]](n)
val result = new mutable.ArrayBuffer[mutable.Buffer[Int]](n)
(0 until n).foreach(i => {
result(i) = new mutable.ArrayBuffer(group_size(i))
result += new mutable.ArrayBuffer(group_size(i))
})
(0 until n).foreach(i => {
result(leader_buf(i)) += i
})

result.collect {
case buf if buf.nonEmpty => buf.toSeq
}.toSeq
result.filter(_.nonEmpty)
}

}
22 changes: 22 additions & 0 deletions src/main/scala/io/github/acl4s/SccGraph.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.github.acl4s

case class SccGraph(private val internal: io.github.acl4s.internal.SccGraph) {

def addEdge(from: Int, to: Int): Unit = {
val n = internal.numVertices
assert(0 <= from && from < n)
assert(0 <= to && to < n)
internal.addEdge(from, to)
}

def scc(): collection.Seq[collection.Seq[Int]] = {
internal.scc()
}

}

object SccGraph {
def apply(n: Int): SccGraph = {
new SccGraph(internal.SccGraph(n))
}
}
105 changes: 105 additions & 0 deletions src/main/scala/io/github/acl4s/internal/SccGraph.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package io.github.acl4s.internal

import scala.collection.mutable
import scala.reflect.ClassTag

private[internal] case class Edge(to: Int)

private[internal] case class Csr[E] private (start: Array[Int], eList: Array[E])
private[internal] object Csr {
def apply[E: ClassTag](n: Int, edges: collection.Seq[(Int, E)]): Csr[E] = {
val csr = Csr(new Array[Int](n + 1), new Array[E](edges.size))
for ((from, _) <- edges) {
csr.start(from + 1) += 1
}
(1 to n).foreach(i => {
csr.start(i) += csr.start(i - 1)
})
val counter = csr.start.clone()
for ((from, edge) <- edges) {
csr.eList(counter(from)) = edge
counter(from) += 1
}
csr
}
}

/**
* Reference:
* R. Tarjan,
* Depth-First Search and Linear Graph Algorithms
*/
case class SccGraph(private val n: Int) {
private val edges: mutable.Buffer[(Int, Edge)] = mutable.ListBuffer.empty

def numVertices: Int = n

def addEdge(from: Int, to: Int): Unit = {
edges += ((from, Edge(to)))
}

/**
* @return pair of (# of scc, scc id)
*/
def sccIds(): (Int, Array[Int]) = {
val g = Csr(n, edges)
var now_ord = 0
var group_num = 0
val visited = new mutable.Stack[Int](n)
val ord = Array.fill(n)(-1)
val low = new Array[Int](n)
val ids = new Array[Int](n)

def dfs(v: Int): Unit = {
low(v) = now_ord
ord(v) = now_ord
now_ord += 1
visited.push(v)
(g.start(v) until g.start(v + 1)).foreach(i => {
val to = g.eList(i).to
if (ord(to) == -1) {
dfs(to)
low(v) = low(v).min(low(to))
} else {
low(v) = low(v).min(ord(to))
}
})
if (low(v) == ord(v)) {
while {
// do
val u = visited.pop()
ord(u) = n
ids(u) = group_num

// while
u != v
} do {}
group_num += 1
}
}

(0 until n).foreach(i => {
if (ord(i) == -1) { dfs(i) }
})
(0 until n).foreach(i => {
ids(i) = group_num - 1 - ids(i)
})

(group_num, ids)
}

def scc(): collection.Seq[collection.Seq[Int]] = {
val (group_nums, ids) = sccIds()
val counts = new Array[Int](group_nums)
ids.foreach(x => { counts(x) += 1 })
val groups = new mutable.ArrayBuffer[mutable.Buffer[Int]](n)
(0 until group_nums).foreach(i => {
groups += new mutable.ArrayBuffer[Int](counts(i))
})
(0 until n).foreach(i => {
groups(ids(i)) += i
})

groups
}
}
45 changes: 45 additions & 0 deletions src/test/scala/io/github/acl4s/SccGraphSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.github.acl4s

class SccGraphSuite extends munit.FunSuite {

/**
* @see https://atcoder.jp/contests/practice2/tasks/practice2_g
*/
test("AtCoder Library Practice Contest G - SCC") {
val graph = SccGraph(6)
graph.addEdge(1, 4)
graph.addEdge(5, 2)
graph.addEdge(3, 0)
graph.addEdge(5, 5)
graph.addEdge(4, 1)
graph.addEdge(0, 3)
graph.addEdge(4, 2)

val scc = graph.scc()
assertEquals(scc.size, 4)
assertEquals(scc.map(_.toSet).toSet, Set(Set(5), Set(1, 4), Set(2), Set(0, 3)))
}

test("empty") {
assertEquals(SccGraph(0).scc(), Seq())
}

test("simple") {
val graph = SccGraph(2)
graph.addEdge(0, 1)
graph.addEdge(1, 0)

val scc = graph.scc()
assertEquals(scc.size, 1)
}

test("self loop") {
val graph = SccGraph(2)
graph.addEdge(0, 0)
graph.addEdge(0, 0)
graph.addEdge(1, 1)

val scc = graph.scc()
assertEquals(scc.size, 2)
}
}