Skip to content

Commit

Permalink
[mlir][VectorOps] Support string literals in vector.print (#68695)
Browse files Browse the repository at this point in the history
Printing strings within integration tests is currently quite annoyingly
verbose, and can't be tucked into shared helpers as the types depend on
the length of the string:

```
llvm.mlir.global internal constant @hello_world("Hello, World!\0")

func.func @entry() {
  %0 = llvm.mlir.addressof @hello_world : !llvm.ptr<array<14 x i8>>
  %1 = llvm.mlir.constant(0 : index) : i64
  %2 = llvm.getelementptr %0[%1, %1]
    : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
  return
}
```

So this patch adds a simple extension to `vector.print` to simplify
this:
```
func.func @entry() {
   // Print a vector of characters ;)
   vector.print str "Hello, World!"
   return
}
```

Most of the logic for this is now shared with `cf.assert` which already
does something similar.

Depends on #68694
  • Loading branch information
MacDue committed Oct 24, 2023
1 parent 1072fcd commit 3be3883
Show file tree
Hide file tree
Showing 14 changed files with 204 additions and 58 deletions.
35 changes: 35 additions & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- PrintCallHelper.h - Helper to emit runtime print calls ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "llvm/ADT/StringRef.h"
#include <optional>

namespace mlir {

class OpBuilder;
class LLVMTypeConverter;

namespace LLVM {

/// Generate IR that prints the given string to stdout.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
///
/// The string is materialized as a constant global created at the start of
/// `moduleOp`'s body; `symbolName` is the requested name for that global and
/// is made unique if a symbol with that name already exists in the module.
/// When `addNewline` is true (the default) a '\n' is appended to the stored
/// string; a terminating '\0' is always appended. The call to the print
/// function is emitted at the insertion point `builder` had on entry, which
/// is restored before the call is created.
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef symbolName, StringRef string,
const LLVMTypeConverter &typeConverter,
bool addNewline = true,
std::optional<StringRef> runtimeFunctionName = {});
} // namespace LLVM

} // namespace mlir

#endif
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include <optional>

namespace mlir {
class Location;
Expand All @@ -38,8 +39,12 @@ LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp,
bool opaquePointers);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
lookupOrCreatePrintStringFn(ModuleOp moduleOp, bool opaquePointers,
std::optional<StringRef> runtimeFunctionName = {});
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
Expand Down
37 changes: 33 additions & 4 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/BuiltinAttributes.td"

// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
Expand Down Expand Up @@ -2477,12 +2478,18 @@ def Vector_TransposeOp :
}

def Vector_PrintOp :
Vector_Op<"print", []>,
Vector_Op<"print", [
PredOpTrait<
"`source` or `punctuation` are not set when printing strings",
CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
>,
]>,
Arguments<(ins Optional<Type<Or<[
AnyVectorOfAnyRank.predicate,
AnyInteger.predicate, Index.predicate, AnyFloat.predicate
]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
"::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
OptionalAttr<Builtin_StringAttr>:$stringLiteral)
> {
let summary = "print operation (for testing and debugging)";
let description = [{
Expand Down Expand Up @@ -2521,6 +2528,13 @@ def Vector_PrintOp :
```mlir
vector.print punctuation <newline>
```

Additionally, to aid with debugging and testing, `vector.print` can also
print constant strings:

```mlir
vector.print str "Hello, World!"
```
}];
let extraClassDeclaration = [{
Type getPrintType() {
Expand All @@ -2529,11 +2543,26 @@ def Vector_PrintOp :
}];
let builders = [
OpBuilder<(ins "PrintPunctuation":$punctuation), [{
build($_builder, $_state, {}, punctuation);
build($_builder, $_state, {}, punctuation, {});
}]>,
OpBuilder<(ins "::mlir::Value":$source), [{
build($_builder, $_state, source, PrintPunctuation::NewLine);
}]>,
OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
build($_builder, $_state, source, punctuation, {});
}]>,
OpBuilder<(ins "::llvm::StringRef":$string), [{
build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
}]>,
];

