[gemini] improve compatibility and add static placement policy #4479

Merged: 10 commits, Aug 24, 2023

104 changes: 39 additions & 65 deletions in colossalai/booster/plugin/gemini_plugin.py
@@ -1,13 +1,11 @@
import gc
import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -16,16 +14,14 @@
from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_state_dict,
save_state_dict_shards,
)
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini import ZeroOptimizer
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.memory_tracer import MemStats

from .dp_plugin_base import DPPluginBase
@@ -132,11 +128,7 @@ def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_
As there is communication when getting state dict, this must be called on all processes.
"""

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()

assert isinstance(optimizer, ZeroOptimizer)
assert isinstance(optimizer, GeminiOptimizer)

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -183,11 +175,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Pa
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")

# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = optimizer.unwrap()

assert isinstance(optimizer, ZeroOptimizer)
assert isinstance(optimizer, GeminiOptimizer)

# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
@@ -220,47 +208,6 @@ def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
super().save_lr_scheduler(lr_scheduler, checkpoint)
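
Since the checkpoint IO above now expects a `GeminiOptimizer` directly instead of unwrapping an `OptimizerWrapper`, sharded saving and loading is normally driven through the booster rather than by calling these methods by hand. Below is a rough sketch of that flow; it assumes `booster`, `model` and `optimizer` come from `booster.boost()` as in the plugin docstring, and the keyword names (`shard`) and path handling are assumptions about the Booster checkpoint API, not taken from this diff.

```python
# Hypothetical paths; shard=True is assumed to route to the sharded
# save_sharded_optimizer / load_sharded_optimizer paths shown above.
booster.save_model(model, "./ckpt/model", shard=True)
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True)

# Later, after rebuilding and boosting the same model/optimizer in a new run:
booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer")
```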


class GeminiModel(ModelWrapper):

def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
super().__init__(module)
self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)

def unwrap(self):
# as save/load state dict is coupled with the GeminiDDP, we only return GeminiDDP model
return self.module


class GeminiOptimizer(OptimizerWrapper):

def __init__(self,
module: GeminiDDP,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)

def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)

def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')

def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')


class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.
@@ -277,8 +224,20 @@ class GeminiPlugin(DPPluginBase):
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
chunk_config_dict (dict, optional): chunk configuration dictionary.
chunk_init_device (torch.device, optional): device to initialize the chunk.
placement_policy (str, optional): "static" and "auto". Defaults to "static".
shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
Defaults to 0.0.
warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
@@ -310,8 +269,14 @@ class GeminiPlugin(DPPluginBase):

def __init__(
self,
device: Optional[torch.device] = None,
placement_policy: str = "cpu",
chunk_config_dict: Optional[dict] = None,
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
warmup_non_model_data_ratio: float = 0.8, # only for auto placement
steady_cuda_cap_ratio: float = 0.9, # only for auto placement
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
@@ -335,8 +300,14 @@ def __init__(
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
device=(device or get_current_device()),
chunk_config_dict=chunk_config_dict,
chunk_init_device=(chunk_init_device or get_current_device()),
placement_policy=placement_policy,
shard_param_frac=shard_param_frac,
offload_optim_frac=offload_optim_frac,
offload_param_frac=offload_param_frac,
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
pin_memory=pin_memory,
force_outputs_fp32=force_outputs_fp32,
strict_ddp_mode=strict_ddp_mode,
@@ -393,12 +364,15 @@ def configure(
# model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

# wrap the model with Gemini
model = GeminiModel(model, self.gemini_config, self.verbose)
model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)

if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
optimizer = GeminiOptimizer(optimizer,
model.unwrap(),
**self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)

return model, optimizer, criterion, dataloader, lr_scheduler

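To tie the new constructor arguments together, here is a minimal usage sketch of the static placement knobs. It assumes the distributed environment is already initialized (for example via `colossalai.launch_from_torch`) and that `model`, `optimizer`, `criterion` and `train_dataloader` are built elsewhere; the fraction values are illustrative, and the `boost()` call mirrors the docstring example above.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    placement_policy="static",
    shard_param_frac=1.0,    # 1.0 shards every parameter (zero-3-like); 0.0 keeps them replicated (zero-2-like)
    offload_optim_frac=0.0,  # with shard_param_frac=1.0, 0.0 matches the old "cuda" placement
    offload_param_frac=0.0,  # only useful once shard_param_frac and offload_optim_frac are both 1.0
    precision="fp16",
)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```

As the docstring notes, when memory is tight the recommended order is to tune `shard_param_frac` first and then `offload_optim_frac`.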
68 changes: 22 additions & 46 deletions in colossalai/tensor/colo_parameter.py

@@ -3,9 +3,15 @@
import torch

from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec

from .colo_tensor import _convert_output

WHITE_LIST_FUNCS = {torch.Tensor.__getitem__}


def is_no_hook_op(func) -> bool:
return func.__name__.startswith('__') and func not in WHITE_LIST_FUNCS


def filter_colo_parameters(*args, **kwargs):
@@ -41,53 +47,25 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):

"""

def __new__(cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> 'ColoParameter':
def __new__(cls, data: Optional[torch.Tensor] = None, requires_grad: bool = True) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)

def __init__(self,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> None:
ColoTensor.__init__(self, data, spec)
self._type = TensorType.MODEL
# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []

@property
def shared_param_modules(self):
return self._shared_param_modules

@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: ColoTensorSpec = None) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor

def __repr__(self):
return super(ColoParameter, self).__repr__()

@classmethod
def __torch_function__(cls, func, types, args=..., kwargs=None):
if ColoParamOpHookManager.has_hook():
if not func.__name__.startswith('__'):
if kwargs is None:
kwargs = {}
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return ret
if kwargs is None:
kwargs = {}
if ColoParamOpHookManager.has_hook() and not is_no_hook_op(func):
params = filter_colo_parameters(*args, **kwargs)
if len(params) > 0:
with torch._C.DisableTorchFunction():
new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
args, kwargs = replace_args(args, kwargs, new_args)
ret = super().__torch_function__(func, types, args, kwargs)
with torch._C.DisableTorchFunction():
ret = ColoParamOpHookManager.post_op(params, ret)
return _convert_output(ret, func)
return super().__torch_function__(func, types, args, kwargs)

def __deepcopy__(self, memo):
@@ -96,9 +74,7 @@ def __deepcopy__(self, memo):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data,
self.requires_grad,
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
tensor = ColoParameter(data, self.requires_grad)
memo[id(self)] = tensor
return tensor
