Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions torchrec/distributed/quant_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,17 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
itertools.chain(module.named_buffers(), module.named_parameters())
)
device = next(iter(state_dict.values())).device
use_cpu = not (device is not None and device.type == "cuda")

# Adjust config to quantized version.
# This obviously doesn't work for column-wise sharding.
# pyre-ignore [29]
config = copy.deepcopy(module.config())
config.data_type = data_type
for table in config.embedding_tables:
table.local_cols = rounded_row_size_in_bytes(table.local_cols, sparse_type)
table.local_cols = rounded_row_size_in_bytes(
table.local_cols, sparse_type, use_cpu
)
if table.local_metadata is not None:
table.local_metadata.shard_sizes = [
table.local_rows,
Expand All @@ -163,7 +166,7 @@ def _to_data_type(dtype: torch.dtype) -> DataType:
shard_meta.shard_sizes = [
shard_meta.shard_sizes[0],
rounded_row_size_in_bytes(
shard_meta.shard_sizes[1], sparse_type
shard_meta.shard_sizes[1], sparse_type, use_cpu
),
]
table.global_metadata.size = torch.Size(
Expand Down