Skip to content

Commit

Permalink
Build the indexing library with the correct Metal version (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 committed Jul 15, 2022
1 parent 6c6ebdd commit 133b250
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
31 changes: 30 additions & 1 deletion aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,32 @@
static std::unique_ptr<MPSDevice> mps_device;
static c10::once_flag mpsdev_init;

static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
#if defined(__MAC_10_13) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_13
#else
#error "Metal is not available on the current platform."
#endif

// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffer and function constants)
MTLLanguageVersion languageVersion;
if (@available(macOS 13.0, *)) {
languageVersion = MTLLanguageVersion3_0;
} else if (@available(macOS 12.0, *)) {
languageVersion = MTLLanguageVersion2_4;
} else if (@available(macOS 11.0, *)) {
languageVersion = MTLLanguageVersion2_3;
} else if (@available(macOS 10.15, *)) {
languageVersion = MTLLanguageVersion2_2;
} else if (@available(macOS 10.14, *)) {
languageVersion = MTLLanguageVersion2_1;
} else if (@available(macOS 10.13, *)) {
languageVersion = MTLLanguageVersion2_0;
}

TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
return languageVersion;
}

MPSDevice* MPSDevice::getInstance() {
c10::call_once(mpsdev_init, [] {
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
Expand All @@ -22,8 +48,11 @@
assert(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions *options = [MTLCompileOptions new];
[options setLanguageVersion: getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled: YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
options: nil
options: options
error: &error];
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mps/operations/Indexing.mm
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ bool dispatchIndexSelectKernel(TensorIteratorBase& iter, IntArrayRef index_size,
using namespace mps;

@autoreleasepool {
if (iter.numel() == 0) {
return true;
}

const Tensor& inputTensor = iter.tensor(1);
Tensor outputTensor = iter.tensor(0);

Expand Down

0 comments on commit 133b250

Please sign in to comment.