Skip to content

Commit

Permalink
Extend DispatchStub to support CUDA dispatch (pytorch#9579)
Browse files Browse the repository at this point in the history
Summary:
These are a few files taken from pytorch#8919. They're unchanged from the latest versions of that PR.

```
This is part of pytorch#8919. It's
separated to make it easier to merge the PR in pieces.

There are a few major changes to DispatchStub:

 - The environment variable ATEN_CPU_CAPABILITY overrides the CPU
   capability detection code (previously ATEN_DISABLE_AVX/AVX2)

 - DispatchStub is defined in the generic native code instead of the
   CPU_CAPABILITY_DEFAULT kernel.
```
Pull Request resolved: pytorch#9579

Differential Revision: D8909000

Pulled By: colesbury

fbshipit-source-id: fdeb606270b06acdab3c01dba97ec9d81584ecc0
  • Loading branch information
colesbury authored and facebook-github-bot committed Jul 19, 2018
1 parent a08119a commit bcf0bf4
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 55 deletions.
9 changes: 3 additions & 6 deletions .jenkins/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,10 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then
(cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)")
fi

export ATEN_DISABLE_AVX=
export ATEN_DISABLE_AVX2=
if [[ "${JOB_BASE_NAME}" == *-NO_AVX-* ]]; then
export ATEN_DISABLE_AVX=1
fi
if [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then
export ATEN_DISABLE_AVX2=1
export ATEN_CPU_CAPABILITY=default
elif [[ "${JOB_BASE_NAME}" == *-NO_AVX2-* ]]; then
export ATEN_CPU_CAPABILITY=avx
fi

test_python_nn() {
Expand Down
44 changes: 44 additions & 0 deletions aten/src/ATen/native/DispatchStub.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include "DispatchStub.h"

#include <ATen/Error.h>

#include <cpuinfo.h>
#include <cstdlib>
#include <cstring>

namespace at { namespace native {

// Determines which CPU kernel variant this process should use.
//
// The ATEN_CPU_CAPABILITY environment variable takes precedence over hardware
// detection: the values "avx2", "avx" and "default" are returned as-is (this
// function does not check that the CPU supports the requested level). Any
// other value produces a warning and is ignored. Without a valid override the
// capability is probed through cpuinfo; if cpuinfo cannot be initialised, or
// neither AVX2+FMA3 nor AVX is reported, DEFAULT is returned.
static CPUCapability compute_cpu_capability() {
  const char* envar = std::getenv("ATEN_CPU_CAPABILITY");
  if (envar) {
    // std::-qualified: <cstring> is only guaranteed to declare std::strcmp.
    if (std::strcmp(envar, "avx2") == 0) {
      return CPUCapability::AVX2;
    }
    if (std::strcmp(envar, "avx") == 0) {
      return CPUCapability::AVX;
    }
    if (std::strcmp(envar, "default") == 0) {
      return CPUCapability::DEFAULT;
    }
    AT_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar);
  }

  // Do not use cpuinfo on PowerPC as it shows confusing errors when run on
  // ppc; PowerPC builds fall straight through to DEFAULT.
#ifndef __powerpc__
  if (cpuinfo_initialize()) {
    // AVX2 is only selected when FMA3 is reported as well.
    if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
      return CPUCapability::AVX2;
    }
    if (cpuinfo_has_x86_avx()) {
      return CPUCapability::AVX;
    }
  }
#endif
  return CPUCapability::DEFAULT;
}

// Returns the CPU capability used for kernel selection. The value comes from
// compute_cpu_capability() and is held in a function-local static, so it is
// computed on the first call and the same result is returned afterwards.
CPUCapability get_cpu_capability() {
  static const CPUCapability cached = compute_cpu_capability();
  return cached;
}

}} // namespace at::native
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <cpuinfo.h>
#include <ATen/Error.h>
#include <ATen/ScalarType.h>
#include <type_traits>
#include <iostream>

// Implements instruction set specific function dispatch.
//
Expand All @@ -23,72 +23,82 @@
// REGISTER_DISPATCH(stub, &kernel);
//
// To call:
// stub(tensor);
// stub(kCPU, tensor);
//

namespace at {
namespace native {

enum class CPUCapability { DEFAULT, AVX, AVX2, NUM_OPTIONS };
// CPU instruction-set levels a kernel can be registered for.
//
// The explicit values matter: they are used as indices into
// DispatchStub::table and are compared numerically (capability >= level)
// when a kernel is chosen, so a larger value must denote a more capable
// level.
enum class CPUCapability {
DEFAULT = 0,
AVX = 1,
AVX2 = 2,
// Not a capability: the number of levels above, used to size the table.
NUM_OPTIONS
};

CPUCapability get_cpu_capability();

template <typename FnPtr>
struct DispatchStub {
static_assert(std::is_pointer<FnPtr>::value, "FnPtr should be a pointer type");

template <typename... ArgTypes>
void operator()(ArgTypes... args) {
if (!dispatch_ptr) {
dispatch_ptr = choose_impl();
void operator()(Backend backend, ArgTypes... args) {
if (backend == Backend::CPU) {
if (!dispatch_ptr) {
dispatch_ptr = choose_cpu_impl();
}
(*dispatch_ptr)(args...);
} else if (backend == Backend::CUDA) {
AT_ASSERTM(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel");
(*cuda_dispatch_ptr)(args...);
} else {
AT_ERROR("DispatchStub: unsupported backend", backend);
}
(*dispatch_ptr)(args...);
}

FnPtr choose_impl() {
// Do not use cpuinfo on PowerPC as it shows confusing errors when run on ppc
#ifndef __powerpc__
if (cpuinfo_initialize()) {
int avx2 = static_cast<int>(CPUCapability::AVX2);
if (!std::getenv("ATEN_DISABLE_AVX2") && cpuinfo_has_x86_avx2() &&
cpuinfo_has_x86_fma3() && table[avx2]) {
return table[avx2];
}
int avx = static_cast<int>(CPUCapability::AVX);
if (!std::getenv("ATEN_DISABLE_AVX") && cpuinfo_has_x86_avx() && table[avx]) {
return table[avx];
}
}
#endif
FnPtr choose_cpu_impl() {
int def = static_cast<int>(CPUCapability::DEFAULT);
int avx = static_cast<int>(CPUCapability::AVX);
int avx2 = static_cast<int>(CPUCapability::AVX2);

auto capability = static_cast<int>(get_cpu_capability());
if (capability >= avx2 && table[avx2]) {
return table[avx2];
}
if (capability >= avx && table[avx]) {
return table[avx];
}
AT_ASSERTM(table[def], "DispatchStub: missing default kernel");
return table[def];
}

FnPtr dispatch_ptr = nullptr;
FnPtr cuda_dispatch_ptr = nullptr;
FnPtr table[static_cast<int>(CPUCapability::NUM_OPTIONS)];
};


#if defined(CPU_CAPABILITY)
#if defined(CPU_CAPABILITY) || defined(__CUDACC__)

constexpr CPUCapability CURRENT_CAPABILITY = CPUCapability::CPU_CAPABILITY;
namespace {

// Registers an implementation a kernel for the current CPU capability.
template<typename FnPtr>
template <typename FnPtr>
struct RegisterDispatch {
RegisterDispatch(DispatchStub<FnPtr>& stub, FnPtr value) {
stub.table[static_cast<int>(CURRENT_CAPABILITY)] = value;
#if defined(__CUDACC__)
stub.cuda_dispatch_ptr = value;
#else
int cap = static_cast<int>(CPUCapability::CPU_CAPABILITY);
AT_ASSERT(!stub.table[cap])
stub.table[cap] = value;
#endif
}
};

// We only define the stub once in the DEFAULT capability compilation
#if defined(CPU_CAPABILITY_DEFAULT)
#define _DEFINE_STUB(stub, fn) DispatchStub<decltype(fn)> stub
#else
#define _DEFINE_STUB(stub, fn)
#endif
} // anonymous namespace

#define REGISTER_DISPATCH(stub, fn) \
_DEFINE_STUB(stub, fn); \
static RegisterDispatch<decltype(fn)> stub ## __register(stub, fn);

#endif
Expand Down
11 changes: 7 additions & 4 deletions aten/src/ATen/native/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
namespace at {
namespace native {

DispatchStub<reduce_fn> sum_kernel;
DispatchStub<reduce_fn> prod_kernel;

static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
ScalarType scalarType = self.type().scalarType();
ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType) ? ScalarType::Long : scalarType);
Expand Down Expand Up @@ -127,7 +130,7 @@ Tensor sum(const Tensor &self) {
Tensor _sum_cpu(const Tensor& self) {
if (self.is_contiguous()) {
Tensor result = at::empty({}, self.type());
sum_kernel(result, self, at::nullopt);
sum_kernel(kCPU, result, self, at::nullopt);
return result;
}
return self._sumall();
Expand All @@ -148,7 +151,7 @@ Tensor prod(const Tensor &self) {
Tensor _prod_cpu(const Tensor &self) {
if (self.is_contiguous()) {
Tensor result = at::empty({}, self.type());
prod_kernel(result, self, at::nullopt);
prod_kernel(kCPU, result, self, at::nullopt);
return result;
}
return self._prodall();
Expand Down Expand Up @@ -222,7 +225,7 @@ Tensor &_sum_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
return result;
if (self.is_contiguous() && result.is_contiguous()) {
_dimreduce_setup(result, self, dim);
sum_kernel(result, self, dim);
sum_kernel(kCPU, result, self, dim);
if (!keepdim) result.squeeze_(dim);
return result;
}
Expand Down Expand Up @@ -260,7 +263,7 @@ Tensor &_prod_out_cpu(Tensor &result, const Tensor &self, int64_t dim_,
return result;
if (self.is_contiguous() && result.is_contiguous()) {
_dimreduce_setup(result, self, dim);
prod_kernel(result, self, dim);
prod_kernel(kCPU, result, self, dim);
if (!keepdim) result.squeeze_(dim);
return result;
}
Expand Down
14 changes: 10 additions & 4 deletions aten/src/ATen/native/SoftMax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ Tensor softmax_cpu(const Tensor& input_, const int64_t dim_) {
dim >= 0 && dim < input.dim(),
"dim must be non-negative and less than input dimensions");
if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
softmax_lastdim_kernel(output, input);
softmax_lastdim_kernel(kCPU, output, input);
} else {
AT_DISPATCH_FLOATING_TYPES(input.type(), "softmax", [&] {
host_softmax<scalar_t, false>(output, input, dim);
Expand All @@ -147,7 +147,7 @@ Tensor log_softmax_cpu(const Tensor& input_, const int64_t dim_) {
dim >= 0 && dim < input.dim(),
"dim must be non-negative and less than input dimensions");
if (input.ndimension() > 0 && dim == input.ndimension() - 1) {
log_softmax_lastdim_kernel(output, input);
log_softmax_lastdim_kernel(kCPU, output, input);
} else {
AT_DISPATCH_FLOATING_TYPES(input.type(), "log_softmax", [&] {
host_softmax<scalar_t, true>(output, input, dim);
Expand Down Expand Up @@ -176,7 +176,7 @@ Tensor softmax_backward_cpu(
dim >= 0 && dim < grad.dim(),
"dim must be non-negative and less than input dimensions");
if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
softmax_backward_lastdim_kernel(grad_input, grad, output);
softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
} else {
AT_DISPATCH_FLOATING_TYPES(grad.type(), "softmax_backward", [&] {
host_softmax_backward<scalar_t, false>(grad_input, grad, output, dim);
Expand Down Expand Up @@ -205,13 +205,19 @@ Tensor log_softmax_backward_cpu(
dim >= 0 && dim < grad.dim(),
"dim must be non-negative and less than input dimensions");
if (grad.ndimension() > 0 && dim == grad.ndimension() - 1) {
log_softmax_backward_lastdim_kernel(grad_input, grad, output);
log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad, output);
} else {
AT_DISPATCH_FLOATING_TYPES(grad.type(), "log_softmax_backward", [&] {
host_softmax_backward<scalar_t, true>(grad_input, grad, output, dim);
});
}
return grad_input;
}

DispatchStub<forward_fn> softmax_lastdim_kernel;
DispatchStub<forward_fn> log_softmax_lastdim_kernel;
DispatchStub<backward_fn> softmax_backward_lastdim_kernel;
DispatchStub<backward_fn> log_softmax_backward_lastdim_kernel;

}
}
28 changes: 26 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ Tensor& fill_(Tensor& self, const Tensor& value) {
Tensor& _##op##__cpu(Tensor& self_) { \
if (self_.numel() > 0) { \
Tensor self = sort_strides(self_); \
op##Impl(self, self); \
op##Impl(kCPU, self, self); \
} \
return self_; \
} \
Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
result.resize_(self.sizes()); \
if (result.numel() > 0) { \
op##Impl(result, self); \
op##Impl(kCPU, result, self); \
} \
return result; \
}
Expand Down Expand Up @@ -145,5 +145,29 @@ IMPLEMENT_UNARY_OP_VEC(tan)
IMPLEMENT_UNARY_OP_VEC(tanh)
IMPLEMENT_UNARY_OP_VEC(trunc)

DispatchStub<unary_fn> absImpl;
DispatchStub<unary_fn> acosImpl;
DispatchStub<unary_fn> asinImpl;
DispatchStub<unary_fn> atanImpl;
DispatchStub<unary_fn> ceilImpl;
DispatchStub<unary_fn> cosImpl;
DispatchStub<unary_fn> erfImpl;
DispatchStub<unary_fn> erfcImpl;
DispatchStub<unary_fn> expImpl;
DispatchStub<unary_fn> expm1Impl;
DispatchStub<unary_fn> floorImpl;
DispatchStub<unary_fn> logImpl;
DispatchStub<unary_fn> log10Impl;
DispatchStub<unary_fn> log1pImpl;
DispatchStub<unary_fn> log2Impl;
DispatchStub<unary_fn> roundImpl;
DispatchStub<unary_fn> rsqrtImpl;
DispatchStub<unary_fn> sigmoidImpl;
DispatchStub<unary_fn> sinImpl;
DispatchStub<unary_fn> sqrtImpl;
DispatchStub<unary_fn> tanImpl;
DispatchStub<unary_fn> tanhImpl;
DispatchStub<unary_fn> truncImpl;

}
} // namespace at
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/ReduceOpsKernel.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/optional.h>
#include "CapabilityDispatch.h"

namespace at {
namespace native {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/SoftmaxKernel.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#pragma once

#include <ATen/ATen.h>
#include "CapabilityDispatch.h"
#include <ATen/native/DispatchStub.h>

namespace at {
namespace native {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "ATen/Dispatch.h"
#include "ATen/cpu/vml.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/native/cpu/CapabilityDispatch.h"
#include "ATen/native/DispatchStub.h"
#ifdef __AVX2__
#include "ATen/native/cpu/avx_mathfun.h"
#endif
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/UnaryOpsKernel.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>
#include <stdexcept>
#include "CapabilityDispatch.h"

namespace at { namespace native {

Expand Down

0 comments on commit bcf0bf4

Please sign in to comment.