Skip to content

Commit

Permalink
Cache the jclass objects
Browse files Browse the repository at this point in the history
Signed-off-by: Fredy Wijaya <fredyw@google.com>
  • Loading branch information
fredyw committed May 24, 2024
1 parent 3fab851 commit 840d8cd
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 85 deletions.
12 changes: 3 additions & 9 deletions mobile/library/jni/jni_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,7 @@ jmethodID JniHelper::getStaticMethodId(jclass clazz, const char* name, const cha
return method_id;
}

LocalRefUniquePtr<jclass> JniHelper::findClass(const char* class_name) {
LocalRefUniquePtr<jclass> result(env_->FindClass(class_name), LocalRefDeleter(env_));
rethrowException();
return result;
}

jclass JniHelper::findClassFromCache(const char* class_name) {
jclass JniHelper::findClass(const char* class_name) {
if (auto i = JCLASS_CACHES.find(class_name); i != JCLASS_CACHES.end()) {
return i->second;
}
Expand All @@ -108,9 +102,9 @@ LocalRefUniquePtr<jclass> JniHelper::getObjectClass(jobject object) {
}

void JniHelper::throwNew(const char* java_class_name, const char* message) {
LocalRefUniquePtr<jclass> java_class = findClass(java_class_name);
jclass java_class = findClass(java_class_name);
if (java_class != nullptr) {
jint error = env_->ThrowNew(java_class.get(), message);
jint error = env_->ThrowNew(java_class, message);
ASSERT(error == JNI_OK, "Failed calling ThrowNew.");
}
}
Expand Down
11 changes: 2 additions & 9 deletions mobile/library/jni/jni_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,11 @@ class JniHelper {
jmethodID getStaticMethodId(jclass clazz, const char* name, const char* signature);

/**
* Finds the given `class_name` using Java classloader.
* Finds the given `class_name` using from the cache.
*
* https://docs.oracle.com/en/java/javase/17/docs/specs/jni/functions.html#findclass
*/
[[nodiscard]] LocalRefUniquePtr<jclass> findClass(const char* class_name);

/**
* Finds the given `class_name` from the cache.
*
* https://docs.oracle.com/en/java/javase/17/docs/specs/jni/functions.html#findclass
*/
[[nodiscard]] jclass findClassFromCache(const char* class_name);
[[nodiscard]] jclass findClass(const char* class_name);

/**
* Returns the class of a given `object`.
Expand Down
24 changes: 17 additions & 7 deletions mobile/library/jni/jni_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ using Envoy::Platform::EngineBuilder;

extern "C" JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void* /*reserved*/) {
Envoy::JNI::JniHelper::initialize(vm);
Envoy::JNI::JniHelper::addClassToCache("java/lang/Object");
Envoy::JNI::JniHelper::addClassToCache("java/lang/Integer");
Envoy::JNI::JniHelper::addClassToCache("java/lang/ClassLoader");
Envoy::JNI::JniHelper::addClassToCache("java/nio/ByteBuffer");
Envoy::JNI::JniHelper::addClassToCache("java/lang/Throwable");
Envoy::JNI::JniHelper::addClassToCache("java/lang/UnsupportedOperationException");
Envoy::JNI::JniHelper::addClassToCache("[B");
Envoy::JNI::JniHelper::addClassToCache("java/util/Map$Entry");
Envoy::JNI::JniHelper::addClassToCache("java/util/LinkedHashMap");
Envoy::JNI::JniHelper::addClassToCache("java/util/HashMap");
Envoy::JNI::JniHelper::addClassToCache("java/util/List");
Envoy::JNI::JniHelper::addClassToCache("java/util/ArrayList");
Envoy::JNI::JniHelper::addClassToCache("io/envoyproxy/envoymobile/engine/types/EnvoyStreamIntel");
Envoy::JNI::JniHelper::addClassToCache(
"io/envoyproxy/envoymobile/engine/types/EnvoyFinalStreamIntel");
Expand Down Expand Up @@ -229,15 +241,13 @@ jvm_on_headers(const char* method, const Envoy::Types::ManagedEnvoyHeaders& head
// Create a "no operation" result:
// 1. Tell the filter chain to continue the iteration.
// 2. Return headers received on as method's input as part of the method's output.
Envoy::JNI::LocalRefUniquePtr<jclass> jcls_object_array =
jni_helper.findClass("java/lang/Object");
jclass jcls_object_array = jni_helper.findClass("java/lang/Object");
Envoy::JNI::LocalRefUniquePtr<jobjectArray> noopResult =
jni_helper.newObjectArray(2, jcls_object_array.get(), NULL);
jni_helper.newObjectArray(2, jcls_object_array, NULL);

Envoy::JNI::LocalRefUniquePtr<jclass> jcls_int = jni_helper.findClass("java/lang/Integer");
jmethodID jmid_intInit = jni_helper.getMethodId(jcls_int.get(), "<init>", "(I)V");
Envoy::JNI::LocalRefUniquePtr<jobject> j_status =
jni_helper.newObject(jcls_int.get(), jmid_intInit, 0);
jclass jcls_int = jni_helper.findClass("java/lang/Integer");
jmethodID jmid_intInit = jni_helper.getMethodId(jcls_int, "<init>", "(I)V");
Envoy::JNI::LocalRefUniquePtr<jobject> j_status = jni_helper.newObject(jcls_int, jmid_intInit, 0);
// Set status to "0" (FilterHeadersStatus::Continue). Signal that the intent
// is to continue the iteration of the filter chain.
jni_helper.setObjectArrayElement(noopResult.get(), 0, j_status.get());
Expand Down
88 changes: 44 additions & 44 deletions mobile/library/jni/jni_utility.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ jobject getClassLoader() {

LocalRefUniquePtr<jclass> findClass(const char* class_name) {
JniHelper jni_helper(JniHelper::getThreadLocalEnv());
LocalRefUniquePtr<jclass> class_loader = jni_helper.findClass("java/lang/ClassLoader");
jmethodID find_class_method = jni_helper.getMethodId(class_loader.get(), "loadClass",
"(Ljava/lang/String;)Ljava/lang/Class;");
jclass class_loader = jni_helper.findClass("java/lang/ClassLoader");
jmethodID find_class_method =
jni_helper.getMethodId(class_loader, "loadClass", "(Ljava/lang/String;)Ljava/lang/Class;");
LocalRefUniquePtr<jstring> str_class_name = jni_helper.newStringUtf(class_name);
LocalRefUniquePtr<jclass> clazz = jni_helper.callObjectMethod<jclass>(
getClassLoader(), find_class_method, str_class_name.get());
Expand All @@ -43,8 +43,8 @@ void jniDeleteConstGlobalRef(const void* context) {
}

int javaIntegerToCppInt(JniHelper& jni_helper, jobject boxed_integer) {
LocalRefUniquePtr<jclass> jcls_Integer = jni_helper.findClass("java/lang/Integer");
jmethodID jmid_intValue = jni_helper.getMethodId(jcls_Integer.get(), "intValue", "()I");
jclass jcls_Integer = jni_helper.findClass("java/lang/Integer");
jmethodID jmid_intValue = jni_helper.getMethodId(jcls_Integer, "intValue", "()I");
return jni_helper.callIntMethod(boxed_integer, jmid_intValue);
}

Expand Down Expand Up @@ -120,11 +120,11 @@ envoy_data javaByteBufferToEnvoyData(JniHelper& jni_helper, jobject j_data) {
jlong data_length = jni_helper.getDirectBufferCapacity(j_data);

if (data_length < 0) {
LocalRefUniquePtr<jclass> jcls_ByteBuffer = jni_helper.findClass("java/nio/ByteBuffer");
jclass jcls_ByteBuffer = jni_helper.findClass("java/nio/ByteBuffer");
// We skip checking hasArray() because only direct ByteBuffers or array-backed ByteBuffers
// are supported. We will crash here if this is an invalid buffer, but guards may be
// implemented in the JVM layer.
jmethodID jmid_array = jni_helper.getMethodId(jcls_ByteBuffer.get(), "array", "()[B");
jmethodID jmid_array = jni_helper.getMethodId(jcls_ByteBuffer, "array", "()[B");
LocalRefUniquePtr<jbyteArray> array =
jni_helper.callObjectMethod<jbyteArray>(j_data, jmid_array);
envoy_data native_data = javaByteArrayToEnvoyData(jni_helper, array.get());
Expand All @@ -139,11 +139,11 @@ envoy_data javaByteBufferToEnvoyData(JniHelper& jni_helper, jobject j_data, jlon
uint8_t* direct_address = jni_helper.getDirectBufferAddress<uint8_t*>(j_data);

if (direct_address == nullptr) {
LocalRefUniquePtr<jclass> jcls_ByteBuffer = jni_helper.findClass("java/nio/ByteBuffer");
jclass jcls_ByteBuffer = jni_helper.findClass("java/nio/ByteBuffer");
// We skip checking hasArray() because only direct ByteBuffers or array-backed ByteBuffers
// are supported. We will crash here if this is an invalid buffer, but guards may be
// implemented in the JVM layer.
jmethodID jmid_array = jni_helper.getMethodId(jcls_ByteBuffer.get(), "array", "()[B");
jmethodID jmid_array = jni_helper.getMethodId(jcls_ByteBuffer, "array", "()[B");
LocalRefUniquePtr<jbyteArray> array =
jni_helper.callObjectMethod<jbyteArray>(j_data, jmid_array);
envoy_data native_data = javaByteArrayToEnvoyData(jni_helper, array.get(), data_length);
Expand Down Expand Up @@ -229,9 +229,9 @@ envoy_map javaArrayOfObjectArrayToEnvoyMap(JniHelper& jni_helper, jobjectArray e
LocalRefUniquePtr<jobjectArray>
envoyHeadersToJavaArrayOfObjectArray(JniHelper& jni_helper,
const Envoy::Types::ManagedEnvoyHeaders& map) {
LocalRefUniquePtr<jclass> jcls_byte_array = jni_helper.findClass("java/lang/Object");
jclass jcls_byte_array = jni_helper.findClass("java/lang/Object");
LocalRefUniquePtr<jobjectArray> javaArray =
jni_helper.newObjectArray(2 * map.get().length, jcls_byte_array.get(), nullptr);
jni_helper.newObjectArray(2 * map.get().length, jcls_byte_array, nullptr);

for (envoy_map_size_t i = 0; i < map.get().length; i++) {
LocalRefUniquePtr<jbyteArray> key =
Expand All @@ -248,9 +248,9 @@ envoyHeadersToJavaArrayOfObjectArray(JniHelper& jni_helper,

LocalRefUniquePtr<jobjectArray>
vectorStringToJavaArrayOfByteArray(JniHelper& jni_helper, const std::vector<std::string>& v) {
LocalRefUniquePtr<jclass> jcls_byte_array = jni_helper.findClass("[B");
jclass jcls_byte_array = jni_helper.findClass("[B");
LocalRefUniquePtr<jobjectArray> joa =
jni_helper.newObjectArray(v.size(), jcls_byte_array.get(), nullptr);
jni_helper.newObjectArray(v.size(), jcls_byte_array, nullptr);

for (size_t i = 0; i < v.size(); ++i) {
LocalRefUniquePtr<jbyteArray> byte_array = byteArrayToJavaByteArray(
Expand Down Expand Up @@ -367,14 +367,14 @@ absl::flat_hash_map<std::string, std::string> javaMapToCppMap(JniHelper& jni_hel
auto java_entry_set_object = jni_helper.callObjectMethod(java_map, java_entry_set_method_id);

auto java_set_class = jni_helper.getObjectClass(java_entry_set_object.get());
auto java_map_entry_class = jni_helper.findClass("java/util/Map$Entry");
jclass java_map_entry_class = jni_helper.findClass("java/util/Map$Entry");

auto java_iterator_method_id =
jni_helper.getMethodId(java_set_class.get(), "iterator", "()Ljava/util/Iterator;");
auto java_get_key_method_id =
jni_helper.getMethodId(java_map_entry_class.get(), "getKey", "()Ljava/lang/Object;");
jni_helper.getMethodId(java_map_entry_class, "getKey", "()Ljava/lang/Object;");
auto java_get_value_method_id =
jni_helper.getMethodId(java_map_entry_class.get(), "getValue", "()Ljava/lang/Object;");
jni_helper.getMethodId(java_map_entry_class, "getValue", "()Ljava/lang/Object;");

auto java_iterator_object =
jni_helper.callObjectMethod(java_entry_set_object.get(), java_iterator_method_id);
Expand Down Expand Up @@ -403,18 +403,18 @@ absl::flat_hash_map<std::string, std::string> javaMapToCppMap(JniHelper& jni_hel
LocalRefUniquePtr<jobject> cppHeadersToJavaHeaders(JniHelper& jni_helper,
const Http::HeaderMap& cpp_headers) {
// Use LinkedHashMap to preserve the insertion order.
auto java_map_class = jni_helper.findClass("java/util/LinkedHashMap");
auto java_map_init_method_id = jni_helper.getMethodId(java_map_class.get(), "<init>", "()V");
jclass java_map_class = jni_helper.findClass("java/util/LinkedHashMap");
auto java_map_init_method_id = jni_helper.getMethodId(java_map_class, "<init>", "()V");
auto java_map_put_method_id = jni_helper.getMethodId(
java_map_class.get(), "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
java_map_class, "put", "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;");
auto java_map_get_method_id =
jni_helper.getMethodId(java_map_class.get(), "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
auto java_map_object = jni_helper.newObject(java_map_class.get(), java_map_init_method_id);
jni_helper.getMethodId(java_map_class, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
auto java_map_object = jni_helper.newObject(java_map_class, java_map_init_method_id);

auto java_list_class = jni_helper.findClass("java/util/ArrayList");
auto java_list_init_method_id = jni_helper.getMethodId(java_list_class.get(), "<init>", "()V");
jclass java_list_class = jni_helper.findClass("java/util/ArrayList");
auto java_list_init_method_id = jni_helper.getMethodId(java_list_class, "<init>", "()V");
auto java_list_add_method_id =
jni_helper.getMethodId(java_list_class.get(), "add", "(Ljava/lang/Object;)Z");
jni_helper.getMethodId(java_list_class, "add", "(Ljava/lang/Object;)Z");

cpp_headers.iterate([&](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
std::string cpp_key = std::string(header.key().getStringView());
Expand All @@ -431,7 +431,7 @@ LocalRefUniquePtr<jobject> cppHeadersToJavaHeaders(JniHelper& jni_helper,
jni_helper.callObjectMethod(java_map_object.get(), java_map_get_method_id, java_key.get());
if (existing_value == nullptr) { // the key does not exist
// Create a new list.
auto java_list_object = jni_helper.newObject(java_list_class.get(), java_list_init_method_id);
auto java_list_object = jni_helper.newObject(java_list_class, java_list_init_method_id);
jni_helper.callBooleanMethod(java_list_object.get(), java_list_add_method_id,
java_value.get());
// Put the new list into the map.
Expand All @@ -456,14 +456,14 @@ void javaHeadersToCppHeaders(JniHelper& jni_helper, jobject java_headers,
auto java_entry_set_object = jni_helper.callObjectMethod(java_headers, java_entry_set_method_id);

auto java_set_class = jni_helper.getObjectClass(java_entry_set_object.get());
auto java_map_entry_class = jni_helper.findClass("java/util/Map$Entry");
jclass java_map_entry_class = jni_helper.findClass("java/util/Map$Entry");

auto java_map_iter_method_id =
jni_helper.getMethodId(java_set_class.get(), "iterator", "()Ljava/util/Iterator;");
auto java_map_get_key_method_id =
jni_helper.getMethodId(java_map_entry_class.get(), "getKey", "()Ljava/lang/Object;");
jni_helper.getMethodId(java_map_entry_class, "getKey", "()Ljava/lang/Object;");
auto java_map_get_value_method_id =
jni_helper.getMethodId(java_map_entry_class.get(), "getValue", "()Ljava/lang/Object;");
jni_helper.getMethodId(java_map_entry_class, "getValue", "()Ljava/lang/Object;");

auto java_iter_object =
jni_helper.callObjectMethod(java_entry_set_object.get(), java_map_iter_method_id);
Expand All @@ -473,10 +473,10 @@ void javaHeadersToCppHeaders(JniHelper& jni_helper, jobject java_headers,
auto java_iter_next_method_id =
jni_helper.getMethodId(java_iterator_class.get(), "next", "()Ljava/lang/Object;");

auto java_list_class = jni_helper.findClass("java/util/List");
auto java_list_size_method_id = jni_helper.getMethodId(java_list_class.get(), "size", "()I");
jclass java_list_class = jni_helper.findClass("java/util/List");
auto java_list_size_method_id = jni_helper.getMethodId(java_list_class, "size", "()I");
auto java_list_get_method_id =
jni_helper.getMethodId(java_list_class.get(), "get", "(I)Ljava/lang/Object;");
jni_helper.getMethodId(java_list_class, "get", "(I)Ljava/lang/Object;");

while (jni_helper.callBooleanMethod(java_iter_object.get(), java_iter_has_next_method_id)) {
auto java_entry_object =
Expand All @@ -503,9 +503,9 @@ void javaHeadersToCppHeaders(JniHelper& jni_helper, jobject java_headers,
}

bool isJavaDirectByteBuffer(JniHelper& jni_helper, jobject java_byte_buffer) {
auto java_byte_buffer_class = jni_helper.findClass("java/nio/ByteBuffer");
jclass java_byte_buffer_class = jni_helper.findClass("java/nio/ByteBuffer");
auto java_byte_buffer_is_direct_method_id =
jni_helper.getMethodId(java_byte_buffer_class.get(), "isDirect", "()Z");
jni_helper.getMethodId(java_byte_buffer_class, "isDirect", "()Z");
return jni_helper.callBooleanMethod(java_byte_buffer, java_byte_buffer_is_direct_method_id);
}

Expand Down Expand Up @@ -541,9 +541,9 @@ LocalRefUniquePtr<jobject> cppBufferInstanceToJavaDirectByteBuffer(
Buffer::InstancePtr javaNonDirectByteBufferToCppBufferInstance(JniHelper& jni_helper,
jobject java_byte_buffer,
jlong length) {
auto java_byte_buffer_class = jni_helper.findClass("java/nio/ByteBuffer");
jclass java_byte_buffer_class = jni_helper.findClass("java/nio/ByteBuffer");
auto java_byte_buffer_array_method_id =
jni_helper.getMethodId(java_byte_buffer_class.get(), "array", "()[B");
jni_helper.getMethodId(java_byte_buffer_class, "array", "()[B");
auto java_byte_array =
jni_helper.callObjectMethod<jbyteArray>(java_byte_buffer, java_byte_buffer_array_method_id);
ASSERT(java_byte_array != nullptr, "The ByteBuffer argument is not a non-direct ByteBuffer.");
Expand All @@ -556,20 +556,20 @@ Buffer::InstancePtr javaNonDirectByteBufferToCppBufferInstance(JniHelper& jni_he

LocalRefUniquePtr<jobject> cppBufferInstanceToJavaNonDirectByteBuffer(
JniHelper& jni_helper, const Buffer::Instance& cpp_buffer_instance, uint64_t length) {
auto java_byte_buffer_class = jni_helper.findClass("java/nio/ByteBuffer");
auto java_byte_buffer_wrap_method_id = jni_helper.getStaticMethodId(
java_byte_buffer_class.get(), "wrap", "([B)Ljava/nio/ByteBuffer;");
jclass java_byte_buffer_class = jni_helper.findClass("java/nio/ByteBuffer");
auto java_byte_buffer_wrap_method_id =
jni_helper.getStaticMethodId(java_byte_buffer_class, "wrap", "([B)Ljava/nio/ByteBuffer;");
auto java_byte_array = jni_helper.newByteArray(static_cast<jsize>(cpp_buffer_instance.length()));
auto java_byte_array_elements = jni_helper.getByteArrayElements(java_byte_array.get(), nullptr);
cpp_buffer_instance.copyOut(0, length, static_cast<void*>(java_byte_array_elements.get()));
return jni_helper.callStaticObjectMethod(java_byte_buffer_class.get(),
java_byte_buffer_wrap_method_id, java_byte_array.get());
return jni_helper.callStaticObjectMethod(java_byte_buffer_class, java_byte_buffer_wrap_method_id,
java_byte_array.get());
}

std::string getJavaExceptionMessage(JniHelper& jni_helper, jthrowable throwable) {
auto java_throwable_class = jni_helper.findClass("java/lang/Throwable");
jclass java_throwable_class = jni_helper.findClass("java/lang/Throwable");
auto java_get_message_method_id =
jni_helper.getMethodId(java_throwable_class.get(), "getMessage", "()Ljava/lang/String;");
jni_helper.getMethodId(java_throwable_class, "getMessage", "()Ljava/lang/String;");
auto java_exception_message =
jni_helper.callObjectMethod<jstring>(throwable, java_get_message_method_id);
return javaStringToCppString(jni_helper, java_exception_message.get());
Expand Down Expand Up @@ -602,7 +602,7 @@ envoy_stream_intel javaStreamIntelToCppStreamIntel(JniHelper& jni_helper,
LocalRefUniquePtr<jobject> cppStreamIntelToJavaStreamIntel(JniHelper& jni_helper,
const envoy_stream_intel& stream_intel) {
auto java_stream_intel_class =
jni_helper.findClassFromCache("io/envoyproxy/envoymobile/engine/types/EnvoyStreamIntel");
jni_helper.findClass("io/envoyproxy/envoymobile/engine/types/EnvoyStreamIntel");
auto java_stream_intel_init_method_id =
jni_helper.getMethodId(java_stream_intel_class, "<init>", "(JJJJ)V");
return jni_helper.newObject(java_stream_intel_class, java_stream_intel_init_method_id,
Expand Down Expand Up @@ -688,7 +688,7 @@ LocalRefUniquePtr<jobject>
cppFinalStreamIntelToJavaFinalStreamIntel(JniHelper& jni_helper,
const envoy_final_stream_intel& final_stream_intel) {
auto java_final_stream_intel_class =
jni_helper.findClassFromCache("io/envoyproxy/envoymobile/engine/types/EnvoyFinalStreamIntel");
jni_helper.findClass("io/envoyproxy/envoymobile/engine/types/EnvoyFinalStreamIntel");
auto java_final_stream_intel_init_method_id =
jni_helper.getMethodId(java_final_stream_intel_class, "<init>", "(JJJJJJJJJJJZJJJJ)V");
return jni_helper.newObject(java_final_stream_intel_class, java_final_stream_intel_init_method_id,
Expand Down
Loading

0 comments on commit 840d8cd

Please sign in to comment.