Skip to content

Commit

Permalink
Revert "Revert MPS changes (pytorch#78335)"
Browse files Browse the repository at this point in the history
This reverts commit ffb3101.
  • Loading branch information
kulinseth committed May 26, 2022
1 parent f4e493e commit b6fee7f
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 27 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/mps/EmptyTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ TensorBase empty_mps(
c10::optional<c10::MemoryFormat> memory_format_opt) {
#if defined(__APPLE__)
#if __is_target_os(macOS)
if (__builtin_available(macOS 12.3, *) || __builtin_available(macOSApplicationExtension 12.3, *)) {
if (at::hasMPS()) {
auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(device.type() == DeviceType::MPS);

Expand Down Expand Up @@ -87,7 +87,7 @@ TensorBase empty_strided_mps(
c10::optional<Device> device_opt) {
#if defined(__APPLE__)
#if __is_target_os(macOS)
if (__builtin_available(macOS 12.3, *) || __builtin_available(macOSApplicationExtension 12.3, *)) {
if (at::hasMPS()) {
auto device = device_or_default(device_opt);
TORCH_INTERNAL_ASSERT(device.is_mps());
TORCH_CHECK_TYPE(dtype != ScalarType::Double, MPS_ERROR_DOUBLE_NOT_SUPPORTED);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ struct HeapBlock
d.type = MTLHeapTypeAutomatic;
heap = [device newHeapWithDescriptor: d];
if (heap) {
[heap setPurgeableState:MTLPurgeableStateEmpty];
[heap setPurgeableState:MTLPurgeableStateNonVolatile];
}
[d release];
}
Expand Down
15 changes: 11 additions & 4 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,17 @@

MPSDevice::MPSDevice(): _mtl_device(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
id mpsCD = NSClassFromString(@"MPSGraphCompilationDescriptor");
if (![mpsCD instancesRespondToSelector:@selector(optimizationLevel)]) {
// According to https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraphcompilationdescriptor/3922624-optimizationlevel
// this means we are running on older MacOS
// Create the MPSGraph and check method introduced in 12.3+
// which is used by MPS backend.
id mpsCD = NSClassFromString(@"MPSGraph");
if ([mpsCD instancesRespondToSelector:@selector(LSTMWithSourceTensor:
recurrentWeight:
inputWeight:
bias:
initState:
initCell:
descriptor:
name:)] == NO)) {
return;
}
NSArray* devices = [MTLCopyAllDevices() autorelease];
Expand Down
30 changes: 10 additions & 20 deletions aten/src/ATen/native/mps/operations/PointwiseOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,21 @@
const bool is_div,
const string op_name)
{
using scalar_t = double;
scalar_t value_scalar = value_opt.to<scalar_t>();
if (&output != &self) {
output.resize_(output.sizes());
}
TORCH_CHECK(output.is_mps());
MPSStream* mpsStream = getCurrentMPSStream();

// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph
{
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *inputTensor = nil, *outputTensor = nil;
MPSGraphTensor *firstTensor = nil, *secondTensor = nil;
MPSGraphTensor *firstTensor = nil, *secondTensor = nil, *valueTensor = nil;
};
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {
string key = op_name + to_string(value_scalar)
+ getTensorsStringKey({self, tensor1, tensor2})+ ":"
+ getMPSTypeString(value_opt.type());
string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false);

CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

Expand All @@ -49,6 +44,7 @@
newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
newCachedGraph->firstTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor1);
newCachedGraph->secondTensor = mpsGraphRankedPlaceHolder(mpsGraph, tensor2);
newCachedGraph->valueTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()));

// the tensor to be optionally multiplied by value_scalar
MPSGraphTensor *multiplicandTensor = nil;
Expand All @@ -62,15 +58,9 @@
name:nil];
}
// the tensor to be added to input_tensor
MPSGraphTensor *addendTensor = multiplicandTensor;
// if value_scalar is 1.0, then we don't bother adding another multiply to graph
if (value_scalar != 1.0) {
MPSGraphTensor* valueTensor = [mpsGraph constantWithScalar:value_scalar
dataType:getMPSScalarType(value_opt.type())];
addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor
secondaryTensor:valueTensor
name:nil];
}
MPSGraphTensor *addendTensor = [mpsGraph multiplicationWithPrimaryTensor:multiplicandTensor
secondaryTensor:newCachedGraph->valueTensor
name:nil];
newCachedGraph->outputTensor = [mpsGraph additionWithPrimaryTensor:newCachedGraph->inputTensor
secondaryTensor:addendTensor
name:nil];
Expand All @@ -87,18 +77,18 @@
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);

// Create dictionary of inputs and outputs
// Utility to dump out graph : [mpsGraph dump];
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
tensor1Placeholder.getMPSGraphTensor() : tensor1Placeholder.getMPSGraphTensorData(),
tensor2Placeholder.getMPSGraphTensor() : tensor2Placeholder.getMPSGraphTensorData()
tensor2Placeholder.getMPSGraphTensor() : tensor2Placeholder.getMPSGraphTensorData(),
cachedGraph->valueTensor : getMPSGraphTensorFromScalar(mpsStream, value_opt, getMPSScalarType(self.scalar_type())),
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
runMPSGraph(mpsStream, cachedGraph->graph(), feeds, results);
}

return output;
Expand Down

0 comments on commit b6fee7f

Please sign in to comment.