[autoparallel] fixed broken node handler tests #1708

Merged
@@ -22,10 +22,6 @@ class BatchNormStrategyGenerator(StrategyGenerator):
In this generator, both methods will be considered.
"""

@property
def has_bias(self):
return 'bias' in self.op_data

def validate(self) -> bool:
'''
In the sanity check, we need to make sure the input data has the correct dimension size.
@@ -17,10 +17,6 @@ class ConvStrategyGenerator(StrategyGenerator):
The operation data is defined as `output = input x other + bias`.
"""

@property
def has_bias(self):
return 'bias' in self.op_data

def validate(self) -> bool:
'''
In the sanity check, we need to make sure the input data has the correct dimension size.
@@ -19,10 +19,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
"""

@property
def has_bias(self):
return 'bias' in self.op_data

def validate(self) -> bool:
return super().validate()

@@ -18,10 +18,6 @@ class LayerNormGenerator(StrategyGenerator):
The operation data is defined as `output = input x other + bias`.
"""

@property
def has_bias(self):
return 'bias' in self.op_data

def validate(self) -> bool:
return super().validate()

@@ -14,10 +14,6 @@ class MatMulStrategyGenerator(StrategyGenerator):
The operation data is defined as `output = input x other + bias`.
"""

@property
def has_bias(self):
return 'bias' in self.op_data

def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
@@ -512,11 +508,13 @@ def split_one_batch_dim(self, mesh_dim):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

# get communication actions
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping = {"bias": bias_comm_spec}
communication_action_mapping = {}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim)
communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -538,11 +536,14 @@ def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

# get communication actions
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {"bias": bias_comm_spec}
communication_action_mapping = {}
if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -566,15 +567,20 @@ def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

# get communication actions
communication_action_mapping = {}
other_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['other'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
communication_action_mapping['other'] = other_comm_spec

if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mapping['bias'] = bias_comm_spec

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -600,15 +606,20 @@ def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

# get communication actions
communication_action_mapping = {}
input_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['input'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
communication_action_mapping['input'] = input_comm_spec

if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@@ -633,15 +644,20 @@ def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

# get communication actions
communication_action_mapping = {}
output_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['output'],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_1)
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
communication_action_mapping['output'] = output_comm_spec

if self.has_bias:
bias_comm_spec = self.get_communication_spec(
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0)
communication_action_mapping['bias'] = bias_comm_spec

return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
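The split_* hunks above all apply the same change: the 'bias' entry of the communication action mapping is now added only when a bias operand exists, instead of being built unconditionally from a 'bias' spec that may be missing. A minimal, self-contained sketch of that guard pattern is shown below; it uses plain dicts and a stub communication-spec function rather than the real colossalai sharding classes.

# Illustrative sketch only: mirrors the conditional-bias pattern from the diff above.
# The tuple values stand in for the real sharding-spec / communication-spec objects.
def build_communication_mapping(sharding_spec_mapping: dict, has_bias: bool) -> dict:
    def stub_comm_spec(sharding_spec, mesh_dim):
        # Placeholder for StrategyGenerator.get_communication_spec(...)
        return ('identity_fwd_allreduce_bwd', sharding_spec, mesh_dim)

    communication_action_mapping = {
        'input': stub_comm_spec(sharding_spec_mapping['input'], 0),
    }
    # The fix: only look up sharding_spec_mapping['bias'] when a bias operand exists,
    # so bias-free operators no longer reference a spec that was never created.
    if has_bias:
        communication_action_mapping['bias'] = stub_comm_spec(sharding_spec_mapping['bias'], 1)
    return communication_action_mapping

# Bias-free case: the mapping simply omits the 'bias' entry.
print(build_communication_mapping({'input': 'S0R'}, has_bias=False))
# Bias case: both entries are produced.
print(build_communication_mapping({'input': 'S0R', 'bias': 'R'}, has_bias=True))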
@@ -24,6 +24,13 @@ def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh
self.op_data = operation_data_mapping
self.device_mesh = device_mesh

@property
def has_bias(self):
"""
A utility method to check for the existence of bias operand for convenience.
"""
return 'bias' in self.op_data

def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.PARAM
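The hunk above is the other half of the refactor: the has_bias check that each generator previously duplicated is now defined once on the StrategyGenerator base class, so every subclass inherits it. A toy example of the behaviour, using a stub class with a plain op_data dict instead of real OperationData objects, is sketched below.

# Toy stand-in for StrategyGenerator, only to illustrate the inherited property.
class StubGenerator:
    def __init__(self, op_data: dict):
        self.op_data = op_data

    @property
    def has_bias(self):
        # Same check the base class now provides: the bias operand is optional.
        return 'bias' in self.op_data

assert StubGenerator({'input': 1, 'other': 2, 'bias': 3}).has_bias
assert not StubGenerator({'input': 1, 'other': 2}).has_bias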
@@ -22,7 +22,6 @@ def forward(self, x1, x2):
return torch.bmm(x1, x2)


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):

@@ -93,7 +92,6 @@ def test_2d_device_mesh(module):
assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
model = module()
@@ -11,7 +11,6 @@
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
tracer = ColoTracer()
@@ -50,7 +49,7 @@ def test_norm_pool_handler():
assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
assert mapping['output'].type == OperationDataType.OUTPUT

strategies_vector = handler.register_strategy()
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
assert len(strategy_name_list) == 9
