[Tensor] add Parameter inheritance for ColoParameter #1041

Merged
merged 5 commits on May 30, 2022
34 changes: 30 additions & 4 deletions colossalai/tensor/colo_parameter.py
@@ -3,23 +3,23 @@
import torch
from colossalai.tensor import TensorSpec, distspec
from copy import copy
from typing import Optional


-class ColoParameter(ColoTensor):
+class ColoParameter(ColoTensor, torch.nn.Parameter):
Contributor:
Can ColoParameter use __torch_function__? Can you add a test for this function?

Contributor Author:
It works.
[image]
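A minimal sketch of the kind of check being referred to here (illustrative only, not part of the PR's diff; the new test in tests/test_tensor/test_parameter.py exercises the same __torch_function__ path):

import torch
from colossalai.tensor import ColoParameter, ColoTensor

param = ColoParameter(torch.randn(4, 4))
cloned = torch.clone(param)  # dispatched through ColoTensor.__torch_function__
assert isinstance(cloned, ColoTensor)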

r"""A kind of ColoTensor to be considered as a module parameter.

"""

def __new__(cls,
data: torch.Tensor,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)

def __init__(self,
data: torch.Tensor,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._spec = copy(spec)
@@ -43,4 +43,30 @@ def from_torch_tensor(tensor: torch.Tensor,

    def __repr__(self):
        return f'ColoParameter: {torch.Tensor.__repr__(self)}'

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                data = self.data.clone()
            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
            memo[id(self)] = tensor
            return tensor

    def __reduce_ex__(self, proto):
        # Adapted from torch._utils._rebuild_parameter
        # def _rebuild_colo_parameter(data, requires_grad, backward_hooks):
        #     colo_param = ColoParameter(data, requires_grad)
        #     colo_param._backward_hooks = backward_hooks
        #     return colo_param

        # return (
        #     _rebuild_colo_parameter,
        #     (self.data, self.requires_grad, OrderedDict())
        # )

        # TODO(jzy) we don't support object reflection now.
        # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
        raise NotImplementedError
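A practical consequence of the two methods above, sketched as an illustration (not part of the PR): copy.deepcopy goes through __deepcopy__ and succeeds, while pickling falls through to __reduce_ex__ and currently raises NotImplementedError.

import copy
import pickle

import torch
from colossalai.tensor import ColoParameter

p = ColoParameter(torch.ones(2, 2))
p2 = copy.deepcopy(p)  # handled by ColoParameter.__deepcopy__
assert p2.requires_grad == p.requires_grad

try:
    pickle.dumps(p)  # ColoParameter.__reduce_ex__ is not implemented yet
except NotImplementedError:
    pass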

94 changes: 0 additions & 94 deletions colossalai/utils/model/colo_init_context.py
@@ -24,96 +24,6 @@ def _named_params_with_replica(
            name = mod_prefix + ('.' if mod_prefix else '') + name
            yield name, val


# Adapted from torch.nn.module.Module.register_param


def _register_parameter_with_colotensor(self, name: str, param):
    if '_parameters' not in self.__dict__:
        raise AttributeError("cannot assign parameter before Module.__init__() call")

    if not isinstance(name, torch._six.string_classes):
        raise TypeError("parameter name should be a string. "
                        "Got {}".format(torch.typename(name)))
    if '.' in name:
        raise KeyError("parameter name can't contain \".\"")
    if name == '':
        raise KeyError("parameter name can't be empty string \"\"")
    if hasattr(self, name) and name not in self._parameters:
        raise KeyError("attribute '{}' already exists".format(name))

    if param is None:
        self._parameters[name] = None
    elif not isinstance(param, (torch.nn.Parameter, ColoParameter)):
        raise TypeError("cannot assign '{}' object to parameter '{}' "
                        "(torch.nn.Parameter or ColoParameter or None required)".format(torch.typename(param), name))
    elif param.grad_fn:
        raise ValueError("Cannot assign non-leaf Tensor to parameter '{0}'. Model "
                         "parameters must be created explicitly. To express '{0}' "
                         "as a function of another Tensor, compute the value in "
                         "the forward() method.".format(name))
    else:
        self._parameters[name] = param


# Adapted from torch.nn.module.Module.__setattr__


def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.nn.Module, ColoTensor]):

    def remove_from(*dicts_or_sets):
        for d in dicts_or_sets:
            if name in d:
                if isinstance(d, dict):
                    del d[name]
                else:
                    d.discard(name)

    params = self.__dict__.get('_parameters')
    if isinstance(value, (ColoParameter, torch.nn.Parameter)):
        if params is None:
            raise AttributeError("cannot assign parameters before Module.__init__() call")
        remove_from(self.__dict__, self._buffers, self._modules, self._non_persistent_buffers_set)
        self.register_parameter(name, value)
    elif params is not None and name in params:
        if value is not None:
            raise TypeError("cannot assign '{}' as parameter '{}' "
                            "(torch.nn.Parameter or None expected)".format(torch.typename(value), name))
        self.register_parameter(name, value)
    else:
        modules = self.__dict__.get('_modules')
        if isinstance(value, torch.nn.Module):
            if modules is None:
                raise AttributeError("cannot assign module before Module.__init__() call")
            remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
            modules[name] = value
        elif modules is not None and name in modules:
            if value is not None:
                raise TypeError("cannot assign '{}' as child module '{}' "
                                "(torch.nn.Module or None expected)".format(torch.typename(value), name))
            modules[name] = value
        else:
            buffers = self.__dict__.get('_buffers')
            if buffers is not None and name in buffers:
                if value is not None and not isinstance(value, torch.Tensor):
                    raise TypeError("cannot assign '{}' as buffer '{}' "
                                    "(torch.Tensor or None expected)".format(torch.typename(value), name))
                buffers[name] = value
            else:
                object.__setattr__(self, name, value)

def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
    module_path, _, param_name = target.rpartition(".")

    mod: torch.nn.Module = self.get_submodule(module_path)

    if not hasattr(mod, param_name):
        raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`")

    param = getattr(mod, param_name)
    return param

def ColoModulize(module):
    """
    Replacing the parameters() and named_parameters() with our customized ones
@@ -134,10 +44,6 @@ def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = to
        self._lazy_memory_allocate = lazy_memory_allocate
        self._device = device

        torch.nn.Module.__setattr__ = _setattr_with_colotensor
        torch.nn.Module.register_parameter = _register_parameter_with_colotensor
        torch.nn.Module.get_parameter = _get_parameter_with_colotensor

        self._register_colo_modules()

    def _register_colo_modules(self):
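The monkey patches removed above are no longer needed because ColoParameter now satisfies torch's own isinstance(..., torch.nn.Parameter) checks, so the stock torch.nn.Module machinery registers it. A minimal sketch of the behavior this relies on (illustrative, not part of the PR):

import torch
from colossalai.tensor import ColoParameter

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute assignment now goes through the unpatched
        # torch.nn.Module.__setattr__ / register_parameter.
        self.weight = ColoParameter(torch.randn(4, 4))

net = Net()
assert isinstance(net.weight, ColoParameter)
assert 'weight' in dict(net.named_parameters())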
4 changes: 2 additions & 2 deletions tests/test_tensor/test_model.py
@@ -353,5 +353,5 @@ def _test_pretrain_load(world_size):
if __name__ == '__main__':
    # test_model_parameters()
    # test_colo_optimizer()
-    test_model(4)
-    # _test_pretrain_load(4)
+    # test_model(4)
+    _test_pretrain_load(4)
26 changes: 26 additions & 0 deletions tests/test_tensor/test_parameter.py
@@ -0,0 +1,26 @@
from colossalai.tensor import ColoParameter, ColoTensor
import torch
from numpy import allclose
from _utils import tensor_equal

def test_multiinheritance():
    colo_param = ColoParameter()
    assert isinstance(colo_param, ColoTensor)
    assert isinstance(colo_param, torch.nn.Parameter)

    # __deepcopy__ overload
    import copy
    colo_param2 = copy.deepcopy(colo_param)
    assert isinstance(colo_param2, ColoParameter)
    assert tensor_equal(colo_param.data, colo_param2.data)
    assert colo_param.requires_grad == colo_param2.requires_grad

    # __repr__ overload
    assert 'ColoParameter' in str(colo_param)

    # __torch_function__
    clone_param = torch.clone(colo_param)
    assert isinstance(clone_param, ColoTensor)

if __name__ == '__main__':
    test_multiinheritance()
1 change: 1 addition & 0 deletions tests/test_tensor/test_tensor.py
@@ -46,3 +46,4 @@ def test_operand():
    t_ref_res = t_ref + t_ref
    t_res = t + t
    assert torch.allclose(t_ref_res, t_res)