Skip to content

Commit

Permalink
feat: high order lambda closure captures
Browse files Browse the repository at this point in the history
Signed-off-by: peefy <xpf6677@163.com>
  • Loading branch information
Peefy committed Mar 7, 2024
1 parent 8cb11e2 commit f4b230f
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 35 deletions.
81 changes: 49 additions & 32 deletions kclvm/compiler/src/codegen/llvm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use kclvm_sema::pkgpath_without_prefix;
use kclvm_sema::plugin;

use crate::codegen::abi::Align;
use crate::codegen::{error as kcl_error, EmitOptions, INNER_LEVEL};
use crate::codegen::{error as kcl_error, EmitOptions};
use crate::codegen::{
traits::*, ENTRY_NAME, GLOBAL_VAL_ALIGNMENT, MODULE_NAME, PKG_INIT_FUNCTION_SUFFIX,
};
Expand Down Expand Up @@ -85,7 +85,7 @@ pub struct LLVMCodeGenContext<'ctx> {
pub imported: RefCell<HashSet<String>>,
pub local_vars: RefCell<HashSet<String>>,
pub schema_stack: RefCell<Vec<value::SchemaType>>,
pub lambda_stack: RefCell<Vec<bool>>,
pub lambda_stack: RefCell<Vec<usize>>,
pub schema_expr_stack: RefCell<Vec<()>>,
pub pkgpath_stack: RefCell<Vec<String>>,
pub filename_stack: RefCell<Vec<String>>,
Expand Down Expand Up @@ -1234,7 +1234,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
imported: RefCell::new(HashSet::new()),
local_vars: RefCell::new(HashSet::new()),
schema_stack: RefCell::new(vec![]),
lambda_stack: RefCell::new(vec![false]),
// 0 denotes the top global main function lambda.
lambda_stack: RefCell::new(vec![0]),
schema_expr_stack: RefCell::new(vec![]),
pkgpath_stack: RefCell::new(vec![String::from(MAIN_PKG_PATH)]),
filename_stack: RefCell::new(vec![String::from("")]),
Expand Down Expand Up @@ -1944,7 +1945,12 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
// Closure vars, 2 denotes the builtin scope and the global scope
let value = if i >= 1 && i < scopes_len - 2 {
closures_mut.insert(name.to_string(), *var);
let variables = last_scopes.variables.borrow();
let last_lambda_scope = *self
.lambda_stack
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG);
let variables = scopes[last_lambda_scope].variables.borrow();
let ptr = variables.get(value::LAMBDA_CLOSURE);
// Lambda closure
match ptr {
Expand Down Expand Up @@ -2036,11 +2042,15 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
Ok(value)
}

/// Get closure map in the current scope.
pub(crate) fn get_closure_map(&self) -> BasicValueEnum<'ctx> {
// Get closures in the current scope.
let closure_map = self.dict_value();
{
/// Get closure map in the current inner scope.
pub(crate) fn get_current_inner_scope_variable_map(&self) -> BasicValueEnum<'ctx> {
let var_map = {
let last_lambda_scope = *self
.lambda_stack
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG);
// Get variable map in the current scope.
let pkgpath = self.current_pkgpath();
let pkgpath = if !pkgpath.starts_with(PKG_PATH_PREFIX) && pkgpath != MAIN_PKG_PATH {
format!("{}{}", PKG_PATH_PREFIX, pkgpath)
Expand All @@ -2051,39 +2061,46 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
let scopes = pkg_scopes
.get(&pkgpath)
.unwrap_or_else(|| panic!("package {} is not found", pkgpath));
// Clouure variable must be inner of the global scope.
if scopes.len() > INNER_LEVEL {
let closures = scopes
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
.closures
.borrow();
// Curret scope vaiable.
let variables = scopes
.get(scopes.len() - INNER_LEVEL)
.expect(kcl_error::INTERNAL_ERROR_MSG)
.variables
.borrow();
// Transverse all scope and capture closures except the builtin amd global scope.
for (key, ptr) in &*closures {
if variables.contains_key(key) {
let value = self.builder.build_load(*ptr, "");
self.dict_insert_override_item(closure_map, key.as_str(), value);
let current_scope = scopes.len() - 1;
// Get last closure map.
let var_map = if current_scope >= last_lambda_scope && last_lambda_scope > 0 {
let variables = scopes[last_lambda_scope].variables.borrow();
let ptr = variables.get(value::LAMBDA_CLOSURE);
let var_map = match ptr {
Some(ptr) => self.builder.build_load(*ptr, ""),
None => self.dict_value(),
};
// Get variable map including schema in the current scope.
for i in last_lambda_scope..current_scope + 1 {
let variables = scopes
.get(i)
.expect(kcl_error::INTERNAL_ERROR_MSG)
.variables
.borrow();
for (key, ptr) in &*variables {
if key != value::LAMBDA_CLOSURE {
let value = self.builder.build_load(*ptr, "");
self.dict_insert_override_item(var_map, key.as_str(), value);
}
}
} // Curret scope vaiable.
}
}
}
var_map
} else {
self.dict_value()
};
var_map
};
// Capture schema `self` closure.
let is_in_schema = self.schema_stack.borrow().len() > 0;
if is_in_schema {
for shcmea_closure_name in value::SCHEMA_VARIABLE_LIST {
let value = self
.get_variable(shcmea_closure_name)
.expect(kcl_error::INTERNAL_ERROR_MSG);
self.dict_insert_override_item(closure_map, shcmea_closure_name, value);
self.dict_insert_override_item(var_map, shcmea_closure_name, value);
}
}
closure_map
var_map
}

