Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions mlir/include/mlir/Dialect/Tosa/IR/TargetEnv.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
/// and provide utilities around the TOSA specification version.
class TosaSpecificationVersion {
public:
TosaSpecificationVersion() = default;

TosaSpecificationVersion(uint32_t major, uint32_t minor)
: majorVersion(major), minorVersion(minor) {}
TosaSpecificationVersion(SpecificationVersion version)
Expand Down Expand Up @@ -83,6 +85,10 @@ class TosaSpecificationVersion {
}
};

TosaSpecificationVersion getMinVersion(const Profile &profile);
TosaSpecificationVersion getMinVersion(const Extension &extension);
TosaSpecificationVersion getMinVersion(const Level &level);

llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);

/// This class represents the capability enabled in the target implementation
Expand All @@ -91,22 +97,19 @@ llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version);
class TargetEnv {
public:
TargetEnv() {}
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
: specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}

explicit TargetEnv(TargetEnvAttr targetAttr)
: TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
targetAttr.getProfiles(), targetAttr.getExtensions()) {}
static FailureOr<TargetEnv>
createTargetEnvFromAttr(TargetEnvAttr targetAttr, Location targetEnvAttrLoc);

static LogicalResult verifyTargetInformation(TargetEnvAttr targetAttr,
Location targetAttrLoc);

void addProfile(Profile p) { enabledProfiles.insert(p); }
void addExtension(Extension e) { enabledExtensions.insert(e); }

SpecificationVersion getSpecVersion() const { return specificationVersion; }
TosaSpecificationVersion getSpecVersion() const {
return specificationVersion;
}

