Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 30e07df5e704ef85668093b7350bfdff1a24a7c8
Author: Captain Jack Sparrow <ritdas@microsoft.com>
Date:   Mon Apr 3 20:38:05 2023 +0000

    Merged PR 3199: Rename _slice to slice and add docs

    Rename _slice to slice and add docs

commit 52491f28481ec9ca555c563eaca249ce7d621ad1
Author: Captain Jack Sparrow <ritdas@microsoft.com>
Date:   Mon Apr 3 06:05:52 2023 +0000

    Merged PR 3197: Preserve dest memref shape during SliceOp to SubViewOp lowering

    Preserve dest memref shape during SliceOp to SubViewOp lowering:

    Without this change, subview op would discard the dest memref type required by the slice op. For example,

    ```
    %7 = "accv.slice"(%arg0, %6) {sliceDimensions = [0]} : (memref<1x30x256xui8>, index) -> memref<30x256xui8, affine_map<...>>
    ```

    would get lowered to:

    ```
    %4 = memref.subview %arg0[%3, 0, 0] [1, 30, 256] [1, 1, 1] : memref<1x30x256xui8> to memref<1x30x256xui8, affine_map<...>>
    %5 = memref.cast %4 : memref<1x30x256xui8, affine_map<...>> to memref<?x?x?xui8, affine_map<...>>
    ```
    which does not drop the first dimension as expected. With this fix, the slice op correctly lowers to:
    ```
    %4 = memref.subview %arg0[%3, 0, 0] [1, 30, 256] [1, 1, 1] : memref<1x30x256xui8> to memref<30x256xui8, affine_map<...>>
    %5 = memref.cast %4 : memref<30x256xui8, affine_map<...>> to memref<30x256xui8, affine_map<...>>
    ```

commit 79b6fba2b083e4f38b4b9b5f86d134ebbaf604de
Author: Denny Sun <dennys@microsoft.com>
Date:   Mon Apr 3 01:02:02 2023 +0000

    Merged PR 3194: Reorder the ops in GetTimeOpLowering to improve the timing accuracy

    In order to get the most accurate timing, we need to order the operations more appropriately,

    ```
    from
            Independent logic
            GetTime()
            Independent logic
            Main logic to profile
            Independent logic
            GetTime()
            Independent logic

    to

            Independent logic
            Independent logic
            GetTime()
            Main logic  to profile
            GetTime()
            Independent logic
            Independent logic
    ```

commit a24f82d514d8ebd5b06de4f5c36d2a13601f4ebe
Author: Denny Sun <dennys@microsoft.com>
Date:   Thu Mar 30 03:47:56 2023 +0000

    Merged PR 3187: Fully dynamic split_dimension op

    This change enable Accera to be able to split a dynamic dimension by a dynamic size

    ```
    `       M, N, MN = create_dimensions()

            Input = Array(role=Role.INPUT, element_type=ScalarType.float32, shape=(MN, ))
            Output = Array(role=Role.INPUT_OUTPUT, element_type=ScalarType.float32, shape=(M, N))

            nest = Nest(shape=(M, N))
            i, j = nest.get_indices()

            @nest.iteration_logic
            def _():
                split_input = Input._split_dimension(0, N)
                Output[i, j] = split_input[i, j]

           package.add(nest, args=(MN, M, N, Input, Output), base_name=f"{test_name}_fn")`
    ```

commit fe1955c975c3597afd6167203a4c9b7ef7cf4d9b
Author: Kern Handa <kerha@microsoft.com>
Date:   Wed Mar 29 21:18:06 2023 +0000

    Merged PR 3185: [nfc] Adds tests for vectorization, fast_exp_sum

commit 0f7daceebbfb7382c64678a624955c3c06e81765
Author: Captain Jack Sparrow <ritdas@microsoft.com>
Date:   Wed Mar 29 05:38:53 2023 +0000

    Merged PR 3168: [docs] Tensorization tutorials and type name updates
  • Loading branch information
Ritwik Das committed Apr 4, 2023
1 parent c42ca38 commit 8affe97
Show file tree
Hide file tree
Showing 64 changed files with 3,294 additions and 945 deletions.
440 changes: 425 additions & 15 deletions accera/acc-opt/test/vectorization.mlir

Large diffs are not rendered by default.

Expand Up @@ -109,7 +109,7 @@ namespace cpp_printer
}

