Merge pull request #18 from opendilab/dev/update
dev(narugo): add support for vmap
HansBug committed Sep 19, 2023
2 parents c8a07a3 + 114f98b commit 356a7f6
Showing 9 changed files with 252 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/doc.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
-        python-version: [ 3.7 ]
+        python-version: [ 3.8 ]

services:
plantuml:
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
@@ -36,6 +36,8 @@ jobs:
numpy-version: '1.22.0'
- python-version: '3.7'
numpy-version: '1.24.0'
- python-version: '3.9'
numpy-version: '1.18.0'
- python-version: '3.10'
numpy-version: '1.18.0'
- python-version: '3.11'
2 changes: 1 addition & 1 deletion requirements-doc.txt
@@ -1,7 +1,7 @@
Jinja2~=3.0.0
sphinx~=3.2.0
sphinx_rtd_theme~=0.4.3
-enum_tools
+enum_tools~=0.9.0
sphinx-toolbox
plantumlcli>=0.0.2
packaging
44 changes: 44 additions & 0 deletions test/torch/funcs/test_construct.py
@@ -190,6 +190,50 @@ def test_randn_like(self):
}
})

@choose_mark()
def test_rand(self):
_target = ttorch.rand(200, 300)
assert 0.45 <= _target.view(60000).mean().tolist() <= 0.55
assert _target.shape == torch.Size([200, 300])

_target = ttorch.rand({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})

@choose_mark()
def test_rand_like(self):
_target = ttorch.rand_like(torch.ones(200, 300))
assert 0.45 <= _target.view(60000).mean().tolist() <= 0.55
assert _target.shape == torch.Size([200, 300])

_target = ttorch.rand_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float),
'x': {
'c': torch.tensor([5, 6, 7], dtype=torch.float64),
'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
}
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})

@choose_mark()
def test_randint(self):
_target = ttorch.randint(-10, 10, {
106 changes: 106 additions & 0 deletions test/torch/funcs/test_wrapper.py
@@ -0,0 +1,106 @@
from unittest import skipUnless

import pytest
import torch
from hbutils.testing import vpip

import treetensor.torch as ttorch
from treetensor.torch import Size


@pytest.fixture()
def treetensor_x():
return ttorch.randn({
'a': (2, 5, 7),
'b': {
'x': (3, 4, 6),
}
})


@pytest.fixture()
def treetensor_y():
return ttorch.randn({
'a': (2, 5, 7),
'b': {
'x': (3, 4, 6),
}
})


@pytest.mark.unittest
class TestTorchTensorWrapper:
@skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
def test_vmap(self, treetensor_x, treetensor_y):
f = lambda x, y: (x.sum() + y.mean() * 2)
native_vf = torch.vmap(f)
tv_vf = ttorch.vmap(f)
r = tv_vf(treetensor_x, treetensor_y)
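        # By default ``vmap`` maps over dim 0 of every leaf; since ``f`` returns
        # a scalar, 'a' (2, 5, 7) -> (2,) and 'b.x' (3, 4, 6) -> (3,).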

assert r.shape == Size({
'a': (2,),
'b': {
'x': (3,)
},
})
assert ttorch.isclose(
r,
ttorch.tensor({
'a': native_vf(treetensor_x.a, treetensor_y.a),
'b': {
'x': native_vf(treetensor_x.b.x, treetensor_y.b.x),
}
})
).all()

@skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
def test_vmap_in_dims(self, treetensor_x, treetensor_y):
f = lambda x, y: (x.sum() + y.mean() * 2)
native_vf = torch.vmap(f, in_dims=1)
tv_vf = ttorch.vmap(f, in_dims=1)
r = tv_vf(treetensor_x, treetensor_y)
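        # With ``in_dims=1`` the mapped dim is dim 1 of every leaf:
        # 'a' (2, 5, 7) -> (5,) and 'b.x' (3, 4, 6) -> (4,).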

assert r.shape == Size({
'a': (5,),
'b': {
'x': (4,)
},
})
assert ttorch.isclose(
r,
ttorch.tensor({
'a': native_vf(treetensor_x.a, treetensor_y.a),
'b': {
'x': native_vf(treetensor_x.b.x, treetensor_y.b.x),
}
})
).all()

@skipUnless(vpip('torch') >= '2', 'Torch 2 required.')
def test_vmap_nested(self, treetensor_x, treetensor_y):
f = lambda x, y: (x.sum() + y.mean() * 2)
native_vf = torch.vmap(torch.vmap(f))
tv_vf = ttorch.vmap(ttorch.vmap(f))
r = tv_vf(treetensor_x, treetensor_y)
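        # Nesting two vmaps maps over the first two dims of every leaf:
        # 'a' (2, 5, 7) -> (2, 5) and 'b.x' (3, 4, 6) -> (3, 4).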

assert r.shape == Size({
'a': (2, 5),
'b': {
'x': (3, 4)
},
})
assert ttorch.isclose(
r,
ttorch.tensor({
'a': native_vf(treetensor_x.a, treetensor_y.a),
'b': {
'x': native_vf(treetensor_x.b.x, treetensor_y.b.x),
}
})
).all()

@skipUnless(vpip('torch') < '2', 'Torch 1.x required.')
def test_vmap_torch_1x(self, treetensor_x, treetensor_y):
f = lambda x, y: (x.sum() + y.mean() * 2)
with pytest.raises(NotImplementedError):
_ = ttorch.vmap(f)
3 changes: 3 additions & 0 deletions treetensor/torch/funcs/__init__.py
@@ -14,6 +14,8 @@
from .operation import __all__ as _operation_all
from .reduction import *
from .reduction import __all__ as _reduction_all
from .wrapper import *
from .wrapper import __all__ as _wrapper_all
from ...utils import module_autoremove

__all__ = [
@@ -24,6 +26,7 @@
*_matrix_all,
*_operation_all,
*_reduction_all,
*_wrapper_all,
]

_current_module = sys.modules[__name__]
17 changes: 17 additions & 0 deletions treetensor/torch/funcs/base.py
@@ -1,4 +1,7 @@
from functools import wraps

import torch
from hbutils.testing import vpip
from treevalue import func_treelize as original_func_treelize

from ..tensor import Tensor
@@ -11,3 +14,17 @@
auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)])
get_func_from_torch = module_func_loader(torch, Tensor,
[(torch.is_tensor, Tensor)])

_is_torch_2 = vpip('torch') >= '2'


def wrap_for_treelize(*args, **kwargs):
    def _decorator(func):
        @wraps(func)
        def _new_func(*args_, **kwargs_):
            # ``func`` is expected to return a callable (e.g. the mapped function
            # produced by ``torch.vmap``); treelize that callable so it maps over
            # the leaves of tree tensors instead of plain tensors.
            retval = func(*args_, **kwargs_)
            return func_treelize(*args, **kwargs)(retval)

        return _new_func

    return _decorator
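
For orientation, a minimal sketch of how this decorator is used (hypothetical names, not part of this diff): the decorated function must return a callable, and that callable comes back treelized so it maps over the leaves of tree tensors.

    @wrap_for_treelize()
    def make_scaler(n):
        def _scale(x):    # written against a plain torch.Tensor
            return x * n
        return _scale     # the returned callable is treelized

    tree_scale = make_scaler(2)  # now accepts treetensor.torch.Tensor trees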
57 changes: 57 additions & 0 deletions treetensor/torch/funcs/construct.py
@@ -10,6 +10,7 @@
'tensor', 'as_tensor', 'clone',
'zeros', 'zeros_like',
'randn', 'randn_like',
'rand', 'rand_like',
'randint', 'randint_like',
'ones', 'ones_like',
'full', 'full_like',
@@ -216,6 +217,62 @@ def randn_like(input, *args, **kwargs):
return stream_call(torch.randn_like, input, *args, **kwargs)


@doc_from_base()
@args_treelize
@func_treelize()
def rand(*args, **kwargs):
"""
In ``treetensor``, you can use ``rand`` to create a tree of tensors with numbers
obey standard normal distribution.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.rand(2, 3) # the same as torch.rand(2, 3)
tensor([[-0.8534, -0.5754, -0.2507],
[ 0.0826, -1.4110, 0.9748]])
>>> ttorch.rand({'a': (2, 3), 'b': {'x': (4, )}})
<Tensor 0x7ff363bb6518>
├── a --> tensor([[ 0.5398, 0.7529, -2.0339],
│ [-0.5722, -1.1900, 0.7945]])
└── b --> <Tensor 0x7ff363bb6438>
└── x --> tensor([-0.7181, 0.1670, -1.3587, -1.5129])
"""
return stream_call(torch.rand, *args, **kwargs)


# noinspection PyShadowingBuiltins
@doc_from_base()
@args_treelize
@func_treelize()
def rand_like(input, *args, **kwargs):
"""
In ``treetensor``, you can use ``rand_like`` to create a tree of tensors with numbers
obey standard normal distribution like another tree.
Example::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.rand_like(torch.ones(2, 3)) # the same as torch.rand_like(torch.ones(2, 3))
tensor([[ 1.8436, 0.2601, 0.9687],
[ 1.6430, -0.1765, -1.1732]])
>>> ttorch.rand_like({
... 'a': torch.ones(2, 3),
... 'b': {'x': torch.ones(4, )},
... })
<Tensor 0x7ff3d6f3cb38>
├── a --> tensor([[-0.1532, 1.3965, -1.2956],
│ [-0.0750, 0.6475, 1.1421]])
└── b --> <Tensor 0x7ff3d6f420b8>
└── x --> tensor([ 0.1730, 1.6085, 0.6487, -1.1022])
"""
return stream_call(torch.rand_like, input, *args, **kwargs)


@doc_from_base()
@args_treelize
@func_treelize()
21 changes: 21 additions & 0 deletions treetensor/torch/funcs/wrapper.py
@@ -0,0 +1,21 @@
import torch

from .base import doc_from_base, wrap_for_treelize, _is_torch_2

__all__ = [
'vmap',
]

if _is_torch_2:
@doc_from_base()
@wrap_for_treelize()
def vmap(func, *args, **kwargs):
return torch.vmap(func, *args, **kwargs)

else:
def vmap(func, *args, **kwargs):
"""
.. warning:
:method:`treetensor.torch.vmap` is not supported for torch 1.x.
"""
raise NotImplementedError(f'Function vmap is not supported in torch {torch.__version__}.')
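
Putting it together, usage on torch 2.x mirrors the tests above (a sketch; shapes follow test_wrapper.py):

    import treetensor.torch as ttorch

    x = ttorch.randn({'a': (2, 5, 7), 'b': {'x': (3, 4, 6)}})
    y = ttorch.randn({'a': (2, 5, 7), 'b': {'x': (3, 4, 6)}})
    f = lambda a, b: a.sum() + b.mean() * 2

    r = ttorch.vmap(f)(x, y)
    # r.shape == ttorch.Size({'a': (2,), 'b': {'x': (3,)}}):
    # f is applied per leaf, mapped over dim 0, exactly like torch.vmap.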
