@@ -51,30 +51,55 @@ export class ModelsCliUsecases {
5151 private readonly inquirerService : InquirerService ,
5252 ) { }
5353
/**
 * Start the model identified by `modelId`.
 *
 * The model is looked up through `getModelOrStop` first; the start request
 * is only issued once that lookup has resolved.
 *
 * @param modelId Identifier of the model to start.
 */
async startModel(modelId: string): Promise<void> {
  // Lookup comes first so that nothing is started if it does not resolve.
  await this.getModelOrStop(modelId);

  const { modelsUsecases } = this;
  await modelsUsecases.startModel(modelId);
}
5862
/**
 * Stop the model identified by `modelId`.
 *
 * The model is looked up through `getModelOrStop` first; the stop request
 * is only issued once that lookup has resolved.
 *
 * @param modelId Identifier of the model to stop.
 */
async stopModel(modelId: string): Promise<void> {
  // Lookup comes first so that nothing is stopped if it does not resolve.
  await this.getModelOrStop(modelId);

  const { modelsUsecases } = this;
  await modelsUsecases.stopModel(modelId);
}
6371
/**
 * Update a model's settings (e.g. ngl, prompt_template).
 *
 * Pure delegation to `modelsUsecases.updateModelSettingParams`; no lookup
 * of the model is performed here.
 *
 * @param modelId Identifier of the model to update.
 * @param settingParams Setting values to apply.
 * @returns Whatever the underlying use case resolves with.
 */
async updateModelSettingParams(
  modelId: string,
  settingParams: ModelSettingParams,
): Promise<ModelSettingParams> {
  const updatedSettings = await this.modelsUsecases.updateModelSettingParams(
    modelId,
    settingParams,
  );
  return updatedSettings;
}
7084
/**
 * Update a model's runtime parameters (e.g. max_tokens, temperature).
 *
 * Pure delegation to `modelsUsecases.updateModelRuntimeParams`; no lookup
 * of the model is performed here.
 *
 * @param modelId Identifier of the model to update.
 * @param runtimeParams Runtime parameter values to apply.
 * @returns Whatever the underlying use case resolves with.
 */
