ptensor
flybird11111 committed Apr 12, 2024
1 parent 14a4342 commit 935c2b3
Showing 8 changed files with 193 additions and 137 deletions.
13 changes: 4 additions & 9 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -41,16 +41,12 @@
ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2


def get_param_info(model: nn.Module, optim: Optimizer):
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.

param_info = {"id2shape": {}, "name2shape": {}}
for p_name, param in model.named_parameters(remove_duplicate=False):
param_info["name2shape"][p_name] = param.shape

if optim is None:
return param_info
return {}
param_info = {"id2shape": {}}

start_index = 0
for group in optim.param_groups:
@@ -531,7 +527,7 @@ def configure(
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
params_info = get_param_info(model, optimizer)
params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@@ -553,7 +549,6 @@ def configure(
zero_group=self.zero_group,
extra_dp_group=self.extra_dp_group,
verbose=self.verbose,
params_info=params_info,
)

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
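The get_param_info change at the top of this file means the helper no longer needs the model at all: it only records an integer-id-to-shape map taken from the optimizer's param_groups. A minimal sketch of that bookkeeping (the loop body is an assumption, since the hunk above is truncated):

from torch.optim import Optimizer

def sketch_get_param_info(optim: Optimizer) -> dict:
    # Mirrors the trimmed helper: no model argument, no name2shape map.
    if optim is None:
        return {}
    param_info = {"id2shape": {}}
    start_index = 0
    for group in optim.param_groups:
        # Assign consecutive integer ids across param groups and record shapes.
        for param_id, param in enumerate(group["params"], start_index):
            param_info["id2shape"][param_id] = param.shape
        start_index += len(group["params"])
    return param_info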
76 changes: 19 additions & 57 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -4,7 +4,7 @@
from functools import reduce
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Set, Tuple
from typing import Dict, Iterator, Optional, OrderedDict, Tuple

import torch
import torch.distributed as dist
@@ -14,6 +14,7 @@

from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.utils import get_current_device

from .general_checkpoint_io import GeneralCheckpointIO
@@ -77,40 +78,6 @@ def __init__(
self.verbose = verbose
self.coordinator = DistCoordinator()

@staticmethod
def _named_modules(
module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""Returns an iterator over all leaf modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
"""
if memo is None:
memo = set()

if module not in memo:
sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
if len(sub_modules) == 0:
if remove_duplicate:
memo.add(module)
yield prefix, module
else:
for name, subm in sub_modules:
submodule_prefix = prefix + ("." if prefix else "") + name
yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)

@staticmethod
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
@@ -120,18 +87,16 @@ def _model_sharder(
state_dict_sharder = StateDictSharder(size_per_shard)

# Save parameters.
for module_name, module in HybridParallelCheckpointIO._named_modules(model):
state_dicts = module.state_dict()
for name, param in state_dicts.items():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
if module_name != "":
module_name = module_name + "."
block, block_size = state_dict_sharder.append_param(module_name + name, param_)
if block is not None:
yield block, block_size
for name, param in model.named_parameters():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size

# Save buffers.
for name, buf in model.named_buffers():
@@ -906,7 +871,12 @@ def gather_from_sharded_optimizer_state(
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)

state_[k] = v.detach().clone()[: original_shape[0], ...].to(device)
padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)

state_[k] = v.detach().clone().to(device)

return state_

@@ -949,15 +919,7 @@ def shard_from_complete_optimizer_state(

padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
padding_size = global_shape[padding_dim] - original_shape[padding_dim]
padding_data = torch.zeros(
*v.shape[:padding_dim],
padding_size,
*v.shape[padding_dim + 1 :],
device=v.device,
dtype=v.dtype,
)
v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)

if partition_dim is not None:
slice_size = current_shape[partition_dim]
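Both optimizer-state hunks above replace ad-hoc slicing and zero-concatenation with the new p_tensor helpers. A small self-contained sketch of the pad-on-shard / unpad-on-gather round trip they now delegate to (shapes and the tp_size of 8 are made up for illustration, not repository code):

import torch
from colossalai.tensor.p_tensor import init_as_ptensor, to_padded_tensor, to_unpadded_tensor

# Loading direction (cf. shard_from_complete_optimizer_state): pad a full
# 100-row state to 104 rows so it splits evenly across a hypothetical tp_size of 8.
full_state = torch.randn(100, 32)
padded = to_padded_tensor(full_state, 104, 0)   # zeros appended along dim 0
shards = torch.chunk(padded, 8, dim=0)          # each rank would keep one 13-row shard

# Saving direction (cf. gather_from_sharded_optimizer_state): after the all-gather,
# mark the reassembled tensor as padded, then strip the padding rows.
gathered = torch.cat(shards, dim=0)             # (104, 32) again
gathered = init_as_ptensor(gathered, current_length=104, origin_length=100, padding_dim=0)
assert to_unpadded_tensor(gathered).shape == (100, 32)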
30 changes: 9 additions & 21 deletions colossalai/shardformer/layer/parallel_module.py
@@ -20,6 +20,7 @@
is_distributed_tensor,
sharded_tensor_to_param,
)
from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["ParallelModule"]

@@ -230,10 +231,9 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
for name, param in self._parameters.items():
if param is not None:
param = gather_distributed_param(param, keep_vars=keep_vars)
if self.new_num_embeddings > self.old_num_embeddings:
destination[prefix + name] = param[: self.old_num_embeddings, ...].data
else:
destination[prefix + name] = param.data
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
destination[prefix + name] = param.data

for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
@@ -296,12 +296,9 @@ def _load_from_state_dict(
)
continue

if self.new_num_embeddings > self.old_num_embeddings:
num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
padding_embeddings = torch.zeros(
num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
)
input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
if is_padded_tensor(param):
print("is_padded_tensor(param)", is_padded_tensor(param))
input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)

if is_distributed_tensor(param):
# shard the input param
@@ -359,16 +356,7 @@ def _load_from_state_dict(
unexpected_keys.append(key)

def resize_embedding_weight(self):
num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
valid_weight = self.weight.data
padding_weight = torch.zeros(
num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
)
# padding to embedding
self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()
self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)

def resize_embedding_bias(self):
num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
valid_bias = self.bias.data
padding_bias = torch.zeros((num_padding_tokens), device=self.bias.device, dtype=self.bias.dtype)
self.bias.data = torch.cat((valid_bias, padding_bias), dim=0).contiguous()
self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
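For the embedding path above, the padding round trip that resize_embedding_weight, _save_to_state_dict and _load_from_state_dict now share looks roughly like this (hypothetical vocab sizes, not repository code):

import torch
from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

# Pad a 50257-row embedding weight to 50304 rows (divisible by the TP size),
# analogous to the one-line resize_embedding_weight above.
weight = torch.nn.Parameter(torch.randn(50257, 64))
weight = to_padded_tensor(weight, 50304, 0)
assert is_padded_tensor(weight) and weight.shape[0] == 50304

# On save, the padding rows are stripped so checkpoints keep the original vocab size.
saved = to_unpadded_tensor(weight)
assert saved.shape[0] == 50257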
3 changes: 3 additions & 0 deletions colossalai/tensor/p_tensor/__init__.py
@@ -0,0 +1,3 @@
from .api import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_ptensor"]
128 changes: 128 additions & 0 deletions colossalai/tensor/p_tensor/api.py
@@ -0,0 +1,128 @@
import torch


def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor to make sure the padding metadata is copied.
Args:
ptensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
ptensor._unpad_detach = ptensor.detach
ptensor._unpad_clone = ptensor.clone

def new_detach(self):
t_ = self._unpad_detach()
t_.padding_dim = self.padding_dim
t_.origin_length = self.origin_length
t_.current_length = self.current_length
return t_

def new_clone(self, *args, **kwargs):
t_ = self._unpad_clone(*args, **kwargs)
t_.padding_dim = self.padding_dim
t_.origin_length = self.origin_length
t_.current_length = self.current_length
return t_

# bind the new methods to the tensor
ptensor.detach = new_detach.__get__(ptensor)
ptensor.clone = new_clone.__get__(ptensor)
return ptensor


def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Restore the original detach and clone methods of the tensor, removing the padding-aware overrides.
Args:
ptensor (torch.Tensor): The tensor to be restored.
Returns:
torch.Tensor: The restored tensor.
"""
ptensor.detach = ptensor._unpad_detach
ptensor.clone = ptensor._unpad_clone

delattr(ptensor, "_unpad_detach")
delattr(ptensor, "_unpad_clone")

return ptensor


def is_padded_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a padded tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a padded tensor.
"""
return hasattr(tensor, "padding_dim")


def to_padded_tensor(
tensor: torch.Tensor,
current_length: int,
padding_dim: int,
) -> torch.Tensor:
assert (
padding_dim < tensor.dim()
), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}"

if is_padded_tensor(tensor):
return tensor

origin_length = tensor.shape[padding_dim]
padding_num = current_length - origin_length
padding_data = torch.zeros(
*tensor.shape[:padding_dim],
padding_num,
*tensor.shape[padding_dim + 1 :],
device=tensor.device,
dtype=tensor.dtype,
)
tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()

setattr(tensor, "padding_dim", padding_dim)
setattr(tensor, "origin_length", origin_length)
setattr(tensor, "current_length", current_length)

_hijack_detach_and_clone(tensor)

return tensor


def to_unpadded_tensor(ptensor: torch.Tensor):
if not is_padded_tensor(ptensor):
return ptensor

unpad_slices = [slice(None)] * ptensor.dim()
unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
tensor = ptensor[tuple(unpad_slices)]

delattr(ptensor, "padding_dim")
delattr(ptensor, "origin_length")
delattr(ptensor, "current_length")

_hijack_back_detach_and_clone(ptensor)

return tensor


def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
if is_padded_tensor(tensor):
return tensor

setattr(tensor, "padding_dim", padding_dim)
setattr(tensor, "origin_length", origin_length)
setattr(tensor, "current_length", current_length)

_hijack_detach_and_clone(tensor)

return tensor
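As a quick usage sketch of the new module as a whole (an assumed example, not part of the commit), the hijacked detach/clone keep the padding metadata until to_unpadded_tensor removes it:

import torch
from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

t = torch.arange(6, dtype=torch.float32).reshape(3, 2)
p = to_padded_tensor(t, 5, 0)      # pad dim 0 from 3 rows to 5 with zeros
assert p.shape == (5, 2) and p.origin_length == 3 and p.current_length == 5

c = p.clone()                      # hijacked clone carries padding_dim/origin_length/current_length over
assert is_padded_tensor(c)

u = to_unpadded_tensor(p)          # slices the padding off and removes the metadata
assert u.shape == (3, 2) and not is_padded_tensor(p)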
2 changes: 1 addition & 1 deletion colossalai/testing/comparison.py
@@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
dtype: {a.dtype} vs {b.dtype}",
dtype: {a.dtype} vs {b.dtype}",
)

