Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ lazy val root = project
"-unchecked",
"-Wunused:all"
),
libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test
libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test,
Test / parallelExecution := false
)
127 changes: 127 additions & 0 deletions src/main/scala/io/github/acl4s/FenwickTree.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package io.github.acl4s

import scala.reflect.ClassTag

import io.github.acl4s.internal.rightOpenInterval

/**
* Reference: https://en.wikipedia.org/wiki/Fenwick_tree
*
* @param n
* @param m
* @tparam T
*/
case class FenwickTree[T: ClassTag](n: Int)(using m: AddSub[T]) {
private val data: Array[T] = Array.fill(n)(m.e())

def add(index: Int, x: T): Unit = {
assert(0 <= index && index < n)
var p = index + 1
while (p <= n) {
data(p - 1) = m.combine(data(p - 1), x)
p += p & -p
}
}

private def sum(i: Int): T = {
var s = m.e()
var r = i
while (r > 0) {
s = m.combine(s, data(r - 1))
r -= r & -r
}
s
}

def sum(range: Range): T = {
val (l, r) = rightOpenInterval(range)
sum(l, r)
}

def sum(l: Int, r: Int): T = {
assert(0 <= l && l <= r && r <= n)
m.subtract(sum(r), sum(l))
}
}

object FenwickTree {
def apply[T: AddSub: ClassTag](array: Array[T]): FenwickTree[T] = {
val ft = FenwickTree[T](array.length)
array.indices.foreach(i => {
ft.add(i, array(i))
})
ft
}
}

trait AddSub[T] extends Add[T] {
def subtract(a: T, b: T): T
}

