Skip to content

Commit

Permalink
[zero] add chunk size search for chunk manager (#1052)
Browse files Browse the repository at this point in the history
  • Loading branch information
ver217 committed Jun 2, 2022
1 parent 2c42b23 commit e1922ea
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions colossalai/tensor/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,41 @@ def __repr__(self) -> str:
for i, chunk in enumerate(group):
msg += f'[{i}] {chunk}\n'
return msg

@staticmethod
def get_chunk_util(chunk_size: int, params_numel: List[int]) -> float:
assert len(params_numel) > 0
total_size = 0
total_utilized_size = 0
cur_chunk_utilized_size = 0
for size in params_numel:
assert chunk_size >= size
total_utilized_size += size
if total_size == 0 or cur_chunk_utilized_size + size > chunk_size:
total_size += chunk_size
cur_chunk_utilized_size = 0
cur_chunk_utilized_size += size
return total_utilized_size / total_size

@staticmethod
def search_chunk_size(module: torch.nn.Module,
search_range: int,
n_grids: int,
min_chunk_size: Optional[int] = None) -> int:
assert search_range % n_grids == 0
# TODO(ver217): sort params and filter unused ones
params_numel = [p.numel() for p in module.parameters()]
max_param_numel = max(params_numel)
if min_chunk_size is not None:
assert min_chunk_size >= max_param_numel
else:
min_chunk_size = max_param_numel
step_size = search_range // n_grids
max_chunk_util = -1
best_chunk_size = -1
for chunk_size in range(min_chunk_size, min_chunk_size + search_range + 1, step_size):
chunk_util = ChunkManager.get_chunk_util(chunk_size, params_numel)
if chunk_util > max_chunk_util:
max_chunk_util = chunk_util
best_chunk_size = chunk_size
return best_chunk_size

0 comments on commit e1922ea

Please sign in to comment.