TosaLevel getLevel() const {
if (level == Level::eightK)
Expand Down Expand Up @@ -140,7 +143,17 @@ class TargetEnv {
}

private:
SpecificationVersion specificationVersion;
// Require target information is verified before constructing, via the use of
// `createTargetEnvFromAttr`.
explicit TargetEnv(SpecificationVersion specificationVersion, Level level,
const ArrayRef<Profile> &profiles,
const ArrayRef<Extension> &extensions)
: specificationVersion(specificationVersion), level(level) {
enabledProfiles.insert_range(profiles);
enabledExtensions.insert_range(extensions);
}

TosaSpecificationVersion specificationVersion;
Level level;
llvm::SmallSet<Profile, 3> enabledProfiles;
llvm::SmallSet<Extension, 13> enabledExtensions;
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,18 @@ extensionComplianceMap = {
allOf},
{{Extension::bf16},
{{{bf16T, bf16T, bf16T, bf16T, fp32T}, SpecificationVersion::V_1_0}}}}},
{"tosa.matmul_t_block_scaled",
{{{Extension::mxfp},
{{{fp4e2m1T, fp8ue8m0T, fp4e2m1T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT},
{{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT},
{{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT},
{{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT},
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T},
SpecificationVersion::V_1_1_DRAFT}}}}},
{"tosa.max_pool2d",
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
{{Extension::fp8e4m3},
Expand Down
19 changes: 16 additions & 3 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -270,21 +270,22 @@ def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>;
def Tosa_EXT_DOUBLEROUND : I32EnumAttrCase<"doubleround", 9>;
def Tosa_EXT_INEXACTROUND : I32EnumAttrCase<"inexactround", 10>;
def Tosa_EXT_DYNAMIC : I32EnumAttrCase<"dynamic", 11>;
def Tosa_EXT_MXFP : I32EnumAttrCase<"mxfp", 12>;

def Tosa_ExtensionAttr
: Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [
Tosa_EXT_NONE, Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16,
Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE,
Tosa_EXT_CONTROLFLOW, Tosa_EXT_DOUBLEROUND, Tosa_EXT_INEXACTROUND,
Tosa_EXT_DYNAMIC
Tosa_EXT_DYNAMIC, Tosa_EXT_MXFP
]> {
let extraClassDeclaration = [{
static llvm::SmallVector<Extension, 11> getAllValues() {
return {
Extension::int16, Extension::int4, Extension::bf16,
Extension::fp8e4m3, Extension::fp8e5m2, Extension::fft,
Extension::variable, Extension::controlflow, Extension::doubleround,
Extension::inexactround, Extension::dynamic
Extension::inexactround, Extension::dynamic, Extension::mxfp
};
}
}];
Expand Down Expand Up @@ -437,7 +438,7 @@ def Tosa_TargetEnv : Tosa_Attr<"TargetEnv", "target_env"> {
}

//===----------------------------------------------------------------------===//
// Iterable attributes.
// Enum attributes.
//===----------------------------------------------------------------------===//
// Defined in `section 3. Enumerations` of the TOSA specification.

Expand All @@ -463,6 +464,18 @@ def Tosa_RoundingModeAttr
: Tosa_I32EnumAttr<"RoundingMode", "Supported rounding modes", "rounding_mode",
[Tosa_ROUNDING_SINGLE_ROUND, Tosa_ROUNDING_INEXACT_ROUND, Tosa_ROUNDING_DOUBLE_ROUND]>;

def Tosa_BLOCK_SIZE_32 : I32EnumAttrCase<"BLOCK_SIZE_32", 32>;

def Tosa_BlockSizeAttr
: Tosa_I32EnumAttr<"BlockSize", "Block size for the block_scaled formats", "block_size",
[Tosa_BLOCK_SIZE_32]> {
let extraClassDeclaration = [{
static uint32_t getBlockSizeValue(BlockSize blockSize) {
return static_cast<uint32_t>(blockSize);
}
}];
}


//===----------------------------------------------------------------------===//
// TOSA Interfaces.
Expand Down
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,40 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
"operands attr-dict `:` functional-type(operands, results)";
}

//===----------------------------------------------------------------------===//
// Operator: matmul_t_block_scaled
//===----------------------------------------------------------------------===//
def Tosa_MatmulTBlockScaledOp : Tosa_InferShapedTypeOp<"matmul_t_block_scaled"> {
let summary = "Performs two dimensional matrix multiplications using block scaled tensors.";

let description = [{
Performs two dimensional matrix multiplications using block scaled tensors. The block
dimension is always the the last dimension of the tensor, so the result is effectively
a matrix multiply of A by the transposed B matrix. If the N dimension of input B is of
size 1, the B matrix will be broadcast.
}];

let arguments = (ins
Tosa_MXFPDataTensor3D:$a_data,
Tosa_MXFPScaleTensor3D:$a_scale,
Tosa_MXFPDataTensor3D:$b_data,
Tosa_MXFPScaleTensor3D:$b_scale,
Tosa_BlockSizeAttr:$block_size
);

let results = (outs
Tosa_Tensor3D:$output_data
);

let hasVerifier = 1;
let hasCustomAssemblyFormat = 1;

list<Availability> availability = [
Profile<[Tosa_PRO_FP]>,
Extension<[Tosa_EXT_MXFP]>
];
}

//===----------------------------------------------------------------------===//
// Operator: max_pool2d
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class TosaProfileCompliance {
case Extension::fp8e4m3:
case Extension::fp8e5m2:
case Extension::fft:
case Extension::mxfp:
return {Profile::pro_fp};
case Extension::variable:
case Extension::controlflow:
Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
"number">;

def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
"micro-scaling format number">;
def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;

//===----------------------------------------------------------------------===//
// TOSA Tensor Conformance
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -187,6 +191,15 @@ def Tosa_Int32Tensor2D : AnyTypeOf<[
def Tosa_TensorAtLeast1D : AnyTypeOf<[
Tosa_UnrankedTensor, TosaRankedTensorOf<[Tosa_AnyNumber], [AtLeastRankOne]>], "tosa-conformant tensor of at least rank 1", "::mlir::TensorType">;

def Tosa_MXFPDataTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPNumber]>,
TosaTensorRankOf<[Tosa_MXFPNumber], [3]>
]>;
def Tosa_MXFPScaleTensor3D : AnyTypeOf<[
TosaUnrankedTensorOf<[Tosa_MXFPScaleNumber]>,
TosaTensorRankOf<[Tosa_MXFPScaleNumber], [3]>
]>;

//===----------------------------------------------------------------------===//
// Generic scalar, vector, or tensor of a particular type.
//===----------------------------------------------------------------------===//
Expand Down
94 changes: 90 additions & 4 deletions mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,96 @@
namespace mlir {
namespace tosa {

llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
}

TosaSpecificationVersion getMinVersion(const Profile &profile) {
switch (profile) {
case Profile::pro_int:
case Profile::pro_fp:
return TosaSpecificationVersion(1, 0);
case Profile::none:
return TosaSpecificationVersion(0, 0);
}
llvm_unreachable("Unknown TOSA profile");
}

TosaSpecificationVersion getMinVersion(const Extension &extension) {
switch (extension) {
case Extension::int16:
case Extension::int4:
case Extension::bf16:
case Extension::fp8e4m3:
case Extension::fp8e5m2:
case Extension::fft:
case Extension::variable:
case Extension::controlflow:
case Extension::doubleround:
case Extension::inexactround:
case Extension::dynamic:
return TosaSpecificationVersion(1, 0);
case Extension::mxfp:
return TosaSpecificationVersion(1, 1);
case Extension::none:
return TosaSpecificationVersion(0, 0);
}
llvm_unreachable("Unknown TOSA extension");
}

TosaSpecificationVersion getMinVersion(const Level &level) {
switch (level) {
case Level::eightK:
case Level::none:
return TosaSpecificationVersion(1, 0);
}
llvm_unreachable("Unknown TOSA level");
}

FailureOr<TargetEnv>
TargetEnv::createTargetEnvFromAttr(TargetEnvAttr targetAttr,
Location targetEnvAttrLoc) {
if (failed(verifyTargetInformation(targetAttr, targetEnvAttrLoc)))
return failure();

return TargetEnv(targetAttr.getSpecificationVersion(), targetAttr.getLevel(),
targetAttr.getProfiles(), targetAttr.getExtensions());
}

LogicalResult TargetEnv::verifyTargetInformation(TargetEnvAttr targetAttr,
Location targetAttrLoc) {
TosaSpecificationVersion targetVersion(targetAttr.getSpecificationVersion());

const auto isCompatibleWithTargetVersion =
[&](const auto &targetEnum, Location targetAttrLoc,
StringRef enumName) -> LogicalResult {
const TosaSpecificationVersion minRequiredVersion =
getMinVersion(targetEnum);
if (!targetVersion.isBackwardsCompatibleWith(minRequiredVersion))
return emitError(targetAttrLoc, enumName)
<< " '" << stringifyEnum(targetEnum)
<< "' is not compatible with the target version "
<< stringifyVersion(targetVersion)
<< ", minimum required version is "
<< stringifyVersion(minRequiredVersion);
return success();
};

for (const auto &profile : targetAttr.getProfiles())
if (failed(
isCompatibleWithTargetVersion(profile, targetAttrLoc, "profile")))
return failure();
for (const auto &extension : targetAttr.getExtensions())
if (failed(isCompatibleWithTargetVersion(extension, targetAttrLoc,
"extension")))
return failure();
if (failed(isCompatibleWithTargetVersion(targetAttr.getLevel(), targetAttrLoc,
"level")))
return failure();

return success();
}

TargetEnvAttr lookupTargetEnv(Operation *op) {
while (op) {
op = SymbolTable::getNearestSymbolTable(op);
Expand Down Expand Up @@ -39,9 +129,5 @@ TargetEnvAttr lookupTargetEnvOrDefault(Operation *op) {
return getDefaultTargetEnv(op->getContext());
}

llvm::SmallString<4> stringifyVersion(TosaSpecificationVersion version) {
return llvm::formatv("{0}.{1}", version.getMajor(), version.getMinor());
}

} // namespace tosa
} // namespace mlir
Loading