Skip to content

Commit

Permalink
Revert "[mlir][VectorOps] Use SCF for vector.print and allow scalable…
Browse files Browse the repository at this point in the history
… vectors"

This reverts commit 3875804.

This caused some test failures for the MLIR python bindings. Reverting
until those are addressed.
  • Loading branch information
MacDue committed Aug 9, 2023
1 parent f04b5ba commit b160442
Show file tree
Hide file tree
Showing 59 changed files with 238 additions and 460 deletions.
67 changes: 14 additions & 53 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2464,80 +2464,41 @@ def Vector_TransposeOp :
let hasVerifier = 1;
}

// 32-bit integer enum naming the punctuation a print can emit: nothing, a
// newline, a comma, or an opening/closing bracket. Each case gives the C++
// enumerator name, its integer value, and the keyword used in textual IR.
def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
// The generated C++ enum is placed in the vector dialect namespace.
let cppNamespace = "::mlir::vector";
// NOTE(review): presumably disabled because the attribute class is provided by
// a separate dialect `EnumAttr` definition rather than generated here --
// confirm against the ODS enum documentation.
let genSpecializedAttr = 0;
}

// Vector dialect attribute wrapping the PrintPunctuation enum under the
// mnemonic `punctuation`, e.g. `#vector.punctuation<comma>`.
def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
// Custom syntax: the enum keyword enclosed in angle brackets.
let assemblyFormat = "`<` $value `>`";
}

def Vector_PrintOp :
Vector_Op<"print", []>,
Arguments<(ins Optional<Type<Or<[
Arguments<(ins Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
> {
]>>:$source)> {
let summary = "print operation (for testing and debugging)";
let description = [{
Prints the source vector (or scalar) to stdout in a human-readable format
(for testing and debugging). No return value.
Prints the source vector (or scalar) to stdout in human readable
format (for testing and debugging). No return value.

Example:

```mlir
%v = arith.constant dense<0.0> : vector<4xf32>
vector.print %v : vector<4xf32>
```
%0 = arith.constant 0.0 : f32
%1 = vector.broadcast %0 : f32 to vector<4xf32>
vector.print %1 : vector<4xf32>

When lowered to LLVM, the vector print is decomposed into elementary
printing method calls that at runtime will yield:
when lowered to LLVM, the vector print is unrolled into
elementary printing method calls that at runtime will yield

```
( 0.0, 0.0, 0.0, 0.0 )
```

This is printed to stdout via a small runtime support library, which only
needs to provide a few printing methods (single value for all data
types, opening/closing bracket, comma, newline).

By default `vector.print` adds a newline after the vector, but this can be
controlled by the `punctuation` attribute. For example, to print a comma
after instead do:

```mlir
vector.print %v : vector<4xf32> #vector.punctuation<comma>
```

Note that it is possible to use the punctuation attribute alone. The
following will print a single newline:

```mlir
vector.print #vector.punctuation<newline>
on stdout when linked with a small runtime support library,
which only needs to provide a few printing methods (single
value for all data types, opening/closing bracket, comma,
newline).
```
}];
let extraClassDeclaration = [{
Type getPrintType() {
return getSource().getType();
}
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
build($_builder, $_state, {}, punctuation);
}]>,
];

let assemblyFormat = "($source^ `:` type($source))? ($punctuation^)? attr-dict";
let assemblyFormat = "$source attr-dict `:` type($source)";
}

//===----------------------------------------------------------------------===//
Expand Down
185 changes: 105 additions & 80 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@
using namespace mlir;
using namespace mlir::vector;

// Returns the vector type obtained by dropping the leading dimension of `tp`:
// same element type, with the shape and the per-dimension scalable flags both
// shortened by their first entry. Only valid for vectors of rank > 1.
static VectorType reducedVectorTypeFront(VectorType tp) {
  assert((tp.getRank() > 1) && "unlowerable vector type");
  auto trailingShape = tp.getShape().drop_front();
  auto trailingScalableDims = tp.getScalableDims().drop_front();
  return VectorType::get(trailingShape, tp.getElementType(),
                         trailingScalableDims);
}

// Helper to reduce vector type by *all* but one rank at back.
static VectorType reducedVectorTypeBack(VectorType tp) {
assert((tp.getRank() > 1) && "unlowerable vector type");
Expand Down Expand Up @@ -1409,89 +1416,45 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
public:
using ConvertOpToLLVMPattern<vector::PrintOp>::ConvertOpToLLVMPattern;

// Lowering implementation that relies on a small runtime support library,
// which only needs to provide a few printing methods (single value for all
// data types, opening/closing bracket, comma, newline). The lowering splits
// the vector into elementary printing operations. The advantage of this
// approach is that the library can remain unaware of all low-level
// implementation details of vectors while still supporting output of any
// shaped and dimensioned vector.
//
// Note: This lowering only handles scalars, n-D vectors are broken into
// printing scalars in loops in VectorToSCF.
// Proof-of-concept lowering implementation that relies on a small
// runtime support library, which only needs to provide a few
// printing methods (single value for all data types, opening/closing
// bracket, comma, newline). The lowering fully unrolls a vector
// in terms of these elementary printing operations. The advantage
// of this approach is that the library can remain unaware of all
// low-level implementation details of vectors while still supporting
// output of any shaped and dimensioned vector. Due to full unrolling,
// this approach is less suited for very large vectors though.
//
// TODO: rely solely on libc in future? something else?
//
LogicalResult
matchAndRewrite(vector::PrintOp printOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto parent = printOp->getParentOfType<ModuleOp>();
auto loc = printOp->getLoc();
Type printType = printOp.getPrintType();

if (auto value = adaptor.getSource()) {
Type printType = printOp.getPrintType();
if (isa<VectorType>(printType)) {
// Vectors should be broken into elementary print ops in VectorToSCF.
return failure();
}
if (failed(emitScalarPrint(rewriter, parent, loc, printType, value)))
return failure();
}

auto punct = printOp.getPunctuation();
if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(parent);
case PrintPunctuation::Open:
return LLVM::lookupOrCreatePrintOpenFn(parent);
case PrintPunctuation::Comma:
return LLVM::lookupOrCreatePrintCommaFn(parent);
case PrintPunctuation::NewLine:
return LLVM::lookupOrCreatePrintNewlineFn(parent);
default:
llvm_unreachable("unexpected punctuation");
}
}());
}

