diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 808950f4e..4b215ce53 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -143,6 +143,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
             itertools.chain(module.named_buffers(), module.named_parameters())
         )
         device = next(iter(state_dict.values())).device
+        use_cpu = not (device is not None and device.type == "cuda")
 
         # Adjust config to quantized version.
         # This obviously doesn't work for column-wise sharding.
@@ -150,7 +151,9 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
         config = copy.deepcopy(module.config())
         config.data_type = data_type
         for table in config.embedding_tables:
-            table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
+            table.local_cols = rounded_row_size_in_bytes(
+                table.local_cols, sparse_type, use_cpu
+            )
             if table.local_metadata is not None:
                 table.local_metadata.shard_sizes = [
                     table.local_rows,
@@ -163,7 +166,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
                 shard_meta.shard_sizes = [
                     shard_meta.shard_sizes[0],
                     rounded_row_size_in_bytes(
-                        shard_meta.shard_sizes[1], sparse_type
+                        shard_meta.shard_sizes[1], sparse_type, use_cpu
                     ),
                 ]
             table.global_metadata.size = torch.Size(
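For context: the patch derives a `use_cpu` flag from the module's parameter device and threads it through both `rounded_row_size_in_bytes` call sites, so the quantized shard sizes reflect the byte layout of the target device. The sketch below is illustrative only, not the fbgemm_gpu implementation; it assumes GPU rows are padded up to a 16-byte alignment boundary while CPU rows are left unpadded, and it takes a plain `bit_rate` integer instead of the real `SparseType` argument.

# Hypothetical sketch (not the real helper): shows the effect the new
# use_cpu argument is meant to have on per-row byte rounding.
def rounded_row_size_in_bytes_sketch(dim: int, bit_rate: int, use_cpu: bool) -> int:
    unpadded = (dim * bit_rate + 7) // 8  # raw row size in bytes
    if use_cpu:
        # assumption: CPU kernels read unaligned rows, so no padding is added
        return unpadded
    # assumption: GPU rows are padded to a 16-byte boundary
    return -(-unpadded // 16) * 16

# Example: a 100-dim INT8 row is 100 bytes unpadded on CPU, 112 bytes on GPU.
assert rounded_row_size_in_bytes_sketch(100, 8, use_cpu=True) == 100
assert rounded_row_size_in_bytes_sketch(100, 8, use_cpu=False) == 112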