diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index b8f3bf17b..1224a66d5 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -20,6 +20,7 @@ Callable, cast, Dict, + FrozenSet, Generic, Iterator, List, @@ -344,6 +345,7 @@ def _populate_zero_collision_tbe_params( meta_header_lens[i] = table.virtual_table_eviction_policy.get_meta_header_len() if not isinstance(table.virtual_table_eviction_policy, NoEvictionPolicy): enabled = True + kvzch_tbe_config = None if enabled: counter_thresholds = [0] * len(config.embedding_tables) ttls_in_mins = [0] * len(config.embedding_tables) @@ -362,20 +364,18 @@ def _populate_zero_collision_tbe_params( assert ( "kvzch_tbe_config" in tbe_params ), "kvzch_tbe_config should be in tbe_params" - eviction_tbe_config = tbe_params["kvzch_tbe_config"] + kvzch_tbe_config = tbe_params["kvzch_tbe_config"] tbe_params.pop("kvzch_tbe_config") - eviction_trigger_mode = eviction_tbe_config.kvzch_eviction_trigger_mode - eviction_free_mem_threshold_gb = ( - eviction_tbe_config.eviction_free_mem_threshold_gb - ) + eviction_trigger_mode = kvzch_tbe_config.kvzch_eviction_trigger_mode + eviction_free_mem_threshold_gb = kvzch_tbe_config.eviction_free_mem_threshold_gb eviction_free_mem_check_interval_batch = ( - eviction_tbe_config.eviction_free_mem_check_interval_batch + kvzch_tbe_config.eviction_free_mem_check_interval_batch ) threshold_calculation_bucket_stride = ( - eviction_tbe_config.threshold_calculation_bucket_stride + kvzch_tbe_config.threshold_calculation_bucket_stride ) threshold_calculation_bucket_num = ( - eviction_tbe_config.threshold_calculation_bucket_num + kvzch_tbe_config.threshold_calculation_bucket_num ) for i, table in enumerate(config.embedding_tables): policy_t = table.virtual_table_eviction_policy @@ -477,6 +477,13 @@ def _populate_zero_collision_tbe_params( else False ) ) + + optimizer_type_for_st: 
Optional[str] = None + optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None + if kvzch_tbe_config and kvzch_tbe_config.is_st_publish: + optimizer_type_for_st = kvzch_tbe_config.optimizer_type_for_st + optimizer_state_dtypes_for_st = kvzch_tbe_config.optimizer_state_dtypes_for_st + tbe_params["kv_zch_params"] = KVZCHParams( bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes, @@ -484,6 +491,12 @@ def _populate_zero_collision_tbe_params( backend_return_whole_row=(backend_type == BackendType.DRAM), eviction_policy=eviction_policy, embedding_cache_mode=embedding_cache_mode_, + # kvzch_tbe_config is None unless it was read from tbe_params above. + load_ckpt_without_opt=( + kvzch_tbe_config.load_ckpt_without_opt if kvzch_tbe_config else False + ), + optimizer_type_for_st=optimizer_type_for_st, + optimizer_state_dtypes_for_st=optimizer_state_dtypes_for_st, ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 70b10c53b..0144b208d 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -664,6 +664,7 @@ class KeyValueParams: enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings kvzch_tbe_config: Optional[KVZCHTBEConfig]: KVZCH config for TBE + load_ckpt_without_opt: bool: whether it is st publish # Parameter Server (PS) Attributes ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses @@ -690,6 +691,7 @@ class KeyValueParams: ) res_store_shards: Optional[int] = None # shards to store the raw embeddings kvzch_tbe_config: Optional[KVZCHTBEConfig] = None + load_ckpt_without_opt: bool = False # is st publish # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -719,6 +721,7 @@ def __hash__(self) -> int: self.enable_raw_embedding_streaming, self.res_store_shards, self.kvzch_tbe_config, + self.load_ckpt_without_opt, ) )