diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 541dcc91b6f0..5cf464797bf5 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -104,8 +104,8 @@ def TPU_TiledLayoutAttr
   [DeclareAttrInterfaceMethods]> {
   let description = [{TODO}];
   let parameters = (ins
-    "int64_t":$rank,
-    ArrayRefParameter<"::xla::Tile", "">:$tiles
+    ArrayRefParameter<"::xla::Tile", "">:$tiles,
+    ArrayRefParameter<"int64_t", "">:$tile_strides
   );
 
   let hasCustomAssemblyFormat = 1;
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
index e0be38f13c5a..164453ce58c2 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
@@ -85,22 +85,23 @@ Attribute VectorLayoutAttr::parse(AsmParser &parser, Type type) {
 
 void TiledLayoutAttr::print(AsmPrinter &printer) const {
   printer << '<';
-  printer << getRank();
-  printer << ',';
   for (const xla::Tile &tile : getTiles()) {
     printer << tile.ToString();
   }
-  printer << '>';
+  printer << ",[";
+  for (int i = 0; i < getTileStrides().size(); ++i) {
+    if (i > 0) {
+      printer << ',';
+    }
+    printer << getTileStrides()[i];
+  }
+  printer << "]>";
 }
 
 Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
   if (failed(parser.parseLess())) {
     return {};
   }
-  int64_t rank;
-  if (parser.parseInteger(rank) || parser.parseComma()) {
-    return {};
-  }
   llvm::SmallVector tiles;
   int64_t size;
   while (succeeded(parser.parseOptionalLParen())) {
@@ -119,14 +120,37 @@ Attribute TiledLayoutAttr::parse(AsmParser &parser, Type type) {
       tile.add_dimensions(size);
     }
   }
+  llvm::SmallVector<int64_t> tile_strides;
+  int64_t stride;
+  if (failed(parser.parseComma())) {
+    return {};
+  }
+  if (succeeded(parser.parseOptionalLSquare())) {
+    bool first = true;
+    while (!succeeded(parser.parseOptionalRSquare())) {
+      if (!first) {
+        if (failed(parser.parseComma())) {
+          return {};
+        }
+      }
+      first = false;
+      if (failed(parser.parseInteger(stride))) {
+        return {};
+      }
+      tile_strides.push_back(stride);
+    }
+  } else {
+    return {};
+  }
   if (failed(parser.parseGreater())) {
     return {};
   }
-  return get(parser.getContext(), rank, tiles);
+  return get(parser.getContext(), tiles, tile_strides);
 }
 
 AffineMap TiledLayoutAttr::getAffineMap() const {
-  AffineMap map = AffineMap::getMultiDimIdentityMap(getRank(), getContext());
+  AffineMap map =
+      AffineMap::getMultiDimIdentityMap(getTileStrides().size(), getContext());
   SmallVector exprs;
   for (const xla::Tile &tile : getTiles()) {
     exprs.clear();
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc
index 513210da9f19..89e3a8bb9f70 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/communication.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/communication.cc
@@ -106,7 +106,7 @@ struct LogicalToPhysicalDeviceIdPass
       if (func.getName() == "main") {
         auto device_assignment_type = MemRefType::get(
             {total_devices}, IntegerType::get(func.getContext(), 32),
-            TiledLayoutAttr::get(func.getContext(), 1, {xla::Tile({128})}),
+            TiledLayoutAttr::get(func.getContext(), {xla::Tile({128})}, {1}),
             MemorySpaceAttr::get(func.getContext(), MemorySpace::smem));
         func.insertArgument(func.getNumArguments(), device_assignment_type,
                             nullptr, UnknownLoc::get(func.getContext()));
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
index f3bf4c0a84f5..7e2901b03f78 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -75,7 +75,7 @@ FailureOr inferLayout(MemRefType memref,
       }
       tiles.append({xla::Tile({128}), xla::Tile({32 / bitwidth, 1})});
     }
-    return TiledLayoutAttr::get(memref.getContext(), /*rank=*/1, tiles);
+    return TiledLayoutAttr::get(memref.getContext(), tiles, {1});
   }
   // memref.getRank() > 1
   const ArrayRef shape = memref.getShape();
@@ -91,7 +91,23 @@ FailureOr inferLayout(MemRefType memref,
       }
       tiles.push_back(xla::Tile({32 / bitwidth, 1}));
     }
-    return TiledLayoutAttr::get(memref.getContext(), memref.getRank(), tiles);
+    // One stride per dimension, accumulated from the last dimension outwards:
+    // the last two dimensions contribute their (rounded-up) tile counts, the
+    // remaining ones their full extent. Sized up front so indexing is valid.
+    SmallVector<int64_t> tile_strides(memref.getRank());
+    int64_t stride = 1;
+    for (int i = memref.getRank() - 1; i >= 0; --i) {
+      tile_strides[i] = stride;
+      if (i == memref.getRank() - 1) {
+        stride *= (memref.getShape()[i] + 127) / 128;
+      } else if (i == memref.getRank() - 2) {
+        stride *=
+            (memref.getShape()[i] + leading_tile_rows - 1) / leading_tile_rows;
+      } else {
+        stride *= memref.getShape()[i];
+      }
+    }
+    return TiledLayoutAttr::get(memref.getContext(), tiles, tile_strides);
   }
   return emitError(UnknownLoc::get(memref.getContext()),
                    "Unrecognized layout annotation");
diff --git a/jaxlib/mosaic/python/infer_memref_layout.py b/jaxlib/mosaic/python/infer_memref_layout.py
index 62300a28bdba..554ec198815b 100644
--- a/jaxlib/mosaic/python/infer_memref_layout.py
+++ b/jaxlib/mosaic/python/infer_memref_layout.py
@@ -89,9 +89,7 @@ def infer_memref(
         if bitwidth.bit_count() != 1 or bitwidth > 32:
           raise NotImplementedError(f"Unsupported bitwidth: {bitwidth}")
         trailing_tiles = f"(128)({32 // bitwidth},1)"
-      layout = ir.Attribute.parse(
-          f"#tpu.tiled<{memref.rank},({tile}){trailing_tiles}>"
-      )
+      layout = ir.Attribute.parse(f"#tpu.tiled<({tile}){trailing_tiles},[1]>")
     else:
       leading_tile = _tiling_factor(
           memref.shape[-2], hardware_generation, bitwidth
@@ -102,8 +100,18 @@ def infer_memref(
         if bitwidth.bit_count() != 1 or bitwidth > 32:
           raise NotImplementedError(f"Unsupported bitwidth: {bitwidth}")
         trailing_tiles = f"({32 // bitwidth},1)"
+      tile_strides = [None] * memref.rank
+      stride = 1
+      for i in range(memref.rank - 1, -1, -1):
+        tile_strides[i] = stride
+        if i == memref.rank - 1:
+          stride *= (memref.shape[i] + 127) // 128
+        elif i == memref.rank - 2:
+          stride *= (memref.shape[i] + leading_tile - 1) // leading_tile
+        else:
+          stride *= memref.shape[i]
       layout = ir.Attribute.parse(
-          f"#tpu.tiled<{memref.rank},({leading_tile},128){trailing_tiles}>"
+          f"#tpu.tiled<({leading_tile},128){trailing_tiles},{tile_strides}>"
       )
   elif tpu.private_is_tiled_layout(memref.layout):
     layout = memref.layout