|
1 | 1 | import { CommandRunner, SubCommand } from 'nest-commander'; |
2 | 2 | import { exit } from 'node:process'; |
3 | 3 | import { ModelsCliUsecases } from '../usecases/models.cli.usecases'; |
| 4 | +import { RepoDesignation, listFiles } from '@huggingface/hub'; |
| 5 | +import YAML from 'yaml'; |
4 | 6 |
|
5 | 7 | @SubCommand({ |
6 | 8 | name: 'pull', |
7 | 9 | aliases: ['download'], |
8 | 10 | description: 'Download a model. Working with HuggingFace model id.', |
9 | 11 | }) |
10 | 12 | export class ModelPullCommand extends CommandRunner { |
| 13 | + private metadataFileName = 'metadata.yaml'; |
| 14 | + |
11 | 15 | constructor(private readonly modelsCliUsecases: ModelsCliUsecases) { |
12 | 16 | super(); |
13 | 17 | } |
14 | 18 |
|
15 | 19 | async run(input: string[]) { |
16 | 20 | if (input.length < 1) { |
17 | | - console.error('Model ID is required'); |
| 21 | + console.error('Model Id is required'); |
18 | 22 | exit(1); |
19 | 23 | } |
20 | 24 |
|
21 | | - await this.modelsCliUsecases.pullModel(input[0]); |
| 25 | + // Check if metadata.yaml file exist |
| 26 | + const metadata = await this.getJanMetadata(input[0]); |
| 27 | + |
| 28 | + if (!metadata) { |
| 29 | + await this.modelsCliUsecases.pullModel(input[0]); |
| 30 | + } else { |
| 31 | + await this.handleJanHqModel(input[0], metadata); |
| 32 | + } |
| 33 | + |
22 | 34 | console.log('\nDownload complete!'); |
23 | 35 | exit(0); |
24 | 36 | } |
| 37 | + |
| 38 | + private async getJanMetadata(input: string): Promise<any> { |
| 39 | + // try to append with janhq/ if it's not already |
| 40 | + const sanitizedInput = input.trim().startsWith('janhq/') |
| 41 | + ? input |
| 42 | + : `janhq/${input}`; |
| 43 | + |
| 44 | + const repo: RepoDesignation = { type: 'model', name: sanitizedInput }; |
| 45 | + let isMetadataFileExist = false; |
| 46 | + for await (const fileInfo of listFiles({ repo })) { |
| 47 | + if (fileInfo.path === this.metadataFileName) { |
| 48 | + isMetadataFileExist = true; |
| 49 | + break; |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + if (!isMetadataFileExist) { |
| 54 | + return undefined; |
| 55 | + } |
| 56 | + |
| 57 | + const path = `https://huggingface.co/${sanitizedInput}/raw/main/${this.metadataFileName}`; |
| 58 | + const res = await fetch(path); |
| 59 | + const metadataJson = await res.text(); |
| 60 | + const parsedMetadata = YAML.parse(metadataJson); |
| 61 | + return parsedMetadata; |
| 62 | + } |
| 63 | + |
| 64 | + private async handleJanHqModel(repoName: string, metadata: any) { |
| 65 | + // TODO: asking user to choose here |
| 66 | + const sanitizedRepoName = `janhq/${repoName.trim()}`; |
| 67 | + const branch = 'default'; |
| 68 | + const engine = 'llamacpp'; // TODO: currently, we only support llamacpp |
| 69 | + |
| 70 | + const revision = metadata.tags?.[branch]?.[engine]; |
| 71 | + if (!revision) { |
| 72 | + console.error("Can't find model revision."); |
| 73 | + exit(1); |
| 74 | + } |
| 75 | + |
| 76 | + const repo: RepoDesignation = { type: 'model', name: sanitizedRepoName }; |
| 77 | + let ggufUrl: string | undefined = undefined; |
| 78 | + for await (const fileInfo of listFiles({ |
| 79 | + repo: repo, |
| 80 | + revision: revision, |
| 81 | + })) { |
| 82 | + if (fileInfo.path.endsWith('.gguf')) { |
| 83 | + ggufUrl = `https://huggingface.co/${sanitizedRepoName}/resolve/${revision}/${fileInfo.path}`; |
| 84 | + break; |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + if (!ggufUrl) { |
| 89 | + console.error("Can't find model file."); |
| 90 | + exit(1); |
| 91 | + } |
| 92 | + await this.modelsCliUsecases.pullModelWithExactUrl( |
| 93 | + `${sanitizedRepoName}/${revision}`, |
| 94 | + ggufUrl, |
| 95 | + ); |
| 96 | + } |
25 | 97 | } |
0 commit comments