Dev/denis/reduce ops multi axes support (#188)
* Add multi axes support for reduce ops

* Add back bessel_corrected variable

* Remove input flattening from reduce ops; enable more tests in TestConsistency

* Refactor Repeat.mm

* Fix remaining reduce ops issues

* Remove debug code

* Fix missing colon

* Always wrap input dimensions

* Remove dimension wrapping (already wrapped)

* Address remaining PR comments
DenisVieriu97 committed Jan 6, 2023
1 parent f219970 commit 8c1df75
Showing 3 changed files with 71 additions and 91 deletions.
25 changes: 25 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.mm
@@ -112,6 +112,31 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
}
}

+NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
+  int64_t ndim = t.dim();
+  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+  for (const auto i : c10::irange(ndim)) {
+    axes[i] = [NSNumber numberWithInteger:i];
+  }
+  return axes;
+}
+
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim) {
+  if (dim.has_value() && dim.value().size() != 0) {
+    IntArrayRef dimValues = dim.value();
+    int ndim = dimValues.size();
+    auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+    for (const auto i : c10::irange(ndim)) {
+      axes[i] = [NSNumber numberWithInteger:dimValues[i]];
+    }
+    return axes;
+  }
+
+  return getTensorAxes(t);
+}


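The two overloads added above decide which axes a reduction runs over. As a rough illustration (not part of the commit), the selection logic behaves like this Python sketch, where `get_tensor_axes` is a hypothetical mirror of the Objective-C++ helper:

```python
import torch

def get_tensor_axes(t: torch.Tensor, dim=None):
    # An explicit, non-empty `dim` wins; an empty or absent `dim`
    # means "reduce over every axis of the tensor".
    if dim is not None and len(dim) != 0:
        return list(dim)
    return list(range(t.dim()))

t = torch.zeros(2, 3, 4)
assert get_tensor_axes(t, (0, 2)) == [0, 2]
assert get_tensor_axes(t) == [0, 1, 2]
```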
113 changes: 24 additions & 89 deletions aten/src/ATen/native/mps/operations/Repeat.mm
@@ -36,48 +36,6 @@ Tensor permute_mps(const Tensor& self, IntArrayRef dims) {
return self.as_strided(newSizes, newStrides);
}

-void set_apparent_shapes(NSArray<NSNumber*> * input_shape,
-                         NSArray<NSNumber*> * &apparent_input_shape,
-                         int64_t num_input_dims,
-                         IntArrayRef repeats,
-                         NSMutableArray<NSNumber*> * &repeats_shape,
-                         int64_t num_repeat_dims) {
-
-  bool repeat_empty = false;
-  if(num_repeat_dims == 0) {
-    num_repeat_dims = num_input_dims;
-    repeat_empty = true;
-  }
-
-  // Set repeats_shape
-  repeats_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_repeat_dims];
-
-  for(int i = 0; i < num_repeat_dims; i++) {
-    if(repeat_empty)
-      repeats_shape[i] = [NSNumber numberWithInteger:1];
-    else
-      repeats_shape[i] = [NSNumber numberWithInteger:repeats[i]];
-  }
-
-  // If no extension of the shape is needed
-  if(num_repeat_dims == num_input_dims) {
-    apparent_input_shape = input_shape;
-  }
-  // num_repeat_dims > num_input_dims
-  else {
-    auto rc = [NSMutableArray<NSNumber*> arrayWithCapacity:num_repeat_dims];
-
-    for(int i = 0; i < num_repeat_dims - num_input_dims; i++)
-      rc[i] = @1;
-
-    for(int i = num_repeat_dims - num_input_dims; i < num_repeat_dims; i++)
-      rc[i] = input_shape[i + num_input_dims - num_repeat_dims];
-    apparent_input_shape = rc;
-  }
-}

Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {

using namespace mps;
@@ -91,54 +49,32 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
MPSGraphTensor *outputTensor_ = nil;
};

-  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
-
-  NSArray<NSNumber*> *apparent_input_shape = nil;
-  NSMutableArray<NSNumber*> *repeats_shape = nil;

-  auto input_shape = getMPSShape(self);
-  auto num_input_dims = [input_shape count];
-  auto num_repeat_dims = repeats.size();
-
-  set_apparent_shapes(input_shape,
-                      apparent_input_shape,
-                      num_input_dims,
-                      repeats,
-                      repeats_shape,
-                      num_repeat_dims);

-  // Set output shape
-  std::vector<int64_t> output_shape(num_repeat_dims);
+  // Add new leading dimensions to the tensor if the
+  // number of target dimensions is larger than the
+  // number of source dimensions.
+  int64_t num_new_dimensions = repeats.size() - self.dim();
+  DimVector padded_size(num_new_dimensions, 1);
+  padded_size.insert(padded_size.end(), self.sizes().begin(), self.sizes().end());
+  DimVector target_size(repeats.size());
bool zero_tensor = false;
-  for(auto i : c10::irange(num_repeat_dims)) {
-    output_shape[i] = repeats[i] * [apparent_input_shape[i] intValue];
-    if(output_shape[i] == 0) {
+  for(const auto idx : c10::irange(repeats.size())) {
+    if (repeats[idx] == 0) {
       zero_tensor = true;
     }
+    target_size[idx] = padded_size[idx] * repeats[idx];
   }

-  Tensor output = at::native::empty_mps(
-                    IntArrayRef(output_shape),
-                    self.scalar_type(),
-                    c10::nullopt,
-                    kMPS,
-                    c10::nullopt,
-                    c10::nullopt);
-
-  // Empty output
-  if(zero_tensor || output.numel() == 0)
-    return output;
+  Tensor expanded_tensor = self.expand(padded_size);
+  Tensor result = at::empty(target_size, self.options());
+  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+  if(zero_tensor || result.numel() == 0) {
+    return result;
+  }

auto stream = at::mps::getCurrentMPSStream();

@autoreleasepool {

-    NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","];
-    NSString* ns_repeats_key = [[repeats_shape valueForKey:@"description"] componentsJoinedByString:@","];
-
-    string key = "repeat_mps:" + getMPSTypeString(self.scalar_type())
-                               + ":" + string([ns_shape_key UTF8String])
-                               + ":" + string([ns_repeats_key UTF8String]);
+    string key = "repeat_mps:" + getTensorsStringKey(self) + ":" + getArrayRefString(repeats);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

if(!cachedGraph) {
Expand All @@ -149,9 +85,9 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

-          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), apparent_input_shape);
+          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, expanded_tensor);
MPSGraphTensor* outputTensor = [mpsGraph tileTensor:inputTensor
-                                               withMultiplier:repeats_shape
+                                               withMultiplier:getMPSShape(repeats)
name:nil];

newCachedGraph->inputTensor_ = inputTensor;
@@ -162,8 +98,8 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

-  Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparent_input_shape);
-  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+  Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, expanded_tensor);
+  Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
@@ -175,9 +111,8 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) {
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

-  return output;
-
+  return result;
}

}
}
} // namespace native
} // namespace at
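The refactor replaces the hand-rolled `set_apparent_shapes` bookkeeping with a plain `expand` followed by a single `tileTensor:` call. A rough Python sketch of the new shape math (an illustration, not code from this commit):

```python
import torch

def repeat_shape_math(input_sizes, repeats):
    # Left-pad the input shape with 1s until it has len(repeats) dims,
    # then multiply elementwise to get the output shape.
    num_new = len(repeats) - len(input_sizes)
    padded = [1] * num_new + list(input_sizes)
    target = [p * r for p, r in zip(padded, repeats)]
    return padded, target

padded, target = repeat_shape_math((2, 3), (4, 1, 2))
assert padded == [1, 2, 3] and target == [4, 2, 6]
assert torch.zeros(2, 3).repeat(4, 1, 2).shape == torch.Size(target)
```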
24 changes: 22 additions & 2 deletions test/test_mps.py
@@ -8358,17 +8358,31 @@ class TestConsistency(TestCase):
'unbind': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'unflatten': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'unsqueeze': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'var': ['f32'],
'view': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'view_as': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'nonzero': ['f32', 'i16', 'i32', 'i64'],
-        'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'std': ['f16', 'f32'],
+        'var': ['f16', 'f32'],
+        'amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'mean': ['f16', 'f32'],
+        'count_nonzero': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.amax': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.amin': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.mean': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.prod': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.std': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.sum': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'masked.var': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
}


@@ -8567,6 +8581,11 @@ class TestConsistency(TestCase):
'__rpow__': [torch.int64],
'masked.std': [torch.int32],
'masked.var': [torch.int32],

+        # Failures due to inconsistency between CPU and GPU for `inf` case
+        'masked.argmax': ['f16', 'f32', 'i32'],
+        'masked.argmin': ['f16', 'f32', 'i32'],

'as_strided_scatter': [torch.uint8],
'atan2': [torch.int64],
'bfloat16': None,
@@ -8683,6 +8702,7 @@ class TestConsistency(TestCase):
'masked.softmax': [torch.float32],
'masked.softmin': [torch.float32],
'masked.log_softmax': [torch.float32],
+        'masked.var': ['f16'],
'dot': [torch.int64],
}

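With multi-axes support in place, reduce ops accept a tuple of dims on the MPS backend, which is what the newly enabled TestConsistency entries above exercise. A minimal usage sketch (assumes an MPS-capable machine):

```python
import torch

if torch.backends.mps.is_available():
    x = torch.randn(2, 3, 4, device="mps")
    # Reduce over several axes at once; previously the input was flattened.
    print(x.sum(dim=(0, 2)))                 # shape: (3,)
    print(x.amax(dim=(0, 1)))                # shape: (4,)
    print(x.var(dim=(1, 2), unbiased=True))  # Bessel-corrected, shape: (2,)
```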
