Skip to content

Commit

Permalink
Extract Xdlops code selection logic from xdlops_gemm lowering logic.
Browse files Browse the repository at this point in the history
the same logic would be used elsewhere (ex: gridwise_gemm lowering logic).
  • Loading branch information
whchung committed Aug 2, 2020
1 parent 07e1385 commit a375550
Showing 1 changed file with 102 additions and 31 deletions.
133 changes: 102 additions & 31 deletions mlir/include/mlir/Dialect/MIOpen/LowerMIOpenOps.h
Expand Up @@ -3233,34 +3233,28 @@ struct TransformRewritePattern : public OpRewritePattern<miopen::TransformOp> {
}
};

//===----------------------------------------------------------------------===//
// XdlopsGemm lowering.
//===----------------------------------------------------------------------===//

struct XdlopsGemmRewritePattern
: public OpRewritePattern<miopen::XdlopsGemmOp> {
using OpRewritePattern<miopen::XdlopsGemmOp>::OpRewritePattern;

LogicalResult matchAndRewrite(miopen::XdlopsGemmOp op,
PatternRewriter &b) const override {
auto loc = op.getLoc();

// Obtain critical information.
int64_t M = op.getAttr("m").template dyn_cast<IntegerAttr>().getInt();
int64_t N = op.getAttr("n").template dyn_cast<IntegerAttr>().getInt();
int64_t K = op.getAttr("k").template dyn_cast<IntegerAttr>().getInt();
int64_t MPerWave = op.getAttr("m_per_wave").template dyn_cast<IntegerAttr>().getInt();
int64_t NPerWave = op.getAttr("n_per_wave").template dyn_cast<IntegerAttr>().getInt();
auto dataType = op.matrixA()
.getType()
.template dyn_cast<MemRefType>()
.getElementType()
.template dyn_cast<FloatType>();

auto MConstantIndexOp = b.create<ConstantIndexOp>(loc, M);
auto NConstantIndexOp = b.create<ConstantIndexOp>(loc, N);
auto KConstantIndexOp = b.create<ConstantIndexOp>(loc, K);

struct XdlopsCodeSelection {
StringRef mfmaInstr;
int64_t MPerXdlops;
int64_t NPerXdlops;
int64_t MRepeats;
int64_t NRepeats;

int64_t group_size;
int64_t num_groups_blk;
int64_t num_regs_blk;
int64_t num_threads_blk;
int64_t wave_size;
int64_t num_input_blks;
int64_t num_output_blks;
int64_t num_regs_xdlops;
int64_t m;
int64_t n;
int64_t k;
int64_t cycles;
int64_t k_base;

static XdlopsCodeSelection get(FloatType dataType, int64_t MPerWave, int64_t NPerWave, PatternRewriter &b) {
// Determine which XDLOPS be used.
int64_t MPerXdlops = 0, NPerXdlops = 0, MRepeats = 0, NRepeats = 0;
StringRef mfmaInstr = "";
Expand Down Expand Up @@ -3375,7 +3369,7 @@ struct XdlopsGemmRewritePattern
NRepeats = 1;
} else {
llvm::errs() << "Unsupported case:\n";
llvm::errs() << "M, N, K:" << M << " " << N << " " << K << "\n";
//llvm::errs() << "M, N, K:" << M << " " << N << " " << K << "\n";
llvm::errs() << "MPerWave: " << MPerWave << "\n";
llvm::errs() << "NPerWave: " << NPerWave << "\n";
llvm::errs() << "dataType: ";
Expand Down Expand Up @@ -3493,7 +3487,7 @@ struct XdlopsGemmRewritePattern
NRepeats = 1;
} else {
llvm::errs() << "Unsupported case:\n";
llvm::errs() << "M, N, K:" << M << " " << N << " " << K << "\n";
//llvm::errs() << "M, N, K:" << M << " " << N << " " << K << "\n";
llvm::errs() << "MPerWave: " << MPerWave << "\n";
llvm::errs() << "NPerWave: " << NPerWave << "\n";
llvm::errs() << "dataType: ";
Expand Down Expand Up @@ -3611,7 +3605,7 @@ struct XdlopsGemmRewritePattern
NRepeats = 1;
} else {
llvm::errs() << "Unsupported case:\n";
llvm::errs() << "M, N, K:" << M << " " << N << " " << K << "\n";
//llvm::errs() << "M, N, K:" << M << " " << N << " " << K << "\n";
llvm::errs() << "MPerWave: " << MPerWave << "\n";
llvm::errs() << "NPerWave: " << NPerWave << "\n";
llvm::errs() << "dataType: ";
Expand Down Expand Up @@ -3836,6 +3830,83 @@ struct XdlopsGemmRewritePattern
llvm::errs() << "Unsupported case as mfmaInstr not selected!\n";
}

// Populate result.
XdlopsCodeSelection result;
result.mfmaInstr = mfmaInstr;
result.MPerXdlops = MPerXdlops;
result.NPerXdlops = NPerXdlops;
result.MRepeats = MRepeats;
result.NRepeats = NRepeats;

result.group_size = group_size;
result.num_groups_blk = num_groups_blk;
result.num_regs_blk = num_regs_blk;
result.num_threads_blk = num_threads_blk;
result.wave_size = wave_size;
result.num_input_blks = num_input_blks;
result.num_output_blks = num_output_blks;
result.num_regs_xdlops = num_regs_xdlops;
result.m = m;
result.n = n;
result.k = k;
result.cycles = cycles;
result.k_base = k_base;

return result;
}
};

//===----------------------------------------------------------------------===//
// XdlopsGemm lowering.
//===----------------------------------------------------------------------===//

struct XdlopsGemmRewritePattern
: public OpRewritePattern<miopen::XdlopsGemmOp> {
using OpRewritePattern<miopen::XdlopsGemmOp>::OpRewritePattern;

LogicalResult matchAndRewrite(miopen::XdlopsGemmOp op,
PatternRewriter &b) const override {
auto loc = op.getLoc();

// Obtain critical information.
int64_t M = op.getAttr("m").template dyn_cast<IntegerAttr>().getInt();
int64_t N = op.getAttr("n").template dyn_cast<IntegerAttr>().getInt();
int64_t K = op.getAttr("k").template dyn_cast<IntegerAttr>().getInt();
int64_t MPerWave = op.getAttr("m_per_wave").template dyn_cast<IntegerAttr>().getInt();
int64_t NPerWave = op.getAttr("n_per_wave").template dyn_cast<IntegerAttr>().getInt();
auto dataType = op.matrixA()
.getType()
.template dyn_cast<MemRefType>()
.getElementType()
.template dyn_cast<FloatType>();

auto MConstantIndexOp = b.create<ConstantIndexOp>(loc, M);
auto NConstantIndexOp = b.create<ConstantIndexOp>(loc, N);
auto KConstantIndexOp = b.create<ConstantIndexOp>(loc, K);

XdlopsCodeSelection xcs = XdlopsCodeSelection::get(dataType, MPerWave, NPerWave, b);

// Extract values from XdlopsCodeSelection.
StringRef mfmaInstr = xcs.mfmaInstr;
int64_t MPerXdlops = xcs.MPerXdlops;
int64_t NPerXdlops = xcs.NPerXdlops;
int64_t MRepeats = xcs.MRepeats;
int64_t NRepeats = xcs.NRepeats;

int64_t group_size = xcs.group_size;
int64_t num_groups_blk = xcs.num_groups_blk;
int64_t num_regs_blk = xcs.num_regs_blk;
int64_t num_threads_blk = xcs.num_threads_blk;
int64_t wave_size = xcs.wave_size;
int64_t num_input_blks = xcs.num_input_blks;
int64_t num_output_blks = xcs.num_output_blks;
int64_t num_regs_xdlops = xcs.num_regs_xdlops;
int64_t m = xcs.m;
int64_t n = xcs.n;
int64_t k = xcs.k;
int64_t cycles = xcs.cycles;
int64_t k_base = xcs.k_base;

bool IsABroadcast = (NPerXdlops >= MPerXdlops);
bool IsKReduction = (num_output_blks == 1) && (num_input_blks > 1);

Expand Down

0 comments on commit a375550

Please sign in to comment.