@@ -764,7 +764,9 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
764764 if (vnniConf) {
765765 vecLoadType = getVnniVector (tileType.getShape (), tileType.getElementType (),
766766 *vnniConf);
767- packedAttr = mlir::UnitAttr::get (rewriter.getContext ());
767+ if (!transpose_bit) {
768+ packedAttr = mlir::UnitAttr::get (rewriter.getContext ());
769+ }
768770 }
769771 SmallVector<Value> loadVec;
770772 for (auto tile : loadTiles) {
@@ -1165,7 +1167,6 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
11651167 if (vnniFactor == -1 )
11661168 return failure ();
11671169
1168- VnniConfig vnniConfA{.vnniFactor = vnniFactor, .vnniAxis = 1 };
11691170 VnniConfig vnniConfB{.vnniFactor = vnniFactor, .vnniAxis = 0 };
11701171
11711172 // Load A sub-tiles.
@@ -1212,9 +1213,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
12121213 }
12131214
12141215 // Extract DPAS tiles from loaded sub-tiles.
1215- TilesArray dpasVecA = extractVecSubTiles (rewriter, loc, loadVecA,
1216- {dimM, kTile }, tileTypeA. getShape () ,
1217- {dpasTileM, dpasTileK}, vnniConfA );
1216+ TilesArray dpasVecA =
1217+ extractVecSubTiles (rewriter, loc, loadVecA, {dimM, kTile },
1218+ tileTypeA. getShape (), {dpasTileM, dpasTileK});
12181219 TilesArray dpasVecB = extractVecSubTiles (rewriter, loc, loadVecB,
12191220 {kTile , dimN}, tileTypeB.getShape (),
12201221 {dpasTileK, dpasTileN}, vnniConfB);
@@ -1629,7 +1630,8 @@ struct LinalgToXeGPU : public gc::impl::LinalgToXeGPUBase<LinalgToXeGPU> {
16291630 using LinalgToXeGPUBase::LinalgToXeGPUBase;
16301631
16311632 void runOnOperation () override {
1632- LinalgToXeGPUOptions options{kTile , stages, dpasTile};
1633+ LinalgToXeGPUOptions options{
1634+ kTile , stages, SmallVector<int64_t >(dpasTile.begin (), dpasTile.end ())};
16331635
16341636 // Run GEMM pattern first to allow fusion with its consumers.
16351637 RewritePatternSet gemmPatterns (&getContext ());
0 commit comments