Skip to content

Commit

Permalink
Fix styling + do not support KiB
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin committed May 23, 2024
1 parent 5d18005 commit 5708144
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 42 deletions.
64 changes: 31 additions & 33 deletions src/huggingface_hub/serialization/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def split_state_dict_into_shards_factory(
current_shard_size = 0
total_size = 0

max_shard_size = convert_file_size_to_int(max_shard_size)
if isinstance(max_shard_size, str):
max_shard_size = parse_size_to_int(max_shard_size)

for key, tensor in state_dict.items():
# when bnb serialization is used the weights in the state dict can be strings
Expand Down Expand Up @@ -171,45 +172,42 @@ def split_state_dict_into_shards_factory(
)


def convert_file_size_to_int(size: Union[int, str]):
SIZE_UNITS = {
"TB": 10**12,
"GB": 10**9,
"MB": 10**6,
"KB": 10**3,
}


def parse_size_to_int(size_as_str: str) -> int:
"""
Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
Parse a size expressed as a string with digits and a unit (like `"5MB"`) to an integer (in bytes).
Supported units are "TB", "GB", "MB", "KB".
Args:
size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
size_as_str (`str`): The size to convert, given as a number followed by a unit (like `"5MB"`).
Example:
```py
>>> convert_file_size_to_int("1MiB")
1048576
>>> parse_size_to_int("5MB")
5000000
```
"""
mem_size = -1
err_msg = (
f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')."
)
size_as_str = size_as_str.strip()

# Parse unit
unit = size_as_str[-2:].upper()
if unit not in SIZE_UNITS:
raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
multiplier = SIZE_UNITS[unit]

# Parse value
try:
if isinstance(size, int):
mem_size = size
elif size.upper().endswith("GIB"):
mem_size = int(float(size[:-3]) * (2**30))
elif size.upper().endswith("MIB"):
mem_size = int(float(size[:-3]) * (2**20))
elif size.upper().endswith("KIB"):
mem_size = int(float(size[:-3]) * (2**10))
elif size.upper().endswith("GB"):
int_size = int(float(size[:-2]) * (10**9))
mem_size = int_size // 8 if size.endswith("b") else int_size
elif size.upper().endswith("MB"):
int_size = int(float(size[:-2]) * (10**6))
mem_size = int_size // 8 if size.endswith("b") else int_size
elif size.upper().endswith("KB"):
int_size = int(float(size[:-2]) * (10**3))
mem_size = int_size // 8 if size.endswith("b") else int_size
except ValueError:
raise ValueError(err_msg)

if mem_size < 0:
raise ValueError(err_msg)
return mem_size
value = float(size_as_str[:-2].strip())
except ValueError as e:
raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e

return int(value * multiplier)
24 changes: 15 additions & 9 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

from huggingface_hub.serialization import split_state_dict_into_shards_factory
from huggingface_hub.serialization._base import convert_file_size_to_int
from huggingface_hub.serialization._base import parse_size_to_int
from huggingface_hub.serialization._numpy import get_tensor_size as get_tensor_size_numpy
from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow
from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch
Expand Down Expand Up @@ -126,11 +128,15 @@ def test_get_tensor_size_torch():
assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


def test_convert_file_size_to_int():
assert convert_file_size_to_int("1KiB") == 2**10
assert convert_file_size_to_int("1KB") == 10**3
assert convert_file_size_to_int("1MiB") == 2**20
assert convert_file_size_to_int("1MB") == 10**6
assert convert_file_size_to_int("1GiB") == 2**30
assert convert_file_size_to_int("1GB") == 10**9
assert convert_file_size_to_int("5GB") == 5 * 10**9
def test_parse_size_to_int():
assert parse_size_to_int("1KB") == 1 * 10**3
assert parse_size_to_int("2MB") == 2 * 10**6
assert parse_size_to_int("3GB") == 3 * 10**9
assert parse_size_to_int(" 10 KB ") == 10 * 10**3 # ok with whitespace
assert parse_size_to_int("20mb") == 20 * 10**6 # ok with lowercase

with pytest.raises(ValueError, match="Unit 'IB' not supported"):
parse_size_to_int("1KiB") # not a valid unit

with pytest.raises(ValueError, match="Could not parse the size value"):
parse_size_to_int("1ooKB") # not a float

0 comments on commit 5708144

Please sign in to comment.