diff --git a/java/Cargo.toml b/java/Cargo.toml index 72c992e5..2c3320c9 100644 --- a/java/Cargo.toml +++ b/java/Cargo.toml @@ -9,7 +9,15 @@ doc = false [dependencies] jni = "0.21.1" +prost = "0.11.8" +prost-types = "0.11.8" +serde_json = "1" +indexmap = "2.2.5" anyhow = "1" +serde = { version = "1", features = ["derive"] } kcl-lang = {path = "../"} once_cell = "1.19.0" lazy_static = "1.4.0" +kclvm-parser = { git = "https://github.com/peefy/KCLVM", version = "0.8.1" } +kclvm-sema = { git = "https://github.com/peefy/KCLVM", version = "0.8.1" } +kclvm-api = { git = "https://github.com/peefy/KCLVM", version = "0.8.1" } diff --git a/java/src/lib.rs b/java/src/lib.rs index 28c58ddb..614c2f9f 100644 --- a/java/src/lib.rs +++ b/java/src/lib.rs @@ -1,19 +1,31 @@ extern crate anyhow; extern crate jni; extern crate kcl_lang; +extern crate kclvm_api; +extern crate kclvm_parser; +extern crate kclvm_sema; extern crate lazy_static; extern crate once_cell; +extern crate prost; use anyhow::Result; use jni::objects::{JByteArray, JClass, JObject}; use jni::sys::jbyteArray; use jni::JNIEnv; +use kcl_lang::API; +use kclvm_api::gpyrpc::LoadPackageArgs; +use kclvm_api::service::KclvmServiceImpl; +use kclvm_parser::KCLModuleCache; +use kclvm_sema::resolver::scope::KCLScopeCache; use lazy_static::lazy_static; use once_cell::sync::OnceCell; +use prost::Message; use std::sync::Mutex; lazy_static! { - static ref API_INSTANCE: Mutex> = Mutex::new(OnceCell::new()); + static ref API_INSTANCE: Mutex> = Mutex::new(OnceCell::new()); + static ref MODULE_CACHE: Mutex> = Mutex::new(OnceCell::new()); + static ref SCOPE_CACHE: Mutex> = Mutex::new(OnceCell::new()); } #[no_mangle] @@ -29,6 +41,18 @@ pub extern "system" fn Java_com_kcl_api_API_callNative( }) } +#[no_mangle] +pub extern "system" fn Java_com_kcl_api_API_loadPackageWithCache( + mut env: JNIEnv, + _: JClass, + args: JByteArray, +) -> jbyteArray { + intern_load_package_with_cache(&mut env, args).unwrap_or_else(|e| { + let _ = throw(&mut env, e); + JObject::default().into_raw() + }) +} + fn intern_call_native(env: &mut JNIEnv, name: JByteArray, args: JByteArray) -> Result { let binding = API_INSTANCE.lock().unwrap(); let api = binding.get_or_init(|| kcl_lang::API::new().expect("Failed to create API instance")); @@ -39,6 +63,25 @@ fn intern_call_native(env: &mut JNIEnv, name: JByteArray, args: JByteArray) -> R Ok(j_byte_array.into_raw()) } +/// This is a stateful API, so we avoid serialization overhead and directly use JVM +/// to instantiate global variables here. +fn intern_load_package_with_cache(env: &mut JNIEnv, args: JByteArray) -> Result { + // AST module cache + let binding = MODULE_CACHE.lock().unwrap(); + let module_cache = binding.get_or_init(|| KCLModuleCache::default()); + // Resolver scope cache + let binding = SCOPE_CACHE.lock().unwrap(); + let scope_cache = binding.get_or_init(|| KCLScopeCache::default()); + // Load package arguments from protobuf bytes. + let args = env.convert_byte_array(args)?; + let args: LoadPackageArgs = ::decode(args.as_ref())?; + let svc = KclvmServiceImpl::default(); + // Call load package API and decode the result to protobuf bytes. + let packages = svc.load_package_with_cache(&args, module_cache.clone(), scope_cache.clone())?; + let j_byte_array = env.byte_array_from_slice(&packages.encode_to_vec())?; + Ok(j_byte_array.into_raw()) +} + fn throw(env: &mut JNIEnv, error: anyhow::Error) -> jni::errors::Result<()> { env.throw(("java/lang/Exception", error.to_string())) } diff --git a/java/src/main/java/com/kcl/api/API.java b/java/src/main/java/com/kcl/api/API.java index 475f5802..4f6f1606 100644 --- a/java/src/main/java/com/kcl/api/API.java +++ b/java/src/main/java/com/kcl/api/API.java @@ -56,6 +56,7 @@ private static String bundledLibraryPath() { } private native byte[] callNative(byte[] call, byte[] args); + private native byte[] loadPackageWithCache(byte[] args); public API() { } @@ -132,6 +133,39 @@ public LoadPackage_Result loadPackage(LoadPackage_Args args) throws Exception { return LoadPackage_Result.parseFrom(call("KclvmService.LoadPackage", args.toByteArray())); } + /** + * Loads KCL package with the internal cache and returns the AST, symbol, type, definition information. * + * + *

+ * Example usage: + * + *

