# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Automatic detect target for testing
"""
import logging
import os
from subprocess import PIPE, Popen
from aitemplate.backend.target import CUDA, ROCM
# pylint: disable=W0702, W0612,R1732
_LOGGER = logging.getLogger(__name__)
# Process-wide cache of the last successful detection, so detect_target()
# only probes the hardware once.
IS_CUDA = None  # True after CUDA detection, False after ROCm, None before any
FLAG = ""  # detected arch string ("70"/"80"/"gfx90a"/...); "" until detection runs
def _detect_cuda_with_nvidia_smi():
try:
proc = Popen(
["nvidia-smi", "--query-gpu=gpu_name", "--format=csv"],
stdout=PIPE,
stderr=PIPE,
)
stdout, stderr = proc.communicate()
stdout = stdout.decode("utf-8")
sm_names = {
"70": ["V100"],
"75": ["T4", "Quadro T2000"],
"80": ["PG509", "A100", "A800", "A10G", "RTX 30", "A30", "RTX 40"],
"90": ["H100", "H800"],
}
for sm, names in sm_names.items():
if any(name in stdout for name in names):
return sm
return None
except Exception:
return None
def _detect_cuda():
    """Detect the CUDA compute arch of visible device 0.

    Prefers the ``cuda`` Python bindings; when they are not installed,
    falls back to parsing ``nvidia-smi`` output.

    Returns
    -------
    str or None
        "70"/"75"/"80"/"90", or None when no supported device is found
        or any CUDA call fails.
    """
    try:
        from cuda import cuda

        def _check(res):
            # cuda-python calls return (error, *values); error 0 is success.
            if res[0].value != 0:
                raise RuntimeError(f"CUDA error code={res[0].value}")
            return res[1:]

        _check(cuda.cuInit(0))
        # Compute capability of the first visible device, e.g. (8, 0) -> 80.
        major, minor = _check(cuda.cuDeviceComputeCapability(0))
        capability = major * 10 + minor
        # Highest supported arch not exceeding the device capability.
        for threshold, arch in ((90, "90"), (80, "80"), (75, "75"), (70, "70")):
            if capability >= threshold:
                return arch
        return None
    except ImportError:
        # go back to old way to detect the CUDA arch
        return _detect_cuda_with_nvidia_smi()
    except Exception:
        return None
def _detect_rocm():
try:
proc = Popen(["rocminfo"], stdout=PIPE, stderr=PIPE)
stdout, stderr = proc.communicate()
stdout = stdout.decode("utf-8")
if "gfx90a" in stdout:
return "gfx90a"
if "gfx908" in stdout:
return "gfx908"
return None
except Exception:
return None
def detect_target(**kwargs):
    """Detect GPU target based on nvidia-smi and rocminfo

    Returns
    -------
    Target
        CUDA or ROCM target

    Raises
    ------
    RuntimeError
        When neither a CUDA nor a ROCm device can be detected.
    """
    global IS_CUDA, FLAG
    # Reuse the cached result of a previous successful detection.
    if FLAG:
        return CUDA(arch=FLAG, **kwargs) if IS_CUDA else ROCM(arch=FLAG, **kwargs)
    # Doc builds run without GPUs; assume an sm80 machine.
    if os.getenv("AIT_BUILD_DOCS", None) is not None:
        return CUDA(arch="80", **kwargs)
    cuda_arch = _detect_cuda()
    if cuda_arch is not None:
        IS_CUDA = True
        FLAG = cuda_arch
        _LOGGER.info("Set target to CUDA")
        return CUDA(arch=cuda_arch, **kwargs)
    rocm_arch = _detect_rocm()
    if rocm_arch is not None:
        IS_CUDA = False
        FLAG = rocm_arch
        _LOGGER.info("Set target to ROCM")
        return ROCM(arch=rocm_arch, **kwargs)
    raise RuntimeError("Unsupported platform")