Skip to content

Commit

Permalink
[mlir][vulkan-runner] Explicitly export vulkan-runtime-wrapper entry …
Browse files Browse the repository at this point in the history
…points.

This ensure that the symbols are being exported no matter what default
visibility is set.
  • Loading branch information
ThomasRaoux committed Sep 1, 2020
1 parent 49dda4e commit 8d65504
Showing 1 changed file with 43 additions and 28 deletions.
71 changes: 43 additions & 28 deletions mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
Expand Up @@ -16,6 +16,9 @@

#include "VulkanRuntime.h"

// Explicitly export entry points to the vulkan-runtime-wrapper.
#define VULKAN_WRAPPER_SYMBOL_EXPORT __attribute__((visibility("default")))

namespace {

class VulkanRuntimeManager {
Expand Down Expand Up @@ -62,8 +65,7 @@ class VulkanRuntimeManager {

} // namespace

template <typename T, int N>
struct MemRefDescriptor {
template <typename T, int N> struct MemRefDescriptor {
T *allocated;
T *aligned;
int64_t offset;
Expand All @@ -84,37 +86,41 @@ void bindMemRef(void *vkRuntimeManager, DescriptorSetIndex setIndex,

extern "C" {
/// Initializes `VulkanRuntimeManager` and returns a pointer to it.
void *initVulkan() { return new VulkanRuntimeManager(); }
VULKAN_WRAPPER_SYMBOL_EXPORT void *initVulkan() {
return new VulkanRuntimeManager();
}

/// Deinitializes `VulkanRuntimeManager` by the given pointer.
void deinitVulkan(void *vkRuntimeManager) {
VULKAN_WRAPPER_SYMBOL_EXPORT void deinitVulkan(void *vkRuntimeManager) {
delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager);
}

void runOnVulkan(void *vkRuntimeManager) {
VULKAN_WRAPPER_SYMBOL_EXPORT void runOnVulkan(void *vkRuntimeManager) {
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan();
}

void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) {
VULKAN_WRAPPER_SYMBOL_EXPORT void setEntryPoint(void *vkRuntimeManager,
const char *entryPoint) {
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
->setEntryPoint(entryPoint);
}

void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y,
uint32_t z) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, uint32_t z) {
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
->setNumWorkGroups({x, y, z});
}

void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) {
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)
->setShaderModule(shader, size);
}

/// Binds the given memref to the given descriptor set and descriptor
/// index.
#define DECLARE_BIND_MEMREF(size, type, typeName) \
void bindMemRef##size##D##typeName( \
VULKAN_WRAPPER_SYMBOL_EXPORT void bindMemRef##size##D##typeName( \
void *vkRuntimeManager, DescriptorSetIndex setIndex, \
BindingIndex bindIndex, MemRefDescriptor<type, size> *ptr) { \
bindMemRef<type, size>(vkRuntimeManager, setIndex, bindIndex, ptr); \
Expand All @@ -137,58 +143,67 @@ DECLARE_BIND_MEMREF(2, int16_t, Half)
DECLARE_BIND_MEMREF(3, int16_t, Half)

/// Fills the given 1D float memref with the given float value.
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
float value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT
float value) {
std::fill_n(ptr->allocated, ptr->sizes[0], value);
}

/// Fills the given 2D float memref with the given float value.
void _mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
float value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DFloat(MemRefDescriptor<float, 2> *ptr, // NOLINT
float value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
}

/// Fills the given 3D float memref with the given float value.
void _mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
float value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DFloat(MemRefDescriptor<float, 3> *ptr, // NOLINT
float value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
value);
}

/// Fills the given 1D int memref with the given int value.
void _mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
int32_t value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt(MemRefDescriptor<int32_t, 1> *ptr, // NOLINT
int32_t value) {
std::fill_n(ptr->allocated, ptr->sizes[0], value);
}

/// Fills the given 2D int memref with the given int value.
void _mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
int32_t value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt(MemRefDescriptor<int32_t, 2> *ptr, // NOLINT
int32_t value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
}

/// Fills the given 3D int memref with the given int value.
void _mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
int32_t value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt(MemRefDescriptor<int32_t, 3> *ptr, // NOLINT
int32_t value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
value);
}

/// Fills the given 1D int memref with the given int8 value.
void _mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
int8_t value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource1DInt8(MemRefDescriptor<int8_t, 1> *ptr, // NOLINT
int8_t value) {
std::fill_n(ptr->allocated, ptr->sizes[0], value);
}

/// Fills the given 2D int memref with the given int8 value.
void _mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
int8_t value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource2DInt8(MemRefDescriptor<int8_t, 2> *ptr, // NOLINT
int8_t value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value);
}

/// Fills the given 3D int memref with the given int8 value.
void _mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
int8_t value) {
VULKAN_WRAPPER_SYMBOL_EXPORT void
_mlir_ciface_fillResource3DInt8(MemRefDescriptor<int8_t, 3> *ptr, // NOLINT
int8_t value) {
std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1] * ptr->sizes[2],
value);
}
Expand Down

0 comments on commit 8d65504

Please sign in to comment.