Skip to content

Commit

Permalink
[mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass
Browse files Browse the repository at this point in the history
Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D140122
  • Loading branch information
PeimingLiu committed Dec 22, 2022
1 parent b49ee01 commit 083ddff
Show file tree
Hide file tree
Showing 11 changed files with 890 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {

let parameters = (ins SparseTensorEncodingAttr : $encoding);
let builders = [
TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
assert(encoding && "sparse tensor encoding should not be null");
return $_get(encoding.getContext(), encoding);
return get(encoding.getContext(), encoding);
}]>,
TypeBuilderWithInferredContext<(ins "Type":$type), [{
return get(getSparseTensorEncoding(type));
Expand All @@ -71,6 +71,10 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
}];

// We skipped the default builder that simply takes the input sparse tensor encoding
// attribute since we need to normalize the dimension level type and remove unrelated
// fields that are irrelavant to sparse tensor storage scheme.
let skipDefaultBuilders = 1;
let assemblyFormat="`<` qualified($encoding) `>`";
}

Expand Down
13 changes: 13 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,19 @@ std::unique_ptr<Pass>
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
bool enableConvert = true);

//===----------------------------------------------------------------------===//
// The SparseStorageSpecifierToLLVM pass.
//===----------------------------------------------------------------------===//

class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
public:
StorageSpecifierToLLVMTypeConverter();
};

void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
RewritePatternSet &patterns);
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();

//===----------------------------------------------------------------------===//
// Other rewriting rules and passes.
//===----------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -301,4 +301,28 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
];
}

def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
let summary = "Lower sparse storage specifer to llvm structure";
let description = [{
This pass rewrites sparse tensor storage specifier-related operations into
LLVMDialect, and converts sparse tensor storage specifier into an llvm.struct.

Example of the conversion:
```mlir
Before:
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
: !sparse_tensor.storage_specifier<#CSR> to i64

After:
%0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
```
}];
let constructor = "mlir::createStorageSpecifierToLLVMPass()";
let dependentDialects = [
"arith::ArithDialect",
"LLVM::LLVMDialect",
"sparse_tensor::SparseTensorDialect",
];
}

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
22 changes: 22 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,28 @@ uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
// SparseTensorDialect Types.
//===----------------------------------------------------------------------===//

/// We normalized sparse tensor encoding attribute by always using
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
/// as other variants) lead to the same storage specifier type, and stripping
/// irrelevant fields that does not alter the sparse tensor memory layout.
static SparseTensorEncodingAttr
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
SmallVector<DimLevelType> dlts;
for (auto dlt : enc.getDimLevelType())
dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));

return SparseTensorEncodingAttr::get(
enc.getContext(), dlts,
AffineMap(), // dimOrdering (irrelavant to storage speicifer)
AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
enc.getPointerBitWidth(), enc.getIndexBitWidth());
}

StorageSpecifierType
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
}

IntegerType StorageSpecifierType::getSizesType() const {
unsigned idxBitWidth =
getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
CodegenEnv.cpp
CodegenUtils.cpp
SparseBufferRewriting.cpp
SparseStorageSpecifierToLLVM.cpp
SparseTensorCodegen.cpp
SparseTensorConversion.cpp
SparseTensorPasses.cpp
SparseTensorRewriting.cpp
SparseTensorStorageLayout.cpp
SparseVectorization.cpp
Sparsification.cpp
SparsificationAndBufferizationPass.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "CodegenUtils.h"
#include "SparseTensorStorageLayout.h"

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"

using namespace mlir;
using namespace sparse_tensor;

static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
MLIRContext *ctx = tp.getContext();
auto enc = tp.getEncoding();
unsigned rank = enc.getDimLevelType().size();

SmallVector<Type, 2> result;
auto indexType = tp.getSizesType();
auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
getNumDataFieldsFromEncoding(enc));
result.push_back(dimSizes);
result.push_back(memSizes);
return result;
}

static Type convertSpecifier(StorageSpecifierType tp) {
return LLVM::LLVMStructType::getLiteral(tp.getContext(),
getSpecifierFields(tp));
}

StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
addConversion([](Type type) { return type; });
addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
}

constexpr uint64_t kDimSizePosInSpecifier = 0;
constexpr uint64_t kMemSizePosInSpecifier = 1;

class SpecifierStructBuilder : public StructBuilder {
public:
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
assert(value);
}

// Undef value for dimension sizes, all zero value for memory sizes.
static Value getInitValue(OpBuilder &builder, Location loc, Type structType);

Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);