+     * {@code
+     * import com.kcl.api.*;
+     *
+     * API api = new API();
+     * LoadPackage_Result result = api.loadPackageWithCache(
+     *     LoadPackage_Args.newBuilder().setResolveAst(true).setParseArgs(
+     *     ParseProgram_Args.newBuilder().addPaths("/path/to/kcl.k").build())
+     *     .build());
+     * result.getSymbolsMap().values().forEach(s -> System.out.println(s));
+     * }
+     * 
+ * + * @param args + * the arguments specifying the file paths to be parsed and resolved. + * + * @return the result of parsing the program and parse errors, type errors, including the AST in JSON format and + * symbol, type and definition information. + * + * @throws Exception + * if an error occurs during the remote procedure call. + */ + @Override + public LoadPackage_Result loadPackageWithCache(LoadPackage_Args args) throws Exception { + return LoadPackage_Result.parseFrom(callLoadPackageWithCache(args.toByteArray())); + } + public ExecProgram_Result execProgram(ExecProgram_Args args) throws Exception { return ExecProgram_Result.parseFrom(call("KclvmService.ExecProgram", args.toByteArray())); } @@ -194,6 +228,14 @@ private byte[] call(String name, byte[] args) throws Exception { return result; } + private byte[] callLoadPackageWithCache(byte[] args) throws Exception { + byte[] result = loadPackageWithCache(args); + if (result != null && startsWith(result, "ERROR")) { + throw new java.lang.Error(result.toString()); + } + return result; + } + static boolean startsWith(byte[] array, String prefix) { byte[] prefixBytes = prefix.getBytes(); if (array.length < prefixBytes.length) { diff --git a/java/src/main/java/com/kcl/api/Service.java b/java/src/main/java/com/kcl/api/Service.java index 2f6ac9c3..a6b9807a 100644 --- a/java/src/main/java/com/kcl/api/Service.java +++ b/java/src/main/java/com/kcl/api/Service.java @@ -12,6 +12,9 @@ public interface Service { // Loads KCL package and returns the AST, symbol, type, definition information. LoadPackage_Result loadPackage(LoadPackage_Args args) throws Exception; + // Loads KCL package and returns the AST, symbol, type, definition information. + LoadPackage_Result loadPackageWithCache(LoadPackage_Args args) throws Exception; + // Execute KCL file with args ExecProgram_Result execProgram(ExecProgram_Args args) throws Exception; diff --git a/java/src/test/java/com/kcl/LoadPackageTest.java b/java/src/test/java/com/kcl/LoadPackageTest.java index 2932b69c..ef60dd8e 100644 --- a/java/src/test/java/com/kcl/LoadPackageTest.java +++ b/java/src/test/java/com/kcl/LoadPackageTest.java @@ -51,4 +51,41 @@ public void testProgramSymbols() throws Exception { // Query Scope using the symbol Assert.assertEquals(SematicUtil.findScopeBySymbol(result, appSymbolIndex).getDefsCount(), 2); } + + @Test + public void testProgramSymbolsWithCache() throws Exception { + // API instance + API api = new API(); + // Note call `loadPackageWithCache` here. + LoadPackage_Result result = api.loadPackageWithCache(LoadPackage_Args.newBuilder().setResolveAst(true) + .setWithAstIndex(true) + .setParseArgs(ParseProgram_Args.newBuilder().addPaths("./src/test_data/schema.k").build()).build()); + // Get parse errors + Assert.assertEquals(result.getParseErrorsList().size(), 0); + // Get Type errors + Assert.assertEquals(result.getTypeErrorsList().size(), 0); + // Get AST + Program program = JsonUtil.deserializeProgram(result.getProgram()); + Assert.assertTrue(program.getRoot().contains("test_data")); + // Variable definitions in the main scope. + Scope mainScope = SematicUtil.findMainPackageScope(result); + // Child scopes of the main scope. + Assert.assertEquals(mainScope.getChildrenList().size(), 2); + // Mapping AST node to Symbol type + NodeRef stmt = program.getFirstModule().getBody().get(0); + Assert.assertTrue(SematicUtil.findSymbolByAstId(result, stmt.getId()).getName().contains("pkg")); + // Mapping symbol to AST node + SymbolIndex appSymbolIndex = mainScope.getDefs(1); + Symbol appSymbol = SematicUtil.findSymbol(result, appSymbolIndex); + Assert.assertEquals(appSymbol.getTy().getSchemaName(), "AppConfig"); + // Query type symbol using variable type. + String schemaFullName = appSymbol.getTy().getPkgPath() + "." + appSymbol.getTy().getSchemaName(); + Symbol appConfigSymbol = SematicUtil.findSymbol(result, + result.getFullyQualifiedNameMapOrDefault(schemaFullName, null)); + Assert.assertEquals(appConfigSymbol.getTy().getSchemaName(), "AppConfig"); + // Query AST node using the symbol + Assert.assertNotNull(SematicUtil.findNodeBySymbol(result, appSymbolIndex)); + // Query Scope using the symbol + Assert.assertEquals(SematicUtil.findScopeBySymbol(result, appSymbolIndex).getDefsCount(), 2); + } }