Skip to content

Commit

Permalink
[mlir][OpenMP] Add custom parser and pretty printer for parallel cons…
Browse files Browse the repository at this point in the history
…truct

Reviewers: jdoerfert

Subscribers: yaxunl, guansong, mehdi_amini, rriddle, jpienaar, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, aartbik, liufengdb, stephenneuendorffer, Joonsoo, grosul1, frgossen, Kayjukh, jurahul, sstefan1, msifontes

Tags: #mlir

Differential Revision: https://reviews.llvm.org/D81264
  • Loading branch information
DavidTruby committed Jun 16, 2020
1 parent d7deef1 commit 245b299
Show file tree
Hide file tree
Showing 4 changed files with 374 additions and 5 deletions.
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
OptionalAttr<ClauseProcBind>:$proc_bind_val);

let regions = (region AnyRegion:$region);

let parser = [{ return parseParallelOp(parser, result); }];
let printer = [{ return printParallelOp(p, *this); }];
}

def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> {
Expand Down
244 changes: 244 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@

#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"

#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include <cstddef>

#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"

Expand All @@ -29,6 +34,245 @@ OpenMPDialect::OpenMPDialect(MLIRContext *context)
>();
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//

/// Parse a list of operands with types.
///
/// operand-and-type-list ::= `(` ssa-id-and-type-list `)`
/// ssa-id-and-type-list ::= ssa-id-and-type |
/// ssa-id-and-type ',' ssa-id-and-type-list
/// ssa-id-and-type ::= ssa-id `:` type
static ParseResult
parseOperandAndTypeList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::OperandType> &operands,
SmallVectorImpl<Type> &types) {
if (parser.parseLParen())
return failure();

do {
OpAsmParser::OperandType operand;
Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
operands.push_back(operand);
types.push_back(type);
} while (succeeded(parser.parseOptionalComma()));

if (parser.parseRParen())
return failure();

return success();
}

static void printParallelOp(OpAsmPrinter &p, ParallelOp op) {
p << "omp.parallel";

if (auto ifCond = op.if_expr_var())
p << " if(" << ifCond << ")";

if (auto threads = op.num_threads_var())
p << " num_threads(" << threads << " : " << threads.getType() << ")";

// Print private, firstprivate, shared and copyin parameters
auto printDataVars = [&p](StringRef name, OperandRange vars) {
if (vars.size()) {
p << " " << name << "(";
for (unsigned i = 0; i < vars.size(); ++i) {
std::string separator = i == vars.size() - 1 ? ")" : ", ";
p << vars[i] << " : " << vars[i].getType() << separator;
}
}
};
printDataVars("private", op.private_vars());
printDataVars("firstprivate", op.firstprivate_vars());
printDataVars("shared", op.shared_vars());
printDataVars("copyin", op.copyin_vars());

if (auto def = op.default_val())
p << " default(" << def->drop_front(3) << ")";

if (auto bind = op.proc_bind_val())
p << " proc_bind(" << bind << ")";

p.printRegion(op.getRegion());
}

/// Emit an error if the same clause is present more than once on an operation.
static ParseResult allowedOnce(OpAsmParser &parser, llvm::StringRef clause,
llvm::StringRef operation) {
return parser.emitError(parser.getNameLoc())
<< " at most one " << clause << " clause can appear on the "
<< operation << " operation";
}

