Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions java/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@ doc = false

[dependencies]
jni = "0.21.1"
prost = "0.11.8"
prost-types = "0.11.8"
serde_json = "1"
indexmap = "2.2.5"
anyhow = "1"
serde = { version = "1", features = ["derive"] }
kcl-lang = {path = "../"}
once_cell = "1.19.0"
lazy_static = "1.4.0"
kclvm-parser = { git = "https://github.com/peefy/KCLVM", version = "0.8.1" }
kclvm-sema = { git = "https://github.com/peefy/KCLVM", version = "0.8.1" }
kclvm-api = { git = "https://github.com/peefy/KCLVM", version = "0.8.1" }
45 changes: 44 additions & 1 deletion java/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,31 @@
extern crate anyhow;
extern crate jni;
extern crate kcl_lang;
extern crate kclvm_api;
extern crate kclvm_parser;
extern crate kclvm_sema;
extern crate lazy_static;
extern crate once_cell;
extern crate prost;

use anyhow::Result;
use jni::objects::{JByteArray, JClass, JObject};
use jni::sys::jbyteArray;
use jni::JNIEnv;
use kcl_lang::API;
use kclvm_api::gpyrpc::LoadPackageArgs;
use kclvm_api::service::KclvmServiceImpl;
use kclvm_parser::KCLModuleCache;
use kclvm_sema::resolver::scope::KCLScopeCache;
use lazy_static::lazy_static;
use once_cell::sync::OnceCell;
use prost::Message;
use std::sync::Mutex;

lazy_static! {
static ref API_INSTANCE: Mutex<OnceCell<kcl_lang::API>> = Mutex::new(OnceCell::new());
static ref API_INSTANCE: Mutex<OnceCell<API>> = Mutex::new(OnceCell::new());
static ref MODULE_CACHE: Mutex<OnceCell<KCLModuleCache>> = Mutex::new(OnceCell::new());
static ref SCOPE_CACHE: Mutex<OnceCell<KCLScopeCache>> = Mutex::new(OnceCell::new());
}

#[no_mangle]
Expand All @@ -29,6 +41,18 @@ pub extern "system" fn Java_com_kcl_api_API_callNative(
})
}

#[no_mangle]
pub extern "system" fn Java_com_kcl_api_API_loadPackageWithCache(
mut env: JNIEnv,
_: JClass,
args: JByteArray,
) -> jbyteArray {
intern_load_package_with_cache(&mut env, args).unwrap_or_else(|e| {
let _ = throw(&mut env, e);
JObject::default().into_raw()
})
}

fn intern_call_native(env: &mut JNIEnv, name: JByteArray, args: JByteArray) -> Result<jbyteArray> {
let binding = API_INSTANCE.lock().unwrap();
let api = binding.get_or_init(|| kcl_lang::API::new().expect("Failed to create API instance"));
Expand All @@ -39,6 +63,25 @@ fn intern_call_native(env: &mut JNIEnv, name: JByteArray, args: JByteArray) -> R
Ok(j_byte_array.into_raw())
}

/// This is a stateful API, so we avoid serialization overhead and directly use JVM
/// to instantiate global variables here.
fn intern_load_package_with_cache(env: &mut JNIEnv, args: JByteArray) -> Result<jbyteArray> {
// AST module cache
let binding = MODULE_CACHE.lock().unwrap();
let module_cache = binding.get_or_init(|| KCLModuleCache::default());
// Resolver scope cache
let binding = SCOPE_CACHE.lock().unwrap();
let scope_cache = binding.get_or_init(|| KCLScopeCache::default());
// Load package arguments from protobuf bytes.
let args = env.convert_byte_array(args)?;
let args: LoadPackageArgs = <LoadPackageArgs as Message>::decode(args.as_ref())?;
let svc = KclvmServiceImpl::default();
// Call load package API and decode the result to protobuf bytes.
let packages = svc.load_package_with_cache(&args, module_cache.clone(), scope_cache.clone())?;
let j_byte_array = env.byte_array_from_slice(&packages.encode_to_vec())?;
Ok(j_byte_array.into_raw())
}

fn throw(env: &mut JNIEnv, error: anyhow::Error) -> jni::errors::Result<()> {
env.throw(("java/lang/Exception", error.to_string()))
}
42 changes: 42 additions & 0 deletions java/src/main/java/com/kcl/api/API.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ private static String bundledLibraryPath() {
}

private native byte[] callNative(byte[] call, byte[] args);
private native byte[] loadPackageWithCache(byte[] args);

public API() {
}
Expand Down Expand Up @@ -132,6 +133,39 @@ public LoadPackage_Result loadPackage(LoadPackage_Args args) throws Exception {
return LoadPackage_Result.parseFrom(call("KclvmService.LoadPackage", args.toByteArray()));
}

/**
* Loads KCL package with the internal cache and returns the AST, symbol, type, definition information. *
*
* <p>
* Example usage:
*
* <pre>
* {@code
* import com.kcl.api.*;
*
* API api = new API();
* LoadPackage_Result result = api.loadPackageWithCache(
* LoadPackage_Args.newBuilder().setResolveAst(true).setParseArgs(
* ParseProgram_Args.newBuilder().addPaths("/path/to/kcl.k").build())
* .build());
* result.getSymbolsMap().values().forEach(s -> System.out.println(s));
* }
* </pre>
*
* @param args
* the arguments specifying the file paths to be parsed and resolved.
*
* @return the result of parsing the program and parse errors, type errors, including the AST in JSON format and
* symbol, type and definition information.
*
* @throws Exception
* if an error occurs during the remote procedure call.
*/
@Override
public LoadPackage_Result loadPackageWithCache(LoadPackage_Args args) throws Exception {
return LoadPackage_Result.parseFrom(callLoadPackageWithCache(args.toByteArray()));
}

public ExecProgram_Result execProgram(ExecProgram_Args args) throws Exception {
return ExecProgram_Result.parseFrom(call("KclvmService.ExecProgram", args.toByteArray()));
}
Expand Down Expand Up @@ -194,6 +228,14 @@ private byte[] call(String name, byte[] args) throws Exception {
return result;
}

private byte[] callLoadPackageWithCache(byte[] args) throws Exception {
byte[] result = loadPackageWithCache(args);
if (result != null && startsWith(result, "ERROR")) {
throw new java.lang.Error(result.toString());
}
return result;
}

static boolean startsWith(byte[] array, String prefix) {
byte[] prefixBytes = prefix.getBytes();
if (array.length < prefixBytes.length) {
Expand Down
3 changes: 3 additions & 0 deletions java/src/main/java/com/kcl/api/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ public interface Service {
// Loads KCL package and returns the AST, symbol, type, definition information.
LoadPackage_Result loadPackage(LoadPackage_Args args) throws Exception;

// Loads KCL package and returns the AST, symbol, type, definition information.
LoadPackage_Result loadPackageWithCache(LoadPackage_Args args) throws Exception;

// Execute KCL file with args
ExecProgram_Result execProgram(ExecProgram_Args args) throws Exception;

Expand Down
37 changes: 37 additions & 0 deletions java/src/test/java/com/kcl/LoadPackageTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,41 @@ public void testProgramSymbols() throws Exception {
// Query Scope using the symbol
Assert.assertEquals(SematicUtil.findScopeBySymbol(result, appSymbolIndex).getDefsCount(), 2);
}

@Test
public void testProgramSymbolsWithCache() throws Exception {
// API instance
API api = new API();
// Note call `loadPackageWithCache` here.
LoadPackage_Result result = api.loadPackageWithCache(LoadPackage_Args.newBuilder().setResolveAst(true)
.setWithAstIndex(true)
.setParseArgs(ParseProgram_Args.newBuilder().addPaths("./src/test_data/schema.k").build()).build());
// Get parse errors
Assert.assertEquals(result.getParseErrorsList().size(), 0);
// Get Type errors
Assert.assertEquals(result.getTypeErrorsList().size(), 0);
// Get AST
Program program = JsonUtil.deserializeProgram(result.getProgram());
Assert.assertTrue(program.getRoot().contains("test_data"));
// Variable definitions in the main scope.
Scope mainScope = SematicUtil.findMainPackageScope(result);
// Child scopes of the main scope.
Assert.assertEquals(mainScope.getChildrenList().size(), 2);
// Mapping AST node to Symbol type
NodeRef<Stmt> stmt = program.getFirstModule().getBody().get(0);
Assert.assertTrue(SematicUtil.findSymbolByAstId(result, stmt.getId()).getName().contains("pkg"));
// Mapping symbol to AST node
SymbolIndex appSymbolIndex = mainScope.getDefs(1);
Symbol appSymbol = SematicUtil.findSymbol(result, appSymbolIndex);
Assert.assertEquals(appSymbol.getTy().getSchemaName(), "AppConfig");
// Query type symbol using variable type.
String schemaFullName = appSymbol.getTy().getPkgPath() + "." + appSymbol.getTy().getSchemaName();
Symbol appConfigSymbol = SematicUtil.findSymbol(result,
result.getFullyQualifiedNameMapOrDefault(schemaFullName, null));
Assert.assertEquals(appConfigSymbol.getTy().getSchemaName(), "AppConfig");
// Query AST node using the symbol
Assert.assertNotNull(SematicUtil.findNodeBySymbol(result, appSymbolIndex));
// Query Scope using the symbol
Assert.assertEquals(SematicUtil.findScopeBySymbol(result, appSymbolIndex).getDefsCount(), 2);
}
}