[OpenAI] Support logit_bias in chat completion #331

Merged
merged 1 commit on Mar 13, 2024
13 changes: 10 additions & 3 deletions examples/openai-api/src/openai_api.ts
@@ -36,7 +36,14 @@ async function mainNonStreaming() {
],
n: 3,
temperature: 1.5,
max_gen_len: 25,
max_gen_len: 50,
// 13813 is "Florida", 10319 is "Texas", and 7660 is "Washington" in Llama-2-7b-chat
// So we would have a higher chance of seeing the latter two, but never the first in the answer
logit_bias: {
"13813": -100,
"10319": 5,
"7660": 5,
}
};

const reply0 = await chat.chatCompletion(request);
@@ -127,6 +134,6 @@ async function mainStateful() {
}

// Run one of the functions
// mainNonStreaming();
mainNonStreaming();
// mainStreaming();
mainStateful();
// mainStateful();
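For reference, here is a minimal end-to-end sketch of passing `logit_bias` through the OpenAI-style API touched by this PR. `ChatModule` and `chatCompletion` appear in this diff; the `reload` call, the `ChatCompletionRequest` type name, the response's `choices` shape, and the model ID string are assumptions for illustration and may differ from the actual package surface.

```typescript
import * as webllm from "@mlc-ai/web-llm";

async function biasedCompletion() {
  // Assumed initialization flow; the model ID below is a placeholder.
  const chat = new webllm.ChatModule();
  await chat.reload("Llama-2-7b-chat-hf-q4f32_1");

  const request: webllm.ChatCompletionRequest = {
    messages: [{ role: "user", content: "Name a US state." }],
    max_gen_len: 50,
    // Token IDs are model-specific; these follow the Llama-2-7b-chat example
    // above: 13813 = "Florida", 10319 = "Texas", 7660 = "Washington".
    logit_bias: {
      "13813": -100, // effectively ban "Florida"
      "10319": 5,    // nudge toward "Texas"
      "7660": 5,     // nudge toward "Washington"
    },
  };

  const reply = await chat.chatCompletion(request);
  console.log(reply.choices[0].message.content);
}
```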
1 change: 1 addition & 0 deletions src/chat_module.ts
@@ -285,6 +285,7 @@ export class ChatModule implements ChatInterface {
stop: request.stop,
top_p: request.top_p,
temperature: request.temperature,
logit_bias: request.logit_bias,
}

// 1. If request is streaming, return an AsyncIterable (an iterable version of `generate()`)
17 changes: 17 additions & 0 deletions src/config.ts
@@ -68,6 +68,7 @@ export interface GenerationConfig {
presence_penalty?: number | null;
stop?: string | null | Array<string>;
n?: number | null;
logit_bias?: Record<string, number> | null;
}

export function postInitAndCheckGenerationConfigValues(config: GenerationConfig): void {
@@ -107,6 +108,22 @@ export function postInitAndCheckGenerationConfigValues(config: GenerationConfig)
config.frequency_penalty = 0.0;
console.log("Only presence_penalty is set; we default frequency_penalty to 0.")
}
// Check logit_bias range
if (_hasValue(config.logit_bias)) {
for (const tokenID in config.logit_bias) {
const bias = config.logit_bias[tokenID];
if (bias > 100 || bias < -100) {
throw new Error(
"logit_bias should be in range [-100, 100]; got " + bias + "for tokenID " + tokenID
);
}
if (isNaN(parseInt(tokenID))) {
throw new Error(
"Expect logit_bias's keys to be number represented in string; got " + tokenID
);
}
}
}
}

/**
24 changes: 20 additions & 4 deletions src/llm_chat.ts
@@ -615,7 +615,8 @@ export class LLMChatPipeline {
logitsOnGPU: tvmjs.NDArray,
genConfig?: GenerationConfig,
) {
// 0. Get value of temperature, top_p, and reptition_penalty, possibly overridden by genConfig
// 0. Get value of temperature, top_p, and various penalties, possibly overridden by genConfig
// Also load other genConfig items like logit_bias. Consume all fields of `genConfig` here.
function _hasValue(value: any): boolean {
return value !== undefined && value !== null;
}
@@ -624,6 +625,7 @@
let repetition_penalty = this.config.repetition_penalty;
let frequency_penalty = undefined;
let presence_penalty = undefined;
let logit_bias = undefined;
if (genConfig !== undefined) {
if (_hasValue(genConfig.temperature)) { temperature = genConfig.temperature!; }
if (_hasValue(genConfig.top_p)) { top_p = genConfig.top_p!; }
@@ -633,6 +635,7 @@
// If only one of frequency or presence penalty is set, make the other one 0.0
if (_hasValue(frequency_penalty) && !_hasValue(presence_penalty)) { presence_penalty = 0.0; }
if (_hasValue(presence_penalty) && !_hasValue(frequency_penalty)) { frequency_penalty = 0.0; }
if (_hasValue(genConfig.logit_bias)) { logit_bias = genConfig.logit_bias; }
}
// Check range validity
if (top_p <= 0 || top_p >= 1) { throw new Error("Make sure 0 < `top_p` < 1."); }
@@ -655,10 +658,23 @@
throw Error("logits should be assigned");
}

// 2. Post process logits
if (this.logitProcessor !== undefined) {
// 2. Post process logits via logitProcessor and/or logit_bias
if (this.logitProcessor !== undefined || _hasValue(logit_bias)) {
let logitsOnCPUArray: Float32Array = <Float32Array>(this.logitsOnCPU.toArray());
logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray);
const vocab_size = logitsOnCPUArray.length;
if (this.logitProcessor !== undefined) {
logitsOnCPUArray = this.logitProcessor.processLogits(logitsOnCPUArray);
}
if (_hasValue(logit_bias)) {
for (const tokenID in logit_bias) {
const curBias = logit_bias[tokenID];
const curTokenID = parseInt(tokenID);
if (curTokenID >= vocab_size) {
throw Error("Token " + curTokenID + " in logit_bias exceeds vocab_size " + vocab_size);
}
logitsOnCPUArray[curTokenID] += curBias;
}
}
this.logitsOnCPU.copyFrom(logitsOnCPUArray);
}

29 changes: 20 additions & 9 deletions src/openai_api_protocols/chat_completion.ts
@@ -99,6 +99,26 @@ export interface ChatCompletionRequestBase {
*/
top_p?: number | null;

/**
* Modify the likelihood of specified tokens appearing in the completion.
*
* Accepts a JSON object that maps tokens (specified by their token ID, which varies per model)
* to an associated bias value from -100 to 100. Typically, you can check the model's
* `tokenizer.json` to see which token ID maps to which string. Mathematically, the bias is added to the
* logits generated by the model prior to sampling. The exact effect will vary per model, but
* values between -1 and 1 should decrease or increase likelihood of selection; values like -100
* or 100 should result in a ban or exclusive selection of the relevant token.
*
* As an example, you can pass `{"16230": -100}` to prevent the `Hello` token from being
* generated in Mistral-7B-Instruct-v0.2, according to the mapping in
* https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/raw/main/tokenizer.json.
*
* @note For stateful and customizable / flexible logit processing, see `webllm.LogitProcessor`.
* @note If used in combination with `webllm.LogitProcessor`, `logit_bias` is applied after
* `LogitProcessor.processLogits()` is called.
*/
logit_bias?: Record<string, number> | null;

//////////////// BELOW FIELDS NOT SUPPORTED YET ////////////////

/**
@@ -108,14 +128,6 @@
*/
model?: string | null;

/**
*
* Modify the likelihood of specified tokens appearing in the completion.
*
* @note Not supported, see `webllm.LogitProcessor` instead.
*/
logit_bias?: Record<string, number> | null;

/**
* Whether to return log probabilities of the output tokens or not.
*
@@ -294,7 +306,6 @@ export interface ChatCompletionChunk {

export const ChatCompletionRequestUnsupportedFields: Array<string> = [
"model",
"logit_bias",
"logprobs",
"tool_choice",
"tools",
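To make the doc comment above concrete, here is a small standalone sketch (independent of WebLLM) of what an additive logit bias does to sampling probabilities: adding +5 multiplies a token's unnormalized weight by e^5 ≈ 148, while -100 drives it to effectively zero.

```typescript
// Standalone illustration: bias is added to logits before softmax/sampling.
function softmax(logits: number[]): number[] {
  const maxLogit = Math.max(...logits);
  const exps = logits.map((l) => Math.exp(l - maxLogit));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

const logits = [2.0, 2.0, 2.0];                          // three equally likely tokens
const bias: Record<number, number> = { 0: -100, 1: 5 };  // ban token 0, boost token 1
const biased = logits.map((l, i) => l + (bias[i] ?? 0)); // [-98, 7, 2]

console.log(softmax(logits)); // ~[0.333, 0.333, 0.333]
console.log(softmax(biased)); // ~[0.000, 0.993, 0.007]
```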
3 changes: 2 additions & 1 deletion src/types.ts
@@ -30,7 +30,8 @@ export type GenerateProgressCallback = (step: number, currentMessage: string) =>

/**
* A stateful logitProcessor used to post-process logits after forwarding the input and before
* sampling the next token.
* sampling the next token. If used with `GenerationConfig.logit_bias`, logit_bias is applied after
* `processLogits()` is called.
*/
export interface LogitProcessor {
/**
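The ordering noted above — `processLogits()` first, then `logit_bias` — can be sketched as a standalone function. This mirrors the `LLMChatPipeline` logic in this diff but is a simplified illustration, not the actual pipeline code; the vocab size and the halving processor below are placeholders.

```typescript
// Sketch of the post-processing order: (1) a custom logit processor transforms
// the logits, (2) logit_bias entries from GenerationConfig are added on top.
type ProcessLogitsFn = (logits: Float32Array) => Float32Array;

function postProcessLogits(
  logits: Float32Array,
  processLogits?: ProcessLogitsFn,         // stands in for LogitProcessor.processLogits
  logitBias?: Record<string, number>,      // stands in for GenerationConfig.logit_bias
): Float32Array {
  const out = processLogits ? processLogits(logits) : logits;
  if (logitBias) {
    for (const tokenID in logitBias) {
      const id = parseInt(tokenID);
      if (id >= out.length) {
        throw new Error("Token " + id + " in logit_bias exceeds vocab_size " + out.length);
      }
      out[id] += logitBias[tokenID];
    }
  }
  return out;
}

// Usage: halve every logit, then add a bias of +5 to token 7660.
const result = postProcessLogits(
  new Float32Array(32000),                 // placeholder vocab size
  (l) => l.map((x) => x / 2),
  { "7660": 5 },
);
console.log(result[7660]); // 5
```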
24 changes: 24 additions & 0 deletions tests/generation_config.test.ts
@@ -10,6 +10,30 @@ describe('Check generation config illegal values', () => {
postInitAndCheckGenerationConfigValues(genConfig)
}).toThrow("`max_gen_len` should be greater than zero.");
});

test('logit_bias exceeds range', () => {
expect(() => {
const genConfig: GenerationConfig = {
max_gen_len: 10,
logit_bias: {
"1355": 155
}
};
postInitAndCheckGenerationConfigValues(genConfig)
}).toThrow("logit_bias should be in range [-100, 100];");
});

test('logit_bias invalid key', () => {
expect(() => {
const genConfig: GenerationConfig = {
max_gen_len: 10,
logit_bias: {
"thisRaisesError": 50
}
};
postInitAndCheckGenerationConfigValues(genConfig)
}).toThrow("Expect logit_bias's keys to be number represented in string");
});
});

describe('Check generation post init', () => {