let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
let assemblyFormat = [{
($source^ `:` type($source))?
oilist(
`str` $stringLiteral
| `punctuation` $punctuation)
attr-dict
}];
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/ExecutionEngine/CRunnerUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printI64(int64_t i);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printU64(uint64_t u);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF32(float f);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printF64(double d);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printString(char const *s);
extern "C" MLIR_CRUNNERUTILS_EXPORT void printOpen();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose();
extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma();
Expand Down
50 changes: 4 additions & 46 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
Expand All @@ -36,51 +37,6 @@ using namespace mlir;

#define PASS_NAME "convert-cf-to-llvm"

/// Returns the first name of the form `assert_msg_<N>` (N = 0, 1, 2, ...) for
/// which `moduleOp.lookupSymbol` finds no existing symbol.
static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
  const std::string prefix = "assert_msg_";
  for (int index = 0;; ++index) {
    std::string candidate = prefix + std::to_string(index);
    // The first candidate that is not yet taken is the answer.
    if (!moduleOp.lookupSymbol(candidate))
      return candidate;
  }
}

/// Emits IR that passes `msg` to the runtime print function declared through
/// `LLVM::lookupOrCreatePrintStrFn`.
///
/// `msg` followed by a terminating zero byte is stored in a private constant
/// global (named by `generateGlobalMsgSymbolName`) created at the start of
/// `moduleOp`'s body. The call itself is emitted at the insertion point the
/// builder had on entry.
static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
                           StringRef msg,
                           const LLVMTypeConverter &typeConverter) {
  // Remember where the caller was emitting code, then move to module scope.
  OpBuilder::InsertPoint callerIp = builder.saveInsertionPoint();
  builder.setInsertionPointToStart(moduleOp.getBody());

  // Payload bytes: the message followed by the zero terminator.
  SmallVector<uint8_t> bytes(msg.begin(), msg.end());
  bytes.push_back(0);

  // Describe the payload both as a dense i8 tensor attribute (the initializer)
  // and as an LLVM array type (the type of the global).
  IntegerType i8Ty = builder.getI8Type();
  const int64_t numBytes = static_cast<int64_t>(bytes.size());
  RankedTensorType tensorTy = RankedTensorType::get({numBytes}, i8Ty);
  DenseElementsAttr initializer =
      DenseElementsAttr::get(tensorTy, llvm::ArrayRef(bytes));
  LLVM::LLVMArrayType arrayTy = LLVM::LLVMArrayType::get(i8Ty, bytes.size());

  std::string globalName = generateGlobalMsgSymbolName(moduleOp);
  LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, globalName,
      initializer);

  // Back at the caller's position: take the address of the global, step to
  // its first byte and hand that pointer to the print function.
  builder.restoreInsertionPoint(callerIp);
  Value globalAddr = builder.create<LLVM::AddressOfOp>(
      loc, typeConverter.getPointerType(arrayTy), global.getName());
  SmallVector<LLVM::GEPArg> zeroIndex(1, 0);
  Value firstByte = builder.create<LLVM::GEPOp>(
      loc, typeConverter.getPointerType(i8Ty), arrayTy, globalAddr, zeroIndex);
  Operation *printFn = LLVM::lookupOrCreatePrintStrFn(
      moduleOp, typeConverter.useOpaquePointers());
  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printFn),
                               firstByte);
}