auto memRefType = allocMatrixOp.result().getType().cast<MemRefType>();
const vir::MMAOp mfmaOpType{ static_cast<vir::MMAShape>(allocMatrixOp.mmaShapeType()) };
const vir::MMAOp mfmaOpType{ static_cast<vir::MMAShapeType>(allocMatrixOp.mmaShapeType()) };
const auto shape = std::make_tuple(mfmaOpType.getM(), mfmaOpType.getN(), mfmaOpType.getK());
const vir::MMAOperandType opType{ allocMatrixOp.operandType() };
const auto rowMajor = allocMatrixOp.rowMajor();
Expand All @@ -135,7 +135,7 @@ namespace cpp_printer

const auto operandType = static_cast<vir::MMAOperandType>(loadMatrixOp.operandType());

return printLoadMatrixOp(state, printer, loadMatrixOp.memref(), loadMatrixOp.dest(), operandType, loadMatrixOp.indices(), loadMatrixOp.rowMajor(), loadMatrixOp.blockThreadId(), loadMatrixOp.staticOffsets(), static_cast<vir::MMAFragmentOp>(loadMatrixOp.mmaPrologueOp()), loadMatrixOp.mmaPrologueArg());
return printLoadMatrixOp(state, printer, loadMatrixOp.memref(), loadMatrixOp.dest(), operandType, loadMatrixOp.indices(), loadMatrixOp.rowMajor(), loadMatrixOp.blockThreadId(), loadMatrixOp.staticOffsets(), static_cast<vir::MMAFragmentOpType>(loadMatrixOp.mmaPrologueOp()), loadMatrixOp.mmaPrologueArg());
}

LogicalResult AcceraDialectCppPrinter::printOp(vir::MMAComputeSyncOp computeMatrixOp)
Expand All @@ -155,7 +155,7 @@ namespace cpp_printer
return storeMatrixOp.emitError("non-cuda version is not supported.");
}

return printStoreMatrixOp(state, printer, storeMatrixOp.src(), storeMatrixOp.memref(), storeMatrixOp.indices(), storeMatrixOp.blockThreadId(), storeMatrixOp.staticOffsets(), static_cast<vir::MMAFragmentOp>(storeMatrixOp.mmaEpilogueOp()), storeMatrixOp.mmaEpilogueArg());
return printStoreMatrixOp(state, printer, storeMatrixOp.src(), storeMatrixOp.memref(), storeMatrixOp.indices(), storeMatrixOp.blockThreadId(), storeMatrixOp.staticOffsets(), static_cast<vir::MMAFragmentOpType>(storeMatrixOp.mmaEpilogueOp()), storeMatrixOp.mmaEpilogueArg());
}

LogicalResult AcceraDialectCppPrinter::printVectorType(mlir::Type elementType, const uint32_t stride) const
Expand Down Expand Up @@ -214,7 +214,7 @@ namespace cpp_printer
const auto wpt = blockLoadOp.workPerThread();
const auto vecWidth = blockLoadOp.vecWidth();
const auto stride = std::min(wpt, vecWidth);
const auto strategy = stringifyCacheStrategy(blockLoadOp.strategy());
const auto strategy = stringifyCacheStrategyType(blockLoadOp.strategy());

if (!blockLoadOp.srcToDst())
{
Expand Down
20 changes: 10 additions & 10 deletions accera/acc-translate/src/Target/Cpp/CppPrinterUtils.cpp
Expand Up @@ -187,7 +187,7 @@ namespace cpp_printer
return success();
}

