[autoparallel] refactored shape consistency to remove redundancy (#1591)
* [autoparallel] refactored shape consistency to remove redundancy

* polish code

* polish code

* polish code
FrankLeeeee committed Sep 13, 2022
1 parent d164449 commit 27fe8af
Showing 13 changed files with 220 additions and 234 deletions.
45 changes: 44 additions & 1 deletion colossalai/auto_parallel/solver/_utils.py
@@ -1,8 +1,9 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List
from typing import Union, Dict, List, Optional


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@@ -31,3 +32,45 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec


def generate_resharding_costs(nodes: List[Node],
                              sharding_specs: List[ShardingSpec],
                              count_backward: Optional[bool] = True,
                              dtype: Optional[torch.dtype] = None):
    '''
    Compute, for each input node, the cost of resharding from every candidate output sharding spec of that node to the sharding spec required by this strategy.

    Argument:
        nodes (List[Node]): a list of nodes
        sharding_specs (List[ShardingSpec]): a list of ShardingSpec, one for each node
        count_backward (Optional[bool]): whether to include the resharding cost of the backward pass, default is True. False can be used for inference.
        dtype (Optional[torch.dtype]): the data type used for cost calculation, default is None.
    '''
    # the resharding cost of the weight is counted as well to handle weight-sharing cases
    resharding_costs = {}
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

    # the shape consistency manager is a singleton class
    shape_consistency_manager = ShapeConsistencyManager()

    for input_node, input_spec in zip(nodes, sharding_specs):
        resharding_costs[input_node] = []
        for strategy in input_node.strategies_vector:
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), 'The input node should not output a tuple of tensors.'
            # compute the resharding cost of the forward phase
            _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)

            if count_backward:
                # in the backward phase, the gradient carrying the target spec has to be converted back into input_sharding_spec
                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
                    input_spec, input_sharding_spec)
                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
            else:
                total_resharding_cost = resharding_cost_forward

            # multiply by the element size in bytes to obtain the communication cost
            resharding_cost = total_resharding_cost * size_per_elem_bytes
            resharding_costs[input_node].append(resharding_cost)
    return resharding_costs
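
For orientation, a minimal usage sketch of this helper (not part of the commit). Here node and device_mesh are assumed to come from an already-traced torch.fx graph and an existing DeviceMesh, and the dim-partition choice is purely illustrative.

# Hypothetical usage sketch, not taken from the repository. Assumes 'node' is a traced
# torch.fx Node whose predecessors already carry populated strategies_vector attributes,
# and 'device_mesh' is an existing DeviceMesh.
predecessor_nodes = [arg for arg in node.args if isinstance(arg, Node)]
input_specs = [
    # shard tensor dim 0 across mesh dim 0; this partition choice is illustrative only
    generate_sharding_spec(pred, device_mesh, {0: [0]})
    for pred in predecessor_nodes
]
costs = generate_resharding_costs(predecessor_nodes, input_specs, count_backward=True, dtype=torch.float32)
# 'costs' maps each predecessor node to one resharding cost per strategy
# in that node's strategies_vector.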
57 changes: 21 additions & 36 deletions colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py
@@ -4,7 +4,6 @@
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from .._utils import generate_sharding_spec

__all__ = ['BatchNormHandler']

@@ -115,15 +114,13 @@ def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -156,8 +153,7 @@ def split_input_channel(self, mesh_dim_0, mesh_dim_1):
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'

dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost

@@ -192,15 +188,13 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'

dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -234,15 +228,13 @@ def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'

dim_partition_dict_for_input = {}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -273,8 +265,7 @@ def non_split(self, mesh_dim_0, mesh_dim_1):

def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
dim_partition_dict_for_output = {0: mesh_dim_list}
new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# the computation cost is all the same
new_compute_cost = compute_cost
@@ -332,15 +323,13 @@ def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'

dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -374,15 +363,13 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'

dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -416,15 +403,13 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'

dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
dim_partition_dict_for_input)
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
dim_partition_dict_for_output)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -459,7 +444,7 @@ def register_strategy(self) -> StrategiesVector:
Generate every possible strategy for a BatchNorm node, and record all strategies in the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
norm_handler = BatchNormHandler(node, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
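Note on the refactor above: every call to the module-level generate_sharding_spec(..., self.device_mesh, ...) is replaced by self._generate_sharding_spec(...), so the device mesh no longer needs to be passed at every call site. The wrapper itself is not part of this diff; a plausible sketch, assuming it lives on the shared OperatorHandler base class and simply forwards the handler's own device mesh:

# Assumed shape of the helper the handlers now call; it is not shown in this diff.
from .._utils import generate_sharding_spec

class OperatorHandler:
    # ... the existing constructor is assumed to store self.device_mesh, self.input_data, self.weight, etc. ...

    def _generate_sharding_spec(self, input_, dim_partition_dict):
        # forward the handler's own device mesh to the shared utility
        return generate_sharding_spec(input_, self.device_mesh, dim_partition_dict)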
