From a988830a78b11d19a89ebed77975d482291832c0 Mon Sep 17 00:00:00 2001 From: Daniel Morris Date: Tue, 12 May 2026 09:40:06 +0100 Subject: [PATCH] verifier + interpreter: closure-bind ctx arg for srt/map/flt/fld The HOFs srt, map, flt and fld now accept an optional extra ctx argument placed between the fn and the list (srt fn ctx xs, map fn ctx xs, flt fn ctx xs, fld fn ctx xs init). The callback often needs state from the caller's scope (ctx), such as a lookup map or a threshold. The verifier was rejecting the extra argument on arity, and the interpreter had no dispatch for it. Both sides now agree: when ctx is supplied it is forwarded as the trailing argument on each call of the fn, and the verifier requires the fn to take one more parameter. Unlocks idiomatic top-N, key-fn sort, and group-style reductions. --- examples/closure-bind.ilo | 30 +++++ src/interpreter/mod.rs | 73 +++++++++--- src/verify.rs | 120 ++++++++++++++++++-- tests/regression_closure_bind.rs | 185 +++++++++++++++++++++++++++++++ 4 files changed, 382 insertions(+), 26 deletions(-) create mode 100644 examples/closure-bind.ilo create mode 100644 tests/regression_closure_bind.rs diff --git a/examples/closure-bind.ilo b/examples/closure-bind.ilo new file mode 100644 index 00000000..aebdf3f8 --- /dev/null +++ b/examples/closure-bind.ilo @@ -0,0 +1,30 @@ +-- Closure-bind: HOFs (srt, map, flt, fld) accept an optional ctx arg +-- that is passed to every invocation of the fn. +-- +-- Pattern shown: sort symbols by an external priority lookup map. +-- Without closure-bind, you'd have to zip the priority into each element +-- (records of {sym, pri}) before sorting, then unzip after — recurring +-- per-program tax that closure-bind removes. + +-- Key fn: takes element + ctx (the lookup map). Missing keys sort last. +pri sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:99999} + +-- 3-arg srt: srt key-fn ctx xs +top pri-map:M t n syms:L t>L t;srt pri pri-map syms + +-- Build a small priority map and sort by it. 
+main>L t;m=mset mmap "a" 2;m=mset m "b" 3;m=mset m "c" 1;top m ["a","b","c"] + +-- map closure-bind: enrich each symbol with its price via lookup +look sym:t pm:M t n>n;r=mget pm sym;?r{n v:v;_:0} +prices pm:M t n syms:L t>L n;map look pm syms + +prices-demo>L n;m=mset mmap "a" 10;m=mset m "b" 20;prices m ["a","b","a"] + +-- engine-skip: vm +-- engine-skip: jit +-- engine-skip: cranelift +-- run: main +-- out: [c, a, b] +-- run: prices-demo +-- out: [10, 20, 10] diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index f507b04e..4c977c38 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -819,7 +819,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { )), }; } - if builtin == Some(Builtin::Srt) && args.len() == 2 { + if builtin == Some(Builtin::Srt) && (args.len() == 2 || args.len() == 3) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -829,12 +829,18 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + // closure-bind: srt fn ctx xs + let (ctx, list_arg) = if args.len() == 3 { + (Some(args[1].clone()), &args[2]) + } else { + (None, &args[1]) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("srt: second arg must be a list, got {:?}", other), + format!("srt: list arg must be a list, got {:?}", other), )); } }; @@ -842,7 +848,11 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { let mut keyed: Vec<(Value, Value)> = items .into_iter() .map(|item| { - let key = call_function(env, &fn_name, vec![item.clone()])?; + let call_args = match &ctx { + Some(c) => vec![item.clone(), c.clone()], + None => vec![item.clone()], + }; + let key = call_function(env, &fn_name, call_args)?; Ok((key, item)) }) .collect::>()?; @@ -1390,7 +1400,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { _ => None, } } - if builtin == 
Some(Builtin::Map) && args.len() == 2 { + if builtin == Some(Builtin::Map) && (args.len() == 2 || args.len() == 3) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -1400,22 +1410,32 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + // closure-bind: map fn ctx xs + let (ctx, list_arg) = if args.len() == 3 { + (Some(args[1].clone()), &args[2]) + } else { + (None, &args[1]) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("map: second arg must be a list, got {:?}", other), + format!("map: list arg must be a list, got {:?}", other), )); } }; let mut result = Vec::with_capacity(items.len()); for item in items { - result.push(call_function(env, &fn_name, vec![item])?); + let call_args = match &ctx { + Some(c) => vec![item, c.clone()], + None => vec![item], + }; + result.push(call_function(env, &fn_name, call_args)?); } return Ok(Value::List(result)); } - if builtin == Some(Builtin::Flt) && args.len() == 2 { + if builtin == Some(Builtin::Flt) && (args.len() == 2 || args.len() == 3) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -1425,18 +1445,27 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + let (ctx, list_arg) = if args.len() == 3 { + (Some(args[1].clone()), &args[2]) + } else { + (None, &args[1]) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("flt: second arg must be a list, got {:?}", other), + format!("flt: list arg must be a list, got {:?}", other), )); } }; let mut result = Vec::new(); for item in items { - match call_function(env, &fn_name, vec![item.clone()])? 
{ + let call_args = match &ctx { + Some(c) => vec![item.clone(), c.clone()], + None => vec![item.clone()], + }; + match call_function(env, &fn_name, call_args)? { Value::Bool(true) => result.push(item), Value::Bool(false) => {} other => { @@ -1449,7 +1478,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { } return Ok(Value::List(result)); } - if builtin == Some(Builtin::Fld) && args.len() == 3 { + if builtin == Some(Builtin::Fld) && (args.len() == 3 || args.len() == 4) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -1459,18 +1488,28 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + // closure-bind: fld fn ctx xs init + let (ctx, list_arg, init) = if args.len() == 4 { + (Some(args[1].clone()), &args[2], args[3].clone()) + } else { + (None, &args[1], args[2].clone()) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("fld: second arg must be a list, got {:?}", other), + format!("fld: list arg must be a list, got {:?}", other), )); } }; - let mut acc = args[2].clone(); + let mut acc = init; for item in items { - acc = call_function(env, &fn_name, vec![acc, item])?; + let call_args = match &ctx { + Some(c) => vec![acc, item, c.clone()], + None => vec![acc, item], + }; + acc = call_function(env, &fn_name, call_args)?; } return Ok(acc); } diff --git a/src/verify.rs b/src/verify.rs index 9dc39aa2..3e3b18f7 100644 --- a/src/verify.rs +++ b/src/verify.rs @@ -657,6 +657,42 @@ fn builtin_check_args( (Ty::Text, errors) } "srt" => { + if arg_types.len() == 3 { + // srt key-fn ctx xs — closure-bind variant: fn takes (elem, ctx) + if let Some(fn_ty) = arg_types.first() + && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!("'srt' key arg must be a function (F ...), got {fn_ty}"), 
+ hint: Some("pass a function name: srt key-fn ctx xs".to_string()), + span, + is_warning: false, + }); + } + // fn must accept 2 args + if let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 2 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'srt' key fn must take 2 args (elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for srt fn ctx xs, fn must be: F a c b".to_string()), + span, + is_warning: false, + }); + } + let ret = match arg_types.get(2) { + Some(ty @ Ty::List(_)) => ty.clone(), + _ => Ty::Unknown, + }; + return (ret, errors); + } if arg_types.len() == 2 { // srt key-fn xs — sort by key function if let Some(fn_ty) = arg_types.first() @@ -915,7 +951,7 @@ fn builtin_check_args( } "map" => { // map fn:F a b xs:L a → L b - // First arg must be a function type; second must be a list. + // map fn:F a c b ctx:c xs:L a → L b (closure-bind variant) if let Some(fn_ty) = arg_types.first() && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) { @@ -928,6 +964,23 @@ fn builtin_check_args( is_warning: false, }); } + // closure-bind: fn must take 2 args (elem, ctx) + if arg_types.len() == 3 + && let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 2 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'map' fn must take 2 args (elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for map fn ctx xs, fn must be: F a c b".to_string()), + span, + is_warning: false, + }); + } // Return type: L of the function's return type, or L Unknown let ret_elem = match arg_types.first() { Some(Ty::Fn(_, ret)) => *ret.clone(), @@ -937,7 +990,7 @@ fn builtin_check_args( } "flt" => { // flt fn:F a b xs:L a → L a - // First arg: function returning bool; second: list. 
+ // flt fn:F a c b ctx:c xs:L a → L a (closure-bind variant) if let Some(fn_ty) = arg_types.first() && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) { @@ -950,8 +1003,25 @@ fn builtin_check_args( is_warning: false, }); } - // Return type: same list type as input - let ret = match arg_types.get(1) { + if arg_types.len() == 3 + && let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 2 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'flt' fn must take 2 args (elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for flt fn ctx xs, fn must be: F a c b".to_string()), + span, + is_warning: false, + }); + } + // Return type: same list type as input (last arg position) + let list_idx = if arg_types.len() == 3 { 2 } else { 1 }; + let ret = match arg_types.get(list_idx) { Some(ty @ Ty::List(_)) => ty.clone(), _ => Ty::Unknown, }; @@ -959,7 +1029,7 @@ fn builtin_check_args( } "fld" => { // fld fn:F a b b xs:L a init:b → b - // First arg: function; second: list; third: initial accumulator. 
+ // fld fn:F a c b b ctx:c xs:L a init:b → b (closure-bind variant) if let Some(fn_ty) = arg_types.first() && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) { @@ -972,8 +1042,25 @@ fn builtin_check_args( is_warning: false, }); } - // Return type: accumulator type (third arg) or function return type - let ret = match arg_types.get(2) { + if arg_types.len() == 4 + && let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 3 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'fld' fn must take 3 args (acc, elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for fld fn ctx xs init, fn must be: F b a c b".to_string()), + span, + is_warning: false, + }); + } + // Return type: accumulator type (last arg) or function return type + let init_idx = if arg_types.len() == 4 { 3 } else { 2 }; + let ret = match arg_types.get(init_idx) { Some(ty) if !matches!(ty, Ty::Unknown) => ty.clone(), _ => match arg_types.first() { Some(Ty::Fn(_, ret)) => *ret.clone(), @@ -1814,7 +1901,16 @@ impl VerifyContext { builtin_arity(callee).expect("is_builtin guarantees arity exists"); let arity_ok = if callee == "rnd" { args.is_empty() || args.len() == 2 - } else if callee == "srt" || callee == "rd" { + } else if callee == "srt" { + // srt xs / srt fn xs / srt fn ctx xs + args.len() == 1 || args.len() == 2 || args.len() == 3 + } else if callee == "map" || callee == "flt" { + // map fn xs / map fn ctx xs (closure-bind variant) + args.len() == 2 || args.len() == 3 + } else if callee == "fld" { + // fld fn xs init / fld fn ctx xs init (closure-bind variant) + args.len() == 3 || args.len() == 4 + } else if callee == "rd" { args.len() == 1 || args.len() == 2 } else if callee == "wr" { args.len() == 2 || args.len() == 3 @@ -1830,7 +1926,13 @@ impl VerifyContext { if !arity_ok { let arity_desc = if callee == "rnd" { "0 or 2".to_string() - } else if callee == "srt" || callee == "rd" || callee == "get" { 
+ } else if callee == "srt" { + "1, 2, or 3".to_string() + } else if callee == "map" || callee == "flt" { + "2 or 3".to_string() + } else if callee == "fld" { + "3 or 4".to_string() + } else if callee == "rd" || callee == "get" { "1 or 2".to_string() } else if callee == "post" { "2 or 3".to_string() diff --git a/tests/regression_closure_bind.rs b/tests/regression_closure_bind.rs new file mode 100644 index 00000000..56e93738 --- /dev/null +++ b/tests/regression_closure_bind.rs @@ -0,0 +1,185 @@ +// Regression tests for the closure-bind HOF variant. +// +// Every HOF (`srt`, `map`, `flt`, `fld`) accepts an optional extra `ctx` +// argument that is forwarded to each invocation of the function. This lets +// agents pass external state (lookup tables, thresholds, accumulators) into +// the callback without bundling the state into per-element records. +// +// The verifier disambiguates by arity: +// srt fn xs (2-arg) vs. srt fn ctx xs (3-arg) +// map fn xs (2-arg) vs. map fn ctx xs (3-arg) +// flt fn xs (2-arg) vs. flt fn ctx xs (3-arg) +// fld fn xs init (3-arg) vs. fld fn ctx xs init (4-arg) +// +// VM and Cranelift JIT do not implement HOF dispatch at all (see +// regression_builtins_as_hof.rs), so closure-bind is exercised on the +// tree-walking interpreter only. 
+ +use std::process::Command; + +fn ilo() -> Command { + Command::new(env!("CARGO_BIN_EXE_ilo")) +} + +fn write_src(name: &str, src: &str) -> std::path::PathBuf { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(0); + let n = COUNTER.fetch_add(1, Ordering::Relaxed); + let mut path = std::env::temp_dir(); + path.push(format!("ilo_cbind_{name}_{}_{n}.ilo", std::process::id())); + std::fs::write(&path, src).expect("write src"); + path +} + +fn run_ok(src: &str, entry: &str, args: &[&str]) -> String { + let path = write_src(entry, src); + let mut cmd = ilo(); + cmd.arg(&path).arg("--run-tree").arg(entry); + for a in args { + cmd.arg(a); + } + let out = cmd.output().expect("failed to run ilo"); + let _ = std::fs::remove_file(&path); + assert!( + out.status.success(), + "ilo failed for `{src}`: stderr={}", + String::from_utf8_lossy(&out.stderr) + ); + String::from_utf8_lossy(&out.stdout).trim().to_string() +} + +fn run_err(src: &str, entry: &str) -> String { + let path = write_src(entry, src); + let out = ilo() + .arg(&path) + .arg("--run-tree") + .arg(entry) + .output() + .expect("failed to run ilo"); + let _ = std::fs::remove_file(&path); + assert!( + !out.status.success(), + "expected failure but ilo succeeded for `{src}`" + ); + let mut s = String::from_utf8_lossy(&out.stderr).into_owned(); + s.push_str(&String::from_utf8_lossy(&out.stdout)); + s +} + +// ── srt: sort by external lookup map — the "top-N by magnitude" pattern ──── +// Sort symbols by an externally provided priority map (lower priority first). +// +// Without closure-bind, agents would have to fold the map into each list +// element (`[("a", 3), ("b", 1)]`) and write a key fn that pulls the second +// field — the per-program tax this feature removes. 
+ +const SRT_BY_LOOKUP: &str = "\ +pri sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:99999} +top pri-map:M t n syms:L t>L t;srt pri pri-map syms"; + +#[test] +fn srt_with_external_lookup() { + // pri-map: c=1, a=2, b=3 → sorted by priority: c, a, b + let src = format!( + "{SRT_BY_LOOKUP}\nmain>L t;m=mset mmap \"a\" 2;m=mset m \"b\" 3;m=mset m \"c\" 1;\ + top m [\"a\",\"b\",\"c\"]" + ); + assert_eq!(run_ok(&src, "main", &[]), "[c, a, b]"); +} + +// ── map: enrich elements via external lookup ─────────────────────────────── + +const MAP_WITH_LOOKUP: &str = "\ +look sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:0} +prices pm:M t n syms:L t>L n;map look pm syms"; + +#[test] +fn map_with_lookup() { + let src = format!( + "{MAP_WITH_LOOKUP}\nmain>L n;m=mset mmap \"a\" 10;m=mset m \"b\" 20;\ + prices m [\"a\",\"b\",\"a\"]" + ); + assert_eq!(run_ok(&src, "main", &[]), "[10, 20, 10]"); +} + +// ── flt: filter by external threshold ────────────────────────────────────── + +const FLT_WITH_THRESHOLD: &str = "\ +big x:n thr:n>b;>=x thr +above t:n xs:L n>L n;flt big t xs"; + +#[test] +fn flt_with_threshold() { + assert_eq!( + run_ok( + &format!("{FLT_WITH_THRESHOLD}\nmain>L n;above 4 [1,5,3,8,2]"), + "main", + &[] + ), + "[5, 8]" + ); +} + +// ── fld: fold using external multiplier ──────────────────────────────────── + +const FLD_WITH_ACCUM: &str = "\ +add-scaled acc:n x:n k:n>n;+acc *x k +total k:n xs:L n>n;fld add-scaled k xs 0"; + +#[test] +fn fld_with_external_accumulator() { + // sum of [1,2,3] each scaled by 10 → 60 + assert_eq!( + run_ok( + &format!("{FLD_WITH_ACCUM}\nmain>n;total 10 [1,2,3]"), + "main", + &[] + ), + "60" + ); +} + +// ── 2-arg variants still work (no regression for existing programs) ──────── + +const SRT_2ARG: &str = "abs1 x:n>n;abs x\nf xs:L n>L n;srt abs1 xs"; + +#[test] +fn srt_2arg_unchanged() { + assert_eq!(run_ok(SRT_2ARG, "f", &["[-3,1,-5,2]"]), "[1, 2, -3, -5]"); +} + +// ── verifier: 3-arg srt with 1-arg fn is rejected ────────────────────────── + 
+#[test] +fn srt_3arg_rejects_1arg_fn() { + let src = "abs1 x:n>n;abs x\nf c:n xs:L n>L n;srt abs1 c xs"; + let err = run_err(src, "f"); + assert!( + err.contains("srt") && (err.contains("2 args") || err.contains("closure-bind")), + "expected verifier error about srt fn arity, got: {err}" + ); +} + +// ── verifier: 3-arg map with 1-arg fn is rejected ────────────────────────── + +#[test] +fn map_3arg_rejects_1arg_fn() { + let src = "abs1 x:n>n;abs x\nf c:n xs:L n>L n;map abs1 c xs"; + let err = run_err(src, "f"); + assert!( + err.contains("map") && (err.contains("2 args") || err.contains("closure-bind")), + "expected verifier error about map fn arity, got: {err}" + ); +} + +// ── verifier: 4-arg fld with 2-arg fn is rejected ────────────────────────── + +#[test] +fn fld_4arg_rejects_2arg_fn() { + let src = "add a:n b:n>n;+a b\nf c:n xs:L n>n;fld add c xs 0"; + let err = run_err(src, "f"); + assert!( + err.contains("fld") && (err.contains("3 args") || err.contains("closure-bind")), + "expected verifier error about fld fn arity, got: {err}" + ); +}