
Commit 050fe6e

feat: add model settings and prompt template from hf

Signed-off-by: James <namnh0122@gmail.com>
1 parent 0ae2c27

File tree

10 files changed: +379 -7 lines changed

cortex-js/src/command.module.ts

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ import { ModelRemoveCommand } from './infrastructure/commanders/models/model-rem
 import { RunCommand } from './infrastructure/commanders/shortcuts/run.command';
 import { InitCudaQuestions } from './infrastructure/commanders/questions/cuda.questions';
 import { CliUsecasesModule } from './infrastructure/commanders/usecases/cli.usecases.module';
+import { ModelUpdateCommand } from './infrastructure/commanders/models/model-update.command';

 @Module({
   imports: [
@@ -55,6 +56,7 @@ import { CliUsecasesModule } from './infrastructure/commanders/usecases/cli.usec
     ModelGetCommand,
     ModelRemoveCommand,
     ModelPullCommand,
+    ModelUpdateCommand,

     // Shortcuts
     RunCommand,
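
Note: nest-commander resolves commands through Nest's dependency injection, so the new ModelUpdateCommand has to be registered in this module's provider list as well as in the `subCommands` of the models command (see models.command.ts below) before `cortex models update` can resolve.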

cortex-js/src/infrastructure/commanders/chat.command.ts

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ export class ChatCommand extends CommandRunner {
   }

   @Option({
-    flags: '--model <model_id>',
+    flags: '-m, --model <model_id>',
     description: 'Model Id to start chat with',
   })
   parseModelId(value: string) {

cortex-js/src/infrastructure/commanders/models.command.ts

Lines changed: 2 additions & 0 deletions

@@ -5,6 +5,7 @@ import { ModelListCommand } from './models/model-list.command';
 import { ModelStopCommand } from './models/model-stop.command';
 import { ModelPullCommand } from './models/model-pull.command';
 import { ModelRemoveCommand } from './models/model-remove.command';
+import { ModelUpdateCommand } from './models/model-update.command';

 @SubCommand({
   name: 'models',
@@ -15,6 +16,7 @@ import { ModelRemoveCommand } from './models/model-remove.command';
     ModelListCommand,
     ModelGetCommand,
     ModelRemoveCommand,
+    ModelUpdateCommand,
   ],
   description: 'Subcommands for managing models',
 })

cortex-js/src/infrastructure/commanders/models/model-update.command.ts

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+import { CommandRunner, SubCommand, Option } from 'nest-commander';
+import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
+import { exit } from 'node:process';
+import { ModelParameterParser } from '../utils/model-parameter.parser';
+import {
+  ModelRuntimeParams,
+  ModelSettingParams,
+} from '@/domain/models/model.interface';
+
+type UpdateOptions = {
+  model?: string;
+  options?: string[];
+};
+
+@SubCommand({ name: 'update', description: 'Update configuration of a model.' })
+export class ModelUpdateCommand extends CommandRunner {
+  constructor(private readonly modelsCliUsecases: ModelsCliUsecases) {
+    super();
+  }
+
+  async run(_input: string[], option: UpdateOptions): Promise<void> {
+    const modelId = option.model;
+    if (!modelId) {
+      console.error('Model Id is required');
+      exit(1);
+    }
+
+    const options = option.options;
+    if (!options || options.length === 0) {
+      console.log('Nothing to update');
+      exit(0);
+    }
+
+    const parser = new ModelParameterParser();
+    const settingParams: ModelSettingParams = {};
+    const runtimeParams: ModelRuntimeParams = {};
+
+    options.forEach((option) => {
+      const [key, stringValue] = option.split('=');
+      if (parser.isModelSettingParam(key)) {
+        const value = parser.parse(key, stringValue);
+        // @ts-expect-error did the check so it's safe
+        settingParams[key] = value;
+      } else if (parser.isModelRuntimeParam(key)) {
+        const value = parser.parse(key, stringValue);
+        // @ts-expect-error did the check so it's safe
+        runtimeParams[key] = value;
+      }
+    });
+
+    if (Object.keys(settingParams).length > 0) {
+      const updatedSettingParams =
+        await this.modelsCliUsecases.updateModelSettingParams(
+          modelId,
+          settingParams,
+        );
+      console.log(
+        'Updated setting params! New setting params:',
+        updatedSettingParams,
+      );
+    }
+
+    if (Object.keys(runtimeParams).length > 0) {
+      await this.modelsCliUsecases.updateModelRuntimeParams(
+        modelId,
+        runtimeParams,
+      );
+      console.log('Updated runtime params! New runtime params:', runtimeParams);
+    }
+  }
+
+  @Option({
+    flags: '-m, --model <model_id>',
+    required: true,
+    description: 'Model Id to update',
+  })
+  parseModelId(value: string) {
+    return value;
+  }
+
+  @Option({
+    flags: '-c, --options <options...>',
+    description:
+      'Specify the options to update the model. Syntax: -c option1=value1 option2=value2. For example: cortex models update -c max_tokens=100 temperature=0.5',
+  })
+  parseOptions(option: string, optionsAccumulator: string[] = []): string[] {
+    optionsAccumulator.push(option);
+    return optionsAccumulator;
+  }
+}
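
The `-c` flag is variadic, and nest-commander invokes `parseOptions` once per value, so `-c max_tokens=100 temperature=0.5` arrives in `run` as `['max_tokens=100', 'temperature=0.5']`. A minimal standalone sketch of the routing above, with the parser stubbed out (the key sets below are illustrative assumptions; the real classification lives in ModelParameterParser):

// Standalone sketch of the "-c key=value" routing in ModelUpdateCommand.
// SETTING_KEYS / RUNTIME_KEYS are assumed for illustration only.
const SETTING_KEYS = new Set(['prompt_template', 'ctx_len']); // assumed
const RUNTIME_KEYS = new Set(['temperature', 'max_tokens']); // assumed

function route(pairs: string[]) {
  const settingParams: Record<string, string | number> = {};
  const runtimeParams: Record<string, string | number> = {};
  for (const pair of pairs) {
    // Mirrors the split above; a value containing '=' would be truncated.
    const [key, raw = ''] = pair.split('=');
    const value = Number.isNaN(Number(raw)) ? raw : Number(raw);
    if (SETTING_KEYS.has(key)) settingParams[key] = value;
    else if (RUNTIME_KEYS.has(key)) runtimeParams[key] = value;
    // Keys matching neither set are dropped silently, as in the command.
  }
  return { settingParams, runtimeParams };
}

// cortex models update -m <model_id> -c max_tokens=100 temperature=0.5
console.log(route(['max_tokens=100', 'temperature=0.5']));
// -> { settingParams: {}, runtimeParams: { max_tokens: 100, temperature: 0.5 } }
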
cortex-js/src/infrastructure/commanders/prompt-constants.ts

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+//// HF Chat template
+export const OPEN_CHAT_3_5_JINJA = ``;
+
+export const ZEPHYR_JINJA = `{% for message in messages %}
+{% if message['role'] == 'user' %}
+{{ '<|user|>
+' + message['content'] + eos_token }}
+{% elif message['role'] == 'system' %}
+{{ '<|system|>
+' + message['content'] + eos_token }}
+{% elif message['role'] == 'assistant' %}
+{{ '<|assistant|>
+' + message['content'] + eos_token }}
+{% endif %}
+{% if loop.last and add_generation_prompt %}
+{{ '<|assistant|>' }}
+{% endif %}
+{% endfor %}`;
+
+//// Corresponding prompt template
+export const OPEN_CHAT_3_5 = `GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:`;
+
+export const ZEPHYR = `<|system|>
+{system_message}</s>
+<|user|>
+{prompt}</s>
+<|assistant|>
+`;
+
+export const COMMAND_R = `<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>{prompt}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{response}
+`;
+
+// taken from https://huggingface.co/TheBloke/Llama-2-70B-Chat-GGUF
+export const LLAMA_2 = `[INST] <<SYS>>
+{system_message}
+<</SYS>>
+{prompt}[/INST]`;
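
These prompt templates use `{placeholder}` slots (`{system_message}` and `{prompt}`; COMMAND_R also uses `{system}` and `{response}`). A minimal sketch of filling the ZEPHYR template; the `fill` helper here is hypothetical, since the actual substitution happens downstream in the inference engine, not in this file:

import { ZEPHYR } from './prompt-constants';

// Hypothetical helper for illustration: replace {name} slots with values,
// leaving unknown slots untouched.
function fill(template: string, vars: Record<string, string>): string {
  return template.replace(/\{(\w+)\}/g, (slot, name) => vars[name] ?? slot);
}

console.log(
  fill(ZEPHYR, {
    system_message: 'You are a helpful assistant.',
    prompt: 'Hello!',
  }),
);
// <|system|>
// You are a helpful assistant.</s>
// <|user|>
// Hello!</s>
// <|assistant|>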

cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts

Lines changed: 7 additions & 2 deletions

@@ -4,6 +4,7 @@ import { CommandRunner, SubCommand, Option } from 'nest-commander';
 import { exit } from 'node:process';
 import { ChatUsecases } from '@/usecases/chat/chat.usecases';
 import { ChatCliUsecases } from '../usecases/chat.cli.usecases';
+import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';

 type RunOptions = {
   model?: string;
@@ -29,7 +30,11 @@ export class RunCommand extends CommandRunner {
       exit(1);
     }

-    await this.cortexUsecases.startCortex();
+    await this.cortexUsecases.startCortex(
+      defaultCortexCppHost,
+      defaultCortexCppPort,
+      false,
+    );
     await this.modelsUsecases.startModel(modelId);
     const chatCliUsecases = new ChatCliUsecases(
       this.chatUsecases,
@@ -39,7 +44,7 @@
   }

   @Option({
-    flags: '--model <model_id>',
+    flags: '-m, --model <model_id>',
     description: 'Model Id to start chat with',
   })
   parseModelId(value: string) {

cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts

Lines changed: 67 additions & 2 deletions

@@ -1,12 +1,24 @@
 import { exit } from 'node:process';
 import { ModelsUsecases } from '@/usecases/models/models.usecases';
-import { Model, ModelFormat } from '@/domain/models/model.interface';
+import {
+  Model,
+  ModelFormat,
+  ModelRuntimeParams,
+  ModelSettingParams,
+} from '@/domain/models/model.interface';
 import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto';
 import { HuggingFaceRepoData } from '@/domain/models/huggingface.interface';
 import { gguf } from '@huggingface/gguf';
 import { InquirerService } from 'nest-commander';
 import { Inject, Injectable } from '@nestjs/common';
 import { Presets, SingleBar } from 'cli-progress';
+import {
+  LLAMA_2,
+  OPEN_CHAT_3_5,
+  OPEN_CHAT_3_5_JINJA,
+  ZEPHYR,
+  ZEPHYR_JINJA,
+} from '../prompt-constants';

 const AllQuantizations = [
   'Q3_K_S',
@@ -49,6 +61,20 @@ export class ModelsCliUsecases {
     await this.modelsUsecases.stopModel(modelId);
   }

+  async updateModelSettingParams(
+    modelId: string,
+    settingParams: ModelSettingParams,
+  ): Promise<ModelSettingParams> {
+    return this.modelsUsecases.updateModelSettingParams(modelId, settingParams);
+  }
+
+  async updateModelRuntimeParams(
+    modelId: string,
+    runtimeParams: ModelRuntimeParams,
+  ): Promise<ModelRuntimeParams> {
+    return this.modelsUsecases.updateModelRuntimeParams(modelId, runtimeParams);
+  }
+
   private async getModelOrStop(modelId: string): Promise<Model> {
     const model = await this.modelsUsecases.findOne(modelId);
     if (!model) {
@@ -103,10 +129,16 @@
     if (!sibling) throw 'No expected quantization found';

     let stopWord = '';
+    let promptTemplate = LLAMA_2;
+
     try {
       const { metadata } = await gguf(sibling.downloadUrl!);
       // @ts-expect-error "tokenizer.ggml.eos_token_id"
       const index = metadata['tokenizer.ggml.eos_token_id'];
+      // @ts-expect-error "tokenizer.chat_template"
+      const hfChatTemplate = metadata['tokenizer.chat_template'];
+      promptTemplate = this.guessPromptTemplateFromHuggingFace(hfChatTemplate);
+
       // @ts-expect-error "tokenizer.ggml.tokens"
       stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
     } catch (err) {
@@ -129,7 +161,9 @@
       version: '',
       format: ModelFormat.GGUF,
       description: '',
-      settings: {},
+      settings: {
+        prompt_template: promptTemplate,
+      },
       parameters: {
         stop: stopWords,
       },
@@ -144,6 +178,37 @@
     await this.modelsUsecases.create(model);
   }

+  // TODO: move this to somewhere else, should be reused by API as well. Maybe in a separate service / provider?
+  private guessPromptTemplateFromHuggingFace(jinjaCode?: string): string {
+    if (!jinjaCode) {
+      console.log('No jinja code provided. Returning default LLAMA_2');
+      return LLAMA_2;
+    }
+
+    if (typeof jinjaCode !== 'string') {
+      console.log(
+        `Invalid jinja code provided (type is ${typeof jinjaCode}). Returning default LLAMA_2`,
+      );
+      return LLAMA_2;
+    }
+
+    switch (jinjaCode) {
+      case ZEPHYR_JINJA:
+        return ZEPHYR;
+
+      case OPEN_CHAT_3_5_JINJA:
+        return OPEN_CHAT_3_5;
+
+      default:
+        console.log(
+          'Unknown jinja code:',
+          jinjaCode,
+          'Returning default LLAMA_2',
+        );
+        return LLAMA_2;
+    }
+  }
+
   private async fetchHuggingFaceRepoData(repoId: string) {
     const sanitizedUrl = this.toHuggingFaceUrl(repoId);
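
`gguf()` from `@huggingface/gguf` reads only the header of a (possibly remote) GGUF file, which is what makes the template and stop-word probe above cheap. A standalone sketch of that lookup using the same metadata keys; the cast stands in for the `@ts-expect-error` workaround above, and the URL argument is a placeholder:

import { gguf } from '@huggingface/gguf';

// Sketch: pull the chat template and EOS stop word out of GGUF metadata,
// mirroring the probe above. Pass any GGUF download URL.
async function inspectGguf(url: string) {
  const { metadata } = await gguf(url);
  const meta = metadata as Record<string, any>; // available keys are model-dependent
  const chatTemplate: string | undefined = meta['tokenizer.chat_template']; // jinja, may be absent
  const eosId: number = meta['tokenizer.ggml.eos_token_id'];
  const stopWord: string = meta['tokenizer.ggml.tokens']?.[eosId] ?? '';
  console.log({ chatTemplate, stopWord });
}

Note that `guessPromptTemplateFromHuggingFace` matches the extracted jinja against the known constants by exact string equality, so any whitespace or revision difference in a repo's `tokenizer.chat_template` falls through to the LLAMA_2 default. Also, since `OPEN_CHAT_3_5_JINJA` is currently an empty string, an empty template hits the falsy check first, which makes the `OPEN_CHAT_3_5` branch effectively unreachable until that constant is filled in.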
