Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add SyclRuntimeWrapper #69648

Merged
merged 8 commits into from Oct 26, 2023
Merged

[MLIR] Add SyclRuntimeWrapper #69648

merged 8 commits into from Oct 26, 2023

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Oct 19, 2023

This PR is a breakdown of the big PR #65539 which enables intel gpu integration. In this PR we add the code for the sycl runtime wrappers and also the cmake modules to find the dependent components. Integration test will be a follow up PR.

This PR is a joint effort by Nishant Patel & Sang Ik Lee.

@github-actions
Copy link

github-actions bot commented Oct 19, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 23, 2023

@grypp @joker-eph tagging for review since I cannot add reviewers

#include <CL/sycl.hpp>
#include <level_zero/ze_api.h>
#include <map>
#include <mutex>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of those includes are not being used and can be removed (map, mutex, vector, atomic)

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall to me. I left a few comments

isDeviceInitialised = true;
return syclDevice;
}
throw std::runtime_error("getDefaultDevice failed");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our other runtimes, we don't use 'throw' or 'std::runtime_error.' Instead, we use on 'fprintf.' (see CUDA and ROCm).

From what I understand, throwing exceptions are heavy and requires a C++ runtime. Do you really want to use exceptions in this context?

Our approach keeps the error lightweight. In the event of an error, the runtime prints the error message, and the program proceeds to run. The compiler can implement fallback if needed (we don't do it right now).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just side comment, SYCL itself uses exceptions to report errors, so it's not possible to disable them completely and that's catchAll wrapper is for.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the explanation! In this case, using throw std::runtime makes perfect sense to me. It aligns the behavior of this small runtime with SYCL.

I was wondering if you truly needed exceptions here, and it turns out that yes, you do want them.

mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp Outdated Show resolved Hide resolved
@grypp grypp requested a review from joker-eph October 24, 2023 13:17
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 25, 2023

@llvm/pr-subscribers-mlir-execution-engine

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR is a breakdown of the big PR #65539 which enables intel gpu integration. In this PR we add the code for the sycl runtime wrappers and also the cmake modules to find the dependent components. Integration test will be a follow up PR.

This PR is a joint effort by Nishant Patel & Sang Ik Lee.


Full diff: https://github.com/llvm/llvm-project/pull/69648.diff

5 Files Affected:

  • (modified) mlir/CMakeLists.txt (+1)
  • (added) mlir/cmake/modules/FindLevelZero.cmake (+221)
  • (added) mlir/cmake/modules/FindSyclRuntime.cmake (+68)
  • (modified) mlir/lib/ExecutionEngine/CMakeLists.txt (+36)
  • (added) mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp (+210)
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index ac120aad0d1eda7..16ff950089734b7 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -126,6 +126,7 @@ add_definitions(-DMLIR_ROCM_CONVERSIONS_ENABLED=${MLIR_ENABLE_ROCM_CONVERSIONS})
 set(MLIR_ENABLE_DEPRECATED_GPU_SERIALIZATION 0 CACHE BOOL "Enable deprecated GPU serialization passes")
 set(MLIR_ENABLE_CUDA_RUNNER 0 CACHE BOOL "Enable building the mlir CUDA runner")
 set(MLIR_ENABLE_ROCM_RUNNER 0 CACHE BOOL "Enable building the mlir ROCm runner")
+set(MLIR_ENABLE_SYCL_RUNNER 0 CACHE BOOL "Enable building the mlir Sycl runner")
 set(MLIR_ENABLE_SPIRV_CPU_RUNNER 0 CACHE BOOL "Enable building the mlir SPIR-V cpu runner")
 set(MLIR_ENABLE_VULKAN_RUNNER 0 CACHE BOOL "Enable building the mlir Vulkan runner")
 set(MLIR_ENABLE_NVPTXCOMPILER 0 CACHE BOOL
diff --git a/mlir/cmake/modules/FindLevelZero.cmake b/mlir/cmake/modules/FindLevelZero.cmake
new file mode 100644
index 000000000000000..012187f0afc0b07
--- /dev/null
+++ b/mlir/cmake/modules/FindLevelZero.cmake
@@ -0,0 +1,221 @@
+# CMake find_package() module for level-zero
+#
+# Example usage:
+#
+# find_package(LevelZero)
+#
+# If successful, the following variables will be defined:
+# LevelZero_FOUND
+# LevelZero_INCLUDE_DIRS
+# LevelZero_LIBRARY
+# LevelZero_LIBRARIES_DIR
+#
+# By default, the module searches the standard paths to locate the "ze_api.h"
+# and the ze_loader shared library. When using a custom level-zero installation,
+# the environment variable "LEVEL_ZERO_DIR" should be specified telling the
+# module to get the level-zero library and headers from that location.
+
+include(FindPackageHandleStandardArgs)
+
+# Search path priority
+# 1. CMake Variable LEVEL_ZERO_DIR
+# 2. Environment Variable LEVEL_ZERO_DIR
+
+if(NOT LEVEL_ZERO_DIR)
+    if(DEFINED ENV{LEVEL_ZERO_DIR})
+        set(LEVEL_ZERO_DIR "$ENV{LEVEL_ZERO_DIR}")
+    endif()
+endif()
+
+if(LEVEL_ZERO_DIR)
+    find_path(LevelZero_INCLUDE_DIR
+        NAMES level_zero/ze_api.h
+        PATHS ${LEVEL_ZERO_DIR}/include
+        NO_DEFAULT_PATH
+    )
+
+    if(LINUX)
+        find_library(LevelZero_LIBRARY
+            NAMES ze_loader
+            PATHS ${LEVEL_ZERO_DIR}/lib
+                  ${LEVEL_ZERO_DIR}/lib/x86_64-linux-gnu
+            NO_DEFAULT_PATH
+        )
+    else()
+        find_library(LevelZero_LIBRARY
+            NAMES ze_loader
+            PATHS ${LEVEL_ZERO_DIR}/lib
+            NO_DEFAULT_PATH
+        )
+    endif()
+else()
+    find_path(LevelZero_INCLUDE_DIR
+        NAMES level_zero/ze_api.h
+    )
+
+    find_library(LevelZero_LIBRARY
+        NAMES ze_loader
+    )
+endif()
+
+# Compares the two version string that are supposed to be in x.y.z format
+# and reports if the argument VERSION_STR1 is greater than or equal than
+# version_str2. The strings are compared lexicographically after conversion to
+# lists of equal lengths, with the shorter string getting zero-padded.
+function(compare_versions VERSION_STR1 VERSION_STR2 OUTPUT)
+    # Convert the strings to list
+    string(REPLACE  "." ";" VL1 ${VERSION_STR1})
+    string(REPLACE  "." ";" VL2 ${VERSION_STR2})
+    # get lengths of both lists
+    list(LENGTH VL1 VL1_LEN)
+    list(LENGTH VL2 VL2_LEN)
+    set(LEN ${VL1_LEN})
+    # If they differ in size pad the shorter list with 0s
+    if(VL1_LEN GREATER VL2_LEN)
+        math(EXPR DIFF "${VL1_LEN} - ${VL2_LEN}" OUTPUT_FORMAT DECIMAL)
+        foreach(IDX RANGE 1 ${DIFF} 1)
+            list(APPEND VL2 "0")
+        endforeach()
+    elseif(VL2_LEN GREATER VL2_LEN)
+        math(EXPR DIFF "${VL1_LEN} - ${VL2_LEN}" OUTPUT_FORMAT DECIMAL)
+        foreach(IDX RANGE 1 ${DIFF} 1)
+            list(APPEND VL2 "0")
+        endforeach()
+        set(LEN ${VL2_LEN})
+    endif()
+    math(EXPR LEN_SUB_ONE "${LEN}-1")
+    foreach(IDX RANGE 0 ${LEN_SUB_ONE} 1)
+        list(GET VL1 ${IDX} VAL1)
+        list(GET VL2 ${IDX} VAL2)
+
+        if(${VAL1} GREATER ${VAL2})
+            set(${OUTPUT} TRUE PARENT_SCOPE)
+            break()
+        elseif(${VAL1} LESS ${VAL2})
+            set(${OUTPUT} FALSE PARENT_SCOPE)
+            break()
+        else()
+            set(${OUTPUT} TRUE PARENT_SCOPE)
+        endif()
+    endforeach()
+
+    endfunction(compare_versions)
+
+# Creates a small function to run and extract the LevelZero loader version.
+function(get_l0_loader_version)
+
+    set(L0_VERSIONEER_SRC
+        [====[
+        #include <iostream>
+        #include <level_zero/loader/ze_loader.h>
+        #include <string>
+        int main() {
+            ze_result_t result;
+            std::string loader("loader");
+            zel_component_version_t *versions;
+            size_t size = 0;
+            result = zeInit(0);
+            if (result != ZE_RESULT_SUCCESS) {
+                std::cerr << "Failed to init ze driver" << std::endl;
+                return -1;
+            }
+            zelLoaderGetVersions(&size, nullptr);
+            versions = new zel_component_version_t[size];
+            zelLoaderGetVersions(&size, versions);
+            for (size_t i = 0; i < size; i++) {
+                if (loader.compare(versions[i].component_name) == 0) {
+                    std::cout << versions[i].component_lib_version.major << "."
+                              << versions[i].component_lib_version.minor << "."
+                              << versions[i].component_lib_version.patch;
+                    break;
+                }
+            }
+            delete[] versions;
+            return 0;
+        }
+        ]====]
+    )
+
+    set(L0_VERSIONEER_FILE ${CMAKE_BINARY_DIR}/temp/l0_versioneer.cpp)
+
+    file(WRITE ${L0_VERSIONEER_FILE} "${L0_VERSIONEER_SRC}")
+
+    # We need both the directories in the include path as ze_loader.h
+    # includes "ze_api.h" and not "level_zero/ze_api.h".
+    list(APPEND INCLUDE_DIRS ${LevelZero_INCLUDE_DIR})
+    list(APPEND INCLUDE_DIRS ${LevelZero_INCLUDE_DIR}/level_zero)
+    list(JOIN INCLUDE_DIRS ";" INCLUDE_DIRS_STR)
+    try_run(L0_VERSIONEER_RUN L0_VERSIONEER_COMPILE
+            "${CMAKE_BINARY_DIR}"
+            "${L0_VERSIONEER_FILE}"
+            LINK_LIBRARIES ${LevelZero_LIBRARY}
+            CMAKE_FLAGS
+                "-DINCLUDE_DIRECTORIES=${INCLUDE_DIRS_STR}"
+            RUN_OUTPUT_VARIABLE L0_VERSION
+    )
+    if(${L0_VERSIONEER_COMPILE} AND (DEFINED L0_VERSIONEER_RUN))
+        set(LevelZero_VERSION ${L0_VERSION} PARENT_SCOPE)
+        message(STATUS "Found Level Zero of version: ${L0_VERSION}")
+    else()
+        message(FATAL_ERROR
+            "Could not compile a level-zero program to extract loader version"
+        )
+    endif()
+endfunction(get_l0_loader_version)
+
+if(LevelZero_INCLUDE_DIR AND LevelZero_LIBRARY)
+    list(APPEND LevelZero_LIBRARIES "${LevelZero_LIBRARY}")
+    list(APPEND LevelZero_INCLUDE_DIRS ${LevelZero_INCLUDE_DIR})
+    if(OpenCL_FOUND)
+      list(APPEND LevelZero_INCLUDE_DIRS ${OpenCL_INCLUDE_DIRS})
+    endif()
+
+    cmake_path(GET LevelZero_LIBRARY PARENT_PATH LevelZero_LIBRARIES_PATH)
+    set(LevelZero_LIBRARIES_DIR ${LevelZero_LIBRARIES_PATH})
+
+    if(NOT TARGET LevelZero::LevelZero)
+      add_library(LevelZero::LevelZero INTERFACE IMPORTED)
+      set_target_properties(LevelZero::LevelZero
+        PROPERTIES INTERFACE_LINK_LIBRARIES "${LevelZero_LIBRARIES}"
+      )
+      set_target_properties(LevelZero::LevelZero
+        PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${LevelZero_INCLUDE_DIRS}"
+      )
+    endif()
+endif()
+
+# Check if a specific version of Level Zero is required
+if(LevelZero_FIND_VERSION)
+    get_l0_loader_version()
+    set(VERSION_GT_FIND_VERSION FALSE)
+    compare_versions(
+        ${LevelZero_VERSION}
+        ${LevelZero_FIND_VERSION}
+        VERSION_GT_FIND_VERSION
+    )
+    if(${VERSION_GT_FIND_VERSION})
+        set(LevelZero_FOUND TRUE)
+    else()
+        set(LevelZero_FOUND FALSE)
+    endif()
+else()
+    set(LevelZero_FOUND TRUE)
+endif()
+
+find_package_handle_standard_args(LevelZero
+    REQUIRED_VARS
+        LevelZero_FOUND
+        LevelZero_INCLUDE_DIRS
+        LevelZero_LIBRARY
+        LevelZero_LIBRARIES_DIR
+    HANDLE_COMPONENTS
+)
+mark_as_advanced(LevelZero_LIBRARY LevelZero_INCLUDE_DIRS)
+
+if(LevelZero_FOUND)
+    find_package_message(LevelZero "Found LevelZero: ${LevelZero_LIBRARY}"
+        "(found version ${LevelZero_VERSION})"
+    )
+else()
+    find_package_message(LevelZero "Could not find LevelZero" "")
+endif()
diff --git a/mlir/cmake/modules/FindSyclRuntime.cmake b/mlir/cmake/modules/FindSyclRuntime.cmake
new file mode 100644
index 000000000000000..38b065a3f284c2c
--- /dev/null
+++ b/mlir/cmake/modules/FindSyclRuntime.cmake
@@ -0,0 +1,68 @@
+# CMake find_package() module for SYCL Runtime
+#
+# Example usage:
+#
+# find_package(SyclRuntime)
+#
+# If successful, the following variables will be defined:
+# SyclRuntime_FOUND
+# SyclRuntime_INCLUDE_DIRS
+# SyclRuntime_LIBRARY
+# SyclRuntime_LIBRARIES_DIR
+#
+
+include(FindPackageHandleStandardArgs)
+
+if(NOT DEFINED ENV{CMPLR_ROOT})
+    message(WARNING "Please make sure to install Intel DPC++ Compiler and run setvars.(sh/bat)")
+    message(WARNING "You can download standalone Intel DPC++ Compiler from https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#compilers")
+else()
+    if(LINUX OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux"))
+        set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/linux")
+    elseif(WIN32)
+        set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/windows")
+    endif()
+    list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include")
+    list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include/sycl")
+
+    set(SyclRuntime_LIBRARY_DIR "${SyclRuntime_ROOT}/lib")
+
+    message(STATUS "SyclRuntime_LIBRARY_DIR: ${SyclRuntime_LIBRARY_DIR}")
+    find_library(SyclRuntime_LIBRARY
+        NAMES sycl
+        PATHS ${SyclRuntime_LIBRARY_DIR}
+        NO_DEFAULT_PATH
+        )
+endif()
+
+if(SyclRuntime_LIBRARY)
+    set(SyclRuntime_FOUND TRUE)
+    if(NOT TARGET SyclRuntime::SyclRuntime)
+        add_library(SyclRuntime::SyclRuntime INTERFACE IMPORTED)
+        set_target_properties(SyclRuntime::SyclRuntime
+            PROPERTIES INTERFACE_LINK_LIBRARIES "${SyclRuntime_LIBRARY}"
+      )
+      set_target_properties(SyclRuntime::SyclRuntime
+          PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${SyclRuntime_INCLUDE_DIRS}"
+      )
+    endif()
+else()
+    set(SyclRuntime_FOUND FALSE)
+endif()
+
+find_package_handle_standard_args(SyclRuntime
+    REQUIRED_VARS
+        SyclRuntime_FOUND
+        SyclRuntime_INCLUDE_DIRS
+        SyclRuntime_LIBRARY
+        SyclRuntime_LIBRARY_DIR
+    HANDLE_COMPONENTS
+)
+
+mark_as_advanced(SyclRuntime_LIBRARY SyclRuntime_INCLUDE_DIRS)
+
+if(SyclRuntime_FOUND)
+    find_package_message(SyclRuntime "Found SyclRuntime: ${SyclRuntime_LIBRARY}" "")
+else()
+    find_package_message(SyclRuntime "Could not find SyclRuntime" "")
+endif()
diff --git a/mlir/lib/ExecutionEngine/CMakeLists.txt b/mlir/lib/ExecutionEngine/CMakeLists.txt
index ea33c2c6ed261e1..fdc797763ae3a41 100644
--- a/mlir/lib/ExecutionEngine/CMakeLists.txt
+++ b/mlir/lib/ExecutionEngine/CMakeLists.txt
@@ -12,6 +12,7 @@ set(LLVM_OPTIONAL_SOURCES
   RunnerUtils.cpp
   OptUtils.cpp
   JitRunner.cpp
+  SyclRuntimeWrappers.cpp
   )
 
 # Use a separate library for OptUtils, to avoid pulling in the entire JIT and
@@ -328,4 +329,39 @@ if(LLVM_ENABLE_PIC)
       hip::host hip::amdhip64
     )
   endif()
+
+  if(MLIR_ENABLE_SYCL_RUNNER)
+    find_package(SyclRuntime)
+
+    if(NOT SyclRuntime_FOUND)
+      message(FATAL_ERROR "syclRuntime not found. Please set check oneapi installation and run setvars.sh.")
+    endif()
+
+    find_package(LevelZero)
+
+    if(NOT LevelZero_FOUND)
+      message(FATAL_ERROR "LevelZero not found. Please set LEVEL_ZERO_DIR.")
+    endif()
+
+    add_mlir_library(mlir_sycl_runtime
+      SHARED
+      SyclRuntimeWrappers.cpp
+
+      EXCLUDE_FROM_LIBMLIR
+    )
+
+    check_cxx_compiler_flag("-frtti" CXX_HAS_FRTTI_FLAG)
+    if(NOT CXX_HAS_FRTTI_FLAG)
+      message(FATAL_ERROR "CXX compiler does not accept flag -frtti")
+    endif()
+    target_compile_options (mlir_sycl_runtime PUBLIC -fexceptions -frtti)
+
+    target_include_directories(mlir_sycl_runtime PRIVATE
+      ${MLIR_INCLUDE_DIRS}
+    )
+
+    target_link_libraries(mlir_sycl_runtime PRIVATE LevelZero::LevelZero SyclRuntime::SyclRuntime)
+
+    set_property(TARGET mlir_sycl_runtime APPEND PROPERTY BUILD_RPATH "${LevelZero_LIBRARIES_DIR}" "${SyclRuntime_LIBRARIES_DIR}")
+  endif()
 endif()
diff --git a/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
new file mode 100644
index 000000000000000..c36ae2859ea65ab
--- /dev/null
+++ b/mlir/lib/ExecutionEngine/SyclRuntimeWrappers.cpp
@@ -0,0 +1,210 @@
+//===- SyclRuntimeWrappers.cpp - MLIR SYCL wrapper library ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements wrappers around the sycl runtime library with C linkage
+//
+//===----------------------------------------------------------------------===//
+
+
+#include <CL/sycl.hpp>
+#include <level_zero/ze_api.h>
+#include <sycl/ext/oneapi/backend/level_zero.hpp>
+
+#ifdef _WIN32
+#define SYCL_RUNTIME_EXPORT __declspec(dllexport)
+#else
+#define SYCL_RUNTIME_EXPORT
+#endif // _WIN32
+
+namespace {
+
+template <typename F>
+auto catchAll(F &&func) {
+  try {
+    return func();
+  } catch (const std::exception &e) {
+    fprintf(stdout, "An exception was thrown: %s\n", e.what());
+    fflush(stdout);
+    abort();
+  } catch (...) {
+    fprintf(stdout, "An unknown exception was thrown\n");
+    fflush(stdout);
+    abort();
+  }
+}
+
+#define L0_SAFE_CALL(call)                                                     \
+  {                                                                            \
+    ze_result_t status = (call);                                               \
+    if (status != ZE_RESULT_SUCCESS) {                                         \
+      fprintf(stdout, "L0 error %d\n", status);                                \
+      fflush(stdout);                                                          \
+      abort();                                                                 \
+    }                                                                          \
+  }
+
+} // namespace
+
+static sycl::device getDefaultDevice() {
+  static sycl::device syclDevice;
+  static bool isDeviceInitialised = false;
+  if (!isDeviceInitialised) {
+    auto platformList = sycl::platform::get_platforms();
+    for (const auto &platform : platformList) {
+      auto platformName = platform.get_info<sycl::info::platform::name>();
+      bool isLevelZero = platformName.find("Level-Zero") != std::string::npos;
+      if (!isLevelZero)
+        continue;
+
+      syclDevice = platform.get_devices()[0];
+      isDeviceInitialised = true;
+      return syclDevice;
+    }
+    throw std::runtime_error("getDefaultDevice failed");
+  } else
+    return syclDevice;
+}
+
+static sycl::context getDefaultContext() {
+  static sycl::context syclContext {getDefaultDevice()};
+  return syclContext;
+}
+
+static void *allocDeviceMemory(sycl::queue *queue, size_t size, bool isShared) {
+  void *memPtr = nullptr;
+  if (isShared) {
+    memPtr = sycl::aligned_alloc_shared(64, size, getDefaultDevice(),
+                                        getDefaultContext());
+  } else {
+    memPtr = sycl::aligned_alloc_device(64, size, getDefaultDevice(),
+                                        getDefaultContext());
+  }
+  if (memPtr == nullptr) {
+    throw std::runtime_error("mem allocation failed!");
+  }
+  return memPtr;
+}
+
+static void deallocDeviceMemory(sycl::queue *queue, void *ptr) {
+  sycl::free(ptr, *queue);
+}
+
+static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
+  assert(data);
+  ze_module_handle_t zeModule;
+  ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
+                           nullptr,
+                           ZE_MODULE_FORMAT_IL_SPIRV,
+                           dataSize,
+                           (const uint8_t *)data,
+                           nullptr,
+                           nullptr};
+  auto zeDevice = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
+      getDefaultDevice());
+  auto zeContext = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(
+      getDefaultContext());
+  L0_SAFE_CALL(zeModuleCreate(zeContext, zeDevice, &desc, &zeModule, nullptr));
+  return zeModule;
+}
+
+static sycl::kernel *getKernel(ze_module_handle_t zeModule, const char *name) {
+  assert(zeModule);
+  assert(name);
+  ze_kernel_handle_t zeKernel;
+  ze_kernel_desc_t desc = {};
+  desc.pKernelName = name;
+
+  L0_SAFE_CALL(zeKernelCreate(zeModule, &desc, &zeKernel));
+  sycl::kernel_bundle<sycl::bundle_state::executable> kernelBundle =
+      sycl::make_kernel_bundle<sycl::backend::ext_oneapi_level_zero,
+                               sycl::bundle_state::executable>(
+          {zeModule}, getDefaultContext());
+
+  auto kernel = sycl::make_kernel<sycl::backend::ext_oneapi_level_zero>(
+      {kernelBundle, zeKernel}, getDefaultContext());
+  return new sycl::kernel(kernel);
+}
+
+static void launchKernel(sycl::queue *queue, sycl::kernel *kernel, size_t gridX,
+                         size_t gridY, size_t gridZ, size_t blockX,
+                         size_t blockY, size_t blockZ, size_t sharedMemBytes,
+                         void **params, size_t paramsCount) {
+  auto syclGlobalRange =
+      sycl::range<3>(blockZ * gridZ, blockY * gridY, blockX * gridX);
+  auto syclLocalRange = sycl::range<3>(blockZ, blockY, blockX);
+  sycl::nd_range<3> syclNdRange(syclGlobalRange, syclLocalRange);
+
+  queue->submit([&](sycl::handler &cgh) {
+    for (size_t i = 0; i < paramsCount; i++) {
+      cgh.set_arg(static_cast<uint32_t>(i), *(static_cast<void **>(params[i])));
+    }
+    cgh.parallel_for(syclNdRange, *kernel);
+  });
+}
+
+// Wrappers
+
+extern "C" SYCL_RUNTIME_EXPORT sycl::queue *mgpuStreamCreate() {
+
+  return catchAll([&]() {
+    sycl::queue *queue =
+        new sycl::queue(getDefaultContext(), getDefaultDevice());
+    return queue;
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamDestroy(sycl::queue *queue) {
+  catchAll([&]() { delete queue; });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void *
+mgpuMemAlloc(uint64_t size, sycl::queue *queue, bool isShared) {
+  return catchAll([&]() {
+    return allocDeviceMemory(queue, static_cast<size_t>(size), true);
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuMemFree(void *ptr, sycl::queue *queue) {
+  catchAll([&]() {
+    if (ptr) {
+      deallocDeviceMemory(queue, ptr);
+    }
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT ze_module_handle_t
+mgpuModuleLoad(const void *data, size_t gpuBlobSize) {
+  return catchAll([&]() { return loadModule(data, gpuBlobSize); });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT sycl::kernel *
+mgpuModuleGetFunction(ze_module_handle_t module, const char *name) {
+  return catchAll([&]() { return getKernel(module, name); });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void
+mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
+                 size_t blockX, size_t blockY, size_t blockZ,
+                 size_t sharedMemBytes, sycl::queue *queue, void **params,
+                 void **extra, size_t paramsCount) {
+  return catchAll([&]() {
+    launchKernel(queue, kernel, gridX, gridY, gridZ, blockX, blockY, blockZ,
+                 sharedMemBytes, params, paramsCount);
+  });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void mgpuStreamSynchronize(sycl::queue *queue) {
+
+  catchAll([&]() { queue->wait(); });
+}
+
+extern "C" SYCL_RUNTIME_EXPORT void
+mgpuModuleUnload(ze_module_handle_t module) {
+
+  catchAll([&]() { L0_SAFE_CALL(zeModuleDestroy(module)); });
+}

@nbpatel nbpatel requested a review from grypp October 25, 2023 23:37
Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me! I'm not a SYCL expert, but maybe someone with more expertise could take a quick look here. @Hardcode84 ?

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

mgpuLaunchKernel(sycl::kernel *kernel, size_t gridX, size_t gridY, size_t gridZ,
size_t blockX, size_t blockY, size_t blockZ,
size_t sharedMemBytes, sycl::queue *queue, void **params,
void **extra, size_t paramsCount) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extra is not being used and will result in warning, either do /*extra*/ or (void)extra.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 26, 2023

Thanks for the feedback and approval. @grypp can you help me merge this PR ?

@grypp grypp merged commit 7fa19e6 into llvm:main Oct 26, 2023
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants