
Commit d4b8a75: feat: pull model yaml from hf
Parent: 9a115ba

8 files changed: +215, -49 lines

cortex-js/package.json

Lines changed: 3 additions & 1 deletion

@@ -26,6 +26,7 @@
   },
   "dependencies": {
     "@huggingface/gguf": "^0.1.5",
+    "@huggingface/hub": "^0.15.1",
     "@nestjs/axios": "^3.0.2",
     "@nestjs/common": "^10.0.0",
     "@nestjs/config": "^3.2.2",
@@ -47,7 +48,8 @@
     "sqlite": "^5.1.1",
     "sqlite3": "^5.1.7",
     "typeorm": "^0.3.20",
-    "ulid": "^2.3.0"
+    "ulid": "^2.3.0",
+    "yaml": "^2.4.2"
   },
   "devDependencies": {
     "@nestjs/cli": "^10.0.0",
Lines changed: 103 additions & 4 deletions

@@ -1,25 +1,124 @@
-import { CommandRunner, SubCommand } from 'nest-commander';
+import { CommandRunner, InquirerService, SubCommand } from 'nest-commander';
 import { exit } from 'node:process';
 import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
+import { RepoDesignation, listFiles } from '@huggingface/hub';
+import { basename } from 'node:path';
 
 @SubCommand({
   name: 'pull',
   aliases: ['download'],
   description: 'Download a model. Working with HuggingFace model id.',
 })
 export class ModelPullCommand extends CommandRunner {
-  constructor(private readonly modelsCliUsecases: ModelsCliUsecases) {
+  private janHqModelPrefix = 'janhq';
+
+  constructor(
+    private readonly inquirerService: InquirerService,
+    private readonly modelsCliUsecases: ModelsCliUsecases,
+  ) {
     super();
   }
 
   async run(input: string[]) {
     if (input.length < 1) {
-      console.error('Model ID is required');
+      console.error('Model Id is required');
       exit(1);
     }
 
-    await this.modelsCliUsecases.pullModel(input[0]);
+    const branches = await this.tryToGetBranches(input[0]);
+
+    if (!branches) {
+      await this.modelsCliUsecases.pullModel(input[0]);
+    } else {
+      // if there's a metadata.yaml file, we assume it's a JanHQ model
+      await this.handleJanHqModel(input[0], branches);
+    }
+
     console.log('\nDownload complete!');
     exit(0);
   }
+
+  private async tryToGetBranches(input: string): Promise<any> {
+    try {
+      // prepend janhq/ if the input is not already prefixed
+      const sanitizedInput = input.trim().startsWith(this.janHqModelPrefix)
+        ? input
+        : `${this.janHqModelPrefix}/${input}`;
+
+      const repo: RepoDesignation = {
+        type: 'model',
+        name: sanitizedInput,
+      };
+
+      for await (const _fileInfo of listFiles({ repo })) {
+        break;
+      }
+
+      const response = await fetch(
+        `https://huggingface.co/api/models/${sanitizedInput}/refs`,
+      );
+      const data = await response.json();
+      const branches: string[] = data.branches.map((branch: any) => {
+        return branch.name;
+      });
+
+      return branches;
+    } catch (err) {
+      return undefined;
+    }
+  }
+
+  private async versionInquiry(tags: string[]): Promise<string> {
+    const { tag } = await this.inquirerService.inquirer.prompt({
+      type: 'list',
+      name: 'tag',
+      message: 'Select version',
+      choices: tags,
+    });
+
+    return tag;
+  }
+
+  private async handleJanHqModel(repoName: string, branches: string[]) {
+    const sanitizedRepoName = repoName.trim().startsWith(this.janHqModelPrefix)
+      ? repoName
+      : `${this.janHqModelPrefix}/${repoName}`;
+
+    let selectedTag = branches[0];
+
+    if (branches.length > 1) {
+      selectedTag = await this.versionInquiry(branches);
+    }
+
+    const revision = selectedTag;
+    if (!revision) {
+      console.error("Can't find model revision.");
+      exit(1);
+    }
+
+    const repo: RepoDesignation = { type: 'model', name: sanitizedRepoName };
+    let ggufUrl: string | undefined = undefined;
+    let fileSize = 0;
+    for await (const fileInfo of listFiles({
+      repo: repo,
+      revision: revision,
+    })) {
+      if (fileInfo.path.endsWith('.gguf')) {
+        ggufUrl = `https://huggingface.co/${sanitizedRepoName}/resolve/${revision}/${fileInfo.path}`;
+        fileSize = fileInfo.size;
+        break;
+      }
+    }
+
+    if (!ggufUrl) {
+      console.error("Can't find model file.");
+      exit(1);
+    }
+    console.log('Downloading', basename(ggufUrl));
+    await this.modelsCliUsecases.pullModelWithExactUrl(
+      `${sanitizedRepoName}/${revision}`,
+      ggufUrl,
+      fileSize,
+    );
+  }
 }
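The empty for await loop in tryToGetBranches is an existence probe: listFiles() throws for a missing repo, so any failure lands in the catch block and the command falls back to the plain pullModel path. The same pattern in isolation, where repoExists is a hypothetical helper rather than part of this commit:

import { RepoDesignation, listFiles } from '@huggingface/hub';

// listFiles() lazily pages through a repo's files and throws if the
// repo does not exist, so pulling at most one entry is a cheap probe.
async function repoExists(name: string): Promise<boolean> {
  const repo: RepoDesignation = { type: 'model', name };
  try {
    for await (const _ of listFiles({ repo })) break;
    return true;
  } catch {
    return false;
  }
}
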
cortex-js/src/infrastructure/commanders/types/model-tokenizer.interface.ts

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+export interface ModelTokenizer {
+  stopWord?: string;
+  promptTemplate: string;
+}

cortex-js/src/infrastructure/commanders/usecases/init.cli.usecases.ts

Lines changed: 13 additions & 12 deletions

@@ -6,6 +6,7 @@ import decompress from 'decompress';
 import { exit } from 'node:process';
 import { InitOptions } from '../types/init-options.interface';
 import { Injectable } from '@nestjs/common';
+import { firstValueFrom } from 'rxjs';
 
 @Injectable()
 export class InitCliUsecases {
@@ -19,17 +20,17 @@ export class InitCliUsecases {
     engineFileName: string,
     version: string = 'latest',
   ): Promise<any> => {
-    const res = await this.httpService
-      .get(
+    const res = await firstValueFrom(
+      this.httpService.get(
         this.CORTEX_RELEASES_URL + `${version === 'latest' ? '/latest' : ''}`,
         {
           headers: {
             'X-GitHub-Api-Version': '2022-11-28',
             Accept: 'application/vnd.github+json',
           },
         },
-      )
-      .toPromise();
+      ),
+    );
 
     if (!res?.data) {
       console.log('Failed to fetch releases');
@@ -55,11 +56,11 @@ export class InitCliUsecases {
     const engineDir = resolve(this.rootDir(), 'cortex-cpp');
     if (existsSync(engineDir)) rmSync(engineDir, { recursive: true });
 
-    const download = await this.httpService
-      .get(toDownloadAsset.browser_download_url, {
+    const download = await firstValueFrom(
+      this.httpService.get(toDownloadAsset.browser_download_url, {
         responseType: 'stream',
-      })
-      .toPromise();
+      }),
+    );
     if (!download) {
       console.log('Failed to download model');
       process.exit(1);
@@ -183,11 +184,11 @@
     ).replace('<platform>', platform);
     const destination = resolve(this.rootDir(), 'cuda-toolkit.tar.gz');
 
-    const download = await this.httpService
-      .get(url, {
+    const download = await firstValueFrom(
+      this.httpService.get(url, {
         responseType: 'stream',
-      })
-      .toPromise();
+      }),
+    );
 
     if (!download) {
       console.log('Failed to download dependency');
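All three hunks in this file are the same mechanical migration: Observable.toPromise() has been deprecated since RxJS 7 and is removed in RxJS 8, with firstValueFrom as the replacement for take-the-first-value usage. A minimal sketch of the pattern, using a hypothetical getJson helper:

import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';

// firstValueFrom resolves with the first emission and rejects if the
// source errors or completes empty; toPromise() instead resolved to
// undefined on empty completion, silently.
async function getJson(http: HttpService, url: string): Promise<unknown> {
  const res = await firstValueFrom(http.get(url));
  return res.data;
}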

cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts

Lines changed: 74 additions & 21 deletions

@@ -21,6 +21,9 @@ import {
   ZEPHYR,
   ZEPHYR_JINJA,
 } from '../prompt-constants';
+import { ModelTokenizer } from '../types/model-tokenizer.interface';
+import { HttpService } from '@nestjs/axios';
+import { firstValueFrom } from 'rxjs';
 
 const AllQuantizations = [
   'Q3_K_S',
@@ -51,6 +54,7 @@ export class ModelsCliUsecases {
     private readonly modelsUsecases: ModelsUsecases,
     @Inject(InquirerService)
     private readonly inquirerService: InquirerService,
+    private readonly httpService: HttpService,
   ) {}
 
   /**
@@ -139,6 +143,47 @@ export class ModelsCliUsecases {
     return this.modelsUsecases.remove(modelId);
   }
 
+  async pullModelWithExactUrl(modelId: string, url: string, fileSize: number) {
+    const tokenizer = await this.getHFModelTokenizer(url);
+    const promptTemplate = tokenizer?.promptTemplate ?? LLAMA_2;
+    const stopWords: string[] = [tokenizer?.stopWord ?? ''];
+
+    const model: CreateModelDto = {
+      sources: [
+        {
+          url: url,
+        },
+      ],
+      id: modelId,
+      name: modelId,
+      version: '',
+      format: ModelFormat.GGUF,
+      description: '',
+      settings: {
+        prompt_template: promptTemplate,
+      },
+      parameters: {
+        stop: stopWords,
+      },
+      metadata: {
+        author: 'janhq',
+        size: fileSize,
+        tags: [],
+      },
+      engine: 'cortex',
+    };
+    if (!(await this.modelsUsecases.findOne(modelId))) {
+      await this.modelsUsecases.create(model);
+    }
+
+    const bar = new SingleBar({}, Presets.shades_classic);
+    bar.start(100, 0);
+    const callback = (progress: number) => {
+      bar.update(progress);
+    };
+    await this.modelsUsecases.downloadModel(modelId, callback);
+  }
+
   /**
    * Pull model from Model repository (HF, Jan...)
    * @param modelId
@@ -155,6 +200,30 @@
     await this.modelsUsecases.downloadModel(modelId, callback);
   }
 
+  private async getHFModelTokenizer(
+    ggufUrl: string,
+  ): Promise<ModelTokenizer | undefined> {
+    try {
+      const { metadata } = await gguf(ggufUrl);
+      // @ts-expect-error "tokenizer.ggml.eos_token_id"
+      const index = metadata['tokenizer.ggml.eos_token_id'];
+      // @ts-expect-error "tokenizer.chat_template"
+      const hfChatTemplate = metadata['tokenizer.chat_template'];
+      const promptTemplate =
+        this.guessPromptTemplateFromHuggingFace(hfChatTemplate);
+      // @ts-expect-error "tokenizer.ggml.tokens"
+      const stopWord: string = metadata['tokenizer.ggml.tokens'][index] ?? '';
+
+      return {
+        stopWord,
+        promptTemplate,
+      };
+    } catch (err) {
+      console.log('Failed to get model metadata:', err);
+      return undefined;
+    }
+  }
+
   //// PRIVATE METHODS ////
 
   /**
@@ -193,26 +262,10 @@
       sibling = data.siblings.find((e) => e.rfilename.includes('.gguf'));
     }
     if (!sibling) throw 'No expected quantization found';
+    const tokenizer = await this.getHFModelTokenizer(sibling.downloadUrl!);
 
-    let stopWord = '';
-    let promptTemplate = LLAMA_2;
-
-    try {
-      const { metadata } = await gguf(sibling.downloadUrl!);
-      // @ts-expect-error "tokenizer.ggml.eos_token_id"
-      const index = metadata['tokenizer.ggml.eos_token_id'];
-      // @ts-expect-error "tokenizer.ggml.eos_token_id"
-      const hfChatTemplate = metadata['tokenizer.chat_template'];
-      promptTemplate = this.guessPromptTemplateFromHuggingFace(hfChatTemplate);
-
-      // @ts-expect-error "tokenizer.ggml.tokens"
-      stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
-    } catch (err) {}
-
-    const stopWords: string[] = [];
-    if (stopWord.length > 0) {
-      stopWords.push(stopWord);
-    }
+    const promptTemplate = tokenizer?.promptTemplate ?? LLAMA_2;
+    const stopWords: string[] = [tokenizer?.stopWord ?? ''];
 
     const model: CreateModelDto = {
       sources: [
@@ -343,8 +396,8 @@
   private async fetchHuggingFaceRepoData(repoId: string) {
     const sanitizedUrl = this.getRepoModelsUrl(repoId);
 
-    const res = await fetch(sanitizedUrl);
-    const response = await res.json();
+    const res = await firstValueFrom(this.httpService.get(sanitizedUrl));
+    const response = res.data;
     if (response['error'] != null) {
       throw new Error(response['error']);
     }
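getHFModelTokenizer extracts the GGUF-metadata probing that pullModel previously inlined, so the new pullModelWithExactUrl can reuse it. The @huggingface/gguf parser reads a remote file's header via HTTP range requests, so only metadata is fetched, never the weights. One behavioral nuance of the refactor: the old code pushed the stop word only when non-empty, while the new code always builds [tokenizer?.stopWord ?? ''], so an empty string can now appear in the stop list. A condensed sketch of the lookup, with readStopWord as a hypothetical helper:

import { gguf } from '@huggingface/gguf';

// Reads only the GGUF header; the keys are the standard GGUF tokenizer
// metadata fields also used in the diff above.
async function readStopWord(ggufUrl: string): Promise<string | undefined> {
  const { metadata } = await gguf(ggufUrl);
  const meta = metadata as Record<string, any>;
  const eosId = meta['tokenizer.ggml.eos_token_id'];
  return meta['tokenizer.ggml.tokens']?.[eosId];
}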

cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts

Lines changed: 7 additions & 4 deletions

@@ -7,6 +7,7 @@ import { HttpService } from '@nestjs/axios';
 import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';
 import { readdirSync } from 'node:fs';
 import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';
+import { firstValueFrom } from 'rxjs';
 
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
@@ -72,13 +73,15 @@ export default class CortexProvider extends OAIEngineExtension {
       modelSettings.ai_prompt = prompt.ai_prompt;
     }
 
-    await this.httpService.post(this.loadModelUrl, modelSettings).toPromise();
+    await firstValueFrom(
+      this.httpService.post(this.loadModelUrl, modelSettings),
+    );
   }
 
   override async unloadModel(modelId: string): Promise<void> {
-    await this.httpService
-      .post(this.unloadModelUrl, { model: modelId })
-      .toPromise();
+    await firstValueFrom(
+      this.httpService.post(this.unloadModelUrl, { model: modelId }),
+    );
   }
 
   private readonly promptTemplateConverter = (
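These two hunks repeat the toPromise() to firstValueFrom migration shown in init.cli.usecases.ts; the sketch after that file's diff applies here unchanged.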
