Skip to content

Commit 083ddff

Browse files
author
Peiming Liu
committed
[mlir][sparse] introduce sparse_tensor::StorageSpecifierToLLVM pass
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D140122
1 parent b49ee01 commit 083ddff

File tree

11 files changed

+890
-5
lines changed

11 files changed

+890
-5
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
5252

5353
let parameters = (ins SparseTensorEncodingAttr : $encoding);
5454
let builders = [
55+
TypeBuilder<(ins "SparseTensorEncodingAttr":$encoding)>,
5556
TypeBuilderWithInferredContext<(ins "SparseTensorEncodingAttr":$encoding), [{
56-
assert(encoding && "sparse tensor encoding should not be null");
57-
return $_get(encoding.getContext(), encoding);
57+
return get(encoding.getContext(), encoding);
5858
}]>,
5959
TypeBuilderWithInferredContext<(ins "Type":$type), [{
6060
return get(getSparseTensorEncoding(type));
@@ -71,6 +71,10 @@ def SparseTensor_StorageSpecifier : SparseTensor_Type<"StorageSpecifier"> {
7171
Type getFieldType(StorageSpecifierKind kind, std::optional<APInt> dim) const;
7272
}];
7373

74+
// We skipped the default builder that simply takes the input sparse tensor encoding
75+
// attribute since we need to normalize the dimension level type and remove unrelated
76+
// fields that are irrelavant to sparse tensor storage scheme.
77+
let skipDefaultBuilders = 1;
7478
let assemblyFormat="`<` qualified($encoding) `>`";
7579
}
7680

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,19 @@ std::unique_ptr<Pass>
158158
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
159159
bool enableConvert = true);
160160

161+
//===----------------------------------------------------------------------===//
162+
// The SparseStorageSpecifierToLLVM pass.
163+
//===----------------------------------------------------------------------===//
164+
165+
class StorageSpecifierToLLVMTypeConverter : public TypeConverter {
166+
public:
167+
StorageSpecifierToLLVMTypeConverter();
168+
};
169+
170+
void populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
171+
RewritePatternSet &patterns);
172+
std::unique_ptr<Pass> createStorageSpecifierToLLVMPass();
173+
161174
//===----------------------------------------------------------------------===//
162175
// Other rewriting rules and passes.
163176
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,4 +301,28 @@ def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
301301
];
302302
}
303303

