Skip to content

Commit

Permalink
Revise XDLOPS vectorization flag emission logic.
Browse files Browse the repository at this point in the history
- Port logic from MIOpen.
- Consider generic layouts.
  • Loading branch information
whchung committed Jun 6, 2020
1 parent c64f0ad commit 7eed23f
Showing 1 changed file with 154 additions and 120 deletions.
Expand Up @@ -47,30 +47,77 @@ class TunableParameters : public TunableParametersBase {
calculateGemmABlockCopyPerformanceParameters(llvm::SmallVector<int64_t, 5> &param) {
int64_t clusterLengths_GemmK = 0;
int64_t clusterLengths_GemmM = 0;
int64_t srcDataPerRead_GemmK = 4;
int64_t srcDataPerRead_Gemm = 4;
int64_t dstDataPerWrite_GemmM = 4;

const auto waveSize = 64;
const auto blockSize =
param[gemmNPerBlock] * param[gemmMPerBlock] / (param[gemmMPerWave] * param[gemmNPerWave]) * waveSize;

// Determine vectorization dimensions and lengths.
int64_t vectorizableLength = 0;

// Find the fastest changing dimension.
bool gemmKVectorizable = false;
bool gemmMVectorizable = false;
if (ctx.dimKF == 3) {
// When K is the fastest changing dimension,
// gemmM dimension is vectorizable.
// vectorization width depending on length of K.
vectorizableLength = ctx.k;
gemmMVectorizable = true;

// gemmK dimension non-vectorizable.
} else {
// gemmK dimension vectorizable,
// depending on which among C, Y, X be the fastest changing dimension.
if (ctx.dimKF == 0) {
// dimKF is the lowest changing dimension, which means dimC/dimY/dimX
vectorizableLength = ctx.c * ctx.y * ctx.x;
} else {
if (ctx.dimCF == 3) {
vectorizableLength = ctx.c;
} else if (ctx.dimXF == 3 && ctx.dimYF == 2) {
vectorizableLength = ctx.y * ctx.x;
}
}

gemmKVectorizable = true;
// gemmM dimension non-vectorizable.
}

if (gemmMVectorizable) {
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, param[gemmMPerBlock]);
} else if (gemmKVectorizable) {
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, param[gemmKPerBlock]);
} else {
srcDataPerRead_Gemm = 1;
}

// calculate threadwise copy size
const auto a_data_per_thread_copy = (param[gemmKPerBlock] * param[gemmMPerBlock]) / blockSize;

if(!(a_data_per_thread_copy > 0))
return std::make_tuple(-1, -1, -1, -1, false);

// calculate vector length on gemmk dimension
srcDataPerRead_GemmK = gcd(srcDataPerRead_GemmK, param[gemmKPerBlock]);


// GemmABlockCopySrcDataPerRead_GemmK also bounded by size of threadwise copy
srcDataPerRead_GemmK = gcd(srcDataPerRead_GemmK, a_data_per_thread_copy);
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, a_data_per_thread_copy);

// decide threadwise copy lengths
const auto a_data_per_thread_copy_gemmk = srcDataPerRead_GemmK;
const auto a_data_per_thread_copy_gemmm =
a_data_per_thread_copy / a_data_per_thread_copy_gemmk;
const auto a_data_per_thread_copy_gemm_vectorized = srcDataPerRead_Gemm;
const auto a_data_per_thread_copy_gemm_nonvectorized =
a_data_per_thread_copy / a_data_per_thread_copy_gemm_vectorized;

int64_t a_data_per_thread_copy_gemmk = 0;
int64_t a_data_per_thread_copy_gemmm = 0;
if (gemmMVectorizable) {
a_data_per_thread_copy_gemmk = a_data_per_thread_copy_gemm_nonvectorized;
a_data_per_thread_copy_gemmm = a_data_per_thread_copy_gemm_vectorized;
} else {
a_data_per_thread_copy_gemmk = a_data_per_thread_copy_gemm_vectorized;
a_data_per_thread_copy_gemmm = a_data_per_thread_copy_gemm_nonvectorized;
}

// GemmABlockCopyDstDataPerWrite_GemmM also bounded by size of threadwise copy
dstDataPerWrite_GemmM = gcd(dstDataPerWrite_GemmM, a_data_per_thread_copy_gemmm);

