From 935c2b33f203c0eda8b8ceb2c9d8309e3a17d15b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 12 Apr 2024 20:52:36 +0800 Subject: [PATCH] ptensor ptensor --- colossalai/booster/plugin/gemini_plugin.py | 13 +- .../hybrid_parallel_checkpoint_io.py | 76 +++-------- .../shardformer/layer/parallel_module.py | 30 ++-- colossalai/tensor/p_tensor/__init__.py | 3 + colossalai/tensor/p_tensor/api.py | 128 ++++++++++++++++++ colossalai/testing/comparison.py | 2 +- colossalai/zero/gemini/gemini_ddp.py | 47 ++----- colossalai/zero/gemini/gemini_optimizer.py | 31 ++--- 8 files changed, 193 insertions(+), 137 deletions(-) create mode 100644 colossalai/tensor/p_tensor/__init__.py create mode 100644 colossalai/tensor/p_tensor/api.py diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 3709b3055c93..442ac4a8da06 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -41,16 +41,12 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 -def get_param_info(model: nn.Module, optim: Optimizer): +def get_param_info(optim: Optimizer): # Get a backup of necessary information of parameters for future use, which includes: # 1. A mapping from integer param_id to param32 shape. - - param_info = {"id2shape": {}, "name2shape": {}} - for p_name, param in model.named_parameters(remove_duplicate=False): - param_info["name2shape"][p_name] = param.shape - if optim is None: - return param_info + return {} + param_info = {"id2shape": {}} start_index = 0 for group in optim.param_groups: @@ -531,7 +527,7 @@ def configure( dataloader: Optional[DataLoader] = None, lr_scheduler: Optional[LRScheduler] = None, ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: - params_info = get_param_info(model, optimizer) + params_info = get_param_info(optimizer) if not isinstance(model, ModelWrapper): # convert model to sync bn # FIXME(ver217): gemini does not support sync bn @@ -553,7 +549,6 @@ def configure( zero_group=self.zero_group, extra_dp_group=self.extra_dp_group, verbose=self.verbose, - params_info=params_info, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 771a5f78bb24..0718b2a60889 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -4,7 +4,7 @@ from functools import reduce from pathlib import Path from shutil import rmtree -from typing import Dict, Iterator, Optional, OrderedDict, Set, Tuple +from typing import Dict, Iterator, Optional, OrderedDict, Tuple import torch import torch.distributed as dist @@ -14,6 +14,7 @@ from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper +from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor from colossalai.utils import get_current_device from .general_checkpoint_io import GeneralCheckpointIO @@ -77,40 +78,6 @@ def __init__( self.verbose = verbose self.coordinator = DistCoordinator() - @staticmethod - def _named_modules( - module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True - ): - r"""Returns an iterator over all leaf modules in the network, yielding - both the name of the module as well as the module itself. 
- - Args: - memo: a memo to store the set of modules already added to the result - prefix: a prefix that will be added to the name of the module - remove_duplicate: whether to remove the duplicated module instances in the result - or not - - Yields: - (str, Module): Tuple of name and module - - Note: - Duplicate modules are returned only once. In the following - example, ``l`` will be returned only once. - """ - if memo is None: - memo = set() - - if module not in memo: - sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None] - if len(sub_modules) == 0: - if remove_duplicate: - memo.add(module) - yield prefix, module - else: - for name, subm in sub_modules: - submodule_prefix = prefix + ("." if prefix else "") + name - yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate) - @staticmethod def _model_sharder( model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024 @@ -120,18 +87,16 @@ def _model_sharder( state_dict_sharder = StateDictSharder(size_per_shard) # Save parameters. - for module_name, module in HybridParallelCheckpointIO._named_modules(model): - state_dicts = module.state_dict() - for name, param in state_dicts.items(): - if param is None: - continue - # Gather tensor pieces when using tensor parallel. - param_ = gather_distributed_param(param, keep_vars=False) - if module_name != "": - module_name = module_name + "." - block, block_size = state_dict_sharder.append_param(module_name + name, param_) - if block is not None: - yield block, block_size + for name, param in model.named_parameters(): + if param is None: + continue + # Gather tensor pieces when using tensor parallel. + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + param_ = gather_distributed_param(param, keep_vars=False) + block, block_size = state_dict_sharder.append_param(prefix + name, param_) + if block is not None: + yield block, block_size # Save buffers. 
for name, buf in model.named_buffers(): @@ -906,7 +871,12 @@ def gather_from_sharded_optimizer_state( dist.all_gather(gather_tensor, v, group=tp_group) v = torch.cat(gather_tensor, dim=partition_dim) - state_[k] = v.detach().clone()[: original_shape[0], ...].to(device) + padding_dim = search_padding_dim(v.shape, original_shape) + if padding_dim is not None: + v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim) + v = to_unpadded_tensor(v) + + state_[k] = v.detach().clone().to(device) return state_ @@ -949,15 +919,7 @@ def shard_from_complete_optimizer_state( padding_dim = search_padding_dim(global_shape, original_shape) if padding_dim is not None: - padding_size = global_shape[padding_dim] - original_shape[padding_dim] - padding_data = torch.zeros( - *v.shape[:padding_dim], - padding_size, - *v.shape[padding_dim + 1 :], - device=v.device, - dtype=v.dtype, - ) - v = torch.cat((v, padding_data), dim=padding_dim).contiguous() + v = to_padded_tensor(v, global_shape[padding_dim], padding_dim) if partition_dim is not None: slice_size = current_shape[partition_dim] diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index eae31215c58d..55114281d1ac 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -20,6 +20,7 @@ is_distributed_tensor, sharded_tensor_to_param, ) +from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor __all__ = ["ParallelModule"] @@ -230,10 +231,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): for name, param in self._parameters.items(): if param is not None: param = gather_distributed_param(param, keep_vars=keep_vars) - if self.new_num_embeddings > self.old_num_embeddings: - destination[prefix + name] = param[: self.old_num_embeddings, ...].data - else: - destination[prefix + name] = param.data + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + destination[prefix + name] = param.data for name, buf in self._buffers.items(): if buf is not None and name not in self._non_persistent_buffers_set: @@ -296,12 +296,9 @@ def _load_from_state_dict( ) continue - if self.new_num_embeddings > self.old_num_embeddings: - num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings - padding_embeddings = torch.zeros( - num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype - ) - input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous() + if is_padded_tensor(param): + print("is_padded_tensor(param)", is_padded_tensor(param)) + input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim) if is_distributed_tensor(param): # shard the input param @@ -359,16 +356,7 @@ def _load_from_state_dict( unexpected_keys.append(key) def resize_embedding_weight(self): - num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings - valid_weight = self.weight.data - padding_weight = torch.zeros( - num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype - ) - # padding to embedding - self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous() + self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0) def resize_embedding_bias(self): - num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings - valid_bias = self.bias.data - padding_bias = torch.zeros((num_padding_tokens), 
device=self.bias.device, dtype=self.bias.dtype) - self.bias.data = torch.cat((valid_bias, padding_bias), dim=0).contiguous() + self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0) diff --git a/colossalai/tensor/p_tensor/__init__.py b/colossalai/tensor/p_tensor/__init__.py new file mode 100644 index 000000000000..84490fc2a538 --- /dev/null +++ b/colossalai/tensor/p_tensor/__init__.py @@ -0,0 +1,3 @@ +from .api import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor + +__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_ptensor"] diff --git a/colossalai/tensor/p_tensor/api.py b/colossalai/tensor/p_tensor/api.py new file mode 100644 index 000000000000..805d650d776b --- /dev/null +++ b/colossalai/tensor/p_tensor/api.py @@ -0,0 +1,128 @@ +import torch + + +def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor._unpad_detach = ptensor.detach + ptensor._unpad_clone = ptensor.clone + + def new_detach(self): + t_ = self._unpad_detach() + t_.padding_dim = self.padding_dim + t_.origin_length = self.origin_length + t_.current_length = self.current_length + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._unpad_clone(*args, **kwargs) + t_.padding_dim = self.padding_dim + t_.origin_length = self.origin_length + t_.current_length = self.current_length + return t_ + + # bind the new methods to the tensor + ptensor.detach = new_detach.__get__(ptensor) + ptensor.clone = new_clone.__get__(ptensor) + return ptensor + + +def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + ptensor.detach = ptensor._unpad_detach + ptensor.clone = ptensor._unpad_clone + + delattr(ptensor, "_unpad_detach") + delattr(ptensor, "_unpad_clone") + + return ptensor + + +def is_padded_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a padding tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a padding tensor. + """ + return hasattr(tensor, "padding_dim") + + +def to_padded_tensor( + tensor: torch.Tensor, + current_length: int, + padding_dim: int, +) -> torch.Tensor: + assert ( + padding_dim < tensor.dim() + ), f"Please passing a valid padding_dim. 
the dimension of the tensor is {tensor.dim()}" + + if is_padded_tensor(tensor): + return tensor + + origin_length = tensor.shape[padding_dim] + padding_num = current_length - origin_length + padding_data = torch.zeros( + *tensor.shape[:padding_dim], + padding_num, + *tensor.shape[padding_dim + 1 :], + device=tensor.device, + dtype=tensor.dtype, + ) + tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous() + + setattr(tensor, "padding_dim", padding_dim) + setattr(tensor, "origin_length", origin_length) + setattr(tensor, "current_length", current_length) + + _hijack_detach_and_clone(tensor) + + return tensor + + +def to_unpadded_tensor(ptensor: torch.Tensor): + if not is_padded_tensor(ptensor): + return ptensor + + unpad_slices = [slice(None)] * ptensor.dim() + unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length) + tensor = ptensor[tuple(unpad_slices)] + + delattr(ptensor, "padding_dim") + delattr(ptensor, "origin_length") + delattr(ptensor, "current_length") + + _hijack_back_detach_and_clone(ptensor) + + return tensor + + +def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int): + if is_padded_tensor(tensor): + return tensor + + setattr(tensor, "padding_dim", padding_dim) + setattr(tensor, "origin_length", origin_length) + setattr(tensor, "current_length", current_length) + + _hijack_detach_and_clone(tensor) + + return tensor diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index e415b5fc3aa3..bdf7b19f39d0 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1 rtol=rtol, atol=atol, msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ - dtype: {a.dtype} vs {b.dtype}", + dtype: {a.dtype} vs {b.dtype}", ) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index e6a08aa31d9a..22351d26e9a6 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -11,7 +11,7 @@ from torch.distributed.distributed_c10d import _get_default_group from colossalai.accelerator import get_accelerator -from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim +from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger @@ -27,6 +27,7 @@ is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import _cast_float, free_storage, is_ddp_ignored @@ -89,7 +90,6 @@ def __init__( memstats: Optional[MemStats] = None, # genimi memory stats master_weights: bool = True, extra_dp_group: Optional[ProcessGroup] = None, - params_info: OrderedDict = None, verbose: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) @@ -131,7 +131,6 @@ def __init__( self.mixed_precision = mixed_precision self.zero_group = zero_group or _get_default_group() self.extra_dp_group = extra_dp_group - self.params_info = params_info self.reuse_fp16_chunk = master_weights self.master_weights = master_weights @@ -462,6 +461,11 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict: 
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn ) record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu() + if is_padded_tensor(tensor): + record_tensor = init_as_ptensor( + record_tensor, tensor.current_length, tensor.origin_length, tensor.padding_dim + ) + record_tensor = to_unpadded_tensor(record_tensor) assert tensor not in chunk_to_save_data chunk_to_save_data[tensor] = record_tensor @@ -522,17 +526,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): # deal with ddp ignored parameters destination[prefix + name] = param if keep_vars else param.detach() else: - if self.params_info is not None: - origin_shape = self.params_info["name2shape"][name] - padding_dim = search_padding_dim(p_mapping[param].shape, origin_shape) - if padding_dim is not None: - unpadding_slices = [slice(None)] * p_mapping[param].dim() - unpadding_slices[padding_dim] = slice(None, origin_shape[0]) - destination[prefix + name] = p_mapping[param][tuple(unpadding_slices)] - else: - destination[prefix + name] = p_mapping[param] - else: - destination[prefix + name] = p_mapping[param] + if is_padded_tensor(p_mapping[param]): + p_mapping[param] = to_unpadded_tensor(p_mapping[param]) + destination[prefix + name] = p_mapping[param] del p_mapping del param_to_save_data @@ -639,6 +635,7 @@ def _load_from_state_dict( list, and will be reported together in :meth:`~torch.nn.Module.load_state_dict` """ + for hook in self._load_state_dict_pre_hooks.values(): hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -663,17 +660,9 @@ def load( if source_device_mesh is not None and source_sharding_spec is not None: global_shape = get_global_shape(dest_tensor) - padding_dim = search_padding_dim(global_shape, input_param.shape) - if padding_dim is not None: - padding_num = global_shape[padding_dim] - input_param.shape[padding_dim] - padding_data = torch.zeros( - *input_param.shape[:padding_dim], - padding_num, - *input_param.shape[padding_dim + 1 :], - device=input_param.device, - dtype=input_param.dtype, - ) - input_param = torch.cat((input_param, padding_data), dim=padding_dim) + if is_padded_tensor(dest_tensor): + padding_dim = dest_tensor.padding_dim + input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim) if source_device_mesh is not None and source_sharding_spec is not None: input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec) @@ -911,14 +900,6 @@ def state_dict_shard( gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) gathered_param = gathered_param_buffer.pop(param_to_save) - if self.params_info is not None: - origin_shape = self.params_info["name2shape"][name] - padding_dim = search_padding_dim(gathered_param.shape, origin_shape) - if padding_dim is not None: - unpadding_slices = [slice(None)] * gathered_param.dim() - unpadding_slices[padding_dim] = slice(None, origin_shape[0]) - gathered_param = gathered_param[tuple(unpadding_slices)] - block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 6bef63baa438..135927e4f295 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -13,7 +13,7 @@ from colossalai.accelerator import get_accelerator from colossalai.amp.naive_amp.mixed_precision_mixin import 
BF16MixedPrecisionMixin, FP16MixedPrecisionMixin -from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param, search_padding_dim +from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_param from colossalai.interface import OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam @@ -28,6 +28,7 @@ is_customized_distributed_tensor, is_distributed_tensor, ) +from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor from colossalai.utils import disposable, is_ddp_ignored from .chunk import Chunk, ChunkManager @@ -461,7 +462,6 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: shard_spec = get_sharding_spec(param) if is_dtensor else None device_mesh = get_device_mesh(param) if is_dtensor else None global_shape = self.params_info["id2shape"][param_id] - origin_shape = global_shape # If the chunk is kept gathered, # the parameters are treated the same as that of those in strict DDP during training. @@ -494,8 +494,11 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() state_tensor = state_tensor.reshape(global_shape) - state_tensor = state_tensor[: origin_shape[0], ...] - + if is_padded_tensor(param): + state_tensor = init_as_ptensor( + state_tensor, param.current_length, param.origin_length, param.padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) collected_states[state_name] = state_tensor return collected_states @@ -551,7 +554,11 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict: state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn ) state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu() - state_tensor = state_tensor[: origin_shape[0], ...] + if is_padded_tensor(param): + state_tensor = init_as_ptensor( + state_tensor, param.current_length, param.origin_length, param.padding_dim + ) + state_tensor = to_unpadded_tensor(state_tensor) return collected_states @@ -723,18 +730,10 @@ def cast(param, state_range, value, global_shape, origin_shape, key=None): if is_dtensor: global_shape = get_global_shape(real_param) - padding_dim = search_padding_dim(global_shape, origin_shape) - if padding_dim is not None: - padding_num = global_shape[padding_dim] - origin_shape[padding_dim] + if is_padded_tensor(real_param): value = torch.reshape(value, origin_shape) - padding_data = torch.zeros( - *value.shape[:padding_dim], - padding_num, - *value.shape[padding_dim + 1 :], - device=value.device, - dtype=value.dtype, - ) - value = torch.cat((value, padding_data), dim=padding_dim).contiguous() + padding_dim = real_param.padding_dim + value = to_padded_tensor(value, global_shape[padding_dim], padding_dim) if is_dtensor: value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
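
For reviewers, a minimal usage sketch of the padded-tensor helpers this patch introduces in colossalai/tensor/p_tensor/api.py (to_padded_tensor, to_unpadded_tensor, init_as_ptensor, is_padded_tensor). The tensor shapes and the 50257 -> 50304 vocabulary-size numbers are illustrative assumptions for the sketch, not values taken from the patch.

# Sketch only: shapes/vocab sizes below are assumed for illustration.
import torch

from colossalai.tensor.p_tensor import (
    init_as_ptensor,
    is_padded_tensor,
    to_padded_tensor,
    to_unpadded_tensor,
)

# An embedding weight whose vocabulary size (dim 0) is not divisible by
# the tensor-parallel world size, e.g. 50257 padded up to 50304.
weight = torch.randn(50257, 768)

# Pad dim 0 to the target length; padding_dim / origin_length /
# current_length are attached to the tensor object as metadata.
weight = to_padded_tensor(weight, current_length=50304, padding_dim=0)
assert is_padded_tensor(weight)
assert weight.shape[0] == 50304 and weight.origin_length == 50257

# detach() and clone() are hijacked so the padding metadata survives.
assert is_padded_tensor(weight.detach())

# Strip the padding again, e.g. before writing a checkpoint shard.
unpadded = to_unpadded_tensor(weight)
assert unpadded.shape[0] == 50257 and not is_padded_tensor(unpadded)

# For a tensor that is already padded but carries no metadata (such as a
# gathered optimizer state), init_as_ptensor only attaches the metadata
# so that to_unpadded_tensor can slice the padding back off.
state = torch.randn(50304, 768)
state = init_as_ptensor(state, current_length=50304, origin_length=50257, padding_dim=0)
state = to_unpadded_tensor(state)
assert state.shape[0] == 50257

The metadata lives on the tensor object itself, which is why detach and clone are wrapped: a plain detach() would otherwise drop the attributes and the checkpoint paths in this patch could no longer recover the original, unpadded shape.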