diff --git a/benchmarks/bench_linear_float8.py b/benchmarks/bench_linear_float8.py index ff6a52a..2780600 100644 --- a/benchmarks/bench_linear_float8.py +++ b/benchmarks/bench_linear_float8.py @@ -147,7 +147,6 @@ def main( linear_float8 = Float8Linear.from_float( copy.deepcopy(linear_ref), - emulate=False, config=config, ) scaling_repr = linear_float8.scaling_repr() diff --git a/benchmarks/bench_multi_gpu.py b/benchmarks/bench_multi_gpu.py index b099d34..9e82cdb 100644 --- a/benchmarks/bench_multi_gpu.py +++ b/benchmarks/bench_multi_gpu.py @@ -79,7 +79,6 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32): if is_fp8: swap_linear_with_float8_linear( m, - emulate=False, config=config, ) return m diff --git a/float8_experimental/config.py b/float8_experimental/config.py index c063c78..e190b2f 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -69,6 +69,9 @@ class Float8LinearConfig: # This can cause a memory spike however so we keep this off by default. pad_inner_dim: bool = False + # If True, emulation is used instead of hardware accelerated gemm + emulate: bool = False + # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 782fa6b..6c1e3e0 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -166,8 +166,8 @@ def __init__(self, *args, **kwargs): ) # Amax scales should always be kept as float32. self.always_float32_buffers = set() - emulate = kwargs.pop("emulate", False) config = kwargs.pop("config") + emulate = config.emulate super().__init__(*args, **kwargs) # Defines the scaling behavior of input, weight, grad_output @@ -434,7 +434,6 @@ def extra_repr(self): def from_float( cls, mod, - emulate: bool = False, config: Optional[Float8LinearConfig] = None, ): """ @@ -442,7 +441,6 @@ def from_float( Args: mod (torch.nn.Linear): nn.Linear to convert - emulate (bool): whether to emulate fp8 matmul logic in float32 config (Optional[Float8LinearConfig]): configuration for conversion to float8 """ if config is None: @@ -452,7 +450,6 @@ def from_float( mod.in_features, mod.out_features, bias=False, - emulate=emulate, config=config, ) new_mod.weight = mod.weight diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 4e4ad3f..41365df 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -127,7 +127,6 @@ def post_order_traversal( def swap_linear_with_float8_linear( module: nn.Module, *, - emulate: bool = False, module_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None, config: Float8LinearConfig = None, ) -> Optional[nn.Module]: @@ -136,7 +135,6 @@ def swap_linear_with_float8_linear( Args: module: Module to modify. - emulate: If True, emulation is used instead of hardware accelerated gemm module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that that pass the filter function will be swapped. The inputs to the filter function are the FQN and module instance. @@ -149,7 +147,6 @@ def swap_linear_with_float8_linear( config = Float8LinearConfig() from_float = lambda m: Float8Linear.from_float( m, - emulate=emulate, config=config, ) return swap_linear_layers( diff --git a/test/test_base.py b/test/test_base.py index 9c102df..10d53b6 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -152,12 +152,10 @@ def _test_linear_impl( self, x, m_ref, - emulate: bool, config: Float8LinearConfig, ): m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), - emulate, config, ) for _ in range(2): @@ -264,11 +262,11 @@ def test_linear( cast_config_grad_output=Float8TensorCastConfig( scaling_type=scaling_type_grad_output ), + emulate=emulate, ) self._test_linear_impl( x, m_ref, - emulate, config, ) @@ -303,8 +301,9 @@ def test_autocast_outputs( cast_config_grad_output=Float8TensorCastConfig( scaling_type=TensorScalingType.DELAYED ), + emulate=emulate, ) - m = Float8Linear.from_float(copy.deepcopy(m_ref), emulate, config) + m = Float8Linear.from_float(copy.deepcopy(m_ref), config) # autocast off x = torch.randn(16, 32, device="cuda", dtype=linear_dtype) @@ -339,8 +338,8 @@ def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool): ) m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype) - config = Float8LinearConfig() - m = Float8Linear.from_float(copy.deepcopy(m), emulate, config) + config = Float8LinearConfig(emulate=emulate) + m = Float8Linear.from_float(copy.deepcopy(m), config) # Cast the module to dtype m = m.to(dtype=linear_dtype) @@ -389,10 +388,10 @@ def test_repr(self): cast_config_weight=Float8TensorCastConfig( scaling_type=TensorScalingType.DELAYED ), + emulate=True, ) m = Float8Linear.from_float( copy.deepcopy(m), - emulate=True, config=config, ) s = m.__repr__() @@ -604,7 +603,8 @@ class TestFloat8LinearUtils(unittest.TestCase): def test_swap_root_linear(self): for emulate in [True, False]: module = nn.Linear(3, 3) - module = swap_linear_with_float8_linear(module, emulate=emulate) + config = Float8LinearConfig(emulate=emulate) + module = swap_linear_with_float8_linear(module, config=config) self.assertIsInstance(module, Float8Linear) self.assertEqual(module.linear_mm_config.y.emulate, emulate) self.assertEqual(module.linear_mm_config.y.emulate, emulate) @@ -613,11 +613,12 @@ def test_swap_root_linear_with_children_raises(self): for emulate in [True, False]: module = nn.Linear(3, 3) module.child = nn.Sequential(nn.Linear(3, 3)) + config = Float8LinearConfig(emulate=emulate) with self.assertRaisesRegex( AssertionError, "Does not support a root nn.Linear with children", ): - swap_linear_with_float8_linear(module, emulate=emulate) + swap_linear_with_float8_linear(module, config=config) def test_swap_submodule_linears(self): class MLP(nn.Module): @@ -628,7 +629,8 @@ def __init__(self, dim: int): for emulate in [True, False]: model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3)) - model = swap_linear_with_float8_linear(model, emulate=emulate) + config = Float8LinearConfig(emulate=emulate) + model = swap_linear_with_float8_linear(model, config=config) self.assertIsInstance(model[0].lin1, Float8Linear) self.assertIsInstance(model[0].lin2, Float8Linear) self.assertIsInstance(model[1], Float8Linear) @@ -655,9 +657,10 @@ def module_filter_fn(fqn, mod): and mod.out_features % 16 == 0 ) + config = Float8LinearConfig(emulate=True) model = swap_linear_with_float8_linear( model, - emulate=True, + config=config, module_filter_fn=module_filter_fn, ) # in_features=8, out_features=32, 8 is less than 32. @@ -683,9 +686,10 @@ def __init__(self, dim: int): "0.lin2", "2.lin1", ] + config = Float8LinearConfig(emulate=True) model = swap_linear_with_float8_linear( model, - emulate=True, + config=config, module_filter_fn=module_filter_fn, ) self.assertTrue(type(model[0].lin1) is Float8Linear) diff --git a/test/test_compile.py b/test/test_compile.py index 3525585..cd2372d 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -36,7 +36,6 @@ def _test_compile_base( backend: str, fullgraph: bool, - emulate: bool, config: Float8LinearConfig, dtype: torch.dtype, ): @@ -50,7 +49,6 @@ def _test_compile_base( m_fp8 = Float8Linear.from_float( copy.deepcopy(m_ref), - emulate, config, ) @@ -95,11 +93,11 @@ def test_eager_only( cast_config_grad_output=Float8TensorCastConfig( scaling_type=scaling_type_grad_output ), + emulate=emulate, ) _test_compile_base( "eager", fullgraph, - emulate, config, dtype, ) @@ -133,11 +131,11 @@ def test_aot_eager( cast_config_grad_output=Float8TensorCastConfig( scaling_type=scaling_type_grad_output ), + emulate=emulate, ) _test_compile_base( "aot_eager", fullgraph, - emulate, config, dtype, ) @@ -171,11 +169,11 @@ def test_inductor( cast_config_grad_output=Float8TensorCastConfig( scaling_type=scaling_type_grad_output ), + emulate=emulate, ) _test_compile_base( "inductor", fullgraph, - emulate, config, dtype, ) @@ -315,11 +313,20 @@ def test_sync_amax_func_cuda_graph_success(): my_module = nn.Sequential( nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True) ).to("cuda") + config = Float8LinearConfig( + cast_config_input=Float8TensorCastConfig( + scaling_type=TensorScalingType.DELAYED + ), + cast_config_weight=Float8TensorCastConfig( + scaling_type=TensorScalingType.DELAYED + ), + cast_config_grad_output=Float8TensorCastConfig( + scaling_type=TensorScalingType.DELAYED + ), + ) swap_linear_with_float8_linear( my_module, - scaling_type_input=TensorScalingType.DELAYED, - scaling_type_weight=TensorScalingType.DELAYED, - scaling_type_grad_output=TensorScalingType.DELAYED, + config=config, ) inpt = torch.randn( 16, 16, device="cuda", dtype=torch.float32, requires_grad=True diff --git a/test/test_dtensor.py b/test/test_dtensor.py index 606b3a5..842e214 100644 --- a/test/test_dtensor.py +++ b/test/test_dtensor.py @@ -13,7 +13,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from float8_experimental.config import TensorScalingType +from float8_experimental import Float8LinearConfig from float8_experimental.float8_dynamic_utils import NoopFwToFloat8E5M2Bw from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear @@ -184,14 +184,15 @@ def _test_fp8_mlp_tensor_parallelism_base( # For now, only supports dynamic scaling of `x` and `dL_dY`. # TODO(future): add support for float8 all-gather with delayed scaling # for activations and gradients. + config = Float8LinearConfig(emulate=True) toy_model = ToyModel().to(device) - toy_model_fp8 = swap_linear_with_float8_linear(toy_model, emulate=True) + toy_model_fp8 = swap_linear_with_float8_linear(toy_model, config=config) tp_model = copy.deepcopy(toy_model) - tp_model = swap_linear_with_float8_linear(tp_model, emulate=True) + tp_model = swap_linear_with_float8_linear(tp_model, config=config) sp_model = copy.deepcopy(toy_model) - sp_model = swap_linear_with_float8_linear(sp_model, emulate=True) + sp_model = swap_linear_with_float8_linear(sp_model, config=config) # vanilla TP tp_model = parallelize_module( @@ -222,7 +223,7 @@ def _test_fp8_mlp_tensor_parallelism_base( # PrepareFloat8ModuleInput with specific submodule fqn sp_model2 = copy.deepcopy(toy_model) - sp_model2 = swap_linear_with_float8_linear(sp_model2, emulate=True) + sp_model2 = swap_linear_with_float8_linear(sp_model2, config=config) sp_model2 = parallelize_module( sp_model2, diff --git a/test/test_fsdp.py b/test/test_fsdp.py index 637be89..202e7e3 100644 --- a/test/test_fsdp.py +++ b/test/test_fsdp.py @@ -84,13 +84,14 @@ def fsdp_main(rank, world_size, args): ) config = Float8LinearConfig( cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight), + # TODO(future): delete this arg as it's always False + emulate=False, ) # Note: we only iterate over `scaling_type_weight` because FSDP only interacts # with weights. swap_linear_with_float8_linear( model_fp8, - emulate=False, config=config, ) diff --git a/test/test_fsdp2/test_fsdp2.py b/test/test_fsdp2/test_fsdp2.py index dacf911..fca34a4 100644 --- a/test/test_fsdp2/test_fsdp2.py +++ b/test/test_fsdp2/test_fsdp2.py @@ -185,8 +185,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool): # requirement to use a smaller activation size float8_linear_config = Float8LinearConfig( enable_fsdp_fp8_all_gather=enable_fsdp_fp8_all_gather, + emulate=True, ) - swap_linear_with_float8_linear(model, emulate=True, config=float8_linear_config) + swap_linear_with_float8_linear(model, config=float8_linear_config) model_unsharded_numel = sum(p.numel() for p in model.parameters()) model_sharded_numel = (model_unsharded_numel + 1) // 2 block_lin_weight_numel = 0 @@ -294,10 +295,10 @@ def test_weight_subclass_dynamic(self): module_fp32 = self.init_single_module() float8_linear_config = Float8LinearConfig( enable_fsdp_fp8_all_gather=True, + emulate=True, ) module = swap_linear_with_float8_linear( module_fp32, - emulate=True, config=float8_linear_config, ) self.assertIsInstance(module.weight, tensor_cls) @@ -311,7 +312,6 @@ def test_weight_subclass_dynamic(self): module = self.init_multi_module() module = swap_linear_with_float8_linear( module, - emulate=True, config=float8_linear_config, ) for param_name, param in module.named_parameters(): @@ -517,12 +517,13 @@ def test_bf16_mp_fp8_dynamic_multi_parity(self): # parameters, changing the numerics. module = self.init_multi_module() ref_module_bf16 = copy.deepcopy(module).to(torch.bfloat16) + float8_config = Float8LinearConfig(emulate=True) ref_module_bf16 = swap_linear_with_float8_linear( ref_module_bf16, - emulate=True, + config=float8_config, ) ref_module_fp32 = copy.deepcopy(module).cuda() - module = swap_linear_with_float8_linear(module, emulate=True) + module = swap_linear_with_float8_linear(module, config=float8_config) mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) for mlp in module: fully_shard(mlp, mp_policy=mp_policy) diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py index c23cfbe..2080782 100644 --- a/test/test_fsdp_compile.py +++ b/test/test_fsdp_compile.py @@ -66,6 +66,7 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): cast_config_grad_output=Float8TensorCastConfig( scaling_type=TensorScalingType.DELAYED ), + emulate=emulate, ) m = nn.Sequential( @@ -74,7 +75,6 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32): ) swap_linear_with_float8_linear( m, - emulate=emulate, config=config, ) return m diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py index 2260780..9b8e2e4 100644 --- a/test/test_numerics_integration.py +++ b/test/test_numerics_integration.py @@ -122,7 +122,6 @@ def test_encoder_fw_bw( ) swap_linear_with_float8_linear( model_fp8, - emulate=False, config=config, )