diff --git a/pom.xml b/pom.xml
index bf3f1326..5ead070e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -81,8 +81,8 @@
maven-compiler-plugin
3.11.0
- 13
- 13
+ 11
+ 11
diff --git a/src/main/java/de/kherud/llama/LlamaModel.java b/src/main/java/de/kherud/llama/LlamaModel.java
index 3ac98863..ae734fb4 100644
--- a/src/main/java/de/kherud/llama/LlamaModel.java
+++ b/src/main/java/de/kherud/llama/LlamaModel.java
@@ -44,8 +44,8 @@ public class LlamaModel implements AutoCloseable {
// cache some things for performance
private final Pointer logitsPointer; // pointer to retrieve the logits of the llm
- private final IntBuffer contextBuffer; // used to hold all tokens of a conversation
- private final IntBuffer tokenBuffer; // used for tokenization
+ private final SliceableIntBuffer contextBuffer; // used to hold all tokens of a conversation
+ private final SliceableIntBuffer tokenBuffer; // used for tokenization
private final byte[] tokenPieceBuffer; // used to decode tokens to string
private final llama_token_data.ByReference[] candidateData; // candidates used for sampling
private final llama_token_data_array candidates; // array holding the candidates
@@ -88,8 +88,8 @@ public LlamaModel(String filePath, Parameters params) {
}
// setup some cached variables used throughout lifecycle
- contextBuffer = IntBuffer.allocate(params.ctx.n_ctx);
- tokenBuffer = IntBuffer.allocate(params.ctx.n_ctx);
+ contextBuffer = new SliceableIntBuffer(IntBuffer.allocate(params.ctx.n_ctx));
+ tokenBuffer = new SliceableIntBuffer(IntBuffer.allocate(params.ctx.n_ctx));
tokenPieceBuffer = new byte[64];
nVocab = getVocabularySize();
@@ -148,7 +148,7 @@ public float[] getEmbedding(String prompt) {
if (params.ctx.embedding == 0) {
throw new IllegalStateException("embedding mode not activated (see parameters)");
}
- IntBuffer tokens = tokenize(prompt, false);
+ SliceableIntBuffer tokens = tokenize(prompt, false);
addContext(tokens);
evaluate();
return LlamaLibrary.llama_get_embeddings(ctx).getPointer().getFloatArray(0, getEmbeddingSize());
@@ -171,9 +171,9 @@ public void reset() {
* @return an array of integers each representing a token id (see {@link #getVocabularySize()})
*/
public int[] encode(String prompt) {
- IntBuffer buffer = tokenize(prompt, false);
+ SliceableIntBuffer buffer = tokenize(prompt, false);
int[] tokens = new int[buffer.capacity()];
- System.arraycopy(buffer.array(), 0, tokens, 0, buffer.capacity());
+ System.arraycopy(buffer.delegate.array(), 0, tokens, 0, buffer.capacity());
return tokens;
}
@@ -310,8 +310,8 @@ public String toString() {
* @return an IntBuffer containing the tokenized prompt without any padding
* @throws RuntimeException if tokenization fails
*/
- private IntBuffer tokenize(String prompt, boolean addBos) {
- int nTokens = LlamaLibrary.llama_tokenize(ctx, prompt, tokenBuffer, params.ctx.n_ctx, addBos ? (byte) 1 : 0);
+ private SliceableIntBuffer tokenize(String prompt, boolean addBos) {
+ int nTokens = LlamaLibrary.llama_tokenize(ctx, prompt, tokenBuffer.delegate, params.ctx.n_ctx, addBos ? (byte) 1 : 0);
if (nTokens < 0) {
throw new RuntimeException("tokenization failed due to unknown reasons");
}
@@ -324,7 +324,7 @@ private void evaluate() {
if (nEval > params.ctx.n_batch) {
nEval = params.ctx.n_batch;
}
- if (LlamaLibrary.llama_eval(ctx, contextBuffer.slice(nPast, nEval), nEval, nPast, params.nThreads) != 0) {
+ if (LlamaLibrary.llama_eval(ctx, contextBuffer.slice(nPast, nEval).delegate, nEval, nPast, params.nThreads) != 0) {
String msg = String.format("evaluation failed (%d to evaluate, %d past, %d threads)", nEval, nPast, params.nThreads);
log(LogLevel.ERROR, msg);
throw new RuntimeException("token evaluation failed");
@@ -380,18 +380,18 @@ private void samplePenalty() {
int repeat_last_n = params.repeatLastN < 0 ? params.ctx.n_ctx : params.repeatLastN;
int last_n_repeat = Math.min(Math.min(nContext, repeat_last_n), params.ctx.n_ctx);
NativeSize nTokens = new NativeSize(last_n_repeat);
- IntBuffer lastTokens = tokenBuffer.slice(nContext - last_n_repeat, last_n_repeat);
+ SliceableIntBuffer lastTokens = tokenBuffer.slice(nContext - last_n_repeat, last_n_repeat);
LlamaLibrary.llama_sample_repetition_penalty(
ctx,
candidates,
- lastTokens,
+ lastTokens.delegate,
nTokens,
params.repeatPenalty
);
LlamaLibrary.llama_sample_frequency_and_presence_penalties(
ctx,
candidates,
- lastTokens,
+ lastTokens.delegate,
nTokens,
params.frequencyPenalty,
params.presencePenalty
@@ -462,9 +462,9 @@ private int sampleTopK() {
return LlamaLibrary.llama_sample_token(ctx, candidates);
}
- private void addContext(IntBuffer tokens) {
+ private void addContext(SliceableIntBuffer tokens) {
truncateContext(tokens.capacity());
- System.arraycopy(tokens.array(), 0, contextBuffer.array(), nContext, tokens.capacity());
+ System.arraycopy(tokens.delegate.array(), 0, contextBuffer.delegate.array(), nContext, tokens.capacity());
nContext += tokens.capacity();
}
@@ -473,7 +473,7 @@ private void truncateContext(int nAdd) {
int nCtxKeep = params.ctx.n_ctx / 2 - nAdd;
String msg = "truncating context from " + nContext + " to " + nCtxKeep + " tokens (+" + nAdd + " to add)";
log(LogLevel.INFO, msg);
- System.arraycopy(contextBuffer.array(), nContext - nCtxKeep, contextBuffer.array(), 0, nCtxKeep);
+ System.arraycopy(contextBuffer.delegate.array(), nContext - nCtxKeep, contextBuffer.delegate.array(), 0, nCtxKeep);
nPast = 0;
nContext = nCtxKeep;
}
@@ -532,7 +532,7 @@ private void setup(String prompt) {
if (nContext == 0 && !prompt.startsWith(" ")) {
prompt = " " + prompt;
}
- IntBuffer tokens = tokenize(prompt, true);
+ SliceableIntBuffer tokens = tokenize(prompt, true);
addContext(tokens);
}
diff --git a/src/main/java/de/kherud/llama/SliceableIntBuffer.java b/src/main/java/de/kherud/llama/SliceableIntBuffer.java
new file mode 100644
index 00000000..2ce769e2
--- /dev/null
+++ b/src/main/java/de/kherud/llama/SliceableIntBuffer.java
@@ -0,0 +1,131 @@
+package de.kherud.llama;
+
+import java.nio.IntBuffer;
+
+
+/**
+ * Container that allows slicing an array-backed {@link IntBuffer}
+ * with arbitrary offsets and lengths on Java versions older than 13.
+ * Does not extend IntBuffer because the super constructor
+ * requires memory segment proxies, and we can't access the delegate's.
+ * Does not implement Buffer because the {@link java.nio.Buffer#slice(int, int)}
+ * method is specifically blocked from being implemented or used on older jdk versions.
+ * <p>
+ * Every view shares the backing array of the buffer it was cut from, so a write
+ * through one view is visible through all others. Indices passed to
+ * {@link #get(int)}, {@link #put(int, int)} and {@link #slice(int, int)} are
+ * relative to the start of the view they are called on.
+ */
+class SliceableIntBuffer {
+
+    /** The wrapped buffer; exposed so it can be passed where a plain IntBuffer is required. */
+    final IntBuffer delegate;
+
+    /** Index within {@link #delegate} of the first element of this view. */
+    private final int offset;
+
+    /** Number of elements in this view. */
+    private final int capacity;
+
+    /**
+     * Creates a view spanning the whole of {@code delegate}.
+     *
+     * @param delegate the buffer to wrap
+     */
+    SliceableIntBuffer(IntBuffer delegate) {
+        this(delegate, 0, delegate.capacity());
+    }
+
+    /**
+     * Creates a view of {@code capacity} elements starting at {@code offset}.
+     *
+     * @param delegate the buffer to wrap
+     * @param offset   index within {@code delegate} of the first element of the view
+     * @param capacity number of elements in the view
+     */
+    SliceableIntBuffer(IntBuffer delegate, int offset, int capacity) {
+        this.delegate = delegate;
+        this.offset = offset;
+        this.capacity = capacity;
+    }
+
+    /**
+     * Creates a sub-view sharing this view's backing array.
+     *
+     * @param offset index, relative to this view, of the first element of the sub-view
+     * @param length number of elements in the sub-view
+     * @return a view whose delegate has its position on the first element and
+     *         its limit just past the last element of the sub-view
+     * @throws IndexOutOfBoundsException if the requested range does not lie within this view
+     */
+    SliceableIntBuffer slice(int offset, int length) {
+        // Compared this way round so that offset + length cannot overflow.
+        if (offset < 0 || length < 0 || offset > capacity - length) {
+            throw new IndexOutOfBoundsException(
+                    "offset " + offset + ", length " + length + " out of bounds for capacity " + capacity);
+        }
+        // Wrapping is equivalent to the slice operation so long as the offsets are
+        // tracked. A wrapped buffer is indexed by array position, so the delegate's
+        // own array offset has to be folded in when translating to that frame.
+        int arrayStart = delegate.arrayOffset() + this.offset + offset;
+        IntBuffer window = IntBuffer.wrap(delegate.array(), arrayStart, length);
+        return new SliceableIntBuffer(window, arrayStart, length);
+    }
+
+    /**
+     * @return the number of elements in this view
+     */
+    int capacity() {
+        return capacity;
+    }
+
+    /**
+     * Writes a value at a position relative to the start of this view.
+     *
+     * @param index position within this view
+     * @param i     value to store
+     * @return this view
+     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, capacity)}
+     */
+    SliceableIntBuffer put(int index, int i) {
+        delegate.put(toDelegateIndex(index), i);
+        return this;
+    }
+
+    /**
+     * Reads the value at a position relative to the start of this view.
+     *
+     * @param index position within this view
+     * @return the stored value
+     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, capacity)}
+     */
+    int get(int index) {
+        return delegate.get(toDelegateIndex(index));
+    }
+
+    /**
+     * Restores the delegate's position and limit to the bounds of this view.
+     */
+    void clear() {
+        // IntBuffer.clear() sets the position to 0 and the limit to the
+        // delegate's capacity, which is not the state wrap() produced for a
+        // slice, so the view's own bounds are re-applied afterwards.
+        delegate.clear();
+        delegate.limit(offset + capacity);
+        delegate.position(offset);
+    }
+
+    /**
+     * Translates a view-relative index into the delegate's frame of reference.
+     * Without the check, a negative index would silently reach elements that
+     * precede a slice instead of failing.
+     */
+    private int toDelegateIndex(int index) {
+        if (index < 0 || index >= capacity) {
+            throw new IndexOutOfBoundsException(
+                    "index " + index + " out of bounds for capacity " + capacity);
+        }
+        return offset + index;
+    }
+
+}