304+
def StorageSpecifierToLLVM : Pass<"sparse-storage-specifier-to-llvm", "ModuleOp"> {
305+
let summary = "Lower sparse storage specifer to llvm structure";
306+
let description = [{
307+
This pass rewrites sparse tensor storage specifier-related operations into
308+
LLVMDialect, and converts sparse tensor storage specifier into an llvm.struct.
309+
310+
Example of the conversion:
311+
```mlir
312+
Before:
313+
%0 = sparse_tensor.storage_specifier.get %arg0 dim_sz at 0
314+
: !sparse_tensor.storage_specifier<#CSR> to i64
315+
316+
After:
317+
%0 = llvm.extractvalue %arg0[0, 0] : !llvm.struct<(array<2 x i64>, array<3 x i64>)>
318+
```
319+
}];
320+
let constructor = "mlir::createStorageSpecifierToLLVMPass()";
321+
let dependentDialects = [
322+
"arith::ArithDialect",
323+
"LLVM::LLVMDialect",
324+
"sparse_tensor::SparseTensorDialect",
325+
];
326+
}
327+
304328
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,28 @@ uint64_t mlir::sparse_tensor::toStoredDim(RankedTensorType type, uint64_t d) {
323323
// SparseTensorDialect Types.
324324
//===----------------------------------------------------------------------===//
325325

326+
/// We normalized sparse tensor encoding attribute by always using
327+
/// ordered/unique DLT such that "compressed-nu-no" and "compressed-nu" (as well
328+
/// as other variants) lead to the same storage specifier type, and stripping
329+
/// irrelevant fields that does not alter the sparse tensor memory layout.
330+
static SparseTensorEncodingAttr
331+
getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) {
332+
SmallVector<DimLevelType> dlts;
333+
for (auto dlt : enc.getDimLevelType())
334+
dlts.push_back(*getDimLevelType(*getLevelFormat(dlt), true, true));
335+
336+
return SparseTensorEncodingAttr::get(
337+
enc.getContext(), dlts,
338+
AffineMap(), // dimOrdering (irrelavant to storage speicifer)
339+
AffineMap(), // highLvlOrdering (irrelavant to storage specifer)
340+
enc.getPointerBitWidth(), enc.getIndexBitWidth());
341+
}
342+
343+
StorageSpecifierType
344+
StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) {
345+
return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding));
346+
}
347+
326348
IntegerType StorageSpecifierType::getSizesType() const {
327349
unsigned idxBitWidth =
328350
getEncoding().getIndexBitWidth() ? getEncoding().getIndexBitWidth() : 64u;

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
33
CodegenEnv.cpp
44
CodegenUtils.cpp
55
SparseBufferRewriting.cpp
6+
SparseStorageSpecifierToLLVM.cpp
67
SparseTensorCodegen.cpp
78
SparseTensorConversion.cpp
89
SparseTensorPasses.cpp
910
SparseTensorRewriting.cpp
11+
SparseTensorStorageLayout.cpp
1012
SparseVectorization.cpp
1113
Sparsification.cpp
1214
SparsificationAndBufferizationPass.cpp
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "CodegenUtils.h"
10+
#include "SparseTensorStorageLayout.h"
11+
12+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
13+
14+
using namespace mlir;
15+
using namespace sparse_tensor;
16+
17+
static SmallVector<Type, 2> getSpecifierFields(StorageSpecifierType tp) {
18+
MLIRContext *ctx = tp.getContext();
19+
auto enc = tp.getEncoding();
20+
unsigned rank = enc.getDimLevelType().size();
21+
22+
SmallVector<Type, 2> result;
23+
auto indexType = tp.getSizesType();
24+
auto dimSizes = LLVM::LLVMArrayType::get(ctx, indexType, rank);
25+
auto memSizes = LLVM::LLVMArrayType::get(ctx, indexType,
26+
getNumDataFieldsFromEncoding(enc));
27+
result.push_back(dimSizes);
28+
result.push_back(memSizes);
29+
return result;
30+
}
31+
32+
static Type convertSpecifier(StorageSpecifierType tp) {
33+
return LLVM::LLVMStructType::getLiteral(tp.getContext(),
34+
getSpecifierFields(tp));
35+
}
36+
37+
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
38+
addConversion([](Type type) { return type; });
39+
addConversion([](StorageSpecifierType tp) { return convertSpecifier(tp); });
40+
}
41+
42+
constexpr uint64_t kDimSizePosInSpecifier = 0;
43+
constexpr uint64_t kMemSizePosInSpecifier = 1;
44+
45+
class SpecifierStructBuilder : public StructBuilder {
46+
public:
47+
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
48+
assert(value);
49+
}
50+
51+
// Undef value for dimension sizes, all zero value for memory sizes.
52+
static Value getInitValue(OpBuilder &builder, Location loc, Type structType);
53+
54+
Value dimSize(OpBuilder &builder, Location loc, unsigned dim);
55+
void setDimSize(OpBuilder &builder, Location loc, unsigned dim, Value size);
56+
57+
Value memSize(OpBuilder &builder, Location loc, unsigned pos);
58+
void setMemSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
59+
};
60+
61+
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
62+
Type structType) {
63+
Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
64+
SpecifierStructBuilder md(metaData);
65+
auto memSizeArrayType = structType.cast<LLVM::LLVMStructType>()
66+
.getBody()[kMemSizePosInSpecifier]
67+
.cast<LLVM::LLVMArrayType>();
68+
69+
Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
70+
// Fill memSizes array with zero.
71+
for (int i = 0, e = memSizeArrayType.getNumElements(); i < e; i++)
72+
md.setMemSize(builder, loc, i, zero);
73+
74+
return md;
75+
}
76+
77+
/// Builds IR inserting the pos-th size into the descriptor.
78+
Value SpecifierStructBuilder::dimSize(OpBuilder &builder, Location loc,
79+
unsigned dim) {
80+
return builder.create<LLVM::ExtractValueOp>(
81+
loc, value, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
82+
}
83+
84+
/// Builds IR inserting the pos-th size into the descriptor.
85+
void SpecifierStructBuilder::setDimSize(OpBuilder &builder, Location loc,
86+
unsigned dim, Value size) {
87+
value = builder.create<LLVM::InsertValueOp>(
88+
loc, value, size, ArrayRef<int64_t>({kDimSizePosInSpecifier, dim}));
89+
}
90+
91+
/// Builds IR extracting the pos-th memory size into the descriptor.
92+
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
93+
unsigned pos) {
94+
return builder.create<LLVM::ExtractValueOp>(
95+
loc, value, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
96+
}
97+
98+
/// Builds IR inserting the pos-th memory size into the descriptor.
99+
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
100+
unsigned pos, Value size) {
101+
value = builder.create<LLVM::InsertValueOp>(
102+
loc, value, size, ArrayRef<int64_t>({kMemSizePosInSpecifier, pos}));
103+
}
104+
105+
template <typename Base, typename SourceOp>
106+
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
107+
public:
108+
using OpAdaptor = typename SourceOp::Adaptor;
109+
using OpConversionPattern<SourceOp>::OpConversionPattern;
110+
111+
LogicalResult
112+
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
113+
ConversionPatternRewriter &rewriter) const override {
114+
SpecifierStructBuilder spec(adaptor.getSpecifier());
115+
Value v;
116+
if (op.getSpecifierKind() == StorageSpecifierKind::DimSize) {
117+
v = Base::onDimSize(rewriter, op, spec,
118+
op.getDim().value().getZExtValue());
119+
} else {
120+
auto enc = op.getSpecifier().getType().getEncoding();
121+
builder::StorageLayout layout(enc);
122+
Optional<unsigned> dim = std::nullopt;
123+
if (op.getDim())
124+
dim = op.getDim().value().getZExtValue();
125+
unsigned idx = layout.getMemRefFieldIndex(op.getSpecifierKind(), dim);
126+
v = Base::onMemSize(rewriter, op, spec, idx);
127+
}
128+
129+
rewriter.replaceOp(op, v);
130+
return success();
131+
}
132+
};
133+
134+
struct StorageSpecifierSetOpConverter
135+
: public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
136+
SetStorageSpecifierOp> {
137+
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
138+
static Value onDimSize(OpBuilder &builder, SetStorageSpecifierOp op,
139+
SpecifierStructBuilder &spec, unsigned d) {
140+
spec.setDimSize(builder, op.getLoc(), d, op.getValue());
141+
return spec;
142+
}
143+
144+
static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
145+
SpecifierStructBuilder &spec, unsigned i) {
146+
spec.setMemSize(builder, op.getLoc(), i, op.getValue());
147+
return spec;
148+
}
149+
};
150+
151+
struct StorageSpecifierGetOpConverter
152+
: public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
153+
GetStorageSpecifierOp> {
154+
using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;
155+
static Value onDimSize(OpBuilder &builder, GetStorageSpecifierOp op,
156+
SpecifierStructBuilder &spec, unsigned d) {
157+
return spec.dimSize(builder, op.getLoc(), d);
158+
}
159+
static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
160+
SpecifierStructBuilder &spec, unsigned i) {
161+
return spec.memSize(builder, op.getLoc(), i);
162+
}
163+
};
164+
165+
struct StorageSpecifierInitOpConverter
166+
: public OpConversionPattern<StorageSpecifierInitOp> {
167+
public:
168+
using OpConversionPattern::OpConversionPattern;
169+
LogicalResult
170+
matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
171+
ConversionPatternRewriter &rewriter) const override {
172+
Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
173+
rewriter.replaceOp(op, SpecifierStructBuilder::getInitValue(
174+
rewriter, op.getLoc(), llvmType));
175+
return success();
176+
}
177+
};
178+
179+
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
180+
RewritePatternSet &patterns) {
181+
patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
182+
StorageSpecifierInitOpConverter>(converter,
183+
patterns.getContext());
184+
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace mlir {
2828
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
2929
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
3030
#define GEN_PASS_DEF_SPARSEVECTORIZATION
31+
#define GEN_PASS_DEF_STORAGESPECIFIERTOLLVM
3132
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
3233
} // namespace mlir
3334

@@ -193,9 +194,14 @@ struct SparseTensorCodegenPass
193194
target.addLegalOp<SortOp>();
194195
target.addLegalOp<SortCooOp>();
195196
target.addLegalOp<PushBackOp>();
196-
// All dynamic rules below accept new function, call, return, and various
197-
// tensor and bufferization operations as legal output of the rewriting
198-
// provided that all sparse tensor types have been fully rewritten.
197+
// Storage specifier outlives sparse tensor pipeline.
198+
target.addLegalOp<GetStorageSpecifierOp>();
199+
target.addLegalOp<SetStorageSpecifierOp>();
200+
target.addLegalOp<StorageSpecifierInitOp>();
201+
// All dynamic rules below accept new function, call, return, and
202+
// various tensor and bufferization operations as legal output of the
203+
// rewriting provided that all sparse tensor types have been fully
204+
// rewritten.
199205
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
200206
return converter.isSignatureLegal(op.getFunctionType());
201207
});
@@ -271,6 +277,44 @@ struct SparseVectorizationPass
271277
}
272278
};
273279

