Skip to content

[mlir][x86] Move AMX dialect into X86 dialect#183717

Merged
adam-smnk merged 1 commit into
llvm:mainfrom
adam-smnk:mlir-unify-x86-dialects
Mar 2, 2026
Merged

[mlir][x86] Move AMX dialect into X86 dialect#183717
adam-smnk merged 1 commit into
llvm:mainfrom
adam-smnk:mlir-unify-x86-dialects

Conversation

@adam-smnk
Copy link
Copy Markdown
Member

@adam-smnk adam-smnk commented Feb 27, 2026

Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions.

Following the dialect renaming, X86 dialect is now a suitable home for wider range of operations targeting specific hardware features. Moving AMX definitions to X86 dialect creates a single, centralized hub for defining all x86 intrinsic-like operations. The new grouping aims to eliminate the need for new dialects as new hardware extensions become available.

The two dialects are simply merged together. X86 dialect refactoring will be addressed separately.

List of changes:

  • operations: 'amx.tile_' => 'x86.amx.tile_'
  • types: '!amx.tile' => '!x86.amx.tile'
  • namespace: 'mlir::amx' => 'mlir::x86::amx'
  • test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  • vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.

Unifies the two dialects that define x86 operations into a single one.
The AMX dialect is moved into x86 in line with other x86 extensions.

The two dialects are simply merged together. X86 dialect refactoring
will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX
integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 27, 2026

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Adam Siemieniuk (adam-smnk)

Changes

Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions.

The two dialects are simply merged together. X86 dialect refactoring will be addressed separately.

List of changes:

  • operations: 'amx.tile_' => 'x86.amx.tile_'
  • types: '!amx.tile' => '!x86.amx.tile'
  • namespace: 'mlir::amx' => 'mlir::x86::amx'
  • test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  • vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.


Patch is 163.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/183717.diff

56 Files Affected:

  • (modified) mlir/Maintainers.md (-1)
  • (modified) mlir/docs/TargetLLVMIR.md (+2-2)
  • (removed) mlir/include/mlir-c/Dialect/AMX.h (-25)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+5-9)
  • (modified) mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h (+2-2)
  • (removed) mlir/include/mlir/Dialect/AMX/AMX.td (-440)
  • (removed) mlir/include/mlir/Dialect/AMX/AMXDialect.h (-34)
  • (removed) mlir/include/mlir/Dialect/AMX/AMXInterfaces.td (-31)
  • (removed) mlir/include/mlir/Dialect/AMX/CMakeLists.txt (-5)
  • (removed) mlir/include/mlir/Dialect/AMX/Transforms.h (-33)
  • (modified) mlir/include/mlir/Dialect/CMakeLists.txt (-1)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (-5)
  • (modified) mlir/include/mlir/Dialect/X86/Transforms.h (+5-2)
  • (modified) mlir/include/mlir/Dialect/X86/X86.td (+384)
  • (modified) mlir/include/mlir/Dialect/X86/X86Dialect.h (+13)
  • (removed) mlir/lib/CAPI/Dialect/AMX.cpp (-13)
  • (modified) mlir/lib/CAPI/Dialect/CMakeLists.txt (-9)
  • (modified) mlir/lib/Conversion/VectorToAMX/CMakeLists.txt (+1-1)
  • (modified) mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp (+32-30)
  • (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (-2)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (-8)
  • (removed) mlir/lib/Dialect/AMX/CMakeLists.txt (-2)
  • (removed) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (-318)
  • (removed) mlir/lib/Dialect/AMX/IR/CMakeLists.txt (-15)
  • (removed) mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt (-9)
  • (removed) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (-70)
  • (modified) mlir/lib/Dialect/CMakeLists.txt (-1)
  • (modified) mlir/lib/Dialect/X86/IR/X86Dialect.cpp (+285)
  • (modified) mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp (+24-1)
  • (modified) mlir/lib/RegisterAllDialects.cpp (-2)
  • (modified) mlir/lib/RegisterAllExtensions.cpp (+2-2)
  • (modified) mlir/test/CMakeLists.txt (+2-2)
  • (modified) mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir (+14-14)
  • (modified) mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir (+9-9)
  • (modified) mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir (-1)
  • (removed) mlir/test/Dialect/AMX/invalid.mlir (-158)
  • (removed) mlir/test/Dialect/AMX/roundtrip.mlir (-77)
  • (removed) mlir/test/Dialect/AMX/side-effects.mlir (-32)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+9-9)
  • (added) mlir/test/Dialect/X86/AMX/invalid.mlir (+158)
  • (renamed) mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir (+34-34)
  • (added) mlir/test/Dialect/X86/AMX/roundtrip.mlir (+77)
  • (added) mlir/test/Dialect/X86/AMX/side-effects.mlir (+32)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg (+1-1)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir (+6-6)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir (+11-11)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir (+21-21)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir (+6-6)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir (+11-11)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir (+3-3)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir (+3-3)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir (+4-3)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir (+11-13)
  • (modified) mlir/test/Target/LLVMIR/amx.mlir (+42-42)
  • (modified) mlir/test/lit.site.cfg.py.in (+1-1)
  • (modified) mlir/test/mlir-opt/commandline.mlir (-1)
diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md
index a023ee0ea1bba..181541a0f3a93 100644
--- a/mlir/Maintainers.md
+++ b/mlir/Maintainers.md
@@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas.
 * ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space))
 * ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
 * ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
-* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
 
diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
index 2bdf400a7759f..2bcacbb4ee946 100644
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -5,8 +5,8 @@ overall flow is two-stage:
 
 1.  **conversion** of the IR to a set of dialects translatable to LLVM IR, for
     example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
-    dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
-    [X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
+    dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md)
+    or [ArmNeon](Dialects/ArmNeon.md);
 2.  **translation** of MLIR dialects to LLVM IR.
 
 This flow allows the non-trivial transformation to be performed within MLIR
diff --git a/mlir/include/mlir-c/Dialect/AMX.h b/mlir/include/mlir-c/Dialect/AMX.h
deleted file mode 100644
index ac4695a107ae6..0000000000000
--- a/mlir/include/mlir-c/Dialect/AMX.h
+++ /dev/null
@@ -1,25 +0,0 @@
-//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM
-// Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_C_DIALECT_AMX_H
-#define MLIR_C_DIALECT_AMX_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // MLIR_C_DIALECT_AMX_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ecc22abb0f935..e77860897399f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
-    architectural-neutral vector dialect lowering.
+    (X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral
+    vector dialect lowering.
 
   }];
   // Override explicitly in C++ to allow conditional dialect dependence.
@@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "vector access are naturally aligned. If operations have an "
            "alignment attribute set, the alignment attribute takes priority "
            "over this option ">,
-    Option<"amx", "enable-amx",
-           "bool", /*default=*/"false",
-           "Enables the use of AMX dialect while lowering the vector "
-	   "dialect.">,
     Option<"armNeon", "enable-arm-neon",
            "bool", /*default=*/"false",
            "Enables the use of ArmNeon dialect while lowering the vector "
@@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
-  let summary = "Lower the operations from the vector dialect into the AMX "
-                "dialect";
+  let summary = "Lower the operations from the vector dialect into the X86 "
+                "dialect AMX operations";
   let dependentDialects = [
-    "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+    "affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect",
     "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
   ];
 }
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
index b075ac92990a2..6b178e02684c0 100644
--- a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -1,4 +1,4 @@
-//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -18,7 +18,7 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTVECTORTOAMX
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Collect a set of patterns to convert from the vector to AMX ops.
+/// Collect a set of patterns to convert from the vector to X86 AMX ops.
 void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
