diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 5ade95252bd88..0ea05eabcbb10 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -569,6 +569,50 @@ def forward(self, x, y):
         example_inputs = (a, b)
         self.check_model(Model(), example_inputs, constraints=constraints)
 
+    @unittest.skipIf(
+        not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
+        "FP8 is only supported on H100+",
+    )
+    def test_fp8(self):
+        class Model(torch.nn.Module):
+            def __init__(self, dtype):
+                super().__init__()
+                self.out_dtype = dtype
+
+            def forward(self, x, weight, bias, scale_a, scale_b):
+                weight = weight.to(torch.float8_e4m3fn)
+                output, updated_amax = torch._scaled_mm(
+                    x,
+                    weight,
+                    bias=bias,
+                    out_dtype=self.out_dtype,
+                    scale_a=scale_a,
+                    scale_b=scale_b,
+                )
+                return output
+
+        dtype = torch.float16
+
+        a_scale = torch.Tensor([1.0]).to(device="cuda")
+        b_scale = torch.Tensor([1.0]).to(device="cuda")
+        input_bias = torch.rand(32, device="cuda", dtype=dtype)
+        weight_shape = (32, 16)
+        weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T
+        a_inverse_scale = 1 / a_scale
+        b_inverse_scale = 1 / b_scale
+
+        x_shape = (16, 16)
+        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(torch.float8_e4m3fn)
+        constraints = [
+            torch._export.dynamic_dim(x, 0) >= 1,
+            torch._export.dynamic_dim(x, 0) <= 2048,
+        ]
+        self.check_model(
+            Model(dtype),
+            (x, weight, input_bias, a_inverse_scale, b_inverse_scale),
+            constraints=constraints,
+        )
+
     def test_poi_multiple_dynamic(self):
         class Model(torch.nn.Module):
             def __init__(self):
@@ -1432,8 +1476,6 @@ class AOTInductorTestABICompatibleCpu(TestCase):
     "test_poi_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
     # There is a double-free issue which will be fixed in another PR
     "test_repeat_output": TestFailure(("abi_compatible_cpu",), is_skip=True),
-    "test_sdpa": TestFailure(("abi_compatible_cpu",)),
-    "test_sdpa_2": TestFailure(("abi_compatible_cpu",)),
     "test_simple_dynamic": TestFailure(("abi_compatible_cpu",)),
     # error: could not find s0
     "test_shifted_constraint_ranges": TestFailure(
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 88b3185e35710..c9965e5233c70 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -77,6 +77,8 @@
     torch.bool: "at::kBool",
     torch.bfloat16: "at::kBFloat16",
     torch.complex64: "at::kComplexFloat",
+    torch.float8_e4m3fn: "at::kFloat8_e4m3fn",
+    torch.float8_e5m2: "at::kFloat8_e5m2",
 }
 
 DEVICE_TO_ATEN = {
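The two DTYPE_TO_ATEN entries above let the C++ wrapper name the fp8 dtypes when it allocates buffers. A minimal sketch of the kind of C++ the wrapper can then emit (buffer name, shapes, and strides here are hypothetical, not taken from this PR):

// Hypothetical generated allocation for a CUDA fp8 (e4m3) buffer,
// using the new at::kFloat8_e4m3fn mapping added above.
at::Tensor buf0 = at::empty_strided(
    {16L, 32L},
    {32L, 1L},
    at::TensorOptions(c10::Device(at::kCUDA, 0)).dtype(at::kFloat8_e4m3fn));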
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 19b83dd5f0069..2a3c442cf5120 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -1042,6 +1042,9 @@ def writeline(self, line):
     def enter_context(self, ctx):
         self.lines.append(LineContext(ctx))
 
+    def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
+        raise NotImplementedError()
+
     def val_to_arg_str(self, s):
         if isinstance(s, SymTypes):
             return pexpr(sympy.expand(repr(s)))
@@ -1756,8 +1759,11 @@ def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
             else:
                 raise NotImplementedError("unsupported type of {output=}")
         args = args + output_args
+        assert (
+            extern_kernel.abi_compatible_kernel is not None
+        ), f"abi_compatible_kernel is None for {extern_kernel.kernel=}"
         self.generate_c_shim_extern_kernel_call(
-            extern_kernel.codegen_kernel_name(), args
+            extern_kernel.abi_compatible_kernel, args
         )
         for raii_handle in output_raii_handles:
            self.writeline(raii_handle)
@@ -2367,13 +2373,29 @@ def extract_output_name(out):
 
             self.extern_call_ops.add(cpp_kernel_key)
 
-    def val_to_arg_str(self, val):
+    def val_to_cpp_arg_str(self, type_, val, is_legacy_abi) -> str:
+        if (
+            config.aot_inductor.abi_compatible
+            and not is_legacy_abi
+            and isinstance(type_, torch.OptionalType)
+        ):
+            if val is None:
+                return "0"  # nullptr is not available in C
+            if isinstance(val, (bool, int, str, float)):
+                var_name = f"var_{next(self.arg_var_id)}"
+                self.writeline(f"auto {var_name} = {self.val_to_arg_str(val)};")
+                return f"&{var_name}"
+            if not isinstance(type_.getElementType(), torch.TensorType):
+                return f"&{self.val_to_arg_str(val)}"
+
+        return self.val_to_arg_str(val)
+
+    def val_to_arg_str(self, val) -> str:
         if val is None:
             # When None is passed as an argument, it represents an optional that does not contain a value.
             if config.aot_inductor.abi_compatible:
-                return "nullptr"
-            else:
-                return "c10::nullopt"
+                return "0"  # nullptr is not available in C
+            return "c10::nullopt"
         elif isinstance(val, bool):
             if config.aot_inductor.abi_compatible:
                 return "1" if val else "0"
@@ -2395,7 +2417,8 @@ def val_to_arg_str(self, val):
             else:
                 return "-std::numeric_limits<double>::infinity()"
         elif isinstance(val, (list, tuple)):
-            result = f"{{{', '.join(list(map(self.val_to_arg_str, val)))}}}"
+            # FIXME: handle embedded optional types?
+            result = f"{{{', '.join(self.val_to_arg_str(x) for x in val)}}}"
             if config.aot_inductor.abi_compatible:
                 # Need to pass the array length because we can't use std::vector
                 return f"{self.codegen_int_array_var(result)}, {len(val)}"
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 98c46d3290e3b..e4840b7777f81 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -79,6 +79,8 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda):
     }
     if cuda:
         supported_dtype.add(torch.float16)
+        supported_dtype.add(torch.float8_e4m3fn)
+        supported_dtype.add(torch.float8_e5m2)
 
     return dtype in supported_dtype
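In ABI-compatible mode the generated code crosses a plain C boundary, so val_to_cpp_arg_str above encodes an optional scalar as a pointer: the value is spilled into a var_{id} local and passed by address, while a missing optional becomes a literal 0. A sketch of the resulting generated C++, with a hypothetical shim entry point:

// Optional double present: spill to a local, pass its address.
auto var_0 = 0.125;
aoti_torch_some_kernel(buf_handle, &var_0);  // hypothetical entry point
// Optional absent: a literal 0 stands in for NULL, since the header is C.
aoti_torch_some_kernel(buf_handle, 0);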
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 15c4d82f0959c..3c5712d251eca 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3628,9 +3628,10 @@ def get_kwargs_value(self, arg_name):
             f"arg {arg_name} not found in self.kwargs or self.kwargs_default_value"
         )
 
+    def is_legacy_abi_kernel(self):
+        return False
+
     def codegen_kwargs(self):
-        if not self.kwargs:
-            return []
         if V.graph.cpp_wrapper:
             # FIXME: we should unconditionally fill self.kwargs with missing default values
             # instead of carrying an extra self.ordered_kwargs_for_cpp_kernel
@@ -3642,7 +3643,16 @@ def codegen_kwargs(self):
                 if isinstance(v, sympy.Expr):
                     kwargs.append(v)
                 else:
-                    kwargs.append(V.graph.wrapper_code.val_to_arg_str(v))
+                    # FIXME: we should let ExternKernel have access to the cpp schema where possible.
+                    if hasattr(self, "kwargs_default_value"):
+                        type_ = self.kwargs_default_value.get(arg_name).get("type")
+                    else:
+                        type_ = None
+                    kwargs.append(
+                        V.graph.wrapper_code.val_to_cpp_arg_str(
+                            type_, v, self.is_legacy_abi_kernel()
+                        )
+                    )
         else:
             kwargs = [
                 f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
@@ -3777,12 +3787,38 @@ def __init__(self, count: int, device: torch.device):
 
 
 class ExternKernelAlloc(ExternKernel):
+    # Generate ABI-compatible kernel names for shim kernels.
+    # Each individual shim kernel may have its own versioning rule, but
+    # we don't expect to end up with too many such rules.
+    def _get_abi_compatible_kernel(self):
+        if not V.graph.cpp_wrapper:
+            return self.kernel
+
+        def sdpa_ver_fn():
+            # For sdpa, we need the v2 version only if any optional
+            # kwarg is missing.
+            if any(
+                self.get_kwargs_value(arg_name) is None
+                for arg_name in self.ordered_kwargs_for_cpp_kernel
+            ):
+                return f"{self.cpp_kernel}_v2"
+            else:
+                return self.cpp_kernel
+
+        kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
+        if (ver_fn := kernel_to_ver.get(self.cpp_kernel, None)) is not None:
+            return ver_fn()
+        return self.cpp_kernel
+
     def codegen_kernel_name(self):
         return self.cpp_kernel if V.graph.cpp_wrapper else self.kernel
 
     def codegen(self, wrapper):
         self.codegen_comment(wrapper)
         args = [*self.codegen_args(), *self.codegen_kwargs()]
+        # Set up abi_compatible_kernel only after self.kernel and the
+        # kwargs have been adjusted appropriately.
+        self.abi_compatible_kernel = self._get_abi_compatible_kernel()
         V.graph.wrapper_code.generate_extern_kernel_alloc(self, args)
         if isinstance(self.layout, Layout):
             self.codegen_size_asserts(wrapper)
@@ -3803,6 +3839,7 @@ def __init__(
         self.name = V.graph.register_buffer(self)
         self.kernel = kernel
         self.cpp_kernel = cpp_kernel
+        self.abi_compatible_kernel = None
         self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
 
     def should_allocate(self):
@@ -4302,6 +4339,9 @@ def is_not_write(arg):
             x.name for x in kernel._schema.arguments if x.kwarg_only
         ]
 
+    def is_legacy_abi_kernel(self):
+        return "_scaled_dot_product_flash_attention" in str(self.kernel)
+
     def get_arg_default_value(self, pos):
         assert hasattr(
             self, "args_default_value"
@@ -4321,7 +4361,17 @@ def __repr__(self):
 
         tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
         args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
-        args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
+
+        if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
+            args = [
+                V.graph.wrapper_code.val_to_cpp_arg_str(
+                    param.real_type, x, self.is_legacy_abi_kernel()
+                )
+                for param, x in zip(self.op_overload._schema.arguments, args)
+            ]
+        else:
+            args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
+
         # Previously, we want to maintain forward-compatibility by skipping
         # default args in the serialized artifacts in fbcode. However,
         # some of our shim interfaces require default values being set.
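The versioning hook above exists because the legacy sdpa shim takes its optional scale by value and therefore cannot represent a missing kwarg; _get_abi_compatible_kernel switches to the _v2 entry point only in that case. A sketch of the two call forms the wrapper ends up emitting (handle and argument values are hypothetical):

// All optional kwargs present -> legacy entry point, scale passed by value:
//   aoti_torch__scaled_dot_product_flash_attention(
//       q, k, v, 0.0, /*is_causal=*/1, /*return_debug_mask=*/0,
//       /*scale=*/0.125, &ret0, ..., &ret8);
// scale is None -> _v2 entry point, optionals as pointers, 0 == absent:
//   aoti_torch__scaled_dot_product_flash_attention_v2(
//       q, k, v, 0.0, 1, 0, /*scale=*/0, &ret0, ..., &ret8);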
diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h
index ae0032e2d0c0c..3bc1aaf99fc2b 100644
--- a/torch/csrc/inductor/aoti_torch/c/shim.h
+++ b/torch/csrc/inductor/aoti_torch/c/shim.h
@@ -83,6 +83,8 @@ using AOTITorchError = int32_t;
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu();
 AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda();
 
+AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2();
+AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn();
 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16();
 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16();
 AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32();
@@ -179,6 +181,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
     AtenTensorHandle* ret // returns new reference
 );
 
+// This version is deprecated. We will remove it later.
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle query,
     AtenTensorHandle key,
     AtenTensorHandle value,
@@ -198,6 +201,38 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle* ret8 // returns new reference
 );
 
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch__scaled_dot_product_flash_attention_v2(
+    AtenTensorHandle query,
+    AtenTensorHandle key,
+    AtenTensorHandle value,
+    double dropout_p,
+    int is_causal,
+    int return_debug_mask,
+    double* scale,
+    AtenTensorHandle* ret0, // returns new reference
+    AtenTensorHandle* ret1, // returns new reference
+    AtenTensorHandle* ret2, // returns new reference
+    AtenTensorHandle* ret3, // returns new reference
+    int64_t* ret4,
+    int64_t* ret5,
+    AtenTensorHandle* ret6, // returns new reference
+    AtenTensorHandle* ret7, // returns new reference
+    AtenTensorHandle* ret8 // returns new reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm(
+    AtenTensorHandle self,
+    AtenTensorHandle mat2,
+    AtenTensorHandle bias,
+    int32_t* out_dtype,
+    AtenTensorHandle scale_a,
+    AtenTensorHandle scale_b,
+    AtenTensorHandle scale_result,
+    int8_t use_fast_accum,
+    AtenTensorHandle* ret0,
+    AtenTensorHandle* ret1);
+
 // This function will create a new uninitialized tensor object
 // and its pointer is returned through *ret.
 AOTI_TORCH_EXPORT AOTITorchError
@@ -242,7 +277,7 @@ aoti_torch_nonzero(AtenTensorHandle self, AtenTensorHandle* out);
 
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
     AtenTensorHandle repeats,
-    int64_t output_size,
+    int64_t* output_size,
     AtenTensorHandle* out);
 
 AOTI_TORCH_EXPORT AOTITorchError
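A minimal caller-side sketch of the new aoti_torch__scaled_mm declaration under the same pointer-as-optional convention (the input handles are assumed to have been created elsewhere; error checking is omitted):

// self, mat2, bias, scale_a, scale_b are pre-created AtenTensorHandles;
// a null handle (e.g. scale_result here) means the optional tensor is absent.
int32_t out_dtype = aoti_torch_dtype_float16();
AtenTensorHandle ret0 = nullptr;
AtenTensorHandle ret1 = nullptr;
AOTITorchError err = aoti_torch__scaled_mm(
    self, mat2, bias,
    &out_dtype,  // pass 0 instead to leave out_dtype unset
    scale_a, scale_b,
    /*scale_result=*/nullptr,
    /*use_fast_accum=*/0,
    &ret0, &ret1);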
diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp
index c28b391c2c5a4..9dd62412e4bf5 100644
--- a/torch/csrc/inductor/aoti_torch/shim_common.cpp
+++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -18,6 +18,7 @@
 #include <ATen/ops/_fft_c2c.h>
 #include <ATen/ops/_scaled_dot_product_flash_attention.h>
+#include <ATen/ops/_scaled_mm.h>
 #include <ATen/ops/addmm.h>
 #include <ATen/ops/bmm.h>
 #include <ATen/ops/convolution.h>
@@ -42,6 +43,17 @@ static c10::Device c10_device(int32_t device_type, int32_t device_index) {
         static_cast<c10::DeviceIndex>(device_index));
   }
 }
+
+template <class T>
+c10::optional<T> pointer_to_optional(T* ptr) {
+  return ptr ? c10::make_optional(*ptr) : c10::nullopt;
+}
+
+template <class T, class U, typename = std::enable_if_t<!std::is_same_v<T, U>>>
+c10::optional<T> pointer_to_optional(U* ptr) {
+  return ptr ? c10::make_optional(T(*ptr)) : c10::nullopt;
+}
+
 } // namespace
 
 int32_t aoti_torch_device_type_cpu() {
@@ -52,6 +64,14 @@ int32_t aoti_torch_device_type_cuda() {
   return (int32_t)c10::DeviceType::CUDA;
 }
 
+int32_t aoti_torch_dtype_float8_e5m2() {
+  return (int32_t)c10::ScalarType::Float8_e5m2;
+}
+
+int32_t aoti_torch_dtype_float8_e4m3fn() {
+  return (int32_t)c10::ScalarType::Float8_e4m3fn;
+}
+
 int32_t aoti_torch_dtype_bfloat16() {
   return (int32_t)c10::ScalarType::BFloat16;
 }
@@ -245,14 +265,14 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
   });
 }
 
-AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
+AOTITorchError aoti_torch__scaled_dot_product_flash_attention_v2(
     AtenTensorHandle query,
     AtenTensorHandle key,
     AtenTensorHandle value,
     double dropout_p,
-    bool is_causal,
-    bool return_debug_mask,
-    double scale,
+    int is_causal,
+    int return_debug_mask,
+    double* scale,
     AtenTensorHandle* ret0, // returns new reference
     AtenTensorHandle* ret1, // returns new reference
     AtenTensorHandle* ret2, // returns new reference
@@ -267,6 +287,7 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
     at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
     at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
+    auto optional_scale = pointer_to_optional(scale);
     auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] =
         at::_scaled_dot_product_flash_attention(
             *query_tensor,
@@ -275,7 +296,7 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
             dropout_p,
             is_causal,
             return_debug_mask,
-            scale);
+            optional_scale);
 
     at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
     *ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
@@ -301,6 +322,43 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
   });
 }
 
+AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
+    AtenTensorHandle query,
+    AtenTensorHandle key,
+    AtenTensorHandle value,
+    double dropout_p,
+    bool is_causal,
+    bool return_debug_mask,
+    double scale,
+    AtenTensorHandle* ret0, // returns new reference
+    AtenTensorHandle* ret1, // returns new reference
+    AtenTensorHandle* ret2, // returns new reference
+    AtenTensorHandle* ret3, // returns new reference
+    int64_t* ret4,
+    int64_t* ret5,
+    AtenTensorHandle* ret6, // returns new reference
+    AtenTensorHandle* ret7, // returns new reference
+    AtenTensorHandle* ret8 // returns new reference
+) {
+  return aoti_torch__scaled_dot_product_flash_attention_v2(
+      query,
+      key,
+      value,
+      dropout_p,
+      is_causal,
+      return_debug_mask,
+      &scale,
+      ret0,
+      ret1,
+      ret2,
+      ret3,
+      ret4,
+      ret5,
+      ret6,
+      ret7,
+      ret8);
+}
+
 AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     at::Tensor* out_tensor = new at::Tensor();
@@ -308,6 +366,41 @@ AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
   });
 }
 
+AOTITorchError aoti_torch__scaled_mm(
+    AtenTensorHandle self,
+    AtenTensorHandle mat2,
+    AtenTensorHandle bias,
+    int32_t* out_dtype,
+    AtenTensorHandle scale_a,
+    AtenTensorHandle scale_b,
+    AtenTensorHandle scale_result,
+    int8_t use_fast_accum,
+    AtenTensorHandle* ret0,
+    AtenTensorHandle* ret1) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    at::Tensor* self_tensor = tensor_handle_to_tensor_pointer(self);
+    at::Tensor* mat2_tensor = tensor_handle_to_tensor_pointer(mat2);
+    at::Tensor* bias_tensor = tensor_handle_to_tensor_pointer(bias);
+    at::Tensor* scale_a_tensor = tensor_handle_to_tensor_pointer(scale_a);
+    at::Tensor* scale_b_tensor = tensor_handle_to_tensor_pointer(scale_b);
+    at::Tensor* scale_result_tensor =
+        tensor_handle_to_tensor_pointer(scale_result);
+    auto [r0, r1] = at::_scaled_mm(
+        *self_tensor,
+        *mat2_tensor,
+        pointer_to_optional(bias_tensor),
+        pointer_to_optional<c10::ScalarType>(out_dtype),
+        pointer_to_optional(scale_a_tensor),
+        pointer_to_optional(scale_b_tensor),
+        pointer_to_optional(scale_result_tensor),
+        use_fast_accum);
+    at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
+    *ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
+    at::Tensor* ret1_tensor = new at::Tensor(std::move(r1));
+    *ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
+  });
+}
+
 // TODO: implement a more efficient version instead of calling into aten
 AOTITorchError aoti_torch_tensor_copy_(
     AtenTensorHandle src,
@@ -395,12 +488,12 @@ AOTITorchError aoti_torch_nonzero(
 
 AOTITorchError aoti_torch_repeat_interleave_Tensor(
     AtenTensorHandle repeats,
-    int64_t output_size,
+    int64_t* output_size,
    AtenTensorHandle* out) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     at::Tensor* repeats_tensor = tensor_handle_to_tensor_pointer(repeats);
-    at::Tensor out_tensor =
-        at::_ops::repeat_interleave_Tensor::call(*repeats_tensor, output_size);
+    at::Tensor out_tensor = at::_ops::repeat_interleave_Tensor::call(
+        *repeats_tensor, pointer_to_optional<c10::SymInt>(output_size));
     at::Tensor* out_tensor_ptr = new at::Tensor(std::move(out_tensor));
     *out = tensor_pointer_to_tensor_handle(out_tensor_ptr);
   });
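To close, a minimal sketch of how the two pointer_to_optional overloads behave, assuming it is compiled in the same translation unit as the anonymous-namespace helpers above:

void pointer_to_optional_demo() {
  double scale = 0.125;
  // First overload: pointee type matches, the value is wrapped.
  c10::optional<double> present = pointer_to_optional(&scale);
  // A null pointer maps to c10::nullopt.
  c10::optional<double> absent = pointer_to_optional<double>(nullptr);
  // Second (enable_if) overload converts the pointee, e.g. the int32_t
  // dtype code that crosses the C ABI into a c10::ScalarType.
  int32_t code = aoti_torch_dtype_float8_e4m3fn();
  c10::optional<c10::ScalarType> dt =
      pointer_to_optional<c10::ScalarType>(&code);
  (void)present;
  (void)absent;
  (void)dt;
}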