Skip to content

Commit

Permalink
Optimize code
Browse files — browse the repository at this point in the history
  • Loading branch information
eoctet committed Feb 24, 2024
1 parent a27eecd commit eed4c31
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 46 deletions.
23 changes: 1 addition & 22 deletions llama-java-core/src/main/java/chat/octet/model/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -308,28 +308,7 @@ public Token next() {
int startIndex = Math.max(0, status.getInputLength() - generateParams.getLastTokensSize());
lastTokens = status.subInputIds(startIndex);
}
int tokenId = LlamaService.sampling(
logits,
lastTokens,
generateParams.getLastTokensSize(),
generateParams.getRepeatPenalty(),
generateParams.getFrequencyPenalty(),
generateParams.getPresencePenalty(),
generateParams.isPenalizeNl(),
generateParams.getMirostatMode().ordinal(),
generateParams.getMirostatTAU(),
generateParams.getMirostatETA(),
generateParams.getTemperature(),
generateParams.getTopK(),
generateParams.getTopP(),
generateParams.getTsf(),
generateParams.getTypical(),
generateParams.getMinP(),
generateParams.getDynatempRange(),
generateParams.getDynatempExponent(),
status.getId(),
status.getPastTokenSize()
);
int tokenId = LlamaService.sampling(generateParams, logits, lastTokens, status.getId(), status.getPastTokenSize());
Token token = new Token(tokenId, LlamaService.getLlamaTokenType(tokenId), tokenToText(tokenId));
//update generate status
status.appendNextToken(token);
Expand Down
57 changes: 37 additions & 20 deletions llama-java-core/src/main/java/chat/octet/model/LlamaService.java
Original file line number Diff line number Diff line change
Expand Up @@ -247,26 +247,43 @@ public class LlamaService {
* @return int, Returns the sampled token id.
* @see GenerateParameter
*/
public static native int sampling(float[] logits,
int[] lastTokens,
int lastTokensSize,
float penalty,
float alphaFrequency,
float alphaPresence,
boolean penalizeNL,
int mirostatMode,
float mirostatTAU,
float mirostatETA,
float temperature,
int topK,
float topP,
float tsf,
float typical,
float minP,
float dynatempRange,
float dynatempExponent,
int sequenceId,
int pastTokenSize) throws DecodeException;
public static native int sampling(
        float[] logits,           // user-defined logits; adjustable via LogitsProcessor
        int[] lastTokens,         // most recent token ids (window for repetition penalties)
        int lastTokensSize,
        float penalty,            // repeat penalty
        float alphaFrequency,     // frequency penalty
        float alphaPresence,      // presence penalty
        boolean penalizeNL,
        int mirostatMode,         // ordinal of GenerateParameter.MirostatMode
        float mirostatTAU,
        float mirostatETA,
        float temperature,
        int topK,
        float topP,
        float tsf,                // tail-free sampling parameter
        float typical,            // typical sampling parameter
        float minP,
        float dynatempRange,      // dynamic temperature range
        float dynatempExponent,   // dynamic temperature exponent
        int sequenceId,           // generation sequence id
        int pastTokenSize) throws DecodeException;

/**
 * Samples the next token during inference.
 * <p>
 * Convenience overload that flattens the fields of {@link GenerateParameter}
 * and delegates to the native {@code sampling(...)} entry point, so callers
 * do not have to pass the full 20-argument parameter list themselves.
 *
 * @param generateParams Generation parameters (penalties, mirostat, temperature, top-k/top-p, etc.).
 * @param logits         User-defined logits; adjustments can be made via LogitsProcessor.
 * @param lastTokens     Last token array (window used for repetition penalties).
 * @param sequenceId     Generation sequence id.
 * @param pastTokenSize  Past token size.
 * @return int, Returns the sampled token id.
 * @throws DecodeException if the native decode/sampling step fails.
 * @see GenerateParameter
 */
public static int sampling(GenerateParameter generateParams, float[] logits, int[] lastTokens, int sequenceId, int pastTokenSize) throws DecodeException {
// NOTE: argument order below must match the native declaration exactly —
// the JNI call is positional, so a reordering here would be a silent bug.
return sampling(
logits,
lastTokens,
generateParams.getLastTokensSize(),
generateParams.getRepeatPenalty(),
generateParams.getFrequencyPenalty(),
generateParams.getPresencePenalty(),
generateParams.isPenalizeNl(),
// native side expects the mirostat mode as its enum ordinal
generateParams.getMirostatMode().ordinal(),
generateParams.getMirostatTAU(),
generateParams.getMirostatETA(),
generateParams.getTemperature(),
generateParams.getTopK(),
generateParams.getTopP(),
generateParams.getTsf(),
generateParams.getTypical(),
generateParams.getMinP(),
generateParams.getDynatempRange(),
generateParams.getDynatempExponent(),
sequenceId,
pastTokenSize
);
}

/**
* Load llama grammar by rules.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;


public class StoppingWordCriteria implements StoppingCriteria {
Expand Down Expand Up @@ -42,9 +42,8 @@ public boolean criteria(@Nullable int[] inputTokenIds, @Nonnull float[] scores,
if (length > generateTokens.size()) {
continue;
}
List<Token> lastTokens = generateTokens.subList(generateTokens.size() - length, generateTokens.size());
int matched = (int) IntStream.range(0, length).filter(i -> tokens[i] == lastTokens.get(i).getId()).count();
if (matched == length) {
int[] lastTokens = generateTokens.subList(generateTokens.size() - length, generateTokens.size()).stream().mapToInt(Token::getId).toArray();
if (Arrays.equals(tokens, lastTokens)) {
return true;
}
}
Expand Down

0 comments on commit eed4c31

Please sign in to comment.