object AddSub {
given (using m: Add[Char]): AddSub[Char] with {
override def e(): Char = m.e()
override def combine(a: Char, b: Char): Char = m.combine(a, b)
override def subtract(a: Char, b: Char): Char = (a - b).asInstanceOf[Char]
}

given (using m: Add[Byte]): AddSub[Byte] with {
override def e(): Byte = m.e()
override def combine(a: Byte, b: Byte): Byte = m.combine(a, b)
override def subtract(a: Byte, b: Byte): Byte = (a - b).asInstanceOf[Byte]
}

given (using m: Add[Short]): AddSub[Short] with {
override def e(): Short = m.e()
override def combine(a: Short, b: Short): Short = m.combine(a, b)
override def subtract(a: Short, b: Short): Short = (a - b).asInstanceOf[Short]
}

given (using m: Add[Int]): AddSub[Int] with {
override def e(): Int = m.e()
override def combine(a: Int, b: Int): Int = m.combine(a, b)
override def subtract(a: Int, b: Int): Int = a - b
}

given (using m: Add[Long]): AddSub[Long] with {
override def e(): Long = m.e()
override def combine(a: Long, b: Long): Long = m.combine(a, b)
override def subtract(a: Long, b: Long): Long = a - b
}

given (using m: Add[Float]): AddSub[Float] with {
override def e(): Float = m.e()
override def combine(a: Float, b: Float): Float = m.combine(a, b)
override def subtract(a: Float, b: Float): Float = a - b
}

given (using m: Add[Double]): AddSub[Double] with {
override def e(): Double = m.e()
override def combine(a: Double, b: Double): Double = m.combine(a, b)
override def subtract(a: Double, b: Double): Double = a - b
}

given (using m: Add[DynamicModInt]): AddSub[DynamicModInt] with {
override def e(): DynamicModInt = m.e()
override def combine(a: DynamicModInt, b: DynamicModInt): DynamicModInt = m.combine(a, b)
override def subtract(a: DynamicModInt, b: DynamicModInt): DynamicModInt = a - b
}

given (using m: Add[ModInt998244353]): AddSub[ModInt998244353] with {
override def e(): ModInt998244353 = m.e()
override def combine(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = m.combine(a, b)
override def subtract(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = a - b
}

given (using m: Add[ModInt1000000007]): AddSub[ModInt1000000007] with {
override def e(): ModInt1000000007 = m.e()
override def combine(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = m.combine(a, b)
override def subtract(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = a - b
}

given [T <: Int](using m: Add[StaticModInt[T]]): AddSub[StaticModInt[T]] with {
override def e(): StaticModInt[T] = m.e()
override def combine(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = m.combine(a, b)
override def subtract(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = a - b
}
}
2 changes: 2 additions & 0 deletions src/main/scala/io/github/acl4s/ModInt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,15 @@ type ModInt998244353 = StaticModInt[Mod998244353.value.type]
object ModInt1000000007 {
given Modulus[1_000_000_007] = Mod1000000007

def apply(): ModInt1000000007 = StaticModInt()
def apply(value: Int): ModInt1000000007 = StaticModInt(value)
def apply(value: Long): ModInt1000000007 = StaticModInt(value)
}

object ModInt998244353 {
given Modulus[998_244_353] = Mod998244353

def apply(): ModInt998244353 = StaticModInt()
def apply(value: Int): ModInt998244353 = StaticModInt(value)
def apply(value: Long): ModInt998244353 = StaticModInt(value)
}
Expand Down
64 changes: 64 additions & 0 deletions src/main/scala/io/github/acl4s/Monoid.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,67 @@ trait Monoid[T] {
object Monoid {
def apply[T](using m: Monoid[T]): Monoid[T] = m
}

trait Add[T] extends Monoid[T]
object Add {
given Add[Char] with {
override def e(): Char = 0
override def combine(a: Char, b: Char): Char = (a + b).asInstanceOf[Char]
}

given Add[Byte] with {
override def e(): Byte = 0
override def combine(a: Byte, b: Byte): Byte = (a + b).asInstanceOf[Byte]
}

given Add[Short] with {
override def e(): Short = 0
override def combine(a: Short, b: Short): Short = (a + b).asInstanceOf[Short]
}

given Add[Int] with {
override def e(): Int = 0
override def combine(a: Int, b: Int): Int = a + b
}

given Add[Long] with {
override def e(): Long = 0L
override def combine(a: Long, b: Long): Long = a + b
}

given Add[Float] with {
override def e(): Float = 0f
override def combine(a: Float, b: Float): Float = a + b
}

given Add[Double] with {
override def e(): Double = 0d
override def combine(a: Double, b: Double): Double = a + b
}

given Add[DynamicModInt] with {
private val zero = DynamicModInt()
override def e(): DynamicModInt = zero
override def combine(a: DynamicModInt, b: DynamicModInt): DynamicModInt = a + b
}

given Add[ModInt998244353] with {
private val zero = ModInt998244353()
override def e(): ModInt998244353 = zero
override def combine(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = a + b
}

given Add[ModInt1000000007] with {
private val zero = ModInt1000000007()
override def e(): ModInt1000000007 = zero
override def combine(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = a + b
}

given [T <: Int](using Modulus[T]): Add[StaticModInt[T]] with {
private val zero = StaticModInt[T]()
override def e(): StaticModInt[T] = zero
override def combine(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = a + b
}

def apply[T](using m: Add[T]): Add[T] = m
}
170 changes: 170 additions & 0 deletions src/test/scala/io/github/acl4s/FenwickTreeSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package io.github.acl4s

class FenwickTreeSuite extends munit.FunSuite {

/**
* @see https://atcoder.jp/contests/practice2/tasks/practice2_b
*/
test("AtCoder Library Practice Contest B - Fenwick Tree") {
val fw = FenwickTree(Array(1L, 2L, 3L, 4L, 5L))

assertEquals(fw.sum(0, 5), 15L)
assertEquals(fw.sum(2, 4), 7L)

fw.add(3, 10)

assertEquals(fw.sum(0, 5), 25L)
assertEquals(fw.sum(0, 3), 6L)
}

test("zero") {
{
val fw = FenwickTree[Long](0)
assertEquals(fw.sum(0, 0), 0L)
}

{
type ModInt = DynamicModInt
val ModInt = DynamicModInt

val fw = FenwickTree[ModInt](0)
assertEquals(fw.sum(0, 0), ModInt(0))
}

{
type ModInt = ModInt998244353
val ModInt = ModInt998244353

val fw = FenwickTree[ModInt](0)
assertEquals(fw.sum(0, 0), ModInt(0))
}

{
type ModInt = ModInt1000000007
val ModInt = ModInt1000000007

val fw = FenwickTree[ModInt](0)
assertEquals(fw.sum(0, 0), ModInt(0))
}

{
given Modulus[1_000_000_009] = Modulus[1_000_000_009]()
type ModInt = StaticModInt[1_000_000_009]
val ModInt = StaticModInt

val fw = FenwickTree[ModInt](0)
assertEquals(fw.sum(0, 0), ModInt(0))
}
}

test("naive") {
(0 to 50).foreach(n => {
val fw = FenwickTree[Long](n)
(0 until n).foreach(i => {
fw.add(i, i.toLong * i)
})

for {
l <- 0 to n
r <- l to n
} {
val sum = (l until r).map(i => i.toLong * i).sum
assertEquals(fw.sum(l, r), sum)
}
})
}

test("bound int") {
val fw = FenwickTree[Int](10)

fw.add(3, Int.MaxValue)
fw.add(5, Int.MinValue)

assertEquals(fw.sum(0, 10), -1)
assertEquals(fw.sum(3, 6), -1)

assertEquals(fw.sum(3, 4), Int.MaxValue)
assertEquals(fw.sum(4, 10), Int.MinValue)
}

test("bound long") {
val fw = FenwickTree[Long](10)

fw.add(3, Long.MaxValue)
fw.add(5, Long.MinValue)

assertEquals(fw.sum(0, 10), -1L)
assertEquals(fw.sum(3, 6), -1L)

assertEquals(fw.sum(3, 4), Long.MaxValue)
assertEquals(fw.sum(4, 10), Long.MinValue)
}

test("overflow") {
val fw = FenwickTree[Int](20)
val a = new Array[Long](20)
(0 until 10).foreach(i => {
fw.add(i, Int.MaxValue)
a(i) += Int.MaxValue
})
(10 until 20).foreach(i => {
fw.add(i, Int.MinValue)
a(i) += Int.MinValue
})

fw.add(5, 11_111)
a(5) += 11_111

for {
l <- 0 to 20
r <- l to 20
} {
val sum = (l until r).map(i => a(i)).sum
val dif = sum - fw.sum(l, r)
assertEquals(dif % (1L << 32), 0L)
}
}

test("StaticModInt") {
given Modulus[11] = Modulus[11]()
type ModInt = StaticModInt[11]
val ModInt = StaticModInt

(0 to 50).foreach(n => {
val fw = FenwickTree[ModInt](n)
(0 until n).foreach(i => {
fw.add(i, ModInt(i.toLong * i))
})

for {
l <- 0 to n
r <- l to n
} {
val sum = (l until r).map(i => i.toLong * i).sum
assertEquals(fw.sum(l, r), ModInt(sum))
}
})
}

test("DynamicModInt") {
type ModInt = DynamicModInt
val ModInt = DynamicModInt
ModInt.setMod(11)

(0 to 50).foreach(n => {
val fw = FenwickTree[ModInt](n)
(0 until n).foreach(i => {
fw.add(i, ModInt(i.toLong * i))
})

for {
l <- 0 to n
r <- l to n
} {
val sum = (l until r).map(i => i.toLong * i).sum
assertEquals(fw.sum(l, r), ModInt(sum))
}
})
}

}