Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions examples/closure-bind.ilo
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
-- Closure-bind: HOFs (srt, map, flt, fld) accept an optional ctx arg
-- that is passed to every invocation of the fn.
--
-- Pattern shown: sort symbols by an external priority lookup map.
-- Without closure-bind, you'd have to zip the priority into each element
-- (records of {sym, pri}) before sorting, then unzip after — recurring
-- per-program tax that closure-bind removes.

-- Key fn: takes element + ctx (the lookup map). Missing keys sort last.
pri sym:t m:M t n>n;r=mget m sym;?r{n v:v;_:99999}

-- 3-arg srt: srt key-fn ctx xs
top pri-map:M t n syms:L t>L t;srt pri pri-map syms

-- Build a small priority map and sort by it.
main>L t;m=mset mmap "a" 2;m=mset m "b" 3;m=mset m "c" 1;top m ["a","b","c"]

-- map closure-bind: enrich each symbol with its price via lookup
look sym:t pm:M t n>n;r=mget pm sym;?r{n v:v;_:0}
prices pm:M t n syms:L t>L n;map look pm syms

prices-demo>L n;m=mset mmap "a" 10;m=mset m "b" 20;prices m ["a","b","a"]

-- engine-skip: vm
-- engine-skip: jit
-- engine-skip: cranelift
-- run: main
-- out: [c, a, b]
-- run: prices-demo
-- out: [10, 20, 10]
73 changes: 56 additions & 17 deletions src/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
)),
};
}
if builtin == Some(Builtin::Srt) && args.len() == 2 {
if builtin == Some(Builtin::Srt) && (args.len() == 2 || args.len() == 3) {
let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| {
RuntimeError::new(
"ILO-R009",
Expand All @@ -829,20 +829,30 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
),
)
})?;
let items = match &args[1] {
// closure-bind: srt fn ctx xs
let (ctx, list_arg) = if args.len() == 3 {
(Some(args[1].clone()), &args[2])
} else {
(None, &args[1])
};
let items = match list_arg {
Value::List(l) => l.clone(),
other => {
return Err(RuntimeError::new(
"ILO-R009",
format!("srt: second arg must be a list, got {:?}", other),
format!("srt: list arg must be a list, got {:?}", other),
));
}
};
// Compute keys for each item, then sort by key
let mut keyed: Vec<(Value, Value)> = items
.into_iter()
.map(|item| {
let key = call_function(env, &fn_name, vec![item.clone()])?;
let call_args = match &ctx {
Some(c) => vec![item.clone(), c.clone()],
None => vec![item.clone()],
};
let key = call_function(env, &fn_name, call_args)?;
Ok((key, item))
})
.collect::<Result<_>>()?;
Expand Down Expand Up @@ -1390,7 +1400,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
_ => None,
}
}
if builtin == Some(Builtin::Map) && args.len() == 2 {
if builtin == Some(Builtin::Map) && (args.len() == 2 || args.len() == 3) {
let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| {
RuntimeError::new(
"ILO-R009",
Expand All @@ -1400,22 +1410,32 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
),
)
})?;
let items = match &args[1] {
// closure-bind: map fn ctx xs
let (ctx, list_arg) = if args.len() == 3 {
(Some(args[1].clone()), &args[2])
} else {
(None, &args[1])
};
let items = match list_arg {
Value::List(l) => l.clone(),
other => {
return Err(RuntimeError::new(
"ILO-R009",
format!("map: second arg must be a list, got {:?}", other),
format!("map: list arg must be a list, got {:?}", other),
));
}
};
let mut result = Vec::with_capacity(items.len());
for item in items {
result.push(call_function(env, &fn_name, vec![item])?);
let call_args = match &ctx {
Some(c) => vec![item, c.clone()],
None => vec![item],
};
result.push(call_function(env, &fn_name, call_args)?);
}
return Ok(Value::List(result));
}
if builtin == Some(Builtin::Flt) && args.len() == 2 {
if builtin == Some(Builtin::Flt) && (args.len() == 2 || args.len() == 3) {
let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| {
RuntimeError::new(
"ILO-R009",
Expand All @@ -1425,18 +1445,27 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
),
)
})?;
let items = match &args[1] {
let (ctx, list_arg) = if args.len() == 3 {
(Some(args[1].clone()), &args[2])
} else {
(None, &args[1])
};
let items = match list_arg {
Value::List(l) => l.clone(),
other => {
return Err(RuntimeError::new(
"ILO-R009",
format!("flt: second arg must be a list, got {:?}", other),
format!("flt: list arg must be a list, got {:?}", other),
));
}
};
let mut result = Vec::new();
for item in items {
match call_function(env, &fn_name, vec![item.clone()])? {
let call_args = match &ctx {
Some(c) => vec![item.clone(), c.clone()],
None => vec![item.clone()],
};
match call_function(env, &fn_name, call_args)? {
Value::Bool(true) => result.push(item),
Value::Bool(false) => {}
other => {
Expand All @@ -1449,7 +1478,7 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
}
return Ok(Value::List(result));
}
if builtin == Some(Builtin::Fld) && args.len() == 3 {
if builtin == Some(Builtin::Fld) && (args.len() == 3 || args.len() == 4) {
let fn_name = resolve_fn_ref(&args[0]).ok_or_else(|| {
RuntimeError::new(
"ILO-R009",
Expand All @@ -1459,18 +1488,28 @@ fn call_function(env: &mut Env, name: &str, args: Vec<Value>) -> Result<Value> {
),
)
})?;
let items = match &args[1] {
// closure-bind: fld fn ctx xs init
let (ctx, list_arg, init) = if args.len() == 4 {
(Some(args[1].clone()), &args[2], args[3].clone())
} else {
(None, &args[1], args[2].clone())
};
let items = match list_arg {
Value::List(l) => l.clone(),
other => {
return Err(RuntimeError::new(
"ILO-R009",
format!("fld: second arg must be a list, got {:?}", other),
format!("fld: list arg must be a list, got {:?}", other),
));
}
};
let mut acc = args[2].clone();
let mut acc = init;
for item in items {
acc = call_function(env, &fn_name, vec![acc, item])?;
let call_args = match &ctx {
Some(c) => vec![acc, item, c.clone()],
None => vec![acc, item],
};
acc = call_function(env, &fn_name, call_args)?;
}
return Ok(acc);
}
Expand Down
120 changes: 111 additions & 9 deletions src/verify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,42 @@ fn builtin_check_args(
(Ty::Text, errors)
}
"srt" => {
if arg_types.len() == 3 {
// srt key-fn ctx xs — closure-bind variant: fn takes (elem, ctx)
if let Some(fn_ty) = arg_types.first()
&& !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown)
{
errors.push(VerifyError {
code: "ILO-T013",
function: func_ctx.to_string(),
message: format!("'srt' key arg must be a function (F ...), got {fn_ty}"),
hint: Some("pass a function name: srt key-fn ctx xs".to_string()),
span,
is_warning: false,
});
}
// fn must accept 2 args
if let Some(Ty::Fn(params, _)) = arg_types.first()
&& params.len() != 2
{
errors.push(VerifyError {
code: "ILO-T013",
function: func_ctx.to_string(),
message: format!(
"'srt' key fn must take 2 args (elem, ctx) for closure-bind variant, got {} args",
params.len()
),
hint: Some("for srt fn ctx xs, fn must be: F a c b".to_string()),
span,
is_warning: false,
});
}
let ret = match arg_types.get(2) {
Some(ty @ Ty::List(_)) => ty.clone(),
_ => Ty::Unknown,
};
return (ret, errors);
}
if arg_types.len() == 2 {
// srt key-fn xs — sort by key function
if let Some(fn_ty) = arg_types.first()
Expand Down Expand Up @@ -915,7 +951,7 @@ fn builtin_check_args(
}
"map" => {
// map fn:F a b xs:L a → L b
// First arg must be a function type; second must be a list.
// map fn:F a c b ctx:c xs:L a → L b (closure-bind variant)
if let Some(fn_ty) = arg_types.first()
&& !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown)
{
Expand All @@ -928,6 +964,23 @@ fn builtin_check_args(
is_warning: false,
});
}
// closure-bind: fn must take 2 args (elem, ctx)
if arg_types.len() == 3
&& let Some(Ty::Fn(params, _)) = arg_types.first()
&& params.len() != 2
{
errors.push(VerifyError {
code: "ILO-T013",
function: func_ctx.to_string(),
message: format!(
"'map' fn must take 2 args (elem, ctx) for closure-bind variant, got {} args",
params.len()
),
hint: Some("for map fn ctx xs, fn must be: F a c b".to_string()),
span,
is_warning: false,
});
}
// Return type: L of the function's return type, or L Unknown
let ret_elem = match arg_types.first() {
Some(Ty::Fn(_, ret)) => *ret.clone(),
Expand All @@ -937,7 +990,7 @@ fn builtin_check_args(
}
"flt" => {
// flt fn:F a b xs:L a → L a
// First arg: function returning bool; second: list.
// flt fn:F a c b ctx:c xs:L a → L a (closure-bind variant)
if let Some(fn_ty) = arg_types.first()
&& !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown)
{
Expand All @@ -950,16 +1003,33 @@ fn builtin_check_args(
is_warning: false,
});
}
// Return type: same list type as input
let ret = match arg_types.get(1) {
if arg_types.len() == 3
&& let Some(Ty::Fn(params, _)) = arg_types.first()
&& params.len() != 2
{
errors.push(VerifyError {
code: "ILO-T013",
function: func_ctx.to_string(),
message: format!(
"'flt' fn must take 2 args (elem, ctx) for closure-bind variant, got {} args",
params.len()
),
hint: Some("for flt fn ctx xs, fn must be: F a c b".to_string()),
span,
is_warning: false,
});
}
// Return type: same list type as input (last arg position)
let list_idx = if arg_types.len() == 3 { 2 } else { 1 };
let ret = match arg_types.get(list_idx) {
Some(ty @ Ty::List(_)) => ty.clone(),
_ => Ty::Unknown,
};
(ret, errors)
}
"fld" => {
// fld fn:F a b b xs:L a init:b → b
// First arg: function; second: list; third: initial accumulator.
// fld fn:F a c b b ctx:c xs:L a init:b → b (closure-bind variant)
if let Some(fn_ty) = arg_types.first()
&& !matches!(fn_ty, Ty::Fn(_, _) | Ty::Unknown)
{
Expand All @@ -972,8 +1042,25 @@ fn builtin_check_args(
is_warning: false,
});
}
// Return type: accumulator type (third arg) or function return type
let ret = match arg_types.get(2) {
if arg_types.len() == 4
&& let Some(Ty::Fn(params, _)) = arg_types.first()
&& params.len() != 3
{
errors.push(VerifyError {
code: "ILO-T013",
function: func_ctx.to_string(),
message: format!(
"'fld' fn must take 3 args (acc, elem, ctx) for closure-bind variant, got {} args",
params.len()
),
hint: Some("for fld fn ctx xs init, fn must be: F b a c b".to_string()),
span,
is_warning: false,
});
}
// Return type: accumulator type (last arg) or function return type
let init_idx = if arg_types.len() == 4 { 3 } else { 2 };
let ret = match arg_types.get(init_idx) {
Some(ty) if !matches!(ty, Ty::Unknown) => ty.clone(),
_ => match arg_types.first() {
Some(Ty::Fn(_, ret)) => *ret.clone(),
Expand Down Expand Up @@ -1814,7 +1901,16 @@ impl VerifyContext {
builtin_arity(callee).expect("is_builtin guarantees arity exists");
let arity_ok = if callee == "rnd" {
args.is_empty() || args.len() == 2
} else if callee == "srt" || callee == "rd" {
} else if callee == "srt" {
// srt xs / srt fn xs / srt fn ctx xs
args.len() == 1 || args.len() == 2 || args.len() == 3
} else if callee == "map" || callee == "flt" {
// map fn xs / map fn ctx xs (closure-bind variant)
args.len() == 2 || args.len() == 3
} else if callee == "fld" {
// fld fn xs init / fld fn ctx xs init (closure-bind variant)
args.len() == 3 || args.len() == 4
} else if callee == "rd" {
args.len() == 1 || args.len() == 2
} else if callee == "wr" {
args.len() == 2 || args.len() == 3
Expand All @@ -1830,7 +1926,13 @@ impl VerifyContext {
if !arity_ok {
let arity_desc = if callee == "rnd" {
"0 or 2".to_string()
} else if callee == "srt" || callee == "rd" || callee == "get" {
} else if callee == "srt" {
"1, 2, or 3".to_string()
} else if callee == "map" || callee == "flt" {
"2 or 3".to_string()
} else if callee == "fld" {
"3 or 4".to_string()
} else if callee == "rd" || callee == "get" {
"1 or 2".to_string()
} else if callee == "post" {
"2 or 3".to_string()
Expand Down
Loading
Loading