Skip to content

Commit

Permalink
[mlir][tosa] Switch TosaFoldConstantTranspose to use ElementsAttr.
Browse files Browse the repository at this point in the history
Also avoid redoing index calculation.

Differential Revision: https://reviews.llvm.org/D132274
  • Loading branch information
jpienaar committed Aug 22, 2022
1 parent 9f6cb3e commit b1f2e26
Showing 1 changed file with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();

DenseElementsAttr inputValues;
ElementsAttr inputValues;
if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
return failure();
// Make sure the input is a constant that has a single user.
Expand All @@ -57,10 +57,9 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
// index.
auto attrValues = inputValues.getValues<Attribute>();
ArrayRef<int64_t> outputShape = outputType.getShape();
for (int srcLinearIndex = 0; srcLinearIndex < numElements;
++srcLinearIndex) {
for (const auto &it : llvm::enumerate(attrValues)) {
SmallVector<uint64_t, 6> srcIndices(inputType.getRank(), 0);
int totalCount = srcLinearIndex;
int totalCount = it.index();
for (int dim = inputType.getRank() - 1; dim >= 0; --dim) {
srcIndices[dim] = totalCount % inputShape[dim];
totalCount /= inputShape[dim];
Expand All @@ -74,7 +73,7 @@ struct TosaFoldConstantTranspose : public OpRewritePattern<tosa::TransposeOp> {
for (int dim = 1; dim < outputType.getRank(); ++dim)
dstLinearIndex = dstLinearIndex * outputShape[dim] + dstIndices[dim];

outputValues[dstLinearIndex] = attrValues[srcIndices];
outputValues[dstLinearIndex] = it.value();
}

rewriter.replaceOpWithNewOp<tosa::ConstOp>(
Expand Down

0 comments on commit b1f2e26

Please sign in to comment.