[zero] add chunk init function for users (#1729)
* add chunk manager init function

* fix unit tests

* add comment

* add flush=True
1SAA committed Oct 18, 2022
1 parent 2e1dbfb commit f69f9bf
Showing 10 changed files with 689 additions and 627 deletions.
7 changes: 4 additions & 3 deletions colossalai/gemini/chunk/__init__.py
@@ -1,3 +1,4 @@
-from .chunk import TensorState, TensorInfo, ChunkFullError, Chunk
-from .manager import ChunkManager
-from .search_utils import clasify_params, search_chunk_configuration
+from .chunk import Chunk, ChunkFullError, TensorInfo, TensorState
+from .manager import ChunkManager
+from .search_utils import clasify_params, search_chunk_configuration
+from .utils import init_chunk_manager
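
With this re-export, the helper defined later in the commit becomes importable from the package root. A minimal sketch of the new public entry point:

# The user-facing helper added by this commit is available at package level:
from colossalai.gemini.chunk import init_chunk_manager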
208 changes: 108 additions & 100 deletions colossalai/gemini/chunk/search_utils.py
@@ -1,100 +1,108 @@
-import math
-from typing import Dict, List
-import numpy as np
-import torch.nn as nn
-from colossalai.tensor import ColoParameter
-
-
-def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
-    """Filter those parameters whose size is too large from others.
-    """
-    params_size = [p.numel() for p in model.parameters() if not getattr(p, '_ddp_to_ignore', False)]
-    params_size_arr = np.array(params_size)
-
-    std = np.std(params_size_arr)
-    mean = np.mean(params_size_arr)
-    upper_limit = mean + 3 * std
-
-    for key in size_dict:
-        org_list = size_dict[key]
-        size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
-
-
-def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
-    """Get unused byte for a certain chunk size.
-    """
-    acc = 0
-    left = 0
-    for s in size_list:
-        if s > left:
-            acc += left
-            left = chunk_size
-        left -= s
-    return left + acc
-
-
-def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
-    params_dict: Dict[int, List[ColoParameter]] = dict()
-    for param in model.parameters():
-        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
-        if getattr(param, '_ddp_to_ignore', False):
-            continue
-
-        param_key = param.process_group.dp_world_size()
-
-        if param_key not in params_dict:
-            params_dict[param_key] = []
-        params_dict[param_key].append(param)
-
-    return params_dict
-
-
-def search_chunk_configuration(
-        model: nn.Module,
-        search_range_mb: float,
-        search_interval_byte: int,    # hidden size is the best value for the interval
-        min_chunk_size_mb: float = 32,
-        filter_exlarge_params: bool = True) -> Dict:
-    search_range_byte = round(search_range_mb * 1024**2)
-    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
-    assert search_range_byte >= 0
-
-    params_dict = clasify_params(model)
-    config_dict: Dict[int, Dict] = dict()
-
-    size_dict: Dict[int, List[int]] = dict()
-    for key in params_dict:
-        params_list = params_dict[key]
-        size_list = [p.numel() for p in params_list]
-        # let small parameters keep gathered in CUDA all the time
-        total_size = sum(size_list)
-        if total_size < min_chunk_size_byte:
-            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
-        else:
-            size_dict[key] = size_list
-
-    if filter_exlarge_params:
-        _filter_exlarge_params(model, size_dict)
-
-    max_size = min_chunk_size_byte
-    for key in size_dict:
-        max_size = max(max_size, max(size_dict[key]))
-    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
-
-    min_chunk_waste = float('+inf')
-    best_chunk_size = start_size
-
-    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
-        temp_waste = 0
-        for key in size_dict:
-            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
-        if temp_waste < min_chunk_waste:
-            min_chunk_waste = temp_waste
-            best_chunk_size = chunk_size
-
-    for key in params_dict:
-        if key in config_dict:
-            continue
-        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
-
-    return config_dict
+import math
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch.nn as nn
+
+from colossalai.tensor import ColoParameter
+
+
+def in_ddp(param: nn.Parameter) -> bool:
+    return not getattr(param, '_ddp_to_ignore', False)
+
+
+def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
+    """Filter those parameters whose size is too large from others.
+    """
+    params_size = [p.numel() for p in model.parameters() if in_ddp(p)]
+    params_size_arr = np.array(params_size)
+
+    std = np.std(params_size_arr)
+    mean = np.mean(params_size_arr)
+    upper_limit = mean + 3 * std
+
+    for key in size_dict:
+        org_list = size_dict[key]
+        size_dict[key] = list(filter(lambda x: x <= upper_limit, org_list))
+
+
+def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
+    """Get unused byte for a certain chunk size.
+    """
+    acc = 0
+    left = 0
+    for s in size_list:
+        if s > left:
+            acc += left
+            left = chunk_size
+        left -= s
+    return left + acc
+
+
+def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]:
+    """Clasify each parameter by its size of DP group.
+    """
+    params_dict: Dict[int, List[ColoParameter]] = dict()
+    for param in model.parameters():
+        assert isinstance(param, ColoParameter), "please init model in the ColoInitContext"
+        if not in_ddp(param):
+            continue
+
+        param_key = param.process_group.dp_world_size()
+
+        if param_key not in params_dict:
+            params_dict[param_key] = []
+        params_dict[param_key].append(param)
+
+    return params_dict
+
+
+def search_chunk_configuration(
+        model: nn.Module,
+        search_range_mb: float,
+        search_interval_byte: int,    # hidden size is the best value for the interval
+        min_chunk_size_mb: float = 32,
+        filter_exlarge_params: bool = True) -> Tuple[Dict, int]:
+    search_range_byte = round(search_range_mb * 1024**2)
+    min_chunk_size_byte = round(min_chunk_size_mb * 1024**2)
+    assert search_range_byte >= 0
+
+    params_dict = clasify_params(model)
+    config_dict: Dict[int, Dict] = dict()
+
+    size_dict: Dict[int, List[int]] = dict()
+    for key in params_dict:
+        params_list = params_dict[key]
+        size_list = [p.numel() for p in params_list]
+        # let small parameters keep gathered in CUDA all the time
+        total_size = sum(size_list)
+        if total_size < min_chunk_size_byte:
+            config_dict[key] = dict(chunk_size=total_size, keep_gathered=True)
+        else:
+            size_dict[key] = size_list
+
+    if filter_exlarge_params:
+        _filter_exlarge_params(model, size_dict)
+
+    max_size = min_chunk_size_byte
+    for key in size_dict:
+        max_size = max(max_size, max(size_dict[key]))
+    start_size = int(math.ceil(max_size / search_interval_byte) * search_interval_byte)
+
+    min_chunk_waste = float('+inf')
+    best_chunk_size = start_size
+
+    for chunk_size in range(start_size, start_size + search_range_byte + 1, search_interval_byte):
+        temp_waste = 0
+        for key in size_dict:
+            temp_waste += _get_unused_byte(size_dict[key], chunk_size)
+        if temp_waste < min_chunk_waste:
+            min_chunk_waste = temp_waste
+            best_chunk_size = chunk_size
+
+    for key in params_dict:
+        if key in config_dict:
+            continue
+        config_dict[key] = dict(chunk_size=best_chunk_size, keep_gathered=False)
+
+    return config_dict, min_chunk_waste
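
The function now returns a tuple of the configuration dict and the total waste of the best chunk size, so every caller must unpack two values. A minimal usage sketch, assuming `model` is an nn.Module whose parameters are ColoParameters (i.e. it was built under ColoInitContext); the variable names here are illustrative:

from colossalai.gemini.chunk import search_chunk_configuration

# `model` is assumed to exist and to hold ColoParameters.
config_dict, wasted_size = search_chunk_configuration(
    model,
    search_range_mb=32,           # scan a 32 MB window above the starting chunk size
    search_interval_byte=1024,    # step size; the model's hidden size is the ideal value
)

# One entry per DP world size; each maps to a chunk size plus a keep_gathered flag.
for dp_world_size, spec in config_dict.items():
    print(dp_world_size, spec['chunk_size'], spec['keep_gathered'])

For intuition on the waste metric: _get_unused_byte([7, 5, 3], chunk_size=8) packs greedily, wasting 1 unit when the size-5 tensor no longer fits in the first chunk and nothing at the end, so it returns 1. Note also that, despite the "byte" naming, the sizes fed into the search are parameter element counts (numel), not bytes.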
58 changes: 58 additions & 0 deletions colossalai/gemini/chunk/utils.py
@@ -0,0 +1,58 @@
+from time import time
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from colossalai.gemini.chunk import ChunkManager
+from colossalai.gemini.chunk.search_utils import in_ddp, search_chunk_configuration
+
+
+def init_chunk_manager(model: nn.Module,
+                       init_device: Optional[torch.device] = None,
+                       hidden_dim: Optional[int] = None,
+                       search_range_mb: Optional[float] = None,
+                       min_chunk_size_mb: Optional[float] = None,
+                       filter_exlarge_params: Optional[bool] = None) -> ChunkManager:
+
+    kwargs_dict = dict()
+
+    if hidden_dim:
+        search_interval_byte = hidden_dim
+    else:
+        search_interval_byte = 1024    # 1kb
+    kwargs_dict["search_interval_byte"] = search_interval_byte
+
+    if search_range_mb:
+        kwargs_dict["search_range_mb"] = search_range_mb
+
+    if min_chunk_size_mb:
+        kwargs_dict["min_chunk_size_mb"] = min_chunk_size_mb
+
+    if filter_exlarge_params:
+        kwargs_dict["filter_exlarge_params"] = filter_exlarge_params
+
+    params_sizes = [p.numel() for p in model.parameters() if in_ddp(p)]
+    total_size = sum(params_sizes) / 1024**2
+
+    dist.barrier()
+    begine = time()
+
+    config_dict, wasted_size = search_chunk_configuration(model, **kwargs_dict)
+
+    dist.barrier()
+    end = time()
+    span_s = end - begine
+    wasted_size /= 1024**2
+
+    if dist.get_rank() == 0:
+        print("searching chunk configuration is completed in {:.2f} s.\n".format(span_s),
+              "used number: {:.2f} MB, wasted number: {:.2f} MB\n".format(total_size, wasted_size),
+              "total wasted percentage is {:.2f}%".format(100 * wasted_size / (total_size + wasted_size)),
+              sep='',
+              flush=True)
+    dist.barrier()
+
+    chunk_manager = ChunkManager(config_dict, init_device)
+    return chunk_manager
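
The new helper hides the search and the ChunkManager construction behind one call. A hypothetical usage sketch (argument values are illustrative); since the function calls dist.barrier(), the process group must already be initialized, e.g. via colossalai.launch:

import torch

from colossalai.gemini.chunk import init_chunk_manager

# `model` is assumed to be built under ColoInitContext in an initialized
# distributed environment.
chunk_manager = init_chunk_manager(
    model,
    init_device=torch.device('cpu'),    # device on which chunks are first allocated
    hidden_dim=1024,                    # doubles as the search interval in bytes
    search_range_mb=32,
)

One quirk worth noting in the code above: only truthy arguments are forwarded to the search, so passing filter_exlarge_params=False (or hidden_dim=0) is silently dropped and the search function's defaults apply.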
28 changes: 15 additions & 13 deletions tests/test_ddp/test_ddp_ignore_params.py
@@ -1,21 +1,23 @@
+import os
+import random
+from functools import partial
+from typing import Callable, Type
+
+import numpy as np
 import pytest
-import colossalai
 import torch
+import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils import free_port
-from colossalai.utils.model.colo_init_context import ColoInitContext
+
+import colossalai
 from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
-from functools import partial
-from colossalai.nn.parallel import ColoDDP, ZeroDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Callable, Type
-import torch.distributed as dist
-import os
-import random
-import numpy as np
+from colossalai.nn.parallel import ColoDDP, ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext


 def set_seed(seed):
@@ -33,7 +35,7 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:


 def init_ddpv2(module: torch.nn.Module) -> ZeroDDP:
-    chunk_config = search_chunk_configuration(module, 4, 1024)
+    chunk_config, _ = search_chunk_configuration(module, 4, 1024)
     chunk_manager = ChunkManager(chunk_config)
     gemini_manager = GeminiManager('cuda', chunk_manager)
     return ZeroDDP(module, gemini_manager)
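
The test keeps the manual wiring (search, then ChunkManager, then GeminiManager, then ZeroDDP) and only adapts to the new tuple return. With the helper from this commit, the first two steps collapse into one call; a sketch under the same distributed setup as the test, with the function name init_ddpv2_with_helper being illustrative:

import torch

from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.nn.parallel import ZeroDDP


def init_ddpv2_with_helper(module: torch.nn.Module) -> ZeroDDP:
    # init_chunk_manager runs the chunk-size search and builds the ChunkManager.
    chunk_manager = init_chunk_manager(module, search_range_mb=4, hidden_dim=1024)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    return ZeroDDP(module, gemini_manager)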
