From 4a739318db93546898f9296f5b553f75f6e5d1f6 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Fri, 10 May 2024 21:02:13 -0700 Subject: [PATCH 1/9] Refactor the lance error handling --- java/core/lance-jni/src/blocking_dataset.rs | 335 +++++++++++++++--- java/core/lance-jni/src/blocking_scanner.rs | 174 ++++----- java/core/lance-jni/src/error.rs | 262 +++++++------- java/core/lance-jni/src/ffi.rs | 133 ++++--- java/core/lance-jni/src/fragment.rs | 160 +++++---- java/core/lance-jni/src/lib.rs | 207 +---------- java/core/lance-jni/src/traits.rs | 58 +-- java/core/lance-jni/src/utils.rs | 35 +- .../main/java/com/lancedb/lance/Dataset.java | 16 +- .../java/com/lancedb/lance/TestUtils.java | 2 +- 10 files changed, 708 insertions(+), 674 deletions(-) diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index ec239b4dc6..4aaff5d423 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -12,15 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::utils::import_ffi_schema; -use crate::{traits::IntoJava, Error, Result, RT}; +use std::iter::empty; +use crate::ffi::JNIEnvExt; +use crate::error::{JavaResult, JavaErrorExt}; +use crate::{traits::IntoJava, RT}; +use crate::utils::extract_write_params; use arrow::array::RecordBatchReader; use jni::sys::jlong; use jni::{objects::JObject, JNIEnv}; -use lance::dataset::fragment::FileFragment; use lance::dataset::transaction::Operation; use lance::dataset::{Dataset, WriteParams}; -use snafu::{location, Location}; +use lance::table::format::Fragment; +use jni::sys::jint; +use jni::objects::JString; +use arrow::ffi::FFI_ArrowSchema; +use arrow::datatypes::Schema; +use arrow::record_batch::RecordBatchIterator; +use std::sync::Arc; +use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::ffi_stream::ArrowArrayStreamReader; +use crate::traits::FromJString; + pub const NATIVE_DATASET: &str = "nativeDatasetHandle"; #[derive(Clone)] @@ -33,34 +45,135 @@ impl BlockingDataset { reader: impl RecordBatchReader + Send + 'static, uri: &str, params: Option, - ) -> Result { - let inner = RT.block_on(Dataset::write(reader, uri, params))?; + ) -> JavaResult { + let inner = RT.block_on(Dataset::write(reader, uri, params)).infer_error()?; Ok(Self { inner }) } - pub fn open(uri: &str) -> Result { - let inner = RT.block_on(Dataset::open(uri))?; + pub fn open(uri: &str) -> JavaResult { + let inner = RT.block_on(Dataset::open(uri)).infer_error()?; Ok(Self { inner }) } - pub fn commit(uri: &str, operation: Operation, read_version: Option) -> Result { - let inner = RT.block_on(Dataset::commit(uri, operation, read_version, None, None))?; + pub fn commit(uri: &str, operation: Operation, read_version: Option) -> JavaResult { + let inner = RT.block_on(Dataset::commit(uri, operation, read_version, None, None)).infer_error()?; Ok(Self { inner }) } - pub fn latest_version(&self) -> Result { - Ok(RT.block_on(self.inner.latest_version_id())?) + pub fn latest_version(&self) -> JavaResult { + Ok(RT.block_on(self.inner.latest_version_id()).infer_error()?) } - pub fn count_rows(&self, filter: Option) -> Result { - Ok(RT.block_on(self.inner.count_rows(filter))?) + pub fn count_rows(&self, filter: Option) -> JavaResult { + Ok(RT.block_on(self.inner.count_rows(filter)).infer_error()?) } pub fn close(&self) {} } +/////////////////// +// Write Methods // +/////////////////// +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_createWithFfiSchema<'local>( + mut env: JNIEnv<'local>, + _obj: JObject, + arrow_schema_addr: jlong, + path: JString, + max_rows_per_file: JObject, // Optional + max_rows_per_group: JObject, // Optional + max_bytes_per_file: JObject, // Optional + mode: JObject, // Optional +) -> JObject<'local> { + ok_or_throw!(env, inner_create_with_ffi_schema(&mut env, arrow_schema_addr, path, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode)) +} + +fn inner_create_with_ffi_schema<'local> ( + env: &mut JNIEnv<'local>, + arrow_schema_addr: jlong, + path: JString, + max_rows_per_file: JObject, // Optional + max_rows_per_group: JObject, // Optional + max_bytes_per_file: JObject, // Optional + mode: JObject, // Optional +) -> JavaResult> { + let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; + let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; + let schema = Schema::try_from(&c_schema).infer_error()?; + + let reader = RecordBatchIterator::new(empty(), Arc::new(schema)); + create_dataset( + env, + path, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode, + reader + ) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_createWithFfiStream<'local>( + mut env: JNIEnv<'local>, + _obj: JObject, + arrow_array_stream_addr: jlong, + path: JString, + max_rows_per_file: JObject, // Optional + max_rows_per_group: JObject, // Optional + max_bytes_per_file: JObject, // Optional + mode: JObject, // Optional +) -> JObject<'local> { + ok_or_throw!(env, inner_create_with_ffi_stream(&mut env, arrow_array_stream_addr, path, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode)) +} + +fn inner_create_with_ffi_stream<'local>( + env: &mut JNIEnv<'local>, + arrow_array_stream_addr: jlong, + path: JString, + max_rows_per_file: JObject, // Optional + max_rows_per_group: JObject, // Optional + max_bytes_per_file: JObject, // Optional + mode: JObject, // Optional +) -> JavaResult> { + let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; + let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; + create_dataset( + env, + path, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode, + reader + ) +} + +fn create_dataset<'local>( + env: &mut JNIEnv<'local>, + path: JString, + max_rows_per_file: JObject, + max_rows_per_group: JObject, + max_bytes_per_file: JObject, + mode: JObject, + reader: impl RecordBatchReader + Send + 'static, +) -> JavaResult> { + let path_str = path.extract(env)?; + + let write_params = extract_write_params( + env, + &max_rows_per_file, + &max_rows_per_group, + &max_bytes_per_file, + &mode, + )?; + + let dataset = BlockingDataset::write(reader, &path_str, Some(write_params))?; + dataset.into_java(env) +} + impl IntoJava for BlockingDataset { - fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JObject<'a> { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JavaResult> { attach_native_dataset(env, self) } } @@ -68,8 +181,8 @@ impl IntoJava for BlockingDataset { fn attach_native_dataset<'local>( env: &mut JNIEnv<'local>, dataset: BlockingDataset, -) -> JObject<'local> { - let j_dataset = create_java_dataset_object(env); +) -> JavaResult> { + let j_dataset = create_java_dataset_object(env)?; // This block sets a native Rust object (dataset) as a field in the Java object (j_dataset). // Caution: This creates a potential for memory leaks. The Rust object (dataset) is not // automatically garbage-collected by Java, and its memory will not be freed unless @@ -79,22 +192,42 @@ fn attach_native_dataset<'local>( // 1. The Java object (`j_dataset`) should implement the `java.io.Closeable` interface. // 2. Users of this Java object should be instructed to always use it within a try-with-resources // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. - match unsafe { env.set_rust_field(&j_dataset, NATIVE_DATASET, dataset) } { - Ok(_) => j_dataset, - Err(err) => { - env.throw_new( - "java/lang/RuntimeException", - format!("Failed to set native handle: {}", err), - ) - .expect("Error throwing exception"); - JObject::null() - } - } + unsafe { env.set_rust_field(&j_dataset, NATIVE_DATASET, dataset) }.infer_error()?; + Ok(j_dataset) } -fn create_java_dataset_object<'a>(env: &mut JNIEnv<'a>) -> JObject<'a> { - env.new_object("com/lancedb/lance/Dataset", "()V", &[]) - .expect("Failed to create Java Dataset instance") +fn create_java_dataset_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { + Ok(env.new_object("com/lancedb/lance/Dataset", "()V", &[]).infer_error()?) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_commitAppend<'local>( + mut env: JNIEnv<'local>, + _obj: JObject, + path: JString, + read_version_obj: JObject, // Optional + fragments_obj: JObject, // List, String is json serialized Fragment +) -> JObject<'local> { + ok_or_throw!(env, inner_commit_append(&mut env, path, read_version_obj, fragments_obj)) +} + +pub fn inner_commit_append<'local>( + env: &mut JNIEnv<'local>, + path: JString, + read_version_obj: JObject, // Optional + fragments_obj: JObject, // List, String is json serialized Fragment) +) -> JavaResult> { + let json_fragments = env.get_strings(&fragments_obj)?; + let mut fragments: Vec = Vec::new(); + for json_fragment in json_fragments { + let fragment = Fragment::from_json(&json_fragment).infer_error()?; + fragments.push(fragment); + } + let op = Operation::Append { fragments }; + let path_str = path.extract(env)?; + let read_version = env.get_u64_opt(&read_version_obj)?; + let dataset = BlockingDataset::commit(&path_str, op, read_version)?; + dataset.into_java(env) } #[no_mangle] @@ -102,11 +235,39 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_releaseNativeDataset( mut env: JNIEnv, obj: JObject, ) { + ok_or_throw_without_return!(env, inner_release_native_dataset(&mut env, obj)) +} + +fn inner_release_native_dataset( + env: &mut JNIEnv, + obj: JObject, +) -> JavaResult<()> { let dataset: BlockingDataset = unsafe { - env.take_rust_field(obj, "nativeDatasetHandle") - .expect("Failed to take native dataset handle") + env.take_rust_field(obj, NATIVE_DATASET).infer_error()? }; - dataset.close() + dataset.close(); + Ok(()) +} + +////////////////// +// Read Methods // +////////////////// +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_openNative<'local>( + mut env: JNIEnv<'local>, + _obj: JObject, + path: JString, +) -> JObject<'local> { + ok_or_throw!(env, inner_open_native(&mut env, path)) +} + +fn inner_open_native<'local>( + env: &mut JNIEnv<'local>, + path: JString, +) -> JavaResult> { + let path_str: String = path.extract(env)?; + let dataset = BlockingDataset::open(&path_str)?; + dataset.into_java(env) } #[no_mangle] @@ -114,45 +275,101 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_getJsonFragments<'a>( mut env: JNIEnv<'a>, jdataset: JObject, ) -> JObject<'a> { + ok_or_throw!(env, inner_get_json_fragments(&mut env, jdataset)) +} + +fn inner_get_json_fragments<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject, +) -> JavaResult> { let fragments = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .expect("Failed to get native dataset handle"); + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; dataset.inner.get_fragments() }; - ok_or_throw!(env, create_json_fragment_list(&mut env, fragments)) + let array_list_class = env.find_class("java/util/ArrayList").infer_error()?; + + let array_list = env.new_object(array_list_class, "()V", &[]).infer_error()?; + + for fragment in fragments { + let json_string = serde_json::to_string(fragment.metadata()).infer_error()?; + let jstring = env.new_string(json_string).infer_error()?; + env.call_method( + &array_list, + "add", + "(Ljava/lang/Object;)Z", + &[(&jstring).into()], + ).infer_error()?; + } + Ok(array_list) } #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Dataset_importFfiSchema( - env: JNIEnv, + mut env: JNIEnv, jdataset: JObject, arrow_schema_addr: jlong, ) { - import_ffi_schema(env, jdataset, arrow_schema_addr, None) + ok_or_throw_without_return!(env, inner_import_ffi_schema(&mut env, jdataset, arrow_schema_addr)) } -fn create_json_fragment_list<'a>( - env: &mut JNIEnv<'a>, - fragments: Vec, -) -> Result> { - let array_list_class = env.find_class("java/util/ArrayList")?; +fn inner_import_ffi_schema( + env: &mut JNIEnv, + jdataset: JObject, + arrow_schema_addr: jlong, +) -> JavaResult<()>{ + let dataset = { + let dataset = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; + dataset.clone() + }; + let schema = Schema::from(dataset.inner.schema()); - let array_list = env.new_object(array_list_class, "()V", &[])?; + let c_schema = FFI_ArrowSchema::try_from(&schema).infer_error()?; + let out_c_schema = unsafe { &mut *(arrow_schema_addr as *mut FFI_ArrowSchema) }; + let _old = std::mem::replace(out_c_schema, c_schema); + Ok(()) +} - for fragment in fragments { - let json_string = serde_json::to_string(fragment.metadata()).map_err(|e| Error::JSON { - message: e.to_string(), - location: location!(), - })?; - let jstring = env.new_string(json_string)?; - env.call_method( - &array_list, - "add", - "(Ljava/lang/Object;)Z", - &[(&jstring).into()], - )?; - } - Ok(array_list) +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_version( + mut env: JNIEnv, + java_dataset: JObject, +) -> jlong { + ok_or_throw_with_return!(env, inner_version(&mut env, java_dataset), -1) as jlong } + +fn inner_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }.infer_error()?; + Ok(dataset_guard.inner.version().version) +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_latestVersion( + mut env: JNIEnv, + java_dataset: JObject, +) -> jlong { + ok_or_throw_with_return!(env, inner_latest_version(&mut env, java_dataset), -1) as jlong +} + +fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }.infer_error()?; + dataset_guard.latest_version() +} + +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_countRows( + mut env: JNIEnv, + java_dataset: JObject, +) -> jint { + ok_or_throw_with_return!(env, inner_count_rows(&mut env, java_dataset), -1) as jint +} + +fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }.infer_error()?; + dataset_guard.count_rows(None) +} \ No newline at end of file diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs index 220c71c849..199d04bd0d 100644 --- a/java/core/lance-jni/src/blocking_scanner.rs +++ b/java/core/lance-jni/src/blocking_scanner.rs @@ -15,16 +15,17 @@ use std::sync::Arc; use crate::ffi::JNIEnvExt; +use crate::JavaError; use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream}; use arrow_schema::SchemaRef; use jni::{objects::JObject, sys::jlong, JNIEnv}; use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; use lance_io::ffi::to_ffi_arrow_array_stream; +use crate::error::{JavaResult, JavaErrorExt}; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, - traits::IntoJava, - Error, Result, RT, + traits::IntoJava, RT, }; pub const NATIVE_SCANNER: &str = "nativeScannerHandle"; @@ -41,57 +42,22 @@ impl BlockingScanner { } } - pub fn open_stream(&self) -> Result { - Ok(RT.block_on(self.inner.try_into_stream())?) + pub fn open_stream(&self) -> JavaResult { + Ok(RT.block_on(self.inner.try_into_stream()).infer_error()?) } - pub fn schema(&self) -> Result { - Ok(RT.block_on(self.inner.schema())?) + pub fn schema(&self) -> JavaResult { + Ok(RT.block_on(self.inner.schema()).infer_error()?) } - pub fn count_rows(&self) -> Result { - Ok(RT.block_on(self.inner.count_rows())?) + pub fn count_rows(&self) -> JavaResult { + Ok(RT.block_on(self.inner.count_rows()).infer_error()?) } } -impl IntoJava for BlockingScanner { - fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JObject<'a> { - attach_native_scanner(env, self) - } -} - -fn attach_native_scanner<'local>( - env: &mut JNIEnv<'local>, - scanner: BlockingScanner, -) -> JObject<'local> { - let j_scanner = create_java_scanner_object(env); - // This block sets a native Rust object (scanner) as a field in the Java object (j_scanner). - // Caution: This creates a potential for memory leaks. The Rust object (scanner) is not - // automatically garbage-collected by Java, and its memory will not be freed unless - // explicitly handled. - // - // To prevent memory leaks, ensure the following: - // 1. The Java object (`j_scanner`) should implement the `java.io.Closeable` interface. - // 2. Users of this Java object should be instructed to always use it within a try-with-resources - // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. - match unsafe { env.set_rust_field(&j_scanner, NATIVE_SCANNER, scanner) } { - Ok(_) => j_scanner, - Err(err) => { - env.throw_new( - "java/lang/RuntimeException", - format!("Failed to set native handle for scanner: {}", err), - ) - .expect("Error throwing exception"); - JObject::null() - } - } -} - -fn create_java_scanner_object<'a>(env: &mut JNIEnv<'a>) -> JObject<'a> { - env.new_object("com/lancedb/lance/ipc/LanceScanner", "()V", &[]) - .expect("Failed to create Java Lance Scanner instance") -} - +/////////////////// +// Write Methods // +/////////////////// #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_createScanner<'local>( mut env: JNIEnv<'local>, @@ -103,50 +69,53 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_createScanner<'lo filter_obj: JObject, // Optional batch_size_obj: JObject, // Optional ) -> JObject<'local> { + ok_or_throw!(env, inner_create_scanner(&mut env, jdataset, fragment_ids_obj, columns_obj, substrait_filter_obj, filter_obj, batch_size_obj)) +} + +fn inner_create_scanner<'local>( + env: &mut JNIEnv<'local>, + jdataset: JObject, + fragment_ids_obj: JObject, + columns_obj: JObject, + substrait_filter_obj: JObject, + filter_obj: JObject, + batch_size_obj: JObject, +) -> JavaResult> { let dataset = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .expect("Dataset handle not set"); + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; dataset.clone() }; let mut scanner = dataset.inner.scan(); - let fragment_ids_opt = ok_or_throw!(env, env.get_ints_opt(&fragment_ids_obj)); + let fragment_ids_opt = env.get_ints_opt(&fragment_ids_obj)?; if let Some(fragment_ids) = fragment_ids_opt { let mut fragments = Vec::with_capacity(fragment_ids.len()); for fragment_id in fragment_ids { let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { - env.throw_new( - "java/lang/RuntimeException", - format!("fragment id {fragment_id} not found"), - ) - .expect("failed to throw java exception"); - return JObject::null(); + return Err(JavaError::input_error(format!("Fragment {fragment_id} not found"))); }; fragments.push(fragment.metadata().clone()); } scanner.with_fragments(fragments); } - let columns_opt = ok_or_throw!(env, env.get_strings_opt(&columns_obj)); + let columns_opt = env.get_strings_opt(&columns_obj)?; if let Some(columns) = columns_opt { - ok_or_throw!(env, scanner.project(&columns)); + scanner.project(&columns).infer_error()?; }; - let substrait_opt = ok_or_throw!(env, env.get_bytes_opt(&substrait_filter_obj)); + let substrait_opt = env.get_bytes_opt(&substrait_filter_obj)?; if let Some(substrait) = substrait_opt { - ok_or_throw!( - env, - RT.block_on(async { scanner.filter_substrait(substrait).await }) - ); + RT.block_on(async { scanner.filter_substrait(substrait).await }).infer_error()?; } - let filter_opt = ok_or_throw!(env, env.get_string_opt(&filter_obj)); + let filter_opt = env.get_string_opt(&filter_obj)?; if let Some(filter) = filter_opt { - ok_or_throw!(env, scanner.filter(filter.as_str())); + scanner.filter(filter.as_str()).infer_error()?; } - let batch_size_opt = ok_or_throw!(env, env.get_long_opt(&batch_size_obj)); + let batch_size_opt = env.get_long_opt(&batch_size_obj)?; if let Some(batch_size) = batch_size_opt { scanner.batch_size(batch_size as usize); } let scanner = BlockingScanner::create(scanner); - scanner.into_java(&mut env) + scanner.into_java(env) } #[no_mangle] @@ -154,27 +123,66 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_releaseNativeScan mut env: JNIEnv, j_scanner: JObject, ) { + ok_or_throw_without_return!(env, inner_release_native_scanner(&mut env, j_scanner)); +} + +fn inner_release_native_scanner(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult<()>{ let _: BlockingScanner = unsafe { env.take_rust_field(j_scanner, NATIVE_SCANNER) - .expect("Failed to take native scanner handle") - }; + }.infer_error()?; + Ok(()) +} + +impl IntoJava for BlockingScanner { + fn into_java<'local>(self, env: &mut JNIEnv<'local>) -> JavaResult> { + attach_native_scanner(env, self) + } } +fn attach_native_scanner<'local>( + env: &mut JNIEnv<'local>, + scanner: BlockingScanner, +) -> JavaResult> { + let j_scanner = create_java_scanner_object(env)?; + // This block sets a native Rust object (scanner) as a field in the Java object (j_scanner). + // Caution: This creates a potential for memory leaks. The Rust object (scanner) is not + // automatically garbage-collected by Java, and its memory will not be freed unless + // explicitly handled. + // + // To prevent memory leaks, ensure the following: + // 1. The Java object (`j_scanner`) should implement the `java.io.Closeable` interface. + // 2. Users of this Java object should be instructed to always use it within a try-with-resources + // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. + unsafe {env.set_rust_field(&j_scanner, NATIVE_SCANNER, scanner)}.infer_error()?; + Ok(j_scanner) +} + +fn create_java_scanner_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { + env.new_object("com/lancedb/lance/ipc/LanceScanner", "()V", &[]).infer_error() +} + +////////////////// +// Read Methods // +////////////////// #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_openStream( mut env: JNIEnv, j_scanner: JObject, stream_addr: jlong, ) { + ok_or_throw_without_return!(env, inner_open_stream(&mut env, j_scanner, stream_addr)); +} + +fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -> JavaResult<()> { let scanner = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .expect("Failed to get native scanner handle"); + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }.infer_error()?; scanner_guard.clone() }; - let record_batch_stream = ok_or_throw_without_return!(env, scanner.open_stream()); - let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone()).unwrap(); + let record_batch_stream = scanner.open_stream()?; + let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone()).infer_error()?; unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } + Ok(()) } #[no_mangle] @@ -183,15 +191,19 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_importFfiSchema( j_scanner: JObject, schema_addr: jlong, ) { + ok_or_throw_without_return!(env, inner_import_ffi_schema(&mut env, j_scanner, schema_addr)); +} + +fn inner_import_ffi_schema(env: &mut JNIEnv, j_scanner: JObject, schema_addr: jlong) -> JavaResult<()> { let scanner = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .expect("Failed to get native scanner handle"); + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }.infer_error()?; scanner_guard.clone() }; - let schema = ok_or_throw_without_return!(env, scanner.schema()); - let ffi_schema = ok_or_throw_without_return!(env, FFI_ArrowSchema::try_from(&*schema)); + let schema = scanner.schema()?; + let ffi_schema = FFI_ArrowSchema::try_from(&*schema).infer_error()?; unsafe { std::ptr::write_unaligned(schema_addr as *mut FFI_ArrowSchema, ffi_schema) } + Ok(()) } #[no_mangle] @@ -199,12 +211,14 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_countRows( mut env: JNIEnv, j_scanner: JObject, ) -> jlong { + ok_or_throw_with_return!(env, inner_count_rows(&mut env, j_scanner), -1) as jlong +} + +fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult { let scanner = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .expect("Failed to get native scanner handle"); + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }.infer_error()?; scanner_guard.clone() }; - let rows = ok_or_throw_with_return!(env, scanner.count_rows(), -1); - rows as jlong + scanner.count_rows() } diff --git a/java/core/lance-jni/src/error.rs b/java/core/lance-jni/src/error.rs index ec8491d624..855eb39136 100644 --- a/java/core/lance-jni/src/error.rs +++ b/java/core/lance-jni/src/error.rs @@ -15,149 +15,167 @@ use std::str::Utf8Error; use arrow_schema::ArrowError; -use jni::errors::Error as JniError; +use jni::{errors::Error as JniError, JNIEnv}; +use lance::error::Error as LanceError; use serde_json::Error as JsonError; -use snafu::{Location, Snafu}; -/// Java Exception types -pub enum JavaException { - IllegalArgumentException, - IOException, - RuntimeException, +pub type JavaResult = std::result::Result; + +#[derive(Debug)] +pub enum JavaExceptionClass { + IllegalArgumentException, + IOException, + RuntimeException, + UnsupportedOperationException, +} + +impl JavaExceptionClass { + pub fn as_str(&self) -> &str { + match self { + Self::IllegalArgumentException => "java/lang/IllegalArgumentException", + Self::IOException => "java/io/IOException", + Self::RuntimeException => "java/lang/RuntimeException", + Self::UnsupportedOperationException => "java/lang/UnsupportedOperationException" + } + } } -impl JavaException { - pub fn as_str(&self) -> &str { - match self { - Self::IllegalArgumentException => "java/lang/IllegalArgumentException", - Self::IOException => "java/io/IOException", - Self::RuntimeException => "java/lang/RuntimeException", - } - } +#[derive(Debug)] +pub struct JavaError { + message: String, + java_class: JavaExceptionClass, } -#[derive(Debug, Snafu)] -#[snafu(visibility(pub))] -pub enum Error { - #[snafu(display("JNI error: {message}, {location}"))] - Jni { message: String, location: Location }, - #[snafu(display("Invalid argument: {message}, {location}"))] - InvalidArgument { message: String, location: Location }, - #[snafu(display("IO error: {message}, {location}"))] - IO { message: String, location: Location }, - #[snafu(display("Arrow error: {message}, {location}"))] - Arrow { message: String, location: Location }, - #[snafu(display("Index error: {message}, {location}"))] - Index { message: String, location: Location }, - #[snafu(display("JSON error: {message}, {location}"))] - JSON { message: String, location: Location }, - #[snafu(display("Dataset at path {path} was not found, {location}"))] - DatasetNotFound { path: String, location: Location }, - #[snafu(display("Dataset already exists: {uri}, {location}"))] - DatasetAlreadyExists { uri: String, location: Location }, - #[snafu(display("Unknown error: {message}, {location}"))] - Other { message: String, location: Location }, +impl JavaError { + pub fn new(message: String, java_class: JavaExceptionClass) -> Self { + JavaError { message, java_class } + } + + pub fn runtime_error(message: String) -> Self { + JavaError { message: message, java_class: JavaExceptionClass::RuntimeException } + } + + pub fn io_error(message: String) -> Self { + JavaError::new(message, JavaExceptionClass::IOException) + } + + pub fn input_error(message: String) -> Self { + JavaError::new(message, JavaExceptionClass::IllegalArgumentException) + } + + pub fn unsupported_error(message: String) -> Self { + JavaError::new(message, JavaExceptionClass::UnsupportedOperationException) + } + + pub fn throw(&self, env: &mut JNIEnv) { + env.throw_new(self.java_class.as_str(), &self.message) + .expect("Error when throwing Java exception"); + } } -impl Error { - /// Throw as Java Exception - pub fn throw(&self, env: &mut jni::JNIEnv) { - match self { - Self::InvalidArgument { .. } => { - self.throw_as(env, JavaException::IllegalArgumentException) - } - Self::IO { .. } | Self::Index { .. } => self.throw_as(env, JavaException::IOException), - Self::Arrow { .. } - | Self::DatasetNotFound { .. } - | Self::DatasetAlreadyExists { .. } - | Self::JSON { .. } - | Self::Other { .. } - | Self::Jni { .. } => self.throw_as(env, JavaException::RuntimeException), - } - } - - /// Throw as an concrete Java Exception - pub fn throw_as(&self, env: &mut jni::JNIEnv, exception: JavaException) { - env.throw_new(exception.as_str(), self.to_string()) - .expect("Error when throwing Java exception"); - } +impl std::fmt::Display for JavaError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.java_class.as_str(), self.message) + } } -pub type Result = std::result::Result; +impl std::error::Error for JavaError {} -trait ToSnafuLocation { - fn to_snafu_location(&'static self) -> snafu::Location; +/// Trait for converting errors to Java exceptions. +pub trait JavaErrorConversion { + /// Convert to `JavaError` as I/O exception. + fn io_error(self) -> JavaResult; + + /// Convert to `JavaError` as runtime exception. + fn runtime_error(self) -> JavaResult; + + /// Convert to `JavaError` as value (input) exception. + fn input_error(self) -> JavaResult; + + /// Convert to `JavaError` as unsupported operation exception. + fn unsupported_error(self) -> JavaResult; } -impl ToSnafuLocation for std::panic::Location<'static> { - fn to_snafu_location(&'static self) -> snafu::Location { - snafu::Location::new(self.file(), self.line(), self.column()) - } + +impl JavaErrorConversion for std::result::Result { + fn io_error(self) -> JavaResult { + self.map_err(|err| JavaError::io_error(err.to_string())) + } + + fn runtime_error(self) -> JavaResult { + self.map_err(|err| JavaError::runtime_error(err.to_string())) + } + + fn input_error(self) -> JavaResult { + self.map_err(|err| JavaError::input_error(err.to_string())) + } + + fn unsupported_error(self) -> JavaResult { + self.map_err(|err| JavaError::unsupported_error(err.to_string())) + } } -impl From for Error { - #[track_caller] - fn from(source: JniError) -> Self { - Self::Jni { - message: source.to_string(), - location: std::panic::Location::caller().to_snafu_location(), - } - } +/// JavaErrorExt trait that converts specific error types to Java exceptions +pub trait JavaErrorExt { + /// Convert to a Java error based on the specific error type + fn infer_error(self) -> JavaResult; } -impl From for Error { - #[track_caller] - fn from(source: Utf8Error) -> Self { - Self::InvalidArgument { - message: source.to_string(), - location: std::panic::Location::caller().to_snafu_location(), - } - } +impl JavaErrorExt for std::result::Result { + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + LanceError::InvalidInput { .. } => self.input_error(), + LanceError::IO { .. } => self.io_error(), + LanceError::NotSupported { .. } => self.unsupported_error(), + _ => self.runtime_error(), + }, + } + } } -impl From for Error { - #[track_caller] - fn from(source: ArrowError) -> Self { - Self::Arrow { - message: source.to_string(), - location: std::panic::Location::caller().to_snafu_location(), - } - } +impl JavaErrorExt for std::result::Result { + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + ArrowError::InvalidArgumentError{ .. } => self.input_error(), + ArrowError::IoError{ .. } => self.io_error(), + ArrowError::NotYetImplemented(_) => self.unsupported_error(), + _ => self.runtime_error(), + }, + } + } } -impl From for Error { - #[track_caller] - fn from(source: JsonError) -> Self { - Self::JSON { - message: source.to_string(), - location: std::panic::Location::caller().to_snafu_location(), - } - } +impl JavaErrorExt for std::result::Result { + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(_) => self.io_error(), + } + } } -impl From for Error { - #[track_caller] - fn from(source: lance::Error) -> Self { - match source { - lance::Error::DatasetNotFound { - path, - source: _, - location, - } => Self::DatasetNotFound { path, location }, - lance::Error::DatasetAlreadyExists { uri, location } => { - Self::DatasetAlreadyExists { uri, location } - } - lance::Error::IO { message, location } => Self::IO { message, location }, - lance::Error::Arrow { message, location } => Self::Arrow { message, location }, - lance::Error::Index { message, location } => Self::Index { message, location }, - lance::Error::InvalidInput { source, location } => Self::InvalidArgument { - message: source.to_string(), - location, - }, - _ => Self::Other { - message: source.to_string(), - location: std::panic::Location::caller().to_snafu_location(), - }, - } - } +impl JavaErrorExt for std::result::Result { + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + _ => self.runtime_error(), + }, + } + } } + +impl JavaErrorExt for std::result::Result { + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + _ => self.input_error(), + }, + } + } +} \ No newline at end of file diff --git a/java/core/lance-jni/src/ffi.rs b/java/core/lance-jni/src/ffi.rs index cebe9661e8..e40574f65a 100644 --- a/java/core/lance-jni/src/ffi.rs +++ b/java/core/lance-jni/src/ffi.rs @@ -14,163 +14,162 @@ use core::slice; +use crate::error::{JavaResult, JavaErrorExt}; use jni::objects::{JByteBuffer, JObjectArray, JString}; use jni::sys::jobjectArray; use jni::{objects::JObject, JNIEnv}; -use crate::error::{Error, Result}; - /// Extend JNIEnv with helper functions. pub trait JNIEnvExt { /// Get integers from Java List object. - fn get_integers(&mut self, obj: &JObject) -> Result>; + fn get_integers(&mut self, obj: &JObject) -> JavaResult>; /// Get strings from Java List object. - fn get_strings(&mut self, obj: &JObject) -> Result>; + fn get_strings(&mut self, obj: &JObject) -> JavaResult>; /// Get strings from Java String[] object. /// Note that get Option> from Java Optional just doesn't work. - fn get_strings_array(&mut self, obj: jobjectArray) -> Result>; + fn get_strings_array(&mut self, obj: jobjectArray) -> JavaResult>; /// Get Option from Java Optional. - fn get_string_opt(&mut self, obj: &JObject) -> Result>; + fn get_string_opt(&mut self, obj: &JObject) -> JavaResult>; /// Get Option> from Java Optional>. - fn get_strings_opt(&mut self, obj: &JObject) -> Result>>; + fn get_strings_opt(&mut self, obj: &JObject) -> JavaResult>>; /// Get Option from Java Optional. - fn get_int_opt(&mut self, obj: &JObject) -> Result>; + fn get_int_opt(&mut self, obj: &JObject) -> JavaResult>; /// Get Option> from Java Optional>. - fn get_ints_opt(&mut self, obj: &JObject) -> Result>>; + fn get_ints_opt(&mut self, obj: &JObject) -> JavaResult>>; /// Get Option from Java Optional. - fn get_long_opt(&mut self, obj: &JObject) -> Result>; + fn get_long_opt(&mut self, obj: &JObject) -> JavaResult>; /// Get Option from Java Optional. - fn get_u64_opt(&mut self, obj: &JObject) -> Result>; + fn get_u64_opt(&mut self, obj: &JObject) -> JavaResult>; /// Get Option<&[u8]> from Java Optional. - fn get_bytes_opt(&mut self, obj: &JObject) -> Result>; + fn get_bytes_opt(&mut self, obj: &JObject) -> JavaResult>; - fn get_optional(&mut self, obj: &JObject, f: F) -> Result> + fn get_optional(&mut self, obj: &JObject, f: F) -> JavaResult> where - F: FnOnce(&mut JNIEnv, &JObject) -> Result; + F: FnOnce(&mut JNIEnv, &JObject) -> JavaResult; } impl JNIEnvExt for JNIEnv<'_> { - fn get_integers(&mut self, obj: &JObject) -> Result> { - let list = self.get_list(obj)?; - let mut iter = list.iter(self)?; - let mut results = Vec::with_capacity(list.size(self)? as usize); - while let Some(elem) = iter.next(self)? { - let int_obj = self.call_method(elem, "intValue", "()I", &[])?; - let int_value = int_obj.i()?; + fn get_integers(&mut self, obj: &JObject) -> JavaResult> { + let list = self.get_list(obj).infer_error()?; + let mut iter = list.iter(self).infer_error()?; + let mut results = Vec::with_capacity(list.size(self).infer_error()? as usize); + while let Some(elem) = iter.next(self).infer_error()? { + let int_obj = self.call_method(elem, "intValue", "()I", &[]).infer_error()?; + let int_value = int_obj.i().infer_error()?; results.push(int_value); } Ok(results) } - fn get_strings(&mut self, obj: &JObject) -> Result> { - let list = self.get_list(obj)?; - let mut iter = list.iter(self)?; - let mut results = Vec::with_capacity(list.size(self)? as usize); - while let Some(elem) = iter.next(self)? { + fn get_strings(&mut self, obj: &JObject) -> JavaResult> { + let list = self.get_list(obj).infer_error()?; + let mut iter = list.iter(self).infer_error()?; + let mut results = Vec::with_capacity(list.size(self).infer_error()? as usize); + while let Some(elem) = iter.next(self).infer_error()? { let jstr = JString::from(elem); - let val = self.get_string(&jstr)?; - results.push(val.to_str()?.to_string()) + let val = self.get_string(&jstr).infer_error()?; + results.push(val.to_str().infer_error()?.to_string()) } Ok(results) } - fn get_strings_array(&mut self, obj: jobjectArray) -> Result> { + fn get_strings_array(&mut self, obj: jobjectArray) -> JavaResult> { let jobject_array = unsafe { JObjectArray::from_raw(obj) }; - let array_len = self.get_array_length(&jobject_array)?; + let array_len = self.get_array_length(&jobject_array).infer_error()?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let item: JString = self.get_object_array_element(&jobject_array, i)?.into(); - res.push(self.get_string(&item)?.into()); + let item: JString = self.get_object_array_element(&jobject_array, i).infer_error()?.into(); + res.push(self.get_string(&item).infer_error()?.into()); } Ok(res) } - fn get_string_opt(&mut self, obj: &JObject) -> Result> { + fn get_string_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_string_obj = java_obj_gen.l()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_string_obj = java_obj_gen.l().infer_error()?; let jstr = JString::from(java_string_obj); - let val = env.get_string(&jstr)?; - Ok(val.to_str()?.to_string()) + let val = env.get_string(&jstr).infer_error()?; + Ok(val.to_str().infer_error()?.to_string()) }) } - fn get_strings_opt(&mut self, obj: &JObject) -> Result>> { + fn get_strings_opt(&mut self, obj: &JObject) -> JavaResult>> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_list_obj = java_obj_gen.l()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_list_obj = java_obj_gen.l().infer_error()?; env.get_strings(&java_list_obj) }) } - fn get_int_opt(&mut self, obj: &JObject) -> Result> { + fn get_int_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_int_obj = java_obj_gen.l()?; - let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[])?; - let int_value = int_obj.i()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_int_obj = java_obj_gen.l().infer_error()?; + let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[]).infer_error()?; + let int_value = int_obj.i().infer_error()?; Ok(int_value) }) } - fn get_ints_opt(&mut self, obj: &JObject) -> Result>> { + fn get_ints_opt(&mut self, obj: &JObject) -> JavaResult>> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_list_obj = java_obj_gen.l()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_list_obj = java_obj_gen.l().infer_error()?; env.get_integers(&java_list_obj) }) } - fn get_long_opt(&mut self, obj: &JObject) -> Result> { + fn get_long_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_long_obj = java_obj_gen.l()?; - let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?; - let long_value = long_obj.j()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_long_obj = java_obj_gen.l().infer_error()?; + let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[]).infer_error()?; + let long_value = long_obj.j().infer_error()?; Ok(long_value) }) } - fn get_u64_opt(&mut self, obj: &JObject) -> Result> { + fn get_u64_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_long_obj = java_obj_gen.l()?; - let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?; - let long_value = long_obj.j()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_long_obj = java_obj_gen.l().infer_error()?; + let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[]).infer_error()?; + let long_value = long_obj.j().infer_error()?; Ok(long_value as u64) }) } - fn get_bytes_opt(&mut self, obj: &JObject) -> Result> { + fn get_bytes_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; - let java_byte_buffer_obj = java_obj_gen.l()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_byte_buffer_obj = java_obj_gen.l().infer_error()?; let j_byte_buffer = JByteBuffer::from(java_byte_buffer_obj); - let raw_data = env.get_direct_buffer_address(&j_byte_buffer)?; - let capacity = env.get_direct_buffer_capacity(&j_byte_buffer)?; + let raw_data = env.get_direct_buffer_address(&j_byte_buffer).infer_error()?; + let capacity = env.get_direct_buffer_capacity(&j_byte_buffer).infer_error()?; let data = unsafe { slice::from_raw_parts(raw_data, capacity) }; Ok(data) }) } - fn get_optional(&mut self, obj: &JObject, f: F) -> Result> + fn get_optional(&mut self, obj: &JObject, f: F) -> JavaResult> where - F: FnOnce(&mut JNIEnv, &JObject) -> Result, + F: FnOnce(&mut JNIEnv, &JObject) -> JavaResult, { if obj.is_null() { return Ok(None); } - let is_empty = self.call_method(obj, "isEmpty", "()Z", &[])?; - if is_empty.z()? { + let is_empty = self.call_method(obj, "isEmpty", "()Z", &[]).infer_error()?; + if is_empty.z().infer_error()? { // TODO(lu): put get java object into here cuz can only get java Object Ok(None) } else { diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index 1bf3eb2e68..f13697af74 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -21,33 +21,54 @@ use jni::{ sys::{jint, jlong}, JNIEnv, }; -use snafu::{location, Location}; use std::iter::once; use lance::dataset::fragment::FileFragment; +use crate::JavaError; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, - error::{Error, Result}, ffi::JNIEnvExt, traits::FromJString, utils::extract_write_params, RT, }; +use crate::error::{JavaResult, JavaErrorExt}; -fn fragment_count_rows(dataset: &BlockingDataset, fragment_id: jlong) -> Result { - let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { - return Err(Error::InvalidArgument { - message: format!("Fragment not found: {}", fragment_id), - location: location!(), - }); + +/////////////////// +// Write Methods // +/////////////////// + + +////////////////// +// Read Methods // +////////////////// +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_DatasetFragment_countRowsNative( + mut env: JNIEnv, + _jfragment: JObject, + jdataset: JObject, + fragment_id: jlong, +) -> jint { + ok_or_throw_with_return!(env, inner_count_rows_native(&mut env, jdataset, fragment_id), -1) as jint +} + +fn inner_count_rows_native( + env: &mut JNIEnv, + jdataset: JObject, + fragment_id: jlong, +) -> JavaResult { + let dataset =unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; + let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { + return Err(JavaError::input_error(format!("Fragment not found: {fragment_id}"))) }; - Ok(RT.block_on(fragment.count_rows())? as jint) + RT.block_on(fragment.count_rows()).infer_error() } #[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'a>( - mut env: JNIEnv<'a>, +pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local>( + mut env: JNIEnv<'local>, _obj: JObject, dataset_uri: JString, arrow_array_addr: jlong, @@ -57,38 +78,43 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'a>( max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional -) -> JString<'a> { +) -> JString<'local> { + ok_or_throw_with_return!(env, inner_create_with_ffi_array(&mut env, dataset_uri, arrow_array_addr, arrow_schema_addr, fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode), JString::default()) +} + +fn inner_create_with_ffi_array<'local>( + env: &mut JNIEnv<'local>, + dataset_uri: JString, + arrow_array_addr: jlong, + arrow_schema_addr: jlong, + fragment_id: JObject, // Optional + max_rows_per_file: JObject, // Optional + max_rows_per_group: JObject, // Optional + max_bytes_per_file: JObject, // Optional + mode: JObject, // Optional +) -> JavaResult> { let c_array_ptr = arrow_array_addr as *mut FFI_ArrowArray; let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; let c_array = unsafe { FFI_ArrowArray::from_raw(c_array_ptr) }; let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; - let data_type = - ok_or_throw_with_return!(env, DataType::try_from(&c_schema), JString::default()); + let data_type = DataType::try_from(&c_schema).infer_error()?; - let array_data = ok_or_throw_with_return!( - env, - unsafe { from_ffi_and_data_type(c_array, data_type) }, - JString::default() - ); + let array_data = unsafe { from_ffi_and_data_type(c_array, data_type) }.infer_error()?; let record_batch = RecordBatch::from(StructArray::from(array_data)); let batch_schema = record_batch.schema().clone(); let reader = RecordBatchIterator::new(once(Ok(record_batch)), batch_schema); - ok_or_throw_with_return!( + create_fragment( env, - create_fragment( - &mut env, - dataset_uri, - fragment_id, - max_rows_per_file, - max_rows_per_group, - max_bytes_per_file, - mode, - reader - ), - JString::default() + dataset_uri, + fragment_id, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode, + reader ) } @@ -104,29 +130,32 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>( max_bytes_per_file: JObject, // Optional mode: JObject, // Optional ) -> JString<'a> { + ok_or_throw_with_return!(env, inner_create_with_ffi_stream(&mut env, dataset_uri, arrow_array_stream_addr, fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode), JString::default()) +} + +fn inner_create_with_ffi_stream<'a>( + env: &mut JNIEnv<'a>, + dataset_uri: JString, + arrow_array_stream_addr: jlong, + fragment_id: JObject, // Optional + max_rows_per_file: JObject, // Optional + max_rows_per_group: JObject, // Optional + max_bytes_per_file: JObject, // Optional + mode: JObject, // Optional +) -> JavaResult> { let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; - let reader = ok_or_throw_with_return!( - env, - unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.map_err(|e| Error::Arrow { - message: e.to_string(), - location: location!(), - }), - JString::default() - ); - - ok_or_throw_with_return!( + let reader = + unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; + + create_fragment( env, - create_fragment( - &mut env, - dataset_uri, - fragment_id, - max_rows_per_file, - max_rows_per_group, - max_bytes_per_file, - mode, - reader - ), - JString::default() + dataset_uri, + fragment_id, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode, + reader ) } @@ -140,7 +169,7 @@ fn create_fragment<'a>( max_bytes_per_file: JObject, // Optional mode: JObject, // Optional reader: impl RecordBatchReader + Send + 'static, -) -> Result> { +) -> JavaResult> { let path_str = dataset_uri.extract(env)?; let fragment_id_opts = env.get_int_opt(&fragment_id)?; @@ -157,26 +186,7 @@ fn create_fragment<'a>( fragment_id_opts.unwrap_or(0) as usize, reader, Some(write_params), - ))?; - let json_string = serde_json::to_string(&fragment)?; - Ok(env.new_string(json_string)?) -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_DatasetFragment_countRowsNative( - mut env: JNIEnv, - _jfragment: JObject, - jdataset: JObject, - fragment_id: jlong, -) -> jint { - ok_or_throw_with_return!( - env, - { - let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .expect("Dataset handle not set"); - fragment_count_rows(&dataset, fragment_id) - }, - -1 - ) + )).infer_error()?; + let json_string = serde_json::to_string(&fragment).infer_error()?; + Ok(env.new_string(json_string).infer_error()?) } diff --git a/java/core/lance-jni/src/lib.rs b/java/core/lance-jni/src/lib.rs index d06d095e55..c5b63b4aa6 100644 --- a/java/core/lance-jni/src/lib.rs +++ b/java/core/lance-jni/src/lib.rs @@ -12,32 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::iter::empty; -use std::sync::Arc; - -use arrow::array::{RecordBatchIterator, RecordBatchReader}; -use arrow::ffi::FFI_ArrowSchema; -use arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; -use arrow_schema::Schema; -use ffi::JNIEnvExt; -use jni::objects::{JObject, JString}; -use jni::sys::{jint, jlong}; -use jni::JNIEnv; -use lance::dataset::transaction::Operation; -use lance::table::format::Fragment; -use lazy_static::lazy_static; -use snafu::{location, Location}; -use traits::IntoJava; - -use crate::utils::extract_write_params; - #[macro_export] macro_rules! ok_or_throw { ($env:expr, $result:expr) => { match $result { Ok(value) => value, Err(err) => { - Error::from(err).throw(&mut $env); + err.throw(&mut $env); return JObject::null(); } } @@ -49,7 +30,7 @@ macro_rules! ok_or_throw_without_return { match $result { Ok(value) => value, Err(err) => { - Error::from(err).throw(&mut $env); + err.throw(&mut $env); return; } } @@ -62,7 +43,7 @@ macro_rules! ok_or_throw_with_return { match $result { Ok(value) => value, Err(err) => { - Error::from(err).throw(&mut $env); + err.throw(&mut $env); return $ret; } } @@ -76,10 +57,9 @@ mod ffi; mod fragment; mod traits; mod utils; +pub use error::{JavaError, JavaResult, JavaErrorExt}; -use self::traits::FromJString; -use crate::blocking_dataset::BlockingDataset; -pub use error::{Error, Result}; +use lazy_static::lazy_static; lazy_static! { static ref RT: tokio::runtime::Runtime = tokio::runtime::Builder::new_multi_thread() @@ -87,180 +67,3 @@ lazy_static! { .build() .expect("Failed to create tokio runtime"); } - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_createWithFfiSchema<'local>( - mut env: JNIEnv<'local>, - _obj: JObject, - arrow_schema_addr: jlong, - path: JString, - max_rows_per_file: JObject, // Optional - max_rows_per_group: JObject, // Optional - max_bytes_per_file: JObject, // Optional - mode: JObject, // Optional -) -> JObject<'local> { - let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; - let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; - let schema = ok_or_throw!( - env, - Schema::try_from(&c_schema).map_err(|e| Error::Arrow { - message: e.to_string(), - location: location!(), - }) - ); - - let reader = RecordBatchIterator::new(empty(), Arc::new(schema)); - ok_or_throw!( - env, - create_dataset( - &mut env, - path, - max_rows_per_file, - max_rows_per_group, - max_bytes_per_file, - mode, - reader - ) - ) -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_writeWithFfiStream<'local>( - mut env: JNIEnv<'local>, - _obj: JObject, - arrow_array_stream_addr: jlong, - path: JString, - max_rows_per_file: JObject, // Optional - max_rows_per_group: JObject, // Optional - max_bytes_per_file: JObject, // Optional - mode: JObject, // Optional -) -> JObject<'local> { - let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; - let reader = ok_or_throw!( - env, - unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.map_err(|e| Error::Arrow { - message: e.to_string(), - location: location!(), - }) - ); - - ok_or_throw!( - env, - create_dataset( - &mut env, - path, - max_rows_per_file, - max_rows_per_group, - max_bytes_per_file, - mode, - reader - ) - ) -} - -fn create_dataset<'local>( - env: &mut JNIEnv<'local>, - path: JString, - max_rows_per_file: JObject, - max_rows_per_group: JObject, - max_bytes_per_file: JObject, - mode: JObject, - reader: impl RecordBatchReader + Send + 'static, -) -> Result> { - let path_str = path.extract(env)?; - - let write_params = extract_write_params( - env, - &max_rows_per_file, - &max_rows_per_group, - &max_bytes_per_file, - &mode, - )?; - - let dataset = BlockingDataset::write(reader, &path_str, Some(write_params))?; - Ok(dataset.into_java(env)) -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_openNative<'local>( - mut env: JNIEnv<'local>, - _obj: JObject, - path: JString, -) -> JObject<'local> { - let path_str: String = ok_or_throw!(env, path.extract(&mut env)); - - let dataset = ok_or_throw!(env, BlockingDataset::open(&path_str)); - dataset.into_java(&mut env) -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_commitAppend<'local>( - mut env: JNIEnv<'local>, - _obj: JObject, - path: JString, - read_version_obj: JObject, // Optional - fragments_obj: JObject, // List, String is json serialized Fragment -) -> JObject<'local> { - let json_fragments = ok_or_throw!(env, env.get_strings(&fragments_obj)); - let mut fragments: Vec = Vec::new(); - for json_fragment in json_fragments { - let fragment = ok_or_throw!( - env, - Fragment::from_json(&json_fragment).map_err(|err| Error::IO { - message: err.to_string(), - location: location!() - }) - ); - fragments.push(fragment); - } - let op = Operation::Append { fragments }; - let path_str: String = ok_or_throw!(env, path.extract(&mut env)); - - let read_version = ok_or_throw!(env, env.get_u64_opt(&read_version_obj)); - let dataset = ok_or_throw!(env, BlockingDataset::commit(&path_str, op, read_version)); - dataset.into_java(&mut env) -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_version( - mut env: JNIEnv, - java_dataset: JObject, -) -> jlong { - let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, "nativeDatasetHandle") }; - match dataset_guard { - Ok(dataset) => dataset.inner.version().version as jlong, - Err(_) => -1, - } -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_latestVersion( - mut env: JNIEnv, - java_dataset: JObject, -) -> jlong { - let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, "nativeDatasetHandle") }; - match dataset_guard { - Ok(dataset) => dataset - .latest_version() - .expect("Failed to get the latest version.") as jlong, - Err(_) => -1, - } -} - -#[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_countRows( - mut env: JNIEnv, - java_dataset: JObject, -) -> jint { - let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, "nativeDatasetHandle") }; - match dataset_guard { - Ok(dataset) => dataset - .count_rows(None) - .expect("Faild to get the row count from dataset's metadata.") - as jint, - Err(_) => -1, - } -} diff --git a/java/core/lance-jni/src/traits.rs b/java/core/lance-jni/src/traits.rs index 01a7a490ea..462b0307d5 100644 --- a/java/core/lance-jni/src/traits.rs +++ b/java/core/lance-jni/src/traits.rs @@ -15,74 +15,74 @@ use jni::objects::{JMap, JObject, JString, JValue}; use jni::JNIEnv; -use crate::Result; +use crate::error::{JavaErrorExt, JavaResult}; pub trait FromJObject { - fn extract(&self) -> Result; + fn extract(&self) -> JavaResult; } /// Convert a Rust type into a Java Object. pub trait IntoJava { - fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JObject<'a>; + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JavaResult>; } impl FromJObject for JObject<'_> { - fn extract(&self) -> Result { - Ok(JValue::from(self).i()?) + fn extract(&self) -> JavaResult { + Ok(JValue::from(self).i().infer_error()?) } } impl FromJObject for JObject<'_> { - fn extract(&self) -> Result { - Ok(JValue::from(self).j()?) + fn extract(&self) -> JavaResult { + Ok(JValue::from(self).j().infer_error()?) } } impl FromJObject for JObject<'_> { - fn extract(&self) -> Result { - Ok(JValue::from(self).f()?) + fn extract(&self) -> JavaResult { + Ok(JValue::from(self).f().infer_error()?) } } impl FromJObject for JObject<'_> { - fn extract(&self) -> Result { - Ok(JValue::from(self).d()?) + fn extract(&self) -> JavaResult { + Ok(JValue::from(self).d().infer_error()?) } } pub trait FromJString { - fn extract(&self, env: &mut JNIEnv) -> Result; + fn extract(&self, env: &mut JNIEnv) -> JavaResult; } impl FromJString for JString<'_> { - fn extract(&self, env: &mut JNIEnv) -> Result { - Ok(env.get_string(self)?.into()) + fn extract(&self, env: &mut JNIEnv) -> JavaResult { + Ok(env.get_string(self).infer_error()?.into()) } } pub trait JMapExt { #[allow(dead_code)] - fn get_string(&self, env: &mut JNIEnv, key: &str) -> Result>; + fn get_string(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; #[allow(dead_code)] - fn get_i32(&self, env: &mut JNIEnv, key: &str) -> Result>; + fn get_i32(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; #[allow(dead_code)] - fn get_i64(&self, env: &mut JNIEnv, key: &str) -> Result>; + fn get_i64(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; #[allow(dead_code)] - fn get_f32(&self, env: &mut JNIEnv, key: &str) -> Result>; + fn get_f32(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; #[allow(dead_code)] - fn get_f64(&self, env: &mut JNIEnv, key: &str) -> Result>; + fn get_f64(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; } -fn get_map_value(env: &mut JNIEnv, map: &JMap, key: &str) -> Result> +fn get_map_value(env: &mut JNIEnv, map: &JMap, key: &str) -> JavaResult> where for<'a> JObject<'a>: FromJObject, { - let key_obj: JObject = env.new_string(key)?.into(); - if let Some(value) = map.get(env, &key_obj)? { + let key_obj: JObject = env.new_string(key).infer_error()?.into(); + if let Some(value) = map.get(env, &key_obj).infer_error()? { if value.is_null() { Ok(None) } else { @@ -94,9 +94,9 @@ where } impl JMapExt for JMap<'_, '_, '_> { - fn get_string(&self, env: &mut JNIEnv, key: &str) -> Result> { - let key_obj: JObject = env.new_string(key)?.into(); - if let Some(value) = self.get(env, &key_obj)? { + fn get_string(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { + let key_obj: JObject = env.new_string(key).infer_error()?.into(); + if let Some(value) = self.get(env, &key_obj).infer_error()? { let value_str: JString = value.into(); Ok(Some(value_str.extract(env)?)) } else { @@ -104,19 +104,19 @@ impl JMapExt for JMap<'_, '_, '_> { } } - fn get_i32(&self, env: &mut JNIEnv, key: &str) -> Result> { + fn get_i32(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { get_map_value(env, self, key) } - fn get_i64(&self, env: &mut JNIEnv, key: &str) -> Result> { + fn get_i64(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { get_map_value(env, self, key) } - fn get_f32(&self, env: &mut JNIEnv, key: &str) -> Result> { + fn get_f32(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { get_map_value(env, self, key) } - fn get_f64(&self, env: &mut JNIEnv, key: &str) -> Result> { + fn get_f64(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { get_map_value(env, self, key) } } diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs index 82c8540d25..27948b842d 100644 --- a/java/core/lance-jni/src/utils.rs +++ b/java/core/lance-jni/src/utils.rs @@ -12,16 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -use arrow::ffi::FFI_ArrowSchema; -use arrow_schema::Schema; use jni::objects::JObject; -use jni::sys::jlong; use jni::JNIEnv; use lance::dataset::{WriteMode, WriteParams}; -use crate::blocking_dataset::{BlockingDataset, NATIVE_DATASET}; use crate::ffi::JNIEnvExt; -use crate::{Error, Result}; +use crate::error::JavaResult; +use crate::JavaErrorExt; pub fn extract_write_params( env: &mut JNIEnv, @@ -29,7 +26,7 @@ pub fn extract_write_params( max_rows_per_group: &JObject, max_bytes_per_file: &JObject, mode: &JObject, -) -> Result { +) -> JavaResult { let mut write_params = WriteParams::default(); if let Some(max_rows_per_file_val) = env.get_int_opt(max_rows_per_file)? { @@ -42,31 +39,7 @@ pub fn extract_write_params( write_params.max_bytes_per_file = max_bytes_per_file_val as usize; } if let Some(mode_val) = env.get_string_opt(mode)? { - write_params.mode = WriteMode::try_from(mode_val.as_str())?; + write_params.mode = WriteMode::try_from(mode_val.as_str()).infer_error()?; } Ok(write_params) } - -pub fn import_ffi_schema( - mut env: JNIEnv, - jdataset: JObject, - arrow_schema_addr: jlong, - columns: Option>, -) { - let dataset = { - let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .expect("Failed to get native dataset handle"); - dataset.clone() - }; - let schema = if let Some(columns) = columns { - let ds_schema = ok_or_throw_without_return!(env, dataset.inner.schema().project(&columns)); - Schema::from(&ds_schema) - } else { - Schema::from(dataset.inner.schema()) - }; - - let c_schema = ok_or_throw_without_return!(env, FFI_ArrowSchema::try_from(&schema)); - let out_c_schema = unsafe { &mut *(arrow_schema_addr as *mut FFI_ArrowSchema) }; - let _old = std::mem::replace(out_c_schema, c_schema); -} diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index ffc0c3d0b8..c812d54be8 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -69,12 +69,8 @@ public static Dataset create(BufferAllocator allocator, String path, Schema sche } } - private static native Dataset createWithFfiSchema(long arrowSchemaMemoryAddress, String path, - Optional maxRowsPerFile, Optional maxRowsPerGroup, - Optional maxBytesPerFile, Optional mode); - /** - * Write a dataset to the specified path. + * Create a dataset with given stream. * * @param allocator buffer allocator * @param stream arrow stream @@ -82,16 +78,20 @@ private static native Dataset createWithFfiSchema(long arrowSchemaMemoryAddress, * @param params write parameters * @return Dataset */ - public static Dataset write(BufferAllocator allocator, ArrowArrayStream stream, + public static Dataset create(BufferAllocator allocator, ArrowArrayStream stream, String path, WriteParams params) { - var dataset = writeWithFfiStream(stream.memoryAddress(), path, + var dataset = createWithFfiStream(stream.memoryAddress(), path, params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode()); dataset.allocator = allocator; return dataset; } - private static native Dataset writeWithFfiStream(long arrowStreamMemoryAddress, String path, + private static native Dataset createWithFfiSchema(long arrowSchemaMemoryAddress, String path, + Optional maxRowsPerFile, Optional maxRowsPerGroup, + Optional maxBytesPerFile, Optional mode); + + private static native Dataset createWithFfiStream(long arrowStreamMemoryAddress, String path, Optional maxRowsPerFile, Optional maxRowsPerGroup, Optional maxBytesPerFile, Optional mode); diff --git a/java/core/src/test/java/com/lancedb/lance/TestUtils.java b/java/core/src/test/java/com/lancedb/lance/TestUtils.java index 2a06e5d80f..fa9e2452af 100644 --- a/java/core/src/test/java/com/lancedb/lance/TestUtils.java +++ b/java/core/src/test/java/com/lancedb/lance/TestUtils.java @@ -139,7 +139,7 @@ public void createDatasetAndValidate() throws IOException, URISyntaxException { allocator); ArrowArrayStream arrowStream = ArrowArrayStream.allocateNew(allocator)) { Data.exportArrayStream(allocator, reader, arrowStream); - try (Dataset dataset = Dataset.write( + try (Dataset dataset = Dataset.create( allocator, arrowStream, datasetPath, From 7ae6fc2db25c507df15a07406fef3bc81bc447fb Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Fri, 10 May 2024 21:06:34 -0700 Subject: [PATCH 2/9] fmt --- java/core/lance-jni/src/blocking_dataset.rs | 126 +++++++---- java/core/lance-jni/src/blocking_scanner.rs | 63 ++++-- java/core/lance-jni/src/error.rs | 223 ++++++++++---------- java/core/lance-jni/src/ffi.rs | 59 ++++-- java/core/lance-jni/src/fragment.rs | 73 +++++-- java/core/lance-jni/src/lib.rs | 2 +- java/core/lance-jni/src/utils.rs | 2 +- 7 files changed, 339 insertions(+), 209 deletions(-) diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index 4aaff5d423..910fba2da4 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -12,26 +12,26 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::iter::empty; +use crate::error::{JavaErrorExt, JavaResult}; use crate::ffi::JNIEnvExt; -use crate::error::{JavaResult, JavaErrorExt}; -use crate::{traits::IntoJava, RT}; +use crate::traits::FromJString; use crate::utils::extract_write_params; +use crate::{traits::IntoJava, RT}; use arrow::array::RecordBatchReader; +use arrow::datatypes::Schema; +use arrow::ffi::FFI_ArrowSchema; +use arrow::ffi_stream::ArrowArrayStreamReader; +use arrow::ffi_stream::FFI_ArrowArrayStream; +use arrow::record_batch::RecordBatchIterator; +use jni::objects::JString; +use jni::sys::jint; use jni::sys::jlong; use jni::{objects::JObject, JNIEnv}; use lance::dataset::transaction::Operation; use lance::dataset::{Dataset, WriteParams}; use lance::table::format::Fragment; -use jni::sys::jint; -use jni::objects::JString; -use arrow::ffi::FFI_ArrowSchema; -use arrow::datatypes::Schema; -use arrow::record_batch::RecordBatchIterator; +use std::iter::empty; use std::sync::Arc; -use arrow::ffi_stream::FFI_ArrowArrayStream; -use arrow::ffi_stream::ArrowArrayStreamReader; -use crate::traits::FromJString; pub const NATIVE_DATASET: &str = "nativeDatasetHandle"; @@ -46,7 +46,9 @@ impl BlockingDataset { uri: &str, params: Option, ) -> JavaResult { - let inner = RT.block_on(Dataset::write(reader, uri, params)).infer_error()?; + let inner = RT + .block_on(Dataset::write(reader, uri, params)) + .infer_error()?; Ok(Self { inner }) } @@ -56,7 +58,9 @@ impl BlockingDataset { } pub fn commit(uri: &str, operation: Operation, read_version: Option) -> JavaResult { - let inner = RT.block_on(Dataset::commit(uri, operation, read_version, None, None)).infer_error()?; + let inner = RT + .block_on(Dataset::commit(uri, operation, read_version, None, None)) + .infer_error()?; Ok(Self { inner }) } @@ -85,10 +89,21 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_createWithFfiSchema<'local max_bytes_per_file: JObject, // Optional mode: JObject, // Optional ) -> JObject<'local> { - ok_or_throw!(env, inner_create_with_ffi_schema(&mut env, arrow_schema_addr, path, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode)) + ok_or_throw!( + env, + inner_create_with_ffi_schema( + &mut env, + arrow_schema_addr, + path, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode + ) + ) } -fn inner_create_with_ffi_schema<'local> ( +fn inner_create_with_ffi_schema<'local>( env: &mut JNIEnv<'local>, arrow_schema_addr: jlong, path: JString, @@ -103,14 +118,14 @@ fn inner_create_with_ffi_schema<'local> ( let reader = RecordBatchIterator::new(empty(), Arc::new(schema)); create_dataset( - env, - path, - max_rows_per_file, - max_rows_per_group, - max_bytes_per_file, - mode, - reader - ) + env, + path, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode, + reader, + ) } #[no_mangle] @@ -124,7 +139,18 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_createWithFfiStream<'local max_bytes_per_file: JObject, // Optional mode: JObject, // Optional ) -> JObject<'local> { - ok_or_throw!(env, inner_create_with_ffi_stream(&mut env, arrow_array_stream_addr, path, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode)) + ok_or_throw!( + env, + inner_create_with_ffi_stream( + &mut env, + arrow_array_stream_addr, + path, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode + ) + ) } fn inner_create_with_ffi_stream<'local>( @@ -145,7 +171,7 @@ fn inner_create_with_ffi_stream<'local>( max_rows_per_group, max_bytes_per_file, mode, - reader + reader, ) } @@ -192,12 +218,14 @@ fn attach_native_dataset<'local>( // 1. The Java object (`j_dataset`) should implement the `java.io.Closeable` interface. // 2. Users of this Java object should be instructed to always use it within a try-with-resources // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. - unsafe { env.set_rust_field(&j_dataset, NATIVE_DATASET, dataset) }.infer_error()?; - Ok(j_dataset) + unsafe { env.set_rust_field(&j_dataset, NATIVE_DATASET, dataset) }.infer_error()?; + Ok(j_dataset) } fn create_java_dataset_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { - Ok(env.new_object("com/lancedb/lance/Dataset", "()V", &[]).infer_error()?) + Ok(env + .new_object("com/lancedb/lance/Dataset", "()V", &[]) + .infer_error()?) } #[no_mangle] @@ -208,7 +236,10 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_commitAppend<'local>( read_version_obj: JObject, // Optional fragments_obj: JObject, // List, String is json serialized Fragment ) -> JObject<'local> { - ok_or_throw!(env, inner_commit_append(&mut env, path, read_version_obj, fragments_obj)) + ok_or_throw!( + env, + inner_commit_append(&mut env, path, read_version_obj, fragments_obj) + ) } pub fn inner_commit_append<'local>( @@ -238,13 +269,9 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_releaseNativeDataset( ok_or_throw_without_return!(env, inner_release_native_dataset(&mut env, obj)) } -fn inner_release_native_dataset( - env: &mut JNIEnv, - obj: JObject, -) -> JavaResult<()> { - let dataset: BlockingDataset = unsafe { - env.take_rust_field(obj, NATIVE_DATASET).infer_error()? - }; +fn inner_release_native_dataset(env: &mut JNIEnv, obj: JObject) -> JavaResult<()> { + let dataset: BlockingDataset = + unsafe { env.take_rust_field(obj, NATIVE_DATASET).infer_error()? }; dataset.close(); Ok(()) } @@ -284,7 +311,8 @@ fn inner_get_json_fragments<'local>( ) -> JavaResult> { let fragments = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } + .infer_error()?; dataset.inner.get_fragments() }; @@ -300,7 +328,8 @@ fn inner_get_json_fragments<'local>( "add", "(Ljava/lang/Object;)Z", &[(&jstring).into()], - ).infer_error()?; + ) + .infer_error()?; } Ok(array_list) } @@ -311,17 +340,21 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_importFfiSchema( jdataset: JObject, arrow_schema_addr: jlong, ) { - ok_or_throw_without_return!(env, inner_import_ffi_schema(&mut env, jdataset, arrow_schema_addr)) + ok_or_throw_without_return!( + env, + inner_import_ffi_schema(&mut env, jdataset, arrow_schema_addr) + ) } fn inner_import_ffi_schema( env: &mut JNIEnv, jdataset: JObject, arrow_schema_addr: jlong, -) -> JavaResult<()>{ +) -> JavaResult<()> { let dataset = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } + .infer_error()?; dataset.clone() }; let schema = Schema::from(dataset.inner.schema()); @@ -342,7 +375,8 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_version( fn inner_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) } + .infer_error()?; Ok(dataset_guard.inner.version().version) } @@ -356,7 +390,8 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_latestVersion( fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) } + .infer_error()?; dataset_guard.latest_version() } @@ -370,6 +405,7 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_countRows( fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) } + .infer_error()?; dataset_guard.count_rows(None) -} \ No newline at end of file +} diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs index 199d04bd0d..9efd9a9e7b 100644 --- a/java/core/lance-jni/src/blocking_scanner.rs +++ b/java/core/lance-jni/src/blocking_scanner.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use crate::error::{JavaErrorExt, JavaResult}; use crate::ffi::JNIEnvExt; use crate::JavaError; use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream}; @@ -21,11 +22,11 @@ use arrow_schema::SchemaRef; use jni::{objects::JObject, sys::jlong, JNIEnv}; use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner}; use lance_io::ffi::to_ffi_arrow_array_stream; -use crate::error::{JavaResult, JavaErrorExt}; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, - traits::IntoJava, RT, + traits::IntoJava, + RT, }; pub const NATIVE_SCANNER: &str = "nativeScannerHandle"; @@ -69,7 +70,18 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_createScanner<'lo filter_obj: JObject, // Optional batch_size_obj: JObject, // Optional ) -> JObject<'local> { - ok_or_throw!(env, inner_create_scanner(&mut env, jdataset, fragment_ids_obj, columns_obj, substrait_filter_obj, filter_obj, batch_size_obj)) + ok_or_throw!( + env, + inner_create_scanner( + &mut env, + jdataset, + fragment_ids_obj, + columns_obj, + substrait_filter_obj, + filter_obj, + batch_size_obj + ) + ) } fn inner_create_scanner<'local>( @@ -83,7 +95,8 @@ fn inner_create_scanner<'local>( ) -> JavaResult> { let dataset = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } + .infer_error()?; dataset.clone() }; let mut scanner = dataset.inner.scan(); @@ -92,7 +105,9 @@ fn inner_create_scanner<'local>( let mut fragments = Vec::with_capacity(fragment_ids.len()); for fragment_id in fragment_ids { let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { - return Err(JavaError::input_error(format!("Fragment {fragment_id} not found"))); + return Err(JavaError::input_error(format!( + "Fragment {fragment_id} not found" + ))); }; fragments.push(fragment.metadata().clone()); } @@ -104,7 +119,8 @@ fn inner_create_scanner<'local>( }; let substrait_opt = env.get_bytes_opt(&substrait_filter_obj)?; if let Some(substrait) = substrait_opt { - RT.block_on(async { scanner.filter_substrait(substrait).await }).infer_error()?; + RT.block_on(async { scanner.filter_substrait(substrait).await }) + .infer_error()?; } let filter_opt = env.get_string_opt(&filter_obj)?; if let Some(filter) = filter_opt { @@ -126,10 +142,9 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_releaseNativeScan ok_or_throw_without_return!(env, inner_release_native_scanner(&mut env, j_scanner)); } -fn inner_release_native_scanner(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult<()>{ - let _: BlockingScanner = unsafe { - env.take_rust_field(j_scanner, NATIVE_SCANNER) - }.infer_error()?; +fn inner_release_native_scanner(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult<()> { + let _: BlockingScanner = + unsafe { env.take_rust_field(j_scanner, NATIVE_SCANNER) }.infer_error()?; Ok(()) } @@ -153,12 +168,13 @@ fn attach_native_scanner<'local>( // 1. The Java object (`j_scanner`) should implement the `java.io.Closeable` interface. // 2. Users of this Java object should be instructed to always use it within a try-with-resources // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. - unsafe {env.set_rust_field(&j_scanner, NATIVE_SCANNER, scanner)}.infer_error()?; + unsafe { env.set_rust_field(&j_scanner, NATIVE_SCANNER, scanner) }.infer_error()?; Ok(j_scanner) } fn create_java_scanner_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { - env.new_object("com/lancedb/lance/ipc/LanceScanner", "()V", &[]).infer_error() + env.new_object("com/lancedb/lance/ipc/LanceScanner", "()V", &[]) + .infer_error() } ////////////////// @@ -176,11 +192,13 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_openStream( fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -> JavaResult<()> { let scanner = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } + .infer_error()?; scanner_guard.clone() }; let record_batch_stream = scanner.open_stream()?; - let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone()).infer_error()?; + let ffi_stream = + to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone()).infer_error()?; unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } Ok(()) } @@ -191,13 +209,21 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_importFfiSchema( j_scanner: JObject, schema_addr: jlong, ) { - ok_or_throw_without_return!(env, inner_import_ffi_schema(&mut env, j_scanner, schema_addr)); + ok_or_throw_without_return!( + env, + inner_import_ffi_schema(&mut env, j_scanner, schema_addr) + ); } -fn inner_import_ffi_schema(env: &mut JNIEnv, j_scanner: JObject, schema_addr: jlong) -> JavaResult<()> { +fn inner_import_ffi_schema( + env: &mut JNIEnv, + j_scanner: JObject, + schema_addr: jlong, +) -> JavaResult<()> { let scanner = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } + .infer_error()?; scanner_guard.clone() }; let schema = scanner.schema()?; @@ -217,7 +243,8 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_countRows( fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult { let scanner = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }.infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } + .infer_error()?; scanner_guard.clone() }; scanner.count_rows() diff --git a/java/core/lance-jni/src/error.rs b/java/core/lance-jni/src/error.rs index 855eb39136..ed839a712d 100644 --- a/java/core/lance-jni/src/error.rs +++ b/java/core/lance-jni/src/error.rs @@ -23,159 +23,164 @@ pub type JavaResult = std::result::Result; #[derive(Debug)] pub enum JavaExceptionClass { - IllegalArgumentException, - IOException, - RuntimeException, - UnsupportedOperationException, + IllegalArgumentException, + IOException, + RuntimeException, + UnsupportedOperationException, } impl JavaExceptionClass { - pub fn as_str(&self) -> &str { - match self { - Self::IllegalArgumentException => "java/lang/IllegalArgumentException", - Self::IOException => "java/io/IOException", - Self::RuntimeException => "java/lang/RuntimeException", - Self::UnsupportedOperationException => "java/lang/UnsupportedOperationException" - } - } + pub fn as_str(&self) -> &str { + match self { + Self::IllegalArgumentException => "java/lang/IllegalArgumentException", + Self::IOException => "java/io/IOException", + Self::RuntimeException => "java/lang/RuntimeException", + Self::UnsupportedOperationException => "java/lang/UnsupportedOperationException", + } + } } #[derive(Debug)] pub struct JavaError { - message: String, - java_class: JavaExceptionClass, + message: String, + java_class: JavaExceptionClass, } impl JavaError { - pub fn new(message: String, java_class: JavaExceptionClass) -> Self { - JavaError { message, java_class } - } - - pub fn runtime_error(message: String) -> Self { - JavaError { message: message, java_class: JavaExceptionClass::RuntimeException } - } - - pub fn io_error(message: String) -> Self { - JavaError::new(message, JavaExceptionClass::IOException) - } - - pub fn input_error(message: String) -> Self { - JavaError::new(message, JavaExceptionClass::IllegalArgumentException) - } - - pub fn unsupported_error(message: String) -> Self { - JavaError::new(message, JavaExceptionClass::UnsupportedOperationException) - } - - pub fn throw(&self, env: &mut JNIEnv) { - env.throw_new(self.java_class.as_str(), &self.message) - .expect("Error when throwing Java exception"); - } + pub fn new(message: String, java_class: JavaExceptionClass) -> Self { + JavaError { + message, + java_class, + } + } + + pub fn runtime_error(message: String) -> Self { + JavaError { + message: message, + java_class: JavaExceptionClass::RuntimeException, + } + } + + pub fn io_error(message: String) -> Self { + JavaError::new(message, JavaExceptionClass::IOException) + } + + pub fn input_error(message: String) -> Self { + JavaError::new(message, JavaExceptionClass::IllegalArgumentException) + } + + pub fn unsupported_error(message: String) -> Self { + JavaError::new(message, JavaExceptionClass::UnsupportedOperationException) + } + + pub fn throw(&self, env: &mut JNIEnv) { + env.throw_new(self.java_class.as_str(), &self.message) + .expect("Error when throwing Java exception"); + } } impl std::fmt::Display for JavaError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}: {}", self.java_class.as_str(), self.message) - } + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.java_class.as_str(), self.message) + } } impl std::error::Error for JavaError {} /// Trait for converting errors to Java exceptions. pub trait JavaErrorConversion { - /// Convert to `JavaError` as I/O exception. - fn io_error(self) -> JavaResult; + /// Convert to `JavaError` as I/O exception. + fn io_error(self) -> JavaResult; - /// Convert to `JavaError` as runtime exception. - fn runtime_error(self) -> JavaResult; + /// Convert to `JavaError` as runtime exception. + fn runtime_error(self) -> JavaResult; - /// Convert to `JavaError` as value (input) exception. - fn input_error(self) -> JavaResult; + /// Convert to `JavaError` as value (input) exception. + fn input_error(self) -> JavaResult; - /// Convert to `JavaError` as unsupported operation exception. - fn unsupported_error(self) -> JavaResult; + /// Convert to `JavaError` as unsupported operation exception. + fn unsupported_error(self) -> JavaResult; } - impl JavaErrorConversion for std::result::Result { - fn io_error(self) -> JavaResult { - self.map_err(|err| JavaError::io_error(err.to_string())) - } + fn io_error(self) -> JavaResult { + self.map_err(|err| JavaError::io_error(err.to_string())) + } - fn runtime_error(self) -> JavaResult { - self.map_err(|err| JavaError::runtime_error(err.to_string())) - } + fn runtime_error(self) -> JavaResult { + self.map_err(|err| JavaError::runtime_error(err.to_string())) + } - fn input_error(self) -> JavaResult { - self.map_err(|err| JavaError::input_error(err.to_string())) - } + fn input_error(self) -> JavaResult { + self.map_err(|err| JavaError::input_error(err.to_string())) + } - fn unsupported_error(self) -> JavaResult { - self.map_err(|err| JavaError::unsupported_error(err.to_string())) - } + fn unsupported_error(self) -> JavaResult { + self.map_err(|err| JavaError::unsupported_error(err.to_string())) + } } /// JavaErrorExt trait that converts specific error types to Java exceptions pub trait JavaErrorExt { - /// Convert to a Java error based on the specific error type - fn infer_error(self) -> JavaResult; + /// Convert to a Java error based on the specific error type + fn infer_error(self) -> JavaResult; } impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - LanceError::InvalidInput { .. } => self.input_error(), - LanceError::IO { .. } => self.io_error(), - LanceError::NotSupported { .. } => self.unsupported_error(), - _ => self.runtime_error(), - }, - } - } + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + LanceError::InvalidInput { .. } => self.input_error(), + LanceError::IO { .. } => self.io_error(), + LanceError::NotSupported { .. } => self.unsupported_error(), + _ => self.runtime_error(), + }, + } + } } impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - ArrowError::InvalidArgumentError{ .. } => self.input_error(), - ArrowError::IoError{ .. } => self.io_error(), - ArrowError::NotYetImplemented(_) => self.unsupported_error(), - _ => self.runtime_error(), - }, - } - } + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + ArrowError::InvalidArgumentError { .. } => self.input_error(), + ArrowError::IoError { .. } => self.io_error(), + ArrowError::NotYetImplemented(_) => self.unsupported_error(), + _ => self.runtime_error(), + }, + } + } } impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(_) => self.io_error(), - } - } + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(_) => self.io_error(), + } + } } impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - _ => self.runtime_error(), - }, - } - } + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + _ => self.runtime_error(), + }, + } + } } impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - _ => self.input_error(), - }, - } - } -} \ No newline at end of file + fn infer_error(self) -> JavaResult { + match &self { + Ok(_) => Ok(self.unwrap()), + Err(err) => match err { + _ => self.input_error(), + }, + } + } +} diff --git a/java/core/lance-jni/src/ffi.rs b/java/core/lance-jni/src/ffi.rs index e40574f65a..8bb17c5945 100644 --- a/java/core/lance-jni/src/ffi.rs +++ b/java/core/lance-jni/src/ffi.rs @@ -14,7 +14,7 @@ use core::slice; -use crate::error::{JavaResult, JavaErrorExt}; +use crate::error::{JavaErrorExt, JavaResult}; use jni::objects::{JByteBuffer, JObjectArray, JString}; use jni::sys::jobjectArray; use jni::{objects::JObject, JNIEnv}; @@ -63,7 +63,9 @@ impl JNIEnvExt for JNIEnv<'_> { let mut iter = list.iter(self).infer_error()?; let mut results = Vec::with_capacity(list.size(self).infer_error()? as usize); while let Some(elem) = iter.next(self).infer_error()? { - let int_obj = self.call_method(elem, "intValue", "()I", &[]).infer_error()?; + let int_obj = self + .call_method(elem, "intValue", "()I", &[]) + .infer_error()?; let int_value = int_obj.i().infer_error()?; results.push(int_value); } @@ -87,7 +89,10 @@ impl JNIEnvExt for JNIEnv<'_> { let array_len = self.get_array_length(&jobject_array).infer_error()?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let item: JString = self.get_object_array_element(&jobject_array, i).infer_error()?.into(); + let item: JString = self + .get_object_array_element(&jobject_array, i) + .infer_error()? + .into(); res.push(self.get_string(&item).infer_error()?.into()); } Ok(res) @@ -95,7 +100,9 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_string_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_string_obj = java_obj_gen.l().infer_error()?; let jstr = JString::from(java_string_obj); let val = env.get_string(&jstr).infer_error()?; @@ -105,7 +112,9 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_strings_opt(&mut self, obj: &JObject) -> JavaResult>> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_list_obj = java_obj_gen.l().infer_error()?; env.get_strings(&java_list_obj) }) @@ -113,9 +122,13 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_int_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_int_obj = java_obj_gen.l().infer_error()?; - let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[]).infer_error()?; + let int_obj = env + .call_method(java_int_obj, "intValue", "()I", &[]) + .infer_error()?; let int_value = int_obj.i().infer_error()?; Ok(int_value) }) @@ -123,7 +136,9 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_ints_opt(&mut self, obj: &JObject) -> JavaResult>> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_list_obj = java_obj_gen.l().infer_error()?; env.get_integers(&java_list_obj) }) @@ -131,9 +146,13 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_long_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_long_obj = java_obj_gen.l().infer_error()?; - let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[]).infer_error()?; + let long_obj = env + .call_method(java_long_obj, "longValue", "()J", &[]) + .infer_error()?; let long_value = long_obj.j().infer_error()?; Ok(long_value) }) @@ -141,9 +160,13 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_u64_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_long_obj = java_obj_gen.l().infer_error()?; - let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[]).infer_error()?; + let long_obj = env + .call_method(java_long_obj, "longValue", "()J", &[]) + .infer_error()?; let long_value = long_obj.j().infer_error()?; Ok(long_value as u64) }) @@ -151,11 +174,17 @@ impl JNIEnvExt for JNIEnv<'_> { fn get_bytes_opt(&mut self, obj: &JObject) -> JavaResult> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]).infer_error()?; + let java_obj_gen = env + .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) + .infer_error()?; let java_byte_buffer_obj = java_obj_gen.l().infer_error()?; let j_byte_buffer = JByteBuffer::from(java_byte_buffer_obj); - let raw_data = env.get_direct_buffer_address(&j_byte_buffer).infer_error()?; - let capacity = env.get_direct_buffer_capacity(&j_byte_buffer).infer_error()?; + let raw_data = env + .get_direct_buffer_address(&j_byte_buffer) + .infer_error()?; + let capacity = env + .get_direct_buffer_capacity(&j_byte_buffer) + .infer_error()?; let data = unsafe { slice::from_raw_parts(raw_data, capacity) }; Ok(data) }) diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index f13697af74..962b37281e 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -25,6 +25,7 @@ use std::iter::once; use lance::dataset::fragment::FileFragment; +use crate::error::{JavaErrorExt, JavaResult}; use crate::JavaError; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, @@ -33,14 +34,11 @@ use crate::{ utils::extract_write_params, RT, }; -use crate::error::{JavaResult, JavaErrorExt}; - /////////////////// // Write Methods // /////////////////// - ////////////////// // Read Methods // ////////////////// @@ -51,7 +49,11 @@ pub extern "system" fn Java_com_lancedb_lance_DatasetFragment_countRowsNative( jdataset: JObject, fragment_id: jlong, ) -> jint { - ok_or_throw_with_return!(env, inner_count_rows_native(&mut env, jdataset, fragment_id), -1) as jint + ok_or_throw_with_return!( + env, + inner_count_rows_native(&mut env, jdataset, fragment_id), + -1 + ) as jint } fn inner_count_rows_native( @@ -59,9 +61,12 @@ fn inner_count_rows_native( jdataset: JObject, fragment_id: jlong, ) -> JavaResult { - let dataset =unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }.infer_error()?; - let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { - return Err(JavaError::input_error(format!("Fragment not found: {fragment_id}"))) + let dataset = unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } + .infer_error()?; + let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { + return Err(JavaError::input_error(format!( + "Fragment not found: {fragment_id}" + ))); }; RT.block_on(fragment.count_rows()).infer_error() } @@ -79,7 +84,21 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local max_bytes_per_file: JObject, // Optional mode: JObject, // Optional ) -> JString<'local> { - ok_or_throw_with_return!(env, inner_create_with_ffi_array(&mut env, dataset_uri, arrow_array_addr, arrow_schema_addr, fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode), JString::default()) + ok_or_throw_with_return!( + env, + inner_create_with_ffi_array( + &mut env, + dataset_uri, + arrow_array_addr, + arrow_schema_addr, + fragment_id, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode + ), + JString::default() + ) } fn inner_create_with_ffi_array<'local>( @@ -100,7 +119,7 @@ fn inner_create_with_ffi_array<'local>( let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; let data_type = DataType::try_from(&c_schema).infer_error()?; - let array_data = unsafe { from_ffi_and_data_type(c_array, data_type) }.infer_error()?; + let array_data = unsafe { from_ffi_and_data_type(c_array, data_type) }.infer_error()?; let record_batch = RecordBatch::from(StructArray::from(array_data)); let batch_schema = record_batch.schema().clone(); @@ -114,7 +133,7 @@ fn inner_create_with_ffi_array<'local>( max_rows_per_group, max_bytes_per_file, mode, - reader + reader, ) } @@ -130,7 +149,20 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>( max_bytes_per_file: JObject, // Optional mode: JObject, // Optional ) -> JString<'a> { - ok_or_throw_with_return!(env, inner_create_with_ffi_stream(&mut env, dataset_uri, arrow_array_stream_addr, fragment_id, max_rows_per_file, max_rows_per_group, max_bytes_per_file, mode), JString::default()) + ok_or_throw_with_return!( + env, + inner_create_with_ffi_stream( + &mut env, + dataset_uri, + arrow_array_stream_addr, + fragment_id, + max_rows_per_file, + max_rows_per_group, + max_bytes_per_file, + mode + ), + JString::default() + ) } fn inner_create_with_ffi_stream<'a>( @@ -144,8 +176,7 @@ fn inner_create_with_ffi_stream<'a>( mode: JObject, // Optional ) -> JavaResult> { let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; - let reader = - unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; + let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; create_fragment( env, @@ -155,7 +186,7 @@ fn inner_create_with_ffi_stream<'a>( max_rows_per_group, max_bytes_per_file, mode, - reader + reader, ) } @@ -181,12 +212,14 @@ fn create_fragment<'a>( &max_bytes_per_file, &mode, )?; - let fragment = RT.block_on(FileFragment::create( - &path_str, - fragment_id_opts.unwrap_or(0) as usize, - reader, - Some(write_params), - )).infer_error()?; + let fragment = RT + .block_on(FileFragment::create( + &path_str, + fragment_id_opts.unwrap_or(0) as usize, + reader, + Some(write_params), + )) + .infer_error()?; let json_string = serde_json::to_string(&fragment).infer_error()?; Ok(env.new_string(json_string).infer_error()?) } diff --git a/java/core/lance-jni/src/lib.rs b/java/core/lance-jni/src/lib.rs index c5b63b4aa6..da53927949 100644 --- a/java/core/lance-jni/src/lib.rs +++ b/java/core/lance-jni/src/lib.rs @@ -57,7 +57,7 @@ mod ffi; mod fragment; mod traits; mod utils; -pub use error::{JavaError, JavaResult, JavaErrorExt}; +pub use error::{JavaError, JavaErrorExt, JavaResult}; use lazy_static::lazy_static; diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs index 27948b842d..9cb1fc895e 100644 --- a/java/core/lance-jni/src/utils.rs +++ b/java/core/lance-jni/src/utils.rs @@ -16,8 +16,8 @@ use jni::objects::JObject; use jni::JNIEnv; use lance::dataset::{WriteMode, WriteParams}; -use crate::ffi::JNIEnvExt; use crate::error::JavaResult; +use crate::ffi::JNIEnvExt; use crate::JavaErrorExt; pub fn extract_write_params( From 3a325bc895c00113374e4c68c1e549bdd6695f22 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Fri, 10 May 2024 21:31:00 -0700 Subject: [PATCH 3/9] Fix cargo clippy --- java/core/lance-jni/src/blocking_dataset.rs | 9 ++++----- java/core/lance-jni/src/blocking_scanner.rs | 6 +++--- java/core/lance-jni/src/error.rs | 20 ++++++++------------ java/core/lance-jni/src/fragment.rs | 10 ++++++---- java/core/lance-jni/src/traits.rs | 8 ++++---- 5 files changed, 25 insertions(+), 28 deletions(-) diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index 910fba2da4..ec7f02148d 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -65,11 +65,11 @@ impl BlockingDataset { } pub fn latest_version(&self) -> JavaResult { - Ok(RT.block_on(self.inner.latest_version_id()).infer_error()?) + RT.block_on(self.inner.latest_version_id()).infer_error() } pub fn count_rows(&self, filter: Option) -> JavaResult { - Ok(RT.block_on(self.inner.count_rows(filter)).infer_error()?) + RT.block_on(self.inner.count_rows(filter)).infer_error() } pub fn close(&self) {} @@ -223,9 +223,8 @@ fn attach_native_dataset<'local>( } fn create_java_dataset_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { - Ok(env - .new_object("com/lancedb/lance/Dataset", "()V", &[]) - .infer_error()?) + env.new_object("com/lancedb/lance/Dataset", "()V", &[]) + .infer_error() } #[no_mangle] diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs index 9efd9a9e7b..0ec4cb7fd3 100644 --- a/java/core/lance-jni/src/blocking_scanner.rs +++ b/java/core/lance-jni/src/blocking_scanner.rs @@ -44,15 +44,15 @@ impl BlockingScanner { } pub fn open_stream(&self) -> JavaResult { - Ok(RT.block_on(self.inner.try_into_stream()).infer_error()?) + RT.block_on(self.inner.try_into_stream()).infer_error() } pub fn schema(&self) -> JavaResult { - Ok(RT.block_on(self.inner.schema()).infer_error()?) + RT.block_on(self.inner.schema()).infer_error() } pub fn count_rows(&self) -> JavaResult { - Ok(RT.block_on(self.inner.count_rows()).infer_error()?) + RT.block_on(self.inner.count_rows()).infer_error() } } diff --git a/java/core/lance-jni/src/error.rs b/java/core/lance-jni/src/error.rs index ed839a712d..f2ce683144 100644 --- a/java/core/lance-jni/src/error.rs +++ b/java/core/lance-jni/src/error.rs @@ -48,29 +48,29 @@ pub struct JavaError { impl JavaError { pub fn new(message: String, java_class: JavaExceptionClass) -> Self { - JavaError { + Self { message, java_class, } } pub fn runtime_error(message: String) -> Self { - JavaError { - message: message, + Self { + message, java_class: JavaExceptionClass::RuntimeException, } } pub fn io_error(message: String) -> Self { - JavaError::new(message, JavaExceptionClass::IOException) + Self::new(message, JavaExceptionClass::IOException) } pub fn input_error(message: String) -> Self { - JavaError::new(message, JavaExceptionClass::IllegalArgumentException) + Self::new(message, JavaExceptionClass::IllegalArgumentException) } pub fn unsupported_error(message: String) -> Self { - JavaError::new(message, JavaExceptionClass::UnsupportedOperationException) + Self::new(message, JavaExceptionClass::UnsupportedOperationException) } pub fn throw(&self, env: &mut JNIEnv) { @@ -167,9 +167,7 @@ impl JavaErrorExt for std::result::Result { fn infer_error(self) -> JavaResult { match &self { Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - _ => self.runtime_error(), - }, + Err(_) => self.runtime_error(), } } } @@ -178,9 +176,7 @@ impl JavaErrorExt for std::result::Result { fn infer_error(self) -> JavaResult { match &self { Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - _ => self.input_error(), - }, + Err(_) => self.input_error(), } } } diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index 962b37281e..d85d3d94e9 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -101,6 +101,7 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiArray<'local ) } +#[allow(clippy::too_many_arguments)] fn inner_create_with_ffi_array<'local>( env: &mut JNIEnv<'local>, dataset_uri: JString, @@ -165,8 +166,9 @@ pub extern "system" fn Java_com_lancedb_lance_Fragment_createWithFfiStream<'a>( ) } -fn inner_create_with_ffi_stream<'a>( - env: &mut JNIEnv<'a>, +#[allow(clippy::too_many_arguments)] +fn inner_create_with_ffi_stream<'local>( + env: &mut JNIEnv<'local>, dataset_uri: JString, arrow_array_stream_addr: jlong, fragment_id: JObject, // Optional @@ -174,7 +176,7 @@ fn inner_create_with_ffi_stream<'a>( max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional -) -> JavaResult> { +) -> JavaResult> { let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; @@ -221,5 +223,5 @@ fn create_fragment<'a>( )) .infer_error()?; let json_string = serde_json::to_string(&fragment).infer_error()?; - Ok(env.new_string(json_string).infer_error()?) + env.new_string(json_string).infer_error() } diff --git a/java/core/lance-jni/src/traits.rs b/java/core/lance-jni/src/traits.rs index 462b0307d5..a2d2156411 100644 --- a/java/core/lance-jni/src/traits.rs +++ b/java/core/lance-jni/src/traits.rs @@ -28,25 +28,25 @@ pub trait IntoJava { impl FromJObject for JObject<'_> { fn extract(&self) -> JavaResult { - Ok(JValue::from(self).i().infer_error()?) + JValue::from(self).i().infer_error() } } impl FromJObject for JObject<'_> { fn extract(&self) -> JavaResult { - Ok(JValue::from(self).j().infer_error()?) + JValue::from(self).j().infer_error() } } impl FromJObject for JObject<'_> { fn extract(&self) -> JavaResult { - Ok(JValue::from(self).f().infer_error()?) + JValue::from(self).f().infer_error() } } impl FromJObject for JObject<'_> { fn extract(&self) -> JavaResult { - Ok(JValue::from(self).d().infer_error()?) + JValue::from(self).d().infer_error() } } From 3074cb26959dad45232ad7c208f096510fca5da9 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Sun, 19 May 2024 22:25:26 -0700 Subject: [PATCH 4/9] Add checkNotNull and some test cases --- java/core/lance-jni/src/error.rs | 5 +- .../main/java/com/lancedb/lance/Dataset.java | 16 +++++++ .../com/lancedb/lance/DatasetFragment.java | 8 +++- .../main/java/com/lancedb/lance/Fragment.java | 10 ++++ .../com/lancedb/lance/FragmentMetadata.java | 2 + .../com/lancedb/lance/FragmentOperation.java | 4 ++ .../com/lancedb/lance/ipc/LanceScanner.java | 4 ++ .../java/com/lancedb/lance/DatasetTest.java | 47 +++++++++++++++++++ .../java/com/lancedb/lance/FragmentTest.java | 42 +++++++++++++++++ 9 files changed, 135 insertions(+), 3 deletions(-) diff --git a/java/core/lance-jni/src/error.rs b/java/core/lance-jni/src/error.rs index f2ce683144..cc3be31f9c 100644 --- a/java/core/lance-jni/src/error.rs +++ b/java/core/lance-jni/src/error.rs @@ -131,7 +131,10 @@ impl JavaErrorExt for std::result::Result { match &self { Ok(_) => Ok(self.unwrap()), Err(err) => match err { - LanceError::InvalidInput { .. } => self.input_error(), + LanceError::DatasetNotFound { .. } + | LanceError::DatasetAlreadyExists { .. } + | LanceError::CommitConflict { .. } + | LanceError::InvalidInput { .. } => self.input_error(), LanceError::IO { .. } => self.io_error(), LanceError::NotSupported { .. } => self.unsupported_error(), _ => self.runtime_error(), diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index e1b49f52c1..87e477b37b 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -25,6 +25,7 @@ import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.types.pojo.Schema; /** @@ -59,6 +60,10 @@ private Dataset() { */ public static Dataset create(BufferAllocator allocator, String path, Schema schema, WriteParams params) { + Preconditions.checkNotNull(allocator); + Preconditions.checkNotNull(path); + Preconditions.checkNotNull(schema); + Preconditions.checkNotNull(params); try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator)) { Data.exportSchema(allocator, schema, null, arrowSchema); var dataset = createWithFfiSchema(arrowSchema.memoryAddress(), @@ -80,6 +85,10 @@ public static Dataset create(BufferAllocator allocator, String path, Schema sche */ public static Dataset create(BufferAllocator allocator, ArrowArrayStream stream, String path, WriteParams params) { + Preconditions.checkNotNull(allocator); + Preconditions.checkNotNull(stream); + Preconditions.checkNotNull(path); + Preconditions.checkNotNull(params); var dataset = createWithFfiStream(stream.memoryAddress(), path, params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), params.getMaxBytesPerFile(), params.getMode()); @@ -103,6 +112,8 @@ private static native Dataset createWithFfiStream(long arrowStreamMemoryAddress, * @return Dataset */ public static Dataset open(String path, BufferAllocator allocator) { + Preconditions.checkNotNull(path); + Preconditions.checkNotNull(allocator); var dataset = openNative(path); dataset.allocator = allocator; return dataset; @@ -129,6 +140,10 @@ public static Dataset open(String path, BufferAllocator allocator) { */ public static Dataset commit(BufferAllocator allocator, String path, FragmentOperation operation, Optional readVersion) { + Preconditions.checkNotNull(allocator); + Preconditions.checkNotNull(path); + Preconditions.checkNotNull(operation); + Preconditions.checkNotNull(readVersion); var dataset = operation.commit(allocator, path, readVersion); dataset.allocator = allocator; return dataset; @@ -163,6 +178,7 @@ public LanceScanner newScan(long batchSize) { * @return a dataset scanner */ public LanceScanner newScan(ScanOptions options) { + Preconditions.checkNotNull(options); return LanceScanner.create(this, options, allocator); } diff --git a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java index 59cfe5b5f1..d655be9f3d 100644 --- a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java +++ b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java @@ -17,11 +17,12 @@ import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; import java.util.List; +import org.apache.arrow.util.Preconditions; /** * Dataset format. * Matching to Lance Rust FileFragment. - * */ + */ public class DatasetFragment { /** Pointer to the {@link Dataset} instance in Java. */ private final Dataset dataset; @@ -29,6 +30,8 @@ public class DatasetFragment { /** Private constructor, calling from JNI. */ DatasetFragment(Dataset dataset, FragmentMetadata metadata) { + Preconditions.checkNotNull(dataset); + Preconditions.checkNotNull(metadata); this.dataset = dataset; this.metadata = metadata; } @@ -52,7 +55,7 @@ public LanceScanner newScan() { public LanceScanner newScan(long batchSize) { return LanceScanner.create(dataset, new ScanOptions.Builder() - .fragmentIds(List.of(metadata.getId())).batchSize(batchSize).build(), + .fragmentIds(List.of(metadata.getId())).batchSize(batchSize).build(), dataset.allocator); } @@ -63,6 +66,7 @@ public LanceScanner newScan(long batchSize) { * @return a dataset scanner */ public LanceScanner newScan(ScanOptions options) { + Preconditions.checkNotNull(options); return LanceScanner.create(dataset, new ScanOptions.Builder(options).fragmentIds(List.of(metadata.getId())).build(), dataset.allocator); diff --git a/java/core/src/main/java/com/lancedb/lance/Fragment.java b/java/core/src/main/java/com/lancedb/lance/Fragment.java index 8f8d12322a..0920c324b0 100644 --- a/java/core/src/main/java/com/lancedb/lance/Fragment.java +++ b/java/core/src/main/java/com/lancedb/lance/Fragment.java @@ -20,6 +20,7 @@ import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.VectorSchemaRoot; /** Fragment operations. */ @@ -31,6 +32,11 @@ public class Fragment { /** Create a fragment from the given data in vector schema root. */ public static FragmentMetadata create(String datasetUri, BufferAllocator allocator, VectorSchemaRoot root, Optional fragmentId, WriteParams params) { + Preconditions.checkNotNull(datasetUri); + Preconditions.checkNotNull(allocator); + Preconditions.checkNotNull(root); + Preconditions.checkNotNull(fragmentId); + Preconditions.checkNotNull(params); try (ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator); ArrowArray arrowArray = ArrowArray.allocateNew(allocator)) { Data.exportVectorSchemaRoot(allocator, root, null, arrowArray, arrowSchema); @@ -43,6 +49,10 @@ public static FragmentMetadata create(String datasetUri, BufferAllocator allocat /** Create a fragment from the given data. */ public static FragmentMetadata create(String datasetUri, ArrowArrayStream stream, Optional fragmentId, WriteParams params) { + Preconditions.checkNotNull(datasetUri); + Preconditions.checkNotNull(stream); + Preconditions.checkNotNull(fragmentId); + Preconditions.checkNotNull(params); return FragmentMetadata.fromJson(createWithFfiStream(datasetUri, stream.memoryAddress(), fragmentId, params.getMaxRowsPerFile(), params.getMaxRowsPerGroup(), diff --git a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java index 36a96ff9f4..1f677854e5 100644 --- a/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java +++ b/java/core/src/main/java/com/lancedb/lance/FragmentMetadata.java @@ -15,6 +15,7 @@ package com.lancedb.lance; import java.io.Serializable; +import org.apache.arrow.util.Preconditions; import org.json.JSONObject; /** @@ -54,6 +55,7 @@ public String getJsonMetadata() { * @return created fragment metadata */ public static FragmentMetadata fromJson(String jsonMetadata) { + Preconditions.checkNotNull(jsonMetadata); JSONObject metadata = new JSONObject(jsonMetadata); if (!metadata.has(ID_KEY) || !metadata.has(PHYSICAL_ROWS_KEY)) { throw new IllegalArgumentException( diff --git a/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java b/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java index fa76946b36..b3e31fe32b 100644 --- a/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java +++ b/java/core/src/main/java/com/lancedb/lance/FragmentOperation.java @@ -18,6 +18,7 @@ import java.util.Optional; import java.util.stream.Collectors; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.types.pojo.Schema; /** Fragment related operations. */ @@ -42,6 +43,9 @@ public Append(List fragments) { @Override public Dataset commit(BufferAllocator allocator, String path, Optional readVersion) { + Preconditions.checkNotNull(allocator); + Preconditions.checkNotNull(path); + Preconditions.checkNotNull(readVersion); return Dataset.commitAppend(path, readVersion, fragments.stream().map(FragmentMetadata::getJsonMetadata).collect(Collectors.toList())); } diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java b/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java index 4520183a35..7d15471291 100644 --- a/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java +++ b/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java @@ -24,6 +24,7 @@ import org.apache.arrow.c.Data; import org.apache.arrow.dataset.scanner.ScanTask; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.Schema; @@ -49,6 +50,9 @@ private LanceScanner() {} */ public static LanceScanner create(Dataset dataset, ScanOptions options, BufferAllocator allocator) { + Preconditions.checkNotNull(dataset); + Preconditions.checkNotNull(options); + Preconditions.checkNotNull(allocator); LanceScanner scanner = createScanner(dataset, options.getFragmentIds(), options.getColumns(), options.getSubstraitFilter(), options.getFilter(), options.getBatchSize()); scanner.allocator = allocator; diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java index e545de194a..60dd0c5c1e 100644 --- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java +++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java @@ -60,6 +60,16 @@ void testCreateEmptyDataset() { } } + @Test + void testCreateDirNotExist() throws IOException, URISyntaxException { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + } + } + @Test void testOpenInvalidPath() { String validPath = tempDir.resolve("Invalid_dataset").toString(); @@ -95,4 +105,41 @@ void testDatasetVersion() { } } } + + @Test + void testOpenNonExist() throws IOException, URISyntaxException { + String datasetPath = tempDir.resolve("non_exist").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + assertThrows(IllegalArgumentException.class, () -> { + Dataset.open(datasetPath, allocator); + }); + } + } + + @Test + void testCreateExist() throws IOException, URISyntaxException { + String datasetPath = tempDir.resolve("create_exist").toString(); + try (BufferAllocator allocator = new RootAllocator()) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + assertThrows(IllegalArgumentException.class, () -> { + testDataset.createEmptyDataset(); + }); + } + } + + @Test + void testCommitConflict() { + String datasetPath = tempDir.resolve("commit_conflict").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + try (Dataset dataset = testDataset.createEmptyDataset()) { + assertEquals(1, dataset.version()); + assertEquals(1, dataset.latestVersion()); + assertThrows(IllegalArgumentException.class, () -> { + testDataset.write(0, 5); + }); + } + } + } } diff --git a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java index 2752c62cdb..e5bb40c961 100644 --- a/java/core/src/test/java/com/lancedb/lance/FragmentTest.java +++ b/java/core/src/test/java/com/lancedb/lance/FragmentTest.java @@ -15,9 +15,11 @@ package com.lancedb.lance; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import com.lancedb.lance.ipc.ScanOptions; import java.nio.file.Path; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import org.apache.arrow.dataset.scanner.Scanner; @@ -71,4 +73,44 @@ void testFragmentCreate() throws Exception { } } } + + @Test + void commitWithoutVersion() { + String datasetPath = tempDir.resolve("commit_without_version").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + FragmentMetadata meta = testDataset.createNewFragment(123, 20); + FragmentOperation.Append appendOp = new FragmentOperation.Append(List.of(meta)); + assertThrows(IllegalArgumentException.class, () -> { + Dataset.commit(allocator, datasetPath, appendOp, Optional.empty()); + }); + } + } + + @Test + void commitOldVersion() { + String datasetPath = tempDir.resolve("commit_old_version").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + FragmentMetadata meta = testDataset.createNewFragment(123, 20); + FragmentOperation.Append appendOp = new FragmentOperation.Append(List.of(meta)); + assertThrows(IllegalArgumentException.class, () -> { + Dataset.commit(allocator, datasetPath, appendOp, Optional.of(0L)); + }); + } + } + + @Test + void appendWithoutFragment() { + String datasetPath = tempDir.resolve("append_without_fragment").toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + testDataset.createEmptyDataset().close(); + assertThrows(IllegalArgumentException.class, () -> { + new FragmentOperation.Append(new ArrayList<>()); + }); + } + } } From 0b82b82025df6a8aa16cab328b3da21b40f37277 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Mon, 20 May 2024 09:31:49 -0700 Subject: [PATCH 5/9] Improve Dataset thread safe --- java/core/lance-jni/src/blocking_dataset.rs | 11 ++- .../main/java/com/lancedb/lance/Dataset.java | 86 +++++++++++++++---- .../java/com/lancedb/lance/DatasetTest.java | 17 +++- 3 files changed, 90 insertions(+), 24 deletions(-) diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index ec7f02148d..7d60984041 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -350,13 +350,12 @@ fn inner_import_ffi_schema( jdataset: JObject, arrow_schema_addr: jlong, ) -> JavaResult<()> { - let dataset = { + let schema = { let dataset = unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } .infer_error()?; - dataset.clone() + Schema::from(dataset.inner.schema()) }; - let schema = Schema::from(dataset.inner.schema()); let c_schema = FFI_ArrowSchema::try_from(&schema).infer_error()?; let out_c_schema = unsafe { &mut *(arrow_schema_addr as *mut FFI_ArrowSchema) }; @@ -365,7 +364,7 @@ fn inner_import_ffi_schema( } #[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_version( +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeVersion( mut env: JNIEnv, java_dataset: JObject, ) -> jlong { @@ -380,7 +379,7 @@ fn inner_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { } #[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_Dataset_latestVersion( +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeLatestVersion( mut env: JNIEnv, java_dataset: JObject, ) -> jlong { @@ -395,7 +394,7 @@ fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult jint { diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index 87e477b37b..8f0f790b9e 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -20,7 +20,10 @@ import java.io.Closeable; import java.util.List; import java.util.Optional; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.stream.Collectors; +import javax.annotation.concurrent.NotThreadSafe; import org.apache.arrow.c.ArrowArrayStream; import org.apache.arrow.c.ArrowSchema; import org.apache.arrow.c.Data; @@ -46,6 +49,8 @@ public class Dataset implements Closeable { BufferAllocator allocator; + private final ReadWriteLock lock = new ReentrantReadWriteLock(); + private Dataset() { } @@ -179,25 +184,49 @@ public LanceScanner newScan(long batchSize) { */ public LanceScanner newScan(ScanOptions options) { Preconditions.checkNotNull(options); - return LanceScanner.create(this, options, allocator); + try (ReadLock readLock = new ReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return LanceScanner.create(this, options, allocator); + } } /** * Gets the currently checked out version of the dataset. */ - public native long version(); + public long version() { + try (ReadLock readLock = new ReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeVersion(); + } + } + + private native long nativeVersion(); /** * Gets the latest version of the dataset. */ - public native long latestVersion(); + public long latestVersion() { + try (ReadLock readLock = new ReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeLatestVersion(); + } + } + + private native long nativeLatestVersion(); /** * Count the number of rows in the dataset. * * @return num of rows. */ - public native int countRows(); + public int countRows() { + try (ReadLock readLock = new ReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeCountRows(); + } + } + + private native int nativeCountRows(); /** * Get all fragments in this dataset. @@ -205,13 +234,16 @@ public LanceScanner newScan(ScanOptions options) { * @return A list of {@link DatasetFragment}. */ public List getFragments() { - // Set a pointer in Fragment to dataset, to make it is easier to issue IOs - // later. - // - // We do not need to close Fragments. - return this.getJsonFragments().stream() - .map(jsonFragment -> new DatasetFragment(this, FragmentMetadata.fromJson(jsonFragment))) - .collect(Collectors.toList()); + try (ReadLock readLock = new ReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + // Set a pointer in Fragment to dataset, to make it is easier to issue IOs + // later. + // + // We do not need to close Fragments. + return this.getJsonFragments().stream() + .map(jsonFragment -> new DatasetFragment(this, FragmentMetadata.fromJson(jsonFragment))) + .collect(Collectors.toList()); + } } private native List getJsonFragments(); @@ -222,9 +254,12 @@ public List getFragments() { * @return the arrow schema */ public Schema getSchema() { - try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { - importFfiSchema(ffiArrowSchema.memoryAddress()); - return Data.importSchema(allocator, ffiArrowSchema, null); + try (ReadLock readLock = new ReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { + importFfiSchema(ffiArrowSchema.memoryAddress()); + return Data.importSchema(allocator, ffiArrowSchema, null); + } } } @@ -237,9 +272,14 @@ public Schema getSchema() { */ @Override public void close() { - if (nativeDatasetHandle != 0) { - releaseNativeDataset(nativeDatasetHandle); - nativeDatasetHandle = 0; + lock.writeLock().lock(); + try { + if (nativeDatasetHandle != 0) { + releaseNativeDataset(nativeDatasetHandle); + nativeDatasetHandle = 0; + } + } finally { + lock.writeLock().unlock(); } } @@ -250,4 +290,16 @@ public void close() { * @param handle The native handle to the dataset resource. */ private native void releaseNativeDataset(long handle); + + private class ReadLock implements AutoCloseable { + /** Read lock. */ + public ReadLock() { + lock.readLock().lock(); + } + + @Override + public void close() { + lock.readLock().unlock(); + } + } } diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java index 60dd0c5c1e..ccd944689f 100644 --- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java +++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java @@ -130,7 +130,8 @@ void testCreateExist() throws IOException, URISyntaxException { @Test void testCommitConflict() { - String datasetPath = tempDir.resolve("commit_conflict").toString(); + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); try (Dataset dataset = testDataset.createEmptyDataset()) { @@ -142,4 +143,18 @@ void testCommitConflict() { } } } + + @Test + void testGetSchemaWithClosedDataset() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.SimpleTestDataset testDataset = new TestUtils.SimpleTestDataset(allocator, datasetPath); + Dataset dataset = testDataset.createEmptyDataset(); + dataset.close(); + // dataset.getSchema(); + // assertThrows(IllegalArgumentException.class, testDataset::getSchema); + // assertThrows(RuntimeException.class, dataset::getSchema); + } + } } From 9a4a04a6ae19787f1ba425d0053f9411f37b073c Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Mon, 20 May 2024 10:04:23 -0700 Subject: [PATCH 6/9] Add ReadWriteLock, avoid clone in JNI Rust object --- java/core/lance-jni/src/blocking_scanner.rs | 37 +++++++--------- .../main/java/com/lancedb/lance/Dataset.java | 36 +++++++-------- .../com/lancedb/lance/DatasetFragment.java | 1 + .../com/lancedb/lance/ipc/LanceScanner.java | 44 +++++++++++++------ 4 files changed, 63 insertions(+), 55 deletions(-) diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs index 0ec4cb7fd3..38c30b2370 100644 --- a/java/core/lance-jni/src/blocking_scanner.rs +++ b/java/core/lance-jni/src/blocking_scanner.rs @@ -93,18 +93,15 @@ fn inner_create_scanner<'local>( filter_obj: JObject, batch_size_obj: JObject, ) -> JavaResult> { - let dataset = { - let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .infer_error()?; - dataset.clone() - }; - let mut scanner = dataset.inner.scan(); let fragment_ids_opt = env.get_ints_opt(&fragment_ids_obj)?; + let dataset_guard = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } + .infer_error()?; + let mut scanner = dataset_guard.inner.scan(); if let Some(fragment_ids) = fragment_ids_opt { let mut fragments = Vec::with_capacity(fragment_ids.len()); for fragment_id in fragment_ids { - let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { + let Some(fragment) = dataset_guard.inner.get_fragment(fragment_id as usize) else { return Err(JavaError::input_error(format!( "Fragment {fragment_id} not found" ))); @@ -113,6 +110,7 @@ fn inner_create_scanner<'local>( } scanner.with_fragments(fragments); } + drop(dataset_guard); let columns_opt = env.get_strings_opt(&columns_obj)?; if let Some(columns) = columns_opt { scanner.project(&columns).infer_error()?; @@ -190,13 +188,12 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_openStream( } fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -> JavaResult<()> { - let scanner = { + let record_batch_stream = { let scanner_guard = unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } .infer_error()?; - scanner_guard.clone() + scanner_guard.open_stream()? }; - let record_batch_stream = scanner.open_stream()?; let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone()).infer_error()?; unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } @@ -220,20 +217,19 @@ fn inner_import_ffi_schema( j_scanner: JObject, schema_addr: jlong, ) -> JavaResult<()> { - let scanner = { + let schema = { let scanner_guard = unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } .infer_error()?; - scanner_guard.clone() + scanner_guard.schema()? }; - let schema = scanner.schema()?; let ffi_schema = FFI_ArrowSchema::try_from(&*schema).infer_error()?; unsafe { std::ptr::write_unaligned(schema_addr as *mut FFI_ArrowSchema, ffi_schema) } Ok(()) } #[no_mangle] -pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_countRows( +pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_nativeCountRows( mut env: JNIEnv, j_scanner: JObject, ) -> jlong { @@ -241,11 +237,8 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_countRows( } fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult { - let scanner = { - let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .infer_error()?; - scanner_guard.clone() - }; - scanner.count_rows() + let scanner_guard = + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } + .infer_error()?; + scanner_guard.count_rows() } diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index 8f0f790b9e..5eebf5dac4 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -49,7 +49,7 @@ public class Dataset implements Closeable { BufferAllocator allocator; - private final ReadWriteLock lock = new ReentrantReadWriteLock(); + private final LockManager lockManager = new LockManager(); private Dataset() { } @@ -184,7 +184,7 @@ public LanceScanner newScan(long batchSize) { */ public LanceScanner newScan(ScanOptions options) { Preconditions.checkNotNull(options); - try (ReadLock readLock = new ReadLock()) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); return LanceScanner.create(this, options, allocator); } @@ -194,7 +194,7 @@ public LanceScanner newScan(ScanOptions options) { * Gets the currently checked out version of the dataset. */ public long version() { - try (ReadLock readLock = new ReadLock()) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); return nativeVersion(); } @@ -206,7 +206,7 @@ public long version() { * Gets the latest version of the dataset. */ public long latestVersion() { - try (ReadLock readLock = new ReadLock()) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); return nativeLatestVersion(); } @@ -220,7 +220,7 @@ public long latestVersion() { * @return num of rows. */ public int countRows() { - try (ReadLock readLock = new ReadLock()) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); return nativeCountRows(); } @@ -234,7 +234,7 @@ public int countRows() { * @return A list of {@link DatasetFragment}. */ public List getFragments() { - try (ReadLock readLock = new ReadLock()) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); // Set a pointer in Fragment to dataset, to make it is easier to issue IOs // later. @@ -254,7 +254,7 @@ public List getFragments() { * @return the arrow schema */ public Schema getSchema() { - try (ReadLock readLock = new ReadLock()) { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { importFfiSchema(ffiArrowSchema.memoryAddress()); @@ -272,14 +272,11 @@ public Schema getSchema() { */ @Override public void close() { - lock.writeLock().lock(); - try { + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { if (nativeDatasetHandle != 0) { releaseNativeDataset(nativeDatasetHandle); nativeDatasetHandle = 0; } - } finally { - lock.writeLock().unlock(); } } @@ -291,15 +288,14 @@ public void close() { */ private native void releaseNativeDataset(long handle); - private class ReadLock implements AutoCloseable { - /** Read lock. */ - public ReadLock() { - lock.readLock().lock(); - } - - @Override - public void close() { - lock.readLock().unlock(); + /** + * Checks if the dataset is closed. + * + * @return true if the dataset is closed, false otherwise. + */ + public boolean closed() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + return nativeDatasetHandle == 0; } } } diff --git a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java index d655be9f3d..7adf3a0178 100644 --- a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java +++ b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java @@ -42,6 +42,7 @@ public class DatasetFragment { * @return a dataset scanner */ public LanceScanner newScan() { + Preconditions.checkState(!dataset.closed(), "Dataset is closed"); return LanceScanner.create(dataset, new ScanOptions.Builder() .fragmentIds(List.of(metadata.getId())).build(), dataset.allocator); } diff --git a/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java b/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java index 7d15471291..20eb5cc518 100644 --- a/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java +++ b/java/core/src/main/java/com/lancedb/lance/ipc/LanceScanner.java @@ -15,6 +15,7 @@ package com.lancedb.lance.ipc; import com.lancedb.lance.Dataset; +import com.lancedb.lance.LockManager; import java.io.IOException; import java.nio.ByteBuffer; import java.util.List; @@ -38,6 +39,8 @@ public class LanceScanner implements org.apache.arrow.dataset.scanner.Scanner { private long nativeScannerHandle; + private final LockManager lockManager = new LockManager(); + private LanceScanner() {} /** @@ -71,9 +74,11 @@ static native LanceScanner createScanner(Dataset dataset, Optional */ @Override public void close() throws Exception { - if (nativeScannerHandle != 0) { - releaseNativeScanner(nativeScannerHandle); - nativeScannerHandle = 0; + try (LockManager.WriteLock writeLock = lockManager.acquireWriteLock()) { + if (nativeScannerHandle != 0) { + releaseNativeScanner(nativeScannerHandle); + nativeScannerHandle = 0; + } } } @@ -87,12 +92,15 @@ public void close() throws Exception { @Override public ArrowReader scanBatches() { - try (ArrowArrayStream s = ArrowArrayStream.allocateNew(allocator)) { - openStream(s.memoryAddress()); - return Data.importArrayStream(allocator, s); - } catch (IOException e) { - // TODO: handle IO exception? - throw new RuntimeException(e); + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeScannerHandle != 0, "Scanner is closed"); + try (ArrowArrayStream s = ArrowArrayStream.allocateNew(allocator)) { + openStream(s.memoryAddress()); + return Data.importArrayStream(allocator, s); + } catch (IOException e) { + // TODO: handle IO exception? + throw new RuntimeException(e); + } } } @@ -106,9 +114,12 @@ public Iterable scan() { @Override public Schema schema() { - try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { - importFfiSchema(ffiArrowSchema.memoryAddress()); - return Data.importSchema(allocator, ffiArrowSchema, null); + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeScannerHandle != 0, "Scanner is closed"); + try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(allocator)) { + importFfiSchema(ffiArrowSchema.memoryAddress()); + return Data.importSchema(allocator, ffiArrowSchema, null); + } } } @@ -119,5 +130,12 @@ public Schema schema() { * * @return num of rows. */ - public native long countRows(); + public long countRows() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeScannerHandle != 0, "Scanner is closed"); + return nativeCountRows(); + } + } + + private native long nativeCountRows(); } From 797919fa66482f92c5c962093ef4786335e444c7 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Mon, 20 May 2024 11:12:40 -0700 Subject: [PATCH 7/9] small fix --- java/core/src/main/java/com/lancedb/lance/DatasetFragment.java | 1 - 1 file changed, 1 deletion(-) diff --git a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java index 7adf3a0178..d655be9f3d 100644 --- a/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java +++ b/java/core/src/main/java/com/lancedb/lance/DatasetFragment.java @@ -42,7 +42,6 @@ public class DatasetFragment { * @return a dataset scanner */ public LanceScanner newScan() { - Preconditions.checkState(!dataset.closed(), "Dataset is closed"); return LanceScanner.create(dataset, new ScanOptions.Builder() .fragmentIds(List.of(metadata.getId())).build(), dataset.allocator); } From 71f35942352acce52a7fa92aa9677807c5f20e7f Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Mon, 20 May 2024 19:41:26 -0700 Subject: [PATCH 8/9] Use Result Error instead of JavaResult and JavaError to use default From trait --- java/core/lance-jni/src/blocking_dataset.rs | 102 ++++++------- java/core/lance-jni/src/blocking_scanner.rs | 73 ++++----- java/core/lance-jni/src/error.rs | 118 ++++---------- java/core/lance-jni/src/ffi.rs | 161 ++++++++------------ java/core/lance-jni/src/fragment.rs | 44 +++--- java/core/lance-jni/src/lib.rs | 2 +- java/core/lance-jni/src/traits.rs | 62 ++++---- java/core/lance-jni/src/utils.rs | 7 +- 8 files changed, 232 insertions(+), 337 deletions(-) diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index 7d60984041..71663ee724 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use crate::error::{JavaErrorExt, JavaResult}; +use crate::error::Result; use crate::ffi::JNIEnvExt; use crate::traits::FromJString; use crate::utils::extract_write_params; @@ -45,31 +45,29 @@ impl BlockingDataset { reader: impl RecordBatchReader + Send + 'static, uri: &str, params: Option, - ) -> JavaResult { - let inner = RT - .block_on(Dataset::write(reader, uri, params)) - .infer_error()?; + ) -> Result { + let inner = RT.block_on(Dataset::write(reader, uri, params))?; Ok(Self { inner }) } - pub fn open(uri: &str) -> JavaResult { - let inner = RT.block_on(Dataset::open(uri)).infer_error()?; + pub fn open(uri: &str) -> Result { + let inner = RT.block_on(Dataset::open(uri))?; Ok(Self { inner }) } - pub fn commit(uri: &str, operation: Operation, read_version: Option) -> JavaResult { - let inner = RT - .block_on(Dataset::commit(uri, operation, read_version, None, None)) - .infer_error()?; + pub fn commit(uri: &str, operation: Operation, read_version: Option) -> Result { + let inner = RT.block_on(Dataset::commit(uri, operation, read_version, None, None))?; Ok(Self { inner }) } - pub fn latest_version(&self) -> JavaResult { - RT.block_on(self.inner.latest_version_id()).infer_error() + pub fn latest_version(&self) -> Result { + let version = RT.block_on(self.inner.latest_version_id())?; + Ok(version) } - pub fn count_rows(&self, filter: Option) -> JavaResult { - RT.block_on(self.inner.count_rows(filter)).infer_error() + pub fn count_rows(&self, filter: Option) -> Result { + let rows = RT.block_on(self.inner.count_rows(filter))?; + Ok(rows) } pub fn close(&self) {} @@ -111,10 +109,10 @@ fn inner_create_with_ffi_schema<'local>( max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional -) -> JavaResult> { +) -> Result> { let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; - let schema = Schema::try_from(&c_schema).infer_error()?; + let schema = Schema::try_from(&c_schema)?; let reader = RecordBatchIterator::new(empty(), Arc::new(schema)); create_dataset( @@ -161,9 +159,9 @@ fn inner_create_with_ffi_stream<'local>( max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional -) -> JavaResult> { +) -> Result> { let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; - let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; + let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }?; create_dataset( env, path, @@ -183,7 +181,7 @@ fn create_dataset<'local>( max_bytes_per_file: JObject, mode: JObject, reader: impl RecordBatchReader + Send + 'static, -) -> JavaResult> { +) -> Result> { let path_str = path.extract(env)?; let write_params = extract_write_params( @@ -199,7 +197,7 @@ fn create_dataset<'local>( } impl IntoJava for BlockingDataset { - fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JavaResult> { + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { attach_native_dataset(env, self) } } @@ -207,7 +205,7 @@ impl IntoJava for BlockingDataset { fn attach_native_dataset<'local>( env: &mut JNIEnv<'local>, dataset: BlockingDataset, -) -> JavaResult> { +) -> Result> { let j_dataset = create_java_dataset_object(env)?; // This block sets a native Rust object (dataset) as a field in the Java object (j_dataset). // Caution: This creates a potential for memory leaks. The Rust object (dataset) is not @@ -218,13 +216,13 @@ fn attach_native_dataset<'local>( // 1. The Java object (`j_dataset`) should implement the `java.io.Closeable` interface. // 2. Users of this Java object should be instructed to always use it within a try-with-resources // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. - unsafe { env.set_rust_field(&j_dataset, NATIVE_DATASET, dataset) }.infer_error()?; + unsafe { env.set_rust_field(&j_dataset, NATIVE_DATASET, dataset) }?; Ok(j_dataset) } -fn create_java_dataset_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { - env.new_object("com/lancedb/lance/Dataset", "()V", &[]) - .infer_error() +fn create_java_dataset_object<'a>(env: &mut JNIEnv<'a>) -> Result> { + let objet = env.new_object("com/lancedb/lance/Dataset", "()V", &[])?; + Ok(objet) } #[no_mangle] @@ -246,11 +244,11 @@ pub fn inner_commit_append<'local>( path: JString, read_version_obj: JObject, // Optional fragments_obj: JObject, // List, String is json serialized Fragment) -) -> JavaResult> { +) -> Result> { let json_fragments = env.get_strings(&fragments_obj)?; let mut fragments: Vec = Vec::new(); for json_fragment in json_fragments { - let fragment = Fragment::from_json(&json_fragment).infer_error()?; + let fragment = Fragment::from_json(&json_fragment)?; fragments.push(fragment); } let op = Operation::Append { fragments }; @@ -268,9 +266,8 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_releaseNativeDataset( ok_or_throw_without_return!(env, inner_release_native_dataset(&mut env, obj)) } -fn inner_release_native_dataset(env: &mut JNIEnv, obj: JObject) -> JavaResult<()> { - let dataset: BlockingDataset = - unsafe { env.take_rust_field(obj, NATIVE_DATASET).infer_error()? }; +fn inner_release_native_dataset(env: &mut JNIEnv, obj: JObject) -> Result<()> { + let dataset: BlockingDataset = unsafe { env.take_rust_field(obj, NATIVE_DATASET)? }; dataset.close(); Ok(()) } @@ -287,10 +284,7 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_openNative<'local>( ok_or_throw!(env, inner_open_native(&mut env, path)) } -fn inner_open_native<'local>( - env: &mut JNIEnv<'local>, - path: JString, -) -> JavaResult> { +fn inner_open_native<'local>(env: &mut JNIEnv<'local>, path: JString) -> Result> { let path_str: String = path.extract(env)?; let dataset = BlockingDataset::open(&path_str)?; dataset.into_java(env) @@ -307,28 +301,26 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_getJsonFragments<'a>( fn inner_get_json_fragments<'local>( env: &mut JNIEnv<'local>, jdataset: JObject, -) -> JavaResult> { +) -> Result> { let fragments = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; dataset.inner.get_fragments() }; - let array_list_class = env.find_class("java/util/ArrayList").infer_error()?; + let array_list_class = env.find_class("java/util/ArrayList")?; - let array_list = env.new_object(array_list_class, "()V", &[]).infer_error()?; + let array_list = env.new_object(array_list_class, "()V", &[])?; for fragment in fragments { - let json_string = serde_json::to_string(fragment.metadata()).infer_error()?; - let jstring = env.new_string(json_string).infer_error()?; + let json_string = serde_json::to_string(fragment.metadata())?; + let jstring = env.new_string(json_string)?; env.call_method( &array_list, "add", "(Ljava/lang/Object;)Z", &[(&jstring).into()], - ) - .infer_error()?; + )?; } Ok(array_list) } @@ -349,15 +341,14 @@ fn inner_import_ffi_schema( env: &mut JNIEnv, jdataset: JObject, arrow_schema_addr: jlong, -) -> JavaResult<()> { +) -> Result<()> { let schema = { let dataset = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; Schema::from(dataset.inner.schema()) }; - let c_schema = FFI_ArrowSchema::try_from(&schema).infer_error()?; + let c_schema = FFI_ArrowSchema::try_from(&schema)?; let out_c_schema = unsafe { &mut *(arrow_schema_addr as *mut FFI_ArrowSchema) }; let _old = std::mem::replace(out_c_schema, c_schema); Ok(()) @@ -371,10 +362,9 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeVersion( ok_or_throw_with_return!(env, inner_version(&mut env, java_dataset), -1) as jlong } -fn inner_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { +fn inner_version(env: &mut JNIEnv, java_dataset: JObject) -> Result { let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; Ok(dataset_guard.inner.version().version) } @@ -386,10 +376,9 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeLatestVersion( ok_or_throw_with_return!(env, inner_latest_version(&mut env, java_dataset), -1) as jlong } -fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { +fn inner_latest_version(env: &mut JNIEnv, java_dataset: JObject) -> Result { let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; dataset_guard.latest_version() } @@ -401,9 +390,8 @@ pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeCountRows( ok_or_throw_with_return!(env, inner_count_rows(&mut env, java_dataset), -1) as jint } -fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> JavaResult { +fn inner_count_rows(env: &mut JNIEnv, java_dataset: JObject) -> Result { let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; dataset_guard.count_rows(None) } diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs index 38c30b2370..973bc19b83 100644 --- a/java/core/lance-jni/src/blocking_scanner.rs +++ b/java/core/lance-jni/src/blocking_scanner.rs @@ -14,9 +14,8 @@ use std::sync::Arc; -use crate::error::{JavaErrorExt, JavaResult}; +use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; -use crate::JavaError; use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream}; use arrow_schema::SchemaRef; use jni::{objects::JObject, sys::jlong, JNIEnv}; @@ -43,16 +42,19 @@ impl BlockingScanner { } } - pub fn open_stream(&self) -> JavaResult { - RT.block_on(self.inner.try_into_stream()).infer_error() + pub fn open_stream(&self) -> Result { + let res = RT.block_on(self.inner.try_into_stream())?; + Ok(res) } - pub fn schema(&self) -> JavaResult { - RT.block_on(self.inner.schema()).infer_error() + pub fn schema(&self) -> Result { + let res = RT.block_on(self.inner.schema())?; + Ok(res) } - pub fn count_rows(&self) -> JavaResult { - RT.block_on(self.inner.count_rows()).infer_error() + pub fn count_rows(&self) -> Result { + let res = RT.block_on(self.inner.count_rows())?; + Ok(res) } } @@ -92,17 +94,16 @@ fn inner_create_scanner<'local>( substrait_filter_obj: JObject, filter_obj: JObject, batch_size_obj: JObject, -) -> JavaResult> { +) -> Result> { let fragment_ids_opt = env.get_ints_opt(&fragment_ids_obj)?; let dataset_guard = - unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; let mut scanner = dataset_guard.inner.scan(); if let Some(fragment_ids) = fragment_ids_opt { let mut fragments = Vec::with_capacity(fragment_ids.len()); for fragment_id in fragment_ids { let Some(fragment) = dataset_guard.inner.get_fragment(fragment_id as usize) else { - return Err(JavaError::input_error(format!( + return Err(Error::input_error(format!( "Fragment {fragment_id} not found" ))); }; @@ -113,16 +114,15 @@ fn inner_create_scanner<'local>( drop(dataset_guard); let columns_opt = env.get_strings_opt(&columns_obj)?; if let Some(columns) = columns_opt { - scanner.project(&columns).infer_error()?; + scanner.project(&columns)?; }; let substrait_opt = env.get_bytes_opt(&substrait_filter_obj)?; if let Some(substrait) = substrait_opt { - RT.block_on(async { scanner.filter_substrait(substrait).await }) - .infer_error()?; + RT.block_on(async { scanner.filter_substrait(substrait).await })?; } let filter_opt = env.get_string_opt(&filter_obj)?; if let Some(filter) = filter_opt { - scanner.filter(filter.as_str()).infer_error()?; + scanner.filter(filter.as_str())?; } let batch_size_opt = env.get_long_opt(&batch_size_obj)?; if let Some(batch_size) = batch_size_opt { @@ -140,14 +140,13 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_releaseNativeScan ok_or_throw_without_return!(env, inner_release_native_scanner(&mut env, j_scanner)); } -fn inner_release_native_scanner(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult<()> { - let _: BlockingScanner = - unsafe { env.take_rust_field(j_scanner, NATIVE_SCANNER) }.infer_error()?; +fn inner_release_native_scanner(env: &mut JNIEnv, j_scanner: JObject) -> Result<()> { + let _: BlockingScanner = unsafe { env.take_rust_field(j_scanner, NATIVE_SCANNER) }?; Ok(()) } impl IntoJava for BlockingScanner { - fn into_java<'local>(self, env: &mut JNIEnv<'local>) -> JavaResult> { + fn into_java<'local>(self, env: &mut JNIEnv<'local>) -> Result> { attach_native_scanner(env, self) } } @@ -155,7 +154,7 @@ impl IntoJava for BlockingScanner { fn attach_native_scanner<'local>( env: &mut JNIEnv<'local>, scanner: BlockingScanner, -) -> JavaResult> { +) -> Result> { let j_scanner = create_java_scanner_object(env)?; // This block sets a native Rust object (scanner) as a field in the Java object (j_scanner). // Caution: This creates a potential for memory leaks. The Rust object (scanner) is not @@ -166,13 +165,13 @@ fn attach_native_scanner<'local>( // 1. The Java object (`j_scanner`) should implement the `java.io.Closeable` interface. // 2. Users of this Java object should be instructed to always use it within a try-with-resources // statement (or manually call the `close()` method) to ensure that `self.close()` is invoked. - unsafe { env.set_rust_field(&j_scanner, NATIVE_SCANNER, scanner) }.infer_error()?; + unsafe { env.set_rust_field(&j_scanner, NATIVE_SCANNER, scanner) }?; Ok(j_scanner) } -fn create_java_scanner_object<'a>(env: &mut JNIEnv<'a>) -> JavaResult> { - env.new_object("com/lancedb/lance/ipc/LanceScanner", "()V", &[]) - .infer_error() +fn create_java_scanner_object<'a>(env: &mut JNIEnv<'a>) -> Result> { + let res = env.new_object("com/lancedb/lance/ipc/LanceScanner", "()V", &[])?; + Ok(res) } ////////////////// @@ -187,15 +186,13 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_openStream( ok_or_throw_without_return!(env, inner_open_stream(&mut env, j_scanner, stream_addr)); } -fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -> JavaResult<()> { +fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -> Result<()> { let record_batch_stream = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; scanner_guard.open_stream()? }; - let ffi_stream = - to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone()).infer_error()?; + let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone())?; unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) } Ok(()) } @@ -212,18 +209,13 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_importFfiSchema( ); } -fn inner_import_ffi_schema( - env: &mut JNIEnv, - j_scanner: JObject, - schema_addr: jlong, -) -> JavaResult<()> { +fn inner_import_ffi_schema(env: &mut JNIEnv, j_scanner: JObject, schema_addr: jlong) -> Result<()> { let schema = { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; scanner_guard.schema()? }; - let ffi_schema = FFI_ArrowSchema::try_from(&*schema).infer_error()?; + let ffi_schema = FFI_ArrowSchema::try_from(&*schema)?; unsafe { std::ptr::write_unaligned(schema_addr as *mut FFI_ArrowSchema, ffi_schema) } Ok(()) } @@ -236,9 +228,8 @@ pub extern "system" fn Java_com_lancedb_lance_ipc_LanceScanner_nativeCountRows( ok_or_throw_with_return!(env, inner_count_rows(&mut env, j_scanner), -1) as jlong } -fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> JavaResult { +fn inner_count_rows(env: &mut JNIEnv, j_scanner: JObject) -> Result { let scanner_guard = - unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) } - .infer_error()?; + unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?; scanner_guard.count_rows() } diff --git a/java/core/lance-jni/src/error.rs b/java/core/lance-jni/src/error.rs index cc3be31f9c..c2e6153c88 100644 --- a/java/core/lance-jni/src/error.rs +++ b/java/core/lance-jni/src/error.rs @@ -19,8 +19,6 @@ use jni::{errors::Error as JniError, JNIEnv}; use lance::error::Error as LanceError; use serde_json::Error as JsonError; -pub type JavaResult = std::result::Result; - #[derive(Debug)] pub enum JavaExceptionClass { IllegalArgumentException, @@ -41,12 +39,12 @@ impl JavaExceptionClass { } #[derive(Debug)] -pub struct JavaError { +pub struct Error { message: String, java_class: JavaExceptionClass, } -impl JavaError { +impl Error { pub fn new(message: String, java_class: JavaExceptionClass) -> Self { Self { message, @@ -79,107 +77,53 @@ impl JavaError { } } -impl std::fmt::Display for JavaError { +pub type Result = std::result::Result; + +impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}: {}", self.java_class.as_str(), self.message) } } -impl std::error::Error for JavaError {} - -/// Trait for converting errors to Java exceptions. -pub trait JavaErrorConversion { - /// Convert to `JavaError` as I/O exception. - fn io_error(self) -> JavaResult; - - /// Convert to `JavaError` as runtime exception. - fn runtime_error(self) -> JavaResult; - - /// Convert to `JavaError` as value (input) exception. - fn input_error(self) -> JavaResult; - - /// Convert to `JavaError` as unsupported operation exception. - fn unsupported_error(self) -> JavaResult; -} - -impl JavaErrorConversion for std::result::Result { - fn io_error(self) -> JavaResult { - self.map_err(|err| JavaError::io_error(err.to_string())) - } - - fn runtime_error(self) -> JavaResult { - self.map_err(|err| JavaError::runtime_error(err.to_string())) - } - - fn input_error(self) -> JavaResult { - self.map_err(|err| JavaError::input_error(err.to_string())) - } - - fn unsupported_error(self) -> JavaResult { - self.map_err(|err| JavaError::unsupported_error(err.to_string())) - } -} - -/// JavaErrorExt trait that converts specific error types to Java exceptions -pub trait JavaErrorExt { - /// Convert to a Java error based on the specific error type - fn infer_error(self) -> JavaResult; -} - -impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - LanceError::DatasetNotFound { .. } - | LanceError::DatasetAlreadyExists { .. } - | LanceError::CommitConflict { .. } - | LanceError::InvalidInput { .. } => self.input_error(), - LanceError::IO { .. } => self.io_error(), - LanceError::NotSupported { .. } => self.unsupported_error(), - _ => self.runtime_error(), - }, +impl From for Error { + fn from(err: LanceError) -> Self { + match err { + LanceError::DatasetNotFound { .. } + | LanceError::DatasetAlreadyExists { .. } + | LanceError::CommitConflict { .. } + | LanceError::InvalidInput { .. } => Self::input_error(err.to_string()), + LanceError::IO { .. } => Self::io_error(err.to_string()), + LanceError::NotSupported { .. } => Self::unsupported_error(err.to_string()), + _ => Self::runtime_error(err.to_string()), } } } -impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(err) => match err { - ArrowError::InvalidArgumentError { .. } => self.input_error(), - ArrowError::IoError { .. } => self.io_error(), - ArrowError::NotYetImplemented(_) => self.unsupported_error(), - _ => self.runtime_error(), - }, +impl From for Error { + fn from(err: ArrowError) -> Self { + match err { + ArrowError::InvalidArgumentError { .. } => Self::input_error(err.to_string()), + ArrowError::IoError { .. } => Self::io_error(err.to_string()), + ArrowError::NotYetImplemented(_) => Self::unsupported_error(err.to_string()), + _ => Self::runtime_error(err.to_string()), } } } -impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(_) => self.io_error(), - } +impl From for Error { + fn from(err: JsonError) -> Self { + Self::io_error(err.to_string()) } } -impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(_) => self.runtime_error(), - } +impl From for Error { + fn from(err: JniError) -> Self { + Self::runtime_error(err.to_string()) } } -impl JavaErrorExt for std::result::Result { - fn infer_error(self) -> JavaResult { - match &self { - Ok(_) => Ok(self.unwrap()), - Err(_) => self.input_error(), - } +impl From for Error { + fn from(err: Utf8Error) -> Self { + Self::input_error(err.to_string()) } } diff --git a/java/core/lance-jni/src/ffi.rs b/java/core/lance-jni/src/ffi.rs index 4f8347fc7b..16aa45a9e8 100644 --- a/java/core/lance-jni/src/ffi.rs +++ b/java/core/lance-jni/src/ffi.rs @@ -14,7 +14,7 @@ use core::slice; -use crate::error::{JavaErrorExt, JavaResult}; +use crate::error::Result; use jni::objects::{JByteBuffer, JObjectArray, JString}; use jni::sys::jobjectArray; use jni::{objects::JObject, JNIEnv}; @@ -22,185 +22,156 @@ use jni::{objects::JObject, JNIEnv}; /// Extend JNIEnv with helper functions. pub trait JNIEnvExt { /// Get integers from Java List object. - fn get_integers(&mut self, obj: &JObject) -> JavaResult>; + fn get_integers(&mut self, obj: &JObject) -> Result>; /// Get strings from Java List object. - fn get_strings(&mut self, obj: &JObject) -> JavaResult>; + fn get_strings(&mut self, obj: &JObject) -> Result>; /// Get strings from Java String[] object. /// Note that get Option> from Java Optional just doesn't work. #[allow(dead_code)] - fn get_strings_array(&mut self, obj: jobjectArray) -> JavaResult>; + fn get_strings_array(&mut self, obj: jobjectArray) -> Result>; /// Get Option from Java Optional. - fn get_string_opt(&mut self, obj: &JObject) -> JavaResult>; + fn get_string_opt(&mut self, obj: &JObject) -> Result>; /// Get Option> from Java Optional>. #[allow(dead_code)] - fn get_strings_opt(&mut self, obj: &JObject) -> JavaResult>>; + fn get_strings_opt(&mut self, obj: &JObject) -> Result>>; /// Get Option from Java Optional. - fn get_int_opt(&mut self, obj: &JObject) -> JavaResult>; + fn get_int_opt(&mut self, obj: &JObject) -> Result>; /// Get Option> from Java Optional>. - fn get_ints_opt(&mut self, obj: &JObject) -> JavaResult>>; + fn get_ints_opt(&mut self, obj: &JObject) -> Result>>; /// Get Option from Java Optional. - fn get_long_opt(&mut self, obj: &JObject) -> JavaResult>; + fn get_long_opt(&mut self, obj: &JObject) -> Result>; /// Get Option from Java Optional. - fn get_u64_opt(&mut self, obj: &JObject) -> JavaResult>; + fn get_u64_opt(&mut self, obj: &JObject) -> Result>; /// Get Option<&[u8]> from Java Optional. - fn get_bytes_opt(&mut self, obj: &JObject) -> JavaResult>; + fn get_bytes_opt(&mut self, obj: &JObject) -> Result>; - fn get_optional(&mut self, obj: &JObject, f: F) -> JavaResult> + fn get_optional(&mut self, obj: &JObject, f: F) -> Result> where - F: FnOnce(&mut JNIEnv, &JObject) -> JavaResult; + F: FnOnce(&mut JNIEnv, &JObject) -> Result; } impl JNIEnvExt for JNIEnv<'_> { - fn get_integers(&mut self, obj: &JObject) -> JavaResult> { - let list = self.get_list(obj).infer_error()?; - let mut iter = list.iter(self).infer_error()?; - let mut results = Vec::with_capacity(list.size(self).infer_error()? as usize); - while let Some(elem) = iter.next(self).infer_error()? { - let int_obj = self - .call_method(elem, "intValue", "()I", &[]) - .infer_error()?; - let int_value = int_obj.i().infer_error()?; + fn get_integers(&mut self, obj: &JObject) -> Result> { + let list = self.get_list(obj)?; + let mut iter = list.iter(self)?; + let mut results = Vec::with_capacity(list.size(self)? as usize); + while let Some(elem) = iter.next(self)? { + let int_obj = self.call_method(elem, "intValue", "()I", &[])?; + let int_value = int_obj.i()?; results.push(int_value); } Ok(results) } - fn get_strings(&mut self, obj: &JObject) -> JavaResult> { - let list = self.get_list(obj).infer_error()?; - let mut iter = list.iter(self).infer_error()?; - let mut results = Vec::with_capacity(list.size(self).infer_error()? as usize); - while let Some(elem) = iter.next(self).infer_error()? { + fn get_strings(&mut self, obj: &JObject) -> Result> { + let list = self.get_list(obj)?; + let mut iter = list.iter(self)?; + let mut results = Vec::with_capacity(list.size(self)? as usize); + while let Some(elem) = iter.next(self)? { let jstr = JString::from(elem); - let val = self.get_string(&jstr).infer_error()?; - results.push(val.to_str().infer_error()?.to_string()) + let val = self.get_string(&jstr)?; + results.push(val.to_str()?.to_string()) } Ok(results) } - fn get_strings_array(&mut self, obj: jobjectArray) -> JavaResult> { + fn get_strings_array(&mut self, obj: jobjectArray) -> Result> { let jobject_array = unsafe { JObjectArray::from_raw(obj) }; - let array_len = self.get_array_length(&jobject_array).infer_error()?; + let array_len = self.get_array_length(&jobject_array)?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let item: JString = self - .get_object_array_element(&jobject_array, i) - .infer_error()? - .into(); - res.push(self.get_string(&item).infer_error()?.into()); + let item: JString = self.get_object_array_element(&jobject_array, i)?.into(); + res.push(self.get_string(&item)?.into()); } Ok(res) } - fn get_string_opt(&mut self, obj: &JObject) -> JavaResult> { + fn get_string_opt(&mut self, obj: &JObject) -> Result> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_string_obj = java_obj_gen.l().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_string_obj = java_obj_gen.l()?; let jstr = JString::from(java_string_obj); - let val = env.get_string(&jstr).infer_error()?; - Ok(val.to_str().infer_error()?.to_string()) + let val = env.get_string(&jstr)?; + Ok(val.to_str()?.to_string()) }) } - fn get_strings_opt(&mut self, obj: &JObject) -> JavaResult>> { + fn get_strings_opt(&mut self, obj: &JObject) -> Result>> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_list_obj = java_obj_gen.l().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_list_obj = java_obj_gen.l()?; env.get_strings(&java_list_obj) }) } - fn get_int_opt(&mut self, obj: &JObject) -> JavaResult> { + fn get_int_opt(&mut self, obj: &JObject) -> Result> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_int_obj = java_obj_gen.l().infer_error()?; - let int_obj = env - .call_method(java_int_obj, "intValue", "()I", &[]) - .infer_error()?; - let int_value = int_obj.i().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_int_obj = java_obj_gen.l()?; + let int_obj = env.call_method(java_int_obj, "intValue", "()I", &[])?; + let int_value = int_obj.i()?; Ok(int_value) }) } - fn get_ints_opt(&mut self, obj: &JObject) -> JavaResult>> { + fn get_ints_opt(&mut self, obj: &JObject) -> Result>> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_list_obj = java_obj_gen.l().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_list_obj = java_obj_gen.l()?; env.get_integers(&java_list_obj) }) } - fn get_long_opt(&mut self, obj: &JObject) -> JavaResult> { + fn get_long_opt(&mut self, obj: &JObject) -> Result> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_long_obj = java_obj_gen.l().infer_error()?; - let long_obj = env - .call_method(java_long_obj, "longValue", "()J", &[]) - .infer_error()?; - let long_value = long_obj.j().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_long_obj = java_obj_gen.l()?; + let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?; + let long_value = long_obj.j()?; Ok(long_value) }) } - fn get_u64_opt(&mut self, obj: &JObject) -> JavaResult> { + fn get_u64_opt(&mut self, obj: &JObject) -> Result> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_long_obj = java_obj_gen.l().infer_error()?; - let long_obj = env - .call_method(java_long_obj, "longValue", "()J", &[]) - .infer_error()?; - let long_value = long_obj.j().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_long_obj = java_obj_gen.l()?; + let long_obj = env.call_method(java_long_obj, "longValue", "()J", &[])?; + let long_value = long_obj.j()?; Ok(long_value as u64) }) } - fn get_bytes_opt(&mut self, obj: &JObject) -> JavaResult> { + fn get_bytes_opt(&mut self, obj: &JObject) -> Result> { self.get_optional(obj, |env, inner_obj| { - let java_obj_gen = env - .call_method(inner_obj, "get", "()Ljava/lang/Object;", &[]) - .infer_error()?; - let java_byte_buffer_obj = java_obj_gen.l().infer_error()?; + let java_obj_gen = env.call_method(inner_obj, "get", "()Ljava/lang/Object;", &[])?; + let java_byte_buffer_obj = java_obj_gen.l()?; let j_byte_buffer = JByteBuffer::from(java_byte_buffer_obj); - let raw_data = env - .get_direct_buffer_address(&j_byte_buffer) - .infer_error()?; - let capacity = env - .get_direct_buffer_capacity(&j_byte_buffer) - .infer_error()?; + let raw_data = env.get_direct_buffer_address(&j_byte_buffer)?; + let capacity = env.get_direct_buffer_capacity(&j_byte_buffer)?; let data = unsafe { slice::from_raw_parts(raw_data, capacity) }; Ok(data) }) } - fn get_optional(&mut self, obj: &JObject, f: F) -> JavaResult> + fn get_optional(&mut self, obj: &JObject, f: F) -> Result> where - F: FnOnce(&mut JNIEnv, &JObject) -> JavaResult, + F: FnOnce(&mut JNIEnv, &JObject) -> Result, { if obj.is_null() { return Ok(None); } - let is_empty = self.call_method(obj, "isEmpty", "()Z", &[]).infer_error()?; - if is_empty.z().infer_error()? { + let is_empty = self.call_method(obj, "isEmpty", "()Z", &[])?; + if is_empty.z()? { // TODO(lu): put get java object into here cuz can only get java Object Ok(None) } else { diff --git a/java/core/lance-jni/src/fragment.rs b/java/core/lance-jni/src/fragment.rs index d85d3d94e9..24d9863b22 100644 --- a/java/core/lance-jni/src/fragment.rs +++ b/java/core/lance-jni/src/fragment.rs @@ -25,8 +25,7 @@ use std::iter::once; use lance::dataset::fragment::FileFragment; -use crate::error::{JavaErrorExt, JavaResult}; -use crate::JavaError; +use crate::error::{Error, Result}; use crate::{ blocking_dataset::{BlockingDataset, NATIVE_DATASET}, ffi::JNIEnvExt, @@ -60,15 +59,15 @@ fn inner_count_rows_native( env: &mut JNIEnv, jdataset: JObject, fragment_id: jlong, -) -> JavaResult { - let dataset = unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) } - .infer_error()?; +) -> Result { + let dataset = unsafe { env.get_rust_field::<_, _, BlockingDataset>(jdataset, NATIVE_DATASET) }?; let Some(fragment) = dataset.inner.get_fragment(fragment_id as usize) else { - return Err(JavaError::input_error(format!( + return Err(Error::input_error(format!( "Fragment not found: {fragment_id}" ))); }; - RT.block_on(fragment.count_rows()).infer_error() + let res = RT.block_on(fragment.count_rows())?; + Ok(res) } #[no_mangle] @@ -112,15 +111,15 @@ fn inner_create_with_ffi_array<'local>( max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional -) -> JavaResult> { +) -> Result> { let c_array_ptr = arrow_array_addr as *mut FFI_ArrowArray; let c_schema_ptr = arrow_schema_addr as *mut FFI_ArrowSchema; let c_array = unsafe { FFI_ArrowArray::from_raw(c_array_ptr) }; let c_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema_ptr) }; - let data_type = DataType::try_from(&c_schema).infer_error()?; + let data_type = DataType::try_from(&c_schema)?; - let array_data = unsafe { from_ffi_and_data_type(c_array, data_type) }.infer_error()?; + let array_data = unsafe { from_ffi_and_data_type(c_array, data_type) }?; let record_batch = RecordBatch::from(StructArray::from(array_data)); let batch_schema = record_batch.schema().clone(); @@ -176,9 +175,9 @@ fn inner_create_with_ffi_stream<'local>( max_rows_per_group: JObject, // Optional max_bytes_per_file: JObject, // Optional mode: JObject, // Optional -) -> JavaResult> { +) -> Result> { let stream_ptr = arrow_array_stream_addr as *mut FFI_ArrowArrayStream; - let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }.infer_error()?; + let reader = unsafe { ArrowArrayStreamReader::from_raw(stream_ptr) }?; create_fragment( env, @@ -202,7 +201,7 @@ fn create_fragment<'a>( max_bytes_per_file: JObject, // Optional mode: JObject, // Optional reader: impl RecordBatchReader + Send + 'static, -) -> JavaResult> { +) -> Result> { let path_str = dataset_uri.extract(env)?; let fragment_id_opts = env.get_int_opt(&fragment_id)?; @@ -214,14 +213,13 @@ fn create_fragment<'a>( &max_bytes_per_file, &mode, )?; - let fragment = RT - .block_on(FileFragment::create( - &path_str, - fragment_id_opts.unwrap_or(0) as usize, - reader, - Some(write_params), - )) - .infer_error()?; - let json_string = serde_json::to_string(&fragment).infer_error()?; - env.new_string(json_string).infer_error() + let fragment = RT.block_on(FileFragment::create( + &path_str, + fragment_id_opts.unwrap_or(0) as usize, + reader, + Some(write_params), + ))?; + let json_string = serde_json::to_string(&fragment)?; + let res = env.new_string(json_string)?; + Ok(res) } diff --git a/java/core/lance-jni/src/lib.rs b/java/core/lance-jni/src/lib.rs index da53927949..0ee83da971 100644 --- a/java/core/lance-jni/src/lib.rs +++ b/java/core/lance-jni/src/lib.rs @@ -57,7 +57,7 @@ mod ffi; mod fragment; mod traits; mod utils; -pub use error::{JavaError, JavaErrorExt, JavaResult}; +pub use error::Result; use lazy_static::lazy_static; diff --git a/java/core/lance-jni/src/traits.rs b/java/core/lance-jni/src/traits.rs index a2d2156411..d91b449b1c 100644 --- a/java/core/lance-jni/src/traits.rs +++ b/java/core/lance-jni/src/traits.rs @@ -15,74 +15,78 @@ use jni::objects::{JMap, JObject, JString, JValue}; use jni::JNIEnv; -use crate::error::{JavaErrorExt, JavaResult}; +use crate::error::Result; pub trait FromJObject { - fn extract(&self) -> JavaResult; + fn extract(&self) -> Result; } /// Convert a Rust type into a Java Object. pub trait IntoJava { - fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> JavaResult>; + fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result>; } impl FromJObject for JObject<'_> { - fn extract(&self) -> JavaResult { - JValue::from(self).i().infer_error() + fn extract(&self) -> Result { + let res = JValue::from(self).i()?; + Ok(res) } } impl FromJObject for JObject<'_> { - fn extract(&self) -> JavaResult { - JValue::from(self).j().infer_error() + fn extract(&self) -> Result { + let res = JValue::from(self).j()?; + Ok(res) } } impl FromJObject for JObject<'_> { - fn extract(&self) -> JavaResult { - JValue::from(self).f().infer_error() + fn extract(&self) -> Result { + let res = JValue::from(self).f()?; + Ok(res) } } impl FromJObject for JObject<'_> { - fn extract(&self) -> JavaResult { - JValue::from(self).d().infer_error() + fn extract(&self) -> Result { + let res = JValue::from(self).d()?; + Ok(res) } } pub trait FromJString { - fn extract(&self, env: &mut JNIEnv) -> JavaResult; + fn extract(&self, env: &mut JNIEnv) -> Result; } impl FromJString for JString<'_> { - fn extract(&self, env: &mut JNIEnv) -> JavaResult { - Ok(env.get_string(self).infer_error()?.into()) + fn extract(&self, env: &mut JNIEnv) -> Result { + Ok(env.get_string(self)?.into()) } } pub trait JMapExt { #[allow(dead_code)] - fn get_string(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; + fn get_string(&self, env: &mut JNIEnv, key: &str) -> Result>; #[allow(dead_code)] - fn get_i32(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; + fn get_i32(&self, env: &mut JNIEnv, key: &str) -> Result>; #[allow(dead_code)] - fn get_i64(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; + fn get_i64(&self, env: &mut JNIEnv, key: &str) -> Result>; #[allow(dead_code)] - fn get_f32(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; + fn get_f32(&self, env: &mut JNIEnv, key: &str) -> Result>; #[allow(dead_code)] - fn get_f64(&self, env: &mut JNIEnv, key: &str) -> JavaResult>; + fn get_f64(&self, env: &mut JNIEnv, key: &str) -> Result>; } -fn get_map_value(env: &mut JNIEnv, map: &JMap, key: &str) -> JavaResult> +fn get_map_value(env: &mut JNIEnv, map: &JMap, key: &str) -> Result> where for<'a> JObject<'a>: FromJObject, { - let key_obj: JObject = env.new_string(key).infer_error()?.into(); - if let Some(value) = map.get(env, &key_obj).infer_error()? { + let key_obj: JObject = env.new_string(key)?.into(); + if let Some(value) = map.get(env, &key_obj)? { if value.is_null() { Ok(None) } else { @@ -94,9 +98,9 @@ where } impl JMapExt for JMap<'_, '_, '_> { - fn get_string(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { - let key_obj: JObject = env.new_string(key).infer_error()?.into(); - if let Some(value) = self.get(env, &key_obj).infer_error()? { + fn get_string(&self, env: &mut JNIEnv, key: &str) -> Result> { + let key_obj: JObject = env.new_string(key)?.into(); + if let Some(value) = self.get(env, &key_obj)? { let value_str: JString = value.into(); Ok(Some(value_str.extract(env)?)) } else { @@ -104,19 +108,19 @@ impl JMapExt for JMap<'_, '_, '_> { } } - fn get_i32(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { + fn get_i32(&self, env: &mut JNIEnv, key: &str) -> Result> { get_map_value(env, self, key) } - fn get_i64(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { + fn get_i64(&self, env: &mut JNIEnv, key: &str) -> Result> { get_map_value(env, self, key) } - fn get_f32(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { + fn get_f32(&self, env: &mut JNIEnv, key: &str) -> Result> { get_map_value(env, self, key) } - fn get_f64(&self, env: &mut JNIEnv, key: &str) -> JavaResult> { + fn get_f64(&self, env: &mut JNIEnv, key: &str) -> Result> { get_map_value(env, self, key) } } diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs index 9cb1fc895e..ba6ba5bf7f 100644 --- a/java/core/lance-jni/src/utils.rs +++ b/java/core/lance-jni/src/utils.rs @@ -16,9 +16,8 @@ use jni::objects::JObject; use jni::JNIEnv; use lance::dataset::{WriteMode, WriteParams}; -use crate::error::JavaResult; +use crate::error::Result; use crate::ffi::JNIEnvExt; -use crate::JavaErrorExt; pub fn extract_write_params( env: &mut JNIEnv, @@ -26,7 +25,7 @@ pub fn extract_write_params( max_rows_per_group: &JObject, max_bytes_per_file: &JObject, mode: &JObject, -) -> JavaResult { +) -> Result { let mut write_params = WriteParams::default(); if let Some(max_rows_per_file_val) = env.get_int_opt(max_rows_per_file)? { @@ -39,7 +38,7 @@ pub fn extract_write_params( write_params.max_bytes_per_file = max_bytes_per_file_val as usize; } if let Some(mode_val) = env.get_string_opt(mode)? { - write_params.mode = WriteMode::try_from(mode_val.as_str()).infer_error()?; + write_params.mode = WriteMode::try_from(mode_val.as_str())?; } Ok(write_params) } From 93a1aeaa3c0416378aca93e37b3200b2c0e6ad79 Mon Sep 17 00:00:00 2001 From: Lu Qiu Date: Mon, 20 May 2024 20:24:41 -0700 Subject: [PATCH 9/9] lock manager --- .../java/com/lancedb/lance/LockManager.java | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 java/core/src/main/java/com/lancedb/lance/LockManager.java diff --git a/java/core/src/main/java/com/lancedb/lance/LockManager.java b/java/core/src/main/java/com/lancedb/lance/LockManager.java new file mode 100644 index 0000000000..f65a500188 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/LockManager.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.lancedb.lance; + +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * The LockManager class provides a way to manage read and write locks. + */ +public class LockManager { + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + + /** + * Represents a read lock for the LockManager. + * This lock allows multiple threads to read concurrently, but prevents write access. + */ + public class ReadLock implements AutoCloseable { + /** + * Acquires a read lock on the lock manager. + */ + public ReadLock() { + lock.readLock().lock(); + } + + @Override + public void close() { + lock.readLock().unlock(); + } + } + + /** + * Represents a write lock that can be acquired and released. + */ + public class WriteLock implements AutoCloseable { + /** + * Constructs a new WriteLock and acquires the write lock. + */ + public WriteLock() { + lock.writeLock().lock(); + } + + @Override + public void close() { + lock.writeLock().unlock(); + } + } + + /** + * Acquires a read lock on the LockManager. + * + * @return the acquired ReadLock object + */ + public ReadLock acquireReadLock() { + return new ReadLock(); + } + + /** + * Acquires a write lock on the LockManager. + * + * @return the acquired WriteLock object + */ + public WriteLock acquireWriteLock() { + return new WriteLock(); + } +} \ No newline at end of file