[gemini] gemini mgr supports "cpu" placement policy #1118

Merged · 8 commits · Jun 15, 2022
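For context, a minimal usage sketch of what this change enables, based only on the signatures visible in the diffs below; the chunk size is illustrative and a colossalai distributed context (not shown here) is assumed to be initialized.

import torch
from colossalai.tensor.chunk import ChunkManager
from colossalai.gemini.gemini_mgr import GeminiManager

# ChunkManager signature as used in tests/test_tensor/test_chunk.py below
chunk_manager = ChunkManager(chunk_size=1024 * 1024, enable_distributed_storage=True)

# Before this PR, GeminiManager asserted placement_policy == 'cuda'.
# Any policy name known to PlacementPolicyFactory, including 'cpu', is now accepted.
gemini_manager = GeminiManager(placement_policy='cpu', chunk_manager=chunk_manager)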
8 changes: 5 additions & 3 deletions colossalai/gemini/gemini_mgr.py
@@ -1,4 +1,4 @@
import functools
import torch
from .memory_tracer.memstats_collector import MemStatsCollectorV2
from typing import List, Optional, Tuple
from time import time
@@ -15,8 +15,6 @@ class GeminiManager:
"""

def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
# TODO: remove assert
assert placement_policy == 'cuda', 'placement_policy can only be "cuda" now'
assert placement_policy in PlacementPolicyFactory.get_polocy_names()
policy_cls = PlacementPolicyFactory.create(placement_policy)
self._chunk_manager = chunk_manager
@@ -111,3 +109,7 @@ def cuda_margin_mem(self) -> Optional[float]:
@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats

@staticmethod
def get_default_device(policy_name: str) -> torch.device:
return PlacementPolicyFactory.get_default_device(policy_name)
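A hedged example of the new static helper above; which device each policy name maps to is decided by PlacementPolicyFactory and is not spelled out in this diff, so the comment is an assumption.

from colossalai.gemini.gemini_mgr import GeminiManager

default_device = GeminiManager.get_default_device('cpu')
# presumably torch.device('cpu') here, and a CUDA device for the 'cuda' policy
print(default_device)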
7 changes: 4 additions & 3 deletions colossalai/gemini/placement_policy.py
@@ -34,10 +34,11 @@ def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[Me

def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> int:
volume = 0
start = time()
for chunk in can_evict_chunks:
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
volume += chunk.mem
return volume, 0
return volume, time() - start


class CUDAPlacementPolicy(PlacementPolicy):
@@ -115,7 +116,7 @@ def evict_tensors(self,
if freed_cuda_model_data >= to_free_cuda_model_data:
break
freed_cuda_model_data += chunk.mem
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
self.chunk_manager.move_chunk(chunk, torch.device('cpu'), update_ptr=False)
if freed_cuda_model_data < to_free_cuda_model_data:
raise RuntimeError(
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
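Two behavioral changes are visible above: chunks are evicted to the CPU with update_ptr=False, so parameter .data pointers are only refreshed when the chunk is accessed again, and evict_tensors now returns the elapsed eviction time instead of a constant 0. A hypothetical call site (not part of this diff) might unpack the new return value like this:

freed_volume, eviction_time = policy.evict_tensors(can_evict_chunks)
# freed_volume: bytes moved off the GPU; eviction_time: seconds spent moving chunks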
2 changes: 2 additions & 0 deletions colossalai/nn/parallel/data_parallel.py
@@ -100,6 +100,8 @@ def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> No
self.fp32_params = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}
self.chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
self.chunk_manager.create_group('fp32_param')
# TODO: get param order and filter unused params
for p in module.parameters():
assert p.dtype == torch.half
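The two create_group calls above reflect a new ChunkManager contract introduced in chunk.py below: chunk groups are no longer created lazily inside append_tensor and must be registered before any tensor is appended, and force_data_on_cuda=True keeps a group's chunk payloads on the GPU. A minimal sketch of that contract, with the chunk size and tensor purely illustrative and a colossalai distributed context assumed:

import torch
from colossalai.tensor.chunk import ChunkManager

chunk_manager = ChunkManager(chunk_size=128, enable_distributed_storage=True)
chunk_manager.create_group('fp16_param', force_data_on_cuda=True)
chunk_manager.create_group('fp32_param')

p = torch.rand(8, 8).half()
chunk_manager.append_tensor(p, 'fp16_param')
# Appending to a group that was never created now fails (KeyError),
# since append_tensor no longer creates groups on the fly.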
122 changes: 91 additions & 31 deletions colossalai/tensor/chunk.py
@@ -36,8 +36,21 @@ class ChunkFullError(Exception):
pass


class Chunk:
def is_storage_empty(tensor: torch.Tensor) -> bool:
return tensor.storage().size() == 0


def free_storage(tensor: torch.Tensor) -> None:
if not is_storage_empty(tensor):
tensor.storage().resize_(0)


def alloc_storage(tensor: torch.Tensor) -> None:
if is_storage_empty(tensor):
tensor.storage().resize_(tensor.numel())


class Chunk:
"""
A chunk is a contiguous memory space which contains multiple tensors.

@@ -46,26 +59,37 @@ class Chunk:
src_rank (int): the process which owns the chunk
dtype (torch.dtype): the data type of the chunk
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
"""

def __init__(self,
chunk_size: int,
src_rank: int,
dtype: torch.dtype,
init_device: Optional[torch.device] = None) -> None:
init_device: Optional[torch.device] = None,
force_data_on_cuda: bool = False) -> None:
self.size = chunk_size
self.utilized_size = 0
self.src_rank = src_rank
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
self.dtype = dtype
self.device = init_device or get_current_device()
self.data = torch.empty(chunk_size, dtype=dtype, device=self.device)
device = init_device or get_current_device()
if force_data_on_cuda:
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
if device.type == 'cuda':
free_storage(self._cpu_data)
else:
free_storage(self.data)
else:
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
self._cpu_data = None

# we only keep the chunk in full in the process by which the tensor is owned
if not self.is_src_rank:
self.data.storage().resize_(0)
free_storage(self._payload)

# each tensor is associated with a TensorInfo to track meta info
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
self.mem = self.size * self.data.element_size()
@@ -83,16 +107,16 @@ def append(self, tensor: torch.Tensor) -> None:
# raise exception when the chunk size is exceeded
if new_utilized_size > self.size:
raise ChunkFullError

# set tensor state
tensor_state = TensorState.FREE

# if the process owns the rank, then copy the tensor to its chunk buffer
# otherwise set its storage size to 0 to reduce memory consumption
if self.is_src_rank:
self.data[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
tensor_state = TensorState.HOLD
tensor.data = self.data[self.utilized_size:new_utilized_size].view(tensor.shape)
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
else:
tensor.storage().resize_(0)
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
@@ -103,12 +127,12 @@ def release(self) -> None:
Release the memory space on processes which do not own the chunk.
"""
if not self.is_src_rank:
self.data.storage().resize_(0)
free_storage(self._payload)
self._update_tensors_state(TensorState.FREE)

def _update_tensors_ptr(self) -> None:
for tensor, tensor_info in self.tensors_info.items():
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
for tensor_info in self.tensors_info.values():
@@ -122,24 +146,41 @@ def access(self) -> None:
# recover the chunk on non-owner processes
# and broadcast the chunk from the source to all processes
if not self.is_src_rank:
self.data.storage().resize_(self.size)
self.data.data = self.data.to(get_current_device())
alloc_storage(self._payload)
self.move_device(get_current_device(), update_ptr=False)
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))

# update tensor meta info
self._update_tensors_ptr()
if not self.is_src_rank:
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)

def move_device(self, device: torch.device) -> None:
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
"""
Move the chunk to a target device.

Args:
device (torch.device): the target device for data movement.
"""
self.data.data = self.data.to(device)
self._update_tensors_ptr()
if self._payload.device == device:
return
if self._cpu_data is None:
self.data.data = self.data.to(device)
else:
if device.type == 'cuda':
# cpu -> cuda
src = self._cpu_data
dest = self.data
else:
# cuda -> cpu
src = self.data
dest = self._cpu_data
alloc_storage(dest)
dest.copy_(src)
free_storage(src)

if update_ptr:
self._update_tensors_ptr()

def reduce(self, is_all_reduce: bool = False) -> None:
"""
@@ -148,7 +189,7 @@ def reduce(self, is_all_reduce: bool = False) -> None:
Args:
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
"""
self.data.data = self.data.to(get_current_device())
self.move_device(get_current_device(), update_ptr=False)
if is_all_reduce:
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
else:
@@ -187,8 +228,8 @@ def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Ten
data_slice (torch.Tensor): the tensor to be copied to the chunk
"""
tensor_info = self.tensors_info[tensor]
self.data[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
tensor.data = self.data[tensor_info.offset:tensor_info.end].view(tensor.shape)
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.view(-1))
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)

@property
def can_release(self) -> bool:
@@ -225,7 +266,7 @@ def is_empty(self) -> bool:
"""
Check whether the chunk is empty.
"""
return self.data.storage().size() == 0
return is_storage_empty(self._payload)

def __repr__(self) -> str:
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
@@ -235,8 +276,8 @@ def has_inf_or_nan(self) -> bool:
"""
Check if the chunk has inf or nan values.
"""
return torch.isinf(self.data[:self.utilized_size]).any().item() or \
torch.isnan(self.data[:self.utilized_size]).any().item()
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
torch.isnan(self._payload[:self.utilized_size]).any().item()

def copy_(self, dest_chunk: 'Chunk'):
"""
@@ -246,15 +287,15 @@ def copy_(self, dest_chunk: 'Chunk'):
assert not dest_chunk.is_empty
assert self.size == dest_chunk.size
assert self.utilized_size == dest_chunk.utilized_size
self.data.copy_(dest_chunk.data)
self._payload.copy_(dest_chunk._payload)
self._update_tensors_ptr()

@property
def device_type(self) -> str:
"""
Get the device type of the chunk.
"""
return self.data.device.type
return self._payload.device.type

def __hash__(self) -> int:
return hash(id(self))
@@ -265,6 +306,12 @@ def __eq__(self, __o: object) -> bool:
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())

@property
def _payload(self) -> torch.Tensor:
if self._cpu_data is None or is_storage_empty(self._cpu_data):
return self.data
return self._cpu_data


class ChunkManager:
"""
@@ -285,13 +332,25 @@ def __init__(self,
self.enable_distributed_storage = enable_distributed_storage
self.device = init_device or get_current_device()
self.chunk_groups: Dict[str, Deque[Chunk]] = {}
self.groups_force_data_on_cuda: Dict[str, bool] = {}
self.tensor_chunk_map: Dict[torch.Tensor, Chunk] = {}
self.accessed_chunks: Set[Chunk] = set()
self.lazy_release_tensors: List[torch.Tensor] = []
if enable_distributed_storage and chunk_size is None:
self.rank_load: Dict[str, torch.Tensor] = {}
self.total_mem: Dict[str, int] = {'cpu': 0, 'cuda': 0}

def create_group(self, group_name: str, force_data_on_cuda: bool = False) -> None:
"""Create a chunk group.

Args:
group_name (str): group name
force_data_on_cuda (bool, optional): If True, the data of chunks in this group is always on cuda. Defaults to False.
"""
assert group_name not in self.chunk_groups
self.chunk_groups[group_name] = deque()
self.groups_force_data_on_cuda[group_name] = force_data_on_cuda

def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
"""
Append a tensor to a chunk.
@@ -304,19 +363,20 @@ def append_tensor(self, tensor: torch.Tensor, group_name: str) -> None:
if self.chunk_size is not None and tensor.numel() > self.chunk_size:
raise ValueError(
f'Cannot create chunk, got tensor numel ({tensor.numel()}) > chunk size ({self.chunk_size})')
if group_name not in self.chunk_groups:
self.chunk_groups[group_name] = deque()

try:
# append the tensor to the last chunk
self.chunk_groups[group_name][-1].append(tensor)
except (IndexError, ChunkFullError):
# the except statement will be triggered when there is no chunk or
# the last chunk in the chunk group is full
# this will create a new chunk and allocate this chunk to its corresponding process
chunk_size = self.chunk_size or tensor.numel()
src_rank = self._get_next_src_rank(group_name)
chunk = Chunk(chunk_size, src_rank, tensor.dtype, self.device)
chunk = Chunk(chunk_size,
src_rank,
tensor.dtype,
self.device,
force_data_on_cuda=self.groups_force_data_on_cuda[group_name])

if self.enable_distributed_storage and self.chunk_size is None:
self.rank_load[group_name][src_rank] += chunk_size
@@ -387,7 +447,7 @@ def release_chunk(self, chunk: Chunk) -> None:
# update the memory consumption after releasing
self.total_mem[chunk.device_type] -= chunk.mem

def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
def move_chunk(self, chunk: Chunk, device: torch.device, update_ptr: bool = True) -> None:
"""
Move the chunk to the target device.

@@ -399,7 +459,7 @@ def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
return
if chunk.can_move_device and not chunk.is_empty:
self.total_mem[chunk.device_type] -= chunk.mem
chunk.move_device(device)
chunk.move_device(device, update_ptr=update_ptr)
self.total_mem[chunk.device_type] += chunk.mem

def trans_tensor_state(self, tensor: torch.Tensor, state: TensorState) -> None:
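The core of the chunk.py change is the dual-buffer payload: with force_data_on_cuda=True a chunk keeps both a CUDA tensor (data) and a CPU tensor of the same size (_cpu_data), only one of which has allocated storage at any time; _payload returns whichever is live, and move_device copies between them instead of reallocating data on another device. A standalone sketch of that trick in plain PyTorch (not ColossalAI code, and it needs a CUDA device to run):

import torch

def free_storage(t: torch.Tensor) -> None:
    if t.storage().size() > 0:
        t.storage().resize_(0)         # releases the buffer; the tensor object survives

def alloc_storage(t: torch.Tensor) -> None:
    if t.storage().size() == 0:
        t.storage().resize_(t.numel())

cuda_buf = torch.empty(1024, device='cuda')
cpu_buf = torch.empty(1024)
free_storage(cpu_buf)                  # start resident on CUDA, like a fresh chunk

def move_to_cpu() -> None:
    alloc_storage(cpu_buf)
    cpu_buf.copy_(cuda_buf)            # cuda -> cpu, mirroring Chunk.move_device
    free_storage(cuda_buf)             # frees GPU memory without dropping the tensor

def payload() -> torch.Tensor:
    # mirrors Chunk._payload: prefer the CUDA buffer unless the CPU copy is live
    return cuda_buf if cpu_buf.storage().size() == 0 else cpu_buf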
1 change: 1 addition & 0 deletions tests/test_tensor/test_chunk.py
@@ -44,6 +44,7 @@ def run_chunk_zero(use_chunk, use_zero):
params = [torch.rand(8, 8) for _ in range(3)]
chunk_size = 128 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
chunk_manager.create_group('param')
assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0
for p in params: