Skip to content

Commit

Permalink
Status before second class session on first-class functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
parkerziegler committed Nov 4, 2022
1 parent 100dbcb commit 43bb105
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 43 deletions.
11 changes: 10 additions & 1 deletion firstclassfns.lisp
Expand Up @@ -6,4 +6,13 @@
(print (f (lambda (x) (+ x x))))

(define (f g) (g 2))
(let ((y 3)) (print (f (lambda (x) (+ x y)))))
(let ((y 3)) (print (f (lambda (x) (+ x y)))))

(define (range lo hi)
(if (< lo hi)
(pair lo (range (add1 lo) hi))
false))
(define (map f l)
(if (not l) l
(pair (f (left l)) (map f (right l)))))
(print (map (lambda (x) (+ x 1)) (range 0 2)))
90 changes: 75 additions & 15 deletions lib/ast.ml
@@ -1,4 +1,5 @@
open S_exp
open Util

type prim0 = ReadNum | Newline

Expand Down Expand Up @@ -57,10 +58,24 @@ type expr =
| Do of expr list
| Num of int
| Var of string
| Call of string * expr list
| Call of expr * expr list
| True
| False

type expr_lam =
| Prim0 of prim0
| Prim1 of prim1 * expr_lam
| Prim2 of prim2 * expr_lam * expr_lam
| Let of string * expr_lam * expr_lam
| If of expr_lam * expr_lam * expr_lam
| Do of expr_lam list
| Num of int
| Var of string
| Call of expr_lam * expr_lam list
| True
| False
| Lambda of string list * expr_lam

type defn = {name: string; args: string list; body: expr}

type program = {defns: defn list; body: expr}
Expand All @@ -69,7 +84,11 @@ let is_defn defns name = List.exists (fun d -> d.name = name) defns

let get_defn defns name = List.find (fun d -> d.name = name) defns

let rec expr_of_s_exp : s_exp -> expr = function
let is_sym e = match e with Sym _ -> true | _ -> false

let as_sym e = match e with Sym s -> s | _ -> raise Not_found

let rec expr_lam_of_s_exp : s_exp -> expr_lam = function
| Num x ->
Num x
| Sym "true" ->
Expand All @@ -79,26 +98,64 @@ let rec expr_of_s_exp : s_exp -> expr = function
| Sym var ->
Var var
| Lst [Sym "let"; Lst [Lst [Sym var; exp]]; body] ->
Let (var, expr_of_s_exp exp, expr_of_s_exp body)
Let (var, expr_lam_of_s_exp exp, expr_lam_of_s_exp body)
| Lst (Sym "do" :: exps) when List.length exps > 0 ->
Do (List.map expr_of_s_exp exps)
Do (List.map expr_lam_of_s_exp exps)
| Lst [Sym "if"; test_s; then_s; else_s] ->
If (expr_of_s_exp test_s, expr_of_s_exp then_s, expr_of_s_exp else_s)
If
( expr_lam_of_s_exp test_s
, expr_lam_of_s_exp then_s
, expr_lam_of_s_exp else_s )
| Lst [Sym "lambda"; Lst args; body] when List.for_all is_sym args ->
Lambda (List.map as_sym args, expr_lam_of_s_exp body)
| Lst [Sym prim] when Option.is_some (prim0_of_string prim) ->
Prim0 (Option.get (prim0_of_string prim))
| Lst [Sym prim; arg] when Option.is_some (prim1_of_string prim) ->
Prim1 (Option.get (prim1_of_string prim), expr_of_s_exp arg)
Prim1 (Option.get (prim1_of_string prim), expr_lam_of_s_exp arg)
| Lst [Sym prim; arg1; arg2] when Option.is_some (prim2_of_string prim) ->
Prim2
( Option.get (prim2_of_string prim)
, expr_of_s_exp arg1
, expr_of_s_exp arg2 )
| Lst (Sym f :: args) ->
Call (f, List.map expr_of_s_exp args)
, expr_lam_of_s_exp arg1
, expr_lam_of_s_exp arg2 )
| Lst (f :: args) ->
Call (expr_lam_of_s_exp f, List.map expr_lam_of_s_exp args)
| e ->
raise (BadSExpression e)

