
Commit b4503c1

torchrec support on kvzch emb lookup module
Summary:

# Change logs
1. Add ZeroCollisionKeyValueEmbedding emb lookup.
2. Address missing unit test coverage for SSD offloading.
3. Add new unit tests for the KV ZCH embedding module.
4. Add a temporary hack for calculating bucket metadata.
5. Embedding updates, detailed below.

# embedding.py updates
1. Keep the original approach of initializing the ShardedTensor during training init.
2. For a KV ZCH table, the ShardedTensor is initialized with the virtual size for metadata calculation, and the actual tensor size check is skipped during ST init; this is needed because the table has 0 rows at training init.
3. The new tensor, weight_id, is not registered in the EC because its shape changes in real time; the weight_id tensor is generated in the post_state_dict hooks.
4. The new tensor, bucket, could be registered and preserved, but in this diff we handle it the same way as weight_id.
5. In the post_state_dict hook, we call get_named_split_embedding_weights_snapshot to get Tuple[table_name, weight (ST), weight_id (ST), bucket (ST)]; all 3 tensors are returned as STs, and we update the destination with the returned STs directly (see the illustrative sketch below the file summary).
6. In pre_load_state_dict_hook, which is called upon load_state_dict(), we skip updating all 3 tensors, because the tensor assignment is done [on the nn.Module side](https://fburl.com/code/it5nior8), which doesn't support updating a KVT through a PMT. This is fine for now because checkpoint loading is done outside of the load_state_dict call, but we need a future plan to make this work cohesively with other types of tensors.

Differential Revision: D73567631
1 parent 65798a9 commit b4503c1

10 files changed (+888, -65 lines)
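Illustrative only, not part of the commit: a minimal, hypothetical sketch of how a post-state-dict hook could consume the `(table_name, weight, weight_id, bucket)` tuples yielded by `get_named_split_embedding_weights_snapshot`, as described in item 5 of the embedding.py notes above. The hook signature and the destination key suffixes are assumptions, not the actual torchrec hook.

```python
from typing import Any, Dict


def _post_state_dict_hook_sketch(
    lookup,  # e.g. a ZeroCollisionKeyValueEmbedding lookup module
    destination: Dict[str, Any],
    prefix: str = "",
) -> None:
    # Each yielded weight / weight_id / bucket is already a ShardedTensor
    # (or a PMT for non-kvzch tables), so the destination is updated directly.
    for table_name, weight, weight_id, bucket in lookup.get_named_split_embedding_weights_snapshot(
        prefix
    ):
        destination[f"{table_name}.weight"] = weight
        if weight_id is not None:  # only populated for KV ZCH tables
            destination[f"{table_name}.weight_id"] = weight_id
        if bucket is not None:
            destination[f"{table_name}.bucket"] = bucket
```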

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 303 additions & 20 deletions
@@ -50,7 +50,11 @@
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
-from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
+from torchrec.distributed.embedding_kernel import (
+    BaseEmbedding,
+    create_virtual_sharded_tensors,
+    get_state_dict,
+)
 from torchrec.distributed.embedding_types import (
     compute_kernel_to_embedding_location,
     DTensorMetadata,
@@ -65,7 +69,7 @@
     ShardMetadata,
     TensorProperties,
 )
-from torchrec.distributed.utils import append_prefix
+from torchrec.distributed.utils import append_prefix, none_throws
 from torchrec.modules.embedding_configs import (
     data_type_to_sparse_type,
     pooling_type_to_pooling_mode,
@@ -169,6 +173,22 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
     return ssd_tbe_params


+def _populate_zero_collision_tbe_params(
+    tbe_params: Dict[str, Any],
+    sharded_local_buckets: List[Tuple[int, int, int]],
+) -> None:
+    """
+    Construct Zero Collision TBE params from config and fused params dict.
+    """
+    tbe_params["zero_collision_tbe"] = True
+    if tbe_params.get("zero_collision_tbe", False):
+        tbe_params["bucket_offsets"] = [
+            (offset_start, offset_end)
+            for offset_start, offset_end, _ in sharded_local_buckets
+        ]
+        tbe_params["bucket_sizes"] = [size for _, _, size in sharded_local_buckets]
+
+
 class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(
         self,
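Illustrative only, not part of the diff: a hedged usage sketch of the helper added above, assuming it is called from the same module. The bucket spec values are made up; there is one (bucket_offset_start, bucket_offset_end, bucket_size) entry per local table.

```python
from typing import Any, Dict, List, Tuple

tbe_params: Dict[str, Any] = {}
sharded_local_buckets: List[Tuple[int, int, int]] = [(0, 10, 1000), (0, 10, 500)]

_populate_zero_collision_tbe_params(tbe_params, sharded_local_buckets)

assert tbe_params["zero_collision_tbe"] is True
assert tbe_params["bucket_offsets"] == [(0, 10), (0, 10)]
assert tbe_params["bucket_sizes"] == [1000, 500]
```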
@@ -631,24 +651,6 @@ def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
         self._emb_module.update_hyper_parameters(params_dict)


-def _gen_named_parameters_by_table_ssd(
-    emb_module: SSDTableBatchedEmbeddingBags,
-    table_name_to_count: Dict[str, int],
-    config: GroupedEmbeddingConfig,
-    pg: Optional[dist.ProcessGroup] = None,
-) -> Iterator[Tuple[str, nn.Parameter]]:
-    """
-    Return an empty tensor to indicate that the table is on remote device.
-    """
-    for table in config.embedding_tables:
-        table_name = table.name
-        # placeholder
-        weight: nn.Parameter = nn.Parameter(torch.empty(0))
-        # pyre-ignore
-        weight._in_backward_optimizers = [EmptyFusedOptimizer()]
-        yield (table_name, weight)
-
-
 def _gen_named_parameters_by_table_ssd_pmt(
     emb_module: SSDTableBatchedEmbeddingBags,
     table_name_to_count: Dict[str, int],
@@ -911,6 +913,10 @@ def __init__(
             **ssd_tbe_params,
         ).to(device)

+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+
         self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
             config,
             self._emb_module,
@@ -1019,6 +1025,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         Return an iterator over embedding tables, yielding both the table name as well as the embedding
         table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
         RocksDB snapshot to support windowed access.
+        optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
+        optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
         """
         for config, tensor in zip(
             self._config.embedding_tables,
@@ -1050,6 +1058,279 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         return self.emb_module.split_embedding_weights(no_snapshot)


+class ZeroCollisionKeyValueEmbedding(
+    BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule
+):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__(config, pg, device)
+
+        assert (
+            len(config.embedding_tables) > 0
+        ), "Expected to see at least one table in SSD TBE, but found 0."
+        assert (
+            len({table.embedding_dim for table in config.embedding_tables}) == 1
+        ), "Currently we expect all tables in SSD TBE to have the same embedding dimension."
+
+        ssd_tbe_params = _populate_ssd_tbe_params(config)
+        self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets()
+        _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec)
+        compute_kernel = config.embedding_tables[0].compute_kernel
+        embedding_location = compute_kernel_to_embedding_location(compute_kernel)
+
+        self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
+            embedding_specs=list(
+                zip(self._num_embeddings, self._local_cols, self._local_rows)
+            ),
+            feature_table_map=self._feature_table_map,
+            ssd_cache_location=embedding_location,
+            pooling_mode=PoolingMode.NONE,
+            **ssd_tbe_params,
+        ).to(device)
+
+        logger.info(
+            f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}"
+        )
+
+        self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
+            config,
+            self._emb_module,
+            pg,
+        )
+        self._param_per_table: Dict[str, nn.Parameter] = dict(
+            _gen_named_parameters_by_table_ssd_pmt(
+                emb_module=self._emb_module,
+                table_name_to_count=self.table_name_to_count.copy(),
+                config=self._config,
+                pg=pg,
+            )
+        )
+        self.init_parameters()
+
+        # every split_embeding_weights call is expensive, since it iterates over all the elements in the backend kv db
+        # use split weights result cache so that multiple calls in the same train iteration will only trigger once
+        self._split_weights_res: Optional[
+            Tuple[
+                List[ShardedTensor],
+                List[ShardedTensor],
+                List[ShardedTensor],
+            ]
+        ] = None
+
+    def init_parameters(self) -> None:
+        """
+        An advantage of KV TBE is that we don't need to init weights. Hence skipping.
+        """
+        pass
+
+    @property
+    def emb_module(
+        self,
+    ) -> SSDTableBatchedEmbeddingBags:
+        return self._emb_module
+
+    @property
+    def fused_optimizer(self) -> FusedOptimizer:
+        """
+        SSD Embedding fuses backward with backward.
+        """
+        return self._optim
+
+    # TODO: this is a temporary hack, we should read shard info from torchrec sharding plan
+    def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]:
+        """
+        utils to get bucket offset start, bucket offset end, bucket size based on embedding sharding spec
+        """
+        sharded_local_buckets: List[Tuple[int, int, int]] = []
+        world_size = dist.get_world_size(self._pg)
+        local_rank = dist.get_rank(self._pg)
+
+        for table in self._config.embedding_tables:
+            # temporary before uneven sharding utils is ready
+            # Question, what happen if we have uneven sharding on the training side?
+            assert (
+                table.num_embeddings % world_size == 0
+            ), "total_num_embeddings must be divisible by world_size"
+            total_num_buckets = none_throws(table.total_num_buckets)
+            bucket_offset_start = total_num_buckets // world_size * local_rank
+            bucket_offset_end = min(
+                total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
+            )
+            bucket_size = (
+                table.num_embeddings + total_num_buckets - 1
+            ) // total_num_buckets
+            sharded_local_buckets.append(
+                (bucket_offset_start, bucket_offset_end, bucket_size)
+            )
+            logger.info(
+                f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}"
+            )
+        return sharded_local_buckets
+
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+        no_snapshot: bool = True,
+    ) -> Dict[str, Any]:
+        """
+        Args:
+            no_snapshot (bool): the tensors in the returned dict are
+                PartiallyMaterializedTensors. this argument controls wether the
+                PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the
+                PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the
+                PartiallyMaterializedTensor has a RocksDB snapshot handle
+        """
+        # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
+        # ShardedEmbeddingBagCollection._pre_state_dict_hook()
+
+        emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        ret = get_state_dict(
+            emb_table_config_copy,
+            emb_tables,
+            self._pg,
+            destination,
+            prefix,
+        )
+        return ret
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        """
+        Only allowed ways to get state_dict.
+        """
+        for name, tensor in self.named_split_embedding_weights(
+            prefix, recurse, remove_duplicate
+        ):
+            # hack before we support optimizer on sharded parameter level
+            # can delete after PEA deprecation
+            # pyre-ignore [6]
+            param = nn.Parameter(tensor)
+            # pyre-ignore
+            param._in_backward_optimizers = [EmptyFusedOptimizer()]
+            yield name, param
+
+    # pyre-ignore [15]
+    def named_split_embedding_weights(
+        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
+    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights()[0],
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, tensor
+
+    def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
+        Tuple[
+            str,
+            Union[ShardedTensor, PartiallyMaterializedTensor],
+            Optional[ShardedTensor],
+            Optional[ShardedTensor],
+        ]
+    ]:
+        """
+        Return an iterator over embedding tables, for each table yielding
+            table name,
+            PMT for embedding table with a valid RocksDB snapshot to support tensor IO
+            optional ShardedTensor for weight_id
+            optional ShardedTensor for bucket_cnt
+        """
+        if self._split_weights_res is not None:
+            pmt_sharded_t_list = self._split_weights_res[0]
+            # pyre-ignore
+            weight_id_sharded_t_list = self._split_weights_res[1]
+            bucket_cnt_sharded_t_list = self._split_weights_res[2]
+            for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
+                table_config = self._config.embedding_tables[table_idx]
+                key = append_prefix(prefix, f"{table_config.name}")
+
+                yield key, pmt_sharded_t, weight_id_sharded_t_list[
+                    table_idx
+                ], bucket_cnt_sharded_t_list[table_idx]
+            return
+
+        pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
+            no_snapshot=False
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        pmt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, pmt_list, self._pg, prefix
+        )
+        weight_id_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, weight_ids_list, self._pg, prefix  # pyre-ignore
+        )
+        bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, bucket_cnt_list, self._pg, prefix  # pyre-ignore
+        )
+        # pyre-ignore
+        assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
+        assert (
+            len(pmt_sharded_t_list)
+            == len(weight_id_sharded_t_list)
+            == len(bucket_cnt_sharded_t_list)
+        )
+        for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list):
+            table_config = self._config.embedding_tables[table_idx]
+            key = append_prefix(prefix, f"{table_config.name}")
+
+            yield key, pmt_sharded_t, weight_id_sharded_t_list[
+                table_idx
+            ], bucket_cnt_sharded_t_list[table_idx]
+
+        self._split_weights_res = (
+            pmt_sharded_t_list,
+            weight_id_sharded_t_list,
+            bucket_cnt_sharded_t_list,
+        )
+
+    def flush(self) -> None:
+        """
+        Flush the embeddings in cache back to SSD. Should be pretty expensive.
+        """
+        self.emb_module.flush()
+
+    def purge(self) -> None:
+        """
+        Reset the cache space. This is needed when we load state dict.
+        """
+        # TODO: move the following to SSD TBE.
+        self.emb_module.lxu_cache_weights.zero_()
+        self.emb_module.lxu_cache_state.fill_(-1)
+
+    # pyre-ignore [15]
+    def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
+        List[PartiallyMaterializedTensor],
+        Optional[List[torch.Tensor]],
+        Optional[List[torch.Tensor]],
+    ]:
+        return self.emb_module.split_embedding_weights(no_snapshot)
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        # reset split weights during training
+        self._split_weights_res = None
+
+        return self.emb_module(
+            indices=features.values().long(),
+            offsets=features.offsets().long(),
+        )
+
+
 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
     def __init__(
         self,
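Illustrative only, not part of the diff: a worked example (with made-up numbers) of the bucket arithmetic in `get_sharded_local_buckets` above, for a single table on rank 1 of a 4-rank job.

```python
world_size, local_rank = 4, 1
num_embeddings = 1_000_000
total_num_buckets = 80

bucket_offset_start = total_num_buckets // world_size * local_rank  # 20
bucket_offset_end = min(
    total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
)  # 40
bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets  # 12500

assert (bucket_offset_start, bucket_offset_end, bucket_size) == (20, 40, 12500)
```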
@@ -1518,6 +1799,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
         Return an iterator over embedding tables, yielding both the table name as well as the embedding
         table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
         RocksDB snapshot to support windowed access.
+        optional ShardedTensor for weight_id, this won't be used here as this is non-kvzch
+        optional ShardedTensor for bucket_cnt, this won't be used here as this is non-kvzch
         """
         for config, tensor in zip(
             self._config.embedding_tables,
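Illustrative only (not a test in this commit): how the `_split_weights_res` cache in `ZeroCollisionKeyValueEmbedding` is expected to behave across a train iteration, per the code above; `config`, `pg`, `device`, and `features` are assumed to be constructed elsewhere.

```python
emb = ZeroCollisionKeyValueEmbedding(config, pg, device)

# first call takes a RocksDB snapshot and caches the ShardedTensor lists
first = list(emb.get_named_split_embedding_weights_snapshot())
# a second call in the same iteration is served from _split_weights_res
second = list(emb.get_named_split_embedding_weights_snapshot())

# forward() resets the cache, so the next call re-snapshots the backend kv db
emb.forward(features)
third = list(emb.get_named_split_embedding_weights_snapshot())
```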
