Commit

fix
flybird11111 committed Apr 16, 2024
1 parent 70e8113 commit 3b4f14a
Showing 11 changed files with 115 additions and 47 deletions.
9 changes: 7 additions & 2 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -14,7 +14,12 @@

from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
- from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+ from colossalai.tensor.padded_tensor import (
+     init_as_padded_tensor,
+     is_padded_tensor,
+     to_padded_tensor,
+     to_unpadded_tensor,
+ )
from colossalai.utils import get_current_device

from .general_checkpoint_io import GeneralCheckpointIO
@@ -873,7 +878,7 @@ def gather_from_sharded_optimizer_state(

padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
- v = init_as_ptensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
+ v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)

state_[k] = v.detach().clone().to(device)
11 changes: 4 additions & 7 deletions colossalai/checkpoint_io/utils.py
@@ -19,7 +19,6 @@
to_global,
to_global_for_customized_distributed_tensor,
)
- from colossalai.tensor.p_tensor.api import init_as_ptensor, is_padded_tensor

SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@@ -208,13 +207,11 @@ def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> to
"""
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
- param_ = to_global(param_)
+ return to_global(param_)
elif is_customized_distributed_tensor(param_):
- param_ = to_global_for_customized_distributed_tensor(param_)
-
- if is_padded_tensor(param):
- param_ = init_as_ptensor(param_, param.current_length, param.origin_length, param.padding_dim)
- return param_
+ return to_global_for_customized_distributed_tensor(param_)
+ else:
+ return param_


def save_state_dict_shards(
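Note on the gather_distributed_param change above: the padded-tensor restoration is not dropped, it moves to the call sites (see gemini_ddp.py and gemini_optimizer.py further down). A minimal CPU-only sketch of that caller-side restore-and-strip step, with hypothetical shapes and a fresh copy standing in for the gathered result:

    import torch
    from colossalai.tensor.padded_tensor import (
        init_as_padded_tensor,
        is_padded_tensor,
        to_padded_tensor,
        to_unpadded_tensor,
    )

    # a local parameter padded from length 6 to 8 along dim 0
    param = to_padded_tensor(torch.rand(6, 4), current_length=8, padding_dim=0)

    # stand-in for the gather_distributed_param(...) result: same data, but the
    # padding attributes are gone because it is a fresh tensor object
    gathered = torch.empty_like(param).copy_(param)
    assert not is_padded_tensor(gathered)

    if is_padded_tensor(param):
        # the callers re-attach the metadata from the source parameter, then strip the padding
        gathered = init_as_padded_tensor(
            gathered, param._current_length, param._origin_length, param._padding_dim
        )
        gathered = to_unpadded_tensor(gathered)
    assert gathered.shape == (6, 4)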
5 changes: 2 additions & 3 deletions colossalai/shardformer/layer/parallel_module.py
@@ -20,7 +20,7 @@
is_distributed_tensor,
sharded_tensor_to_param,
)
- from colossalai.tensor.p_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+ from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["ParallelModule"]

@@ -297,8 +297,7 @@ def _load_from_state_dict(
continue

if is_padded_tensor(param):
- print("is_padded_tensor(param)", is_padded_tensor(param))
- input_param = to_padded_tensor(input_param, param.current_length, param.padding_dim)
+ input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)

if is_distributed_tensor(param):
# shard the input param
15 changes: 12 additions & 3 deletions colossalai/tensor/d_tensor/layout_converter.py
@@ -10,6 +10,7 @@
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.misc import LayoutException
+ from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

from .sharding_spec import ShardingSpec
@@ -607,8 +608,16 @@ def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layo
[3.],
[3.]])
"""

_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
for comm_spec in comm_action_sequence:
- tensor = comm_spec.covert_spec_to_action(tensor)
- tensor.dist_layout = target_layout
- return tensor
+ target_tensor = comm_spec.covert_spec_to_action(tensor)
+ target_tensor.dist_layout = target_layout
+
+ # restore the padding information
+ if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
+ target_tensor = init_as_padded_tensor(
+ target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+ )
+
+ return target_tensor
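The hunk above re-attaches the padding metadata after the communication actions because any operation that returns a fresh tensor object loses these instance attributes. A minimal CPU-only sketch of that behaviour, with hypothetical shapes and assuming to_padded_tensor records the pre-padding length as _origin_length (as the api.py changes below indicate):

    import torch
    from colossalai.tensor.padded_tensor import init_as_padded_tensor, is_padded_tensor, to_padded_tensor

    src = to_padded_tensor(torch.rand(6, 4), current_length=8, padding_dim=0)  # pad dim 0 from 6 to 8
    out = src * 2                        # a plain torch op returns a new tensor, so the metadata is gone
    assert not is_padded_tensor(out)

    # copy the metadata back from the source tensor, mirroring the restore step in apply()
    out = init_as_padded_tensor(out, src._current_length, src._origin_length, src._padding_dim)
    assert is_padded_tensor(out) and out._origin_length == 6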
3 changes: 0 additions & 3 deletions colossalai/tensor/p_tensor/__init__.py

This file was deleted.

3 changes: 3 additions & 0 deletions colossalai/tensor/padded_tensor/__init__.py
@@ -0,0 +1,3 @@
from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]
36 changes: 18 additions & 18 deletions colossalai/tensor/padded_tensor/api.py (renamed from colossalai/tensor/p_tensor/api.py)
@@ -16,16 +16,16 @@ def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:

def new_detach(self):
t_ = self._unpad_detach()
- t_.padding_dim = self.padding_dim
- t_.origin_length = self.origin_length
- t_.current_length = self.current_length
+ t_._padding_dim = self._padding_dim
+ t_._origin_length = self._origin_length
+ t_._current_length = self._current_length
return t_

def new_clone(self, *args, **kwargs):
t_ = self._unpad_clone(*args, **kwargs)
- t_.padding_dim = self.padding_dim
- t_.origin_length = self.origin_length
- t_.current_length = self.current_length
+ t_._padding_dim = self._padding_dim
+ t_._origin_length = self._origin_length
+ t_._current_length = self._current_length
return t_

# bind the new methods to the tensor
@@ -63,7 +63,7 @@ def is_padded_tensor(tensor: torch.Tensor) -> bool:
Returns:
bool: Whether the given tensor is a padding tensor.
"""
- return hasattr(tensor, "padding_dim")
+ return hasattr(tensor, "_padding_dim")


def to_padded_tensor(
@@ -89,9 +89,9 @@ def to_padded_tensor(
)
tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()

- setattr(tensor, "padding_dim", padding_dim)
- setattr(tensor, "origin_length", origin_length)
- setattr(tensor, "current_length", current_length)
+ tensor._padding_dim = padding_dim
+ tensor._origin_length = origin_length
+ tensor._current_length = current_length

_hijack_detach_and_clone(tensor)

@@ -103,25 +103,25 @@ def to_unpadded_tensor(ptensor: torch.Tensor):
return ptensor

unpad_slices = [slice(None)] * ptensor.dim()
- unpad_slices[ptensor.padding_dim] = slice(None, ptensor.origin_length)
+ unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
ptensor.data = ptensor.data[tuple(unpad_slices)]

- delattr(ptensor, "padding_dim")
- delattr(ptensor, "origin_length")
- delattr(ptensor, "current_length")
+ delattr(ptensor, "_padding_dim")
+ delattr(ptensor, "_origin_length")
+ delattr(ptensor, "_current_length")

_hijack_back_detach_and_clone(ptensor)

return ptensor


- def init_as_ptensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
+ def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
if is_padded_tensor(tensor):
return tensor

- setattr(tensor, "padding_dim", padding_dim)
- setattr(tensor, "origin_length", origin_length)
- setattr(tensor, "current_length", current_length)
+ tensor._padding_dim = padding_dim
+ tensor._origin_length = origin_length
+ tensor._current_length = current_length

_hijack_detach_and_clone(tensor)

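Taken together, the renamed API keeps its behaviour; only the module path and the attribute names (now prefixed with an underscore) change. A small CPU-only usage sketch with hypothetical shapes (the new unit test below exercises the same round trip on distributed tensors):

    import torch
    from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

    t = to_padded_tensor(torch.rand(6, 4), current_length=8, padding_dim=0)
    assert t.shape == (8, 4) and is_padded_tensor(t)
    # _origin_length is assumed to record the pre-padding length (6 here)
    assert (t._padding_dim, t._origin_length, t._current_length) == (0, 6, 8)

    c = t.clone()                      # the hijacked clone/detach carry the metadata over
    assert is_padded_tensor(c) and c._origin_length == 6

    u = to_unpadded_tensor(t)          # strips the zero padding and removes the attributes
    assert u.shape == (6, 4) and not is_padded_tensor(u)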
13 changes: 9 additions & 4 deletions colossalai/zero/gemini/gemini_ddp.py
@@ -27,7 +27,12 @@
is_customized_distributed_tensor,
is_distributed_tensor,
)
- from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+ from colossalai.tensor.padded_tensor import (
+     init_as_padded_tensor,
+     is_padded_tensor,
+     to_padded_tensor,
+     to_unpadded_tensor,
+ )
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored

@@ -462,8 +467,8 @@ def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
)
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
if is_padded_tensor(tensor):
- record_tensor = init_as_ptensor(
- record_tensor, tensor.current_length, tensor.origin_length, tensor.padding_dim
+ record_tensor = init_as_padded_tensor(
+ record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
)
record_tensor = to_unpadded_tensor(record_tensor)

@@ -661,7 +666,7 @@ def load(
global_shape = get_global_shape(dest_tensor)

if is_padded_tensor(dest_tensor):
- padding_dim = dest_tensor.padding_dim
+ padding_dim = dest_tensor._padding_dim
input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)

if source_device_mesh is not None and source_sharding_spec is not None:
17 changes: 11 additions & 6 deletions colossalai/zero/gemini/gemini_optimizer.py
@@ -28,7 +28,12 @@
is_customized_distributed_tensor,
is_distributed_tensor,
)
- from colossalai.tensor.p_tensor import init_as_ptensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+ from colossalai.tensor.padded_tensor import (
+     init_as_padded_tensor,
+     is_padded_tensor,
+     to_padded_tensor,
+     to_unpadded_tensor,
+ )
from colossalai.utils import disposable, is_ddp_ignored

from .chunk import Chunk, ChunkManager
@@ -495,8 +500,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
state_tensor = state_tensor.reshape(global_shape)
if is_padded_tensor(param):
- state_tensor = init_as_ptensor(
- state_tensor, param.current_length, param.origin_length, param.padding_dim
+ state_tensor = init_as_padded_tensor(
+ state_tensor, param._current_length, param._origin_length, param._padding_dim
)
state_tensor = to_unpadded_tensor(state_tensor)
collected_states[state_name] = state_tensor
@@ -555,8 +560,8 @@ def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
if is_padded_tensor(param):
- state_tensor = init_as_ptensor(
- state_tensor, param.current_length, param.origin_length, param.padding_dim
+ state_tensor = init_as_padded_tensor(
+ state_tensor, param._current_length, param._origin_length, param._padding_dim
)
state_tensor = to_unpadded_tensor(state_tensor)

@@ -732,7 +737,7 @@ def cast(param, state_range, value, global_shape, origin_shape, key=None):

if is_padded_tensor(real_param):
value = torch.reshape(value, origin_shape)
- padding_dim = real_param.padding_dim
+ padding_dim = real_param._padding_dim
value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)

if is_dtensor:
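For the `cast` change above: a saved optimizer-state value arrives in the unpadded origin shape and is padded back to the padded global shape along the recorded dimension before it is scattered into the local state. A sketch with hypothetical shapes (origin (6, 4), padded global (8, 4)):

    import torch
    from colossalai.tensor.padded_tensor import to_padded_tensor

    origin_shape, global_shape, padding_dim = (6, 4), (8, 4), 0   # padding_dim plays the role of real_param._padding_dim
    value = torch.arange(24, dtype=torch.float32)                 # flat state value as loaded from the checkpoint
    value = torch.reshape(value, origin_shape)
    value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
    assert value.shape == global_shape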
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/_utils.py
@@ -21,7 +21,7 @@
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
- from colossalai.tensor.p_tensor.api import is_padded_tensor, to_unpadded_tensor
+ from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor


def build_model(
48 changes: 48 additions & 0 deletions tests/test_tensor/test_padded_tensor/test_padded_tensor.py
@@ -0,0 +1,48 @@
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn


def check_padded_tensor(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
original_tensor = torch.rand(32, 64).to("cuda")

device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
assert padded_tensor.shape == (64, 64)

tensor_copy = padded_tensor.clone()
assert is_padded_tensor(tensor_copy)
assert is_distributed_tensor(tensor_copy)

tensor_detached = padded_tensor.detach()
assert is_padded_tensor(tensor_detached)
assert is_distributed_tensor(tensor_detached)
assert tensor_detached.requires_grad == False
assert tensor_detached.grad == None

unpadded_tensor = to_unpadded_tensor(padded_tensor)
assert unpadded_tensor.shape == d_tensor.shape
assert is_distributed_tensor(unpadded_tensor)

global_tensor = to_global(unpadded_tensor)
assert global_tensor.shape == original_tensor.shape


@rerun_if_address_is_in_use()
def test_padded_tensor():
world_size = 4
spawn(check_padded_tensor, world_size)


if __name__ == "__main__":
test_padded_tensor()
