Skip to content

Commit

Permalink
fix uneven heads issue (microsoft#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yejing-Lai committed Oct 31, 2023
1 parent 2016f30 commit 57ff508
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions deepspeed/module_inject/tp_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,14 @@ def get_num_kv_heads():

def get_shard_size(total_size, mp_size, rank=None):
global num_kv_heads
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
if num_kv_heads != None:
if (rank == None):
rank = dist.get_rank()
# When we have num_kv_heads defined, uneven division is possible, otherwise enforce near even division
if rank == None:
rank = dist.get_rank()
if num_kv_heads != None and total_size % num_kv_heads == 0:
my_slices = (num_kv_heads // mp_size) + (1 if rank < (num_kv_heads % mp_size) else 0)
return total_size * my_slices // num_kv_heads
else:
if total_size % mp_size == 0:
return total_size // mp_size
else:
assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"
return total_size // mp_size + (1 if rank < (total_size % mp_size) else 0)


def get_shard_size_list(total_size, mp_size):
Expand Down

0 comments on commit 57ff508

Please sign in to comment.