deleted file mode 100644
index cace63d32fd80..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ /dev/null
@@ -1,440 +0,0 @@
-//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the basic operations for the AMX dialect.
-//
-// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-// multiply unit (TMUL), a tile control register (TILECFG), and eight
-// tile registers TMM0 through TMM7 (TILEDATA).
-//
-// The AMX dialect provides a bridge between MLIR concepts, such as
-// 2-d vector, operations, and memrefs, and the lower level details
-// of Intel AMX, such as configuration setup, tile sizes, instructions,
-// and tile release.
-//
-// Note that since configuration changes (implicit at dialect level) are
-// costly, it is highly recommended to use the AMX dialect on same-shaped
-// vectors, at least within a single method.
-//
-// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX
-#define AMX
-
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/Dialect/AMX/AMXInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/BuiltinTypes.td"
-
-//===----------------------------------------------------------------------===//
-// AMX dialect definition.
-//===----------------------------------------------------------------------===//
-
-def AMX_Dialect : Dialect {
-  let name = "amx";
-  let cppNamespace = "::mlir::amx";
-  let description = [{
-    The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-    multiply unit (TMUL), a tile control register (TILECFG), and eight
-    tile registers TMM0 through TMM7 (TILEDATA).
-
-    This `AMX` dialect provides a bridge between MLIR concepts such as
-    vectors and memrefs and the lower level LLVM IR support of AMX.
-
-    Note that since configuration changes (implicit at dialect level) are
-    costly, it is highly recommended to use the AMX dialect on same-shaped
-    vectors, at least within a single method.
-
-    For details, see the Intel documentation:
-    https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-  }];
-  let useDefaultTypePrinterParser = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// AMX Tile definition.
-//===----------------------------------------------------------------------===//
-
-class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
-    : TypeDef<AMX_Dialect, typeName, traits> {
-  let mnemonic = typeMnemonic;
-}
-
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
-  let cppFunctionName = "isValidTileTypeElementType";
-}
-
-def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
-  let summary = "AMX 2D tile to be used by AMX opertaions.";
-
-  let description = [{
-    This type is used to represent values in AMX tile registers. All AMX operations
-    work on AMX tiles and these tiles cannot be used in other operations directly.
-    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
-    element type for IR verification and lowering to LLVMIR dialect.
-  }];
-
-  let parameters = (ins
-    ArrayRefParameter<"int64_t">:$shape,
-    AMX_TileTypeElementType:$elementType
-  );
-
-  let builders = [
-    TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
-      return $_get(elementType.getContext(), shape, elementType);
-    }]>
-  ];
-
-  let extraClassDeclaration = [{
-    /// Returns if this type is ranked (always true).
-    bool hasRank() const { return true; }
-
-    /// Clone this tile type with the given shape and element type. If the
-    /// provided shape is `std::nullopt`, the current shape of the type is used.
-    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const {
-      return get(shape.value_or(getShape()), elementType);
-    }
-  }];
-
-  let hasCustomAssemblyFormat = 1;
-  let skipDefaultBuilders = 1;
-}
-
-def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
-  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
-
-class AMXTileOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
-                      "::mlir::amx::TileType">;
-
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
-
-def AMXTileF32 : AMXTileOf<[F32]>;
-
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
-
-def AMXTileI32 : AMXTileOf<[I32]>;
-
-def AMXTileI8 : AMXTileOf<[I8]>;
-
-//===----------------------------------------------------------------------===//
-// AMX Op and IntrOp definitions.
-//===----------------------------------------------------------------------===//
-
-class AMX_Op<string mnemonic, list<Trait> traits = []> :
-  Op<AMX_Dialect, mnemonic, traits> {}
-
-//===----------------------------------------------------------------------===//
-// AMX Op definitions
-//===----------------------------------------------------------------------===//
-
-//
-// Tile reset.
-//
-
-def TileZeroOp : AMX_Op<"tile_zero", [
-    AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>
-  ]> {
-  let summary = "tile zero operation";
-  let description = [{
-    Zeroes the destination tile, with the shape defined by the 2-dim
-    vector type of the result.
-    
-    The operation is eventually lowered into the "tilezero" instruction
-    with the corresponding tile configuration.
-    
-    With the write memory effect, each `amx.tile_zero` operation serves as
-    a compilation hint to use a separate tile register.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
-    ```
-  }];
-  let results = (outs AnyAMXTile:$res);
-  let extraClassDeclaration = [{
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tilezero.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "attr-dict `:` qualified(type($res))";
-  let hasVerifier = 1;
-}
-
-//
-// Tile memory operations.
-//
-
-def TileLoadOp : AMX_Op<"tile_load", [
-    AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>,
-    AttrSizedOperandSegments
-  ]> {
-  let summary = "tile load operation";
-  let description = [{
-    Loads a tile from memory defined by a `base` and `indices`, with the
-    shape defined by the 2-dim vector type of the result.
-    The tile's rows are populated by reading contiguous elements starting
-    at the `base`. For each tile row, the `base` is incremented by `stride`
-    number of elements.
-
-    The tile is loaded using the following indexing scheme:
-
-    ```
-    for row in enumerate(tile_rows):
-      mem_row = base[i0, i1, ..., iN + row * stride]
-      for col in enumerate(tile_cols):
-        tile[row, col] = mem_row[col]
-    ```
-
-    If the `stride` is not provided, then the `base` buffer must be at least
-    2-dimensional, and the `stride` is automatically inferred and corresponds
-    to the stride of the buffer's second innermost dimension.
-
-    The operation is eventually lowered into the "tileloadd" instruction
-    with the corresponding tile configuration.
-
-    With the write memory effect, each `amx.tile_load` operation serves as
-    a compilation hint to use a separate tile register.
-
-    Example:
-
-    ```mlir
-      // Tile load from a 2-D memref with implicit stride.
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-
-      // Tile load from a 1-D memref with explicit stride.
-      %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
-    ```
-  }];
-  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
-                   Variadic<Index>:$indices,
-                   Optional<Index>:$stride);
-  let results = (outs AnyAMXTile:$res);
-  let builders = [
-    OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
-  ];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return ::llvm::cast<MemRefType>(getBase().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tileloadd64.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
-                       "`:` type($base) `into` qualified(type($res))";
-  let hasVerifier = 1;
-}
-
-def TileStoreOp : AMX_Op<"tile_store", [
-    AMXIntrinsicOpInterface,
-    AttrSizedOperandSegments
-  ]> {
-  let summary = "tile store operation";
-  let description = [{
-    Stores a tile to memory defined by a `base` and `indices`, with the
-    shape defined by the 2-dim vector type of the value.
-    The tile's rows are written contiguously to the buffer starting at
-    the `base`. For each tile row, the `base` is incremented by `stride`
-    number of elements.
-
-    The tile is stored using the following indexing scheme:
-
-    ```
-    for row in enumerate(tile_rows):
-      mem_row = base[i0, i1, ..., iN + row * stride]
-      for col in enumerate(tile_cols):
-        mem_row[col] = tile[row, col]
-    ```
-
-    If the `stride` is not provided, then the `base` buffer must be at least
-    2-dimensional, and the `stride` is automatically inferred and corresponds
-    to the stride of the buffer's second innermost dimension.
-
-    The operation is eventually lowered into the "tilestored" instruction
-    with the corresponding tile configuration.
-
-    Example:
-
-    ```mlir
-      // Tile store to a 2-D memref with implicit stride.
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
-
-      // Tile store to a 1-D memref with explicit stride.
-      amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
-    ```
-  }];
-  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
-                   Variadic<Index>:$indices,
-                   AnyAMXTile:$val,
-                   Optional<Index>:$stride);
-  let builders = [
-    OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
-  ];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return ::llvm::cast<MemRefType>(getBase().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getVal().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tilestored64.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
-                       "attr-dict `:` type($base) `,` qualified(type($val))";
-  let hasVerifier = 1;
-}
-
-//
-// Tile arithmetic operations.
-//
-
-def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
-    AMXIntrinsicOpInterface,
-    AllTypesMatch<["acc", "res"]>
-  ]> {
-  let summary = "tile multiplication operation (floating-point)";
-  let description = [{
-    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
-    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
-    pairs of "bf16").
-    
-    The operation is eventually lowered into the "tdpbf16ps" instruction with
-    the corresponding tile configuration.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_mulf %a, %b, %c
-        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-    ```
-  }];
-  let arguments = (ins AMXTileF16OrBF16:$lhs,
-                       AMXTileF16OrBF16:$rhs,
-                       AMXTileF32:$acc);
-  let results = (outs AMXTileF32:$res);
-  let extraClassDeclaration = [{
-    TileType getLhsTileType() {
-      return ::llvm::cast<TileType>(getLhs().getType());
-    }
-    TileType getRhsTileType() {
-      return ::llvm::cast<TileType>(getRhs().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      std::string intr = "llvm.x86.tdp";
-      auto elementType =
-        getLhsTileType().getElementType();
-      intr += elementType.isF16() ? "fp16" : "bf16";
-      intr += "ps.internal";
-      return intr;
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
-                       "qualified(type($lhs)) `,` qualified(type($rhs))"
-                       " `,` qualified(type($acc)) ";
-  let hasVerifier = 1;
-}
-
-def TileMulIOp : AMX_Op<"tile_muli", [Pure,
-    AMXIntrinsicOpInterface,
-    AllTypesMatch<["acc", "res"]>
-  ]> {
-  let summary = "tile multiplication operation (integer)";
-  let description = [{
-    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
-    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
-    combinations (4 bytes packed into dwords in the columns of both the
-    source operand tiles; the zero or sign extension is specified with
-    the attributes and default to sign extended).
-    
-    The operation is eventually lowered into one of the "tdpbssd",
-    "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
-    tile configuration.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c
-        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-    ```
-  }];
-  let arguments = (ins AMXTileI8:$lhs,
-                       AMXTileI8:$rhs,
-                       AMXTileI32:$acc,
-                       UnitAttr:$isZextLhs,
-                       UnitAttr:$isZextRhs
-                       );
-  let results = (outs AMXTileI32:$res);
-  let extraClassDeclaration = [{
-    TileType getLhsTileType() {
-      return ::llvm::cast<TileType>(getLhs().getType());
-    }
-    TileType getRhsT...
[truncated]

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Feb 27, 2026

@llvm/pr-subscribers-mlir-amx

Author: Adam Siemieniuk (adam-smnk)

Changes

Unifies the two dialects that define x86 operations into a single one. The AMX dialect is moved into X86 in line with other x86 extensions.

The two dialects are simply merged together. X86 dialect refactoring will be addressed separately.

List of changes:

  • operations: 'amx.tile_' => 'x86.amx.tile_'
  • types: '!amx.tile' => '!x86.amx.tile'
  • namespace: 'mlir::amx' => 'mlir::x86::amx'
  • test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  • vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.


Patch is 163.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/183717.diff

56 Files Affected:

  • (modified) mlir/Maintainers.md (-1)
  • (modified) mlir/docs/TargetLLVMIR.md (+2-2)
  • (removed) mlir/include/mlir-c/Dialect/AMX.h (-25)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+5-9)
  • (modified) mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h (+2-2)
  • (removed) mlir/include/mlir/Dialect/AMX/AMX.td (-440)
  • (removed) mlir/include/mlir/Dialect/AMX/AMXDialect.h (-34)
  • (removed) mlir/include/mlir/Dialect/AMX/AMXInterfaces.td (-31)
  • (removed) mlir/include/mlir/Dialect/AMX/CMakeLists.txt (-5)
  • (removed) mlir/include/mlir/Dialect/AMX/Transforms.h (-33)
  • (modified) mlir/include/mlir/Dialect/CMakeLists.txt (-1)
  • (modified) mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h (-5)
  • (modified) mlir/include/mlir/Dialect/X86/Transforms.h (+5-2)
  • (modified) mlir/include/mlir/Dialect/X86/X86.td (+384)
  • (modified) mlir/include/mlir/Dialect/X86/X86Dialect.h (+13)
  • (removed) mlir/lib/CAPI/Dialect/AMX.cpp (-13)
  • (modified) mlir/lib/CAPI/Dialect/CMakeLists.txt (-9)
  • (modified) mlir/lib/Conversion/VectorToAMX/CMakeLists.txt (+1-1)
  • (modified) mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp (+32-30)
  • (modified) mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt (-2)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (-8)
  • (removed) mlir/lib/Dialect/AMX/CMakeLists.txt (-2)
  • (removed) mlir/lib/Dialect/AMX/IR/AMXDialect.cpp (-318)
  • (removed) mlir/lib/Dialect/AMX/IR/CMakeLists.txt (-15)
  • (removed) mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt (-9)
  • (removed) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (-70)
  • (modified) mlir/lib/Dialect/CMakeLists.txt (-1)
  • (modified) mlir/lib/Dialect/X86/IR/X86Dialect.cpp (+285)
  • (modified) mlir/lib/Dialect/X86/Transforms/LegalizeForLLVMExport.cpp (+24-1)
  • (modified) mlir/lib/RegisterAllDialects.cpp (-2)
  • (modified) mlir/lib/RegisterAllExtensions.cpp (+2-2)
  • (modified) mlir/test/CMakeLists.txt (+2-2)
  • (modified) mlir/test/Conversion/VectorToAMX/contract-to-amx.mlir (+14-14)
  • (modified) mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir (+9-9)
  • (modified) mlir/test/Conversion/VectorToLLVM/pass-option-serialization.mlir (-1)
  • (removed) mlir/test/Dialect/AMX/invalid.mlir (-158)
  • (removed) mlir/test/Dialect/AMX/roundtrip.mlir (-77)
  • (removed) mlir/test/Dialect/AMX/side-effects.mlir (-32)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+9-9)
  • (added) mlir/test/Dialect/X86/AMX/invalid.mlir (+158)
  • (renamed) mlir/test/Dialect/X86/AMX/legalize-for-llvm.mlir (+34-34)
  • (added) mlir/test/Dialect/X86/AMX/roundtrip.mlir (+77)
  • (added) mlir/test/Dialect/X86/AMX/side-effects.mlir (+32)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/lit.local.cfg (+1-1)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf-full.mlir (+6-6)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/mulf.mlir (+11-11)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-ext.mlir (+21-21)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli-full.mlir (+6-6)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/muli.mlir (+11-11)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero-block.mlir (+3-3)
  • (renamed) mlir/test/Integration/Dialect/Vector/CPU/X86/AMX/tilezero.mlir (+3-3)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/X86/dot.mlir (+4-3)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/X86/sparse-dot-product.mlir (+11-13)
  • (modified) mlir/test/Target/LLVMIR/amx.mlir (+42-42)
  • (modified) mlir/test/lit.site.cfg.py.in (+1-1)
  • (modified) mlir/test/mlir-opt/commandline.mlir (-1)
diff --git a/mlir/Maintainers.md b/mlir/Maintainers.md
index a023ee0ea1bba..181541a0f3a93 100644
--- a/mlir/Maintainers.md
+++ b/mlir/Maintainers.md
@@ -104,7 +104,6 @@ available, should be contacted first, as they're more active in those areas.
 * ‘arm_neon’ Dialect ([@banach-space](https://github.com/banach-space))
 * ‘arm_sve’ Dialect ([@banach-space](https://github.com/banach-space))
 * ‘ArmSME’ Dialect ([@banach-space](https://github.com/banach-space))
-* ‘amx’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * ‘x86’ Dialect ([@adam-smnk](https://github.com/adam-smnk))
 * ‘vcix’ Dialect ([@mshockwave](https://github.com/mshockwave))
 
diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
index 2bdf400a7759f..2bcacbb4ee946 100644
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -5,8 +5,8 @@ overall flow is two-stage:
 
 1.  **conversion** of the IR to a set of dialects translatable to LLVM IR, for
     example [LLVM Dialect](Dialects/LLVM.md) or one of the hardware-specific
-    dialects derived from LLVM IR intrinsics such as [AMX](Dialects/AMX.md),
-    [X86](Dialects/X86.md) or [ArmNeon](Dialects/ArmNeon.md);
+    dialects derived from LLVM IR intrinsics such as [X86](Dialects/X86.md)
+    or [ArmNeon](Dialects/ArmNeon.md);
 2.  **translation** of MLIR dialects to LLVM IR.
 
 This flow allows the non-trivial transformation to be performed within MLIR
diff --git a/mlir/include/mlir-c/Dialect/AMX.h b/mlir/include/mlir-c/Dialect/AMX.h
deleted file mode 100644
index ac4695a107ae6..0000000000000
--- a/mlir/include/mlir-c/Dialect/AMX.h
+++ /dev/null
@@ -1,25 +0,0 @@
-//===-- mlir-c/Dialect/AMX.h - C API for AMX Dialect --------*- C -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM
-// Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_C_DIALECT_AMX_H
-#define MLIR_C_DIALECT_AMX_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMX, amx);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // MLIR_C_DIALECT_AMX_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index ecc22abb0f935..e77860897399f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1521,8 +1521,8 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
     operations. The lowering pass provides several options to control
     the kinds of optimizations that are allowed. It also provides options
     that enable the use of one or more architectural-specific dialects
-    (AMX, X86, ArmNeon, ArmSVE, etc.) in combination with the
-    architectural-neutral vector dialect lowering.
+    (X86, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral
+    vector dialect lowering.
 
   }];
   // Override explicitly in C++ to allow conditional dialect dependence.
@@ -1544,10 +1544,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "vector access are naturally aligned. If operations have an "
            "alignment attribute set, the alignment attribute takes priority "
            "over this option ">,
-    Option<"amx", "enable-amx",
-           "bool", /*default=*/"false",
-           "Enables the use of AMX dialect while lowering the vector "
-	   "dialect.">,
     Option<"armNeon", "enable-arm-neon",
            "bool", /*default=*/"false",
            "Enables the use of ArmNeon dialect while lowering the vector "
@@ -1626,10 +1622,10 @@ def ConvertVectorToXeGPU : Pass<"convert-vector-to-xegpu"> {
 //===----------------------------------------------------------------------===//
 
 def ConvertVectorToAMX : Pass<"convert-vector-to-amx"> {
-  let summary = "Lower the operations from the vector dialect into the AMX "
-                "dialect";
+  let summary = "Lower the operations from the vector dialect into the X86 "
+                "dialect AMX operations";
   let dependentDialects = [
-    "affine::AffineDialect", "amx::AMXDialect", "arith::ArithDialect",
+    "affine::AffineDialect", "x86::X86Dialect", "arith::ArithDialect",
     "memref::MemRefDialect", "scf::SCFDialect", "vector::VectorDialect"
   ];
 }
diff --git a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
index b075ac92990a2..6b178e02684c0 100644
--- a/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
+++ b/mlir/include/mlir/Conversion/VectorToAMX/VectorToAMX.h
@@ -1,4 +1,4 @@
-//===- VectorToAMX.h - Convert vector to AMX dialect ------------*- C++ -*-===//
+//===- VectorToAMX.h - Convert vector to X86 dialect AMX ops ----*- C++ -*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -18,7 +18,7 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTVECTORTOAMX
 #include "mlir/Conversion/Passes.h.inc"
 
-/// Collect a set of patterns to convert from the vector to AMX ops.
+/// Collect a set of patterns to convert from the vector to X86 AMX ops.
 void populateVectorToAMXConversionPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
deleted file mode 100644
index cace63d32fd80..0000000000000
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ /dev/null
@@ -1,440 +0,0 @@
-//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file defines the basic operations for the AMX dialect.
-//
-// The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-// multiply unit (TMUL), a tile control register (TILECFG), and eight
-// tile registers TMM0 through TMM7 (TILEDATA).
-//
-// The AMX dialect provides a bridge between MLIR concepts, such as
-// 2-d vector, operations, and memrefs, and the lower level details
-// of Intel AMX, such as configuration setup, tile sizes, instructions,
-// and tile release.
-//
-// Note that since configuration changes (implicit at dialect level) are
-// costly, it is highly recommended to use the AMX dialect on same-shaped
-// vectors, at least within a single method.
-//
-// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef AMX
-#define AMX
-
-include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
-include "mlir/Dialect/AMX/AMXInterfaces.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/AttrTypeBase.td"
-include "mlir/IR/BuiltinTypes.td"
-
-//===----------------------------------------------------------------------===//
-// AMX dialect definition.
-//===----------------------------------------------------------------------===//
-
-def AMX_Dialect : Dialect {
-  let name = "amx";
-  let cppNamespace = "::mlir::amx";
-  let description = [{
-    The Intel Advanced Matrix Extensions (AMX) provide a tile matrix
-    multiply unit (TMUL), a tile control register (TILECFG), and eight
-    tile registers TMM0 through TMM7 (TILEDATA).
-
-    This `AMX` dialect provides a bridge between MLIR concepts such as
-    vectors and memrefs and the lower level LLVM IR support of AMX.
-
-    Note that since configuration changes (implicit at dialect level) are
-    costly, it is highly recommended to use the AMX dialect on same-shaped
-    vectors, at least within a single method.
-
-    For details, see the Intel documentation:
-    https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
-  }];
-  let useDefaultTypePrinterParser = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// AMX Tile definition.
-//===----------------------------------------------------------------------===//
-
-class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
-    : TypeDef<AMX_Dialect, typeName, traits> {
-  let mnemonic = typeMnemonic;
-}
-
-def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
-  let cppFunctionName = "isValidTileTypeElementType";
-}
-
-def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
-  let summary = "AMX 2D tile to be used by AMX opertaions.";
-
-  let description = [{
-    This type is used to represent values in AMX tile registers. All AMX operations
-    work on AMX tiles and these tiles cannot be used in other operations directly.
-    LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
-    element type for IR verification and lowering to LLVMIR dialect.
-  }];
-
-  let parameters = (ins
-    ArrayRefParameter<"int64_t">:$shape,
-    AMX_TileTypeElementType:$elementType
-  );
-
-  let builders = [
-    TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
-      return $_get(elementType.getContext(), shape, elementType);
-    }]>
-  ];
-
-  let extraClassDeclaration = [{
-    /// Returns if this type is ranked (always true).
-    bool hasRank() const { return true; }
-
-    /// Clone this tile type with the given shape and element type. If the
-    /// provided shape is `std::nullopt`, the current shape of the type is used.
-    TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
-                       Type elementType) const {
-      return get(shape.value_or(getShape()), elementType);
-    }
-  }];
-
-  let hasCustomAssemblyFormat = 1;
-  let skipDefaultBuilders = 1;
-}
-
-def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
-  CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
-
-class AMXTileOf<list<Type> allowedTypes> :
-  ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
-                      "::mlir::amx::TileType">;
-
-def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;
-
-def AMXTileF32 : AMXTileOf<[F32]>;
-
-def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;
-
-def AMXTileI32 : AMXTileOf<[I32]>;
-
-def AMXTileI8 : AMXTileOf<[I8]>;
-
-//===----------------------------------------------------------------------===//
-// AMX Op and IntrOp definitions.
-//===----------------------------------------------------------------------===//
-
-class AMX_Op<string mnemonic, list<Trait> traits = []> :
-  Op<AMX_Dialect, mnemonic, traits> {}
-
-//===----------------------------------------------------------------------===//
-// AMX Op definitions
-//===----------------------------------------------------------------------===//
-
-//
-// Tile reset.
-//
-
-def TileZeroOp : AMX_Op<"tile_zero", [
-    AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>
-  ]> {
-  let summary = "tile zero operation";
-  let description = [{
-    Zeroes the destination tile, with the shape defined by the 2-dim
-    vector type of the result.
-    
-    The operation is eventually lowered into the "tilezero" instruction
-    with the corresponding tile configuration.
-    
-    With the write memory effect, each `amx.tile_zero` operation serves as
-    a compilation hint to use a separate tile register.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_zero : !amx.tile<16x16xbf16>
-    ```
-  }];
-  let results = (outs AnyAMXTile:$res);
-  let extraClassDeclaration = [{
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tilezero.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "attr-dict `:` qualified(type($res))";
-  let hasVerifier = 1;
-}
-
-//
-// Tile memory operations.
-//
-
-def TileLoadOp : AMX_Op<"tile_load", [
-    AMXIntrinsicOpInterface,
-    MemoryEffects<[MemWrite]>,
-    AttrSizedOperandSegments
-  ]> {
-  let summary = "tile load operation";
-  let description = [{
-    Loads a tile from memory defined by a `base` and `indices`, with the
-    shape defined by the 2-dim vector type of the result.
-    The tile's rows are populated by reading contiguous elements starting
-    at the `base`. For each tile row, the `base` is incremented by `stride`
-    number of elements.
-
-    The tile is loaded using the following indexing scheme:
-
-    ```
-    for row in enumerate(tile_rows):
-      mem_row = base[i0, i1, ..., iN + row * stride]
-      for col in enumerate(tile_cols):
-        tile[row, col] = mem_row[col]
-    ```
-
-    If the `stride` is not provided, then the `base` buffer must be at least
-    2-dimensional, and the `stride` is automatically inferred and corresponds
-    to the stride of the buffer's second innermost dimension.
-
-    The operation is eventually lowered into the "tileloadd" instruction
-    with the corresponding tile configuration.
-
-    With the write memory effect, each `amx.tile_load` operation serves as
-    a compilation hint to use a separate tile register.
-
-    Example:
-
-    ```mlir
-      // Tile load from a 2-D memref with implicit stride.
-      %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
-
-      // Tile load from a 1-D memref with explicit stride.
-      %0 = amx.tile_load %arg0[%c0], %stride : memref<?xi8> into !amx.tile<16x64xi8>
-    ```
-  }];
-  let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
-                   Variadic<Index>:$indices,
-                   Optional<Index>:$stride);
-  let results = (outs AnyAMXTile:$res);
-  let builders = [
-    OpBuilder<(ins "Type":$res, "Value":$base, "ValueRange":$indices)>
-  ];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return ::llvm::cast<MemRefType>(getBase().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tileloadd64.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$base `[` $indices `]` (`,` $stride^ )? attr-dict"
-                       "`:` type($base) `into` qualified(type($res))";
-  let hasVerifier = 1;
-}
-
-def TileStoreOp : AMX_Op<"tile_store", [
-    AMXIntrinsicOpInterface,
-    AttrSizedOperandSegments
-  ]> {
-  let summary = "tile store operation";
-  let description = [{
-    Stores a tile to memory defined by a `base` and `indices`, with the
-    shape defined by the 2-dim vector type of the value.
-    The tile's rows are written contiguously to the buffer starting at
-    the `base`. For each tile row, the `base` is incremented by `stride`
-    number of elements.
-
-    The tile is stored using the following indexing scheme:
-
-    ```
-    for row in enumerate(tile_rows):
-      mem_row = base[i0, i1, ..., iN + row * stride]
-      for col in enumerate(tile_cols):
-        mem_row[col] = tile[row, col]
-    ```
-
-    If the `stride` is not provided, then the `base` buffer must be at least
-    2-dimensional, and the `stride` is automatically inferred and corresponds
-    to the stride of the buffer's second innermost dimension.
-
-    The operation is eventually lowered into the "tilestored" instruction
-    with the corresponding tile configuration.
-
-    Example:
-
-    ```mlir
-      // Tile store to a 2-D memref with implicit stride.
-      amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
-
-      // Tile store to a 1-D memref with explicit stride.
-      amx.tile_store %arg1[%c0], %0, %stride : memref<?xi8>, !amx.tile<16x64xi8>
-    ```
-  }];
-  let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
-                   Variadic<Index>:$indices,
-                   AnyAMXTile:$val,
-                   Optional<Index>:$stride);
-  let builders = [
-    OpBuilder<(ins "Value":$base, "ValueRange":$indices, "Value":$val)>
-  ];
-  let extraClassDeclaration = [{
-    MemRefType getMemRefType() {
-      return ::llvm::cast<MemRefType>(getBase().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getVal().getType());
-    }
-
-    std::string getIntrinsicName() {
-      return "llvm.x86.tilestored64.internal";
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$base `[` $indices `]` `,` $val (`,` $stride^ )?"
-                       "attr-dict `:` type($base) `,` qualified(type($val))";
-  let hasVerifier = 1;
-}
-
-//
-// Tile arithmetic operations.
-//
-
-def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
-    AMXIntrinsicOpInterface,
-    AllTypesMatch<["acc", "res"]>
-  ]> {
-  let summary = "tile multiplication operation (floating-point)";
-  let description = [{
-    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
-    into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with
-    pairs of "bf16").
-    
-    The operation is eventually lowered into the "tdpbf16ps" instruction with
-    the corresponding tile configuration.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_mulf %a, %b, %c
-        : !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
-    ```
-  }];
-  let arguments = (ins AMXTileF16OrBF16:$lhs,
-                       AMXTileF16OrBF16:$rhs,
-                       AMXTileF32:$acc);
-  let results = (outs AMXTileF32:$res);
-  let extraClassDeclaration = [{
-    TileType getLhsTileType() {
-      return ::llvm::cast<TileType>(getLhs().getType());
-    }
-    TileType getRhsTileType() {
-      return ::llvm::cast<TileType>(getRhs().getType());
-    }
-    TileType getTileType() {
-      return ::llvm::cast<TileType>(getRes().getType());
-    }
-
-    std::string getIntrinsicName() {
-      std::string intr = "llvm.x86.tdp";
-      auto elementType =
-        getLhsTileType().getElementType();
-      intr += elementType.isF16() ? "fp16" : "bf16";
-      intr += "ps.internal";
-      return intr;
-    }
-    SmallVector<Value> getIntrinsicOperands(
-        ::mlir::ArrayRef<Value> operands,
-        const ::mlir::LLVMTypeConverter &typeConverter,
-        ::mlir::RewriterBase &rewriter);
-  }];
-  let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
-                       "qualified(type($lhs)) `,` qualified(type($rhs))"
-                       " `,` qualified(type($acc)) ";
-  let hasVerifier = 1;
-}
-
-def TileMulIOp : AMX_Op<"tile_muli", [Pure,
-    AMXIntrinsicOpInterface,
-    AllTypesMatch<["acc", "res"]>
-  ]> {
-  let summary = "tile multiplication operation (integer)";
-  let description = [{
-    Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
-    into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8"
-    combinations (4 bytes packed into dwords in the columns of both the
-    source operand tiles; the zero or sign extension is specified with
-    the attributes and default to sign extended).
-    
-    The operation is eventually lowered into one of the "tdpbssd",
-    "tdpbsud", "tdpbusd", or "tdpbuud" instructions with the corresponding
-    tile configuration.
-
-    Example:
-
-    ```mlir
-      %0 = amx.tile_muli %a zext, %b zext, %c
-        : !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
-    ```
-  }];
-  let arguments = (ins AMXTileI8:$lhs,
-                       AMXTileI8:$rhs,
-                       AMXTileI32:$acc,
-                       UnitAttr:$isZextLhs,
-                       UnitAttr:$isZextRhs
-                       );
-  let results = (outs AMXTileI32:$res);
-  let extraClassDeclaration = [{
-    TileType getLhsTileType() {
-      return ::llvm::cast<TileType>(getLhs().getType());
-    }
-    TileType getRhsT...
[truncated]

Comment thread mlir/include/mlir/Dialect/X86/X86Dialect.h
@rengolin
Copy link
Copy Markdown
Member

The MLIR AMX tests are now nested under X86 directory. To enable AMX integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.

Can we not just merge the two flags into one? Why do we need MLIR_RUN_X86_AMX_TESTS?

@adam-smnk
Copy link
Copy Markdown
Member Author

Can we not just merge the two flags into one? Why do we need MLIR_RUN_X86_AMX_TESTS?

It's only for the integration tests. I'd imagine AMX extension is not that widely available.

@rengolin
Copy link
Copy Markdown
Member

It's only for the integration tests. I'd imagine AMX extension is not that widely available.

No, and neither is AVX512. I thought those tests were just IR tests. Integration tests needs to check the actual hardware flags (lit.cfg tricks on different dirs).

@joker-eph
Copy link
Copy Markdown
Contributor

joker-eph commented Feb 27, 2026

Can you please update the commit description to not just describe "what" the change is doing but the "why"? (it's easy to recover the "what" from the diff, the "context" and "motivation" on the other hand is what is most important in a commit description IMO, context can be also a link to a Discourse thread or similar).
Thanks!

@adam-smnk
Copy link
Copy Markdown
Member Author

It's only for the integration tests. I'd imagine AMX extension is not that widely available.

No, and neither is AVX512. I thought those tests were just IR tests. Integration tests needs to check the actual hardware flags (lit.cfg tricks on different dirs).

There are IR only tests too and these are executed by default just like for any other dialect.
x86 integration testing needs larger revision as these tests don't even actively run on any bots.

Well, for now I'm just moving things around and keeping changes to the minimum.

Copy link
Copy Markdown
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As with the previous PR, mostly mechanical change of previously agreed move, so LGTM. I'll let others chime in and approve.

@kuhar @banach-space FYI.

@rengolin
Copy link
Copy Markdown
Member

There are IR only tests too and these are executed by default just like for any other dialect. x86 integration testing needs larger revision as these tests don't even actively run on any bots.

Well, for now I'm just moving things around and keeping changes to the minimum.

Makes sense. That can go to a future PR.

@adam-smnk
Copy link
Copy Markdown
Member Author

@joker-eph Thanks for the input. The change is mechanical in my mind but the context isn't necessarily obvious, indeed.
I updated the description, hopefully it paints better picture now.

@joker-eph
Copy link
Copy Markdown
Contributor

Thanks @adam-smnk !

Copy link
Copy Markdown
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me

@adam-smnk adam-smnk merged commit e44fd05 into llvm:main Mar 2, 2026
18 checks passed
ingomueller-net added a commit that referenced this pull request Mar 2, 2026
This PR fixes the bazel build that got broken by #183717, which moved
the AMX dialect into the X86 dialect. The fix consists of replicating
the changes from the CMake files into BUILD files as usual; in this
case, mostly removing the AMX dialect targets, adding a few new
references to the corresponding X86 targets, and adding a few new
dependencies to the existing X86 targets due to the new code.

Signed-off-by: Ingo Müller <ingomueller@google.com>
sahas3 pushed a commit to sahas3/llvm-project that referenced this pull request Mar 4, 2026
Unifies the two dialects that define x86 operations into a single one.
The AMX dialect is moved into X86 in line with other x86 extensions.

Following the dialect renaming, X86 dialect is now a suitable home for
wider range of operations targeting specific hardware features. Moving
AMX definitions to X86 dialect creates a single, centralized hub for
defining all x86 intrinsic-like operations. The new grouping aims to
eliminate the need for new dialects as new hardware extensions become
available.

The two dialects are simply merged together. X86 dialect refactoring
will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX
integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
sahas3 pushed a commit to sahas3/llvm-project that referenced this pull request Mar 4, 2026
…vm#184165)

This PR fixes the bazel build that got broken by llvm#183717, which moved
the AMX dialect into the X86 dialect. The fix consists of replicating
the changes from the CMake files into BUILD files as usual; in this
case, mostly removing the AMX dialect targets, adding a few new
references to the corresponding X86 targets, and adding a few new
dependencies to the existing X86 targets due to the new code.

Signed-off-by: Ingo Müller <ingomueller@google.com>
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
Unifies the two dialects that define x86 operations into a single one.
The AMX dialect is moved into X86 in line with other x86 extensions.

Following the dialect renaming, X86 dialect is now a suitable home for
wider range of operations targeting specific hardware features. Moving
AMX definitions to X86 dialect creates a single, centralized hub for
defining all x86 intrinsic-like operations. The new grouping aims to
eliminate the need for new dialects as new hardware extensions become
available.

The two dialects are simply merged together. X86 dialect refactoring
will be addressed separately.

List of changes:
  - operations: 'amx.tile_*' => 'x86.amx.tile_*'
  - types: '!amx.tile' => '!x86.amx.tile'
  - namespace: 'mlir::amx' => 'mlir::x86::amx'
  - test define: 'MLIR_RUN_AMX_TESTS' => 'MLIR_RUN_X86_AMX_TESTS'
  - vector lowering: AMX is enabled by default together with X86

The MLIR AMX tests are now nested under X86 directory. To enable AMX
integration tests, 'MLIR_RUN_X86_TESTS' must also be defined.
sujianIBM pushed a commit to sujianIBM/llvm-project that referenced this pull request Mar 5, 2026
…vm#184165)

This PR fixes the bazel build that got broken by llvm#183717, which moved
the AMX dialect into the X86 dialect. The fix consists of replicating
the changes from the CMake files into BUILD files as usual; in this
case, mostly removing the AMX dialect targets, adding a few new
references to the corresponding X86 targets, and adding a few new
dependencies to the existing X86 targets due to the new code.

Signed-off-by: Ingo Müller <ingomueller@google.com>
arun-thmn added a commit to libxsmm/tpp-mlir that referenced this pull request Mar 28, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants