Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 0aa0dd8

Browse files
committed
chore: download model files
1 parent ba6cb15 commit 0aa0dd8

File tree

3 files changed

+77
-4
lines changed

3 files changed

+77
-4
lines changed

cortex-js/src/utils/cuda.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ import { existsSync } from 'fs';
33
import { delimiter } from 'path';
44
import { checkFileExistenceInPaths } from './app-path';
55

6+
export type GpuSettingInfo = {
7+
id: string;
8+
vram: string;
9+
name: string;
10+
arch?: string;
11+
};
12+
613
/**
714
* Return the CUDA version installed on the system
815
* @returns CUDA Version 11 | 12
@@ -63,3 +70,46 @@ export const checkNvidiaGPUExist = (): Promise<boolean> => {
6370
});
6471
});
6572
};
73+
74+
/**
75+
* Get GPU information from the system
76+
* @returns GPU information
77+
*/
78+
export const getGpuInfo = async (): Promise<GpuSettingInfo[]> =>
79+
new Promise((resolve) => {
80+
exec(
81+
'nvidia-smi --query-gpu=index,memory.total,name --format=csv,noheader,nounits',
82+
async (error, stdout) => {
83+
if (!error) {
84+
// Get GPU info and gpu has higher memory first
85+
let highestVram = 0;
86+
let highestVramId = '0';
87+
const gpus: GpuSettingInfo[] = stdout
88+
.trim()
89+
.split('\n')
90+
.map((line) => {
91+
let [id, vram, name] = line.split(', ');
92+
const arch = getGpuArch(name);
93+
vram = vram.replace(/\r/g, '');
94+
if (parseFloat(vram) > highestVram) {
95+
highestVram = parseFloat(vram);
96+
highestVramId = id;
97+
}
98+
return { id, vram, name, arch };
99+
});
100+
101+
resolve(gpus);
102+
} else {
103+
resolve([]);
104+
}
105+
},
106+
);
107+
});
108+
109+
const getGpuArch = (gpuName: string): string => {
110+
if (!gpuName.toLowerCase().includes('nvidia')) return 'unknown';
111+
112+
if (gpuName.includes('30')) return 'ampere';
113+
else if (gpuName.includes('40')) return 'ada';
114+
else return 'unknown';
115+
};

cortex-js/src/utils/huggingface.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import {
2020
} from '@/infrastructure/constants/prompt-constants';
2121
import { gguf } from '@huggingface/gguf';
2222
import axios from 'axios';
23+
import { parseModelHubEngineBranch } from './normalize-model-id';
2324

2425
// TODO: move this to somewhere else, should be reused by API as well. Maybe in a separate service / provider?
2526
export function guessPromptTemplateFromHuggingFace(jinjaCode?: string): string {
@@ -64,7 +65,6 @@ export function guessPromptTemplateFromHuggingFace(jinjaCode?: string): string {
6465
export async function fetchHuggingFaceRepoData(
6566
repoId: string,
6667
): Promise<HuggingFaceRepoData> {
67-
6868
const sanitizedUrl = getRepoModelsUrl(repoId);
6969

7070
const { data: response } = await axios.get(sanitizedUrl);
@@ -113,7 +113,7 @@ export async function fetchJanRepoData(
113113
modelId: string,
114114
): Promise<HuggingFaceRepoData> {
115115
const repo = modelId.split(':')[0];
116-
const tree = modelId.split(':')[1] ?? 'default';
116+
const tree = await parseModelHubEngineBranch(modelId.split(':')[1] ?? 'default');
117117
const url = getRepoModelsUrl(`cortexhub/${repo}`, tree);
118118

119119
const res = await fetch(url);
@@ -164,8 +164,6 @@ export async function fetchJanRepoData(
164164

165165
data.modelUrl = url;
166166

167-
168-
169167
return data;
170168
}
171169

cortex-js/src/utils/normalize-model-id.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { ModelArtifact } from '@/domain/models/model.interface';
2+
import { getGpuInfo } from './cuda';
23

34
export const normalizeModelId = (modelId: string): string => {
45
return modelId.replace(':default', '').replace(/[:/]/g, '-');
@@ -13,3 +14,27 @@ export const isLocalModel = (
1314
!/^(http|https):\/\/[^/]+\/.*/.test(modelFiles[0])
1415
);
1516
};
17+
18+
/**
19+
* Parse the model hub engine branch
20+
* @param branch
21+
* @returns
22+
*/
23+
export const parseModelHubEngineBranch = async (
24+
branch: string,
25+
): Promise<string> => {
26+
if (branch.includes('tensorrt')) {
27+
let engineBranch = branch;
28+
const platform = process.platform == 'win32' ? 'windows' : 'linux';
29+
if (!engineBranch.includes(platform)) {
30+
engineBranch += `-${platform}`;
31+
}
32+
33+
const gpus = await getGpuInfo();
34+
if (gpus[0]?.arch && !engineBranch.includes(gpus[0].arch)) {
35+
engineBranch += `-${gpus[0].arch}`;
36+
}
37+
return engineBranch;
38+
}
39+
return branch;
40+
};

0 commit comments

Comments
 (0)