Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BinInt, implement divmod and just for fun, add Nat.scala #1194

Merged
merged 3 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/Nat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package org.bykn.bosatsu
import org.bykn.bosatsu.Nat.Shift
import org.bykn.bosatsu.Nat.Small

import Nat._

sealed abstract class Nat { lhs =>
def toBigInt: BigInt =
lhs match {
case Small(asInt) =>
BigInt(toLong(asInt))
case Shift(x, b) =>
// m(x + 1) + b
two_32_BigInt * (x.inc.toBigInt) + toLong(b)
}

override def toString = toBigInt.toString

def maybeLong: Option[Long] =
lhs match {
case Small(asInt) => Some(toLong(asInt))
case Shift(x, b) =>
// 2^32 * (x + 1) + b
x.maybeLong match {
case Some(v0) =>
if (v0 < Int.MaxValue) {
val value = v0 + 1
Some((value << 32) + toLong(b))
}
else None
case None => None
}
}


def inc: Nat =
lhs match {
case Small(asInt) =>
val next = asInt + 1
if (next != 0) wrapInt(next)
else two_32
case Shift(x, b) =>
// m(x + 1) + b
val b1 = b + 1
if (b1 != 0) Shift(x, b1)
else {
// m(x + 1) + m
// m(x + 1 + 1)
Shift(x.inc, 0)
}
}

def dec: Nat =
lhs match {
case Small(asInt) =>
if (asInt != 0) wrapInt(asInt - 1)
else zero
case Shift(x, b) =>
if (b != 0) Shift(x, b - 1)
else {
if (x.isZero) wrapInt(-1)
else Shift(x.dec, -1)
}
}

def isZero: Boolean =
lhs match {
case Small(0) => true
case _ => false
}

// this * 2^32
def shift_32: Nat =
lhs match {
case Small(asInt) =>
if (asInt == 0) zero
else {
// m*(b - 1 + 1)
Shift(wrapInt(asInt - 1), 0)
}
case Shift(x, b) =>
// m*(m*(x + 1) + b)
// m*((m(x + 1) + (b - 1)) + 1) + 0
if (b != 0) {
val b1 = b - 1
Shift(Shift(x, b1), 0)
}
else if (x.isZero) {
// m * (m(0 + 1) + 0)
// m * ((m - 1) + 1) + 0
val inner = wrapLong(0xFFFFFFFFL)
Shift(inner, 0)
}
else {
// x > 0
//m(m(x + 1)) =
//m(m(x + 1) - 1 + 1) + 0
//m((m(x + 1) - 1) + 1) + 0
val inner = x.inc.shift_32.dec
Shift(inner, 0)
}
}

def +(rhs: Nat): Nat = {
lhs match {
case Small(l) =>
rhs match {
case Small(r) =>
wrapLong(toLong(l) + toLong(r))
case Shift(x, b) =>
val res = toLong(l) + toLong(b)
val low = lowBits(res)
val high = highBits(res)
if (high == 0) Shift(x, low)
else {
Shift(x + wrapInt(high), low)
}
}

case Shift(x, b) =>
rhs match {
case Small(l) =>
val res = toLong(l) + toLong(b)
val low = lowBits(res)
val high = highBits(res)
if (high == 0) Shift(x, low)
else {
Shift(x + wrapInt(high), low)
}
case Shift(x1, b1) =>
// (m(x + 1) + b) + (m (x1 + 1) + b1) = m(x + x1 + 1 + 1) + (b + b1)
val res = toLong(b) + toLong(b1)
val low = lowBits(res)
val high = highBits(res)
val xs = (x + x1).inc
if (high == 0) Shift(xs, low)
else {
Shift(xs + wrapInt(high), low)
}
}
}
}
def *(rhs: Nat): Nat = {
lhs match {
case Small(l) =>
if (l == 0) zero
else (rhs match {
case Small(r) =>
// can't overflow Long
wrapLong(toLong(l) * toLong(r))
case Shift(x, b) =>
// (m(x + 1) + b) * l
(x.inc * lhs).shift_32 + wrapInt(b) * lhs
})

case Shift(x, b) =>
// (m(x + 1) + b) * rhs
// (x + 1)*rhs*m + b * rhs
(rhs * x.inc).shift_32 + rhs * wrapInt(b)
}
}
}

object Nat {
private val two_32_BigInt = BigInt(1L << 32)

private val intMaskLow: Long = 0xFFFFFFFFL
def toLong(i: Int): Long = i.toLong & intMaskLow
def lowBits(l: Long): Int = l.toInt
private def highBits(l: Long): Int = lowBits(l >> 32)

// all numbers from [0, 2^{32} - 1]
private case class Small(asInt: Int) extends Nat
// base * (x + 1) + b
private case class Shift(x: Nat, b: Int) extends Nat

private val cache: Array[Small] =
Array.tabulate(1024)(Small(_))

val zero: Nat = cache(0)
val one: Nat = cache(1)
val two_32: Nat = Shift(zero, 0)

// if the number is <= 0, return 0
def fromInt(i: Int): Nat =
if (i < 0) zero
else if (i < cache.length) cache(i)
else Small(i)

// if the number is <= 0, return 0
def fromLong(l: Long): Nat =
if (l < 0) zero
else wrapLong(l)

def wrapInt(i: Int): Nat =
if (0 <= i && i < cache.length) cache(i)
else Small(i)

def wrapLong(l: Long): Nat = {
val low = lowBits(l)
val high = highBits(l)
if (high == 0) wrapInt(low)
else {
Shift(wrapInt(high - 1), low)
}
}

// if b < 0 return 0
def fromBigInt(b: BigInt): Nat =
if (b <= 0) zero
else if (b < two_32_BigInt) {
fromLong(b.toLong)
}
else {
val low = b % two_32_BigInt
val high = b >> 32
Shift(fromBigInt(high - 1), lowBits(low.toLong))
}
}
154 changes: 154 additions & 0 deletions core/src/test/scala/org/bykn/bosatsu/NatTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
package org.bykn.bosatsu