namespace {
/// Lower `cf.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
Expand All @@ -105,7 +61,9 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {

// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
*getTypeConverter(), /*addNewLine=*/false,
/*runtimeFunctionName=*/"puts");
if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
LoweringOptions.cpp
MemRefBuilder.cpp
Pattern.cpp
PrintCallHelper.cpp
StructBuilder.cpp
TypeConverter.cpp
VectorPattern.cpp
Expand Down
66 changes: 66 additions & 0 deletions mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===- PrintCallHelper.cpp - Helper to emit runtime print calls -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace llvm;

/// Returns a name based on `symbolName` for which `moduleOp.lookupSymbol`
/// finds no existing symbol.
///
/// `symbolName` itself is returned when it is free; otherwise the result is
/// the first free name of the form `<symbolName>_<N>` with N = 0, 1, 2, ...
///
/// The suffix counter is deliberately a plain local rather than a
/// function-local `static`. A static counter is mutable state shared by every
/// caller in the process: incrementing it without synchronization is a data
/// race when modules are lowered on more than one thread, and it makes the
/// chosen suffix depend on whatever was lowered earlier instead of only on the
/// contents of `moduleOp`. The trade-off is that each call probes suffixes
/// starting from 0 again.
static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
                                            StringRef symbolName) {
  const std::string baseName = std::string(symbolName);
  std::string uniqueName = baseName;
  for (int suffix = 0; moduleOp.lookupSymbol(uniqueName); ++suffix)
    uniqueName = baseName + "_" + std::to_string(suffix);
  return uniqueName;
}

/// Emits IR that passes `string` to a runtime function taking a C string.
///
/// The bytes of `string`, followed by a '\n' when `addNewline` is set and
/// always by a terminating '\0', are stored in a private constant global
/// created at the start of `moduleOp`'s body. The global is named after
/// `symbolName`, adjusted by `ensureSymbolNameIsUnique` when needed. The call
/// is emitted at the insertion point the builder had on entry and targets the
/// function returned by `lookupOrCreatePrintStringFn` for
/// `runtimeFunctionName`.
void mlir::LLVM::createPrintStrCall(
    OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
    StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
    std::optional<StringRef> runtimeFunctionName) {
  // Remember where the caller was emitting code, then move to module scope.
  OpBuilder::InsertPoint callerIp = builder.saveInsertionPoint();
  builder.setInsertionPointToStart(moduleOp.getBody());

  // Payload bytes: the string, an optional newline, then the NUL terminator.
  SmallVector<uint8_t> bytes(string.begin(), string.end());
  if (addNewline)
    bytes.push_back('\n');
  bytes.push_back('\0');

  // Describe the payload both as a dense i8 tensor attribute (the initializer)
  // and as an LLVM array type (the type of the global).
  IntegerType i8Ty = builder.getI8Type();
  const int64_t numBytes = static_cast<int64_t>(bytes.size());
  RankedTensorType tensorTy = RankedTensorType::get({numBytes}, i8Ty);
  DenseElementsAttr initializer =
      DenseElementsAttr::get(tensorTy, llvm::ArrayRef(bytes));
  LLVM::LLVMArrayType arrayTy = LLVM::LLVMArrayType::get(i8Ty, bytes.size());

  std::string globalName = ensureSymbolNameIsUnique(moduleOp, symbolName);
  LLVM::GlobalOp global = builder.create<LLVM::GlobalOp>(
      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, globalName,
      initializer);

  // Back at the caller's position: take the address of the global, step to
  // its first byte and hand that pointer to the selected print function.
  builder.restoreInsertionPoint(callerIp);
  Value globalAddr = builder.create<LLVM::AddressOfOp>(
      loc, typeConverter.getPointerType(arrayTy), global.getName());
  SmallVector<LLVM::GEPArg> zeroIndex(1, 0);
  Value firstByte = builder.create<LLVM::GEPOp>(
      loc, typeConverter.getPointerType(i8Ty), arrayTy, globalAddr, zeroIndex);
  Operation *printFn = LLVM::lookupOrCreatePrintStringFn(
      moduleOp, typeConverter.useOpaquePointers(), runtimeFunctionName);
  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printFn),
                               firstByte);
}
6 changes: 5 additions & 1 deletion mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"

#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
Expand Down Expand Up @@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
}

