Skip to content

Commit

Permalink
[Mosaic] Expose C API for VectorLayout, VRegDataBounds
Browse files Browse the repository at this point in the history
This is in preparation for Python bindings

PiperOrigin-RevId: 579355000
  • Loading branch information
tlongeri authored and jax authors committed Nov 4, 2023
1 parent 953f467 commit 1c1dd7c
Show file tree
Hide file tree
Showing 3 changed files with 418 additions and 6 deletions.
4 changes: 4 additions & 0 deletions jaxlib/mosaic/BUILD
Expand Up @@ -186,8 +186,12 @@ cc_library(
deps = [
":tpu_dialect",
":tpu_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:CAPIIR",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
],
)

Expand Down
265 changes: 260 additions & 5 deletions jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc
Expand Up @@ -15,13 +15,114 @@ limitations under the License.

#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"

#include "mlir/include/mlir/CAPI/Pass.h"
#include "mlir/include/mlir/CAPI/Registration.h"
#include "mlir/include/mlir/CAPI/Support.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
#include <array>
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <utility>

#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MemAlloc.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

// TODO(tlongeri): null pointer checks?

namespace {
DEFINE_C_API_PTR_METHODS(MlirTpuVectorLayout, mlir::tpu::VectorLayout);
DEFINE_C_API_PTR_METHODS(MlirTpuVregDataBounds, mlir::tpu::VRegDataBounds);

MlirTpuImplicitDim wrap(mlir::tpu::VectorLayout::ImplicitDim implicit_dim) {
switch (implicit_dim) {
case mlir::tpu::VectorLayout::ImplicitDim::kNone:
return MlirTpuImplicitDimNone;
case mlir::tpu::VectorLayout::ImplicitDim::kMinor:
return MlirTpuImplicitDimMinor;
case mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor:
return MlirTpuImplicitDimSecondMinor;
}
LOG(FATAL) << "Invalid implicit dim (C++)";
}
mlir::tpu::VectorLayout::ImplicitDim unwrap(MlirTpuImplicitDim implicit_dim) {
switch (implicit_dim) {
case MlirTpuImplicitDimNone:
return mlir::tpu::VectorLayout::ImplicitDim::kNone;
case MlirTpuImplicitDimMinor:
return mlir::tpu::VectorLayout::ImplicitDim::kMinor;
case MlirTpuImplicitDimSecondMinor:
return mlir::tpu::VectorLayout::ImplicitDim::kSecondMinor;
}
LOG(FATAL) << "Invalid implicit dim (C)";
}
mlir::tpu::Direction unwrap(MlirTpuDirection direction) {
switch (direction) {
case MlirTpuDirectionSublanes:
return mlir::tpu::Direction::kSublanes;
case MlirTpuImplicitDimMinor:
return mlir::tpu::Direction::kLanes;
case MlirTpuImplicitDimSecondMinor:
return mlir::tpu::Direction::kSubelements;
}
LOG(FATAL) << "Invalid direction (C)";
}
MlirTpuLayoutOffsets wrap(mlir::tpu::LayoutOffsets offsets) {
return {offsets[0].value_or(-1), offsets[1].value_or(-1)};
}
mlir::tpu::LayoutOffsets unwrap(MlirTpuLayoutOffsets offsets) {
auto translateOffset = [](int64_t offset) {
CHECK_GE(offset, -1);
return offset == -1 ? std::nullopt : mlir::tpu::LayoutOffset{offset};
};
return {translateOffset(offsets.sublane), translateOffset(offsets.lane)};
}
std::array<bool, 2> unwrap(MlirTpuBoolTargetTuple arr) {
return {arr.sublane, arr.lane};
}
std::array<int64_t, 2> unwrap(MlirTpuI64TargetTuple arr) {
return {arr.sublane, arr.lane};
}
MlirTpuI64TargetTuple wrap(std::array<int64_t, 2> arr) {
return {arr[0], arr[1]};
}

mlir::OpBuilder mlirTpuInsertionPointToOpBuilder(
MlirTpuInsertionPoint insertion_point) {
mlir::Operation *ref_operation = unwrap(insertion_point.ref_operation);
return ref_operation == nullptr
? mlir::OpBuilder::atBlockEnd(unwrap(insertion_point.block))
: mlir::OpBuilder(ref_operation);
}

// We do not use the names wrap/unwrap for MlirTpuI64ArrayRef because whether
// they should refer to SmallVector or ArrayRef is ambiguous
MlirTpuI64ArrayRef mlirTpuI64ArrayRefFromLlvmSmallVector(
const mlir::SmallVector<int64_t> &vec) {
// TODO(tlongeri): It would be good to steal the buffer from implicit_shape,
// but there are no public member functions for this.
int64_t *ptr =
static_cast<int64_t *>(llvm::safe_malloc(vec.size() * sizeof(int64_t)));
memcpy(ptr, vec.data(), vec.size() * sizeof(int64_t));
return {ptr, vec.size()};
}
llvm::ArrayRef<int64_t> mlirTpuI64ArrayRefToLlvmArrayRef(
MlirTpuI64ArrayRef tpu_array_ref) {
return {tpu_array_ref.ptr, tpu_array_ref.size};
}
} // namespace

