Skip to content

Commit

Permalink
Add flexibility for the mp axis name in partitioning.py
Browse files — browse the repository at this point in the history
  • Loading branch information
fattorib committed May 27, 2023
1 parent d482e38 commit f89c2ad
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
34 changes: 17 additions & 17 deletions src/partitioning/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _get_partition_rules_dp(mesh: Mesh):
),
]

def _get_partition_rules_tp(mesh: Mesh):
def _get_partition_rules_tp(mesh: Mesh, axis_name:str):
"""
Follows Megatron-LM partition rules from
Expand All @@ -87,16 +87,16 @@ def _get_partition_rules_tp(mesh: Mesh):
"""
return [
(("wte", "embedding"), NamedSharding(mesh, PartitionSpec("mp", None))),
(("wpe", "embedding"), NamedSharding(mesh,PartitionSpec("mp", None))),
(("wte", "embedding"), NamedSharding(mesh, PartitionSpec(axis_name, None))),
(("wpe", "embedding"), NamedSharding(mesh,PartitionSpec(axis_name, None))),
# attention
(("(query_proj|key_proj|value_proj)", "kernel"), NamedSharding(mesh,PartitionSpec(None, "mp"))),
(("residual_out", "kernel"), NamedSharding(mesh,PartitionSpec("mp", None))),
(("(query_proj|key_proj|value_proj)", "kernel"), NamedSharding(mesh,PartitionSpec(None, axis_name))),
(("residual_out", "kernel"), NamedSharding(mesh,PartitionSpec(axis_name, None))),
(("(query_proj|key_proj|value_proj)", "bias"), NamedSharding(mesh,PartitionSpec(None))),
(("residual_out", "bias"), NamedSharding(mesh,PartitionSpec(None))),
# MLP
(("fc_in", "kernel"), NamedSharding(mesh,PartitionSpec(None, "mp"))),
(("fc_residual", "kernel"), NamedSharding(mesh,PartitionSpec("mp", None))),
(("fc_in", "kernel"), NamedSharding(mesh,PartitionSpec(None, axis_name))),
(("fc_residual", "kernel"), NamedSharding(mesh,PartitionSpec(axis_name, None))),
(("fc_in", "bias"), NamedSharding(mesh,PartitionSpec(None))),
(("fc_residual", "bias"), NamedSharding(mesh,PartitionSpec(None))),
# layer norms
Expand All @@ -117,7 +117,7 @@ def _get_partition_rules_tp(mesh: Mesh):
),
]

def _get_partition_rules_tp_dp(mesh: Mesh):
def _get_partition_rules_tp_dp(mesh: Mesh, axis_name:str):
"""
Follows Megatron-LM partition rules from
Expand All @@ -127,16 +127,16 @@ def _get_partition_rules_tp_dp(mesh: Mesh):
"""
return [
(("wte", "embedding"), NamedSharding(mesh, PartitionSpec("dp","mp", None))),
(("wpe", "embedding"), NamedSharding(mesh,PartitionSpec("dp","mp", None))),
(("wte", "embedding"), NamedSharding(mesh, PartitionSpec("dp",axis_name, None))),
(("wpe", "embedding"), NamedSharding(mesh,PartitionSpec("dp",axis_name, None))),
# attention
(("(query_proj|key_proj|value_proj)", "kernel"), NamedSharding(mesh,PartitionSpec("dp",None, "mp"))),
(("residual_out", "kernel"), NamedSharding(mesh,PartitionSpec("dp","mp", None))),
(("(query_proj|key_proj|value_proj)", "kernel"), NamedSharding(mesh,PartitionSpec("dp",None, axis_name))),
(("residual_out", "kernel"), NamedSharding(mesh,PartitionSpec("dp",axis_name, None))),
(("(query_proj|key_proj|value_proj)", "bias"), NamedSharding(mesh,PartitionSpec("dp",None))),
(("residual_out", "bias"), NamedSharding(mesh,PartitionSpec("dp","mp"))),
(("residual_out", "bias"), NamedSharding(mesh,PartitionSpec("dp",axis_name))),
# MLP
(("fc_in", "kernel"), NamedSharding(mesh,PartitionSpec("dp",None, "mp"))),
(("fc_residual", "kernel"), NamedSharding(mesh,PartitionSpec("dp","mp", None))),
(("fc_in", "kernel"), NamedSharding(mesh,PartitionSpec("dp",None, axis_name))),
(("fc_residual", "kernel"), NamedSharding(mesh,PartitionSpec("dp",axis_name, None))),
(("fc_in", "bias"), NamedSharding(mesh,PartitionSpec("dp",None))),
(("fc_residual", "bias"), NamedSharding(mesh,PartitionSpec("dp",None))),
# layer norms
Expand All @@ -157,13 +157,13 @@ def _get_partition_rules_tp_dp(mesh: Mesh):
),
]

def set_partitions_rules(in_dict, mesh: Mesh, rules_func: Callable):
def set_partitions_rules(in_dict, mesh: Mesh, rules_func: Callable, axis_name: str = 'mp'):
"""
Takes a FrozenDict and returns the associated PartitionSpec rule
for all groups of parameters
"""

rules = rules_func(mesh)
rules = rules_func(mesh, axis_name)
replace = _replacement_rules(rules)

_unmatched = object()
Expand Down
8 changes: 4 additions & 4 deletions tensor_parallel_emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def cumul_minibatch_step(carry, x_y):

if args.emulation:
print("Emulating 8 TPU cores")
GRAD_ACCUM_STEPS = 32
GRAD_ACCUM_STEPS = 8
BATCH_SIZE = 128
CTX_LEN = 32
NUM_PASSES = args.iter
Expand Down Expand Up @@ -143,12 +143,12 @@ def to_bf16(t):
param_shape = jax.eval_shape(model.init, rng, batch_tok)

if args.mp > 1:
param_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_tp)
batch_grad_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_tp_dp)
param_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_tp, axis_name = 'mp')
batch_grad_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_tp_dp, axis_name = 'mp')
batch_loss_spec = NamedSharding(mesh, P(None, 'dp', None))

else:
param_spec = no_shard
param_spec = set_partitions_rules(param_shape, mesh, _get_partition_rules_tp, axis_name = 'dp')
batch_grad_spec = no_shard
batch_loss_spec = NamedSharding(mesh, P(None, 'dp', None))

Expand Down

0 comments on commit f89c2ad

Please sign in to comment.