280+
struct StorageSpecifierToLLVMPass
281+
: public impl::StorageSpecifierToLLVMBase<StorageSpecifierToLLVMPass> {
282+
283+
StorageSpecifierToLLVMPass() = default;
284+
285+
void runOnOperation() override {
286+
auto *ctx = &getContext();
287+
ConversionTarget target(*ctx);
288+
RewritePatternSet patterns(ctx);
289+
StorageSpecifierToLLVMTypeConverter converter;
290+
291+
// All ops in the sparse dialect must go!
292+
target.addIllegalDialect<SparseTensorDialect>();
293+
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
294+
return converter.isSignatureLegal(op.getFunctionType());
295+
});
296+
target.addDynamicallyLegalOp<func::CallOp>([&](func::CallOp op) {
297+
return converter.isSignatureLegal(op.getCalleeType());
298+
});
299+
target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
300+
return converter.isLegal(op.getOperandTypes());
301+
});
302+
target.addLegalDialect<arith::ArithDialect, LLVM::LLVMDialect>();
303+
304+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
305+
converter);
306+
populateCallOpTypeConversionPattern(patterns, converter);
307+
populateBranchOpInterfaceTypeConversionPattern(patterns, converter);
308+
populateReturnOpTypeConversionPattern(patterns, converter);
309+
scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
310+
target);
311+
populateStorageSpecifierToLLVMPatterns(converter, patterns);
312+
if (failed(applyPartialConversion(getOperation(), target,
313+
std::move(patterns))))
314+
signalPassFailure();
315+
}
316+
};
317+
274318
} // namespace
275319

276320
//===----------------------------------------------------------------------===//
@@ -355,3 +399,7 @@ mlir::createSparseVectorizationPass(unsigned vectorLength,
355399
return std::make_unique<SparseVectorizationPass>(
356400
vectorLength, enableVLAVectorization, enableSIMDIndex32);
357401
}
402+
403+
std::unique_ptr<Pass> mlir::createStorageSpecifierToLLVMPass() {
404+
return std::make_unique<StorageSpecifierToLLVMPass>();
405+
}

0 commit comments

Comments
 (0)