32 changes: 22 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -339,12 +339,23 @@ def emb_module(
]:
...

@property
def config(self) -> GroupedEmbeddingConfig:
return self._config

def flush(self) -> None:
pass

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, param in zip(
self._config.embedding_tables,
self.emb_module.split_embedding_weights(),
):
key = append_prefix(prefix, f"{config.name}.weight")
yield key, param


class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule):
def __init__(
@@ -555,12 +566,23 @@ def emb_module(
]:
...

@property
def config(self) -> GroupedEmbeddingConfig:
return self._config

def flush(self) -> None:
pass

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, param in zip(
self._config.embedding_tables,
self.emb_module.split_embedding_weights(),
):
key = append_prefix(prefix, f"{config.name}.weight")
yield key, param


class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule):
def __init__(
@@ -630,16 +652,6 @@ def named_parameters(
) -> Iterator[Tuple[str, nn.Parameter]]:
yield from ()

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, param in zip(
self._config.embedding_tables,
self.emb_module.split_embedding_weights(),
):
key = append_prefix(prefix, f"{config.name}.weight")
yield key, param

def flush(self) -> None:
self._emb_module.flush()

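The named_buffers overrides added in this file (and mirrored in the files below) share one contract: one (key, tensor) pair per embedding table, keyed as "<prefix>.<table_name>.weight" via append_prefix. A minimal sketch of that key construction, assuming append_prefix joins non-empty parts with a dot and using hypothetical table names:

# Sketch of the key format yielded by the named_buffers overrides above.
# Table names are hypothetical; append_prefix comes from torchrec.distributed.utils.
from torchrec.distributed.utils import append_prefix

table_names = ["t_user", "t_item"]
keys = [append_prefix("embedding_bags", f"{name}.weight") for name in table_names]
# keys == ["embedding_bags.t_user.weight", "embedding_bags.t_item.weight"]
# Callers such as from_float (in quant_embedding_kernel.py below) consume this
# contract as state_dict = dict(module.named_buffers()).
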
24 changes: 23 additions & 1 deletion torchrec/distributed/embedding_kernel.py
@@ -204,9 +204,19 @@ def sparse_grad_parameter_names(
destination.append(append_prefix(prefix, f"{config.name}.weight"))
return destination

@property
def config(self) -> GroupedEmbeddingConfig:
return self._config

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, emb_module in zip(
self._config.embedding_tables,
self._emb_modules,
):
yield append_prefix(prefix, f"{config.name}.weight"), emb_module.weight


class GroupedEmbeddingBag(BaseEmbedding):
def __init__(
@@ -226,7 +236,6 @@ def __init__(
self._emb_names: List[str] = []
self._lengths_per_emb: List[int] = []

shared_feature: Dict[str, bool] = {}
for embedding_config in self._config.embedding_tables:
self._emb_modules.append(
nn.EmbeddingBag(
@@ -294,6 +303,18 @@ def named_parameters(
assert config.local_cols == param.size(1)
yield append_prefix(prefix, f"{config.name}.weight"), param

def named_buffers(
self, prefix: str = "", recurse: bool = True
) -> Iterator[Tuple[str, torch.Tensor]]:
for config, emb_module in zip(
self._config.embedding_tables,
self._emb_modules,
):
param = emb_module.weight
assert config.local_rows == param.size(0)
assert config.local_cols == param.size(1)
yield append_prefix(prefix, f"{config.name}.weight"), param

def sparse_grad_parameter_names(
self, destination: Optional[List[str]] = None, prefix: str = ""
) -> List[str]:
@@ -303,5 +324,6 @@ def sparse_grad_parameter_names(
destination.append(append_prefix(prefix, f"{config.name}.weight"))
return destination

@property
def config(self) -> GroupedEmbeddingConfig:
return self._config
110 changes: 107 additions & 3 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -5,8 +5,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from typing import List, Optional, Tuple, Iterator
from typing import List, Optional, Tuple, Iterator, Dict

import torch
import torch.distributed as dist
@@ -15,19 +16,80 @@
PoolingMode,
EmbeddingLocation,
IntNBitTableBatchedEmbeddingBagsCodegen,
rounded_row_size_in_bytes,
)
from torchrec.distributed.batched_embedding_kernel import (
BaseBatchedEmbeddingBag,
BaseBatchedEmbedding,
)
from torchrec.distributed.embedding_kernel import BaseEmbedding
from torchrec.distributed.embedding_types import GroupedEmbeddingConfig
from torchrec.distributed.utils import append_prefix
from torchrec.modules.embedding_configs import data_type_to_sparse_type
from torchrec.modules.embedding_configs import (
DATA_TYPE_NUM_BITS,
dtype_to_data_type,
data_type_to_sparse_type,
DataType,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


logger: logging.Logger = logging.getLogger(__name__)


def _copy_config(
original: GroupedEmbeddingConfig, data_type: DataType, sparse_type: SparseType
) -> GroupedEmbeddingConfig:
# Adjust config to quantized version.
# This obviously doesn't work for column-wise sharding.
config = copy.deepcopy(original)
config.data_type = data_type
for table in config.embedding_tables:
table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
if table.local_metadata is not None:
table.local_metadata.shard_sizes = [
table.local_rows,
table.local_cols,
]

global_metadata = table.global_metadata
if global_metadata is not None:
for shard_meta in global_metadata.shards_metadata:
if shard_meta != table.local_metadata:
shard_meta.shard_sizes = [
shard_meta.shard_sizes[0],
rounded_row_size_in_bytes(
shard_meta.shard_sizes[1], sparse_type
),
]
global_metadata.size = torch.Size(
[
global_metadata.size[0],
sum(
shard_meta.shard_sizes[1]
for shard_meta in global_metadata.shards_metadata
),
]
)
return config
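
As a rough worked example of the shape bookkeeping in _copy_config (the exact padding is delegated to fbgemm's rounded_row_size_in_bytes; the 4-byte scale/shift and any alignment rounding below are assumptions for illustration, not taken from this PR):

# Rough illustration of how an INT8 table's local_cols is rewritten to a
# byte-oriented width. Exact rounding belongs to rounded_row_size_in_bytes;
# the arithmetic here is only an approximation.
float_local_cols = 16
int8_payload_bytes = float_local_cols        # 1 byte per element at 8 bits
approx_quant_cols = int8_payload_bytes + 4   # plus a 2x fp16 scale/shift per row
# table.local_cols and the matching shard/global metadata are updated to this
# quantized width so downstream sharding math sees the packed layout.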


def _quantize_weight(
state_dict: Dict[str, torch.Tensor], data_type: DataType
) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
quant_weight_list = []
for weight in state_dict.values():
quantized_weights = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
weight, DATA_TYPE_NUM_BITS[data_type]
)
# weight and 4 byte scale shift (2xfp16)
quant_weight = quantized_weights[:, :-4]
scale_shift = quantized_weights[:, -4:]

quant_weight_list.append((quant_weight, scale_shift))
return quant_weight_list
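
The fused row layout that _quantize_weight slices apart can be sketched directly against the fbgemm operator it calls (assuming fbgemm_gpu is installed; sizes are illustrative):

# Sketch of the fused 8-bit row-wise layout split by _quantize_weight above.
import torch
import fbgemm_gpu  # noqa: F401  registers the torch.ops.fbgemm operators

weight = torch.randn(4, 16)  # toy table: 4 rows x 16 fp32 columns
quantized = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(weight, 8)

# Each row packs the quantized payload first, then 4 bytes (2x fp16) of
# scale/shift, which is what the [:, :-4] / [:, -4:] split recovers.
quant_weight = quantized[:, :-4]  # 16 payload bytes per row in this example
scale_shift = quantized[:, -4:]   # fp16 scale + fp16 shift per row
# expected: quantized.shape == (4, 20) for this 8-bit example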


class QuantBatchedEmbeddingBag(BaseBatchedEmbeddingBag):
def __init__(
self,
@@ -77,7 +139,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
indices=features.values().int(),
offsets=features.offsets().int(),
per_sample_weights=features.weights_or_none(),
).float()
)

def named_buffers(
self, prefix: str = "", recurse: bool = True
@@ -96,6 +158,27 @@ def split_embedding_weights(self) -> List[torch.Tensor]:
)
]

@classmethod
def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbeddingBag":
assert hasattr(
module, "qconfig"
), "BaseEmbedding input float module must have qconfig defined"

# pyre-ignore [16]
data_type = dtype_to_data_type(module.qconfig.weight().dtype)
sparse_type = data_type_to_sparse_type(data_type)

state_dict = dict(module.named_buffers())
device = next(iter(state_dict.values())).device

config = _copy_config(module.config, data_type, sparse_type)
ret = QuantBatchedEmbeddingBag(config=config, device=device)

quant_weight_list = _quantize_weight(state_dict, data_type)
ret.emb_module.assign_embedding_weights(quant_weight_list)

return ret
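
A hypothetical call site for the new from_float hook (the float kernel and the qconfig below are illustrative; from_float itself only reads qconfig.weight().dtype, config, and named_buffers):

# Hypothetical usage of QuantBatchedEmbeddingBag.from_float. "float_kernel" is
# assumed to be an already-constructed float kernel (e.g. the GroupedEmbeddingBag
# above) exposing the config property and named_buffers added in this PR.
import torch
from torch.quantization import PlaceholderObserver, QConfig

float_kernel.qconfig = QConfig(
    activation=PlaceholderObserver.with_args(dtype=torch.float),
    weight=PlaceholderObserver.with_args(dtype=torch.qint8),  # maps to DataType.INT8
)
quant_kernel = QuantBatchedEmbeddingBag.from_float(float_kernel)
# quant_kernel now owns int8 fbgemm tables initialized from float_kernel's buffers.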


class QuantBatchedEmbedding(BaseBatchedEmbedding):
def __init__(
@@ -159,3 +242,24 @@ def named_buffers(
self.emb_module.split_embedding_weights(),
):
yield append_prefix(prefix, f"{config.name}.weight"), weight[0]

@classmethod
def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbedding":
assert hasattr(
module, "qconfig"
), "BaseEmbedding input float module must have qconfig defined"

# pyre-ignore [16]
data_type = dtype_to_data_type(module.qconfig.weight().dtype)
sparse_type = data_type_to_sparse_type(data_type)

state_dict = dict(module.named_buffers())
device = next(iter(state_dict.values())).device

config = _copy_config(module.config, data_type, sparse_type)
ret = QuantBatchedEmbedding(config=config, device=device)

quant_weight_list = _quantize_weight(state_dict, data_type)
ret.emb_module.assign_embedding_weights(quant_weight_list)

return ret