diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 04f07a574..390413b2d 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -215,6 +215,237 @@ def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
 
+class BaseBatchedEmbedding(BaseEmbedding):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__()
+        torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
+        self._config = config
+        self._pg = pg
+
+        self._local_rows: List[int] = []
+        self._weight_init_mins: List[float] = []
+        self._weight_init_maxs: List[float] = []
+        self._num_embeddings: List[int] = []
+        self._local_cols: List[int] = []
+        self._feature_table_map: List[int] = []
+
+        for idx, config in enumerate(self._config.embedding_tables):
+            self._local_rows.append(config.local_rows)
+            self._weight_init_mins.append(config.get_weight_init_min())
+            self._weight_init_maxs.append(config.get_weight_init_max())
+            self._num_embeddings.append(config.num_embeddings)
+            self._local_cols.append(config.local_cols)
+            self._feature_table_map.extend([idx] * config.num_features())
+
+    def init_parameters(self) -> None:
+        # initialize embedding weights
+        assert len(self._num_embeddings) == len(self.split_embedding_weights())
+        for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
+            self._local_rows,
+            self._local_cols,
+            self._weight_init_mins,
+            self._weight_init_maxs,
+            self.split_embedding_weights(),
+        ):
+            assert param.shape == (rows, emb_dim)
+            param.data.uniform_(
+                weight_init_min,
+                weight_init_max,
+            )
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        return self.emb_module(
+            indices=features.values().long(),
+            offsets=features.offsets().long(),
+        )
+
+    def state_dict(
+        self,
+        destination: Optional[Dict[str, Any]] = None,
+        prefix: str = "",
+        keep_vars: bool = False,
+    ) -> Dict[str, Any]:
+        if destination is None:
+            destination = OrderedDict()
+            # pyre-ignore [16]
+            destination._metadata = OrderedDict()
+
+        for config, param in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights(),
+        ):
+            key = prefix + f"{config.name}.weight"
+            assert config.local_rows == param.size(0)
+            assert config.local_cols == param.size(1)
+            if config.global_metadata is not None:
+                # set additional field of sharded tensor based on local tensor properties
+                config.global_metadata.tensor_properties.dtype = param.dtype
+                config.global_metadata.tensor_properties.requires_grad = (
+                    param.requires_grad
+                )
+                destination[
+                    key
+                ] = ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=[Shard(param, config.local_metadata)],
+                    sharded_tensor_metadata=config.global_metadata,
+                    process_group=self._pg,
+                )
+            else:
+                destination[key] = param
+        return destination
+
+    def split_embedding_weights(self) -> List[torch.Tensor]:
+        return self.emb_module.split_embedding_weights()
+
+    @property
+    @abc.abstractmethod
+    def emb_module(
+        self,
+    ) -> Union[
+        DenseTableBatchedEmbeddingBagsCodegen,
+        SplitTableBatchedEmbeddingBagsCodegen,
+        IntNBitTableBatchedEmbeddingBagsCodegen,
+    ]:
+        ...
+
+    def config(self) -> GroupedEmbeddingConfig:
+        return self._config
+
+
+class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+        fused_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__(config, pg, device)
+
+        def to_embedding_location(
+            compute_kernel: EmbeddingComputeKernel,
+        ) -> EmbeddingLocation:
+            if compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED:
+                return EmbeddingLocation.DEVICE
+            elif compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM:
+                return EmbeddingLocation.MANAGED
+            elif compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING:
+                return EmbeddingLocation.MANAGED_CACHING
+            else:
+                raise ValueError(f"Invalid EmbeddingComputeKernel {compute_kernel}")
+
+        managed: List[EmbeddingLocation] = []
+        compute_devices: List[ComputeDevice] = []
+        for table in config.embedding_tables:
+            if device is not None and device.type == "cuda":
+                compute_devices.append(ComputeDevice.CUDA)
+                managed.append(to_embedding_location(table.compute_kernel))
+            else:
+                compute_devices.append(ComputeDevice.CPU)
+                managed.append(EmbeddingLocation.HOST)
+        if fused_params is None:
+            fused_params = {}
+        self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
+            SplitTableBatchedEmbeddingBagsCodegen(
+                embedding_specs=list(
+                    zip(self._local_rows, self._local_cols, managed, compute_devices)
+                ),
+                feature_table_map=self._feature_table_map,
+                pooling_mode=PoolingMode.NONE,
+                weights_precision=BatchedFusedEmbeddingBag.to_sparse_type(
+                    config.data_type
+                ),
+                device=device,
+                **fused_params,
+            )
+        )
+        self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
+            config,
+            self._emb_module,
+            pg,
+        )
+
+        self.init_parameters()
+
+    @staticmethod
+    def to_sparse_type(data_type: DataType) -> SparseType:
+        if data_type == DataType.FP32:
+            return SparseType.FP32
+        elif data_type == DataType.FP16:
+            return SparseType.FP16
+        elif data_type == DataType.INT8:
+            return SparseType.INT8
+        else:
+            raise ValueError(f"Invalid DataType {data_type}")
+
+    @property
+    def emb_module(
+        self,
+    ) -> SplitTableBatchedEmbeddingBagsCodegen:
+        return self._emb_module
+
+    @property
+    def fused_optimizer(self) -> FusedOptimizer:
+        return self._optim
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        yield from ()
+
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        for config, param in zip(
+            self._config.embedding_tables,
+            self.emb_module.split_embedding_weights(),
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, param
+
+
+class BatchedDenseEmbedding(BaseBatchedEmbedding):
+    def __init__(
+        self,
+        config: GroupedEmbeddingConfig,
+        pg: Optional[dist.ProcessGroup] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
+        super().__init__(config, pg, device)
+
+        self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
+            DenseTableBatchedEmbeddingBagsCodegen(
+                list(zip(self._local_rows, self._local_cols)),
+                feature_table_map=self._feature_table_map,
+                pooling_mode=PoolingMode.NONE,
+                use_cpu=device is None or device.type == "cpu",
+            )
+        )
+
+        self.init_parameters()
+
+    @property
+    def emb_module(
+        self,
+    ) -> DenseTableBatchedEmbeddingBagsCodegen:
+        return self._emb_module
+
+    def named_parameters(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, nn.Parameter]]:
+        combined_key = "/".join(
+            [config.name for config in self._config.embedding_tables]
+        )
+        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
+            nn.Parameter, self._emb_module.weights
+        )
+
+
 class GroupedEmbeddingsLookup(BaseEmbeddingLookup):
     def __init__(
         self,
@@ -226,7 +457,20 @@ def __init__(
         def _create_lookup(
             config: GroupedEmbeddingConfig,
         ) -> BaseEmbedding:
-            if config.compute_kernel == EmbeddingComputeKernel.DENSE:
+            if config.compute_kernel == EmbeddingComputeKernel.BATCHED_DENSE:
+                return BatchedDenseEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                )
+            elif config.compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED:
+                return BatchedFusedEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    fused_params=fused_params,
+                )
+            elif config.compute_kernel == EmbeddingComputeKernel.DENSE:
                 return GroupedEmbedding(
                     config=config,
                     sparse=False,
@@ -635,7 +879,7 @@ def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
 
-class EmbeddingBagFusedOptimizer(FusedOptimizer):
+class EmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(
         self,
         config: GroupedEmbeddingConfig,
@@ -828,7 +1072,7 @@ def to_embedding_location(
                 **fused_params,
             )
         )
-        self._optim: EmbeddingBagFusedOptimizer = EmbeddingBagFusedOptimizer(
+        self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
             config,
             self._emb_module,
             pg,
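
Note for reviewers (not part of the diff): the sketch below is a minimal, standalone illustration of the FBGEMM calling convention that `BaseBatchedEmbedding.forward` relies on, i.e. the flat `indices`/`offsets` pulled from a `KeyedJaggedTensor` and `PoolingMode.NONE` for unpooled sequence embeddings. The table shapes and feature layout are made-up assumptions; it only reuses the `DenseTableBatchedEmbeddingBagsCodegen` constructor arguments that `BatchedDenseEmbedding` passes above, and assumes `fbgemm_gpu` is installed.

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops import (
    DenseTableBatchedEmbeddingBagsCodegen,
    PoolingMode,
)

# Two hypothetical tables (100 rows x 8 dims, 50 rows x 8 dims), one feature each,
# mirroring BatchedDenseEmbedding's `list(zip(self._local_rows, self._local_cols))`
# and `feature_table_map`.
emb_module = DenseTableBatchedEmbeddingBagsCodegen(
    [(100, 8), (50, 8)],
    feature_table_map=[0, 1],
    pooling_mode=PoolingMode.NONE,  # sequence (unpooled) embeddings, as in this diff
    use_cpu=True,
)

# Flat KJT-style lookup for one sample: feature 0 reads ids [3, 7], feature 1 reads [11].
indices = torch.tensor([3, 7, 11], dtype=torch.int64)
offsets = torch.tensor([0, 2, 3], dtype=torch.int64)

# Same call shape as BaseBatchedEmbedding.forward(features): with PoolingMode.NONE,
# each looked-up id contributes its own embedding row instead of being pooled per bag.
out = emb_module(indices=indices, offsets=offsets)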