From caf6cdc9a87abca9614a863c2428d37232a61251 Mon Sep 17 00:00:00 2001 From: Branden Butler Date: Fri, 8 Sep 2023 23:12:15 -0500 Subject: [PATCH 1/2] Support Java 11 by backporting IntBuffer.slice(int, int) via a custom container class --- pom.xml | 4 +- src/main/java/de/kherud/llama/LlamaModel.java | 34 ++++---- .../de/kherud/llama/SliceableIntBuffer.java | 77 +++++++++++++++++++ 3 files changed, 96 insertions(+), 19 deletions(-) create mode 100644 src/main/java/de/kherud/llama/SliceableIntBuffer.java diff --git a/pom.xml b/pom.xml index bf3f1326..5ead070e 100644 --- a/pom.xml +++ b/pom.xml @@ -81,8 +81,8 @@ maven-compiler-plugin 3.11.0 - 13 - 13 + 11 + 11 diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java index 3ac98863..ae734fb4 100644 --- a/src/main/java/de/kherud/llama/LlamaModel.java +++ b/src/main/java/de/kherud/llama/LlamaModel.java @@ -44,8 +44,8 @@ public class LlamaModel implements AutoCloseable { // cache some things for performance private final Pointer logitsPointer; // pointer to retrieve the logits of the llm - private final IntBuffer contextBuffer; // used to hold all tokens of a conversation - private final IntBuffer tokenBuffer; // used for tokenization + private final SliceableIntBuffer contextBuffer; // used to hold all tokens of a conversation + private final SliceableIntBuffer tokenBuffer; // used for tokenization private final byte[] tokenPieceBuffer; // used to decode tokens to string private final llama_token_data.ByReference[] candidateData; // candidates used for sampling private final llama_token_data_array candidates; // array holding the candidates @@ -88,8 +88,8 @@ public LlamaModel(String filePath, Parameters params) { } // setup some cached variables used throughout lifecycle - contextBuffer = IntBuffer.allocate(params.ctx.n_ctx); - tokenBuffer = IntBuffer.allocate(params.ctx.n_ctx); + contextBuffer = new SliceableIntBuffer(IntBuffer.allocate(params.ctx.n_ctx)); + tokenBuffer = new SliceableIntBuffer(IntBuffer.allocate(params.ctx.n_ctx)); tokenPieceBuffer = new byte[64]; nVocab = getVocabularySize(); @@ -148,7 +148,7 @@ public float[] getEmbedding(String prompt) { if (params.ctx.embedding == 0) { throw new IllegalStateException("embedding mode not activated (see parameters)"); } - IntBuffer tokens = tokenize(prompt, false); + SliceableIntBuffer tokens = tokenize(prompt, false); addContext(tokens); evaluate(); return LlamaLibrary.llama_get_embeddings(ctx).getPointer().getFloatArray(0, getEmbeddingSize()); @@ -171,9 +171,9 @@ public void reset() { * @return an array of integers each representing a token id (see {@link #getVocabularySize()}) */ public int[] encode(String prompt) { - IntBuffer buffer = tokenize(prompt, false); + SliceableIntBuffer buffer = tokenize(prompt, false); int[] tokens = new int[buffer.capacity()]; - System.arraycopy(buffer.array(), 0, tokens, 0, buffer.capacity()); + System.arraycopy(buffer.delegate.array(), 0, tokens, 0, buffer.capacity()); return tokens; } @@ -310,8 +310,8 @@ public String toString() { * @return an IntBuffer containing the tokenized prompt without any padding * @throws RuntimeException if tokenization fails */ - private IntBuffer tokenize(String prompt, boolean addBos) { - int nTokens = LlamaLibrary.llama_tokenize(ctx, prompt, tokenBuffer, params.ctx.n_ctx, addBos ? (byte) 1 : 0); + private SliceableIntBuffer tokenize(String prompt, boolean addBos) { + int nTokens = LlamaLibrary.llama_tokenize(ctx, prompt, tokenBuffer.delegate, params.ctx.n_ctx, addBos ? (byte) 1 : 0); if (nTokens < 0) { throw new RuntimeException("tokenization failed due to unknown reasons"); } @@ -324,7 +324,7 @@ private void evaluate() { if (nEval > params.ctx.n_batch) { nEval = params.ctx.n_batch; } - if (LlamaLibrary.llama_eval(ctx, contextBuffer.slice(nPast, nEval), nEval, nPast, params.nThreads) != 0) { + if (LlamaLibrary.llama_eval(ctx, contextBuffer.slice(nPast, nEval).delegate, nEval, nPast, params.nThreads) != 0) { String msg = String.format("evaluation failed (%d to evaluate, %d past, %d threads)", nEval, nPast, params.nThreads); log(LogLevel.ERROR, msg); throw new RuntimeException("token evaluation failed"); @@ -380,18 +380,18 @@ private void samplePenalty() { int repeat_last_n = params.repeatLastN < 0 ? params.ctx.n_ctx : params.repeatLastN; int last_n_repeat = Math.min(Math.min(nContext, repeat_last_n), params.ctx.n_ctx); NativeSize nTokens = new NativeSize(last_n_repeat); - IntBuffer lastTokens = tokenBuffer.slice(nContext - last_n_repeat, last_n_repeat); + SliceableIntBuffer lastTokens = tokenBuffer.slice(nContext - last_n_repeat, last_n_repeat); LlamaLibrary.llama_sample_repetition_penalty( ctx, candidates, - lastTokens, + lastTokens.delegate, nTokens, params.repeatPenalty ); LlamaLibrary.llama_sample_frequency_and_presence_penalties( ctx, candidates, - lastTokens, + lastTokens.delegate, nTokens, params.frequencyPenalty, params.presencePenalty @@ -462,9 +462,9 @@ private int sampleTopK() { return LlamaLibrary.llama_sample_token(ctx, candidates); } - private void addContext(IntBuffer tokens) { + private void addContext(SliceableIntBuffer tokens) { truncateContext(tokens.capacity()); - System.arraycopy(tokens.array(), 0, contextBuffer.array(), nContext, tokens.capacity()); + System.arraycopy(tokens.delegate.array(), 0, contextBuffer.delegate.array(), nContext, tokens.capacity()); nContext += tokens.capacity(); } @@ -473,7 +473,7 @@ private void truncateContext(int nAdd) { int nCtxKeep = params.ctx.n_ctx / 2 - nAdd; String msg = "truncating context from " + nContext + " to " + nCtxKeep + " tokens (+" + nAdd + " to add)"; log(LogLevel.INFO, msg); - System.arraycopy(contextBuffer.array(), nContext - nCtxKeep, contextBuffer.array(), 0, nCtxKeep); + System.arraycopy(contextBuffer.delegate.array(), nContext - nCtxKeep, contextBuffer.delegate.array(), 0, nCtxKeep); nPast = 0; nContext = nCtxKeep; } @@ -532,7 +532,7 @@ private void setup(String prompt) { if (nContext == 0 && !prompt.startsWith(" ")) { prompt = " " + prompt; } - IntBuffer tokens = tokenize(prompt, true); + SliceableIntBuffer tokens = tokenize(prompt, true); addContext(tokens); } diff --git a/src/main/java/de/kherud/llama/SliceableIntBuffer.java b/src/main/java/de/kherud/llama/SliceableIntBuffer.java new file mode 100644 index 00000000..615fc4ff --- /dev/null +++ b/src/main/java/de/kherud/llama/SliceableIntBuffer.java @@ -0,0 +1,77 @@ +package de.kherud.llama; + +import java.nio.IntBuffer; + + +/** + * Container that allows slicing an {@link IntBuffer} + * with arbitrary slicing lengths on Java versions older than 13. + * Does not extend IntBuffer because the super constructor + * requires memory segment proxies, and we can't access the delegate's. + * Does not implement Buffer because the {@link java.nio.Buffer#slice(int, int)} + * method is specifically blocked from being implemented or used on older jdk versions. + */ +public class SliceableIntBuffer { + public final IntBuffer delegate; + + private final int offset; + + private final int capacity; + + public SliceableIntBuffer(IntBuffer delegate) { + this.delegate = delegate; + this.capacity = delegate.capacity(); + this.offset = 0; + } + + public SliceableIntBuffer(IntBuffer delegate, int offset, int capacity) { + this.delegate = delegate; + this.offset = offset; + this.capacity = capacity; + } + + public SliceableIntBuffer slice(int offset, int length) { + // Where the magic happens + // Wrapping is equivalent to the slice operation so long + // as you keep track of your offsets and capacities. + // So, we use this container class to track those offsets and translate + // them to the correct frame of reference. + return new SliceableIntBuffer( + IntBuffer.wrap( + this.delegate.array(), + this.offset + offset, + length + ), + this.offset + offset, + length + ); + + } + + public int capacity() { + return capacity; + } + + public SliceableIntBuffer put(int index, int i) { + delegate.put(offset + index, i); + return this; + } + + public int get(int index) { + return delegate.get(offset + index); + } + + public void clear() { + // Clear set the limit and position + // to 0 and capacity respectively, + // but that's not what the buffer was initially + // after the wrap() call, so we manually + // set the limit and position to what they were + // after the wrap call. + delegate.clear(); + delegate.limit(offset + capacity); + delegate.position(offset); + } + + +} From db8f5f534e236e8b40b1e8a7f2c4656af7400392 Mon Sep 17 00:00:00 2001 From: Konstantin Herud Date: Sat, 9 Sep 2023 11:01:30 +0200 Subject: [PATCH 2/2] Make sliceable int buffer package private --- .../de/kherud/llama/SliceableIntBuffer.java | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/main/java/de/kherud/llama/SliceableIntBuffer.java b/src/main/java/de/kherud/llama/SliceableIntBuffer.java index 615fc4ff..2ce769e2 100644 --- a/src/main/java/de/kherud/llama/SliceableIntBuffer.java +++ b/src/main/java/de/kherud/llama/SliceableIntBuffer.java @@ -11,26 +11,26 @@ * Does not implement Buffer because the {@link java.nio.Buffer#slice(int, int)} * method is specifically blocked from being implemented or used on older jdk versions. */ -public class SliceableIntBuffer { - public final IntBuffer delegate; +class SliceableIntBuffer { + final IntBuffer delegate; private final int offset; private final int capacity; - public SliceableIntBuffer(IntBuffer delegate) { + SliceableIntBuffer(IntBuffer delegate) { this.delegate = delegate; this.capacity = delegate.capacity(); this.offset = 0; } - public SliceableIntBuffer(IntBuffer delegate, int offset, int capacity) { + SliceableIntBuffer(IntBuffer delegate, int offset, int capacity) { this.delegate = delegate; this.offset = offset; this.capacity = capacity; } - public SliceableIntBuffer slice(int offset, int length) { + SliceableIntBuffer slice(int offset, int length) { // Where the magic happens // Wrapping is equivalent to the slice operation so long // as you keep track of your offsets and capacities. @@ -48,20 +48,20 @@ public SliceableIntBuffer slice(int offset, int length) { } - public int capacity() { + int capacity() { return capacity; } - public SliceableIntBuffer put(int index, int i) { + SliceableIntBuffer put(int index, int i) { delegate.put(offset + index, i); return this; } - public int get(int index) { + int get(int index) { return delegate.get(offset + index); } - public void clear() { + void clear() { // Clear set the limit and position // to 0 and capacity respectively, // but that's not what the buffer was initially @@ -73,5 +73,4 @@ public void clear() { delegate.position(offset); } - }