Skip to content

Commit

Permalink
[mlir] Use ReassociationIndices instead of affine maps in linalg.resh…
Browse files Browse the repository at this point in the history
…ape.

Differential Revision: https://reviews.llvm.org/D101861
  • Loading branch information
pifon2a committed May 5, 2021
1 parent e4eec51 commit 2865d11
Show file tree
Hide file tree
Showing 13 changed files with 508 additions and 742 deletions.
69 changes: 34 additions & 35 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Expand Up @@ -315,55 +315,54 @@ class Linalg_ReshapeLikeOp<string mnemonic, list<OpTrait> traits = []> :
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
OpBuilder<(ins "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
"ArrayRef<ReassociationIndices>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
OpBuilder<(ins "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
"ArrayRef<ReassociationExprs>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
convertReassociationIndicesToMaps($_builder, reassociation);
convertReassociationMapsToIndices($_builder, reassociation);
build($_builder, $_state, src, reassociationMaps, attrs);
}]>,

// Builders for a reshape whose result type is passed explicitly. This may
// be either a contracting or expanding reshape.
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationExprs>":$reassociation,
"ArrayRef<ReassociationIndices>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
OpBuilder<(ins "Type":$resultType, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation,
"ArrayRef<ReassociationExprs>":$reassociation,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
[{
auto reassociationMaps =
convertReassociationIndicesToMaps($_builder, reassociation);
convertReassociationMapsToIndices($_builder, reassociation);
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];

code commonExtraClassDeclaration = [{
static StringRef getReassociationAttrName() { return "reassociation"; }
SmallVector<AffineMap, 4> getReassociationMaps() {
return llvm::to_vector<4>(llvm::map_range(reassociation(), [
](Attribute a) { return a.cast<AffineMapAttr>().getValue(); }));
}
SmallVector<ReassociationExprs, 4> getReassociationExprs() {
return
llvm::to_vector<4>(llvm::map_range(reassociation(),
[](Attribute a) {
return llvm::to_vector<2>(
a.cast<AffineMapAttr>().getValue().getResults());
}));
}
}];
let assemblyFormat = [{
$src $reassociation attr-dict `:` type($src) `into` type(results)
SmallVector<AffineMap, 4> getReassociationMaps();
SmallVector<ReassociationExprs, 4> getReassociationExprs();
SmallVector<ReassociationIndices, 4> getReassociationIndices() {
SmallVector<ReassociationIndices, 4> reassociationIndices;
for (auto attr : reassociation())
reassociationIndices.push_back(llvm::to_vector<2>(
llvm::map_range(attr.cast<ArrayAttr>(), [&](Attribute indexAttr) {
return indexAttr.cast<IntegerAttr>().getInt();
})));
return reassociationIndices;
};
}];
}

def IndexListArrayAttr :
TypedArrayAttrBase<I64ArrayAttr, "Array of 64-bit integer array attributes">;

def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
[DeclareOpInterfaceMethods<ViewLikeOpInterface>]>,
Arguments<(ins AnyStridedMemRef:$src, AffineMapArrayAttr:$reassociation)>,
Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
Results<(outs AnyStridedMemRef:$result)> {
let summary = "linalg.reshape produces a new view into the operand view";
let description = [{
Expand All @@ -373,9 +372,7 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
and copies.

A reassociation is defined as a continuous grouping of dimensions and is
represented with an affine map array attribute. In the future,
non-continuous groupings may be allowed (i.e. permutations, reindexings
etc).
represented with an array of I64ArrayAttr attribute.

For now, it is assumed that either:
1. a reassociation produces and consumes contiguous MemRefType or,
Expand All @@ -401,13 +398,13 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",

```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
%1 = linalg.reshape %0 [[0, 1], [2]] :
memref<?x?x?xf32, stride_spec> into memref<?x?xf32, stride_spec_2>
```

```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%1 = linalg.reshape %0 [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
%1 = linalg.reshape %0 [[0, 1], [2]] :
memref<?x?xf32, stride_spec> into memref<?x?x?xf32, stride_spec_2>
```
}];
Expand All @@ -417,24 +414,24 @@ def Linalg_ReshapeOp : Linalg_ReshapeLikeOp<"reshape",
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}

def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
"tensor_reshape",
[DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["reifyReturnTypeShapesPerResultDim"]>]>,
Arguments<(ins AnyTensor:$src,
AffineMapArrayAttr:$reassociation)>,
IndexListArrayAttr:$reassociation)>,
Results<(outs AnyTensor:$result)> {
let summary = "linalg.tensor_reshape produces a new reshaped tensor.";
let description = [{
The `linalg.reshape` op produces a new tensor whose sizes are a
reassociation of the original `src`.

A reassociation is defined as a continuous grouping of dimensions and is
represented with an affine map array attribute. In the future,
non-continuous groupings may be allowed (i.e. permutations, reindexings
etc).
represented with an array of I64ArrayAttr attribute.

A reshape may either collapse or expand dimensions, depending on the
relationship between source and target tensor ranks. The verification rule
Expand All @@ -453,14 +450,14 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<

```mlir
// Dimension collapse (i, j) -> i' and k -> k'
%b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
tensor<?x?x?xf32> into tensor<?x?xf32>
%b = linalg.tensor_reshape %a [[0, 1], [2]]
: tensor<?x?x?xf32> into tensor<?x?xf32>
```

```mlir
// Dimension expansion i -> (i', j') and (k) -> (k')
%b = linalg.tensor_reshape %a [(i, j, k) -> (i, j), (i, j, k) -> (k)] :
tensor<?x?xf32> into tensor<?x?x?xf32>
%b = linalg.tensor_reshape %a [[0, 1], [2]]
: tensor<?x?xf32> into tensor<?x?x?xf32>
```
}];
let extraClassDeclaration = commonExtraClassDeclaration # [{
Expand All @@ -473,6 +470,8 @@ def Linalg_TensorReshapeOp : Linalg_ReshapeLikeOp<
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}

def Linalg_YieldOp : Linalg_Op<"yield", [NoSideEffect, ReturnLike, Terminator]>,
Expand Down

0 comments on commit 2865d11

Please sign in to comment.