Skip to content

Commit

Permalink
Improve vectorization computation logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent f0c88af commit 266a958
Showing 1 changed file with 69 additions and 37 deletions.
Expand Up @@ -30,6 +30,16 @@ namespace {
// result string to keep C++ source / header / flags emission.
std::string resultStr;

/// Computes the greatest common divisor of `u` and `v` using the iterative
/// Euclidean algorithm.
///
/// The result is always non-negative, regardless of the signs of the inputs.
/// `gcd(x, 0)` and `gcd(0, x)` yield `|x|`, and `gcd(0, 0)` yields 0.
///
/// \tparam Number an integral type supporting `%`, comparison with 0 and
///         unary minus.
/// \param u first operand.
/// \param v second operand.
/// \return the greatest common divisor of `u` and `v` as a non-negative value.
///
/// NOTE: for signed types the most negative value cannot be negated without
/// overflow; callers must not rely on a result whose magnitude is not
/// representable in `Number`.
template <typename Number>
Number gcd(Number u, Number v) {
  while (v != 0) {
    const Number r = u % v;
    u = v;
    v = r;
  }
  // In C++ the sign of `a % b` follows the dividend, so with negative inputs
  // the loop may terminate on a negative value. Normalize the sign so the
  // result can safely be used as a length or as a divisor.
  return u < 0 ? -u : u;
}

class TunableParameters : public TunableParametersBase {
public:
TunableParameters() : TunableParametersBase("gridwise_convolution_implicit_gemm_v4r4.yaml") {}
Expand All @@ -41,27 +51,28 @@ class TunableParameters : public TunableParametersBase {
params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] = 8;
params["CK_PARAM_TUNABLE_GEMM_M_PER_THREAD_SUB_C"] = 4;
params["CK_PARAM_TUNABLE_GEMM_N_PER_THREAD_SUB_C"] = 4;
params["CK_PARAM_TUNABLE_BLOCK_SIZE"] = 256;

// parameters derivable from tunable parameters.
// parameters fixed.
params["CK_PARAM_TUNABLE_GEMM_M_LEVEL0_CLUSTER"] = 4;
params["CK_PARAM_TUNABLE_GEMM_N_LEVEL0_CLUSTER"] = 4;
params["CK_PARAM_TUNABLE_GEMM_M_LEVEL1_CLUSTER"] = 4;
params["CK_PARAM_TUNABLE_GEMM_N_LEVEL1_CLUSTER"] = 4;
params["CK_PARAM_TUNABLE_BLOCK_SIZE"] = 256;

params["CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M"] = 1;
params["CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N"] = 1;

// parameters vary per data layout.
// specify the most conservative parameters first.
// TBD. add vectorization computation logic.
params["CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K"] = 2;
params["CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M"] = 128;
params["CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM"] = 1;

params["CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K"] = 2;
params["CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N"] = 128;
params["CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM"] = 1;

// parameters vary per data layout.
// specify the most conservative parameters first.
// TBD. add vectorization computation logic.
params["CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K"] = 1;
params["CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M"] = 1;
params["CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N"] = 1;
params["CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N"] = 1;
params["CK_PARAM_TUNABLE_GEMM_C_THREAD_COPY_DST_DATA_PER_WRITE_GEMM_N1"] = 1;
}
};
Expand Down Expand Up @@ -148,7 +159,7 @@ static constexpr StringLiteral kCppInterlude = R"(
Sequence<GemmABlockCopyClusterLengths_GemmK, GemmABlockCopyClusterLengths_GemmM>;
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK =
CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K;
CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM;
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM =
CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M;
Expand All @@ -172,7 +183,7 @@ static constexpr StringLiteral kCppInterlude = R"(
Sequence<GemmBBlockCopyClusterLengths_GemmK, GemmBBlockCopyClusterLengths_GemmN>;
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM;
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN =
CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N;
Expand Down Expand Up @@ -712,7 +723,7 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenHeader(ModuleOp m)
output << ");\n\n";
});

