Revert "[feature] new zero implementation" #1643

Merged
merged 1 commit into from
Sep 26, 2022
8 changes: 6 additions & 2 deletions colossalai/gemini/__init__.py
@@ -1,6 +1,10 @@
from .chunk import TensorInfo, TensorState
from .chunk import TensorInfo, Chunk, TensorState
from .chunk_mgr import ChunkManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
from .gemini_mgr import GeminiManager

__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState']
__all__ = [
    'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk',
    'TensorState'
]
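As a quick orientation (not part of the diff), the expanded `__all__` means the chunk primitives become importable directly from the package. A hypothetical snippet, assuming this revert is applied:

```python
# Hypothetical import sketch showing the symbols re-exported by this revert.
from colossalai.gemini import Chunk, ChunkManager, TensorInfo, TensorState

print([s.name for s in TensorState])    # the states a tensor inside a Chunk can take
```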
316 changes: 316 additions & 0 deletions colossalai/gemini/chunk.py
@@ -0,0 +1,316 @@
import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, List

from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup


class TensorState(Enum):
    FREE = 0
    COMPUTE = 1
    HOLD = 2
    HOLD_AFTER_BWD = 3
    READY_FOR_REDUCE = 4


STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
               (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
               (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
               (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
               (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
                                                                            TensorState.HOLD))


@dataclass
class TensorInfo:
    state: TensorState
    offset: int
    end: int


class ChunkFullError(Exception):
    pass


def is_storage_empty(tensor: torch.Tensor) -> bool:
    return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
    if not is_storage_empty(tensor):
        tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
    if is_storage_empty(tensor):
        tensor.storage().resize_(tensor.numel())


class Chunk:
    """
    A chunk is a contiguous memory space which contains multiple tensors.

    Args:
        chunk_size (int): the number of elements in a chunk
        src_rank (int): the process which owns the chunk
        process_group (ColoProcessGroup): the data parallel process group to which the chunk belongs
        dtype (torch.dtype): the data type of the chunk
        init_device (torch.device): optional, the device where the chunk is initialized. Defaults to None, which means the current CUDA device.
        force_data_on_cuda (bool): optional, if True, chunk.data is always kept on CUDA. Defaults to False.
    """

    def __init__(self,
                 chunk_size: int,
                 src_rank: int,
                 process_group: ColoProcessGroup,
                 dtype: torch.dtype,
                 init_device: Optional[torch.device] = None,
                 force_data_on_cuda: bool = False) -> None:
        self.size = chunk_size
        self.utilized_size = 0
        self.src_rank = src_rank
        self.process_group = process_group
        self.is_src_rank = process_group.dp_local_rank() == src_rank
        self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
        self.dtype = dtype
        device = init_device or get_current_device()
        if force_data_on_cuda:
            self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
            self._cpu_data = torch.empty(chunk_size, dtype=dtype)
            if device.type == 'cuda':
                free_storage(self._cpu_data)
            else:
                free_storage(self.data)
        else:
            self.data = torch.empty(chunk_size, dtype=dtype, device=device)
            self._cpu_data = None

        # we only keep the chunk fully materialized on the process that owns it
        if not self.is_src_rank:
            free_storage(self._payload)

        # each tensor is associated with a TensorInfo to track its meta info
        self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
        self.mem = self.size * self.data.element_size()

    def append(self, tensor: torch.Tensor) -> None:
        """
        Add a tensor to the chunk.

        Args:
            tensor (torch.Tensor): a tensor to be added to the chunk
        """
        assert tensor.dtype == self.dtype
        new_utilized_size = self.utilized_size + tensor.numel()

        # raise an exception when the chunk size would be exceeded
        if new_utilized_size > self.size:
            raise ChunkFullError

        # set tensor state
        tensor_state = TensorState.FREE

        # if this process owns the chunk, copy the tensor into its chunk buffer;
        # otherwise shrink the tensor's storage to 0 to reduce memory consumption
        if self.is_src_rank:
            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
            tensor_state = TensorState.HOLD
            assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
            tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
        else:
            tensor.storage().resize_(0)
        self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
        self.utilized_size = new_utilized_size

    def release(self) -> None:
        """
        Release the memory space on processes which do not own the chunk.
        """
        if not self.is_src_rank:
            free_storage(self._payload)
            self._update_tensors_state(TensorState.FREE)

    def _update_tensors_ptr(self) -> None:
        assert type(self._payload) == torch.Tensor
        for tensor, tensor_info in self.tensors_info.items():
            tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
        for tensor_info in self.tensors_info.values():
            if prev_state is None or tensor_info.state == prev_state:
                tensor_info.state = next_state

    def access(self) -> None:
        """
        Broadcast the chunk to synchronize the tensors across data parallel processes.
        """
        # recover the chunk on non-owner processes
        # and broadcast the chunk from the source to all processes
        if not self.is_src_rank:
            alloc_storage(self._payload)
        self.move_device(get_current_device(), update_ptr=False)
        dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())

        # update tensor meta info
        self._update_tensors_ptr()
        if not self.is_src_rank:
            self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

    def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
        """
        Move the chunk to a target device.

        Args:
            device (torch.device): the target device for data movement.
        """
        if self._payload.device == device:
            return
        if self._cpu_data is None:
            self.data.data = self.data.to(device)
        else:
            if device.type == 'cuda':
                # cpu -> cuda
                src = self._cpu_data
                dest = self.data
            else:
                # cuda -> cpu
                src = self.data
                dest = self._cpu_data
            alloc_storage(dest)
            dest.copy_(src)
            free_storage(src)

        if update_ptr:
            self._update_tensors_ptr()

    def reduce(self, is_all_reduce: bool = False) -> None:
        """
        Reduce or all-reduce the chunk.

        Args:
            is_all_reduce (bool): optional, whether to all-reduce the chunk. Defaults to False.
        """
        self.move_device(get_current_device(), update_ptr=False)
        if is_all_reduce:
            dist.all_reduce(self.data, group=self.process_group.dp_process_group())
        else:
            dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
        self._update_tensors_ptr()
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        """
        Make a transition of the tensor into the next state.

        Args:
            tensor (torch.Tensor): a torch Tensor object.
            tensor_state (TensorState): the target state for transition.
        """

        # As the gradient hook can be triggered either before or after post-backward,
        # a tensor's state can go compute -> hold_after_bwd -> ready_for_reduce
        # or compute -> ready_for_reduce -> hold_after_bwd.
        # The second path is invalid, so ready_for_reduce -> hold_after_bwd is simply ignored.
        # This function only applies valid state transitions;
        # invalid calls are ignored and nothing changes.
        if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
            # print(
            #     f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
            # )
            return
        self.tensors_info[tensor].state = tensor_state

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """
        Copy a data slice to the memory space indexed by the input tensor in the chunk.

        Args:
            tensor (torch.Tensor): the tensor used to retrieve meta information
            data_slice (torch.Tensor): the tensor to be copied to the chunk
        """
        tensor_info = self.tensors_info[tensor]
        self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
        tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

    @property
    def can_release(self) -> bool:
        """
        Check whether the chunk can be released.
        """
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.HOLD:
                return False
        return True

    @property
    def can_move_device(self) -> bool:
        """
        Check whether the chunk can be moved across devices.
        """
        for tensor_info in self.tensors_info.values():
            if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
                return False
        return True

    @property
    def can_reduce(self) -> bool:
        """
        Check whether the chunk can be reduced.
        """
        for tensor_info in self.tensors_info.values():
            if tensor_info.state != TensorState.READY_FOR_REDUCE:
                return False
        return True

    @property
    def is_empty(self) -> bool:
        """
        Check whether the chunk is empty.
        """
        return is_storage_empty(self._payload)

    def __repr__(self) -> str:
        return f'Chunk: src rank={self.src_rank}, size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'

    @property
    def has_inf_or_nan(self) -> bool:
        """
        Check if the chunk has inf or nan values.
        """
        return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
            torch.isnan(self._payload[:self.utilized_size]).any().item()

    def copy_(self, dest_chunk: 'Chunk'):
        """
        Copy the data of this chunk to a destination chunk.
        """
        assert not self.is_empty
        assert not dest_chunk.is_empty
        assert self.size == dest_chunk.size
        assert self.utilized_size == dest_chunk.utilized_size
        self._payload.copy_(dest_chunk._payload)
        self._update_tensors_ptr()

    @property
    def device_type(self) -> str:
        """
        Get the device type of the chunk.
        """
        return self._payload.device.type

    def __hash__(self) -> int:
        return hash(id(self))

    def __eq__(self, __o: object) -> bool:
        return self is __o

    def get_tensors(self) -> List[torch.Tensor]:
        return list(self.tensors_info.keys())

    @property
    def _payload(self) -> torch.Tensor:
        if self._cpu_data is None or is_storage_empty(self._cpu_data):
            return self.data
        return self._cpu_data
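For reviewers who want to see the intended call sequence, here is a minimal usage sketch (not part of the diff) of the Chunk lifecycle restored by this revert. It assumes a distributed environment has already been initialized (for example via colossalai.launch) and that constructing ProcessGroup() with default arguments yields the data parallel group; the sizes are arbitrary illustrative values.

```python
# Minimal sketch of the Chunk lifecycle; assumes torch.distributed is initialized
# (e.g. via colossalai.launch) and that ProcessGroup() defaults to the DP group.
import torch
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.gemini import Chunk, TensorState

pg = ColoProcessGroup()    # assumed: default construction gives the data parallel group
chunk = Chunk(chunk_size=1024, src_rank=0, process_group=pg, dtype=torch.float32)

# Register a tensor: the owner rank copies it into the payload and keeps it in
# HOLD state, while non-owner ranks shrink its storage to zero.
t = torch.randn(256)
chunk.append(t)

# Before compute: broadcast the payload so every rank has the data, then mark
# the tensor as in use by computation.
chunk.access()
chunk.tensor_trans_state(t, TensorState.COMPUTE)

# After backward: walk the tensor through the post-backward states, reduce
# across ranks, and release the payload on non-owner ranks.
chunk.tensor_trans_state(t, TensorState.HOLD_AFTER_BWD)
chunk.tensor_trans_state(t, TensorState.READY_FOR_REDUCE)
if chunk.can_reduce:
    chunk.reduce(is_all_reduce=True)
if chunk.can_release:
    chunk.release()
```

In the actual integration, ChunkManager drives these transitions through hooks rather than manual calls; the sketch only illustrates the state machine encoded in STATE_TRANS.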
3 changes: 0 additions & 3 deletions colossalai/gemini/chunk/__init__.py

This file was deleted.