From b7acd3d1082d02373e8a2fde15da0da76ce38c99 Mon Sep 17 00:00:00 2001 From: william Date: Wed, 18 Oct 2023 01:14:37 +0800 Subject: [PATCH] Add special token handling --- CHANGELOG | 2 +- llamajava/llamajava.cpp | 4 ++-- llamajava/llamajava.h | 2 +- src/main/java/chat/octet/model/Generator.java | 2 +- .../java/chat/octet/model/LlamaService.java | 22 ++++++++++--------- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/CHANGELOG b/CHANGELOG index dd8a2f1..c3cbda8 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,3 @@ 1. Update tensor_split param. 2. Update Java docs. -3. Update llama.cpp libs version to b1385. \ No newline at end of file +3. Update llama.cpp libs version to b1387. \ No newline at end of file diff --git a/llamajava/llamajava.cpp b/llamajava/llamajava.cpp index 900cf39..eba09e2 100644 --- a/llamajava/llamajava.cpp +++ b/llamajava/llamajava.cpp @@ -418,14 +418,14 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_getTokenEOS JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenize (JNIEnv *env, jclass thisClass, jbyteArray buf, jint buffer_length, jintArray tokens_arrays, - jint maxTokens, jboolean addBos) { + jint maxTokens, jboolean addBos, jboolean specialTokens) { llama_token *tokens = (llama_token *) env->GetIntArrayElements(tokens_arrays, JNI_FALSE); jbyte *buffer = new jbyte[buffer_length]; env->GetByteArrayRegion(buf, 0, buffer_length, buffer); const char *text = (char *) buffer; - int code = llama_tokenize(model, text, buffer_length, tokens, maxTokens, ToCBool(addBos)); + int code = llama_tokenize(model, text, buffer_length, tokens, maxTokens, ToCBool(addBos), ToCBool(specialTokens)); env->ReleaseIntArrayElements(tokens_arrays, tokens, 0); env->ReleaseByteArrayElements(buf, buffer, 0); return code; diff --git a/llamajava/llamajava.h b/llamajava/llamajava.h index 79c7a89..c79e4aa 100644 --- a/llamajava/llamajava.h +++ b/llamajava/llamajava.h @@ -132,7 +132,7 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_getTokenEOS * Method: tokenize */ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenize - (JNIEnv *, jclass, jbyteArray, jint, jintArray, jint, jboolean); + (JNIEnv *, jclass, jbyteArray, jint, jintArray, jint, jboolean, jboolean); /* * Class: chat_octet_model_LlamaService diff --git a/src/main/java/chat/octet/model/Generator.java b/src/main/java/chat/octet/model/Generator.java index fda7631..b92a789 100644 --- a/src/main/java/chat/octet/model/Generator.java +++ b/src/main/java/chat/octet/model/Generator.java @@ -47,7 +47,7 @@ protected Generator(GenerateParameter generateParams, String prompt, Status srcS this.contextSize = LlamaService.getContextSize(); this.status = srcStatus == null ? new Status() : new Status(srcStatus); - int[] tokens = StringUtils.isNotBlank(prompt) ? LlamaService.tokenize(prompt, true) : new int[]{LlamaService.getTokenBOS()}; + int[] tokens = StringUtils.isNotBlank(prompt) ? LlamaService.tokenize(prompt, true, true) : new int[]{LlamaService.getTokenBOS()}; if (tokens.length >= contextSize) { throw new IllegalArgumentException(MessageFormat.format("Requested tokens ({0}) exceed context window of {1}.", tokens.length, contextSize)); } diff --git a/src/main/java/chat/octet/model/LlamaService.java b/src/main/java/chat/octet/model/LlamaService.java index 9991f94..854d448 100644 --- a/src/main/java/chat/octet/model/LlamaService.java +++ b/src/main/java/chat/octet/model/LlamaService.java @@ -163,14 +163,15 @@ public class LlamaService { * The tokens pointer must be large enough to hold the resulting tokens. * Returns the number of tokens on success, no more than n_max_tokens. * - * @param buf Text byte buffer. - * @param bufferLength Text byte buffer length. - * @param tokens Empty token arrays, Used to receive the returned tokens. - * @param maxTokens Max token size, by default is context size. - * @param addBos Add special BOS token. + * @param buf Text byte buffer. + * @param bufferLength Text byte buffer length. + * @param tokens Empty token arrays, Used to receive the returned tokens. + * @param maxTokens Max token size, by default is context size. + * @param addBos Add special BOS token. + * @param specialTokens Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. * @return int, Returns a negative number on failure, else the number of tokens that would have been returned. */ - public static native int tokenize(byte[] buf, int bufferLength, int[] tokens, int maxTokens, boolean addBos); + public static native int tokenize(byte[] buf, int bufferLength, int[] tokens, int maxTokens, boolean addBos, boolean specialTokens); /** * Convert the token id to text piece. @@ -263,15 +264,16 @@ public static void clearCache(int sequenceId) { /** * Convert the provided text into tokens. * - * @param text Input text. - * @param addBos Add special BOS token. + * @param text Input text. + * @param addBos Add special BOS token. + * @param specialTokens Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. * @return Returns a negative number on failure, else the number of tokens that would have been returned. */ - public static int[] tokenize(String text, boolean addBos) { + public static int[] tokenize(String text, boolean addBos, boolean specialTokens) { Preconditions.checkNotNull(text, "Text cannot be null"); int[] tokens = new int[getContextSize()]; byte[] textBytes = text.getBytes(StandardCharsets.UTF_8); - int nextTokens = tokenize(textBytes, textBytes.length, tokens, getContextSize(), addBos); + int nextTokens = tokenize(textBytes, textBytes.length, tokens, getContextSize(), addBos, specialTokens); if (nextTokens < 0) { throw new ModelException(MessageFormat.format("Failed to tokenize: {0}, next_tokens: {1}", text, nextTokens)); }