Skip to content

Commit

Permalink
[Matrix] Propagate and use shape information for loads.
Browse files Browse the repository at this point in the history
This patch extends to shape propagation to also include load
instructions and implements shape aware lowering for vector loads.

Reviewers: anemet, Gerolf, reames, hfinkel, andrew.w.kaylor

Reviewed By: anemet

Differential Revision: https://reviews.llvm.org/D70900

(Cherry-picked from 7adf664)
  • Loading branch information
fhahn committed Mar 5, 2020
1 parent a43aacc commit 40ff955
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 136 deletions.
42 changes: 29 additions & 13 deletions llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
Expand Up @@ -95,20 +95,20 @@ Value *computeColumnAddr(Value *BasePtr, Value *Col, Value *Stride,
unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();

// Compute the start of the column with index Col as Col * Stride.
Value *ColumnStart = Builder.CreateMul(Col, Stride);
Value *ColumnStart = Builder.CreateMul(Col, Stride, "col.start");

// Get pointer to the start of the selected column. Skip GEP creation,
// if we select column 0.
if (isa<ConstantInt>(ColumnStart) && cast<ConstantInt>(ColumnStart)->isZero())
ColumnStart = BasePtr;
else
ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart);
ColumnStart = Builder.CreateGEP(EltType, BasePtr, ColumnStart, "col.gep");

// Cast elementwise column start pointer to a pointer to a column
// (EltType x NumRows)*.
Type *ColumnType = VectorType::get(EltType, NumRows);
Type *ColumnPtrType = PointerType::get(ColumnType, AS);
return Builder.CreatePointerCast(ColumnStart, ColumnPtrType);
return Builder.CreatePointerCast(ColumnStart, ColumnPtrType, "col.cast");
}

/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
Expand Down Expand Up @@ -317,7 +317,7 @@ class LowerMatrixIntrinsics {
default:
return false;
}
return isUniformShape(V) || isa<StoreInst>(V);
return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
}

/// Propagate the shape information of instructions to their users.
Expand Down Expand Up @@ -481,6 +481,8 @@ class LowerMatrixIntrinsics {
Value *Op2;
if (auto *BinOp = dyn_cast<BinaryOperator>(&Inst))
Changed |= VisitBinaryOperator(BinOp);
if (match(&Inst, m_Load(m_Value(Op1))))
Changed |= VisitLoad(&Inst, Op1, Builder);
else if (match(&Inst, m_Store(m_Value(Op1), m_Value(Op2))))
Changed |= VisitStore(&Inst, Op1, Op2, Builder);
}
Expand All @@ -495,7 +497,7 @@ class LowerMatrixIntrinsics {
LoadInst *createColumnLoad(Value *ColumnPtr, Type *EltType,
IRBuilder<> Builder) {
unsigned Align = DL.getABITypeAlignment(EltType);
return Builder.CreateAlignedLoad(ColumnPtr, Align);
return Builder.CreateAlignedLoad(ColumnPtr, Align, "col.load");
}

StoreInst *createColumnStore(Value *ColumnValue, Value *ColumnPtr,
Expand Down Expand Up @@ -536,17 +538,11 @@ class LowerMatrixIntrinsics {
return true;
}

/// Lowers llvm.matrix.columnwise.load.
///
/// The intrinsic loads a matrix from memory using a stride between columns.
void LowerColumnwiseLoad(CallInst *Inst) {
void LowerLoad(Instruction *Inst, Value *Ptr, Value *Stride,
ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
Value *Ptr = Inst->getArgOperand(0);
Value *Stride = Inst->getArgOperand(1);
auto VType = cast<VectorType>(Inst->getType());
Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
ShapeInfo Shape(Inst->getArgOperand(2), Inst->getArgOperand(3));

ColumnMatrixTy Result;
// Distance between start of one column and the start of the next
for (unsigned C = 0, E = Shape.NumColumns; C < E; ++C) {
Expand All @@ -560,6 +556,16 @@ class LowerMatrixIntrinsics {
finalizeLowering(Inst, Result, Builder);
}

/// Lowers llvm.matrix.columnwise.load.
///
/// The intrinsic loads a matrix from memory using a stride between columns.
void LowerColumnwiseLoad(CallInst *Inst) {
Value *Ptr = Inst->getArgOperand(0);
Value *Stride = Inst->getArgOperand(1);
LowerLoad(Inst, Ptr, Stride,
{Inst->getArgOperand(2), Inst->getArgOperand(3)});
}

void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, Value *Stride,
ShapeInfo Shape) {
IRBuilder<> Builder(Inst);
Expand Down Expand Up @@ -755,6 +761,16 @@ class LowerMatrixIntrinsics {
finalizeLowering(Inst, Result, Builder);
}

/// Lower load instructions, if shape information is available.
bool VisitLoad(Instruction *Inst, Value *Ptr, IRBuilder<> &Builder) {
auto I = ShapeMap.find(Inst);
if (I == ShapeMap.end())
return false;

LowerLoad(Inst, Ptr, Builder.getInt32(I->second.NumRows), I->second);
return true;
}

bool VisitStore(Instruction *Inst, Value *StoredVal, Value *Ptr,
IRBuilder<> &Builder) {
auto I = ShapeMap.find(StoredVal);
Expand Down

0 comments on commit 40ff955

Please sign in to comment.