-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[zero] add chunk init function for users (#1729)
* add chunk manager init function * fix unit tests * add comment * add flush=True
- Loading branch information
Showing
10 changed files
with
689 additions
and
627 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk | ||
from .manager import ChunkManager | ||
from .search_utils import clasify_params, search_chunk_configuration | ||
from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState | ||
from .manager import ChunkManager | ||
from .search_utils import clasify_params, search_chunk_configuration | ||
from .utils import init_chunk_manager |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,100 +1,108 @@ | ||
import math | ||
from typing import Dict, List | ||
import numpy as np | ||
import torch.nn as nn | ||
from colossalai.tensor import ColoParameter | ||
|
||
|
||
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: | ||
"""Filter those parameters whose size is too large from others. | ||
""" | ||
params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)] | ||
params_size_arr = np.array(params_size) | ||
|
||
std = np.std(params_size_arr) | ||
mean = np.mean(params_size_arr) | ||
upper_limit = mean + 3 * std | ||
|
||
for key in size_dict: | ||
org_list = size_dict[key] | ||
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list)) | ||
|
||
|
||
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: | ||
"""Get unused byte for a certain chunk size. | ||
""" | ||
acc = 0 | ||
left = 0 | ||
for s in size_list: | ||
if s > left: | ||
acc += left | ||
left = chunk_size | ||
left -= s | ||
return left + acc | ||
|
||
|
||
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: | ||
params_dict: Dict[int, List[ColoParameter]] = dict() | ||
for param in model.parameters(): | ||
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" | ||
if getattr(param, '_ddp_to_ignore', False): | ||
continue | ||
|
||
param_key = param.process_group.dp_world_size() | ||
|
||
if param_key not in params_dict: | ||
params_dict[param_key] = [] | ||
params_dict[param_key].append(param) | ||
|
||
return params_dict | ||
|
||
|
||
def search_chunk_configuration( | ||
model: nn.Module, | ||
search_range_mb: float, | ||
search_interval_byte: int, # hidden size is the best value for the interval | ||
min_chunk_size_mb: float = 32, | ||
filter_exlarge_params: bool = True) -> Dict: | ||
search_range_byte = round(search_range_mb * 1024**2) | ||
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) | ||
assert search_range_byte >= 0 | ||
|
||
params_dict = clasify_params(model) | ||
config_dict: Dict[int, Dict] = dict() | ||
|
||
size_dict: Dict[int, List[int]] = dict() | ||
for key in params_dict: | ||
params_list = params_dict[key] | ||
size_list = [p.numel() for p in params_list] | ||
# let small parameters keep gathered in CUDA all the time | ||
total_size = sum(size_list) | ||
if total_size < min_chunk_size_byte: | ||
config_dict[key] = dict(chunk_size=total_size, keep_gathered=True) | ||
else: | ||
size_dict[key] = size_list | ||
|
||
if filter_exlarge_params: | ||
_filter_exlarge_params(model, size_dict) | ||
|
||
max_size = min_chunk_size_byte | ||
for key in size_dict: | ||
max_size = max(max_size, max(size_dict[key])) | ||
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) | ||
|
||
min_chunk_waste = float('+inf') | ||
best_chunk_size = start_size | ||
|
||
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): | ||
temp_waste = 0 | ||
for key in size_dict: | ||
temp_waste += _get_unused_byte(size_dict[key], chunk_size) | ||
if temp_waste < min_chunk_waste: | ||
min_chunk_waste = temp_waste | ||
best_chunk_size = chunk_size | ||
|
||
for key in params_dict: | ||
if key in config_dict: | ||
continue | ||
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False) | ||
|
||
return config_dict | ||
import math | ||
from typing import Dict, List, Tuple | ||
|
||
import numpy as np | ||
import torch.nn as nn | ||
|
||
from colossalai.tensor import ColoParameter | ||
|
||
|
||
def in_ddp(param: nn.Parameter) -> bool: | ||
return not getattr(param, '_ddp_to_ignore', False) | ||
|
||
|
||
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: | ||
"""Filter those parameters whose size is too large from others. | ||
""" | ||
params_size = [p.numel() for p in model.parameters() if in_ddp(p)] | ||
params_size_arr = np.array(params_size) | ||
|
||
std = np.std(params_size_arr) | ||
mean = np.mean(params_size_arr) | ||
upper_limit = mean + 3 * std | ||
|
||
for key in size_dict: | ||
org_list = size_dict[key] | ||
size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list)) | ||
|
||
|
||
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: | ||
"""Get unused byte for a certain chunk size. | ||
""" | ||
acc = 0 | ||
left = 0 | ||
for s in size_list: | ||
if s > left: | ||
acc += left | ||
left = chunk_size | ||
left -= s | ||
return left + acc | ||
|
||
|
||
def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: | ||
"""Clasify each parameter by its size of DP group. | ||
""" | ||
params_dict: Dict[int, List[ColoParameter]] = dict() | ||
for param in model.parameters(): | ||
assert isinstance(param, ColoParameter), "please init model in the ColoInitContext" | ||
if not in_ddp(param): | ||
continue | ||
|
||
param_key = param.process_group.dp_world_size() | ||
|
||
if param_key not in params_dict: | ||
params_dict[param_key] = [] | ||
params_dict[param_key].append(param) | ||
|
||
return params_dict | ||
|
||
|
||
def search_chunk_configuration( | ||
model: nn.Module, | ||
search_range_mb: float, | ||
search_interval_byte: int, # hidden size is the best value for the interval | ||
min_chunk_size_mb: float = 32, | ||
filter_exlarge_params: bool = True) -> Tuple[Dict, int]: | ||
search_range_byte = round(search_range_mb * 1024**2) | ||
min_chunk_size_byte = round(min_chunk_size_mb * 1024**2) | ||
assert search_range_byte >= 0 | ||
|
||
params_dict = clasify_params(model) | ||
config_dict: Dict[int, Dict] = dict() | ||
|
||
size_dict: Dict[int, List[int]] = dict() | ||
for key in params_dict: | ||
params_list = params_dict[key] | ||
size_list = [p.numel() for p in params_list] | ||
# let small parameters keep gathered in CUDA all the time | ||
total_size = sum(size_list) | ||
if total_size < min_chunk_size_byte: | ||
config_dict[key] = dict(chunk_size=total_size, keep_gathered=True) | ||
else: | ||
size_dict[key] = size_list | ||
|
||
if filter_exlarge_params: | ||
_filter_exlarge_params(model, size_dict) | ||
|
||
max_size = min_chunk_size_byte | ||
for key in size_dict: | ||
max_size = max(max_size, max(size_dict[key])) | ||
start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte) | ||
|
||
min_chunk_waste = float('+inf') | ||
best_chunk_size = start_size | ||
|
||
for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte): | ||
temp_waste = 0 | ||
for key in size_dict: | ||
temp_waste += _get_unused_byte(size_dict[key], chunk_size) | ||
if temp_waste < min_chunk_waste: | ||
min_chunk_waste = temp_waste | ||
best_chunk_size = chunk_size | ||
|
||
for key in params_dict: | ||
if key in config_dict: | ||
continue | ||
config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False) | ||
|
||
return config_dict, min_chunk_waste |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from time import time | ||
from typing import Optional | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import torch.nn as nn | ||
|
||
from colossalai.gemini.chunk import ChunkManager | ||
from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration | ||
|
||
|
||
def init_chunk_manager(model: nn.Module, | ||
init_device: Optional[torch.device] = None, | ||
hidden_dim: Optional[int] = None, | ||
search_range_mb: Optional[float] = None, | ||
min_chunk_size_mb: Optional[float] = None, | ||
filter_exlarge_params: Optional[bool] = None) -> ChunkManager: | ||
|
||
kwargs_dict = dict() | ||
|
||
if hidden_dim: | ||
search_interval_byte = hidden_dim | ||
else: | ||
search_interval_byte = 1024 # 1kb | ||
kwargs_dict["search_interval_byte"] = search_interval_byte | ||
|
||
if search_range_mb: | ||
kwargs_dict["search_range_mb"] = search_range_mb | ||
|
||
if min_chunk_size_mb: | ||
kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb | ||
|
||
if filter_exlarge_params: | ||
kwargs_dict["filter_exlarge_params"] = filter_exlarge_params | ||
|
||
params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)] | ||
total_size = sum(params_sizes) / 1024**2 | ||
|
||
dist.barrier() | ||
begine = time() | ||
|
||
config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict) | ||
|
||
dist.barrier() | ||
end = time() | ||
span_s = end - begine | ||
wasted_size /= 1024**2 | ||
|
||
if dist.get_rank() == 0: | ||
print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s), | ||
"used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size), | ||
"total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)), | ||
sep='', | ||
flush=True) | ||
dist.barrier() | ||
|
||
chunk_manager = ChunkManager(config_dict, init_device) | ||
return chunk_manager |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.