-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Revert "[feature] new zero implementation (#1623)"
This reverts commit 5be118f.
- Loading branch information
1 parent
5be118f
commit fb8f86b
Showing
29 changed files
with
1,540 additions
and
918 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
from .chunk import TensorInfo, TensorState | ||
from .chunk import TensorInfo, Chunk, TensorState | ||
from .chunk_mgr import ChunkManager | ||
from .stateful_tensor_mgr import StatefulTensorMgr | ||
from .tensor_placement_policy import TensorPlacementPolicyFactory | ||
from .gemini_mgr import GeminiManager | ||
|
||
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState'] | ||
__all__ = [ | ||
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk', | ||
'TensorState' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
import torch | ||
import torch.distributed as dist | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from typing import Optional, Dict, List | ||
|
||
from colossalai.utils import get_current_device | ||
from colossalai.tensor import ProcessGroup as ColoProcessGroup | ||
|
||
|
||
class TensorState(Enum):
    """Life-cycle state of a tensor tracked inside a chunk."""
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


# Whitelist of legal (current_state, next_state) transitions.
# Any pair not listed here is treated as an invalid transition and ignored
# by Chunk.tensor_trans_state.
STATE_TRANS = (
    (TensorState.FREE, TensorState.HOLD),
    (TensorState.FREE, TensorState.COMPUTE),
    (TensorState.HOLD, TensorState.FREE),
    (TensorState.HOLD, TensorState.COMPUTE),
    (TensorState.COMPUTE, TensorState.HOLD),
    (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
    (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE),
    (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
    (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
    (TensorState.READY_FOR_REDUCE, TensorState.HOLD),
)
|
||
|
||
@dataclass
class TensorInfo:
    """Per-tensor bookkeeping record kept by a Chunk for each tensor it stores."""
    # current life-cycle state of the tensor (see TensorState)
    state: TensorState
    # start offset (in elements) of the tensor inside the chunk buffer
    offset: int
    # end offset (exclusive, in elements) of the tensor inside the chunk buffer
    end: int
|
||
|
||
class ChunkFullError(Exception):
    """Raised when appending a tensor would exceed the chunk's fixed capacity."""
    pass
|
||
|
||
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Return True if the underlying storage of ``tensor`` currently holds zero elements."""
    storage_numel = tensor.storage().size()
    return storage_numel == 0
|
||
|
||
def free_storage(tensor: torch.Tensor) -> None:
    """Release the memory behind ``tensor`` by shrinking its storage to zero elements.

    The tensor object itself (shape, dtype, device) stays valid; only the backing
    storage is freed. A no-op when the storage is already empty.
    """
    if is_storage_empty(tensor):
        return
    tensor.storage().resize_(0)
|
||
|
||
def alloc_storage(tensor: torch.Tensor) -> None:
    """Re-materialize the storage of ``tensor`` to hold ``tensor.numel()`` elements.

    Inverse of ``free_storage``; a no-op when the storage is already allocated.
    The newly allocated memory is uninitialized.
    """
    if not is_storage_empty(tensor):
        return
    tensor.storage().resize_(tensor.numel())
|
||
|
||
class Chunk:
    """
    A chunk is a contiguous memory space which contains multiple tensors.

    The chunk is owned by a single process of the data-parallel group (``src_rank``);
    non-owner processes keep a zero-sized storage for it and re-materialize the data
    via :meth:`access` (a broadcast from the owner) when the tensors are needed.

    Args:
        chunk_size (int): the number of elements in a chunk
        src_rank (int): the process which owns the chunk
        process_group (ColoProcessGroup): the data-parallel process group the chunk belongs to
        dtype (torch.dtype): the data type of the chunk
        init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
        force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
    """

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 force_data_on_cuda: bool = False) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.process_group = process_group
        # whether this process is the owner of the chunk (compared by local dp rank)
        self.is_src_rank = process_group.dp_local_rank() == src_rank
        # global rank of the owner, used as the root of broadcast/reduce collectives
        self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
        self.dtype = dtype
        device = init_device or get_current_device()
        if force_data_on_cuda:
            # keep a cuda buffer permanently plus a cpu shadow buffer; exactly one of
            # the two holds live data at a time (the other's storage is freed)
            self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
            self._cpu_data = torch.empty(chunk_size, dtype=dtype)
            if device.type == 'cuda':
                free_storage(self._cpu_data)
            else:
                free_storage(self.data)
        else:
            self.data = torch.empty(chunk_size, dtype=dtype, device=device)
            self._cpu_data = None

        # we only keep the chunk in full in the process by which the tensor is owned
        if not self.is_src_rank:
            free_storage(self._payload)

        # each tensor is associated with a TensorInfo to track meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        # total memory footprint of the chunk in bytes (fixed at construction)
        self.mem = self.size * self.data.element_size()

    def append(self, tensor: torch.Tensor) -> None:
        """
        Add a tensor to the chunk.

        On the owner process the tensor's data is copied into the chunk buffer and
        ``tensor.data`` is re-pointed into that buffer (the original storage is dropped);
        on other processes the tensor's storage is simply freed.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk

        Raises:
            ChunkFullError: if the tensor does not fit into the remaining space
        """
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()

        # raise exception when the chunk size is exceeded
        if new_utilized_size > self.size:
            raise ChunkFullError

        # set tensor state
        tensor_state = TensorState.FREE

        # if the process owns the rank, then copy the tensor to its chunk buffer
        # otherwise set its storage size to 0 to reduce memory consumption
        if self.is_src_rank:
            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
            tensor_state = TensorState.HOLD
            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
            # alias the tensor into the chunk buffer so they share memory from now on
            tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        """
        Release the memory space on processes which do not own the chunk.

        The owner keeps the authoritative copy; non-owners free their storage and
        mark all tensors FREE so a later :meth:`access` can restore them.
        """
        if not self.is_src_rank:
            free_storage(self._payload)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        # re-point every registered tensor's .data into the (possibly re-allocated or
        # moved) chunk buffer, using the offsets recorded at append time
        assert type(self._payload) == torch.Tensor
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None) -> None:
        # bulk state transition: move every tensor to next_state; when prev_state is
        # given, only tensors currently in prev_state are changed
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        """
        Broadcast the chunk to synchronize the tensors across data parallel processes.

        Non-owner processes first re-allocate their storage and move the chunk to the
        current device; then the owner's data is broadcast to everyone.
        """
        # recover the chunk on non-owner processes
        # and broadcast the chunk from the source to all processes
        if not self.is_src_rank:
            alloc_storage(self._payload)
            self.move_device(get_current_device(), update_ptr=False)
        dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())

        # update tensor meta info
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to a target device.

        Args:
            device (torch.device): the target device for data movement.
            update_ptr (bool): optional, whether to re-point the registered tensors
                into the moved buffer afterwards. Defaults to True.
        """
        if self._payload.device == device:
            return
        if self._cpu_data is None:
            # simple mode: a single buffer that is moved wholesale
            self.data.data = self.data.to(device)
        else:
            # force_data_on_cuda mode: swap the live data between the fixed cuda
            # buffer and the cpu shadow buffer instead of moving a tensor
            if device.type == 'cuda':
                # cpu -> cuda
                src = self._cpu_data
                dest = self.data
            else:
                # cuda -> cpu
                src = self.data
                dest = self._cpu_data
            alloc_storage(dest)
            dest.copy_(src)
            free_storage(src)

        if update_ptr:
            self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        """
        Reduce or all-reduce the chunk.

        Args:
            is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.

        NOTE(review): with ``is_all_reduce=False`` the reduced result is only defined
        on the owner rank (``dist.reduce`` semantics) — other ranks keep partial data.
        """
        self.move_device(get_current_device(), update_ptr=False)
        if is_all_reduce:
            dist.all_reduce(self.data, group=self.process_group.dp_process_group())
        else:
            dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """

        # As the gradient hook can be triggered either before or after post-backward
        # tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd
        # the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
        # this function only apply valid state transformation
        # invalid calls will be ignored and nothing changes
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            # print(
            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
            # )
            return
        self.tensors_info[tensor].state = tensor_state

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        tensor_info = self.tensors_info[tensor]
        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
        # keep the registered tensor aliased to the chunk buffer after the copy
        tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        """
        Check whether the chunk can be released: true iff every tensor is in HOLD state.
        """
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.HOLD:
                return False
        return True

    @property
    def can_move_device(self) -> bool:
        """
        Check whether the chunk can be moved across devices: false while any tensor
        is being computed on or is waiting to be reduced.
        """
        for tensor_info in self.tensors_info.values():
            if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
                return False
        return True

    @property
    def can_reduce(self) -> bool:
        """
        Check whether the chunk can be reduced: true iff every tensor is READY_FOR_REDUCE.
        """
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.READY_FOR_REDUCE:
                return False
        return True

    @property
    def is_empty(self) -> bool:
        """
        Check whether the chunk's live buffer currently has no allocated storage.
        """
        return is_storage_empty(self._payload)

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the used portion of the chunk contains any inf or nan values.
        """
        return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
            torch.isnan(self._payload[:self.utilized_size]).any().item()

    def copy_(self, dest_chunk: 'Chunk') -> None:
        """
        Copy the payload of ``dest_chunk`` into this chunk.

        NOTE(review): despite the parameter name, data flows FROM ``dest_chunk``
        INTO ``self`` (``torch.Tensor.copy_`` copies its argument into the callee).
        Both chunks must be materialized and have identical size/utilization.
        """
        assert not self.is_empty
        assert not dest_chunk.is_empty
        assert self.size == dest_chunk.size
        assert self.utilized_size == dest_chunk.utilized_size
        self._payload.copy_(dest_chunk._payload)
        self._update_tensors_ptr()

    @property
    def device_type(self) -> str:
        """
        Get the device type ('cuda' or 'cpu') of the chunk's live buffer.
        """
        return self._payload.device.type

    def __hash__(self) -> int:
        # identity hash: chunks are used as dict keys and must not compare by content
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        # identity equality, consistent with __hash__
        return self is __o

    def get_tensors(self) -> List[torch.Tensor]:
        """Return the list of tensors registered in this chunk, in append order."""
        return list(self.tensors_info.keys())

    @property
    def _payload(self) -> torch.Tensor:
        # the live buffer: the cpu shadow copy when it currently holds the data
        # (force_data_on_cuda mode, chunk offloaded), otherwise ``data``
        if self._cpu_data is None or is_storage_empty(self._cpu_data):
            return self.data
        return self._cpu_data
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.