Skip to content

Commit

Permalink
simplify nvrtc major, minor versions (pytorch#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-xq authored and Mikhail Zolotukhin committed Feb 18, 2020
1 parent 278cd37 commit fd2439b
Showing 1 changed file with 22 additions and 29 deletions.
51 changes: 22 additions & 29 deletions torch/csrc/jit/tensorexpr/cuda_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,36 +61,29 @@ static void getMajorMinor(
const cudaDeviceProp* const prop,
int& major,
int& minor) {
int nvrtc_major, nvrtc_minor;
AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcVersion(&nvrtc_major, &nvrtc_minor));

// Short-circuits if NVRTC version too low
AT_ASSERT(nvrtc_major >= 6);

// Major and minor is determined by device properties and
// possibly "downcompiled" to a lower (compatible) compute architecture
// based on the NVRTC version
major = prop->major;
minor = prop->minor;
if (nvrtc_major <= 7 && prop->major > 5) { // 7 supports 2-5.x
major = 5;
minor = 0;
} else if (nvrtc_major <= 8 && prop->major > 6) { // 8 supports 2-6.x
major = 6;
minor = 0;
} else if (nvrtc_major <= 9 && prop->major >= 7) { // 9 supports 3-7.2
major = 7;
if (prop->major == 7 && prop->minor <= 2)
minor = prop->minor;
else
minor = 0;
} else if (nvrtc_major <= 10 && prop->major >= 7) { // 10 supports 3-7.5
major = 7;
if (prop->major == 7 && prop->minor <= 5)
minor = prop->minor;
else
minor = 0;
using CudaVersion = std::pair<int, int>;
CudaVersion nvrtc_version;
AT_CUDA_NVRTC_CHECK(
nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second));

AT_ASSERT(nvrtc_version.first >= 6);

CudaVersion dev_version = CudaVersion(prop->major, prop->minor);
CudaVersion max_dev_version(dev_version);
if (nvrtc_version.first <= 7) { // 7 supports 2-5.x
max_dev_version = CudaVersion(5, 0);
} else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x
max_dev_version = CudaVersion(6, 0);
} else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2
max_dev_version = CudaVersion(7, 2);
} else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5
max_dev_version = CudaVersion(7, 5);
}
if (dev_version > max_dev_version) {
dev_version = max_dev_version;
}
major = dev_version.first;
minor = dev_version.second;
}

void CudaPrinter::visit(const For* v) {
Expand Down

0 comments on commit fd2439b

Please sign in to comment.