Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[autoparallel] resnet block runtime apply #1709

Merged
merged 5 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,30 @@ def get_operation_data_mapping(self) -> Dict[str, OperationData]:
logical_shape=self.named_parameters['weight'].shape)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
physical_running_mean_operand = OperationData(name="running_mean",
type=OperationDataType.BUFFER,
data=self.named_buffers['running_mean'],
logical_shape=self.named_buffers['running_mean'].shape)

physical_running_var_operand = OperationData(name="running_var",
type=OperationDataType.BUFFER,
data=self.named_buffers['running_var'],
logical_shape=self.named_buffers['running_var'].shape)

physical_num_batches_tracked_operand = OperationData(
name="num_batches_tracked",
type=OperationDataType.BUFFER,
data=self.named_buffers['num_batches_tracked'],
logical_shape=self.named_buffers['num_batches_tracked'].shape)

mapping = {
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
"running_mean": physical_running_mean_operand,
"running_var": physical_running_var_operand,
"num_batches_tracked": physical_num_batches_tracked_operand
}

if self.named_parameters['bias'] is not None:
physical_bias_operand = OperationData(name="bias",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def __init__(self, *args, **kwargs) -> None:
f'The graph is not associated with a module, please make sure it can be used to instantiate a GraphModule object.'
module = self.node.graph.owning_module.get_submodule(self.node.target)
named_parameters = list(module.named_parameters(recurse=False))
named_buffers = list(module.named_buffers(recurse=False))
# convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters}
named_buffers = {k: v for k, v in named_buffers}
self.module = module
self.named_parameters = named_parameters
self.named_buffers = named_buffers
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
@operator_registry.register(torch.reshape)
@operator_registry.register(torch.flatten)
@operator_registry.register(torch.Tensor.permute)
@operator_registry.register(torch.nn.AdaptiveAvgPool2d)
class ReshapeHandler(NodeHandler):
"""
A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
'output': self._compute_size_in_bytes(strategy, "output")
'output': self._compute_size_in_bytes(strategy, "output"),
'running_mean': self._compute_size_in_bytes(strategy, "running_mean"),
'running_var': self._compute_size_in_bytes(strategy, "running_var"),
}

if self.has_bias:
Expand All @@ -75,24 +77,27 @@ def update_memory_cost(self, strategy: ShardingStrategy):
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_activation_cost = sum(
[v for k, v in forward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
fwd_buffer_cost = sum([v for k, v in forward_size_mapping.items() if self.is_buffer(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost, buffer=fwd_buffer_cost)

# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_activation_cost = sum(
[v for k, v in backward_size_mapping.items() if not self.is_param(k) and not self.is_buffer(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
parameter=fwd_parameter_cost + bwd_parameter_cost,
buffer=fwd_buffer_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost

def split_input_channel(self, mesh_dim_0):
strategy_list = []
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_mapping = {
"input": {
Expand All @@ -104,6 +109,13 @@ def split_input_channel(self, mesh_dim_0):
"output": {
1: [mesh_dim_0]
},
"running_mean": {
0: [mesh_dim_0]
},
"running_var": {
0: [mesh_dim_0]
},
"num_batches_tracked": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
Expand All @@ -128,6 +140,13 @@ def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
"output": {
1: [mesh_dim_0, mesh_dim_1]
},
"running_mean": {
0: [mesh_dim_0, mesh_dim_1]
},
"running_var": {
0: [mesh_dim_0, mesh_dim_1]
},
"num_batches_tracked": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
Expand All @@ -146,6 +165,9 @@ def non_split(self):
"input": {},
"other": {},
"output": {},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
Expand All @@ -168,6 +190,9 @@ def split_input_batch(self, mesh_dim_0):
"output": {
0: [mesh_dim_0]
},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
Expand Down Expand Up @@ -199,6 +224,9 @@ def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
"running_mean": {},
"running_var": {},
"num_batches_tracked": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {}
Expand Down Expand Up @@ -234,6 +262,13 @@ def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
0: [mesh_dim_0],
1: [mesh_dim_1],
},
"running_mean": {
0: [mesh_dim_1],
},
"running_var": {
0: [mesh_dim_1],
},
"num_batches_tracked": {},
}
if self.has_bias:
dim_partition_dict_mapping["bias"] = {
Expand Down Expand Up @@ -273,16 +308,22 @@ def generate(self):
# RS01 = RS01 x S01
strategy_list.append(self.split_input_channel_1d(0, 1))

# The strategies with SYNC_BN are temporarily commented,
# because it requires some additional passes to keep runtime
# computation correctness.

# TODO: The strategies below should be uncommented after runtime
# passes ready.
# SR = SR x R WITH SYNC_BN
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
# strategy_list.append(self.split_input_batch(0))
# strategy_list.append(self.split_input_batch(1))

# SS = SS x S WITH SYNC_BN
strategy_list.append(self.split_input_both_dim(0, 1))
strategy_list.append(self.split_input_both_dim(1, 0))
# strategy_list.append(self.split_input_both_dim(0, 1))
# strategy_list.append(self.split_input_both_dim(1, 0))

# S01R = S01R x R WITH SYNC_BN
strategy_list.append(self.split_input_batch_1d(0, 1))
# strategy_list.append(self.split_input_batch_1d(0, 1))
FrankLeeeee marked this conversation as resolved.
Show resolved Hide resolved

for strategy in strategy_list:
self.update_communication_cost(strategy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.PARAM

def is_buffer(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER

def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec]):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class OperationDataType(Enum):
INPUT = 0
ARG = 1
PARAM = 2
OUTPUT = 3
BUFFER = 3
OUTPUT = 4


@dataclass
Expand Down Expand Up @@ -80,6 +81,7 @@ class MemoryCost:
"""
activation: int = 0
parameter: int = 0
buffer: int = 0


@dataclass
Expand Down
12 changes: 7 additions & 5 deletions colossalai/auto_parallel/tensor_shard/solver/cost_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
import torch


class CostGraph:
Expand Down Expand Up @@ -51,7 +52,6 @@ def _build_cost_graph(self):
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
Expand All @@ -62,10 +62,12 @@ def _build_cost_graph(self):
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
children_nodes = [node for node in strategies_vector.successor_nodes]
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
# self._remove_invalid_node(dst_node, 'parents')
# self._remove_invalid_node(dst_node, 'children')

if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
Expand Down
5 changes: 1 addition & 4 deletions colossalai/auto_parallel/tensor_shard/solver/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,7 @@ def _prepare_data_for_solver(self):
else:
communication_costs.append(origin_communication_cost)
memory_costs.append(memory_cost)
# if isinstance(memory_cost, tuple):
# memory_costs.append(memory_cost[0])
# else:
# memory_costs.append(memory_cost)

compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,19 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
for name, param in target_module.named_parameters():
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
setattr(param, 'sharding_spec', origin_sharding_spec)
target_weight_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
apply(param, target_weight_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
apply(param, target_sharding_spec)

for name, buffer in target_module.named_buffers():
origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})
setattr(buffer, 'sharding_spec', origin_sharding_spec)
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
apply(buffer, target_sharding_spec)

# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
for index, node in enumerate(nodes):
target_sharding_specs = []
if node.name == 'bn1':
print(node.strategies_vector.successor_nodes)
assert False
for user_node in node.strategies_vector.successor_nodes:
# node_index = user_node.strategies_vector.predecessor_nodes.index(node)
# target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
Expand Down