Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions build2cmake/src/templates/metal/compile-metal.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,30 @@ function(compile_metal_shaders TARGET_NAME METAL_SOURCES)
VERBATIM
)

# Generate C++ header with embedded metallib data
set(METALLIB_HEADER "${CMAKE_BINARY_DIR}/${TARGET_NAME}_metallib.h")
set(METALLIB_TO_HEADER_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/cmake/metallib_to_header.py")

add_custom_command(
OUTPUT ${METALLIB_HEADER}
COMMAND ${Python_EXECUTABLE} ${METALLIB_TO_HEADER_SCRIPT} ${METALLIB_FILE} ${METALLIB_HEADER} ${TARGET_NAME}
DEPENDS ${METALLIB_FILE} ${METALLIB_TO_HEADER_SCRIPT}
COMMENT "Generating embedded Metal library header ${METALLIB_HEADER}"
VERBATIM
)

# Create a custom target for the metallib
add_custom_target(${TARGET_NAME}_metallib ALL DEPENDS ${METALLIB_FILE})
add_custom_target(${TARGET_NAME}_metallib ALL DEPENDS ${METALLIB_FILE} ${METALLIB_HEADER})

# Add dependency to main target
add_dependencies(${TARGET_NAME} ${TARGET_NAME}_metallib)

# Set property so we can access the metallib path later
set_target_properties(${TARGET_NAME} PROPERTIES
METALLIB_FILE ${METALLIB_FILE}
# Add the generated header to include directories
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_BINARY_DIR})

# Pass the metallib header and namespace as compile definitions
target_compile_definitions(${TARGET_NAME} PRIVATE
EMBEDDED_METALLIB_HEADER="${TARGET_NAME}_metallib.h"
EMBEDDED_METALLIB_NAMESPACE=${TARGET_NAME}_metal
)
endfunction()
73 changes: 73 additions & 0 deletions build2cmake/src/templates/metal/metallib_to_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python3
import sys
import os