LogicalResult printFragmentOp(CppPrinter* printer, Type elementType, const StringRef& fragName, const vir::MMAFragmentOp mmaFragmentOp, const StringRef& mmaFragmentArg)
LogicalResult printFragmentOp(CppPrinter* printer, Type elementType, const StringRef& fragName, const vir::MMAFragmentOpType mmaFragmentOp, const StringRef& mmaFragmentArg)
{
auto&& os = printer->getOStream();
auto fragDataName = fragName + "_data";
Expand All @@ -202,16 +202,16 @@ namespace cpp_printer
os << "); ++i) { ";
switch (mmaFragmentOp)
{
case vir::MMAFragmentOp::ReLU:
case vir::MMAFragmentOpType::ReLU:
os << "relu(" << fragDataName << "[i]);";
break;
case vir::MMAFragmentOp::ReLU_NoConditional:
case vir::MMAFragmentOpType::ReLU_NoConditional:
os << "relu_no_conditional(" << fragDataName << "[i]);";
break;
case vir::MMAFragmentOp::Set:
case vir::MMAFragmentOpType::Set:
os << "set(" << fragDataName << "[i], " << mmaFragmentArg << ");";
break;
case vir::MMAFragmentOp::Scale:
case vir::MMAFragmentOpType::Scale:
os << "scale(" << fragDataName << "[i], " << mmaFragmentArg << ");";
break;
default:
Expand All @@ -222,9 +222,9 @@ namespace cpp_printer
return success();
}

LogicalResult printLoadMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, const vir::MMAOperandType operandType, mlir::Operation::operand_range indices, bool rowMajor, Value blockTid, const bool useStaticOffsets, const vir::MMAFragmentOp mmaPrologueOp, Value mmaPrologueArg)
LogicalResult printLoadMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, const vir::MMAOperandType operandType, mlir::Operation::operand_range indices, bool rowMajor, Value blockTid, const bool useStaticOffsets, const vir::MMAFragmentOpType mmaPrologueOp, Value mmaPrologueArg)
{
if (mmaPrologueOp == vir::MMAFragmentOp::Set)
if (mmaPrologueOp == vir::MMAFragmentOpType::Set)
{
return printConstantMatrixOp(state, printer, dest, mmaPrologueArg);
}
Expand Down Expand Up @@ -264,7 +264,7 @@ namespace cpp_printer
}
os << ")";

if (mmaPrologueOp != vir::MMAFragmentOp::None)
if (mmaPrologueOp != vir::MMAFragmentOpType::None)
{
os << ";\n";

Expand Down Expand Up @@ -292,7 +292,7 @@ namespace cpp_printer
return success();
}

LogicalResult printStoreMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, mlir::Operation::operand_range indices, Value blockTid, const bool useStaticOffsets, const vir::MMAFragmentOp mmaEpilogueOp, Value mmaEpilogueArg)
LogicalResult printStoreMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, mlir::Operation::operand_range indices, Value blockTid, const bool useStaticOffsets, const vir::MMAFragmentOpType mmaEpilogueOp, Value mmaEpilogueArg)
{
auto fragName = state.nameState.getName(src);

Expand All @@ -310,7 +310,7 @@ namespace cpp_printer
const auto dstMemrefStr = getMemrefAccessStr(printer, sharedMem, memRefType, state.nameState.getName(dest).str(), indices);
auto&& os = printer->getOStream();

if (mmaEpilogueOp != vir::MMAFragmentOp::None)
if (mmaEpilogueOp != vir::MMAFragmentOpType::None)
{
auto srcElementType = src.getType().cast<MemRefType>().getElementType();
auto mmaEpilogueArgName = state.nameState.getOrCreateName(mmaEpilogueArg, SSANameState::SSANameKind::Variable, "mmaEpilogueArg_");
Expand Down
4 changes: 2 additions & 2 deletions accera/acc-translate/src/Target/Cpp/CppPrinterUtils.h
Expand Up @@ -24,9 +24,9 @@ namespace cpp_printer

LogicalResult printMMAMatrixOp(PrinterState& state, CppPrinter* printer, Type elementType, const std::tuple<int, int, int>& matrixShape, Value dest, vir::MMAOperandType operandType, int totalBlocks, int blocks, bool rowMajor);
LogicalResult printConstantMatrixOp(PrinterState& state, CppPrinter* printer, Value dest, Value value);
LogicalResult printLoadMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, vir::MMAOperandType operandType, mlir::Operation::operand_range indices, bool rowMajor, Value blockTid = {}, bool useStaticOffsets = {}, vir::MMAFragmentOp mmaPrologueOp = vir::MMAFragmentOp::None, Value mmaPrologueArg = {});
LogicalResult printLoadMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, vir::MMAOperandType operandType, mlir::Operation::operand_range indices, bool rowMajor, Value blockTid = {}, bool useStaticOffsets = {}, vir::MMAFragmentOpType mmaPrologueOp = vir::MMAFragmentOpType::None, Value mmaPrologueArg = {});
LogicalResult printComputeMatrixOp(PrinterState& state, CppPrinter* printer, Value A, Value B, Value C, Value D, int cbsz = 0, int abid = 0, int blgp = 0);
LogicalResult printStoreMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, mlir::Operation::operand_range indices, Value blockTid = {}, bool useStaticOffsets = {}, vir::MMAFragmentOp mmaEpilogueOp = vir::MMAFragmentOp::None, Value mmaEpilogueArg = {});
LogicalResult printStoreMatrixOp(PrinterState& state, CppPrinter* printer, Value src, Value dest, mlir::Operation::operand_range indices, Value blockTid = {}, bool useStaticOffsets = {}, vir::MMAFragmentOpType mmaEpilogueOp = vir::MMAFragmentOpType::None, Value mmaEpilogueArg = {});

