-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BinInt, implement divmod and just for fun, add Nat.scala (#1194)
* Add BinInt, implement divmod and just for fun, add Nat.scala * improve Nat.scala test coverage * improve cmp_BinInt
- Loading branch information
Showing
6 changed files
with
583 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,219 @@ | ||
package org.bykn.bosatsu | ||
import org.bykn.bosatsu.Nat.Shift | ||
import org.bykn.bosatsu.Nat.Small | ||
|
||
import Nat._ | ||
|
||
sealed abstract class Nat { lhs => | ||
def toBigInt: BigInt = | ||
lhs match { | ||
case Small(asInt) => | ||
BigInt(toLong(asInt)) | ||
case Shift(x, b) => | ||
// m(x + 1) + b | ||
two_32_BigInt * (x.inc.toBigInt) + toLong(b) | ||
} | ||
|
||
override def toString = toBigInt.toString | ||
|
||
def maybeLong: Option[Long] = | ||
lhs match { | ||
case Small(asInt) => Some(toLong(asInt)) | ||
case Shift(x, b) => | ||
// 2^32 * (x + 1) + b | ||
x.maybeLong match { | ||
case Some(v0) => | ||
if (v0 < Int.MaxValue) { | ||
val value = v0 + 1 | ||
Some((value << 32) + toLong(b)) | ||
} | ||
else None | ||
case None => None | ||
} | ||
} | ||
|
||
|
||
def inc: Nat = | ||
lhs match { | ||
case Small(asInt) => | ||
val next = asInt + 1 | ||
if (next != 0) wrapInt(next) | ||
else two_32 | ||
case Shift(x, b) => | ||
// m(x + 1) + b | ||
val b1 = b + 1 | ||
if (b1 != 0) Shift(x, b1) | ||
else { | ||
// m(x + 1) + m | ||
// m(x + 1 + 1) | ||
Shift(x.inc, 0) | ||
} | ||
} | ||
|
||
def dec: Nat = | ||
lhs match { | ||
case Small(asInt) => | ||
if (asInt != 0) wrapInt(asInt - 1) | ||
else zero | ||
case Shift(x, b) => | ||
if (b != 0) Shift(x, b - 1) | ||
else { | ||
if (x.isZero) wrapInt(-1) | ||
else Shift(x.dec, -1) | ||
} | ||
} | ||
|
||
def isZero: Boolean = | ||
lhs match { | ||
case Small(0) => true | ||
case _ => false | ||
} | ||
|
||
// this * 2^32 | ||
def shift_32: Nat = | ||
lhs match { | ||
case Small(asInt) => | ||
if (asInt == 0) zero | ||
else { | ||
// m*(b - 1 + 1) | ||
Shift(wrapInt(asInt - 1), 0) | ||
} | ||
case Shift(x, b) => | ||
// m*(m*(x + 1) + b) | ||
// m*((m(x + 1) + (b - 1)) + 1) + 0 | ||
if (b != 0) { | ||
val b1 = b - 1 | ||
Shift(Shift(x, b1), 0) | ||
} | ||
else if (x.isZero) { | ||
// m * (m(0 + 1) + 0) | ||
// m * ((m - 1) + 1) + 0 | ||
val inner = wrapLong(0xFFFFFFFFL) | ||
Shift(inner, 0) | ||
} | ||
else { | ||
// x > 0 | ||
//m(m(x + 1)) = | ||
//m(m(x + 1) - 1 + 1) + 0 | ||
//m((m(x + 1) - 1) + 1) + 0 | ||
val inner = x.inc.shift_32.dec | ||
Shift(inner, 0) | ||
} | ||
} | ||
|
||
def +(rhs: Nat): Nat = { | ||
lhs match { | ||
case Small(l) => | ||
rhs match { | ||
case Small(r) => | ||
wrapLong(toLong(l) + toLong(r)) | ||
case Shift(x, b) => | ||
val res = toLong(l) + toLong(b) | ||
val low = lowBits(res) | ||
val high = highBits(res) | ||
if (high == 0) Shift(x, low) | ||
else { | ||
Shift(x + wrapInt(high), low) | ||
} | ||
} | ||
|
||
case Shift(x, b) => | ||
rhs match { | ||
case Small(l) => | ||
val res = toLong(l) + toLong(b) | ||
val low = lowBits(res) | ||
val high = highBits(res) | ||
if (high == 0) Shift(x, low) | ||
else { | ||
Shift(x + wrapInt(high), low) | ||
} | ||
case Shift(x1, b1) => | ||
// (m(x + 1) + b) + (m (x1 + 1) + b1) = m(x + x1 + 1 + 1) + (b + b1) | ||
val res = toLong(b) + toLong(b1) | ||
val low = lowBits(res) | ||
val high = highBits(res) | ||
val xs = (x + x1).inc | ||
if (high == 0) Shift(xs, low) | ||
else { | ||
Shift(xs + wrapInt(high), low) | ||
} | ||
} | ||
} | ||
} | ||
def *(rhs: Nat): Nat = { | ||
lhs match { | ||
case Small(l) => | ||
if (l == 0) zero | ||
else (rhs match { | ||
case Small(r) => | ||
// can't overflow Long | ||
wrapLong(toLong(l) * toLong(r)) | ||
case Shift(x, b) => | ||
// (m(x + 1) + b) * l | ||
(x.inc * lhs).shift_32 + wrapInt(b) * lhs | ||
}) | ||
|
||
case Shift(x, b) => | ||
// (m(x + 1) + b) * rhs | ||
// (x + 1)*rhs*m + b * rhs | ||
(rhs * x.inc).shift_32 + rhs * wrapInt(b) | ||
} | ||
} | ||
} | ||
|
||
object Nat { | ||
private val two_32_BigInt = BigInt(1L << 32) | ||
|
||
private val intMaskLow: Long = 0xFFFFFFFFL | ||
def toLong(i: Int): Long = i.toLong & intMaskLow | ||
def lowBits(l: Long): Int = l.toInt | ||
private def highBits(l: Long): Int = lowBits(l >> 32) | ||
|
||
// all numbers from [0, 2^{32} - 1] | ||
private case class Small(asInt: Int) extends Nat | ||
// base * (x + 1) + b | ||
private case class Shift(x: Nat, b: Int) extends Nat | ||
|
||
private val cache: Array[Small] = | ||
Array.tabulate(1024)(Small(_)) | ||
|
||
val zero: Nat = cache(0) | ||
val one: Nat = cache(1) | ||
val two_32: Nat = Shift(zero, 0) | ||
|
||
// if the number is <= 0, return 0 | ||
def fromInt(i: Int): Nat = | ||
if (i < 0) zero | ||
else if (i < cache.length) cache(i) | ||
else Small(i) | ||
|
||
// if the number is <= 0, return 0 | ||
def fromLong(l: Long): Nat = | ||
if (l < 0) zero | ||
else wrapLong(l) | ||
|
||
def wrapInt(i: Int): Nat = | ||
if (0 <= i && i < cache.length) cache(i) | ||
else Small(i) | ||
|
||
def wrapLong(l: Long): Nat = { | ||
val low = lowBits(l) | ||
val high = highBits(l) | ||
if (high == 0) wrapInt(low) | ||
else { | ||
Shift(wrapInt(high - 1), low) | ||
} | ||
} | ||
|
||
// if b < 0 return 0 | ||
def fromBigInt(b: BigInt): Nat = | ||
if (b <= 0) zero | ||
else if (b < two_32_BigInt) { | ||
fromLong(b.toLong) | ||
} | ||
else { | ||
val low = b % two_32_BigInt | ||
val high = b >> 32 | ||
Shift(fromBigInt(high - 1), lowBits(low.toLong)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
package org.bykn.bosatsu | ||
|
||
import org.scalacheck.{Gen, Prop} | ||
|
||
import Prop.forAll | ||
|
||
class NatTest extends munit.ScalaCheckSuite { | ||
|
||
override def scalaCheckTestParameters = | ||
super.scalaCheckTestParameters | ||
.withMinSuccessfulTests(if (Platform.isScalaJvm) 100000 else 100) | ||
.withMaxDiscardRatio(10) | ||
|
||
//override def scalaCheckInitialSeed = "YOFqcGzXOFtFVgRFxOmODi5100tVovDS3EPOv0Ihk4C=" | ||
|
||
lazy val genNat: Gen[Nat] = { | ||
val recur = Gen.lzy(genNat) | ||
Gen.frequency( | ||
// make sure to exercise the cached table | ||
(1, Gen.chooseNum(0, 1024, 1).map(Nat.fromInt(_))), | ||
(5, Gen.chooseNum(0, Long.MaxValue, Int.MaxValue.toLong, Int.MaxValue.toLong + 1).map(Nat.fromLong(_))), | ||
(1, Gen.zip(recur, recur).map { case (a, b) => a + b }), | ||
(1, Gen.zip(recur, recur).map { case (a, b) => a * b }) | ||
) | ||
} | ||
|
||
test("constants are right") { | ||
assertEquals(Nat.zero.maybeLong, Some(0L)) | ||
assertEquals(Nat.one.maybeLong, Some(1L)) | ||
assertEquals(Nat.two_32.maybeLong, Some(1L << 32)) | ||
} | ||
property("fromInt/maybeLong identity") { | ||
forAll { (i: Int) => | ||
val n = Nat.fromInt(i) | ||
n.maybeLong match { | ||
case None => fail(s"couldn't maybeLong $i") | ||
case Some(l) => assert(i < 0 || (l.toInt == i)) | ||
} | ||
} | ||
} | ||
property("fromLong/maybeLong identity") { | ||
forAll { (i: Long) => | ||
val n = Nat.fromLong(i) | ||
n.maybeLong match { | ||
case Some(l) => assert(i < 0 || (l == i)) | ||
case None => fail(s"couldn't maybeLong $i") | ||
} | ||
} | ||
} | ||
|
||
def trunc(i: Int): Long = if (i < 0) 0L else i.toLong | ||
|
||
property("toLong => lowBits is identity") { | ||
val big = 0x80000000L | ||
|
||
assertEquals(Nat.lowBits(big), Int.MinValue, s"${Nat.lowBits(big)}") | ||
forAll(Gen.chooseNum(Int.MinValue, Int.MaxValue)) { (i: Int) => | ||
val l = Nat.toLong(i) | ||
assertEquals(Nat.lowBits(Nat.toLong(i)), i, s"long = $l") | ||
} | ||
} | ||
|
||
property("x + y homomorphism") { | ||
forAll(genNat, genNat) { (ni, nj) => | ||
val nk = ni + nj | ||
assertEquals(nk.toBigInt, ni.toBigInt + nj.toBigInt) | ||
} | ||
} | ||
|
||
property("x * y = y * x") { | ||
forAll(genNat, genNat) { (n1, n2) => | ||
assertEquals(n1 * n2, n2 * n1) | ||
} | ||
} | ||
|
||
property("x + y = y + x") { | ||
forAll(genNat, genNat) { (n1, n2) => | ||
assertEquals(n1 + n2, n2 + n1) | ||
} | ||
} | ||
|
||
property("x * y homomorphism") { | ||
forAll(genNat, genNat) { (ni, nj) => | ||
val nk = ni * nj | ||
assertEquals(nk.toBigInt, ni.toBigInt * nj.toBigInt) | ||
} | ||
} | ||
|
||
property("x.inc == x + 1") { | ||
forAll(genNat) { n => | ||
val i = n.inc | ||
val a = n + Nat.one | ||
assertEquals(i.toBigInt, a.toBigInt) | ||
} | ||
} | ||
|
||
property("x.dec == x - 1 when x > 0") { | ||
forAll(genNat) { n => | ||
val i = n.dec.toBigInt | ||
if (n == Nat.zero) assertEquals(i, BigInt(0)) | ||
else assertEquals(i, n.toBigInt - 1) | ||
} | ||
} | ||
|
||
property("x.shift_32 == x.toBigInt * 2^32") { | ||
val n1 = BigInt("8588888260151524380556863712485265508") | ||
val shift = n1 << 32 | ||
assertEquals(Nat.fromBigInt(n1).shift_32.toBigInt, shift) | ||
|
||
forAll(genNat) { n => | ||
val s = n.shift_32 | ||
val viaBigInt = n.toBigInt << 32 | ||
val viaTimes = n * Nat.two_32 | ||
assertEquals(s.toBigInt, viaBigInt, s"viaTimes = $viaTimes") | ||
assertEquals(s, viaTimes) | ||
} | ||
} | ||
|
||
property("Nat.fromBigInt/toBigInt") { | ||
assertEquals(Nat.fromLong(Long.MaxValue).toBigInt, BigInt(Long.MaxValue)) | ||
|
||
forAll { (bi0: BigInt) => | ||
val bi = bi0.abs | ||
val n = Nat.fromBigInt(bi) | ||
val b2 = n.toBigInt | ||
assertEquals(b2, bi) | ||
} | ||
} | ||
|
||
property("x.inc.dec == x") { | ||
forAll(genNat) { n => | ||
assertEquals(n.inc.dec, n) | ||
} | ||
} | ||
|
||
property("x.dec.inc == x || x.isZero") { | ||
forAll(genNat) { n => | ||
assert((n.dec.inc == n) || n.isZero) | ||
} | ||
} | ||
|
||
property("if the value is > Long.MaxValue maybeLong = None") { | ||
forAll(genNat) { n => | ||
val bi = n.toBigInt | ||
val ml = n.maybeLong | ||
assertEquals(ml.isEmpty, bi > Long.MaxValue) | ||
} | ||
} | ||
property("the string repr matches toBigInt") { | ||
forAll(genNat) { n => | ||
assertEquals(n.toString, n.toBigInt.toString) | ||
} | ||
} | ||
} |
Oops, something went wrong.