From ab760d3c8cc6ecd2354152ad16fd8e478c24bc4b Mon Sep 17 00:00:00 2001 From: Xing Liu Date: Wed, 8 Dec 2021 10:18:15 -0800 Subject: [PATCH] flush table batched embedding modules when calling state_dict(...) (#27) Summary: Pull Request resolved: https://github.com/facebookresearch/torchrec/pull/27 As title Reviewed By: bigning Differential Revision: D32860295 fbshipit-source-id: ea97e473b1583b045d334b7cfb3d4b0f84552adb --- torchrec/distributed/embedding_lookup.py | 304 ++++++++++++----------- 1 file changed, 159 insertions(+), 145 deletions(-) diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 390413b2d..a552831ca 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -85,6 +85,152 @@ def _load_state_dict( return missing_keys, unexpected_keys +class EmbeddingFusedOptimizer(FusedOptimizer): + def __init__( + self, + config: GroupedEmbeddingConfig, + emb_module: SplitTableBatchedEmbeddingBagsCodegen, + pg: Optional[dist.ProcessGroup] = None, + ) -> None: + self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = emb_module + self._pg = pg + + def to_rowwise_sharded_metadata( + local_metadata: ShardMetadata, + global_metadata: ShardedTensorMetadata, + sharding_dim: int, + ) -> Tuple[ShardMetadata, ShardedTensorMetadata]: + rw_shards: List[ShardMetadata] = [] + rw_local_shard: ShardMetadata = local_metadata + shards_metadata = global_metadata.shards_metadata + # column-wise sharding + # sort the metadata based on column offset and + # we construct the momentum tensor in row-wise sharded way + if sharding_dim == 1: + shards_metadata = sorted( + shards_metadata, key=lambda shard: shard.shard_offsets[1] + ) + + for idx, shard in enumerate(shards_metadata): + offset = shard.shard_offsets[0] + # for column-wise sharding, we still create row-wise sharded metadata for optimizer + # manually create a row-wise offset + + if sharding_dim == 1: + offset = idx * shard.shard_sizes[0] + rw_shard = ShardMetadata( + shard_sizes=[shard.shard_sizes[0]], + shard_offsets=[offset], + placement=shard.placement, + ) + + if local_metadata == shard: + rw_local_shard = rw_shard + + rw_shards.append(rw_shard) + + tensor_properties = TensorProperties( + dtype=global_metadata.tensor_properties.dtype, + layout=global_metadata.tensor_properties.layout, + requires_grad=global_metadata.tensor_properties.requires_grad, + memory_format=global_metadata.tensor_properties.memory_format, + pin_memory=global_metadata.tensor_properties.pin_memory, + ) + len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1 + rw_metadata = ShardedTensorMetadata( + shards_metadata=rw_shards, + size=torch.Size([global_metadata.size[0] * len_rw_shards]), + tensor_properties=tensor_properties, + ) + return rw_local_shard, rw_metadata + + # pyre-ignore [33] + state: Dict[Any, Any] = {} + param_group: Dict[str, Any] = { + "params": [], + "lr": emb_module.optimizer_args.learning_rate, + } + params: Dict[str, torch.Tensor] = {} + + # Fused optimizers use buffers (they don't use autograd) and we want to make sure + # that state_dict look identical to non-fused version. 
+
+        split_embedding_weights = emb_module.split_embedding_weights()
+        for table_config, weight in zip(
+            config.embedding_tables,
+            split_embedding_weights,
+        ):
+            param_group["params"].append(weight)
+            param_key = table_config.name + ".weight"
+            params[param_key] = weight
+
+        for table_config, optimizer_states, weight in zip(
+            config.embedding_tables,
+            emb_module.split_optimizer_states(),
+            split_embedding_weights,
+        ):
+            state[weight] = {}
+            # momentum1
+            assert table_config.local_rows == optimizer_states[0].size(0)
+            sharding_dim = (
+                1 if table_config.local_cols != table_config.embedding_dim else 0
+            )
+            momentum1_key = f"{table_config.name}.momentum1"
+            if optimizer_states[0].dim() == 1:
+                (local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata(
+                    table_config.local_metadata,
+                    table_config.global_metadata,
+                    sharding_dim,
+                )
+            else:
+                (local_metadata, sharded_tensor_metadata) = (
+                    table_config.local_metadata,
+                    table_config.global_metadata,
+                )
+
+            momentum1 = ShardedTensor._init_from_local_shards_and_global_metadata(
+                local_shards=[Shard(optimizer_states[0], local_metadata)],
+                sharded_tensor_metadata=sharded_tensor_metadata,
+                process_group=self._pg,
+            )
+            state[weight][momentum1_key] = momentum1
+            # momentum2
+            if len(optimizer_states) == 2:
+                assert table_config.local_rows == optimizer_states[1].size(0)
+                momentum2_key = f"{table_config.name}.momentum2"
+
+                if optimizer_states[1].dim() == 1:
+                    (
+                        local_metadata,
+                        sharded_tensor_metadata,
+                    ) = to_rowwise_sharded_metadata(
+                        table_config.local_metadata,
+                        table_config.global_metadata,
+                        sharding_dim,
+                    )
+                else:
+                    (local_metadata, sharded_tensor_metadata) = (
+                        table_config.local_metadata,
+                        table_config.global_metadata,
+                    )
+                momentum2 = ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=[Shard(optimizer_states[1], local_metadata)],
+                    sharded_tensor_metadata=sharded_tensor_metadata,
+                    process_group=self._pg,
+                )
+                state[weight][momentum2_key] = momentum2
+
+        super().__init__(params, state, [param_group])
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        # pyre-ignore [16]
+        self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
+
+    # pyre-ignore [2]
+    def step(self, closure: Any = None) -> None:
+        # pyre-ignore [16]
+        self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
+
+
 class BaseEmbedding(abc.ABC, nn.Module):
     """
     abstract base class for grouped nn.Embedding
@@ -270,6 +416,7 @@ def state_dict(
         prefix: str = "",
         keep_vars: bool = False,
     ) -> Dict[str, Any]:
+        self.flush()
         if destination is None:
             destination = OrderedDict()
             # pyre-ignore [16]
@@ -316,6 +463,9 @@ def emb_module(
     def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
+    def flush(self) -> None:
+        pass
+
 
 class BatchedFusedEmbedding(BaseBatchedEmbedding, FusedOptimizerModule):
     def __init__(
@@ -408,6 +558,9 @@ def named_buffers(
             key = append_prefix(prefix, f"{config.name}.weight")
             yield key, param
 
+    def flush(self) -> None:
+        self._emb_module.flush()
+
 
 class BatchedDenseEmbedding(BaseBatchedEmbedding):
     def __init__(
@@ -832,6 +985,7 @@ def state_dict(
         prefix: str = "",
         keep_vars: bool = False,
     ) -> Dict[str, Any]:
+        self.flush()
         if destination is None:
             destination = OrderedDict()
             # pyre-ignore [16]
@@ -878,151 +1032,8 @@ def emb_module(
     def config(self) -> GroupedEmbeddingConfig:
         return self._config
 
-
-class EmbeddingFusedOptimizer(FusedOptimizer):
-    def __init__(
-        self,
-        config: GroupedEmbeddingConfig,
-        emb_module: SplitTableBatchedEmbeddingBagsCodegen,
-        pg: Optional[dist.ProcessGroup] = None,
-    ) -> None:
-        self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = emb_module
-        self._pg = pg
-
-        def to_rowwise_sharded_metadata(
-            local_metadata: ShardMetadata,
-            global_metadata: ShardedTensorMetadata,
-            sharding_dim: int,
-        ) -> Tuple[ShardMetadata, ShardedTensorMetadata]:
-            rw_shards: List[ShardMetadata] = []
-            rw_local_shard: ShardMetadata = local_metadata
-            shards_metadata = global_metadata.shards_metadata
-            # column-wise sharding
-            # sort the metadata based on column offset and
-            # we construct the momentum tensor in row-wise sharded way
-            if sharding_dim == 1:
-                shards_metadata = sorted(
-                    shards_metadata, key=lambda shard: shard.shard_offsets[1]
-                )
-
-            for idx, shard in enumerate(shards_metadata):
-                offset = shard.shard_offsets[0]
-                # for column-wise sharding, we still create row-wise sharded metadata for optimizer
-                # manually create a row-wise offset
-
-                if sharding_dim == 1:
-                    offset = idx * shard.shard_sizes[0]
-                rw_shard = ShardMetadata(
-                    shard_sizes=[shard.shard_sizes[0]],
-                    shard_offsets=[offset],
-                    placement=shard.placement,
-                )
-
-                if local_metadata == shard:
-                    rw_local_shard = rw_shard
-
-                rw_shards.append(rw_shard)
-
-            tensor_properties = TensorProperties(
-                dtype=global_metadata.tensor_properties.dtype,
-                layout=global_metadata.tensor_properties.layout,
-                requires_grad=global_metadata.tensor_properties.requires_grad,
-                memory_format=global_metadata.tensor_properties.memory_format,
-                pin_memory=global_metadata.tensor_properties.pin_memory,
-            )
-            len_rw_shards = len(shards_metadata) if sharding_dim == 1 else 1
-            rw_metadata = ShardedTensorMetadata(
-                shards_metadata=rw_shards,
-                size=torch.Size([global_metadata.size[0] * len_rw_shards]),
-                tensor_properties=tensor_properties,
-            )
-            return rw_local_shard, rw_metadata
-
-        # pyre-ignore [33]
-        state: Dict[Any, Any] = {}
-        param_group: Dict[str, Any] = {
-            "params": [],
-            "lr": emb_module.optimizer_args.learning_rate,
-        }
-        params: Dict[str, torch.Tensor] = {}
-
-        # Fused optimizers use buffers (they don't use autograd) and we want to make sure
-        # that state_dict look identical to non-fused version.
- split_embedding_weights = emb_module.split_embedding_weights() - for table_config, weight in zip( - config.embedding_tables, - split_embedding_weights, - ): - param_group["params"].append(weight) - param_key = table_config.name + ".weight" - params[param_key] = weight - - for table_config, optimizer_states, weight in zip( - config.embedding_tables, - emb_module.split_optimizer_states(), - split_embedding_weights, - ): - state[weight] = {} - # momentum1 - assert table_config.local_rows == optimizer_states[0].size(0) - sharding_dim = ( - 1 if table_config.local_cols != table_config.embedding_dim else 0 - ) - momentum1_key = f"{table_config.name}.momentum1" - if optimizer_states[0].dim() == 1: - (local_metadata, sharded_tensor_metadata) = to_rowwise_sharded_metadata( - table_config.local_metadata, - table_config.global_metadata, - sharding_dim, - ) - else: - (local_metadata, sharded_tensor_metadata) = ( - table_config.local_metadata, - table_config.global_metadata, - ) - - momentum1 = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=[Shard(optimizer_states[0], local_metadata)], - sharded_tensor_metadata=sharded_tensor_metadata, - process_group=self._pg, - ) - state[weight][momentum1_key] = momentum1 - # momentum2 - if len(optimizer_states) == 2: - assert table_config.local_rows == optimizer_states[1].size(0) - momentum2_key = f"{table_config.name}.momentum2" - - if optimizer_states[1].dim() == 1: - ( - local_metadata, - sharded_tensor_metadata, - ) = to_rowwise_sharded_metadata( - table_config.local_metadata, - table_config.global_metadata, - sharding_dim, - ) - else: - (local_metadata, sharded_tensor_metadata) = ( - table_config.local_metadata, - table_config.global_metadata, - ) - momentum2 = ShardedTensor._init_from_local_shards_and_global_metadata( - local_shards=[Shard(optimizer_states[1], local_metadata)], - sharded_tensor_metadata=sharded_tensor_metadata, - process_group=self._pg, - ) - state[weight][momentum2_key] = momentum2 - - super().__init__(params, state, [param_group]) - - def zero_grad(self, set_to_none: bool = False) -> None: - # pyre-ignore [16] - self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) - - # pyre-ignore [2] - def step(self, closure: Any = None) -> None: - # pyre-ignore [16] - self._emb_module.set_learning_rate(self.param_groups[0]["lr"]) + def flush(self) -> None: + pass class BatchedFusedEmbeddingBag(BaseBatchedEmbeddingBag, FusedOptimizerModule): @@ -1116,6 +1127,9 @@ def named_buffers( key = append_prefix(prefix, f"{config.name}.weight") yield key, param + def flush(self) -> None: + self._emb_module.flush() + class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag): def __init__(
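
Why the flush matters: fbgemm's SplitTableBatchedEmbeddingBagsCodegen applies optimizer updates inside its fused backward pass and can hold recently written rows in a software-managed (e.g. UVM) cache, so the visible weight tensors may be stale until flush() writes the cache back. The patch therefore makes state_dict() call self.flush() first, where flush() is a no-op on the base/dense variants and delegates to the fbgemm module on the fused ones. A minimal, self-contained sketch of the same pattern; ToyCachedEmbedding and its staging dict are hypothetical stand-ins, not torchrec or fbgemm APIs:

    from typing import Dict

    import torch
    import torch.nn as nn


    class ToyCachedEmbedding(nn.Module):
        # Hypothetical stand-in for a table-batched embedding whose visible
        # weights can lag behind a device-side cache.
        def __init__(self, rows: int, dim: int) -> None:
            super().__init__()
            self.register_buffer("weight", torch.zeros(rows, dim))
            self._cache: Dict[int, torch.Tensor] = {}  # rows staged, not yet in `weight`

        def fused_update(self, row: int, value: torch.Tensor) -> None:
            # A fused optimizer updates rows out-of-band (no autograd), so the
            # new value lands in the cache first.
            self._cache[row] = value

        def flush(self) -> None:
            # Write staged rows back, mirroring what the real module's flush() does.
            for row, value in self._cache.items():
                self.weight[row] = value
            self._cache.clear()

        def state_dict(self, destination=None, prefix="", keep_vars=False):
            # The fix in this patch: flush before snapshotting, so checkpoints
            # contain every fused update.
            self.flush()
            return super().state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )


    emb = ToyCachedEmbedding(rows=4, dim=2)
    emb.fused_update(1, torch.ones(2))
    # Without the flush() inside state_dict(), this row would still be zeros.
    assert torch.equal(emb.state_dict()["weight"][1], torch.ones(2))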
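The subtlest part of the relocated EmbeddingFusedOptimizer is to_rowwise_sharded_metadata. Row-wise optimizer state (e.g. Adagrad's momentum1) is 1-D with one entry per embedding row, so for a column-wise sharded table (sharding_dim == 1) every column shard carries a full-length state vector; after sorting the shards by column offset, shard idx is re-addressed row-wise at offset idx * rows, and the global length becomes num_shards * rows. A small worked example under assumed shard shapes, with plain dicts standing in for ShardMetadata:

    # A hypothetical 100 x 16 table, column-wise sharded into two 100 x 8 shards.
    shards = [
        {"shard_sizes": [100, 8], "shard_offsets": [0, 8]},  # second half of the columns
        {"shard_sizes": [100, 8], "shard_offsets": [0, 0]},  # first half of the columns
    ]

    # Mirrors the sharding_dim == 1 branch: sort by column offset, then lay the
    # per-shard 1-D states end to end as one row-wise sharded tensor of length 200.
    rowwise = []
    for idx, shard in enumerate(sorted(shards, key=lambda s: s["shard_offsets"][1])):
        rows = shard["shard_sizes"][0]
        rowwise.append({"shard_sizes": [rows], "shard_offsets": [idx * rows]})

    print(rowwise)
    # [{'shard_sizes': [100], 'shard_offsets': [0]},
    #  {'shard_sizes': [100], 'shard_offsets': [100]}]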