diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 4ff0166..5e198af 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -19,6 +19,7 @@
 from float8_experimental.float8_dynamic_linear import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
+    WeightWithDynamicFloat8CastTensor,
 )
 
 from float8_experimental.float8_tensor import (
@@ -163,6 +164,7 @@ def __init__(self, *args, **kwargs):
         )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
+        emulate = kwargs.pop("emulate", False)
         scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
         scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
         scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
@@ -187,8 +189,12 @@ def __init__(self, *args, **kwargs):
         self.create_buffers()
 
         # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig()
-        self.backward_config = ScaledMMConfig()
+        self.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        self.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
@@ -428,19 +434,20 @@ def from_float(
             scaling_type_x=scaling_type_x,
             scaling_type_w=scaling_type_w,
             scaling_type_dL_dY=scaling_type_dL_dY,
+            emulate=emulate,
+        )
+        if (
+            scaling_type_w == TensorScalingType.DYNAMIC
+            and config.enable_fsdp_fp8_all_gather
+        ):
+            new_mod.weight = torch.nn.Parameter(
+                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
         )
-        new_mod.weight = mod.weight
+        else:
+            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
+            new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
         new_mod.create_buffers()
-        # Defines the behavior of the matmul in the forward and backward
-        # Forward we use fast_accum, backwards we do not
-        # TODO(future PR): move below to the constructor
-        new_mod.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        new_mod.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
-        )
         return new_mod
diff --git a/test/test_fsdp2/test_fsdp2_eager.py b/test/test_fsdp2/test_fsdp2_eager.py
index 98ef92b..5e4dc8f 100644
--- a/test/test_fsdp2/test_fsdp2_eager.py
+++ b/test/test_fsdp2/test_fsdp2_eager.py
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -11,6 +12,7 @@
     Float8DynamicLinear,
     WeightWithDynamicFloat8CastTensor,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+    def swap_linear_with_dynamic(
+        self, module: nn.Module, use_float8_linear=False, **kwargs: Any
+    ) -> nn.Module:
+        if use_float8_linear:
+            kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
+        else:
+            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
 
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@@ -85,10 +95,16 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_parity_dynamic(
+                enable_fsdp_fp8_all_gather, use_float8_linear
+            )
 
-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -96,9 +112,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying)
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            module = self.swap_linear_with_dynamic(module, use_float8_linear)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -108,6 +124,8 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
+        # TODO(future): change Float8DynamicLinear to module_cls below, and
+        # ensure there is no amax syncing for all-dynamic
         check_parity_no_mp(
             self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
         )
@@ -115,10 +133,15 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
-
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+        # for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)
+
+    def _test_transformer_memory(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -141,7 +164,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            model = self.swap_linear_with_dynamic(
+                model, emulate=True, use_float8_linear=use_float8_linear
+            )
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -242,16 +267,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     def world_size(self) -> int:
         return 2
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_weight_subclass_dynamic(self):
+    def _test_weight_subclass_dynamic(self, use_float8_linear):
+        float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+        extra_kwargs = {}
+        if use_float8_linear:
+            extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            pass
         tensor_cls = WeightWithDynamicFloat8CastTensor
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module_fp32,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         self.assertIsInstance(module.weight, tensor_cls)
         fully_shard(module)
@@ -265,8 +297,9 @@ def test_weight_subclass_dynamic(self):
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         for param_name, param in module.named_parameters():
             if "weight" in param_name:
@@ -280,7 +313,14 @@ def test_weight_subclass_dynamic(self):
                 self.assertIsInstance(param.to_local(), tensor_cls)
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+    def test_weight_subclass_float8_dynamic_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_weight_subclass_float8_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=True)
+
+    def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
         """
         Tests that fp8 all-gather with dynamic scaling communicates the
         expected number of bytes.
@@ -314,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +398,30 @@ def get_expected_all_gather_size(module: nn.Module):
             [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
         )
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_fp32_fp8_single_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with a
         single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = self.swap_linear_with_dynamic(
+                copy.deepcopy(module_fp32), use_float8_linear
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +442,16 @@ def test_fp32_fp8_multi_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module = self.init_multi_module()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = self.swap_linear_with_dynamic(
+                ref_module, use_float8_linear
+            ).cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)