-
Notifications
You must be signed in to change notification settings - Fork 11
/
Nat.bosatsu
164 lines (139 loc) · 4.4 KB
/
Nat.bosatsu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
package Bosatsu/Nat
from Bosatsu/Predef import add as operator +, times as operator *
export Nat(), times2, add, sub_Nat, mult, exp, to_Int, to_Nat, is_even, div2, cmp_Nat
# Unary (Peano-style) natural numbers: a Nat is either Zero or
# the successor of a smaller Nat.
# This shape is convenient when code must visit every value
# on the way down to Zero.
enum Nat:
  Zero
  Succ(prev: Nat)
# Compare two Nats by peeling one Succ off each side at a time.
# Returns LT when `a` runs out first, GT when `b` runs out first,
# and EQ when both reach Zero together.
def cmp_Nat(a: Nat, b: Nat) -> Comparison:
  recur a:
    case Zero: EQ if b matches Zero else LT
    case Succ(prev_a):
      match b:
        case Zero: GT
        case Succ(prev_b): cmp_Nat(prev_a, prev_b)
# Double a Nat. The loop runs once per Succ in the input,
# so this is an O(n) operation.
def times2(n: Nat) -> Nat:
  def loop(rest: Nat, doubled: Nat):
    recur rest:
      case Zero: doubled
      case Succ(smaller):
        # 2*(k + 1) = 2*k + 2: every Succ removed from the input
        # contributes two Succs to the result
        loop(smaller, Succ(Succ(doubled)))
  loop(n, Zero)
# Add two Nats by moving every Succ of n1 onto n2.
# When n2 is Zero the answer is n1 itself, so the loop is skipped.
def add(n1: Nat, n2: Nat) -> Nat:
  def loop(remaining: Nat, total: Nat):
    recur remaining:
      case Zero: total
      case Succ(rest): loop(rest, Succ(total))

  match n2:
    case Zero: n1
    case _: loop(n1, n2)
# Subtract n2 from n1, stopping at Zero: if n1 runs out before n2
# does, the result is Zero rather than a negative value.
def sub_Nat(n1: Nat, n2: Nat) -> Nat:
  recur n2:
    case Zero: n1
    case Succ(rest2):
      # (a + 1) - (b + 1) == a - b, so drop one Succ from each side
      match n1:
        case Succ(rest1): sub_Nat(rest1, rest2)
        case Zero: Zero
# Multiply two Nats.
def mult(n1: Nat, n2: Nat) -> Nat:
  # loop(a, b, acc) returns a * b + acc, using the identity
  #   (a + 1) * (b + 1) + acc == a * b + (a + b + 1 + acc)
  # so each step removes one Succ from both a and b.
  def loop(a: Nat, b: Nat, acc: Nat):
    recur a:
      case Zero: acc
      case Succ(prev_a):
        match b:
          case Zero: acc
          case Succ(prev_b):
            loop(prev_a, prev_b, Succ(add(add(prev_a, prev_b), acc)))

  # Each step calls add(a, b) while both sides shrink by one.
  # add is cheaper when its left argument is the smaller one,
  # so put the smaller operand on the left before starting.
  match cmp_Nat(n1, n2):
    case GT: loop(n2, n1, Zero)
    case _: loop(n1, n2, Zero)
# True when n is even. Starts from True (Zero is even) and flips
# the answer once for every Succ in n.
def is_even(n: Nat) -> Bool:
  def loop(rest: Nat, parity: Bool) -> Bool:
    recur rest:
      case Zero: parity
      case Succ(smaller):
        flipped = False if parity else True
        loop(smaller, flipped)
  loop(n, True)
# Halve a Nat, rounding down.
def div2(n: Nat) -> Nat:
  # rest_is_even tracks the parity of `rest`; it starts as
  # is_even(n) and flips each time one Succ is removed.
  def loop(rest: Nat, half: Nat, rest_is_even):
    recur rest:
      case Zero: half
      case Succ(smaller):
        # (k + 1) / 2 = k/2 + 1 when k + 1 is even, else k/2,
        # so only the even steps add to the result
        match rest_is_even:
          case True: loop(smaller, Succ(half), False)
          case False: loop(smaller, half, True)
  loop(n, Zero, is_even(n))
# The Nat for 1, used by exp below.
one = Succ(Zero)

