Skip to content

Commit

Permalink
Disallow index types in memrefs.
Browse files Browse the repository at this point in the history
As specified in the MLIR language reference and rationale documents, `memref`
types should not be allowed to have `index` as element types. As observed in
https://groups.google.com/a/tensorflow.org/forum/#!msg/mlir/P49hVWqTMNc/nW89a4i_AgAJ
this restriction was lifted when canonicalization unit tests for affine
operations were introduced, without sufficient motivation to lift the
restriction itself.  The test in question can be trivially rewritten (return
the value from a function instead of storing it to prevent DCE from removing
the producer operation) and the restriction put back in place.

If `memref<...x index>` is relevant for some use cases, the relaxation of the
type system can be implemented separately with appropriate modifications to the
documentation.

PiperOrigin-RevId: 272607043
  • Loading branch information
ftynse authored and tensorflower-gardener committed Oct 3, 2019
1 parent 9604bb6 commit 44ef5e5
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
7 changes: 7 additions & 0 deletions mlir/lib/IR/StandardTypes.cpp
Expand Up @@ -341,6 +341,13 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
Optional<Location> location) {
auto *context = elementType.getContext();

// Check that memref is formed from allowed types.
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) {
if (location)
emitError(*location, "invalid memref element type");
return nullptr;
}

for (int64_t s : shape) {
// Negative sizes are not allowed except for `-1` that means dynamic size.
if (s < -1) {
Expand Down
7 changes: 2 additions & 5 deletions mlir/test/AffineOps/canonicalize.mlir
Expand Up @@ -258,15 +258,12 @@ func @trivial_maps() {
}

// CHECK-LABEL: func @partial_fold_map
func @partial_fold_map(%arg0: memref<index>, %arg1: index, %arg2: index) {
func @partial_fold_map(%arg1: index, %arg2: index) -> index {
// TODO: Constant fold one index into affine.apply
%c42 = constant 42 : index
%2 = affine.apply (d0, d1) -> (d0 - d1) (%arg1, %c42)
store %2, %arg0[] : memref<index>
// CHECK: [[X:%[0-9]+]] = affine.apply [[MAP15]]()[%{{.*}}]
// CHECK-NEXT: store [[X]], %{{.*}}

return
return %2 : index
}

// CHECK-LABEL: func @symbolic_composition_a(%{{.*}}: index, %{{.*}}: index) -> index {
Expand Down
3 changes: 1 addition & 2 deletions mlir/test/IR/invalid.mlir
Expand Up @@ -21,8 +21,7 @@ func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements m

// -----

// Everything is valid in a memref.
func @indexmemref(memref<? x index>) -> ()
func @indexmemref(memref<? x index>) -> () // expected-error {{invalid memref element type}}

// -----

Expand Down

0 comments on commit 44ef5e5

Please sign in to comment.