-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix #162780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4c58d3d
554b95e
446b951
9f9744c
c89c5db
bbd43d0
0344761
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { | |
return getAttrs().getAs<ArrayAttr>("stride"); | ||
} | ||
|
||
ArrayAttr getBlockAttr() { | ||
return getAttrs().getAs<ArrayAttr>("block"); | ||
} | ||
|
||
}]; | ||
|
||
} | ||
|
||
// Enum cases for MatrixAccessDirection: | ||
//   ROW -> integer value 0, textual form "row" | ||
//   COL -> integer value 1, textual form "col" | ||
def RowOriented : I32EnumAttrCase<"ROW", 0, "row">; | ||
def ColOriented : I32EnumAttrCase<"COL", 1, "col">; | ||
// 32-bit integer enum listing the two directions (row or column) in which | ||
// matrix elements/vectors can be accessed. | ||
def MatrixAccessDirection : | ||
I32EnumAttr<"MatrixAccessDirection", | ||
"Matrix elements/vectors can have row or column direction", [ | ||
RowOriented, ColOriented | ||
]> { | ||
// NOTE(review): presumably disabled so that the attribute class comes from | ||
// the dialect-specific EnumAttr wrapper defined below rather than from this | ||
// enum definition -- confirm against the MLIR ODS enum documentation. | ||
let genSpecializedAttr = 0; | ||
// C++ namespace in which the generated enum is placed. | ||
let cppNamespace = "::mlir::xegpu"; | ||
} | ||
def MatrixAccessDirectionAttr : | ||
EnumAttr<XeGPU_Dialect, | ||
MatrixAccessDirection, | ||
"matrix_access_direction">{ | ||
let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this already covered by It's unclear to me how to use this new attr. I had a look at the |
||
let assemblyFormat = "`<` $value `>`"; | ||
} | ||
|
||
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1298,14 +1298,16 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure, | |
} | ||
|
||
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, | ||
AllElementTypesMatch<["mem_desc", "res"]>, | ||
AllRanksMatch<["mem_desc", "res"]>]> { | ||
AllElementTypesMatch<["mem_desc", "res"]>]> { | ||
let arguments = (ins XeGPU_MemDesc:$mem_desc, | ||
Variadic<Index>: $offsets, | ||
DenseI64ArrayAttr: $const_offsets, | ||
OptionalAttr<I32Attr>:$vec_length, | ||
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction, | ||
OptionalAttr<UnitAttr>:$subgroup_block_io, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Whose responsibility will it be to assign this option? |
||
OptionalAttr<DistributeLayoutAttr>:$layout | ||
); | ||
let results = (outs XeGPU_ValueType:$res); | ||
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res); | ||
let assemblyFormat = [{ | ||
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) | ||
prop-dict attr-dict `` `:` type(operands) `->` type(results) | ||
|
@@ -1336,21 +1338,26 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>, | |
} | ||
|
||
ArrayRef<int64_t> getDataShape() { | ||
return getRes().getType().getShape(); | ||
auto resTy = getRes().getType(); | ||
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy)) | ||
return vecTy.getShape(); | ||
return {}; | ||
} | ||
}]; | ||
|
||
let hasVerifier = 1; | ||
} | ||
|
||
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, | ||
AllElementTypesMatch<["mem_desc", "data"]>, | ||
AllRanksMatch<["mem_desc", "data"]>]> { | ||
AllElementTypesMatch<["mem_desc", "data"]>]> { | ||
let arguments = (ins | ||
XeGPU_ValueType:$data, | ||
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data, | ||
XeGPU_MemDesc:$mem_desc, | ||
Variadic<Index>: $offsets, | ||
DenseI64ArrayAttr: $const_offsets, | ||
OptionalAttr<I32Attr>:$vec_length, | ||
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction, | ||
OptionalAttr<UnitAttr>:$subgroup_block_io, | ||
OptionalAttr<DistributeLayoutAttr>:$layout | ||
); | ||
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets) | ||
|
@@ -1378,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>, | |
} | ||
|
||
ArrayRef<int64_t> getDataShape() { | ||
return getData().getType().getShape(); | ||
auto DataTy = getData().getType(); | ||
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy)) | ||
return vecTy.getShape(); | ||
return {}; | ||
} | ||
|
||
}]; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m | |
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); | ||
} | ||
|
||
ArrayAttr getStrides() { | ||
ArrayAttr getStridesAttr() { | ||
auto layout = getMemLayout(); | ||
if (layout && layout.hasAttr("stride")) { | ||
return layout.getStrides(); | ||
|
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m | |
Builder builder(getContext()); | ||
return builder.getI64ArrayAttr(defaultStrides); | ||
} | ||
|
||
/// Heuristic to determine if the MemDesc uses column-major layout, | ||
/// based on the rank and the value of the first stride dimension. | ||
bool isColMajor() { | ||
auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]); | ||
return getRank() == 2 && dim0 && dim0.getInt() == 1; | ||
} | ||
|
||
// Get the blocking shape for a MemDescType, which is represented | ||
// as an attribute in MemDescType. By default it is the shape | ||
// of the MemDescType itself. | ||
SmallVector<int64_t> getBlockSize() { | ||
SmallVector<int64_t> size(getShape()); | ||
MemLayoutAttr layout = getMemLayout(); | ||
if (layout && layout.hasAttr("block")) { | ||
ArrayAttr attr = layout.getBlockAttr(); | ||
size.clear(); | ||
llvm::for_each(attr, [&](Attribute elem) { | ||
if (auto intElem = dyn_cast<IntegerAttr>(elem)) | ||
size.push_back(intElem.getInt()); | ||
Comment on lines
+271
to
+272
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also note for later. I think it's not the case now but this check shouldn't be needed. |
||
}); | ||
} | ||
return size; | ||
} | ||
|
||
// Get the strides as a vector of integers. | ||
// If the layout contains a block attribute, the strides are blocked strides. | ||
// | ||
// The blocking is applied against the original matrix shape | ||
// so that the linear offset is not impacted by the subview. | ||
// | ||
// It first computes the original matrix shape using the stride info, | ||
// then computes the number of blocks in each dimension of the original shape, | ||
// then computes the outer block shape and stride, | ||
// then combines the inner and outer block shape and stride. | ||
// e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> | ||
// its memory layout tuple is ([2,32,16,8],[128,256,1,16]) | ||
// for mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1] | ||
// its memory layout tuple is ([32,2,8,16],[256,128,16,1]) | ||
SmallVector<int64_t> getStrides(); | ||
|
||
/// Generates instructions to compute the linearized offset. | ||
/// If the memory descriptor is blocked, it returns the linearized offset based on the blocked layout. | ||
/// The strides of the memory descriptor are always considered, regardless of whether it is blocked or not. | ||
Value getLinearOffsets(OpBuilder &builder, | ||
Location loc, ArrayRef<OpFoldResult> offsets); | ||
|
||
|
||
}]; | ||
|
||
let hasCustomAssemblyFormat = true; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note for the future, not a burning issue here.
It'd be nice to align the two getters.
The `getXAttr` version might be better in this case, as `getStrides()` and `getBlocks()` are already used for many other things.