diff --git a/src/huggingface_hub/serialization/_base.py b/src/huggingface_hub/serialization/_base.py
index 23d82145cf..e16e4a8137 100644
--- a/src/huggingface_hub/serialization/_base.py
+++ b/src/huggingface_hub/serialization/_base.py
@@ -89,7 +89,8 @@ def split_state_dict_into_shards_factory(
     current_shard_size = 0
     total_size = 0
 
-    max_shard_size = convert_file_size_to_int(max_shard_size)
+    if isinstance(max_shard_size, str):
+        max_shard_size = parse_size_to_int(max_shard_size)
 
     for key, tensor in state_dict.items():
         # when bnb serialization is used the weights in the state dict can be strings
@@ -171,45 +172,52 @@ def split_state_dict_into_shards_factory(
     )
 
 
-def convert_file_size_to_int(size: Union[int, str]):
+SIZE_UNITS = {
+    "TB": 10**12,
+    "GB": 10**9,
+    "MB": 10**6,
+    "KB": 10**3,
+}
+
+
+def parse_size_to_int(size_as_str: str) -> int:
     """
-    Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
+    Parse a size expressed as a string with digits and unit (like `"5MB"`) to an integer (in bytes).
+
+    Supported units are "TB", "GB", "MB", "KB" (powers of 10, matched case-insensitively).
 
     Args:
-        size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
+        size_as_str (`str`): The size to convert, as a number followed by a unit. Surrounding whitespace is ignored.
+
+    Returns:
+        `int`: The size in bytes. Fractional byte counts are truncated.
+
+    Raises:
+        `ValueError`: If the unit is not supported, or if the value is not a finite, non-negative number.
 
     Example:
 
     ```py
-    >>> convert_file_size_to_int("1MiB")
-    1048576
+    >>> parse_size_to_int("5MB")
+    5000000
     ```
     """
-    mem_size = -1
-    err_msg = (
-        f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')."
-    )
+    size_as_str = size_as_str.strip()
+
+    # Parse unit
+    unit = size_as_str[-2:].upper()
+    if unit not in SIZE_UNITS:
+        raise ValueError(f"Unit '{unit}' not supported. Supported units are TB, GB, MB, KB. Got '{size_as_str}'.")
+    multiplier = SIZE_UNITS[unit]
+
+    # Parse value
     try:
-        if isinstance(size, int):
-            mem_size = size
-        elif size.upper().endswith("GIB"):
-            mem_size = int(float(size[:-3]) * (2**30))
-        elif size.upper().endswith("MIB"):
-            mem_size = int(float(size[:-3]) * (2**20))
-        elif size.upper().endswith("KIB"):
-            mem_size = int(float(size[:-3]) * (2**10))
-        elif size.upper().endswith("GB"):
-            int_size = int(float(size[:-2]) * (10**9))
-            mem_size = int_size // 8 if size.endswith("b") else int_size
-        elif size.upper().endswith("MB"):
-            int_size = int(float(size[:-2]) * (10**6))
-            mem_size = int_size // 8 if size.endswith("b") else int_size
-        elif size.upper().endswith("KB"):
-            int_size = int(float(size[:-2]) * (10**3))
-            mem_size = int_size // 8 if size.endswith("b") else int_size
-    except ValueError:
-        raise ValueError(err_msg)
-
-    if mem_size < 0:
-        raise ValueError(err_msg)
-    return mem_size
+        value = float(size_as_str[:-2].strip())
+    except ValueError as e:
+        raise ValueError(f"Could not parse the size value from '{size_as_str}': {e}") from e
+
+    # NaN fails every comparison, so this single check rejects NaN, infinities and negative values.
+    if not 0 <= value < float("inf"):
+        raise ValueError(f"Size must be a finite, non-negative number. Got '{size_as_str}'.")
+
+    return int(value * multiplier)
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 4eecbe997d..47a78d5e2e 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -1,5 +1,7 @@
+import pytest
+
 from huggingface_hub.serialization import split_state_dict_into_shards_factory
-from huggingface_hub.serialization._base import convert_file_size_to_int
+from huggingface_hub.serialization._base import parse_size_to_int
 from huggingface_hub.serialization._numpy import get_tensor_size as get_tensor_size_numpy
 from huggingface_hub.serialization._tensorflow import get_tensor_size as get_tensor_size_tensorflow
 from huggingface_hub.serialization._torch import get_tensor_size as get_tensor_size_torch
@@ -126,11 +128,21 @@ def test_get_tensor_size_torch():
     assert get_tensor_size_torch(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2
 
 
-def test_convert_file_size_to_int():
-    assert convert_file_size_to_int("1KiB") == 2**10
-    assert convert_file_size_to_int("1KB") == 10**3
-    assert convert_file_size_to_int("1MiB") == 2**20
-    assert convert_file_size_to_int("1MB") == 10**6
-    assert convert_file_size_to_int("1GiB") == 2**30
-    assert convert_file_size_to_int("1GB") == 10**9
-    assert convert_file_size_to_int("5GB") == 5 * 10**9
+def test_parse_size_to_int():
+    assert parse_size_to_int("1KB") == 1 * 10**3
+    assert parse_size_to_int("2MB") == 2 * 10**6
+    assert parse_size_to_int("3GB") == 3 * 10**9
+    assert parse_size_to_int(" 10 KB ") == 10 * 10**3  # ok with whitespace
+    assert parse_size_to_int("20mb") == 20 * 10**6  # ok with lowercase
+
+    with pytest.raises(ValueError, match="Unit 'IB' not supported"):
+        parse_size_to_int("1KiB")  # not a valid unit
+
+    with pytest.raises(ValueError, match="Could not parse the size value"):
+        parse_size_to_int("1ooKB")  # not a float
+
+    with pytest.raises(ValueError, match="must be a finite, non-negative number"):
+        parse_size_to_int("-5MB")  # negative sizes are rejected
+
+    with pytest.raises(ValueError, match="must be a finite, non-negative number"):
+        parse_size_to_int("infGB")  # non-finite sizes are rejected