304 changes: 159 additions & 145 deletions torchrec/distributed/embedding_lookup.py
@@ -85,6 +85,152 @@ def _load_state_dict(
return missing_keys, unexpected_keys


class EmbeddingFusedOptimizer(FusedOptimizer):
def __init__(
self,
config: GroupedEmbeddingConfig,
emb_module: SplitTableBatchedEmbeddingBagsCodegen,
pg: Optional[dist.ProcessGroup] = None,
) -> None:
self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = emb_module
self._pg = pg

def to_rowwise_sharded_metadata(
local_metadata: ShardMetadata,
global_metadata: ShardedTensorMetadata,
sharding_dim: int,
) -> Tuple[ShardMetadata, ShardedTensorMetadata]:
rw_shards: List[ShardMetadata] = []
rw_local_shard: ShardMetadata = local_metadata
shards_metadata = global_metadata.shards_metadata
# Column-wise sharding: sort the metadata by column offset so that
# the momentum tensor can be constructed in a row-wise sharded way.
if sharding_dim == 1:
shards_metadata = sorted(
shards_metadata, key=lambda shard: shard.shard_offsets[1]
)

for idx, shard in enumerate(shards_metadata):
offset = shard.shard_offsets[0]
# For column-wise sharding we still create row-wise sharded metadata for the
# optimizer state, so derive the row-wise offset manually.

if sharding_dim == 1:
offset = idx * shard.shard_sizes[0]
rw_shard = ShardMetadata(
shard_sizes=[shard.shard_sizes[0]],
shard_offsets=[offset],
placement=shard.placement,
)

if local_metadata == shard:
rw_local_shard = rw_shard

rw_shards.append(rw_shard)

tensor_properties = TensorProperties(
dtype=global_metadata.tensor_properties.dtype,
layout=global_metadata.tensor_properties.layout,
requires_grad=global_metadata.tensor_properties.requires_grad,
memory_format=global_metadata.tensor_properties.memory_format,
pin_memory=global_metadata.tensor_properties.pin_memory,
)
len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1
rw_metadata = ShardedTensorMetadata(
shards_metadata=rw_shards,
size=torch.Size([global_metadata.size[0] * len_rw_shards]),
tensor_properties=tensor_properties,
)
return rw_local_shard, rw_metadata

# pyre-ignore [33]
state: Dict[Any, Any] = {}
param_group: Dict[str, Any] = {
"params": [],
"lr": emb_module.optimizer_args.learning_rate,
}
params: Dict[str, torch.Tensor] = {}

# Fused optimizers use buffers (they don't use autograd), and we want the
# state_dict to look identical to the non-fused version.
split_embedding_weights = emb_module.split_embedding_weights()
for table_config, weight in zip(
config.embedding_tables,
split_embedding_weights,
):
param_group["params"].append(weight)
param_key = table_config.name + ".weight"
params[param_key] = weight

for table_config, optimizer_states, weight in zip(
config.embedding_tables,
emb_module.split_optimizer_states(),
split_embedding_weights,
):
state[weight] = {}
# momentum1
assert table_config.local_rows == optimizer_states[0].size(0)
sharding_dim = (
1 if table_config.local_cols != table_config.embedding_dim else 0
)
momentum1_key = f"{table_config.name}.momentum1"
if optimizer_states[0].dim() == 1:
(local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata(
table_config.local_metadata,
table_config.global_metadata,
sharding_dim,
)
else:
(local_metadata, sharded_tensor_metadata) = (
table_config.local_metadata,
table_config.global_metadata,
)

momentum1 = ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=[Shard(optimizer_states[0], local_metadata)],
sharded_tensor_metadata=sharded_tensor_metadata,
process_group=self._pg,
)
state[weight][momentum1_key] = momentum1
# momentum2
if len(optimizer_states) == 2:
assert table_config.local_rows == optimizer_states[1].size(0)
momentum2_key = f"{table_config.name}.momentum2"

if optimizer_states[1].dim() == 1:
(
local_metadata,
sharded_tensor_metadata,
) = to_rowwise_sharded_metadata(
table_config.local_metadata,
table_config.global_metadata,
sharding_dim,
)
else:
(local_metadata, sharded_tensor_metadata) = (
table_config.local_metadata,
table_config.global_metadata,
)
momentum2 = ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=[Shard(optimizer_states[1], local_metadata)],
sharded_tensor_metadata=sharded_tensor_metadata,
process_group=self._pg,
)
state[weight][momentum2_key] = momentum2

super().__init__(params, state, [param_group])

# The parameter update itself is fused into the TBE backward pass; zero_grad()
# and step() only propagate the current learning rate to the kernel.
def zero_grad(self, set_to_none: bool = False) -> None:
# pyre-ignore [16]
self._emb_module.set_learning_rate(self.param_groups[0]["lr"])

# pyre-ignore [2]
def step(self, closure: Any = None) -> None:
# pyre-ignore [16]
self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
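For readers skimming the diff: the nested to_rowwise_sharded_metadata helper above converts column-wise shard metadata into row-wise metadata so that each table's 1-D momentum state can be stitched into a single ShardedTensor. Below is a minimal, self-contained sketch of the offset arithmetic it performs; the function name and the example shard shapes are illustrative only, not torchrec API.

from typing import List, Tuple

def rowwise_offsets_for_cw_shards(
    shard_sizes: List[Tuple[int, int]],    # (rows, cols) of each column-wise shard
    shard_offsets: List[Tuple[int, int]],  # (row_offset, col_offset) of each shard
) -> List[Tuple[int, int]]:
    """Return (rows, row_offset) describing each shard's slice of the 1-D momentum."""
    # Sort shards by column offset, mirroring the sorted(...) call in the diff.
    order = sorted(range(len(shard_sizes)), key=lambda i: shard_offsets[i][1])
    out: List[Tuple[int, int]] = []
    for idx, i in enumerate(order):
        rows = shard_sizes[i][0]
        # Column-wise shards cover the same rows, so their momentum slices are
        # simply stacked: row offset = shard index * rows, as in the diff.
        out.append((rows, idx * rows))
    return out

# Two column-wise shards of a [100, 128] table split at column 64:
print(rowwise_offsets_for_cw_shards([(100, 64), (100, 64)], [(0, 0), (0, 64)]))
# -> [(100, 0), (100, 100)]; the stitched momentum tensor has global size 100 * 2 = 200.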


class BaseEmbedding(abc.ABC, nn.Module):
"""
abstract base class for grouped nn.Embedding
@@ -270,6 +416,7 @@ def state_dict(
prefix: str = "",
keep_vars: bool = False,
) -> Dict[str, Any]:
self.flush()
if destination is None:
destination = OrderedDict()
# pyre-ignore [16]
Expand Down Expand Up @@ -316,6 +463,9 @@ def emb_module(
def config(self) -> GroupedEmbeddingConfig:
return self._config

def flush(self) -> None:
pass


class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule):
def __init__(
Expand Down Expand Up @@ -408,6 +558,9 @@ def named_buffers(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, param

def flush(self) -> None:
self._emb_module.flush()
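The flush() hook added above exists so that state_dict() never reads stale weights: the fused TBE kernel may keep pending updates in device-side buffers until flushed. Below is a minimal sketch of the flush-before-state_dict pattern; FakeFusedKernel and the table name are made up for illustration and merely stand in for SplitTableBatchedEmbeddingBagsCodegen.

from collections import OrderedDict
from typing import Any, Dict, Optional

import torch
import torch.nn as nn

class FakeFusedKernel:
    """Pretends to cache weight updates until flush() writes them back."""
    def __init__(self) -> None:
        self.weight = torch.zeros(4, 8)
        self._pending = torch.randn(4, 8)  # updates not yet materialized

    def flush(self) -> None:
        self.weight.copy_(self._pending)

class BatchedEmbeddingSketch(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._emb_module = FakeFusedKernel()

    def flush(self) -> None:
        # Only fused variants override this; the dense base class keeps a no-op.
        self._emb_module.flush()

    def state_dict(
        self,
        destination: Optional[Dict[str, Any]] = None,
        prefix: str = "",
        keep_vars: bool = False,
    ) -> Dict[str, Any]:
        self.flush()  # materialize cached weights before reading them
        if destination is None:
            destination = OrderedDict()
        destination[prefix + "table_0.weight"] = self._emb_module.weight
        return destination

print(BatchedEmbeddingSketch().state_dict()["table_0.weight"].shape)  # torch.Size([4, 8])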


class BatchedDenseEmbedding(BaseBatchedEmbedding):
def __init__(
Expand Down Expand Up @@ -832,6 +985,7 @@ def state_dict(
prefix: str = "",
keep_vars: bool = False,
) -> Dict[str, Any]:
self.flush()
if destination is None:
destination = OrderedDict()
# pyre-ignore [16]
Expand Down Expand Up @@ -878,151 +1032,8 @@ def emb_module(
def config(self) -> GroupedEmbeddingConfig:
return self._config


class EmbeddingFusedOptimizer(FusedOptimizer):
def __init__(
self,
config: GroupedEmbeddingConfig,
emb_module: SplitTableBatchedEmbeddingBagsCodegen,
pg: Optional[dist.ProcessGroup] = None,
) -> None:
self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = emb_module
self._pg = pg

def to_rowwise_sharded_metadata(
local_metadata: ShardMetadata,
global_metadata: ShardedTensorMetadata,
sharding_dim: int,
) -> Tuple[ShardMetadata, ShardedTensorMetadata]:
rw_shards: List[ShardMetadata] = []
rw_local_shard: ShardMetadata = local_metadata
shards_metadata = global_metadata.shards_metadata
# Column-wise sharding: sort the metadata by column offset so that
# the momentum tensor can be constructed in a row-wise sharded way.
if sharding_dim == 1:
shards_metadata = sorted(
shards_metadata, key=lambda shard: shard.shard_offsets[1]
)

for idx, shard in enumerate(shards_metadata):
offset = shard.shard_offsets[0]
# For column-wise sharding we still create row-wise sharded metadata for the
# optimizer state, so derive the row-wise offset manually.

if sharding_dim == 1:
offset = idx * shard.shard_sizes[0]
rw_shard = ShardMetadata(
shard_sizes=[shard.shard_sizes[0]],
shard_offsets=[offset],
placement=shard.placement,
)

if local_metadata == shard:
rw_local_shard = rw_shard

rw_shards.append(rw_shard)

tensor_properties = TensorProperties(
dtype=global_metadata.tensor_properties.dtype,
layout=global_metadata.tensor_properties.layout,
requires_grad=global_metadata.tensor_properties.requires_grad,
memory_format=global_metadata.tensor_properties.memory_format,
pin_memory=global_metadata.tensor_properties.pin_memory,
)
len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1
rw_metadata = ShardedTensorMetadata(
shards_metadata=rw_shards,
size=torch.Size([global_metadata.size[0] * len_rw_shards]),
tensor_properties=tensor_properties,
)
return rw_local_shard, rw_metadata

# pyre-ignore [33]
state: Dict[Any, Any] = {}
param_group: Dict[str, Any] = {
"params": [],
"lr": emb_module.optimizer_args.learning_rate,
}
params: Dict[str, torch.Tensor] = {}

# Fused optimizers use buffers (they don't use autograd), and we want the
# state_dict to look identical to the non-fused version.
split_embedding_weights = emb_module.split_embedding_weights()
for table_config, weight in zip(
config.embedding_tables,
split_embedding_weights,
):
param_group["params"].append(weight)
param_key = table_config.name + ".weight"
params[param_key] = weight

for table_config, optimizer_states, weight in zip(
config.embedding_tables,
emb_module.split_optimizer_states(),
split_embedding_weights,
):
state[weight] = {}
# momentum1
assert table_config.local_rows == optimizer_states[0].size(0)
sharding_dim = (
1 if table_config.local_cols != table_config.embedding_dim else 0
)
momentum1_key = f"{table_config.name}.momentum1"
if optimizer_states[0].dim() == 1:
(local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata(
table_config.local_metadata,
table_config.global_metadata,
sharding_dim,
)
else:
(local_metadata, sharded_tensor_metadata) = (
table_config.local_metadata,
table_config.global_metadata,
)

momentum1 = ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=[Shard(optimizer_states[0], local_metadata)],
sharded_tensor_metadata=sharded_tensor_metadata,
process_group=self._pg,
)
state[weight][momentum1_key] = momentum1
# momentum2
if len(optimizer_states) == 2:
assert table_config.local_rows == optimizer_states[1].size(0)
momentum2_key = f"{table_config.name}.momentum2"

if optimizer_states[1].dim() == 1:
(
local_metadata,
sharded_tensor_metadata,
) = to_rowwise_sharded_metadata(
table_config.local_metadata,
table_config.global_metadata,
sharding_dim,
)
else:
(local_metadata, sharded_tensor_metadata) = (
table_config.local_metadata,
table_config.global_metadata,
)
momentum2 = ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=[Shard(optimizer_states[1], local_metadata)],
sharded_tensor_metadata=sharded_tensor_metadata,
process_group=self._pg,
)
state[weight][momentum2_key] = momentum2

super().__init__(params, state, [param_group])

# The parameter update itself is fused into the TBE backward pass; zero_grad()
# and step() only propagate the current learning rate to the kernel.
def zero_grad(self, set_to_none: bool = False) -> None:
# pyre-ignore [16]
self._emb_module.set_learning_rate(self.param_groups[0]["lr"])

# pyre-ignore [2]
def step(self, closure: Any = None) -> None:
# pyre-ignore [16]
self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
def flush(self) -> None:
pass


class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule):
@@ -1116,6 +1127,9 @@ def named_buffers(
key = append_prefix(prefix, f"{config.name}.weight")
yield key, param

def flush(self) -> None:
self._emb_module.flush()


class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag):
def __init__(