Skip to content

Commit

Permalink
NFC: A few cleanups for SPIRVLowering
Browse files Browse the repository at this point in the history
Updated comments and used static instead of anonymous namspace
to hide functions to be consistent with the existing codebase.

PiperOrigin-RevId: 282847784
  • Loading branch information
antiagainst authored and tensorflower-gardener committed Nov 27, 2019
1 parent a4d7650 commit 5810efe
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 60 deletions.
19 changes: 10 additions & 9 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h
Expand Up @@ -30,21 +30,23 @@

namespace mlir {

/// Converts a function type according to the requirements of a SPIR-V entry
/// function. The arguments need to be converted to spv.GlobalVariables of
/// spv.ptr types so that they could be bound by the runtime.
/// Type conversion from stdandard types to SPIR-V types for shader interface.
///
/// For composite types, this converter additionally performs type wrapping to
/// satisfy shader interface requirements: shader interface types must be
/// pointers to structs.
class SPIRVTypeConverter final : public TypeConverter {
public:
using TypeConverter::TypeConverter;

/// Converts types to SPIR-V types using the basic type converter.
Type convertType(Type t) override;
/// Converts the given standard `type` to SPIR-V correspondance.
Type convertType(Type type) override;

/// Gets the index type equivalent in SPIR-V.
Type getIndexType(MLIRContext *context);
/// Gets the SPIR-V correspondance for the standard index type.
static Type getIndexType(MLIRContext *context);
};

/// Base class to define a conversion pattern to translate Ops into SPIR-V.
/// Base class to define a conversion pattern to lower `SourceOp` into SPIR-V.
template <typename SourceOp>
class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
public:
Expand All @@ -54,7 +56,6 @@ class SPIRVOpLowering : public OpConversionPattern<SourceOp> {
typeConverter(typeConverter) {}

protected:
/// Type lowering class.
SPIRVTypeConverter &typeConverter;
};

Expand Down
104 changes: 53 additions & 51 deletions mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
Expand Up @@ -68,8 +68,7 @@ mlir::spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize,
// Type Conversion
//===----------------------------------------------------------------------===//

namespace {
Type convertIndexType(MLIRContext *context) {
Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
// Convert to 32-bit integers for now. Might need a way to control this in
// future.
// TODO(ravishankarm): It is porbably better to make it 64-bit integers. To
Expand All @@ -82,7 +81,7 @@ Type convertIndexType(MLIRContext *context) {

// TODO(ravishankarm): This is a utility function that should probably be
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
Optional<int64_t> getTypeNumBytes(Type t) {
static Optional<int64_t> getTypeNumBytes(Type t) {
if (auto integerType = t.dyn_cast<IntegerType>()) {
return integerType.getWidth() / 8;
} else if (auto floatType = t.dyn_cast<FloatType>()) {
Expand All @@ -92,17 +91,17 @@ Optional<int64_t> getTypeNumBytes(Type t) {
return llvm::None;
}

Type typeConversionImpl(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirv::SPIRVDialect::isValidType(t)) {
return t;
static Type convertStdType(Type type) {
// If the type is already valid in SPIR-V, directly return.
if (spirv::SPIRVDialect::isValidType(type)) {
return type;
}

if (auto indexType = t.dyn_cast<IndexType>()) {
return convertIndexType(t.getContext());
if (auto indexType = type.dyn_cast<IndexType>()) {
return SPIRVTypeConverter::getIndexType(type.getContext());
}

if (auto memRefType = t.dyn_cast<MemRefType>()) {
if (auto memRefType = type.dyn_cast<MemRefType>()) {
// TODO(ravishankarm): For now only support default memory space. The memory
// space description is not set is stone within MLIR, i.e. it depends on the
// context it is being used. To map this to SPIR-V storage classes, we
Expand All @@ -111,60 +110,65 @@ Type typeConversionImpl(Type t) {
if (memRefType.getMemorySpace()) {
return Type();
}
auto elementType = typeConversionImpl(memRefType.getElementType());

auto elementType = convertStdType(memRefType.getElementType());
if (!elementType) {
return Type();
}

auto elementSize = getTypeNumBytes(elementType);
if (!elementSize) {
return Type();
}
// TODO(ravishankarm) : Handle dynamic shapes.
if (memRefType.hasStaticShape()) {
// Get the strides and offset
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
offset == MemRefType::getDynamicStrideOrOffset() ||
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
// TODO(ravishankarm) : Handle dynamic strides and offsets.
return Type();
}
// Convert to a multi-dimensional spv.array if size is known.
auto shape = memRefType.getShape();
assert(shape.size() == strides.size());
for (int i = shape.size(); i > 0; --i) {
elementType = spirv::ArrayType::get(
elementType, shape[i - 1], strides[i - 1] * elementSize.getValue());
}
// For the offset, need to wrap the array in a struct.
auto structType =
spirv::StructType::get(elementType, offset * elementSize.getValue());
// For now initialize the storage class to StorageBuffer. This will be
// updated later based on whats passed in w.r.t to the ABI attributes.
return spirv::PointerType::get(structType,
spirv::StorageClass::StorageBuffer);

if (!memRefType.hasStaticShape()) {
// TODO(ravishankarm) : Handle dynamic shapes.
return Type();
}

// Get the strides and offset.
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(getStridesAndOffset(memRefType, strides, offset)) ||
offset == MemRefType::getDynamicStrideOrOffset() ||
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
// TODO(ravishankarm) : Handle dynamic strides and offsets.
return Type();
}

// Convert to a multi-dimensional spv.array if size is known.
auto shape = memRefType.getShape();
assert(shape.size() == strides.size());
Type arrayType = elementType;
// TODO(antiagainst): Introduce layout as part of the shader ABI to have
// better separate of concerns.
for (int i = shape.size(); i > 0; --i) {
arrayType = spirv::ArrayType::get(
arrayType, shape[i - 1], strides[i - 1] * elementSize.getValue());
}

// For the offset, need to wrap the array in a struct.
auto structType =
spirv::StructType::get(arrayType, offset * elementSize.getValue());
// For now initialize the storage class to StorageBuffer. This will be
// updated later based on whats passed in w.r.t to the ABI attributes.
return spirv::PointerType::get(structType,
spirv::StorageClass::StorageBuffer);
}

return Type();
}
} // namespace

Type SPIRVTypeConverter::convertType(Type t) { return typeConversionImpl(t); }

Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
return convertType(IndexType::get(context));
}
Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); }

//===----------------------------------------------------------------------===//
// Builtin Variables
//===----------------------------------------------------------------------===//

namespace {
/// Look through all global variables in `moduleOp` and check if there is a
/// spv.globalVariable that has the same `builtin` attribute.
spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
spirv::BuiltIn builtin) {
static spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
spirv::BuiltIn builtin) {
for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
stringifyDecoration(spirv::Decoration::BuiltIn)))) {
Expand All @@ -178,15 +182,14 @@ spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
}

/// Gets name of global variable for a buitlin.
std::string getBuiltinVarName(spirv::BuiltIn builtin) {
static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
}

/// Gets or inserts a global variable for a builtin within a module.
spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
Location loc,
spirv::BuiltIn builtin,
OpBuilder &builder) {
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
spirv::BuiltIn builtin, OpBuilder &builder) {
if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
return varOp;
}
Expand Down Expand Up @@ -217,7 +220,6 @@ spirv::GlobalVariableOp getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp,
builder.restoreInsertionPoint(ip);
return newVarOp;
}
} // namespace

/// Gets the global variable associated with a builtin and add
/// it if it doesnt exist.
Expand Down

0 comments on commit 5810efe

Please sign in to comment.