Reland pytorch#113487 and pytorch#112527 (sdpa shim & fp8 AOTInductor support) (pytorch#114974)

This is a backout of pytorch#113747 which reverted the above two commits. Now that
pytorch#113997 has landed, this diff can be landed safely without breaking ABI compatibility.

Pull Request resolved: pytorch#114974
Approved by: https://github.com/chenyang78
int3 authored and dmenig committed Dec 21, 2023
1 parent df22e9c commit 5ebbaf0
Showing 7 changed files with 268 additions and 21 deletions.
46 changes: 44 additions & 2 deletions test/inductor/test_aot_inductor.py
@@ -569,6 +569,50 @@ def forward(self, x, y):
example_inputs = (a, b)
self.check_model(Model(), example_inputs, constraints=constraints)

@unittest.skipIf(
not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
"FP8 is only supported on H100+",
)
def test_fp8(self):
class Model(torch.nn.Module):
def __init__(self, dtype):
super().__init__()
self.out_dtype = dtype

def forward(self, x, weight, bias, scale_a, scale_b):
weight = weight.to(torch.float8_e4m3fn)
output, updated_amax = torch._scaled_mm(
x,
weight,
bias=bias,
out_dtype=self.out_dtype,
scale_a=scale_a,
scale_b=scale_b,
)
return output

dtype = torch.float16

a_scale = torch.Tensor([1.0]).to(device="cuda")
b_scale = torch.Tensor([1.0]).to(device="cuda")
input_bias = torch.rand(32, device="cuda", dtype=dtype)
weight_shape = (32, 16)
weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
a_inverse_scale = 1 / a_scale
b_inverse_scale = 1 / b_scale

x_shape = (16, 16)
x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
constraints = [
torch._export.dynamic_dim(x, 0) >= 1,
torch._export.dynamic_dim(x, 0) <= 2048,
]
self.check_model(
Model(dtype),
(x, weight, input_bias, a_inverse_scale, b_inverse_scale),
constraints=constraints,
)

def test_poi_multiple_dynamic(self):
class Model(torch.nn.Module):
def __init__(self):
@@ -1432,8 +1476,6 @@ class AOTInductorTestABICompatibleCpu(TestCase):
"test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
# There is a double-free issue which will be fixed in another PR
"test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
"test_sdpa": TestFailure(("abi_compatible_cpu",)),
"test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
"test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
# error: could not find s0
"test_shifted_constraint_ranges": TestFailure(
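
For readers who want to try the operator outside the AOTInductor harness, the following is a minimal eager-mode sketch of the calling convention the new test_fp8 exercises. The guard mirrors the test's skip condition; the variable names, the eager usage, and the tuple return are assumptions grounded only in this test, not part of the test suite itself.

import torch

# Hedged sketch, not part of the test suite: reproduces the _scaled_mm call
# pattern from test_fp8 above in eager mode. Requires an SM90+ (H100-class) GPU.
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
    dtype = torch.float16
    x = torch.rand(16, 16, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
    weight = torch.rand(32, 16, device="cuda", dtype=dtype).T.to(torch.float8_e4m3fn)
    bias = torch.rand(32, device="cuda", dtype=dtype)
    scale_a = torch.tensor([1.0], device="cuda")
    scale_b = torch.tensor([1.0], device="cuda")
    # At the time of this commit, _scaled_mm returned (output, amax).
    out, amax = torch._scaled_mm(
        x,
        weight,
        bias=bias,
        out_dtype=dtype,
        scale_a=1 / scale_a,
        scale_b=1 / scale_b,
    )
    print(out.shape)  # torch.Size([16, 32])
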
2 changes: 2 additions & 0 deletions torch/_inductor/codegen/cpp.py
@@ -77,6 +77,8 @@
torch.bool: "at::kBool",
torch.bfloat16: "at::kBFloat16",
torch.complex64: "at::kComplexFloat",
torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
torch.float8_e5m2: "at::kFloat8_e5m2",
}

DEVICE_TO_ATEN = {
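
To make the intent of the two new table entries concrete, here is a small self-contained sketch (not the real codegen path; the helper name is made up for illustration) of how a dtype-to-ATen mapping like this is consulted when emitting C++ type literals.

import torch

# Illustrative copy of the two new entries; the real table lives in
# torch/_inductor/codegen/cpp.py.
FP8_DTYPE_TO_ATEN = {
    torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
    torch.float8_e5m2: "at::kFloat8_e5m2",
}

def cpp_scalar_type(dtype: torch.dtype) -> str:
    # Hypothetical helper: fail loudly rather than emit invalid C++.
    try:
        return FP8_DTYPE_TO_ATEN[dtype]
    except KeyError as e:
        raise NotImplementedError(f"no ATen scalar type registered for {dtype}") from e

print(cpp_scalar_type(torch.float8_e4m3fn))  # at::kFloat8_e4m3fn
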
35 changes: 29 additions & 6 deletions torch/_inductor/codegen/wrapper.py
@@ -1042,6 +1042,9 @@ def writeline(self, line):
def enter_context(self, ctx):
self.lines.append(LineContext(ctx))

def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
raise NotImplementedError()

def val_to_arg_str(self, s):
if isinstance(s, SymTypes):
return pexpr(sympy.expand(repr(s)))
@@ -1756,8 +1759,11 @@ def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
else:
raise NotImplementedError(f"unsupported type of {output=}")
args = args + output_args
assert (
extern_kernel.abi_compatible_kernel is not None
), f"abi_compatible_kernel is None for {extern_kernel.kernel=}"
self.generate_c_shim_extern_kernel_call(
extern_kernel.codegen_kernel_name(), args
extern_kernel.abi_compatible_kernel, args
)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)
@@ -2367,13 +2373,29 @@ def extract_output_name(out):

self.extern_call_ops.add(cpp_kernel_key)

def val_to_arg_str(self, val):
def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
if (
config.aot_inductor.abi_compatible
and not is_legacy_abi
and isinstance(type_, torch.OptionalType)
):
if val is None:
return "0" # nullptr is not available in C
if isinstance(val, (bool, int, str, float)):
var_name = f"var_{next(self.arg_var_id)}"
self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
return f"&{var_name}"
if not isinstance(type_.getElementType(), torch.TensorType):
return f"&{self.val_to_arg_str(val)}"

return self.val_to_arg_str(val)

def val_to_arg_str(self, val) -> str:
if val is None:
# When None is passed as an argument, it represents an optional that does not contain a value.
if config.aot_inductor.abi_compatible:
return "nullptr"
else:
return "c10::nullopt"
return "0" # nullptr is not available in C
return "c10::nullopt"
elif isinstance(val, bool):
if config.aot_inductor.abi_compatible:
return "1" if val else "0"
@@ -2395,7 +2417,8 @@ def val_to_arg_str(self, val):
else:
return "-std::numeric_limits<float>::infinity()"
elif isinstance(val, (list, tuple)):
result = f"{{{', '.join(list(map(self.val_to_arg_str, val)))}}}"
# FIXME handle embedded optional types?
result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}"
if config.aot_inductor.abi_compatible:
# Need to pass the array length because we can't use std::vector
return f"{self.codegen_int_array_var(result)}, {len(val)}"
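
The new val_to_cpp_arg_str codifies a pointer convention for optional arguments in ABI-compatible mode: a present optional scalar is materialized into a local variable and passed by address, while a missing optional becomes a literal 0, since neither nullptr nor c10::nullopt is available through the C shim. A stripped-down, hedged illustration of that rule, with all torch types removed and names chosen only for this sketch:

# Self-contained illustration of the optional-argument convention;
# not the real wrapper codegen class.
emitted_lines = []
_next_var = iter(range(1000))

def optional_scalar_to_cpp_arg(val) -> str:
    if val is None:
        return "0"  # nullptr is not available in C, so emit a plain 0
    var_name = f"var_{next(_next_var)}"
    emitted_lines.append(f"auto {var_name} = {val};")
    return f"&{var_name}"

print(optional_scalar_to_cpp_arg(0.5))   # &var_0
print(optional_scalar_to_cpp_arg(None))  # 0
print(emitted_lines)                     # ['auto var_0 = 0.5;']
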
2 changes: 2 additions & 0 deletions torch/_inductor/graph.py
@@ -79,6 +79,8 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
}
if cuda:
supported_dtype.add(torch.float16)
supported_dtype.add(torch.float8_e4m3fn)
supported_dtype.add(torch.float8_e5m2)

return dtype in supported_dtype

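
A quick hedged check of what the two added lines change; the function name and signature come from the hunk above, and the import path is an assumption based on the file location.

import torch
from torch._inductor.graph import supported_dtype_of_cpp_wrapper

# With this commit, fp8 dtypes are accepted by the cpp wrapper only on CUDA.
print(supported_dtype_of_cpp_wrapper(torch.float8_e4m3fn, cuda=True))   # True
print(supported_dtype_of_cpp_wrapper(torch.float8_e4m3fn, cuda=False))  # False
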
58 changes: 54 additions & 4 deletions torch/_inductor/ir.py
@@ -3628,9 +3628,10 @@ def get_kwargs_value(self, arg_name):
f"arg {arg_name} not found in self.kwargs or self.kwargs_default_value"
)

def is_legacy_abi_kernel(self):
return False

def codegen_kwargs(self):
if not self.kwargs:
return []
if V.graph.cpp_wrapper:
# FIXME: we should unconditionally fill self.kwargs with missing default values
# instead of carrying an extra self.ordered_kwargs_for_cpp_kernel
@@ -3642,7 +3643,16 @@ def codegen_kwargs(self):
if isinstance(v, sympy.Expr):
kwargs.append(v)
else:
kwargs.append(V.graph.wrapper_code.val_to_arg_str(v))
# FIXME We should let ExternKernel have access to the cpp schema where possible.
if hasattr(self, "kwargs_default_value"):
type_ = self.kwargs_default_value.get(arg_name).get("type")
else:
type_ = None
kwargs.append(
V.graph.wrapper_code.val_to_cpp_arg_str(
type_, v, self.is_legacy_abi_kernel()
)
)
else:
kwargs = [
f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
@@ -3777,12 +3787,38 @@ def __init__(self, count: int, device: torch.device):


class ExternKernelAlloc(ExternKernel):
# Generate abi-compatible kernel names for shim kernels.
# Each individual shim kernel may have its own versioning rule.
# However, we don't expect to end up with too many such rules.
def _get_abi_compatible_kernel(self):
if not V.graph.cpp_wrapper:
return self.kernel

def sdpa_ver_fn():
# For sdpa, we need the v2 version only if any optional
# kwarg is missing.
if any(
self.get_kwargs_value(arg_name) is None
for arg_name in self.ordered_kwargs_for_cpp_kernel
):
return f"{self.cpp_kernel}_v2"
else:
return self.cpp_kernel

kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
if (ver_fn := kernel_to_ver.get(self.cpp_kernel, None)) is not None:
return ver_fn()
return self.cpp_kernel

def codegen_kernel_name(self):
return self.cpp_kernel if V.graph.cpp_wrapper else self.kernel

def codegen(self, wrapper):
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
# Now we set up abi_compatible_kernel after self.kernel
# and kwargs are adjusted appropriately.
self.abi_compatible_kernel = self._get_abi_compatible_kernel()
V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
@@ -3803,6 +3839,7 @@ def __init__(
self.name = V.graph.register_buffer(self)
self.kernel = kernel
self.cpp_kernel = cpp_kernel
self.abi_compatible_kernel = None
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel

def should_allocate(self):
@@ -4302,6 +4339,9 @@ def is_not_write(arg):
x.name for x in kernel._schema.arguments if x.kwarg_only
]

def is_legacy_abi_kernel(self):
return "_scaled_dot_product_flash_attention" in str(self.kernel)

def get_arg_default_value(self, pos):
assert hasattr(
self, "args_default_value"
@@ -4321,7 +4361,17 @@ def __repr__(self):

tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]

if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
args = [
V.graph.wrapper_code.val_to_cpp_arg_str(
param.real_type, x, self.is_legacy_abi_kernel()
)
for param, x in zip(self.op_overload._schema.arguments, args)
]
else:
args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]

# Previously, we want to maintain forward-compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values being set.
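
The versioning hook added to ExternKernelAlloc can be summarized as: look up a per-kernel rule, and for sdpa fall back to the "_v2" shim whenever any optional kwarg is absent. A simplified stand-alone sketch follows; the function and argument names are illustrative stand-ins, not the real IR node.

def pick_abi_compatible_kernel(cpp_kernel: str, kwargs: dict) -> str:
    # Mirrors _get_abi_compatible_kernel above in spirit only.
    def sdpa_ver_fn() -> str:
        # The v2 shim takes optional args by pointer, so it is needed
        # whenever an optional kwarg (e.g. `scale`) is missing.
        if any(v is None for v in kwargs.values()):
            return f"{cpp_kernel}_v2"
        return cpp_kernel

    kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
    ver_fn = kernel_to_ver.get(cpp_kernel)
    return ver_fn() if ver_fn is not None else cpp_kernel

print(pick_abi_compatible_kernel(
    "at::_scaled_dot_product_flash_attention", {"scale": None}
))  # at::_scaled_dot_product_flash_attention_v2
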
37 changes: 36 additions & 1 deletion torch/csrc/inductor/aoti_torch/c/shim.h
@@ -83,6 +83,8 @@ using AOTITorchError = int32_t;
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();

AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();
@@ -179,6 +181,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
AtenTensorHandle* ret // returns new reference
);

// This version is deprecated. We will remove it later
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle query,
AtenTensorHandle key,
@@ -198,6 +201,38 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError
aoti_torch__scaled_dot_product_flash_attention_v2(
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
double dropout_p,
int is_causal,
int return_debug_mask,
double* scale,
AtenTensorHandle* ret0, // returns new reference
AtenTensorHandle* ret1, // returns new reference
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
AtenTensorHandle self,
AtenTensorHandle mat2,
AtenTensorHandle bias,
int32_t* out_dtype,
AtenTensorHandle scale_a,
AtenTensorHandle scale_b,
AtenTensorHandle scale_result,
int8_t use_fast_accum,
AtenTensorHandle* ret0,
AtenTensorHandle* ret1);

// This function will create a new uninitialized tensor object
// and its pointer is returned through *ret.
AOTI_TORCH_EXPORT AOTITorchError
@@ -242,7 +277,7 @@ aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
AtenTensorHandle repeats,
int64_t output_size,
int64_t* output_size,
AtenTensorHandle* out);

AOTI_TORCH_EXPORT AOTITorchError
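
For orientation, this is roughly the shape of the call the ABI-compatible wrapper would emit against the new aoti_torch__scaled_mm entry point, rendered from Python the way the codegen builds strings. The buffer names, the out_dtype temporary, and the omission of error-check boilerplate are all illustrative assumptions, not the exact generated code.

def render_scaled_mm_call(out_dtype_var=None) -> str:
    # Optional scalar arguments follow the pointer convention:
    # present -> "&var", absent -> "0".
    out_dtype_arg = f"&{out_dtype_var}" if out_dtype_var else "0"
    args = [
        "buf_x",        # self
        "buf_weight",   # mat2
        "buf_bias",     # bias
        out_dtype_arg,  # int32_t* out_dtype
        "buf_scale_a",  # scale_a
        "buf_scale_b",  # scale_b
        "0",            # scale_result (optional tensor, not provided)
        "0",            # use_fast_accum
        "&buf_out0",    # ret0, returns new reference
        "&buf_out1",    # ret1, returns new reference
    ]
    return f"aoti_torch__scaled_mm({', '.join(args)});"

print(render_scaled_mm_call("var_out_dtype"))
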
