In [3]:
from dataclasses import dataclass
from typing import Sequence
from typing import TypeAlias

# Abstract syntax of PCF terms and the values produced by interp.
# Field order matters: the interpreter destructures these classes with
# positional class patterns (e.g. `case App(t, u)`), so fields must not be reordered.

# Base class of all PCF terms.
@dataclass
class Term: pass

# Variable reference, identified by its name.
@dataclass
class Var(Term):
    name: str

# Environment: (variable, bound term) pairs. searchEnv scans it front to back and
# extendEnv prepends, so the first matching pair is the innermost binding.
Environment: TypeAlias = Sequence[tuple[Var,Term]]

# Function abstraction: fun x -> t.
@dataclass
class Fun(Term):
    x: Var
    t: Term

# Application of the function term l to the argument term r.
@dataclass
class App(Term):
    l: Term
    r: Term

# Base class of evaluation results.
@dataclass
class Value(Term): pass

# Integer literal / integer result.
@dataclass
class Num(Value):
    n: int

# Binary operation `left op right`; op is an operator symbol such as '+' or '<'.
@dataclass
class Op(Term):
    op: str
    left: Term
    right: Term

# ifz cond then thenTerm else elseTerm: thenTerm is taken when cond evaluates to Num(0).
@dataclass
class Ifz(Term):
    cond: Term
    thenTerm: Term
    elseTerm: Term

# fix x t: evaluates t with x bound to (a Thunk of) this whole Fix term.
@dataclass
class Fix(Term):
    x: Var
    t: Term

# let x = t in body
@dataclass
class Let(Term):
    x: Var
    t: Term
    body: Term

# Function value: parameter x, body t and the environment e captured when the Fun was evaluated.
@dataclass
class Closure(Value):
    x: Var
    t: Term
    e: Environment

# Suspended Fix term together with its environment; searchEnv evaluates it on lookup.
@dataclass
class Thunk(Term):
    term: Fix
    e: Environment

def searchEnv(x: Var, e: Environment) -> Value:
    """Return the value bound to x in e; the first (innermost) binding wins.

    A binding holding a Thunk (introduced by Fix) is forced by interpreting
    the suspended term in its saved environment.

    Raises:
        Exception: if x has no binding in e.
    """
    for name, bound in e:
        if name != x:
            continue
        if isinstance(bound, Thunk):
            return interp(bound.term, bound.e)
        return bound
    raise Exception("Unbound variable: " + str(x))

def extendEnv(x: Var, t: Term, e: Environment) -> Environment:
    """Return a new environment with x bound to t in front of e.

    e itself is not modified. Any sequence of (Var, Term) pairs is accepted
    (lists or tuples); the result is always a new list whose first pair is
    the new binding, so it shadows older bindings of x.
    """
    # Unpacking instead of `[...] + e` so non-list Sequences (e.g. tuples) work too.
    return [(x, t), *e]

def checkNumber(num: Value) -> int:
    """Return the integer carried by a Num value.

    Raises:
        Exception: if num is not a Num.
    """
    if isinstance(num, Num):
        return num.n
    raise Exception("Not a number: " + str(num))

def interp(t: Term, e: Environment) -> Value:
    match t:
        case Var(x): return searchEnv(t, e)
        case App(t, u):
            w = interp(u, e)
            match interp(t, e):
                case Closure(x, t2, e2):
                    return interp(t2, extendEnv(x, w, e2))
                case _: raise Exception("Illegal function: " + str(v))
        case Fun(x, t2): return Closure(x, t2, e)
        case Num(n): return t
        case Op(op, l, r):
            lv = checkNumber(interp(l, e))
            rv = checkNumber(interp(r, e))
            match op:
                case '+': return Num(lv + rv)
                case '-': return Num(lv - rv)
                case '*': return Num(lv * rv)
                case '/': return Num(lv / rv)
                case '%': return Num(lv % rv)
                case '<': return Num(1 if lv < rv else 0)
                case '=': return Num(1 if lv == rv else 0)
                case _: raise Exception("Unknown op: " + str(op))
        case Ifz(cond, t, u):
            c = interp(cond, e)
            match c:
                case Num(0): return interp(t, e)
                case Num(_): return interp(u, e)
                case _: raise Exception("Condition not a number: " + str(c))
        case Fix(x, t2):
            return interp(t2, extendEnv(x, Thunk(t, e), e))
        case Let(x, t, u):
            w = interp(t, e)
            return interp(u, extendEnv(x, w, e))

