Skip to content

Commit

Permalink
Merge branch 'openxla:main' into shark
Browse files Browse the repository at this point in the history
  • Loading branch information
saienduri committed Feb 21, 2024
2 parents eb2b348 + 4a80ee3 commit 373685f
Show file tree
Hide file tree
Showing 96 changed files with 3,735 additions and 3,084 deletions.
3 changes: 2 additions & 1 deletion build_tools/scripts/generate_release_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def get_all(self):
url = f"https://api.github.com/repos/{self._repo}/releases"
page = 1

while True:
# GitHub limits API responses to the first 1000 results.
while page * self._per_page < 1000:
response = self._session.get(
url,
params={
Expand Down
13 changes: 0 additions & 13 deletions compiler/plugins/target/CUDA/CUDATarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ struct CUDAOptions {
bool clUsePtxas = false;
std::string clUsePtxasFrom;
std::string clUsePtxasParams;
bool enableLegacySync = false;

void bindOptions(OptionsBinder &binder) {
static llvm::cl::OptionCategory category("CUDA HAL Target");
Expand Down Expand Up @@ -101,12 +100,6 @@ struct CUDAOptions {
"iree-hal-cuda-use-ptxas-params", clUsePtxasParams,
llvm::cl::cat(category),
llvm::cl::desc("Passes the given additional parameters to ptxas."));

binder.opt<bool>(
"iree-hal-cuda-enable-legacy-sync", enableLegacySync,
llvm::cl::cat(category),
llvm::cl::desc(
"Enable legacy sync mode that handles semaphores synchronously."));
}
};
} // namespace
Expand Down Expand Up @@ -391,12 +384,6 @@ class CUDATargetBackend final : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

// Indicates that the runtime HAL driver operates only in the legacy
// synchronous mode.
if (options.enableLegacySync) {
configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr());
}

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

Expand Down
3 changes: 2 additions & 1 deletion compiler/plugins/target/ROCM/ROCMTargetFeatures.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ static ArrayAttr getMfmaArrayAttr(MLIRContext *context,
}

