@@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
12981298}
12991299
13001300def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
1301-                               AllElementTypesMatch<["mem_desc", "res"]>]>  {
1301+                               AllElementTypesMatch<["mem_desc", "res"]>,
1302+                               AllRanksMatch<["mem_desc", "res"]>]>  {
13021303  let arguments = (ins XeGPU_MemDesc:$mem_desc,
13031304    Variadic<Index>: $offsets,
13041305    DenseI64ArrayAttr: $const_offsets,
1305-     OptionalAttr<UnitAttr>:$subgroup_block_io,
13061306    OptionalAttr<DistributeLayoutAttr>:$layout
13071307  );
1308-   let results = (outs AnyTypeOf<[ XeGPU_ValueType, XeGPU_ScalarType]> :$res);   
1308+   let results = (outs XeGPU_ValueType:$res);
13091309  let assemblyFormat = [{
13101310    $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
13111311    prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1319,9 +1319,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13191319    Arguments:
13201320     - `mem_desc`: the memory descriptor identifying the SLM region.
13211321     - `offsets`: the coordinates within the matrix to read from.
1322-      - `subgroup_block_io`: [optional] An attribute indicating that the operation can be 
1323-                  lowered to a subgroup block load. When this attribute is present, 
1324-                  the offsets are subgroup-uniform across all lanes.
13251322     - `layout`: [optional] An attribute for guiding distributions among
13261323                 subgroups and/or work-items. It currently can accept either
13271324                 LayoutAttr or SliceAttr.
@@ -1339,24 +1336,21 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13391336    }
13401337
13411338    ArrayRef<int64_t> getDataShape() {
1342-       auto resTy = getRes().getType();
1343-       if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
1344-         return vecTy.getShape();
1345-       return {};
1339+       return getRes().getType().getShape();
13461340    }
13471341  }];
13481342
13491343  let hasVerifier = 1;
13501344}
13511345
13521346def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
1353-                               AllElementTypesMatch<["mem_desc", "data"]>]> {
1347+                               AllElementTypesMatch<["mem_desc", "data"]>,
1348+                               AllRanksMatch<["mem_desc", "data"]>]> {
13541349  let arguments = (ins
1355-     AnyTypeOf<[ XeGPU_ValueType, XeGPU_ScalarType]> :$data,
1350+     XeGPU_ValueType:$data,
13561351    XeGPU_MemDesc:$mem_desc,
13571352    Variadic<Index>: $offsets,
13581353    DenseI64ArrayAttr: $const_offsets,
1359-     OptionalAttr<UnitAttr>:$subgroup_block_io,
13601354    OptionalAttr<DistributeLayoutAttr>:$layout
13611355  );
13621356  let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1370,9 +1364,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13701364     - `mem_desc`: the memory descriptor specifying the SLM region.
13711365     - `offsets`: the coordinates within the matrix where the data will be written.
13721366     - `data`: the values to be stored in the matrix.
1373-      - `subgroup_block_io`: [optional] An attribute indicating that the operation can be 
1374-                  lowered to a subgroup block store. When this attribute is present, 
1375-                  the offsets are subgroup-uniform across all lanes.     
13761367     - `layout`: [optional] An attribute for guiding distributions among
13771368                 subgroups and/or work-items. It currently can accept either
13781369                 LayoutAttr or SliceAttr.
@@ -1387,15 +1378,49 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13871378    }
13881379
13891380    ArrayRef<int64_t> getDataShape() {
1390-       auto DataTy = getData().getType();
1391-       if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
1392-         return vecTy.getShape();
1393-       return {};
1381+       return getData().getType().getShape();
13941382    }
13951383
13961384  }];
13971385
13981386  let hasVerifier = 1;
13991387}
14001388
1389+ def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
1390+           [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
1391+   let description = [{
1392+     Creates a subview of a memory descriptor. The resulting memory descriptor can have
1393+     a lower rank than the source; in this case, the result dimensions correspond to the
1394+     higher-order dimensions of the source memory descriptor.
1395+ 
1396+     Arguments:
1397+      - `src` : a memory descriptor.
1398+      - `offsets` : the coordinates within the matrix the subview will be created from.
1399+ 
1400+     Results:
1401+     - `res` : a memory descriptor with smaller size.
1402+ 
1403+   }];
1404+   let arguments = (ins XeGPU_MemDesc:$src,
1405+                        Variadic<Index>:$offsets,
1406+                        DenseI64ArrayAttr:$const_offsets);
1407+   let results = (outs XeGPU_MemDesc:$res);
1408+   let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
1409+                          attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
1410+   let builders = [
1411+     OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
1412+   ];
1413+ 
1414+   let extraClassDeclaration = [{
1415+     mlir::Value getViewSource() { return getSrc(); }
1416+ 
1417+     SmallVector<OpFoldResult> getMixedOffsets() {
1418+       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1419+     }
1420+   }];
1421+ 
1422+   let hasVerifier = 1;
1423+ }
1424+ 
1425+ 
14011426#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
0 commit comments