# Numeric expression
print(interp(Op('+', Num(1), Num(2)), []))
# Variable
print(interp(Var('x'), [(Var('x'), Num(1))]))
# Function and application
print(interp(App(Fun(Var('x'), Op('*', Var('x'), Var('x'))), Num(3)), []))
# Recursive function call using Fix (factorial of 10)
print(interp(
    App(Fix(Var('f'),
            Fun(Var('x'),
                Ifz(Var('x'),
                    Num(1),
                    Op('*', Var('x'),
                        App(Var('f'), Op('-', Var('x'), Num(1))))))),
        Num(10)),
    []))


Num(n=3)
Num(n=1)
Num(n=9)
Num(n=3628800)


In [7]:
# Recursive function: fixfun f x -> t. Inside t, f names the function itself and x its argument.
@dataclass
class FixFun(Term):
    f: Var
    x: Var
    t: Term

# Recursive closure. Applying it evaluates t with x bound to the argument and f bound to
# this closure, on top of the captured environment e. A plain Fun is represented with f = Var('').
@dataclass
class RecClosure(Value):
    f: Var
    x: Var
    t: Term
    e: Environment

def searchEnv(x: Var, e: Environment) -> Value:
    """Return the value bound to x in e; the first (innermost) binding wins.

    Bindings are returned as stored; this version performs no Thunk forcing.

    Raises:
        Exception: if x has no binding in e.
    """
    missing = object()
    found = next((bound for name, bound in e if name == x), missing)
    if found is missing:
        raise Exception("Unbound variable: " + str(x))
    return found

def interp(t: Term, e: Environment) -> Value:
    match t:
        case Var(x): return searchEnv(t, e)
        case App(t, u):
            w = interp(u, e)
            v = interp(t, e)
            match v:
                case RecClosure(f, x, t2, e2):
                    return interp(t2, extendEnv(x, w, extendEnv(f, v, e2)))
                case _: raise Exception("Illegal function: " + str(v))
        case Fun(x, t2): return RecClosure(Var(''), x, t2, e)
        case Num(n): return t
        case Op(op, l, r):
            lv = checkNumber(interp(l, e))
            rv = checkNumber(interp(r, e))
            match op:
                case '+': return Num(lv + rv)
                case '-': return Num(lv - rv)
                case '*': return Num(lv * rv)
                case '/': return Num(lv / rv)
                case '%': return Num(lv % rv)
                case '<': return Num(1 if lv < rv else 0)
                case '=': return Num(1 if lv == rv else 0)
                case _: raise Exception("Unknown op: " + str(op))
        case Ifz(cond, t, u):
            c = interp(cond, e)
            match c:
                case Num(0): return interp(t, e)
                case Num(_): return interp(u, e)
                case _: raise Exception("Condition not a number: " + str(c))
        case FixFun(f, x, t2):
            return RecClosure(f, x, t2, e)
        case Let(x, t, u):
            w = interp(t, e)
            return interp(u, extendEnv(x, w, e))

# Recursive function call using FixFun (factorial of 10)
print(interp(
    App(FixFun(Var('f'), Var('x'),
               Ifz(Var('x'),
                   Num(1),
                   Op('*', Var('x'),
                       App(Var('f'), Op('-', Var('x'), Num(1)))))),
        Num(10)),
    []))


Num(n=3628800)


In [14]:
from dataclasses import dataclass
from typing import Sequence
from typing import Union

