diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index bf3e6c018..602cc0091 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -54,6 +54,8 @@ from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
 
+logger = logging.getLogger(__name__)
+
 from . import fsdp_optim_utils as ou
 
 if TYPE_CHECKING:
@@ -64,6 +66,11 @@ else:
     enable_nccl_base_collectives = True
 
+if os.getenv("PERFORM_ALL_GATHER_IN_8_BITS", "0") == "1":
+    perform_all_gather_in_8_bits = True
+else:
+    perform_all_gather_in_8_bits = False
+
 try:
     import fairscale.experimental.nn.ssd_offload as ssd_offload
     from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
@@ -75,6 +82,18 @@
     pass
 
 
+try:
+    # Use fbgemm_gpu for the bf16 <-> fp8 conversion for now.
+    # This is an experiment; if it works well, we can copy those
+    # kernels into fairscale.
+    import fbgemm_gpu
+    dynamic_file_location = os.path.join(os.path.dirname(fbgemm_gpu.__file__), "fbgemm_gpu_py.so")
+    torch.ops.load_library(dynamic_file_location)
+    FBGEMM_FOUND = True
+except Exception:
+    FBGEMM_FOUND = False
+
+
 class TrainingState(Enum):
     """
     Simple enum to indicate what state FSDP is in. Used for asserting
@@ -518,6 +537,11 @@ def __init__(
                 if isinstance(m, FullyShardedDataParallel):
                     m._free_ssd_offload()
 
+        self.perform_all_gather_in_8_bits = perform_all_gather_in_8_bits and FBGEMM_FOUND
+        if self.perform_all_gather_in_8_bits:
+            logger.info("Performing FSDP all_gather in 8-bit precision.")
+
+
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
         while world_size % factor == 0 and world_size / factor > factor:
@@ -1986,14 +2010,7 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
                             alloc_storage_(p._full_param_padded, size=p_size)
                             output_tensor = p._full_param_padded
 
-                        # Fill output_tensor with (p.data for each shard in self.world_size)
-                        if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
-                            # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
-                            dist._all_gather_base(output_tensor, p_data, group=self.process_group)
-                        else:
-                            chunks = list(output_tensor.chunk(self.world_size))
-                            dist.all_gather(chunks, p_data, group=self.process_group)
-
+                        self._perform_all_gather(output_tensor, p_data)
                         # Set p.data = output_tensor (with padding trimmed)
                         update_p_data(output_tensor)
 
@@ -2006,6 +2023,63 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
             torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
         return output_tensors
 
+
+    def _convert_to_fp8(self, input_tensor, ebits=4, mbits=3, exponent_bias=7):
+        max_pos = (1 << ((1 << ebits) - 2 - exponent_bias)) * (2 - 2 ** (-mbits))
+        # A per-tensor scale is not communicated here; doing so would require
+        # sending the scale from each GPU and rescaling each gathered chunk
+        # accordingly.
+        #
+        # So for now, assume the tensor max is 1, which is not a bad assumption
+        # for most NN initializations. This can certainly be improved.
+        # tensor_max = input_tensor.detach().abs().max()
+        scale = max_pos
+        tensor_fp8 = torch.ops.fbgemm.FloatToHFP8Quantized(
+            input_tensor.data * scale,
+            ebits,
+            exponent_bias,
+            max_pos,
+        )
+        return tensor_fp8, scale
+
+    def _convert_from_fp8(self, fp8_tensor, scale, ebits=4, exponent_bias=7):
+        converted_back_tensor = torch.ops.fbgemm.HFP8QuantizedToFloat(
+            fp8_tensor.contiguous(),
+            ebits,
+            exponent_bias,
+        ) / (scale if scale else 1.0)
+        return converted_back_tensor
+
+    def _perform_all_gather(self, output_tensor, p_data):
+        if self.perform_all_gather_in_8_bits:
+            p_data_original = p_data
+            output_tensor_original = output_tensor
+            p_data, scale = self._convert_to_fp8(p_data_original)
+            output_tensor = torch.empty(
+                output_tensor_original.shape,
+                dtype=p_data.dtype,
+                device=output_tensor_original.device,
+            )
+
+        # Fill output_tensor with (p.data for each shard in self.world_size)
+        if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
+            # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
+            dist._all_gather_base(output_tensor, p_data, group=self.process_group)
+        else:
+            chunks = list(output_tensor.chunk(self.world_size))
+            dist.all_gather(chunks, p_data, group=self.process_group)
+
+        if self.perform_all_gather_in_8_bits:
+            # FBGEMM only provides fp32 <-> fp8 kernels, so dequantize into a
+            # temporary tensor and copy the result back into the original output buffer.
+            output_tensor = self._convert_from_fp8(output_tensor, scale)
+            output_tensor_original.copy_(output_tensor)
+
+
     @torch.no_grad()
     def _use_full_params(self) -> None:
         """
@@ -2072,6 +2146,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
             # Storage object and unshard it in-place. For now, just resize
             # the Storage to 0 to save memory.
             free_storage_(p._full_param_padded)
+            torch.cuda.current_stream().synchronize()
 
     def local_metadata_dict(self) -> Dict[str, Any]:
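The 8-bit path is opt-in: it is enabled via the `PERFORM_ALL_GATHER_IN_8_BITS=1` environment variable and silently falls back to the regular all_gather when `fbgemm_gpu` cannot be loaded. Below is a minimal, standalone sketch of the HFP8 round trip that `_convert_to_fp8` / `_convert_from_fp8` perform; the op names, argument order, and the `scale = max_pos` choice are taken from the patch, while the CUDA device, tensor shape, and printed error check are illustrative assumptions.

```python
import os

import fbgemm_gpu
import torch

# Load the fbgemm_gpu custom ops, mirroring the try/except block in the patch.
torch.ops.load_library(os.path.join(os.path.dirname(fbgemm_gpu.__file__), "fbgemm_gpu_py.so"))

ebits, mbits, exponent_bias = 4, 3, 7
max_pos = (1 << ((1 << ebits) - 2 - exponent_bias)) * (2 - 2 ** (-mbits))
scale = max_pos  # assumes the tensor max is ~1, as _convert_to_fp8 does

# A stand-in for a parameter shard, with values in [-1, 1] to match the scale assumption.
x = torch.randn(1024, device="cuda").clamp_(-1.0, 1.0)

# Quantize to HFP8 (stored as uint8), then dequantize back to fp32 and undo the scale.
x_fp8 = torch.ops.fbgemm.FloatToHFP8Quantized(x * scale, ebits, exponent_bias, max_pos)
x_back = torch.ops.fbgemm.HFP8QuantizedToFloat(x_fp8, ebits, exponent_bias) / scale

# With a 3-bit mantissa the round trip is lossy; the error should be small but nonzero.
print("max abs round-trip error:", (x - x_back).abs().max().item())
```

If the round-trip error turns out to be too large for real parameter tensors, that is the cue to replace the fixed `scale = max_pos` assumption with the commented-out per-tensor max (at the cost of communicating one scale per shard).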