Skip to content

Commit

Permalink
[MPS] Support for median with dim (pytorch#88807)
Browse files Browse the repository at this point in the history
## Summary ⚡

**Aim**: Add support for aten::median for MPS backend (Fixes pytorch#87220)

This is a fresh, clean PR following the previous [PR](pytorch#88554)

- Implementing the new median function in aten/src/ATen/native/mps/operations/ReduceOps.mm
- Adding it to aten/src/ATen/native/native_functions.yaml
- Adding it to existing test_median

### **This will work like this** 🪶
median of entire input tensor on MPS
`torch.median(mps_inputTensor)`
median along a dim
`torch.median(mps_inputTensor, dim=[int], keepdim=[Bool])`
Pull Request resolved: pytorch#88807
Approved by: https://github.com/kulinseth
  • Loading branch information
Raman-Kumar authored and kulinseth committed Dec 9, 2022
1 parent 159f70c commit 3f4a672
Show file tree
Hide file tree
Showing 4 changed files with 366 additions and 0 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/native/mps/MPSGraphVenturaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,12 @@
// Declares the cumulativeSumWithTensor:axis:name: selector, which takes a
// tensor, an axis and an op name and returns an MPSGraphTensor.
// NOTE(review): the enclosing @interface is outside this view; presumably a
// category on MPSGraph that forward-declares newer-SDK ops — confirm.
- (MPSGraphTensor *)cumulativeSumWithTensor:(MPSGraphTensor *)tensor
axis:(NSInteger)axis
name:(NSString *)name;

// Declares the sortWithTensor:axis:name: selector. ReduceOps.mm sends it to an
// MPSGraph to obtain the tensor sorted along `axis`, from which the median
// element is sliced; those call sites are only reached after an
// is_macos_13_or_newer() check.
// NOTE(review): sort direction is not visible here — the median code assumes
// ascending order; confirm against the MPSGraph documentation.
- (MPSGraphTensor *)sortWithTensor:(MPSGraphTensor *)tensor
axis:(NSInteger)axis
name:(NSString *)name;

// Declares the argSortWithTensor:axis:name: selector. ReduceOps.mm sends it to
// an MPSGraph to obtain, along `axis`, the tensor from which the median's
// index is sliced; that call site is only reached after an
// is_macos_13_or_newer() check.
// NOTE(review): the element type of the returned index tensor is not visible
// here — confirm it is compatible with the Long indices tensor it feeds.
- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor
axis:(NSInteger)axis
name:(NSString *)name;
@end
315 changes: 315 additions & 0 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Pool.h>
#include <torch/library.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <limits>

namespace at {
namespace native {
Expand Down Expand Up @@ -1638,5 +1639,319 @@ Tensor min_mps(const Tensor& input_t) {
return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps");
}

// Median of the entire tensor, returned as a 0-dim tensor.
//
// The input is flattened and sorted along axis 0, and the element at index
// (numel + 1) / 2 - 1 is sliced out, i.e. the lower of the two middle elements
// when the element count is even.
//
// - On macOS older than 13.0 the sort op is unavailable, so the computation
//   runs on the CPU and the result is moved back to the input's device.
// - Long (int64) inputs are rejected on the native path.
// - An empty input yields NaN converted to the input's dtype.
//   NOTE(review): this is intended to mirror the CPU kernel's behaviour for
//   empty tensors — confirm against the CPU implementation.
Tensor median_mps(const Tensor& input_t) {
  if (!is_macos_13_or_newer()) {
    TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ",
                    "Falling back on CPU. This may have performance implications.");
    // The caller dispatched on an MPS tensor and expects the result on the
    // same device, so move the CPU result back.
    return at::median(input_t.to("cpu")).to(input_t.device());
  }

  // Unsupported dtype is a user-facing limitation, not an internal invariant.
  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS");

  namespace native_mps = at::native::mps;
  using CachedGraph = native_mps::MPSUnaryCachedGraph;

  const int64_t num_in_elements = input_t.numel();
  if (num_in_elements == 0) {
    // There is nothing to sort or slice; never hand back uninitialised memory.
    return at::full({}, std::numeric_limits<float>::quiet_NaN()).to(input_t.options());
  }

  // Kept in a 64-bit NSInteger so large tensors do not overflow the index.
  const NSInteger median_index = static_cast<NSInteger>((num_in_elements + 1) / 2 - 1);

  Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);

  native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();

  @autoreleasepool {
    // The key includes the tensor key (as before), which matters because
    // median_index is baked into the cached graph.
    string key = "median_mps:" + native_mps::getMPSTypeString(input_t.scalar_type()) + native_mps::getTensorsStringKey(input_t);
    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
    // Initialize once if configuration not found in cache
    if (!cachedGraph) {
      native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {

        CachedGraph* newCachedGraph = nil;

        @autoreleasepool {
          MPSGraph* mpsGraph = native_mps::make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);

          // Flatten to 1-D so a single sort along axis 0 orders every element.
          MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:inputTensor
                                                         withShape:@[@-1]
                                                              name:nil];
          MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:reshapedTensor
                                                             axis:(NSInteger)0
                                                             name:nil];

          MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
                                                     dimension:0
                                                         start:median_index
                                                        length:1
                                                          name:nil];

          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
    }

    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
    // The slice yields shape [1]; present the 0-dim output with that shape.
    auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, @[@1]);

    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
    };

    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    native_mps::runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
  }

  return output_t;
}


