[autoparallel] add embedding handler
1 parent 9772581 · commit 77231a5
Showing 3 changed files with 193 additions and 3 deletions.
176 changes: 176 additions & 0 deletions
colossalai/auto_parallel/solver/op_handler/embedding_handler.py
@@ -0,0 +1,176 @@
import operator
from functools import reduce
import warnings
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.solver._utils import exception_handler

__all__ = ['EmbeddingHandler']


class EmbeddingHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of Embedding operators (such as nn.Embedding).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, total_sharding_size):
        input_shape = self.input_data.shape
        weight_shape = self.weight.shape
        input_shape_product = reduce(operator.mul, input_shape, 1)
        weight_shape_product = reduce(operator.mul, weight_shape, 1)
        compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
        return compute_cost

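    # Worked example (illustrative; the shapes and mesh size are hypothetical,
    # not taken from this commit): for an input of shape (4, 8) and a weight of
    # shape (1000, 32) on a 2 x 2 device mesh (total_sharding_size = 4), the
    # estimate above evaluates to
    #   32 * 32000 * 2 / 4 = 512000
    # i.e. the dense per-element multiply-accumulate count divided evenly
    # across all devices in the mesh.
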
    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number of partitions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number of partitions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number of partitions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy; the first element of this tuple is the forward
                memory cost, and the second element is the backward memory cost.
            memory_cost_forward_activation(float): Memory cost of the forward activation
                per device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of the backward activation
                per device with this specific strategy.
            memory_cost_backward_weight(float): Memory cost of the weight gradient
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        numel_input = self.input_data.numel()
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory_cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory_cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

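    # Worked example (illustrative, same hypothetical shapes as above): with
    # 4-byte elements, an input of shape (4, 8), an output of shape (4, 8, 32),
    # a weight of shape (1000, 32), sharding_size_forward = 2,
    # sharding_size_backward_activation = 1 and sharding_size_weight = 4:
    #   forward:  (4*8*32)*4 / 2 + (1000*32)*4 / 4 = 2048 + 32000 = 34048 bytes
    #   backward: (4*8)*4 / 1    + (1000*32)*4 / 4 = 128  + 32000 = 32128 bytes
    # so memory_cost would be (34048.0, 32128.0).
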
    @exception_handler
    def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {2: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = 1
        sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # compute the communication cost of this strategy during forward phase
        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
        # compute the communication cost of this strategy during backward phase
        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

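    # In the RRS strategy above, the weight is sharded along both the vocabulary
    # dimension (mesh_dim_0) and the embedding dimension (mesh_dim_1): the cost
    # model charges a forward all-reduce of the output activation over mesh_dim_0
    # (to combine partial lookup results) and a backward all-reduce of the
    # activation gradient over mesh_dim_1.
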
    @exception_handler
    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

        # compute the computation cost of this strategy
        total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        sharding_size_weight = 1
        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # this strategy does not need an all-reduce during the forward phase
        communication_cost_forward = 0
        # compute the communication cost of this strategy during the backward phase
        communication_cost_backward_activation = 0
        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
            memory_cost_backward_weight, 0)
        communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
        communication_cost = communication_cost_forward + communication_cost_backward
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.append(sharding_strategies)

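    # In the SSR strategy above, the weight is fully replicated while the input
    # (and therefore the output) is sharded across both mesh dimensions: the
    # forward pass needs no communication, and the only backward communication
    # charged is an all-reduce of the weight gradient over the flattened device
    # mesh.
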
    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategy for an Embedding node, and record all strategies in the strategies_vector.
        '''
        # RRS = RR x SS
        self.split_weight_both_dim(0, 1)
        self.split_weight_both_dim(1, 0)

        # SSR = SS x RR
        self.split_input_both_dim(0, 1)
        self.split_input_both_dim(1, 0)

        return self.strategies_vector
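
The cost numbers produced by the two strategy generators can be sanity-checked outside the solver. The sketch below is illustrative only: the 2 x 2 mesh, the tensor shapes, the element size, and the helper names are assumptions rather than part of this commit, and communication cost is omitted because it depends on the device mesh's all_reduce_cost model. It simply re-evaluates the compute- and memory-cost formulas from EmbeddingHandler for a hypothetical nn.Embedding(1000, 32) lookup over a (4, 8) batch of indices.

# Standalone sanity check of the cost formulas used by EmbeddingHandler.
# All values are hypothetical; the real handler pulls them from the traced graph.
from functools import reduce
import operator

MESH_SHAPE = (2, 2)                 # assumed 2 x 2 device mesh
INPUT_SHAPE = (4, 8)                # token indices
WEIGHT_SHAPE = (1000, 32)           # nn.Embedding(1000, 32)
OUTPUT_SHAPE = INPUT_SHAPE + (32,)  # lookup output
ELEM_BYTES = 4                      # assume 4-byte elements


def numel(shape):
    return reduce(operator.mul, shape, 1)


def compute_cost(total_sharding_size):
    # same formula as EmbeddingHandler._generate_compute_cost
    return numel(INPUT_SHAPE) * numel(WEIGHT_SHAPE) * 2 / total_sharding_size


def memory_cost(fwd_parts, bwd_act_parts, weight_parts):
    # same formulas as EmbeddingHandler._generate_memory_cost
    fwd = numel(OUTPUT_SHAPE) * ELEM_BYTES / fwd_parts + numel(WEIGHT_SHAPE) * ELEM_BYTES / weight_parts
    bwd = numel(INPUT_SHAPE) * ELEM_BYTES / bwd_act_parts + numel(WEIGHT_SHAPE) * ELEM_BYTES / weight_parts
    return fwd, bwd


total = MESH_SHAPE[0] * MESH_SHAPE[1]

# RRS1 = RR x S0S1: weight sharded on both mesh dims
print('RRS compute:', compute_cost(total))
print('RRS memory :', memory_cost(MESH_SHAPE[1], 1, total))

# S0S1R = S0S1 x RR: input sharded on both mesh dims, weight replicated
print('SSR compute:', compute_cost(total))
print('SSR memory :', memory_cost(total, total, 1))

Running the sketch prints identical compute costs for both strategies (the formula only depends on the total mesh size) but different per-device memory footprints, which is exactly the trade-off the solver weighs against resharding and communication cost.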