Commit 1084844

feat: pull model from jan hub (#606)

1 parent d31a788

File tree: 5 files changed (+159, -42 lines)

cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts

Lines changed: 147 additions & 38 deletions

@@ -51,30 +51,55 @@ export class ModelsCliUsecases {
     private readonly inquirerService: InquirerService,
   ) {}
 
+  /**
+   * Start a model by ID
+   * @param modelId
+   */
   async startModel(modelId: string): Promise<void> {
     await this.getModelOrStop(modelId);
     await this.modelsUsecases.startModel(modelId);
   }
 
+  /**
+   * Stop a model by ID
+   * @param modelId
+   */
   async stopModel(modelId: string): Promise<void> {
     await this.getModelOrStop(modelId);
     await this.modelsUsecases.stopModel(modelId);
   }
 
+  /**
+   * Update a model's settings, e.g. ngl, prompt_template, etc.
+   * @param modelId
+   * @param settingParams
+   * @returns
+   */
   async updateModelSettingParams(
     modelId: string,
     settingParams: ModelSettingParams,
   ): Promise<ModelSettingParams> {
     return this.modelsUsecases.updateModelSettingParams(modelId, settingParams);
   }
 
+  /**
+   * Update a model's runtime parameters, e.g. max_tokens, temperature, etc.
+   * @param modelId
+   * @param runtimeParams
+   * @returns
+   */
   async updateModelRuntimeParams(
     modelId: string,
     runtimeParams: ModelRuntimeParams,
   ): Promise<ModelRuntimeParams> {
     return this.modelsUsecases.updateModelRuntimeParams(modelId, runtimeParams);
   }
 
+  /**
+   * Find a model or abort if it does not exist
+   * @param modelId
+   * @returns
+   */
   private async getModelOrStop(modelId: string): Promise<Model> {
     const model = await this.modelsUsecases.findOne(modelId);
     if (!model) {
@@ -84,25 +109,42 @@ export class ModelsCliUsecases {
     return model;
   }
 
+  /**
+   * List all models
+   * @returns
+   */
   async listAllModels(): Promise<Model[]> {
     return this.modelsUsecases.findAll();
   }
 
+  /**
+   * Get a model by ID
+   * @param modelId
+   * @returns
+   */
   async getModel(modelId: string): Promise<Model> {
     const model = await this.getModelOrStop(modelId);
     return model;
   }
 
+  /**
+   * Remove a model; this also deletes the model files
+   * @param modelId
+   * @returns
+   */
   async removeModel(modelId: string) {
     await this.getModelOrStop(modelId);
     return this.modelsUsecases.remove(modelId);
   }
 
+  /**
+   * Pull a model from a model repository (HF, Jan, ...)
+   * @param modelId
+   */
   async pullModel(modelId: string) {
-    if (modelId.includes('/')) {
+    if (modelId.includes('/') || modelId.includes(':')) {
       await this.pullHuggingFaceModel(modelId);
     }
-
     const bar = new SingleBar({}, Presets.shades_classic);
     bar.start(100, 0);
     const callback = (progress: number) => {
@@ -111,21 +153,43 @@ export class ModelsCliUsecases {
     await this.modelsUsecases.downloadModel(modelId, callback);
   }
 
-  private async pullHuggingFaceModel(modelId: string) {
-    const data = await this.fetchHuggingFaceRepoData(modelId);
-    const { quantization } = await this.inquirerService.inquirer.prompt({
-      type: 'list',
-      name: 'quantization',
-      message: 'Select quantization',
-      choices: data.siblings
-        .map((e) => e.quantization)
-        .filter((e) => e != null),
-    });
+  //// PRIVATE METHODS ////
 
-    const sibling = data.siblings
-      .filter((e) => !!e.quantization)
-      .find((e: any) => e.quantization === quantization);
+  /**
+   * Pull a model from a HuggingFace repository.
+   * It can be a model from Jan's repo or from another author.
+   * @param modelId HuggingFace model id, e.g. "janhq/llama-3" or "llama3:7b"
+   */
+  private async pullHuggingFaceModel(modelId: string) {
+    let data: HuggingFaceRepoData;
+    if (modelId.includes('/'))
+      data = await this.fetchHuggingFaceRepoData(modelId);
+    else data = await this.fetchJanRepoData(modelId);
+
+    let sibling;
+
+    const listChoices = data.siblings
+      .filter((e) => e.quantization != null)
+      .map((e) => {
+        return {
+          name: e.quantization,
+          value: e.quantization,
+        };
+      });
 
+    if (listChoices.length > 1) {
+      const { quantization } = await this.inquirerService.inquirer.prompt({
+        type: 'list',
+        name: 'quantization',
+        message: 'Select quantization',
+        choices: listChoices,
+      });
+      sibling = data.siblings
+        .filter((e) => !!e.quantization)
+        .find((e: any) => e.quantization === quantization);
+    } else {
+      sibling = data.siblings.find((e) => e.rfilename.includes('.gguf'));
+    }
     if (!sibling) throw 'No expected quantization found';
 
     let stopWord = '';
@@ -141,9 +205,7 @@ export class ModelsCliUsecases {
 
       // @ts-expect-error "tokenizer.ggml.tokens"
       stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
-    } catch (err) {
-      console.log('Failed to get stop word: ', err);
-    }
+    } catch (err) {}
 
     const stopWords: string[] = [];
     if (stopWord.length > 0) {
@@ -163,6 +225,7 @@ export class ModelsCliUsecases {
       description: '',
       settings: {
         prompt_template: promptTemplate,
+        llama_model_path: sibling.rfilename,
       },
       parameters: {
         stop: stopWords,
@@ -209,8 +272,71 @@ export class ModelsCliUsecases {
     }
   }
 
+  /**
+   * Fetch the model data from Jan's repo
+   * @param modelId Jan model id, e.g. "llama-3:7b"
+   * @returns
+   */
+  private async fetchJanRepoData(modelId: string) {
+    const repo = modelId.split(':')[0];
+    const tree = modelId.split(':')[1];
+    const url = this.getRepoModelsUrl(`janhq/${repo}`, tree);
+    const res = await fetch(url);
+    const response:
+      | {
+          path: string;
+          size: number;
+        }[]
+      | { error: string } = await res.json();
+
+    if ('error' in response && response.error != null) {
+      throw new Error(response.error);
+    }
+
+    const data: HuggingFaceRepoData = {
+      siblings: Array.isArray(response)
+        ? response.map((e) => {
+            return {
+              rfilename: e.path,
+              downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${e.path}`,
+              fileSize: e.size ?? 0,
+            };
+          })
+        : [],
+      tags: ['gguf'],
+      id: modelId,
+      modelId: modelId,
+      author: 'janhq',
+      sha: '',
+      downloads: 0,
+      lastModified: '',
+      private: false,
+      disabled: false,
+      gated: false,
+      pipeline_tag: 'text-generation',
+      cardData: {},
+      createdAt: '',
+    };
+
+    AllQuantizations.forEach((quantization) => {
+      data.siblings.forEach((sibling: any) => {
+        if (!sibling.quantization && sibling.rfilename.includes(quantization)) {
+          sibling.quantization = quantization;
+        }
+      });
+    });
+
+    data.modelUrl = url;
+    return data;
+  }
+
+  /**
+   * Fetch the model data from the HuggingFace API
+   * @param repoId HuggingFace model id, e.g. "janhq/llama-3"
+   * @returns
+   */
   private async fetchHuggingFaceRepoData(repoId: string) {
-    const sanitizedUrl = this.toHuggingFaceUrl(repoId);
+    const sanitizedUrl = this.getRepoModelsUrl(repoId);
 
     const res = await fetch(sanitizedUrl);
     const response = await res.json();
@@ -245,24 +371,7 @@ export class ModelsCliUsecases {
     return data;
   }
 
-  private toHuggingFaceUrl(repoId: string): string {
-    try {
-      const url = new URL(`https://huggingface.co/${repoId}`);
-      if (url.host !== 'huggingface.co') {
-        throw `Invalid Hugging Face repo URL: ${repoId}`;
-      }
-
-      const paths = url.pathname.split('/').filter((e) => e.trim().length > 0);
-      if (paths.length < 2) {
-        throw `Invalid Hugging Face repo URL: ${repoId}`;
-      }
-
-      return `${url.origin}/api/models/${paths[0]}/${paths[1]}`;
-    } catch (err) {
-      if (repoId.startsWith('https')) {
-        throw new Error(`Cannot parse url: ${repoId}`);
-      }
-      throw err;
-    }
+  private getRepoModelsUrl(repoId: string, tree?: string): string {
+    return `https://huggingface.co/api/models/${repoId}${tree ? `/tree/${tree}` : ''}`;
  }
 }
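
The substance of the commit is the id routing: an id containing '/' is treated as a full HuggingFace repo id, while a bare "name:tag" id is resolved against Jan's hub under the janhq organization, with the tag used as the repo tree. A minimal standalone sketch of that resolution, distilled from the diff (resolveModelUrl is an illustrative helper, not part of the commit; getRepoModelsUrl mirrors the new private method above):

// Builds the HuggingFace models API URL, optionally scoped to a tree/branch.
const getRepoModelsUrl = (repoId: string, tree?: string): string =>
  `https://huggingface.co/api/models/${repoId}${tree ? `/tree/${tree}` : ''}`;

// Hypothetical helper showing how pullHuggingFaceModel picks its data source.
const resolveModelUrl = (modelId: string): string => {
  if (modelId.includes('/')) {
    // Full HuggingFace repo id, e.g. "janhq/llama-3"
    return getRepoModelsUrl(modelId);
  }
  // Jan hub id, e.g. "llama3:7b" -> repo "llama3", tree "7b"
  const [repo, tree] = modelId.split(':');
  return getRepoModelsUrl(`janhq/${repo}`, tree);
};

console.log(resolveModelUrl('janhq/llama-3'));
// https://huggingface.co/api/models/janhq/llama-3
console.log(resolveModelUrl('llama3:7b'));
// https://huggingface.co/api/models/janhq/llama3/tree/7b

Each returned file is then tagged with a quantization by matching its rfilename against AllQuantizations, and the interactive prompt is skipped when there is at most one quantized candidate, in which case the code falls back to the first .gguf file.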
cortex-js/src/infrastructure/commanders/utils/normalize-model-id.ts

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+export const normalizeModelId = (modelId: string): string => {
+  return modelId.replace(':', '%3A');
+};
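
A short usage sketch of the new helper: '%3A' is the percent-encoding of ':', presumably so tag-style ids remain valid folder names on Windows, where ':' is a reserved path character (the commit itself only introduces the replacement):

import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

// ':' appears at most once in a "name:tag" id, so the single
// (first-occurrence) string replace is sufficient.
normalizeModelId('llama3:7b');     // => 'llama3%3A7b'
normalizeModelId('janhq/llama-3'); // => 'janhq/llama-3' (unchanged)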

cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts

Lines changed: 5 additions & 1 deletion

@@ -6,6 +6,7 @@ import { Model, ModelSettingParams } from '@/domain/models/model.interface';
 import { HttpService } from '@nestjs/axios';
 import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';
 import { readdirSync } from 'node:fs';
+import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';
 
 /**
  * A class that implements the InferenceExtension interface from the @janhq/core package.
@@ -32,7 +33,10 @@ export default class CortexProvider extends OAIEngineExtension {
   ): Promise<void> {
     const modelsContainerDir = this.modelDir();
 
-    const modelFolderFullPath = join(modelsContainerDir, model.id);
+    const modelFolderFullPath = join(
+      modelsContainerDir,
+      normalizeModelId(model.id),
+    );
     const ggufFiles = readdirSync(modelFolderFullPath).filter((file) => {
       return file.endsWith('.gguf');
     });

cortex-js/src/usecases/cortex/cortex.usecases.ts

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ export class CortexUsecases {
     );
 
     if (!existsSync(cortexCppPath)) {
-      throw new Error('Cortex binary not found');
+      throw new Error('The engine is not available, please run "cortex init".');
     }
 
     // go up one level to get the binary folder, have to also work on windows

cortex-js/src/usecases/models/models.usecases.ts

Lines changed: 3 additions & 2 deletions

@@ -23,6 +23,7 @@ import { ExtensionRepository } from '@/domain/repositories/extension.interface';
 import { EngineExtension } from '@/domain/abstracts/engine.abstract';
 import { HttpService } from '@nestjs/axios';
 import { ModelSettingParamsDto } from '@/infrastructure/dtos/models/model-setting-params.dto';
+import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';
 
 @Injectable()
 export class ModelsUsecases {
@@ -106,7 +107,7 @@ export class ModelsUsecases {
       return;
     }
 
-    const modelFolder = join(modelsContainerDir, id);
+    const modelFolder = join(modelsContainerDir, normalizeModelId(id));
 
     return this.modelRepository
       .delete(id)
@@ -205,7 +206,7 @@ export class ModelsUsecases {
       mkdirSync(modelsContainerDir, { recursive: true });
     }
 
-    const modelFolder = join(modelsContainerDir, model.id);
+    const modelFolder = join(modelsContainerDir, normalizeModelId(model.id));
     await promises.mkdir(modelFolder, { recursive: true });
     const destination = join(modelFolder, fileName);
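The normalization now runs on both sides of the model lifecycle: ModelsUsecases applies it when creating the download folder and when removing it, and CortexProvider applies it when loading the .gguf back. A tiny sketch of the invariant, using '/models' as an illustrative container dir (not the real default):

import { join } from 'node:path';
import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

// Download, removal, and load must all resolve the same folder; otherwise a
// tag-style id would be written to one path and read back from another.
join('/models', normalizeModelId('llama3:7b')); // => '/models/llama3%3A7b'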