@@ -16,9 +16,13 @@ import { join, basename } from 'path';
1616import { load } from 'js-yaml' ;
1717import { existsSync , readdirSync , readFileSync } from 'fs' ;
1818import { isLocalModel , normalizeModelId } from '@/utils/normalize-model-id' ;
19- import { getHFModelMetadata } from '@/utils/huggingface' ;
19+ import {
20+ fetchJanRepoData ,
21+ getHFModelMetadata ,
22+ } from '@/utils/huggingface' ;
2023import { createWriteStream , mkdirSync , promises } from 'node:fs' ;
2124import { firstValueFrom } from 'rxjs' ;
25+ import { listFiles , downloadFile } from '@huggingface/hub' ;
2226
2327@Injectable ( )
2428export class ModelsCliUsecases {
@@ -120,8 +124,8 @@ export class ModelsCliUsecases {
120124 process . exit ( 1 ) ;
121125 }
122126
123- if ( modelId . includes ( 'onnx' ) ) {
124- await this . pullOnnxModel ( modelId ) ;
127+ if ( modelId . includes ( 'onnx' ) || modelId . includes ( 'tensorrt' ) ) {
128+ await this . pullEngineModelFiles ( modelId ) ;
125129 } else {
126130 await this . pullGGUFModel ( modelId ) ;
127131 const bar = new SingleBar ( { } , Presets . shades_classic ) ;
@@ -151,10 +155,10 @@ export class ModelsCliUsecases {
151155 }
152156
153157 /**
154- * It's to pull ONNX model from HuggingFace repository
158+ * It's to pull engine model files from HuggingFace repository
155159 * @param modelId
156160 */
157- private async pullOnnxModel ( modelId : string ) {
161+ private async pullEngineModelFiles ( modelId : string ) {
158162 const modelsContainerDir = await this . fileService . getModelsPath ( ) ;
159163
160164 if ( ! existsSync ( modelsContainerDir ) ) {
@@ -164,35 +168,22 @@ export class ModelsCliUsecases {
164168 const modelFolder = join ( modelsContainerDir , normalizeModelId ( modelId ) ) ;
165169 await promises . mkdir ( modelFolder , { recursive : true } ) . catch ( ( ) => { } ) ;
166170
167- const files = [
168- 'genai_config.json' ,
169- 'model.onnx' ,
170- 'model.onnx.data' ,
171- 'model.yml' ,
172- 'special_tokens_map.json' ,
173- 'tokenizer.json' ,
174- 'tokenizer_config.json' ,
175- ] ;
176- const repo = modelId . split ( ':' ) [ 0 ] ;
177- const branch = modelId . split ( ':' ) [ 1 ] || 'default' ;
171+ const files = ( await fetchJanRepoData ( modelId ) ) . siblings ;
178172 for ( const file of files ) {
179- console . log ( `Downloading ${ file } ` ) ;
173+ console . log ( `Downloading ${ file . rfilename } ` ) ;
180174 const bar = new SingleBar ( { } , Presets . shades_classic ) ;
181175 bar . start ( 100 , 0 ) ;
182176 const response = await firstValueFrom (
183- this . httpService . get (
184- `https://huggingface.co/cortexhub/${ repo } /resolve/${ branch } /${ file } ?download=true` ,
185- {
186- responseType : 'stream' ,
187- } ,
188- ) ,
177+ this . httpService . get ( file . downloadUrl ?? '' , {
178+ responseType : 'stream' ,
179+ } ) ,
189180 ) ;
190181 if ( ! response ) {
191182 throw new Error ( 'Failed to download model' ) ;
192183 }
193184
194185 await new Promise ( ( resolve , reject ) => {
195- const writer = createWriteStream ( join ( modelFolder , file ) ) ;
186+ const writer = createWriteStream ( join ( modelFolder , file . rfilename ) ) ;
196187 let receivedBytes = 0 ;
197188 const totalBytes = response . headers [ 'content-length' ] ;
198189
0 commit comments