Skip to content

Commit

Permalink
[CPU] Simplify how tile sizes are updated
Browse files Browse the repository at this point in the history
This is a follow-up for #16350 and is meant to simplify how tile sizes
are updated. In particular, a new wrapper for tile sizes is added,
`SizesAndScalableFlagsTuple`, that enables the following simplification:

**BEFORE**
```cpp
vecTileSizes[idx] = innerVecTileSizes[idx];
vecScalableTileFlags[idx] = innerVecScalableTileFlags[idx];
```
(size and scalable flag updated separately)

**AFTER**
```cpp
vecTileSizesAndFlags[idx] = innerVecTileSizesAndFlags[idx];

```
(size and scalable flag updated in one stmt)

The ultimate goal is to "hide" scalable flags for folks working on
targets that don't require those while preserving enough flexibility for
targets that do need to track this extra info. It should also simplify
further work (and review process) for future patches similar to #16350.
  • Loading branch information
banach-space committed Feb 17, 2024
1 parent 2892d81 commit 52cb8af
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 21 deletions.
75 changes: 73 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/TileSizeSelection.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,75 @@ namespace mlir::iree_compiler {
using SizesAndScalableFlags =
std::pair<SmallVector<int64_t>, SmallVector<bool>>;

using SizeAndScalableFlag = std::tuple<int64_t &, bool &>;

/// A tuple that encapsulates two quantities describing tile sizes:
/// * regular tile sizes (integers) - that's always required
/// * scalable tile flags (bool - only used/required for scalable
/// vectorisation.
/// Use this wrapper to make sure that both quantities are upated when
/// manipulating tile sizes.
struct SizesAndScalableFlagsTuple {
SmallVector<int64_t> sizes;
SmallVector<bool> flags;

// Represents a pair of references to a size and a scalable flag at the given
// index. Due to various implementation details of vector of bools, it's much
// easier to store a reference to a whole container and an index. While this
// increases the size of this wrapper, it also simplifies the implementation.
struct ReferencePair {
SmallVector<int64_t> &sizesVec;
SmallVector<bool> &flagsVec;
// Index of this pair within the vectors
size_t index;

explicit ReferencePair(const ReferencePair &a) = default;
ReferencePair(SmallVector<int64_t> &sizesVecRef,
SmallVector<bool> &boolVectorRef, size_t indexRef)
: sizesVec(sizesVecRef), flagsVec(boolVectorRef), index(indexRef) {}

// Update this pair based on the input integer + bool
ReferencePair &operator=(const std::pair<int64_t, bool> &values) {
sizesVec[index] = values.first;
flagsVec[index] = values.second;
return *this;
}

// Update this pair based on the input integer. Assume that the scalable
// size is false. This is safe to use in cases where no scalable
// vectorisation/tiling is used/supported.
ReferencePair &operator=(int64_t size) {
sizesVec[index] = size;
flagsVec[index] = false;
return *this;
}

// Update this pair based on the input ReferencePair
ReferencePair &operator=(const ReferencePair &pair) {
sizesVec[index] = pair.sizesVec[index];
flagsVec[index] = pair.flagsVec[index];
return *this;
}
};

SizesAndScalableFlagsTuple(SmallVector<int64_t> s, SmallVector<bool> f)
: sizes(s), flags(f) {}

// Initialise to {0, false} for all sizes
SizesAndScalableFlagsTuple(size_t numElements)
: sizes(SmallVector<int64_t>(numElements, 0)),
flags(SmallVector<bool>(numElements, false)) {}

SizesAndScalableFlags get() {
return std::pair<SmallVector<int64_t>, SmallVector<bool>>(sizes, flags);
}

ReferencePair operator[](size_t index) {
// A new pair requires a reference to sizes, scalable flags and an index.
return {sizes, flags, index};
}
};

/// Provides unified API to get access to all the tile size needed during the
/// CPU lowering process, while abstracting the representation and verification
/// details of such information in the IR.
Expand Down Expand Up @@ -127,8 +196,10 @@ class TilingConfig {

private:
SizesAndScalableFlags getVectorSizesForLevel(unsigned level) {
return std::make_pair(loweringConfig.getTileSizeVals(level),
loweringConfig.getScalableTileFlagVals(level));
return SizesAndScalableFlagsTuple(
loweringConfig.getTileSizeVals(level),
loweringConfig.getScalableTileFlagVals(level))
.get();
}

SmallVector<int64_t> getTileSizesForLevel(unsigned level) {
Expand Down
34 changes: 15 additions & 19 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/KernelDispatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2375,18 +2375,16 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
<< "\n");

// Split parallel vector tile sizes into common parts and op-specific parts.
SmallVector<int64_t> commonVecTileSizes = parallelVecTileSizes;
SmallVector<bool> commonVecScalableTileFlags = parallelVecScalableTileSizes;
SmallVector<int64_t> innerVecTileSizes(maxLoopNums, 0);
SmallVector<bool> innerVecScalableTileFlags(maxLoopNums, false);
SizesAndScalableFlagsTuple commanVecTileSizesAndFlags = {
parallelVecTileSizes, parallelVecScalableTileSizes};
SizesAndScalableFlagsTuple innerVecTileSizesAndFlags(maxLoopNums);
for (auto op : computeOps) {
auto iterTypes = cast<TilingInterface>(op).getLoopIteratorTypes();
for (auto [idx, iterType] : llvm::enumerate(iterTypes)) {
if (iterType == utils::IteratorType::reduction) {
innerVecTileSizes[idx] = parallelVecTileSizes[idx];
innerVecScalableTileFlags[idx] = parallelVecScalableTileSizes[idx];
commonVecTileSizes[idx] = 0;
commonVecScalableTileFlags[idx] = false;
innerVecTileSizesAndFlags[idx] = {parallelVecTileSizes[idx],
parallelVecScalableTileSizes[idx]};
commanVecTileSizesAndFlags[idx] = {/*size=*/0, /*scalableFlag=*/false};
}
}
}
Expand All @@ -2409,18 +2407,18 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
}
if (tilingConfig.getNumTilingLevels() > 1) {
tileSizesList[tilingConfig.getVectorCommonParallelLevel()] =
commonVecTileSizes;
commanVecTileSizesAndFlags.sizes;
scalableTileFlagsList[tilingConfig.getVectorCommonParallelLevel()] =
commonVecScalableTileFlags;
commanVecTileSizesAndFlags.flags;
}
} else {
// Build 4-level lowering configs for other ops.
tileSizesList = {distTileSizes, commonVecTileSizes};
tileSizesList = {distTileSizes, commanVecTileSizesAndFlags.sizes};
SmallVector<int64_t> zeros(numLoops, 0);
SmallVector<bool> falseVec(numLoops, 0);
// No scalable tiling for the distribution
scalableTileFlagsList.push_back(falseVec);
scalableTileFlagsList.push_back(commonVecScalableTileFlags);
scalableTileFlagsList.push_back(commanVecTileSizesAndFlags.flags);
bool setUpOK =
TypeSwitch<Operation *, bool>(op)
.Case<tensor::PackOp>([&](auto packOp) {
Expand All @@ -2431,7 +2429,7 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
return false;
}
tileSizesList.push_back(zeros);
tileSizesList.push_back(innerVecTileSizes);
tileSizesList.push_back(innerVecTileSizesAndFlags.sizes);
// Scale and permutate the outer dim tiles for pack op.
ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
ArrayRef<int64_t> dimPos = packOp.getInnerDimsPos();
Expand Down Expand Up @@ -2460,18 +2458,16 @@ setLoweringConfigForComputeOps(mlir::FunctionOpInterface entryPointFn,
scalableTileFlagsList.push_back(falseVec);
}
// Only copy the inner vector tile sizes on parallel dims.
SmallVector<int64_t> vecTileSizes(numLoops, 0);
SmallVector<bool> vecScalableTileFlags(numLoops, false);
SizesAndScalableFlagsTuple vecTileSizesAndFlags(numLoops);
auto iterTypes =
cast<TilingInterface>(op).getLoopIteratorTypes();
for (auto [idx, iterType] : llvm::enumerate(iterTypes)) {
if (iterType == utils::IteratorType::parallel) {
vecTileSizes[idx] = innerVecTileSizes[idx];
vecScalableTileFlags[idx] = innerVecScalableTileFlags[idx];
vecTileSizesAndFlags[idx] = innerVecTileSizesAndFlags[idx];
}
}
tileSizesList.push_back(vecTileSizes);
scalableTileFlagsList.push_back(vecScalableTileFlags);
tileSizesList.push_back(vecTileSizesAndFlags.sizes);
scalableTileFlagsList.push_back(vecTileSizesAndFlags.flags);

return true;
});
Expand Down

0 comments on commit 52cb8af

Please sign in to comment.