@@ -4,7 +4,7 @@ import { UpdateModelDto } from '@/infrastructure/dtos/models/update-model.dto';
44import { BadRequestException , Injectable } from '@nestjs/common' ;
55import { Model , ModelSettingParams } from '@/domain/models/model.interface' ;
66import { ModelNotFoundException } from '@/infrastructure/exception/model-not-found.exception' ;
7- import { basename , join } from 'path' ;
7+ import { basename , join , parse } from 'path' ;
88import { promises , existsSync , mkdirSync , readFileSync , rmSync } from 'fs' ;
99import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-success.dto' ;
1010import { ExtensionRepository } from '@/domain/repositories/extension.interface' ;
@@ -17,7 +17,6 @@ import { TelemetrySource } from '@/domain/telemetry/telemetry.interface';
1717import { ModelRepository } from '@/domain/repositories/model.interface' ;
1818import { ModelParameterParser } from '@/utils/model-parameter.parser' ;
1919import {
20- HuggingFaceModelVersion ,
2120 HuggingFaceRepoData ,
2221 HuggingFaceRepoSibling ,
2322} from '@/domain/models/huggingface.interface' ;
@@ -26,7 +25,10 @@ import {
2625 fetchJanRepoData ,
2726 getHFModelMetadata ,
2827} from '@/utils/huggingface' ;
29- import { DownloadType } from '@/domain/models/download.interface' ;
28+ import {
29+ DownloadStatus ,
30+ DownloadType ,
31+ } from '@/domain/models/download.interface' ;
3032import { EventEmitter2 } from '@nestjs/event-emitter' ;
3133import { ModelEvent , ModelId , ModelStatus } from '@/domain/models/model.event' ;
3234import { DownloadManagerService } from '@/infrastructure/services/download-manager/download-manager.service' ;
@@ -35,6 +37,7 @@ import { Engines } from '@/infrastructure/commanders/types/engine.interface';
3537import { load } from 'js-yaml' ;
3638import { llamaModelFile } from '@/utils/app-path' ;
3739import { CortexUsecases } from '../cortex/cortex.usecases' ;
40+ import { isLocalFile } from '@/utils/urls' ;
3841
3942@Injectable ( )
4043export class ModelsUsecases {
@@ -127,7 +130,9 @@ export class ModelsUsecases {
127130 ) ) as EngineExtension | undefined ;
128131
129132 if ( engine ) {
130- await engine . unloadModel ( id , model . engine || Engines . llamaCPP ) . catch ( ( ) => { } ) ; // Silent fail
133+ await engine
134+ . unloadModel ( id , model . engine || Engines . llamaCPP )
135+ . catch ( ( ) => { } ) ; // Silent fail
131136 }
132137 return this . modelRepository
133138 . remove ( id )
@@ -174,7 +179,7 @@ export class ModelsUsecases {
174179 }
175180
176181 // Attempt to start cortex
177- await this . cortexUsecases . startCortex ( )
182+ await this . cortexUsecases . startCortex ( ) ;
178183
179184 const loadingModelSpinner = ora ( 'Loading model...' ) . start ( ) ;
180185 // update states and emitting event
@@ -341,10 +346,26 @@ export class ModelsUsecases {
341346 ) {
342347 const modelId = persistedModelId ?? originModelId ;
343348 const existingModel = await this . findOne ( modelId ) ;
349+
344350 if ( isLocalModel ( existingModel ?. files ) ) {
345351 throw new BadRequestException ( 'Model already exists' ) ;
346352 }
347353
354+ // Pull a local model file
355+ if ( isLocalFile ( originModelId ) ) {
356+ await this . populateHuggingFaceModel ( originModelId , persistedModelId ) ;
357+ this . eventEmitter . emit ( 'download.event' , [
358+ {
359+ id : modelId ,
360+ type : DownloadType . Model ,
361+ status : DownloadStatus . Downloaded ,
362+ progress : 100 ,
363+ children : [ ] ,
364+ } ,
365+ ] ) ;
366+ return ;
367+ }
368+
348369 const modelsContainerDir = await this . fileManagerService . getModelsPath ( ) ;
349370
350371 if ( ! existsSync ( modelsContainerDir ) ) {
@@ -422,22 +443,18 @@ export class ModelsUsecases {
422443 model . model = modelId ;
423444 if ( ! ( await this . findOne ( modelId ) ) ) await this . create ( model ) ;
424445 } else {
425- await this . populateHuggingFaceModel ( modelId , files [ 0 ] ) ;
426- const model = await this . findOne ( modelId ) ;
427- if ( model ) {
428- const fileUrl = join (
429- await this . fileManagerService . getModelsPath ( ) ,
430- normalizeModelId ( modelId ) ,
431- basename (
432- files . find ( ( e ) => e . rfilename . endsWith ( '.gguf' ) ) ?. rfilename ??
433- files [ 0 ] . rfilename ,
434- ) ,
435- ) ;
436- await this . update ( modelId , {
437- files : [ fileUrl ] ,
438- name : modelId . replace ( ':main' , '' ) ,
439- } ) ;
440- }
446+ const fileUrl = join (
447+ await this . fileManagerService . getModelsPath ( ) ,
448+ normalizeModelId ( modelId ) ,
449+ basename (
450+ files . find ( ( e ) => e . rfilename . endsWith ( '.gguf' ) ) ?. rfilename ??
451+ files [ 0 ] . rfilename ,
452+ ) ,
453+ ) ;
454+ await this . populateHuggingFaceModel (
455+ fileUrl ,
456+ modelId . replace ( ':main' , '' ) ,
457+ ) ;
441458 }
442459 uploadModelMetadataSpiner . succeed ( 'Model metadata updated' ) ;
443460 const modelEvent : ModelEvent = {
@@ -458,21 +475,18 @@ export class ModelsUsecases {
458475 * It could be a model from Jan's repo or other authors
459476 * @param modelId HuggingFace model id, e.g. "janhq/llama-3" or "llama3:7b"
460477 */
461- async populateHuggingFaceModel (
462- modelId : string ,
463- modelVersion : HuggingFaceModelVersion ,
464- ) {
465- if ( ! modelVersion ) throw 'No expected quantization found' ;
466-
467- const tokenizer = await getHFModelMetadata ( modelVersion . downloadUrl ! ) ;
478+ async populateHuggingFaceModel ( ggufUrl : string , overridenId ?: string ) {
479+ const metadata = await getHFModelMetadata ( ggufUrl ) ;
468480
469- const stopWords : string [ ] = tokenizer ?. stopWord ? [ tokenizer . stopWord ] : [ ] ;
481+ const stopWords : string [ ] = metadata ?. stopWord ? [ metadata . stopWord ] : [ ] ;
470482
483+ const modelId =
484+ overridenId ?? ( isLocalFile ( ggufUrl ) ? parse ( ggufUrl ) . name : ggufUrl ) ;
471485 const model : CreateModelDto = {
472- files : [ modelVersion . downloadUrl ?? '' ] ,
486+ files : [ ggufUrl ] ,
473487 model : modelId ,
474- name : modelId ,
475- prompt_template : tokenizer ?. promptTemplate ,
488+ name : metadata ?. name ?? modelId ,
489+ prompt_template : metadata ?. promptTemplate ,
476490 stop : stopWords ,
477491
478492 // Default Inference Params
0 commit comments