bool filterGemmKVectorizable, inputGemmKVectorizable;
bool filterGemmKVectorizable = false, inputGemmKVectorizable = false;
f.walk([&filterGemmKVectorizable, &inputGemmKVectorizable](miopen::GridwiseGemmOp op) {
auto filterLayoutAttr = op.getAttrOfType<ArrayAttr>("filter_layout");
auto inputLayoutAttr = op.getAttrOfType<ArrayAttr>("input_layout");
Expand Down Expand Up @@ -952,23 +963,22 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlags(ModuleOp m)

// TBD.
// Determine vectorization dimensions and lengths.
int64_t vectorizableLength = 0;

// Filter tensor.
// Find the fastest changing dimension.
bool filterGemmKVectorizable = false;
if (dimKF == 3) {
// When K is the fastest changing dimension,
// gemmM dimension is vectorizable.
// vectorization width depending on length of K.
if (k % 4 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M", 4);
} else if (k % 2 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_M", 2);
}
vectorizableLength = k;

// gemmK dimension non-vectorizable.
filterGemmKVectorizable = false;
} else {
// gemmK dimension vectorizable,
// depending on which among C, Y, X be the fastest changing dimension.
int64_t vectorizableLength = 0;
if (dimKF == 0) {
// dimKF is the slowest changing dimension, which means dimC/dimY/dimX
vectorizableLength = c * y * x;
Expand All @@ -979,45 +989,67 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlags(ModuleOp m)
vectorizableLength = y * x;
}
}
if (vectorizableLength % 4 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K", 4);
} else if (vectorizableLength % 2 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_K", 2);
}

filterGemmKVectorizable = true;
// gemmM dimension non-vectorizable.
}

int perThreadOpsA = params["CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK"] * params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / params["CK_PARAM_TUNABLE_BLOCK_SIZE"];
int perThreadOpsAVectorLength = 1;
if ((vectorizableLength > 0) && (vectorizableLength % 4 == 0)) {
perThreadOpsAVectorLength = gcd(4, perThreadOpsA);
} else if ((vectorizableLength > 0) && (vectorizableLength % 2 == 0)) {
perThreadOpsAVectorLength = gcd(2, perThreadOpsA);
}
int perThreadOpsANonVectorizedLength = perThreadOpsA / perThreadOpsAVectorLength;
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM", perThreadOpsAVectorLength);
if (filterGemmKVectorizable) {
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M", params["CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK"] / perThreadOpsANonVectorizedLength);
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsAVectorLength);
} else {
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsANonVectorizedLength);
params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M", params["CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK"] / perThreadOpsAVectorLength);
}

// Input tensor.
bool inputGemmKVectorizable = false;
vectorizableLength = 0;
// Find the fastest changing dimension.
if (dimNI == 3) {
// When N is the fastest changing dimension,
// gemmN dimension is vectorizable.
// vectorization width depending on length of N.
if (n % 4 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N", 4);
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N", 4);
} else if (n % 2 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N", 2);
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_N", 2);
}
vectorizableLength = n;

// gemmK dimension non-vectorizable.
inputGemmKVectorizable = false;
} else if (dimCI == 3) {
// When C is the fastest changing dimension,
// gemmK dimension vectorizable.
// vectorization width depending on length of C.
vectorizableLength = c;

// NOTE: After discussion with MIOpen dev, set only READ vectorization here. NOT WRITE.
if (c % 4 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N", 4);
} else if (c % 2 == 0) {
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N", 2);
}

inputGemmKVectorizable = true;
// gemmN dimension non-vectorizable.
}

int perThreadOpsB = params["CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK"] * params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / params["CK_PARAM_TUNABLE_BLOCK_SIZE"];
int perThreadOpsBVectorLength = 1;
if ((vectorizableLength > 0) && (vectorizableLength % 4 == 0)) {
perThreadOpsBVectorLength = gcd(4, perThreadOpsB);
} else if ((vectorizableLength > 0) && (vectorizableLength % 2 == 0)) {
perThreadOpsBVectorLength = gcd(2, perThreadOpsB);
}
int perThreadOpsBNonVectorizedLength = perThreadOpsB / perThreadOpsBVectorLength;
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM", perThreadOpsBVectorLength);
if (inputGemmKVectorizable) {
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N", params["CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK"] / perThreadOpsBNonVectorizedLength);
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsBVectorLength);
} else {
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsBNonVectorizedLength);
params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N", params["CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK"] / perThreadOpsBVectorLength);
}

// Output tensor.
if (dimKO == 3) {
// gemmM vectorizable.
Expand Down

0 comments on commit 266a958

Please sign in to comment.