Skip to content

Commit

Permalink
Address review feedback - get rid of runtime dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryyin committed Oct 12, 2020
1 parent 00bbf5f commit c3289a4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 72 deletions.
40 changes: 1 addition & 39 deletions mlir/tools/mlir-miopen-driver/CMakeLists.txt
Expand Up @@ -5,38 +5,6 @@ set(LLVM_LINK_COMPONENTS
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

# Configure ROCm support.
if (NOT DEFINED ROCM_PATH)
if (NOT DEFINED ENV{ROCM_PATH})
set(ROCM_PATH "/opt/rocm" CACHE PATH "Path to which ROCm has been installed")
else()
set(ROCM_PATH $ENV{ROCM_PATH} CACHE PATH "Path to which ROCm has been installed")
endif()
set(HIP_PATH "${ROCM_PATH}/hip" CACHE PATH " Path to which HIP has been installed")
endif()
set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
find_package(HIP)
if (NOT HIP_FOUND)
message(SEND_ERROR "Build the mlir rocm runner requires a working ROCm and HIP install")
else()
message(STATUS "ROCm HIP version: ${HIP_VERSION}")
endif()

# Set compile-time flags for ROCm path.
add_definitions(-D__ROCM_PATH__="${ROCM_PATH}")

# Locate HIP runtime library.
find_library(ROCM_RUNTIME_LIBRARY amdhip64
PATHS "${HIP_PATH}/lib")
if (NOT ROCM_RUNTIME_LIBRARY)
message(SEND_ERROR "Could not locate ROCm HIP runtime library")
else()
message(STATUS "ROCm HIP runtime lib: ${ROCM_RUNTIME_LIBRARY}")
endif()

# Set HIP compile-time flags.
add_definitions(-D__HIP_PLATFORM_HCC__)

set(LIBS
${dialect_libs}
${conversion_libs}
Expand All @@ -63,14 +31,8 @@ add_llvm_executable(mlir-miopen-driver
${LIBS}
)

target_include_directories(mlir-miopen-driver
PRIVATE
"${HIP_PATH}/../include"
"${HIP_PATH}/include"
)

llvm_update_compile_flags(mlir-miopen-driver)
target_link_libraries(mlir-miopen-driver PRIVATE ${LIBS} ${ROCM_RUNTIME_LIBRARY})
target_link_libraries(mlir-miopen-driver PRIVATE ${LIBS})
mlir_check_link_libraries(mlir-miopen-driver)


Expand Down
41 changes: 8 additions & 33 deletions mlir/tools/mlir-miopen-driver/mlir-miopen-driver.cpp
Expand Up @@ -37,8 +37,6 @@
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Support/raw_ostream.h"

#include "hip/hip_runtime.h"

using namespace llvm;
using namespace mlir;

Expand All @@ -56,11 +54,16 @@ static cl::opt<std::string>
cl::value_desc("convolution flavor string"), cl::init("conv2d"));

static cl::opt<std::string>
arch("arch", cl::desc("amdgpu architecture, eg: gfx900, gfx906 ..."),
arch("arch",
cl::desc("amdgpu architecture, eg: gfx803, gfx900, gfx906, gfx908"),
cl::value_desc("GFX architecture string"), cl::init("gfx906"));

static cl::opt<int> num_cu("num_cu", cl::desc("Number of compute units"),
cl::value_desc("compute unit value"), cl::init(64));
static cl::opt<int>
num_cu("num_cu",
cl::desc("Number of compute units, valid combinations include: "
"gfx803(36/64), gfx900(56/64), "
"gfx906(60/64), gfx908(120)"),
cl::value_desc("compute unit value"), cl::init(64));

static cl::opt<std::string> filterLayout("fil_layout", cl::desc("Filter layout"),
cl::value_desc("layout string"),
Expand Down Expand Up @@ -190,31 +193,6 @@ static cl::opt<std::string> tensorDataType("t", cl::desc("Data type for convolut
cl::value_desc("Data type for convolution"),
cl::init("f32"));

int getDeviceId() // Get default device
{
int device = 0;
auto status = hipGetDevice(&device);
if (status != hipSuccess)
llvm::errs() << "No device found";
return device;
}

std::size_t GetMaxComputeUnits(int device) {
int result = 0;
auto status = hipDeviceGetAttribute(
&result, hipDeviceAttributeMultiprocessorCount, device);
if (status != hipSuccess)
llvm::errs() << "Failed to get compute units.";

return result;
}

std::string GetDeviceName(int device) {
hipDeviceProp_t props{};
hipGetDeviceProperties(&props, device);
return "gfx" + std::to_string(props.gcnArch);
}

static void populateDefaults() {
if (populateDefaultValues == true) {
if (xdlopsV2.getValue() == false) {
Expand Down Expand Up @@ -250,9 +228,6 @@ static void populateDefaults() {
paddingHeight.setValue(0);
paddingWidth.setValue(0);
}
int device = getDeviceId();
arch.setValue(GetDeviceName(device));
num_cu.setValue(GetMaxComputeUnits(device));
}
}

Expand Down

0 comments on commit c3289a4

Please sign in to comment.