diff --git a/build.sbt b/build.sbt index 3b2ee7f..b100bb5 100644 --- a/build.sbt +++ b/build.sbt @@ -16,5 +16,6 @@ lazy val root = project "-unchecked", "-Wunused:all" ), - libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test + libraryDependencies += "org.scalameta" %% "munit" % "0.7.29" % Test, + Test / parallelExecution := false ) diff --git a/src/main/scala/io/github/acl4s/FenwickTree.scala b/src/main/scala/io/github/acl4s/FenwickTree.scala new file mode 100644 index 0000000..b9171ef --- /dev/null +++ b/src/main/scala/io/github/acl4s/FenwickTree.scala @@ -0,0 +1,127 @@ +package io.github.acl4s + +import scala.reflect.ClassTag + +import io.github.acl4s.internal.rightOpenInterval + +/** + * Reference: https://en.wikipedia.org/wiki/Fenwick_tree + * + * @param n + * @param m + * @tparam T + */ +case class FenwickTree[T: ClassTag](n: Int)(using m: AddSub[T]) { + private val data: Array[T] = Array.fill(n)(m.e()) + + def add(index: Int, x: T): Unit = { + assert(0 <= index && index < n) + var p = index + 1 + while (p <= n) { + data(p - 1) = m.combine(data(p - 1), x) + p += p & -p + } + } + + private def sum(i: Int): T = { + var s = m.e() + var r = i + while (r > 0) { + s = m.combine(s, data(r - 1)) + r -= r & -r + } + s + } + + def sum(range: Range): T = { + val (l, r) = rightOpenInterval(range) + sum(l, r) + } + + def sum(l: Int, r: Int): T = { + assert(0 <= l && l <= r && r <= n) + m.subtract(sum(r), sum(l)) + } +} + +object FenwickTree { + def apply[T: AddSub: ClassTag](array: Array[T]): FenwickTree[T] = { + val ft = FenwickTree[T](array.length) + array.indices.foreach(i => { + ft.add(i, array(i)) + }) + ft + } +} + +trait AddSub[T] extends Add[T] { + def subtract(a: T, b: T): T +} + +object AddSub { + given (using m: Add[Char]): AddSub[Char] with { + override def e(): Char = m.e() + override def combine(a: Char, b: Char): Char = m.combine(a, b) + override def subtract(a: Char, b: Char): Char = (a - b).asInstanceOf[Char] + } + + given (using m: Add[Byte]): AddSub[Byte] with { + override def e(): Byte = m.e() + override def combine(a: Byte, b: Byte): Byte = m.combine(a, b) + override def subtract(a: Byte, b: Byte): Byte = (a - b).asInstanceOf[Byte] + } + + given (using m: Add[Short]): AddSub[Short] with { + override def e(): Short = m.e() + override def combine(a: Short, b: Short): Short = m.combine(a, b) + override def subtract(a: Short, b: Short): Short = (a - b).asInstanceOf[Short] + } + + given (using m: Add[Int]): AddSub[Int] with { + override def e(): Int = m.e() + override def combine(a: Int, b: Int): Int = m.combine(a, b) + override def subtract(a: Int, b: Int): Int = a - b + } + + given (using m: Add[Long]): AddSub[Long] with { + override def e(): Long = m.e() + override def combine(a: Long, b: Long): Long = m.combine(a, b) + override def subtract(a: Long, b: Long): Long = a - b + } + + given (using m: Add[Float]): AddSub[Float] with { + override def e(): Float = m.e() + override def combine(a: Float, b: Float): Float = m.combine(a, b) + override def subtract(a: Float, b: Float): Float = a - b + } + + given (using m: Add[Double]): AddSub[Double] with { + override def e(): Double = m.e() + override def combine(a: Double, b: Double): Double = m.combine(a, b) + override def subtract(a: Double, b: Double): Double = a - b + } + + given (using m: Add[DynamicModInt]): AddSub[DynamicModInt] with { + override def e(): DynamicModInt = m.e() + override def combine(a: DynamicModInt, b: DynamicModInt): DynamicModInt = m.combine(a, b) + override def subtract(a: DynamicModInt, b: DynamicModInt): DynamicModInt = a - b + } + + given (using m: Add[ModInt998244353]): AddSub[ModInt998244353] with { + override def e(): ModInt998244353 = m.e() + override def combine(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = m.combine(a, b) + override def subtract(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = a - b + } + + given (using m: Add[ModInt1000000007]): AddSub[ModInt1000000007] with { + override def e(): ModInt1000000007 = m.e() + override def combine(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = m.combine(a, b) + override def subtract(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = a - b + } + + given [T <: Int](using m: Add[StaticModInt[T]]): AddSub[StaticModInt[T]] with { + override def e(): StaticModInt[T] = m.e() + override def combine(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = m.combine(a, b) + override def subtract(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = a - b + } +} diff --git a/src/main/scala/io/github/acl4s/ModInt.scala b/src/main/scala/io/github/acl4s/ModInt.scala index 0b1f52c..6756af1 100644 --- a/src/main/scala/io/github/acl4s/ModInt.scala +++ b/src/main/scala/io/github/acl4s/ModInt.scala @@ -189,6 +189,7 @@ type ModInt998244353 = StaticModInt[Mod998244353.value.type] object ModInt1000000007 { given Modulus[1_000_000_007] = Mod1000000007 + def apply(): ModInt1000000007 = StaticModInt() def apply(value: Int): ModInt1000000007 = StaticModInt(value) def apply(value: Long): ModInt1000000007 = StaticModInt(value) } @@ -196,6 +197,7 @@ object ModInt1000000007 { object ModInt998244353 { given Modulus[998_244_353] = Mod998244353 + def apply(): ModInt998244353 = StaticModInt() def apply(value: Int): ModInt998244353 = StaticModInt(value) def apply(value: Long): ModInt998244353 = StaticModInt(value) } diff --git a/src/main/scala/io/github/acl4s/Monoid.scala b/src/main/scala/io/github/acl4s/Monoid.scala index 3db742e..91c5cf0 100644 --- a/src/main/scala/io/github/acl4s/Monoid.scala +++ b/src/main/scala/io/github/acl4s/Monoid.scala @@ -8,3 +8,67 @@ trait Monoid[T] { object Monoid { def apply[T](using m: Monoid[T]): Monoid[T] = m } + +trait Add[T] extends Monoid[T] +object Add { + given Add[Char] with { + override def e(): Char = 0 + override def combine(a: Char, b: Char): Char = (a + b).asInstanceOf[Char] + } + + given Add[Byte] with { + override def e(): Byte = 0 + override def combine(a: Byte, b: Byte): Byte = (a + b).asInstanceOf[Byte] + } + + given Add[Short] with { + override def e(): Short = 0 + override def combine(a: Short, b: Short): Short = (a + b).asInstanceOf[Short] + } + + given Add[Int] with { + override def e(): Int = 0 + override def combine(a: Int, b: Int): Int = a + b + } + + given Add[Long] with { + override def e(): Long = 0L + override def combine(a: Long, b: Long): Long = a + b + } + + given Add[Float] with { + override def e(): Float = 0f + override def combine(a: Float, b: Float): Float = a + b + } + + given Add[Double] with { + override def e(): Double = 0d + override def combine(a: Double, b: Double): Double = a + b + } + + given Add[DynamicModInt] with { + private val zero = DynamicModInt() + override def e(): DynamicModInt = zero + override def combine(a: DynamicModInt, b: DynamicModInt): DynamicModInt = a + b + } + + given Add[ModInt998244353] with { + private val zero = ModInt998244353() + override def e(): ModInt998244353 = zero + override def combine(a: ModInt998244353, b: ModInt998244353): ModInt998244353 = a + b + } + + given Add[ModInt1000000007] with { + private val zero = ModInt1000000007() + override def e(): ModInt1000000007 = zero + override def combine(a: ModInt1000000007, b: ModInt1000000007): ModInt1000000007 = a + b + } + + given [T <: Int](using Modulus[T]): Add[StaticModInt[T]] with { + private val zero = StaticModInt[T]() + override def e(): StaticModInt[T] = zero + override def combine(a: StaticModInt[T], b: StaticModInt[T]): StaticModInt[T] = a + b + } + + def apply[T](using m: Add[T]): Add[T] = m +} diff --git a/src/test/scala/io/github/acl4s/FenwickTreeSuite.scala b/src/test/scala/io/github/acl4s/FenwickTreeSuite.scala new file mode 100644 index 0000000..4a63b41 --- /dev/null +++ b/src/test/scala/io/github/acl4s/FenwickTreeSuite.scala @@ -0,0 +1,170 @@ +package io.github.acl4s + +class FenwickTreeSuite extends munit.FunSuite { + + /** + * @see https://atcoder.jp/contests/practice2/tasks/practice2_b + */ + test("AtCoder Library Practice Contest B - Fenwick Tree") { + val fw = FenwickTree(Array(1L, 2L, 3L, 4L, 5L)) + + assertEquals(fw.sum(0, 5), 15L) + assertEquals(fw.sum(2, 4), 7L) + + fw.add(3, 10) + + assertEquals(fw.sum(0, 5), 25L) + assertEquals(fw.sum(0, 3), 6L) + } + + test("zero") { + { + val fw = FenwickTree[Long](0) + assertEquals(fw.sum(0, 0), 0L) + } + + { + type ModInt = DynamicModInt + val ModInt = DynamicModInt + + val fw = FenwickTree[ModInt](0) + assertEquals(fw.sum(0, 0), ModInt(0)) + } + + { + type ModInt = ModInt998244353 + val ModInt = ModInt998244353 + + val fw = FenwickTree[ModInt](0) + assertEquals(fw.sum(0, 0), ModInt(0)) + } + + { + type ModInt = ModInt1000000007 + val ModInt = ModInt1000000007 + + val fw = FenwickTree[ModInt](0) + assertEquals(fw.sum(0, 0), ModInt(0)) + } + + { + given Modulus[1_000_000_009] = Modulus[1_000_000_009]() + type ModInt = StaticModInt[1_000_000_009] + val ModInt = StaticModInt + + val fw = FenwickTree[ModInt](0) + assertEquals(fw.sum(0, 0), ModInt(0)) + } + } + + test("naive") { + (0 to 50).foreach(n => { + val fw = FenwickTree[Long](n) + (0 until n).foreach(i => { + fw.add(i, i.toLong * i) + }) + + for { + l <- 0 to n + r <- l to n + } { + val sum = (l until r).map(i => i.toLong * i).sum + assertEquals(fw.sum(l, r), sum) + } + }) + } + + test("bound int") { + val fw = FenwickTree[Int](10) + + fw.add(3, Int.MaxValue) + fw.add(5, Int.MinValue) + + assertEquals(fw.sum(0, 10), -1) + assertEquals(fw.sum(3, 6), -1) + + assertEquals(fw.sum(3, 4), Int.MaxValue) + assertEquals(fw.sum(4, 10), Int.MinValue) + } + + test("bound long") { + val fw = FenwickTree[Long](10) + + fw.add(3, Long.MaxValue) + fw.add(5, Long.MinValue) + + assertEquals(fw.sum(0, 10), -1L) + assertEquals(fw.sum(3, 6), -1L) + + assertEquals(fw.sum(3, 4), Long.MaxValue) + assertEquals(fw.sum(4, 10), Long.MinValue) + } + + test("overflow") { + val fw = FenwickTree[Int](20) + val a = new Array[Long](20) + (0 until 10).foreach(i => { + fw.add(i, Int.MaxValue) + a(i) += Int.MaxValue + }) + (10 until 20).foreach(i => { + fw.add(i, Int.MinValue) + a(i) += Int.MinValue + }) + + fw.add(5, 11_111) + a(5) += 11_111 + + for { + l <- 0 to 20 + r <- l to 20 + } { + val sum = (l until r).map(i => a(i)).sum + val dif = sum - fw.sum(l, r) + assertEquals(dif % (1L << 32), 0L) + } + } + + test("StaticModInt") { + given Modulus[11] = Modulus[11]() + type ModInt = StaticModInt[11] + val ModInt = StaticModInt + + (0 to 50).foreach(n => { + val fw = FenwickTree[ModInt](n) + (0 until n).foreach(i => { + fw.add(i, ModInt(i.toLong * i)) + }) + + for { + l <- 0 to n + r <- l to n + } { + val sum = (l until r).map(i => i.toLong * i).sum + assertEquals(fw.sum(l, r), ModInt(sum)) + } + }) + } + + test("DynamicModInt") { + type ModInt = DynamicModInt + val ModInt = DynamicModInt + ModInt.setMod(11) + + (0 to 50).foreach(n => { + val fw = FenwickTree[ModInt](n) + (0 until n).foreach(i => { + fw.add(i, ModInt(i.toLong * i)) + }) + + for { + l <- 0 to n + r <- l to n + } { + val sum = (l until r).map(i => i.toLong * i).sum + assertEquals(fw.sum(l, r), ModInt(sum)) + } + }) + } + +}