Skip to content

Commit

Permalink
Add BinInt, implement divmod and just for fun, add Nat.scala (#1194)
Browse files Browse the repository at this point in the history
* Add BinInt, implement divmod and just for fun, add Nat.scala

* improve Nat.scala test coverage

* improve cmp_BinInt
  • Loading branch information
johnynek committed Apr 10, 2024
1 parent a71c059 commit 56fc9c1
Show file tree
Hide file tree
Showing 6 changed files with 583 additions and 3 deletions.
219 changes: 219 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Nat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package org.bykn.bosatsu
import org.bykn.bosatsu.Nat.Shift
import org.bykn.bosatsu.Nat.Small

import Nat._

sealed abstract class Nat { lhs =>
def toBigInt: BigInt =
lhs match {
case Small(asInt) =>
BigInt(toLong(asInt))
case Shift(x, b) =>
// m(x + 1) + b
two_32_BigInt * (x.inc.toBigInt) + toLong(b)
}

override def toString = toBigInt.toString

def maybeLong: Option[Long] =
lhs match {
case Small(asInt) => Some(toLong(asInt))
case Shift(x, b) =>
// 2^32 * (x + 1) + b
x.maybeLong match {
case Some(v0) =>
if (v0 < Int.MaxValue) {
val value = v0 + 1
Some((value << 32) + toLong(b))
}
else None
case None => None
}
}


def inc: Nat =
lhs match {
case Small(asInt) =>
val next = asInt + 1
if (next != 0) wrapInt(next)
else two_32
case Shift(x, b) =>
// m(x + 1) + b
val b1 = b + 1
if (b1 != 0) Shift(x, b1)
else {
// m(x + 1) + m
// m(x + 1 + 1)
Shift(x.inc, 0)
}
}

def dec: Nat =
lhs match {
case Small(asInt) =>
if (asInt != 0) wrapInt(asInt - 1)
else zero
case Shift(x, b) =>
if (b != 0) Shift(x, b - 1)
else {
if (x.isZero) wrapInt(-1)
else Shift(x.dec, -1)
}
}

def isZero: Boolean =
lhs match {
case Small(0) => true
case _ => false
}

// this * 2^32
def shift_32: Nat =
lhs match {
case Small(asInt) =>
if (asInt == 0) zero
else {
// m*(b - 1 + 1)
Shift(wrapInt(asInt - 1), 0)
}
case Shift(x, b) =>
// m*(m*(x + 1) + b)
// m*((m(x + 1) + (b - 1)) + 1) + 0
if (b != 0) {
val b1 = b - 1
Shift(Shift(x, b1), 0)
}
else if (x.isZero) {
// m * (m(0 + 1) + 0)
// m * ((m - 1) + 1) + 0
val inner = wrapLong(0xFFFFFFFFL)
Shift(inner, 0)
}
else {
// x > 0
//m(m(x + 1)) =
//m(m(x + 1) - 1 + 1) + 0
//m((m(x + 1) - 1) + 1) + 0
val inner = x.inc.shift_32.dec
Shift(inner, 0)
}
}

def +(rhs: Nat): Nat = {
lhs match {
case Small(l) =>
rhs match {
case Small(r) =>
wrapLong(toLong(l) + toLong(r))
case Shift(x, b) =>
val res = toLong(l) + toLong(b)
val low = lowBits(res)
val high = highBits(res)
if (high == 0) Shift(x, low)
else {
Shift(x + wrapInt(high), low)
}
}

case Shift(x, b) =>
rhs match {
case Small(l) =>
val res = toLong(l) + toLong(b)
val low = lowBits(res)
val high = highBits(res)
if (high == 0) Shift(x, low)
else {
Shift(x + wrapInt(high), low)
}
case Shift(x1, b1) =>
// (m(x + 1) + b) + (m (x1 + 1) + b1) = m(x + x1 + 1 + 1) + (b + b1)
val res = toLong(b) + toLong(b1)
val low = lowBits(res)
val high = highBits(res)
val xs = (x + x1).inc
if (high == 0) Shift(xs, low)
else {
Shift(xs + wrapInt(high), low)
}
}
}
}
def *(rhs: Nat): Nat = {
lhs match {
case Small(l) =>
if (l == 0) zero
else (rhs match {
case Small(r) =>
// can't overflow Long
wrapLong(toLong(l) * toLong(r))
case Shift(x, b) =>
// (m(x + 1) + b) * l
(x.inc * lhs).shift_32 + wrapInt(b) * lhs
})

case Shift(x, b) =>
// (m(x + 1) + b) * rhs
// (x + 1)*rhs*m + b * rhs
(rhs * x.inc).shift_32 + rhs * wrapInt(b)
}
}
}

object Nat {
private val two_32_BigInt = BigInt(1L << 32)

private val intMaskLow: Long = 0xFFFFFFFFL
def toLong(i: Int): Long = i.toLong & intMaskLow
def lowBits(l: Long): Int = l.toInt
private def highBits(l: Long): Int = lowBits(l >> 32)

// all numbers from [0, 2^{32} - 1]
private case class Small(asInt: Int) extends Nat
// base * (x + 1) + b
private case class Shift(x: Nat, b: Int) extends Nat

private val cache: Array[Small] =
Array.tabulate(1024)(Small(_))

val zero: Nat = cache(0)
val one: Nat = cache(1)
val two_32: Nat = Shift(zero, 0)

// if the number is <= 0, return 0
def fromInt(i: Int): Nat =
if (i < 0) zero
else if (i < cache.length) cache(i)
else Small(i)

// if the number is <= 0, return 0
def fromLong(l: Long): Nat =
if (l < 0) zero
else wrapLong(l)

def wrapInt(i: Int): Nat =
if (0 <= i && i < cache.length) cache(i)
else Small(i)

def wrapLong(l: Long): Nat = {
val low = lowBits(l)
val high = highBits(l)
if (high == 0) wrapInt(low)
else {
Shift(wrapInt(high - 1), low)
}
}

// if b < 0 return 0
def fromBigInt(b: BigInt): Nat =
if (b <= 0) zero
else if (b < two_32_BigInt) {
fromLong(b.toLong)
}
else {
val low = b % two_32_BigInt
val high = b >> 32
Shift(fromBigInt(high - 1), lowBits(low.toLong))
}
}
154 changes: 154 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/NatTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package org.bykn.bosatsu

import org.scalacheck.{Gen, Prop}

import Prop.forAll

class NatTest extends munit.ScalaCheckSuite {

override def scalaCheckTestParameters =
super.scalaCheckTestParameters
.withMinSuccessfulTests(if (Platform.isScalaJvm) 100000 else 100)
.withMaxDiscardRatio(10)

//override def scalaCheckInitialSeed = "YOFqcGzXOFtFVgRFxOmODi5100tVovDS3EPOv0Ihk4C="

lazy val genNat: Gen[Nat] = {
val recur = Gen.lzy(genNat)
Gen.frequency(
// make sure to exercise the cached table
(1, Gen.chooseNum(0, 1024, 1).map(Nat.fromInt(_))),
(5, Gen.chooseNum(0, Long.MaxValue, Int.MaxValue.toLong, Int.MaxValue.toLong + 1).map(Nat.fromLong(_))),
(1, Gen.zip(recur, recur).map { case (a, b) => a + b }),
(1, Gen.zip(recur, recur).map { case (a, b) => a * b })
)
}

test("constants are right") {
assertEquals(Nat.zero.maybeLong, Some(0L))
assertEquals(Nat.one.maybeLong, Some(1L))
assertEquals(Nat.two_32.maybeLong, Some(1L << 32))
}
property("fromInt/maybeLong identity") {
forAll { (i: Int) =>
val n = Nat.fromInt(i)
n.maybeLong match {
case None => fail(s"couldn't maybeLong $i")
case Some(l) => assert(i < 0 || (l.toInt == i))
}
}
}
property("fromLong/maybeLong identity") {
forAll { (i: Long) =>
val n = Nat.fromLong(i)
n.maybeLong match {
case Some(l) => assert(i < 0 || (l == i))
case None => fail(s"couldn't maybeLong $i")
}
}
}

def trunc(i: Int): Long = if (i < 0) 0L else i.toLong

property("toLong => lowBits is identity") {
val big = 0x80000000L

assertEquals(Nat.lowBits(big), Int.MinValue, s"${Nat.lowBits(big)}")
forAll(Gen.chooseNum(Int.MinValue, Int.MaxValue)) { (i: Int) =>
val l = Nat.toLong(i)
assertEquals(Nat.lowBits(Nat.toLong(i)), i, s"long = $l")
}
}

property("x + y homomorphism") {
forAll(genNat, genNat) { (ni, nj) =>
val nk = ni + nj
assertEquals(nk.toBigInt, ni.toBigInt + nj.toBigInt)
}
}

property("x * y = y * x") {
forAll(genNat, genNat) { (n1, n2) =>
assertEquals(n1 * n2, n2 * n1)
}
}

property("x + y = y + x") {
forAll(genNat, genNat) { (n1, n2) =>
assertEquals(n1 + n2, n2 + n1)
}
}

property("x * y homomorphism") {
forAll(genNat, genNat) { (ni, nj) =>
val nk = ni * nj
assertEquals(nk.toBigInt, ni.toBigInt * nj.toBigInt)
}
}

property("x.inc == x + 1") {
forAll(genNat) { n =>
val i = n.inc
val a = n + Nat.one
assertEquals(i.toBigInt, a.toBigInt)
}
}

property("x.dec == x - 1 when x > 0") {
forAll(genNat) { n =>
val i = n.dec.toBigInt
if (n == Nat.zero) assertEquals(i, BigInt(0))
else assertEquals(i, n.toBigInt - 1)
}
}

property("x.shift_32 == x.toBigInt * 2^32") {
val n1 = BigInt("8588888260151524380556863712485265508")
val shift = n1 << 32
assertEquals(Nat.fromBigInt(n1).shift_32.toBigInt, shift)

forAll(genNat) { n =>
val s = n.shift_32
val viaBigInt = n.toBigInt << 32
val viaTimes = n * Nat.two_32
assertEquals(s.toBigInt, viaBigInt, s"viaTimes = $viaTimes")
assertEquals(s, viaTimes)
}
}

property("Nat.fromBigInt/toBigInt") {
assertEquals(Nat.fromLong(Long.MaxValue).toBigInt, BigInt(Long.MaxValue))

forAll { (bi0: BigInt) =>
val bi = bi0.abs
val n = Nat.fromBigInt(bi)
val b2 = n.toBigInt
assertEquals(b2, bi)
}
}

property("x.inc.dec == x") {
forAll(genNat) { n =>
assertEquals(n.inc.dec, n)
}
}

property("x.dec.inc == x || x.isZero") {
forAll(genNat) { n =>
assert((n.dec.inc == n) || n.isZero)
}
}

property("if the value is > Long.MaxValue maybeLong = None") {
forAll(genNat) { n =>
val bi = n.toBigInt
val ml = n.maybeLong
assertEquals(ml.isEmpty, bi > Long.MaxValue)
}
}
property("the string repr matches toBigInt") {
forAll(genNat) { n =>
assertEquals(n.toString, n.toBigInt.toString)
}
}
}
Loading

0 comments on commit 56fc9c1

Please sign in to comment.