[StreamExecutor] Simplify Kernel classes
Summary:
Make the Kernel class follow the pattern of the other classes. It now
has a type-safe user wrapper and a typeless, platform-specific handle.

Reviewers: jlebar

Subscribers: jprice, parallel_libs-commits

Differential Revision: https://reviews.llvm.org/D24043

llvm-svn: 280176
henline committed Aug 30, 2016
1 parent ddb53dd commit 90ce6e1
Showing 7 changed files with 87 additions and 212 deletions.
26 changes: 20 additions & 6 deletions parallel-libs/streamexecutor/include/streamexecutor/Device.h
@@ -15,25 +15,39 @@
#ifndef STREAMEXECUTOR_DEVICE_H
#define STREAMEXECUTOR_DEVICE_H

#include <type_traits>

#include "streamexecutor/KernelSpec.h"
#include "streamexecutor/PlatformInterfaces.h"
#include "streamexecutor/Utils/Error.h"

namespace streamexecutor {

class KernelInterface;
class Stream;

class Device {
public:
explicit Device(PlatformDevice *PDevice);
virtual ~Device();

/// Gets the kernel implementation for the underlying platform.
virtual Expected<std::unique_ptr<KernelInterface>>
getKernelImplementation(const MultiKernelLoaderSpec &Spec) {
// TODO(jhen): Implement this.
return nullptr;
/// Creates a kernel object for this device.
///
/// If the return value is not an error, the returned pointer will never be
/// null.
///
/// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how
/// this method is used.
template <typename KernelT>
Expected<std::unique_ptr<typename std::enable_if<
std::is_base_of<KernelBase, KernelT>::value, KernelT>::type>>
createKernel(const MultiKernelLoaderSpec &Spec) {
Expected<std::unique_ptr<PlatformKernelHandle>> MaybeKernelHandle =
PDevice->createKernel(Spec);
if (!MaybeKernelHandle) {
return MaybeKernelHandle.takeError();
}
return llvm::make_unique<KernelT>(Spec.getKernelName(),
std::move(*MaybeKernelHandle));
}

Expected<std::unique_ptr<Stream>> createStream();
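
The `createKernel` template above is only usable for types derived from `KernelBase`, and it wraps a typeless platform handle in a typed kernel object. A minimal, self-contained sketch of that shape follows. Every name ending in `Sketch` is hypothetical, `ResultSketch` is a simplified stand-in for `llvm::Expected`, and the boolean flag on the device replaces the real `PlatformDevice` call; C++14 is assumed for `std::make_unique`.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for PlatformKernelHandle (typeless, platform-specific).
struct HandleSketch {
  virtual ~HandleSketch() = default;
};

// Hypothetical stand-in for KernelBase: stores only the kernel name.
class KernelBaseSketch {
public:
  explicit KernelBaseSketch(std::string Name) : Name(std::move(Name)) {}
  const std::string &getName() const { return Name; }

private:
  std::string Name;
};

// Hypothetical stand-in for Kernel<ParameterTs...>: typed wrapper that owns
// the typeless handle.
template <typename... ParameterTs>
class KernelSketch : public KernelBaseSketch {
public:
  KernelSketch(std::string Name, std::unique_ptr<HandleSketch> Handle)
      : KernelBaseSketch(std::move(Name)), Handle(std::move(Handle)) {}
  HandleSketch *getPlatformHandle() const { return Handle.get(); }

private:
  std::unique_ptr<HandleSketch> Handle;
};

// Simplified stand-in for llvm::Expected: an empty message means success.
template <typename T> struct ResultSketch {
  T Value;
  std::string ErrorMessage;
  explicit operator bool() const { return ErrorMessage.empty(); }
};

class DeviceSketch {
public:
  explicit DeviceSketch(bool SupportsKernels)
      : SupportsKernels(SupportsKernels) {}

  // enable_if removes this overload for types not derived from the base
  // class, so createKernel<int>() is rejected at compile time.
  template <typename KernelT>
  ResultSketch<std::unique_ptr<typename std::enable_if<
      std::is_base_of<KernelBaseSketch, KernelT>::value, KernelT>::type>>
  createKernel(const std::string &KernelName) {
    if (!SupportsKernels)
      return {nullptr, "createKernel not implemented for this platform"};
    return {std::make_unique<KernelT>(KernelName,
                                      std::make_unique<HandleSketch>()),
            ""};
  }

private:
  bool SupportsKernels;
};

// Usage: on success the returned pointer is never null.
inline void exampleCreateKernel() {
  DeviceSketch Dev(/*SupportsKernels=*/true);
  auto MaybeKernel = Dev.createKernel<KernelSketch<float>>("saxpy");
  assert(static_cast<bool>(MaybeKernel) && MaybeKernel.Value != nullptr);
}
```

The factory lives on the device rather than on the kernel class, so the kernel type itself needs no knowledge of how a platform loads code.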
114 changes: 40 additions & 74 deletions parallel-libs/streamexecutor/include/streamexecutor/Kernel.h
@@ -11,62 +11,64 @@
/// Types to represent device kernels (code compiled to run on GPU or other
/// accelerator).
///
/// The TypedKernel class is used to provide type safety to the user API's
/// launch functions, and the KernelBase class is used like a void* function
/// pointer to perform type-unsafe operations inside StreamExecutor.
///
/// With the kernel parameter types recorded in the TypedKernel template
/// parameters, type-safe kernel launch functions can be written with signatures
/// like the following:
/// With the kernel parameter types recorded in the Kernel template parameters,
/// type-safe kernel launch functions can be written with signatures like the
/// following:
/// \code
/// template <typename... ParameterTs>
/// void Launch(
/// const TypedKernel<ParameterTs...> &Kernel, ParameterTs... Arguments);
/// const Kernel<ParameterTs...> &Kernel, ParameterTs... Arguments);
/// \endcode
/// and the compiler will check that the user passes in arguments with types
/// matching the corresponding kernel parameters.
///
/// A problem is that a TypedKernel template specialization with the right
/// parameter types must be passed as the first argument to the Launch function,
/// and it's just as hard to get the types right in that template specialization
/// as it is to get them right for the kernel arguments.
/// A problem is that a Kernel template specialization with the right parameter
/// types must be passed as the first argument to the Launch function, and it's
/// just as hard to get the types right in that template specialization as it is
/// to get them right for the kernel arguments.
///
/// With this problem in mind, it is not recommended for users to specialize the
/// TypedKernel template class themselves, but instead to let the compiler do it
/// for them. When the compiler encounters a device kernel function, it can
/// create a TypedKernel template specialization in the host code that has the
/// right parameter types for that kernel and which has a type name based on the
/// name of the kernel function.
/// Kernel template class themselves, but instead to let the compiler do it for
/// them. When the compiler encounters a device kernel function, it can create a
/// Kernel template specialization in the host code that has the right parameter
/// types for that kernel and which has a type name based on the name of the
/// kernel function.
///
/// \anchor CompilerGeneratedKernelExample
/// For example, if a CUDA device kernel function with the following signature
/// has been defined:
/// \code
/// void Saxpy(float *A, float *X, float *Y);
/// void Saxpy(float A, float *X, float *Y);
/// \endcode
/// the compiler can insert the following declaration in the host code:
/// \code
/// namespace compiler_cuda_namespace {
/// namespace se = streamexecutor;
/// using SaxpyKernel =
/// streamexecutor::TypedKernel<float *, float *, float *>;
/// se::Kernel<
/// float,
/// se::GlobalDeviceMemory<float>,
/// se::GlobalDeviceMemory<float>>;
/// } // namespace compiler_cuda_namespace
/// \endcode
/// and then the user can launch the kernel by calling the StreamExecutor launch
/// function as follows:
/// \code
/// namespace ccn = compiler_cuda_namespace;
/// using KernelPtr = std::unique_ptr<ccn::SaxpyKernel>;
/// // Assumes Device is a pointer to the Device on which to launch the
/// // kernel.
/// //
/// // See KernelSpec.h for details on how the compiler can create a
/// // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below.
/// Expected<ccn::SaxpyKernel> MaybeKernel =
/// ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec);
/// Expected<KernelPtr> MaybeKernel =
/// Device->createKernel<ccn::SaxpyKernel>(ccn::SaxpyKernelLoaderSpec);
/// if (!MaybeKernel) { /* Handle error */ }
/// ccn::SaxpyKernel SaxpyKernel = *MaybeKernel;
/// Launch(SaxpyKernel, A, X, Y);
/// KernelPtr SaxpyKernel = std::move(*MaybeKernel);
/// Launch(*SaxpyKernel, A, X, Y);
/// \endcode
///
/// With the compiler's help in specializing TypedKernel for each device kernel
/// With the compiler's help in specializing Kernel for each device kernel
/// function (and generating a MultiKernelLoaderSpec instance for each kernel),
/// the user can safely launch the device kernel from the host and get an error
/// message at compile time if the argument types don't match the kernel
@@ -84,73 +86,37 @@

namespace streamexecutor {

class Device;
class KernelInterface;
class PlatformKernelHandle;

/// The base class for device kernel functions.
///
/// This class has no information about the types of the parameters taken by the
/// kernel, so it is analogous to a void* pointer to a device function.
/// The base class for all kernel types.
///
/// See the TypedKernel class below for the subclass which does have information
/// about parameter types.
/// Stores the name of the kernel in both mangled and demangled forms.
class KernelBase {
public:
KernelBase(KernelBase &&) = default;
KernelBase &operator=(KernelBase &&) = default;
~KernelBase();

/// Creates a kernel object from a Device and a MultiKernelLoaderSpec.
///
/// The Device knows which platform it belongs to and the
/// MultiKernelLoaderSpec knows how to find the kernel code for different
/// platforms, so the combined information is enough to get the kernel code
/// for the appropriate platform.
static Expected<KernelBase> create(Device *Dev,
const MultiKernelLoaderSpec &Spec);
KernelBase(llvm::StringRef Name);

const std::string &getName() const { return Name; }
const std::string &getDemangledName() const { return DemangledName; }

/// Gets a pointer to the platform-specific implementation of this kernel.
KernelInterface *getImplementation() { return Implementation.get(); }

private:
KernelBase(Device *Dev, const std::string &Name,
const std::string &DemangledName,
std::unique_ptr<KernelInterface> Implementation);

Device *TheDevice;
std::string Name;
std::string DemangledName;
std::unique_ptr<KernelInterface> Implementation;

KernelBase(const KernelBase &) = delete;
KernelBase &operator=(const KernelBase &) = delete;
};

/// A device kernel function with specified parameter types.
template <typename... ParameterTs> class TypedKernel : public KernelBase {
/// A StreamExecutor kernel.
///
/// The template parameters are the types of the parameters to the kernel
/// function.
template <typename... ParameterTs> class Kernel : public KernelBase {
public:
TypedKernel(TypedKernel &&) = default;
TypedKernel &operator=(TypedKernel &&) = default;
Kernel(llvm::StringRef Name, std::unique_ptr<PlatformKernelHandle> PHandle)
: KernelBase(Name), PHandle(std::move(PHandle)) {}

/// Parameters here have the same meaning as in KernelBase::create.
static Expected<TypedKernel> create(Device *Dev,
const MultiKernelLoaderSpec &Spec) {
auto MaybeBase = KernelBase::create(Dev, Spec);
if (!MaybeBase) {
return MaybeBase.takeError();
}
TypedKernel Instance(std::move(*MaybeBase));
return std::move(Instance);
}
/// Gets the underlying platform-specific handle for this kernel.
PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); }

private:
TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {}

TypedKernel(const TypedKernel &) = delete;
TypedKernel &operator=(const TypedKernel &) = delete;
std::unique_ptr<PlatformKernelHandle> PHandle;
};

} // namespace streamexecutor
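
The header comment above says the compiler checks launch arguments against the kernel's template parameters. The sketch below illustrates why, under the assumption that the parameter pack appears both in the kernel type and in the argument list, so template deduction fails when the two disagree. `KernelSketch`, `DeviceMemorySketch`, `launch`, and `canLaunch` are hypothetical names, not StreamExecutor API.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for se::GlobalDeviceMemory<T>.
template <typename T> struct DeviceMemorySketch {
  T *Ptr;
};

// Hypothetical stand-in for se::Kernel; only the parameter types matter here.
template <typename... ParameterTs> class KernelSketch {};

// ParameterTs is deduced from the kernel AND from the arguments. If the two
// deductions disagree, the call does not compile. Returns the argument count
// so the example has something observable.
template <typename... ParameterTs>
std::size_t launch(const KernelSketch<ParameterTs...> &,
                   const ParameterTs &... Arguments) {
  return sizeof...(Arguments);
}

// Detection helper: the first overload is viable only if launch(...) compiles.
template <typename KernelT, typename... ArgTs>
auto canLaunch(int) -> decltype(launch(std::declval<const KernelT &>(),
                                       std::declval<ArgTs>()...),
                                std::true_type{});
template <typename KernelT, typename... ArgTs>
std::false_type canLaunch(...);

// What a compiler might generate for: void Saxpy(float A, float *X, float *Y);
using SaxpyKernelSketch =
    KernelSketch<float, DeviceMemorySketch<float>, DeviceMemorySketch<float>>;

// Matching argument types are accepted.
static_assert(
    decltype(canLaunch<SaxpyKernelSketch, float, DeviceMemorySketch<float>,
                       DeviceMemorySketch<float>>(0))::value,
    "matching arguments should be launchable");
// A double where the kernel expects float is rejected, with no implicit
// conversion.
static_assert(
    !decltype(canLaunch<SaxpyKernelSketch, double, DeviceMemorySketch<float>,
                        DeviceMemorySketch<float>>(0))::value,
    "mismatched arguments should be rejected");

inline void exampleLaunch() {
  SaxpyKernelSketch K;
  DeviceMemorySketch<float> X{nullptr}, Y{nullptr};
  assert(launch(K, 2.0f, X, Y) == 3);
}
```

Because deduction must agree exactly, passing `2.0` instead of `2.0f` is a compile error in this sketch rather than a silent conversion.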
@@ -33,9 +33,17 @@ namespace streamexecutor {

class PlatformDevice;

/// Methods supported by device kernel function objects on all platforms.
class KernelInterface {
// TODO(jhen): Add methods.
/// Platform-specific kernel handle.
class PlatformKernelHandle {
public:
explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {}

virtual ~PlatformKernelHandle();

PlatformDevice *getDevice() { return PDevice; }

private:
PlatformDevice *PDevice;
};

/// Platform-specific stream handle.
@@ -64,12 +72,20 @@ class PlatformDevice {

virtual std::string getName() const = 0;

/// Creates a platform-specific kernel.
virtual Expected<std::unique_ptr<PlatformKernelHandle>>
createKernel(const MultiKernelLoaderSpec &Spec) {
return make_error("createKernel not implemented for platform " + getName());
}

/// Creates a platform-specific stream.
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() = 0;
virtual Expected<std::unique_ptr<PlatformStreamHandle>> createStream() {
return make_error("createStream not implemented for platform " + getName());
}

/// Launches a kernel on the given stream.
virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize,
GridDimensions GridSize, const KernelBase &Kernel,
GridDimensions GridSize, PlatformKernelHandle *K,
const PackedKernelArgumentArrayBase &ArgumentArray) {
return make_error("launch not implemented for platform " + getName());
}
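
In this hunk `createStream` stops being pure virtual, and every platform operation gets a default body that returns a "not implemented" error naming the platform. A rough sketch of that pattern is below. `PlatformDeviceSketch` and `KernelOnlyPlatformSketch` are hypothetical, and a plain `std::string` (empty on success) stands in for `llvm::Error`.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for PlatformDevice. Only getName() is pure virtual;
// every operation has a default that reports it as unimplemented.
class PlatformDeviceSketch {
public:
  virtual ~PlatformDeviceSketch() = default;

  virtual std::string getName() const = 0;

  // An empty string means success in this sketch.
  virtual std::string createKernel() {
    return "createKernel not implemented for platform " + getName();
  }
  virtual std::string createStream() {
    return "createStream not implemented for platform " + getName();
  }
};

// A platform that supports kernels but has not implemented streams yet. It
// still compiles and can be instantiated, which a pure virtual createStream()
// would have prevented.
class KernelOnlyPlatformSketch : public PlatformDeviceSketch {
public:
  std::string getName() const override { return "kernel-only"; }
  std::string createKernel() override { return ""; }
};

inline void exampleDefaults() {
  KernelOnlyPlatformSketch P;
  assert(P.createKernel().empty());
  assert(!P.createStream().empty());
}
```

The trade-off is that a missing override is found at run time instead of compile time, in exchange for letting partial platform implementations and test doubles build.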
6 changes: 3 additions & 3 deletions parallel-libs/streamexecutor/include/streamexecutor/Stream.h
@@ -86,15 +86,15 @@ class Stream {
/// These arguments can be device memory types like GlobalDeviceMemory<T> and
/// SharedDeviceMemory<T>, or they can be primitive types such as int. The
/// allowable argument types are determined by the template parameters to the
/// TypedKernel argument.
/// Kernel argument.
template <typename... ParameterTs>
Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize,
const TypedKernel<ParameterTs...> &Kernel,
const Kernel<ParameterTs...> &K,
const ParameterTs &... Arguments) {
auto ArgumentArray =
make_kernel_argument_pack<ParameterTs...>(Arguments...);
setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize,
Kernel, ArgumentArray));
K.getPlatformHandle(), ArgumentArray));
return *this;
}
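
After this change `thenLaunch` hands the platform only the typeless handle and a packed argument array, so type information stops at the `Stream` boundary. The sketch below shows one way such erasure could look. `PackedArgsSketch` is a hypothetical simplification and not the real `PackedKernelArgumentArrayBase`, and keeping only the first error in `setError` is an assumption made for the example.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Hypothetical typeless kernel handle.
struct HandleSketch {
  int Id;
};

// Hypothetical typeless view of the arguments: addresses and byte sizes only.
struct PackedArgsSketch {
  const void *const *Addresses;
  const std::size_t *Sizes;
  std::size_t Count;
};

// Hypothetical platform layer: sees no template parameters at all.
struct PlatformSketch {
  const HandleSketch *LastHandle = nullptr;
  std::size_t LastArgumentCount = 0;
  std::size_t LastArgumentBytes = 0;

  std::string launch(const HandleSketch *K, const PackedArgsSketch &Args) {
    LastHandle = K;
    LastArgumentCount = Args.Count;
    LastArgumentBytes = 0;
    for (std::size_t I = 0; I < Args.Count; ++I)
      LastArgumentBytes += Args.Sizes[I];
    return ""; // Empty string means success in this sketch.
  }
};

// Hypothetical typed kernel wrapper owning a handle.
template <typename... ParameterTs> class KernelSketch {
public:
  explicit KernelSketch(HandleSketch H) : Handle(H) {}
  const HandleSketch *getPlatformHandle() const { return &Handle; }

private:
  HandleSketch Handle;
};

class StreamSketch {
public:
  explicit StreamSketch(PlatformSketch *P) : Platform(P) {}

  // Types are checked here, then erased before calling the platform.
  template <typename... ParameterTs>
  StreamSketch &thenLaunch(const KernelSketch<ParameterTs...> &K,
                           const ParameterTs &... Arguments) {
    std::array<const void *, sizeof...(ParameterTs)> Addresses{
        static_cast<const void *>(&Arguments)...};
    std::array<std::size_t, sizeof...(ParameterTs)> Sizes{
        sizeof(ParameterTs)...};
    PackedArgsSketch Packed{Addresses.data(), Sizes.data(),
                            sizeof...(ParameterTs)};
    setError(Platform->launch(K.getPlatformHandle(), Packed));
    return *this;
  }

  bool isOK() const { return ErrorMessage.empty(); }

private:
  // Assumption: the stream keeps the first error it sees.
  void setError(std::string Message) {
    if (ErrorMessage.empty())
      ErrorMessage = std::move(Message);
  }

  PlatformSketch *Platform;
  std::string ErrorMessage;
};
```

Returning `*this` lets launches be chained, with any failure recorded on the stream rather than thrown.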

24 changes: 3 additions & 21 deletions parallel-libs/streamexecutor/lib/Kernel.cpp
@@ -20,26 +20,8 @@

namespace streamexecutor {

KernelBase::KernelBase(Device *Dev, const std::string &Name,
const std::string &DemangledName,
std::unique_ptr<KernelInterface> Implementation)
: TheDevice(Dev), Name(Name), DemangledName(DemangledName),
Implementation(std::move(Implementation)) {}

KernelBase::~KernelBase() = default;

Expected<KernelBase> KernelBase::create(Device *Dev,
const MultiKernelLoaderSpec &Spec) {
auto MaybeImplementation = Dev->getKernelImplementation(Spec);
if (!MaybeImplementation) {
return MaybeImplementation.takeError();
}
std::string Name = Spec.getKernelName();
std::string DemangledName =
llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr);
KernelBase Instance(Dev, Name, DemangledName,
std::move(*MaybeImplementation));
return std::move(Instance);
}
KernelBase::KernelBase(llvm::StringRef Name)
: Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName(
Name, nullptr)) {}

} // namespace streamexecutor
10 changes: 0 additions & 10 deletions parallel-libs/streamexecutor/lib/unittests/CMakeLists.txt
@@ -8,16 +8,6 @@ target_link_libraries(
${CMAKE_THREAD_LIBS_INIT})
add_test(DeviceTest device_test)

add_executable(
kernel_test
KernelTest.cpp)
target_link_libraries(
kernel_test
streamexecutor
${GTEST_BOTH_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT})
add_test(KernelTest kernel_test)

add_executable(
kernel_spec_test
KernelSpecTest.cpp)