Skip to content

Commit

Permalink
[MPS] Fix chaining of View tensor when shape is different from Parent…
Browse files Browse the repository at this point in the history
… tensor.

Also add the TestConsistency tests to test_mps.
  • Loading branch information
kulinseth committed Jun 14, 2022
1 parent 5399fef commit d360fe4
Show file tree
Hide file tree
Showing 2 changed files with 1,308 additions and 8 deletions.
21 changes: 18 additions & 3 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
setStrided(result, size, stride, storage_offset);

// 0 sizes won't result in any change in the shape of the Tensor so we can
// skip it. Also if the memory is contiguous we don't need to do
// gather-scatter operations using graph.
// skip it.
if (size.size() > 0) {

// If self itself was a view tensor, that means we need to chain the graphs
Expand All @@ -127,17 +126,33 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {
string lookup_key = mps::getStridedKey(self, self.sizes(), self.strides(),
self.storage_offset());

MPSGraphTensor *parentInputTensor = nil;
CachedGraph* parentCachedGraph = static_cast<CachedGraph *>(cache_->LookUp(lookup_key));
if (parentCachedGraph) {
parentInputTensor = parentCachedGraph->inputTensor_;
}

string key = mps::getStridedKey(self, size, stride, storage_offset);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if (!cachedGraph) {
cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool {
MPSShape *shape = nil;
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

// All chained view operations should use the shape of the first contiguous tensor from which the view was created
if (parentInputTensor)
shape = [parentInputTensor shape];
else
shape = getMPSShape(self);

// Self is the input tensor we are creating view of
MPSGraphTensor* inputTensor = [mpsGraph placeholderWithShape : getMPSShape(self)
MPSGraphTensor* inputTensor = [mpsGraph placeholderWithShape : shape
dataType : getMPSDataType(self.scalar_type())
name : nil];
newCachedGraph->inputTensor_ = inputTensor;
Expand Down
Loading

0 comments on commit d360fe4

Please sign in to comment.