
Commit 723672c

torchrec support on kvzch emb lookup module (#2922)
Summary:
X-link: pytorch/FBGEMM#4035
Pull Request resolved: #2922
X-link: facebookresearch/FBGEMM#1120

# Change logs
1. Add ZeroCollisionKeyValueEmbedding emb lookup.
2. Address missing unit test coverage for SSD offloading.
3. Add new unit tests for the KV ZCH embedding module.
4. Add a temporary hack to calculate bucket metadata.
5. embedding.py updates, detailed below.

# embedding.py updates
1. Keep the original idea of initializing the ShardedTensor during training init.
2. For a KV ZCH table, the ShardedTensor is initialized using the virtual size for metadata calculation, and the actual tensor size check is skipped for ST init. This is needed because the table has 0 rows during training init.
3. The new tensor, weight_id, is not registered in the EC because its shape changes in real time; the weight_id tensor is generated in the post_state_dict hooks.
4. The new tensor, bucket, could be registered and preserved, but in this diff we treat it the same way as weight_id.
5. In the post state dict hook, we call get_named_split_embedding_weights_snapshot to get Tuple[table_name, weight(ST), weight_id(ST), bucket(ST)]; all 3 tensors are returned as ShardedTensors, and we update destination with the returned STs directly (see the sketch after this summary).
6. In pre_load_state_dict_hook, which is called upon load_state_dict(), we skip the update of all 3 tensors, because the tensor assignment is done [on the nn.Module side](https://fburl.com/code/it5nior8), which doesn't support updating a KVT through a PMT. This is fine for now because checkpoint loading is done outside of the load_state_dict call, but we need future plans to make it work cohesively with other types of tensors.

Reviewed By: kausv, emlin

Differential Revision: D73567631
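For illustration, the post-state-dict flow described in item 5 above could look roughly like the sketch below. The hook name, the key suffixes (`.weight`, `.weight_id`, `.bucket`), and the `lookup_module` parameter are placeholders for this sketch, not the actual torchrec embedding.py wiring:

```python
# Hedged sketch (not the actual torchrec hook): how a post-state-dict hook could
# copy the (weight, weight_id, bucket) ShardedTensors returned by
# get_named_split_embedding_weights_snapshot() into the state-dict destination.
from typing import Any, Dict


def post_state_dict_hook_sketch(
    lookup_module: Any,  # assumed to expose get_named_split_embedding_weights_snapshot()
    destination: Dict[str, Any],
    prefix: str = "",
) -> None:
    for (
        table_key,       # table name, already prefixed by the iterator
        weight_st,       # ShardedTensor (or PMT) for the embedding weights
        weight_id_st,    # Optional[ShardedTensor], only set for KV ZCH tables
        bucket_st,       # Optional[ShardedTensor], only set for KV ZCH tables
    ) in lookup_module.get_named_split_embedding_weights_snapshot(prefix):
        destination[f"{table_key}.weight"] = weight_st
        if weight_id_st is not None:
            destination[f"{table_key}.weight_id"] = weight_id_st
        if bucket_st is not None:
            destination[f"{table_key}.bucket"] = bucket_st
```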
1 parent d152bf7 commit 723672c

15 files changed: +1036 −72 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 307 additions & 20 deletions
@@ -41,6 +41,7 @@
     SplitTableBatchedEmbeddingBagsCodegen,
 )
 from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
+from fbgemm_gpu.tbe.ssd.training import BackendType, KVZCHParams
 from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
     PartiallyMaterializedTensor,
 )
@@ -50,7 +51,11 @@
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
-from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
+from torchrec.distributed.embedding_kernel import (
+    BaseEmbedding,
+    create_virtual_sharded_tensors,
+    get_state_dict,
+)
 from torchrec.distributed.embedding_types import (
     compute_kernel_to_embedding_location,
     DTensorMetadata,
@@ -65,7 +70,7 @@
     ShardMetadata,
     TensorProperties,
 )
-from torchrec.distributed.utils import append_prefix
+from torchrec.distributed.utils import append_prefix, none_throws
 from torchrec.modules.embedding_configs import (
     data_type_to_sparse_type,
     pooling_type_to_pooling_mode,
@@ -169,6 +174,24 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
     return ssd_tbe_params


+def _populate_zero_collision_tbe_params(
+    tbe_params: Dict[str, Any],
+    sharded_local_buckets: List[Tuple[int, int, int]],
+) -> None:
+    """
+    Construct Zero Collision TBE params from config and fused params dict.
+    """
+    bucket_offsets: List[Tuple[int, int]] = [
+        (offset_start, offset_end)
+        for offset_start, offset_end, _ in sharded_local_buckets
+    ]
+    bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]
+
+    tbe_params["kv_zch_params"] = KVZCHParams(
+        bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes
+    )
+
+
 class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(
         self,
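To make the new `_populate_zero_collision_tbe_params` helper concrete, here is a minimal, dependency-free sketch of the same transformation with made-up bucket specs; the plain dict stands in for `KVZCHParams` so the snippet runs on its own:

```python
# Illustrative only: mirrors what _populate_zero_collision_tbe_params does.
# Each local bucket spec is (bucket_offset_start, bucket_offset_end, bucket_size).
from typing import Any, Dict, List, Tuple

sharded_local_buckets: List[Tuple[int, int, int]] = [
    (0, 2, 125),  # table A: buckets [0, 2) on this rank, 125 rows per bucket
    (4, 6, 250),  # table B: buckets [4, 6) on this rank, 250 rows per bucket
]

# Split the spec into the two fields KVZCHParams expects.
bucket_offsets = [(start, end) for start, end, _ in sharded_local_buckets]
bucket_sizes = [size for _, _, size in sharded_local_buckets]

tbe_params: Dict[str, Any] = {}  # would normally come from _populate_ssd_tbe_params(config)
# In the real code this value is KVZCHParams(bucket_offsets=..., bucket_sizes=...);
# a plain dict stands in here to keep the sketch self-contained.
tbe_params["kv_zch_params"] = {
    "bucket_offsets": bucket_offsets,  # [(0, 2), (4, 6)]
    "bucket_sizes": bucket_sizes,      # [125, 250]
}
```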
@@ -676,24 +699,6 @@ def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
         self._emb_module.update_hyper_parameters(params_dict)


-def _gen_named_parameters_by_table_ssd(
-    emb_module: SSDTableBatchedEmbeddingBags,
-    table_name_to_count: Dict[str, int],
-    config: GroupedEmbeddingConfig,
-    pg: Optional[dist.ProcessGroup] = None,
-) -> Iterator[Tuple[str, nn.Parameter]]:
-    """
-    Return an empty tensor to indicate that the table is on remote device.
-    """
-    for table in config.embedding_tables:
-        table_name = table.name
-        # placeholder
-        weight: nn.Parameter = nn.Parameter(torch.empty(0))
-        # pyre-ignore
-        weight._in_backward_optimizers = [EmptyFusedOptimizer()]
-        yield (table_name, weight)
-
-
 def _gen_named_parameters_by_table_ssd_pmt(
     emb_module: SSDTableBatchedEmbeddingBags,
     table_name_to_count: Dict[str, int],
@@ -956,6 +961,10 @@ def __init__(
             **ssd_tbe_params,
         ).to(device)

+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+
         self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
             config,
             self._emb_module,
@@ -1064,6 +1073,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         Return an iterator over embedding tables, yielding both the table name as well as the embedding
         table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
         RocksDB snapshot to support windowed access.
+        optional ShardedTensor for weight_id; not used here since this is a non-kvzch module
+        optional ShardedTensor for bucket_cnt; not used here since this is a non-kvzch module
         """
         for config, tensor in zip(
             self._config.embedding_tables,
@@ -1095,6 +1106,280 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         return self.emb_module.split_embedding_weights(no_snapshot)


+class ZeroCollisionKeyValueEmbedding(
+    BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule
+):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+        backend_type: BackendType = BackendType.SSD,
+    ) -> None:
+        super().__init__(config, pg, device)
+
+        assert (
+            len(config.embedding_tables) > 0
+        ), "Expected to see at least one table in SSD TBE, but found 0."
+        assert (
+            len({table.embedding_dim for table in config.embedding_tables}) == 1
+        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+
+        ssd_tbe_params = _populate_ssd_tbe_params(config)
+        self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
+        _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
+        compute_kernel = config.embedding_tables[0].compute_kernel
+        embedding_location = compute_kernel_to_embedding_location(compute_kernel)
+
+        self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
+            embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
+            feature_table_map=self._feature_table_map,
+            ssd_cache_location=embedding_location,
+            pooling_mode=PoolingMode.NONE,
+            backend_type=backend_type,
+            **ssd_tbe_params,
+        ).to(device)
+
+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+
+        self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
+            config,
+            self._emb_module,
+            pg,
+        )
+        self._param_per_table: Dict[str, nn.Parameter] = dict(
+            _gen_named_parameters_by_table_ssd_pmt(
+                emb_module=self._emb_module,
+                table_name_to_count=self.table_name_to_count.copy(),
+                config=self._config,
+                pg=pg,
+            )
+        )
+        self.init_parameters()
+
+        # every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db
+        # use a split-weights result cache so that multiple calls in the same train iteration only trigger it once
+        self._split_weights_res: Optional[
+            Tuple[
+                List[ShardedTensor],
+                List[ShardedTensor],
+                List[ShardedTensor],
+            ]
+        ] = None
+
+    def init_parameters(self) -> None:
+        """
+        An advantage of KV TBE is that we don't need to init weights. Hence skipping.
+        """
+        pass
+
+    @property
+    def emb_module(
+        self,
+    ) -> SSDTableBatchedEmbeddingBags:
+        return self._emb_module
+
+    @property
+    def fused_optimizer(self) -> FusedOptimizer:
+        """
+        SSD Embedding fuses the optimizer step with the backward pass.
+        """
+        return self._optim
+
+    def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]:
+        """
+        Utility to compute (bucket offset start, bucket offset end, bucket size) based on the embedding sharding spec.
+        """
+        sharded_local_buckets: List[Tuple[int, int, int]] = []
+        world_size = dist.get_world_size(self._pg)
+        local_rank = dist.get_rank(self._pg)
+
+        for table in self._config.embedding_tables:
+            total_num_buckets = none_throws(table.total_num_buckets)
+            assert (
+                total_num_buckets % world_size == 0
+            ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}"
+            assert (
+                table.total_num_buckets
+                and table.num_embeddings % table.total_num_buckets == 0
+            ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'"
+            bucket_offset_start = total_num_buckets // world_size * local_rank
+            bucket_offset_end = min(
+                total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+            )
+            bucket_size = (
+                table.num_embeddings + total_num_buckets - 1
+            ) // total_num_buckets
+            sharded_local_buckets.append(
+                (bucket_offset_start, bucket_offset_end, bucket_size)
+            )
+            logger.info(
+                f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}"
+            )
+        return sharded_local_buckets
+
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+        no_snapshot: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. This argument controls whether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle.
+        """
+        # in the case no_snapshot=False, a flush is required. We rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()
+
+        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        """
+        Only allowed ways to get state_dict.
+        """
+        for name, tensor in self.named_split_embedding_weights(
+            prefix, recurse, remove_duplicate
+        ):
+            # hack before we support optimizer on sharded parameter level
+            # can delete after PEA deprecation
+            # pyre-ignore [6]
+            param = nn.Parameter(tensor)
+            # pyre-ignore
+            param._in_backward_optimizers = [EmptyFusedOptimizer()]
+            yield name, param
+
+    # pyre-ignore [15]
+    def named_split_embedding_weights(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights()[0],
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, tensor
+
+    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
+        Tuple[
+            str,
+            Union[ShardedTensor, PartiallyMaterializedTensor],
+            Optional[ShardedTensor],
+            Optional[ShardedTensor],
+        ]
+    ]:
+        """
+        Return an iterator over embedding tables, for each table yielding:
+            table name,
+            PMT for the embedding table with a valid RocksDB snapshot to support tensor IO,
+            optional ShardedTensor for weight_id,
+            optional ShardedTensor for bucket_cnt.
+        """
+        if self._split_weights_res is not None:
+            pmt_sharded_t_list = self._split_weights_res[0]
+            # pyre-ignore
+            weight_id_sharded_t_list = self._split_weights_res[1]
+            bucket_cnt_sharded_t_list = self._split_weights_res[2]
+            for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
+                table_config = self._config.embedding_tables[table_idx]
+                key = append_prefix(prefix, f"{table_config.name}")
+
+                yield key, pmt_sharded_t, weight_id_sharded_t_list[
+                    table_idx
+                ], bucket_cnt_sharded_t_list[table_idx]
+            return
+
+        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
+            no_snapshot=False
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        pmt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, pmt_list, self._pg, prefix
+        )
+        weight_id_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, weight_ids_list, self._pg, prefix  # pyre-ignore
+        )
+        bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, bucket_cnt_list, self._pg, prefix  # pyre-ignore
+        )
+        # pyre-ignore
+        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        assert (
+            len(pmt_sharded_t_list)
+            == len(weight_id_sharded_t_list)
+            == len(bucket_cnt_sharded_t_list)
+        )
+        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
+            table_config = self._config.embedding_tables[table_idx]
+            key = append_prefix(prefix, f"{table_config.name}")
+
+            yield key, pmt_sharded_t, weight_id_sharded_t_list[
+                table_idx
+            ], bucket_cnt_sharded_t_list[table_idx]
+
+        self._split_weights_res = (
+            pmt_sharded_t_list,
+            weight_id_sharded_t_list,
+            bucket_cnt_sharded_t_list,
+        )
+
+    def flush(self) -> None:
+        """
+        Flush the embeddings in cache back to SSD. Should be pretty expensive.
+        """
+        self.emb_module.flush()
+
+    def purge(self) -> None:
+        """
+        Reset the cache space. This is needed when we load state dict.
+        """
+        # TODO: move the following to SSD TBE.
+        self.emb_module.lxu_cache_weights.zero_()
+        self.emb_module.lxu_cache_state.fill_(-1)
+
+    # pyre-ignore [15]
+    def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
+        List[PartiallyMaterializedTensor],
+        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+    ]:
+        return self.emb_module.split_embedding_weights(no_snapshot)
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        # reset split weights during training
+        self._split_weights_res = None
+
+        return self.emb_module(
+            indices=features.values().long(),
+            offsets=features.offsets().long(),
+        )
+
+
 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
     def __init__(
         self,
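The bucket math in `get_sharded_local_buckets()` above can be checked with a small standalone sketch; the `sharded_local_bucket` helper and the numbers below are illustrative only, not part of the diff:

```python
# Worked example of the per-rank bucket assignment: each rank owns a contiguous
# range of buckets, and bucket_size is the per-bucket row count rounded up.
from typing import Tuple


def sharded_local_bucket(
    num_embeddings: int, total_num_buckets: int, world_size: int, local_rank: int
) -> Tuple[int, int, int]:
    assert total_num_buckets % world_size == 0
    assert num_embeddings % total_num_buckets == 0
    bucket_offset_start = total_num_buckets // world_size * local_rank
    bucket_offset_end = min(
        total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
    )
    bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
    return bucket_offset_start, bucket_offset_end, bucket_size


# e.g. a table with 1_000 rows and 8 buckets, sharded over 4 ranks:
# rank 0 -> (0, 2, 125), rank 1 -> (2, 4, 125), rank 2 -> (4, 6, 125), rank 3 -> (6, 8, 125)
print([sharded_local_bucket(1_000, 8, 4, r) for r in range(4)])
```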
@@ -1563,6 +1848,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         Return an iterator over embedding tables, yielding both the table name as well as the embedding
         table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
         RocksDB snapshot to support windowed access.
+        optional ShardedTensor for weight_id; not used here since this is a non-kvzch module
+        optional ShardedTensor for bucket_cnt; not used here since this is a non-kvzch module
         """
         for config, tensor in zip(
             self._config.embedding_tables,
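As a side note on the `_split_weights_res` cache added in `ZeroCollisionKeyValueEmbedding`: the intended contract is that repeated snapshot calls within one train iteration reuse the first (expensive) result, and `forward()` invalidates it. A toy, self-contained sketch of that contract follows; `ToyKvLookup` and its fields are illustrative stand-ins, not torchrec APIs:

```python
# Sketch of the split-weights caching contract: first call scans the backend,
# later calls reuse the cache, forward() drops the cache for the next iteration.
from typing import List, Optional, Tuple


class ToyKvLookup:
    def __init__(self) -> None:
        self._split_weights_res: Optional[Tuple[List[str], List[str], List[str]]] = None
        self.backend_scans = 0  # counts expensive split_embedding_weights-style calls

    def get_named_split_embedding_weights_snapshot(
        self,
    ) -> Tuple[List[str], List[str], List[str]]:
        if self._split_weights_res is not None:
            return self._split_weights_res  # cache hit: no backend scan
        self.backend_scans += 1
        self._split_weights_res = (["weight_t0"], ["weight_id_t0"], ["bucket_cnt_t0"])
        return self._split_weights_res

    def forward(self) -> None:
        self._split_weights_res = None  # new iteration: weights changed, drop the cache


lookup = ToyKvLookup()
lookup.get_named_split_embedding_weights_snapshot()
lookup.get_named_split_embedding_weights_snapshot()
assert lookup.backend_scans == 1  # second call reused the cache
lookup.forward()
lookup.get_named_split_embedding_weights_snapshot()
assert lookup.backend_scans == 2  # cache was invalidated by forward()
```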