# Instruction set of the PCF abstract machine run by PCFmachine. The machine state is
# an accumulator (acc), a stack, an environment (env) and the remaining code.
# Code-carrying fields are typed Sequence[Instruction]; the previous annotation
# referenced an undefined name `code` and failed at class creation on a fresh kernel.

# Base class of all machine instructions.
@dataclass
class Instruction: pass

# acc := Num(n)
@dataclass
class Ldi(Instruction):
    n: int

# Push acc onto the stack.
@dataclass
class Push(Instruction): pass

# Prepend acc to env (used for `let`).
@dataclass
class Extend(Instruction): pass

# acc := env[n], n counted from the front of env.
@dataclass
class Search(Instruction):
    n: int

# Save the current env on the stack.
@dataclass
class Pushenv(Instruction): pass

# Restore env from the top of the stack.
@dataclass
class Popenv(Instruction): pass

# acc := closure of body i over the current env.
@dataclass
class Mkclos(Instruction):
    i: Sequence[Instruction]

# Pop the argument and enter the closure held in acc.
@dataclass
class Apply(Instruction): pass

# Continue with i when acc is Num(0), otherwise with j.
@dataclass
class Test(Instruction):
    i: Sequence[Instruction]
    j: Sequence[Instruction]

# Arithmetic instructions: acc is the left operand, the value popped from the stack the right one.
@dataclass
class Add(Instruction): pass

@dataclass
class Sub(Instruction): pass

@dataclass
class Mult(Instruction): pass

@dataclass
class Div(Instruction): pass

# Closure value of the abstract machine: compiled body i and the environment e
# captured by Mkclos. (Annotations previously named the undefined `Code`/`value`.)
@dataclass
class MachineClosure(Value):
    i: Sequence[Instruction]
    e: Sequence[Value]

# The machine stack holds both values (pushed by Push) and saved environments (pushed by Pushenv).
Stack: TypeAlias = Sequence[Value | Sequence[Value]]

def trimValue(v): # for debug
    match v:
        case Num(n): return v
        case [item, *items]: return "[{}, ...]".format(trimValue(item))
        case []: return "[]"
        case _: return type(v).__name__

# stackは教科書と逆に末尾に破壊的にpushし、末尾からpopする。
# envは教科書と逆に先頭にコピーして追加する。Searchの引数は先頭からの位置を表す。
# codeは先頭から破壊的にpopし、先頭にコピーして結合する。
def PCFmachine(acc: Value, stack: Stack, env: Sequence[Value], code: Sequence[Instruction]) -> Value:
    while code:
        # print("acc={}, stack={}, env={}, code={}".format(acc, trimValue(stack), trimValue(env), trimValue(code)))
        insn = code.pop(0)
        match insn:
            case Mkclos(i):
                acc = MachineClosure(i, env)
            case Push():
                stack.append(acc)
            case Extend():
                env = [acc] + env
            case Search(n):
                acc = env[n]
            case Pushenv():
                stack.append(env)
            case Popenv():
                env = stack.pop()
            case Apply():
                w = stack.pop()
                match acc:
                    case MachineClosure(i, env):
                        env = [w, acc] + env
                        code = i + code
                    case _:
                        raise Exception("Not a closure: " + str(acc))
            case Ldi(n):
                acc = Num(n)
            case Add():
                m = stack.pop(); acc = Num(acc.n + m.n)
            case Sub():
                m = stack.pop(); acc = Num(acc.n - m.n)
            case Mult():
                m = stack.pop(); acc = Num(acc.n * m.n)
            case Div():
                m = stack.pop(); acc = Num(acc.n / m.n)
            case Test(i, j):
                match acc:
                    case Num(0):
                        code = i + code
                    case Num(_):
                        code = j + code
                    case _:
                        raise Exception("Not a number: " + str(acc))
    return acc

