diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index c2193b9..cbf992e 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -191,10 +191,40 @@ def swap_linear_with_float8_linear(
     skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
     linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    scaling_type_x: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_w: TensorScalingType = TensorScalingType.DELAYED,
+    scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED,
 ) -> Optional[nn.Module]:
+    """
+    Swaps `torch.nn.Linear` in `module` with `Float8Linear` or `Float8DynamicLinear`.
+
+    Args:
+        module: Module to modify.
+        module_cls: `Float8Linear` or `Float8DynamicLinear`.
+        skip_fqn_list: If specified, a list of module FQNs to skip.
+        emulate: If True, emulation is used instead of hardware-accelerated gemms.
+        linear_layer_filter: If specified, only the linear layers
+            that pass the filter function will be swapped.
+        scaling_type_x (TensorScalingType): scaling type for `x`
+        scaling_type_w (TensorScalingType): scaling type for `w`
+        scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
+
+    Returns:
+        nn.Module: The modified module with swapped linear layers.
+    """
+    if module_cls is Float8DynamicLinear:
+        from_float = lambda m: module_cls.from_float(m, emulate=emulate)
+    else:
+        from_float = lambda m: module_cls.from_float(
+            m,
+            emulate=emulate,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
     return swap_linear_layers(
         module,
-        lambda m: module_cls.from_float(m, emulate=emulate),
+        from_float,
         skip_fqn_list=skip_fqn_list,
         linear_layer_filter=linear_layer_filter,
     )
diff --git a/pyproject.toml b/pyproject.toml
index addd522..b0ee7bb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
 
 [project.optional-dependencies]
 test = [
-    "transformers==4.38.2",
     "pandas >= 2.0",
     "tqdm==4.66.2",
     "fire==0.5.0",
diff --git a/test/test_everything.sh b/test/test_everything.sh
index ada305d..5eeb17c 100755
--- a/test/test_everything.sh
+++ b/test/test_everything.sh
@@ -5,9 +5,9 @@ set -e
 IS_ROCM=$(rocm-smi --version || true)
 
 pytest test/test_base.py
-pytest test/test_sam.py
 pytest test/test_compile.py
 pytest test/test_inference_flows.py
+pytest test/test_numerics_integration.py
 
 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]
diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py
new file mode 100644
index 0000000..1d571de
--- /dev/null
+++ b/test/test_numerics_integration.py
@@ -0,0 +1,191 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Tests LLaMa FeedForward numerics with float8
+
+import copy
+from typing import Optional
+
+import pytest
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    LinearType,
+    swap_linear_with_float8_linear,
+    sync_float8_amax_and_scale_history,
+)
+from float8_experimental.float8_utils import compute_error, IS_ROCM
+
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
+
+
+torch.manual_seed(0)
+
+
+# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
+class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+    def init_weights(self, init_std: float):
+        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
+        for linear in (self.w2, self.w3):
+            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
+
+
+class TestFloat8NumericsIntegrationTest:
+    @pytest.mark.parametrize(
+        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    )
+    @pytest.mark.parametrize(
+        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    )
+    @pytest.mark.parametrize(
+        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    )
+    @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear])
+    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
+    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
+    def test_encoder_fw_bw(
+        self,
+        linear_cls,
+        scaling_type_x: TensorScalingType,
+        scaling_type_w: TensorScalingType,
+        scaling_type_dL_dY: TensorScalingType,
+    ):
+        linear_type = (
+            LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC
+        )
+        if linear_type is LinearType.DYNAMIC:
+            # Only test one combination of scaling types, as they are a no-op
+            # for Float8DynamicLinear. It would be cleaner to split into two
+            # tests, but it is not worth it since Float8DynamicLinear will be
+            # deleted soon.
+            is_all_dynamic = (
+                scaling_type_x is TensorScalingType.DYNAMIC
+                and scaling_type_w is TensorScalingType.DYNAMIC
+                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
+            )
+            if not is_all_dynamic:
+                pytest.skip()
+
+        # TODO(later): maybe add float16 back if it becomes important
+        data_dtype = torch.bfloat16
+
+        # LLaMa 3 8B shapes
+        model_ref = (
+            FeedForward(
+                dim=4096,
+                hidden_dim=16384,
+                multiple_of=1024,
+                ffn_dim_multiplier=1.3,
+            )
+            .cuda()
+            .to(data_dtype)
+        )
+
+        # create a float8 copy of the reference model
+        model_fp8 = copy.deepcopy(model_ref)
+        swap_linear_with_float8_linear(
+            model_fp8,
+            linear_cls,
+            emulate=False,
+            scaling_type_x=scaling_type_x,
+            scaling_type_w=scaling_type_w,
+            scaling_type_dL_dY=scaling_type_dL_dY,
+        )
+
+        lr = 0.01
+        optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr)
+        optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)
+
+        # Note: you need two different inputs to properly test numerics
+        # of delayed scaling, because the first time around the initialization
+        # logic of delayed scaling behaves like dynamic scaling
+        # TODO(future): also make unit tests do this properly
+        shape = (1, 8192, 4096)
+        data1 = torch.randn(*shape, device="cuda", dtype=data_dtype)
+        data2 = torch.randn(*shape, device="cuda", dtype=data_dtype)
+
+        model_ref(data1).sum().backward()
+        # zero out grads without stepping, since we just want to compare grads
+        # of the second datum
+        optim_ref.zero_grad()
+        model_ref_out = model_ref(data2)
+        model_ref_out.sum().backward()
+
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
+            sync_float8_amax_and_scale_history(model_fp8)
+        model_fp8(data1).sum().backward()
+        # zero out grads without stepping, since we just want to compare grads
+        # of the second datum
+        optim_fp8.zero_grad()
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
+            sync_float8_amax_and_scale_history(model_fp8)
+        model_fp8_out = model_fp8(data2)
+        model_fp8_out.sum().backward()
+
+        out_sqnr = compute_error(model_ref_out, model_fp8_out)
+        assert out_sqnr > 20.0
+
+        ref_name_to_grad = {
+            name: param.grad for name, param in model_ref.named_parameters()
+        }
+
+        grad_sqnr_threshold = 20.0
+
+        for name, param in model_fp8.named_parameters():
+            ref_grad = ref_name_to_grad[name]
+            cur_grad = param.grad
+            sqnr = compute_error(ref_grad, cur_grad)
+            assert sqnr > grad_sqnr_threshold
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/test/test_sam.py b/test/test_sam.py
deleted file mode 100644
index 9341241..0000000
--- a/test/test_sam.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD 3-Clause license found in the
-# LICENSE file in the root directory of this source tree.
-# Tests SAM with real weights with float8
-# if we want finetuning later, we can use
-# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb
-
-import copy
-
-import pytest
-
-import torch
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_linear_utils import (
-    swap_linear_with_float8_linear,
-    sync_float8_amax_and_scale_history,
-)
-from float8_experimental.float8_utils import compute_error, IS_ROCM
-from transformers import SamModel
-
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-
-
-torch.manual_seed(0)
-
-
-class TestFloat8SAMIntegrationTest:
-    @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16])
-    @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear])
-    @pytest.mark.skipif(not is_H100, reason="requires H100 GPU")
-    @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack")
-    def test_encoder_fw_bw(self, data_dtype, linear_type):
-        model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda()
-        # print(model)
-
-        # for now just test the encoder to simplify things
-        encoder_ref = model.vision_encoder
-        encoder_fp8 = copy.deepcopy(encoder_ref)
-        swap_linear_with_float8_linear(encoder_fp8, linear_type, emulate=False)
-
-        # an image
-        # Note: bsz==4 or a larger power of 2 for this model is needed to
-        # ensure all matmuls have arguments with dimensions divisible by 16
-        data = torch.randn(4, 3, 1024, 1024).to(data_dtype).cuda()
-
-        encoder_ref_out = encoder_ref(data)
-        last_hidden_ref = encoder_ref_out.last_hidden_state
-        last_hidden_ref.max().backward()
-
-        sync_float8_amax_and_scale_history(encoder_fp8)
-        encoder_fp8_out = encoder_fp8(data)
-        last_hidden_fp8 = encoder_fp8_out.last_hidden_state
-        last_hidden_fp8.max().backward()
-
-        hidden_sqnr = compute_error(last_hidden_ref, last_hidden_fp8)
-        assert hidden_sqnr > 20.0
-
-        ref_name_to_grad = {
-            name: param.grad for name, param in encoder_ref.named_parameters()
-        }
-
-        # Delayed scaling has less performant numerics
-        fudge_factor = 7.0 if linear_type == Float8Linear else 1.0
-        sqnr_threshold = -1.0 if data_dtype == torch.float16 else -4
-        sqnr_threshold *= fudge_factor
-
-        for name, param in encoder_fp8.named_parameters():
-            ref_grad = ref_name_to_grad[name]
-            cur_grad = param.grad
-            sqnr = compute_error(ref_grad, cur_grad)
-            assert sqnr > sqnr_threshold
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
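
For reference, a minimal usage sketch of the new per-tensor scaling-type arguments (assuming only the
`swap_linear_with_float8_linear`, `sync_float8_amax_and_scale_history`, `Float8Linear`, and
`TensorScalingType` APIs shown in this diff; the toy model, shapes, and dtype are illustrative only):

    import torch
    import torch.nn as nn

    from float8_experimental.float8_linear import Float8Linear, TensorScalingType
    from float8_experimental.float8_linear_utils import (
        swap_linear_with_float8_linear,
        sync_float8_amax_and_scale_history,
    )

    # any module containing nn.Linear children; bfloat16 on an H100-class GPU
    model = nn.Sequential(
        nn.Linear(4096, 4096, bias=False),
        nn.ReLU(),
        nn.Linear(4096, 4096, bias=False),
    ).cuda().to(torch.bfloat16)

    # swap nn.Linear -> Float8Linear, mixing delayed and dynamic scaling per tensor
    swap_linear_with_float8_linear(
        model,
        Float8Linear,
        emulate=False,  # set True to emulate instead of using float8 hardware gemms
        scaling_type_x=TensorScalingType.DELAYED,
        scaling_type_w=TensorScalingType.DELAYED,
        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    )

    x = torch.randn(1, 8192, 4096, device="cuda", dtype=torch.bfloat16)

    # when any tensor uses delayed scaling, sync the amax/scale history before the
    # forward pass, as done in test_numerics_integration.py above
    sync_float8_amax_and_scale_history(model)
    model(x).sum().backward()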