extern "C" {

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(TPU, tpu, mlir::tpu::TPUDialect);
Expand Down Expand Up @@ -51,6 +152,160 @@ void mlirTPUAnalyzePotentialCommunication(MlirOperation op,
*has_custom_barrier = result.second;
}

MlirTpuVectorLayout mlirTpuVectorLayoutCreate(int bitwidth,
MlirTpuLayoutOffsets offsets,
MlirTpuI64TargetTuple tiling,
MlirTpuImplicitDim implicit_dim) {
return wrap(new mlir::tpu::VectorLayout(
bitwidth, unwrap(offsets), unwrap(tiling), unwrap(implicit_dim)));
}

void mlirTpuVectorLayoutDestroy(MlirTpuVectorLayout layout) {
delete unwrap(layout);
}

int mlirTpuVectorLayoutGetBitwidth(MlirTpuVectorLayout layout) {
return unwrap(layout)->bitwidth();
}

MlirTpuLayoutOffsets mlirTpuVectorLayoutGetOffsets(MlirTpuVectorLayout layout) {
return wrap(unwrap(layout)->offsets());
}

MlirTpuI64TargetTuple mlirTpuVectorLayoutGetTiling(MlirTpuVectorLayout layout) {
return wrap(unwrap(layout)->tiling());
}

MlirTpuImplicitDim mlirTpuVectorLayoutGetImplicitDim(
MlirTpuVectorLayout layout) {
return wrap(unwrap(layout)->implicit_dim());
}

int mlirTpuVectorLayoutGetPacking(MlirTpuVectorLayout layout) {
return unwrap(layout)->packing();
}

int mlirTpuVectorLayoutGetLayoutRank(MlirTpuVectorLayout layout) {
return unwrap(layout)->layout_rank();
}

bool mlirTpuVectorLayoutEquals(MlirTpuVectorLayout lhs,
MlirTpuVectorLayout rhs) {
return *unwrap(lhs) == *unwrap(rhs);
}

int64_t mlirTpuVectorLayoutTilesPerVreg(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->tilesPerVreg(unwrap(target_shape));
}

int64_t mlirTpuVectorLayoutSublanesPerTile(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->sublanesPerTile(unwrap(target_shape));
}

MlirTpuI64TargetTuple mlirTpuVectorLayoutVregSlice(
MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape) {
return wrap(unwrap(layout)->vregSlice(unwrap(target_shape)));
}

MlirTpuI64ArrayRef mlirTpuVectorLayoutImplicitShape(MlirTpuVectorLayout layout,
MlirTpuI64ArrayRef shape) {
mlir::SmallVector<int64_t> implicit_shape =
unwrap(layout)->implicitShape(mlirTpuI64ArrayRefToLlvmArrayRef(shape));
return mlirTpuI64ArrayRefFromLlvmSmallVector(implicit_shape);
}

MlirTpuI64ArrayRef mlirTpuVectorLayoutTileArrayShape(
MlirTpuVectorLayout layout, MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape) {
mlir::SmallVector<int64_t> tile_array_shape = unwrap(layout)->tileArrayShape(
mlirTpuI64ArrayRefToLlvmArrayRef(shape), unwrap(target_shape));
return mlirTpuI64ArrayRefFromLlvmSmallVector(tile_array_shape);
}

MlirTpuVregDataBounds mlirTpuVectorLayoutTileDataBounds(
MlirTpuVectorLayout layout, MlirContext ctx, int64_t *full_shape,
int64_t *idxs, size_t size, MlirTpuI64TargetTuple target_shape,
MlirTpuBoolTargetTuple allow_replicated) {
std::unique_ptr<mlir::tpu::VRegDataBounds> ptr =
unwrap(layout)->tileDataBounds(
unwrap(ctx), llvm::ArrayRef<int64_t>{full_shape, size},
llvm::ArrayRef<int64_t>{idxs, size}, unwrap(target_shape),
unwrap(allow_replicated));
return wrap(ptr.release());
}

bool mlirTpuVectorLayoutHasNaturalTopology(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->hasNaturalTopology(unwrap(target_shape));
}

bool mlirTpuVectorLayoutHasNativeTiling(MlirTpuVectorLayout layout,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->hasNativeTiling(unwrap(target_shape));
}

bool mlirTpuVectorLayoutGeneralizes(MlirTpuVectorLayout layout,
MlirTpuVectorLayout other,
MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->generalizes(*unwrap(other),
mlirTpuI64ArrayRefToLlvmArrayRef(shape),
unwrap(target_shape));
}

bool mlirTpuVectorLayoutEquivalentTo(MlirTpuVectorLayout layout,
MlirTpuVectorLayout other,
MlirTpuI64ArrayRef shape,
MlirTpuI64TargetTuple target_shape) {
return unwrap(layout)->equivalentTo(*unwrap(other),
mlirTpuI64ArrayRefToLlvmArrayRef(shape),
unwrap(target_shape));
}

void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) {
delete unwrap(data_bounds);
}

bool mlirTpuVregDataBoundsMaskVariesAlong(MlirTpuVregDataBounds data_bounds,
MlirTpuDirection direction,
MlirTpuI64TargetTuple target_shape) {
return unwrap(data_bounds)
->maskVariesAlong(unwrap(direction), unwrap(target_shape));
}

bool mlirTpuVregDataBoundsIsComplete(MlirTpuVregDataBounds data_bounds,
MlirTpuI64TargetTuple target_shape) {
return unwrap(data_bounds)->isComplete(unwrap(target_shape));
}

MlirValue mlirTpuVregDataBoundsGetVectorMask(
MlirTpuVregDataBounds data_bounds, MlirTpuInsertionPoint insertion_point,
MlirLocation location, int generation, MlirTpuI64TargetTuple target_shape) {
mlir::OpBuilder builder = mlirTpuInsertionPointToOpBuilder(insertion_point);
auto failure_or_mask = unwrap(data_bounds)
->getVectorMask(builder, unwrap(location),
generation, unwrap(target_shape));
if (failed(failure_or_mask)) {
return wrap(mlir::Value());
} else {
return wrap(failure_or_mask.value());
}
}

MlirAttribute mlirTpuVregDataBoundsGetSublaneMask(
MlirTpuVregDataBounds data_bounds, MlirContext ctx,
MlirTpuI64TargetTuple target_shape) {
return wrap(
unwrap(data_bounds)->getSublaneMask(unwrap(ctx), unwrap(target_shape)));
}
}

#include "mlir/CAPI/Pass.h" // IWYU pragma: keep
#include "mlir/CAPI/Support.h" // IWYU pragma: keep

extern "C" {
using namespace mlir::tpu;

#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_passes.capi.cc.inc"
Expand Down

0 comments on commit 1c1dd7c

Please sign in to comment.