let rec expr_of_expr_lam (defns : defn list ref) : expr_lam -> expr = function
| Num x ->
Num x
| Var s ->
Var s
| True ->
True
| False ->
False
| If (test_exp, then_exp, else_exp) ->
If
( expr_of_expr_lam defns test_exp
, expr_of_expr_lam defns then_exp
, expr_of_expr_lam defns else_exp )
| Let (var, exp, body) ->
Let (var, expr_of_expr_lam defns exp, expr_of_expr_lam defns body)
| Prim0 p ->
Prim0 p
| Prim1 (p, e) ->
Prim1 (p, expr_of_expr_lam defns e)
| Prim2 (p, e1, e2) ->
Prim2 (p, expr_of_expr_lam defns e1, expr_of_expr_lam defns e2)
| Do exps ->
Do (List.map (expr_of_expr_lam defns) exps)
| Call (exp, args) ->
Call (expr_of_expr_lam defns exp, List.map (expr_of_expr_lam defns) args)
| Lambda (args, body) ->
let name = gensym "_lambda" in
let body = expr_of_expr_lam defns body in
defns := {name; args; body} :: !defns ;
Var name

let program_of_s_exps (exps : s_exp list) : program =
let defns = ref [] in
let rec get_args args =
match args with
| Sym v :: args ->
Expand All @@ -111,19 +168,22 @@ let program_of_s_exps (exps : s_exp list) : program =
let get_defn = function
| Lst [Sym "define"; Lst (Sym name :: args); body] ->
let args = get_args args in
{name; args; body= expr_of_s_exp body}
{name; args; body= body |> expr_lam_of_s_exp |> expr_of_expr_lam defns}
| e ->
raise (BadSExpression e)
in
let rec go exps defns =
let rec go exps =
match exps with
| [e] ->
{defns= List.rev defns; body= expr_of_s_exp e}
let body = e |> expr_lam_of_s_exp |> expr_of_expr_lam defns in
{defns= List.rev !defns; body}
| d :: exps ->
go exps (get_defn d :: defns)
let defn = get_defn d in
defns := defn :: !defns ;
go exps
| _ ->
raise (BadSExpression (Sym "empty"))
in
go exps []
go exps

exception BadExpression of expr
50 changes: 33 additions & 17 deletions lib/compile.ml
Expand Up @@ -14,6 +14,8 @@ let bool_tag = 0b0011111
let heap_mask = 0b111
let pair_tag = 0b010

let fn_tag = 0b110

let operand_of_bool (b:bool) : operand =
Imm (((if b then 1 else 0) lsl bool_shift) lor bool_tag)
let operand_of_num (x:int) : operand =
Expand Down Expand Up @@ -50,35 +52,41 @@ let ensure_pair (op: operand) : directive list =
Jnz "error"
]

let ensure_fn (op: operand) : directive list =
[
Mov (Reg R8, op);
And (Reg R8, Imm heap_mask);
Cmp (Reg R8, Imm fn_tag);
Jnz "error"
]

let stack_address (stack_index : int) = MemOffset (Reg Rsp, Imm stack_index)

let align_stack_index (stack_index : int) : int =
if stack_index mod 16 = -8 then stack_index else stack_index - 8

