Skip to content

Commit

Permalink
Lower miopen.lds_barrier to Std.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Jun 6, 2020
1 parent 10f4843 commit d87f8d1
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 93 deletions.
Expand Up @@ -29,7 +29,7 @@ void populateMIOpenOpsToLLVMConversionPatterns(LLVMTypeConverter &converter,
namespace miopen {

/// Create a pass to convert MIOpen operations to LLVM operations.
///
/// NOTE(review): two declarations of the same function follow, differing only
/// in the operation type the returned pass is anchored on (FuncOp vs ModuleOp).
/// This text looks like a rendered one-line diff without +/- markers;
/// presumably the FuncOp form is the removed line and the ModuleOp form is the
/// added one -- confirm against the actual header before relying on either.
std::unique_ptr<OpPassBase<FuncOp>> createLowerMIOpenOpsToLLVMPass();
std::unique_ptr<OpPassBase<ModuleOp>> createLowerMIOpenOpsToLLVMPass();

} // namespace miopen
} // namespace mlir
Expand Down
206 changes: 114 additions & 92 deletions mlir/lib/Dialect/MIOpenOps/LLVMOutput/ConvertMIOpenOpsToLLVM.cpp
Expand Up @@ -47,106 +47,128 @@
using namespace mlir;

namespace {
// Pass that lowers MIOpen dialect operations; the lowering itself lives in the
// out-of-line runOnFunction()/runOnModule() definition.
//
// NOTE(review): two struct headers are interleaved here (a FunctionPass form
// with runOnFunction and a ModulePass form with runOnModule) but only one
// closing "};" follows. This reads as removed/added diff lines shown without
// markers; presumably the ModulePass form is the one that survives -- confirm.
struct LowerMIOpenOpsToLLVMPass : public FunctionPass<LowerMIOpenOpsToLLVMPass> {
void runOnFunction() override;
struct LowerMIOpenOpsToLLVMPass : public ModulePass<LowerMIOpenOpsToLLVMPass> {
void runOnModule() override;
};
} // end anonymous namespace

// NOTE(review): this region is a rendered diff with the removed and added
// sides interleaved and no +/- markers. It starts with a function-scoped
// runOnFunction() that is cut off mid-body, followed by the module-scoped
// runOnModule(). Fragments of the GpuAllocOp/SubviewOp handling also appear a
// second time further down. Presumably the runOnFunction text and the
// duplicated fragments are the removed side -- confirm against the real file.

// Function-scoped lowering: walks one FuncOp and rewrites or erases MIOpen ops.
void LowerMIOpenOpsToLLVMPass::runOnFunction() {
FuncOp func = getFunction();
LLVMTypeConverter converter(&getContext());

// TransformOp: forward its input to all users, then drop the op.
func.walk([&](miopen::TransformOp op) {
op.replaceAllUsesWith(op.input());
op.erase();
});

// FillOp is erased without emitting a replacement.
func.walk([&](miopen::FillOp op) {
op.erase();
});

// In this version LdsBarrierOp is simply erased.
func.walk([&](miopen::LdsBarrierOp op) {
op.erase();
});

// ThreadwiseGemmOp is erased without emitting a replacement.
func.walk([&](miopen::ThreadwiseGemmOp op) {
op.erase();
});

// ThreadwiseCopyOp is erased without emitting a replacement.
func.walk([&](miopen::ThreadwiseCopyOp op) {
op.erase();
});

// GpuAllocOp: dispatch on the memory space of the result memref type.
func.walk([&](miopen::GpuAllocOp op) {
auto loc = op.getLoc();
// Reads the "value" attribute of the op that defines sizeBytes.
// NOTE(review): neither getDefiningOp() nor the dyn_cast result is checked
// for null before being dereferenced.
auto sizeBytes = op.sizeBytes().getDefiningOp()->getAttr("value").dyn_cast<IntegerAttr>().getInt();
auto type = op.output().getType().cast<MemRefType>();

OpBuilder b(op.getContext());

if (type.getMemorySpace() == 5) {
// Create llvm.mlir.alloca for VGPRs.
b.setInsertionPointToStart(op.getOperation()->getBlock());
auto ptrType = converter.convertType(type.getElementType())
.cast<LLVM::LLVMType>().getPointerTo();

// Module-scoped lowering: applies the per-function rewrites below to every
// FuncOp found in the module.
void LowerMIOpenOpsToLLVMPass::runOnModule() {
auto m = getModule();

for (auto func : m.getOps<FuncOp>()) {
// A fresh type converter is constructed for each function visited.
LLVMTypeConverter converter(&getContext());

// TransformOp: forward its input to all users, then drop the op.
func.walk([&](miopen::TransformOp op) {
op.replaceAllUsesWith(op.input());
op.erase();
});

// FillOp is erased; the commented-out lines sketch a store-based lowering
// that is not active.
func.walk([&](miopen::FillOp op) {
//OpBuilder b(op.getContext());
//b.setInsertionPointToStart(op.getOperation()->getBlock());
//auto zeroConstantIndexOp = b.create<ConstantIndexOp>(op.getLoc(), 0);
//auto storeOp = b.create<StoreOp>(op.getLoc(), op.value(), op.input(), ValueRange{zeroConstantIndexOp});
//storeOp.dump();
op.erase();
});

// ThreadwiseGemmOp is erased without emitting a replacement.
func.walk([&](miopen::ThreadwiseGemmOp op) {
op.erase();
});

// ThreadwiseCopyOp is erased without emitting a replacement.
func.walk([&](miopen::ThreadwiseCopyOp op) {
op.erase();
});

// GpuAllocOp: memory space 5 becomes an llvm alloca, memory space 3 becomes
// an llvm global plus an address-of; any other memory space gets no
// replacement value before the op is erased.
func.walk([&](miopen::GpuAllocOp op) {
auto loc = op.getLoc();
// Reads the "value" attribute of the op that defines sizeBytes.
// NOTE(review): neither getDefiningOp() nor the dyn_cast result is checked
// for null before being dereferenced.
auto sizeBytes = op.sizeBytes().getDefiningOp()->getAttr("value").dyn_cast<IntegerAttr>().getInt();
auto type = op.output().getType().cast<MemRefType>();

OpBuilder b(op.getContext());

if (type.getMemorySpace() == 5) {
// Create llvm.mlir.alloca for VGPRs.
b.setInsertionPointToStart(op.getOperation()->getBlock());
auto ptrType = converter.convertType(type.getElementType())
.cast<LLVM::LLVMType>().getPointerTo();

auto *llvmDialect = b.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();

// The alloca's count operand is an i64 constant holding sizeBytes.
// NOTE(review): sizeBytes is used directly as the element count while the
// pointer type is built from the element type -- confirm this is intended.
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto numElements = b.create<LLVM::ConstantOp>(loc, int64Ty, b.getIntegerAttr(b.getIndexType(), sizeBytes));
auto allocated = b.create<LLVM::AllocaOp>(loc, ptrType, numElements, 0);
op.replaceAllUsesWith(allocated.res());
} else if (type.getMemorySpace() == 3) {
// Create llvm.mlir.global for LDS.
// The global is inserted at the start of the block that holds the parent op.
b.setInsertionPointToStart(op.getOperation()->getParentOp()->getBlock());
auto elementType = converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, sizeBytes);
// NOTE(review): the symbol name is a fixed literal; presumably a second LDS
// allocation would reuse the same name -- verify.
StringRef name = "lds_buffer";
auto globalOp = b.create<LLVM::GlobalOp>(loc, arrayType.cast<LLVM::LLVMType>(),
/*isConstant=*/false, LLVM::Linkage::Internal, name,
/*value=*/Attribute(), 3);
// Take the global's address right before the alloc op and hand it to users.
b.setInsertionPoint(op);
auto addrOfOp = b.create<LLVM::AddressOfOp>(loc, globalOp);
op.replaceAllUsesWith(addrOfOp.res());
}
op.erase();
});

// SubviewOp: constant offset -> GEP -> bitcast to a pointer to a nested array
// type built from the result memref shape.
func.walk([&](miopen::SubviewOp op) {
OpBuilder b(op.getContext());
b.setInsertionPoint(op);
auto *llvmDialect = b.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();

// NOTE(review): the next block of lines (int64Ty .. GlobalOp creation) repeats
// the GpuAllocOp code above and refers to names not declared in this lambda
// (sizeBytes, ptrType, converter's `type` as MemRefType); presumably it is the
// removed side of the diff rather than part of this lambda -- confirm.
auto int64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
auto numElements = b.create<LLVM::ConstantOp>(loc, int64Ty, b.getIntegerAttr(b.getIndexType(), sizeBytes));
auto allocated = b.create<LLVM::AllocaOp>(loc, ptrType, numElements, 0);
op.replaceAllUsesWith(allocated.res());
} else if (type.getMemorySpace() == 3) {
// Create llvm.mlir.global for LDS.
b.setInsertionPointToStart(op.getOperation()->getParentOp()->getBlock());
auto elementType = converter.convertType(type.getElementType()).cast<LLVM::LLVMType>();
auto arrayType = LLVM::LLVMType::getArrayTy(elementType, sizeBytes);
StringRef name = "lds_buffer";
auto globalOp = b.create<LLVM::GlobalOp>(loc, arrayType.cast<LLVM::LLVMType>(),
/*isConstant=*/false, LLVM::Linkage::Internal, name,
/*value=*/Attribute(), 3);
// The SubviewOp handling continues here.
auto loc = op.getLoc();
auto type = op.input().getType();
// llvmType is not referenced again in the visible SubviewOp code.
auto llvmType = type.cast<LLVM::LLVMType>();

// The offset is read from the "value" attribute of the op defining op.offset()
// and re-emitted as an i32 LLVM constant.
// NOTE(review): getDefiningOp() and the dyn_cast result are not null-checked.
auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
auto offset = op.offset().getDefiningOp()->getAttr("value").dyn_cast<IntegerAttr>().getInt();
Value offsetValue = b.create<LLVM::ConstantOp>(loc, int32Ty, b.getI32IntegerAttr(offset));

// inputType would have been converted to LLVMType already.
// NOTE(review): the dyn_cast to llvm::PointerType is dereferenced unchecked.
auto inputType = op.input().getType().cast<LLVM::LLVMType>();
auto inputAddrSpace = dyn_cast<llvm::PointerType>(inputType.getUnderlyingType())->getAddressSpace();

// NOTE(review): only offsetValue is passed as an operand; op.input() is not
// passed as the GEP base -- confirm this matches the GEPOp builder in use.
auto gepOp = b.create<LLVM::GEPOp>(loc, inputType, ArrayRef<Value>({offsetValue}));

auto outputType = op.output().getType().cast<MemRefType>();
auto outputElementType = converter.convertType(outputType.getElementType()).cast<LLVM::LLVMType>();

// Wrap the element type in array types from the last shape dimension to the
// first, so the first dimension ends up outermost.
auto outputShape = outputType.getShape();
LLVM::LLVMType outputLLVMType = outputElementType;
for (int i = outputShape.size() - 1; i >= 0; --i) {
outputLLVMType = LLVM::LLVMType::getArrayTy(outputLLVMType, outputShape[i]);
}
// The resulting pointer keeps the address space of the input pointer.
auto bitcastOp = b.create<LLVM::BitcastOp>(loc, outputLLVMType.getPointerTo(inputAddrSpace), gepOp);

op.replaceAllUsesWith(bitcastOp.res());
op.erase();
});

// LdsBarrierOp: replaced by a call to a module-level function named
// "lds_barrier", which is declared on first use.
func.walk([&](miopen::LdsBarrierOp op) {
OpBuilder b(op.getContext());
auto loc = op.getLoc();
if (!getModule().lookupSymbol<FuncOp>("lds_barrier")) {
// Function type with no inputs and no results.
auto funcType = b.getFunctionType({}, {});

StringRef funcName = "lds_barrier";
// The new FuncOp is inserted at the very beginning of the module body.
b.setInsertionPoint(getModule().getBody(), getModule().getBody()->begin());
// NOTE(review): this local shadows the loop variable `func` and is not used
// afterwards; the function is found again via lookupSymbol below.
auto func = b.create<FuncOp>(loc, funcName, funcType, ArrayRef<NamedAttribute>{});
}
auto barrierFunc = getModule().lookupSymbol<FuncOp>("lds_barrier");
b.setInsertionPoint(op);
// NOTE(review): the lines from here down to the CallOp creation repeat the
// tail of the GpuAllocOp lambda and the whole SubviewOp lambda; `globalOp` is
// not declared in this lambda. Presumably they are the removed side of the
// diff and the barrier lowering resumes at the CallOp creation -- confirm.
auto addrOfOp = b.create<LLVM::AddressOfOp>(loc, globalOp);
op.replaceAllUsesWith(addrOfOp.res());
}
op.erase();
});

func.walk([&](miopen::SubviewOp op) {
OpBuilder b(op.getContext());
b.setInsertionPoint(op);
auto *llvmDialect = b.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();

auto loc = op.getLoc();
auto type = op.input().getType();
auto llvmType = type.cast<LLVM::LLVMType>();

auto int32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
auto offset = op.offset().getDefiningOp()->getAttr("value").dyn_cast<IntegerAttr>().getInt();
Value offsetValue = b.create<LLVM::ConstantOp>(loc, int32Ty, b.getI32IntegerAttr(offset));

// inputType would have been converted to LLVMType already.
auto inputType = op.input().getType().cast<LLVM::LLVMType>();
auto inputAddrSpace = dyn_cast<llvm::PointerType>(inputType.getUnderlyingType())->getAddressSpace();

auto gepOp = b.create<LLVM::GEPOp>(loc, inputType, ArrayRef<Value>({offsetValue}));

auto outputType = op.output().getType().cast<MemRefType>();
auto outputElementType = converter.convertType(outputType.getElementType()).cast<LLVM::LLVMType>();

auto outputShape = outputType.getShape();
LLVM::LLVMType outputLLVMType = outputElementType;
for (int i = outputShape.size() - 1; i >= 0; --i) {
outputLLVMType = LLVM::LLVMType::getArrayTy(outputLLVMType, outputShape[i]);
}
auto bitcastOp = b.create<LLVM::BitcastOp>(loc, outputLLVMType.getPointerTo(inputAddrSpace), gepOp);

op.replaceAllUsesWith(bitcastOp.res());
op.erase();
});
// Emit the call (no results, no operands) in place of the barrier op, then
// erase the barrier op.
b.create<CallOp>(loc, ArrayRef<Type>{},
b.getSymbolRefAttr(barrierFunc),
ArrayRef<Value>{});
op.erase();
});
}
}

// Factory for the MIOpen-to-LLVM lowering pass: returns a newly allocated
// LowerMIOpenOpsToLLVMPass.
//
// NOTE(review): two signature lines precede a single body (FuncOp-anchored and
// ModuleOp-anchored return types). This reads as a removed/added diff pair
// shown without markers; presumably the ModuleOp form is the current one,
// matching the ModulePass base of the pass struct -- confirm.
std::unique_ptr<OpPassBase<FuncOp>> mlir::miopen::createLowerMIOpenOpsToLLVMPass() {
std::unique_ptr<OpPassBase<ModuleOp>> mlir::miopen::createLowerMIOpenOpsToLLVMPass() {
return std::make_unique<LowerMIOpenOpsToLLVMPass>();
}

Expand Down

0 comments on commit d87f8d1

Please sign in to comment.