/// Parses a parallel operation.
///
/// operation ::= `omp.parallel` clause-list
/// clause-list ::= clause | clause clause-list
/// clause ::= if | numThreads | private | firstprivate | shared | copyin |
/// default | procBind
/// if ::= `if` `(` ssa-id `)`
/// numThreads ::= `num_threads` `(` ssa-id-and-type `)`
/// private ::= `private` operand-and-type-list
/// firstprivate ::= `firstprivate` operand-and-type-list
/// shared ::= `shared` operand-and-type-list
/// copyin ::= `copyin` operand-and-type-list
/// default ::= `default` `(` (`private` | `firstprivate` | `shared` | `none`)
/// procBind ::= `proc_bind` `(` (`master` | `close` | `spread`) `)`
///
/// Note that each clause can only appear once in the clase-list.
static ParseResult parseParallelOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType ifCond;
std::pair<OpAsmParser::OperandType, Type> numThreads;
llvm::SmallVector<OpAsmParser::OperandType, 4> privates;
llvm::SmallVector<Type, 4> privateTypes;
llvm::SmallVector<OpAsmParser::OperandType, 4> firstprivates;
llvm::SmallVector<Type, 4> firstprivateTypes;
llvm::SmallVector<OpAsmParser::OperandType, 4> shareds;
llvm::SmallVector<Type, 4> sharedTypes;
llvm::SmallVector<OpAsmParser::OperandType, 4> copyins;
llvm::SmallVector<Type, 4> copyinTypes;
std::array<int, 6> segments{0, 0, 0, 0, 0, 0};
llvm::StringRef keyword;
bool defaultVal = false;
bool procBind = false;

const int ifClausePos = 0;
const int numThreadsClausePos = 1;
const int privateClausePos = 2;
const int firstprivateClausePos = 3;
const int sharedClausePos = 4;
const int copyinClausePos = 5;
const llvm::StringRef opName = result.name.getStringRef();

while (succeeded(parser.parseOptionalKeyword(&keyword))) {
if (keyword == "if") {
// Fail if there was already another if condition
if (segments[ifClausePos])
return allowedOnce(parser, "if", opName);
if (parser.parseLParen() || parser.parseOperand(ifCond) ||
parser.parseRParen())
return failure();
segments[ifClausePos] = 1;
} else if (keyword == "num_threads") {
// fail if there was already another num_threads clause
if (segments[numThreadsClausePos])
return allowedOnce(parser, "num_threads", opName);
if (parser.parseLParen() || parser.parseOperand(numThreads.first) ||
parser.parseColonType(numThreads.second) || parser.parseRParen())
return failure();
segments[numThreadsClausePos] = 1;
} else if (keyword == "private") {
// fail if there was already another private clause
if (segments[privateClausePos])
return allowedOnce(parser, "private", opName);
if (parseOperandAndTypeList(parser, privates, privateTypes))
return failure();
segments[privateClausePos] = privates.size();
} else if (keyword == "firstprivate") {
// fail if there was already another firstprivate clause
if (segments[firstprivateClausePos])
return allowedOnce(parser, "firstprivate", opName);
if (parseOperandAndTypeList(parser, firstprivates, firstprivateTypes))
return failure();
segments[firstprivateClausePos] = firstprivates.size();
} else if (keyword == "shared") {
// fail if there was already another shared clause
if (segments[sharedClausePos])
return allowedOnce(parser, "shared", opName);
if (parseOperandAndTypeList(parser, shareds, sharedTypes))
return failure();
segments[sharedClausePos] = shareds.size();
} else if (keyword == "copyin") {
// fail if there was already another copyin clause
if (segments[copyinClausePos])
return allowedOnce(parser, "copyin", opName);
if (parseOperandAndTypeList(parser, copyins, copyinTypes))
return failure();
segments[copyinClausePos] = copyins.size();
} else if (keyword == "default") {
// fail if there was already another default clause
if (defaultVal)
return allowedOnce(parser, "default", opName);
defaultVal = true;
llvm::StringRef defval;
if (parser.parseLParen() || parser.parseKeyword(&defval) ||
parser.parseRParen())
return failure();
llvm::SmallString<16> attrval;
// The def prefix is required for the attribute as "private" is a keyword
// in C++
attrval += "def";
attrval += defval;
auto attr = parser.getBuilder().getStringAttr(attrval);
result.addAttribute("default_val", attr);
} else if (keyword == "proc_bind") {
// fail if there was already another default clause
if (procBind)
return allowedOnce(parser, "proc_bind", opName);
procBind = true;
llvm::StringRef bind;
if (parser.parseLParen() || parser.parseKeyword(&bind) ||
parser.parseRParen())
return failure();
auto attr = parser.getBuilder().getStringAttr(bind);
result.addAttribute("proc_bind_val", attr);
} else {
return parser.emitError(parser.getNameLoc())
<< keyword << " is not a valid clause for the " << opName
<< " operation";
}
}

