Skip to content

Commit

Permalink
Remove .impl_UNBOXED() and functionalities associated with it (pytorc…
Browse files Browse the repository at this point in the history
…h#49220)

Summary:
Pull Request resolved: pytorch#49220

Since all ops are c10-full, we can remove .impl_UNBOXED now.
This also removes the ability of KernelFunction or CppFunction to store unboxedOnly kernels.
ghstack-source-id: 119450489

Test Plan: waitforsandcastle

Reviewed By: ezyang

Differential Revision: D25490225

fbshipit-source-id: 32de9d591e6a842fe18abc82541580647e9cfdad
  • Loading branch information
smessmer authored and hwangdeyu committed Jan 14, 2021
1 parent f285487 commit 00c95ca
Show file tree
Hide file tree
Showing 16 changed files with 64 additions and 218 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,7 +1015,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("_add_batch_dim", native::_add_batch_dim);
m.impl("_remove_batch_dim", native::_remove_batch_dim);

m.impl_UNBOXED("sum.dim_IntList", sum_batching_rule);
m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("is_complex", native::is_complex);
m.impl("conj", native::conj);

Expand Down
30 changes: 13 additions & 17 deletions aten/src/ATen/autocast_mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,9 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call);

#define KERNEL_UNBOXED_ONLY(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call);

// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
#define KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
m.impl_UNBOXED(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
m.impl(TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \
&WrapFunction<CastPolicy::POLICY, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call);

/*****************************************
Expand Down Expand Up @@ -367,41 +363,41 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(binary_cross_entropy_with_logits), "binary_cross_entropy_with_logits", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, const c10::optional<Tensor>&, int64_t), fp32)
KERNEL(ADD_NS(dist), "dist", Tensor (const Tensor &, const Tensor &, Scalar), fp32)
KERNEL(ADD_NS(pdist), "pdist", Tensor (const Tensor &, double), fp32)
KERNEL_UNBOXED_ONLY(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
KERNEL(ADD_NS(cdist), "cdist", Tensor (const Tensor &, const Tensor &, double, c10::optional<int64_t>), fp32)
KERNEL(ADD_NS(renorm), "renorm", Tensor (const Tensor &, Scalar, int64_t, Scalar), fp32)
// fp32_set_opt_dtype
KERNEL(ADD_NS(prod), "prod", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(prod), "prod.dim_int", Tensor (const Tensor &, int64_t, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL_UNBOXED_ONLY(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(prod), "prod.dim_Dimname", Tensor (const Tensor &, Dimname, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(softmax), "softmax.int", Tensor (const Tensor &, int64_t, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL_UNBOXED_ONLY(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(softmax), "softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(log_softmax), "log_softmax.int", Tensor (const Tensor &, int64_t, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL_UNBOXED_ONLY(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(log_softmax), "log_softmax.Dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumprod), "cumprod", Tensor (const Tensor &, int64_t, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL_UNBOXED_ONLY(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumprod), "cumprod.dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumsum), "cumsum", Tensor (const Tensor &, int64_t, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL_UNBOXED_ONLY(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(cumsum), "cumsum.dimname", Tensor (const Tensor &, Dimname, c10::optional<ScalarType>), fp32_set_opt_dtype)
// commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
// when autocasting.
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dtype", Tensor (const Tensor &, c10::optional<Scalar>, ScalarType), fp32_set_opt_dtype)
// KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype)
// KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL_UNBOXED_ONLY(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
// fp32_append_dtype
// The fp32_append_dtype wrapper overrides implicit promotion behavior.
// norm does not implicitly promote, but be aware when adding new ops to this policy.
KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional<Scalar>, ScalarType), fp32_append_dtype)
KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool), Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_append_dtype)
KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool), Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.Scalar", Tensor (const Tensor &, Scalar), Tensor (const Tensor &, c10::optional<Scalar>, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.ScalarOpt_dim", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool), Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_append_dtype)
KERNEL_DIFFERENT_REDISPATCH_SIGNATURE(ADD_NS(norm), "norm.names_ScalarOpt_dim", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool), Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_append_dtype)
// promote
KERNEL(ADD_NS(addcdiv), "addcdiv", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote)
KERNEL(ADD_NS(addcmul), "addcmul", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar), promote)
KERNEL(ADD_NS(atan2), "atan2", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(bilinear), "bilinear", Tensor (const Tensor &, const Tensor &, const Tensor &, const c10::optional<Tensor>&), promote)
KERNEL(ADD_NS(cat), "cat", Tensor (TensorList, int64_t), promote)
KERNEL_UNBOXED_ONLY(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote)
KERNEL(ADD_NS(cat), "cat.names", Tensor (TensorList, Dimname), promote)
KERNEL(ADD_NS(_cat), "_cat", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
Expand Down
21 changes: 0 additions & 21 deletions aten/src/ATen/core/boxing/KernelFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,4 @@ bool KernelFunction::_equalsBoxedAndUnboxed(const KernelFunction& other) const {
unboxed_kernel_func_ == other.unboxed_kernel_func_;
}

void KernelFunction::checkBoxedKernel(const OperatorHandle& opHandle) const {
if (C10_UNLIKELY(boxed_kernel_func_ == nullptr)) {
if (unboxed_kernel_func_ == nullptr) {
TORCH_INTERNAL_ASSERT(
false,
"Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction.",
" opname: ",
opHandle.operator_name(),
" If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`.");
} else {
// TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this case should be impossible.
TORCH_INTERNAL_ASSERT(
false,
"Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::call().",
" opname: ",
opHandle.operator_name(),
" If you're using mobile selective build please make sure to include all ops exported from `torch.jit.export_opnames(model)`.");
}
}
}

} // namespace c10
42 changes: 0 additions & 42 deletions aten/src/ATen/core/boxing/KernelFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,26 +123,6 @@ class TORCH_API KernelFunction final {
template<bool AllowLegacyTypes = false, class KernelFunctor>
static KernelFunction makeFromUnboxedFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);

/**
* Create a KernelFunction from an unboxed functor and prevent creation of an
* unboxing-wrapper. This means that you cannot call this KernelFunction
* using KernelFunction::callBoxed()
*
* This is necessary because our unboxing wrappers don't work for all types
* yet, so if you want to use one of these types as function arguments,
* you need to use makeFromUnboxedOnlyFunctor.
*
* Example:
*
* > class MyFunctor final {
* > public:
* > Tensor operator()(Tensor a, Tensor b) {...}
* > };
* > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunctor(std::make_unique<MyFunctor>());
*/
template<class KernelFunctor>
static KernelFunction makeFromUnboxedOnlyFunctor(std::unique_ptr<OperatorKernel> kernelFunctor);

/**
* Create a KernelFunction from an unboxed function.
* This is usually better than KernelFunction::makeFromUnboxedRuntimeFunction
Expand All @@ -158,23 +138,6 @@ class TORCH_API KernelFunction final {
template<class FuncPtr, bool AllowLegacyTypes = false>
static KernelFunction makeFromUnboxedFunction(FuncPtr);

/**
* Create a KernelFunction from an unboxed function and prevent creation of an
* unboxing-wrapper. This means that you cannot call this KernelFunction
* using KernelFunction::callBoxed()
*
* This is necessary because our unboxing wrappers don't work for all types
* yet, so if you want to use one of these types as function arguments,
* you need to use makeFromUnboxedOnlyFunctor.
*
* Example:
*
* > Tensor unboxed_func(Tensor a, Tensor b) {...}
* > KernelFunction func = KernelFunction::makeFromUnboxedOnlyFunction<decltype(unboxed_func), &unboxed_func>();
*/
template<class FuncPtr>
static KernelFunction makeFromUnboxedOnlyFunction(FuncPtr);

/**
* Create a KernelFunction from an unboxed function.
* KernelFunction::makeFromUnboxedFunction is usually a better choice than
Expand All @@ -189,9 +152,6 @@ class TORCH_API KernelFunction final {
template<bool AllowLegacyTypes = false, class FuncType>
static KernelFunction makeFromUnboxedRuntimeFunction(FuncType* func);

template<class FuncType>
static KernelFunction makeFromUnboxedOnlyRuntimeFunction(FuncType* func);

static KernelFunction makeFallthrough();
static KernelFunction makeAmbiguousAutogradOther();
static KernelFunction makeNamedNotSupported();
Expand Down Expand Up @@ -226,8 +186,6 @@ class TORCH_API KernelFunction final {
template<BoxedKernelFunction* func>
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, Stack* stack);

void checkBoxedKernel(const OperatorHandle& opHandle) const;

OperatorKernel* getFunctor_() const;

std::shared_ptr<OperatorKernel> functor_;
Expand Down
54 changes: 5 additions & 49 deletions aten/src/ATen/core/boxing/KernelFunction_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ inline void KernelFunction::make_boxed_function(OperatorKernel*, const OperatorH
}

inline bool KernelFunction::isValid() const {
// TODO We want to introduce the invariant that all kernels must be callable in a boxed way, then this should only check boxed_kernel_func_.
return boxed_kernel_func_ != nullptr || unboxed_kernel_func_ != nullptr;
return boxed_kernel_func_ != nullptr;
}

inline bool KernelFunction::isFallthrough() const {
return boxed_kernel_func_ == &fallthrough_kernel;
}

inline void KernelFunction::callBoxed(const OperatorHandle& opHandle, Stack* stack) const {
checkBoxedKernel(opHandle);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
boxed_kernel_func_ != nullptr,
"Tried to call KernelFunction::callBoxed() on an uninitialized KernelFunction."
);
(*boxed_kernel_func_)(functor_.get(), opHandle, stack);
}

Expand Down Expand Up @@ -111,21 +113,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunctor(std::unique_ptr<Ope
);
}

template<class KernelFunctor>
inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunctor(std::unique_ptr<OperatorKernel> kernelFunctor) {
// TODO We want to get rid of kernels that have only an unboxed function pointer.
// All kernels should have a boxed pointer.

static_assert(guts::is_functor<KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor> but the argument is not a functor.");
static_assert(std::is_base_of<OperatorKernel, KernelFunctor>::value, "Tried to call KernelFunction::makeFromUnboxedFunctor<KernelFunctor>, but the functor doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it.");

return KernelFunction(
std::move(kernelFunctor),
nullptr, // Don't create a boxed kernel for this
reinterpret_cast<void*>(&impl::wrap_kernel_functor_unboxed<KernelFunctor>::call)
);
}

template<class FuncPtr, bool AllowLegacyTypes>
inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) {
static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
Expand All @@ -144,26 +131,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr)
#endif
}

template<class FuncPtr>
inline KernelFunction KernelFunction::makeFromUnboxedOnlyFunction(FuncPtr func_ptr) {
// TODO We want to get rid of kernels that have only an unboxed function pointer.
// All kernels should have a boxed pointer.
static_assert(is_compile_time_function_pointer<FuncPtr>::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with an invalid parameter. It must be a function pointer created with TORCH_FN.");
static_assert(!std::is_same<typename FuncPtr::FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedOnlyFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr");

#if !defined(C10_MOBILE)
return makeFromUnboxedOnlyFunctor<typename impl::WrapFunctionIntoFunctor<FuncPtr>::type> (
guts::make_unique_base<OperatorKernel, typename impl::WrapFunctionIntoFunctor<FuncPtr>::type>()
);
#else
// On mobile, we rather want to optimize for binary size than for performance,
// so let's not inline the kernel into the wrapper but use makeFromUnboxedOnlyRuntimeFunction
// instead.
return makeFromUnboxedOnlyRuntimeFunction(func_ptr.func_ptr());
#endif
}

template<bool AllowLegacyTypes, class FuncType>
inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* func) {
static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
Expand All @@ -175,17 +142,6 @@ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* f
);
}

template<class FuncType>
inline KernelFunction KernelFunction::makeFromUnboxedOnlyRuntimeFunction(FuncType* func) {
static_assert(guts::is_function_type<FuncType>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a non-function type.");
static_assert(!std::is_same<FuncType, BoxedKernelFunction>::value, "Tried to call KernelFunction::makeFromUnboxedRuntimeFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead.");
TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr");

return makeFromUnboxedOnlyFunctor<impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(
guts::make_unique_base<OperatorKernel, impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<FuncType>>>(func)
);
}

template<bool AllowLegacyTypes, class Lambda>
inline std::enable_if_t<guts::is_stateless_lambda<std::decay_t<Lambda>>::value, KernelFunction> KernelFunction::makeFromUnboxedLambda(Lambda&& lambda) {
static_assert(guts::is_functor<std::decay_t<Lambda>>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type.");
Expand Down

0 comments on commit 00c95ca

Please sign in to comment.