[MPS] Add fmax fmin op (#95191)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch/pytorch#95191
Approved by: https://github.com/kulinseth
qqaatw authored and cyyever committed Mar 5, 2023
1 parent 300e2d0 commit a2d04c0
Showing 4 changed files with 207 additions and 2 deletions.
2 changes: 2 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.h
@@ -1,5 +1,7 @@
// Copyright © 2022 Apple Inc.

#pragma once

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
199 changes: 199 additions & 0 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -0,0 +1,199 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/BinaryOps.h>

namespace at::native {
namespace mps {

static const char* METAL_BINARY = R"BINARY_METAL(
#include <metal_stdlib>
using namespace metal;
template<typename T>
kernel void fmax(constant void  * input_  [[buffer(0)]],
                 constant void  * other_  [[buffer(1)]],
                 device   void  * out_    [[buffer(2)]],
                 constant uint3 * offsets [[buffer(3)]],
                 uint tid [[thread_position_in_grid]]) {
  device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
  constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
  constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
  *out = fmax(*input, *other);
}
template<typename T>
kernel void fmin(constant void  * input_  [[buffer(0)]],
                 constant void  * other_  [[buffer(1)]],
                 device   void  * out_    [[buffer(2)]],
                 constant uint3 * offsets [[buffer(3)]],
                 uint tid [[thread_position_in_grid]]) {
  device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
  constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
  constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
  *out = fmin(*input, *other);
}
#define REGISTER_FMAX_OP(DTYPE)              \
template                                     \
[[host_name("fmax_" #DTYPE)]]                \
kernel void fmax<DTYPE>(                     \
  constant void  * input_  [[buffer(0)]],    \
  constant void  * other_  [[buffer(1)]],    \
  device   void  * out_    [[buffer(2)]],    \
  constant uint3 * offsets [[buffer(3)]],    \
  uint tid [[thread_position_in_grid]]);
#define REGISTER_FMIN_OP(DTYPE)              \
template                                     \
[[host_name("fmin_" #DTYPE)]]                \
kernel void fmin<DTYPE>(                     \
  constant void  * input_  [[buffer(0)]],    \
  constant void  * other_  [[buffer(1)]],    \
  device   void  * out_    [[buffer(2)]],    \
  constant uint3 * offsets [[buffer(3)]],    \
  uint tid [[thread_position_in_grid]]);
REGISTER_FMAX_OP(float);
REGISTER_FMAX_OP(half);
REGISTER_FMIN_OP(float);
REGISTER_FMIN_OP(half);
)BINARY_METAL";

using namespace mps;

static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
  static id<MTLLibrary> binaryLibrary = nil;
  if (binaryLibrary) {
    return binaryLibrary;
  }

  NSError* error = nil;
  MTLCompileOptions* options = [[MTLCompileOptions new] autorelease];
  [options setLanguageVersion:MTLLanguageVersion2_3];
  binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString:METAL_BINARY encoding:NSASCIIStringEncoding]
                                       options:options
                                         error:&error];
  TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
  return binaryLibrary;
}

static id<MTLComputePipelineState> binaryPipelineState(id<MTLDevice> device, const std::string& kernel) {
  static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
  id<MTLComputePipelineState> pso = psoCache[kernel];
  if (pso) {
    return pso;
  }

  NSError* error = nil;
  id<MTLLibrary> binaryLib = compileBinaryOpsLibrary(device);
  id<MTLFunction> binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
  TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel);
  pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error];
  TORCH_CHECK(pso, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);

  psoCache[kernel] = pso;
  return pso;
}

void fmax_fmin_mps_impl(TensorIteratorBase& iter, const std::string max_min) {
  TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");

  Tensor input = iter.input(0);
  Tensor other = iter.input(1);
  Tensor out = iter.output(0);
  id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
  id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
  id<MTLBuffer> outputBuffer = getMTLBufferStorage(out);
  id<MTLDevice> device = MPSDevice::getInstance()->device();
  MPSStream* mpsStream = getCurrentMPSStream();
  const uint32_t nDim = iter.ndim();
  constexpr uint32_t nOffsets = 3;
  const uint32_t numThreads = iter.numel();
  dispatch_sync(mpsStream->queue(), ^() {
    @autoreleasepool {
      NSError* error = nil;
      id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
      id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
      MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
      const IntArrayRef& iterShape = iter.shape();
      std::vector<uint32_t> iterShapeData(iterShape.size());
      std::vector<std::array<uint32_t, nOffsets>> strides(nDim);

      for (const auto i : c10::irange(iterShape.size())) {
        TORCH_CHECK(i <= UINT32_MAX);
        iterShapeData[i] = (uint32_t)(iterShape[i]);
      }

      for (const auto i : c10::irange(nDim)) {
        for (const auto offset : c10::irange(nOffsets)) {
          strides[i][offset] = iter.strides(offset)[i];
        }
      }

      // First dispatch: compute per-thread byte offsets for output, input, and other
      // from the iterator's shape and strides.
      id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
      id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction:kernelDataOffsetsFunction
                                                                                                error:&error] autorelease];
      id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
                                                             options:0] autorelease];
      TORCH_CHECK(kernelDataOffsetsPSO, "Failed to create pipeline state object, error: ", [[error description] UTF8String]);
      [computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
      [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
      [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
      [computeEncoder setBytes:iterShapeData.data() length:sizeof(uint32_t) * iterShape.size() atIndex:2];
      [computeEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3];
      [computeEncoder setBytes:&nOffsets length:sizeof(uint32_t) atIndex:4];

      NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup;
      if (kernelOffsetsTGSize > numThreads) {
        kernelOffsetsTGSize = numThreads;
      }

      MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
      [computeEncoder dispatchThreads:gridSize
                threadsPerThreadgroup:kernelOffsetsThreadGroupSize];

      // Second dispatch: run the fmax/fmin kernel using the offsets computed above.
      const std::string kernel = "f" + max_min + "_" + scalarToMetalTypeString(out.scalar_type());
      id<MTLComputePipelineState> fmaxfminPSO = binaryPipelineState(device, kernel);
      [computeEncoder setComputePipelineState:fmaxfminPSO];
      [computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
      [computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1];
      [computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2];
      [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];

      NSUInteger tgSize = fmaxfminPSO.maxTotalThreadsPerThreadgroup;
      if (tgSize > numThreads) {
        tgSize = numThreads;
      }

      MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
      [computeEncoder dispatchThreads:gridSize
                threadsPerThreadgroup:threadGroupSize];

      [computeEncoder endEncoding];
      mpsStream->commit(true);
    }
  });
}
} // namespace mps

void fmax_mps_kernel(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    mps::fmax_fmin_mps_impl(iter, "max");
  } else {
    at::maximum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
  }
}

void fmin_mps_kernel(TensorIteratorBase& iter) {
  if (isFloatingType(iter.common_dtype())) {
    mps::fmax_fmin_mps_impl(iter, "min");
  } else {
    at::minimum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
  }
}

REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel);
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel);

} // namespace at::native
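For reference, a minimal usage sketch of the new ops from Python (not part of this change; it assumes a machine where torch.backends.mps.is_available() returns True). The floating-point path runs the Metal kernels above, while integral inputs take the maximum_out/minimum_out fallback, since fmax/fmin only differ from maximum/minimum in how NaN is handled:

    import torch

    if torch.backends.mps.is_available():
        device = torch.device("mps")
        a = torch.tensor([1.0, float("nan"), 3.0], device=device)
        b = torch.tensor([2.0, 2.0, float("nan")], device=device)

        # fmax/fmin prefer the non-NaN operand ...
        print(torch.fmax(a, b))     # tensor([2., 2., 3.], device='mps:0')
        print(torch.fmin(a, b))     # tensor([1., 2., 3.], device='mps:0')

        # ... whereas maximum/minimum propagate NaN.
        print(torch.maximum(a, b))  # tensor([2., nan, nan], device='mps:0')

        # Integral dtypes hit the fallback branch in fmax_mps_kernel/fmin_mps_kernel.
        i = torch.tensor([1, 5, 3], device=device)
        j = torch.tensor([4, 2, 6], device=device)
        print(torch.fmax(i, j))     # tensor([4, 5, 6], device='mps:0')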
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -9301,7 +9301,7 @@
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: fmin_out
    CPU, CUDA, MPS: fmin_out
  tags: pointwise

- func: max(Tensor self) -> Tensor
@@ -9323,7 +9323,7 @@
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    CPU, CUDA: fmax_out
    CPU, CUDA, MPS: fmax_out
  tags: pointwise

- func: maximum(Tensor self, Tensor other) -> Tensor
4 changes: 4 additions & 0 deletions test/test_mps.py
@@ -9404,6 +9404,8 @@ class TestConsistency(TestCaseMPS):
        'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
        'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
        'floor_divide': ['f32', 'f16'],
        'fmax': ['b8', 'f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
        'fmin': ['b8', 'f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
        'fmod': ['f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
        'frac': ['f16', 'f32'],
'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9673,6 +9675,8 @@ class TestConsistency(TestCaseMPS):
        'flipud': ['f16', 'f32'],
        'float': ['f32'],
        'floor': ['f32'],
        'fmax': ['f16', 'f32'],
        'fmin': ['f16', 'f32'],
        'gradient': ['f32'],
        'half': ['f16'],
        'hstack': ['f16', 'f32'],
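The two allowlist additions above are exercised by the MPS consistency suite; a sketch of running just those cases locally (assuming the standard unittest -k name filter is honored by this test file):

    python test/test_mps.py -v -k fmax -k fmin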
