Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -1228,7 +1228,7 @@ def AMDGPU_ScaledMFMAOp :
}

def AMDGPU_MakeDmaBaseOp :
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments]>,
AMDGPU_Op<"make_dma_base", [Pure, AttrSizedOperandSegments, AllElementTypesMatch<["global", "lds"]>]>,
Arguments<(ins Arg<AnyMemRef>:$global,
Variadic<Index>:$global_indices,
Arg<AnyMemRef>:$lds,
Expand Down Expand Up @@ -1294,8 +1294,8 @@ def AMDGPU_MakeDmaDescriptorOp :
DenseI64ArrayAttr: $global_static_strides,
Variadic<Index>: $shared_dynamic_sizes,
DenseI64ArrayAttr: $shared_static_sizes,
Optional<Index>: $pad,
Optional<Index>: $pad_every,
Optional<Index>: $pad_amount,
Optional<Index>: $pad_interval,
Optional<AnyMemRef>: $atomic_barrier_address,
Variadic<Index>: $atomic_barrier_indices,
Optional<Index>: $global_increment,
Expand Down Expand Up @@ -1331,7 +1331,7 @@ def AMDGPU_MakeDmaDescriptorOp :

// Example of moving a two-dimensional tensor to LDS, where padding is applied after every integer.
%base = amdgpu.make_dma_base %global[0, 0], %lds[0, 0] : memref<32x32xi32>, memref<64x64xi32, #gpu.address_space<workgroup>> -> !amdgpu.tdm_base<i32>
%descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padding(%pad pad_every %pad_every) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
%descriptor = amdgpu.make_dma_descriptor %base globalSize [32, 32] globalStride [32, 1] sharedSize [64, 64] padding(%pad_amount pad_every %pad_interval) : !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
amdgpu.tensor_load_to_lds %descriptor : !amdgpu.tdm_descriptor
```
}];
Expand All @@ -1341,14 +1341,60 @@ def AMDGPU_MakeDmaDescriptorOp :
`globalSize` custom<DynamicIndexList>($global_dynamic_sizes, $global_static_sizes)
`globalStride` custom<DynamicIndexList>($global_dynamic_strides, $global_static_strides)
`sharedSize` custom<DynamicIndexList>($shared_dynamic_sizes, $shared_static_sizes)
( `padShared` `(` $pad^ `every` $pad_every `)` )?
( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )?
( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]`
`:` type($atomic_barrier_address) `)`)?
( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )?
attr-dict `:` qualified(type($base)) `->` type(results)
}];

let extraClassDeclaration = [{
int getRank() {
return getGlobalStaticSizes().size();
}

int getElementTypeWidth() {
Type elementType = getBase().getType().getElementType();
int width;
if (auto floatType = dyn_cast<FloatType>(elementType)) {
width = floatType.getWidth();
} else if (auto intType = dyn_cast<IntegerType>(elementType)) {
width = intType.getWidth();
} else {
llvm_unreachable("element type must have getWidth interface");
}
return width;
}

/// Merges a dynamic value list with its static counterpart into a single
/// OpFoldResult list, following the usual MLIR mixed static/dynamic
/// convention: each entry of `statics` equal to the dynamic sentinel is
/// replaced by the next value from `dynamics`; every other entry becomes
/// an index attribute.
///
/// Takes a non-owning ValueRange instead of a by-value SmallVector so that
/// callers passing operand ranges do not copy the operand list; both
/// OperandRange and SmallVector<Value> convert to ValueRange implicitly.
SmallVector<OpFoldResult> getMixedList(ValueRange dynamics,
                                       ArrayRef<int64_t> statics) {
  SmallVector<OpFoldResult> result;
  result.reserve(statics.size());
  unsigned dynIdx = 0;
  OpBuilder b(getContext());
  for (int64_t staticElem : statics) {
    if (ShapedType::isDynamic(staticElem))
      result.push_back(dynamics[dynIdx++]);
    else
      result.push_back(b.getIndexAttr(staticElem));
  }
  return result;
}

/// Global tensor sizes as one mixed static/dynamic OpFoldResult list.
SmallVector<OpFoldResult> getMixedGlobalSizes() {
  auto dynamicSizes = getGlobalDynamicSizes();
  auto staticSizes = getGlobalStaticSizes();
  return getMixedList(dynamicSizes, staticSizes);
}

/// Global tensor strides as one mixed static/dynamic OpFoldResult list.
SmallVector<OpFoldResult> getMixedGlobalStrides() {
  auto dynamicStrides = getGlobalDynamicStrides();
  auto staticStrides = getGlobalStaticStrides();
  return getMixedList(dynamicStrides, staticStrides);
}

/// Shared (LDS) sizes as one mixed static/dynamic OpFoldResult list.
SmallVector<OpFoldResult> getMixedSharedSizes() {
  auto dynamicSizes = getSharedDynamicSizes();
  auto staticSizes = getSharedStaticSizes();
  return getMixedList(dynamicSizes, staticSizes);
}
}];

let hasVerifier = 1;
let hasFolder = 1;
}

#endif // AMDGPU
Loading