def compilePCF(t: term, env: Sequence[Var]) -> Sequence[Code]:
    match t:
        case Var(name):
            for (index, v) in enumerate(env):
                if v.name == name:
                    return [Search(index)]
            raise Exception("Unbound variable: " + name)
        case App(t, u):
            return [Pushenv()] + compilePCF(u, env) + [Push()] + compilePCF(t, env) + [Apply(), Popenv()]
        case Fun(x, t):
            return [Mkclos(compilePCF(t, [x, Var('')] + env))]
        case FixFun(f, x, t):
            return [Mkclos(compilePCF(t, [x, f] + env))]
        case Num(n):
            return [Ldi(n)]
        case Op(op, t, u):
            dic = {'+': Add(), '-': Sub(), '*': Mult(), '/': Div()}
            return compilePCF(u, env) + [Push()] + compilePCF(t, env) + [dic[op]]
        case Ifz(t, u, v):
            return compilePCF(t, env) + [Test(compilePCF(u, env), compilePCF(v, env))]
        case Let(x, t, u):
            return [Pushenv()] + compilePCF(t, env) + [Extend()] + compilePCF(u, [x] + env) + [Popenv()]
        case _:
            raise Exception("Illegal term: " + str(t))

import pprint

def compileTest(t: Term) -> None:
    """Pretty-print a term, its compiled code, and the result of running that code.

    The machine starts with acc = Num(0), an empty stack and an empty environment.
    """
    print("# Source term")
    pprint.pp(t)

    print("# Compiled code")
    instructions = compilePCF(t, [])
    pprint.pp(instructions)

    print("# Result")
    outcome = PCFmachine(Num(0), [], [], instructions)
    pprint.pp(outcome)

    print()


compileTest(App(Fun(Var('x'), Var('x')), Num(1)))  # => 1
compileTest(App(Fun(Var('x'), Op('*', Var('x'), Var('x'))), Num(3)))  # => 9
compileTest(
    App(App(FixFun(Var('f'), Var('x'),
                   Fun(Var('y'), Op('*', Var('x'), Var('y')))),
            Num(2)),
        Num(3)))  # => 6

# let f = fixfun g x -> ifz x then 1 else x * g (x - 1) in f 6   (computes 6!)
factcall = Let(
    Var('f'),
    FixFun(Var('g'), Var('x'),
           Ifz(Var('x'),
               Num(1),
               Op('*', Var('x'), App(Var('g'), Op('-', Var('x'), Num(1)))))),
    App(Var('f'), Num(6)),
)
compileTest(factcall)

# Source term
App(l=Fun(x=Var(name='x'), t=Var(name='x')), r=Num(n=1))
# Compiled code
[Pushenv(), Ldi(n=1), Push(), Mkclos(i=[Search(n=0)]), Apply(), Popenv()]
# Result
Num(n=1)

# Source term
App(l=Fun(x=Var(name='x'), t=Op(op='*', left=Var(name='x'), right=Var(name='x'))), r=Num(n=3))
# Compiled code
[Pushenv(),
 Ldi(n=3),
 Push(),
 Mkclos(i=[Search(n=0), Push(), Search(n=0), Mult()]),
 Apply(),
 Popenv()]
# Result
Num(n=9)

# Source term
App(l=App(l=FixFun(f=Var(name='f'), x=Var(name='x'), t=Fun(x=Var(name='y'), t=Op(op='*', left=Var(name='x'), right=Var(name='y')))), r=Num(n=2)), r=Num(n=3))
# Compiled code
[Pushenv(),
 Ldi(n=3),
 Push(),
 Pushenv(),
 Ldi(n=2),
 Push(),
 Mkclos(i=[Mkclos(i=[Search(n=0), Push(), Search(n=2), Mult()])]),
 Apply(),
 Popenv(),
 Apply(),
 Popenv()]
# Result
Num(n=6)

# Source term
Let(x=Var(name='f'), t=FixFun(f=Var(name='g'), x=Var(name='x'), t=Ifz(cond=Var(name='x'), thenTerm=Num(n=1), elseTerm=Op(op='*', left=Var(name='x'), right=App(l=Var(name='g')