Value memSize(OpBuilder &builder, Location loc, unsigned pos);
void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
};

Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
Type structType) {
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
SpecifierStructBuilder md(metaData);
auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
.getBody()[kMemSizePosInSpecifier]
.cast<LLVM::LLVMArrayType>();

Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
// Fill memSizes array with zero.
for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
md.setMemSize(builder, loc, i, zero);

return md;
}

/// Builds IR inserting the pos-th size into the descriptor.
Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
unsigned dim) {
return builder.create<LLVM::ExtractValueOp>(
loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
}

/// Builds IR inserting the pos-th size into the descriptor.
void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
unsigned dim, Value size) {
value = builder.create<LLVM::InsertValueOp>(
loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
}

/// Builds IR extracting the pos-th memory size into the descriptor.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
unsigned pos) {
return builder.create<LLVM::ExtractValueOp>(
loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
}

/// Builds IR inserting the pos-th memory size into the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
unsigned pos, Value size) {
value = builder.create<LLVM::InsertValueOp>(
loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
}

template <typename Base, typename SourceOp>
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
public:
using OpAdaptor = typename SourceOp::Adaptor;
using OpConversionPattern<SourceOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SpecifierStructBuilder spec(adaptor.getSpecifier());
Value v;
if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
v = Base::onDimSize(rewriter, op, spec,
op.getDim().value().getZExtValue());
} else {
auto enc = op.getSpecifier().getType().getEncoding();
builder::StorageLayout layout(enc);
Optional<unsigned> dim = std::nullopt;
if (op.getDim())
dim = op.getDim().value().getZExtValue();
unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
v = Base::onMemSize(rewriter, op, spec, idx);
}

rewriter.replaceOp(op, v);
return success();
}
};

struct StorageSpecifierSetOpConverter
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
SetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned d) {
spec.setDimSize(builder, op.getLoc(), d, op.getValue());
return spec;
}

static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned i) {
spec.setMemSize(builder, op.getLoc(), i, op.getValue());
return spec;
}
};

struct StorageSpecifierGetOpConverter
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
GetStorageSpecifierOp> {
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned d) {
return spec.dimSize(builder, op.getLoc(), d);
}
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
SpecifierStructBuilder &spec, unsigned i) {
return spec.memSize(builder, op.getLoc(), i);
}
};

struct StorageSpecifierInitOpConverter
: public OpConversionPattern<StorageSpecifierInitOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
rewriter, op.getLoc(), llvmType));
return success();
}
};

void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
RewritePatternSet &patterns) {
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
StorageSpecifierInitOpConverter>(converter,
patterns.getContext());
}
54 changes: 51 additions & 3 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
#define GEN_PASS_DEF_SPARSEVECTORIZATION
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
} // namespace mlir

Expand Down Expand Up @@ -193,9 +194,14 @@ struct SparseTensorCodegenPass
target.addLegalOp<SortOp>();
target.addLegalOp<SortCooOp>();
target.addLegalOp<PushBackOp>();
// All dynamic rules below accept new function, call, return, and various
// tensor and bufferization operations as legal output of the rewriting
// provided that all sparse tensor types have been fully rewritten.
// Storage specifier outlives sparse tensor pipeline.
target.addLegalOp<GetStorageSpecifierOp>();
target.addLegalOp<SetStorageSpecifierOp>();
target.addLegalOp<StorageSpecifierInitOp>();
// All dynamic rules below accept new function, call, return, and
// various tensor and bufferization operations as legal output of the
// rewriting provided that all sparse tensor types have been fully
// rewritten.
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
Expand Down Expand Up @@ -271,6 +277,44 @@ struct SparseVectorizationPass
}
};

struct StorageSpecifierToLLVMPass
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {

StorageSpecifierToLLVMPass() = default;

void runOnOperation() override {
auto *ctx = &getContext();
ConversionTarget target(*ctx);
RewritePatternSet patterns(ctx);
StorageSpecifierToLLVMTypeConverter converter;

// All ops in the sparse dialect must go!
target.addIllegalDialect<SparseTensorDialect>();
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
return converter.isSignatureLegal(op.getCalleeType());
});
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
return converter.isLegal(op.getOperandTypes());
});
target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateStorageSpecifierToLLVMPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -355,3 +399,7 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
return std::make_unique<SparseVectorizationPass>(
vectorLength, enableVLAVectorization, enableSIMDIndex32);
}

std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
return std::make_unique<StorageSpecifierToLLVMPass>();
}
Loading

0 comments on commit 083ddff

Please sign in to comment.