diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 651a4bb75..4c70f0d25 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -339,12 +339,23 @@ def emb_module(
     ]:
         ...
 
+    @property
     def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
     def flush(self) -> None:
         pass
 
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        for config, param in zip(
+            self._config.embedding_tables,
+            self.emb_module.split_embedding_weights(),
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, param
+
 
 class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule):
     def __init__(
@@ -555,12 +566,23 @@ def emb_module(
     ]:
         ...
 
+    @property
     def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
     def flush(self) -> None:
         pass
 
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        for config, param in zip(
+            self._config.embedding_tables,
+            self.emb_module.split_embedding_weights(),
+        ):
+            key = append_prefix(prefix, f"{config.name}.weight")
+            yield key, param
+
 
 class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule):
     def __init__(
@@ -630,16 +652,6 @@ def named_parameters(
     ) -> Iterator[Tuple[str, nn.Parameter]]:
         yield from ()
 
-    def named_buffers(
-        self, prefix: str = "", recurse: bool = True
-    ) -> Iterator[Tuple[str, torch.Tensor]]:
-        for config, param in zip(
-            self._config.embedding_tables,
-            self.emb_module.split_embedding_weights(),
-        ):
-            key = append_prefix(prefix, f"{config.name}.weight")
-            yield key, param
-
     def flush(self) -> None:
         self._emb_module.flush()
 
diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index bc997f1ac..1f90df4b3 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -204,9 +204,19 @@ def sparse_grad_parameter_names(
             destination.append(append_prefix(prefix, f"{config.name}.weight"))
         return destination
 
+    @property
     def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        for config, emb_module in zip(
+            self._config.embedding_tables,
+            self._emb_modules,
+        ):
+            yield append_prefix(prefix, f"{config.name}.weight"), emb_module.weight
+
 
 class GroupedEmbeddingBag(BaseEmbedding):
     def __init__(
@@ -226,7 +236,6 @@ def __init__(
         self._emb_names: List[str] = []
         self._lengths_per_emb: List[int] = []
 
-        shared_feature: Dict[str, bool] = {}
         for embedding_config in self._config.embedding_tables:
             self._emb_modules.append(
                 nn.EmbeddingBag(
@@ -294,6 +303,18 @@ def named_parameters(
             assert config.local_cols == param.size(1)
             yield append_prefix(prefix, f"{config.name}.weight"), param
 
+    def named_buffers(
+        self, prefix: str = "", recurse: bool = True
+    ) -> Iterator[Tuple[str, torch.Tensor]]:
+        for config, emb_module in zip(
+            self._config.embedding_tables,
+            self._emb_modules,
+        ):
+            param = emb_module.weight
+            assert config.local_rows == param.size(0)
+            assert config.local_cols == param.size(1)
+            yield append_prefix(prefix, f"{config.name}.weight"), param
+
     def sparse_grad_parameter_names(
         self, destination: Optional[List[str]] = None, prefix: str = ""
     ) -> List[str]:
@@ -303,5 +324,6 @@ def sparse_grad_parameter_names(
             destination.append(append_prefix(prefix, f"{config.name}.weight"))
         return destination
 
+    @property
     def config(self) -> GroupedEmbeddingConfig:
         return self._config
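Illustrative sketch (not part of the patch): the two kernel files above give every grouped kernel a named_buffers() that yields its table weights under "<table_name>.weight" keys, and expose config as a property rather than a method. A consumer written against that surface, with collect_table_weights as a hypothetical helper name, could look like:

# Sketch only, not part of the patch; collect_table_weights is a hypothetical helper.
from typing import Dict

import torch
import torch.nn as nn


def collect_table_weights(kernel: nn.Module, prefix: str = "") -> Dict[str, torch.Tensor]:
    # Keys follow the "<prefix><table_name>.weight" convention of the
    # named_buffers() implementations added above.
    weights = dict(kernel.named_buffers(prefix=prefix))
    # config is now read as a property (kernel.config, not kernel.config()).
    assert len(weights) == len(kernel.config.embedding_tables)
    return weights

This is the same surface the from_float() conversions below rely on: dict(module.named_buffers()) plus module.config.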
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 60764a801..86f4f2b73 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -5,8 +5,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import logging
-from typing import List, Optional, Tuple, Iterator
+from typing import List, Optional, Tuple, Iterator, Dict
 
 import torch
 import torch.distributed as dist
@@ -15,19 +16,80 @@
     PoolingMode,
     EmbeddingLocation,
     IntNBitTableBatchedEmbeddingBagsCodegen,
+    rounded_row_size_in_bytes,
 )
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbeddingBag,
     BaseBatchedEmbedding,
 )
+from torchrec.distributed.embedding_kernel import BaseEmbedding
 from torchrec.distributed.embedding_types import GroupedEmbeddingConfig
 from torchrec.distributed.utils import append_prefix
-from torchrec.modules.embedding_configs import data_type_to_sparse_type
+from torchrec.modules.embedding_configs import (
+    DATA_TYPE_NUM_BITS,
+    dtype_to_data_type,
+    data_type_to_sparse_type,
+    DataType,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+def _copy_config(
+    original: GroupedEmbeddingConfig, data_type: DataType, sparse_type: SparseType
+) -> GroupedEmbeddingConfig:
+    # Adjust config to quantized version.
+    # This obviously doesn't work for column-wise sharding.
+    config = copy.deepcopy(original)
+    config.data_type = data_type
+    for table in config.embedding_tables:
+        table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
+        if table.local_metadata is not None:
+            table.local_metadata.shard_sizes = [
+                table.local_rows,
+                table.local_cols,
+            ]
+
+        global_metadata = table.global_metadata
+        if global_metadata is not None:
+            for shard_meta in global_metadata.shards_metadata:
+                if shard_meta != table.local_metadata:
+                    shard_meta.shard_sizes = [
+                        shard_meta.shard_sizes[0],
+                        rounded_row_size_in_bytes(
+                            shard_meta.shard_sizes[1], sparse_type
+                        ),
+                    ]
+            global_metadata.size = torch.Size(
+                [
+                    global_metadata.size[0],
+                    sum(
+                        shard_meta.shard_sizes[1]
+                        for shard_meta in global_metadata.shards_metadata
+                    ),
+                ]
+            )
+    return config
+
+
+def _quantize_weight(
+    state_dict: Dict[str, torch.Tensor], data_type: DataType
+) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+    quant_weight_list = []
+    for weight in state_dict.values():
+        quantized_weights = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
+            weight, DATA_TYPE_NUM_BITS[data_type]
+        )
+        # weight and 4 byte scale shift (2xfp16)
+        quant_weight = quantized_weights[:, :-4]
+        scale_shift = quantized_weights[:, -4:]
+
+        quant_weight_list.append((quant_weight, scale_shift))
+    return quant_weight_list
+
+
 class QuantBatchedEmbeddingBag(BaseBatchedEmbeddingBag):
     def __init__(
         self,
@@ -77,7 +139,7 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
             indices=features.values().int(),
             offsets=features.offsets().int(),
             per_sample_weights=features.weights_or_none(),
-        ).float()
+        )
 
     def named_buffers(
         self, prefix: str = "", recurse: bool = True
@@ -96,6 +158,27 @@ def split_embedding_weights(self) -> List[torch.Tensor]:
             )
         ]
 
+    @classmethod
+    def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbeddingBag":
+        assert hasattr(
+            module, "qconfig"
+        ), "BaseEmbedding input float module must have qconfig defined"
+
+        # pyre-ignore [16]
+        data_type = dtype_to_data_type(module.qconfig.weight().dtype)
+        sparse_type = data_type_to_sparse_type(data_type)
+
+        state_dict = dict(module.named_buffers())
+        device = next(iter(state_dict.values())).device
+
+        config = _copy_config(module.config, data_type, sparse_type)
+        ret = QuantBatchedEmbeddingBag(config=config, device=device)
+
+        quant_weight_list = _quantize_weight(state_dict, data_type)
+        ret.emb_module.assign_embedding_weights(quant_weight_list)
+
+        return ret
+
 
 class QuantBatchedEmbedding(BaseBatchedEmbedding):
     def __init__(
@@ -159,3 +242,24 @@ def named_buffers(
             self.emb_module.split_embedding_weights(),
         ):
             yield append_prefix(prefix, f"{config.name}.weight"), weight[0]
+
+    @classmethod
+    def from_float(cls, module: BaseEmbedding) -> "QuantBatchedEmbedding":
+        assert hasattr(
+            module, "qconfig"
+        ), "BaseEmbedding input float module must have qconfig defined"
+
+        # pyre-ignore [16]
+        data_type = dtype_to_data_type(module.qconfig.weight().dtype)
+        sparse_type = data_type_to_sparse_type(data_type)
+
+        state_dict = dict(module.named_buffers())
+        device = next(iter(state_dict.values())).device
+
+        config = _copy_config(module.config, data_type, sparse_type)
+        ret = QuantBatchedEmbedding(config=config, device=device)
+
+        quant_weight_list = _quantize_weight(state_dict, data_type)
+        ret.emb_module.assign_embedding_weights(quant_weight_list)
+
+        return ret
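For orientation, a sketch (not part of the patch; 8-bit case, requires fbgemm_gpu) of the fused row layout that _quantize_weight() and _copy_config() assume: FloatToFusedNBitRowwiseQuantizedSBHalf packs each row as the quantized payload followed by 4 trailing bytes holding the fp16 scale and shift, which is why _copy_config() widens local_cols with rounded_row_size_in_bytes.

# Sketch only, not part of the patch.
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the torch.ops.fbgemm operators

weight = torch.rand(10, 16, dtype=torch.float32)
bit_rate = 8  # DATA_TYPE_NUM_BITS[DataType.INT8]
fused = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(weight, bit_rate)

# Per the comment in _quantize_weight(): payload bytes first, then a 4-byte
# (2 x fp16) scale/shift pair at the end of every row.
payload, scale_shift = fused[:, :-4], fused[:, -4:]
assert fused.dtype == torch.uint8
assert fused.size(1) == weight.size(1) * bit_rate // 8 + 4  # 16 + 4 bytes per row here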
diff --git a/torchrec/distributed/tests/test_quantize.py b/torchrec/distributed/tests/test_quantize.py
new file mode 100644
index 000000000..4deea8b2a
--- /dev/null
+++ b/torchrec/distributed/tests/test_quantize.py
@@ -0,0 +1,207 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+
+import hypothesis.strategies as st
+import torch
+import torch.distributed as dist
+import torch.quantization as quant
+from hypothesis import Verbosity, given, settings
+from torchrec.distributed.embedding_kernel import GroupedEmbeddingBag
+from torchrec.distributed.embedding_lookup import (
+    BatchedFusedEmbeddingBag,
+    BatchedDenseEmbeddingBag,
+    GroupedPooledEmbeddingsLookup,
+    GroupedEmbeddingsLookup,
+    BatchedDenseEmbedding,
+    BatchedFusedEmbedding,
+    GroupedEmbedding,
+)
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embedding_types import (
+    GroupedEmbeddingConfig,
+    ShardedEmbeddingTable,
+    ShardMetadata,
+    ShardedTensorMetadata,
+)
+from torchrec.distributed.quant_embedding_kernel import (
+    QuantBatchedEmbeddingBag,
+    QuantBatchedEmbedding,
+)
+from torchrec.modules.embedding_configs import (
+    PoolingType,
+    DataType,
+)
+from torchrec.test_utils import get_free_port
+
+
+def quantize_sharded_embeddings(
+    module: torch.nn.Module, dtype: torch.dtype
+) -> torch.nn.Module:
+    qconfig = quant.QConfigDynamic(
+        activation=quant.PlaceholderObserver,
+        weight=quant.PlaceholderObserver.with_args(dtype=dtype),
+    )
+    return quant.quantize_dynamic(
+        module,
+        qconfig_spec={
+            GroupedEmbeddingBag: qconfig,
+            BatchedFusedEmbeddingBag: qconfig,
+            BatchedDenseEmbeddingBag: qconfig,
+            BatchedDenseEmbedding: qconfig,
+            BatchedFusedEmbedding: qconfig,
+            GroupedEmbedding: qconfig,
+        },
+        mapping={
+            GroupedEmbeddingBag: QuantBatchedEmbeddingBag,
+            BatchedFusedEmbeddingBag: QuantBatchedEmbeddingBag,
+            BatchedDenseEmbeddingBag: QuantBatchedEmbeddingBag,
+            BatchedDenseEmbedding: QuantBatchedEmbedding,
+            BatchedFusedEmbedding: QuantBatchedEmbedding,
+            GroupedEmbedding: QuantBatchedEmbedding,
+        },
+        inplace=False,
+    )
+
+
+class QuantizeKernelTest(unittest.TestCase):
+    def setUp(self) -> None:
+        os.environ["RANK"] = "0"
+        os.environ["WORLD_SIZE"] = "1"
+        os.environ["LOCAL_WORLD_SIZE"] = "1"
+        os.environ["MASTER_ADDR"] = str("localhost")
+        os.environ["MASTER_PORT"] = str(get_free_port())
+        os.environ["NCCL_SOCKET_IFNAME"] = "lo"
+        self.device = torch.device("cuda:0")
+        backend = "nccl"
+        torch.cuda.set_device(self.device)
+        dist.init_process_group(backend=backend)
+
+    def tearDown(self) -> None:
+        dist.destroy_process_group()
+        del os.environ["NCCL_SOCKET_IFNAME"]
+        super().tearDown()
+
+    def _create_config(
+        self, compute_kernel: EmbeddingComputeKernel
+    ) -> GroupedEmbeddingConfig:
+        num_embedding_tables = 2
+        embedding_tables = []
+        for i in range(num_embedding_tables):
+            rows = (i + 1) * 10
+            cols = 16
+            local_metadata = ShardMetadata(
+                shard_offsets=[0, 0],
+                shard_sizes=[rows, cols],
+                placement=torch.distributed._remote_device("rank:0/cuda:0"),
+            )
+            embedding_tables.append(
+                ShardedEmbeddingTable(
+                    num_embeddings=rows,
+                    embedding_dim=cols,
+                    name="table_" + str(i),
+                    feature_names=["feature_" + str(i)],
+                    pooling=PoolingType.MEAN,
+                    is_weighted=False,
+                    has_feature_processor=False,
+                    local_rows=rows,
+                    local_cols=cols,
+                    compute_kernel=compute_kernel,
+                    local_metadata=local_metadata,
+                    global_metadata=ShardedTensorMetadata(
+                        shards_metadata=[local_metadata],
+                        size=torch.Size([rows, cols]),
+                    ),
+                    weight_init_max=1.0,
+                    weight_init_min=0.0,
+                )
+            )
+        return GroupedEmbeddingConfig(
+            data_type=DataType.FP32,
+            pooling=PoolingType.MEAN,
+            is_weighted=False,
+            has_feature_processor=False,
+            compute_kernel=compute_kernel,
+            embedding_tables=embedding_tables,
+        )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-ignore [56]
+    @given(
+        compute_kernel=st.sampled_from(
+            [
+                EmbeddingComputeKernel.BATCHED_DENSE,
+                EmbeddingComputeKernel.BATCHED_FUSED,
+                EmbeddingComputeKernel.DENSE,
+                EmbeddingComputeKernel.SPARSE,
+            ]
+        ),
+        dtype=st.sampled_from(
+            [
+                torch.qint8,
+                torch.quint4x2,
+                torch.quint2x4,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=12, deadline=None)
+    def test_quantize_embedding_bag_kernels(
+        self, compute_kernel: EmbeddingComputeKernel, dtype: torch.dtype
+    ) -> None:
+        config = self._create_config(compute_kernel)
+        sharded = GroupedPooledEmbeddingsLookup(
+            grouped_configs=[config],
+            grouped_score_configs=[],
+            device=torch.device("cuda:0"),
+        )
+
+        quantized = quantize_sharded_embeddings(sharded, dtype=dtype)
+
+        for _, buffer in quantized.named_buffers():
+            self.assertEqual(buffer.dtype, torch.uint8)
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 0,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-ignore [56]
+    @given(
+        compute_kernel=st.sampled_from(
+            [
+                EmbeddingComputeKernel.BATCHED_DENSE,
+                EmbeddingComputeKernel.BATCHED_FUSED,
+                EmbeddingComputeKernel.DENSE,
+                EmbeddingComputeKernel.SPARSE,
+            ]
+        ),
+        dtype=st.sampled_from(
+            [
+                torch.qint8,
+                torch.quint4x2,
+                torch.quint2x4,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=12, deadline=None)
+    def test_quantize_embedding_kernels(
+        self, compute_kernel: EmbeddingComputeKernel, dtype: torch.dtype
+    ) -> None:
+        config = self._create_config(compute_kernel)
+        sharded = GroupedEmbeddingsLookup(
+            grouped_configs=[config],
+            device=torch.device("cuda:0"),
+        )
+
+        quantized = quantize_sharded_embeddings(sharded, dtype=dtype)
+
+        for _, buffer in quantized.named_buffers():
+            self.assertEqual(buffer.dtype, torch.uint8)
diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py
index 7cc327004..9bbbfa996 100644
--- a/torchrec/modules/embedding_configs.py
+++ b/torchrec/modules/embedding_configs.py
@@ -48,9 +48,9 @@ class DataType(Enum):
 def dtype_to_data_type(dtype: torch.dtype) -> DataType:
     if dtype == torch.quint8 or dtype == torch.qint8:
         return DataType.INT8
-    elif dtype == torch.quint4 or dtype == torch.qint4:
+    elif dtype == torch.quint4x2:
         return DataType.INT4
-    elif dtype == torch.quint2 or dtype == torch.qint2:
+    elif dtype == torch.quint2x4:
         return DataType.INT2
     else:
         raise Exception(f"Invalid data type {dtype}")
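Finally, a quick sanity check of the corrected dtype mapping (illustrative, not part of the patch; it assumes a PyTorch build that defines the sub-byte quantized dtypes). The old branches referenced torch.quint4 / torch.qint4 / torch.quint2 / torch.qint2, which are not attributes of torch and would raise AttributeError as soon as they were evaluated; torch.quint4x2 and torch.quint2x4 are the dtypes that actually exist.

# Sketch only, not part of the patch.
import torch

from torchrec.modules.embedding_configs import DataType, dtype_to_data_type

assert dtype_to_data_type(torch.qint8) == DataType.INT8
assert dtype_to_data_type(torch.quint4x2) == DataType.INT4
assert dtype_to_data_type(torch.quint2x4) == DataType.INT2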