// Computes the median of `input_t` along `dim` on the MPS device, writing the
// median values into `output_t` and their positions along `dim` into
// `indices_t`.
//
// Both outputs must already be allocated by the caller; they are fed to the
// graph with the keepdim=true shape (size 1 along `dim`). `keepdim` itself is
// not read here. `func_name` prefixes the graph-cache key.
//
// The median is the element at index (size(dim) + 1) / 2 - 1 of the data
// sorted along `dim`, i.e. the lower middle element for even sizes.
void median_out_mps
   (const Tensor& input_t,
    int64_t dim,
    bool keepdim,
    const Tensor& output_t,
    const Tensor& indices_t,
    const std::string& func_name) {

  namespace native_mps = at::native::mps;

  if (output_t.numel() == 0) {
    return;
  }
  // A 0-dim input is its own median, at index 0.
  if (input_t.numel() == 1 && input_t.dim() == 0) {
    output_t.fill_(input_t);
    indices_t.fill_(0);
    return;
  }

  // Derive from MPSCachedGraph
  struct CachedGraph : public native_mps::MPSCachedGraph
  {
    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor *inputTensor_ = nil;
    MPSGraphTensor *outputTensor_ = nil;
    MPSGraphTensor *indicesTensor_ = nil;
  };

  native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();

  const int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());

  // Output shape as if keepdim=true: the input shape with size 1 along dim_.
  IntArrayRef input_shape = input_t.sizes();
  const int64_t num_input_dims = input_shape.size();
  NSMutableArray<NSNumber*>* apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
  for (int64_t i = 0; i < num_input_dims; i++) {
    if (dim_ == i) {
      apparent_out_shape[i] = @1;
    } else {
      // 64-bit boxing: sizes are int64_t and must not be narrowed to int.
      apparent_out_shape[i] = [NSNumber numberWithLongLong:input_shape[i]];
    }
  }

  // Position of the median within the sorted data, kept 64-bit.
  const NSInteger median_index = static_cast<NSInteger>((input_shape[dim_] + 1) / 2 - 1);

  auto stream = at::mps::getCurrentMPSStream();

  @autoreleasepool {
    // median_index is baked into the cached graph, so the key must separate
    // inputs whose size along dim_ differs. Otherwise a graph built for one
    // size would be reused for another and slice at the wrong position. This
    // relies on getTensorsStringKey encoding the tensor's shape, as the key in
    // the whole-tensor median_mps also does.
    string key = func_name + ":" + to_string(dim_) + ":" +
                 native_mps::getMPSTypeString(input_t.scalar_type()) + ":" +
                 native_mps::getTensorsStringKey(input_t);
    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

    if (!cachedGraph) {
      native_mps::MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {

        CachedGraph* newCachedGraph = nil;

        @autoreleasepool {
          MPSGraph* mpsGraph = native_mps::make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);

          MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));

          // Values: sort along dim_ and take the median position.
          MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:inputTensor
                                                             axis:(NSInteger)dim_
                                                             name:nil];
          MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
                                                     dimension:(NSUInteger)dim_
                                                         start:median_index
                                                        length:1
                                                          name:nil];

          // Indices: argsort along dim_ and take the same position.
          MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:inputTensor
                                                                   axis:(NSInteger)dim_
                                                                   name:@"argmax_out"];
          MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argSortedTensor
                                                        dimension:(NSUInteger)dim_
                                                            start:median_index
                                                           length:1
                                                             name:nil];

          newCachedGraph->inputTensor_ = inputTensor;
          newCachedGraph->outputTensor_ = outputTensor;
          newCachedGraph->indicesTensor_ = argOutputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
    }

    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
    auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
    auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape);

    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *>* feeds = @{
      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
    };

    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
      indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
    };

    native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }
}

