Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 6151fea

Browse files
committed
feat: pull model yaml from hf
1 parent 9a115ba commit 6151fea

File tree

4 files changed

+183
-25
lines changed

4 files changed

+183
-25
lines changed

cortex-js/package.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
},
2727
"dependencies": {
2828
"@huggingface/gguf": "^0.1.5",
29+
"@huggingface/hub": "^0.15.1",
2930
"@nestjs/axios": "^3.0.2",
3031
"@nestjs/common": "^10.0.0",
3132
"@nestjs/config": "^3.2.2",
@@ -47,7 +48,8 @@
4748
"sqlite": "^5.1.1",
4849
"sqlite3": "^5.1.7",
4950
"typeorm": "^0.3.20",
50-
"ulid": "^2.3.0"
51+
"ulid": "^2.3.0",
52+
"yaml": "^2.4.2"
5153
},
5254
"devDependencies": {
5355
"@nestjs/cli": "^10.0.0",
Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,127 @@
1-
import { CommandRunner, SubCommand } from 'nest-commander';
1+
import { CommandRunner, InquirerService, SubCommand } from 'nest-commander';
22
import { exit } from 'node:process';
33
import { ModelsCliUsecases } from '../usecases/models.cli.usecases';
4+
import { RepoDesignation, listFiles } from '@huggingface/hub';
5+
import YAML from 'yaml';
6+
import { basename } from 'node:path';
47

58
@SubCommand({
69
name: 'pull',
710
aliases: ['download'],
811
description: 'Download a model. Working with HuggingFace model id.',
912
})
1013
export class ModelPullCommand extends CommandRunner {
11-
constructor(private readonly modelsCliUsecases: ModelsCliUsecases) {
14+
private metadataFileName = 'metadata.yaml';
15+
private janHqModelPrefix = 'janhq';
16+
17+
constructor(
18+
private readonly inquirerService: InquirerService,
19+
private readonly modelsCliUsecases: ModelsCliUsecases,
20+
) {
1221
super();
1322
}
1423

1524
async run(input: string[]) {
1625
if (input.length < 1) {
17-
console.error('Model ID is required');
26+
console.error('Model Id is required');
1827
exit(1);
1928
}
2029

21-
await this.modelsCliUsecases.pullModel(input[0]);
30+
// Check if metadata.yaml file exist
31+
const metadata = await this.getJanMetadata(input[0]);
32+
33+
if (!metadata) {
34+
await this.modelsCliUsecases.pullModel(input[0]);
35+
} else {
36+
// if there's metadata.yaml file, we assumed it's a JanHQ model
37+
await this.handleJanHqModel(input[0], metadata);
38+
}
39+
2240
console.log('\nDownload complete!');
2341
exit(0);
2442
}
43+
44+
private async getJanMetadata(input: string): Promise<any> {
45+
// try to append with janhq/ if it's not already
46+
const sanitizedInput = input.trim().startsWith(this.janHqModelPrefix)
47+
? input
48+
: `${this.janHqModelPrefix}/${input}`;
49+
50+
const repo: RepoDesignation = { type: 'model', name: sanitizedInput };
51+
let isMetadataFileExist = false;
52+
for await (const fileInfo of listFiles({ repo })) {
53+
if (fileInfo.path === this.metadataFileName) {
54+
isMetadataFileExist = true;
55+
break;
56+
}
57+
}
58+
59+
if (!isMetadataFileExist) {
60+
return undefined;
61+
}
62+
63+
const path = `https://huggingface.co/${sanitizedInput}/raw/main/${this.metadataFileName}`;
64+
const res = await fetch(path);
65+
const metadataJson = await res.text();
66+
return YAML.parse(metadataJson);
67+
}
68+
69+
private async versionInquiry(tags: string[]): Promise<string> {
70+
const { tag } = await this.inquirerService.inquirer.prompt({
71+
type: 'list',
72+
name: 'tag',
73+
message: 'Select version',
74+
choices: tags,
75+
});
76+
77+
return tag;
78+
}
79+
80+
private async handleJanHqModel(repoName: string, metadata: any) {
81+
const sanitizedRepoName = repoName.trim().startsWith(this.janHqModelPrefix)
82+
? repoName
83+
: `${this.janHqModelPrefix}/${repoName}`;
84+
85+
const tags = metadata.tags;
86+
let selectedTag = 'default';
87+
const allTags: string[] = Object.keys(tags);
88+
89+
if (allTags.length > 1) {
90+
selectedTag = await this.versionInquiry(allTags);
91+
}
92+
93+
const branch = selectedTag;
94+
const engine = 'llamacpp'; // TODO: currently, we only support llamacpp
95+
96+
const revision = metadata.tags?.[branch]?.[engine];
97+
if (!revision) {
98+
console.error("Can't find model revision.");
99+
exit(1);
100+
}
101+
102+
const repo: RepoDesignation = { type: 'model', name: sanitizedRepoName };
103+
let ggufUrl: string | undefined = undefined;
104+
let fileSize = 0;
105+
for await (const fileInfo of listFiles({
106+
repo: repo,
107+
revision: revision,
108+
})) {
109+
if (fileInfo.path.endsWith('.gguf')) {
110+
ggufUrl = `https://huggingface.co/${sanitizedRepoName}/resolve/${revision}/${fileInfo.path}`;
111+
fileSize = fileInfo.size;
112+
break;
113+
}
114+
}
115+
116+
if (!ggufUrl) {
117+
console.error("Can't find model file.");
118+
exit(1);
119+
}
120+
console.log('Downloading', basename(ggufUrl));
121+
await this.modelsCliUsecases.pullModelWithExactUrl(
122+
`${sanitizedRepoName}/${revision}`,
123+
ggufUrl,
124+
fileSize,
125+
);
126+
}
25127
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
/**
 * Prompting information extracted from a model's GGUF tokenizer metadata.
 */
export interface ModelTokenizer {
  /** End-of-sequence token to use as a stop word; absent when unknown. */
  stopWord?: string;
  /** Chat prompt template resolved for the model. */
  promptTemplate: string;
}

cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import {
2121
ZEPHYR,
2222
ZEPHYR_JINJA,
2323
} from '../prompt-constants';
24+
import { ModelTokenizer } from '../types/model-tokenizer.interface';
2425

2526
const AllQuantizations = [
2627
'Q3_K_S',
@@ -139,7 +140,48 @@ export class ModelsCliUsecases {
139140
return this.modelsUsecases.remove(modelId);
140141
}
141142

142-
/**
143+
async pullModelWithExactUrl(modelId: string, url: string, fileSize: number) {
144+
const tokenizer = await this.getHFModelTokenizer(url);
145+
const promptTemplate = tokenizer?.promptTemplate ?? LLAMA_2;
146+
const stopWords: string[] = [tokenizer?.stopWord ?? ''];
147+
148+
const model: CreateModelDto = {
149+
sources: [
150+
{
151+
url: url,
152+
},
153+
],
154+
id: modelId,
155+
name: modelId,
156+
version: '',
157+
format: ModelFormat.GGUF,
158+
description: '',
159+
settings: {
160+
prompt_template: promptTemplate,
161+
},
162+
parameters: {
163+
stop: stopWords,
164+
},
165+
metadata: {
166+
author: 'janhq',
167+
size: fileSize,
168+
tags: [],
169+
},
170+
engine: 'cortex',
171+
};
172+
if (!(await this.modelsUsecases.findOne(modelId))) {
173+
await this.modelsUsecases.create(model);
174+
}
175+
176+
const bar = new SingleBar({}, Presets.shades_classic);
177+
bar.start(100, 0);
178+
const callback = (progress: number) => {
179+
bar.update(progress);
180+
};
181+
await this.modelsUsecases.downloadModel(modelId, callback);
182+
}
183+
184+
/**
143185
* Pull model from Model repository (HF, Jan...)
144186
* @param modelId
145187
*/
@@ -155,6 +197,30 @@ export class ModelsCliUsecases {
155197
await this.modelsUsecases.downloadModel(modelId, callback);
156198
}
157199

200+
private async getHFModelTokenizer(
201+
ggufUrl: string,
202+
): Promise<ModelTokenizer | undefined> {
203+
try {
204+
const { metadata } = await gguf(ggufUrl);
205+
// @ts-expect-error "tokenizer.ggml.eos_token_id"
206+
const index = metadata['tokenizer.ggml.eos_token_id'];
207+
// @ts-expect-error "tokenizer.ggml.eos_token_id"
208+
const hfChatTemplate = metadata['tokenizer.chat_template'];
209+
const promptTemplate =
210+
this.guessPromptTemplateFromHuggingFace(hfChatTemplate);
211+
// @ts-expect-error "tokenizer.ggml.tokens"
212+
const stopWord: string = metadata['tokenizer.ggml.tokens'][index] ?? '';
213+
214+
return {
215+
stopWord,
216+
promptTemplate,
217+
};
218+
} catch (err) {
219+
console.log('Failed to get model metadata:', err);
220+
return undefined;
221+
}
222+
}
223+
158224
//// PRIVATE METHODS ////
159225

160226
/**
@@ -193,26 +259,10 @@ export class ModelsCliUsecases {
193259
sibling = data.siblings.find((e) => e.rfilename.includes('.gguf'));
194260
}
195261
if (!sibling) throw 'No expected quantization found';
262+
const tokenizer = await this.getHFModelTokenizer(sibling.downloadUrl!);
196263

197-
let stopWord = '';
198-
let promptTemplate = LLAMA_2;
199-
200-
try {
201-
const { metadata } = await gguf(sibling.downloadUrl!);
202-
// @ts-expect-error "tokenizer.ggml.eos_token_id"
203-
const index = metadata['tokenizer.ggml.eos_token_id'];
204-
// @ts-expect-error "tokenizer.ggml.eos_token_id"
205-
const hfChatTemplate = metadata['tokenizer.chat_template'];
206-
promptTemplate = this.guessPromptTemplateFromHuggingFace(hfChatTemplate);
207-
208-
// @ts-expect-error "tokenizer.ggml.tokens"
209-
stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
210-
} catch (err) {}
211-
212-
const stopWords: string[] = [];
213-
if (stopWord.length > 0) {
214-
stopWords.push(stopWord);
215-
}
264+
const promptTemplate = tokenizer?.promptTemplate ?? LLAMA_2;
265+
const stopWords: string[] = [tokenizer?.stopWord ?? ''];
216266

217267
const model: CreateModelDto = {
218268
sources: [

0 commit comments

Comments
 (0)