diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43ad435ccf1c1..0bd4f4ffe11b2 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2082,23 +2082,28 @@ def Vector_GatherOp : 3-D and the result is 2-D: ```mlir - func.func @gather_3D_to_2D( - %base: memref, %ofs_0: index, %ofs_1: index, %ofs_2: index, - %indices: vector<2x3xi32>, %mask: vector<2x3xi1>, - %fall_thru: vector<2x3xf32>) -> vector<2x3xf32> { - %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2] - [%indices], %mask, %fall_thru : [...] - return %result : vector<2x3xf32> + %base: memref, %ofs_0: index, %ofs_1: index, %ofs_2: index, + %indices: vector<2x3xi32>, %mask: vector<2x3xi1>, + %pass_thru: vector<2x3xf32>) -> vector<2x3xf32> { + %result = vector.gather %base[%ofs_0, %ofs_1, %ofs_2] + [%indices], %mask, %pass_thru : [...] } ``` The indexing semantics are then, ``` - result[i,j] := if mask[i,j] then base[i0, i1, i2 + indices[i,j]] - else pass_thru[i,j] + %result[i,j] := if %mask[i,j] then %base[%ofs_0, %ofs_1, %ofs_2 + %indices[i,j]] + else %pass_thru[i,j] ``` - The index into `base` only varies in the innermost ((k-1)-th) dimension. + Note, `indices` are element offsets - they are expressed in units of + elements (not bytes). Each element in `indices` represents a displacement + in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1, + %ofs_2]` for the example above. Importantly, for MemRefs, `indices` are + interpreted assuming an identity (contiguous) MemRef layout. Any + non-identity layout (e.g. strided) is not reflected in the indices + themselves and is instead handled during lowering. + If a mask bit is set and the corresponding index is out-of-bounds for the given base, the behavior is undefined. If a mask bit is not set, the value @@ -2191,6 +2196,28 @@ def Vector_ScatterOp is stored regardless of the index, and the index is allowed to be out-of-bounds. + ```mlir + %base: memref, %ofs_0: index, %ofs_1: index, %ofs_2: index, + %indices: vector<2x3xi32>, %mask: vector<2x3xi1>, + %src: vector<2x3xf32>) -> memref { + %result = vector.scatter %base[%ofs_0, %ofs_1, %ofs_2] + [%indices], %mask, %src : [...] + ``` + The indexing semantics are then, + + ``` + if %mask[i,j] then + %base[%ofs_0, %ofs_1, %ofs_2 + %indices[i,j]] := %valueToStore[i,j] + ``` + + Note, `indices` are element offsets - they are expressed in units of + elements (not bytes). Each element in `indices` represents a displacement + in units of elements from the starting element, i.e. `%base[%ofs_0, %ofs_1, + %ofs_2]` for the example above. Importantly, for MemRefs, `indices` are + interpreted assuming an identity (contiguous) MemRef layout. Any + non-identity layout (e.g. strided) is not reflected in the indices + themselves and is instead handled during lowering. + If the index vector contains two or more duplicate indices, the behavior is undefined. Underlying implementation may enforce strict sequential semantics.