diff --git a/torchstore/utils.py b/torchstore/utils.py index 5536bde..bbc8109 100644 --- a/torchstore/utils.py +++ b/torchstore/utils.py @@ -4,12 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math import uuid from logging import getLogger from typing import List, Tuple, TYPE_CHECKING -import numpy as np - import torch from torchstore.constants import MONARCH_HOSTMESH_V1 @@ -119,8 +118,8 @@ def get_target_tensor_shape_and_offset( # Verify that local tensors can fill the target tensor, this verification is only necessary but not # sufficient to guarantee that the target tensor can be filled by local tensors. - local_tensor_total_size = sum([np.prod(shape) for shape in local_tensor_shapes]) - target_tensor_size = np.prod(target_shape) + local_tensor_total_size = sum([math.prod(shape) for shape in local_tensor_shapes]) + target_tensor_size = math.prod(target_shape) assert ( local_tensor_total_size >= target_tensor_size ), "Local tensor sizes doesn't match target tensor. "