diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
index ae0a416175f9e1..eba83493e686d9 100644
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5093,6 +5093,10 @@ let TargetPrefix = "x86" in {
               [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty,
                llvm_x86amx_ty, llvm_x86amx_ty], []>;
+  def int_x86_cast_vector_to_tile:
+      Intrinsic<[llvm_x86amx_ty], [llvm_anyvector_ty], [IntrNoMem]>;
+  def int_x86_cast_tile_to_vector:
+      Intrinsic<[llvm_anyvector_ty], [llvm_x86amx_ty], [IntrNoMem]>;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
index 4ba44ccb6c1607..a2bcc98f3d5b64 100644
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -40,8 +40,10 @@
 //
 #include "X86.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/CodeGen/TargetPassConfig.h"
@@ -56,66 +58,44 @@
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
+#include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
 using namespace PatternMatch;
 
 #define DEBUG_TYPE "lower-amx-type"
 
-static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder,
-                                           BasicBlock *BB) {
+static bool isAMXCast(Instruction *II) {
+  return match(II,
+               m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value())) ||
+         match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
+}
+
+static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
+                                           Type *Ty) {
   Function &F = *BB->getParent();
   Module *M = BB->getModule();
   const DataLayout &DL = M->getDataLayout();
 
-  Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
   LLVMContext &Ctx = Builder.getContext();
   auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
   unsigned AllocaAS = DL.getAllocaAddrSpace();
   AllocaInst *AllocaRes =
-      new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
+      new AllocaInst(Ty, AllocaAS, "", &F.getEntryBlock().front());
   AllocaRes->setAlignment(AllocaAlignment);
   return AllocaRes;
 }
 
-namespace {
-class X86LowerAMXType {
-  Function &Func;
-  TargetMachine *TM = nullptr;
-
-  // In AMX intrinsics we let Shape = {Row, Col}, but the
-  // RealCol = Col / ElementSize. We may use the RealCol
-  // as a new Row for other new created AMX intrinsics.
-  std::map<Value *, Value *> Col2Row;
-
-public:
-  X86LowerAMXType(Function &F, TargetMachine *TargetM) : Func(F), TM(TargetM) {}
-  bool visit();
-  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
-  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
-  bool transformBitcast(BitCastInst *Bitcast);
-  std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo);
-  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
-};
-
-Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
-                                      unsigned Granularity) {
-  if (Col2Row.count(V))
-    return Col2Row[V];
-  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
-  if (auto *I = dyn_cast<Instruction>(V)) {
-    BasicBlock::iterator Iter = I->getIterator();
-    ++Iter;
-    Builder.SetInsertPoint(&*Iter);
-  }
-  ConstantInt *Gran = Builder.getInt16(Granularity);
-  Value *RealRow = Builder.CreateUDiv(V, Gran);
-  Col2Row[V] = RealRow;
-  return RealRow;
+static Instruction *getFirstNonAllocaInTheEntryBlock(Function &F) {
+  for (Instruction &I : F.getEntryBlock())
+    if (!isa<AllocaInst>(&I))
+      return &I;
+  llvm_unreachable("No terminator in the entry block!");
 }
 
-std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
-                                                      unsigned OpNo) {
+static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
+  IRBuilder<> Builder(II);
   Value *Row = nullptr, *Col = nullptr;
   switch (II->getIntrinsicID()) {
   default:
@@ -144,14 +124,32 @@ std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
       Col = II->getArgOperand(2);
       break;
     case 5:
-      Row = II->getArgOperand(2);
-      // FIXME: There is a design bug for AMX shape, which the Col should be
-      // Col/4 if it will be used as Row, but current Greedy RA can't handle
-      // this case well, it may failed if we generate a new Shape definition.
-      // So Let's just do it in O0 first.
-      // Row = Row / 4
-      if (TM->getOptLevel() == CodeGenOpt::None)
-        Row = getRowFromCol(II, Row, 4);
+      if (isa<ConstantInt>(II->getArgOperand(2)))
+        Row = Builder.getInt16(
+            (dyn_cast<ConstantInt>(II->getOperand(2))->getSExtValue()) / 4);
+      else if (isa<Instruction>(II->getArgOperand(2))) {
+        // When it is not a const value and it is not a function argument, we
+        // create Row after the definition of II->getOperand(2) instead of
+        // before II. For example, II is %118, we try to getshape for %117:
+        // %117 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x
+        // i32> %115).
+        // %118 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16
+        // %104, i16 %105, i16 %106, x86_amx %110, x86_amx %114, x86_amx
+        // %117).
+        // If we create %row = udiv i16 %106, 4 before %118(aka. II), then its
+        // definition is after its user(new tileload for %117).
+        // So, the best choice is to create %row right after the definition of
+        // %106.
+        Builder.SetInsertPoint(cast<Instruction>(II->getOperand(2)));
+        Row = Builder.CreateUDiv(II->getOperand(2), Builder.getInt16(4));
+        cast<Instruction>(Row)->moveAfter(cast<Instruction>(II->getOperand(2)));
+      } else {
+        // When it is not a const value and it is a function argument, we create
+        // Row at the entry bb.
+        IRBuilder<> NewBuilder(
+            getFirstNonAllocaInTheEntryBlock(*II->getFunction()));
+        Row = NewBuilder.CreateUDiv(II->getOperand(2), NewBuilder.getInt16(4));
+      }
       Col = II->getArgOperand(1);
       break;
     }
@@ -162,6 +160,40 @@ std::pair<Value *, Value *> X86LowerAMXType::getShape(IntrinsicInst *II,
   return std::make_pair(Row, Col);
 }
 
+namespace {
+class X86LowerAMXType {
+  Function &Func;
+
+  // In AMX intrinsics we let Shape = {Row, Col}, but the
+  // RealCol = Col / ElementSize. We may use the RealCol
+  // as a new Row for other new created AMX intrinsics.
+  std::map<Value *, Value *> Col2Row;
+
+public:
+  X86LowerAMXType(Function &F) : Func(F) {}
+  bool visit();
+  void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast);
+  void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST);
+  bool transformBitcast(BitCastInst *Bitcast);
+  Value *getRowFromCol(Instruction *II, Value *V, unsigned Granularity);
+};
+
+Value *X86LowerAMXType::getRowFromCol(Instruction *II, Value *V,
+                                      unsigned Granularity) {
+  if (Col2Row.count(V))
+    return Col2Row[V];
+  IRBuilder<> Builder(&*II->getParent()->getFirstInsertionPt());
+  if (auto *I = dyn_cast<Instruction>(V)) {
+    BasicBlock::iterator Iter = I->getIterator();
+    ++Iter;
+    Builder.SetInsertPoint(&*Iter);
+  }
+  ConstantInt *Gran = Builder.getInt16(Granularity);
+  Value *RealRow = Builder.CreateUDiv(V, Gran);
+  Col2Row[V] = RealRow;
+  return RealRow;
+}
+
 // %src = load <256 x i32>, <256 x i32>* %addr, align 64
 // %2 = bitcast <256 x i32> %src to x86_amx
 // -->
@@ -230,8 +262,8 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
   Value *I8Ptr, *Stride;
   auto *Src = Bitcast->getOperand(0);
 
-  auto Prepare = [&]() {
-    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent());
+  auto Prepare = [&](Type *MemTy) {
+    AllocaAddr = createAllocaInstAtEntry(Builder, Bitcast->getParent(), MemTy);
     I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
     Stride = Builder.getInt64(64);
   };
@@ -250,7 +282,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
     auto *II = dyn_cast<IntrinsicInst>(U.getUser());
     if (!II)
       return false; // May be bitcast from x86amx to <256 x i32>.
-    Prepare();
+    Prepare(Bitcast->getOperand(0)->getType());
     Builder.CreateStore(Src, AllocaAddr);
     // TODO we can pick an constant operand for the shape.
     Value *Row = nullptr, *Col = nullptr;
@@ -270,7 +302,7 @@ bool X86LowerAMXType::transformBitcast(BitCastInst *Bitcast) {
     auto *II = dyn_cast<IntrinsicInst>(Src);
     if (!II)
       return false; // May be bitcast from <256 x i32> to x86amx.
-    Prepare();
+    Prepare(Bitcast->getType());
     Value *Row = II->getOperand(0);
     Value *Col = II->getOperand(1);
     std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
@@ -637,6 +669,364 @@ bool X86VolatileTileData::volatileTileData() {
 
 namespace {
 
+class X86LowerAMXCast {
+  Function &Func;
+
+public:
+  X86LowerAMXCast(Function &F) : Func(F) {}
+  bool combineAMXcast(TargetLibraryInfo *TLI);
+  bool transformAMXCast(IntrinsicInst *AMXCast);
+  bool transformAllAMXCast();
+  bool optimizeAMXCastFromPhi(IntrinsicInst *CI, PHINode *PN,
+                              SmallSetVector<Instruction *, 16> &DeadInst);
+};
+
+static bool DCEInstruction(Instruction *I,
+                           SmallSetVector<Instruction *, 16> &WorkList,
+                           const TargetLibraryInfo *TLI) {
+  if (isInstructionTriviallyDead(I, TLI)) {
+    salvageDebugInfo(*I);
+    salvageKnowledge(I);
+
+    // Null out all of the instruction's operands to see if any operand becomes
+    // dead as we go.
+    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+      Value *OpV = I->getOperand(i);
+      I->setOperand(i, nullptr);
+
+      if (!OpV->use_empty() || I == OpV)
+        continue;
+
+      // If the operand is an instruction that became dead as we nulled out the
+      // operand, and if it is 'trivially' dead, delete it in a future loop
+      // iteration.
+      if (Instruction *OpI = dyn_cast<Instruction>(OpV)) {
+        if (isInstructionTriviallyDead(OpI, TLI)) {
+          WorkList.insert(OpI);
+        }
+      }
+    }
+    I->eraseFromParent();
+    return true;
+  }
+  return false;
+}
+
+/// This function handles following case
+///
+///     A  ->  B    amxcast
+///     PHI
+///     B  ->  A    amxcast
+///
+/// All the related PHI nodes can be replaced by new PHI nodes with type A.
+/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
+bool X86LowerAMXCast::optimizeAMXCastFromPhi(
+    IntrinsicInst *CI, PHINode *PN,
+    SmallSetVector<Instruction *, 16> &DeadInst) {
+  IRBuilder<> Builder(CI);
+  Value *Src = CI->getOperand(0);
+  Type *SrcTy = Src->getType(); // Type B
+  Type *DestTy = CI->getType(); // Type A
+
+  SmallVector<PHINode *, 4> PhiWorklist;
+  SmallSetVector<PHINode *, 4> OldPhiNodes;
+
+  // Find all of the A->B casts and PHI nodes.
+  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
+  // OldPhiNodes is used to track all known PHI nodes, before adding a new
+  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
+  PhiWorklist.push_back(PN);
+  OldPhiNodes.insert(PN);
+  while (!PhiWorklist.empty()) {
+    auto *OldPN = PhiWorklist.pop_back_val();
+    for (Value *IncValue : OldPN->incoming_values()) {
+      // TODO: currently, We ignore cases where it is a const. In the future, we
+      // might support const.
+      if (isa<Constant>(IncValue))
+        return false;
+
+      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
+        if (OldPhiNodes.insert(PNode))
+          PhiWorklist.push_back(PNode);
+        continue;
+      }
+      Instruction *ACI = dyn_cast<Instruction>(IncValue);
+      if (ACI && isAMXCast(ACI)) {
+        // Verify it's a A->B cast.
+        Type *TyA = ACI->getOperand(0)->getType();
+        Type *TyB = ACI->getType();
+        if (TyA != DestTy || TyB != SrcTy)
+          return false;
+        continue;
+      }
+      return false;
+    }
+  }
+
+  // Check that each user of each old PHI node is something that we can
+  // rewrite, so that all of the old PHI nodes can be cleaned up afterwards.
+  for (auto *OldPN : OldPhiNodes) {
+    for (User *V : OldPN->users()) {
+      Instruction *ACI = dyn_cast<Instruction>(V);
+      if (ACI && isAMXCast(ACI)) {
+        // Verify it's a B->A cast.
+        Type *TyB = ACI->getOperand(0)->getType();
+        Type *TyA = ACI->getType();
+        if (TyA != DestTy || TyB != SrcTy)
+          return false;
+      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        // As long as the user is another old PHI node, then even if we don't
+        // rewrite it, the PHI web we're considering won't have any users
+        // outside itself, so it'll be dead.
+        // example:
+        //   bb.0:
+        //     %0 = amxcast ...
+        //   bb.1:
+        //     %1 = amxcast ...
+        //   bb.2:
+        //     %goodphi = phi %0, %1
+        //     %3 = amxcast %goodphi
+        //   bb.3:
+        //     %goodphi2 = phi %0, %goodphi
+        //     %4 = amxcast %goodphi2
+        // When optimizeAMXCastFromPhi process %3 and %goodphi, %goodphi2 is
+        // outside the phi-web, so the combination stop When
+        // optimizeAMXCastFromPhi process %4 and %goodphi2, the optimization
+        // will be done.
+        if (OldPhiNodes.count(PHI) == 0)
+          return false;
+      } else
+        return false;
+    }
+  }
+
+  // For each old PHI node, create a corresponding new PHI node with a type A.
+  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
+  for (auto *OldPN : OldPhiNodes) {
+    Builder.SetInsertPoint(OldPN);
+    PHINode *NewPN = Builder.CreatePHI(DestTy, OldPN->getNumOperands());
+    NewPNodes[OldPN] = NewPN;
+  }
+
+  // Fill in the operands of new PHI nodes.
+  for (auto *OldPN : OldPhiNodes) {
+    PHINode *NewPN = NewPNodes[OldPN];
+    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
+      Value *V = OldPN->getOperand(j);
+      Value *NewV = nullptr;
+      Instruction *ACI = dyn_cast<Instruction>(V);
+      // There should not be a AMXcast from a const.
+      if (ACI && isAMXCast(ACI))
+        NewV = ACI->getOperand(0);
+      else if (auto *PrevPN = dyn_cast<PHINode>(V))
+        NewV = NewPNodes[PrevPN];
+      assert(NewV);
+      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
+    }
+  }
+
+  // Traverse all accumulated PHI nodes and process its users,
+  // which are Stores and BitcCasts. Without this processing
+  // NewPHI nodes could be replicated and could lead to extra
+  // moves generated after DeSSA.
+  // If there is a store with type B, change it to type A.
+
+  // Replace users of BitCast B->A with NewPHI. These will help
+  // later to get rid of a closure formed by OldPHI nodes.
+  for (auto *OldPN : OldPhiNodes) {
+    PHINode *NewPN = NewPNodes[OldPN];
+    for (User *V : make_early_inc_range(OldPN->users())) {
+      Instruction *ACI = dyn_cast<Instruction>(V);
+      if (ACI && isAMXCast(ACI)) {
+        Type *TyB = ACI->getOperand(0)->getType();
+        Type *TyA = ACI->getType();
+        assert(TyA == DestTy && TyB == SrcTy);
+        (void)TyA;
+        (void)TyB;
+        ACI->replaceAllUsesWith(NewPN);
+        DeadInst.insert(ACI);
+      } else if (auto *PHI = dyn_cast<PHINode>(V)) {
+        // We don't need to push PHINode into DeadInst since they are operands
+        // of rootPN DCE can safely delete rootPN's operands if rootPN is dead.
+        assert(OldPhiNodes.contains(PHI));
+        (void)PHI;
+      } else
+        llvm_unreachable("all uses should be handled");
+    }
+  }
+  return true;
+}
+
+bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
+  bool Change = false;
+  // Collect tile cast instruction.
+  SmallVector<Instruction *, 8> Vec2TileInsts;
+  SmallVector<Instruction *, 8> Tile2VecInsts;
+  SmallVector<Instruction *, 8> PhiCastWorkList;
+  SmallSetVector<Instruction *, 16> DeadInst;
+  for (BasicBlock &BB : Func) {
+    for (Instruction &I : BB) {
+      Value *Vec;
+      if (match(&I,
+                m_Intrinsic<Intrinsic::x86_cast_vector_to_tile>(m_Value(Vec))))
+        Vec2TileInsts.push_back(&I);
+      else if (match(&I, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(
+                             m_Value(Vec))))
+        Tile2VecInsts.push_back(&I);
+    }
+  }
+
+  auto Convert = [&](SmallVectorImpl<Instruction *> &Insts, Intrinsic::ID IID) {
+    for (auto *Inst : Insts) {
+      for (User *U : Inst->users()) {
+        IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+        if (!II || II->getIntrinsicID() != IID)
+          continue;
+        // T1 = vec2tile V0
+        // V2 = tile2vec T1
+        // V3 = OP V2
+        // -->
+        // T1 = vec2tile V0
+        // V2 = tile2vec T1
+        // V3 = OP V0
+        II->replaceAllUsesWith(Inst->getOperand(0));
+        Change = true;
+      }
+    }
+  };
+
+  Convert(Vec2TileInsts, Intrinsic::x86_cast_tile_to_vector);
+  Convert(Tile2VecInsts, Intrinsic::x86_cast_vector_to_tile);
+
+  auto EraseInst = [](SmallVectorImpl<Instruction *> &Insts) {
+    for (auto *Inst : Insts) {
+      if (Inst->use_empty())
+        Inst->eraseFromParent();
+    }
+  };
+
+  EraseInst(Vec2TileInsts);
+  EraseInst(Tile2VecInsts);
+
+  // Handle the A->B->A cast, and there is an intervening PHI node.
+  for (BasicBlock &BB : Func) {
+    for (Instruction &I : BB) {
+      if (isAMXCast(&I)) {
+        if (PHINode *PN = dyn_cast<PHINode>(I.getOperand(0)))
+          PhiCastWorkList.push_back(&I);
+      }
+    }
+  }
+  for (auto *I : PhiCastWorkList) {
+    // We skip the dead Amxcast.
+    if (DeadInst.contains(I))
+      continue;
+    PHINode *PN = cast<PHINode>(I->getOperand(0));
+    if (optimizeAMXCastFromPhi(cast<IntrinsicInst>(I), PN, DeadInst)) {
+      DeadInst.insert(PN);
+      Change = true;
+    }
+  }
+
+  // Since we create new phi and merge AMXCast, some old phis and AMXCast might
+  // have no uses. We do some DeadCodeElimination for them.
+  while (!DeadInst.empty()) {
+    Instruction *I = DeadInst.pop_back_val();
+    Change |= DCEInstruction(I, DeadInst, TLI);
+  }
+  return Change;
+}
+
+// There might be remaining AMXcast after combineAMXcast and they should be
+// handled elegantly.
+bool X86LowerAMXCast::transformAMXCast(IntrinsicInst *AMXCast) {
+  IRBuilder<> Builder(AMXCast);
+  AllocaInst *AllocaAddr;
+  Value *I8Ptr, *Stride;
+  auto *Src = AMXCast->getOperand(0);
+
+  auto Prepare = [&](Type *MemTy) {
+    AllocaAddr = createAllocaInstAtEntry(Builder, AMXCast->getParent(), MemTy);
+    I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
+    Stride = Builder.getInt64(64);
+  };
+
+  if (AMXCast->getType()->isX86_AMXTy()) {
+    // %2 = amxcast <225 x i32> %src to x86_amx
+    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
+    //                                           i8* %addr3, i64 60, x86_amx %2)
+    // -->
+    // %addr = alloca <225 x i32>, align 64
+    // store <225 x i32> %src, <225 x i32>* %addr, align 64
+    // %addr2 = bitcast <225 x i32>* %addr to i8*
+    // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 15, i16 60,
+    //                                                  i8* %addr2,
+    //                                                  i64 60)
+    // call void @llvm.x86.tilestored64.internal(i16 15, i16 60,
+    //                                           i8* %addr3, i64 60, x86_amx %2)
+    Use &U = *(AMXCast->use_begin());
+    unsigned OpNo = U.getOperandNo();
+    auto *II = dyn_cast<IntrinsicInst>(U.getUser());
+    if (!II)
+      return false; // May be bitcast from x86amx to <256 x i32>.
+    Prepare(AMXCast->getOperand(0)->getType());
+    Builder.CreateStore(Src, AllocaAddr);
+    // TODO we can pick an constant operand for the shape.
+    Value *Row = nullptr, *Col = nullptr;
+    std::tie(Row, Col) = getShape(II, OpNo);
+    std::array<Value *, 4> Args = {
+        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty())};
+    Value *NewInst = Builder.CreateIntrinsic(
+        Intrinsic::x86_tileloadd64_internal, None, Args);
+    AMXCast->replaceAllUsesWith(NewInst);
+    AMXCast->eraseFromParent();
+  } else {
+    // %2 = amxcast x86_amx %src to <225 x i32>
+    // -->
+    // %addr = alloca <225 x i32>, align 64
+    // %addr2 = bitcast <225 x i32>* to i8*
+    // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
+    //                                           i8* %addr2, i64 %stride)
+    // %2 = load <225 x i32>, <225 x i32>* %addr, align 64
+    auto *II = dyn_cast<IntrinsicInst>(Src);
+    if (!II)
+      return false; // May be bitcast from <256 x i32> to x86amx.
+    Prepare(AMXCast->getType());
+    Value *Row = II->getOperand(0);
+    Value *Col = II->getOperand(1);
+    std::array<Value *, 5> Args = {
+        Row, Col, I8Ptr, Builder.CreateSExt(Col, Builder.getInt64Ty()), Src};
+    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+    Value *NewInst = Builder.CreateLoad(AMXCast->getType(), AllocaAddr);
+    AMXCast->replaceAllUsesWith(NewInst);
+    AMXCast->eraseFromParent();
+  }
+
+  return true;
+}
+
+bool X86LowerAMXCast::transformAllAMXCast() {
+  bool Change = false;
+  // Collect tile cast instruction.
+  SmallVector<Instruction *, 8> WorkLists;
+  for (BasicBlock &BB : Func) {
+    for (Instruction &I : BB) {
+      if (isAMXCast(&I))
+        WorkLists.push_back(&I);
+    }
+  }
+
+  for (auto *Inst : WorkLists) {
+    Change |= transformAMXCast(cast<IntrinsicInst>(Inst));
+  }
+
+  return Change;
+}
+
+} // anonymous namespace
+
+namespace {
+
 class X86LowerAMXTypeLegacyPass : public FunctionPass {
 public:
   static char ID;
@@ -647,8 +1037,15 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
 
   bool runOnFunction(Function &F) override {
     TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    TargetLibraryInfo *TLI =
+        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
+    X86LowerAMXCast LAC(F);
+    LAC.combineAMXcast(TLI);
+    // There might be remaining AMXcast after combineAMXcast and they should be
+    // handled elegantly.
+    LAC.transformAllAMXCast();
 
-    X86LowerAMXType LAT(F, TM);
+    X86LowerAMXType LAT(F);
     bool C = LAT.visit();
 
     // Prepare for fast register allocation at O0.
@@ -671,6 +1068,7 @@ class X86LowerAMXTypeLegacyPass : public FunctionPass {
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
     AU.addRequired<TargetPassConfig>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
   }
 };
 
@@ -681,6 +1079,7 @@ char X86LowerAMXTypeLegacyPass::ID = 0;
 
 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                       false)
 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                     false)
diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
index 989a1076ce7a6d..ddf650525baaa8 100644
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -163,18 +163,19 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._
 ; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2
 ; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
 ; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2
-; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT:    [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
-; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
-; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
-; CHECK-NEXT:    [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
-; CHECK-NEXT:    [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]])
-; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, x86_amx [[TMP19]])
+; CHECK-NEXT:    [[TMP10:%.*]] = udiv i16 [[TMP9]], 4
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64)
+; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP14]] to i8*
+; CHECK-NEXT:    [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP15]], i64 64)
+; CHECK-NEXT:    [[TMP17:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP18:%.*]] = bitcast <256 x i32>* [[TMP17]] to i8*
+; CHECK-NEXT:    [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP10]], i16 [[TMP7]], i8* [[TMP18]], i64 64)
+; CHECK-NEXT:    [[TMP20:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx
[[TMP13]], x86_amx [[TMP16]], x86_amx [[TMP19]]) +; CHECK-NEXT: [[TMP21:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP21]], i64 64, x86_amx [[TMP20]]) ; CHECK-NEXT: ret void ; %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0 @@ -200,15 +201,16 @@ define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct._ define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { ; CHECK-LABEL: @__tile_dpbsud( -; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* -; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64) -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* -; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64) -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* -; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64) -; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]]) -; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8* -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]]) +; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64) +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]]) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]]) ; CHECK-NEXT: ret void ; %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 @@ -225,15 +227,16 @@ define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, < define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { ; CHECK-LABEL: @__tile_dpbusd( -; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* -; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64) -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* -; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64) -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* -; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64) -; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]]) -; 
CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8* -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]]) +; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64) +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]]) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]]) ; CHECK-NEXT: ret void ; %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 @@ -250,15 +253,16 @@ define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, < define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { ; CHECK-LABEL: @__tile_dpbuud( -; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* -; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64) -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* -; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64) -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* -; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64) -; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]]) -; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8* -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]]) +; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64) +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]]) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]]) ; CHECK-NEXT: ret void ; %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 @@ -275,15 +279,16 @@ define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, < define dso_local void 
@__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { ; CHECK-LABEL: @__tile_dpbf16ps( -; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* -; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K:%.*]], i8* [[TMP1]], i64 64) -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* -; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[K]], i16 [[N:%.*]], i8* [[TMP3]], i64 64) -; CHECK-NEXT: [[TMP5:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* -; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP5]], i64 64) -; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP6]], x86_amx [[TMP2]], x86_amx [[TMP4]]) -; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[PC]] to i8* -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP7]], i64 64, x86_amx [[T6]]) +; CHECK-NEXT: [[TMP1:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[TMP2:%.*]] = bitcast <256 x i32>* [[PA:%.*]] to i8* +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP2]], i64 64) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <256 x i32>* [[PB:%.*]] to i8* +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP1]], i16 [[N:%.*]], i8* [[TMP4]], i64 64) +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[PC:%.*]] to i8* +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP6]], i64 64) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP7]], x86_amx [[TMP3]], x86_amx [[TMP5]]) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <256 x i32>* [[PC]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP8]], i64 64, x86_amx [[T6]]) ; CHECK-NEXT: ret void ; %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 diff --git a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll new file mode 100644 index 00000000000000..4aa5c7e3e1b9a6 --- /dev/null +++ b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll @@ -0,0 +1,412 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s + +define void @combine_amx_cast_inside_bb() { +; CHECK-LABEL: @combine_amx_cast_inside_bb( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP0]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) + %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0) + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %tmp) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %1) + ret void +} + +; Cases where amxcast can be combined across bb +; %5 and %6 is combined together since %goodphi's incoming is phi or amxcast +define void @combine_amx_cast_and_phi() { +; CHECK-LABEL: @combine_amx_cast_and_phi( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <560 x 
i8>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = alloca <616 x i8>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]] +; CHECK: for.body.i.lr.ph.i: +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8* +; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512 +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40) +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8* +; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024 +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40) +; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]]) +; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]] +; CHECK: for.cond.cleanup.i.i: +; CHECK-NEXT: [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) + %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0) + br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i + +for.body.i.lr.ph.i: ; preds = %wrapper_entry + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef) + %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef) + %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3) + %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4) + br label %for.cond.cleanup.i.i + +for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry + %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ] + %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6) + ret void +} + +; Cases where amxcast can't be combined across bb +; %5 and %6 is not combined together since %evilphi's incoming is not phi or amxcast +define void @fail_to_combine_amx_cast_and_phi(<110 x i32> %tmp) { +; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <560 x i8>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <616 x i8>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = add <110 x i32> [[TMP:%.*]], [[TMP]] +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]] +; CHECK: for.body.i.lr.ph.i: +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store 
<110 x i32> undef, <110 x i32>* [[TMP4]], align 512 +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP6]], i64 40) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8* +; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP8]], i64 56) +; CHECK-NEXT: [[TMP10:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP10]], i64 40) +; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP7]], x86_amx [[TMP9]], x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP13:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP13]], i64 40, x86_amx [[TMP12]]) +; CHECK-NEXT: [[TMP14:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]] +; CHECK: for.cond.cleanup.i.i: +; CHECK-NEXT: [[EVILPHI:%.*]] = phi <110 x i32> [ [[TMP5]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP14]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512 +; CHECK-NEXT: [[TMP16:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP15]], i64 40) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP16]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = add <110 x i32> %tmp, %tmp + br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i + +for.body.i.lr.ph.i: ; preds = %wrapper_entry + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef) + %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef) + %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3) + %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4) + br label %for.cond.cleanup.i.i + +for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry + %evilphi = phi <110 x i32> [ %0, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ] + %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6) + ret void +} + +; Cases where amxcast can't be combined across bb +; %5 and %6 is not combined together since %evilphi's user aka %evilphi2 is not inside phi web. 
+define void @fail_to_combine_amx_cast_and_phi2() { +; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi2( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <560 x i8>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <616 x i8>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <110 x i32>* [[TMP5]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP7]], i64 40, x86_amx [[TMP6]]) +; CHECK-NEXT: [[TMP8:%.*]] = load <110 x i32>, <110 x i32>* [[TMP5]], align 512 +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]] +; CHECK: for.body.i.lr.ph.i: +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512 +; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP9]], i64 40) +; CHECK-NEXT: [[TMP11:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8* +; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP11]], i64 56) +; CHECK-NEXT: [[TMP13:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP13]], i64 40) +; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP10]], x86_amx [[TMP12]], x86_amx [[TMP14]]) +; CHECK-NEXT: [[TMP16:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP16]], i64 40, x86_amx [[TMP15]]) +; CHECK-NEXT: [[TMP17:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512 +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I]], label [[EXIT:%.*]] +; CHECK: for.cond.cleanup.i.i: +; CHECK-NEXT: [[GOODPHI:%.*]] = phi <110 x i32> [ [[TMP8]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP17]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: [[TMP18:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: store <110 x i32> [[GOODPHI]], <110 x i32>* [[TMP0]], align 512 +; CHECK-NEXT: [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP18]], i64 40) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP19]]) +; CHECK-NEXT: br i1 undef, label [[EXIT]], label [[FOR_BODY_I_LR_PH_I]] +; CHECK: exit: +; CHECK-NEXT: [[EVILPHI2:%.*]] = phi <110 x i32> [ [[GOODPHI]], [[FOR_COND_CLEANUP_I_I]] ], [ [[TMP17]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: store <110 x i32> [[EVILPHI2]], <110 x i32>* undef, align 512 +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) + %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0) + br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i + +for.body.i.lr.ph.i: ; preds = %wrapper_entry + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef) + %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef) + %3 = 
call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3) + %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4) + br i1 undef, label %for.cond.cleanup.i.i, label %exit + +for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry + %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ] + %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6) + br i1 undef, label %exit, label %for.body.i.lr.ph.i +exit: + %evilphi2 = phi <110 x i32> [ %goodphi, %for.cond.cleanup.i.i ], [ %5, %for.body.i.lr.ph.i ] + store <110 x i32> %evilphi2, <110 x i32>* undef, align 512 + ret void +} + +define void @fail_to_combine_amx_cast_and_phi_due_to_const_value() { +; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi_due_to_const_value( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <560 x i8>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <616 x i8>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]] +; CHECK: for.body.i.lr.ph.i: +; CHECK-NEXT: [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512 +; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40) +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8* +; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56) +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40) +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]]) +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512 +; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]] +; CHECK: for.cond.cleanup.i.i: +; CHECK-NEXT: [[EVILPHI:%.*]] = phi <110 x i32> [ undef, [[WRAPPER_ENTRY:%.*]] ], [ [[TMP13]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: [[TMP14:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512 +; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP14]], i64 40) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i + +for.body.i.lr.ph.i: ; preds = %wrapper_entry + %0 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef) + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef) + %2 = call x86_amx 
@llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %3 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %0, x86_amx %1, x86_amx %2) + %4 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %3) + br label %for.cond.cleanup.i.i + +for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry + %evilphi = phi <110 x i32> [ undef, %wrapper_entry ], [ %4, %for.body.i.lr.ph.i ] + %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %5) + ret void +} + +; Cases where amxcast can be combined across bb +; When optimizeAMXCastFromPhi process %6 and %goodphi, %goodphi2 is outside the phi-web, so the optimization stop +; When optimizeAMXCastFromPhi process %7 and %goodphi2, the optimization continue. +define void @combine_amx_cast_and_multiple_phi() { +; CHECK-LABEL: @combine_amx_cast_and_multiple_phi( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <560 x i8>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = alloca <616 x i8>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]] +; CHECK: for.body.i.lr.ph.i: +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8* +; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512 +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40) +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8* +; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024 +; CHECK-NEXT: [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56) +; CHECK-NEXT: [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40) +; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]]) +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I]], label [[EXIT:%.*]] +; CHECK: for.cond.cleanup.i.i: +; CHECK-NEXT: [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]]) +; CHECK-NEXT: br i1 undef, label [[EXIT]], label [[FOR_BODY_I_LR_PH_I]] +; CHECK: exit: +; CHECK-NEXT: [[TMP12:%.*]] = phi x86_amx [ [[TMP11]], [[FOR_COND_CLEANUP_I_I]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP12]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) + %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0) + br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i + +for.body.i.lr.ph.i: ; preds = %wrapper_entry + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef) + %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef) + %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %4 = call x86_amx 
@llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3) + %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4) + br i1 undef, label %for.cond.cleanup.i.i, label %exit + +for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry + %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %5, %for.body.i.lr.ph.i ] + %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6) + br i1 undef, label %exit, label %for.body.i.lr.ph.i +exit: + %evilphi2 = phi <110 x i32> [ %goodphi, %for.cond.cleanup.i.i ], [ %5, %for.body.i.lr.ph.i ] + %7 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %7) + ret void +} + +; Currently we are not able to delete DeadPHICycle, later we will handle with them +define void @combine_amx_cast_and_phi_in_a_circle() { +; CHECK-LABEL: @combine_amx_cast_and_phi_in_a_circle( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = alloca <560 x i8>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <616 x i8>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <110 x i32>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: br label [[BB1:%.*]] +; CHECK: bb1: +; CHECK-NEXT: [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP3]] to i8* +; CHECK-NEXT: store <110 x i32> undef, <110 x i32>* [[TMP3]], align 512 +; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40) +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP2]] to i8* +; CHECK-NEXT: store <616 x i8> undef, <616 x i8>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56) +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP1]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP1]], align 1024 +; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40) +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]]) +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP0]], align 512 +; CHECK-NEXT: br i1 undef, label [[BB2:%.*]], label [[BB3:%.*]] +; CHECK: bb2: +; CHECK-NEXT: [[TMP14:%.*]] = phi x86_amx [ [[TMP15:%.*]], [[BB3]] ], [ [[TMP11]], [[BB1]] ] +; CHECK-NEXT: [[GOODPHI:%.*]] = phi <110 x i32> [ [[EVILPHI2:%.*]], [[BB3]] ], [ [[TMP13]], [[BB1]] ] +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP14]]) +; CHECK-NEXT: br label [[BB3]] +; CHECK: bb3: +; CHECK-NEXT: [[TMP15]] = phi x86_amx [ [[TMP14]], [[BB2]] ], [ [[TMP11]], [[BB1]] ] +; CHECK-NEXT: [[EVILPHI2]] = phi <110 x i32> [ [[GOODPHI]], [[BB2]] ], [ [[TMP13]], [[BB1]] ] +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]]) +; CHECK-NEXT: br i1 undef, label [[BB2]], label [[EXIT:%.*]] +; CHECK: exit: +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 
11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) + %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0) + br label %bb1 + +bb1: ; preds = %wrapper_entry + %1 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> undef) + %2 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> undef) + %3 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %4 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %1, x86_amx %2, x86_amx %3) + %5 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %4) + br i1 undef, label %bb2, label %bb3 + +bb2: ; preds = %bb1, %wrapper_entry + %goodphi = phi <110 x i32> [ %evilphi2, %bb3], [ %5, %bb1 ] + %6 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %6) + br label %bb3 +bb3: + %evilphi2 = phi <110 x i32> [ %goodphi, %bb2 ], [ %5, %bb1 ] + %7 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %7) + br i1 undef, label %bb2, label %exit +exit: + %8 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %evilphi2) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %8) + ret void +} + +define void @eliminate_unused_phi_and_cast() { +; CHECK-LABEL: @eliminate_unused_phi_and_cast( +; CHECK-NEXT: wrapper_entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <560 x i8>, align 64 +; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]] +; CHECK: for.body.i.lr.ph.i: +; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* undef, i64 undef) +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* undef, i64 undef) +; CHECK-NEXT: [[TMP4:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8* +; CHECK-NEXT: store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP4]], i64 40) +; CHECK-NEXT: [[TMP6:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP2]], x86_amx [[TMP3]], x86_amx [[TMP5]]) +; CHECK-NEXT: br label [[FOR_COND_CLEANUP_I_I]] +; CHECK: for.cond.cleanup.i.i: +; CHECK-NEXT: [[TMP7:%.*]] = phi x86_amx [ [[TMP1]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP6]], [[FOR_BODY_I_LR_PH_I]] ] +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP7]]) +; CHECK-NEXT: ret void +; +wrapper_entry: + %0 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* undef, i64 undef) + %tmp = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %0) + br i1 undef, label %for.cond.cleanup.i.i, label %for.body.i.lr.ph.i + +for.body.i.lr.ph.i: ; preds = %wrapper_entry + %1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* undef, i64 undef) + %v1 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %1) + %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* undef, i64 undef) + %v2 = call <616 x i8> @llvm.x86.cast.tile.to.vector.v616i8(x86_amx %2) + %3 = call x86_amx 
@llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %v1) + %4 = call x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8> %v2) + %5 = call x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8> undef) + %6 = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx %3, x86_amx %4, x86_amx %5) + %7 = call <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx %6) + br label %for.cond.cleanup.i.i + +for.cond.cleanup.i.i: ; preds = %for.body.i.lr.ph.i, %wrapper_entry + %goodphi = phi <110 x i32> [ %tmp, %wrapper_entry ], [ %7, %for.body.i.lr.ph.i ] + %8 = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %goodphi) + call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx %8) + ret void +} + +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) +declare <110 x i32> @llvm.x86.cast.tile.to.vector.v110i32(x86_amx) +declare <616 x i8> @llvm.x86.cast.tile.to.vector.v616i8(x86_amx) +declare x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32>) +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) +declare x86_amx @llvm.x86.cast.vector.to.tile.v616i8(<616 x i8>) +declare x86_amx @llvm.x86.cast.vector.to.tile.v560i8(<560 x i8>) +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) diff --git a/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll new file mode 100644 index 00000000000000..98a820197bbd6a --- /dev/null +++ b/llvm/test/CodeGen/X86/AMX/lat-transform-amx-bitcast.ll @@ -0,0 +1,429 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s + +%struct.__tile_str = type { i16, i16, <256 x i32> } + +@buf = dso_local global [1024 x i8] zeroinitializer, align 64 +@buf2 = dso_local global [1024 x i8] zeroinitializer, align 64 + +; test cast x86_amx to <256 x i32> +define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) { +; CHECK-LABEL: @test_user_empty( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) +; CHECK-NEXT: ret void +; +entry: + %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s) + %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1) + ret void +} + +; test cast <256 x i32> to x86_amx +define dso_local void @test_user_empty2(<256 x i32> %in) { +; CHECK-LABEL: @test_user_empty2( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret void +; +entry: + %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %in) + ret void +} + +define dso_local <256 x i32> @test_amx_load_bitcast_v256i32(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) { +; CHECK-LABEL: @test_amx_load_bitcast_v256i32( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[IN:%.*]], align 64 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: store <256 x i32> [[T1]], <256 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[N:%.*]] to i64 +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N]], i8* [[TMP1]], i64 [[TMP2]]) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]]) +; CHECK-NEXT: ret <256 x 
i32> [[T1]] +; +entry: + %t1 = load <256 x i32>, <256 x i32>* %in, align 64 + %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t1) + call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2) + ret <256 x i32> %t1 +} + +define dso_local <225 x i32> @test_amx_load_bitcast_v225i32(<225 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) { +; CHECK-LABEL: @test_amx_load_bitcast_v225i32( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <225 x i32>, align 64 +; CHECK-NEXT: [[T1:%.*]] = load <225 x i32>, <225 x i32>* [[IN:%.*]], align 64 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <225 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: store <225 x i32> [[T1]], <225 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[N:%.*]] to i64 +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N]], i8* [[TMP1]], i64 [[TMP2]]) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]]) +; CHECK-NEXT: ret <225 x i32> [[T1]] +; +entry: + %t1 = load <225 x i32>, <225 x i32>* %in, align 64 + %t2 = call x86_amx @llvm.x86.cast.vector.to.tile.v225i32(<225 x i32> %t1) + call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2) + ret <225 x i32> %t1 +} + +define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m, i16 %n, i8 *%buf, i64 %s) { +; CHECK-LABEL: @test_amx_bitcast_store( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[M]], i8* [[BUF:%.*]], i64 [[S:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[M]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]]) +; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: store <256 x i32> [[TMP3]], <256 x i32>* [[OUT:%.*]], align 1024 +; CHECK-NEXT: ret <256 x i32> [[TMP3]] +; +entry: + %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %m, i8* %buf, i64 %s) + %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1) + store <256 x i32> %t2, <256 x i32>* %out + ret <256 x i32> %t2 +} + +define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) { +; CHECK-LABEL: @test_src_add( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[C:%.*]] to i64 +; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C]], i8* [[TMP1]], i64 [[TMP2]]) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP3]]) +; CHECK-NEXT: ret void +; +entry: + %add = add <256 x i32> %y, %x + %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %add) + call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t) + ret void +} + +define dso_local void @test_src_add2(<256 x i32> %x, i16 %r, i16 %c, i8* %buf, i64 %s) { +; CHECK-LABEL: @test_src_add2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = 
alloca <256 x i32>, align 64 +; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: [[TMP2:%.*]] = sext i16 [[C]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[TMP1]], i64 [[TMP2]], x86_amx [[T1]]) +; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[TMP3]], [[X:%.*]] +; CHECK-NEXT: ret void +; +entry: + %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, i8* %buf, i64 %s) + %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1) + %add = add <256 x i32> %t2, %x + ret void +} + +define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr { +; CHECK-LABEL: @__tile_loadd( +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64 +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1 +; CHECK-NEXT: [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2 +; CHECK-NEXT: [[TMP9:%.*]] = shl i64 [[TMP2:%.*]], 32 +; CHECK-NEXT: [[TMP10:%.*]] = ashr exact i64 [[TMP9]], 32 +; CHECK-NEXT: [[TMP11:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP1:%.*]], i64 [[TMP10]]) +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[TMP8]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP12]], i64 [[TMP13]], x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP14:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2 +; CHECK-NEXT: store <256 x i32> [[TMP14]], <256 x i32>* [[TMP15]], align 64 +; CHECK-NEXT: ret void +; + %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0 + %5 = load i16, i16* %4, align 64 + %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 1 + %7 = load i16, i16* %6, align 2 + %8 = shl i64 %2, 32 + %9 = ashr exact i64 %8, 32 + %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) + %11 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %10) + %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 + store <256 x i32> %11, <256 x i32>* %12, align 64 + ret void +} + +define dso_local void @__tile_dpbssd(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr { +; CHECK-LABEL: @__tile_dpbssd( +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP6:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP7:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 64 +; CHECK-NEXT: [[TMP10:%.*]] = 
getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1 +; CHECK-NEXT: [[TMP11:%.*]] = load i16, i16* [[TMP10]], align 2 +; CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1 +; CHECK-NEXT: [[TMP13:%.*]] = load i16, i16* [[TMP12]], align 2 +; CHECK-NEXT: [[TMP14:%.*]] = udiv i16 [[TMP13]], 4 +; CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2 +; CHECK-NEXT: [[TMP16:%.*]] = load <256 x i32>, <256 x i32>* [[TMP15]], align 64 +; CHECK-NEXT: [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP7]] to i8* +; CHECK-NEXT: store <256 x i32> [[TMP16]], <256 x i32>* [[TMP7]], align 1024 +; CHECK-NEXT: [[TMP18:%.*]] = sext i16 [[TMP11]] to i64 +; CHECK-NEXT: [[TMP19:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP17]], i64 [[TMP18]]) +; CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2 +; CHECK-NEXT: [[TMP21:%.*]] = load <256 x i32>, <256 x i32>* [[TMP20]], align 64 +; CHECK-NEXT: [[TMP22:%.*]] = bitcast <256 x i32>* [[TMP6]] to i8* +; CHECK-NEXT: store <256 x i32> [[TMP21]], <256 x i32>* [[TMP6]], align 1024 +; CHECK-NEXT: [[TMP23:%.*]] = sext i16 [[TMP13]] to i64 +; CHECK-NEXT: [[TMP24:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP13]], i8* [[TMP22]], i64 [[TMP23]]) +; CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2 +; CHECK-NEXT: [[TMP26:%.*]] = load <256 x i32>, <256 x i32>* [[TMP25]], align 64 +; CHECK-NEXT: [[TMP27:%.*]] = bitcast <256 x i32>* [[TMP5]] to i8* +; CHECK-NEXT: store <256 x i32> [[TMP26]], <256 x i32>* [[TMP5]], align 1024 +; CHECK-NEXT: [[TMP28:%.*]] = sext i16 [[TMP11]] to i64 +; CHECK-NEXT: [[TMP29:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP14]], i16 [[TMP11]], i8* [[TMP27]], i64 [[TMP28]]) +; CHECK-NEXT: [[TMP30:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP9]], i16 [[TMP11]], i16 [[TMP13]], x86_amx [[TMP19]], x86_amx [[TMP24]], x86_amx [[TMP29]]) +; CHECK-NEXT: [[TMP31:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: [[TMP32:%.*]] = sext i16 [[TMP11]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP9]], i16 [[TMP11]], i8* [[TMP31]], i64 [[TMP32]], x86_amx [[TMP30]]) +; CHECK-NEXT: [[TMP33:%.*]] = load <256 x i32>, <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: store <256 x i32> [[TMP33]], <256 x i32>* [[TMP15]], align 64 +; CHECK-NEXT: ret void +; + %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0 + %5 = load i16, i16* %4, align 64 + %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1 + %7 = load i16, i16* %6, align 2 + %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 1 + %9 = load i16, i16* %8, align 2 + %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 + %11 = load <256 x i32>, <256 x i32>* %10, align 64 + %12 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %11) + %13 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2 + %14 = load <256 x i32>, <256 x i32>* %13, align 64 + %15 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %14) + %16 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 + %17 = load <256 x i32>, 
<256 x i32>* %16, align 64 + %18 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %17) + %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, x86_amx %12, x86_amx %15, x86_amx %18) + %20 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %19) + store <256 x i32> %20, <256 x i32>* %10, align 64 + ret void +} + +define dso_local void @__tile_dpbsud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { +; CHECK-LABEL: @__tile_dpbsud( +; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64 +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64 +; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]]) +; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64 +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8* +; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64 +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]]) +; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64 +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8* +; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]]) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]]) +; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024 +; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64 +; CHECK-NEXT: ret void +; + %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 + %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0) + %t2 = load <256 x i32>, <256 x i32>* %pb, align 64 + %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2) + %t4 = load <256 x i32>, <256 x i32>* %pc, align 64 + %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4) + %t6 = tail call x86_amx @llvm.x86.tdpbsud.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3) + %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6) + store <256 x i32> %t7, <256 x i32>* %pc, align 64 + ret void +} + +define dso_local void @__tile_dpbusd(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { +; CHECK-LABEL: @__tile_dpbusd( +; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = 
alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64 +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64 +; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]]) +; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64 +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8* +; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64 +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]]) +; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64 +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8* +; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]]) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]]) +; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024 +; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64 +; CHECK-NEXT: ret void +; + %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 + %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0) + %t2 = load <256 x i32>, <256 x i32>* %pb, align 64 + %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2) + %t4 = load <256 x i32>, <256 x i32>* %pc, align 64 + %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4) + %t6 = tail call x86_amx @llvm.x86.tdpbusd.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3) + %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6) + store <256 x i32> %t7, <256 x i32>* %pc, align 64 + ret void +} + +define dso_local void @__tile_dpbuud(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { +; CHECK-LABEL: @__tile_dpbuud( +; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64 +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64 +; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]]) +; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64 +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 
x i32>* [[TMP3]] to i8* +; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64 +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]]) +; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64 +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8* +; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], i8* [[TMP12]], i64 [[TMP13]]) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]]) +; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024 +; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64 +; CHECK-NEXT: ret void +; + %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 + %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0) + %t2 = load <256 x i32>, <256 x i32>* %pb, align 64 + %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2) + %t4 = load <256 x i32>, <256 x i32>* %pc, align 64 + %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4) + %t6 = tail call x86_amx @llvm.x86.tdpbuud.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3) + %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6) + store <256 x i32> %t7, <256 x i32>* %pc, align 64 + ret void +} + +define dso_local void @__tile_dpbf16ps(i16 %m, i16 %n, i16 %k, <256 x i32>* %pc, <256 x i32>* %pa, <256 x i32>* %pb) { +; CHECK-LABEL: @__tile_dpbf16ps( +; CHECK-NEXT: [[TMP1:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP2:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP3:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = udiv i16 [[K:%.*]], 4 +; CHECK-NEXT: [[T0:%.*]] = load <256 x i32>, <256 x i32>* [[PA:%.*]], align 64 +; CHECK-NEXT: [[TMP6:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <256 x i32> [[T0]], <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: [[TMP7:%.*]] = sext i16 [[K]] to i64 +; CHECK-NEXT: [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[K]], i8* [[TMP6]], i64 [[TMP7]]) +; CHECK-NEXT: [[T2:%.*]] = load <256 x i32>, <256 x i32>* [[PB:%.*]], align 64 +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP3]] to i8* +; CHECK-NEXT: store <256 x i32> [[T2]], <256 x i32>* [[TMP3]], align 1024 +; CHECK-NEXT: [[TMP10:%.*]] = sext i16 [[N:%.*]] to i64 +; CHECK-NEXT: [[TMP11:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[N]], i8* [[TMP9]], i64 [[TMP10]]) +; CHECK-NEXT: [[T4:%.*]] = load <256 x i32>, <256 x i32>* [[PC:%.*]], align 64 +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP2]] to i8* +; CHECK-NEXT: store <256 x i32> [[T4]], <256 x i32>* [[TMP2]], align 1024 +; CHECK-NEXT: [[TMP13:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: [[TMP14:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M]], i16 [[N]], 
i8* [[TMP12]], i64 [[TMP13]]) +; CHECK-NEXT: [[T6:%.*]] = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 [[M]], i16 [[N]], i16 [[K]], x86_amx [[TMP14]], x86_amx [[TMP8]], x86_amx [[TMP11]]) +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8* +; CHECK-NEXT: [[TMP16:%.*]] = sext i16 [[N]] to i64 +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[TMP15]], i64 [[TMP16]], x86_amx [[T6]]) +; CHECK-NEXT: [[TMP17:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024 +; CHECK-NEXT: store <256 x i32> [[TMP17]], <256 x i32>* [[PC]], align 64 +; CHECK-NEXT: ret void +; + %t0 = load <256 x i32>, <256 x i32>* %pa, align 64 + %t1 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t0) + %t2 = load <256 x i32>, <256 x i32>* %pb, align 64 + %t3 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t2) + %t4 = load <256 x i32>, <256 x i32>* %pc, align 64 + %t5 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t4) + %t6 = tail call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %m, i16 %n, i16 %k, x86_amx %t5, x86_amx %t1, x86_amx %t3) + %t7 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t6) + store <256 x i32> %t7, <256 x i32>* %pc, align 64 + ret void +} + +define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr { +; CHECK-LABEL: @__tile_stored( +; CHECK-NEXT: [[TMP4:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[TMP6:%.*]] = load i16, i16* [[TMP5]], align 64 +; CHECK-NEXT: [[TMP7:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1 +; CHECK-NEXT: [[TMP8:%.*]] = load i16, i16* [[TMP7]], align 2 +; CHECK-NEXT: [[TMP9:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2 +; CHECK-NEXT: [[TMP10:%.*]] = load <256 x i32>, <256 x i32>* [[TMP9]], align 64 +; CHECK-NEXT: [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP4]] to i8* +; CHECK-NEXT: store <256 x i32> [[TMP10]], <256 x i32>* [[TMP4]], align 1024 +; CHECK-NEXT: [[TMP12:%.*]] = sext i16 [[TMP8]] to i64 +; CHECK-NEXT: [[TMP13:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP11]], i64 [[TMP12]]) +; CHECK-NEXT: [[TMP14:%.*]] = shl i64 [[TMP1:%.*]], 32 +; CHECK-NEXT: [[TMP15:%.*]] = ashr exact i64 [[TMP14]], 32 +; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP6]], i16 [[TMP8]], i8* [[TMP0:%.*]], i64 [[TMP15]], x86_amx [[TMP13]]) +; CHECK-NEXT: ret void +; + %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0 + %5 = load i16, i16* %4, align 64 + %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1 + %7 = load i16, i16* %6, align 2 + %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 + %9 = load <256 x i32>, <256 x i32>* %8, align 64 + %10 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %9) + %11 = shl i64 %1, 32 + %12 = ashr exact i64 %11, 32 + tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx %10) + ret void +} + +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, 
x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) + +declare x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32>) +declare x86_amx @llvm.x86.cast.vector.to.tile.v225i32(<225 x i32>) +declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx) +declare <225 x i32> @llvm.x86.cast.tile.to.vector.v225i32(x86_amx)
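
Every vector-to-tile cast in the checks above is lowered the same way: the vector value is spilled to a 64-byte-aligned stack slot, then reloaded as a tile with llvm.x86.tileloadd64.internal, with the column operand sign-extended to i64 as the stride; the tile-to-vector direction mirrors this with llvm.x86.tilestored64.internal into the slot followed by a plain load (see @test_amx_bitcast_store). The IRBuilder sketch below only illustrates that store-then-tileload sequence; the helper name and signature are invented for illustration and are not code from this patch.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Illustrative helper (not part of the patch): emit the spill/reload pattern
// the CHECK lines expect when a plain vector value must become an x86_amx
// tile. `Slot` is assumed to be a 64-byte-aligned alloca of the vector's type
// in the entry block, and `Row`/`Col` are the i16 shape operands.
static Value *lowerVectorToTileViaStack(IRBuilder<> &Builder, Value *Vec,
                                        Value *Row, Value *Col,
                                        AllocaInst *Slot) {
  Module *M = Builder.GetInsertBlock()->getModule();
  // Reinterpret the slot as i8* for the tile load, then spill the vector.
  Value *I8Ptr = Builder.CreateBitCast(Slot, Builder.getInt8PtrTy());
  Builder.CreateStore(Vec, Slot);
  // Reload the bytes as a tile; the stride is the column count in bytes.
  Value *Stride = Builder.CreateSExt(Col, Builder.getInt64Ty());
  Function *TileLoad =
      Intrinsic::getDeclaration(M, Intrinsic::x86_tileloadd64_internal);
  return Builder.CreateCall(TileLoad, {Row, Col, I8Ptr, Stride});
}

This is only a reading aid for the expected test output; the actual transformation is implemented in X86LowerAMXType.cpp.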