|
50 | 50 | from torchrec.distributed.composable.table_batched_embedding_slice import (
|
51 | 51 | TableBatchedEmbeddingSlice,
|
52 | 52 | )
|
53 | | -from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict |
| 53 | +from torchrec.distributed.embedding_kernel import ( |
| 54 | + BaseEmbedding, |
| 55 | + create_virtual_sharded_tensors, |
| 56 | + get_state_dict, |
| 57 | +) |
54 | 58 | from torchrec.distributed.embedding_types import (
|
55 | 59 | compute_kernel_to_embedding_location,
|
56 | 60 | DTensorMetadata,
|
|
65 | 69 | ShardMetadata,
|
66 | 70 | TensorProperties,
|
67 | 71 | )
|
68 | | -from torchrec.distributed.utils import append_prefix |
| 72 | +from torchrec.distributed.utils import append_prefix, none_throws |
69 | 73 | from torchrec.modules.embedding_configs import (
|
70 | 74 | data_type_to_sparse_type,
|
71 | 75 | pooling_type_to_pooling_mode,
|
@@ -169,6 +173,22 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
|
169 | 173 | return ssd_tbe_params
|
170 | 174 |
|
171 | 175 |
|
| 176 | +def _populate_zero_collision_tbe_params( |
| 177 | + tbe_params: Dict[str, Any], |
| 178 | + sharded_local_buckets: List[Tuple[int, int, int]], |
| 179 | +) -> None: |
| 180 | + """ |
| 181 | + Populate zero-collision (KV ZCH) TBE params in the fused params dict from the sharded local bucket spec. |
| 182 | + """ |
| 183 | + tbe_params["zero_collision_tbe"] = True |
| 184 | + if tbe_params.get("zero_collision_tbe", False): |
| 185 | + tbe_params["bucket_offsets"] = [ |
| 186 | + (offset_start, offset_end) |
| 187 | + for offset_start, offset_end, _ in sharded_local_buckets |
| 188 | + ] |
| 189 | + tbe_params["bucket_sizes"] = [size for _, _, size in sharded_local_buckets] |
| 190 | + |
| 191 | + |
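# Illustrative sketch, not part of the changeset: the expected effect of
# _populate_zero_collision_tbe_params on a fused-params dict. The pre-existing
# "cache_sets" key and the toy bucket values are hypothetical.
ssd_tbe_params = {"cache_sets": 1024}
_populate_zero_collision_tbe_params(
    ssd_tbe_params,
    sharded_local_buckets=[(0, 20, 12500), (20, 40, 12500)],
)
assert ssd_tbe_params["zero_collision_tbe"] is True
assert ssd_tbe_params["bucket_offsets"] == [(0, 20), (20, 40)]
assert ssd_tbe_params["bucket_sizes"] == [12500, 12500]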
172 | 192 | class KeyValueEmbeddingFusedOptimizer(FusedOptimizer):
|
173 | 193 | def __init__(
|
174 | 194 | self,
|
@@ -631,24 +651,6 @@ def update_hyper_parameters(self, params_dict: Dict[str, Any]) -> None:
|
631 | 651 | self._emb_module.update_hyper_parameters(params_dict)
|
632 | 652 |
|
633 | 653 |
|
634 | | -def _gen_named_parameters_by_table_ssd( |
635 | | - emb_module: SSDTableBatchedEmbeddingBags, |
636 | | - table_name_to_count: Dict[str, int], |
637 | | - config: GroupedEmbeddingConfig, |
638 | | - pg: Optional[dist.ProcessGroup] = None, |
639 | | -) -> Iterator[Tuple[str, nn.Parameter]]: |
640 | | - """ |
641 | | - Return an empty tensor to indicate that the table is on remote device. |
642 | | - """ |
643 | | - for table in config.embedding_tables: |
644 | | - table_name = table.name |
645 | | - # placeholder |
646 | | - weight: nn.Parameter = nn.Parameter(torch.empty(0)) |
647 | | - # pyre-ignore |
648 | | - weight._in_backward_optimizers = [EmptyFusedOptimizer()] |
649 | | - yield (table_name, weight) |
650 | | - |
651 | | - |
652 | 654 | def _gen_named_parameters_by_table_ssd_pmt(
|
653 | 655 | emb_module: SSDTableBatchedEmbeddingBags,
|
654 | 656 | table_name_to_count: Dict[str, int],
|
@@ -911,6 +913,10 @@ def __init__(
|
911 | 913 | **ssd_tbe_params,
|
912 | 914 | ).to(device)
|
913 | 915 |
|
| 916 | + logger.info( |
| 917 | + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" |
| 918 | + ) |
| 919 | + |
914 | 920 | self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer(
|
915 | 921 | config,
|
916 | 922 | self._emb_module,
|
@@ -1019,6 +1025,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
|
1019 | 1025 | Return an iterator over embedding tables, yielding both the table name as well as the embedding
|
1020 | 1026 | table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
|
1021 | 1027 | RocksDB snapshot to support windowed access.
|
| 1028 | + An optional ShardedTensor for weight_id is also yielded; it is unused here since this kernel is non-kvzch. |
| 1029 | + An optional ShardedTensor for bucket_cnt is also yielded; it is unused here since this kernel is non-kvzch. |
1022 | 1030 | """
|
1023 | 1031 | for config, tensor in zip(
|
1024 | 1032 | self._config.embedding_tables,
|
@@ -1050,6 +1058,279 @@ def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
|
1050 | 1058 | return self.emb_module.split_embedding_weights(no_snapshot)
|
1051 | 1059 |
|
1052 | 1060 |
|
| 1061 | +class ZeroCollisionKeyValueEmbedding( |
| 1062 | + BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule |
| 1063 | +): |
| 1064 | + def __init__( |
| 1065 | + self, |
| 1066 | + config: GroupedEmbeddingConfig, |
| 1067 | + pg: Optional[dist.ProcessGroup] = None, |
| 1068 | + device: Optional[torch.device] = None, |
| 1069 | + ) -> None: |
| 1070 | + super().__init__(config, pg, device) |
| 1071 | + |
| 1072 | + assert ( |
| 1073 | + len(config.embedding_tables) > 0 |
| 1074 | + ), "Expected to see at least one table in SSD TBE, but found 0." |
| 1075 | + assert ( |
| 1076 | + len({table.embedding_dim for table in config.embedding_tables}) == 1 |
| 1077 | + ), "Currently we expect all tables in SSD TBE to have the same embedding dimension." |
| 1078 | + |
| 1079 | + ssd_tbe_params = _populate_ssd_tbe_params(config) |
| 1080 | + self._bucket_spec: List[Tuple[int, int, int]] = self.get_sharded_local_buckets() |
| 1081 | + _populate_zero_collision_tbe_params(ssd_tbe_params, self._bucket_spec) |
| 1082 | + compute_kernel = config.embedding_tables[0].compute_kernel |
| 1083 | + embedding_location = compute_kernel_to_embedding_location(compute_kernel) |
| 1084 | + |
| 1085 | + self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags( |
| 1086 | + embedding_specs=list( |
| 1087 | + zip(self._num_embeddings, self._local_cols, self._local_rows) |
| 1088 | + ), |
| 1089 | + feature_table_map=self._feature_table_map, |
| 1090 | + ssd_cache_location=embedding_location, |
| 1091 | + pooling_mode=PoolingMode.NONE, |
| 1092 | + **ssd_tbe_params, |
| 1093 | + ).to(device) |
| 1094 | + |
| 1095 | + logger.info( |
| 1096 | + f"tbe_unique_id:{self._emb_module.tbe_unique_id} => table name to count dict:{self.table_name_to_count}" |
| 1097 | + ) |
| 1098 | + |
| 1099 | + self._optim: KeyValueEmbeddingFusedOptimizer = KeyValueEmbeddingFusedOptimizer( |
| 1100 | + config, |
| 1101 | + self._emb_module, |
| 1102 | + pg, |
| 1103 | + ) |
| 1104 | + self._param_per_table: Dict[str, nn.Parameter] = dict( |
| 1105 | + _gen_named_parameters_by_table_ssd_pmt( |
| 1106 | + emb_module=self._emb_module, |
| 1107 | + table_name_to_count=self.table_name_to_count.copy(), |
| 1108 | + config=self._config, |
| 1109 | + pg=pg, |
| 1110 | + ) |
| 1111 | + ) |
| 1112 | + self.init_parameters() |
| 1113 | + |
| 1114 | + # every split_embedding_weights call is expensive, since it iterates over all the elements in the backend kv db |
| 1115 | + # cache the split-weights result so that multiple calls in the same train iteration only trigger the work once |
| 1116 | + self._split_weights_res: Optional[ |
| 1117 | + Tuple[ |
| 1118 | + List[ShardedTensor], |
| 1119 | + List[ShardedTensor], |
| 1120 | + List[ShardedTensor], |
| 1121 | + ] |
| 1122 | + ] = None |
| 1123 | + |
| 1124 | + def init_parameters(self) -> None: |
| 1125 | + """ |
| 1126 | + An advantage of KV TBE is that we don't need to init weights. Hence skipping. |
| 1127 | + """ |
| 1128 | + pass |
| 1129 | + |
| 1130 | + @property |
| 1131 | + def emb_module( |
| 1132 | + self, |
| 1133 | + ) -> SSDTableBatchedEmbeddingBags: |
| 1134 | + return self._emb_module |
| 1135 | + |
| 1136 | + @property |
| 1137 | + def fused_optimizer(self) -> FusedOptimizer: |
| 1138 | + """ |
| 1139 | + SSD Embedding fuses backward with backward. |
| 1140 | + """ |
| 1141 | + return self._optim |
| 1142 | + |
| 1143 | + # TODO: this is a temporary hack, we should read shard info from torchrec sharding plan |
| 1144 | + def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]: |
| 1145 | + """ |
| 1146 | + Utility to compute (bucket offset start, bucket offset end, bucket size) for each table based on the embedding sharding spec. |
| 1147 | + """ |
| 1148 | + sharded_local_buckets: List[Tuple[int, int, int]] = [] |
| 1149 | + world_size = dist.get_world_size(self._pg) |
| 1150 | + local_rank = dist.get_rank(self._pg) |
| 1151 | + |
| 1152 | + for table in self._config.embedding_tables: |
| 1153 | + # temporary until the uneven-sharding utils are ready |
| 1154 | + # open question: what happens if we have uneven sharding on the training side? |
| 1155 | + assert ( |
| 1156 | + table.num_embeddings % world_size == 0 |
| 1157 | + ), "total_num_embeddings must be divisible by world_size" |
| 1158 | + total_num_buckets = none_throws(table.total_num_buckets) |
| 1159 | + bucket_offset_start = total_num_buckets // world_size * local_rank |
| 1160 | + bucket_offset_end = min( |
| 1161 | + total_num_buckets, total_num_buckets // world_size * (local_rank + 1) |
| 1162 | + ) |
| 1163 | + bucket_size = ( |
| 1164 | + table.num_embeddings + total_num_buckets - 1 |
| 1165 | + ) // total_num_buckets |
| 1166 | + sharded_local_buckets.append( |
| 1167 | + (bucket_offset_start, bucket_offset_end, bucket_size) |
| 1168 | + ) |
| 1169 | + logger.info( |
| 1170 | + f"bucket_offset: {bucket_offset_start}:{bucket_offset_end}, bucket_size: {bucket_size} for table {table.name}" |
| 1171 | + ) |
| 1172 | + return sharded_local_buckets |
| 1173 | + |
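# Illustrative sketch, not part of the changeset: the bucket math in
# get_sharded_local_buckets for a hypothetical table with num_embeddings=1_000_000
# and total_num_buckets=80, on a world_size=4 job, seen from local_rank=1.
world_size, local_rank = 4, 1
num_embeddings, total_num_buckets = 1_000_000, 80
bucket_offset_start = total_num_buckets // world_size * local_rank  # 20
bucket_offset_end = min(
    total_num_buckets, total_num_buckets // world_size * (local_rank + 1)
)  # 40
bucket_size = (num_embeddings + total_num_buckets - 1) // total_num_buckets  # 12500
# rank 1 therefore owns buckets [20, 40), each holding up to 12500 rows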
| 1174 | + def state_dict( |
| 1175 | + self, |
| 1176 | + destination: Optional[Dict[str, Any]] = None, |
| 1177 | + prefix: str = "", |
| 1178 | + keep_vars: bool = False, |
| 1179 | + no_snapshot: bool = True, |
| 1180 | + ) -> Dict[str, Any]: |
| 1181 | + """ |
| 1182 | + Args: |
| 1183 | + no_snapshot (bool): the tensors in the returned dict are |
| 1184 | + PartiallyMaterializedTensors. This argument controls whether the |
| 1185 | + PartiallyMaterializedTensor owns a RocksDB snapshot handle. True means the |
| 1186 | + PartiallyMaterializedTensor doesn't have a RocksDB snapshot handle. False means the |
| 1187 | + PartiallyMaterializedTensor has a RocksDB snapshot handle. |
| 1188 | + """ |
| 1189 | + # In the case no_snapshot=False, a flush is required. We rely on the flush operation in |
| 1190 | + # ShardedEmbeddingBagCollection._pre_state_dict_hook() |
| 1191 | + |
| 1192 | + emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot) |
| 1193 | + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) |
| 1194 | + for emb_table in emb_table_config_copy: |
| 1195 | + emb_table.local_metadata.placement._device = torch.device("cpu") |
| 1196 | + ret = get_state_dict( |
| 1197 | + emb_table_config_copy, |
| 1198 | + emb_tables, |
| 1199 | + self._pg, |
| 1200 | + destination, |
| 1201 | + prefix, |
| 1202 | + ) |
| 1203 | + return ret |
| 1204 | + |
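# Illustrative sketch, not part of the changeset: how the no_snapshot flag is
# expected to be used on this kernel. `kernel` is a hypothetical
# ZeroCollisionKeyValueEmbedding instance.
light_sd = kernel.state_dict(no_snapshot=True)  # PMTs carry no RocksDB snapshot handle
full_sd = kernel.state_dict(no_snapshot=False)  # PMTs hold a snapshot handle for windowed reads;
# the required flush is expected to come from ShardedEmbeddingBagCollection._pre_state_dict_hook()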
| 1205 | + def named_parameters( |
| 1206 | + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True |
| 1207 | + ) -> Iterator[Tuple[str, nn.Parameter]]: |
| 1208 | + """ |
| 1209 | + The only allowed way to get the state_dict. |
| 1210 | + """ |
| 1211 | + for name, tensor in self.named_split_embedding_weights( |
| 1212 | + prefix, recurse, remove_duplicate |
| 1213 | + ): |
| 1214 | + # hack before we support optimizer on sharded parameter level |
| 1215 | + # can delete after PEA deprecation |
| 1216 | + # pyre-ignore [6] |
| 1217 | + param = nn.Parameter(tensor) |
| 1218 | + # pyre-ignore |
| 1219 | + param._in_backward_optimizers = [EmptyFusedOptimizer()] |
| 1220 | + yield name, param |
| 1221 | + |
| 1222 | + # pyre-ignore [15] |
| 1223 | + def named_split_embedding_weights( |
| 1224 | + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True |
| 1225 | + ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]: |
| 1226 | + assert ( |
| 1227 | + remove_duplicate |
| 1228 | + ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights" |
| 1229 | + for config, tensor in zip( |
| 1230 | + self._config.embedding_tables, |
| 1231 | + self.split_embedding_weights()[0], |
| 1232 | + ): |
| 1233 | + key = append_prefix(prefix, f"{config.name}.weight") |
| 1234 | + yield key, tensor |
| 1235 | + |
| 1236 | + def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[ |
| 1237 | + Tuple[ |
| 1238 | + str, |
| 1239 | + Union[ShardedTensor, PartiallyMaterializedTensor], |
| 1240 | + Optional[ShardedTensor], |
| 1241 | + Optional[ShardedTensor], |
| 1242 | + ] |
| 1243 | + ]: |
| 1244 | + """ |
| 1245 | + Return an iterator over embedding tables, yielding for each table: |
| 1246 | + the table name, |
| 1247 | + a PMT for the embedding table with a valid RocksDB snapshot to support tensor IO, |
| 1248 | + an optional ShardedTensor for weight_id, |
| 1249 | + and an optional ShardedTensor for bucket_cnt. |
| 1250 | + """ |
| 1251 | + if self._split_weights_res is not None: |
| 1252 | + pmt_sharded_t_list = self._split_weights_res[0] |
| 1253 | + # pyre-ignore |
| 1254 | + weight_id_sharded_t_list = self._split_weights_res[1] |
| 1255 | + bucket_cnt_sharded_t_list = self._split_weights_res[2] |
| 1256 | + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): |
| 1257 | + table_config = self._config.embedding_tables[table_idx] |
| 1258 | + key = append_prefix(prefix, f"{table_config.name}") |
| 1259 | + |
| 1260 | + yield key, pmt_sharded_t, weight_id_sharded_t_list[ |
| 1261 | + table_idx |
| 1262 | + ], bucket_cnt_sharded_t_list[table_idx] |
| 1263 | + return |
| 1264 | + |
| 1265 | + pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights( |
| 1266 | + no_snapshot=False |
| 1267 | + ) |
| 1268 | + emb_table_config_copy = copy.deepcopy(self._config.embedding_tables) |
| 1269 | + for emb_table in emb_table_config_copy: |
| 1270 | + emb_table.local_metadata.placement._device = torch.device("cpu") |
| 1271 | + |
| 1272 | + pmt_sharded_t_list = create_virtual_sharded_tensors( |
| 1273 | + emb_table_config_copy, pmt_list, self._pg, prefix |
| 1274 | + ) |
| 1275 | + weight_id_sharded_t_list = create_virtual_sharded_tensors( |
| 1276 | + emb_table_config_copy, weight_ids_list, self._pg, prefix # pyre-ignore |
| 1277 | + ) |
| 1278 | + bucket_cnt_sharded_t_list = create_virtual_sharded_tensors( |
| 1279 | + emb_table_config_copy, bucket_cnt_list, self._pg, prefix # pyre-ignore |
| 1280 | + ) |
| 1281 | + # pyre-ignore |
| 1282 | + assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list) |
| 1283 | + assert ( |
| 1284 | + len(pmt_sharded_t_list) |
| 1285 | + == len(weight_id_sharded_t_list) |
| 1286 | + == len(bucket_cnt_sharded_t_list) |
| 1287 | + ) |
| 1288 | + for table_idx, pmt_sharded_t in enumerate(pmt_sharded_t_list): |
| 1289 | + table_config = self._config.embedding_tables[table_idx] |
| 1290 | + key = append_prefix(prefix, f"{table_config.name}") |
| 1291 | + |
| 1292 | + yield key, pmt_sharded_t, weight_id_sharded_t_list[ |
| 1293 | + table_idx |
| 1294 | + ], bucket_cnt_sharded_t_list[table_idx] |
| 1295 | + |
| 1296 | + self._split_weights_res = ( |
| 1297 | + pmt_sharded_t_list, |
| 1298 | + weight_id_sharded_t_list, |
| 1299 | + bucket_cnt_sharded_t_list, |
| 1300 | + ) |
| 1301 | + |
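# Illustrative sketch, not part of the changeset: consuming the snapshot iterator
# above during checkpointing. `kernel` and `writer` are hypothetical.
for key, pmt_st, weight_id_st, bucket_cnt_st in kernel.get_named_split_embedding_weights_snapshot(
    prefix="embeddings."
):
    writer.save(f"{key}.weight", pmt_st)  # ShardedTensor wrapping PMT shards
    if weight_id_st is not None:
        writer.save(f"{key}.weight_id", weight_id_st)
    if bucket_cnt_st is not None:
        writer.save(f"{key}.bucket_cnt", bucket_cnt_st)
# repeated calls in the same train iteration reuse the cached _split_weights_res;
# the cache is invalidated on the next forward()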
| 1302 | + def flush(self) -> None: |
| 1303 | + """ |
| 1304 | + Flush the embeddings in cache back to SSD. Should be pretty expensive. |
| 1305 | + """ |
| 1306 | + self.emb_module.flush() |
| 1307 | + |
| 1308 | + def purge(self) -> None: |
| 1309 | + """ |
| 1310 | + Reset the cache space. This is needed when we load state dict. |
| 1311 | + """ |
| 1312 | + # TODO: move the following to SSD TBE. |
| 1313 | + self.emb_module.lxu_cache_weights.zero_() |
| 1314 | + self.emb_module.lxu_cache_state.fill_(-1) |
| 1315 | + |
| 1316 | + # pyre-ignore [15] |
| 1317 | + def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[ |
| 1318 | + List[PartiallyMaterializedTensor], |
| 1319 | + Optional[List[torch.Tensor]], |
| 1320 | + Optional[List[torch.Tensor]], |
| 1321 | + ]: |
| 1322 | + return self.emb_module.split_embedding_weights(no_snapshot) |
| 1323 | + |
| 1324 | + def forward(self, features: KeyedJaggedTensor) -> torch.Tensor: |
| 1325 | + # reset split weights during training |
| 1326 | + self._split_weights_res = None |
| 1327 | + |
| 1328 | + return self.emb_module( |
| 1329 | + indices=features.values().long(), |
| 1330 | + offsets=features.offsets().long(), |
| 1331 | + ) |
| 1332 | + |
| 1333 | + |
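# Illustrative sketch, not part of the changeset: the forward contract of the
# kernel above. `kernel` is a hypothetical ZeroCollisionKeyValueEmbedding and
# the feature ids are made up.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([101, 202, 303]),
    lengths=torch.tensor([1, 2]),
)
out = kernel(features)  # per-id embeddings of shape (3, embedding_dim), since pooling_mode is NONE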
1053 | 1334 | class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
|
1054 | 1335 | def __init__(
|
1055 | 1336 | self,
|
@@ -1518,6 +1799,8 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
|
1518 | 1799 | Return an iterator over embedding tables, yielding both the table name as well as the embedding
|
1519 | 1800 | table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
|
1520 | 1801 | RocksDB snapshot to support windowed access.
|
| 1802 | + An optional ShardedTensor for weight_id is also yielded; it is unused here since this kernel is non-kvzch. |
| 1803 | + An optional ShardedTensor for bucket_cnt is also yielded; it is unused here since this kernel is non-kvzch. |
1521 | 1804 | """
|
1522 | 1805 | for config, tensor in zip(
|
1523 | 1806 | self._config.embedding_tables,
|
|