8-bit all_gather #1105

Open · wants to merge 2 commits into base: ngoyal_bf16_changes
91 changes: 83 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -54,6 +54,8 @@
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_

logger = logging.getLogger(__name__)

from . import fsdp_optim_utils as ou

if TYPE_CHECKING:
@@ -64,6 +66,11 @@
else:
    enable_nccl_base_collectives = True

if os.getenv("PERFORM_ALL_GATHER_IN_8_BITS", "0") == "1":
    perform_all_gather_in_8_bits = True
else:
    perform_all_gather_in_8_bits = False
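
The 8-bit path is gated by an environment variable that is read at import time. A minimal usage sketch (hypothetical driver code, not part of this diff; the flag must be set before fairscale is imported):

# Hypothetical usage sketch, not part of the diff: the flag is read when the
# module above is imported, so it must be set in the environment first.
import os

os.environ["PERFORM_ALL_GATHER_IN_8_BITS"] = "1"

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP  # flag picked up here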

try:
import fairscale.experimental.nn.ssd_offload as ssd_offload
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter
@@ -75,6 +82,18 @@
pass


try:
    # Use fbgemm_gpu for the bf16 <-> fp8 conversion for now.
    # This is experimental: if it works well, we may copy those kernels
    # into fairscale proper.
    import fbgemm_gpu

    dynamic_file_location = os.path.join(os.path.dirname(fbgemm_gpu.__file__), "fbgemm_gpu_py.so")
    torch.ops.load_library(dynamic_file_location)
    FBGEMM_FOUND = True
except Exception:
    FBGEMM_FOUND = False
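
Because the fp8 kernels are loaded dynamically from fbgemm_gpu_py.so, a quick sanity check (hypothetical, not part of this diff) can confirm the two ops used further down were actually registered:

# Hypothetical sanity check: looking up an unregistered op raises, so this fails
# loudly if the library load above picked up a build without the HFP8 kernels.
if FBGEMM_FOUND:
    _ = torch.ops.fbgemm.FloatToHFP8Quantized
    _ = torch.ops.fbgemm.HFP8QuantizedToFloat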


class TrainingState(Enum):
"""
Simple enum to indicate what state FSDP is in. Used for asserting
@@ -518,6 +537,11 @@ def __init__(
if isinstance(m, FullyShardedDataParallel):
m._free_ssd_offload()

        self.perform_all_gather_in_8_bits = perform_all_gather_in_8_bits and FBGEMM_FOUND
        if self.perform_all_gather_in_8_bits:
            logger.info("Performing FSDP all_gather in 8 bit precision.")

def _get_gradient_predivide_factor(self, world_size: int) -> float:
factor: int = 1
while world_size % factor == 0 and world_size / factor > factor:
Expand Down Expand Up @@ -1986,14 +2010,7 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
alloc_storage_(p._full_param_padded, size=p_size)
output_tensor = p._full_param_padded

# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p_data, group=self.process_group)
else:
chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p_data, group=self.process_group)

self._perform_all_gather(output_tensor, p_data)
# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)

@@ -2006,6 +2023,63 @@ def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors


    def _convert_to_fp8(self, input_tensor, ebits=4, mbits=3, exponent_bias=7):
        # Largest value representable in this HFP8 configuration
        # (240.0 for the defaults ebits=4, mbits=3, bias=7).
        max_pos = (1 << ((1 << ebits) - 2 - exponent_bias)) * (2 - 2 ** (-mbits))
        # The scale is not communicated between ranks. It could be, but that would
        # require gathering a per-GPU scale and rescaling each chunk of the gathered
        # tensor accordingly.
        #
        # So for now we assume the tensor max is 1, which is a reasonable assumption
        # for most NN initializations. This can surely be improved.
        # tensor_max = input_tensor.detach().abs().max()
        scale = max_pos
        tensor_fp8 = torch.ops.fbgemm.FloatToHFP8Quantized(
            input_tensor.data * scale,
            ebits,
            exponent_bias,
            max_pos,
        )
        return tensor_fp8, scale
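
For the defaults used here (ebits=4, mbits=3, exponent_bias=7) the formula evaluates to max_pos = 2^(16 - 2 - 7) * (2 - 2^-3) = 128 * 1.875 = 240.0. A standalone check of that arithmetic (illustrative only, not part of the diff):

# Illustrative only: the max_pos expression above, evaluated for the default arguments.
ebits, mbits, exponent_bias = 4, 3, 7
max_pos = (1 << ((1 << ebits) - 2 - exponent_bias)) * (2 - 2 ** (-mbits))
assert max_pos == 240.0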

    def _convert_from_fp8(self, fp8_tensor, scale, ebits=4, exponent_bias=7):
        # Dequantize with the FBGEMM kernel, then undo the scaling applied in _convert_to_fp8.
        converted_back_tensor = torch.ops.fbgemm.HFP8QuantizedToFloat(
            fp8_tensor.contiguous(),
            ebits,
            exponent_bias,
        ) / (scale if scale else 1.0)
        return converted_back_tensor
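
A round trip through the two helpers (hypothetical usage, not part of the diff; `fsdp_module` stands in for a FullyShardedDataParallel instance and fbgemm_gpu is assumed to be importable, i.e. FBGEMM_FOUND is True):

# Hypothetical round trip through the quantize/dequantize helpers.
x = torch.randn(1024, device="cuda").clamp_(-1.0, 1.0)  # respects the "max is 1" assumption
x_fp8, scale = fsdp_module._convert_to_fp8(x)
x_back = fsdp_module._convert_from_fp8(x_fp8, scale)
# x_back approximates x up to HFP8 quantization error.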

    def _perform_all_gather(self, output_tensor: torch.Tensor, p_data: torch.Tensor) -> None:
        if self.perform_all_gather_in_8_bits:
            # Quantize the local shard to HFP8 and gather into a temporary 8-bit buffer.
            p_data_original = p_data
            output_tensor_original = output_tensor
            p_data, scale = self._convert_to_fp8(p_data_original)
            output_tensor = torch.empty(
                output_tensor_original.shape,
                dtype=p_data.dtype,
                device=output_tensor_original.device,
            )

        # Fill output_tensor with (p.data for each shard in self.world_size).
        if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
            # Newer versions of PyTorch have _all_gather_base, which is faster than
            # chunk followed by all_gather.
            dist._all_gather_base(output_tensor, p_data, group=self.process_group)
        else:
            chunks = list(output_tensor.chunk(self.world_size))
            dist.all_gather(chunks, p_data, group=self.process_group)

        if self.perform_all_gather_in_8_bits:
            # Dequantize and copy back into the original buffer: FBGEMM only provides
            # fp32 <-> fp8 kernels, so the dequantized result is fp32 and copy_() casts
            # it back to the dtype of output_tensor_original.
            output_tensor = self._convert_from_fp8(output_tensor, scale)
            output_tensor_original.copy_(output_tensor)

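The comment in _convert_to_fp8 notes that a per-GPU scale could be exchanged instead of assuming the tensor max is 1. A rough sketch of what that exchange might look like (hypothetical, not part of this PR; it costs one extra single-element-per-rank gather and assumes _all_gather_base is available, as in the code above, with dist.all_gather as the fallback otherwise):

    def _gather_scales(self, local_scale):
        # Rough sketch, not part of this PR: gather one fp32 scale per rank so each
        # chunk of the dequantized tensor could be rescaled by its producer's scale.
        local = torch.full((1,), local_scale, dtype=torch.float32, device="cuda")
        scales = torch.empty(self.world_size, dtype=torch.float32, device="cuda")
        dist._all_gather_base(scales, local, group=self.process_group)
        return scales  # scales[r] would divide the r-th chunk of the gathered tensor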

@torch.no_grad()
def _use_full_params(self) -> None:
"""
@@ -2072,6 +2146,7 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
# Storage object and unshard it in-place. For now, just resize
# the Storage to 0 to save memory.
free_storage_(p._full_param_padded)

torch.cuda.current_stream().synchronize()

def local_metadata_dict(self) -> Dict[str, Any]: