[DRAFT] Add vectorization support to binary kernels #445

Draft · wants to merge 1 commit into base: master
187 changes: 122 additions & 65 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -16,12 +16,18 @@
Strided_Tensor
};

struct BinaryKernelOpInfo {
std::string op_name;
std::string metal_operator;
bool canVectorize;
};

static char* BINARY_OP_TEMPLATE_TENSOR = R"METAL_BINARY(
kernel void {3}_kernel(uint tid [[thread_position_in_grid]],
const device {1} * input [[buffer(0)]],
const device {2} * other [[buffer(1)]],
device {0} * output [[buffer(2)]]) {{
output[tid] = ({5})input[tid] {4} ({5})other[tid];
output[tid] = ({5})input[tid] {4} ({6})other[tid];
}}
)METAL_BINARY";
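
For reference (not part of this diff): the numbered placeholders are filled by fmt::format in compileBinaryOpsLibrary below, in the order {0} output type, {1} input type, {2} other type, {3} op name, {4} operator, {5} input cast type, {6} other cast type. A minimal standalone C++ sketch of the expansion, assuming fmt >= 8 (fmt::runtime for a non-literal format string) and a hypothetical float4 instantiation:

#include <fmt/core.h>
#include <iostream>

int main() {
  // Same placeholder convention as BINARY_OP_TEMPLATE_TENSOR above; the doubled
  // braces {{ }} keep the kernel-body braces out of fmt's placeholder syntax.
  const char* tmpl = R"METAL(
kernel void {3}_kernel(uint tid [[thread_position_in_grid]],
                       const device {1} * input [[buffer(0)]],
                       const device {2} * other [[buffer(1)]],
                       device {0} * output [[buffer(2)]]) {{
  output[tid] = ({5})input[tid] {4} ({6})other[tid];
}}
)METAL";
  // Hypothetical vectorized "add": float tensors, vector width 4.
  std::cout << fmt::format(fmt::runtime(tmpl),
                           "float4", "float4", "float4", "add", "+", "float4", "float4");
  // The printed kernel body reads: output[tid] = (float4)input[tid] + (float4)other[tid];
}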

@@ -40,7 +46,7 @@
const device {1}* input = (const device {1}*)((const device uint8_t*)input_ + offsets.y);
const device {2}* other = (const device {2}*)((const device uint8_t*)other_ + offsets.z);

*output = ({5})*input {4} ({5})*other;
*output = ({5})*input {4} ({6})*other;
}}
)METAL_BINARY";

@@ -49,7 +55,7 @@
const device {1} & input [[buffer(0)]],
const device {2} * other [[buffer(1)]],
device {0} * output [[buffer(2)]]) {{
output[tid] = ({5})input {4} ({5})other[tid];
output[tid] = ({5})input {4} ({6})other[tid];
}}
)METAL_BINARY";

@@ -58,7 +64,7 @@
const device {1} * input [[buffer(0)]],
const device {2} & other [[buffer(1)]],
device {0} * output [[buffer(2)]]) {{
output[tid] = ({5})input[tid] {4} ({5})other;
output[tid] = ({5})input[tid] {4} ({6})other;
}}
)METAL_BINARY";

@@ -67,7 +73,7 @@
const device {1} & input [[buffer(0)]],
const device {2} & other [[buffer(1)]],
device {0} & output [[buffer(2)]]) {{
output = ({5})input {4} ({5})other;
output = ({5})input {4} ({6})other;
}}
)METAL_BINARY";

@@ -85,7 +91,7 @@
device {0}* output = (device {0}*)((device uint8_t*)output_ + offsets.x);
const device {1}* input = (const device {1}*)((const device uint8_t*)input_ + offsets.y);

*output = ({5})*input {4} ({5})other;
*output = ({5})*input {4} ({6})other;
}}
)METAL_BINARY";

@@ -103,19 +109,28 @@
device {0}* output = (device {0}*)((device uint8_t*)output_ + offsets.x);
const device {2}* other = (const device {2}*)((const device uint8_t*)other_ + offsets.z);

*output = ({5})input {4} ({5})*other;
*output = ({5})input {4} ({6})*other;
}}
)METAL_BINARY";

static uint8_t getVectorType(int64_t input_numel) {
if (input_numel % 4 == 0) return 4;
if (input_numel % 3 == 0) return 3;
if (input_numel % 2 == 0) return 2;

return 0;
}
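
A small self-contained sketch (not part of this diff) of what the selected width means downstream, assuming getMetalScalarType returns Metal scalar names such as "float" that then pick up the vector suffix ("float2", "float3", "float4"):

#include <cstdint>
#include <iostream>
#include <string>

// Mirrors getVectorType above: the widest of 4/3/2 that evenly divides numel.
static uint8_t pick_vector_width(int64_t numel) {
  if (numel % 4 == 0) return 4;
  if (numel % 3 == 0) return 3;
  if (numel % 2 == 0) return 2;
  return 0;  // nothing divides numel; fall back to the scalar kernel
}

int main() {
  const int64_t numel = 1'000'000;          // hypothetical contiguous tensor
  const std::string scalar_type = "float";  // assumed getMetalScalarType() result
  const uint8_t w = pick_vector_width(numel);
  const std::string metal_type = w >= 2 ? scalar_type + std::to_string(w) : scalar_type;
  const int64_t threads = w >= 2 ? numel / w : numel;
  std::cout << metal_type << " kernel, " << threads << " threads\n";  // float4 kernel, 250000 threads
}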

static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device,
const std::string& t1,
const std::string& t2,
const std::string& t3,
const std::string& common_dtype,
const std::string& cast_dtype_input,
const std::string& cast_dtype_other,
const std::string& op,
const std::string& kernel_operator,
BinaryKernelType binaryKernelType) {
auto key = op + t1 + t2 + t3 + common_dtype + std::to_string(int(binaryKernelType));
auto key = op + t1 + t2 + t3 + cast_dtype_input + std::to_string(int(binaryKernelType));
static std::unordered_map<std::string, id<MTLLibrary>> libMap;
auto it = libMap.find(key);
if (it != libMap.end()) {
@@ -158,7 +173,7 @@
TORCH_CHECK(false, "Unknown binary template");
}

auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(str, t1, t2, t3, op, kernel_operator, common_dtype).c_str()]
auto rc = [device newLibraryWithSource:[NSString stringWithUTF8String:fmt::format(str, t1, t2, t3, op, kernel_operator, cast_dtype_input, cast_dtype_other).c_str()]
options:options
error:&error];
TORCH_CHECK(rc != nil && error == nil, "Failed to compile library: ", [[error localizedDescription] UTF8String]);
@@ -170,19 +185,20 @@
const std::string& t1,
const std::string& t2,
const std::string& t3,
const std::string& common_dtype,
const std::string& cast_dtype_input,
const std::string& cast_dtype_other,
const std::string& fname,
const std::string& op,
const std::string& kernel_operator,
BinaryKernelType binaryKernelType) {
auto key = t1 + t2 + t3 + common_dtype + fname;
auto key = t1 + t2 + t3 + cast_dtype_input + fname;
static std::unordered_map<std::string, id<MTLComputePipelineState>> cplMap;
auto it = cplMap.find(key);
if (it != cplMap.end()) {
return it->second;
}
NSError *error = nil;
auto library = compileBinaryOpsLibrary(device, t1, t2, t3, common_dtype, op, kernel_operator, binaryKernelType);
auto library = compileBinaryOpsLibrary(device, t1, t2, t3, cast_dtype_input, cast_dtype_other, op, kernel_operator, binaryKernelType);
id<MTLFunction> func = [library newFunctionWithName:[NSString stringWithUTF8String:fname.c_str()]];
TORCH_CHECK(func != nil, "Can't get function ", fname);
auto rc = [device newComputePipelineStateWithFunction:func error:&error];
@@ -192,11 +208,14 @@
}

static
void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op, const std::string& kernel_operator) {
void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const BinaryKernelOpInfo& binaryKernelOpInfo) {
Tensor inputTensor;
Tensor otherTensor;
BinaryKernelType type;

const std::string& op = binaryKernelOpInfo.op_name;
const std::string& kernel_operator = binaryKernelOpInfo.metal_operator;

int scalar_pos = 0;
bool all_scalar = false;
const Tensor& outputTensor = iter.tensor(0);
@@ -264,52 +283,87 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
MPSStream* mpsStream = getCurrentMPSStream();
id<MTLDevice> device = MPSDevice::getInstance()->device();

std::string outputStringType = getMetalScalarType(outputDataType);
std::string inputStringType = getMetalScalarType(inputDataType);
std::string otherStringType = getMetalScalarType(otherDataType);
std::string inputCastType = getMetalScalarType(common_dtype);
std::string otherCastType = getMetalScalarType(common_dtype);

id<MTLBuffer> inputBuffer = mps::getMTLBufferStorage(inputTensor);
id<MTLBuffer> otherBuffer = mps::getMTLBufferStorage(otherTensor);
id<MTLBuffer> outputBuffer = mps::getMTLBufferStorage(outputTensor);
uint32_t inputTensorStorage = inputTensor.storage_offset() * inputTensor.element_size();
uint32_t otherTensorStorage = otherTensor.storage_offset() * otherTensor.element_size();
mps::MPSScalar scalar;
uint32_t numThreads = iter.numel();

if (all_scalar) {
type = BinaryKernelType::Scalar;
if (iter.is_cpu_scalar(1)) {
scalar = mps::getMPSScalar(inputTensor.item(), inputTensor.scalar_type());
inputBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
inputTensorStorage = 0;
}
if (iter.is_cpu_scalar(2)) {
scalar = mps::getMPSScalar(otherTensor.item(), otherTensor.scalar_type());
otherBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
otherTensorStorage = 0;
}
} else if (scalar_pos) {
if (allContiguous) {
type = scalar_pos == 1 ? BinaryKernelType::LHS_Scalar : BinaryKernelType::RHS_Scalar;
} else {
type = scalar_pos == 1 ? BinaryKernelType::Strided_LHS_Scalar : BinaryKernelType::Strided_RHS_Scalar;
}

if (iter.is_cpu_scalar(scalar_pos)) {
uint8_t vecType = 0;
string vecStringType = "";
if (scalar_pos == 1) {
scalar = mps::getMPSScalar(inputTensor.item(), inputTensor.scalar_type());
inputBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
inputTensorStorage = 0;
type = BinaryKernelType::LHS_Scalar;
if (binaryKernelOpInfo.canVectorize) {
vecType = getVectorType(otherTensor.numel());
vecStringType = vecType >= 2 ? std::to_string(vecType) : "";
otherStringType += vecStringType;
otherCastType += vecStringType;
}
} else {
scalar = mps::getMPSScalar(otherTensor.item(), otherTensor.scalar_type());
otherBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
otherTensorStorage = 0;
type = BinaryKernelType::RHS_Scalar;
if (binaryKernelOpInfo.canVectorize) {
vecType = getVectorType(inputTensor.numel());
vecStringType = vecType >= 2 ? std::to_string(vecType) : "";
inputStringType += vecStringType;
inputCastType += vecStringType;
}
}
if (vecType >= 2) {
numThreads /= vecType;
outputStringType += vecStringType;
}
} else {
type = scalar_pos == 1 ? BinaryKernelType::Strided_LHS_Scalar : BinaryKernelType::Strided_RHS_Scalar;
}
} else {
type = allContiguous ? BinaryKernelType::Tensor : BinaryKernelType::Strided_Tensor;
if (allContiguous) {
type = BinaryKernelType::Tensor;
if (binaryKernelOpInfo.canVectorize) {
uint8_t inputVecType = getVectorType(inputTensor.numel());
uint8_t otherVecType = getVectorType(otherTensor.numel());
if (inputVecType >= 2 && inputVecType == otherVecType) {
std::string vecType = std::to_string(inputVecType);
inputStringType += vecType;
inputCastType += vecType;
otherStringType += vecType;
otherCastType += vecType;
outputStringType += vecType;
numThreads /= inputVecType;
}
}
} else {
type = BinaryKernelType::Strided_Tensor;
}
}

if (iter.is_cpu_scalar(1)) {
scalar = mps::getMPSScalar(inputTensor.item(), inputTensor.scalar_type());
inputBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
inputTensorStorage = 0;
}
if (iter.is_cpu_scalar(2)) {
scalar = mps::getMPSScalar(otherTensor.item(), otherTensor.scalar_type());
otherBuffer = (id<MTLBuffer>)getIMPSAllocator()->allocScalarBufferWithValue(&scalar.value, scalar.size).get();
otherTensorStorage = 0;
}
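
Condensed into a standalone sketch (not part of this diff; the is_cpu_scalar condition is omitted): a scalar-broadcast case vectorizes only the tensor operand and the output, while the tensor-tensor case vectorizes only when both operands admit the same width on the contiguous path:

#include <cstdint>
#include <iostream>

// Same width rule as getVectorType in this file: widest of 4/3/2 dividing numel.
static uint8_t width(int64_t numel) {
  if (numel % 4 == 0) return 4;
  if (numel % 3 == 0) return 3;
  if (numel % 2 == 0) return 2;
  return 0;
}

// Width the kernel is built with; 0 means the plain scalar kernel is used.
// scalar_pos: 0 = tensor-tensor, 1 = LHS scalar, 2 = RHS scalar.
static uint8_t chosen_width(bool can_vectorize, bool contiguous,
                            int64_t input_numel, int64_t other_numel, int scalar_pos) {
  if (!can_vectorize || !contiguous) return 0;
  if (scalar_pos == 1) return width(other_numel);  // vectorize `other` and the output
  if (scalar_pos == 2) return width(input_numel);  // vectorize `input` and the output
  const uint8_t wi = width(input_numel);
  const uint8_t wo = width(other_numel);
  return (wi >= 2 && wi == wo) ? wi : 0;           // tensor-tensor: widths must match
}

int main() {
  std::cout << int(chosen_width(true, true, 1, 12, 1)) << '\n';  // 4: LHS scalar, 12 % 4 == 0
  std::cout << int(chosen_width(true, true, 9, 9, 0))  << '\n';  // 3: both divisible by 3
  std::cout << int(chosen_width(true, true, 8, 6, 0))  << '\n';  // 0: widths differ (4 vs 3)
}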

const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;

dispatch_sync(mpsStream->queue(), ^(){
@autoreleasepool {
uint32_t numThreads = iter.numel();
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
const IntArrayRef& iterShape = iter.shape();
@@ -347,14 +401,15 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
}

id<MTLComputePipelineState> binaryPSO = mps::getBinaryPSO(device,
getMetalScalarType(outputDataType),
getMetalScalarType(inputDataType),
getMetalScalarType(otherDataType),
getMetalScalarType(common_dtype),
kernel,
op,
kernel_operator,
type);
outputStringType,
inputStringType,
otherStringType,
inputCastType,
otherCastType,
kernel,
op,
kernel_operator,
type);
getMPSProfiler().beginProfileKernel(binaryPSO, kernel, {inputTensor, otherTensor, outputTensor});
[computeEncoder setComputePipelineState:binaryPSO];
[computeEncoder setBuffer:inputBuffer offset:inputTensorStorage atIndex:0];
@@ -380,46 +435,47 @@ void dispatch_binary_kernel_mps_(TensorIteratorBase& iter, const std::string& op
}

static
void dispatch_binary_kernel_mps(const Tensor& self, const Tensor& other, const Tensor& output, const std::string& op, const std::string& kernel_operator) {
void dispatch_binary_kernel_mps(const Tensor& self, const Tensor& other, const Tensor& output, const BinaryKernelOpInfo& binaryKernelOpInfo) {
TensorIterator iter;
const std::string& op = binaryKernelOpInfo.op_name;
if (op == "lt" || op == "le" || op == "gt" || op == "ge" || op == "ne" || op == "logical_or" || op == "logical_and" || op == "eq") {
iter = TensorIterator::comparison_op(const_cast<Tensor&>(output), self, other);
} else {
iter = TensorIterator::borrowing_binary_op(output, self, other);
}

dispatch_binary_kernel_mps_(iter, op, kernel_operator);
dispatch_binary_kernel_mps_(iter, binaryKernelOpInfo);
}

bool getBinaryKernelOperator(const std::string& op_name, std::pair<std::string, std::string>& kernel_operator) {
bool getBinaryKernelOpInfo(const std::string& op_name, BinaryKernelOpInfo& binaryKernelOpInfo) {
static bool macOS13_0_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
if (!macOS13_0_plus) {
return false;
}

static std::unordered_map<std::string, std::pair<std::string, std::string>> opToKernelOperator = {
{"multiplication", {"mul", "*" }},
{"div_out_mps:", {"div", "/" }},
{"add_out_mps:", {"add", "+" }},
{"sub_out_mps:", {"sub", "-" }},
static std::unordered_map<std::string, BinaryKernelOpInfo> opToKernelOperator = {
{"multiplication", {"mul", "*", true}},
{"div_out_mps:", {"div", "/", true}},
{"add_out_mps:", {"add", "+", true}},
{"sub_out_mps:", {"sub", "-", true}},

// comparison ops
{"lessThan", {"lt", "<" }},
{"lessThanOrEqualTo", {"le", "<="}},
{"greaterThan", {"gt", ">" }},
{"greaterThanOrEqualTo", {"ge", ">="}},
{"notEqual", {"ne", "!="}},
{"logicalOR", {"logical_or", "||"}},
{"logicalAND", {"logical_and", "&&"}},
{"equal", {"eq", "=="}},
{"lessThan", {"lt", "<" , true}},
{"lessThanOrEqualTo", {"le", "<=", true}},
{"greaterThan", {"gt", ">" , true}},
{"greaterThanOrEqualTo", {"ge", ">=", true}},
{"notEqual", {"ne", "!=", true}},
{"logicalOR", {"logical_or", "||", false}},
{"logicalAND", {"logical_and", "&&", false}},
{"equal", {"eq", "==", true}},
};

auto it = opToKernelOperator.find(op_name);
if (it == opToKernelOperator.end()) {
return false;
}

kernel_operator = it->second;
binaryKernelOpInfo = it->second;
return true;
}

@@ -430,8 +486,9 @@ bool dispatchNativeBinaryKernel(const Tensor& self,
const std::string& op_name) {
if (alpha.toFloat() == 1.0) {
std::pair<std::string, std::string> kernel_operator;
if (getBinaryKernelOperator(op_name, kernel_operator)) {
dispatch_binary_kernel_mps(self, other, output, kernel_operator.first, kernel_operator.second);
BinaryKernelOpInfo binaryKernelOpInfo;
if (getBinaryKernelOpInfo(op_name, binaryKernelOpInfo)) {
dispatch_binary_kernel_mps(self, other, output, binaryKernelOpInfo);
return true;
}
}