# Raise base to power.
# 0^0 is one and 0^k is Zero for any other k; 1^k is always one.
# Larger bases multiply `power` times starting from one.
def exp(base: Nat, power: Nat) -> Nat:
  match base:
    case Zero:
      match power:
        case Zero: one
        case _: Zero
    case Succ(Zero): one
    case big_base:
      def loop(remaining, product):
        recur remaining:
          case Zero: product
          case Succ(fewer):
            # b^(k + 1) = (b^k) * b
            loop(fewer, mult(product, big_base))
      loop(power, one)
# Convert a Nat to an Int by counting its Succ constructors.
def to_Int(n: Nat) -> Int:
  def loop(count: Int, rest: Nat):
    recur rest:
      case Zero: count
      case Succ(smaller): loop(count + 1, smaller)
  loop(0, n)
# Convert an Int to a Nat by running int_loop from i with a step
# that subtracts one and wraps the result in one more Succ.
# The tests in this file expect negative inputs to give Zero.
def to_Nat(i: Int) -> Nat:
  def step(remaining, built):
    (remaining.sub(1), Succ(built))
  int_loop(i, Zero, step)
################
# Test code below
################
# Small Nat fixtures for the tests: n1 through n5 encode 1 through 5,
# each one built as the Succ of the previous fixture.
n1 = Succ(Zero)
n2 = Succ(n1)
n3 = Succ(n2)
n4 = Succ(n3)
n5 = Succ(n4)
# Int equality for the tests: True exactly when cmp_Int reports EQ.
def operator ==(i0: Int, i1: Int):
  match cmp_Int(i0, i1):
    case EQ: True
    case _: False
# Test that Nat addition agrees with Int addition: converting
# add(n1, n2) to an Int must equal the sum of the converted inputs.
def addLaw(n1: Nat, n2: Nat, label: String) -> Test:
  nat_sum = to_Int(add(n1, n2))
  int_sum = to_Int(n1) + to_Int(n2)
  Assertion(nat_sum == int_sum, label)
# Test that Nat multiplication agrees with Int multiplication:
# converting mult(n1, n2) to an Int must equal the product of the
# converted inputs.
def multLaw(n1: Nat, n2: Nat, label: String) -> Test:
  nat_product = to_Int(mult(n1, n2))
  int_product = to_Int(n1) * to_Int(n2)
  Assertion(nat_product == int_product, label)
# Test that an Int survives a round trip through Nat unchanged.
def from_to_law(i: Int, message: String) -> Test:
  round_trip = to_Int(to_Nat(i))
  Assertion(round_trip == i, message)
# Conversion tests: negative Ints are expected to become Zero,
# and non-negative Ints must round-trip through Nat unchanged.
from_to_suite = TestSuite(
  "to_Nat/to_Int tests",
  [
    Assertion(to_Int(to_Nat(-1)) == 0, "-1 -> 0"),
    Assertion(to_Int(to_Nat(-42)) == 0, "-42 -> 0"),
    from_to_law(0, "0"),
    from_to_law(1, "1"),
    from_to_law(10, "10"),
    from_to_law(42, "42"),
  ])
# Top-level suite: addition and multiplication laws on small inputs,
# the conversion suite, and spot checks for exp and div2.
tests = TestSuite("Nat tests", [
    addLaw(Zero, Zero, "0 + 0"),
    addLaw(Zero, n1, "0 + 1"),
    addLaw(n1, Zero, "1 + 0"),
    addLaw(n1, n2, "1 + 2"),
    addLaw(n2, n1, "2 + 1"),

    multLaw(Zero, Zero, "0 * 0"),
    multLaw(Zero, n1, "0 * 1"),
    multLaw(n1, Zero, "1 * 0"),
    multLaw(n1, n2, "1 * 2"),
    multLaw(n2, n1, "2 * 1"),

    from_to_suite,

    Assertion(to_Int(exp(n2, n5)) matches 32, "exp(2, 5) == 32"),

    Assertion(to_Int(div2(n1)) matches 0, "1 div2 == 0"),
    Assertion(to_Int(div2(n2)) matches 1, "2 div2 == 1"),
    Assertion(to_Int(div2(n3)) matches 1, "3 div2 == 1"),
    Assertion(to_Int(div2(n4)) matches 2, "4 div2 == 2"),
  ])