|
1 | 1 | import { CommandRunner, SubCommand } from 'nest-commander'; |
2 | | -import { exit } from 'node:process'; |
| 2 | +import { exit, stdin, stdout } from 'node:process'; |
3 | 3 | import { ModelsCliUsecases } from '../usecases/models.cli.usecases'; |
| 4 | +import { RepoDesignation, listFiles } from '@huggingface/hub'; |
| 5 | +import YAML from 'yaml'; |
| 6 | +import * as readline from 'node:readline/promises'; |
| 7 | +import { basename } from 'node:path'; |
4 | 8 |
|
5 | 9 | @SubCommand({ |
6 | 10 | name: 'pull', |
7 | 11 | aliases: ['download'], |
8 | 12 | description: 'Download a model. Working with HuggingFace model id.', |
9 | 13 | }) |
10 | 14 | export class ModelPullCommand extends CommandRunner { |
| 15 | + private metadataFileName = 'metadata.yaml'; |
| 16 | + |
11 | 17 | constructor(private readonly modelsCliUsecases: ModelsCliUsecases) { |
12 | 18 | super(); |
13 | 19 | } |
14 | 20 |
|
15 | 21 | async run(input: string[]) { |
16 | 22 | if (input.length < 1) { |
17 | | - console.error('Model ID is required'); |
| 23 | + console.error('Model Id is required'); |
18 | 24 | exit(1); |
19 | 25 | } |
20 | 26 |
|
21 | | - await this.modelsCliUsecases.pullModel(input[0]); |
| 27 | + // Check if metadata.yaml file exist |
| 28 | + const metadata = await this.getJanMetadata(input[0]); |
| 29 | + |
| 30 | + if (!metadata) { |
| 31 | + await this.modelsCliUsecases.pullModel(input[0]); |
| 32 | + } else { |
| 33 | + await this.handleJanHqModel(input[0], metadata); |
| 34 | + } |
| 35 | + |
22 | 36 | console.log('\nDownload complete!'); |
23 | 37 | exit(0); |
24 | 38 | } |
| 39 | + |
| 40 | + private async getJanMetadata(input: string): Promise<any> { |
| 41 | + // try to append with janhq/ if it's not already |
| 42 | + const sanitizedInput = input.trim().startsWith('janhq/') |
| 43 | + ? input |
| 44 | + : `janhq/${input}`; |
| 45 | + |
| 46 | + const repo: RepoDesignation = { type: 'model', name: sanitizedInput }; |
| 47 | + let isMetadataFileExist = false; |
| 48 | + for await (const fileInfo of listFiles({ repo })) { |
| 49 | + if (fileInfo.path === this.metadataFileName) { |
| 50 | + isMetadataFileExist = true; |
| 51 | + break; |
| 52 | + } |
| 53 | + } |
| 54 | + |
| 55 | + if (!isMetadataFileExist) { |
| 56 | + return undefined; |
| 57 | + } |
| 58 | + |
| 59 | + const path = `https://huggingface.co/${sanitizedInput}/raw/main/${this.metadataFileName}`; |
| 60 | + const res = await fetch(path); |
| 61 | + const metadataJson = await res.text(); |
| 62 | + const parsedMetadata = YAML.parse(metadataJson); |
| 63 | + return parsedMetadata; |
| 64 | + } |
| 65 | + |
| 66 | + private async versionInquiry(tags: string[]): Promise<string> { |
| 67 | + return new Promise((resolve) => { |
| 68 | + let selectedTag = 'default'; |
| 69 | + let prompt = 'Select the version you want to download:\n'; |
| 70 | + for (let i = 0; i < tags.length; i++) { |
| 71 | + prompt += `${i}. ${tags[i]}\n`; |
| 72 | + } |
| 73 | + prompt += '>> '; |
| 74 | + |
| 75 | + const rl = readline.createInterface({ |
| 76 | + input: stdin, |
| 77 | + output: stdout, |
| 78 | + prompt: prompt, |
| 79 | + }); |
| 80 | + rl.prompt(); |
| 81 | + |
| 82 | + rl.on('close', () => { |
| 83 | + resolve(selectedTag); |
| 84 | + }); |
| 85 | + |
| 86 | + rl.on('line', (input) => { |
| 87 | + if (input.trim().length === 0) { |
| 88 | + rl.close(); |
| 89 | + } |
| 90 | + |
| 91 | + try { |
| 92 | + if (Number(input) >= 0 && Number(input) < tags.length) { |
| 93 | + selectedTag = tags[Number(input)]; |
| 94 | + rl.close(); |
| 95 | + } else { |
| 96 | + console.error('Invalid version'); |
| 97 | + rl.prompt(); |
| 98 | + } |
| 99 | + } catch (e) { |
| 100 | + console.error('Invalid version'); |
| 101 | + rl.prompt(); |
| 102 | + } |
| 103 | + }); |
| 104 | + }); |
| 105 | + } |
| 106 | + |
| 107 | + private async handleJanHqModel(repoName: string, metadata: any) { |
| 108 | + const sanitizedRepoName = repoName.trim().startsWith('janhq/') |
| 109 | + ? repoName |
| 110 | + : `janhq/${repoName}`; |
| 111 | + |
| 112 | + const tags = metadata.tags; |
| 113 | + let selectedTag = 'default'; |
| 114 | + const allTags: string[] = Object.keys(tags); |
| 115 | + |
| 116 | + if (allTags.length > 1) { |
| 117 | + selectedTag = await this.versionInquiry(allTags); |
| 118 | + } |
| 119 | + |
| 120 | + const branch = selectedTag; |
| 121 | + const engine = 'llamacpp'; // TODO: currently, we only support llamacpp |
| 122 | + |
| 123 | + const revision = metadata.tags?.[branch]?.[engine]; |
| 124 | + if (!revision) { |
| 125 | + console.error("Can't find model revision."); |
| 126 | + exit(1); |
| 127 | + } |
| 128 | + |
| 129 | + const repo: RepoDesignation = { type: 'model', name: sanitizedRepoName }; |
| 130 | + let ggufUrl: string | undefined = undefined; |
| 131 | + for await (const fileInfo of listFiles({ |
| 132 | + repo: repo, |
| 133 | + revision: revision, |
| 134 | + })) { |
| 135 | + if (fileInfo.path.endsWith('.gguf')) { |
| 136 | + ggufUrl = `https://huggingface.co/${sanitizedRepoName}/resolve/${revision}/${fileInfo.path}`; |
| 137 | + break; |
| 138 | + } |
| 139 | + } |
| 140 | + |
| 141 | + if (!ggufUrl) { |
| 142 | + console.error("Can't find model file."); |
| 143 | + exit(1); |
| 144 | + } |
| 145 | + console.log('Downloading', basename(ggufUrl)); |
| 146 | + await this.modelsCliUsecases.pullModelWithExactUrl( |
| 147 | + `${sanitizedRepoName}/${revision}`, |
| 148 | + ggufUrl, |
| 149 | + ); |
| 150 | + } |
25 | 151 | } |
0 commit comments