Skip to content

Commit

Permalink
[MOSAIC] apply_vector_layout C++ rewrite (1) VectorLayout functions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561237760
  • Loading branch information
tlongeri authored and jax authors committed Aug 30, 2023
1 parent 6b57470 commit d02b59e
Show file tree
Hide file tree
Showing 4 changed files with 411 additions and 9 deletions.
1 change: 1 addition & 0 deletions jaxlib/mosaic/BUILD
Expand Up @@ -41,6 +41,7 @@ cc_library(
"dialect/tpu/layout.cc",
"dialect/tpu/tpu_dialect.cc",
"dialect/tpu/tpu_ops.cc",
"dialect/tpu/util.h",
] + glob(["dialect/tpu/transforms/*.cc"]),
hdrs = [
"dialect/tpu/layout.h",
Expand Down
178 changes: 176 additions & 2 deletions jaxlib/mosaic/dialect/tpu/layout.cc
Expand Up @@ -15,22 +15,77 @@ limitations under the License.

#include "jaxlib/mosaic/dialect/tpu/layout.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <ostream>
#include <string>
#include <tuple>

#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "absl/log/check.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"

namespace mlir::tpu {

bool RectangularVregBounds::maskVariesAlong(
const Direction direction,
const std::array<int64_t, 2> target_shape) const {
switch (direction) {
case Direction::kSublanes:
return starts_[0] != 0 || ends_[0] != target_shape[0];
case Direction::kLanes:
return starts_[1] != 0 || ends_[1] != target_shape[1];
case Direction::kSubelements:
return false;
}
}

FailureOr<TypedValue<VectorType>> RectangularVregBounds::getVectorMask(
OpBuilder& builder, const int /*generation*/,
const std::array<int64_t, 2> target_shape) const {
auto boundIdxConst = std::bind(IdxConst, std::placeholders::_1, builder,
builder.getUnknownLoc());
return cast<TypedValue<VectorType>>(
builder
.create<tpu::CreateMaskOp>(
builder.getUnknownLoc(),
VectorType::get(target_shape, builder.getI1Type()),
/*low=*/
ValueRange{boundIdxConst(starts_[0]), boundIdxConst(starts_[1])},
/*high=*/
ValueRange{boundIdxConst(ends_[0]), boundIdxConst(ends_[1])})
.getResult());
}

DenseBoolArrayAttr RectangularVregBounds::getSublaneMask(
MLIRContext* mlir_ctxt, const std::array<int64_t, 2> target_shape) const {
llvm::SmallVector<bool, 8> sublane_mask(target_shape[0], false);
for (int64_t i = starts_[0]; i < ends_[0]; ++i) {
sublane_mask[i] = true;
}
return DenseBoolArrayAttr::get(mlir_ctxt, sublane_mask);
}

namespace {

mlir::ParseResult parseOffset(llvm::StringRef* data,
Expand All @@ -47,6 +102,12 @@ mlir::ParseResult parseOffset(llvm::StringRef* data,
return failure();
}

std::array<int64_t, 2> nativeTiling(const int8_t bitwidth,
const std::array<int64_t, 2> target_shape) {
const int packing = 32 / bitwidth;
return {target_shape[0] * packing, target_shape[1]};
}

} // namespace

std::tuple<std::optional<int64_t>, std::optional<int64_t>, int64_t, int64_t,
Expand All @@ -60,6 +121,114 @@ bool VectorLayout::operator==(const VectorLayout& other) const {
return as_tuple() == other.as_tuple();
}

bool VectorLayout::hasNativeTiling(
const std::array<int64_t, 2> target_shape) const {
return tiling_ == nativeTiling(bitwidth_, target_shape);
}

llvm::SmallVector<int64_t> VectorLayout::implicitShape(
ArrayRef<int64_t> shape) const {
CHECK(!shape.empty());
switch (implicit_dim_) {
case ImplicitDim::kNone:
return llvm::SmallVector<int64_t>(shape);
case ImplicitDim::kMinor: {
llvm::SmallVector<int64_t> implicit_shape;
implicit_shape.reserve(shape.size() + 1);
implicit_shape.append(shape.begin(), shape.end());
implicit_shape.push_back(1);
return implicit_shape;
}
case ImplicitDim::kSecondMinor: {
llvm::SmallVector<int64_t> implicit_shape;
implicit_shape.reserve(shape.size() + 1);
implicit_shape.append(shape.begin(), std::prev(shape.end()));
implicit_shape.push_back(1);
implicit_shape.push_back(shape.back());
return implicit_shape;
}
}
}

llvm::SmallVector<int64_t> VectorLayout::tileArrayImplicitShape(
const ArrayRef<int64_t> shape,
const std::array<int64_t, 2> target_shape) const {
const std::array<int64_t, 2> vreg_slice = vregSlice(target_shape);
llvm::SmallVector<int64_t> tiles_shape = implicitShape(shape);
tiles_shape[tiles_shape.size() - 2] =
ceilDiv(offsets_[0].value_or(0) + tiles_shape[tiles_shape.size() - 2],
vreg_slice[0]);
tiles_shape[tiles_shape.size() - 1] =
ceilDiv(offsets_[1].value_or(0) + tiles_shape[tiles_shape.size() - 1],
vreg_slice[1]);
return tiles_shape;
}

llvm::SmallVector<int64_t> VectorLayout::tileArrayShape(
const ArrayRef<int64_t> shape,
const std::array<int64_t, 2> target_shape) const {
llvm::SmallVector<int64_t> tiles_shape =
tileArrayImplicitShape(shape, target_shape);
// Remove the implicit dimension --- it's always of size 1.
switch (implicit_dim_) {
case ImplicitDim::kNone:
break;
case ImplicitDim::kMinor:
tiles_shape.pop_back();
break;
case ImplicitDim::kSecondMinor:
tiles_shape.erase(tiles_shape.end() - 1);
break;
}
return tiles_shape;
}

bool VectorLayout::generalizes(
const VectorLayout& other, ArrayRef<int64_t> shape,
const std::array<int64_t, 2> target_shape) const {
if (bitwidth_ != other.bitwidth_) {
return false;
}
for (auto [s, o] : llvm::zip(offsets_, other.offsets_)) {
if (s.has_value() && s != o) {
return false;
}
}
if (implicit_dim_ != other.implicit_dim_) {
// Don't fail yet!
// If the second-minor dimension is of size 1, then it does not matter
// whether we have a second minor implicit dim or not.
if (shape.data() == nullptr) {
return false;
}
const llvm::SmallVector<int64_t> implicit_shape = implicitShape(shape);
if (!(implicit_shape[implicit_shape.size() - 2] == 1 &&
((implicit_dim_ == ImplicitDim::kSecondMinor &&
other.implicit_dim_ == ImplicitDim::kNone) ||
(other.implicit_dim_ == ImplicitDim::kSecondMinor &&
implicit_dim_ == ImplicitDim::kNone)))) {
return false;
}
}
if (tiling_ != other.tiling_) {
// Don't fail yet!
// If there is only one tile in both tilings, then they are equivalent.
if (shape.data() == nullptr) {
return false;
}
const SmallVector<int64_t> ishape = implicitShape(shape);
if (!(tiling_[1] == other.tiling_[1] &&
tiling_[1] == target_shape[1] &&
offsets_[1].value_or(0) + ishape[ishape.size() - 1] <=
target_shape[1] &&
offsets_[0].value_or(0) + ishape[ishape.size() - 2] <=
std::min(tiling_[0], other.tiling_[0]))) {
return false;
}
}
return true;
}

template <typename Stream>
void VectorLayout::print(Stream& os) const {
os << static_cast<int32_t>(bitwidth_) << ",{";
Expand Down Expand Up @@ -146,7 +315,8 @@ std::optional<VectorLayout> VectorLayout::parse(llvm::StringRef* data) {
}

namespace {
template<class> inline constexpr bool false_v = false;
template <class>
inline constexpr bool false_v = false;

template <typename Stream>
Stream& printLayout(Stream& os, const Layout& v) {
Expand All @@ -170,6 +340,10 @@ llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const Layout& v) {
return printLayout<llvm::raw_ostream>(os, v);
}

mlir::Diagnostic& operator<<(mlir::Diagnostic& diag, const Layout& v) {
return printLayout<mlir::Diagnostic>(diag, v);
}

llvm::hash_code hash_value(const VectorLayout& layout) {
return llvm::hash_value(layout.as_tuple());
}
Expand Down

0 comments on commit d02b59e

Please sign in to comment.