constexpr auto HipIncludesAndTypes = R"ROCM(
#if defined(__HIP_PLATFORM_HCC__)
Expand Down
12 changes: 6 additions & 6 deletions accera/ir/include/exec/ExecutionPlanOps.td
Expand Up @@ -189,7 +189,7 @@ def accxp_ActiveBlockCacheCopyOp : accxp_Op<"active_block_cache_copy", [AttrSize
UnitAttr:$toCache,
UnitAttr:$thrifty,
UnitAttr:$readOnlyCache,
CacheStrategyAttr:$strategy,
CacheStrategyTypeAttr:$strategy,
UnitAttr:$skipBarriers, // TODO : remove this once barrier analysis hoists barriers out of loops
OptionalAttr<accxp_VectorizationInfoAttr>:$vectorizationInfo,
OptionalAttr<accxp_TensorizationInfoAttr>:$tensorizationInfo);
Expand Down Expand Up @@ -218,7 +218,7 @@ def accxp_MultiCacheCopyOp : accxp_Op<"multi_cache_copy"> {
AffineMapAttr:$activeBlockToCacheMap,
UnitAttr:$thrifty,
UnitAttr:$readOnlyCache,
CacheStrategyAttr:$strategy,
CacheStrategyTypeAttr:$strategy,
UnitAttr:$toCache,
OptionalAttr<accxp_VectorizationInfoAttr>:$vectorizationInfo,
OptionalAttr<accxp_TensorizationInfoAttr>:$tensorizationInfo);
Expand Down Expand Up @@ -389,7 +389,7 @@ def accxp_BeginCreateCacheOp : accxp_Op<"begin_create_cache",
UnitAttr:$activeBlockCache,
UnitAttr:$dimReorderCache,
UnitAttr:$thrifty,
CacheStrategyAttr:$strategy,
CacheStrategyTypeAttr:$strategy,
UnitAttr:$doubleBufferCache,
OptionalAttr<MemorySpaceAttr>:$doubleBufferMemorySpace,
OptionalAttr<accxp_VectorizationInfoAttr>:$vectorizationInfo);
Expand All @@ -409,7 +409,7 @@ def accxp_BeginCreateCacheOp : accxp_Op<"begin_create_cache",
"bool":$activeBlockCache,
"bool":$dimReorderCache,
"bool":$thrifty,
"value::CacheStrategy":$strategy,
"value::CacheStrategyType":$strategy,
"bool":$doubleBufferCache,
"MemorySpace":$doubleBufferMemorySpace,
"const VectorizationInfo&":$vectorizationInfo)>
Expand Down Expand Up @@ -484,7 +484,7 @@ def accxp_BeginCreateMaxElementCacheOp : accxp_Op<"begin_create_max_element_cach
I64Attr:$maxElements,
UnitAttr:$dimReorderCache,
UnitAttr:$thrifty,
CacheStrategyAttr:$strategy,
CacheStrategyTypeAttr:$strategy,
UnitAttr:$doubleBufferCache,
OptionalAttr<MemorySpaceAttr>:$doubleBufferMemorySpace,
OptionalAttr<accxp_VectorizationInfoAttr>:$vectorizationInfo);
Expand All @@ -504,7 +504,7 @@ def accxp_BeginCreateMaxElementCacheOp : accxp_Op<"begin_create_max_element_cach
"int64_t":$cacheHierarchyLevel,
"bool":$dimReorderCache,
"bool":$thrifty,
"value::CacheStrategy":$strategy,
"value::CacheStrategyType":$strategy,
"bool":$doubleBufferCache,
"MemorySpace":$doubleBufferMemorySpace,
"const VectorizationInfo&":$vectorizationInfo)>
Expand Down
8 changes: 4 additions & 4 deletions accera/ir/include/exec/TensorizationInfo.h
Expand Up @@ -14,14 +14,14 @@ namespace executionPlan
{
struct TensorizationInfo
{
accera::ir::value::MMAShape dim;
accera::ir::value::MMAShapeType dim;
int numTotalPasses{ 1 };
bool useStaticOffsets{};
int numFusedPasses{ -1 };
accera::ir::value::MMASchedulingPolicy schedulingPolicy{};
accera::ir::value::MMAFragmentOp prologueOp{};
accera::ir::value::MMASchedulingPolicyType schedulingPolicy{};
accera::ir::value::MMAFragmentOpType prologueOp{};
double prologueArg{};
accera::ir::value::MMAFragmentOp epilogueOp{};
accera::ir::value::MMAFragmentOpType epilogueOp{};
double epilogueArg{};
bool _useRocWMMA{};

Expand Down
4 changes: 2 additions & 2 deletions accera/ir/include/value/ValueAttrs.td
Expand Up @@ -85,8 +85,8 @@ def MemorySpaceAttr : I64EnumAttr<
def CACHE_STRATEGY_BLOCKED : I32EnumAttrCase<"Blocked", 0>;
def CACHE_STRATEGY_STRIPED : I32EnumAttrCase<"Striped", 1>;

def CacheStrategyAttr : I32EnumAttr<
"CacheStrategy", "An attribute containing a cache strategy type enum",
def CacheStrategyTypeAttr : I32EnumAttr<
"CacheStrategyType", "An attribute containing a cache strategy type enum",
[CACHE_STRATEGY_BLOCKED, CACHE_STRATEGY_STRIPED]> {
let cppNamespace = "::accera::ir::value";
}
Expand Down
12 changes: 6 additions & 6 deletions accera/ir/include/value/ValueMMAOp.h
Expand Up @@ -23,7 +23,7 @@ using mlir::Type;

constexpr auto MFMAThreadBufferMapName = "threadOffsetsMFMA";

enum class MMAShape
enum class MMAShapeType
{
// The shapes below refer to the dimensions of the matmul operation
// they perform. The B{N} refers to the number of blocks the operation
Expand Down Expand Up @@ -55,13 +55,13 @@ enum class MMAOperandType
Acc
};

enum class MMASchedulingPolicy
enum class MMASchedulingPolicyType
{
BlockOrder,
PassOrder
};

enum class MMAFragmentOp
enum class MMAFragmentOpType
{
None,
ReLU,
Expand All @@ -73,9 +73,9 @@ enum class MMAFragmentOp
class MMAOp
{
public:
MMAOp(MMAShape shape);
MMAOp(MMAShapeType shape);

MMAShape getShapeType() const;
MMAShapeType getShapeType() const;
int getM() const { return m; }
int getN() const { return n; }
int getK() const { return k; }
Expand All @@ -87,7 +87,7 @@ class MMAOp
std::vector<int64_t> getOperandShape(MMAOperandType operandType) const;

private:
MMAShape shape;
MMAShapeType shape;
int m{};
int n{};
int k{};
Expand Down
10 changes: 5 additions & 5 deletions accera/ir/include/value/ValueOps.td
Expand Up @@ -1281,7 +1281,7 @@ def accv_GPUBlockCacheOp : accv_Op<"gpu_block_cache",
BoolAttr:$srcToDst,
BoolAttr:$dstRowMajor,
ArrayAttr:$tileShape,
CacheStrategyAttr:$strategy,
CacheStrategyTypeAttr:$strategy,
I32Attr:$vecWidth,
I32Attr:$blockDimX,
I32Attr:$blockDimY,
Expand All @@ -1293,7 +1293,7 @@ def accv_GPUBlockCacheOp : accv_Op<"gpu_block_cache",
let results = (outs MemRefOf<[UI8, I8, I32, F16, BF16, F32]>:$result);

let builders = [
OpBuilder<(ins "Type":$resultType, "Value":$blockThreadId, "CacheStrategy":$strategy, "Value":$memref, "Value":$srcOffsetRows, "Value":$srcOffsetCols, "Value":$dest, "bool":$srcToDst, "bool":$dstRowMajor, "ArrayAttr":$tileShape, "int32_t":$vecWidth, "const std::array<int64_t, 3>&":$blockDim, "int32_t":$workPerThread), [{
OpBuilder<(ins "Type":$resultType, "Value":$blockThreadId, "CacheStrategyType":$strategy, "Value":$memref, "Value":$srcOffsetRows, "Value":$srcOffsetCols, "Value":$dest, "bool":$srcToDst, "bool":$dstRowMajor, "ArrayAttr":$tileShape, "int32_t":$vecWidth, "const std::array<int64_t, 3>&":$blockDim, "int32_t":$workPerThread), [{
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> () for zero-dimensional memrefs.
Expand All @@ -1303,7 +1303,7 @@ def accv_GPUBlockCacheOp : accv_Op<"gpu_block_cache",
$_state.addOperands(srcOffsetRows);
$_state.addOperands(srcOffsetCols);
$_state.addOperands(dest);
$_state.addAttribute("strategy", CacheStrategyAttr::get($_builder.getContext(), strategy));
$_state.addAttribute("strategy", CacheStrategyTypeAttr::get($_builder.getContext(), strategy));
$_state.addAttribute("tileShape", tileShape);
$_state.addAttribute("srcToDst", $_builder.getBoolAttr(srcToDst));
$_state.addAttribute("dstRowMajor", $_builder.getBoolAttr(dstRowMajor));
Expand Down Expand Up @@ -1442,7 +1442,7 @@ def accv_MMALoadSyncOp : accv_Op<"wmma_load_sync",
);

let builders = [
OpBuilder<(ins "Value":$blockThreadId, "Value":$memref, "Value":$dest, "MMAOperandType":$operandType, "bool":$rowMajor, "ValueRange":$indices, "bool":$staticOffsets, "MMAFragmentOp":$mmaPrologueOp, "Value":$mmaPrologueArg), [{
OpBuilder<(ins "Value":$blockThreadId, "Value":$memref, "Value":$dest, "MMAOperandType":$operandType, "bool":$rowMajor, "ValueRange":$indices, "bool":$staticOffsets, "MMAFragmentOpType":$mmaPrologueOp, "Value":$mmaPrologueArg), [{
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> () for zero-dimensional memrefs.
Expand Down Expand Up @@ -1507,7 +1507,7 @@ def accv_MMAStoreSyncOp : accv_Op<"wmma_store_sync", [
AnyTypeOf<[I8, I16, I32, F16, BF16, F32]>:$mmaEpilogueArg);

let builders = [
OpBuilder<(ins "Value":$blockThreadId, "Value":$src, "Value":$memref, "ValueRange":$indices, "bool":$staticOffsets, "MMAFragmentOp":$mmaEpilogueOp, "Value":$mmaEpilogueArg), [{
OpBuilder<(ins "Value":$blockThreadId, "Value":$src, "Value":$memref, "ValueRange":$indices, "bool":$staticOffsets, "MMAFragmentOpType":$mmaEpilogueOp, "Value":$mmaEpilogueArg), [{
auto memrefType = memref.getType().cast<MemRefType>();
int64_t rank = memrefType.getRank();
// Create identity map for memrefs with at least one dimension or () -> ()
Expand Down
2 changes: 1 addition & 1 deletion accera/ir/src/exec/ExecutionPlanAttributes.cpp
Expand Up @@ -219,7 +219,7 @@ namespace executionPlan
return {};
if (useStaticOffsets != 0 && useStaticOffsets != 1)
return {};
return TensorizationInfoAttr::get(TensorizationInfo{ accera::ir::value::MMAShape{ dim }, numTotalPasses, useStaticOffsets, numFusedPasses, accera::ir::value::MMASchedulingPolicy{ schedulingPolicy }, accera::ir::value::MMAFragmentOp{ prologueOp }, prologueArg, accera::ir::value::MMAFragmentOp{ epilogueOp }, epilogueArg, _useRocWMMA }, parser.getBuilder().getContext());
return TensorizationInfoAttr::get(TensorizationInfo{ accera::ir::value::MMAShapeType{ dim }, numTotalPasses, useStaticOffsets, numFusedPasses, accera::ir::value::MMASchedulingPolicyType{ schedulingPolicy }, accera::ir::value::MMAFragmentOpType{ prologueOp }, prologueArg, accera::ir::value::MMAFragmentOpType{ epilogueOp }, epilogueArg, _useRocWMMA }, parser.getBuilder().getContext());
}

void print(TensorizationInfoAttr attr, mlir::DialectAsmPrinter& printer)
Expand Down

0 comments on commit 8affe97

Please sign in to comment.