Expand All @@ -80,10 +127,20 @@ class TunableParameters : public TunableParametersBase {

if(!(clusterLengths_GemmK > 0 && clusterLengths_GemmM > 0))
return std::make_tuple(-1, -1, -1, -1, false);

//llvm::errs() << "======================\n";
//llvm::errs() << "Matrix A\n";
//llvm::errs() << "gemmK Vectorizable: " << gemmKVectorizable << "\n";
//llvm::errs() << "gemmM Vectorizable: " << gemmMVectorizable << "\n";
//llvm::errs() << "cluster lengths gemmK: " << clusterLengths_GemmK << "\n";
//llvm::errs() << "cluster lengths gemmM: " << clusterLengths_GemmM << "\n";
//llvm::errs() << "data per read: " << srcDataPerRead_Gemm << "\n";
//llvm::errs() << "data per write: " << dstDataPerWrite_GemmM << "\n";
//llvm::errs() << "======================\n";

return std::make_tuple(clusterLengths_GemmK,
clusterLengths_GemmM,
srcDataPerRead_GemmK,
srcDataPerRead_Gemm,
dstDataPerWrite_GemmM,
true);
}
Expand All @@ -93,45 +150,84 @@ class TunableParameters : public TunableParametersBase {
calculateGemmBBlockCopyPerformanceParameters(llvm::SmallVector<int64_t, 5> &param) {
int64_t clusterLengths_GemmK = 0;
int64_t clusterLengths_GemmN = 0;
int64_t srcDataPerRead_GemmN = 4;
int64_t srcDataPerRead_Gemm = 4;
int64_t dstDataPerWrite_GemmN = 4;

const auto waveSize = 64;
const auto blockSize =
const int64_t waveSize = 64;
const int64_t blockSize =
param[gemmNPerBlock] * param[gemmMPerBlock] / (param[gemmMPerWave] * param[gemmNPerWave]) * waveSize;

srcDataPerRead_GemmN = gcd(srcDataPerRead_GemmN, param[gemmNPerBlock]);

// calculate vector length on gemmn dimension
if(ctx.y == 1 && ctx.x == 1 && ctx.strideH == 1 && ctx.strideW == 1 && ctx.paddingHL == 0 &&
ctx.paddingHR == 0 && ctx.paddingWL == 0 && ctx.paddingWR == 0)
{
// \todo there are more configs that can go through this if branch
srcDataPerRead_GemmN = gcd(srcDataPerRead_GemmN, ctx.hi * ctx.wi);
}
else if(ctx.strideW == 1)
{
srcDataPerRead_GemmN =
gcd(srcDataPerRead_GemmN, ctx.paddingWL, ctx.wi, ctx.paddingWR, ctx.dilationW);
// Determine vectorization dimensions and lengths.
int64_t vectorizableLength = 0;

bool gemmKVectorizable = false;
bool gemmNVectorizable = false;
// Find the fastest changing dimension.
if (ctx.dimNI == 3) {
// When N is the fastest changing dimension,
// gemmN dimension is vectorizable.
// vectorization width depending on length of N.
vectorizableLength = ctx.n;
gemmNVectorizable = true;

// gemmK dimension non-vectorizable.
} else if (ctx.dimCI == 3) {
// When C is the fastest changing dimension,
// gemmK dimension vectorizable.
// vectorization width depending on length of C.
vectorizableLength = ctx.c;
gemmKVectorizable = true;
// gemmN dimension non-vectorizable.
} else if (ctx.dimHI == 2 && ctx.dimWI == 3) {
if(ctx.y == 1 && ctx.x == 1 && ctx.strideH == 1 && ctx.strideW == 1 && ctx.paddingHL == 0 &&
ctx.paddingHR == 0 && ctx.paddingWL == 0 && ctx.paddingWR == 0) {
// \todo there are more configs that can go through this if branch
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, ctx.hi * ctx.wi);

gemmNVectorizable = true;
}
else if(ctx.strideW == 1) {
srcDataPerRead_Gemm =
gcd(srcDataPerRead_Gemm, ctx.paddingWL, ctx.wi, ctx.paddingWR, ctx.dilationW);

gemmNVectorizable = true;
} else {
srcDataPerRead_Gemm = 1;
}
} else {
srcDataPerRead_Gemm = 1;
}
else
{
srcDataPerRead_GemmN = 1;

if (gemmNVectorizable) {
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, param[gemmNPerBlock]);
} else if (gemmKVectorizable) {
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, param[gemmKPerBlock]);
} else {
srcDataPerRead_Gemm = 1;
}

// calculate threadwise copy size
const auto b_data_per_thread_copy = (param[gemmKPerBlock] * param[gemmNPerBlock]) / blockSize;
const int64_t b_data_per_thread_copy = (param[gemmKPerBlock] * param[gemmNPerBlock]) / blockSize;

if(!(b_data_per_thread_copy > 0))
return std::make_tuple(-1, -1, -1, -1, false);

// GemmBBlockCopySrcDataPerRead_GemmN also bounded by size of threadwise copy
srcDataPerRead_GemmN = gcd(srcDataPerRead_GemmN, b_data_per_thread_copy);
srcDataPerRead_Gemm = gcd(srcDataPerRead_Gemm, b_data_per_thread_copy);

const auto b_data_per_thread_copy_gemmn = srcDataPerRead_GemmN;
const auto b_data_per_thread_copy_gemmk =
b_data_per_thread_copy / b_data_per_thread_copy_gemmn;
const int64_t b_data_per_thread_copy_gemm_vectorized = srcDataPerRead_Gemm;
const int64_t b_data_per_thread_copy_gemm_nonvectorized =
b_data_per_thread_copy / b_data_per_thread_copy_gemm_vectorized;

int64_t b_data_per_thread_copy_gemmk = 0;
int64_t b_data_per_thread_copy_gemmn = 0;
if (gemmNVectorizable) {
b_data_per_thread_copy_gemmk = b_data_per_thread_copy_gemm_nonvectorized;
b_data_per_thread_copy_gemmn = b_data_per_thread_copy_gemm_vectorized;
} else {
b_data_per_thread_copy_gemmk = b_data_per_thread_copy_gemm_vectorized;
b_data_per_thread_copy_gemmn = b_data_per_thread_copy_gemm_nonvectorized;
}

// GemmBBlockCopyDstDataPerWrite_GemmN also bounded by size of threadwise copy
dstDataPerWrite_GemmN = gcd(dstDataPerWrite_GemmN, b_data_per_thread_copy_gemmn);

Expand All @@ -142,9 +238,19 @@ class TunableParameters : public TunableParametersBase {
if(!(clusterLengths_GemmK > 0 && clusterLengths_GemmN > 0))
return std::make_tuple(-1, -1, -1, -1, false);

//llvm::errs() << "======================\n";
//llvm::errs() << "Matrix B\n";
//llvm::errs() << "gemmK Vectorizable: " << gemmKVectorizable << "\n";
//llvm::errs() << "gemmN Vectorizable: " << gemmNVectorizable << "\n";
//llvm::errs() << "cluster lengths gemmK: " << clusterLengths_GemmK << "\n";
//llvm::errs() << "cluster lengths gemmN: " << clusterLengths_GemmN << "\n";
//llvm::errs() << "data per read: " << srcDataPerRead_Gemm << "\n";
//llvm::errs() << "data per write: " << dstDataPerWrite_GemmN << "\n";
//llvm::errs() << "======================\n";

return std::make_tuple(clusterLengths_GemmK,
clusterLengths_GemmN,
srcDataPerRead_GemmN,
srcDataPerRead_Gemm,
dstDataPerWrite_GemmN,
true);
}
Expand Down Expand Up @@ -231,22 +337,22 @@ class TunableParameters : public TunableParametersBase {
int64_t gemmN = ctx.n * ctx.ho * ctx.wo;
int64_t gemmK = ctx.c * ctx.y * ctx.x;

llvm::errs() << "gemmM: " << gemmM << " gemmN: " << gemmN << " gemmK: " << gemmK << "\n";
llvm::errs() << "MPerBlock: " << param[gemmMPerBlock] << "\n";
llvm::errs() << "NPerBlock: " << param[gemmNPerBlock] << "\n";
llvm::errs() << "KPerBlock: " << param[gemmKPerBlock] << "\n";
llvm::errs() << "MPerWave: " << param[gemmMPerWave] << "\n";
llvm::errs() << "NPerWave: " << param[gemmNPerWave] << "\n";
//llvm::errs() << "gemmM: " << gemmM << " gemmN: " << gemmN << " gemmK: " << gemmK << "\n";
//llvm::errs() << "MPerBlock: " << param[gemmMPerBlock] << "\n";
//llvm::errs() << "NPerBlock: " << param[gemmNPerBlock] << "\n";
//llvm::errs() << "KPerBlock: " << param[gemmKPerBlock] << "\n";
//llvm::errs() << "MPerWave: " << param[gemmMPerWave] << "\n";
//llvm::errs() << "NPerWave: " << param[gemmNPerWave] << "\n";

if (!(gemmM % param[gemmMPerBlock] == 0 &&
gemmN % param[gemmNPerBlock] == 0 &&
gemmK % param[gemmKPerBlock] == 0)) {
llvm::errs() << "NOT VALID\n";
//llvm::errs() << "NOT VALID\n";
return false;
}

if (!isValidXDLOPSGemm(param)) {
llvm::errs() << "NOT VALID\n";
//llvm::errs() << "NOT VALID\n";
return false;
}

Expand All @@ -257,7 +363,7 @@ class TunableParameters : public TunableParametersBase {
calculateGemmABlockCopyPerformanceParameters(param);

if(!valid) {
llvm::errs() << "NOT VALID\n";
//llvm::errs() << "NOT VALID\n";
return false;
}

Expand All @@ -266,19 +372,19 @@ class TunableParameters : public TunableParametersBase {
calculateGemmBBlockCopyPerformanceParameters(param);

if(!valid) {
llvm::errs() << "NOT VALID\n";
//llvm::errs() << "NOT VALID\n";
return false;
}

std::size_t lds_size = 0;
std::tie(lds_size, valid) = calculateLdsNumberOfByte(param);

if (!valid || (lds_size > 64 * 1024)) {
llvm::errs() << "NOT VALID\n";
//llvm::errs() << "NOT VALID\n";
return false;
}

llvm::errs() << "VALID WITH LDS SIZE: " << lds_size << "\n";
//llvm::errs() << "VALID WITH LDS SIZE: " << lds_size << "\n";
return (valid && lds_size <= 64 * 1024);
}

Expand Down Expand Up @@ -1198,57 +1304,6 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlagsXDLOPS(Modul
TunableParameters params;
params.initWithContext(ctx);

// XXX disable for now.
//// Determine vectorization dimensions and lengths.
//int64_t vectorizableLength = 0;

//// Filter tensor.
//// Find the fastest changing dimension.
//bool filterGemmKVectorizable = false;
//if (ctx.dimKF == 3) {
// // When K is the fastest changing dimension,
// // gemmM dimension is vectorizable.
// // vectorization width depending on length of K.
// vectorizableLength = ctx.k;

// // gemmK dimension non-vectorizable.
// filterGemmKVectorizable = false;
//} else {
// // gemmK dimension vectorizable,
// // depending on which among C, Y, X be the fastest changing dimension.
// if (ctx.dimKF == 0) {
// // dimKF is the lowest changing dimension, which means dimC/dimY/dimX
// vectorizableLength = ctx.c * ctx.y * ctx.x;
// } else {
// if (ctx.dimCF == 3) {
// vectorizableLength = ctx.c;
// } else if (ctx.dimXF == 3 && ctx.dimYF == 2) {
// vectorizableLength = ctx.y * ctx.x;
// }
// }

// filterGemmKVectorizable = true;
// // gemmM dimension non-vectorizable.
//}

// XXX disable vectorization logic on matrix A for now.
//int perThreadOpsA = params["CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK"] * params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / params["CK_PARAM_TUNABLE_BLOCK_SIZE"];
//int perThreadOpsAVectorLength = 1;
//if ((vectorizableLength > 0) && (vectorizableLength % 4 == 0)) {
// perThreadOpsAVectorLength = gcd(4, perThreadOpsA);
//} else if ((vectorizableLength > 0) && (vectorizableLength % 2 == 0)) {
// perThreadOpsAVectorLength = gcd(2, perThreadOpsA);
//}
//int perThreadOpsANonVectorizedLength = perThreadOpsA / perThreadOpsAVectorLength;
//params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM", perThreadOpsAVectorLength);
//if (filterGemmKVectorizable) {
// params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M", params["CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK"] / perThreadOpsANonVectorizedLength);
// params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsAVectorLength);
//} else {
// params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsANonVectorizedLength);
// params.setValue("CK_PARAM_TUNABLE_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M", params["CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK"] / perThreadOpsAVectorLength);
//}

// XXX disable for now.
//// Input tensor.
//bool inputGemmKVectorizable = false;
Expand All @@ -1272,27 +1327,6 @@ std::unique_ptr<llvm::StringRef> mlir::translateModuleToMIOpenCFlagsXDLOPS(Modul
// // gemmN dimension non-vectorizable.
//}

// XXX disable vectorization logic on matrix B for now.
//int perThreadOpsB = params["CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK"] * params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / params["CK_PARAM_TUNABLE_BLOCK_SIZE"];
//int perThreadOpsBVectorLength = 1;
//if ((vectorizableLength > 0) && (vectorizableLength % 4 == 0)) {
// perThreadOpsBVectorLength = gcd(4, perThreadOpsB);
//} else if ((vectorizableLength > 0) && (vectorizableLength % 2 == 0)) {
// perThreadOpsBVectorLength = gcd(2, perThreadOpsB);
//}
//int perThreadOpsBNonVectorizedLength = perThreadOpsB / perThreadOpsBVectorLength;
//params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM", perThreadOpsBVectorLength);
//if (inputGemmKVectorizable) {
// params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N", params["CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK"] / perThreadOpsBNonVectorizedLength);
// params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsBVectorLength);
//} else {
// params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K", params["CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK"] / perThreadOpsBNonVectorizedLength);
// params.setValue("CK_PARAM_TUNABLE_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N", params["CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK"] / perThreadOpsBVectorLength);
//}

// Output tensor.
// Dont vectorize on matrix C for now.

// Print out the tunable parameters.
params.print(output);
if (IsPopulateTunableParameters.getValue()) {
Expand Down

0 comments on commit 7eed23f

Please sign in to comment.