Skip to content

Commit

Permalink
[Mosaic] Introduce tile_strides to memref layout to support slice.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567764629
  • Loading branch information
bythew3i authored and jax authors committed Sep 23, 2023
1 parent 1466c3d commit 6ab4806
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 18 deletions.
4 changes: 2 additions & 2 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Expand Up @@ -104,8 +104,8 @@ def TPU_TiledLayoutAttr
[DeclareAttrInterfaceMethods<MemRefLayoutAttrInterface>]> {
let description = [{TODO}];
let parameters = (ins
"int64_t":$rank,
ArrayRefParameter<"::xla::Tile", "">:$tiles
ArrayRefParameter<"::xla::Tile", "">:$tiles,
ArrayRefParameter<"int64_t", "">:$tile_strides
);

let hasCustomAssemblyFormat = 1;
Expand Down
42 changes: 33 additions & 9 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
Expand Up @@ -85,22 +85,23 @@ Attribute VectorLayoutAttr::parse(AsmParser &parser, Type type) {

void TiledLayoutAttr::print(AsmPrinter &printer) const {
printer << '<';
printer << getRank();
printer << ',';
for (const xla::Tile &tile : getTiles()) {
printer << tile.ToString();
}
printer << '>';
printer << ",[";
for (int i = 0; i < getTileStrides().size(); ++i) {
if (i > 0) {
printer << ',';
}
printer << getTileStrides()[i];
}
printer << "]>";
}

Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
if (failed(parser.parseLess())) {
return {};
}
int64_t rank;
if (parser.parseInteger(rank) || parser.parseComma()) {
return {};
}
llvm::SmallVector<xla::Tile, 2> tiles;
int64_t size;
while (succeeded(parser.parseOptionalLParen())) {
Expand All @@ -119,14 +120,37 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
tile.add_dimensions(size);
}
}
llvm::SmallVector<int64_t, 2> tile_strides;
int64_t stride;
if (failed(parser.parseComma())) {
return {};
}
if (succeeded(parser.parseOptionalLSquare())) {
bool first = true;
while (!succeeded(parser.parseOptionalRSquare())) {
if (!first) {
if (failed(parser.parseComma())) {
return {};
}
}
first = false;
if (failed(parser.parseInteger(stride))) {
return {};
}
tile_strides.push_back(stride);
}
} else {
return {};
}
if (failed(parser.parseGreater())) {
return {};
}
return get(parser.getContext(), rank, tiles);
return get(parser.getContext(), tiles, tile_strides);
}

AffineMap TiledLayoutAttr::getAffineMap() const {
AffineMap map = AffineMap::getMultiDimIdentityMap(getRank(), getContext());
AffineMap map =
AffineMap::getMultiDimIdentityMap(getTileStrides().size(), getContext());
SmallVector<AffineExpr, 8> exprs;
for (const xla::Tile &tile : getTiles()) {
exprs.clear();
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/communication.cc
Expand Up @@ -106,7 +106,7 @@ struct LogicalToPhysicalDeviceIdPass
if (func.getName() == "main") {
auto device_assignment_type = MemRefType::get(
{total_devices}, IntegerType::get(func.getContext(), 32),
TiledLayoutAttr::get(func.getContext(), 1, {xla::Tile({128})}),
TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}),
MemorySpaceAttr::get(func.getContext(), MemorySpace::smem));
func.insertArgument(func.getNumArguments(), device_assignment_type,
nullptr, UnknownLoc::get(func.getContext()));
Expand Down
18 changes: 16 additions & 2 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
Expand Up @@ -75,7 +75,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
}
tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})});
}
return TiledLayoutAttr::get(memref.getContext(), /*rank=*/1, tiles);
return TiledLayoutAttr::get(memref.getContext(), tiles, {1});
}
// memref.getRank() > 1
const ArrayRef<int64_t> shape = memref.getShape();
Expand All @@ -91,7 +91,21 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref,
}
tiles.push_back(xla::Tile({32 / bitwidth, 1}));
}
return TiledLayoutAttr::get(memref.getContext(), memref.getRank(), tiles);
SmallVector<int64_t> tile_strides;
tile_strides.reserve(memref.getRank());
int64_t stride = 1;
for (int i = memref.getRank() - 1; i > 0; --i) {
tile_strides[i] = stride;
if (i == memref.getRank() - 1) {
stride *= (memref.getShape()[i] + 127) / 128;
} else if (i == memref.getRank() - 2) {
stride *=
(memref.getShape()[i] + leading_tile_rows - 1) / leading_tile_rows;
} else {
stride *= memref.getShape()[i];
}
}
return TiledLayoutAttr::get(memref.getContext(), tiles, tile_strides);
}
return emitError(UnknownLoc::get(memref.getContext()),
"Unrecognized layout annotation");
Expand Down
16 changes: 12 additions & 4 deletions jaxlib/mosaic/python/infer_memref_layout.py
Expand Up @@ -89,9 +89,7 @@ def infer_memref(
if bitwidth.bit_count() != 1 or bitwidth > 32:
raise NotImplementedError(f"Unsupported bitwidth: {bitwidth}")
trailing_tiles = f"(128)({32 // bitwidth},1)"
layout = ir.Attribute.parse(
f"#tpu.tiled<{memref.rank},({tile}){trailing_tiles}>"
)
layout = ir.Attribute.parse(f"#tpu.tiled<({tile}){trailing_tiles},[1]>")
else:
leading_tile = _tiling_factor(
memref.shape[-2], hardware_generation, bitwidth
Expand All @@ -102,8 +100,18 @@ def infer_memref(
if bitwidth.bit_count() != 1 or bitwidth > 32:
raise NotImplementedError(f"Unsupported bitwidth: {bitwidth}")
trailing_tiles = f"({32 // bitwidth},1)"
tile_strides = [None] * memref.rank
stride = 1
for i in range(memref.rank - 1, -1, -1):
tile_strides[i] = stride
if i == memref.rank - 1:
stride *= (memref.shape[i] + 127) // 128
elif i == memref.rank - 2:
stride *= (memref.shape[i] + leading_tile - 1) // leading_tile
else:
stride *= memref.shape[i]
layout = ir.Attribute.parse(
f"#tpu.tiled<{memref.rank},({leading_tile},128){trailing_tiles}>"
f"#tpu.tiled<({leading_tile},128){trailing_tiles},{tile_strides}>"
)
elif tpu.private_is_tiled_layout(memref.layout):
layout = memref.layout
Expand Down

0 comments on commit 6ab4806

Please sign in to comment.