Use backward- and forward-compatible code for efficient bfloat16 serialization
borzunov committed Apr 26, 2023
1 parent 06db5e3 commit 07e98ae
Showing 1 changed file with 9 additions and 10 deletions.
hivemind/compression/base.py (19 changes: 9 additions & 10 deletions)
@@ -87,12 +87,11 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: b
         dtype_name = str(tensor.dtype).lstrip("torch.")
         raw_data = tensor
         if tensor.dtype == torch.bfloat16:
-            if USE_LEGACY_BFLOAT16:
+            if USE_LEGACY_BFLOAT16:  # legacy mode: convert to fp32
                 raw_data = tensor.to(torch.float32)
-            else:
-                typed_storage = tensor.storage()
-                storage = typed_storage.untyped() if hasattr(typed_storage, "untyped") else typed_storage._untyped()
-                raw_data = torch.tensor(storage, dtype=torch.int8)
+            else:  # efficient mode: send bfloat16 data directly
+                # reinterpret_cast to an arbitrary 2-byte type supported by numpy
+                raw_data = tensor.view(torch.int16)

         return runtime_pb2.Tensor(
             compression=self.compression_type,
@@ -106,13 +105,13 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         shape = torch.Size(serialized_tensor.size)
         if serialized_tensor.dtype == "bfloat16":
             numel = shape.numel()
-            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:  # legacy mode: convert to fp32
+            if numel > 0 and len(serialized_tensor.buffer) // numel == 4:
                 array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
                 tensor = torch.as_tensor(array, dtype=torch.bfloat16)
-            else:  # efficient mode: send bfloat16 data directly
-                storage_type = torch.TypedStorage if hasattr(torch, "TypedStorage") else torch._TypedStorage
-                storage = storage_type.from_buffer(serialized_tensor.buffer, byte_order="little", dtype=torch.bfloat16)
-                tensor = torch.as_tensor(storage, dtype=torch.bfloat16)
+            else:
+                array = np.frombuffer(serialized_tensor.buffer, dtype=np.int16)
+                # reinterpret_cast from an arbitrary 2-byte type supported by numpy
+                tensor = torch.as_tensor(array).view(torch.bfloat16)
         else:
             array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
             tensor = torch.as_tensor(array)

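For context, here is a minimal, self-contained sketch of the round trip that the new code path performs; the serialize_bf16 and deserialize_bf16 helper names are illustrative, not part of hivemind's API. Since numpy has no bfloat16 dtype, the sender reinterprets the bfloat16 tensor as int16 before taking its bytes, and the receiver reinterprets the int16 buffer back as bfloat16. A buffer with 4 bytes per element is treated as an fp32 payload from a legacy peer.

import numpy as np
import torch


def serialize_bf16(tensor: torch.Tensor) -> bytes:
    # reinterpret the 2-byte bfloat16 elements as int16, which numpy understands
    return tensor.detach().view(torch.int16).numpy().tobytes()


def deserialize_bf16(buffer: bytes, shape: torch.Size) -> torch.Tensor:
    numel = shape.numel()
    if numel > 0 and len(buffer) // numel == 4:  # 4 bytes per element: legacy fp32 payload
        array = np.frombuffer(buffer, dtype=np.float32)
        return torch.as_tensor(array, dtype=torch.bfloat16).reshape(shape)
    array = np.frombuffer(buffer, dtype=np.int16)  # 2 bytes per element: efficient payload
    return torch.as_tensor(array).view(torch.bfloat16).reshape(shape)


x = torch.randn(3, 4, dtype=torch.bfloat16)
assert torch.equal(deserialize_bf16(serialize_bf16(x), x.shape), x)

Because both branches of the receiver dispatch purely on buffer size, new peers can keep decoding fp32 payloads from old peers while sending the compact 2-byte representation themselves, which is what makes the scheme backward- and forward-compatible.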