auto punct = printOp.getPunctuation();
if (punct != PrintPunctuation::NoPunctuation) {
if (auto stringLiteral = printOp.getStringLiteral()) {
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter());
} else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
switch (punct) {
case PrintPunctuation::Close:
Expand Down
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ static constexpr llvm::StringRef kPrintF16 = "printF16";
static constexpr llvm::StringRef kPrintBF16 = "printBF16";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintStr = "puts";
static constexpr llvm::StringRef kPrintString = "printString";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
Expand Down Expand Up @@ -107,9 +107,10 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context,
return getCharPtr(context, opaquePointers);
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStrFn(ModuleOp moduleOp,
bool opaquePointers) {
return lookupOrCreateFn(moduleOp, kPrintStr,
/// Resolves the runtime function used to print a C string by delegating to
/// `lookupOrCreateFn`. The function name is `runtimeFunctionName` when one is
/// supplied and `kPrintString` otherwise; the requested signature is a single
/// char-pointer parameter (built by `getCharPtr` according to
/// `opaquePointers`) with a void result.
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
    ModuleOp moduleOp, bool opaquePointers,
    std::optional<StringRef> runtimeFunctionName) {
  MLIRContext *context = moduleOp->getContext();
  StringRef functionName = runtimeFunctionName.value_or(kPrintString);
  return lookupOrCreateFn(moduleOp, functionName,
                          getCharPtr(context, opaquePointers),
                          LLVM::LLVMVoidType::get(context));
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/ExecutionEngine/CRunnerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ extern "C" void printI64(int64_t i) { fprintf(stdout, "%" PRId64, i); }
extern "C" void printU64(uint64_t u) { fprintf(stdout, "%" PRIu64, u); }
extern "C" void printF32(float f) { fprintf(stdout, "%g", f); }
extern "C" void printF64(double d) { fprintf(stdout, "%lg", d); }
/// Writes the C string `s` to stdout with `fputs`; the function itself adds no
/// trailing newline or other characters.
extern "C" void printString(char const *s) { fputs(s, stdout); }
extern "C" void printOpen() { fputs("( ", stdout); }
extern "C" void printClose() { fputs(" )", stdout); }
extern "C" void printComma() { fputs(", ", stdout); }
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/ExecutionEngine/RunnerUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ extern "C" void printMemrefC64(int64_t rank, void *ptr) {
_mlir_ciface_printMemrefC64(&descriptor);
}

extern "C" void printCString(char *str) { printf("%s", str); }
/// Deprecated. This should be unified with printString from CRunnerUtils.
/// Writes the C string `str` to stdout with `fputs`.
extern "C" void printCString(char *str) { fputs(str, stdout); }

extern "C" void _mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *M) {
impl::printMemRef(*M);
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,20 @@ func.func @vector_print_scalar_f64(%arg0: f64) {

// -----

// CHECK-LABEL: module {
// CHECK: llvm.func @printString(!llvm.ptr)
// CHECK: llvm.mlir.global private constant @[[GLOBAL_STR:.*]]({{.*}})
// CHECK: @vector_print_string
// CHECK-NEXT: %[[GLOBAL_ADDR:.*]] = llvm.mlir.addressof @[[GLOBAL_STR]] : !llvm.ptr
// CHECK-NEXT: %[[STR_PTR:.*]] = llvm.getelementptr %[[GLOBAL_ADDR]][0] : (!llvm.ptr) -> !llvm.ptr
// CHECK-NEXT: llvm.call @printString(%[[STR_PTR]]) : (!llvm.ptr) -> ()
func.func @vector_print_string() {
vector.print str "Hello, World!"
return
}

// -----

func.func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
%0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
return %0 : vector<2xf32>
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,22 @@ func.func private @print_needs_vector(%arg0: tensor<8xf32>) {

// -----

func.func @cannot_print_string_with_punctuation_set() {
// expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
vector.print str "Whoops!" punctuation <comma>
return
}

// -----

func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
// expected-error@+1 {{`source` or `punctuation` are not set when printing strings}}
vector.print %vec: vector<[4]xf32> str "Yay!"
return
}

// -----

func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
Expand Down

0 comments on commit 3be3883

Please sign in to comment.