// Add if parameter
if (segments[ifClausePos]) {
parser.resolveOperand(ifCond, parser.getBuilder().getI1Type(),
result.operands);
}

// Add num_threads parameter
if (segments[numThreadsClausePos]) {
parser.resolveOperand(numThreads.first, numThreads.second, result.operands);
}

// Add private parameters
if (segments[privateClausePos]) {
parser.resolveOperands(privates, privateTypes, privates[0].location,
result.operands);
}

// Add firstprivate parameters
if (segments[firstprivateClausePos]) {
parser.resolveOperands(firstprivates, firstprivateTypes,
firstprivates[0].location, result.operands);
}

// Add shared parameters
if (segments[sharedClausePos]) {
parser.resolveOperands(shareds, sharedTypes, shareds[0].location,
result.operands);
}

// Add copyin parameters
if (segments[copyinClausePos]) {
parser.resolveOperands(copyins, copyinTypes, copyins[0].location,
result.operands);
}

result.addAttribute("operand_segment_sizes",
parser.getBuilder().getI32VectorAttr(segments));

Region *body = result.addRegion();
llvm::SmallVector<OpAsmParser::OperandType, 4> regionArgs;
llvm::SmallVector<Type, 4> regionArgTypes;
if (parser.parseRegion(*body, regionArgs, regionArgTypes))
return failure();
return success();
}

namespace mlir {
namespace omp {
#define GET_OP_CLASSES
Expand Down
89 changes: 89 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s

func @unknown_clause() {
// expected-error@+1 {{invalid is not a valid clause for the omp.parallel operation}}
omp.parallel invalid {
}

return
}

// -----

func @if_once(%n : i1) {
// expected-error@+1 {{at most one if clause can appear on the omp.parallel operation}}
omp.parallel if(%n) if(%n) {
}

return
}

// -----

func @num_threads_once(%n : si32) {
// expected-error@+1 {{at most one num_threads clause can appear on the omp.parallel operation}}
omp.parallel num_threads(%n : si32) num_threads(%n : si32) {
}

return
}

// -----

func @private_once(%n : memref<i32>) {
// expected-error@+1 {{at most one private clause can appear on the omp.parallel operation}}
omp.parallel private(%n : memref<i32>) private(%n : memref<i32>) {
}

return
}

// -----

func @firstprivate_once(%n : memref<i32>) {
// expected-error@+1 {{at most one firstprivate clause can appear on the omp.parallel operation}}
omp.parallel firstprivate(%n : memref<i32>) firstprivate(%n : memref<i32>) {
}

return
}

// -----

func @shared_once(%n : memref<i32>) {
// expected-error@+1 {{at most one shared clause can appear on the omp.parallel operation}}
omp.parallel shared(%n : memref<i32>) shared(%n : memref<i32>) {
}

return
}

// -----

func @copyin_once(%n : memref<i32>) {
// expected-error@+1 {{at most one copyin clause can appear on the omp.parallel operation}}
omp.parallel copyin(%n : memref<i32>) copyin(%n : memref<i32>) {
}

return
}

// -----

func @default_once() {
// expected-error@+1 {{at most one default clause can appear on the omp.parallel operation}}
omp.parallel default(private) default(firstprivate) {
}

return
}

// -----

func @proc_bind_once() {
// expected-error@+1 {{at most one proc_bind clause can appear on the omp.parallel operation}}
omp.parallel proc_bind(close) proc_bind(spread) {
}

return
}
Loading

0 comments on commit 245b299

Please sign in to comment.