diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index df6a23b98..958697699 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -232,6 +232,7 @@ def _populate_zero_collision_tbe_params( l2_weight_thresholds = [0.0] * len(config.embedding_tables) eviction_strategy = -1 table_names = [table.name for table in config.embedding_tables] + l2_cache_size = tbe_params["l2_cache_size"] for i, table in enumerate(config.embedding_tables): policy_t = table.virtual_table_eviction_policy if policy_t is not None: @@ -276,6 +277,7 @@ def _populate_zero_collision_tbe_params( ) eviction_policy = EvictionPolicy( eviction_trigger_mode=2, # 2 means mem_util based eviction + eviction_mem_threshold_gb=l2_cache_size, eviction_strategy=eviction_strategy, counter_thresholds=counter_thresholds, ttls_in_mins=ttls_in_mins, @@ -288,7 +290,7 @@ def _populate_zero_collision_tbe_params( tbe_params["kv_zch_params"] = KVZCHParams( bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes, - enable_optimizer_offloading=False, + enable_optimizer_offloading=True, eviction_policy=eviction_policy, )