SDP: fuse ops in sdp math (#3535)
* sdp: fuse ops in sdp_math path

* Ensure attn_mask contiguous

* Add ut

---------

Co-authored-by: Ye Ting <ting.ye@intel.com>
Co-authored-by: zhuyuhua-v <yuhua.zhu@intel.com>
Co-authored-by: xiaolil1 <xiaoli.liu@intel.com>
4 people committed Dec 25, 2023
1 parent 4ba6dec commit 7ea2a3c
Showing 5 changed files with 104 additions and 29 deletions.
1 change: 1 addition & 0 deletions csrc/gpu/aten/operators/Blas.cpp
@@ -1,3 +1,4 @@
#include "Blas.h"
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/native/Resize.h>
#include "BlasImpl.h"
26 changes: 26 additions & 0 deletions csrc/gpu/aten/operators/Blas.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace at {
+namespace AtenIpexTypeXPU {
+// res = (m1 * m2.transpose()) / oscale
+at::Tensor trans_matmul_div_scalar(
+    const at::Tensor& tensor2,
+    int64_t dim1,
+    int64_t dim2,
+    const at::Tensor& tensor1,
+    Scalar oscale);
+
+// res = (m1 * m2.transpose()) / oscale + accumul
+at::Tensor trans_matmul_div_add(
+    const at::Tensor& tensor2,
+    int64_t dim1,
+    int64_t dim2,
+    const at::Tensor& tensor1,
+    Scalar oscale,
+    Tensor& accumul,
+    Scalar alpha);
+
+} // namespace AtenIpexTypeXPU
+} // namespace at
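Note: for readers unfamiliar with these helpers, the eager-mode PyTorch sketch below spells out the math the comments above describe (`res = (m1 * m2.transpose()) / oscale`, optionally plus `accumul` scaled by `alpha`). The `_ref` function names, the argument order, and the treatment of `alpha` as an `add`-style multiplier are assumptions for illustration only; this is not the XPU kernel or its exact signature.

```python
import torch

def trans_matmul_div_scalar_ref(m1, m2, oscale):
    # Unfused reference: res = (m1 @ m2.transpose(-2, -1)) / oscale
    return torch.matmul(m1, m2.transpose(-2, -1)) / oscale

def trans_matmul_div_add_ref(m1, m2, oscale, accumul, alpha=1.0):
    # Unfused reference: res = (m1 @ m2.transpose(-2, -1)) / oscale + alpha * accumul
    return torch.matmul(m1, m2.transpose(-2, -1)) / oscale + alpha * accumul

# Shapes matching an attention-score computation: q @ k.T / sqrt(head_dim) + mask
q = torch.randn(2, 8, 4, 64)
k = torch.randn(2, 8, 16, 64)
mask = torch.zeros(2, 1, 4, 16)
print(trans_matmul_div_add_ref(q, k, 64 ** 0.5, mask).shape)  # torch.Size([2, 8, 4, 16])
```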
82 changes: 55 additions & 27 deletions csrc/gpu/aten/operators/transformers/attention.cpp
@@ -3,6 +3,7 @@
 #include <ATen/record_function.h>
 #include <runtime/Utils.h>
 #include <utils/DPCPP.h>
+#include "../Blas.h"
 #include "../comm/ATDispatch.h"
 #include "sdp_utils.h"
 #include "utils/CustomOperatorRegistration.h"
@@ -92,38 +93,62 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math_impl(
   auto attn_mask = attn_mask_;
   // Naive, composite implementation defined here.
 
-  // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for
-  // math
+  // [Original] Scale q, k before matmul for stability see
+  // https://tinyurl.com/sudb9s96 for math
+  // Here we apply scaling after matmul for op fusion purpose
   bool is_negative_scaling = scale.has_value() && scale.value() < 0.0;
-  const auto scaling_factor =
-      sdp::calculate_scale(
-          query_, is_negative_scaling ? std::abs(scale.value()) : scale)
-          .sqrt();
-
-  const auto query = query_ *
-      (is_negative_scaling ? c10::SymFloat(0.0) - scaling_factor
-                           : scaling_factor);
+  const auto orig_scaling_factor = sdp::calculate_scale(
+      query_, is_negative_scaling ? std::abs(scale.value()) : scale);
+
   if (is_causal) {
     TORCH_CHECK(
         !attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
     TORCH_CHECK(
-        !query.is_nested() && !key.is_nested(),
+        !query_.is_nested() && !key.is_nested(),
         "_scaled_dot_product_attention: Nested tensors for query / key are not supported when is_causal=True");
 
     // Replace attn_mask with causal mask; lower triangular elements take part
     // in attention.
-    const auto L = query.sym_size(-2), S = key.sym_size(-2);
+    const auto L = query_.sym_size(-2), S = key.sym_size(-2);
     attn_mask =
-        at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
-    attn_mask = sdp::convert_boolean_attn_mask(attn_mask, query.dtype());
+        at::ones_symint({L, S}, query_.options().dtype(at::kBool)).tril();
+    attn_mask = sdp::convert_boolean_attn_mask(attn_mask, query_.dtype());
   }
-  auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
+
+  Tensor attn;
   if (attn_mask.has_value()) {
-    if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
-      attn = attn.add(*attn_mask);
+    attn_mask = attn_mask->contiguous();
+    if (is_negative_scaling) {
+      attn = trans_matmul_div_add(
+          key,
+          /*dim1=*/-1,
+          /*dim2=*/-1,
+          query_,
+          c10::SymFloat(0.0) - orig_scaling_factor,
+          *attn_mask,
+          1.0);
+    } else {
+      attn = trans_matmul_div_add(
+          key,
+          /*dim1=*/-1,
+          /*dim2=*/-1,
+          query_,
+          orig_scaling_factor,
+          *attn_mask,
+          1.0);
+    }
+  } else {
+    if (is_negative_scaling) {
+      attn = trans_matmul_div_scalar(
+          key,
+          /*dim1=*/-1,
+          /*dim2=*/-1,
+          query_,
+          c10::SymFloat(0.0) - orig_scaling_factor);
     } else {
-      attn.add_(*attn_mask);
+      attn = trans_matmul_div_scalar(
+          key, /*dim1=*/-1, /*dim2=*/-1, query_, orig_scaling_factor);
     }
   }
   attn = at::softmax(attn, -1);
@@ -166,15 +191,18 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     Tensor query_fp32 = query_.to(at::kFloat);
     Tensor key_fp32 = key.to(at::kFloat);
     Tensor value_fp32 = value.to(at::kFloat);
-    return _scaled_dot_product_attention_math_impl(
-        query_fp32,
-        key_fp32,
-        value_fp32,
-        attn_mask_,
-        dropout_p,
-        is_causal,
-        dropout_mask,
-        scale);
+    auto [attn_output, attn_weight] =
+        _scaled_dot_product_attention_math_impl(
+            query_fp32,
+            key_fp32,
+            value_fp32,
+            attn_mask_,
+            dropout_p,
+            is_causal,
+            dropout_mask,
+            scale);
+    return std::make_tuple(
+        attn_output.to(at::kHalf), attn_weight.to(at::kHalf));
   }
   return _scaled_dot_product_attention_math_impl(
       query_,
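As a sanity check on the rewrite above: the old math path scaled `q` and `k` by the square root of the softmax scale before the matmul, while the new path runs the matmul first and lets the fused op divide by `orig_scaling_factor` (√head_dim in the default no-`scale` case) and add the mask in one shot. Below is a small eager-mode sketch of that equivalence with made-up shapes; it is not the XPU kernel path.

```python
import math
import torch

q = torch.randn(1, 16, 4, 64, dtype=torch.float64)
k = torch.randn(1, 16, 33, 64, dtype=torch.float64)
mask = torch.randn(1, 1, 4, 33, dtype=torch.float64)

# Old path: scale q and k by sqrt(softmax_scale) before the matmul, then add the mask.
softmax_scale = 1.0 / math.sqrt(q.size(-1))
s = math.sqrt(softmax_scale)
old = torch.matmul(q * s, (k * s).transpose(-2, -1)) + mask

# New path: matmul first, divide by oscale = sqrt(head_dim), add the mask (fused on XPU).
oscale = math.sqrt(q.size(-1))
new = torch.matmul(q, k.transpose(-2, -1)) / oscale + mask

print(torch.allclose(old, new))  # True up to floating-point rounding
```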
2 changes: 1 addition & 1 deletion csrc/gpu/aten/operators/transformers/sdp_utils.h
@@ -19,7 +19,7 @@ inline c10::SymFloat calculate_scale(
     c10::optional<double> scale) {
   const auto softmax_scale = scale.has_value()
       ? scale.value()
-      : (c10::SymFloat(1.0) / (c10::SymFloat(query.sym_size(-1)).sqrt()));
+      : c10::SymFloat(query.sym_size(-1)).sqrt();
   return c10::SymFloat(softmax_scale);
 }
 
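This helper change goes hand in hand with the fusion: the fused ops divide by the returned value instead of multiplying by it, so in the default case `calculate_scale` now returns √E rather than 1/√E. A one-line check of that identity with illustrative numbers:

```python
import math

head_dim = 256
score = 3.7  # an arbitrary q·k dot product
# Multiplying by the old 1/sqrt(E) factor equals dividing by the new sqrt(E) factor.
assert math.isclose(score * (1.0 / math.sqrt(head_dim)), score / math.sqrt(head_dim))
```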
22 changes: 21 additions & 1 deletion tests/gpu/examples/test_sdp.py
@@ -1,5 +1,5 @@
 import torch
-import intel_extension_for_pytorch # noqa
+import intel_extension_for_pytorch as ipex # noqa
 import torch.nn.functional as F
 
 from torch.testing._internal.common_utils import TestCase
@@ -23,3 +23,23 @@ def test_sdp_mem_effi_half(self, dtype=torch.float16):
         out_xpu = F.scaled_dot_product_attention(query.xpu(), key.xpu(), value.xpu())
 
         self.assertEqual(out_cpu, out_xpu.cpu().float(), atol=1e-3, rtol=1e-3)
+
+    @pytest.mark.skipif(
+        ipex._C._has_2d_block_array(0),
+        reason="Only for naive sdp with half datatype on ATS-M",
+    )
+    def test_sdp_math_half(self, dtype=torch.float16):
+        head_dim = 256
+        seq_lenth = 1
+        k_seq_lenth = 33
+        v_seq_lenth = 33
+        query = torch.rand(1, 16, seq_lenth, head_dim, dtype=dtype)
+        key = torch.rand(1, 16, k_seq_lenth, head_dim, dtype=dtype)
+        value = torch.rand(1, 16, v_seq_lenth, head_dim, dtype=dtype)
+
+        out_cpu = F.scaled_dot_product_attention(
+            query.float(), key.float(), value.float()
+        )
+        out_xpu = F.scaled_dot_product_attention(query.xpu(), key.xpu(), value.xpu())
+
+        self.assertEqual(out_cpu, out_xpu.cpu().float(), atol=1e-3, rtol=1e-3)
