@@ -16,7 +16,10 @@ import { join, basename } from 'path';
1616import { load } from 'js-yaml' ;
1717import { existsSync , readdirSync , readFileSync } from 'fs' ;
1818import { isLocalModel , normalizeModelId } from '@/utils/normalize-model-id' ;
19- import { getHFModelMetadata } from '@/utils/huggingface' ;
19+ import {
20+ fetchJanRepoData ,
21+ getHFModelMetadata ,
22+ } from '@/utils/huggingface' ;
2023import { createWriteStream , mkdirSync , promises } from 'node:fs' ;
2124import { firstValueFrom } from 'rxjs' ;
2225
@@ -120,8 +123,8 @@ export class ModelsCliUsecases {
120123 process . exit ( 1 ) ;
121124 }
122125
123- if ( modelId . includes ( 'onnx' ) ) {
124- await this . pullOnnxModel ( modelId ) ;
126+ if ( modelId . includes ( 'onnx' ) || modelId . includes ( 'tensorrt' ) ) {
127+ await this . pullEngineModelFiles ( modelId ) ;
125128 } else {
126129 await this . pullGGUFModel ( modelId ) ;
127130 const bar = new SingleBar ( { } , Presets . shades_classic ) ;
@@ -151,10 +154,10 @@ export class ModelsCliUsecases {
151154 }
152155
153156 /**
154- * It's to pull ONNX model from HuggingFace repository
157+ * It's to pull engine model files from HuggingFace repository
155158 * @param modelId
156159 */
157- private async pullOnnxModel ( modelId : string ) {
160+ private async pullEngineModelFiles ( modelId : string ) {
158161 const modelsContainerDir = await this . fileService . getModelsPath ( ) ;
159162
160163 if ( ! existsSync ( modelsContainerDir ) ) {
@@ -164,35 +167,22 @@ export class ModelsCliUsecases {
164167 const modelFolder = join ( modelsContainerDir , normalizeModelId ( modelId ) ) ;
165168 await promises . mkdir ( modelFolder , { recursive : true } ) . catch ( ( ) => { } ) ;
166169
167- const files = [
168- 'genai_config.json' ,
169- 'model.onnx' ,
170- 'model.onnx.data' ,
171- 'model.yml' ,
172- 'special_tokens_map.json' ,
173- 'tokenizer.json' ,
174- 'tokenizer_config.json' ,
175- ] ;
176- const repo = modelId . split ( ':' ) [ 0 ] ;
177- const branch = modelId . split ( ':' ) [ 1 ] || 'default' ;
170+ const files = ( await fetchJanRepoData ( modelId ) ) . siblings ;
178171 for ( const file of files ) {
179- console . log ( `Downloading ${ file } ` ) ;
172+ console . log ( `Downloading ${ file . rfilename } ` ) ;
180173 const bar = new SingleBar ( { } , Presets . shades_classic ) ;
181174 bar . start ( 100 , 0 ) ;
182175 const response = await firstValueFrom (
183- this . httpService . get (
184- `https://huggingface.co/cortexhub/${ repo } /resolve/${ branch } /${ file } ?download=true` ,
185- {
186- responseType : 'stream' ,
187- } ,
188- ) ,
176+ this . httpService . get ( file . downloadUrl ?? '' , {
177+ responseType : 'stream' ,
178+ } ) ,
189179 ) ;
190180 if ( ! response ) {
191181 throw new Error ( 'Failed to download model' ) ;
192182 }
193183
194184 await new Promise ( ( resolve , reject ) => {
195- const writer = createWriteStream ( join ( modelFolder , file ) ) ;
185+ const writer = createWriteStream ( join ( modelFolder , file . rfilename ) ) ;
196186 let receivedBytes = 0 ;
197187 const totalBytes = response . headers [ 'content-length' ] ;
198188
0 commit comments