Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/math.ilo
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
-- Transcendental builtins: pow, sqrt, log, exp, sin, cos (radians).

-- 2D Euclidean distance: sqrt(x^2 + y^2)
dist x:n y:n>n;a=*x x;b=*y y;sqrt +a b

-- Compound interest: principal * (1 + rate)^years
ci p:n r:n y:n>n;g=+1 r;f=pow g y;*p f

-- Log/exp round-trip
rt x:n>n;e=exp x;log e

-- run: dist 3 4
-- out: 5
-- run: ci 1000 0.05 10
-- out: 1628.894626777442
-- run: rt 5
-- out: 5
28 changes: 23 additions & 5 deletions src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ pub enum Builtin {
Min,
Max,
Mod,
Pow,
Sqrt,
Log,
Exp,
Sin,
Cos,
Sum,
Avg,

Expand Down Expand Up @@ -92,6 +98,12 @@ impl Builtin {
"min" => Some(Builtin::Min),
"max" => Some(Builtin::Max),
"mod" => Some(Builtin::Mod),
"pow" => Some(Builtin::Pow),
"sqrt" => Some(Builtin::Sqrt),
"log" => Some(Builtin::Log),
"exp" => Some(Builtin::Exp),
"sin" => Some(Builtin::Sin),
"cos" => Some(Builtin::Cos),
"sum" => Some(Builtin::Sum),
"avg" => Some(Builtin::Avg),
"len" => Some(Builtin::Len),
Expand Down Expand Up @@ -151,6 +163,12 @@ impl Builtin {
Builtin::Min => "min",
Builtin::Max => "max",
Builtin::Mod => "mod",
Builtin::Pow => "pow",
Builtin::Sqrt => "sqrt",
Builtin::Log => "log",
Builtin::Exp => "exp",
Builtin::Sin => "sin",
Builtin::Cos => "cos",
Builtin::Sum => "sum",
Builtin::Avg => "avg",
Builtin::Len => "len",
Expand Down Expand Up @@ -209,11 +227,11 @@ mod tests {
#[test]
fn round_trip_all_builtins() {
let all = [
"str", "num", "abs", "flr", "cel", "rou", "min", "max", "mod", "sum", "avg", "len",
"hd", "at", "tl", "rev", "srt", "slc", "unq", "flat", "has", "spl", "cat", "map",
"flt", "fld", "grp", "rnd", "now", "rd", "rdl", "rdb", "wr", "wrl", "prnt", "env",
"trm", "fmt", "rgx", "jpth", "jdmp", "jpar", "get", "post", "mmap", "mget", "mset",
"mhas", "mkeys", "mvals", "mdel",
"str", "num", "abs", "flr", "cel", "rou", "min", "max", "mod", "pow", "sqrt", "log",
"exp", "sin", "cos", "sum", "avg", "len", "hd", "at", "tl", "rev", "srt", "slc", "unq",
"flat", "has", "spl", "cat", "map", "flt", "fld", "grp", "rnd", "now", "rd", "rdl",
"rdb", "wr", "wrl", "prnt", "env", "trm", "fmt", "rgx", "jpth", "jdmp", "jpar", "get",
"post", "mmap", "mget", "mset", "mhas", "mkeys", "mvals", "mdel",
];
for name in &all {
let b = Builtin::from_name(name).unwrap_or_else(|| panic!("missing builtin: {name}"));
Expand Down
89 changes: 89 additions & 0 deletions src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,37 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
)),
};
}
if matches!(
builtin,
Some(Builtin::Sqrt | Builtin::Log | Builtin::Exp | Builtin::Sin | Builtin::Cos)
) && args.len() == 1
{
return match &args[0] {
Value::Number(n) => {
let result = match builtin {
Some(Builtin::Sqrt) => n.sqrt(),
Some(Builtin::Log) => n.ln(),
Some(Builtin::Exp) => n.exp(),
Some(Builtin::Sin) => n.sin(),
_ => n.cos(),
};
Ok(Value::Number(result))
}
other => Err(RuntimeError::new(
"ILO-R009",
format!("{} requires a number, got {:?}", name, other),
)),
};
}
if builtin == Some(Builtin::Pow) && args.len() == 2 {
return match (&args[0], &args[1]) {
(Value::Number(a), Value::Number(b)) => Ok(Value::Number(a.powf(*b))),
_ => Err(RuntimeError::new(
"ILO-R009",
"pow requires two numbers".to_string(),
)),
};
}
if builtin == Some(Builtin::Now) && args.is_empty() {
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
Expand Down Expand Up @@ -2548,6 +2579,64 @@ mod tests {
assert!(result.is_err());
}

// ── Error paths for the new transcendental math builtins ─────────────
// The tree-walker accepts any Value at runtime; verify catches the type
// mismatch at compile time but does not run here. These tests cover the
// `other => Err(...)` arms in the Sqrt|Log|Exp|Sin|Cos and Pow handlers.
#[test]
fn interpret_sqrt_non_number_errors() {
let source = "f x:t>n;sqrt x";
let prog = parse_program(source);
let err = run(&prog, Some("f"), vec![Value::Text("nope".into())]).unwrap_err();
assert!(
err.to_string().contains("sqrt") && err.to_string().contains("requires a number"),
"unexpected error: {err}"
);
}

#[test]
fn interpret_log_non_number_errors() {
let prog = parse_program("f x:t>n;log x");
let err = run(&prog, Some("f"), vec![Value::Text("nope".into())]).unwrap_err();
assert!(err.to_string().contains("log"), "unexpected error: {err}");
}

#[test]
fn interpret_exp_non_number_errors() {
let prog = parse_program("f x:t>n;exp x");
let err = run(&prog, Some("f"), vec![Value::Text("nope".into())]).unwrap_err();
assert!(err.to_string().contains("exp"), "unexpected error: {err}");
}

#[test]
fn interpret_sin_non_number_errors() {
let prog = parse_program("f x:t>n;sin x");
let err = run(&prog, Some("f"), vec![Value::Text("nope".into())]).unwrap_err();
assert!(err.to_string().contains("sin"), "unexpected error: {err}");
}

#[test]
fn interpret_cos_non_number_errors() {
let prog = parse_program("f x:t>n;cos x");
let err = run(&prog, Some("f"), vec![Value::Text("nope".into())]).unwrap_err();
assert!(err.to_string().contains("cos"), "unexpected error: {err}");
}

#[test]
fn interpret_pow_non_number_errors() {
let prog = parse_program("f x:t y:t>n;pow x y");
let err = run(
&prog,
Some("f"),
vec![Value::Text("a".into()), Value::Text("b".into())],
)
.unwrap_err();
assert!(
err.to_string().contains("pow") && err.to_string().contains("two numbers"),
"unexpected error: {err}"
);
}

#[test]
fn interpret_logical_and() {
let source = "f a:b b:b>b;&a b";
Expand Down
10 changes: 8 additions & 2 deletions src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ const BUILTINS: &[(&str, &[&str], &str)] = &[
("min", &["n", "n"], "n"),
("max", &["n", "n"], "n"),
("mod", &["n", "n"], "n"),
("pow", &["n", "n"], "n"),
("sqrt", &["n"], "n"),
("log", &["n"], "n"),
("exp", &["n"], "n"),
("sin", &["n"], "n"),
("cos", &["n"], "n"),
("get", &["t"], "R t t"),
("get", &["t", "M t t"], "R t t"),
("post", &["t", "t"], "R t t"),
Expand Down Expand Up @@ -372,7 +378,7 @@ fn builtin_check_args(
}
(Ty::Result(Box::new(Ty::Number), Box::new(Ty::Text)), errors)
}
"abs" | "flr" | "cel" | "rou" => {
"abs" | "flr" | "cel" | "rou" | "sqrt" | "log" | "exp" | "sin" | "cos" => {
if let Some(arg) = arg_types.first()
&& !compatible(arg, &Ty::Number)
{
Expand All @@ -387,7 +393,7 @@ fn builtin_check_args(
}
(Ty::Number, errors)
}
"min" | "max" | "mod" => {
"min" | "max" | "mod" | "pow" => {
for (i, arg) in arg_types.iter().enumerate() {
if !compatible(arg, &Ty::Number) {
errors.push(VerifyError {
Expand Down
86 changes: 85 additions & 1 deletion src/vm/compile_cranelift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ struct HelperFuncs {
flr: FuncId,
cel: FuncId,
rou: FuncId,
pow: FuncId,
sqrt: FuncId,
log: FuncId,
exp: FuncId,
sin: FuncId,
cos: FuncId,
rnd0: FuncId,
rnd2: FuncId,
now: FuncId,
Expand Down Expand Up @@ -165,6 +171,12 @@ fn declare_all_helpers(module: &mut ObjectModule) -> HelperFuncs {
flr: declare_helper(module, "jit_flr", 1, 1),
cel: declare_helper(module, "jit_cel", 1, 1),
rou: declare_helper(module, "jit_rou", 1, 1),
pow: declare_helper(module, "jit_pow", 2, 1),
sqrt: declare_helper(module, "jit_sqrt", 1, 1),
log: declare_helper(module, "jit_log", 1, 1),
exp: declare_helper(module, "jit_exp", 1, 1),
sin: declare_helper(module, "jit_sin", 1, 1),
cos: declare_helper(module, "jit_cos", 1, 1),
rnd0: declare_helper(module, "jit_rnd0", 0, 1),
rnd2: declare_helper(module, "jit_rnd2", 2, 1),
now: declare_helper(module, "jit_now", 0, 1),
Expand Down Expand Up @@ -870,7 +882,8 @@ fn compile_function_body(
// Guaranteed numeric outputs.
OP_ADD_NN | OP_SUB_NN | OP_MUL_NN | OP_DIV_NN | OP_ADDK_N | OP_SUBK_N
| OP_MULK_N | OP_DIVK_N | OP_LEN | OP_ABS | OP_MIN | OP_MAX | OP_FLR | OP_CEL
| OP_ROU | OP_RND0 | OP_RND2 | OP_NOW | OP_MOD => {
| OP_ROU | OP_RND0 | OP_RND2 | OP_NOW | OP_MOD | OP_POW | OP_SQRT | OP_LOG
| OP_EXP | OP_SIN | OP_COS => {
num_write[a] = true;
}
// LOADK: numeric only when the constant itself is a number.
Expand Down Expand Up @@ -1742,6 +1755,36 @@ fn compile_function_body(
builder.def_var(f64_vars[a_idx], rf);
}
}
OP_POW => {
let bv = builder.use_var(vars[b_idx]);
let cv = builder.use_var(vars[c_idx]);
let fref = get_func_ref(&mut builder, module, helpers.pow);
let call_inst = builder.ins().call(fref, &[bv, cv]);
let result = builder.inst_results(call_inst)[0];
builder.def_var(vars[a_idx], result);
if a_idx < reg_count && reg_always_num[a_idx] {
let rf = builder.ins().bitcast(F64, mf, result);
builder.def_var(f64_vars[a_idx], rf);
}
}
OP_SQRT | OP_LOG | OP_EXP | OP_SIN | OP_COS => {
let bv = builder.use_var(vars[b_idx]);
let fid = match op {
OP_SQRT => helpers.sqrt,
OP_LOG => helpers.log,
OP_EXP => helpers.exp,
OP_SIN => helpers.sin,
_ => helpers.cos,
};
let fref = get_func_ref(&mut builder, module, fid);
let call_inst = builder.ins().call(fref, &[bv]);
let result = builder.inst_results(call_inst)[0];
builder.def_var(vars[a_idx], result);
if a_idx < reg_count && reg_always_num[a_idx] {
let rf = builder.ins().bitcast(F64, mf, result);
builder.def_var(f64_vars[a_idx], rf);
}
}
OP_RND0 => {
let fref = get_func_ref(&mut builder, module, helpers.rnd0);
let call_inst = builder.ins().call(fref, &[]);
Expand Down Expand Up @@ -3651,6 +3694,47 @@ mod tests {
assert!(bytes4.is_ok());
}

// ── AOT translator coverage for the new transcendental math opcodes ───
// These exercise the OP_POW and OP_SQRT|OP_LOG|OP_EXP|OP_SIN|OP_COS arms
// in compile_function_body, which the JIT-based --run-cranelift tests do
// not reach (those go through jit_cranelift::compile_and_call, not the
// AOT translator).
#[test]
fn codegen_pow_emits_object() {
let bytes = compile_to_object_bytes("f>n;pow 2 10");
assert!(bytes.is_ok(), "pow AOT failed: {:?}", bytes.err());
}

#[test]
fn codegen_sqrt_emits_object() {
let bytes = compile_to_object_bytes("f>n;sqrt 4");
assert!(bytes.is_ok(), "sqrt AOT failed: {:?}", bytes.err());
}

#[test]
fn codegen_log_emits_object() {
let bytes = compile_to_object_bytes("f>n;log 2.5");
assert!(bytes.is_ok(), "log AOT failed: {:?}", bytes.err());
}

#[test]
fn codegen_exp_emits_object() {
let bytes = compile_to_object_bytes("f>n;exp 1");
assert!(bytes.is_ok(), "exp AOT failed: {:?}", bytes.err());
}

#[test]
fn codegen_sin_emits_object() {
let bytes = compile_to_object_bytes("f>n;sin 0");
assert!(bytes.is_ok(), "sin AOT failed: {:?}", bytes.err());
}

#[test]
fn codegen_cos_emits_object() {
let bytes = compile_to_object_bytes("f>n;cos 0");
assert!(bytes.is_ok(), "cos AOT failed: {:?}", bytes.err());
}

#[test]
fn codegen_min_max_emits_object() {
let bytes = compile_to_object_bytes("f a:n b:n>n;min a b");
Expand Down
Loading
Loading