From 7947615a53466b2034402a8f834074fe21599b15 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Wed, 9 Feb 2022 21:19:37 -0800
Subject: [PATCH] Remove padding for CPU TBE op to reduce the memory waste (#14)

Summary:
Pull Request resolved: https://github.com/facebookresearch/torchrec/pull/14

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/922

CPU TBE doesn't need padding. We want to reduce the potential model loading OOM issue.

Differential Revision: D34099684

fbshipit-source-id: 8d7def68101add9f62c48df350e6a3f3451d792f
---
 torchrec/distributed/quant_embedding_kernel.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 808950f4e..4b215ce53 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -143,6 +143,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
             itertools.chain(module.named_buffers(), module.named_parameters())
         )
         device = next(iter(state_dict.values())).device
+        use_cpu = not (device is not None and device.type == "cuda")

         # Adjust config to quantized version.
         # This obviously doesn't work for column-wise sharding.
@@ -150,7 +151,9 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
         config = copy.deepcopy(module.config())
         config.data_type = data_type
         for table in config.embedding_tables:
-            table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
+            table.local_cols = rounded_row_size_in_bytes(
+                table.local_cols, sparse_type, use_cpu
+            )
             if table.local_metadata is not None:
                 table.local_metadata.shard_sizes = [
                     table.local_rows,
@@ -163,7 +166,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
                     shard_meta.shard_sizes = [
                         shard_meta.shard_sizes[0],
                         rounded_row_size_in_bytes(
-                            shard_meta.shard_sizes[1], sparse_type
+                            shard_meta.shard_sizes[1], sparse_type, use_cpu
                         ),
                     ]
                 table.global_metadata.size = torch.Size(
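
Reviewer note (illustrative, not part of the patch): a minimal sketch of what the new `use_cpu` argument is expected to do, assuming `rounded_row_size_in_bytes` pads GPU rows up to a 16-byte boundary and leaves CPU rows unpadded. The `SparseType` enum, the bit-rate mapping, and the 16-byte alignment constant below are stand-ins for illustration; the real helper lives in FBGEMM and may also account for per-row scale/bias bytes.

from enum import Enum


class SparseType(Enum):
    FP16 = "fp16"
    INT8 = "int8"
    INT4 = "int4"


def _bit_rate(sparse_type: SparseType) -> int:
    # Bits per element for each quantized type (assumed mapping for this sketch).
    return {SparseType.FP16: 16, SparseType.INT8: 8, SparseType.INT4: 4}[sparse_type]


def rounded_row_size_in_bytes(
    cols: int, sparse_type: SparseType, use_cpu: bool = False
) -> int:
    # Raw row size in bytes (scale/bias bytes omitted for simplicity).
    unpadded = (cols * _bit_rate(sparse_type) + 7) // 8
    if use_cpu:
        # CPU TBE reads rows without an alignment requirement, so no padding
        # is added -- this is the memory saving the patch is after.
        return unpadded
    # GPU kernels assume aligned rows, so round up to a 16-byte boundary.
    return (unpadded + 15) // 16 * 16


# Example: a 100-column INT8 table row.
assert rounded_row_size_in_bytes(100, SparseType.INT8, use_cpu=True) == 100
assert rounded_row_size_in_bytes(100, SparseType.INT8, use_cpu=False) == 112

Under these assumptions, a 100-column INT8 row stays at 100 bytes on CPU instead of being padded to 112, which is the per-row saving that reduces the model-loading OOM risk mentioned in the summary.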