[autoparallel] added generate_sharding_spec to utils (#1590)
FrankLeeeee committed Sep 13, 2022
1 parent 49ccf8b commit 7c18a58
Showing 6 changed files with 161 additions and 103 deletions.
33 changes: 33 additions & 0 deletions colossalai/auto_parallel/solver/_utils.py
@@ -0,0 +1,33 @@
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List


def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
                           dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
    """
    Generate the sharding spec of the tensor based on the given dim_partition_dict.

    Args:
        input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a Node is given, the sharding spec is generated from the meta data associated with that node.
        device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
        dim_partition_dict (Dict[int, List[int]]): a dictionary specifying the sharding spec, where the key is the tensor dimension and the value is the list of mesh dimensions to shard over.
    """

    if isinstance(input_, Node):
        assert hasattr(input_, '_meta_data'), 'The given node does not have the attribute _meta_data'
        meta_tensor = input_._meta_data
        assert meta_tensor is not None, "The given node's _meta_data attribute is None"
        shape = meta_tensor.shape
    elif isinstance(input_, torch.Tensor):
        shape = input_.shape
    else:
        raise TypeError(
            f'Cannot generate a sharding spec for type {type(input_)}; only torch.fx.Node or torch.Tensor is expected.'
        )

    sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
    return sharding_spec
55 changes: 35 additions & 20 deletions colossalai/auto_parallel/solver/op_handler/batch_norm_handler.py
@@ -4,6 +4,7 @@
 import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
+from .._utils import generate_sharding_spec
 
 __all__ = ['BatchNormHandler']
 
@@ -114,13 +115,15 @@ def split_input_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
 
         dim_partition_dict_for_input = {1: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {0: [mesh_dim_0]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
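As a reading aid for the strategy names above (an editorial summary, not part of the commit): in names such as `RS0 = RS0 x S0`, each position describes one tensor dimension of output = input x weight, with `R` meaning replicated and `S{i}` meaning sharded along logical mesh dimension `i`; the corresponding `dim_partition_dict` maps each sharded tensor dimension to the list of mesh dimensions it is split over. A small, self-contained helper illustrating that mapping, under the stated assumption about the notation:

```python
from typing import Dict, List


def partition_to_notation(num_dims: int, dim_partition_dict: Dict[int, List[int]]) -> str:
    """Render a dim_partition_dict in the R/S notation used by the strategy names.

    Illustrative helper only (not part of this commit): a tensor dimension is 'R'
    (replicated) unless it appears in the dict, in which case it becomes 'S'
    followed by the mesh dimensions it is sharded over.
    """
    parts = []
    for dim in range(num_dims):
        mesh_dims = dim_partition_dict.get(dim, [])
        parts.append('R' if not mesh_dims else 'S' + ''.join(str(m) for m in mesh_dims))
    return ''.join(parts)


# {1: [0]} on a 2-D tensor -> 'RS0' (as in split_input_channel),
# {1: [0, 1]} -> 'RS01' (as in split_input_channel_1d).
print(partition_to_notation(2, {1: [0]}))     # RS0
print(partition_to_notation(2, {1: [0, 1]}))  # RS01
```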
@@ -153,7 +156,8 @@ def split_input_channel(self, mesh_dim_0, mesh_dim_1):
         new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
 
         dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
-        new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                              dim_partition_dict_for_output)
         # the computation cost is all the same
         new_compute_cost = compute_cost
 
@@ -188,13 +192,15 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
 
         dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -228,13 +234,15 @@ def non_split(self, mesh_dim_0, mesh_dim_1):
         name = f'RR = RR x R'
 
         dim_partition_dict_for_input = {}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -265,7 +273,8 @@ def non_split(self, mesh_dim_0, mesh_dim_1):
 
         def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
             dim_partition_dict_for_output = {0: mesh_dim_list}
-            new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+            new_sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                                  dim_partition_dict_for_output)
 
             # the computation cost is all the same
             new_compute_cost = compute_cost
@@ -323,13 +332,15 @@ def split_input_batch(self, mesh_dim_0):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
 
         dim_partition_dict_for_input = {0: [mesh_dim_0]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -363,13 +374,15 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
 
         dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
@@ -403,13 +416,15 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
 
         dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+        sharding_spec_for_input = generate_sharding_spec(self.input_data, self.device_mesh,
+                                                         dim_partition_dict_for_input)
 
         dim_partition_dict_for_weight = {0: [mesh_dim_1]}
-        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+        sharding_spec_for_weight = generate_sharding_spec(self.weight, self.device_mesh, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
+        sharding_spec_for_output = generate_sharding_spec(self.output_data, self.device_mesh,
+                                                          dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
         resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
