170 changes: 87 additions & 83 deletions torchrec/distributed/embedding.py
@@ -147,88 +147,6 @@ def get_ec_index_dedup() -> bool:
    return EC_INDEX_DEDUP


def create_sharding_infos_by_sharding(
    module: EmbeddingCollectionInterface,
    table_name_to_parameter_sharding: Dict[str, ParameterSharding],
    fused_params: Optional[Dict[str, Any]],
) -> Dict[str, List[EmbeddingShardingInfo]]:

    if fused_params is None:
        fused_params = {}

    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
    # state_dict returns parameter.Tensor, which loses parameter level attributes
    parameter_by_name = dict(module.named_parameters())
    # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there
    state_dict = module.state_dict()

    for (
        config,
        embedding_names,
    ) in zip(module.embedding_configs(), module.embedding_names_by_table()):
        table_name = config.name
        assert table_name in table_name_to_parameter_sharding

        parameter_sharding = table_name_to_parameter_sharding[table_name]
        if parameter_sharding.compute_kernel not in [
            kernel.value for kernel in EmbeddingComputeKernel
        ]:
            raise ValueError(
                f"Compute kernel not supported {parameter_sharding.compute_kernel}"
            )

        param_name = "embeddings." + config.name + ".weight"
        assert param_name in parameter_by_name or param_name in state_dict
        param = parameter_by_name.get(param_name, state_dict[param_name])

        if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
            sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []

        optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
        optimizer_classes = getattr(param, "_optimizer_classes", [None])

        assert (
            len(optimizer_classes) == 1 and len(optimizer_params) == 1
        ), f"Only support 1 optimizer, given {len(optimizer_classes)}"

        optimizer_class = optimizer_classes[0]
        optimizer_params = optimizer_params[0]
        if optimizer_class:
            optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type(
                optimizer_class
            )

        per_table_fused_params = merge_fused_params(fused_params, optimizer_params)
        per_table_fused_params = add_params_from_parameter_sharding(
            per_table_fused_params, parameter_sharding
        )
        per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)

        sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
            (
                EmbeddingShardingInfo(
                    embedding_config=EmbeddingTableConfig(
                        num_embeddings=config.num_embeddings,
                        embedding_dim=config.embedding_dim,
                        name=config.name,
                        data_type=config.data_type,
                        feature_names=copy.deepcopy(config.feature_names),
                        pooling=PoolingType.NONE,
                        is_weighted=False,
                        has_feature_processor=False,
                        embedding_names=embedding_names,
                        weight_init_max=config.weight_init_max,
                        weight_init_min=config.weight_init_min,
                    ),
                    param_sharding=parameter_sharding,
                    param=param,
                    fused_params=per_table_fused_params,
                )
            )
        )
    return sharding_type_to_sharding_infos


def create_sharding_infos_by_sharding_device_group(
module: EmbeddingCollectionInterface,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
@@ -503,7 +421,7 @@ def __init__(
        self._output_dtensor: bool = env.output_dtensor
        # TODO get rid of get_ec_index_dedup global flag
        self._use_index_dedup: bool = use_index_dedup or get_ec_index_dedup()
        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
        sharding_type_to_sharding_infos = self.create_grouped_sharding_infos(
            module,
            table_name_to_parameter_sharding,
            fused_params,
@@ -597,6 +515,92 @@ def __init__(
        if module.device != torch.device("meta"):
            self.load_state_dict(module.state_dict())

    @classmethod
    def create_grouped_sharding_infos(
        cls,
        module: EmbeddingCollectionInterface,
        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
        fused_params: Optional[Dict[str, Any]],
    ) -> Dict[str, List[EmbeddingShardingInfo]]:
"""
convert ParameterSharding (table_name_to_parameter_sharding: Dict[str, ParameterSharding]) to
EmbeddingShardingInfo that are grouped by sharding_type, and propagate the configs/parameters
"""
        if fused_params is None:
            fused_params = {}

        sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
        # state_dict returns parameter.Tensor, which loses parameter level attributes
        parameter_by_name = dict(module.named_parameters())
        # QuantEBC registers weights as buffers (since they are INT8), and so we need to grab it there
        state_dict = module.state_dict()

        for (
            config,
            embedding_names,
        ) in zip(module.embedding_configs(), module.embedding_names_by_table()):
            table_name = config.name
            assert table_name in table_name_to_parameter_sharding

            parameter_sharding = table_name_to_parameter_sharding[table_name]
            if parameter_sharding.compute_kernel not in [
                kernel.value for kernel in EmbeddingComputeKernel
            ]:
                raise ValueError(
                    f"Compute kernel not supported {parameter_sharding.compute_kernel}"
                )

            param_name = "embeddings." + config.name + ".weight"
            assert param_name in parameter_by_name or param_name in state_dict
            param = parameter_by_name.get(param_name, state_dict[param_name])

            if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
                sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []

            optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
            optimizer_classes = getattr(param, "_optimizer_classes", [None])

            assert (
                len(optimizer_classes) == 1 and len(optimizer_params) == 1
            ), f"Only support 1 optimizer, given {len(optimizer_classes)}"

            optimizer_class = optimizer_classes[0]
            optimizer_params = optimizer_params[0]
            if optimizer_class:
                optimizer_params["optimizer"] = optimizer_type_to_emb_opt_type(
                    optimizer_class
                )

            per_table_fused_params = merge_fused_params(fused_params, optimizer_params)
            per_table_fused_params = add_params_from_parameter_sharding(
                per_table_fused_params, parameter_sharding
            )
            per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)

            sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
                (
                    EmbeddingShardingInfo(
                        embedding_config=EmbeddingTableConfig(
                            num_embeddings=config.num_embeddings,
                            embedding_dim=config.embedding_dim,
                            name=config.name,
                            data_type=config.data_type,
                            feature_names=copy.deepcopy(config.feature_names),
                            pooling=PoolingType.NONE,
                            is_weighted=False,
                            has_feature_processor=False,
                            embedding_names=embedding_names,
                            weight_init_max=config.weight_init_max,
                            weight_init_min=config.weight_init_min,
                        ),
                        param_sharding=parameter_sharding,
                        param=param,
                        fused_params=per_table_fused_params,
                    )
                )
            )
        return sharding_type_to_sharding_infos
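
For context, the following is a minimal, self-contained sketch of the grouping pattern that the new classmethod implements. It is an illustration only, not part of this change; ToyShardingInfo, the table names, and the sharding-type strings below are hypothetical stand-ins for EmbeddingShardingInfo and ParameterSharding.sharding_type.

# Illustrative sketch (not part of this PR): group per-table infos into one
# bucket per sharding type, mirroring sharding_type_to_sharding_infos above.
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ToyShardingInfo:
    table_name: str
    sharding_type: str  # e.g. "table_wise", "row_wise"


def group_by_sharding_type(
    infos: List[ToyShardingInfo],
) -> Dict[str, List[ToyShardingInfo]]:
    grouped: Dict[str, List[ToyShardingInfo]] = {}
    for info in infos:
        # Create the bucket for this sharding type lazily, then append
        # the per-table info, as the method above does for each table.
        grouped.setdefault(info.sharding_type, []).append(info)
    return grouped


tables = [
    ToyShardingInfo("t1", "table_wise"),
    ToyShardingInfo("t2", "row_wise"),
    ToyShardingInfo("t3", "row_wise"),
]
print(group_by_sharding_type(tables))
# t1 lands in the "table_wise" bucket; t2 and t3 share the "row_wise" bucket.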

    @classmethod
    def create_embedding_sharding(
        cls,