// CPU fallback for median-along-dim, used when the MPS sort ops are not
// supported by the running macOS version.
//
// `self` is the input the caller has already moved to the CPU. The result is
// computed into temporaries allocated with self's options, then copied into
// the caller's `valuesI` / `indicesI`. The temporaries use
// `vec_apparent_out_shape` when keepdim is true and `vec_out_shape` otherwise.
// Returns references to `valuesI` and `indicesI`.
std::tuple<Tensor&, Tensor&> median_from_cpu(
    const Tensor& self,
    int64_t dim,
    bool keepdim, Tensor & valuesI, Tensor & indicesI, IntArrayRef vec_out_shape, IntArrayRef vec_apparent_out_shape) {
  const IntArrayRef result_shape = keepdim ? vec_apparent_out_shape : vec_out_shape;

  Tensor cpu_values = at::empty(result_shape, self.options());
  Tensor cpu_indices = at::empty(result_shape, self.options().dtype(kLong));

  at::median_out(cpu_values, cpu_indices, self, dim, keepdim);

  // Hand the results back through the caller-provided output tensors.
  valuesI.copy_(cpu_values);
  indicesI.copy_(cpu_indices);
  return std::tuple<Tensor&, Tensor&>{valuesI, indicesI};
}

// Median of `input_t` along `dim` for the MPS backend.
//
// `values` and `indices` are (re)assigned to freshly allocated MPS tensors:
// `values` has the input's dtype and `indices` is Long. Their shape is the
// input shape with `dim` removed, or with size 1 along `dim` when `keepdim`
// is true. Returns references to `values` and `indices`.
//
// - Long (int64) inputs are rejected.
// - A zero-sized reduction dimension is rejected by zero_numel_check_dims.
// - On macOS older than 13.0 the computation falls back to the CPU.
TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> median_out_mps
   (const at::Tensor & input_t,
    int64_t dim,
    bool keepdim,
    at::Tensor & values,
    at::Tensor & indices){

  // Unsupported dtype is a user-facing limitation, not an internal invariant.
  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS");

  const int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
  // Report this op's own name if the reduction dimension is empty.
  native::zero_numel_check_dims(input_t, dim_, "median()");

  IntArrayRef input_shape = input_t.sizes();
  const int64_t num_input_dims = input_shape.size();

  // keepdim=true shape: the input shape with size 1 along dim_.
  std::vector<int64_t> vec_apparent_out_shape(input_shape.begin(), input_shape.end());
  // keepdim=false shape: the input shape with dim_ dropped. Built with
  // push_back so a 0-dim input gives an empty shape rather than a vector of
  // size -1.
  std::vector<int64_t> vec_out_shape;
  vec_out_shape.reserve(num_input_dims > 0 ? num_input_dims - 1 : 0);
  for (int64_t i = 0; i < num_input_dims; i++) {
    if (dim_ == i) {
      vec_apparent_out_shape[i] = 1;
    } else {
      vec_out_shape.push_back(input_shape[i]);
    }
  }

  const std::vector<int64_t>& result_shape = keepdim ? vec_apparent_out_shape : vec_out_shape;
  values = at::native::empty_mps(
      IntArrayRef(result_shape),
      input_t.scalar_type(),
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);
  indices = at::native::empty_mps(
      IntArrayRef(result_shape),
      ScalarType::Long,
      c10::nullopt,
      kMPS,
      c10::nullopt,
      c10::nullopt);

  if (values.numel() == 0 || input_t.numel() == 0) {
    return std::tuple<Tensor&, Tensor&>{values, indices};
  }

  if (!is_macos_13_or_newer()) {
    TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ",
                    "Falling back on CPU. This may have performance implications.");
    return median_from_cpu(input_t.to("cpu"), dim, keepdim, values, indices,
                           IntArrayRef(vec_out_shape), IntArrayRef(vec_apparent_out_shape));
  }

  // Native path: the six-argument overload builds and runs the MPS graph.
  median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps");

  return std::tuple<Tensor&, Tensor&>{values, indices};
}

} // native
} // at
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3492,6 +3492,7 @@
dispatch:
CPU: median_cpu
CUDA: median_cuda
MPS: median_mps
autogen: median.out

- func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
Expand All @@ -3503,6 +3504,7 @@
dispatch:
CPU: median_out_cpu
CUDA: median_out_cuda
MPS: median_out_mps

- func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
variants: function, method
Expand Down
Loading

0 comments on commit 3f4a672

Please sign in to comment.