let rec compile_exp (defns : defn list) tab (stack_index : int) (program:expr) (is_tail : bool) : directive list =
match program with
| Call (f, args) when is_defn defns f && not is_tail ->
let defn = get_defn defns f in
if List.length args = List.length defn.args then
| Call (f, args) when not is_tail ->
let stack_base = align_stack_index (stack_index + 8) in
let compiled_args =
args
|> List.mapi (fun i arg ->
compile_exp defns tab (stack_base - ((i+2) * 8)) arg false
@ [Mov (stack_address (stack_base - ((i+2) * 8)), Reg Rax)]
compile_exp defns tab (stack_base - ((i + 2) * 8)) arg false
@ [Mov (stack_address (stack_base - ((i + 2) * 8)), Reg Rax)]
)
|> List.concat in
compiled_args
|> List.concat
in
compiled_args
@ compile_exp defns tab (stack_base - ((List.length args + 2) * 8)) f false
@ ensure_fn (Reg Rax)
@ [Sub (Reg Rax, Imm fn_tag)]
@ [
Add (Reg Rsp, Imm stack_base);
Call (defn_label f);
ComputedCall (Reg Rax);
Sub (Reg Rsp, Imm stack_base);
]
else
raise (BadExpression program)
| Call (f, args) when is_defn defns f && is_tail ->
let defn = get_defn defns f in
if List.length args = List.length defn.args then
| Call (f, args) when is_tail ->
let compiled_args =
args
|> List.mapi (fun i arg ->
Expand All @@ -94,10 +102,11 @@ let rec compile_exp (defns : defn list) tab (stack_index : int) (program:expr) (
)
|> List.concat in
compiled_args
@ compile_exp defns tab (stack_index - (8 * List.length args)) f false
@ ensure_fn (Reg Rax)
@ [Sub (Reg Rax, Imm fn_tag)]
@ moved_args
@ [Jmp (defn_label f) ]
else
raise (BadExpression program)
@ [ComputedJmp (Reg Rax) ]
| Call _ ->
raise (BadExpression program)
| Num n ->
Expand Down Expand Up @@ -163,6 +172,9 @@ let rec compile_exp (defns : defn list) tab (stack_index : int) (program:expr) (
@ compile_exp defns (Symtab.add var stack_index tab) (stack_index - 8) body is_tail
| Var var when Symtab.mem var tab ->
[Mov (Reg Rax, stack_address (Symtab.find var tab))]
| Var var when is_defn defns var ->
[ LeaLabel (Reg Rax, defn_label var)
; Or (Reg Rax, Imm fn_tag)]
| Var _ ->
raise (BadExpression program)
| Prim1 (Not, arg) ->
Expand Down Expand Up @@ -245,7 +257,7 @@ let compile_defn defns defn =
|> List.mapi (fun i arg -> (arg, -8 * (i + 1)))
|> Symtab.of_list
in
[Label (defn_label defn.name)]
[Align 8; Label (defn_label defn.name)]
@ compile_exp defns ftab (-8 * (List.length defn.args + 1)) defn.body true
@ [Ret]

Expand All @@ -271,6 +283,8 @@ let compile_to_file (program: string): unit =
compile_to_file program;
let format = (if Asm.macos then "macho64" else "elf64") in
ignore (Unix.system ("nasm program.s -f " ^ format ^ " -o program.o"));
(* Add -fno-pie -no-pie flags on Linux to make first-class functions work.
See: https://stackoverflow.com/questions/43367427/32-bit-absolute-addresses-no-longer-allowed-in-x86-64-linux *)
ignore (Unix.system "gcc program.o runtime.c -o program");
let inp = Unix.open_process_in "./program" in
let r = input_line inp in
Expand All @@ -279,6 +293,8 @@ let compile_to_file (program: string): unit =
let compile_and_run_io (program : string) (input : string) : string =
compile_to_file program ;
ignore (Unix.system "nasm program.s -f macho64 -o program.o") ;
(* Add -fno-pie -no-pie flags on Linux to make first-class functions work.
See: https://stackoverflow.com/questions/43367427/32-bit-absolute-addresses-no-longer-allowed-in-x86-64-linux *)
ignore (Unix.system "gcc program.o runtime.c -o program") ;
let inp, outp = Unix.open_process "./program" in
output_string outp input ;
Expand Down
30 changes: 20 additions & 10 deletions lib/interp.ml
Expand Up @@ -2,29 +2,37 @@ open S_exp
open Ast
open Util

type value = Number of int | Boolean of bool | Pair of (value * value)
type value =
| Number of int
| Boolean of bool
| Pair of (value * value)
| Function of string

let rec string_of_value (v: value) : string =
match v with
| Number n -> string_of_int n
| Boolean b -> if b then "true" else "false"
| Pair (v1, v2) -> Printf.sprintf "(pair %s %s)" (string_of_value v1) (string_of_value v2)
| Function _ -> "<function>"

let input_channel = ref stdin
let output_channel = ref stdout

let rec interp_exp (defns : defn list) (env: value symtab) (exp:expr): value =
match exp with
| Call (f, args) when is_defn defns f ->
let defn = get_defn defns f in
| Call (f, args) -> (
let vals = List.map (interp_exp defns env) args in
let fenv = (List.combine defn.args vals) |> Symtab.of_list in
if List.length args = List.length defn.args then
interp_exp defns fenv defn.body
else
raise (BadExpression exp)
| Call _ ->
raise (BadExpression exp)
let fv = interp_exp defns env f in
match fv with
| Function name ->
let defn = get_defn defns name in
let fenv = (List.combine defn.args vals) |> Symtab.of_list in
if List.length args = List.length defn.args then
interp_exp defns fenv defn.body
else
raise (BadExpression exp)
| _ -> raise (BadExpression exp)
)
| Num n -> Number n
| True -> Boolean true
| False -> Boolean false
Expand Down Expand Up @@ -54,6 +62,8 @@ let rec interp_exp (defns : defn list) (env: value symtab) (exp:expr): value =
)
| Var var when Symtab.mem var env ->
Symtab.find var env
| Var var when is_defn defns var ->
Function var
| Var _ ->
raise (BadExpression exp)
| Let (var, e, body) ->
Expand Down

0 comments on commit 43bb105

Please sign in to comment.