Skip to content

Commit

Permalink
feat: high order lambda closure captures (#1116)
Browse files Browse the repository at this point in the history
Signed-off-by: peefy <xpf6677@163.com>
  • Loading branch information
Peefy authored Mar 8, 2024
1 parent f11ddd4 commit 03005f3
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 94 deletions.
191 changes: 118 additions & 73 deletions kclvm/compiler/src/codegen/llvm/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use kclvm_sema::pkgpath_without_prefix;
use kclvm_sema::plugin;

use crate::codegen::abi::Align;
use crate::codegen::{error as kcl_error, EmitOptions, INNER_LEVEL};
use crate::codegen::{error as kcl_error, EmitOptions};
use crate::codegen::{
traits::*, ENTRY_NAME, GLOBAL_VAL_ALIGNMENT, MODULE_NAME, PKG_INIT_FUNCTION_SUFFIX,
};
Expand Down Expand Up @@ -85,7 +85,7 @@ pub struct LLVMCodeGenContext<'ctx> {
pub imported: RefCell<HashSet<String>>,
pub local_vars: RefCell<HashSet<String>>,
pub schema_stack: RefCell<Vec<value::SchemaType>>,
pub lambda_stack: RefCell<Vec<bool>>,
pub lambda_stack: RefCell<Vec<usize>>,
pub schema_expr_stack: RefCell<Vec<()>>,
pub pkgpath_stack: RefCell<Vec<String>>,
pub filename_stack: RefCell<Vec<String>>,
Expand Down Expand Up @@ -1234,7 +1234,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
imported: RefCell::new(HashSet::new()),
local_vars: RefCell::new(HashSet::new()),
schema_stack: RefCell::new(vec![]),
lambda_stack: RefCell::new(vec![false]),
// 0 denotes the top global main function lambda.
lambda_stack: RefCell::new(vec![0]),
schema_expr_stack: RefCell::new(vec![]),
pkgpath_stack: RefCell::new(vec![String::from(MAIN_PKG_PATH)]),
filename_stack: RefCell::new(vec![String::from("")]),
Expand Down Expand Up @@ -1634,12 +1635,12 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
pub fn store_variable_in_current_scope(&self, name: &str, value: BasicValueEnum<'ctx>) -> bool {
// Find argument name in the scope
let current_pkgpath = self.current_pkgpath();
let mut pkg_scopes = self.pkg_scopes.borrow_mut();
let pkg_scopes = self.pkg_scopes.borrow();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let scopes = pkg_scopes.get(&current_pkgpath).expect(&msg);
let index = scopes.len() - 1;
let variables_mut = scopes[index].variables.borrow_mut();
if let Some(var) = variables_mut.get(&name.to_string()) {
let variables = scopes[index].variables.borrow();
if let Some(var) = variables.get(&name.to_string()) {
self.builder.build_store(*var, value);
return true;
}
Expand All @@ -1650,13 +1651,13 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
pub fn store_variable(&self, name: &str, value: BasicValueEnum<'ctx>) -> bool {
// Find argument name in the scope
let current_pkgpath = self.current_pkgpath();
let mut pkg_scopes = self.pkg_scopes.borrow_mut();
let pkg_scopes = self.pkg_scopes.borrow();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let scopes = pkg_scopes.get(&current_pkgpath).expect(&msg);
for i in 0..scopes.len() {
let index = scopes.len() - i - 1;
let variables_mut = scopes[index].variables.borrow_mut();
if let Some(var) = variables_mut.get(&name.to_string()) {
let variables = scopes[index].variables.borrow();
if let Some(var) = variables.get(&name.to_string()) {
self.builder.build_store(*var, value);
return true;
}
Expand All @@ -1668,14 +1669,14 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
pub fn resolve_variable(&self, name: &str) -> bool {
// Find argument name in the scope
let current_pkgpath = self.current_pkgpath();
let mut pkg_scopes = self.pkg_scopes.borrow_mut();
let pkg_scopes = self.pkg_scopes.borrow();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let scopes = pkg_scopes.get(&current_pkgpath).expect(&msg);
let mut existed = false;
for i in 0..scopes.len() {
let index = scopes.len() - i - 1;
let variables_mut = scopes[index].variables.borrow_mut();
if variables_mut.get(&name.to_string()).is_some() {
let variables = scopes[index].variables.borrow();
if variables.get(&name.to_string()).is_some() {
existed = true;
break;
}
Expand Down Expand Up @@ -1729,8 +1730,8 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
let scopes = pkg_scopes.get_mut(&current_pkgpath).expect(&msg);
let mut existed = false;
if let Some(last) = scopes.last_mut() {
let variables_mut = last.variables.borrow_mut();
if let Some(var) = variables_mut.get(&name.to_string()) {
let variables = last.variables.borrow();
if let Some(var) = variables.get(&name.to_string()) {
self.builder.build_store(*var, value);
existed = true;
}
Expand Down Expand Up @@ -1791,7 +1792,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
self.builder.position_at_end(then_block);
let target_attr = self
.target_vars
.borrow_mut()
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
.clone();
Expand Down Expand Up @@ -1851,7 +1852,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {

/// Get the variable value named `name` from the scope named `pkgpath`, return Err when not found
pub fn get_variable_in_pkgpath(&self, name: &str, pkgpath: &str) -> CompileResult<'ctx> {
let pkg_scopes = self.pkg_scopes.borrow_mut();
let pkg_scopes = self.pkg_scopes.borrow();
let pkgpath =
if !pkgpath.starts_with(kclvm_runtime::PKG_PATH_PREFIX) && pkgpath != MAIN_PKG_PATH {
format!("{}{}", kclvm_runtime::PKG_PATH_PREFIX, pkgpath)
Expand Down Expand Up @@ -1935,32 +1936,38 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
.unwrap_or_else(|| panic!("package {} is not found", pkgpath));
// Scopes 0 is builtin scope, Scopes 1 is the global scope, Scopes 2~ are the local scopes
let scopes_len = scopes.len();
let last_scopes = scopes.last().expect(kcl_error::INTERNAL_ERROR_MSG);
let mut closures_mut = last_scopes.closures.borrow_mut();
for i in 0..scopes_len {
let index = scopes_len - i - 1;
let variables_mut = scopes[index].variables.borrow_mut();
if let Some(var) = variables_mut.get(&name.to_string()) {
// Closure vars, 2 denotes the builtin scope and the global scope
let variables = scopes[index].variables.borrow();
if let Some(var) = variables.get(&name.to_string()) {
// Closure vars, 2 denotes the builtin scope and the global scope, here is a closure scope.
let value = if i >= 1 && i < scopes_len - 2 {
closures_mut.insert(name.to_string(), *var);
let variables = last_scopes.variables.borrow();
let ptr = variables.get(value::LAMBDA_CLOSURE);
// Lambda closure
match ptr {
Some(ptr) => {
let closure_map = self.builder.build_load(*ptr, "");
let string_ptr_value = self.native_global_string(name, "").into();
self.build_call(
&ApiFunc::kclvm_dict_get_value.name(),
&[
self.current_runtime_ctx_ptr(),
closure_map,
string_ptr_value,
],
)
let last_lambda_scope = self.last_lambda_scope();
// Local scope variable
if index >= last_lambda_scope {
self.builder.build_load(*var, name)
} else {
// Outer lamba closure
let variables = scopes[last_lambda_scope].variables.borrow();
let ptr = variables.get(value::LAMBDA_CLOSURE);
// Lambda closure
match ptr {
Some(ptr) => {
let closure_map = self.builder.build_load(*ptr, "");
let string_ptr_value =
self.native_global_string(name, "").into();
// Not a closure, mapbe a local variale
self.build_call(
&ApiFunc::kclvm_dict_get_value.name(),
&[
self.current_runtime_ctx_ptr(),
closure_map,
string_ptr_value,
],
)
}
None => self.builder.build_load(*var, name),
}
None => self.builder.build_load(*var, name),
}
} else {
self.builder.build_load(*var, name)
Expand Down Expand Up @@ -2010,9 +2017,9 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
// User module external variable
let external_var_name = format!("${}.${}", pkgpath_without_prefix!(ext_pkgpath), name);
let current_pkgpath = self.current_pkgpath();
let modules = self.modules.borrow_mut();
let modules = self.modules.borrow();
let msg = format!("pkgpath {} is not found", current_pkgpath);
let module = modules.get(&current_pkgpath).expect(&msg).borrow_mut();
let module = modules.get(&current_pkgpath).expect(&msg).borrow();
let tpe = self.value_ptr_type();
let mut global_var_maps = self.global_vars.borrow_mut();
let pkgpath = self.current_pkgpath();
Expand All @@ -2036,11 +2043,11 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
Ok(value)
}

/// Get closure map in the current scope.
pub(crate) fn get_closure_map(&self) -> BasicValueEnum<'ctx> {
// Get closures in the current scope.
let closure_map = self.dict_value();
{
/// Get closure map in the current inner scope.
pub(crate) fn get_current_inner_scope_variable_map(&self) -> BasicValueEnum<'ctx> {
let var_map = {
let last_lambda_scope = self.last_lambda_scope();
// Get variable map in the current scope.
let pkgpath = self.current_pkgpath();
let pkgpath = if !pkgpath.starts_with(PKG_PATH_PREFIX) && pkgpath != MAIN_PKG_PATH {
format!("{}{}", PKG_PATH_PREFIX, pkgpath)
Expand All @@ -2051,39 +2058,77 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
let scopes = pkg_scopes
.get(&pkgpath)
.unwrap_or_else(|| panic!("package {} is not found", pkgpath));
// Clouure variable must be inner of the global scope.
if scopes.len() > INNER_LEVEL {
let closures = scopes
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
.closures
.borrow();
// Curret scope vaiable.
let variables = scopes
.get(scopes.len() - INNER_LEVEL)
.expect(kcl_error::INTERNAL_ERROR_MSG)
.variables
.borrow();
// Transverse all scope and capture closures except the builtin amd global scope.
for (key, ptr) in &*closures {
if variables.contains_key(key) {
let value = self.builder.build_load(*ptr, "");
self.dict_insert_override_item(closure_map, key.as_str(), value);
let current_scope = scopes.len() - 1;
// Get last closure map.
let var_map = if current_scope >= last_lambda_scope && last_lambda_scope > 0 {
let variables = scopes[last_lambda_scope].variables.borrow();
let ptr = variables.get(value::LAMBDA_CLOSURE);
let var_map = match ptr {
Some(ptr) => self.builder.build_load(*ptr, ""),
None => self.dict_value(),
};
// Get variable map including schema in the current scope.
for i in last_lambda_scope..current_scope + 1 {
let variables = scopes
.get(i)
.expect(kcl_error::INTERNAL_ERROR_MSG)
.variables
.borrow();
for (key, ptr) in &*variables {
if key != value::LAMBDA_CLOSURE {
let value = self.builder.build_load(*ptr, "");
self.dict_insert_override_item(var_map, key.as_str(), value);
}
}
} // Curret scope vaiable.
}
}
}
var_map
} else {
self.dict_value()
};
var_map
};
// Capture schema `self` closure.
let is_in_schema = self.schema_stack.borrow().len() > 0;
if is_in_schema {
for shcmea_closure_name in value::SCHEMA_VARIABLE_LIST {
let value = self
.get_variable(shcmea_closure_name)
.expect(kcl_error::INTERNAL_ERROR_MSG);
self.dict_insert_override_item(closure_map, shcmea_closure_name, value);
self.dict_insert_override_item(var_map, shcmea_closure_name, value);
}
}
closure_map
var_map
}

/// Push a lambda definition scope into the lambda stack
#[inline]
pub fn push_lambda(&self, scope: usize) {
self.lambda_stack.borrow_mut().push(scope);
}

/// Pop a lambda definition scope.
#[inline]
pub fn pop_lambda(&self) {
self.lambda_stack.borrow_mut().pop();
}

#[inline]
pub fn is_in_lambda(&self) -> bool {
*self
.lambda_stack
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
> 0
}

#[inline]
pub fn last_lambda_scope(&self) -> usize {
*self
.lambda_stack
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
}

/// Push a function call frame into the function stack
Expand Down Expand Up @@ -2111,9 +2156,9 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
/// Plan globals to a json string
pub fn globals_to_json_str(&self) -> BasicValueEnum<'ctx> {
let current_pkgpath = self.current_pkgpath();
let mut pkg_scopes = self.pkg_scopes.borrow_mut();
let pkg_scopes = self.pkg_scopes.borrow();
let scopes = pkg_scopes
.get_mut(&current_pkgpath)
.get(&current_pkgpath)
.unwrap_or_else(|| panic!("pkgpath {} is not found", current_pkgpath));
// The global scope.
let scope = scopes.last().expect(kcl_error::INTERNAL_ERROR_MSG);
Expand Down
31 changes: 10 additions & 21 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,7 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
let value = self
.walk_schema_expr(&unification_stmt.value.node)
.expect(kcl_error::COMPILE_ERROR_MSG);
if self.scope_level() == GLOBAL_LEVEL
|| *self
.lambda_stack
.borrow_mut()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
{
if self.scope_level() == GLOBAL_LEVEL || self.is_in_lambda() {
if self.resolve_variable(name) {
let org_value = self
.walk_identifier_with_ctx(
Expand Down Expand Up @@ -2136,6 +2130,9 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
check_backtrack_stop!(self);
let pkgpath = &self.current_pkgpath();
let is_in_schema = self.schema_stack.borrow().len() > 0;
// Higher-order lambda requires capturing the current lambda closure variable
// as well as the closure of a more external scope.
let last_closure_map = self.get_current_inner_scope_variable_map();
let func_before_block = self.append_block("");
self.br(func_before_block);
// Use "pkgpath"+"kclvm_lambda" to name 'function' to prevent conflicts between lambdas with the same name in different packages
Expand All @@ -2146,7 +2143,8 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
));
// Enter the function
self.push_function(function);
self.lambda_stack.borrow_mut().push(true);
// Push the current lambda scope level in the lambda stack.
self.push_lambda(self.scope_level() + 1);
// Lambda function body
let block = self.context.append_basic_block(function, ENTRY_NAME);
self.builder.position_at_end(block);
Expand Down Expand Up @@ -2191,11 +2189,11 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
let closure = self.list_value();
// Use closure map in the laste scope to construct curret closure map.
// The default value of the closure map is `{}`.
self.list_append(closure, self.get_closure_map());
self.list_append(closure, last_closure_map);
let function = self.closure_value(function, closure);
self.leave_scope();
self.pop_function();
self.lambda_stack.borrow_mut().pop();
self.pop_lambda();
Ok(function)
}

Expand Down Expand Up @@ -2441,12 +2439,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
right_value.expect(kcl_error::INTERNAL_ERROR_MSG),
);
// Local variables including schema/rule/lambda
} else if *self
.lambda_stack
.borrow_mut()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG)
{
} else if self.is_in_lambda() {
let value = right_value.expect(kcl_error::INTERNAL_ERROR_MSG);
// If variable exists in the scope and update it, if not, add it to the scope.
if !self.store_variable_in_current_scope(name, value) {
Expand Down Expand Up @@ -2610,11 +2603,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
let local_vars = self.local_vars.borrow_mut();
local_vars.contains(name)
};
let is_in_lambda = *self
.lambda_stack
.borrow()
.last()
.expect(kcl_error::INTERNAL_ERROR_MSG);
let is_in_lambda = self.is_in_lambda();
// Set config value for the schema attribute if the attribute is in the schema and
// it is not a local variable in the lambda function.
if self.scope_level() >= INNER_LEVEL
Expand Down
Loading

0 comments on commit 03005f3

Please sign in to comment.