diff --git a/examples/closure-bind.ilo b/examples/closure-bind.ilo new file mode 100644 index 00000000..aebdf3f8 --- /dev/null +++ b/examples/closure-bind.ilo @@ -0,0 +1,30 @@ +-- Closure-bind: HOFs (srt, map, flt, fld) accept an optional ctx arg +-- that is passed to every invocation of the fn. +-- +-- Pattern shown: sort symbols by an external priority lookup map. +-- Without closure-bind, you'd have to zip the priority into each element +-- (records of {sym, pri}) before sorting, then unzip after — recurring +-- per-program tax that closure-bind removes. + +-- Key fn: takes element + ctx (the lookup map). Missing keys sort last. +pri sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:99999} + +-- 3-arg srt: srt key-fn ctx xs +top pri-map:M t n syms:L t>L t;srt pri pri-map syms + +-- Build a small priority map and sort by it. +main>L t;m=mset mmap "a" 2;m=mset m "b" 3;m=mset m "c" 1;top m ["a","b","c"] + +-- map closure-bind: enrich each symbol with its price via lookup +look sym:t pm:M t n>n;r=mget pm sym;?r{n v:v;_:0} +prices pm:M t n syms:L t>L n;map look pm syms + +prices-demo>L n;m=mset mmap "a" 10;m=mset m "b" 20;prices m ["a","b","a"] + +-- engine-skip: vm +-- engine-skip: jit +-- engine-skip: cranelift +-- run: main +-- out: [c, a, b] +-- run: prices-demo +-- out: [10, 20, 10] diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index f507b04e..4c977c38 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -819,7 +819,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { )), }; } - if builtin == Some(Builtin::Srt) && args.len() == 2 { + if builtin == Some(Builtin::Srt) && (args.len() == 2 || args.len() == 3) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -829,12 +829,18 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + // closure-bind: srt fn ctx xs + let (ctx, list_arg) = if args.len() == 3 { + (Some(args[1].clone()), &args[2]) + } else { + (None, &args[1]) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("srt: second arg must be a list, got {:?}", other), + format!("srt: list arg must be a list, got {:?}", other), )); } }; @@ -842,7 +848,11 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { let mut keyed: Vec<(Value, Value)> = items .into_iter() .map(|item| { - let key = call_function(env, &fn_name, vec![item.clone()])?; + let call_args = match &ctx { + Some(c) => vec![item.clone(), c.clone()], + None => vec![item.clone()], + }; + let key = call_function(env, &fn_name, call_args)?; Ok((key, item)) }) .collect::>()?; @@ -1390,7 +1400,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { _ => None, } } - if builtin == Some(Builtin::Map) && args.len() == 2 { + if builtin == Some(Builtin::Map) && (args.len() == 2 || args.len() == 3) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -1400,22 +1410,32 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + // closure-bind: map fn ctx xs + let (ctx, list_arg) = if args.len() == 3 { + (Some(args[1].clone()), &args[2]) + } else { + (None, &args[1]) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("map: second arg must be a list, got {:?}", other), + format!("map: list arg must be a list, got {:?}", other), )); } }; let mut result = Vec::with_capacity(items.len()); for item in items { - result.push(call_function(env, &fn_name, vec![item])?); + let call_args = match &ctx { + Some(c) => vec![item, c.clone()], + None => vec![item], + }; + result.push(call_function(env, &fn_name, call_args)?); } return Ok(Value::List(result)); } - if builtin == Some(Builtin::Flt) && args.len() == 2 { + if builtin == Some(Builtin::Flt) && (args.len() == 2 || args.len() == 3) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -1425,18 +1445,27 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + let (ctx, list_arg) = if args.len() == 3 { + (Some(args[1].clone()), &args[2]) + } else { + (None, &args[1]) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("flt: second arg must be a list, got {:?}", other), + format!("flt: list arg must be a list, got {:?}", other), )); } }; let mut result = Vec::new(); for item in items { - match call_function(env, &fn_name, vec![item.clone()])? { + let call_args = match &ctx { + Some(c) => vec![item.clone(), c.clone()], + None => vec![item.clone()], + }; + match call_function(env, &fn_name, call_args)? { Value::Bool(true) => result.push(item), Value::Bool(false) => {} other => { @@ -1449,7 +1478,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { } return Ok(Value::List(result)); } - if builtin == Some(Builtin::Fld) && args.len() == 3 { + if builtin == Some(Builtin::Fld) && (args.len() == 3 || args.len() == 4) { let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| { RuntimeError::new( "ILO-R009", @@ -1459,18 +1488,28 @@ fn call_function(env: &mut Env, name: &str, args: Vec) -> Result { ), ) })?; - let items = match &args[1] { + // closure-bind: fld fn ctx xs init + let (ctx, list_arg, init) = if args.len() == 4 { + (Some(args[1].clone()), &args[2], args[3].clone()) + } else { + (None, &args[1], args[2].clone()) + }; + let items = match list_arg { Value::List(l) => l.clone(), other => { return Err(RuntimeError::new( "ILO-R009", - format!("fld: second arg must be a list, got {:?}", other), + format!("fld: list arg must be a list, got {:?}", other), )); } }; - let mut acc = args[2].clone(); + let mut acc = init; for item in items { - acc = call_function(env, &fn_name, vec![acc, item])?; + let call_args = match &ctx { + Some(c) => vec![acc, item, c.clone()], + None => vec![acc, item], + }; + acc = call_function(env, &fn_name, call_args)?; } return Ok(acc); } diff --git a/src/verify.rs b/src/verify.rs index 9dc39aa2..3e3b18f7 100644 --- a/src/verify.rs +++ b/src/verify.rs @@ -657,6 +657,42 @@ fn builtin_check_args( (Ty::Text, errors) } "srt" => { + if arg_types.len() == 3 { + // srt key-fn ctx xs — closure-bind variant: fn takes (elem, ctx) + if let Some(fn_ty) = arg_types.first() + && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!("'srt' key arg must be a function (F ...), got {fn_ty}"), + hint: Some("pass a function name: srt key-fn ctx xs".to_string()), + span, + is_warning: false, + }); + } + // fn must accept 2 args + if let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 2 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'srt' key fn must take 2 args (elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for srt fn ctx xs, fn must be: F a c b".to_string()), + span, + is_warning: false, + }); + } + let ret = match arg_types.get(2) { + Some(ty @ Ty::List(_)) => ty.clone(), + _ => Ty::Unknown, + }; + return (ret, errors); + } if arg_types.len() == 2 { // srt key-fn xs — sort by key function if let Some(fn_ty) = arg_types.first() @@ -915,7 +951,7 @@ fn builtin_check_args( } "map" => { // map fn:F a b xs:L a → L b - // First arg must be a function type; second must be a list. + // map fn:F a c b ctx:c xs:L a → L b (closure-bind variant) if let Some(fn_ty) = arg_types.first() && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) { @@ -928,6 +964,23 @@ fn builtin_check_args( is_warning: false, }); } + // closure-bind: fn must take 2 args (elem, ctx) + if arg_types.len() == 3 + && let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 2 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'map' fn must take 2 args (elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for map fn ctx xs, fn must be: F a c b".to_string()), + span, + is_warning: false, + }); + } // Return type: L of the function's return type, or L Unknown let ret_elem = match arg_types.first() { Some(Ty::Fn(_, ret)) => *ret.clone(), @@ -937,7 +990,7 @@ fn builtin_check_args( } "flt" => { // flt fn:F a b xs:L a → L a - // First arg: function returning bool; second: list. + // flt fn:F a c b ctx:c xs:L a → L a (closure-bind variant) if let Some(fn_ty) = arg_types.first() && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) { @@ -950,8 +1003,25 @@ fn builtin_check_args( is_warning: false, }); } - // Return type: same list type as input - let ret = match arg_types.get(1) { + if arg_types.len() == 3 + && let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 2 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'flt' fn must take 2 args (elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for flt fn ctx xs, fn must be: F a c b".to_string()), + span, + is_warning: false, + }); + } + // Return type: same list type as input (last arg position) + let list_idx = if arg_types.len() == 3 { 2 } else { 1 }; + let ret = match arg_types.get(list_idx) { Some(ty @ Ty::List(_)) => ty.clone(), _ => Ty::Unknown, }; @@ -959,7 +1029,7 @@ fn builtin_check_args( } "fld" => { // fld fn:F a b b xs:L a init:b → b - // First arg: function; second: list; third: initial accumulator. + // fld fn:F a c b b ctx:c xs:L a init:b → b (closure-bind variant) if let Some(fn_ty) = arg_types.first() && !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown) { @@ -972,8 +1042,25 @@ fn builtin_check_args( is_warning: false, }); } - // Return type: accumulator type (third arg) or function return type - let ret = match arg_types.get(2) { + if arg_types.len() == 4 + && let Some(Ty::Fn(params, _)) = arg_types.first() + && params.len() != 3 + { + errors.push(VerifyError { + code: "ILO-T013", + function: func_ctx.to_string(), + message: format!( + "'fld' fn must take 3 args (acc, elem, ctx) for closure-bind variant, got {} args", + params.len() + ), + hint: Some("for fld fn ctx xs init, fn must be: F b a c b".to_string()), + span, + is_warning: false, + }); + } + // Return type: accumulator type (last arg) or function return type + let init_idx = if arg_types.len() == 4 { 3 } else { 2 }; + let ret = match arg_types.get(init_idx) { Some(ty) if !matches!(ty, Ty::Unknown) => ty.clone(), _ => match arg_types.first() { Some(Ty::Fn(_, ret)) => *ret.clone(), @@ -1814,7 +1901,16 @@ impl VerifyContext { builtin_arity(callee).expect("is_builtin guarantees arity exists"); let arity_ok = if callee == "rnd" { args.is_empty() || args.len() == 2 - } else if callee == "srt" || callee == "rd" { + } else if callee == "srt" { + // srt xs / srt fn xs / srt fn ctx xs + args.len() == 1 || args.len() == 2 || args.len() == 3 + } else if callee == "map" || callee == "flt" { + // map fn xs / map fn ctx xs (closure-bind variant) + args.len() == 2 || args.len() == 3 + } else if callee == "fld" { + // fld fn xs init / fld fn ctx xs init (closure-bind variant) + args.len() == 3 || args.len() == 4 + } else if callee == "rd" { args.len() == 1 || args.len() == 2 } else if callee == "wr" { args.len() == 2 || args.len() == 3 @@ -1830,7 +1926,13 @@ impl VerifyContext { if !arity_ok { let arity_desc = if callee == "rnd" { "0 or 2".to_string() - } else if callee == "srt" || callee == "rd" || callee == "get" { + } else if callee == "srt" { + "1, 2, or 3".to_string() + } else if callee == "map" || callee == "flt" { + "2 or 3".to_string() + } else if callee == "fld" { + "3 or 4".to_string() + } else if callee == "rd" || callee == "get" { "1 or 2".to_string() } else if callee == "post" { "2 or 3".to_string() diff --git a/tests/regression_closure_bind.rs b/tests/regression_closure_bind.rs new file mode 100644 index 00000000..56e93738 --- /dev/null +++ b/tests/regression_closure_bind.rs @@ -0,0 +1,185 @@ +// Regression tests for the closure-bind HOF variant. +// +// Every HOF (`srt`, `map`, `flt`, `fld`) accepts an optional extra `ctx` +// argument that is forwarded to each invocation of the function. This lets +// agents pass external state (lookup tables, thresholds, accumulators) into +// the callback without bundling the state into per-element records. +// +// The verifier disambiguates by arity: +// srt fn xs (2-arg) vs. srt fn ctx xs (3-arg) +// map fn xs (2-arg) vs. map fn ctx xs (3-arg) +// flt fn xs (2-arg) vs. flt fn ctx xs (3-arg) +// fld fn xs init (3-arg) vs. fld fn ctx xs init (4-arg) +// +// VM and Cranelift JIT do not implement HOF dispatch at all (see +// regression_builtins_as_hof.rs), so closure-bind is exercised on the +// tree-walking interpreter only. + +use std::process::Command; + +fn ilo() -> Command { + Command::new(env!("CARGO_BIN_EXE_ilo")) +} + +fn write_src(name: &str, src: &str) -> std::path::PathBuf { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(0); + let n = COUNTER.fetch_add(1, Ordering::Relaxed); + let mut path = std::env::temp_dir(); + path.push(format!("ilo_cbind_{name}_{}_{n}.ilo", std::process::id())); + std::fs::write(&path, src).expect("write src"); + path +} + +fn run_ok(src: &str, entry: &str, args: &[&str]) -> String { + let path = write_src(entry, src); + let mut cmd = ilo(); + cmd.arg(&path).arg("--run-tree").arg(entry); + for a in args { + cmd.arg(a); + } + let out = cmd.output().expect("failed to run ilo"); + let _ = std::fs::remove_file(&path); + assert!( + out.status.success(), + "ilo failed for `{src}`: stderr={}", + String::from_utf8_lossy(&out.stderr) + ); + String::from_utf8_lossy(&out.stdout).trim().to_string() +} + +fn run_err(src: &str, entry: &str) -> String { + let path = write_src(entry, src); + let out = ilo() + .arg(&path) + .arg("--run-tree") + .arg(entry) + .output() + .expect("failed to run ilo"); + let _ = std::fs::remove_file(&path); + assert!( + !out.status.success(), + "expected failure but ilo succeeded for `{src}`" + ); + let mut s = String::from_utf8_lossy(&out.stderr).into_owned(); + s.push_str(&String::from_utf8_lossy(&out.stdout)); + s +} + +// ── srt: sort by external lookup map — the "top-N by magnitude" pattern ──── +// Sort symbols by an externally provided priority map (lower priority first). +// +// Without closure-bind, agents would have to fold the map into each list +// element (`[("a", 3), ("b", 1)]`) and write a key fn that pulls the second +// field — the per-program tax this feature removes. + +const SRT_BY_LOOKUP: &str = "\ +pri sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:99999} +top pri-map:M t n syms:L t>L t;srt pri pri-map syms"; + +#[test] +fn srt_with_external_lookup() { + // pri-map: c=1, a=2, b=3 → sorted by priority: c, a, b + let src = format!( + "{SRT_BY_LOOKUP}\nmain>L t;m=mset mmap \"a\" 2;m=mset m \"b\" 3;m=mset m \"c\" 1;\ + top m [\"a\",\"b\",\"c\"]" + ); + assert_eq!(run_ok(&src, "main", &[]), "[c, a, b]"); +} + +// ── map: enrich elements via external lookup ─────────────────────────────── + +const MAP_WITH_LOOKUP: &str = "\ +look sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:0} +prices pm:M t n syms:L t>L n;map look pm syms"; + +#[test] +fn map_with_lookup() { + let src = format!( + "{MAP_WITH_LOOKUP}\nmain>L n;m=mset mmap \"a\" 10;m=mset m \"b\" 20;\ + prices m [\"a\",\"b\",\"a\"]" + ); + assert_eq!(run_ok(&src, "main", &[]), "[10, 20, 10]"); +} + +// ── flt: filter by external threshold ────────────────────────────────────── + +const FLT_WITH_THRESHOLD: &str = "\ +big x:n thr:n>b;>=x thr +above t:n xs:L n>L n;flt big t xs"; + +#[test] +fn flt_with_threshold() { + assert_eq!( + run_ok( + &format!("{FLT_WITH_THRESHOLD}\nmain>L n;above 4 [1,5,3,8,2]"), + "main", + &[] + ), + "[5, 8]" + ); +} + +// ── fld: fold using external multiplier ──────────────────────────────────── + +const FLD_WITH_ACCUM: &str = "\ +add-scaled acc:n x:n k:n>n;+acc *x k +total k:n xs:L n>n;fld add-scaled k xs 0"; + +#[test] +fn fld_with_external_accumulator() { + // sum of [1,2,3] each scaled by 10 → 60 + assert_eq!( + run_ok( + &format!("{FLD_WITH_ACCUM}\nmain>n;total 10 [1,2,3]"), + "main", + &[] + ), + "60" + ); +} + +// ── 2-arg variants still work (no regression for existing programs) ──────── + +const SRT_2ARG: &str = "abs1 x:n>n;abs x\nf xs:L n>L n;srt abs1 xs"; + +#[test] +fn srt_2arg_unchanged() { + assert_eq!(run_ok(SRT_2ARG, "f", &["[-3,1,-5,2]"]), "[1, 2, -3, -5]"); +} + +// ── verifier: 3-arg srt with 1-arg fn is rejected ────────────────────────── + +#[test] +fn srt_3arg_rejects_1arg_fn() { + let src = "abs1 x:n>n;abs x\nf c:n xs:L n>L n;srt abs1 c xs"; + let err = run_err(src, "f"); + assert!( + err.contains("srt") && (err.contains("2 args") || err.contains("closure-bind")), + "expected verifier error about srt fn arity, got: {err}" + ); +} + +// ── verifier: 3-arg map with 1-arg fn is rejected ────────────────────────── + +#[test] +fn map_3arg_rejects_1arg_fn() { + let src = "abs1 x:n>n;abs x\nf c:n xs:L n>L n;map abs1 c xs"; + let err = run_err(src, "f"); + assert!( + err.contains("map") && (err.contains("2 args") || err.contains("closure-bind")), + "expected verifier error about map fn arity, got: {err}" + ); +} + +// ── verifier: 4-arg fld with 2-arg fn is rejected ────────────────────────── + +#[test] +fn fld_4arg_rejects_2arg_fn() { + let src = "add a:n b:n>n;+a b\nf c:n xs:L n>n;fld add c xs 0"; + let err = run_err(src, "f"); + assert!( + err.contains("fld") && (err.contains("3 args") || err.contains("closure-bind")), + "expected verifier error about fld fn arity, got: {err}" + ); +}