def convert_metallib_to_header(metallib_path: str, header_path: str, target_name: str) -> None:
    """Convert a metallib binary file to a C++ header with embedded data.

    The generated header declares, inside the namespace ``<target_name>_metal``,
    a byte array ``metallib_data`` holding the file contents, its length
    ``metallib_data_len``, and an inline ``createLibrary`` helper.

    Args:
        metallib_path: Path of the compiled ``.metallib`` file to embed.
        header_path: Path of the header to write; its parent directory is
            created when missing.
        target_name: Prefix of the generated namespace (``<target_name>_metal``).

    Raises:
        ValueError: If the metallib file is empty. An empty file would produce
            a zero-length array in the header, so fail here at generation time.
        OSError: If the input cannot be read or the header cannot be written.
    """

    # Read the metallib binary data
    with open(metallib_path, 'rb') as f:
        data: bytes = f.read()

    if not data:
        raise ValueError("metallib file is empty: {}".format(metallib_path))

    # Convert binary data to C array format: one row per 16 bytes, each row
    # ending in a comma (a trailing comma is valid in a C++ initializer list).
    bytes_per_line: int = 16
    rows = [
        " " + ', '.join('0x{:02x}'.format(b) for b in data[offset:offset + bytes_per_line]) + ","
        for offset in range(0, len(data), bytes_per_line)
    ]

    # NOTE(review): the generated createLibrary() submits its no-op destructor
    # to the main dispatch queue - confirm this is intended for host processes
    # that may never drain the main queue.
    # Assemble the header in one join instead of repeated string appends.
    header_content: str = "".join([
        """// Auto-generated file containing embedded Metal library
#pragma once
#include <cstddef>
#include <Metal/Metal.h>

namespace """, target_name, """_metal {
static const unsigned char metallib_data[] = {
""",
        "\n".join(rows),
        """
};
static const size_t metallib_data_len = """, str(len(data)), """;

// Convenience function to create Metal library from embedded data
inline id<MTLLibrary> createLibrary(id<MTLDevice> device, NSError** error = nullptr) {
dispatch_data_t libraryData = dispatch_data_create(
metallib_data,
metallib_data_len,
dispatch_get_main_queue(),
^{ /* No cleanup needed for static data */ });

NSError* localError = nil;
id<MTLLibrary> library = [device newLibraryWithData:libraryData error:&localError];

if (error) {
*error = localError;
}

return library;
}
} // namespace """, target_name, """_metal
""",
    ])

    # Write the header file, creating the destination directory if needed
    dir_path: str = os.path.dirname(header_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    with open(header_path, 'w') as f:
        f.write(header_content)

    print("Generated {} ({} bytes)".format(header_path, len(data)))

if __name__ == "__main__":
    # Script entry point: exactly three positional arguments are required
    # (metallib path, header path, target name), in that order.
    cli_args = sys.argv[1:]
    if len(cli_args) != 3:
        print("Usage: metallib_to_header.py <metallib_path> <header_path> <target_name>")
        sys.exit(1)

    convert_metallib_to_header(*cli_args)
22 changes: 0 additions & 22 deletions build2cmake/src/templates/metal/torch-extension.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,4 @@ define_gpu_extension_target(
# Compile Metal shaders if any were found
if(ALL_METAL_SOURCES)
compile_metal_shaders({{ ops_name }} "${ALL_METAL_SOURCES}")

# Get the metallib file path
get_target_property(METALLIB_FILE {{ ops_name }} METALLIB_FILE)

# Copy metallib to the output directory (same as the .so file)
add_custom_command(TARGET {{ ops_name }} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${METALLIB_FILE}
$<TARGET_FILE_DIR:{{ ops_name }}>/{{ ops_name }}.metallib
COMMENT "Copying metallib to output directory"
)

# Also copy to the source directory for editable installs
add_custom_command(TARGET {{ ops_name }} POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${METALLIB_FILE}
${CMAKE_CURRENT_SOURCE_DIR}/torch-ext/{{ name }}/{{ ops_name }}.metallib
COMMENT "Copying metallib to source directory for editable installs"
)

# Use a relative path for runtime loading
target_compile_definitions({{ ops_name }} PRIVATE METALLIB_PATH="{{ ops_name }}.metallib")
endif()
8 changes: 8 additions & 0 deletions build2cmake/src/torch/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::{
static CMAKE_UTILS: &str = include_str!("../templates/utils.cmake");
static REGISTRATION_H: &str = include_str!("../templates/registration.h");
static COMPILE_METAL_CMAKE: &str = include_str!("../templates/metal/compile-metal.cmake");
static METALLIB_TO_HEADER_PY: &str = include_str!("../templates/metal/metallib_to_header.py");

pub fn write_torch_ext_metal(
env: &Environment,
Expand Down Expand Up @@ -77,6 +78,13 @@ fn write_cmake(
.entry(compile_metal_path)
.extend_from_slice(COMPILE_METAL_CMAKE.as_bytes());

let mut metallib_to_header_path = PathBuf::new();
metallib_to_header_path.push("cmake");
metallib_to_header_path.push("metallib_to_header.py");
file_set
.entry(metallib_to_header_path)
.extend_from_slice(METALLIB_TO_HEADER_PY.as_bytes());

let cmake_writer = file_set.entry("CMakeLists.txt");

render_preamble(env, name, cmake_writer)?;
Expand Down
40 changes: 14 additions & 26 deletions examples/relu/relu_metal/relu.mm
Original file line number Diff line number Diff line change
Expand Up @@ -2,44 +2,32 @@

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
#include <string>
#include <dlfcn.h>
#include <mach-o/dyld.h>

// Include the auto-generated header with embedded metallib
#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#else
#error "EMBEDDED_METALLIB_HEADER not defined"
#endif

// Reinterprets the data pointer of the tensor's storage as an id<MTLBuffer>.
// NOTE(review): presumably only valid for MPS tensors, whose storage pointer
// is the Metal buffer handle - confirm against the callers.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
  auto storagePointer = tensor.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, storagePointer);
}

// Returns the directory part of the path reported by dladdr() for the binary
// image that contains this function, or "." when the lookup fails or the
// reported path contains no '/'.
static std::string getModuleDirectory() {
  Dl_info imageInfo;
  if (!dladdr((void*)getModuleDirectory, &imageInfo)) {
    return ".";
  }

  const std::string imagePath(imageInfo.dli_fname);
  const size_t lastSlash = imagePath.rfind('/');
  if (lastSlash == std::string::npos) {
    return ".";
  }
  return imagePath.substr(0, lastSlash);
}

torch::Tensor &dispatchReluKernel(torch::Tensor const &input,
torch::Tensor &output) {
@autoreleasepool {
id<MTLDevice> device = MTLCreateSystemDefaultDevice();
NSError *error = nil;

int numThreads = input.numel();

// Construct the full path to the metallib file
std::string moduleDir = getModuleDirectory();
std::string metallibPath = moduleDir + "/" + METALLIB_PATH;

NSString *metallibPathStr = [NSString stringWithUTF8String:metallibPath.c_str()];
NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
id<MTLLibrary> customKernelLibrary = [device newLibraryWithURL:metallibURL error:&error];
if (!customKernelLibrary) {
NSLog(@"[relu.mm] Failed to load pre-compiled Metal library at %@, will fall back to runtime compilation. Error: %@", metallibPathStr, error.localizedDescription);
}
// Load the embedded Metal library from memory
NSError *error = nil;
id<MTLLibrary> customKernelLibrary = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
TORCH_CHECK(customKernelLibrary,
"Failed to create Metal library from embedded data: ",
error.localizedDescription.UTF8String);

std::string kernel_name =
std::string("relu_forward_kernel_") +
Expand Down Expand Up @@ -94,7 +82,7 @@
return output;
}

void relu(torch::Tensor &out, const torch::Tensor &input) {
void relu(torch::Tensor &out, torch::Tensor const &input) {
TORCH_CHECK(input.device().is_mps(), "input must be a MPS tensor");
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(input.scalar_type() == torch::kFloat ||
Expand Down
2 changes: 1 addition & 1 deletion examples/relu/torch-ext/relu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
if out is None:
out = torch.empty_like(x)
ops.relu(out, x)
return out
return out
2 changes: 1 addition & 1 deletion examples/relu/torch-ext/torch_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
2 changes: 1 addition & 1 deletion examples/relu/torch-ext/torch_binding.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

#include <torch/torch.h>

void relu(torch::Tensor &out, torch::Tensor const &input);
void relu(torch::Tensor &out, torch::Tensor const &input);