PR #59515: Cublaslt fp8 matmul restriction work-around
Imported from GitHub PR tensorflow/tensorflow#59515

As of 2/1/2023, cublasLt FP8 matmul only supports column-major inputs (the cuBLAS default), but when called from TF/XLA the inputs can arrive in any storage layout. This PR "canonicalizes" FP8 matmuls by rewriting them to lhs_contracting_dims={1} and rhs_contracting_dims={0} and adding the necessary transposes to the inputs. A reproducer of this restriction is located [here](https://github.com/wenscarl/fp8_gemm_test/blob/main/fp8_gemm_backward_fail.py).

A remaining restriction is that the batch dimension still needs to be the leading dimension.
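
To make the restriction concrete, here is a minimal sketch (illustrative only, not code from this PR; the names are made up) of the per-operand rule the rewriter applies: transpose A when it is column-major and transpose B when it is row-major, so every FP8 matmul reaches cuBLASLt as (row-major A, column-major B).

#include <iostream>

struct Plan {
  bool transpose_a;
  bool transpose_b;
};

// Decide which operands need a layout-preserving transpose before the FP8
// matmul is handed to cuBLASLt.
Plan CanonicalizeFp8Operands(bool a_is_col_major, bool b_is_col_major) {
  return Plan{/*transpose_a=*/a_is_col_major, /*transpose_b=*/!b_is_col_major};
}

int main() {
  // Both operands column-major (the cuBLAS default): only A gets transposed.
  Plan p = CanonicalizeFp8Operands(/*a_is_col_major=*/true,
                                   /*b_is_col_major=*/true);
  std::cout << p.transpose_a << " " << p.transpose_b << "\n";  // prints: 1 0
  return 0;
}
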
Copybara import of the project:

--
77de65a4d0c2c0c7db8e6b305d03e039c454c2e2 by wenscarl <shuw@nvidia.com>:

Workaround cublasLt fp8 matmul restrictions.
--
706ad9adb3e5109acc7909d39fa89c9286c192f3 by shuw <shuw@nvidia.com>:

Work around cublasLt fp8 matmul restrictions

--
8f4856968d826b621a3a083f7843ffb30b8d921f by wenscarl <shuw@nvidia.com>:

Add MatrixIsColumnMajor
--
e0e67a4223dda278463e2004f975b32c0d5ac3d8 by shuw <shuw@nvidia.com>:

Remove dead code

--
cfea2cb8dc60edb07bdf1fd3c2113cd401354701 by shuw <shuw@nvidia.com>:

Add gemm_rewrite test

--
460f616103888ad6d80dc22cc3876690ab7e16c3 by shuw <shuw@nvidia.com>:

Cover 32 cases

--
6f724ff30570d7f35dba03a13d6a9bf6aa903da5 by shuw <shuw@nvidia.com>:

Abbreviate branch logic

--
64f153a32cb8ddd925f6f8d4a0ae88178990f24c by shuw <shuw@nvidia.com>:

Parameterized test

--
6146558b21416ec3c5b8aa32f7f59c9f3b49ae87 by shuw <shuw@nvidia.com>:

auto -> type names

--
96e758886d67fe868e4d263b44cc729af972fb71 by shuw <shuw@nvidia.com>:

Update comments

--
07db882429e7e93afba335714d672c460483ec4f by shuw <shuw@nvidia.com>:

vector -> array

--
3089065fb4d78033b85f46b68cd25a0643dd17e3 by shuw <shuw@nvidia.com>:

fix typo

--
50eaec634ccf7d77a0e8315ade5e0a5bab06c07e by shuw <shuw@nvidia.com>:

Convert to NN config and add batched matmul tests

Merging this change closes #59515

FUTURE_COPYBARA_INTEGRATE_REVIEW=tensorflow/tensorflow#59515 from wenscarl:cublaslt_fp8_matmul_war bff4bc682d5e0bc200e6261317ba1499960c3e45
PiperOrigin-RevId: 512191684
reedwm authored and Copybara-Service committed Feb 25, 2023
1 parent fd24c67 commit 3653fc3
Showing 3 changed files with 222 additions and 44 deletions.
108 changes: 72 additions & 36 deletions xla/service/gpu/gemm_rewriter.cc
@@ -625,14 +625,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       return false;
     }
     if ((a_bitcast ? a_bitcast : a)->shape().dimensions_size() -
-            gemm_backend_config.dot_dimension_numbers()
-                .lhs_batch_dimensions()
-                .size() !=
+            a_batch_dims.size() !=
             2 ||
         (b_bitcast ? b_bitcast : b)->shape().dimensions_size() -
-            gemm_backend_config.dot_dimension_numbers()
-                .rhs_batch_dimensions()
-                .size() !=
+            b_batch_dims.size() !=
             2) {
       return false;
     }
@@ -650,24 +646,19 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
     // Shift any bitcasts to the unconverted and unscaled operands.
     if (a_bitcast) {
       a = instr->AddInstruction(a_bitcast->CloneWithNewOperands(
-          ShapeUtil::MakeShape(a->shape().element_type(),
-                               a_bitcast->shape().dimensions()),
+          ShapeUtil::MakeShapeWithDenseLayout(
+              a->shape().element_type(), a_bitcast->shape().dimensions(),
+              a_bitcast->shape().layout().minor_to_major()),
           {a}));
     }
     if (b_bitcast) {
       b = instr->AddInstruction(b_bitcast->CloneWithNewOperands(
-          ShapeUtil::MakeShape(b->shape().element_type(),
-                               b_bitcast->shape().dimensions()),
+          ShapeUtil::MakeShapeWithDenseLayout(
+              b->shape().element_type(), b_bitcast->shape().dimensions(),
+              b_bitcast->shape().layout().minor_to_major()),
           {b}));
     }
 
-    // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A,
-    // to be transposed. If the result of the GEMM is not in column major
-    // order, A and B are later exchanged, and B is transposed here instead.
-    // TODO(philipphack): Remove once cuBLASLt supports the NN configuration.
-    TF_ASSIGN_OR_RETURN(bool is_col_major,
-                        OutputIsColumnMajor(instr, gemm_backend_config));
-
     // Identify the dimensional order which describes a transpose of the
     // contracting and non-contracting dimensions of the GEMM.
     auto transp_dim_order =
@@ -700,21 +691,57 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
       }
       return transp_dims;
     };
+    // Transpose a matrix by swapping the contracting and non-contracting
+    // dimension. Keep the layout the same.
+    auto plain_transpose =
+        [&](HloInstruction **x,
+            const absl::Span<const int64_t> &contracting_dims,
+            const absl::Span<const int64_t> &batch_dims) {
+          std::vector<int64_t> new_dim_order =
+              transp_dim_order(*x, contracting_dims[0], batch_dims);
+          *x = instr->AddInstruction(HloInstruction::CreateTranspose(
+              ShapeUtil::MakeShapeWithDenseLayout(
+                  (*x)->shape().element_type(), transp_dims(*x, new_dim_order),
+                  (*x)->shape().layout().minor_to_major()),
+              *x, new_dim_order));
+        };
+
+    TF_ASSIGN_OR_RETURN(bool a_is_col_major,
+                        MatrixIsColumnMajor(instr, gemm_backend_config, "a"));
+    TF_ASSIGN_OR_RETURN(bool b_is_col_major,
+                        MatrixIsColumnMajor(instr, gemm_backend_config, "b"));
+
+    DotDimensionNumbers *dim_nums =
+        gemm_backend_config.mutable_dot_dimension_numbers();
+    int a_batch_dim_offset = a_batch_dims.size();
+    int b_batch_dim_offset = b_batch_dims.size();
 
-    if (is_col_major) {
-      std::vector<int64_t> new_dim_order =
-          transp_dim_order(a, a_contracting_dims[0], a_batch_dims);
-      a = instr->AddInstruction(HloInstruction::CreateTranspose(
-          ShapeUtil::MakeShape(a->shape().element_type(),
-                               transp_dims(a, new_dim_order)),
-          a, new_dim_order));
-    } else {
-      std::vector<int64_t> new_dim_order =
-          transp_dim_order(b, b_contracting_dims[0], b_batch_dims);
-      b = instr->AddInstruction(HloInstruction::CreateTranspose(
-          ShapeUtil::MakeShape(b->shape().element_type(),
-                               transp_dims(b, new_dim_order)),
-          b, new_dim_order));
+    // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A,
+    // to be row-major. If A is column-major, swap the contracting and
+    // non-contracting dimension and transpose the matrix to effectively make
+    // it row-major.
+    // TODO(philipphack): Remove once cuBLASLt supports A being column-major.
+    if (a_is_col_major) {
+      DCHECK(a_contracting_dims[0] == a_batch_dim_offset ||
+             a_contracting_dims[0] == a_batch_dim_offset + 1);
+      if (a_contracting_dims[0] == a_batch_dim_offset) {
+        dim_nums->set_lhs_contracting_dimensions(0, a_batch_dim_offset + 1);
+      } else {
+        dim_nums->set_lhs_contracting_dimensions(0, a_batch_dim_offset);
+      }
+      plain_transpose(&a, a_contracting_dims, a_batch_dims);
+    }
+    // Similarly, cuBLASLt requires the second operand, i.e. B, to be
+    // column-major, so make it column-major if it is currently row-major.
+    if (!b_is_col_major) {
+      DCHECK(b_contracting_dims[0] == b_batch_dim_offset ||
+             b_contracting_dims[0] == b_batch_dim_offset + 1);
+      if (b_contracting_dims[0] == b_batch_dim_offset) {
+        dim_nums->set_rhs_contracting_dimensions(0, b_batch_dim_offset + 1);
+      } else {
+        dim_nums->set_rhs_contracting_dimensions(0, b_batch_dim_offset);
+      }
+      plain_transpose(&b, b_contracting_dims, b_batch_dims);
     }
 
     std::unique_ptr<HloInstruction> new_custom_call =
@@ -1217,9 +1244,9 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
                                        output_dtype));
   }
 
-  StatusOr<bool> OutputIsColumnMajor(
-      const HloInstruction *instr,
-      const GemmBackendConfig &gemm_backend_config) const {
+  StatusOr<bool> MatrixIsColumnMajor(
+      const HloInstruction *instr, const GemmBackendConfig &gemm_backend_config,
+      const std::string matrix_name = "output") const {
     const HloInstruction *lhs = instr->operand(0);
     const HloInstruction *rhs = instr->operand(1);
 
@@ -1236,7 +1263,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         gemm_backend_config.alpha_imag(), gemm_backend_config.beta(),
         /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision));
 
-    return gemm_config.output_layout.order == MatrixLayout::Order::kColumnMajor;
+    if (matrix_name == "lhs" || matrix_name == "a") {
+      return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor;
+    } else if (matrix_name == "rhs" || matrix_name == "b") {
+      return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor;
+    } else if (matrix_name == "output" || matrix_name == "d") {
+      return gemm_config.output_layout.order ==
+             MatrixLayout::Order::kColumnMajor;
+    } else {
+      return InternalError("Invalid matrix name.");
+    }
   }
 
   StatusOr<bool> GemmIsSupportedByCublasLt(
@@ -1292,7 +1328,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
         gemm_backend_config.dot_dimension_numbers();
 
     TF_ASSIGN_OR_RETURN(bool output_is_column_major,
-                        OutputIsColumnMajor(instr, gemm_backend_config));
+                        MatrixIsColumnMajor(instr, gemm_backend_config));
     if (!output_is_column_major) {
       // cublasLt's matmul output is column major by default. This gemm requires
       // the output to be in row major. Later we will swap lhs & rhs (and
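
As a side note on the mechanics: plain_transpose builds a permutation that swaps the two non-batch dimensions of an operand while leaving batch dimensions (and the physical layout) untouched. A self-contained sketch of that permutation, using plain integer vectors rather than XLA's HLO types (the helper name here is illustrative, not from the patch):

#include <cstdint>
#include <iostream>
#include <vector>

// Build the dimension permutation that swaps the two non-batch (matrix)
// dimensions of an operand while keeping every batch dimension in place.
// This mirrors the dimension order the rewriter feeds to CreateTranspose.
std::vector<int64_t> TransposeDimOrder(int64_t rank,
                                       const std::vector<int64_t>& batch_dims) {
  std::vector<int64_t> order(rank);
  std::vector<int64_t> matrix_dims;
  for (int64_t d = 0; d < rank; ++d) {
    bool is_batch = false;
    for (int64_t b : batch_dims) is_batch = is_batch || (b == d);
    if (is_batch) {
      order[d] = d;  // batch dimensions keep their positions
    } else {
      matrix_dims.push_back(d);
    }
  }
  // Exactly two non-batch dimensions remain for a (batched) matmul operand;
  // swapping them is the logical transpose.
  order[matrix_dims[0]] = matrix_dims[1];
  order[matrix_dims[1]] = matrix_dims[0];
  return order;
}

int main() {
  // Rank-3 operand with batch dimension 0, e.g. f8e4m3fn[2,64,32]:
  for (int64_t d : TransposeDimOrder(3, {0})) std::cout << d << " ";
  std::cout << "\n";  // prints: 0 2 1
  return 0;
}
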
18 changes: 10 additions & 8 deletions xla/service/gpu/matmul_utils.cc
@@ -806,13 +806,15 @@ StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
   lhs_layout.batch_size = batch_size;
   rhs_layout.batch_size = batch_size;
 
-  // cuBLASLt FP8 GEMM kernels require A (i.e. lhs) to be transposed.
-  se::blas::Transpose trans_a;
-  if (lhs_layout.dtype == F8E4M3FN || lhs_layout.dtype == F8E5M2) {
-    trans_a = se::blas::Transpose::kTranspose;
-  } else {
-    trans_a = se::blas::Transpose::kNoTranspose;
-  }
+  // The cuBLASLt documentation somewhat incorrectly claims "A must be
+  // transposed and B non-transposed" when A and B are FP8
+  // (https://docs.nvidia.com/cuda/cublas/#cublasltmatmul). In reality, this is
+  // only true if A and B are column-major. If A is row-major, A must *not* be
+  // transposed, and if B is row-major, B must be transposed. We never transpose
+  // A or B, and expect the caller to ensure A is row-major and B is
+  // column-major when A and B are FP8.
+  const se::blas::Transpose trans_a = se::blas::Transpose::kNoTranspose;
+  const se::blas::Transpose trans_b = se::blas::Transpose::kNoTranspose;
 
   bool must_swap_operands =
       MakeOutputColumnMajor(lhs_layout, rhs_layout, c_layout, output_layout);
@@ -826,7 +828,7 @@ StatusOr<se::cuda::BlasLt::Epilogue> AsBlasLtEpilogue(
       se::cuda::BlasLt::MatmulDesc op_desc,
       se::cuda::BlasLt::MatmulDesc::Create(
           computation_type, GetScaleType(output_dtype, computation_type),
-          trans_a, /*trans_b=*/se::blas::Transpose::kNoTranspose, epilogue));
+          trans_a, trans_b, epilogue));
 
   TF_ASSIGN_OR_RETURN(se::cuda::BlasLt::MatrixLayout a_desc,
                       AsBlasLtMatrixLayout(lhs_layout));
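
The invariant behind hard-coding both transpose flags can be stated in a few lines. A self-contained sketch, with a local Order enum standing in for xla::gpu::MatrixLayout::Order (an assumption-level simplification, not the real matmul_utils API):

enum class Order { kRowMajor, kColumnMajor };

// With trans_a and trans_b fixed to kNoTranspose, an FP8 cuBLASLt matmul is
// well-formed only when the rewriter already produced this operand pair.
bool Fp8OperandOrdersAreCanonical(Order a_order, Order b_order) {
  return a_order == Order::kRowMajor && b_order == Order::kColumnMajor;
}

int main() {
  // The gemm_rewriter pass above establishes exactly this before the call.
  return Fp8OperandOrdersAreCanonical(Order::kRowMajor, Order::kColumnMajor)
             ? 0
             : 1;
}
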
140 changes: 140 additions & 0 deletions xla/service/gpu/tests/gemm_rewrite_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <array>
 #include <string>
 #include <tuple>
 #include <utility>
@@ -4658,6 +4659,145 @@ TEST_F(CublasLtF8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) {
 )");
 }
 
+TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) {
+  if (!GetCudaComputeCapability().IsAtLeast(
+          se::CudaComputeCapability::HOPPER)) {
+    GTEST_SKIP()
+        << "cuBLASLt FP8 kernels require Hopper or newer architecture.";
+  }
+
+  std::array<std::array<absl::string_view, 7>, 32> combinations;
+  int i = 0;
+  for (bool d_is_col : {false, true}) {
+    for (bool a_is_col : {false, true}) {
+      for (bool b_is_col : {false, true}) {
+        for (int lhs_contracting_dim : {0, 1}) {
+          for (int rhs_contracting_dim : {0, 1}) {
+            const absl::string_view lcd =
+                lhs_contracting_dim == 1 ? "{1}" : "{0}";
+            const absl::string_view rcd =
+                rhs_contracting_dim == 1 ? "{1}" : "{0}";
+            const absl::string_view a_shape =
+                lhs_contracting_dim == 1 ? "[64,32]" : "[32,64]";
+            const absl::string_view b_shape =
+                rhs_contracting_dim == 0 ? "[32,16]" : "[16,32]";
+            const absl::string_view a_layout = a_is_col ? "{0,1}" : "{1,0}";
+            const absl::string_view b_layout = b_is_col ? "{0,1}" : "{1,0}";
+            const absl::string_view output_layout =
+                d_is_col ? "{0,1}" : "{1,0}";
+            combinations[i++] = std::array{
+                lcd, rcd, a_shape, b_shape, a_layout, b_layout, output_layout};
+          }
+        }
+      }
+    }
+  }
+  const char* hlo_template = R"(
+    HloModule test
+    ENTRY test {
+      x = f8e4m3fn<<Ashape>><<Alayout>> parameter(0)
+      x_f32 = f32<<Ashape>><<Alayout>> convert(x)
+      x_scale = f32[] parameter(2)
+      x_scale_bcast = f32<<Ashape>> broadcast(x_scale), dimensions={}
+      x_unscaled = f32<<Ashape>> multiply(x_f32, x_scale_bcast)
+      y = f8e4m3fn<<Bshape>><<Blayout>> parameter(1)
+      y_f32 = f32<<Bshape>><<Blayout>> convert(y)
+      y_scale = f32[] parameter(3)
+      y_scale_bcast = f32<<Bshape>> broadcast(y_scale), dimensions={}
+      y_unscaled = f32<<Bshape>> multiply(y_f32, y_scale_bcast)
+      ROOT out = f32[64,16]<<Olayout>> dot(x_unscaled, y_unscaled), lhs_contracting_dims=<<Lcd>>, rhs_contracting_dims=<<Rcd>>
+    }
+)";
+  for (const auto& combination : combinations) {
+    absl::flat_hash_map<absl::string_view, absl::string_view> replacements;
+    replacements["<<Lcd>>"] = std::get<0>(combination);
+    replacements["<<Rcd>>"] = std::get<1>(combination);
+    replacements["<<Ashape>>"] = std::get<2>(combination);
+    replacements["<<Bshape>>"] = std::get<3>(combination);
+    replacements["<<Alayout>>"] = std::get<4>(combination);
+    replacements["<<Blayout>>"] = std::get<5>(combination);
+    replacements["<<Olayout>>"] = std::get<6>(combination);
+    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 0.}));
+
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+  )");
+  }
+}
+
+TEST_F(CublasLtF8GemmRewriteTest, ScaledABUnscaledDF8ParameterizedBatched) {
+  // TODO(wenscarl): For batched matmul, not all combinations of A, B and
+  // output layouts get pattern matched successfully to FP8 custom call. Only
+  // a handful of cases are tested here.
+  if (!GetCudaComputeCapability().IsAtLeast(
+          se::CudaComputeCapability::HOPPER)) {
+    GTEST_SKIP()
+        << "cuBLASLt FP8 kernels require Hopper or newer architecture.";
+  }
+
+  std::array<std::array<std::string, 7>, 32> combinations;
+  std::string lcd, rcd, a_shape, b_shape, a_layout, b_layout, o_layout;
+  int i = 0;
+  for (bool o_is_col : {false, true}) {
+    for (bool a_is_col : {false, true}) {
+      for (bool b_is_col : {false, true}) {
+        for (int lhs_contracting_dim : {2, 1}) {
+          for (int rhs_contracting_dim : {2, 1}) {
+            lcd = lhs_contracting_dim == 2 ? "{2}" : "{1}";
+            rcd = rhs_contracting_dim == 2 ? "{2}" : "{1}";
+            a_shape = lhs_contracting_dim == 2 ? "[2,64,32]" : "[2,32,64]";
+            b_shape = rhs_contracting_dim == 1 ? "[2,32,16]" : "[2,16,32]";
+            a_layout = a_is_col ? "{1,2,0}" : "{2,1,0}";
+            b_layout = b_is_col ? "{1,2,0}" : "{2,1,0}";
+            o_layout = o_is_col ? "{1,2,0}" : "{2,1,0}";
+            combinations[i++] = std::array{lcd, rcd, a_shape, b_shape,
+                                           a_layout, b_layout, o_layout};
+          }
+        }
+      }
+    }
+  }
+
+  const char* hlo_template = R"(
+    HloModule m
+    ENTRY f {
+      x_q = f8e4m3fn<<Ashape>><<Alayout>> parameter(0)
+      x_scale = f32[] parameter(2)
+      x_scale_broadcast = f32<<Ashape>><<Alayout>> broadcast(x_scale), dimensions={}
+      x_q_convert = f32<<Ashape>><<Alayout>> convert(x_q)
+      x_qdq = f32<<Ashape>><<Alayout>> multiply(x_q_convert, x_scale_broadcast)
+      y_q = f8e4m3fn<<Bshape>><<Blayout>> parameter(1)
+      y_scale = f32[] parameter(3)
+      y_scale_broadcast = f32<<Bshape>><<Blayout>> broadcast(y_scale), dimensions={}
+      y_q_convert = f32<<Bshape>><<Blayout>> convert(y_q)
+      y_qdq = f32<<Bshape>><<Blayout>> multiply(y_q_convert, y_scale_broadcast)
+      ROOT out = f32[2,64,16]<<Olayout>> dot(x_qdq, y_qdq), lhs_batch_dims={0}, lhs_contracting_dims=<<Lcd>>, rhs_batch_dims={0}, rhs_contracting_dims=<<Rcd>>
+    }
+)";
+  for (const auto& combination : combinations) {
+    absl::flat_hash_map<std::string, std::string> replacements;
+    replacements["<<Lcd>>"] = std::get<0>(combination);
+    replacements["<<Rcd>>"] = std::get<1>(combination);
+    replacements["<<Ashape>>"] = std::get<2>(combination);
+    replacements["<<Bshape>>"] = std::get<3>(combination);
+    replacements["<<Alayout>>"] = std::get<4>(combination);
+    replacements["<<Blayout>>"] = std::get<5>(combination);
+    replacements["<<Olayout>>"] = std::get<6>(combination);
+
+    const auto hlo_text = absl::StrReplaceAll(hlo_template, replacements);
+    EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-4, 0.}));
+
+    MatchOptimizedHlo(hlo_text,
+                      R"(
+; CHECK: custom_call_target="__cublas$lt$matmul$f8",
+  )");
+  }
+}
 
 TEST_F(GemmRewriteTest, NoFuseBiasBroadcast) {
   const char* hlo = R"(