/// Push a function call frame into the function stack
Expand Down
13 changes: 10 additions & 3 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
.borrow_mut()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
> 0
{
if self.resolve_variable(name) {
let org_value = self
Expand Down Expand Up @@ -2136,6 +2137,9 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
check_backtrack_stop!(self);
let pkgpath = &self.current_pkgpath();
let is_in_schema = self.schema_stack.borrow().len() > 0;
// Higher-order lambda requires capturing the current lambda closure variable
// as well as the closure of a more external scope.
let last_closure_map = self.get_current_inner_scope_variable_map();
let func_before_block = self.append_block("");
self.br(func_before_block);
// Use "pkgpath"+"kclvm_lambda" to name 'function' to prevent conflicts between lambdas with the same name in different packages
Expand All @@ -2146,7 +2150,8 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
));
// Enter the function
self.push_function(function);
self.lambda_stack.borrow_mut().push(true);
// Push the current lambda scope level in the lambda stack.
self.lambda_stack.borrow_mut().push(self.scope_level() + 1);
// Lambda function body
let block = self.context.append_basic_block(function, ENTRY_NAME);
self.builder.position_at_end(block);
Expand Down Expand Up @@ -2191,7 +2196,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
let closure = self.list_value();
// Use closure map in the laste scope to construct curret closure map.
// The default value of the closure map is `{}`.
self.list_append(closure, self.get_closure_map());
self.list_append(closure, last_closure_map);
let function = self.closure_value(function, closure);
self.leave_scope();
self.pop_function();
Expand Down Expand Up @@ -2446,6 +2451,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
.borrow_mut()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
> 0
{
let value = right_value.expect(kcl_error::INTERNAL_ERROR_MSG);
// If variable exists in the scope and update it, if not, add it to the scope.
Expand Down Expand Up @@ -2614,7 +2620,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
.lambda_stack
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG);
.expect(kcl_error::INTERNAL_ERROR_MSG)
> 0;
// Set config value for the schema attribute if the attribute is in the schema and
// it is not a local variable in the lambda function.
if self.scope_level() >= INNER_LEVEL
Expand Down
14 changes: 14 additions & 0 deletions test/grammar/lambda/closure_6/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
func = lambda config: {str:} {
x = 1
lambda {
y = 1
lambda {
z = 1
lambda {
x + y + z + config.key
}()
}()
}()
}

x = func({key = 1})
1 change: 1 addition & 0 deletions test/grammar/lambda/closure_6/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x: 4
14 changes: 14 additions & 0 deletions test/grammar/lambda/closure_7/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
func = lambda config: {str:} {
x = 1
lambda {
y = 1
lambda {
z = 1
lambda {
{value = x + y + z + config.key}
}()
}()
}()
}

x = func({key = 1})
2 changes: 2 additions & 0 deletions test/grammar/lambda/closure_7/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x:
value: 4
8 changes: 8 additions & 0 deletions test/grammar/lambda/closure_8/main.k
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
items = [1, 2, 3]
func = lambda config: {str:} {
[lambda {
config.key + i
}() for i in items]
}

x = func({key = 1})
8 changes: 8 additions & 0 deletions test/grammar/lambda/closure_8/stdout.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
items:
- 1
- 2
- 3
x:
- 2
- 3
- 4

0 comments on commit f4b230f

Please sign in to comment.