Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions torchrec/distributed/planner/planners.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
ParameterConstraints,
Partitioner,
PerfModel,
PlanDebugStats,
PlannerError,
PlannerErrorType,
Proposer,
Expand Down Expand Up @@ -528,6 +529,10 @@ def plan(
enumerator=self._enumerator,
sharders=sharders,
debug=self._debug,
debug_stats=PlanDebugStats(
planner_type=self.__class__.__name__,
timeout_seconds=self._timeout_seconds,
),
)
return sharding_plan
else:
Expand Down
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
Enumerator,
ParameterConstraints,
Perf,
PlanDebugStats,
ShardingOption,
Stats,
Storage,
Expand Down Expand Up @@ -160,6 +161,7 @@ def log(
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
enumerator: Optional[Enumerator] = None,
debug: bool = True,
debug_stats: Optional[PlanDebugStats] = None,
) -> None:
"""
Logs stats for a given sharding plan.
Expand Down Expand Up @@ -1138,5 +1140,6 @@ def log(
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
enumerator: Optional[Enumerator] = None,
debug: bool = True,
debug_stats: Optional[PlanDebugStats] = None,
) -> None:
pass
79 changes: 79 additions & 0 deletions torchrec/distributed/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,22 @@ def __hash__(self) -> int:
)
)

def storage_hash(self) -> int:
    """
    Hash needed to preserve sharding option uniquely based on input before
    planning. This is needed to restore sharding option from the loaded plan.
    Hash is computed based on the following attributes:
        - fqn
        - sharding_type
        - compute_kernel
        - cache_load_factor
        - num_shards
    """
    # BLAKE2b (unlike the builtin hash()) is deterministic across processes
    # and platforms. A 7-byte digest yields a non-negative 56-bit integer,
    # which always fits in a signed 64-bit int.
    hash_str = f"{self.fqn}|{self.sharding_type}|{self.compute_kernel}|{self.cache_load_factor}|{self.num_shards}"
    hash_bytes = hashlib.blake2b(hash_str.encode("utf-8"), digest_size=7).digest()
    return int.from_bytes(hash_bytes, byteorder="big")

def __deepcopy__(
self, memo: Optional[Dict[int, "ShardingOption"]]
) -> "ShardingOption":
Expand Down Expand Up @@ -944,6 +960,16 @@ def partition(
...


@dataclass
class PlanDebugStats:
    """
    Representation of debug stats associated with a sharding plan, used for logging.
    """

    # __name__ of the planner class that produced the plan.
    planner_type: str
    # Planner timeout budget, presumably in seconds; None when no timeout was set.
    timeout_seconds: Optional[int]


class Stats(abc.ABC):
"""
Logs statistics related to the sharding plan.
Expand All @@ -964,6 +990,7 @@ def log(
sharders: Optional[List[ModuleSharder[nn.Module]]] = None,
enumerator: Optional[Enumerator] = None,
debug: bool = False,
debug_stats: Optional[PlanDebugStats] = None,
) -> None:
"""
See class description
Expand Down Expand Up @@ -991,6 +1018,16 @@ def hash_sha256_to_int(hashable_list: List[Any]) -> int: # pyre-ignore
return int(hash_digest, 16)


def hash_sha256_str(hashable_list: List[Any]) -> str:  # pyre-ignore
    """
    Hashes the given data using SHA256 and returns the hash as a hex string.

    The input is serialized via its ``str()`` representation, so the result is
    stable for inputs whose repr is deterministic.
    """
    payload = str(hashable_list).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def hash_planner_context_inputs(
topology: Topology,
batch_size: int,
Expand Down Expand Up @@ -1031,3 +1068,45 @@ def hash_planner_context_inputs(
constraints.items() if constraints else None,
]
return hash_function(hashable_list)


def hash_planner_context_inputs_str(
    topology: Topology,
    batch_size: int,
    enumerator: Enumerator,
    storage_reservation: StorageReservation,
    constraints: Optional[Dict[str, ParameterConstraints]],
    # pyre-ignore
    hash_function: Callable[[List[Any]], str] = hash_sha256_str,
) -> str:
    """
    Produces a string hash of the planner context inputs so that a loaded plan
    can be matched against the inputs it was generated from.

    Requires an enumerator with a precomputed ``last_stored_search_space`` and
    a storage reservation with a precomputed ``_last_reserved_topology``.
    """
    assert hasattr(
        enumerator, "last_stored_search_space"
    ), "This enumerator is not compatible with hashing"
    stored_space = enumerator.last_stored_search_space  # pyre-ignore
    assert (
        stored_space is not None
    ), "Unable to hash planner context without an enumerator that has a precomputed search space"

    reserved_topology = storage_reservation._last_reserved_topology  # pyre-ignore
    assert (
        reserved_topology is not None
    ), "Unable to hash planner context without a storage reservation that has a precomputed topology"

    # One stable signature per sharding option; order follows the search space.
    option_signatures = [
        [
            option.fqn,
            option.sharding_type,
            option.compute_kernel,
            tuple(option.shards),
            option.cache_params,
        ]
        for option in stored_space
    ]
    return hash_function(
        [
            topology,
            batch_size,
            option_signatures,
            type(storage_reservation).__name__,
            reserved_topology,
            constraints.items() if constraints else None,
        ]
    )
Loading