From 4fde1e38a57011a209dd9e1f88990b701c8363d9 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 28 Jun 2024 15:33:13 -0700 Subject: [PATCH] [2/x]: fix numerics integration test and test delayed vs dynamic Summary: 1. the SAM test wasn't easy to use because it had real weights and hence required real data for useful testing, which is not convenient from an integration test. Switched to LLaMa FFN with random weights, and made all the thresholds tight to actually check numerics are close. 2. extended numerics test to check all combinations of delayed vs dynamic 3. to be able to do (2), extended the module swap utility to configure delayed vs dynamic on a model level, for now without an option to customize further Test Plan: ``` pytest test/test_numerics_integration.py -s -x ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/float8_linear_utils.py | 19 +- pyproject.toml | 1 - test/test_everything.sh | 2 +- test/test_numerics_integration.py | 191 +++++++++++++++++++++ test/test_sam.py | 78 --------- 5 files changed, 210 insertions(+), 81 deletions(-) create mode 100644 test/test_numerics_integration.py delete mode 100644 test/test_sam.py diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 6b7235a..4e6dcb2 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -114,6 +114,8 @@ def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], ) +# TODO(future PR): probably create a per-linear config which contains +# all of the options (emulate, scaling, etc) def swap_linear_with_float8_linear( module: nn.Module, module_cls: Type[nn.Module], @@ -121,6 +123,9 @@ def swap_linear_with_float8_linear( skip_fqn_list: Optional[List[str]] = None, emulate: bool = False, linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None, + scaling_type_x: TensorScalingType = TensorScalingType.DELAYED, + scaling_type_w: TensorScalingType = TensorScalingType.DELAYED, + scaling_type_dL_dY: TensorScalingType = TensorScalingType.DELAYED, ) -> nn.Module: """ Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances @@ -134,6 +139,9 @@ def swap_linear_with_float8_linear( emulate (bool): Whether to emulate the fp8 matmul logic in fp32. linear_layer_filter (Optional[Callable[[nn.Linear], bool]]): If specified, only the linear layers that pass the filter function will be swapped. + scaling_type_x (TensorScalingType): scaling type for `x` + scaling_type_w (TensorScalingType): scaling type for `w` + scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY` """ module_names_to_skip = set(skip_fqn_list or []) if isinstance(module, nn.Linear) and ( @@ -167,7 +175,16 @@ def post_order_traversal( assert ( parent_module is not None ), f"Linear root module should return early: {module}" - float8linear_module = module_cls.from_float(module, emulate=emulate) + if module_cls is Float8DynamicLinear: + float8linear_module = module_cls.from_float(module, emulate=emulate) + else: + float8linear_module = module_cls.from_float( + module, + emulate=emulate, + scaling_type_x=scaling_type_x, + scaling_type_w=scaling_type_w, + scaling_type_dL_dY=scaling_type_dL_dY, + ) setattr(parent_module, module_name, float8linear_module) post_order_traversal(root_module, "", None) diff --git a/pyproject.toml b/pyproject.toml index 858e53b..6ecc64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,6 @@ dependencies = [ [project.optional-dependencies] test = [ - "transformers==4.38.2", "pandas >= 2.0", "tqdm==4.66.2", "fire==0.5.0", diff --git a/test/test_everything.sh b/test/test_everything.sh index b989393..a4cf7ae 100755 --- a/test/test_everything.sh +++ b/test/test_everything.sh @@ -5,8 +5,8 @@ set -e IS_ROCM=$(rocm-smi --version || true) pytest test/test_base.py -pytest test/test_sam.py pytest test/test_compile.py +pytest test/test_numerics_integration.py # These tests do not work on ROCm yet if [ -z "$IS_ROCM" ] diff --git a/test/test_numerics_integration.py b/test/test_numerics_integration.py new file mode 100644 index 0000000..1d571de --- /dev/null +++ b/test/test_numerics_integration.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +# Tests LLaMa FeedForward numerics with float8 + +import copy +from typing import Optional + +import pytest + +import torch +import torch.nn as nn +import torch.nn.functional as F +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear +from float8_experimental.float8_linear import Float8Linear, TensorScalingType +from float8_experimental.float8_linear_utils import ( + linear_requires_sync, + LinearType, + swap_linear_with_float8_linear, + sync_float8_amax_and_scale_history, +) +from float8_experimental.float8_utils import compute_error, IS_ROCM + +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) + + +torch.manual_seed(0) + + +# copied from https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TestFloat8NumericsIntegrationTest: + @pytest.mark.parametrize( + "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + ) + @pytest.mark.parametrize( + "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + ) + @pytest.mark.parametrize( + "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC] + ) + @pytest.mark.parametrize("linear_cls", [Float8Linear, Float8DynamicLinear]) + @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") + @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") + def test_encoder_fw_bw( + self, + linear_cls, + scaling_type_x: TensorScalingType, + scaling_type_w: TensorScalingType, + scaling_type_dL_dY: TensorScalingType, + ): + linear_type = ( + LinearType.DELAYED if linear_cls == Float8Linear else LinearType.DYNAMIC + ) + if linear_type is LinearType.DYNAMIC: + # Only test one combination of scaling types, as they are a no-op + # for Float8DynamicLinear. It would be cleaner to split into two + # tests, but IMO not worth it since Float8DynamicLinear will be + # deleted soon + is_all_dynamic = ( + scaling_type_x is TensorScalingType.DYNAMIC + and scaling_type_w is TensorScalingType.DYNAMIC + and scaling_type_dL_dY is TensorScalingType.DYNAMIC + ) + if not is_all_dynamic: + pytest.skip() + + # TODO(later): maybe add float16 back if it becomes important + data_dtype = torch.bfloat16 + + # LLaMa 3 70B shapes + model_ref = ( + FeedForward( + dim=4096, + hidden_dim=16384, + multiple_of=1024, + ffn_dim_multiplier=1.3, + ) + .cuda() + .to(data_dtype) + ) + + # for now just test the encoder to simplify things + model_fp8 = copy.deepcopy(model_ref) + swap_linear_with_float8_linear( + model_fp8, + linear_cls, + emulate=False, + scaling_type_x=scaling_type_x, + scaling_type_w=scaling_type_w, + scaling_type_dL_dY=scaling_type_dL_dY, + ) + + lr = 0.01 + optim_ref = torch.optim.SGD(model_ref.parameters(), lr=lr) + optim_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr) + + # Note: you need two different inputs to properly test numerics + # of delayed scaling, because the first time around the initialization + # logic of delayed scaling behaves as dynamic scaling + # TODO(future): also make unit tests do this properly + shape = (1, 8192, 4096) + data1 = torch.randn(*shape, device="cuda", dtype=data_dtype) + data2 = torch.randn(*shape, device="cuda", dtype=data_dtype) + + model_ref(data1).sum().backward() + # zero out grads without stepping, since we just want to compare grads + # of the second datum + optim_ref.zero_grad() + model_ref_out = model_ref(data2) + model_ref_out.sum().backward() + + if linear_requires_sync( + linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY + ): + sync_float8_amax_and_scale_history(model_fp8) + model_fp8(data1).sum().backward() + # zero out grads without stepping, since we just want to compare grads + # of the second datum + optim_fp8.zero_grad() + if linear_requires_sync( + linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY + ): + sync_float8_amax_and_scale_history(model_fp8) + model_fp8_out = model_fp8(data2) + model_fp8_out.sum().backward() + + out_sqnr = compute_error(model_ref_out, model_fp8_out) + assert out_sqnr > 20.0 + + ref_name_to_grad = { + name: param.grad for name, param in model_ref.named_parameters() + } + + grad_sqnr_threshold = 20.0 + + for name, param in model_fp8.named_parameters(): + ref_grad = ref_name_to_grad[name] + cur_grad = param.grad + sqnr = compute_error(ref_grad, cur_grad) + assert sqnr > grad_sqnr_threshold + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_sam.py b/test/test_sam.py deleted file mode 100644 index 9341241..0000000 --- a/test/test_sam.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. -# Tests SAM with real weights with float8 -# if we want finetuning later, we can use -# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb - -import copy - -import pytest - -import torch -from float8_experimental.float8_dynamic_linear import Float8DynamicLinear -from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_linear_utils import ( - swap_linear_with_float8_linear, - sync_float8_amax_and_scale_history, -) -from float8_experimental.float8_utils import compute_error, IS_ROCM -from transformers import SamModel - -is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0) - - -torch.manual_seed(0) - - -class TestFloat8SAMIntegrationTest: - @pytest.mark.parametrize("data_dtype", [torch.float16, torch.bfloat16]) - @pytest.mark.parametrize("linear_type", [Float8Linear, Float8DynamicLinear]) - @pytest.mark.skipif(not is_H100, reason="requires H100 GPU") - @pytest.mark.skipif(IS_ROCM, reason="test doesn't currently work on the ROCm stack") - def test_encoder_fw_bw(self, data_dtype, linear_type): - model = SamModel.from_pretrained("facebook/sam-vit-base").to(data_dtype).cuda() - # print(model) - - # for now just test the encoder to simplify things - encoder_ref = model.vision_encoder - encoder_fp8 = copy.deepcopy(encoder_ref) - swap_linear_with_float8_linear(encoder_fp8, linear_type, emulate=False) - - # an image - # Note: bsz==4 or a larger power of 2 for this model is needed to - # ensure all matmuls have arguments with dimensions divisible by 16 - data = torch.randn(4, 3, 1024, 1024).to(data_dtype).cuda() - - encoder_ref_out = encoder_ref(data) - last_hidden_ref = encoder_ref_out.last_hidden_state - last_hidden_ref.max().backward() - - sync_float8_amax_and_scale_history(encoder_fp8) - encoder_fp8_out = encoder_fp8(data) - last_hidden_fp8 = encoder_fp8_out.last_hidden_state - last_hidden_fp8.max().backward() - - hidden_sqnr = compute_error(last_hidden_ref, last_hidden_fp8) - assert hidden_sqnr > 20.0 - - ref_name_to_grad = { - name: param.grad for name, param in encoder_ref.named_parameters() - } - - # Delayed scaling has less performant numerics - fudge_factor = 7.0 if linear_type == Float8Linear else 1.0 - sqnr_threshold = -1.0 if data_dtype == torch.float16 else -4 - sqnr_threshold *= fudge_factor - - for name, param in encoder_fp8.named_parameters(): - ref_grad = ref_name_to_grad[name] - cur_grad = param.grad - sqnr = compute_error(ref_grad, cur_grad) - assert sqnr > sqnr_threshold - - -if __name__ == "__main__": - pytest.main([__file__])