[CPU] Simplify how tile sizes are updated #16435

Open
wants to merge 1 commit into
base: main

banach-space
Collaborator

This is a follow-up for #16350 and is meant to simplify how tile sizes
are updated. In particular, a new wrapper for tile sizes is added,
SizesAndScalableFlagsTuple, that enables the following simplification:

BEFORE

vecTileSizes[idx] = innerVecTileSizes[idx];
vecScalableTileFlags[idx] = innerVecScalableTileFlags[idx];

(size and scalable flag updated separately)

AFTER

vecTileSizesAndFlags[idx] = innerVecTileSizesAndFlags[idx];

(size and scalable flag updated in one stmt)

The ultimate goal is to "hide" scalable flags for folks working on
targets that don't require those while preserving enough flexibility for
targets that do need to track this extra info. It should also simplify
further work (and review process) for future patches similar to #16350.
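
To make the BEFORE/AFTER above concrete, here is a minimal sketch of how a wrapper can support paired assignment through `operator[]`. It is not the code from this patch: `SizesAndFlagsSketch` and `Slot` are hypothetical names, `std::vector` stands in for `SmallVector` so the snippet builds without LLVM, and the proxy design (parent reference plus index) is an assumption, since the internals of `ReferencePair` are not shown here.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for the PR's SizesAndScalableFlagsTuple.
class SizesAndFlagsSketch {
public:
  // Proxy naming one (size, flag) slot of a parent object. It is only valid
  // while the parent object is alive.
  class Slot {
  public:
    Slot(SizesAndFlagsSketch &p, std::size_t i) : parent(p), index(i) {}

    // Copy the size and the scalable flag together from another slot.
    Slot &operator=(const Slot &other) {
      parent.sizes[index] = other.parent.sizes[other.index];
      parent.flags[index] = other.parent.flags[other.index];
      return *this;
    }

    int64_t size() const { return parent.sizes[index]; }
    bool scalable() const { return parent.flags[index]; }

  private:
    SizesAndFlagsSketch &parent;
    std::size_t index;
  };

  SizesAndFlagsSketch(std::vector<int64_t> s, std::vector<bool> f)
      : sizes(std::move(s)), flags(std::move(f)) {
    assert(sizes.size() == flags.size() && "expected one flag per tile size");
  }

  Slot operator[](std::size_t index) {
    assert(index < sizes.size() && "index out of range");
    return Slot(*this, index);
  }

private:
  std::vector<int64_t> sizes;
  std::vector<bool> flags;
};
```

With this, `outer[1] = inner[1];` copies both the size and the scalable flag in one statement. As with any proxy, a `Slot` must not outlive the object it was obtained from.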

@banach-space force-pushed the andrzej/better_logic_for_tile_size_update branch from 8c63a4c to 52cb8af on February 17, 2024 08:45
@@ -15,6 +15,75 @@ namespace mlir::iree_compiler {
using SizesAndScalableFlags =
std::pair<SmallVector<int64_t>, SmallVector<bool>>;

using SizeAndScalableFlag = std::tuple<int64_t &, bool &>;
Contributor

I'm about to refactor this part and found this PR.

Is there a reason not to define a struct like:

struct ScalableTileSize {
  int64_t tileSize;
  bool scalableFlag;
};

Having field names seems clearer and makes the members easier to access.

Collaborator Author

What we really need is a pair of references (rather than a plain int64_t and a bool) to elements of two vectors. That's because we do this in KernelDispatch.cpp:

SmallVector<int64_t> commonVecTileSizes = parallelVecTileSizes;
SmallVector<bool> commonVecScalableTileFlags = parallelVecScalableTileSizes;

Since one of the vectors contains bool, things are not as straightforward.

Btw, this alias is not actually used (sorry, I meant to remove it). I appreciate that this code is not pretty regardless.

return std::pair<SmallVector<int64_t>, SmallVector<bool>>(sizes, flags);
}

ReferencePair operator[](size_t index) {
Contributor

@pzread pzread Feb 29, 2024

This lets users obtain and store a ReferencePair, which makes it easy to introduce a bug where the reference pair is still used after the SizesAndScalableFlagsTuple has been destroyed, especially if users don't understand the internal state of ReferencePair.

Collaborator Author

Thanks for pointing this out!

Not sure there’s a way around it, TBH. Well, other than replacing this (“pair of vectors”):

SmallVector<int64_t> commonVecTileSizes = parallelVecTileSizes;
SmallVector<bool> commonVecScalableTileFlags = parallelVecScalableTileSizes;

with (“vector of pairs”):

SmallVector<ScalableTileSize> commonVecTileSizesAndFlags
    = parallelVecTileSizesAndFlags;

That would be more intrusive, but would allow safer code. Having said that, ReferencePair is not intended for general use anyway. WDYT?

Contributor

@pzread pzread Feb 29, 2024

Yes, I was thinking of replacing it with a vector of pairs, SmallVector<ScalableTileSize>, everywhere in KernelDispatch.cpp, and converting it at the very end into two separate vectors, tileSizes and scalableTileFlags, for the lowering_config attribute here:

https://github.com/openxla/iree/blob/9eff8615c2b2615db9f2ba6c79211452329aa872/compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h#L135-L152

So the new setOpConfigAndEntryPointFnTranslation will be:

inline LogicalResult setOpConfigAndEntryPointFnTranslation(
    mlir::FunctionOpInterface entryPointFn, Operation *op,
    SmallVector<SmallVector<ScalableTileSize>> scalableTileSizesList,
    IREE::Codegen::DispatchLoweringPassPipeline passPipeline,
    ArrayRef<int64_t> workgroupSize = {},
    std::optional<int64_t> subgroupSize = {},
    DictionaryAttr pipelineConfig = DictionaryAttr()) { ... }

Ideally we could also have [{tileSize, scalableFlag}, ...] on the lowering_config attribute, but that would be a larger refactor and can be done separately. WDYT?
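
A rough sketch of the final conversion step described above, assuming the ScalableTileSize struct proposed earlier in this thread. `splitSizesAndFlags` is a hypothetical helper, not an existing IREE function, and `std::vector` stands in for `SmallVector` to keep the snippet free of LLVM dependencies.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Struct as proposed earlier in this thread.
struct ScalableTileSize {
  int64_t tileSize;
  bool scalableFlag;
};

// Hypothetical helper: split a "vector of pairs" into the "pair of vectors"
// layout (sizes first, scalable flags second).
std::pair<std::vector<int64_t>, std::vector<bool>>
splitSizesAndFlags(const std::vector<ScalableTileSize> &tiles) {
  std::vector<int64_t> sizes;
  std::vector<bool> flags;
  sizes.reserve(tiles.size());
  flags.reserve(tiles.size());
  for (const ScalableTileSize &tile : tiles) {
    sizes.push_back(tile.tileSize);
    flags.push_back(tile.scalableFlag);
  }
  return {std::move(sizes), std::move(flags)};
}
```

Keeping the split in one place would mean the rest of the dispatch logic only ever touches a size together with its flag.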

Collaborator Author

+1

I am really glad that you also see this as a better design.

One small nit: I would call it TileSizeAndScalableFlag rather than ScalableTileSize. The former indicates that two values are encapsulated, which is always true. The latter suggests that tile sizes are always scalable, which is not the case: tile sizes can be either fixed-width or scalable. Naming is hard.

Contributor

I think we can have a C++ class that maintains all the information about tile sizes and scalable flags, so we are no longer bothered by the pair-versus-vector question. The class would generate SmallVector<SmallVector<ScalableTileSize>> at the end, and we would use that to set the strategy and configuration. This changes the whole flow and requires a broader review. I proposed the idea internally (in AMD) and collected some feedback from @pzread and @dcaballe, but I did not get a chance to push it forward. I'll try to find some cycles and can share it with you offline.

Collaborator Author

Please share when possible. If there's appetite for a larger refactor then perhaps that's the path that we should be taking. Is it something to raise in the mai-tai call?

Contributor

It was still at the stage of collecting initial feedback. I will revamp it and share it with you.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Fly-by comment: I see this is in Common/. While we are cleaning this up, can we move it to Common/CPU?
Thanks for the cleanup here though. I prefer structs to tuples as well. Less verbose.

flags(SmallVector<bool>(numElements, false)) {}

SizesAndScalableFlags get() {
return std::pair<SmallVector<int64_t>, SmallVector<bool>>(sizes, flags);
Contributor

can this be std::make_pair(sizes, flags)?
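
For comparison, a small sketch of three equivalent ways to build the returned pair. `Holder` and its method names are hypothetical, and `std::vector` stands in for `SmallVector`. All three forms copy both vectors into the result.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

using SizesAndFlags = std::pair<std::vector<int64_t>, std::vector<bool>>;

// Hypothetical holder of the two parallel vectors.
struct Holder {
  std::vector<int64_t> sizes;
  std::vector<bool> flags;

  // Pair type spelled out in full, as in the line under review.
  SizesAndFlags getVerbose() const {
    return std::pair<std::vector<int64_t>, std::vector<bool>>(sizes, flags);
  }

  // std::make_pair deduces the element types from its arguments.
  SizesAndFlags getWithMakePair() const {
    return std::make_pair(sizes, flags);
  }

  // Braced return relies on the declared return type.
  SizesAndFlags getWithBraces() const { return {sizes, flags}; }
};
```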

Comment on lines +2378 to +2379
SizesAndScalableFlagsTuple commanVecTileSizesAndFlags = {
parallelVecTileSizes, parallelVecScalableTileSizes};
Contributor

I think we can directly use the constructor here? E.g.,

Suggested change
- SizesAndScalableFlagsTuple commanVecTileSizesAndFlags = {
-     parallelVecTileSizes, parallelVecScalableTileSizes};
+ SizesAndScalableFlagsTuple commanVecTileSizesAndFlags(parallelVecTileSizes, parallelVecScalableTileSizes);

}
};

SizesAndScalableFlagsTuple(SmallVector<int64_t> s, SmallVector<bool> f)
Contributor

can this be ArrayRef?

Comment on lines +73 to +75
SizesAndScalableFlagsTuple(size_t numElements)
: sizes(SmallVector<int64_t>(numElements, 0)),
flags(SmallVector<bool>(numElements, false)) {}
Contributor

I don't think we need the explicit SmallVector<...> temporaries to construct these. Does the snippet below work?

Suggested change
- SizesAndScalableFlagsTuple(size_t numElements)
-     : sizes(SmallVector<int64_t>(numElements, 0)),
-       flags(SmallVector<bool>(numElements, false)) {}
+ SizesAndScalableFlagsTuple(size_t numElements)
+     : sizes(numElements, 0),
+       flags(numElements, false) {}
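
A minimal sketch of the suggested form, with `std::vector` standing in for `SmallVector` and `FilledSketch` as a hypothetical name. `std::vector` has a (count, value) constructor and `SmallVector` offers one of the same shape, so the member initializers can take the arguments directly without naming a temporary.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the wrapper's size-only constructor.
struct FilledSketch {
  std::vector<int64_t> sizes;
  std::vector<bool> flags;

  // Each member is built in place from (count, value).
  explicit FilledSketch(std::size_t numElements)
      : sizes(numElements, 0), flags(numElements, false) {}
};
```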

4 participants