import org.scalacheck.{Gen, Prop}

import Prop.forAll

class NatTest extends munit.ScalaCheckSuite {

override def scalaCheckTestParameters =
super.scalaCheckTestParameters
.withMinSuccessfulTests(if (Platform.isScalaJvm) 100000 else 100)
.withMaxDiscardRatio(10)

//override def scalaCheckInitialSeed = "YOFqcGzXOFtFVgRFxOmODi5100tVovDS3EPOv0Ihk4C="

lazy val genNat: Gen[Nat] = {
val recur = Gen.lzy(genNat)
Gen.frequency(
// make sure to exercise the cached table
(1, Gen.chooseNum(0, 1024, 1).map(Nat.fromInt(_))),
(5, Gen.chooseNum(0, Long.MaxValue, Int.MaxValue.toLong, Int.MaxValue.toLong + 1).map(Nat.fromLong(_))),
(1, Gen.zip(recur, recur).map { case (a, b) => a + b }),
(1, Gen.zip(recur, recur).map { case (a, b) => a * b })
)
}

test("constants are right") {
assertEquals(Nat.zero.maybeLong, Some(0L))
assertEquals(Nat.one.maybeLong, Some(1L))
assertEquals(Nat.two_32.maybeLong, Some(1L << 32))
}
property("fromInt/maybeLong identity") {
forAll { (i: Int) =>
val n = Nat.fromInt(i)
n.maybeLong match {
case None => fail(s"couldn't maybeLong $i")
case Some(l) => assert(i < 0 || (l.toInt == i))
}
}
}
property("fromLong/maybeLong identity") {
forAll { (i: Long) =>
val n = Nat.fromLong(i)
n.maybeLong match {
case Some(l) => assert(i < 0 || (l == i))
case None => fail(s"couldn't maybeLong $i")
}
}
}

def trunc(i: Int): Long = if (i < 0) 0L else i.toLong

property("toLong => lowBits is identity") {
val big = 0x80000000L

assertEquals(Nat.lowBits(big), Int.MinValue, s"${Nat.lowBits(big)}")
forAll(Gen.chooseNum(Int.MinValue, Int.MaxValue)) { (i: Int) =>
val l = Nat.toLong(i)
assertEquals(Nat.lowBits(Nat.toLong(i)), i, s"long = $l")
}
}

property("x + y homomorphism") {
forAll(genNat, genNat) { (ni, nj) =>
val nk = ni + nj
assertEquals(nk.toBigInt, ni.toBigInt + nj.toBigInt)
}
}

property("x * y = y * x") {
forAll(genNat, genNat) { (n1, n2) =>
assertEquals(n1 * n2, n2 * n1)
}
}

property("x + y = y + x") {
forAll(genNat, genNat) { (n1, n2) =>
assertEquals(n1 + n2, n2 + n1)
}
}

property("x * y homomorphism") {
forAll(genNat, genNat) { (ni, nj) =>
val nk = ni * nj
assertEquals(nk.toBigInt, ni.toBigInt * nj.toBigInt)
}
}

property("x.inc == x + 1") {
forAll(genNat) { n =>
val i = n.inc
val a = n + Nat.one
assertEquals(i.toBigInt, a.toBigInt)
}
}

property("x.dec == x - 1 when x > 0") {
forAll(genNat) { n =>
val i = n.dec.toBigInt
if (n == Nat.zero) assertEquals(i, BigInt(0))
else assertEquals(i, n.toBigInt - 1)
}
}

property("x.shift_32 == x.toBigInt * 2^32") {
val n1 = BigInt("8588888260151524380556863712485265508")
val shift = n1 << 32
assertEquals(Nat.fromBigInt(n1).shift_32.toBigInt, shift)

forAll(genNat) { n =>
val s = n.shift_32
val viaBigInt = n.toBigInt << 32
val viaTimes = n * Nat.two_32
assertEquals(s.toBigInt, viaBigInt, s"viaTimes = $viaTimes")
assertEquals(s, viaTimes)
}
}

property("Nat.fromBigInt/toBigInt") {
assertEquals(Nat.fromLong(Long.MaxValue).toBigInt, BigInt(Long.MaxValue))

forAll { (bi0: BigInt) =>
val bi = bi0.abs
val n = Nat.fromBigInt(bi)
val b2 = n.toBigInt
assertEquals(b2, bi)
}
}

property("x.inc.dec == x") {
forAll(genNat) { n =>
assertEquals(n.inc.dec, n)
}
}

property("x.dec.inc == x || x.isZero") {
forAll(genNat) { n =>
assert((n.dec.inc == n) || n.isZero)
}
}

property("if the value is > Long.MaxValue maybeLong = None") {
forAll(genNat) { n =>
val bi = n.toBigInt
val ml = n.maybeLong
assertEquals(ml.isEmpty, bi > Long.MaxValue)
}
}
property("the string repr matches toBigInt") {
forAll(genNat) { n =>
assertEquals(n.toString, n.toBigInt.toString)
}
}
}
Loading
Loading