Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>13</source>
<target>13</target>
<source>11</source>
<target>11</target>
</configuration>
</plugin>
</plugins>
Expand Down
34 changes: 17 additions & 17 deletions src/main/java/de/kherud/llama/LlamaModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ public class LlamaModel implements AutoCloseable {

// cache some things for performance
private final Pointer logitsPointer; // pointer to retrieve the logits of the llm
private final IntBuffer contextBuffer; // used to hold all tokens of a conversation
private final IntBuffer tokenBuffer; // used for tokenization
private final SliceableIntBuffer contextBuffer; // used to hold all tokens of a conversation
private final SliceableIntBuffer tokenBuffer; // used for tokenization
private final byte[] tokenPieceBuffer; // used to decode tokens to string
private final llama_token_data.ByReference[] candidateData; // candidates used for sampling
private final llama_token_data_array candidates; // array holding the candidates
Expand Down Expand Up @@ -88,8 +88,8 @@ public LlamaModel(String filePath, Parameters params) {
}

// setup some cached variables used throughout lifecycle
contextBuffer = IntBuffer.allocate(params.ctx.n_ctx);
tokenBuffer = IntBuffer.allocate(params.ctx.n_ctx);
contextBuffer = new SliceableIntBuffer(IntBuffer.allocate(params.ctx.n_ctx));
tokenBuffer = new SliceableIntBuffer(IntBuffer.allocate(params.ctx.n_ctx));
tokenPieceBuffer = new byte[64];
nVocab = getVocabularySize();

Expand Down Expand Up @@ -148,7 +148,7 @@ public float[] getEmbedding(String prompt) {
if (params.ctx.embedding == 0) {
throw new IllegalStateException("embedding mode not activated (see parameters)");
}
IntBuffer tokens = tokenize(prompt, false);
SliceableIntBuffer tokens = tokenize(prompt, false);
addContext(tokens);
evaluate();
return LlamaLibrary.llama_get_embeddings(ctx).getPointer().getFloatArray(0, getEmbeddingSize());
Expand All @@ -171,9 +171,9 @@ public void reset() {
* @return an array of integers each representing a token id (see {@link #getVocabularySize()})
*/
public int[] encode(String prompt) {
IntBuffer buffer = tokenize(prompt, false);
SliceableIntBuffer buffer = tokenize(prompt, false);
int[] tokens = new int[buffer.capacity()];
System.arraycopy(buffer.array(), 0, tokens, 0, buffer.capacity());
System.arraycopy(buffer.delegate.array(), 0, tokens, 0, buffer.capacity());
return tokens;
}

Expand Down Expand Up @@ -310,8 +310,8 @@ public String toString() {
* @return an IntBuffer containing the tokenized prompt without any padding
* @throws RuntimeException if tokenization fails
*/
private IntBuffer tokenize(String prompt, boolean addBos) {
int nTokens = LlamaLibrary.llama_tokenize(ctx, prompt, tokenBuffer, params.ctx.n_ctx, addBos ? (byte) 1 : 0);
private SliceableIntBuffer tokenize(String prompt, boolean addBos) {
int nTokens = LlamaLibrary.llama_tokenize(ctx, prompt, tokenBuffer.delegate, params.ctx.n_ctx, addBos ? (byte) 1 : 0);
if (nTokens < 0) {
throw new RuntimeException("tokenization failed due to unknown reasons");
}
Expand All @@ -324,7 +324,7 @@ private void evaluate() {
if (nEval > params.ctx.n_batch) {
nEval = params.ctx.n_batch;
}
if (LlamaLibrary.llama_eval(ctx, contextBuffer.slice(nPast, nEval), nEval, nPast, params.nThreads) != 0) {
if (LlamaLibrary.llama_eval(ctx, contextBuffer.slice(nPast, nEval).delegate, nEval, nPast, params.nThreads) != 0) {
String msg = String.format("evaluation failed (%d to evaluate, %d past, %d threads)", nEval, nPast, params.nThreads);
log(LogLevel.ERROR, msg);
throw new RuntimeException("token evaluation failed");
Expand Down Expand Up @@ -380,18 +380,18 @@ private void samplePenalty() {
int repeat_last_n = params.repeatLastN < 0 ? params.ctx.n_ctx : params.repeatLastN;
int last_n_repeat = Math.min(Math.min(nContext, repeat_last_n), params.ctx.n_ctx);
NativeSize nTokens = new NativeSize(last_n_repeat);
IntBuffer lastTokens = tokenBuffer.slice(nContext - last_n_repeat, last_n_repeat);
SliceableIntBuffer lastTokens = tokenBuffer.slice(nContext - last_n_repeat, last_n_repeat);
LlamaLibrary.llama_sample_repetition_penalty(
ctx,
candidates,
lastTokens,
lastTokens.delegate,
nTokens,
params.repeatPenalty
);
LlamaLibrary.llama_sample_frequency_and_presence_penalties(
ctx,
candidates,
lastTokens,
lastTokens.delegate,
nTokens,
params.frequencyPenalty,
params.presencePenalty
Expand Down Expand Up @@ -462,9 +462,9 @@ private int sampleTopK() {
return LlamaLibrary.llama_sample_token(ctx, candidates);
}

private void addContext(IntBuffer tokens) {
private void addContext(SliceableIntBuffer tokens) {
truncateContext(tokens.capacity());
System.arraycopy(tokens.array(), 0, contextBuffer.array(), nContext, tokens.capacity());
System.arraycopy(tokens.delegate.array(), 0, contextBuffer.delegate.array(), nContext, tokens.capacity());
nContext += tokens.capacity();
}

Expand All @@ -473,7 +473,7 @@ private void truncateContext(int nAdd) {
int nCtxKeep = params.ctx.n_ctx / 2 - nAdd;
String msg = "truncating context from " + nContext + " to " + nCtxKeep + " tokens (+" + nAdd + " to add)";
log(LogLevel.INFO, msg);
System.arraycopy(contextBuffer.array(), nContext - nCtxKeep, contextBuffer.array(), 0, nCtxKeep);
System.arraycopy(contextBuffer.delegate.array(), nContext - nCtxKeep, contextBuffer.delegate.array(), 0, nCtxKeep);
nPast = 0;
nContext = nCtxKeep;
}
Expand Down Expand Up @@ -532,7 +532,7 @@ private void setup(String prompt) {
if (nContext == 0 && !prompt.startsWith(" ")) {
prompt = " " + prompt;
}
IntBuffer tokens = tokenize(prompt, true);
SliceableIntBuffer tokens = tokenize(prompt, true);
addContext(tokens);
}

Expand Down
76 changes: 76 additions & 0 deletions src/main/java/de/kherud/llama/SliceableIntBuffer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package de.kherud.llama;

import java.nio.IntBuffer;


/**
 * Container that allows slicing an {@link IntBuffer} with arbitrary offsets
 * and lengths on Java versions older than 13, where
 * {@code Buffer.slice(int, int)} does not exist.
 * <p>
 * Does not extend {@link IntBuffer} because the super constructor requires
 * memory segment proxies and we can't access the delegate's. Does not
 * implement {@link java.nio.Buffer} because the
 * {@link java.nio.Buffer#slice(int, int)} method is specifically blocked from
 * being implemented or used on older JDK versions.
 * <p>
 * A slice is a view over the same backing array: writes through a slice are
 * visible through the parent and vice versa.
 */
class SliceableIntBuffer {

    // The wrapped buffer; package-visible so callers can pass the raw
    // IntBuffer to native library functions.
    final IntBuffer delegate;

    // First index of this view inside the delegate's backing array.
    private final int offset;

    // Number of ints visible through this view.
    private final int capacity;

    /**
     * Wraps an entire buffer; the view spans the delegate's full capacity.
     *
     * @param delegate the buffer to wrap (must be array-backed)
     */
    SliceableIntBuffer(IntBuffer delegate) {
        this(delegate, 0, delegate.capacity());
    }

    /**
     * Creates a view of {@code capacity} ints starting at {@code offset}
     * within the delegate's backing array.
     *
     * @param delegate the buffer to wrap (must be array-backed)
     * @param offset   start of the view in the backing array
     * @param capacity number of ints visible through the view
     */
    SliceableIntBuffer(IntBuffer delegate, int offset, int capacity) {
        this.delegate = delegate;
        this.offset = offset;
        this.capacity = capacity;
    }

    /**
     * Returns a new view of {@code length} ints starting at {@code sliceOffset}
     * relative to this view.
     *
     * @throws IndexOutOfBoundsException if the requested range does not fit
     *                                   within this view's capacity
     */
    SliceableIntBuffer slice(int sliceOffset, int length) {
        if (sliceOffset < 0 || length < 0 || sliceOffset + length > capacity) {
            throw new IndexOutOfBoundsException(
                    "slice(" + sliceOffset + ", " + length + ") out of bounds for capacity " + capacity);
        }
        // Where the magic happens:
        // wrapping the shared backing array is equivalent to the slice
        // operation, so long as offsets and capacities are translated into the
        // backing array's frame of reference. This container tracks those
        // offsets and performs the translation.
        return new SliceableIntBuffer(
                IntBuffer.wrap(delegate.array(), offset + sliceOffset, length),
                offset + sliceOffset,
                length
        );
    }

    /** Returns the number of ints visible through this view. */
    int capacity() {
        return capacity;
    }

    /**
     * Writes {@code i} at the given index relative to this view.
     *
     * @throws IndexOutOfBoundsException if {@code index} is outside the view
     */
    SliceableIntBuffer put(int index, int i) {
        delegate.put(checkIndex(index), i);
        return this;
    }

    /**
     * Reads the int at the given index relative to this view.
     *
     * @throws IndexOutOfBoundsException if {@code index} is outside the view
     */
    int get(int index) {
        return delegate.get(checkIndex(index));
    }

    /**
     * Resets the delegate's position and limit to the bounds of this view.
     * A plain {@link IntBuffer#clear()} sets the position and limit to 0 and
     * the backing capacity respectively, but that is not the state the buffer
     * had after the {@link IntBuffer#wrap(int[], int, int)} call, so both are
     * restored manually afterwards.
     */
    void clear() {
        delegate.clear();
        delegate.limit(offset + capacity);
        delegate.position(offset);
    }

    // Validates a view-relative index and translates it to an absolute index
    // in the backing array. The delegate's own bounds checks are insufficient
    // for slices, because its capacity spans the whole backing array.
    private int checkIndex(int index) {
        if (index < 0 || index >= capacity) {
            throw new IndexOutOfBoundsException(
                    "index " + index + " out of bounds for capacity " + capacity);
        }
        return offset + index;
    }

}