Skip to content

Commit

Permalink
Matmul crash errors (#76)
Browse files Browse the repository at this point in the history
* Handle empty input with non-empty output

* Remove transpose options from mm op
  • Loading branch information
abhudev authored and kulinseth committed Aug 9, 2022
1 parent 3953764 commit 68e486a
Showing 1 changed file with 29 additions and 34 deletions.
63 changes: 29 additions & 34 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,11 @@ void prepare_matrices_for_broadcasting(

MPSStream* stream = getCurrentMPSStream();

bool transpose_mat1 = false;
bool transpose_mat2 = false;

prepare_matrices_for_broadcasting(NULL, self, other, NULL, NULL, transpose_mat1, transpose_mat2);

mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();

@autoreleasepool {

string key = "mm_out_mps_impl" + getTensorsStringKey({self, other})
+ ":" + to_string(transpose_mat1) + ":" + to_string(transpose_mat2);
string key = "mm_out_mps_impl" + getTensorsStringKey({self, other});

CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
Expand All @@ -147,31 +141,25 @@ void prepare_matrices_for_broadcasting(
MPSGraph *mpsGraph = mps::make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
MPSGraphTensor *otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
MPSGraphTensor *selfTensor = nil;
MPSGraphTensor *otherTensor = nil;
MPSGraphTensor *outputTensor = nil;

MPSGraphTensor* t1 = nil;
MPSGraphTensor* t2 = nil;
if(self.numel() == 0 || other.numel() == 0) {

if(transpose_mat1)
t1 = [mpsGraph transposeTensor:selfTensor
dimension:-1
withDimension:-2
name:nil];
else
t1 = selfTensor;
outputTensor = [mpsGraph constantWithScalar:0.
shape:getMPSShape(output_sizes)
dataType:getMPSDataType(output.scalar_type())];

if(transpose_mat2)
t2 = [mpsGraph transposeTensor:otherTensor
dimension:-1
withDimension:-2
name:nil];
else
t2 = otherTensor;
}
else {

MPSGraphTensor* outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
secondaryTensor:t2
name:nil];
selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, self);
otherTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, other);
outputTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:selfTensor
secondaryTensor:otherTensor
name:nil];
}

newCachedGraph->selfTensor_ = selfTensor;
newCachedGraph->otherTensor_ = otherTensor;
Expand All @@ -181,14 +169,21 @@ void prepare_matrices_for_broadcasting(
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
Placeholder otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
Placeholder selfPlaceholder = Placeholder();
Placeholder otherPlaceholder = Placeholder();
if(!(self.numel() == 0 || other.numel() == 0)) {
selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self);
otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other);
}
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;

if(!(self.numel() == 0 || other.numel() == 0))
feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
otherPlaceholder.getMPSGraphTensor() : otherPlaceholder.getMPSGraphTensorData()
};

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
Expand Down

0 comments on commit 68e486a

Please sign in to comment.