Skip to content

Commit

Permalink
Add special token handling
Browse files Browse the repository at this point in the history
  • Loading branch information
eoctet committed Oct 17, 2023
1 parent c729dc8 commit b7acd3d
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
1. Update tensor_split param.
2. Update Java docs.
3. Update llama.cpp libs version to b1385.
3. Update llama.cpp libs version to b1387.
4 changes: 2 additions & 2 deletions llamajava/llamajava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,14 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_getTokenEOS
JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenize
(JNIEnv *env, jclass thisClass, jbyteArray buf, jint buffer_length,
jintArray tokens_arrays,
jint maxTokens, jboolean addBos) {
jint maxTokens, jboolean addBos, jboolean specialTokens) {
llama_token *tokens = (llama_token *) env->GetIntArrayElements(tokens_arrays, JNI_FALSE);

jbyte *buffer = new jbyte[buffer_length];
env->GetByteArrayRegion(buf, 0, buffer_length, buffer);
const char *text = (char *) buffer;

int code = llama_tokenize(model, text, buffer_length, tokens, maxTokens, ToCBool(addBos));
int code = llama_tokenize(model, text, buffer_length, tokens, maxTokens, ToCBool(addBos), ToCBool(specialTokens));
env->ReleaseIntArrayElements(tokens_arrays, tokens, 0);
env->ReleaseByteArrayElements(buf, buffer, 0);
return code;
Expand Down
2 changes: 1 addition & 1 deletion llamajava/llamajava.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_getTokenEOS
* Method: tokenize
*/
JNIEXPORT jint JNICALL Java_chat_octet_model_LlamaService_tokenize
(JNIEnv *, jclass, jbyteArray, jint, jintArray, jint, jboolean);
(JNIEnv *, jclass, jbyteArray, jint, jintArray, jint, jboolean, jboolean);

/*
* Class: chat_octet_model_LlamaService
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/chat/octet/model/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected Generator(GenerateParameter generateParams, String prompt, Status srcS
this.contextSize = LlamaService.getContextSize();
this.status = srcStatus == null ? new Status() : new Status(srcStatus);

int[] tokens = StringUtils.isNotBlank(prompt) ? LlamaService.tokenize(prompt, true) : new int[]{LlamaService.getTokenBOS()};
int[] tokens = StringUtils.isNotBlank(prompt) ? LlamaService.tokenize(prompt, true, true) : new int[]{LlamaService.getTokenBOS()};
if (tokens.length >= contextSize) {
throw new IllegalArgumentException(MessageFormat.format("Requested tokens ({0}) exceed context window of {1}.", tokens.length, contextSize));
}
Expand Down
22 changes: 12 additions & 10 deletions src/main/java/chat/octet/model/LlamaService.java
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,15 @@ public class LlamaService {
* The tokens pointer must be large enough to hold the resulting tokens.
* Returns the number of tokens on success, no more than n_max_tokens.
*
* @param buf Text byte buffer.
* @param bufferLength Text byte buffer length.
* @param tokens Empty token arrays, Used to receive the returned tokens.
* @param maxTokens Max token size, by default is context size.
* @param addBos Add special BOS token.
* @param buf Text byte buffer.
* @param bufferLength Text byte buffer length.
* @param tokens        Empty token array, used to receive the returned tokens.
* @param maxTokens Max token size, by default is context size.
* @param addBos Add special BOS token.
* @param specialTokens Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
* @return int, Returns a negative number on failure, else the number of tokens that would have been returned.
*/
public static native int tokenize(byte[] buf, int bufferLength, int[] tokens, int maxTokens, boolean addBos);
public static native int tokenize(byte[] buf, int bufferLength, int[] tokens, int maxTokens, boolean addBos, boolean specialTokens);

/**
* Convert the token id to text piece.
Expand Down Expand Up @@ -263,15 +264,16 @@ public static void clearCache(int sequenceId) {
/**
* Convert the provided text into tokens.
*
* @param text Input text.
* @param addBos Add special BOS token.
* @param text Input text.
* @param addBos Add special BOS token.
* @param specialTokens Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
* @return Returns the token ids of the input text; a ModelException is thrown if the native tokenizer reports failure (negative count).
*/
public static int[] tokenize(String text, boolean addBos) {
public static int[] tokenize(String text, boolean addBos, boolean specialTokens) {
Preconditions.checkNotNull(text, "Text cannot be null");
int[] tokens = new int[getContextSize()];
byte[] textBytes = text.getBytes(StandardCharsets.UTF_8);
int nextTokens = tokenize(textBytes, textBytes.length, tokens, getContextSize(), addBos);
int nextTokens = tokenize(textBytes, textBytes.length, tokens, getContextSize(), addBos, specialTokens);
if (nextTokens < 0) {
throw new ModelException(MessageFormat.format("Failed to tokenize: {0}, next_tokens: {1}", text, nextTokens));
}
Expand Down

0 comments on commit b7acd3d

Please sign in to comment.