185 changes: 147 additions & 38 deletions cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts
@@ -51,30 +51,55 @@ export class ModelsCliUsecases {
private readonly inquirerService: InquirerService,
) {}

/**
* Start a model by ID
* @param modelId
*/
async startModel(modelId: string): Promise<void> {
await this.getModelOrStop(modelId);
await this.modelsUsecases.startModel(modelId);
}

/**
* Stop a model by ID
* @param modelId
*/
async stopModel(modelId: string): Promise<void> {
await this.getModelOrStop(modelId);
await this.modelsUsecases.stopModel(modelId);
}

/**
* Update model's settings. E.g. ngl, prompt_template, etc.
* @param modelId
* @param settingParams
* @returns
*/
async updateModelSettingParams(
modelId: string,
settingParams: ModelSettingParams,
): Promise<ModelSettingParams> {
return this.modelsUsecases.updateModelSettingParams(modelId, settingParams);
}

/**
* Update model's runtime parameters. E.g. max_tokens, temperature, etc.
* @param modelId
* @param runtimeParams
* @returns
*/
async updateModelRuntimeParams(
modelId: string,
runtimeParams: ModelRuntimeParams,
): Promise<ModelRuntimeParams> {
return this.modelsUsecases.updateModelRuntimeParams(modelId, runtimeParams);
}

/**
* Find a model or abort if it does not exist
* @param modelId
* @returns
*/
private async getModelOrStop(modelId: string): Promise<Model> {
const model = await this.modelsUsecases.findOne(modelId);
if (!model) {
@@ -84,25 +109,42 @@ export class ModelsCliUsecases {
return model;
}

/**
* List all of the models
* @returns
*/
async listAllModels(): Promise<Model[]> {
return this.modelsUsecases.findAll();
}

/**
* Get a model by ID
* @param modelId
* @returns
*/
async getModel(modelId: string): Promise<Model> {
const model = await this.getModelOrStop(modelId);
return model;
}

/**
* Remove a model; this also deletes the model files
* @param modelId
* @returns
*/
async removeModel(modelId: string) {
await this.getModelOrStop(modelId);
return this.modelsUsecases.remove(modelId);
}

/**
* Pull a model from a model repository (HuggingFace, Jan, ...)
* @param modelId
*/
async pullModel(modelId: string) {
-    if (modelId.includes('/')) {
+    if (modelId.includes('/') || modelId.includes(':')) {
await this.pullHuggingFaceModel(modelId);
}

const bar = new SingleBar({}, Presets.shades_classic);
bar.start(100, 0);
const callback = (progress: number) => {
@@ -111,21 +153,43 @@ export class ModelsCliUsecases {
await this.modelsUsecases.downloadModel(modelId, callback);
}

-  private async pullHuggingFaceModel(modelId: string) {
-    const data = await this.fetchHuggingFaceRepoData(modelId);
-    const { quantization } = await this.inquirerService.inquirer.prompt({
-      type: 'list',
-      name: 'quantization',
-      message: 'Select quantization',
-      choices: data.siblings
-        .map((e) => e.quantization)
-        .filter((e) => e != null),
-    });
-
-    const sibling = data.siblings
-      .filter((e) => !!e.quantization)
-      .find((e: any) => e.quantization === quantization);
+  //// PRIVATE METHODS ////
+
+  /**
+   * Pull a model from the HuggingFace repository.
+   * The model can come from Jan's repo or from other authors.
+   * @param modelId HuggingFace model id, e.g. "janhq/llama-3" or "llama3:7b"
+   */
+  private async pullHuggingFaceModel(modelId: string) {
+    let data: HuggingFaceRepoData;
+    if (modelId.includes('/'))
+      data = await this.fetchHuggingFaceRepoData(modelId);
+    else data = await this.fetchJanRepoData(modelId);
+
+    let sibling;
+
+    const listChoices = data.siblings
+      .filter((e) => e.quantization != null)
+      .map((e) => {
+        return {
+          name: e.quantization,
+          value: e.quantization,
+        };
+      });
+
+    if (listChoices.length > 1) {
+      const { quantization } = await this.inquirerService.inquirer.prompt({
+        type: 'list',
+        name: 'quantization',
+        message: 'Select quantization',
+        choices: listChoices,
+      });
+      sibling = data.siblings
+        .filter((e) => !!e.quantization)
+        .find((e: any) => e.quantization === quantization);
+    } else {
+      sibling = data.siblings.find((e) => e.rfilename.includes('.gguf'));
+    }
if (!sibling) throw 'No expected quantization found';

let stopWord = '';
@@ -141,9 +205,7 @@ export class ModelsCliUsecases {

// @ts-expect-error "tokenizer.ggml.tokens"
stopWord = metadata['tokenizer.ggml.tokens'][index] ?? '';
-    } catch (err) {
-      console.log('Failed to get stop word: ', err);
-    }
+    } catch (err) {}

const stopWords: string[] = [];
if (stopWord.length > 0) {
@@ -163,6 +225,7 @@ export class ModelsCliUsecases {
description: '',
settings: {
prompt_template: promptTemplate,
llama_model_path: sibling.rfilename,
},
parameters: {
stop: stopWords,
@@ -209,8 +272,71 @@
}
}

/**
* Fetch the model data from Jan's repo
* @param modelId Jan model id, e.g. "llama-3:7b"
* @returns
*/
private async fetchJanRepoData(modelId: string) {
const repo = modelId.split(':')[0];
const tree = modelId.split(':')[1];
const url = this.getRepoModelsUrl(`janhq/${repo}`, tree);
const res = await fetch(url);
const response:
| {
path: string;
size: number;
}[]
| { error: string } = await res.json();

if ('error' in response && response.error != null) {
throw new Error(response.error);
}

const data: HuggingFaceRepoData = {
siblings: Array.isArray(response)
? response.map((e) => {
return {
rfilename: e.path,
downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${e.path}`,
fileSize: e.size ?? 0,
};
})
: [],
tags: ['gguf'],
id: modelId,
modelId: modelId,
author: 'janhq',
sha: '',
downloads: 0,
lastModified: '',
private: false,
disabled: false,
gated: false,
pipeline_tag: 'text-generation',
cardData: {},
createdAt: '',
};

AllQuantizations.forEach((quantization) => {
data.siblings.forEach((sibling: any) => {
if (!sibling.quantization && sibling.rfilename.includes(quantization)) {
sibling.quantization = quantization;
}
});
});

data.modelUrl = url;
return data;
}

/**
* Fetches the model data from the HuggingFace API
* @param repoId HuggingFace repo id, e.g. "janhq/llama-3"
* @returns
*/
private async fetchHuggingFaceRepoData(repoId: string) {
-    const sanitizedUrl = this.toHuggingFaceUrl(repoId);
+    const sanitizedUrl = this.getRepoModelsUrl(repoId);

const res = await fetch(sanitizedUrl);
const response = await res.json();
Expand Down Expand Up @@ -245,24 +371,7 @@ export class ModelsCliUsecases {
return data;
}

-  private toHuggingFaceUrl(repoId: string): string {
-    try {
-      const url = new URL(`https://huggingface.co/${repoId}`);
-      if (url.host !== 'huggingface.co') {
-        throw `Invalid Hugging Face repo URL: ${repoId}`;
-      }
-
-      const paths = url.pathname.split('/').filter((e) => e.trim().length > 0);
-      if (paths.length < 2) {
-        throw `Invalid Hugging Face repo URL: ${repoId}`;
-      }
-
-      return `${url.origin}/api/models/${paths[0]}/${paths[1]}`;
-    } catch (err) {
-      if (repoId.startsWith('https')) {
-        throw new Error(`Cannot parse url: ${repoId}`);
-      }
-      throw err;
-    }
+  private getRepoModelsUrl(repoId: string, tree?: string): string {
+    return `https://huggingface.co/api/models/${repoId}${tree ? `/tree/${tree}` : ''}`;
  }
}
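Note on the resolution flow above, as a small sketch (the model id, tag, and file names are illustrative, not real releases):

    // "llama-3:7b" splits into repo "llama-3" and tree "7b"
    const modelId = 'llama-3:7b';
    const [repo, tree] = modelId.split(':');
    // getRepoModelsUrl(`janhq/${repo}`, tree) builds the listing URL:
    //   https://huggingface.co/api/models/janhq/llama-3/tree/7b
    // and each listed file downloads from:
    //   https://huggingface.co/janhq/llama-3/resolve/7b/<rfilename>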
3 changes: 3 additions & 0 deletions cortex-js/src/infrastructure/commanders/utils/normalize-model-id.ts
@@ -0,0 +1,3 @@
export const normalizeModelId = (modelId: string): string => {
return modelId.replace(':', '%3A');
};
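A quick usage sketch (ids are illustrative): normalizeModelId escapes the ':' in tagged ids, which is not a valid path character on Windows, so the id can double as a folder name:

    import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

    normalizeModelId('llama3:7b'); // "llama3%3A7b"
    normalizeModelId('janhq/llama-3'); // no ':' to escape, returns "janhq/llama-3"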
6 changes: 5 additions & 1 deletion cortex-js/src/infrastructure/providers/cortex/cortex.provider.ts
@@ -6,6 +6,7 @@ import { Model, ModelSettingParams } from '@/domain/models/model.interface';
import { HttpService } from '@nestjs/axios';
import { defaultCortexCppHost, defaultCortexCppPort } from 'constant';
import { readdirSync } from 'node:fs';
import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

/**
* A class that implements the InferenceExtension interface from the @janhq/core package.
@@ -32,7 +33,10 @@ export default class CortexProvider extends OAIEngineExtension {
): Promise<void> {
const modelsContainerDir = this.modelDir();

-    const modelFolderFullPath = join(modelsContainerDir, model.id);
+    const modelFolderFullPath = join(
+      modelsContainerDir,
+      normalizeModelId(model.id),
+    );
const ggufFiles = readdirSync(modelFolderFullPath).filter((file) => {
return file.endsWith('.gguf');
});
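With this change the provider reads from the same escaped folder the downloader writes to; for example (paths illustrative, modelsContainerDir depends on local configuration):

    // join(modelsContainerDir, normalizeModelId('llama3:7b'))
    //   resolves to <modelsContainerDir>/llama3%3A7b
    // readdirSync(modelFolderFullPath) then picks up the .gguf files inside it.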
2 changes: 1 addition & 1 deletion cortex-js/src/usecases/cortex/cortex.usecases.ts
@@ -32,7 +32,7 @@ export class CortexUsecases {
);

if (!existsSync(cortexCppPath)) {
-      throw new Error('Cortex binary not found');
+      throw new Error('The engine is not available, please run "cortex init".');
}

// go up one level to get the binary folder; this must also work on Windows
5 changes: 3 additions & 2 deletions cortex-js/src/usecases/models/models.usecases.ts
@@ -23,6 +23,7 @@ import { ExtensionRepository } from '@/domain/repositories/extension.interface';
import { EngineExtension } from '@/domain/abstracts/engine.abstract';
import { HttpService } from '@nestjs/axios';
import { ModelSettingParamsDto } from '@/infrastructure/dtos/models/model-setting-params.dto';
import { normalizeModelId } from '@/infrastructure/commanders/utils/normalize-model-id';

@Injectable()
export class ModelsUsecases {
@@ -106,7 +107,7 @@
return;
}

-    const modelFolder = join(modelsContainerDir, id);
+    const modelFolder = join(modelsContainerDir, normalizeModelId(id));

return this.modelRepository
.delete(id)
@@ -205,7 +206,7 @@
mkdirSync(modelsContainerDir, { recursive: true });
}

-    const modelFolder = join(modelsContainerDir, model.id);
+    const modelFolder = join(modelsContainerDir, normalizeModelId(model.id));
await promises.mkdir(modelFolder, { recursive: true });
const destination = join(modelFolder, fileName);

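Download and removal now agree on the on-disk layout; a sketch with an illustrative id and file name:

    const modelFolder = join(modelsContainerDir, normalizeModelId('llama3:7b'));
    // download writes <modelsContainerDir>/llama3%3A7b/<fileName>
    // remove deletes that same folder, so tagged ids stay consistent end to end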