Skip to content

Commit

Permalink
Handle transposes in second batch of matrices in bmm (#451)
Browse files Browse the repository at this point in the history
* Add transpose in bmm

* Handle transposes in second batch of matrices in bmm

* Fix comment

* Fix formatting

* Fix build failure

* Fix test failures
  • Loading branch information
DenisVieriu97 committed Jun 14, 2023
1 parent 0ea54ca commit 3683c76
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
8 changes: 4 additions & 4 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ kernel void copy_cast_kernel(uint tid [[thread_position_in_grid]],
NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
auto gatherScatterLib = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(COPY_CAST_OP_TEMPLATE_TENSOR, dtypeSrc, dtypeDst).c_str()]
auto copyCastLib = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(COPY_CAST_OP_TEMPLATE_TENSOR, dtypeSrc, dtypeDst).c_str()]
options:options
error:&error];
TORCH_CHECK(gatherScatterLib != nil && error == nil, "Failed to compile gather-scatter library, error: ", [[error description] UTF8String]);
_libCache[key] = gatherScatterLib;
return gatherScatterLib;
TORCH_CHECK(copyCastLib != nil && error == nil, "Failed to compile copy cast library, error: ", [[error description] UTF8String]);
_libCache[key] = copyCastLib;
return copyCastLib;
}

static id<MTLComputePipelineState> getPipelineState(id<MTLDevice> device,
Expand Down
34 changes: 31 additions & 3 deletions aten/src/ATen/native/mps/operations/LinearAlgebra.mm
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
Expand Down Expand Up @@ -530,8 +531,28 @@ Tensor addr_mps(const Tensor& self,

mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();

MPSShape* shape = nil;
bool doTranspose = false;

// Handle transposes for the second batch of matrices.
if (batch2.is_view() && !batch2.is_contiguous()) {
if (batch2.numel() == batch2._base().numel()) {
const IntArrayRef& viewSizes = batch2.sizes();

// Handle 3D and 4D tensors.
// For a 4D base tensor, the view must first have been reshaped from 4D to 3D and then transposed.
int32_t baseTransposeStrideDim = batch2._base().dim() == 4 ? -3 : -2;
if (batch2.size(0) == batch2.size(0) &&
batch2._base().stride(0) == batch2.stride(0) &&
batch2._base().stride(baseTransposeStrideDim) == batch2.stride(-1)) {
shape = @[@(viewSizes[0]), @(viewSizes[2]), @(viewSizes[1])];
doTranspose = true;
}
}
}

@autoreleasepool {
string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2}, true, /*exclude_shape*/true);
string key = "bmm_out_mps_impl" + getTensorsStringKey({batch1, batch2}, true, /*exclude_shape*/true) + std::to_string(doTranspose);

CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
Expand All @@ -545,9 +566,16 @@ Tensor addr_mps(const Tensor& self,

MPSGraphTensor *batch1Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch1.scalar_type()));
MPSGraphTensor *batch2Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch2.scalar_type()));
MPSGraphTensor *batch2TensorTranspose = batch2Tensor;

if (doTranspose) {
batch2TensorTranspose = [mpsGraph transposeTensor:batch2Tensor
dimension:-1
withDimension:-2
name:nil];
}
MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor
secondaryTensor:batch2Tensor
secondaryTensor:batch2TensorTranspose
name:@"MM/(batch1@batch2)"];

newCachedGraph->batch1Tensor_ = batch1Tensor;
Expand All @@ -559,7 +587,7 @@ Tensor addr_mps(const Tensor& self,
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
Placeholder batch1Placeholder = Placeholder(cachedGraph->batch1Tensor_, batch1);
Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2);
Placeholder batch2Placeholder = Placeholder(cachedGraph->batch2Tensor_, batch2, shape, !doTranspose);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
Expand Down

0 comments on commit 3683c76

Please sign in to comment.