Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][SVE] Add an e2e test for vectorization of linalg.matmul #69592

Merged
merged 4 commits into from Oct 26, 2023

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Oct 19, 2023

Adds an end-2-end test for scalable vectorization of linalg.matmul. This
is the most basic case where the dimension along which we vectorize fits
perfectly within SVE registers. I will be extending this to more generic
cases in the forthcoming patches.

Depends on #68794.

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sve

@llvm/pr-subscribers-mlir-llvm

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][VectorOps] Support string literals in vector.print
  • [mlir][ArmSVE] Add -arm-sve-legalize-vector-storage pass
  • [mlir][SVE] Add an e2e test for vectorization of linalg.matmul

Patch is 51.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69592.diff

18 Files Affected:

  • (added) mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h (+36)
  • (modified) mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt (+1)
  • (added) mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt (+5)
  • (added) mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h (+33)
  • (added) mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td (+67)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+33-4)
  • (modified) mlir/include/mlir/InitAllPasses.h (+2)
  • (modified) mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp (+3-46)
  • (modified) mlir/lib/Conversion/LLVMCommon/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp (+66)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+5-1)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp (+310)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+14)
  • (added) mlir/test/Dialect/ArmSVE/legalize-vector-storage.mlir (+160)
  • (added) mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir (+77)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/arrays-of-scalable-vectors.mlir (+121)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/test-hello-world.mlir (+10)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
new file mode 100644
index 000000000000000..7e26858589f2756
--- /dev/null
+++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
@@ -0,0 +1,36 @@
+
+//===- PrintCallHelper.h - LLVM Interfaces ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+#define MLIR_DIALECT_LLVMIR_PRINTCALLHELPER_H_
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+
+class Location;
+class ModuleOp;
+class OpBuilder;
+class Operation;
+class Type;
+class ValueRange;
+class LLVMTypeConverter;
+
+namespace LLVM {
+
+/// Generate IR that prints the given string to stdout.
+void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
+                        StringRef symbolName, StringRef string,
+                        const LLVMTypeConverter &typeConverter);
+} // namespace LLVM
+
+} // namespace mlir
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
index f33061b2d87cffc..9f57627c321fb0c 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSVE/CMakeLists.txt
@@ -1 +1,2 @@
 add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
new file mode 100644
index 000000000000000..7226642daf86172
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name ArmSVE)
+add_public_tablegen_target(MLIRArmSVEPassIncGen)
+
+add_mlir_doc(Passes ArmSVEPasses ./ -gen-pass-doc)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
new file mode 100644
index 000000000000000..317fb9021b3c577
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.h
@@ -0,0 +1,33 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir::arm_sve {
+
+#define GEN_PASS_DECL
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+/// Pass to legalize the types of mask stores.
+std::unique_ptr<Pass> createLegalizeVectorStoragePass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h.inc"
+
+} // namespace mlir::arm_sve
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_H
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
new file mode 100644
index 000000000000000..35c49607181da0c
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Passes.td
@@ -0,0 +1,67 @@
+//===-- Passes.td - ArmSVE pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+#define MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
+
+include "mlir/Pass/PassBase.td"
+
+def LegalizeVectorStorage
+    : Pass<"arm-sve-legalize-vector-storage", "mlir::func::FuncOp"> {
+  let summary = "Ensures stores of SVE vector types will be legal";
+  let description = [{
+    This pass ensures that loads, stores, and allocations of SVE vector types
+    will be legal in the LLVM backend. It does this at the memref level, so this
+    pass must be applied before lowering all the way to LLVM.
+
+    This pass currently fixes two issues.
+
+    ## Loading and storing predicate types
+
+    It is only legal to load/store predicate types equal to (or greater than) a
+    full predicate register, which in MLIR is `vector<[16]xi1>`. Smaller
+    predicate types (`vector<[1|2|4|8]xi1>`) must be converted to/from a full
+    predicate type (referred to as a `svbool`) before and after storing and
+    loading respectively. This pass does this by widening allocations and
+    inserting conversion intrinsics.
+
+    For example:
+
+    ```mlir
+    %alloca = memref.alloca() : memref<vector<[4]xi1>>
+    %mask = vector.constant_mask [4] : vector<[4]xi1>
+    memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+    %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+    ```
+    Becomes:
+    ```mlir
+    %alloca = memref.alloca() {alignment = 1 : i64} : memref<vector<[16]xi1>>
+    %mask = vector.constant_mask [4] : vector<[4]xi1>
+    %svbool = arm_sve.convert_to_svbool %mask : vector<[4]xi1>
+    memref.store %svbool, %alloca[] : memref<vector<[16]xi1>>
+    %reload_svbool = memref.load %alloca[] : memref<vector<[16]xi1>>
+    %reload = arm_sve.convert_from_svbool %reload_svbool : vector<[4]xi1>
+    ```
+
+    ## Relax alignments for SVE vector allocas
+
+    The storage for SVE vector types only needs to have an alignment that
+    matches the element type (for example 4 byte alignment for `f32`s). However,
+    the LLVM backend currently defaults to aligning to `base size` x
+    `element size` bytes. For non-legal vector types like `vector<[8]xf32>` this
+    results in 8 x 4 = 32-byte alignment, but the backend only supports up to
+    16-byte alignment for SVE vectors on the stack. Explicitly setting a smaller
+    alignment prevents this issue.
+  }];
+  let constructor = "mlir::arm_sve::createLegalizeVectorStoragePass()";
+  let dependentDialects = ["func::FuncDialect",
+    "memref::MemRefDialect", "vector::VectorDialect",
+    "arm_sve::ArmSVEDialect"];
+}
+
+#endif // MLIR_DIALECT_ARMSVE_TRANSFORMS_PASSES_TD
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 2df2fe4c5ce8e9c..2b60055ca9db94b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -26,6 +26,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/IR/BuiltinAttributes.td"
 
 // TODO: Add an attribute to specify a different algebra with operators other
 // than the current set: {*, +}.
@@ -2476,12 +2477,18 @@ def Vector_TransposeOp :
 }
 
 def Vector_PrintOp :
-  Vector_Op<"print", []>,
+  Vector_Op<"print", [
+    PredOpTrait<
+      "`source` or `punctuation` are not set printing strings",
+      CPred<"!getStringLiteral() || (!getSource() && getPunctuation() == PrintPunctuation::NewLine)">
+    >,
+  ]>,
   Arguments<(ins Optional<Type<Or<[
     AnyVectorOfAnyRank.predicate,
     AnyInteger.predicate, Index.predicate, AnyFloat.predicate
   ]>>>:$source, DefaultValuedAttr<Vector_PrintPunctuation,
-                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation)
+                      "::mlir::vector::PrintPunctuation::NewLine">:$punctuation,
+                OptionalAttr<Builtin_StringAttr>:$stringLiteral)
   > {
   let summary = "print operation (for testing and debugging)";
   let description = [{
@@ -2520,6 +2527,13 @@ def Vector_PrintOp :
     ```mlir
     vector.print punctuation <newline>
     ```
+
+    Additionally, to aid with debugging and testing `vector.print` can also
+    print constant strings:
+
+    ```mlir
+    vector.print str "Hello, World!"
+    ```
   }];
   let extraClassDeclaration = [{
     Type getPrintType() {
@@ -2528,11 +2542,26 @@ def Vector_PrintOp :
   }];
   let builders = [
     OpBuilder<(ins "PrintPunctuation":$punctuation), [{
-      build($_builder, $_state, {}, punctuation);
+      build($_builder, $_state, {}, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source), [{
+      build($_builder, $_state, source, PrintPunctuation::NewLine);
+    }]>,
+    OpBuilder<(ins "::mlir::Value":$source, "PrintPunctuation":$punctuation), [{
+      build($_builder, $_state, source, punctuation, {});
+    }]>,
+    OpBuilder<(ins "::llvm::StringRef":$string), [{
+      build($_builder, $_state, {}, PrintPunctuation::NewLine, $_builder.getStringAttr(string));
     }]>,
   ];
 
-  let assemblyFormat = "($source^ `:` type($source))? (`punctuation` $punctuation^)? attr-dict";
+  let assemblyFormat = [{
+      ($source^ `:` type($source))?
+        oilist(
+            `str` $stringLiteral
+          | `punctuation` $punctuation)
+        attr-dict
+    }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
index 5489a13a8040bdb..7301905954f56d8 100644
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/Passes.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
 #include "mlir/Dialect/Async/Passes.h"
 #include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
 #include "mlir/Dialect/Bufferization/Transforms/Passes.h"
@@ -82,6 +83,7 @@ inline void registerAllPasses() {
   transform::registerTransformPasses();
   vector::registerVectorPasses();
   arm_sme::registerArmSMEPasses();
+  arm_sve::registerArmSVEPasses();
 
   // Dialect pipelines
   bufferization::registerBufferizationPipelines();
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index a4f146bbe475cc6..6b7647b038f1d94 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
@@ -36,51 +37,6 @@ using namespace mlir;
 
 #define PASS_NAME "convert-cf-to-llvm"
 
-static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
-  std::string prefix = "assert_msg_";
-  int counter = 0;
-  while (moduleOp.lookupSymbol(prefix + std::to_string(counter)))
-    ++counter;
-  return prefix + std::to_string(counter);
-}
-
-/// Generate IR that prints the given string to stderr.
-static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg,
-                           const LLVMTypeConverter &typeConverter) {
-  auto ip = builder.saveInsertionPoint();
-  builder.setInsertionPointToStart(moduleOp.getBody());
-  MLIRContext *ctx = builder.getContext();
-
-  // Create a zero-terminated byte representation and allocate global symbol.
-  SmallVector<uint8_t> elementVals;
-  elementVals.append(msg.begin(), msg.end());
-  elementVals.push_back(0);
-  auto dataAttrType = RankedTensorType::get(
-      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
-  auto dataAttr =
-      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
-  auto arrayTy =
-      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
-  std::string symbolName = generateGlobalMsgSymbolName(moduleOp);
-  auto globalOp = builder.create<LLVM::GlobalOp>(
-      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private, symbolName,
-      dataAttr);
-
-  // Emit call to `printStr` in runtime library.
-  builder.restoreInsertionPoint(ip);
-  auto msgAddr = builder.create<LLVM::AddressOfOp>(
-      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
-  SmallVector<LLVM::GEPArg> indices(1, 0);
-  Value gep = builder.create<LLVM::GEPOp>(
-      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
-      indices);
-  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
-      moduleOp, typeConverter.useOpaquePointers());
-  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
-                               gep);
-}
-
 namespace {
 /// Lower `cf.assert`. The default lowering calls the `abort` function if the
 /// assertion is violated and has no effect otherwise. The failure message is
@@ -105,7 +61,8 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
+    LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
+                             *getTypeConverter());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 091cd539f0ae014..568d9339aaabcb4 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
   LoweringOptions.cpp
   MemRefBuilder.cpp
   Pattern.cpp
+  PrintCallHelper.cpp
   StructBuilder.cpp
   TypeConverter.cpp
   VectorPattern.cpp
diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
new file mode 100644
index 000000000000000..487abb435d10ad7
--- /dev/null
+++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
@@ -0,0 +1,66 @@
+
+//===- PrintCallHelper.cpp - LLVM Interfaces --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/ArrayRef.h"
+
+using namespace mlir;
+using namespace llvm;
+
+static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
+                                            StringRef symbolName) {
+  static int counter = 0;
+  std::string uniqueName = std::string(symbolName);
+  while (moduleOp.lookupSymbol(uniqueName)) {
+    uniqueName = std::string(symbolName) + "_" + std::to_string(counter++);
+  }
+  return uniqueName;
+}
+
+void mlir::LLVM::createPrintStrCall(OpBuilder &builder, Location loc,
+                                    ModuleOp moduleOp, StringRef symbolName,
+                                    StringRef string,
+                                    const LLVMTypeConverter &typeConverter) {
+  auto ip = builder.saveInsertionPoint();
+  builder.setInsertionPointToStart(moduleOp.getBody());
+  MLIRContext *ctx = builder.getContext();
+
+  // Create a zero-terminated byte representation and allocate global symbol.
+  SmallVector<uint8_t> elementVals;
+  elementVals.append(string.begin(), string.end());
+  elementVals.push_back(0);
+  auto dataAttrType = RankedTensorType::get(
+      {static_cast<int64_t>(elementVals.size())}, builder.getI8Type());
+  auto dataAttr =
+      DenseElementsAttr::get(dataAttrType, llvm::ArrayRef(elementVals));
+  auto arrayTy =
+      LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), elementVals.size());
+  auto globalOp = builder.create<LLVM::GlobalOp>(
+      loc, arrayTy, /*constant=*/true, LLVM::Linkage::Private,
+      ensureSymbolNameIsUnique(moduleOp, symbolName), dataAttr);
+
+  // Emit call to `printStr` in runtime library.
+  builder.restoreInsertionPoint(ip);
+  auto msgAddr = builder.create<LLVM::AddressOfOp>(
+      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
+  SmallVector<LLVM::GEPArg> indices(1, 0);
+  Value gep = builder.create<LLVM::GEPOp>(
+      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+      indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+      moduleOp, typeConverter.useOpaquePointers());
+  builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
+                               gep);
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 8427d60f14c0bcc..4af58653c8227ae 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 
 #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
+#include "mlir/Conversion/LLVMCommon/PrintCallHelper.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -1548,7 +1549,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
     }
 
     auto punct = printOp.getPunctuation();
-    if (punct != PrintPunctuation::NoPunctuation) {
+    if (auto stringLiteral = printOp.getStringLiteral()) {
+      LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
+                               *stringLiteral, *getTypeConverter());
+    } else if (punct != PrintPunctuation::NoPunctuation) {
       emitCall(rewriter, printOp->getLoc(), [&] {
         switch (punct) {
         case PrintPunctuation::Close:
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index 2f1c43fae240d51..a70c489a51fea9a 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_mlir_dialect_library(MLIRArmSVETransforms
   LegalizeForLLVMExport.cpp
+  LegalizeVectorStorage.cpp
 
   DEPENDS
   MLIRArmSVEConversionsIncGen
+  MLIRArmSVEPassIncGen
 
   LINK_LIBS PUBLIC
   MLIRArmSVEDialect
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
new file mode 100644
index 000000000000000..610eb38089c4c88
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -0,0 +1,310 @@
+//===- LegalizeVectorStorage.cpp - Ensures SVE loads/stores are legal -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#i...
[truncated]

@banach-space banach-space changed the title andrzej/add e2e scalable matmul [mlir][SVE] Add an e2e test for vectorization of linalg.matmul Oct 19, 2023
@banach-space
Copy link
Contributor Author

NOTE: Only the latest commit belongs to this PR.

// Hence, when checking the outupt there will always be at least 4 elements
// in every row. For implementations with wider vectors, you should see more
// elements being printed.
// CHECK: [9.8596, 9.8596, 9.8596, 9.8596
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// CHECK: [9.8596, 9.8596, 9.8596, 9.8596
// CHECK-NEXT: [9.8596, 9.8596, 9.8596, 9.8596

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual output is:

SVE: START OF TEST OUTPUT
Unranked Memref base@ = 0x5500237580 rank = 2 offset = 0 sizes = [2, 16] strides = [16, 1] data =
[[9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596],
 [9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596,   9.8596]]
SVE: END OF TEST OUTPUT

Not sure it's worth checking the 2nd line?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently it's doing:

// CHECK-LABEL: 
// CHECK:
// CHECK-NEXT:

But you could do:

// CHECK-LABEL: 
// CHECK-NEXT:
// CHECK-NEXT:

And that'd work too, and be nice an aligned.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻 Updated, will land shortly.

@banach-space banach-space force-pushed the andrzej/add_e2e_scalable_matmul branch from 395644a to c75447f Compare October 24, 2023 14:21
Adds an end-2-end test for scalable vectorization of linalg.matmul. This
is the most basic case where the dimension along which we vectorize fits
perfectly within SVE registers. I will be extending this to more generic
cases in the forthcoming patches.

Depends on llvm#68794.
@banach-space banach-space force-pushed the andrzej/add_e2e_scalable_matmul branch from c75447f to 2b08169 Compare October 26, 2023 11:51
@banach-space banach-space merged commit 64025b8 into llvm:main Oct 26, 2023
2 of 3 checks passed
@banach-space banach-space deleted the andrzej/add_e2e_scalable_matmul branch October 26, 2023 12:16
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants