Skip to content

Commit

Permalink
Fix using .lstrip() in hivemind.compression (#578)
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Jul 21, 2023
1 parent ec1d7fe commit da130cd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion hivemind/compression/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class NoCompression(CompressionBase):
def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
tensor = tensor.detach()
shape = tensor.shape
dtype_name = str(tensor.dtype).lstrip("torch.")
dtype_name = str(tensor.dtype).replace("torch.", "")
raw_data = tensor
if tensor.dtype == torch.bfloat16:
if USE_LEGACY_BFLOAT16: # legacy mode: convert to fp32
Expand Down
2 changes: 1 addition & 1 deletion hivemind/compression/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def quantize(

def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
tensor = tensor.detach()
dtype_name = str(tensor.dtype).lstrip("torch.")
dtype_name = str(tensor.dtype).replace("torch.", "")
if tensor.dtype == torch.bfloat16:
tensor = tensor.to(torch.float32)

Expand Down

0 comments on commit da130cd

Please sign in to comment.