Merged
30 changes: 19 additions & 11 deletions src/main/cpp/jllama.cpp
@@ -3,6 +3,7 @@
#include "arg.h"
#include "llama.h"
#include "log.h"
#include "json-schema-to-grammar.h"
#include "nlohmann/json.hpp"
#include "server.hpp"

@@ -431,7 +432,6 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
if (!ctx_server->load_model(params))
{
llama_backend_free();
;
env->ThrowNew(c_llama_error, "could not load model from given file path");
return;
}
@@ -442,7 +442,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
LOG_INF("%s: model loaded\n", __func__);

const auto model_meta = ctx_server->model_meta();

if (!params.speculative.model.empty() || !params.speculative.hf_repo.empty()) {
SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
auto params_dft = params;
@@ -493,7 +493,7 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
common_chat_templates_source(ctx_server->chat_templates.get()),
common_chat_format_example(ctx_server->chat_templates.get(), ctx_server->params_base.use_jinja).c_str());


// print sample chat example to make it clear which template is used
// LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
@@ -543,9 +543,9 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
try
{
const auto & prompt = data.at("prompt");

std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server->vocab, prompt, true, true);

tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++)
{
@@ -600,7 +600,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

server_task_result_ptr result = ctx_server->queue_results.recv(id_task);

if (result->is_error())
{
std::string response = result->to_json()["message"].get<std::string>();
@@ -609,9 +609,9 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIE
return nullptr;
}
const auto out_res = result->to_json();





std::string response = out_res["content"].get<std::string>();
if (result->is_stop())
{
@@ -652,11 +652,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
"model was not loaded with embedding support (see ModelParameters#setEmbedding(boolean))");
return nullptr;
}



const std::string prompt = parse_jstring(env, jprompt);

SRV_INF("Calling embedding '%s'\n", prompt.c_str());

const auto tokens = tokenize_mixed(ctx_server->vocab, prompt, true, true);
@@ -716,7 +716,7 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
// Extract only the first row
const std::vector<float>& first_row = embedding[0]; // Reference to avoid copying


// Create a new float array in JNI
jfloatArray j_embedding = env->NewFloatArray(embedding_cols);
if (j_embedding == nullptr) {
@@ -819,3 +819,11 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_setLogger(JNIEnv *env, jc
}
}
}

JNIEXPORT jbyteArray JNICALL Java_de_kherud_llama_LlamaModel_jsonSchemaToGrammarBytes(JNIEnv *env, jclass clazz, jstring j_schema)
{
const std::string c_schema = parse_jstring(env, j_schema);
// Parse with nlohmann::ordered_json so the schema's property order is preserved in the grammar
nlohmann::ordered_json c_schema_json = nlohmann::ordered_json::parse(c_schema);
const std::string c_grammar = json_schema_to_grammar(c_schema_json);
// Return the GBNF grammar as raw UTF-8 bytes; the Java wrapper decodes them
return parse_jbytes(env, c_grammar);
}
16 changes: 11 additions & 5 deletions src/main/cpp/jllama.h

Some generated files are not rendered by default.

5 changes: 5 additions & 0 deletions src/main/java/de/kherud/llama/LlamaModel.java
@@ -132,4 +132,9 @@ public void close() {

private native void releaseTask(int taskId);

private static native byte[] jsonSchemaToGrammarBytes(String schema);

public static String jsonSchemaToGrammar(String schema) {
return new String(jsonSchemaToGrammarBytes(schema), StandardCharsets.UTF_8);
}
}
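
A minimal usage sketch of the new API (assuming a loaded LlamaModel; InferenceParameters#setGrammar is taken here to be the library's existing setter for GBNF grammars, and the schema literal is purely illustrative):

String schema = "{ \"properties\": { \"name\": { \"type\": \"string\" } }, \"additionalProperties\": false }";
String grammar = LlamaModel.jsonSchemaToGrammar(schema);
// Constrain sampling so the model can only emit JSON matching the schema
InferenceParameters params = new InferenceParameters("Describe a person as JSON: ")
        .setGrammar(grammar);
// model.complete(params) would then return schema-conforming output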
25 changes: 25 additions & 0 deletions src/test/java/de/kherud/llama/LlamaModelTest.java
@@ -271,4 +271,29 @@ private LogMessage(LogLevel level, String text) {
this.text = text;
}
}

@Test
public void testJsonSchemaToGrammar() {
String schema = "{\n" +
" \"properties\": {\n" +
" \"a\": {\"type\": \"string\"},\n" +
" \"b\": {\"type\": \"string\"},\n" +
" \"c\": {\"type\": \"string\"}\n" +
" },\n" +
" \"additionalProperties\": false\n" +
"}";

String expectedGrammar = "a-kv ::= \"\\\"a\\\"\" space \":\" space string\n" +
"a-rest ::= ( \",\" space b-kv )? b-rest\n" +
"b-kv ::= \"\\\"b\\\"\" space \":\" space string\n" +
"b-rest ::= ( \",\" space c-kv )?\n" +
"c-kv ::= \"\\\"c\\\"\" space \":\" space string\n" +
"char ::= [^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})\n" +
"root ::= \"{\" space (a-kv a-rest | b-kv b-rest | c-kv )? \"}\" space\n" +
"space ::= | \" \" | \"\\n\"{1,2} [ \\t]{0,20}\n" +
"string ::= \"\\\"\" char* \"\\\"\" space\n";

String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema);
Assert.assertEquals(expectedGrammar, actualGrammar);
}
}
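
As a reading aid for the expected grammar: the a-rest/b-rest chaining makes each property optional while fixing the key order, so the grammar accepts objects like { "a": "x", "c": "z" } but rejects out-of-order keys such as { "b": "y", "a": "x" }.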