(* Residue from the hosting web page this file was copied from; it is not
   part of the program:
   Permalink
   Find file
   Fetching contributors…
   Cannot retrieve contributors at this time
   424 lines (336 sloc) 11.2 KB *)
(*
* zooly.ml
*
* This file is part of Zooly.
*
* Copyright (c) 2012 Jesse Haber-Kucharsky
* For licence information, see the included file
* LICENSE.txt
*)
open Batteries_uni ;;
open Result ;;
open Parser ;;
(* The structure of the allowed expressions. *)
(* Binary arithmetic operators.  Each constructor corresponds to one
   operator token recognised by the operator parsers below and to one
   Float operation applied in [evaluate]. *)
type operator =
| Add (* '+', evaluated with Float.add *)
| Subtract (* '-', evaluated with Float.sub *)
| Multiply (* '*', evaluated with Float.mul *)
| Divide (* '/', evaluated with Float.div *)
| Pow (* '^', evaluated with Float.pow *)
;;
(* Abstract syntax tree produced by the parsers and consumed by [evaluate]. *)
type expr =
| BinaryExpr of expr * operator * expr (* left operand, operator, right operand *)
| Variable of string (* a name looked up in the environment's variable table *)
| FunctionInvocation of string * expr list (* function name and its argument expressions *)
| Number of float (* numeric literal; number_parser has already applied any sign *)
| Assignment of string * expr (* let name = expr; in this file only line_parser builds it *)
;;
(* Specific Parsers *)
class variable_parser () =
  object
    inherit [expr] parseable
    (* Reads one token with string_parser and wraps it as a [Variable].
       On failure the underlying message is replaced by a fixed one. *)
    method parse pi =
      let name_p = new string_parser () in
      match name_p#parse pi with
      | Bad _ -> Bad "Expected a variable"
      | Ok (ParseResult (name, consumed)) ->
        Ok (ParseResult (Variable name, consumed))
  end ;;
class additive_operator_parser () =
  object
    inherit [operator] parseable
    (* Recognises one additive operator token: '+' becomes [Add] and '-'
       becomes [Subtract].  Any other outcome, including the unreachable
       case of an unexpected matched string, yields one fixed error. *)
    method parse pi =
      let plus_or_minus =
        new alternate_parser
          (new string_const_parser "+")
          (new string_const_parser "-") in
      let expected = "Expected '+' or '-'" in
      match plus_or_minus#parse pi with
      | Ok (ParseResult ("+", consumed)) -> Ok (ParseResult (Add, consumed))
      | Ok (ParseResult ("-", consumed)) -> Ok (ParseResult (Subtract, consumed))
      | _ -> Bad expected
  end ;;
class multiplicative_operator_parser () =
  object
    inherit [operator] parseable
    (* Recognises one multiplicative operator token: '*' becomes [Multiply]
       and '/' becomes [Divide].  Any other outcome, including the
       unreachable case of an unexpected matched string, yields one fixed
       error. *)
    method parse pi =
      let times_or_divide =
        new alternate_parser
          (new string_const_parser "*")
          (new string_const_parser "/") in
      let expected = "Expected '*' or '/'" in
      match times_or_divide#parse pi with
      | Ok (ParseResult ("*", consumed)) -> Ok (ParseResult (Multiply, consumed))
      | Ok (ParseResult ("/", consumed)) -> Ok (ParseResult (Divide, consumed))
      | _ -> Bad expected
  end ;;
class power_operator_parser () =
  object
    inherit [operator] parseable
    (* Recognises the exponentiation token '^' and maps it to [Pow].
       The consumed-token count is taken from the underlying
       string_const_parser result rather than hard-coded.  This matches the
       additive and multiplicative operator parsers. *)
    method parse pi =
      let err_msg = "Expected '^'" in
      let p = new string_const_parser "^" in
      match p#parse pi with
      | Ok (ParseResult (_, consumed)) -> Ok (ParseResult (Pow, consumed))
      | Bad _ -> Bad err_msg
  end ;;
class number_parser () =
  object
    inherit [expr] parseable
    (* Parses an optionally signed numeric literal into [Number].

       The tokenizer splits on '+' and '-', so a signed literal such as -3
       arrives as two tokens.  [count] tracks every token taken from [pi]
       (the optional sign plus the candidate number token).  This lets a
       failed attempt be undone with [pi#unshift], and lets a successful
       result report the true number of consumed tokens, sign included. *)
    method parse pi =
      let count = ref 0 in
      let mult = ref 1.0 in
      (* A number is optionally prefixed with a '+' or '-' sign. *)
      let sign_parser =
        new optional_parser
          (new alternate_parser
             (new string_const_parser "+")
             (new string_const_parser "-")) in
      begin
        match sign_parser#parse pi with
        | Ok (ParseResult (Some sign, c)) ->
          count := !count + c;
          if sign = "-" then mult := -1.0
        | _ -> ()
      end;
      let try_parse tok =
        (* The token has already been shifted off [pi], so it is counted
           even if it turns out not to be a number. *)
        count := !count + 1;
        try Some (Number (!mult *. Float.of_string tok))
        with Failure _ -> None in
      match Option.bind try_parse (pi#shift ()) with
      | Some num -> Ok (ParseResult (num, !count))
      | None ->
        pi#unshift !count;
        Bad "Expected a number"
  end ;;
(* [compose_arith lhs rhs_list] folds a chain of (operator, operand) pairs
   onto [lhs] from the left.  [a] followed by [(op1, b); (op2, c)] becomes
   BinaryExpr(BinaryExpr(a, op1, b), op2, c), so every operator built with
   this function, '^' included, is left-associative.  [rhs_list] is the
   optional result of the zero-or-more parser; [None] leaves [lhs] as is.
   (The former [rec] flag was unused: the function does not recurse.) *)
let compose_arith lhs rhs_list =
  match rhs_list with
  | None -> lhs
  | Some rhs ->
    List.fold_left
      (fun acc (op, operand) -> BinaryExpr (acc, op, operand))
      lhs rhs ;;
class arith_expr_parser op_p lower_p =
  object
    inherit [expr] parseable
    val op_p = op_p
    val lower_p = lower_p
    (* Parses  operand { operator operand }, where operands come from
       [lower_p] and operators from [op_p].  The resulting chain is folded
       into a left-associative tree by [compose_arith].  The combined
       parser is rebuilt on every call. *)
    method parse pi =
      let pair_up op operand = (op, operand) in
      let op_then_operand = new sequence_parser pair_up op_p lower_p in
      let trailing = new zero_or_more_parser op_then_operand in
      let chain = new sequence_parser compose_arith lower_p trailing in
      chain#parse pi
  end ;;
(* Parser for a complete arithmetic expression (no assignment).

   Grammar, from lowest to highest precedence (braces mean zero or more):
     additive       := multiplicative { ('+' | '-') multiplicative }
     multiplicative := power { ('*' | '/') power }
     power          := unary { '^' unary }
     unary          := '(' expr ')' | atom
     atom           := number | name '(' arguments ')' | variable
   arguments is the ',' separated list handled by list_parser.  Its exact
   rules (e.g. whether it may be empty) live in Parser and are not visible
   here.

   Every binary level is built with arith_expr_parser, which folds to the
   left via compose_arith, so '^' is left-associative too: 2^3^2 is read
   as (2^3)^2.

   All sub-parsers are constructed afresh on every call to parse.  [self]
   is used in the parenthesised and argument positions so the grammar can
   recurse. *)
class expr_parser () =
object (self)
inherit [expr] parseable
method parse pi =
let number = new number_parser () in
let variable = new variable_parser () in
let identifier = new string_parser () in
let peren_open = new string_const_parser "(" in
let peren_close = new string_const_parser ")" in
(* name '(' arguments ')' ; the parentheses are matched and discarded. *)
let function_invocation = new sequence_parser (fun name args -> FunctionInvocation(name, args))
identifier
(new sequence_parser (fun lp args -> args)
peren_open
(new sequence_parser (fun args rp -> args)
(new list_parser self (new string_const_parser ","))
peren_close)) in
(* Alternatives are tried in order: number, then function call, then a
   plain variable.  The call form is tried before variable so that a name
   followed by '(' is read as an invocation. *)
let atom = new alternate_parser
number
(new alternate_parser
function_invocation
variable) in
(* Notice the 'self' which makes the whole thing recursive. *)
let unary_expr = new alternate_parser
(new sequence_parser (fun lp e -> e)
peren_open
(new sequence_parser (fun e rp -> e)
self
peren_close))
atom in
(* Precedence levels, tightest first: '^', then '*' and '/', then '+' and '-'. *)
let power_expr = new arith_expr_parser
(new power_operator_parser ())
unary_expr in
let multiplicative_expr = new arith_expr_parser
(new multiplicative_operator_parser ())
power_expr in
let additive_expr = new arith_expr_parser
(new additive_operator_parser ())
multiplicative_expr in
let expr = additive_expr in
expr#parse pi
end ;;
class line_parser () =
  object
    inherit [expr] parseable
    (* A line is either an assignment  let <name> = <expr>  or a bare
       expression.  The assignment form is tried first; both alternatives
       share the same expression parser instance. *)
    method parse pi =
      let value = new expr_parser () in
      let let_kw = new string_const_parser "let" in
      let name = new string_parser () in
      let eq = new string_const_parser "=" in
      let rhs = new sequence_parser (fun _eq e -> e) eq value in
      let binding = new sequence_parser (fun ident e -> (ident, e)) name rhs in
      let assignment =
        new sequence_parser
          (fun _kw (ident, e) -> Assignment (ident, e))
          let_kw binding in
      let statement = new alternate_parser assignment value in
      statement#parse pi
  end ;;
(* Parses a token list as one calculator line.  On success returns the AST
   together with the tokens [parse_info#get_tokens] reports afterwards.
   main treats a non-empty list as unconsumed input.  On failure the
   parser's message is passed through unchanged. *)
let parse tokens =
  let line = new line_parser () in
  let state = new parse_info tokens in
  match line#parse state with
  | Bad msg -> Bad msg
  | Ok (ParseResult (ast, _)) -> Ok (ast, state#get_tokens ()) ;;
(* The tokenizer which parses the input strings. *)
(* Characters that always form tokens of their own.  Input is split on each
   of them, and the space tokens are later trimmed away.  '=' is included
   so that  let x=5  tokenizes the same as  let x = 5 ; without it, x=5
   stayed a single token and the assignment syntax failed to parse. *)
let sep_tokens = [","; "("; ")"; "+"; "-"; "*"; "/"; " "; "^"; "="] ;;
(* [all_but_last l] is [l] without its final element; the empty list maps
   to itself. *)
let all_but_last l =
  match List.rev l with
  | [] -> []
  | _last :: reversed_init -> List.rev reversed_init ;;
(* [intersperse l sep] places [sep] between consecutive elements of [l]
   (not before the first nor after the last): [a; b; c] becomes
   [a; sep; b; sep; c].  Empty and single-element lists are unchanged. *)
let intersperse l sep =
  match l with
  | [] -> []
  | first :: rest ->
    first :: List.rev (List.fold_left (fun acc x -> x :: sep :: acc) [] rest) ;;
(* Splits an input line into tokens.  Each separator from [sep_tokens] is
   applied in turn and kept as a token of its own between the pieces it
   splits.  Finally every token is trimmed and the empty ones, including
   those left by the space separator, are dropped. *)
let tokenize s =
  let split_on pieces sep =
    List.flatten
      (List.map (fun piece -> intersperse (String.nsplit piece sep) sep) pieces) in
  let raw = List.fold_left split_on [s] sep_tokens in
  let trimmed = List.map String.trim raw in
  List.filter (fun tok -> not (String.is_empty tok)) trimmed ;;
(* Functions to evaluate the AST. *)
(* A function callable from calculator expressions. *)
type function_def = {
name: string; (* name used in evaluation error messages *)
(* Number of arguments evaluate_function requires.  The field name is a
   misspelling of "arity"; renaming it would touch every user of the record. *)
arrity: int;
body: float list -> (float, string) Result.t (* receives the already-evaluated arguments *)
} ;;
(* Evaluation environment: variable values and callable functions, both
   keyed by name.  The tables are mutated in place by define_variable,
   define_function and by assignments in evaluate. *)
type env = {
vars: (string, float) Hashtbl.t;
funcs: (string, function_def) Hashtbl.t
} ;;
(* [define_variable env name v] binds [name] to [v] in [env.vars],
   replacing any existing binding.  Hashtbl.replace is used rather than
   Hashtbl.add so repeated definitions do not accumulate hidden bindings. *)
let define_variable env name v = Hashtbl.replace env.vars name v ;;
(* [define_function env name arrity body] registers [body] under [name] in
   [env.funcs], expecting exactly [arrity] arguments.  An existing function
   of the same name is replaced, not shadowed. *)
let define_function env name arrity body =
  Hashtbl.replace env.funcs name { name = name; arrity = arrity; body = body } ;;
(* Applies [f] to already-evaluated [args].  Returns an error naming the
   function when the number of arguments differs from [f.arrity].
   Structural inequality [<>] is used; the former [!=] is physical
   inequality, which only happened to work because the operands are ints. *)
let evaluate_function f args =
  if List.length args <> f.arrity then
    Bad (Printf.sprintf "'%s' expects %d arguments" f.name f.arrity)
  else
    f.body args ;;
(* Evaluates [expr] in [env], returning the numeric result or an error
   message.

   - Binary operators map to the matching Float operation.
   - Unknown variables and unknown functions produce distinct errors.
   - An assignment stores its value in [env.vars], replacing any previous
     binding, and yields that value.
   - Function arguments are evaluated left to right, stopping at the first
     error, which is returned.
   Exceptions raised by a function body are no longer misreported as
   "No such function"; only a failed lookup produces that message. *)
let evaluate env expr =
  let rec evaluate_rec expr =
    let eval_result op x =
      (evaluate_rec x) >>= (fun r -> Ok (op r)) in
    let eval_bin_result op x y =
      (evaluate_rec x) >>= (fun lhs -> (evaluate_rec y) >>= (fun rhs -> Ok (op lhs rhs))) in
    match expr with
    | BinaryExpr (x, op, y) ->
      (match op with
       | Add -> eval_bin_result Float.add x y
       | Subtract -> eval_bin_result Float.sub x y
       | Multiply -> eval_bin_result Float.mul x y
       | Divide -> eval_bin_result Float.div x y
       | Pow -> eval_bin_result Float.pow x y)
    | Number x -> Ok x
    | Variable v ->
      (try Ok (Hashtbl.find env.vars v)
       with Not_found -> Bad (Printf.sprintf "No such variable '%s'" v))
    | Assignment (n, v) -> eval_result (fun f -> Hashtbl.replace env.vars n f; f) v
    | FunctionInvocation (name, args) ->
      (match (try Some (Hashtbl.find env.funcs name) with Not_found -> None) with
       | None -> Bad (Printf.sprintf "No such function '%s'" name)
       | Some f ->
         (* Evaluate arguments in order and stop at the first error. *)
         let rec collect acc remaining =
           match remaining with
           | [] -> Ok (List.rev acc)
           | arg :: rest ->
             (match evaluate_rec arg with
              | Ok v -> collect (v :: acc) rest
              | Bad msg -> Bad msg) in
         (match collect [] args with
          | Ok values -> evaluate_function f values
          | Bad msg -> Bad msg)) in
  evaluate_rec expr ;;
(* Finally, the main interface of the calculator and predefined
functions. *)
(* Built-in numeric functions exposed to the calculator.  Each has the
   function_def body type: a list of evaluated arguments to a result. *)
module FunctionLibrary =
struct
  (* [pop_args args n] returns the first [n] elements of [args] in order.
     It fails when fewer than [n] are supplied; any extra elements are
     ignored here (evaluate_function already checks the exact count). *)
  let pop_args args n =
    let rec take acc rest taken =
      if taken >= n then Ok (List.rev acc)
      else
        match rest with
        | [] -> Bad (Printf.sprintf "Expected %d arguments" n)
        | hd :: tl -> take (hd :: acc) tl (taken + 1) in
    take [] args 0

  (* Lifts a float -> float function to the argument-list calling convention. *)
  let single_arg_func f =
    fun args ->
      match pop_args args 1 with
      | Bad msg -> Bad msg
      | Ok first -> Ok (f (List.hd first))

  let sin = single_arg_func Float.sin
  let asin = single_arg_func Float.asin
  let cos = single_arg_func Float.cos
  let acos = single_arg_func Float.acos
  let tan = single_arg_func Float.tan
  let atan = single_arg_func Float.atan
  (* The right-hand [sqrt] is the outer (pervasive) square root: this
     [let] is not recursive, so it cannot refer to itself. *)
  let sqrt = single_arg_func sqrt
end ;;
(* Entry point of the calculator REPL.

   Builds an environment holding the constants pi and e and the
   one-argument functions sin, asin, cos, acos, tan, atan and sqrt.  It then
   loops forever:
   - print a ">>> " prompt and read one line from stdin;
   - tokenize, parse and evaluate the line;
   - print the result with "%f", or an error message, or a complaint about
     unconsumed trailing tokens.

   Only end of input stops the loop: it prints a newline and exits with
   status 0.  Any other exception raised while handling a line is reported
   and the loop continues.  Previously a catch-all handler silently
   terminated the program on every error. *)
let main () =
  let env = {
    vars = Hashtbl.create 10;
    funcs = Hashtbl.create 10
  } in
  define_variable env "pi" Float.pi;
  define_variable env "e" (Float.exp 1.0);
  define_function env "sin" 1 FunctionLibrary.sin;
  define_function env "asin" 1 FunctionLibrary.asin;
  define_function env "cos" 1 FunctionLibrary.cos;
  define_function env "acos" 1 FunctionLibrary.acos;
  define_function env "tan" 1 FunctionLibrary.tan;
  define_function env "atan" 1 FunctionLibrary.atan;
  define_function env "sqrt" 1 FunctionLibrary.sqrt;
  (* Tokenize, parse, evaluate and print a single input line. *)
  let handle_line line =
    match parse (tokenize line) with
    | Bad msg -> Printf.printf "ERROR: %s\n" msg
    | Ok (_, (_ :: _ as remaining)) ->
      Printf.printf "Unexpected remaining expression: '%s'\n"
        (String.concat " " remaining)
    | Ok (expr, []) ->
      (match evaluate env expr with
       | Ok result -> Printf.printf "%f\n" result
       | Bad msg -> Printf.printf "ERROR: %s\n" msg) in
  while true do
    output_string stdout ">>> ";
    flush stdout;
    match (try Some (input_line stdin) with End_of_file -> None) with
    | None ->
      output_string stdout "\n";
      exit 0
    | Some line ->
      (try handle_line line
       with exn -> Printf.printf "ERROR: %s\n" (Printexc.to_string exn))
  done ;;
(* Program entry: run the REPL.  main only ends by calling exit. *)
let () = main ()