ArrayAttr getROCMSupportedMmaAttrs(MLIRContext *context, StringRef targetArch) {
if (targetArch == "gfx940") {
// MI300a/x
if (targetArch == "gfx940" || targetArch == "gfx942") {
return getMfmaArrayAttr(context,
{IREE::GPU::MFMAIntrinsic::F16_16x16x16_F32,
IREE::GPU::MFMAIntrinsic::F16_32x32x8_F32});
Expand Down
5 changes: 4 additions & 1 deletion compiler/plugins/target/ROCM/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ package(
iree_lit_test_suite(
name = "lit",
srcs = enforce_glob(
["smoketest.mlir"],
[
"smoketest.mlir",
"target_device_features.mlir",
],
include = ["*.mlir"],
),
cfg = "//compiler:lit.cfg.py",
Expand Down
1 change: 1 addition & 0 deletions compiler/plugins/target/ROCM/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ iree_lit_test_suite(
lit
SRCS
"smoketest.mlir"
"target_device_features.mlir"
TOOLS
FileCheck
iree-opt
Expand Down
27 changes: 27 additions & 0 deletions compiler/plugins/target/ROCM/test/target_device_features.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targets=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx940 %s | FileCheck %s --check-prefix=MI300
// RUN: iree-opt --pass-pipeline='builtin.module(iree-hal-assign-target-devices{targets=rocm},iree-hal-transformation-pipeline{serialize-executables=false})' --iree-rocm-target-chip=gfx942 %s | FileCheck %s --check-prefix=MI300

// MI300: mma_intrinsics = [#iree_gpu.mfma_layout<F16_16x16x16_F32>, #iree_gpu.mfma_layout<F16_32x32x8_F32>]

stream.executable public @reduce_dispatch {
stream.executable.export @reduce_dispatch workgroups(%arg0: index) -> (index, index, index) {
%x, %y, %z = flow.dispatch.workgroup_count_from_dag_root %arg0
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @reduce_dispatch(%arg0_binding: !stream.binding, %arg1_binding: !stream.binding) {
%c0 = arith.constant 0 : index
%arg0 = stream.binding.subspan %arg0_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<readonly:tensor<16xf32>>
%arg1 = stream.binding.subspan %arg1_binding[%c0] : !stream.binding -> !flow.dispatch.tensor<writeonly:tensor<f32>>
%0 = tensor.empty() : tensor<f32>
%1 = flow.dispatch.tensor.load %arg0, offsets=[0], sizes=[16], strides=[1] : !flow.dispatch.tensor<readonly:tensor<16xf32>> -> tensor<16xf32>
%3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>], iterator_types = ["reduction"]} ins(%1 : tensor<16xf32>) outs(%0 : tensor<f32>) {
^bb0(%arg2: f32, %arg3: f32):
%4 = arith.addf %arg2, %arg3 : f32
linalg.yield %4 : f32
} -> tensor<f32>
flow.dispatch.tensor.store %3, %arg1, offsets=[], sizes=[], strides=[] : tensor<f32> -> !flow.dispatch.tensor<writeonly:tensor<f32>>
return
}
}
}
4 changes: 0 additions & 4 deletions compiler/plugins/target/WebGPU/WebGPUTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ class WebGPUTargetBackend : public TargetBackend {
Builder b(context);
SmallVector<NamedAttribute> configItems;

// Indicates that the runtime HAL driver operates only in the legacy
// synchronous mode.
configItems.emplace_back(b.getStringAttr("legacy_sync"), b.getUnitAttr());

configItems.emplace_back(b.getStringAttr("executable_targets"),
getExecutableTargets(context));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,20 @@

#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
#include "iree/compiler/Codegen/Utils/VectorOpUtils.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

#define DEBUG_TYPE "iree-amdgpu-distribute-contract"
#define DEBUG_TYPE "iree-codegen-amdgpu-distribute-contract"

namespace mlir::iree_compiler {
namespace {

using namespace mlir::iree_compiler::IREE::VectorExt;
using VectorValue = TypedValue<VectorType>;

/// A class for querying information about a contract op.
class ContractOpDetail {
public:
enum class OpKind { MK_KN_MN, MK_NK_MN, UNKNOWN };

explicit ContractOpDetail(vector::ContractionOp op) {
opKind = inferOpKind(op.getContext(), op.getIndexingMapsArray());
}

OpKind getOpKind() const { return opKind; }

// Returns the (LHS M, RHS N) dimension index pair.
std::optional<std::pair<int, int>> getOperandMNIndex() const {
switch (opKind) {
case OpKind::MK_KN_MN:
return std::make_pair(0, 1);
case OpKind::MK_NK_MN:
return std::make_pair(0, 0);
case OpKind::UNKNOWN:
break;
}
return std::nullopt;
}

// Returns the (LHS K, RHS K) dimension index pair.
std::optional<std::pair<int, int>> getOperandKIndex() const {
switch (opKind) {
case OpKind::MK_KN_MN:
return std::make_pair(1, 0);
case OpKind::MK_NK_MN:
return std::make_pair(1, 1);
case OpKind::UNKNOWN:
break;
}
return std::nullopt;
}

// Returns the result (M, N) dimension index pair.
std::optional<std::pair<int, int>> getResultMNIndex() const {
switch (opKind) {
case OpKind::MK_KN_MN:
case OpKind::MK_NK_MN:
return std::make_pair(0, 1);
default:
break;
}
return std::nullopt;
}

private:
// Gets the kind of a contract op with the given indexing |maps|.
OpKind inferOpKind(MLIRContext *ctx, SmallVector<AffineMap> maps) {
using MapList = ArrayRef<ArrayRef<AffineExpr>>;
auto infer = [&](MapList m) {
return AffineMap::inferFromExprList(m, ctx);
};
AffineExpr m, n, k;
bindDims(ctx, m, n, k);
if (maps == infer({{m, k}, {k, n}, {m, n}}))
return OpKind::MK_KN_MN;
if (maps == infer({{m, k}, {n, k}, {m, n}}))
return OpKind::MK_NK_MN;
return OpKind::UNKNOWN;
}

private:
OpKind opKind = OpKind::UNKNOWN;
};

/// Distributes `vector.contract` ops with nested layouts.
struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
using OpDistributionPattern::OpDistributionPattern;
Expand Down Expand Up @@ -140,8 +71,8 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
mfmaParams.blocks = mfmaAttr.getBlockSize();

// Infer the contract kind so that we know know to correlate M/N/K dims.
ContractOpDetail opDetail(contractOp);
if (opDetail.getOpKind() == ContractOpDetail::OpKind::UNKNOWN) {
VectorContractOpInfo opDetail(contractOp);
if (opDetail.getOpKind() == VectorContractOpInfo::OpKind::UNKNOWN) {
return rewriter.notifyMatchFailure(contractOp, "unknown contract kind");
}

Expand Down Expand Up @@ -243,7 +174,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {
}

// Gets the batch size for matmul K dimensions.
std::optional<int64_t> getKBatchSize(const ContractOpDetail &opDetail,
std::optional<int64_t> getKBatchSize(const VectorContractOpInfo &opDetail,
NestedLayoutAttr lhsLayout,
NestedLayoutAttr rhsLayout) const {
auto [lhsK, rhsK] = *opDetail.getOperandKIndex();
Expand All @@ -257,7 +188,7 @@ struct DistributeContract final : OpDistributionPattern<vector::ContractionOp> {

// Given a contract op's batch |resultOffsets|, fills its batch offsets for
// both LHS and RHS.
void fillOperandBatchOffsets(const ContractOpDetail &opDetail,
void fillOperandBatchOffsets(const VectorContractOpInfo &opDetail,
int64_t kOffset, ArrayRef<int64_t> resultOffsets,
NestedLayoutAttr resultLayout,
SmallVector<int64_t, 2> &lhsOffsets,
Expand Down
16 changes: 16 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ iree_compiler_cc_library(
"//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
"//compiler/src/iree/compiler/Codegen/Transforms",
"//compiler/src/iree/compiler/Codegen/Utils",
"//compiler/src/iree/compiler/Codegen/Utils:VectorOpUtils",
"//compiler/src/iree/compiler/Dialect/HAL/IR",
"//llvm-external-projects/iree-dialects:IREELinalgExtDialect",
"//llvm-external-projects/iree-dialects:IREELinalgExtTransforms",
Expand Down Expand Up @@ -125,3 +126,18 @@ iree_compiler_cc_library(
"@llvm-project//mlir:VectorTransforms",
],
)

iree_compiler_cc_library(
name = "GPUHeuristics",
srcs = [
"GPUHeuristics.cpp",
],
hdrs = [
"GPUHeuristics.h",
],
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
15 changes: 15 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,23 @@ iree_cc_library(
iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect
iree::compiler::Codegen::Transforms
iree::compiler::Codegen::Utils
iree::compiler::Codegen::Utils::VectorOpUtils
iree::compiler::Dialect::HAL::IR
PUBLIC
)

iree_cc_library(
NAME
GPUHeuristics
HDRS
"GPUHeuristics.h"
SRCS
"GPUHeuristics.cpp"
DEPS
LLVMSupport
MLIRIR
MLIRSupport
PUBLIC
)

### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
113 changes: 113 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "iree-codegen-gpu-heuristics"

using llvm::APIntOps::GreatestCommonDivisor;

namespace mlir::iree_compiler {

std::optional<GPUMMASchedule>
deduceMMASchedule(const GPUMatmulShapeType &problem,
ArrayRef<GPUMatmulShapeType> intrinsics,
const GPUMMAHeuristicSeeds &seeds) {
for (auto [index, intrinsic] : llvm::enumerate(intrinsics)) {
if (problem.aType != intrinsic.aType || problem.bType != intrinsic.bType ||
problem.cType != intrinsic.cType) {
continue; // Cannot use this intrinsic for mismatched types
}

if (problem.mSize % intrinsic.mSize != 0 ||
problem.nSize % intrinsic.nSize != 0 ||
problem.kSize % intrinsic.kSize != 0) {
continue; // Cannot use this intrinsic for misaligned cases
}

int64_t mTotalTileCount = problem.mSize / intrinsic.mSize;
int64_t nTotalTileCount = problem.nSize / intrinsic.nSize;

int64_t remainingWarps = seeds.bestSubgroupCountPerWorkgroup;
int64_t remainingTiles = seeds.bestMNTileCountPerSubgroup;
// Assign more warps to the M dimension (used later) to balance thread
// counts along X and Y dimensions.
int64_t warpSqrt = 1ull
<< (llvm::divideCeil(llvm::Log2_64(remainingWarps), 2));
int64_t tileSqrt = 1ull << (llvm::Log2_64(remainingTiles) / 2);

int64_t mWarpCount = 0, nWarpCount = 0;
int64_t mTileCount = 0, nTileCount = 0;

// See if the square root can divide mTotalTileCount. If so it means we can
// distribute to both dimensions evenly. Otherwise, try to distribute to N
// and then M.
if (mTotalTileCount > (warpSqrt * tileSqrt) &&
mTotalTileCount % (warpSqrt * tileSqrt) == 0) {
mWarpCount = warpSqrt;
mTileCount = tileSqrt;

remainingWarps /= warpSqrt;
remainingTiles /= tileSqrt;

APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingWarps));
nWarpCount = nGCD.getSExtValue();
nTotalTileCount /= nWarpCount;
remainingWarps /= nWarpCount;

nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingTiles));
nTileCount = nGCD.getSExtValue();
} else {
APInt nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingWarps));
nWarpCount = nGCD.getSExtValue();
nTotalTileCount /= nWarpCount;
remainingWarps /= nWarpCount;

nGCD = GreatestCommonDivisor(APInt(64, nTotalTileCount),
APInt(64, remainingTiles));
nTileCount = nGCD.getSExtValue();
remainingTiles /= nTileCount;

APInt mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
APInt(64, remainingWarps));
mWarpCount = mGCD.getSExtValue();
mTotalTileCount /= mWarpCount;
remainingWarps /= mWarpCount;

mGCD = GreatestCommonDivisor(APInt(64, mTotalTileCount),
APInt(64, remainingTiles));
mTileCount = mGCD.getSExtValue();
}

const uint64_t kTotalTileCount = problem.kSize / intrinsic.kSize;
APInt kGCD = GreatestCommonDivisor(
APInt(64, kTotalTileCount), APInt(64, seeds.bestKTileCountPerSubgroup));
int64_t kTileCount = kGCD.getSExtValue();

LLVM_DEBUG({
llvm::dbgs() << "chosen MMA schedule:\n";
llvm::dbgs() << " intrinsic (M, N, K) = (" << intrinsic.mSize << ", "
<< intrinsic.nSize << ", " << intrinsic.kSize << ")\n";
llvm::dbgs() << " subgroup count (M, N) = (" << mWarpCount << ", "
<< nWarpCount << ")\n";
llvm::dbgs() << " subgroup tile count (M, N, K) = (" << mTileCount
<< ", " << nTileCount << ", " << kTileCount << ")\n";
});
return GPUMMASchedule{index, intrinsic.mSize, intrinsic.nSize,
intrinsic.kSize, mWarpCount, nWarpCount,
mTileCount, nTileCount, kTileCount};
}
return std::nullopt;
}

} // namespace mlir::iree_compiler
Loading

0 comments on commit 373685f

Please sign in to comment.