Eye op #7

Merged: 4 commits, May 26, 2022
120 changes: 120 additions & 0 deletions aten/src/ATen/native/mps/operations/Eye.mm
@@ -0,0 +1,120 @@
#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <torch/library.h>
#include "c10/util/Optional.h"


// Steps to add an op for the MPS backend:
// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key
// 2. Define the function interface for the MPS backend similar to other
//    backends, depending on whether it's structured or non-structured
// 3. Add boiler-plate error checking code as expected for the op
// 4. The code structure roughly follows this pattern:
//    a) get the MPS stream handle to encode work onto
//    b) get an instance of MPSGraphCache and create a key unique to the graph
//       needed for implementing this op; any shape, dataType or parameter
//       passed to the MPSGraph during its construction needs to be included here
//    c) create the graph using make_mps_graph() and add operations to the
//       instance of MPSGraph (this happens only when the Cache->LookUp() fails)
//    d) store the MPSGraphTensors for inputs and outputs, which are needed at
//       runtime
//    e) use the CachedGraph instance's inputs and output to create Placeholders;
//       a Tensor must be passed in to create the MPSGraphTensorData objects
//    f) create a feeds dictionary from the MPSGraphTensor and MPSGraphTensorData
//       instances
//    g) call runMPSGraph() with the input params and return the result
//
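// In this file, step 4(a) corresponds to the getCurrentMPSStream() call below,
// steps 4(b)-(d) to the CachedGraph lookup/creation block, and steps 4(e)-(g)
// to the Placeholder, feeds dictionary and runMPSGraph() sequence at the end.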


namespace at {
namespace native {

Tensor& eye_out_mps(int64_t n, Tensor& result) {
  // the default value of `m` equals `n`
  return eye_out_mps(n, n, result);
}

Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) {
  // This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts
  TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n);
  TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m);

  result.resize_({n, m});
  result.zero_();

  // Handle empty outputs
  if (result.numel() == 0)
    return result;

  // Get the MPS stream
  using namespace mps;
  MPSStream* stream = getCurrentMPSStream();

  // Derive from MPSCachedGraph.
  // This structure caches an MPSGraph under a key so that the same MPSGraph
  // does not have to be compiled time and time again for the same operation.
  // The key is based on the inputs and outputs the operation needs;
  // here there are no input tensors, just an output tensor.
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* outputTensor_ = nil;
  };

  MPSGraphCache* cache_ = MPSGraphCache::getInstance();

  @autoreleasepool {
    // The key identifies the MPSGraph created on the first call; the cached
    // graph is reused whenever the parameters, data types, etc. match.
    string key = "eye_out_mps:" + getTensorsStringKey({result});
    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
    if (!cachedGraph) {
      MPSCachedGraph* tmpCachedGraph = cache_->CreateCachedGraph(key, ^MPSCachedGraph*() {
        CachedGraph* newCachedGraph = nil;

        @autoreleasepool {
          // Initialize graph
          MPSGraph* mpsGraph = make_mps_graph();
          newCachedGraph = new CachedGraph(mpsGraph);
          MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f
                                                              shape:getMPSShape(result)
                                                           dataType:getMPSDataType(result.scalar_type())];

          // Call the MPSGraph API needed to execute the operation; the API details
          // are at https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph
          // With numLower = numUpper = 0, bandPartWithTensor keeps only the main
          // diagonal of the all-ones tensor, which yields the identity matrix.
          MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor
                                                             numLower:0
                                                             numUpper:0
                                                                 name:nil];
          newCachedGraph->outputTensor_ = outputTensor;
        }
        return newCachedGraph;
      });
      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
    }

    // Create a placeholder that binds the cached graph's output tensor to `result`
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    // Create dictionaries of inputs/feeds and outputs/results.
    // In this case there are no inputs, so the feeds are nil.
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
    };

    // Run the graph
    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
  }

  return result;
}

} // namespace native
} // namespace at
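For context, a minimal sketch of how this kernel is reached from Python (assuming a macOS build where the MPS backend is available; the size is illustrative):

import torch

if torch.backends.mps.is_available():
    # torch.eye(..., device='mps') dispatches to the eye_out_mps kernel above
    mps_eye = torch.eye(4, device='mps')
    cpu_eye = torch.eye(4, device='cpu')
    assert torch.equal(mps_eye.cpu(), cpu_eye)  # MPS result matches the CPU reference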
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2178,11 +2178,13 @@
  dispatch:
    CPU: eye_out_cpu
    CUDA: eye_out_cuda
    MPS: eye_out_mps

- func: eye.m_out(int n, int m, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CPU: eye_out_cpu
    CUDA: eye_out_cuda
    MPS: eye_out_mps

- func: flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)
  variants: function, method
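These two dispatch entries register the same MPS kernel for both schemas of the op. A minimal sketch of how each schema surfaces in Python (assuming an MPS-capable build; variable names are illustrative):

import torch

if torch.backends.mps.is_available():
    # eye.out: torch.eye(n, out=...) routes to eye_out_mps(n, result)
    out_square = torch.empty(0, device='mps')
    torch.eye(3, out=out_square)    # resized to (3, 3), the identity

    # eye.m_out: torch.eye(n, m, out=...) routes to eye_out_mps(n, m, result)
    out_rect = torch.empty(0, device='mps')
    torch.eye(3, 5, out=out_rect)   # resized to (3, 5), ones on the main diagonal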
23 changes: 23 additions & 0 deletions test/test_mps.py
@@ -3516,6 +3516,29 @@ def helper(shape, diag=0):
        helper((2, 8, 4, 5), diag=-2)
        helper((2, 8, 4, 5), diag=-3)

    # Test eye
    def test_eye(self):
        def helper(n, m, dtype):
            if n == m:
                cpu_result = torch.eye(n, dtype=dtype, device='cpu')
                result = torch.eye(n, dtype=dtype, device='mps')
            else:
                cpu_result = torch.eye(n, m, dtype=dtype, device='cpu')
                result = torch.eye(n, m, dtype=dtype, device='mps')

            self.assertEqual(result, cpu_result)

        for dtype in [torch.float32, torch.int32, torch.int64]:
            helper(2, 2, dtype)
            helper(2, 3, dtype)
            helper(0, 2, dtype)
            helper(0, 0, dtype)
            helper(3, 8, dtype)
            helper(8, 3, dtype)

    # Test diag
    def test_diag(self):
        def helper(shape, diag=0):
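Assuming the repository's standard unittest harness, the new case can be exercised on its own with a command along the lines of `python test/test_mps.py -k test_eye`, run on a machine where the MPS backend is available.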