41 | 41 | SplitTableBatchedEmbeddingBagsCodegen,
42 | 42 | )
43 | 43 | from fbgemm_gpu.tbe.ssd import ASSOC, SSDTableBatchedEmbeddingBags
| 44 | +from fbgemm_gpu.tbe.ssd.training import BackendType, KVZCHParams |
44 | 45 | from fbgemm_gpu.tbe.ssd.utils.partially_materialized_tensor import (
45 | 46 | PartiallyMaterializedTensor,
46 | 47 | )

50 | 51 | from torchrec.distributed.composable.table_batched_embedding_slice import (
51 | 52 | TableBatchedEmbeddingSlice,
52 | 53 | )
53 | | -from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict |
| 54 | +from torchrec.distributed.embedding_kernel import ( |
| 55 | + BaseEmbedding, |
| 56 | + create_virtual_sharded_tensors, |
| 57 | + get_state_dict, |
| 58 | +) |
54 | 59 | from torchrec.distributed.embedding_types import (
55 | 60 | compute_kernel_to_embedding_location,
56 | 61 | DTensorMetadata,

65 | 70 | ShardMetadata,
66 | 71 | TensorProperties,
67 | 72 | )
68 | | -from torchrec.distributed.utils import append_prefix |
| 73 | +from torchrec.distributed.utils import append_prefix, none_throws |
69 | 74 | from torchrec.modules.embedding_configs import (
70 | 75 | data_type_to_sparse_type,
71 | 76 | pooling_type_to_pooling_mode,

@@ -169,6 +174,24 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
169 | 174 | return ssd_tbe_params
170 | 175 |
171 | 176 |
| 177 | +def _populate_zero_collision_tbe_params( |
| 178 | + tbe_params: Dict[str, Any], |
| 179 | + sharded_local_buckets: List[Tuple[int, int, int]], |
| 180 | +) -> None: |
| 181 | + """ |
| 182 | + Construct Zero Collision TBE params from config and fused params dict. |
| 183 | + """ |
| 184 | + bucket_offsets: List[Tuple[int, int]] = [ |
| 185 | + (offset_start, offset_end) |
| 186 | + for offset_start, offset_end, _ in sharded_local_buckets |
| 187 | + ] |
| 188 | + bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets] |
| 189 | + |
| 190 | + tbe_params["kv_zch_params"] = KVZCHParams( |
| 191 | + bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes |
| 192 | + ) |
| 193 | + |
| 194 | + |
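
For readers unfamiliar with the KV-ZCH wiring, the helper above only unpacks the per-table bucket spec into the two parallel lists that `KVZCHParams` expects. Below is a minimal, dependency-free sketch of the same transformation; the bucket spec values are made up, and a plain tuple stands in for the real `KVZCHParams` dataclass so the snippet runs without fbgemm_gpu installed.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical bucket spec: one (bucket_offset_start, bucket_offset_end, bucket_size)
# triple per local table, e.g. a rank owning buckets [10, 20) with 1000 rows per bucket.
sharded_local_buckets: List[Tuple[int, int, int]] = [(10, 20, 1000), (5, 10, 500)]

tbe_params: Dict[str, Any] = {}
bucket_offsets = [(start, end) for start, end, _ in sharded_local_buckets]  # [(10, 20), (5, 10)]
bucket_sizes = [size for _, _, size in sharded_local_buckets]               # [1000, 500]

# The real helper wraps these two lists in fbgemm_gpu's KVZCHParams and stores the
# result under the "kv_zch_params" key; a plain tuple keeps this sketch import-free.
tbe_params["kv_zch_params"] = (bucket_offsets, bucket_sizes)
print(tbe_params)
```
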
172 | 195 | class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
173 | 196 | def __init__(
174 | 197 | self,

@@ -676,24 +699,6 @@ def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:

676 | 699 | self._emb_module.update_hyper_parameters(params_dict)
677 | 700 |
678 | 701 |
679 | | -def _gen_named_parameters_by_table_ssd( |
680 | | - emb_module: SSDTableBatchedEmbeddingBags, |
681 | | - table_name_to_count: Dict[str, int], |
682 | | - config: GroupedEmbeddingConfig, |
683 | | - pg: Optional[dist.ProcessGroup] = None, |
684 | | -) -> Iterator[Tuple[str, nn.Parameter]]: |
685 | | - """ |
686 | | - Return an empty tensor to indicate that the table is on remote device. |
687 | | - """ |
688 | | - for table in config.embedding_tables: |
689 | | - table_name = table.name |
690 | | - # placeholder |
691 | | - weight: nn.Parameter = nn.Parameter(torch.empty(0)) |
692 | | - # pyre-ignore |
693 | | - weight._in_backward_optimizers = [EmptyFusedOptimizer()] |
694 | | - yield (table_name, weight) |
695 | | - |
696 | | - |
697 | 702 | def _gen_named_parameters_by_table_ssd_pmt(
698 | 703 | emb_module: SSDTableBatchedEmbeddingBags,
699 | 704 | table_name_to_count: Dict[str, int],

@@ -956,6 +961,10 @@ def __init__(

956 | 961 | **ssd_tbe_params,
957 | 962 | ).to(device)
958 | 963 |
| 964 | + logger.info( |
| 965 | + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" |
| 966 | + ) |
| 967 | + |
959 | 968 | self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
960 | 969 | config,
961 | 970 | self._emb_module,

@@ -1064,6 +1073,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat

1064 | 1073 | Return an iterator over embedding tables, yielding both the table name as well as the embedding
1065 | 1074 | table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
1066 | 1075 | RocksDB snapshot to support windowed access.
| 1076 | + Also yields an optional ShardedTensor for weight_id; not used here since this is non-KV-ZCH. |
| 1077 | + Also yields an optional ShardedTensor for bucket_cnt; not used here since this is non-KV-ZCH. |
1067 | 1078 | """
1068 | 1079 | for config, tensor in zip(
1069 | 1080 | self._config.embedding_tables,

@@ -1095,6 +1106,280 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[

1095 | 1106 | return self.emb_module.split_embedding_weights(no_snapshot)
1096 | 1107 |
1097 | 1108 |
| 1109 | +class ZeroCollisionKeyValueEmbedding( |
| 1110 | + BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule |
| 1111 | +): |
| 1112 | + def __init__( |
| 1113 | + self, |
| 1114 | + config: GroupedEmbeddingConfig, |
| 1115 | + pg: Optional[dist.ProcessGroup] = None, |
| 1116 | + device: Optional[torch.device] = None, |
| 1117 | + backend_type: BackendType = BackendType.SSD, |
| 1118 | + ) -> None: |
| 1119 | + super().__init__(config, pg, device) |
| 1120 | + |
| 1121 | + assert ( |
| 1122 | + len(config.embedding_tables) > 0 |
| 1123 | + ), "Expected to see at least one table in SSD TBE, but found 0." |
| 1124 | + assert ( |
| 1125 | + len({table.embedding_dim for table in config.embedding_tables}) == 1 |
| 1126 | + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." |
| 1127 | + |
| 1128 | + ssd_tbe_params = _populate_ssd_tbe_params(config) |
| 1129 | + self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets() |
| 1130 | + _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec) |
| 1131 | + compute_kernel = config.embedding_tables[0].compute_kernel |
| 1132 | + embedding_location = compute_kernel_to_embedding_location(compute_kernel) |
| 1133 | + |
| 1134 | + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( |
| 1135 | + embedding_specs=list(zip(self._num_embeddings, self._local_cols)), |
| 1136 | + feature_table_map=self._feature_table_map, |
| 1137 | + ssd_cache_location=embedding_location, |
| 1138 | + pooling_mode=PoolingMode.NONE, |
| 1139 | + backend_type=backend_type, |
| 1140 | + **ssd_tbe_params, |
| 1141 | + ).to(device) |
| 1142 | + |
| 1143 | + logger.info( |
| 1144 | + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" |
| 1145 | + ) |
| 1146 | + |
| 1147 | + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( |
| 1148 | + config, |
| 1149 | + self._emb_module, |
| 1150 | + pg, |
| 1151 | + ) |
| 1152 | + self._param_per_table: Dict[str, nn.Parameter] = dict( |
| 1153 | + _gen_named_parameters_by_table_ssd_pmt( |
| 1154 | + emb_module=self._emb_module, |
| 1155 | + table_name_to_count=self.table_name_to_count.copy(), |
| 1156 | + config=self._config, |
| 1157 | + pg=pg, |
| 1158 | + ) |
| 1159 | + ) |
| 1160 | + self.init_parameters() |
| 1161 | + |
| 1162 | + # every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db |
| 1163 | + # cache the split weights result so that multiple calls within the same train iteration only trigger the scan once |
| 1164 | + self._split_weights_res: Optional[ |
| 1165 | + Tuple[ |
| 1166 | + List[ShardedTensor], |
| 1167 | + List[ShardedTensor], |
| 1168 | + List[ShardedTensor], |
| 1169 | + ] |
| 1170 | + ] = None |
| 1171 | + |
| 1172 | + def init_parameters(self) -> None: |
| 1173 | + """ |
| 1174 | + An advantage of KV TBE is that we don't need to init weights. Hence skipping. |
| 1175 | + """ |
| 1176 | + pass |
| 1177 | + |
| 1178 | + @property |
| 1179 | + def emb_module( |
| 1180 | + self, |
| 1181 | + ) -> SSDTableBatchedEmbeddingBags: |
| 1182 | + return self._emb_module |
| 1183 | + |
| 1184 | + @property |
| 1185 | + def fused_optimizer(self) -> FusedOptimizer: |
| 1186 | + """ |
| 1187 | + SSD Embedding fuses the optimizer step with the backward pass. |
| 1188 | + """ |
| 1189 | + return self._optim |
| 1190 | + |
| 1191 | + def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]: |
| 1192 | + """ |
| 1193 | + Utility to compute (bucket offset start, bucket offset end, bucket size) per table based on the embedding sharding spec. |
| 1194 | + """ |
| 1195 | + sharded_local_buckets: List[Tuple[int, int, int]] = [] |
| 1196 | + world_size = dist.get_world_size(self._pg) |
| 1197 | + local_rank = dist.get_rank(self._pg) |
| 1198 | + |
| 1199 | + for table in self._config.embedding_tables: |
| 1200 | + total_num_buckets = none_throws(table.total_num_buckets) |
| 1201 | + assert ( |
| 1202 | + total_num_buckets % world_size == 0 |
| 1203 | + ), f"total_num_buckets={total_num_buckets} must be divisible by world_size={world_size}" |
| 1204 | + assert ( |
| 1205 | + table.total_num_buckets |
| 1206 | + and table.num_embeddings % table.total_num_buckets == 0 |
| 1207 | + ), f"Table size '{table.num_embeddings}' must be divisible by num_buckets '{table.total_num_buckets}'" |
| 1208 | + bucket_offset_start = total_num_buckets // world_size * local_rank |
| 1209 | + bucket_offset_end = min( |
| 1210 | + total_num_buckets, total_num_buckets // world_size * (local_rank + 1) |
| 1211 | + ) |
| 1212 | + bucket_size = ( |
| 1213 | + table.num_embeddings + total_num_buckets - 1 |
| 1214 | + ) // total_num_buckets |
| 1215 | + sharded_local_buckets.append( |
| 1216 | + (bucket_offset_start, bucket_offset_end, bucket_size) |
| 1217 | + ) |
| 1218 | + logger.info( |
| 1219 | + f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}" |
| 1220 | + ) |
| 1221 | + return sharded_local_buckets |
| 1222 | + |
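
To make the bucket arithmetic in `get_sharded_local_buckets` concrete, here is the same per-table computation as a standalone function with made-up inputs; the table size, bucket count, world size, and rank below are hypothetical and not taken from this PR.

```python
from typing import Tuple


def sharded_local_bucket(
    num_embeddings: int, total_num_buckets: int, world_size: int, rank: int
) -> Tuple[int, int, int]:
    # Mirrors the per-table logic above: each rank owns a contiguous range of buckets,
    # and bucket_size = ceil(num_embeddings / total_num_buckets).
    assert total_num_buckets % world_size == 0
    assert num_embeddings % total_num_buckets == 0
    bucket_offset_start = total_num_buckets // world_size * rank
    bucket_offset_end = min(
        total_num_buckets, total_num_buckets // world_size * (rank + 1)
    )
    bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets
    return bucket_offset_start, bucket_offset_end, bucket_size


# Hypothetical table: 1_000_000 rows, 80 buckets, 8 ranks.
# Rank 3 owns buckets [30, 40), and each bucket covers 12_500 rows.
print(sharded_local_bucket(1_000_000, 80, 8, 3))  # (30, 40, 12500)
```
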
| 1223 | + def state_dict( |
| 1224 | + self, |
| 1225 | + destination: Optional[Dict[str, Any]] = None, |
| 1226 | + prefix: str = "", |
| 1227 | + keep_vars: bool = False, |
| 1228 | + no_snapshot: bool = True, |
| 1229 | + ) -> Dict[str, Any]: |
| 1230 | + """ |
| 1231 | + Args: |
| 1232 | + no_snapshot (bool): the tensors in the returned dict are |
| 1233 | + PartiallyMaterializedTensors. This argument controls whether each |
| 1234 | + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the |
| 1235 | + PartiallyMaterializedTensor does not have a RocksDB snapshot handle; False means the |
| 1236 | + PartiallyMaterializedTensor does have one. |
| 1237 | + """ |
| 1238 | + # in the case no_snapshot=False, a flush is required. we rely on the flush operation in |
| 1239 | + # ShardedEmbeddingBagCollection._pre_state_dict_hook() |
| 1240 | + |
| 1241 | + emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot) |
| 1242 | + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) |
| 1243 | + for emb_table in emb_table_config_copy: |
| 1244 | + emb_table.local_metadata.placement._device = torch.device("cpu") |
| 1245 | + ret = get_state_dict( |
| 1246 | + emb_table_config_copy, |
| 1247 | + emb_tables, |
| 1248 | + self._pg, |
| 1249 | + destination, |
| 1250 | + prefix, |
| 1251 | + ) |
| 1252 | + return ret |
| 1253 | + |
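
A hedged sketch of how a caller might drive the `no_snapshot` flag documented above; `kv_embedding` is assumed to be an already-constructed kernel exposing the `flush()` and `state_dict(no_snapshot=...)` surface shown in this diff, and the function name is illustrative, not a torchrec API.

```python
from typing import Any, Dict


def checkpoint_state_dict(kv_embedding: Any) -> Dict[str, Any]:
    """Illustrative only: obtain checkpoint-ready tensors from a KV embedding kernel."""
    # A checkpoint needs PartiallyMaterializedTensors backed by a RocksDB snapshot,
    # so the cache is flushed first (normally done by ShardedEmbeddingBagCollection's
    # _pre_state_dict_hook) and no_snapshot is set to False.
    kv_embedding.flush()
    return kv_embedding.state_dict(no_snapshot=False)
```
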
| 1254 | + def named_parameters( |
| 1255 | + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True |
| 1256 | + ) -> Iterator[Tuple[str, nn.Parameter]]: |
| 1257 | + """ |
| 1258 | + Only allowed way to get state_dict. |
| 1259 | + """ |
| 1260 | + for name, tensor in self.named_split_embedding_weights( |
| 1261 | + prefix, recurse, remove_duplicate |
| 1262 | + ): |
| 1263 | + # hack before we support optimizer on sharded parameter level |
| 1264 | + # can delete after PEA deprecation |
| 1265 | + # pyre-ignore [6] |
| 1266 | + param = nn.Parameter(tensor) |
| 1267 | + # pyre-ignore |
| 1268 | + param._in_backward_optimizers = [EmptyFusedOptimizer()] |
| 1269 | + yield name, param |
| 1270 | + |
| 1271 | + # pyre-ignore [15] |
| 1272 | + def named_split_embedding_weights( |
| 1273 | + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True |
| 1274 | + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: |
| 1275 | + assert ( |
| 1276 | + remove_duplicate |
| 1277 | + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" |
| 1278 | + for config, tensor in zip( |
| 1279 | + self._config.embedding_tables, |
| 1280 | + self.split_embedding_weights()[0], |
| 1281 | + ): |
| 1282 | + key = append_prefix(prefix, f"{config.name}.weight") |
| 1283 | + yield key, tensor |
| 1284 | + |
| 1285 | + def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[ |
| 1286 | + Tuple[ |
| 1287 | + str, |
| 1288 | + Union[ShardedTensor, PartiallyMaterializedTensor], |
| 1289 | + Optional[ShardedTensor], |
| 1290 | + Optional[ShardedTensor], |
| 1291 | + ] |
| 1292 | + ]: |
| 1293 | + """ |
| 1294 | + Return an iterator over embedding tables; for each table it yields: |
| 1295 | + the table name, |
| 1296 | + a PMT for the embedding table with a valid RocksDB snapshot to support tensor IO, |
| 1297 | + an optional ShardedTensor for weight_id, and |
| 1298 | + an optional ShardedTensor for bucket_cnt. |
| 1299 | + """ |
| 1300 | + if self._split_weights_res is not None: |
| 1301 | + pmt_sharded_t_list = self._split_weights_res[0] |
| 1302 | + # pyre-ignore |
| 1303 | + weight_id_sharded_t_list = self._split_weights_res[1] |
| 1304 | + bucket_cnt_sharded_t_list = self._split_weights_res[2] |
| 1305 | + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): |
| 1306 | + table_config = self._config.embedding_tables[table_idx] |
| 1307 | + key = append_prefix(prefix, f"{table_config.name}") |
| 1308 | + |
| 1309 | + yield key, pmt_sharded_t, weight_id_sharded_t_list[ |
| 1310 | + table_idx |
| 1311 | + ], bucket_cnt_sharded_t_list[table_idx] |
| 1312 | + return |
| 1313 | + |
| 1314 | + pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights( |
| 1315 | + no_snapshot=False |
| 1316 | + ) |
| 1317 | + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) |
| 1318 | + for emb_table in emb_table_config_copy: |
| 1319 | + emb_table.local_metadata.placement._device = torch.device("cpu") |
| 1320 | + |
| 1321 | + pmt_sharded_t_list = create_virtual_sharded_tensors( |
| 1322 | + emb_table_config_copy, pmt_list, self._pg, prefix |
| 1323 | + ) |
| 1324 | + weight_id_sharded_t_list = create_virtual_sharded_tensors( |
| 1325 | + emb_table_config_copy, weight_ids_list, self._pg, prefix # pyre-ignore |
| 1326 | + ) |
| 1327 | + bucket_cnt_sharded_t_list = create_virtual_sharded_tensors( |
| 1328 | + emb_table_config_copy, bucket_cnt_list, self._pg, prefix # pyre-ignore |
| 1329 | + ) |
| 1330 | + # pyre-ignore |
| 1331 | + assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list) |
| 1332 | + assert ( |
| 1333 | + len(pmt_sharded_t_list) |
| 1334 | + == len(weight_id_sharded_t_list) |
| 1335 | + == len(bucket_cnt_sharded_t_list) |
| 1336 | + ) |
| 1337 | + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): |
| 1338 | + table_config = self._config.embedding_tables[table_idx] |
| 1339 | + key = append_prefix(prefix, f"{table_config.name}") |
| 1340 | + |
| 1341 | + yield key, pmt_sharded_t, weight_id_sharded_t_list[ |
| 1342 | + table_idx |
| 1343 | + ], bucket_cnt_sharded_t_list[table_idx] |
| 1344 | + |
| 1345 | + self._split_weights_res = ( |
| 1346 | + pmt_sharded_t_list, |
| 1347 | + weight_id_sharded_t_list, |
| 1348 | + bucket_cnt_sharded_t_list, |
| 1349 | + ) |
| 1350 | + |
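
The caching above means repeated snapshot calls within one train iteration reuse the same ShardedTensor lists, and the `forward` method later in this class invalidates the cache. Below is a dependency-free sketch of that memoize-until-next-forward pattern; the class and field names are illustrative stand-ins, not torchrec APIs.

```python
from typing import List, Optional


class SplitWeightsCache:
    """Sketch of the caching pattern used above: the first snapshot call materializes
    the expensive result, later calls in the same iteration reuse it, and forward()
    invalidates it."""

    def __init__(self) -> None:
        self._split_weights_res: Optional[List[str]] = None

    def _expensive_split(self) -> List[str]:
        # Stands in for split_embedding_weights(no_snapshot=False) plus
        # create_virtual_sharded_tensors(), which scan the backend KV store.
        return ["table_0_weights", "table_1_weights"]

    def snapshot(self) -> List[str]:
        if self._split_weights_res is None:
            self._split_weights_res = self._expensive_split()
        return self._split_weights_res

    def forward(self) -> None:
        # Training step: the cached snapshot may be stale afterwards, so drop it.
        self._split_weights_res = None


cache = SplitWeightsCache()
assert cache.snapshot() is cache.snapshot()  # second call reuses the cached result
cache.forward()                              # next iteration recomputes the snapshot
```
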
| 1351 | + def flush(self) -> None: |
| 1352 | + """ |
| 1353 | + Flush the embeddings in cache back to SSD. This is expected to be expensive. |
| 1354 | + """ |
| 1355 | + self.emb_module.flush() |
| 1356 | + |
| 1357 | + def purge(self) -> None: |
| 1358 | + """ |
| 1359 | + Reset the cache space. This is needed when we load state dict. |
| 1360 | + """ |
| 1361 | + # TODO: move the following to SSD TBE. |
| 1362 | + self.emb_module.lxu_cache_weights.zero_() |
| 1363 | + self.emb_module.lxu_cache_state.fill_(-1) |
| 1364 | + |
| 1365 | + # pyre-ignore [15] |
| 1366 | + def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[ |
| 1367 | + List[PartiallyMaterializedTensor], |
| 1368 | + Optional[List[torch.Tensor]], |
| 1369 | + Optional[List[torch.Tensor]], |
| 1370 | + ]: |
| 1371 | + return self.emb_module.split_embedding_weights(no_snapshot) |
| 1372 | + |
| 1373 | + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: |
| 1374 | + # invalidate the cached split weights result during training |
| 1375 | + self._split_weights_res = None |
| 1376 | + |
| 1377 | + return self.emb_module( |
| 1378 | + indices=features.values().long(), |
| 1379 | + offsets=features.offsets().long(), |
| 1380 | + ) |
| 1381 | + |
| 1382 | + |
1098 | 1383 | class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
1099 | 1384 | def __init__(
1100 | 1385 | self,

@@ -1563,6 +1848,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat

1563 | 1848 | Return an iterator over embedding tables, yielding both the table name as well as the embedding
1564 | 1849 | table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
1565 | 1850 | RocksDB snapshot to support windowed access.
| 1851 | + Also yields an optional ShardedTensor for weight_id; not used here since this is non-KV-ZCH. |
| 1852 | + Also yields an optional ShardedTensor for bucket_cnt; not used here since this is non-KV-ZCH. |
1566 | 1853 | """
1567 | 1854 | for config, tensor in zip(
1568 | 1855 | self._config.embedding_tables,