diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 1224a66d5..37d4446ef 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -480,7 +480,7 @@ def _populate_zero_collision_tbe_params( optimizer_type_for_st: Optional[str] = None optimizer_state_dtypes_for_st: Optional[FrozenSet[Tuple[str, int]]] = None - if kvzch_tbe_config and kvzch_tbe_config.is_st_publish: + if kvzch_tbe_config and kvzch_tbe_config.load_ckpt_without_opt: optimizer_type_for_st = kvzch_tbe_config.optimizer_type_for_st optimizer_state_dtypes_for_st = kvzch_tbe_config.optimizer_state_dtypes_for_st