Enabling ssd_offload training and basic tests. (#887)
* Enabled ssd_offload training and testing via tests/nn/data_parallel/test_fsdp_offload.py.
* Removed unused classes: SsdBuffer, SsdTensorHandleView, SsdParameter, SsdTensor.
* Enhanced test coverage of test_ssd_offloading_train_flatten_params_wrapper.
* Addressed review comments on PR #887.
* Updated CHANGELOG.
another-pjohnson committed Jan 5, 2022
1 parent 541bb8c commit c5e471b
Showing 7 changed files with 301 additions and 245 deletions.
CHANGELOG.md: 1 change (1 addition, 0 deletions)
@@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ### Changed
 - Fixed a corner case of FSDP init order and losing one of the flags [#880]
+- FSDP: Added basic training support for SSD Offload; it currently supports only flattened parameters. Renamed OffloadConfig.ssd_filepath_dir to the more generic OffloadConfig.dir. SSD Offload remains an experimental feature. [#887]

 ## [0.4.3] - 2021-11-18
fairscale/experimental/nn/ssd_offload.py: 201 changes (107 additions, 94 deletions)
@@ -9,7 +9,7 @@
 from functools import reduce
 import io
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, Type

 import numpy as np
 import torch
@@ -131,7 +131,7 @@ def from_file(
         return handle

     @classmethod
-    def from_tensor(cls, tensor: torch.Tensor) -> SsdTensorHandle:
+    def from_tensor(cls: Type[SsdTensorHandle], tensor: torch.Tensor) -> SsdTensorHandle:
         """Returns a new SsdTensorHandle from a tensor."""
         handle = cls(shape=tensor.shape, dtype=tensor.dtype, requires_grad=tensor.requires_grad)
         handle.tensor = tensor
@@ -159,6 +159,13 @@ def point_to_tensor(self, tensor: torch.Tensor) -> None:
         assert self._dtype == tensor.dtype
         self.tensor = tensor

+    # if resizing a handle that is part of an ssd buffer, care must be taken that the new size
+    # doesn't conflict with adjacent handles!
+    def point_to_resized_tensor(self, tensor: torch.Tensor) -> None:
+        assert self._dtype == tensor.dtype
+        self._shape = tensor.shape
+        self.tensor = tensor
+
     def to_tensor(self) -> torch.Tensor:
         """Returns the tensor represented by the SsdTensorHandle object.
@@ -173,13 +180,15 @@ def to_tensor(self) -> torch.Tensor:
         self.storage_state = StorageState.ON_CPU
         return self.tensor

-    def to_file(self, release_tensor_after_write: bool = True) -> None:
+    def to_file(self, permit_when_tensor_none: bool = False, release_tensor_after_write: bool = True) -> None:
         """Saves the tensor to disk and releases memory if specified."""
-        assert self.tensor is not None
-        write(self.tensor, self.filename, self.offset * self.tensor.element_size())
-        if release_tensor_after_write:
-            self.tensor = None
-            self.storage_state = StorageState.ON_DISK
+        assert self.tensor is not None or permit_when_tensor_none
+
+        if self.tensor is not None:
+            write(self.tensor, self.filename, self.offset * self.tensor.element_size())
+            if release_tensor_after_write:
+                self.tensor = None
+                self.storage_state = StorageState.ON_DISK

     def copy_into_tensor(self, tensor: torch.Tensor) -> None:
         """Copies SsdTensorHandle's data into the given tensor.
@@ -229,92 +238,96 @@ def unwrap(e: Any) -> torch.Tensor:
         return r


-class SsdBuffer:
-    """
-    The SsdBuffer represents a single buffer containing a list of tensors. Each of the
-    tensors are represented by a `SsdTensorHandle`.
-
-    Args:
-        num_elems (int): Dictates the size of the 1-D tensor.
-        dtype (torch.dtype): Dtype of the buffer.
-    """
+class SsdFlatParameter(torch.nn.Parameter, SsdTensorHandle):
+    """A parameter that is initialized from a list of parameters and can be
+    turned into a list of views as needed.
+    """

-    def __init__(self, num_elems: int, filename: str, dtype: torch.dtype = torch.float32) -> None:
-        self.buffer: torch.Tensor = torch.empty((num_elems,), dtype=dtype)
-        self.filename = filename
-        self.offset = 0
-        self.tensors: Dict[int, SsdTensorHandle] = {}
-        self.storage_state = StorageState.ON_CPU
-
-    def allocate(self, num_elems: int) -> SsdTensorHandle:
-        """Allocates a new tensor handle of size num_elems."""
-        assert num_elems > 0
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        assert self.can_alloc(num_elems)
-
-        tensor = self.buffer.narrow(0, self.offset, num_elems)
-
-        tensor_offset = self.offset
-        handle = SsdTensorHandle.from_tensor(tensor)
-        self.tensors[tensor_offset] = handle
-        handle.set_file_params(self.filename, tensor_offset)
-        self.offset += num_elems
-
-        return handle
-
-    def insert(self, tensor: torch.Tensor) -> SsdTensorHandle:
-        """Insert a new tensor by allocating memory and creating a corresponding handle."""
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        # For the non sharded case, the tensor will not be flattened
-        tensor = tensor.reshape(-1)
-        assert self.buffer.dtype == tensor.dtype
-        handle = self.allocate(tensor.numel())
-        handle.get_tensor().copy_(tensor)
-        return handle
+    def __new__(
+        cls, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True
+    ) -> "SsdFlatParameter":
+        """Make an object using the parent's __new__ function."""
+
+        # An empty or non-list input doesn't make sense.
+        if not isinstance(params, (list, tuple)) or len(params) == 0:
+            raise ValueError("A non-empty list or tuple argument is needed")
+
+        # Normally, all items are Parameters. But during pickling, we will have a single
+        # Tensor as the input and later in __init__, the correct _param_numels and _param_shapes
+        # are set.
+        if not all(isinstance(p, (torch.nn.Parameter, torch.Tensor)) for p in params):
+            raise ValueError("List items need to be Parameter types")
+
+        # Flattening involves (1) making a tensor flat (i.e. single dimensional) and (2) making a module
+        # hierarchy flat (using a single tensor to replace a tree of tensors). Therefore,
+        # adding back nesting and hierarchy is counter-productive. If nesting is encountered
+        # in the future, the reasonable thing to do is likely for the top-level SsdFlatParameter to
+        # absorb the nested one and keep the result flat, free from hierarchy.
+        if any(isinstance(p, SsdFlatParameter) for p in params):
+            raise ValueError("Nesting SsdFlatParameter is not supported")
+
+        dtype = params[0].dtype
+        size = sum(p.numel() for p in params)
+        r = SsdTensorHandle._make_wrapper_subclass(cls, (size,), dtype=dtype, requires_grad=requires_grad)  # type: ignore
+        return r

-    def can_alloc(self, num_elems: int) -> bool:
-        """Verify that you can allocate a tensor within the bounds
-        of the larger SsdBuffer memory buffer."""
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        return (self.offset + num_elems) <= self.buffer.numel()
-
-    def get_tensors(self) -> List[SsdTensorHandle]:
-        """Returns the list of tensor handles in SsdBuffer."""
-        return [t for t in self.tensors.values()]
-
-    def to_disk(self) -> None:
-        """Writes all tensors backed by handles to disk."""
-        if self.storage_state == StorageState.ON_DISK:
-            return
-        assert self.storage_state == StorageState.ON_CPU, self.storage_state
-        # We use `narrow` so that we write valid tensors that have been allocated
-        # as opposed to the entire SSD buffer.
-        valid_data = self.buffer.narrow(0, 0, self.offset)
-        write(valid_data, self.filename)
-
-        # Remove all Tensor references
-        for offset, t in self.tensors.items():
-            t.point_to_file(self.filename, offset)
-
-        # TODO(anj-s): Setting this to None does not result in GC picking
-        # this reference up.
-        self.buffer = torch.empty((1))
-        self.storage_state = StorageState.ON_DISK
-
-    def from_disk(self, num_elems: int, dtype: torch.dtype = torch.float32) -> None:
-        """Reads all tensors backed by handles into memory."""
-        if self.storage_state == StorageState.ON_CPU:
-            return
-        assert self.storage_state == StorageState.ON_DISK, self.storage_state
-        if num_elems < self.offset:
-            raise RuntimeError(
-                f"Attempted to load from file ssdbuffer of size: {self.offset} into a buffer that is of size: {num_elems}"
-            )
-        self.buffer = torch.empty((num_elems,), dtype=dtype)
-        valid_data = self.buffer.narrow(0, 0, self.offset)
-        read(valid_data, self.filename)
-
-        for offset, t in self.tensors.items():
-            t.point_to_tensor(self.buffer.narrow(0, t.offset, t._numel))
-
-        self.storage_state = StorageState.ON_CPU

+    def __init__(self, params: Sequence[torch.nn.Parameter], filename: str, requires_grad: bool = True):
+        """Initialize the _param_numels and _param_shapes lists."""
+        self._param_numels = [p.numel() for p in params]
+        total_numels = sum(self._param_numels)
+        assert (
+            self.numel() <= total_numels
+        ), f"Something wrong with __new__ method, {self.numel()} vs. {sum(self._param_numels)}"
+        self._param_shapes = [p.size() for p in params]
+
+        # These are set by FPW class below, not by this class itself.
+        self._param_infos: List[Tuple[str, torch.nn.Module, str]] = []
+        self._shared_param_infos: List[Tuple[str, str, torch.nn.Module, str, torch.nn.Module, str]] = []
+
+        super(SsdFlatParameter, self).__init__(shape=(total_numels,), dtype=params[0].dtype, requires_grad=requires_grad)  # type: ignore
+
+        tensor = torch.cat(
+            [p.detach().reshape(-1) if isinstance(p, torch.nn.Parameter) else p.reshape(-1) for p in params], 0
+        )
+        tensor.requires_grad = requires_grad
+        self.set_file_params(filename, 0)
+        self.point_to_tensor(tensor)

+    def get_param_views(self, external_data: Optional[torch.Tensor] = None) -> Iterator[torch.Tensor]:
+        """Return a generator of views that map to the original parameters."""
+        # Note, self.data could be sharded, so its numel is <= to the sum; the assert below is
+        # intentionally disabled (kept as a string literal):
+        """
+        assert self.data.numel() <= sum(
+            self._param_numels
+        ), f"Incorrect internal state {self.data.numel()} vs. {sum(self._param_numels)}"
+        """
+        if external_data is not None:
+            if external_data.numel() != sum(self._param_numels):
+                raise ValueError(
+                    f"Incorrect numel of supplied data: got {external_data.numel()} but expected {sum(self._param_numels)}"
+                )
+            return (t.view(s) for (t, s) in zip(external_data.split(self._param_numels), self._param_shapes))
+        else:
+            return (t.view(s) for (t, s) in zip(self.split(self._param_numels), self._param_shapes))

+    def metadata(self) -> Tuple[List[str], List[torch.Size], List[int]]:
+        """Return tuple of (names, shapes, numels) metadata for this flat parameter."""
+        names = [".".join([m, n]) if m else n for (m, _, n) in self._param_infos]
+        return names, self._param_shapes, self._param_numels
+
+    def __setstate__(self, state: Tuple[Any, Any, Any, Any]) -> None:
+        """Used by pickle to set the internal state."""
+        (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos) = state
+        assert self.numel() <= sum(
+            self._param_numels
+        ), f"Incorrect pickling {self.numel()} vs. {sum(self._param_numels)}"
+
+    def __reduce_ex__(self, proto: int) -> Tuple[Any, Any, Any]:
+        """Support pickling between ranks."""
+        return (
+            SsdFlatParameter,  # Callable
+            # Args to the callable above
+            ([self.data], self.filename, self.requires_grad),
+            # Args to __setstate__
+            (self._param_numels, self._param_shapes, self._param_infos, self._shared_param_infos),
+        )
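Putting the new class together, an editor's sketch of flattening a module's parameters into one disk-backed `SsdFlatParameter` and recovering per-parameter views (names are from the diff above; the file path is illustrative):

```python
# Sketch: flatten a module's parameters into a single disk-backed parameter,
# then rebuild correctly shaped views for each original parameter.
import torch
from fairscale.experimental.nn.ssd_offload import SsdFlatParameter

module = torch.nn.Linear(4, 2)
flat = SsdFlatParameter(list(module.parameters()), filename="/tmp/flat_param.bin")

views = list(flat.get_param_views())
assert [v.shape for v in views] == [p.shape for p in module.parameters()]

flat.to_file(release_tensor_after_write=True)  # page the flat parameter out to SSD
```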
