Commit: ptensor

flybird11111 committed Apr 12, 2024
1 parent 14a4342 commit 175fd26
Showing 9 changed files with 326 additions and 133 deletions.
132 changes: 76 additions & 56 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -4,7 +4,7 @@
from functools import reduce
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Set, Tuple
from typing import Dict, Iterator, Optional, OrderedDict, Tuple

import torch
import torch.distributed as dist
@@ -14,6 +14,7 @@

from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.utils import get_current_device

from .general_checkpoint_io import GeneralCheckpointIO
@@ -77,39 +78,39 @@ def __init__(
self.verbose = verbose
self.coordinator = DistCoordinator()

@staticmethod
def _named_modules(
module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):
r"""Returns an iterator over all leaf modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
"""
if memo is None:
memo = set()

if module not in memo:
sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
if len(sub_modules) == 0:
if remove_duplicate:
memo.add(module)
yield prefix, module
else:
for name, subm in sub_modules:
submodule_prefix = prefix + ("." if prefix else "") + name
yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)
# @staticmethod
# def _named_modules(
# module: nn.Module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
# ):
# r"""Returns an iterator over all leaf modules in the network, yielding
# both the name of the module as well as the module itself.

# Args:
# memo: a memo to store the set of modules already added to the result
# prefix: a prefix that will be added to the name of the module
# remove_duplicate: whether to remove the duplicated module instances in the result
# or not

# Yields:
# (str, Module): Tuple of name and module

# Note:
# Duplicate modules are returned only once. In the following
# example, ``l`` will be returned only once.
# """
# if memo is None:
# memo = set()

# if module not in memo:
# sub_modules = [(name, subm) for (name, subm) in module._modules.items() if subm is not None]
# if len(sub_modules) == 0:
# if remove_duplicate:
# memo.add(module)
# yield prefix, module
# else:
# for name, subm in sub_modules:
# submodule_prefix = prefix + ("." if prefix else "") + name
# yield from HybridParallelCheckpointIO._named_modules(subm, memo, submodule_prefix, remove_duplicate)

@staticmethod
def _model_sharder(
@@ -120,18 +121,29 @@ def _model_sharder(
state_dict_sharder = StateDictSharder(size_per_shard)

# Save parameters.
for module_name, module in HybridParallelCheckpointIO._named_modules(model):
state_dicts = module.state_dict()
for name, param in state_dicts.items():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
if module_name != "":
module_name = module_name + "."
block, block_size = state_dict_sharder.append_param(module_name + name, param_)
if block is not None:
yield block, block_size
# for module_name, module in HybridParallelCheckpointIO._named_modules(model):
# state_dicts = module.state_dict()
# for name, param in state_dicts.items():
# if param is None:
# continue
# # Gather tensor pieces when using tensor parallel.
# param_ = gather_distributed_param(param, keep_vars=False)
# if module_name != "":
# module_name = module_name + "."
# block, block_size = state_dict_sharder.append_param(module_name + name, param_)
# if block is not None:
# yield block, block_size
for name, param in model.named_parameters():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
print("bbbbbbbbbbbbbbbbbbbbbbbbbb")
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size

# Save buffers.
for name, buf in model.named_buffers():
@@ -906,7 +918,13 @@ def gather_from_sharded_optimizer_state(
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)

state_[k] = v.detach().clone()[: original_shape[0], ...].to(device)
padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
print("cccccccccccec")
v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)

state_[k] = v.detach().clone().to(device)

return state_
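
For reference, a minimal sketch of the new unpadding path in gather_from_sharded_optimizer_state: the gathered state is tagged with its padding metadata via init_as_ptensor and then sliced back to the original length by to_unpadded_tensor. The shapes are invented, the all_gather step is omitted, and the real search_padding_dim used in this file is replaced by a simplified local stand-in.

import torch

from colossalai.tensor.p_tensor import init_as_ptensor, to_unpadded_tensor


def _search_padding_dim(current_shape, original_shape):
    # Simplified stand-in: return the first dim whose size differs, else None.
    for dim, (cur, orig) in enumerate(zip(current_shape, original_shape)):
        if cur != orig:
            return dim
    return None


original_shape = torch.Size([1000, 64])  # hypothetical unpadded shape
v = torch.randn(1024, 64)                # gathered optimizer state, padded along dim 0

padding_dim = _search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
    # Tag the gathered tensor with its padding metadata, then strip the padding.
    v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
    v = to_unpadded_tensor(v)

assert v.shape == original_shape
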

@@ -949,15 +967,17 @@ def shard_from_complete_optimizer_state(

padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
padding_size = global_shape[padding_dim] - original_shape[padding_dim]
padding_data = torch.zeros(
*v.shape[:padding_dim],
padding_size,
*v.shape[padding_dim + 1 :],
device=v.device,
dtype=v.dtype,
)
v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
# padding_size = global_shape[padding_dim] - original_shape[padding_dim]
# padding_data = torch.zeros(
# *v.shape[:padding_dim],
# padding_size,
# *v.shape[padding_dim + 1 :],
# device=v.device,
# dtype=v.dtype,
# )
# v = torch.cat((v, padding_data), dim=padding_dim).contiguous()
print("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)

if partition_dim is not None:
slice_size = current_shape[partition_dim]
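
The padding step in shard_from_complete_optimizer_state can be pictured with a standalone sketch: to_padded_tensor appends the same zero block that the commented-out torch.zeros/torch.cat code built, and additionally records the padding metadata on the tensor. The shapes below are made up for illustration.

import torch

from colossalai.tensor.p_tensor import to_padded_tensor

global_shape = torch.Size([1024, 64])  # hypothetical padded (global) shape
v = torch.randn(1000, 64)              # optimizer state as stored in the checkpoint

padding_dim = 0                        # dim where global_shape and original_shape differ
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)

assert v.shape == global_shape                        # zeros appended along padding_dim
assert (v.origin_length, v.current_length) == (1000, 1024)
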
52 changes: 31 additions & 21 deletions colossalai/shardformer/layer/parallel_module.py
@@ -20,6 +20,7 @@
is_distributed_tensor,
sharded_tensor_to_param,
)
from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["ParallelModule"]

@@ -230,10 +231,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
for name, param in self._parameters.items():
if param is not None:
param = gather_distributed_param(param, keep_vars=keep_vars)
if self.new_num_embeddings > self.old_num_embeddings:
destination[prefix + name] = param[: self.old_num_embeddings, ...].data
else:
destination[prefix + name] = param.data
# if self.new_num_embeddings > self.old_num_embeddings:
# destination[prefix + name] = param[: self.old_num_embeddings, ...].data
# else:
# destination[prefix + name] = param.data
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
destination[prefix + name] = param.data

for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
@@ -296,12 +300,15 @@ def _load_from_state_dict(
)
continue

if self.new_num_embeddings > self.old_num_embeddings:
num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
padding_embeddings = torch.zeros(
num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
)
input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
# if self.new_num_embeddings > self.old_num_embeddings:
# num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
# padding_embeddings = torch.zeros(
# num_padding_tokens, *input_param.shape[1:], device=input_param.device, dtype=input_param.dtype
# )
# input_param.data = torch.cat((input_param.data, padding_embeddings), dim=0).contiguous()
if is_padded_tensor(param):
print("is_padded_tensor(param)", is_padded_tensor(param))
input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)

if is_distributed_tensor(param):
# shard the input param
@@ -359,16 +366,19 @@ def _load_from_state_dict(
unexpected_keys.append(key)

def resize_embedding_weight(self):
num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
valid_weight = self.weight.data
padding_weight = torch.zeros(
num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
)
# padding to embedding
self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()
self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
# num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
# valid_weight = self.weight.data
# padding_weight = torch.zeros(
# num_padding_tokens, *self.weight.shape[1:], device=self.weight.device, dtype=self.weight.dtype
# )
# # padding to embedding
# self.weight.data = torch.cat((valid_weight, padding_weight), dim=0).contiguous()

def resize_embedding_bias(self):
num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
valid_bias = self.bias.data
padding_bias = torch.zeros((num_padding_tokens), device=self.bias.device, dtype=self.bias.dtype)
self.bias.data = torch.cat((valid_bias, padding_bias), dim=0).contiguous()
print("resize bias")
self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
# num_padding_tokens = self.new_num_embeddings - self.old_num_embeddings
# valid_bias = self.bias.data
# padding_bias = torch.zeros((num_padding_tokens), device=self.bias.device, dtype=self.bias.dtype)
# self.bias.data = torch.cat((valid_bias, padding_bias), dim=0).contiguous()
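
Taken together, these ParallelModule changes give a pad-on-resize / pad-on-load / unpad-on-save round trip. The sketch below imitates it with a plain nn.Embedding and invented vocabulary sizes; it is not colossalai's parallel embedding, only an illustration of the p_tensor calls the diff introduces.

import torch
import torch.nn as nn

from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

old_num_embeddings, new_num_embeddings = 1000, 1024  # hypothetical vocab sizes
emb = nn.Embedding(old_num_embeddings, 64)

# resize: pad the weight rows up to the padded vocabulary size along dim 0
emb.weight = to_padded_tensor(emb.weight, new_num_embeddings, 0)
assert emb.weight.shape[0] == new_num_embeddings and is_padded_tensor(emb.weight)

# load side: a checkpoint tensor of the original size is padded back up before copying
loaded = torch.randn(old_num_embeddings, 64)
loaded = to_padded_tensor(loaded, emb.weight.current_length, emb.weight.padding_dim)
assert loaded.shape == emb.weight.shape

# save side: strip the padding so the checkpoint keeps only the original rows
# (this also removes the padding metadata from the tensor it is called on)
saved = to_unpadded_tensor(emb.weight)
assert saved.shape[0] == old_num_embeddings
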
3 changes: 3 additions & 0 deletions colossalai/tensor/p_tensor/__init__.py
@@ -0,0 +1,3 @@
from .api import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_ptensor"]
129 changes: 129 additions & 0 deletions colossalai/tensor/p_tensor/api.py
@@ -0,0 +1,129 @@
import torch


def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor so that the padding metadata (padding_dim, origin_length, current_length) is carried over to the result.
Args:
ptensor (torch.Tensor): The padded tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
ptensor._unpad_detach = ptensor.detach
ptensor._unpad_clone = ptensor.clone

def new_detach(self):
t_ = self._unpad_detach()
t_.padding_dim = self.padding_dim
t_.origin_length = self.origin_length
t_.current_length = self.current_length
return t_

def new_clone(self, *args, **kwargs):
t_ = self._unpad_clone(*args, **kwargs)
t_.padding_dim = self.padding_dim
t_.origin_length = self.origin_length
t_.current_length = self.current_length
return t_

# bind the new methods to the tensor
ptensor.detach = new_detach.__get__(ptensor)
ptensor.clone = new_clone.__get__(ptensor)
return ptensor


def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Restore the original detach and clone methods of the tensor and remove the hijacking helpers.
Args:
ptensor (torch.Tensor): The padded tensor to be restored.
Returns:
torch.Tensor: The tensor with its original detach and clone methods.
"""
ptensor.detach = ptensor._unpad_detach
ptensor.clone = ptensor._unpad_clone

delattr(ptensor, "_unpad_detach")
delattr(ptensor, "_unpad_clone")

return ptensor


def is_padded_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a padded tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a padded tensor.
"""
return hasattr(tensor, "padding_dim")


def to_padded_tensor(
tensor: torch.Tensor,
current_length: int,
padding_dim: int,
) -> torch.Tensor:
assert (
padding_dim < tensor.dim()
), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}"

if is_padded_tensor(tensor):
return tensor

origin_length = tensor.shape[padding_dim]
padding_num = current_length - origin_length
padding_data = torch.zeros(
*tensor.shape[:padding_dim],
padding_num,
*tensor.shape[padding_dim + 1 :],
device=tensor.device,
dtype=tensor.dtype,
)
tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()

setattr(tensor, "padding_dim", padding_dim)
setattr(tensor, "origin_length", origin_length)
setattr(tensor, "current_length", current_length)

_hijack_detach_and_clone(tensor)

return tensor


def to_unpadded_tensor(ptensor: torch.Tensor):
print("ptensor", ptensor.shape)
if not is_padded_tensor(ptensor):
return ptensor

unpad_slices = [slice(None)] * ptensor.dim()
unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
tensor = ptensor[tuple(unpad_slices)]

delattr(ptensor, "padding_dim")
delattr(ptensor, "origin_length")
delattr(ptensor, "current_length")

_hijack_back_detach_and_clone(ptensor)

return tensor


def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
if is_padded_tensor(tensor):
return tensor

setattr(tensor, "padding_dim", padding_dim)
setattr(tensor, "origin_length", origin_length)
setattr(tensor, "current_length", current_length)

_hijack_detach_and_clone(tensor)

return tensor
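
A small usage sketch of the helpers defined above (values are arbitrary): padding appends zero entries along one dimension and records padding_dim / origin_length / current_length on the tensor, the hijacked detach carries that metadata over to the detached copy, and to_unpadded_tensor slices the padding away again.

import torch

from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

t = torch.arange(6.0).reshape(3, 2)  # origin length 3 along dim 0
t = to_padded_tensor(t, 5, 0)        # pad dim 0 up to length 5 with zeros
assert t.shape == (5, 2) and is_padded_tensor(t)
assert (t.padding_dim, t.origin_length, t.current_length) == (0, 3, 5)

d = t.detach()                       # hijacked detach keeps the padding metadata
assert is_padded_tensor(d)

u = to_unpadded_tensor(t)            # slice the padding away again
assert u.shape == (3, 2) and not is_padded_tensor(u)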