async updateModelRuntimeParams(
  modelId: string,
  runtimeParams: ModelRuntimeParams,
): Promise<ModelRuntimeParams> {
  const updatedParams = await this.modelsUsecases.updateModelRuntimeParams(
    modelId,
    runtimeParams,
  );
  return updatedParams;
}
7797
98+ /**
99+ * Find a model, or abort if it does not exist
100+ * @param modelId
101+ * @returns
102+ */
78103 private async getModelOrStop ( modelId : string ) : Promise < Model > {
79104 const model = await this . modelsUsecases . findOne ( modelId ) ;
80105 if ( ! model ) {
@@ -84,25 +109,42 @@ export class ModelsCliUsecases {
84109 return model ;
85110 }
86111
/**
 * List every model known to `modelsUsecases`.
 *
 * @returns The result of `modelsUsecases.findAll()`.
 */
async listAllModels(): Promise<Model[]> {
  const models = await this.modelsUsecases.findAll();
  return models;
}
90119
/**
 * Get a single model by ID.
 *
 * Thin public wrapper around the private `getModelOrStop` lookup.
 *
 * @param modelId Identifier of the model to fetch.
 * @returns The model resolved by `getModelOrStop`.
 */
async getModel(modelId: string): Promise<Model> {
  return this.getModelOrStop(modelId);
}
95129
/**
 * Remove a model by ID.
 *
 * NOTE(review): the original comment states that removal also deletes the
 * model files; that happens (if at all) inside `modelsUsecases.remove`,
 * which is not visible here — confirm.
 *
 * @param modelId Identifier of the model to remove.
 * @returns Whatever `modelsUsecases.remove` resolves with.
 */
async removeModel(modelId: string) {
  // Lookup comes first so that nothing is removed if it does not resolve.
  await this.getModelOrStop(modelId);

  const removalResult = await this.modelsUsecases.remove(modelId);
  return removalResult;
}
100139
140+ /**
141+ * Pull model from Model repository (HF, Jan...)
142+ * @param modelId
143+ */
101144 async pullModel ( modelId : string ) {
102- if ( modelId . includes ( '/' ) ) {
145+ if ( modelId . includes ( '/' ) || modelId . includes ( ':' ) ) {
103146 await this . pullHuggingFaceModel ( modelId ) ;
104147 }
105-
106148 const bar = new SingleBar ( { } , Presets . shades_classic ) ;
107149 bar . start ( 100 , 0 ) ;
108150 const callback = ( progress : number ) => {
@@ -111,21 +153,43 @@ export class ModelsCliUsecases {
111153 await this . modelsUsecases . downloadModel ( modelId , callback ) ;
112154 }
113155
114- private async pullHuggingFaceModel ( modelId : string ) {
115- const data = await this . fetchHuggingFaceRepoData ( modelId ) ;
116- const { quantization } = await this . inquirerService . inquirer . prompt ( {
117- type : 'list' ,
118- name : 'quantization' ,
119- message : 'Select quantization' ,
120- choices : data . siblings
121- . map ( ( e ) => e . quantization )
122- . filter ( ( e ) => e != null ) ,
123- } ) ;
156+ //// PRIVATE METHODS ////
124157
125- const sibling = data . siblings
126- . filter ( ( e ) => ! ! e . quantization )
127- . find ( ( e : any ) => e . quantization === quantization ) ;
158+ /**
159+ * Pull a model from a HuggingFace repository.
160+ * The model can come from Jan's repo or from another author's repo.
161+ * @param modelId HuggingFace model id, e.g. "janhq/llama-3" or "llama3:7b"
162+ */
163+ private async pullHuggingFaceModel ( modelId : string ) {
164+ let data : HuggingFaceRepoData ;
165+ if ( modelId . includes ( '/' ) )
166+ data = await this . fetchHuggingFaceRepoData ( modelId ) ;
167+ else data = await this . fetchJanRepoData ( modelId ) ;
168+
169+ let sibling ;
170+
171+ const listChoices = data . siblings
172+ . filter ( ( e ) => e . quantization != null )
173+ . map ( ( e ) => {
174+ return {
175+ name : e . quantization ,
176+ value : e . quantization ,
177+ } ;
178+ } ) ;
128179
180+ if ( listChoices . length > 1 ) {
181+ const { quantization } = await this . inquirerService . inquirer . prompt ( {
182+ type : 'list' ,
183+ name : 'quantization' ,
184+ message : 'Select quantization' ,
185+ choices : listChoices ,
186+ } ) ;
187+ sibling = data . siblings
188+ . filter ( ( e ) => ! ! e . quantization )
189+ . find ( ( e : any ) => e . quantization === quantization ) ;
190+ } else {
191+ sibling = data . siblings . find ( ( e ) => e . rfilename . includes ( '.gguf' ) ) ;
192+ }
129193 if ( ! sibling ) throw 'No expected quantization found' ;
130194
131195 let stopWord = '' ;
@@ -141,9 +205,7 @@ export class ModelsCliUsecases {
141205
142206 // @ts -expect-error "tokenizer.ggml.tokens"
143207 stopWord = metadata [ 'tokenizer.ggml.tokens' ] [ index ] ?? '' ;
144- } catch ( err ) {
145- console . log ( 'Failed to get stop word: ' , err ) ;
146- }
208+ } catch ( err ) { }
147209
148210 const stopWords : string [ ] = [ ] ;
149211 if ( stopWord . length > 0 ) {
@@ -163,6 +225,7 @@ export class ModelsCliUsecases {
163225 description : '' ,
164226 settings : {
165227 prompt_template : promptTemplate ,
228+ llama_model_path : sibling . rfilename ,
166229 } ,
167230 parameters : {
168231 stop : stopWords ,
@@ -209,8 +272,71 @@ export class ModelsCliUsecases {
209272 }
210273 }
211274
/**
 * Fetch a model's file listing from Jan's HuggingFace organisation
 * (`janhq`) and shape it like a regular HuggingFace repo payload.
 *
 * @param modelId Jan model id in `<repo>:<branch>` form, e.g. "llama-3:7b".
 *   When the branch part is missing or empty, `main` is used so that the
 *   listing and download URLs stay well-formed.
 * @returns Repo data whose `siblings` carry a download URL, a file size and,
 *   when the file name contains a known quantization, a `quantization` tag.
 * @throws Error when the API answers with an `error` payload.
 */
private async fetchJanRepoData(modelId: string) {
  // Split once. An id such as "llama-3:" yields an empty branch, which used
  // to produce ".../resolve//<file>" download URLs and an empty file list.
  const [repo, branch] = modelId.split(':');
  const tree = branch || 'main';

  const url = this.getRepoModelsUrl(`janhq/${repo}`, tree);
  const res = await fetch(url);
  const response:
    | {
        path: string;
        size: number;
      }[]
    | { error: string } = await res.json();

  if ('error' in response && response.error != null) {
    throw new Error(response.error);
  }

  // Anything that is not a file listing is treated as "no files".
  const files = Array.isArray(response) ? response : [];

  const data: HuggingFaceRepoData = {
    siblings: files.map((file) => ({
      rfilename: file.path,
      downloadUrl: `https://huggingface.co/janhq/${repo}/resolve/${tree}/${file.path}`,
      fileSize: file.size ?? 0,
    })),
    tags: ['gguf'],
    id: modelId,
    modelId: modelId,
    author: 'janhq',
    sha: '',
    downloads: 0,
    lastModified: '',
    private: false,
    disabled: false,
    gated: false,
    pipeline_tag: 'text-generation',
    cardData: {},
    createdAt: '',
  };

  // Tag each file with the first known quantization (in AllQuantizations
  // order) that appears in its file name; files without a match stay untagged.
  for (const sibling of data.siblings) {
    if (sibling.quantization) continue;
    const match = AllQuantizations.find((quantization) =>
      sibling.rfilename.includes(quantization),
    );
    if (match) {
      sibling.quantization = match;
    }
  }

  data.modelUrl = url;
  return data;
}
332+
333+ /**
334+ * Fetches the model data from HuggingFace API
335+ * @param repoId HuggingFace model id. e.g. "janhq/llama-3"
336+ * @returns
337+ */
212338 private async fetchHuggingFaceRepoData ( repoId : string ) {
213- const sanitizedUrl = this . toHuggingFaceUrl ( repoId ) ;
339+ const sanitizedUrl = this . getRepoModelsUrl ( repoId ) ;
214340
215341 const res = await fetch ( sanitizedUrl ) ;
216342 const response = await res . json ( ) ;
@@ -245,24 +371,7 @@ export class ModelsCliUsecases {
245371 return data ;
246372 }
247373
248- private toHuggingFaceUrl ( repoId : string ) : string {
249- try {
250- const url = new URL ( `https://huggingface.co/${ repoId } ` ) ;
251- if ( url . host !== 'huggingface.co' ) {
252- throw `Invalid Hugging Face repo URL: ${ repoId } ` ;
253- }
254-
255- const paths = url . pathname . split ( '/' ) . filter ( ( e ) => e . trim ( ) . length > 0 ) ;
256- if ( paths . length < 2 ) {
257- throw `Invalid Hugging Face repo URL: ${ repoId } ` ;
258- }
259-
260- return `${ url . origin } /api/models/${ paths [ 0 ] } /${ paths [ 1 ] } ` ;
261- } catch ( err ) {
262- if ( repoId . startsWith ( 'https' ) ) {
263- throw new Error ( `Cannot parse url: ${ repoId } ` ) ;
264- }
265- throw err ;
266- }
/**
 * Build the HuggingFace models API URL for a repository.
 *
 * @param repoId Repository id, e.g. "janhq/llama-3"; inserted into the URL as-is.
 * @param tree Optional branch/revision; when truthy, `/tree/<tree>` is
 *   appended so the URL targets that revision's file listing.
 * @returns The API URL as a string.
 */
private getRepoModelsUrl(repoId: string, tree?: string): string {
  const repoUrl = `https://huggingface.co/api/models/${repoId}`;
  if (!tree) {
    return repoUrl;
  }
  return `${repoUrl}/tree/${tree}`;
}
268377}
0 commit comments