250 changes: 247 additions & 3 deletions torchrec/distributed/embedding_lookup.py
@@ -215,6 +215,237 @@ def config(self) -> GroupedEmbeddingConfig:
        return self._config


class BaseBatchedEmbedding(BaseEmbedding):
    def __init__(
        self,
        config: GroupedEmbeddingConfig,
        pg: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
        self._config = config
        self._pg = pg

        self._local_rows: List[int] = []
        self._weight_init_mins: List[float] = []
        self._weight_init_maxs: List[float] = []
        self._num_embeddings: List[int] = []
        self._local_cols: List[int] = []
        self._feature_table_map: List[int] = []

        for idx, config in enumerate(self._config.embedding_tables):
            self._local_rows.append(config.local_rows)
            self._weight_init_mins.append(config.get_weight_init_min())
            self._weight_init_maxs.append(config.get_weight_init_max())
            self._num_embeddings.append(config.num_embeddings)
            self._local_cols.append(config.local_cols)
            self._feature_table_map.extend([idx] * config.num_features())

    def init_parameters(self) -> None:
        # initialize embedding weights
        assert len(self._num_embeddings) == len(self.split_embedding_weights())
        for (rows, emb_dim, weight_init_min, weight_init_max, param) in zip(
            self._local_rows,
            self._local_cols,
            self._weight_init_mins,
            self._weight_init_maxs,
            self.split_embedding_weights(),
        ):
            assert param.shape == (rows, emb_dim)
            param.data.uniform_(
                weight_init_min,
                weight_init_max,
            )

    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
        return self.emb_module(
            indices=features.values().long(),
            offsets=features.offsets().long(),
        )

    def state_dict(
        self,
        destination: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> Dict[str, Any]:
        if destination is None:
            destination = OrderedDict()
            # pyre-ignore [16]
            destination._metadata = OrderedDict()

        for config, param in zip(
            self._config.embedding_tables,
            self.split_embedding_weights(),
        ):
            key = prefix + f"{config.name}.weight"
            assert config.local_rows == param.size(0)
            assert config.local_cols == param.size(1)
            if config.global_metadata is not None:
                # set additional fields of the sharded tensor based on local tensor properties
                config.global_metadata.tensor_properties.dtype = param.dtype
                config.global_metadata.tensor_properties.requires_grad = (
                    param.requires_grad
                )
                destination[
                    key
                ] = ShardedTensor._init_from_local_shards_and_global_metadata(
                    local_shards=[Shard(param, config.local_metadata)],
                    sharded_tensor_metadata=config.global_metadata,
                    process_group=self._pg,
                )
            else:
                destination[key] = param
        return destination

    def split_embedding_weights(self) -> List[torch.Tensor]:
        return self.emb_module.split_embedding_weights()

    @property
    @abc.abstractmethod
    def emb_module(
        self,
    ) -> Union[
        DenseTableBatchedEmbeddingBagsCodegen,
        SplitTableBatchedEmbeddingBagsCodegen,
        IntNBitTableBatchedEmbeddingBagsCodegen,
    ]:
        ...

    def config(self) -> GroupedEmbeddingConfig:
        return self._config
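For intuition, here is a minimal, self-contained sketch (not part of the diff; the table layout is made up) of how `BaseBatchedEmbedding.__init__` builds `_feature_table_map`, which maps every feature to the index of the grouped table that serves it:

# Sketch only: two hypothetical tables, table 0 serving two features and
# table 1 serving one, mirroring the loop over embedding_tables above.
num_features_per_table = [2, 1]

feature_table_map = []
for idx, num_features in enumerate(num_features_per_table):
    feature_table_map.extend([idx] * num_features)  # feature -> owning table

assert feature_table_map == [0, 0, 1]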


class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule):
    def __init__(
        self,
        config: GroupedEmbeddingConfig,
        pg: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
        fused_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__(config, pg, device)

        def to_embedding_location(
            compute_kernel: EmbeddingComputeKernel,
        ) -> EmbeddingLocation:
            if compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED:
                return EmbeddingLocation.DEVICE
            elif compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM:
                return EmbeddingLocation.MANAGED
            elif compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED_UVM_CACHING:
                return EmbeddingLocation.MANAGED_CACHING
            else:
                raise ValueError(f"Invalid EmbeddingComputeKernel {compute_kernel}")

        managed: List[EmbeddingLocation] = []
        compute_devices: List[ComputeDevice] = []
        for table in config.embedding_tables:
            if device is not None and device.type == "cuda":
                compute_devices.append(ComputeDevice.CUDA)
                managed.append(to_embedding_location(table.compute_kernel))
            else:
                compute_devices.append(ComputeDevice.CPU)
                managed.append(EmbeddingLocation.HOST)
        if fused_params is None:
            fused_params = {}
        self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
            SplitTableBatchedEmbeddingBagsCodegen(
                embedding_specs=list(
                    zip(self._local_rows, self._local_cols, managed, compute_devices)
                ),
                feature_table_map=self._feature_table_map,
                pooling_mode=PoolingMode.NONE,
                weights_precision=BatchedFusedEmbeddingBag.to_sparse_type(
                    config.data_type
                ),
                device=device,
                **fused_params,
            )
        )
        self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
            config,
            self._emb_module,
            pg,
        )

        self.init_parameters()

    @staticmethod
    def to_sparse_type(data_type: DataType) -> SparseType:
        if data_type == DataType.FP32:
            return SparseType.FP32
        elif data_type == DataType.FP16:
            return SparseType.FP16
        elif data_type == DataType.INT8:
            return SparseType.INT8
        else:
            raise ValueError(f"Invalid DataType {data_type}")

    @property
    def emb_module(
        self,
    ) -> SplitTableBatchedEmbeddingBagsCodegen:
        return self._emb_module

    @property
    def fused_optimizer(self) -> FusedOptimizer:
        return self._optim

    def named_parameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        yield from ()

    def named_buffers(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        for config, param in zip(
            self._config.embedding_tables,
            self.emb_module.split_embedding_weights(),
        ):
            key = append_prefix(prefix, f"{config.name}.weight")
            yield key, param
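A note on the empty `named_parameters` above: the fused kernel owns its weights and applies optimizer updates internally, so the weights surface only as buffers and a standard `torch.optim` optimizer never steps them. A toy stand-alone illustration of that registration pattern (plain PyTorch, not torchrec code):

import torch
import torch.nn as nn


class FusedStyleModule(nn.Module):
    # Toy stand-in for the fused kernel: the weight is a registered buffer,
    # so dense optimizers never see it; updates happen inside the module.
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("weight", torch.zeros(4, 8))


module = FusedStyleModule()
assert dict(module.named_parameters()) == {}  # nothing for torch.optim to step
assert "weight" in dict(module.named_buffers())  # still visible to state_dict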


class BatchedDenseEmbedding(BaseBatchedEmbedding):
    def __init__(
        self,
        config: GroupedEmbeddingConfig,
        pg: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__(config, pg, device)

        self._emb_module: DenseTableBatchedEmbeddingBagsCodegen = (
            DenseTableBatchedEmbeddingBagsCodegen(
                list(zip(self._local_rows, self._local_cols)),
                feature_table_map=self._feature_table_map,
                pooling_mode=PoolingMode.NONE,
                use_cpu=device is None or device.type == "cpu",
            )
        )

        self.init_parameters()

    @property
    def emb_module(
        self,
    ) -> DenseTableBatchedEmbeddingBagsCodegen:
        return self._emb_module

    def named_parameters(
        self, prefix: str = "", recurse: bool = True
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        combined_key = "/".join(
            [config.name for config in self._config.embedding_tables]
        )
        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
            nn.Parameter, self._emb_module.weights
        )
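Because `DenseTableBatchedEmbeddingBagsCodegen` holds all tables in a single flat weight tensor, `named_parameters` yields exactly one entry whose name joins the table names with "/". A quick sketch of the key construction with hypothetical table names:

table_names = ["t0", "t1", "t2"]  # hypothetical; the real names come from config.name
prefix = "lookup."

combined_key = "/".join(table_names)
assert f"{prefix}{combined_key}.weight" == "lookup.t0/t1/t2.weight"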


class GroupedEmbeddingsLookup(BaseEmbeddingLookup):
    def __init__(
        self,
@@ -226,7 +457,20 @@ def __init__(
        def _create_lookup(
            config: GroupedEmbeddingConfig,
        ) -> BaseEmbedding:
-            if config.compute_kernel == EmbeddingComputeKernel.DENSE:
+            if config.compute_kernel == EmbeddingComputeKernel.BATCHED_DENSE:
+                return BatchedDenseEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                )
+            elif config.compute_kernel == EmbeddingComputeKernel.BATCHED_FUSED:
+                return BatchedFusedEmbedding(
+                    config=config,
+                    pg=pg,
+                    device=device,
+                    fused_params=fused_params,
+                )
+            elif config.compute_kernel == EmbeddingComputeKernel.DENSE:
                return GroupedEmbedding(
                    config=config,
                    sparse=False,
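The branch added above routes each grouped config to a kernel-specific lookup. A condensed, self-contained sketch of that dispatch, with the enum and the returned classes reduced to stand-ins:

from enum import Enum, auto


class Kernel(Enum):  # stand-in for EmbeddingComputeKernel
    DENSE = auto()
    BATCHED_DENSE = auto()
    BATCHED_FUSED = auto()


def create_lookup(kernel: Kernel) -> str:
    # The real _create_lookup returns module instances, not names.
    if kernel == Kernel.BATCHED_DENSE:
        return "BatchedDenseEmbedding"
    elif kernel == Kernel.BATCHED_FUSED:
        return "BatchedFusedEmbedding"
    elif kernel == Kernel.DENSE:
        return "GroupedEmbedding"
    raise ValueError(f"unsupported kernel: {kernel}")


assert create_lookup(Kernel.BATCHED_FUSED) == "BatchedFusedEmbedding"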
@@ -635,7 +879,7 @@ def config(self) -> GroupedEmbeddingConfig:
        return self._config


-class EmbeddingBagFusedOptimizer(FusedOptimizer):
+class EmbeddingFusedOptimizer(FusedOptimizer):
    def __init__(
        self,
        config: GroupedEmbeddingConfig,
@@ -828,7 +1072,7 @@ def to_embedding_location(
                **fused_params,
            )
        )
-        self._optim: EmbeddingBagFusedOptimizer = EmbeddingBagFusedOptimizer(
+        self._optim: EmbeddingFusedOptimizer = EmbeddingFusedOptimizer(
            config,
            self._emb_module,
            pg,