rewriter.eraseOp(printOp);
return success();
}

private:
// Adaptation applied to a scalar operand before it is passed to the runtime
// print function.
enum class PrintConversion {
// clang-format off
// Pass the value through unchanged.
None,
// Zero-extend the value to a 64-bit integer (arith::ExtUIOp).
ZeroExt64,
// Sign-extend the value to a 64-bit integer (arith::ExtSIOp).
SignExt64,
// Bitcast the value to a 16-bit integer (LLVM::BitcastOp); used for the
// f16/bf16 printers, which receive the raw bits.
Bitcast16
// clang-format on
};

LogicalResult emitScalarPrint(ConversionPatternRewriter &rewriter,
ModuleOp parent, Location loc, Type printType,
Value value) const {
if (typeConverter->convertType(printType) == nullptr)
return failure();

// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
VectorType vectorType = dyn_cast<VectorType>(printType);
Type eltType = vectorType ? vectorType.getElementType() : printType;
auto parent = printOp->getParentOfType<ModuleOp>();
Operation *printer;
if (printType.isF32()) {
if (eltType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (printType.isF64()) {
} else if (eltType.isF64()) {
printer = LLVM::lookupOrCreatePrintF64Fn(parent);
} else if (printType.isF16()) {
} else if (eltType.isF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintF16Fn(parent);
} else if (printType.isBF16()) {
} else if (eltType.isBF16()) {
conversion = PrintConversion::Bitcast16; // bits!
printer = LLVM::lookupOrCreatePrintBF16Fn(parent);
} else if (printType.isIndex()) {
} else if (eltType.isIndex()) {
printer = LLVM::lookupOrCreatePrintU64Fn(parent);
} else if (auto intTy = dyn_cast<IntegerType>(printType)) {
} else if (auto intTy = dyn_cast<IntegerType>(eltType)) {
// Integers need a zero or sign extension on the operand
// (depending on the source type) as well as a signed or
// unsigned print method. Up to 64-bit is supported.
Expand Down Expand Up @@ -1522,26 +1485,88 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
return failure();
}

switch (conversion) {
case PrintConversion::ZeroExt64:
value = rewriter.create<arith::ExtUIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::SignExt64:
value = rewriter.create<arith::ExtSIOp>(
loc, IntegerType::get(rewriter.getContext(), 64), value);
break;
case PrintConversion::Bitcast16:
value = rewriter.create<LLVM::BitcastOp>(
loc, IntegerType::get(rewriter.getContext(), 16), value);
break;
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
// Unroll vector into elementary print calls.
int64_t rank = vectorType ? vectorType.getRank() : 0;
Type type = vectorType ? vectorType : eltType;
emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank,
conversion);
emitCall(rewriter, printOp->getLoc(),
LLVM::lookupOrCreatePrintNewlineFn(parent));
rewriter.eraseOp(printOp);
return success();
}

private:
// Adaptation applied to a scalar operand before it is passed to the runtime
// print function.
enum class PrintConversion {
// clang-format off
// Pass the value through unchanged.
None,
// Zero-extend the value to a 64-bit integer (arith::ExtUIOp).
ZeroExt64,
// Sign-extend the value to a 64-bit integer (arith::ExtSIOp).
SignExt64,
// Bitcast the value to a 16-bit integer (LLVM::BitcastOp); used for the
// f16/bf16 printers, which receive the raw bits.
Bitcast16
// clang-format on
};

// Recursively emits the runtime calls that print `value`.
//
// If `type` is not a vector this is the base case (`rank` must be 0): the
// scalar is adapted as requested by `conversion` and passed to `printer`.
//
// If `type` is a vector, the output is an opening bracket, the elements along
// the leading dimension separated by commas, and a closing bracket. Each
// element is obtained with `extractOne` and printed by a nested call. The
// loop runs while the pattern is applied, so the vector is fully unrolled
// into individual print calls.
//
// `op` supplies the location and the enclosing module in which the runtime
// print functions are looked up or created.
void emitRanks(ConversionPatternRewriter &rewriter, Operation *op,
               Value value, Type type, Operation *printer, int64_t rank,
               PrintConversion conversion) const {
  Location loc = op->getLoc();
  VectorType vectorType = dyn_cast<VectorType>(type);

  // Base case: a single scalar.
  if (!vectorType) {
    assert(rank == 0 && "The scalar case expects rank == 0");
    switch (conversion) {
    case PrintConversion::ZeroExt64:
      value = rewriter.create<arith::ExtUIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::SignExt64:
      value = rewriter.create<arith::ExtSIOp>(
          loc, IntegerType::get(rewriter.getContext(), 64), value);
      break;
    case PrintConversion::Bitcast16:
      value = rewriter.create<LLVM::BitcastOp>(
          loc, IntegerType::get(rewriter.getContext(), 16), value);
      break;
    case PrintConversion::None:
      break;
    }
    emitCall(rewriter, loc, printer, value);
    return;
  }

  auto parent = op->getParentOfType<ModuleOp>();
  emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent));
  Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent);

  // For rank <= 1 the elements are scalars; for higher ranks each element is
  // the vector with its leading dimension dropped. Both types are the same
  // for every element, so they are computed once ahead of the loop.
  const bool innermost = rank <= 1;
  Type nestedType;
  if (innermost)
    nestedType = vectorType.getElementType();
  else
    nestedType = reducedVectorTypeFront(vectorType);
  Type llvmType = typeConverter->convertType(nestedType);

  // Scalars are extracted and printed with rank 0; sub-vectors are extracted
  // at the current rank and printed at one rank less.
  const int64_t extractRank = innermost ? 0 : rank;
  const int64_t nestedRank = innermost ? 0 : rank - 1;
  // A rank-0 vector is printed as exactly one element.
  const int64_t dim = rank == 0 ? 1 : vectorType.getDimSize(0);

  for (int64_t d = 0; d < dim; ++d) {
    Value nestedVal = extractOne(rewriter, *getTypeConverter(), loc, value,
                                 llvmType, extractRank, d);
    emitRanks(rewriter, op, nestedVal, nestedType, printer, nestedRank,
              conversion);
    // Separator between elements, but not after the last one.
    if (d != dim - 1)
      emitCall(rewriter, loc, printComma);
  }
  emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent));
}

// Helper to emit a call.
static void emitCall(ConversionPatternRewriter &rewriter, Location loc,
Operation *ref, ValueRange params = ValueRange()) {
Expand Down
Loading

0 comments on commit b160442

Please sign in to comment.